Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -34,7 +34,7 @@ def get_dataloader(start, end, batch_size=8):
|
|
| 34 |
subset = torch.utils.data.Subset(global_data, range(start, end))
|
| 35 |
return DataLoader(subset, batch_size=batch_size)
|
| 36 |
|
| 37 |
-
@spaces.GPU(duration=
|
| 38 |
def train_batch(dataloader):
|
| 39 |
model.train()
|
| 40 |
start_time = time.time()
|
|
@@ -88,8 +88,10 @@ def train_step(file=None, start_idx=0):
|
|
| 88 |
return start_idx # Trả về start_idx nếu lỗi xảy ra
|
| 89 |
|
| 90 |
except HTMLError as e:
|
| 91 |
-
print("Exceeded GPU quota
|
| 92 |
-
|
|
|
|
|
|
|
| 93 |
return start_idx # Trả về start_idx để lưu lại vị trí
|
| 94 |
|
| 95 |
start_idx = end_idx
|
|
|
|
| 34 |
subset = torch.utils.data.Subset(global_data, range(start, end))
|
| 35 |
return DataLoader(subset, batch_size=batch_size)
|
| 36 |
|
| 37 |
+
@spaces.GPU(duration=120)
|
| 38 |
def train_batch(dataloader):
|
| 39 |
model.train()
|
| 40 |
start_time = time.time()
|
|
|
|
| 88 |
return start_idx # Trả về start_idx nếu lỗi xảy ra
|
| 89 |
|
| 90 |
except HTMLError as e:
|
| 91 |
+
print("Exceeded GPU quota.")
|
| 92 |
+
if not os.path.exists('./checkpoint'):
|
| 93 |
+
os.makedirs('./checkpoint')
|
| 94 |
+
torch.save(model.state_dict(), "./checkpoint/model.pt")
|
| 95 |
return start_idx # Trả về start_idx để lưu lại vị trí
|
| 96 |
|
| 97 |
start_idx = end_idx
|