ShynBui commited on
Commit
20127c2
·
verified ·
1 Parent(s): 7f9ed76

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -90,7 +90,7 @@ def train_step(file=None, start_idx=0):
90
  try:
91
  success, message = train_batch(dataloader)
92
  if not success:
93
- return start_idx # Trả về start_idx nếu lỗi xảy ra
94
 
95
  except HTMLError as e:
96
  print(e)
@@ -98,20 +98,20 @@ def train_step(file=None, start_idx=0):
98
  os.makedirs('./checkpoint')
99
  print('Save checkpoint')
100
  torch.save(model.state_dict(), "./checkpoint/model.pt")
101
- return start_idx # Trả về start_idx để lưu lại vị trí
102
 
103
  start_idx = end_idx
104
 
105
  if not os.path.exists('./checkpoint'):
106
  os.makedirs('./checkpoint')
107
  torch.save(model.state_dict(), "./checkpoint/model.pt")
108
- return start_idx
109
 
110
 
111
  if __name__ == "__main__":
112
  iface = gr.Interface(
113
  fn=train_step,
114
  inputs=[gr.File(label="Upload CSV"), gr.Textbox()],
115
- outputs="text"
116
  )
117
  iface.launch()
 
90
  try:
91
  success, message = train_batch(dataloader)
92
  if not success:
93
+ return start_idx, "./checkpoint/model.pt # Trả về start_idx nếu lỗi xảy ra
94
 
95
  except HTMLError as e:
96
  print(e)
 
98
  os.makedirs('./checkpoint')
99
  print('Save checkpoint')
100
  torch.save(model.state_dict(), "./checkpoint/model.pt")
101
+ return start_idx, "./checkpoint/model.pt" # Trả về start_idx để lưu lại vị trí
102
 
103
  start_idx = end_idx
104
 
105
  if not os.path.exists('./checkpoint'):
106
  os.makedirs('./checkpoint')
107
  torch.save(model.state_dict(), "./checkpoint/model.pt")
108
+ return start_idx, "./checkpoint/model.pt
109
 
110
 
111
  if __name__ == "__main__":
112
  iface = gr.Interface(
113
  fn=train_step,
114
  inputs=[gr.File(label="Upload CSV"), gr.Textbox()],
115
+ outputs=["text", gradio.File()]
116
  )
117
  iface.launch()