AbstractPhil committed on
Commit
cb977ae
·
verified ·
1 Parent(s): f512b47

Update trainer_v2.py

Browse files
Files changed (1) hide show
  1. trainer_v2.py +77 -44
trainer_v2.py CHANGED
@@ -189,7 +189,7 @@ class DanbooruTrainingConfig:
189
  use_gradient_checkpointing=self.use_gradient_checkpointing,
190
  share_scale_embeddings=self.share_scale_embeddings,
191
  geometric_init_method="hybrid",
192
- geometric_init_validate=False,
193
  geometric_init_seed=42
194
  )
195
 
@@ -199,7 +199,7 @@ class DanbooruTrainingConfig:
199
  # ============================================================================
200
 
201
  class CheckpointManager:
202
- """Manages checkpoints with proper naming (no step in directory name)."""
203
 
204
  def __init__(
205
  self,
@@ -210,11 +210,16 @@ class CheckpointManager:
210
  ):
211
  self.local_dir = Path(local_dir)
212
  self.hf_repo_id = hf_repo_id
213
- self.sub_name = sub_name
 
 
 
 
 
214
  self.hf_private = hf_private
215
 
216
- # Checkpoint directory structure: checkpoints/{sub_name}/{timestamp}/
217
- self.sub_checkpoint_dir = self.local_dir / sub_name
218
  self.sub_checkpoint_dir.mkdir(parents=True, exist_ok=True)
219
 
220
  self.checkpoints_file = self.sub_checkpoint_dir / "checkpoints.json"
@@ -242,6 +247,7 @@ class CheckpointManager:
242
  return json.load(f)
