Spaces:
Runtime error
Runtime error
Update watchdog.py
Browse files- watchdog.py +13 -2
watchdog.py
CHANGED
|
@@ -85,10 +85,16 @@ def retrain_model():
|
|
| 85 |
if input_ids is None:
|
| 86 |
return "⚠️ Not enough data to retrain.", None, "Please log more feedback first."
|
| 87 |
|
|
|
|
|
|
|
| 88 |
config = mutate_config()
|
| 89 |
-
model = EvoTransformerForClassification(config)
|
| 90 |
model.train()
|
| 91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
|
| 93 |
loss_fn = torch.nn.CrossEntropyLoss()
|
| 94 |
|
|
@@ -100,6 +106,10 @@ def retrain_model():
|
|
| 100 |
optimizer.step()
|
| 101 |
print(f"🔁 Epoch {epoch+1}: Loss = {loss.item():.4f}")
|
| 102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
# Accuracy
|
| 104 |
model.eval()
|
| 105 |
with torch.no_grad():
|
|
@@ -126,8 +136,9 @@ def retrain_model():
|
|
| 126 |
with open(log_path, "w") as f:
|
| 127 |
json.dump(history, f, indent=2)
|
| 128 |
|
| 129 |
-
# Save model
|
| 130 |
model.save_pretrained("trained_model")
|
|
|
|
| 131 |
print("✅ EvoTransformer retrained and saved.")
|
| 132 |
|
| 133 |
# Load updated summary + plot
|
|
|
|
| 85 |
if input_ids is None:
|
| 86 |
return "⚠️ Not enough data to retrain.", None, "Please log more feedback first."
|
| 87 |
|
| 88 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 89 |
+
|
| 90 |
config = mutate_config()
|
| 91 |
+
model = EvoTransformerForClassification(config).to(device)
|
| 92 |
model.train()
|
| 93 |
|
| 94 |
+
input_ids = input_ids.to(device)
|
| 95 |
+
attention_masks = attention_masks.to(device)
|
| 96 |
+
labels = labels.to(device)
|
| 97 |
+
|
| 98 |
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
|
| 99 |
loss_fn = torch.nn.CrossEntropyLoss()
|
| 100 |
|
|
|
|
| 106 |
optimizer.step()
|
| 107 |
print(f"🔁 Epoch {epoch+1}: Loss = {loss.item():.4f}")
|
| 108 |
|
| 109 |
+
# Sanity check logits
|
| 110 |
+
if logits.shape[-1] < 2:
|
| 111 |
+
raise ValueError("Logits shape invalid. Retrained model did not output 2 classes.")
|
| 112 |
+
|
| 113 |
# Accuracy
|
| 114 |
model.eval()
|
| 115 |
with torch.no_grad():
|
|
|
|
| 136 |
with open(log_path, "w") as f:
|
| 137 |
json.dump(history, f, indent=2)
|
| 138 |
|
| 139 |
+
# Save model + tokenizer
|
| 140 |
model.save_pretrained("trained_model")
|
| 141 |
+
tokenizer.save_pretrained("trained_model")
|
| 142 |
print("✅ EvoTransformer retrained and saved.")
|
| 143 |
|
| 144 |
# Load updated summary + plot
|