Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -70,6 +70,11 @@ def train_step(file=None, start_idx=0):
|
|
| 70 |
if os.path.exists("./checkpoint/model.pt"):
|
| 71 |
print("Loading checkpoint...")
|
| 72 |
model.load_state_dict(torch.load("./checkpoint/model.pt"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
batch_size = 8
|
| 75 |
total_samples = len(global_data)
|
|
|
|
| 70 |
if os.path.exists("./checkpoint/model.pt"):
|
| 71 |
print("Loading checkpoint...")
|
| 72 |
model.load_state_dict(torch.load("./checkpoint/model.pt"))
|
| 73 |
+
else:
|
| 74 |
+
print("Checkpoint not found, starting fresh...")
|
| 75 |
+
if not os.path.exists('./checkpoint'):
|
| 76 |
+
os.makedirs('./checkpoint')
|
| 77 |
+
torch.save(model.state_dict(), "./checkpoint/model.pt")
|
| 78 |
|
| 79 |
batch_size = 8
|
| 80 |
total_samples = len(global_data)
|