243
  return {
244
  "sub_name": self.sub_name,
 
245
  "checkpoints": [],
246
  "latest": None,
247
  "best": None
@@ -252,9 +258,8 @@ class CheckpointManager:
252
  json.dump(self.checkpoint_history, f, indent=2)
253
 
254
  def get_checkpoint_dir(self, step: int, epoch: int) -> Path:
255
- """Generate checkpoint directory name (timestamp-based, step in metadata)."""
256
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
257
- dirname = f"epoch{epoch}_step{step}_{timestamp}"
258
  return self.sub_checkpoint_dir / dirname
259
 
260
  def _safe_state_dict(self, model: nn.Module) -> Dict[str, torch.Tensor]:
@@ -311,11 +316,12 @@ class CheckpointManager:
311
  ckpt_dir.mkdir(parents=True, exist_ok=True)
312
 
313
  print(f"\n💾 Saving checkpoint: {self.sub_name}/{ckpt_dir.name}")
 
314
 
315
  state_dict = self._safe_state_dict(model)
316
  weights_path = ckpt_dir / "model.safetensors"
317
  save_file(state_dict, weights_path)
318
- print(f" ✓ Model weights: {weights_path.name}")
319
 
320
  training_state = {
321
  'epoch': epoch,
@@ -323,7 +329,8 @@ class CheckpointManager:
323
  'optimizer_state_dict': optimizer.state_dict(),
324
  'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
325
  'val_loss': val_loss,
326
- 'sub_name': self.sub_name
 
327
  }
328
  torch.save(training_state, ckpt_dir / "training_state.pt")
329
  print(f" ✓ Training state: training_state.pt")
@@ -393,7 +400,7 @@ class CheckpointManager:
393
  traceback.print_exc()
394
 
395
  def find_latest_checkpoint(self) -> Optional[Dict]:
396
- """Find the latest checkpoint for this sub_name."""
397
  checkpoints = self.checkpoint_history.get('checkpoints', [])
398
  if checkpoints:
399
  return max(checkpoints, key=lambda x: x['step'])
@@ -409,7 +416,7 @@ class CheckpointManager:
409
  latest = self.find_latest_checkpoint()
410
 
411
  if not latest:
412
- print(f"ℹ️ No previous checkpoint found for sub_name='{self.sub_name}'")
413
  return 0, 0, float('inf')
414
 
415
  ckpt_dir = self.sub_checkpoint_dir / latest['dirname']
@@ -437,7 +444,7 @@ class CheckpointManager:
437
  print(f" ⚠️ Checkpoint directory not found: {ckpt_dir}")
438
  return 0, 0, float('inf')
439
 
440
- print(f"\n🔄 Resuming from checkpoint: {latest['dirname']}")
441
  print(f" Step: {latest['step']}, Epoch: {latest['epoch']}, Val Loss: {latest['val_loss']:.4f}")
442
 
443
  weights_path = ckpt_dir / "model.safetensors"
@@ -1117,18 +1124,30 @@ class DanbooruLiminalStaircaseTrainer:
1117
  return
1118
 
1119
  for key, value in metrics.items():
1120
- self.writer.add_scalar(f"{prefix}/{key}", value, self.global_step)
 
 
 
 
 
 
 
 
 
 
 
1121
 
1122
- current_lr = self.optimizer.param_groups[0]['lr']
1123
- self.writer.add_scalar("train/learning_rate", current_lr, self.global_step)
1124
 
1125
- # Log text modality stats
1126
- if self.global_step % self.config.log_every == 0:
1127
  total = sum(self.text_dropout_stats.values()) or 1
1128
  for mode, count in self.text_dropout_stats.items():
1129
  self.writer.add_scalar(f"text_modality/{mode}_pct", 100 * count / total, self.global_step)
1130
 
1131
- if self.global_step % (self.config.log_every * 10) == 0:
 
1132
  fusion_diag = self.get_fusion_diagnostics()
1133
 
1134
  for i, w in enumerate(fusion_diag.get('layer_weights', [])):
@@ -1142,6 +1161,8 @@ class DanbooruLiminalStaircaseTrainer:
1142
 
1143
  for i, b in enumerate(fusion_diag.get('beta_per_scale', [])):
1144
  self.writer.add_scalar(f"fusion/beta_scale_{i}", b, self.global_step)
 
 
1145
 
1146
  @torch.no_grad()
1147
  def validate(self, max_batches: int = 100) -> Dict[str, float]:
@@ -1197,14 +1218,20 @@ class DanbooruLiminalStaircaseTrainer:
1197
  continue
1198
 
1199
  if stats_with_text['count'] == 0 or stats_vision_only['count'] == 0:
1200
- return {'loss/val': float('inf'), 'acc/val': 0.0}
 
 
 
 
 
 
 
1201
 
1202
  return {
1203
- 'loss/val_with_text': stats_with_text['loss'] / stats_with_text['count'],
1204
- 'acc/val_with_text': stats_with_text['acc'] / stats_with_text['count'],
1205
- 'loss/val_vision_only': stats_vision_only['loss'] / stats_vision_only['count'],
1206
- 'acc/val_vision_only': stats_vision_only['acc'] / stats_vision_only['count'],
1207
- # Overall metric = vision-only (the real use case)
1208
  'loss/val': stats_vision_only['loss'] / stats_vision_only['count'],
1209
  'acc/val': stats_vision_only['acc'] / stats_vision_only['count'],
1210
  }
@@ -1212,7 +1239,14 @@ class DanbooruLiminalStaircaseTrainer:
1212
  except Exception as e:
1213
  print(f"\n⚠️ Validation completely failed: {e}")
1214
  traceback.print_exc()
1215
- return {'loss/val': float('inf'), 'acc/val': 0.0}
 
 
 
 
 
 
 
1216
 
1217
  def save_checkpoint_and_upload(self, epoch: int, val_loss: float = float('inf'), is_best: bool = False):
1218
  """Save checkpoint first, then optionally upload."""
@@ -1294,8 +1328,8 @@ class DanbooruLiminalStaircaseTrainer:
1294
  print("\n🔍 Running validation...")
1295
  val_metrics = self.validate(max_batches=50)
1296
  self.log_metrics(val_metrics, prefix="val")
1297
- print(f"✓ Val (with text) - Loss: {val_metrics.get('loss/val_with_text', 0):.4f}, Acc: {val_metrics.get('acc/val_with_text', 0):.4f}")
1298
- print(f"✓ Val (vision-only) - Loss: {val_metrics.get('loss/val_vision_only', 0):.4f}, Acc: {val_metrics.get('acc/val_vision_only', 0):.4f}")
1299
 
1300
  # HuggingFace upload
1301
  if (self.config.hf_repo_id and
@@ -1305,8 +1339,8 @@ class DanbooruLiminalStaircaseTrainer:
1305
  if self.accelerator.is_main_process:
1306
  print("\n🔍 Running validation for upload...")
1307
  val_metrics = self.validate(max_batches=50)
1308
- print(f"✓ Val (with text) - Loss: {val_metrics.get('loss/val_with_text', 0):.4f}, Acc: {val_metrics.get('acc/val_with_text', 0):.4f}")
1309
- print(f"✓ Val (vision-only) - Loss: {val_metrics.get('loss/val_vision_only', 0):.4f}, Acc: {val_metrics.get('acc/val_vision_only', 0):.4f}")
1310
 
1311
  if self._interrupt_received:
1312
  break
@@ -1320,11 +1354,11 @@ class DanbooruLiminalStaircaseTrainer:
1320
 
1321
  print(f"\n📊 Validation Results:")
1322
  print(f" With Text:")
1323
- print(f" Loss: {val_metrics.get('loss/val_with_text', 0):.4f}")
1324
- print(f" Acc: {val_metrics.get('acc/val_with_text', 0):.4f}")
1325
  print(f" Vision-Only (PRIMARY METRIC):")
1326
- print(f" Loss: {val_metrics.get('loss/val_vision_only', 0):.4f}")
1327
- print(f" Acc: {val_metrics.get('acc/val_vision_only', 0):.4f}")
1328
 
1329
  self.log_metrics(val_metrics, prefix="val")
1330
 
@@ -1369,7 +1403,6 @@ class DanbooruLiminalStaircaseTrainer:
1369
  if self.writer:
1370
  self.writer.close()
1371
 
1372
-
1373
  # ============================================================================
1374
  # MAIN
1375
  # ============================================================================
@@ -1377,16 +1410,16 @@ class DanbooruLiminalStaircaseTrainer:
1377
  if __name__ == "__main__":
1378
  config = DanbooruTrainingConfig(
1379
  # Run identifier
1380
- sub_name="danbooru-50k-v1-512",
1381
 
1382
  # Model architecture
1383
  num_opinion_anchors=225,
1384
- pentachoron_dim=256,
1385
  scales=[128, 256, 512, 1024],
1386
- scale_hidden_dims={128: 128, 256: 512, 512: 1024, 1024: 2048},
1387
 
1388
  # Fusion controller
1389
- alpha_init=0.1,
1390
  alpha_learnable=True,
1391
  beta_init=0.5,
1392
  beta_learnable=True,
@@ -1394,16 +1427,16 @@ if __name__ == "__main__":
1394
  learn_layer_weights=True,
1395
 
1396
  # Encoders
1397
- clip_skip=0,
1398
- siglip_layer_indices=[3, 6, 9, 12, 21, 23, 24, 25, 26],
1399
 
1400
  # Optimizations
1401
  use_gradient_checkpointing=False,
1402
  share_scale_embeddings=False,
1403
 
1404
  # Training
1405
- batch_size=32,
1406
- num_epochs=3,
1407
  learning_rate=1e-4,
1408
  save_every=500,
1409
 
@@ -1417,7 +1450,7 @@ if __name__ == "__main__":
1417
  text_dropout_end=0.5,
1418
 
1419
  # Resume
1420
- resume=True,
1421
 
1422
  # HuggingFace
1423
  hf_repo_id="AbstractPhil/liminal-staircase-v2",
 
189
  use_gradient_checkpointing=self.use_gradient_checkpointing,
190
  share_scale_embeddings=self.share_scale_embeddings,
191
  geometric_init_method="hybrid",
192
+ geometric_init_validate=True,
193
  geometric_init_seed=42
194
  )
195
 
 
199
  # ============================================================================
200
 
201
  class CheckpointManager:
202
+ """Manages checkpoints with run timestamp, simple step-based checkpoint names."""
203
 
204
  def __init__(
205
  self,
 
210
  ):
211
  self.local_dir = Path(local_dir)
212
  self.hf_repo_id = hf_repo_id
213
+ self.base_sub_name = sub_name
214
+
215
+ # ADD RUN TIMESTAMP TO SUB_NAME (once, when training starts)
216
+ run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
217
+ self.sub_name = f"{sub_name}-{run_timestamp}"
218
+
219
  self.hf_private = hf_private
220
 
221
+ # Checkpoint directory: checkpoints/{sub_name-timestamp}/
222
+ self.sub_checkpoint_dir = self.local_dir / self.sub_name
223
  self.sub_checkpoint_dir.mkdir(parents=True, exist_ok=True)
224
 
225
  self.checkpoints_file = self.sub_checkpoint_dir / "checkpoints.json"
 
247
  return json.load(f)
248
  return {
249
  "sub_name": self.sub_name,
250
+ "base_name": self.base_sub_name,
251
  "checkpoints": [],
252
  "latest": None,
253
  "best": None
 
258
  json.dump(self.checkpoint_history, f, indent=2)
259
 
260
  def get_checkpoint_dir(self, step: int, epoch: int) -> Path:
261
+ """Generate checkpoint directory name: just step{N}."""
262
+ dirname = f"step{step}"
 
263
  return self.sub_checkpoint_dir / dirname
264
 
265
  def _safe_state_dict(self, model: nn.Module) -> Dict[str, torch.Tensor]:
 
316
  ckpt_dir.mkdir(parents=True, exist_ok=True)
317
 
318
  print(f"\n💾 Saving checkpoint: {self.sub_name}/{ckpt_dir.name}")
319
+ print(f" Step: {step}, Epoch: {epoch}")
320
 
321
  state_dict = self._safe_state_dict(model)
322
  weights_path = ckpt_dir / "model.safetensors"
323
  save_file(state_dict, weights_path)
324
+ print(f" ✓ Model weights: model.safetensors")
325
 
326
  training_state = {
327
  'epoch': epoch,
 
329
  'optimizer_state_dict': optimizer.state_dict(),
330
  'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
331
  'val_loss': val_loss,
332
+ 'sub_name': self.sub_name,
333
+ 'base_name': self.base_sub_name
334
  }
335
  torch.save(training_state, ckpt_dir / "training_state.pt")
336
  print(f" ✓ Training state: training_state.pt")
 
400
  traceback.print_exc()
401
 
402
  def find_latest_checkpoint(self) -> Optional[Dict]:
403
+ """Find the latest checkpoint for this training run."""
404
  checkpoints = self.checkpoint_history.get('checkpoints', [])
405
  if checkpoints:
406
  return max(checkpoints, key=lambda x: x['step'])
 
416
  latest = self.find_latest_checkpoint()
417
 
418
  if not latest:
419
+ print(f"ℹ️ No previous checkpoint found for training run '{self.sub_name}'")
420
  return 0, 0, float('inf')
421
 
422
  ckpt_dir = self.sub_checkpoint_dir / latest['dirname']
 
444
  print(f" ⚠️ Checkpoint directory not found: {ckpt_dir}")
445
  return 0, 0, float('inf')
446
 
447
+ print(f"\n🔄 Resuming from checkpoint: {self.sub_name}/{latest['dirname']}")
448
  print(f" Step: {latest['step']}, Epoch: {latest['epoch']}, Val Loss: {latest['val_loss']:.4f}")
449
 
450
  weights_path = ckpt_dir / "model.safetensors"
 
1124
  return
1125
 
1126
  for key, value in metrics.items():
1127
+ # Handle validation metrics that already have prefixes
1128
+ if prefix == "val" and key.startswith(('loss/', 'acc/')):
1129
+ # Strip the redundant prefix
1130
+ clean_key = key.replace('loss/', '').replace('acc/', '')
1131
+ self.writer.add_scalar(f"val/{clean_key}", value, self.global_step)
1132
+ else:
1133
+ self.writer.add_scalar(f"{prefix}/{key}", value, self.global_step)
1134
+
1135
+ # Log learning rate
1136
+ if prefix == "train":
1137
+ current_lr = self.optimizer.param_groups[0]['lr']
1138
+ self.writer.add_scalar("train/learning_rate", current_lr, self.global_step)
1139
 
1140
+ # Flush to disk
1141
+ self.writer.flush()
1142
 
1143
+ # Log text modality stats periodically
1144
+ if prefix == "train" and self.global_step % self.config.log_every == 0:
1145
  total = sum(self.text_dropout_stats.values()) or 1
1146
  for mode, count in self.text_dropout_stats.items():
1147
  self.writer.add_scalar(f"text_modality/{mode}_pct", 100 * count / total, self.global_step)
1148
 
1149
+ # Log fusion diagnostics periodically
1150
+ if prefix == "train" and self.global_step % (self.config.log_every * 10) == 0:
1151
  fusion_diag = self.get_fusion_diagnostics()
1152
 
1153
  for i, w in enumerate(fusion_diag.get('layer_weights', [])):
 
1161
 
1162
  for i, b in enumerate(fusion_diag.get('beta_per_scale', [])):
1163
  self.writer.add_scalar(f"fusion/beta_scale_{i}", b, self.global_step)
1164
+
1165
+ self.writer.flush()
1166
 
1167
  @torch.no_grad()
1168
  def validate(self, max_batches: int = 100) -> Dict[str, float]:
 
1218
  continue
1219
 
1220
  if stats_with_text['count'] == 0 or stats_vision_only['count'] == 0:
1221
+ return {
1222
+ 'val_with_text_loss': float('inf'),
1223
+ 'val_with_text_acc': 0.0,
1224
+ 'val_vision_only_loss': float('inf'),
1225
+ 'val_vision_only_acc': 0.0,
1226
+ 'loss/val': float('inf'),
1227
+ 'acc/val': 0.0
1228
+ }
1229
 
1230
  return {
1231
+ 'val_with_text_loss': stats_with_text['loss'] / stats_with_text['count'],
1232
+ 'val_with_text_acc': stats_with_text['acc'] / stats_with_text['count'],
1233
+ 'val_vision_only_loss': stats_vision_only['loss'] / stats_vision_only['count'],
1234
+ 'val_vision_only_acc': stats_vision_only['acc'] / stats_vision_only['count'],
 
1235
  'loss/val': stats_vision_only['loss'] / stats_vision_only['count'],
1236
  'acc/val': stats_vision_only['acc'] / stats_vision_only['count'],
1237
  }
 
1239
  except Exception as e:
1240
  print(f"\n⚠️ Validation completely failed: {e}")
1241
  traceback.print_exc()
1242
+ return {
1243
+ 'val_with_text_loss': float('inf'),
1244
+ 'val_with_text_acc': 0.0,
1245
+ 'val_vision_only_loss': float('inf'),
1246
+ 'val_vision_only_acc': 0.0,
1247
+ 'loss/val': float('inf'),
1248
+ 'acc/val': 0.0
1249
+ }
1250
 
1251
  def save_checkpoint_and_upload(self, epoch: int, val_loss: float = float('inf'), is_best: bool = False):
1252
  """Save checkpoint first, then optionally upload."""
 
1328
  print("\n🔍 Running validation...")
1329
  val_metrics = self.validate(max_batches=50)
1330
  self.log_metrics(val_metrics, prefix="val")
1331
+ print(f"✓ Val (with text) - Loss: {val_metrics['val_with_text_loss']:.4f}, Acc: {val_metrics['val_with_text_acc']:.4f}")
1332
+ print(f"✓ Val (vision-only) - Loss: {val_metrics['val_vision_only_loss']:.4f}, Acc: {val_metrics['val_vision_only_acc']:.4f}")
1333
 
1334
  # HuggingFace upload
1335
  if (self.config.hf_repo_id and
 
1339
  if self.accelerator.is_main_process:
1340
  print("\n🔍 Running validation for upload...")
1341
  val_metrics = self.validate(max_batches=50)
1342
+ print(f"✓ Val (with text) - Loss: {val_metrics['val_with_text_loss']:.4f}, Acc: {val_metrics['val_with_text_acc']:.4f}")
1343
+ print(f"✓ Val (vision-only) - Loss: {val_metrics['val_vision_only_loss']:.4f}, Acc: {val_metrics['val_vision_only_acc']:.4f}")
1344
 
1345
  if self._interrupt_received:
1346
  break
 
1354
 
1355
  print(f"\n📊 Validation Results:")
1356
  print(f" With Text:")
1357
+ print(f" Loss: {val_metrics['val_with_text_loss']:.4f}")
1358
+ print(f" Acc: {val_metrics['val_with_text_acc']:.4f}")
1359
  print(f" Vision-Only (PRIMARY METRIC):")
1360
+ print(f" Loss: {val_metrics['val_vision_only_loss']:.4f}")
1361
+ print(f" Acc: {val_metrics['val_vision_only_acc']:.4f}")
1362
 
1363
  self.log_metrics(val_metrics, prefix="val")
1364
 
 
1403
  if self.writer:
1404
  self.writer.close()
1405
 
 
1406
  # ============================================================================
1407
  # MAIN
1408
  # ============================================================================
 
1410
  if __name__ == "__main__":
1411
  config = DanbooruTrainingConfig(
1412
  # Run identifier
1413
+ sub_name="danbooru-50k-v1-512-2",
1414
 
1415
  # Model architecture
1416
  num_opinion_anchors=225,
1417
+ pentachoron_dim=512,
1418
  scales=[128, 256, 512, 1024],
1419
+ scale_hidden_dims={128: 256, 256: 512, 512: 1024, 1024: 2048},
1420
 
1421
  # Fusion controller
1422
+ alpha_init=0.125,
1423
  alpha_learnable=True,
1424
  beta_init=0.5,
1425
  beta_learnable=True,
 
1427
  learn_layer_weights=True,
1428
 
1429
  # Encoders
1430
+ clip_skip=1,
1431
+ siglip_layer_indices=[1, 2, 3, 4, 5, 6, 9, 12, 18, 21, 23, 24, 25, 26],
1432
 
1433
  # Optimizations
1434
  use_gradient_checkpointing=False,
1435
  share_scale_embeddings=False,
1436
 
1437
  # Training
1438
+ batch_size=24,
1439
+ num_epochs=20,
1440
  learning_rate=1e-4,
1441
  save_every=500,
1442
 
 
1450
  text_dropout_end=0.5,
1451
 
1452
  # Resume
1453
+ resume=False,
1454
 
1455
  # HuggingFace
1456
  hf_repo_id="AbstractPhil/liminal-staircase-v2",