HemanM commited on
Commit
bf6e0ca
·
verified ·
1 Parent(s): e22bade

Update watchdog.py

Browse files
Files changed (1) hide show
  1. watchdog.py +13 -2
watchdog.py CHANGED
@@ -85,10 +85,16 @@ def retrain_model():
85
  if input_ids is None:
86
  return "⚠️ Not enough data to retrain.", None, "Please log more feedback first."
87
 
 
 
88
  config = mutate_config()
89
- model = EvoTransformerForClassification(config)
90
  model.train()
91
 
 
 
 
 
92
  optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
93
  loss_fn = torch.nn.CrossEntropyLoss()
94
 
@@ -100,6 +106,10 @@ def retrain_model():
100
  optimizer.step()
101
  print(f"🔁 Epoch {epoch+1}: Loss = {loss.item():.4f}")
102
 
 
 
 
 
103
  # Accuracy
104
  model.eval()
105
  with torch.no_grad():
@@ -126,8 +136,9 @@ def retrain_model():
126
  with open(log_path, "w") as f:
127
  json.dump(history, f, indent=2)
128
 
129
- # Save model
130
  model.save_pretrained("trained_model")
 
131
  print("✅ EvoTransformer retrained and saved.")
132
 
133
  # Load updated summary + plot
 
85
  if input_ids is None:
86
  return "⚠️ Not enough data to retrain.", None, "Please log more feedback first."
87
 
88
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
89
+
90
  config = mutate_config()
91
+ model = EvoTransformerForClassification(config).to(device)
92
  model.train()
93
 
94
+ input_ids = input_ids.to(device)
95
+ attention_masks = attention_masks.to(device)
96
+ labels = labels.to(device)
97
+
98
  optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
99
  loss_fn = torch.nn.CrossEntropyLoss()
100
 
 
106
  optimizer.step()
107
  print(f"🔁 Epoch {epoch+1}: Loss = {loss.item():.4f}")
108
 
109
+ # Sanity check logits
110
+ if logits.shape[-1] < 2:
111
+ raise ValueError("Logits shape invalid. Retrained model did not output 2 classes.")
112
+
113
  # Accuracy
114
  model.eval()
115
  with torch.no_grad():
 
136
  with open(log_path, "w") as f:
137
  json.dump(history, f, indent=2)
138
 
139
+ # Save model + tokenizer
140
  model.save_pretrained("trained_model")
141
+ tokenizer.save_pretrained("trained_model")
142
  print("✅ EvoTransformer retrained and saved.")
143
 
144
  # Load updated summary + plot