Update trainer_v2.py
Browse files- trainer_v2.py +77 -44
trainer_v2.py
CHANGED
|
@@ -189,7 +189,7 @@ class DanbooruTrainingConfig:
|
|
| 189 |
use_gradient_checkpointing=self.use_gradient_checkpointing,
|
| 190 |
share_scale_embeddings=self.share_scale_embeddings,
|
| 191 |
geometric_init_method="hybrid",
|
| 192 |
-
geometric_init_validate=
|
| 193 |
geometric_init_seed=42
|
| 194 |
)
|
| 195 |
|
|
@@ -199,7 +199,7 @@ class DanbooruTrainingConfig:
|
|
| 199 |
# ============================================================================
|
| 200 |
|
| 201 |
class CheckpointManager:
|
| 202 |
-
"""Manages checkpoints with
|
| 203 |
|
| 204 |
def __init__(
|
| 205 |
self,
|
|
@@ -210,11 +210,16 @@ class CheckpointManager:
|
|
| 210 |
):
|
| 211 |
self.local_dir = Path(local_dir)
|
| 212 |
self.hf_repo_id = hf_repo_id
|
| 213 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
self.hf_private = hf_private
|
| 215 |
|
| 216 |
-
# Checkpoint directory
|
| 217 |
-
self.sub_checkpoint_dir = self.local_dir / sub_name
|
| 218 |
self.sub_checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
| 219 |
|
| 220 |
self.checkpoints_file = self.sub_checkpoint_dir / "checkpoints.json"
|
|
@@ -242,6 +247,7 @@ class CheckpointManager:
|
|
| 242 |
return json.load(f)
|
| 243 |
return {
|
| 244 |
"sub_name": self.sub_name,
|
|
|
|
| 245 |
"checkpoints": [],
|
| 246 |
"latest": None,
|
| 247 |
"best": None
|
|
@@ -252,9 +258,8 @@ class CheckpointManager:
|
|
| 252 |
json.dump(self.checkpoint_history, f, indent=2)
|
| 253 |
|
| 254 |
def get_checkpoint_dir(self, step: int, epoch: int) -> Path:
|
| 255 |
-
"""Generate checkpoint directory name
|
| 256 |
-
|
| 257 |
-
dirname = f"epoch{epoch}_step{step}_{timestamp}"
|
| 258 |
return self.sub_checkpoint_dir / dirname
|
| 259 |
|
| 260 |
def _safe_state_dict(self, model: nn.Module) -> Dict[str, torch.Tensor]:
|
|
@@ -311,11 +316,12 @@ class CheckpointManager:
|
|
| 311 |
ckpt_dir.mkdir(parents=True, exist_ok=True)
|
| 312 |
|
| 313 |
print(f"\n💾 Saving checkpoint: {self.sub_name}/{ckpt_dir.name}")
|
|
|
|
| 314 |
|
| 315 |
state_dict = self._safe_state_dict(model)
|
| 316 |
weights_path = ckpt_dir / "model.safetensors"
|
| 317 |
save_file(state_dict, weights_path)
|
| 318 |
-
print(f" ✓ Model weights:
|
| 319 |
|
| 320 |
training_state = {
|
| 321 |
'epoch': epoch,
|
|
@@ -323,7 +329,8 @@ class CheckpointManager:
|
|
| 323 |
'optimizer_state_dict': optimizer.state_dict(),
|
| 324 |
'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
|
| 325 |
'val_loss': val_loss,
|
| 326 |
-
'sub_name': self.sub_name
|
|
|
|
| 327 |
}
|
| 328 |
torch.save(training_state, ckpt_dir / "training_state.pt")
|
| 329 |
print(f" ✓ Training state: training_state.pt")
|
|
@@ -393,7 +400,7 @@ class CheckpointManager:
|
|
| 393 |
traceback.print_exc()
|
| 394 |
|
| 395 |
def find_latest_checkpoint(self) -> Optional[Dict]:
|
| 396 |
-
"""Find the latest checkpoint for this
|
| 397 |
checkpoints = self.checkpoint_history.get('checkpoints', [])
|
| 398 |
if checkpoints:
|
| 399 |
return max(checkpoints, key=lambda x: x['step'])
|
|
@@ -409,7 +416,7 @@ class CheckpointManager:
|
|
| 409 |
latest = self.find_latest_checkpoint()
|
| 410 |
|
| 411 |
if not latest:
|
| 412 |
-
print(f"ℹ️ No previous checkpoint found for
|
| 413 |
return 0, 0, float('inf')
|
| 414 |
|
| 415 |
ckpt_dir = self.sub_checkpoint_dir / latest['dirname']
|
|
@@ -437,7 +444,7 @@ class CheckpointManager:
|
|
| 437 |
print(f" ⚠️ Checkpoint directory not found: {ckpt_dir}")
|
| 438 |
return 0, 0, float('inf')
|
| 439 |
|
| 440 |
-
print(f"\n🔄 Resuming from checkpoint: {latest['dirname']}")
|
| 441 |
print(f" Step: {latest['step']}, Epoch: {latest['epoch']}, Val Loss: {latest['val_loss']:.4f}")
|
| 442 |
|
| 443 |
weights_path = ckpt_dir / "model.safetensors"
|
|
@@ -1117,18 +1124,30 @@ class DanbooruLiminalStaircaseTrainer:
|
|
| 1117 |
return
|
| 1118 |
|
| 1119 |
for key, value in metrics.items():
|
| 1120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1121 |
|
| 1122 |
-
|
| 1123 |
-
self.writer.
|
| 1124 |
|
| 1125 |
-
# Log text modality stats
|
| 1126 |
-
if self.global_step % self.config.log_every == 0:
|
| 1127 |
total = sum(self.text_dropout_stats.values()) or 1
|
| 1128 |
for mode, count in self.text_dropout_stats.items():
|
| 1129 |
self.writer.add_scalar(f"text_modality/{mode}_pct", 100 * count / total, self.global_step)
|
| 1130 |
|
| 1131 |
-
|
|
|
|
| 1132 |
fusion_diag = self.get_fusion_diagnostics()
|
| 1133 |
|
| 1134 |
for i, w in enumerate(fusion_diag.get('layer_weights', [])):
|
|
@@ -1142,6 +1161,8 @@ class DanbooruLiminalStaircaseTrainer:
|
|
| 1142 |
|
| 1143 |
for i, b in enumerate(fusion_diag.get('beta_per_scale', [])):
|
| 1144 |
self.writer.add_scalar(f"fusion/beta_scale_{i}", b, self.global_step)
|
|
|
|
|
|
|
| 1145 |
|
| 1146 |
@torch.no_grad()
|
| 1147 |
def validate(self, max_batches: int = 100) -> Dict[str, float]:
|
|
@@ -1197,14 +1218,20 @@ class DanbooruLiminalStaircaseTrainer:
|
|
| 1197 |
continue
|
| 1198 |
|
| 1199 |
if stats_with_text['count'] == 0 or stats_vision_only['count'] == 0:
|
| 1200 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1201 |
|
| 1202 |
return {
|
| 1203 |
-
'
|
| 1204 |
-
'
|
| 1205 |
-
'
|
| 1206 |
-
'
|
| 1207 |
-
# Overall metric = vision-only (the real use case)
|
| 1208 |
'loss/val': stats_vision_only['loss'] / stats_vision_only['count'],
|
| 1209 |
'acc/val': stats_vision_only['acc'] / stats_vision_only['count'],
|
| 1210 |
}
|
|
@@ -1212,7 +1239,14 @@ class DanbooruLiminalStaircaseTrainer:
|
|
| 1212 |
except Exception as e:
|
| 1213 |
print(f"\n⚠️ Validation completely failed: {e}")
|
| 1214 |
traceback.print_exc()
|
| 1215 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1216 |
|
| 1217 |
def save_checkpoint_and_upload(self, epoch: int, val_loss: float = float('inf'), is_best: bool = False):
|
| 1218 |
"""Save checkpoint first, then optionally upload."""
|
|
@@ -1294,8 +1328,8 @@ class DanbooruLiminalStaircaseTrainer:
|
|
| 1294 |
print("\n🔍 Running validation...")
|
| 1295 |
val_metrics = self.validate(max_batches=50)
|
| 1296 |
self.log_metrics(val_metrics, prefix="val")
|
| 1297 |
-
print(f"✓ Val (with text) - Loss: {val_metrics
|
| 1298 |
-
print(f"✓ Val (vision-only) - Loss: {val_metrics
|
| 1299 |
|
| 1300 |
# HuggingFace upload
|
| 1301 |
if (self.config.hf_repo_id and
|
|
@@ -1305,8 +1339,8 @@ class DanbooruLiminalStaircaseTrainer:
|
|
| 1305 |
if self.accelerator.is_main_process:
|
| 1306 |
print("\n🔍 Running validation for upload...")
|
| 1307 |
val_metrics = self.validate(max_batches=50)
|
| 1308 |
-
print(f"✓ Val (with text) - Loss: {val_metrics
|
| 1309 |
-
print(f"✓ Val (vision-only) - Loss: {val_metrics
|
| 1310 |
|
| 1311 |
if self._interrupt_received:
|
| 1312 |
break
|
|
@@ -1320,11 +1354,11 @@ class DanbooruLiminalStaircaseTrainer:
|
|
| 1320 |
|
| 1321 |
print(f"\n📊 Validation Results:")
|
| 1322 |
print(f" With Text:")
|
| 1323 |
-
print(f" Loss: {val_metrics
|
| 1324 |
-
print(f" Acc: {val_metrics
|
| 1325 |
print(f" Vision-Only (PRIMARY METRIC):")
|
| 1326 |
-
print(f" Loss: {val_metrics
|
| 1327 |
-
print(f" Acc: {val_metrics
|
| 1328 |
|
| 1329 |
self.log_metrics(val_metrics, prefix="val")
|
| 1330 |
|
|
@@ -1369,7 +1403,6 @@ class DanbooruLiminalStaircaseTrainer:
|
|
| 1369 |
if self.writer:
|
| 1370 |
self.writer.close()
|
| 1371 |
|
| 1372 |
-
|
| 1373 |
# ============================================================================
|
| 1374 |
# MAIN
|
| 1375 |
# ============================================================================
|
|
@@ -1377,16 +1410,16 @@ class DanbooruLiminalStaircaseTrainer:
|
|
| 1377 |
if __name__ == "__main__":
|
| 1378 |
config = DanbooruTrainingConfig(
|
| 1379 |
# Run identifier
|
| 1380 |
-
sub_name="danbooru-50k-v1-512",
|
| 1381 |
|
| 1382 |
# Model architecture
|
| 1383 |
num_opinion_anchors=225,
|
| 1384 |
-
pentachoron_dim=
|
| 1385 |
scales=[128, 256, 512, 1024],
|
| 1386 |
-
scale_hidden_dims={128:
|
| 1387 |
|
| 1388 |
# Fusion controller
|
| 1389 |
-
alpha_init=0.
|
| 1390 |
alpha_learnable=True,
|
| 1391 |
beta_init=0.5,
|
| 1392 |
beta_learnable=True,
|
|
@@ -1394,16 +1427,16 @@ if __name__ == "__main__":
|
|
| 1394 |
learn_layer_weights=True,
|
| 1395 |
|
| 1396 |
# Encoders
|
| 1397 |
-
clip_skip=
|
| 1398 |
-
siglip_layer_indices=[3, 6, 9, 12, 21, 23, 24, 25, 26],
|
| 1399 |
|
| 1400 |
# Optimizations
|
| 1401 |
use_gradient_checkpointing=False,
|
| 1402 |
share_scale_embeddings=False,
|
| 1403 |
|
| 1404 |
# Training
|
| 1405 |
-
batch_size=
|
| 1406 |
-
num_epochs=
|
| 1407 |
learning_rate=1e-4,
|
| 1408 |
save_every=500,
|
| 1409 |
|
|
@@ -1417,7 +1450,7 @@ if __name__ == "__main__":
|
|
| 1417 |
text_dropout_end=0.5,
|
| 1418 |
|
| 1419 |
# Resume
|
| 1420 |
-
resume=
|
| 1421 |
|
| 1422 |
# HuggingFace
|
| 1423 |
hf_repo_id="AbstractPhil/liminal-staircase-v2",
|
|
|
|
| 189 |
use_gradient_checkpointing=self.use_gradient_checkpointing,
|
| 190 |
share_scale_embeddings=self.share_scale_embeddings,
|
| 191 |
geometric_init_method="hybrid",
|
| 192 |
+
geometric_init_validate=True,
|
| 193 |
geometric_init_seed=42
|
| 194 |
)
|
| 195 |
|
|
|
|
| 199 |
# ============================================================================
|
| 200 |
|
| 201 |
class CheckpointManager:
|
| 202 |
+
"""Manages checkpoints with run timestamp, simple step-based checkpoint names."""
|
| 203 |
|
| 204 |
def __init__(
|
| 205 |
self,
|
|
|
|
| 210 |
):
|
| 211 |
self.local_dir = Path(local_dir)
|
| 212 |
self.hf_repo_id = hf_repo_id
|
| 213 |
+
self.base_sub_name = sub_name
|
| 214 |
+
|
| 215 |
+
# ADD RUN TIMESTAMP TO SUB_NAME (once, when training starts)
|
| 216 |
+
run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 217 |
+
self.sub_name = f"{sub_name}-{run_timestamp}"
|
| 218 |
+
|
| 219 |
self.hf_private = hf_private
|
| 220 |
|
| 221 |
+
# Checkpoint directory: checkpoints/{sub_name-timestamp}/
|
| 222 |
+
self.sub_checkpoint_dir = self.local_dir / self.sub_name
|
| 223 |
self.sub_checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
| 224 |
|
| 225 |
self.checkpoints_file = self.sub_checkpoint_dir / "checkpoints.json"
|
|
|
|
| 247 |
return json.load(f)
|
| 248 |
return {
|
| 249 |
"sub_name": self.sub_name,
|
| 250 |
+
"base_name": self.base_sub_name,
|
| 251 |
"checkpoints": [],
|
| 252 |
"latest": None,
|
| 253 |
"best": None
|
|
|
|
| 258 |
json.dump(self.checkpoint_history, f, indent=2)
|
| 259 |
|
| 260 |
def get_checkpoint_dir(self, step: int, epoch: int) -> Path:
|
| 261 |
+
"""Generate checkpoint directory name: just step{N}."""
|
| 262 |
+
dirname = f"step{step}"
|
|
|
|
| 263 |
return self.sub_checkpoint_dir / dirname
|
| 264 |
|
| 265 |
def _safe_state_dict(self, model: nn.Module) -> Dict[str, torch.Tensor]:
|
|
|
|
| 316 |
ckpt_dir.mkdir(parents=True, exist_ok=True)
|
| 317 |
|
| 318 |
print(f"\n💾 Saving checkpoint: {self.sub_name}/{ckpt_dir.name}")
|
| 319 |
+
print(f" Step: {step}, Epoch: {epoch}")
|
| 320 |
|
| 321 |
state_dict = self._safe_state_dict(model)
|
| 322 |
weights_path = ckpt_dir / "model.safetensors"
|
| 323 |
save_file(state_dict, weights_path)
|
| 324 |
+
print(f" ✓ Model weights: model.safetensors")
|
| 325 |
|
| 326 |
training_state = {
|
| 327 |
'epoch': epoch,
|
|
|
|
| 329 |
'optimizer_state_dict': optimizer.state_dict(),
|
| 330 |
'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
|
| 331 |
'val_loss': val_loss,
|
| 332 |
+
'sub_name': self.sub_name,
|
| 333 |
+
'base_name': self.base_sub_name
|
| 334 |
}
|
| 335 |
torch.save(training_state, ckpt_dir / "training_state.pt")
|
| 336 |
print(f" ✓ Training state: training_state.pt")
|
|
|
|
| 400 |
traceback.print_exc()
|
| 401 |
|
| 402 |
def find_latest_checkpoint(self) -> Optional[Dict]:
|
| 403 |
+
"""Find the latest checkpoint for this training run."""
|
| 404 |
checkpoints = self.checkpoint_history.get('checkpoints', [])
|
| 405 |
if checkpoints:
|
| 406 |
return max(checkpoints, key=lambda x: x['step'])
|
|
|
|
| 416 |
latest = self.find_latest_checkpoint()
|
| 417 |
|
| 418 |
if not latest:
|
| 419 |
+
print(f"ℹ️ No previous checkpoint found for training run '{self.sub_name}'")
|
| 420 |
return 0, 0, float('inf')
|
| 421 |
|
| 422 |
ckpt_dir = self.sub_checkpoint_dir / latest['dirname']
|
|
|
|
| 444 |
print(f" ⚠️ Checkpoint directory not found: {ckpt_dir}")
|
| 445 |
return 0, 0, float('inf')
|
| 446 |
|
| 447 |
+
print(f"\n🔄 Resuming from checkpoint: {self.sub_name}/{latest['dirname']}")
|
| 448 |
print(f" Step: {latest['step']}, Epoch: {latest['epoch']}, Val Loss: {latest['val_loss']:.4f}")
|
| 449 |
|
| 450 |
weights_path = ckpt_dir / "model.safetensors"
|
|
|
|
| 1124 |
return
|
| 1125 |
|
| 1126 |
for key, value in metrics.items():
|
| 1127 |
+
# Handle validation metrics that already have prefixes
|
| 1128 |
+
if prefix == "val" and key.startswith(('loss/', 'acc/')):
|
| 1129 |
+
# Strip the redundant prefix
|
| 1130 |
+
clean_key = key.replace('loss/', '').replace('acc/', '')
|
| 1131 |
+
self.writer.add_scalar(f"val/{clean_key}", value, self.global_step)
|
| 1132 |
+
else:
|
| 1133 |
+
self.writer.add_scalar(f"{prefix}/{key}", value, self.global_step)
|
| 1134 |
+
|
| 1135 |
+
# Log learning rate
|
| 1136 |
+
if prefix == "train":
|
| 1137 |
+
current_lr = self.optimizer.param_groups[0]['lr']
|
| 1138 |
+
self.writer.add_scalar("train/learning_rate", current_lr, self.global_step)
|
| 1139 |
|
| 1140 |
+
# Flush to disk
|
| 1141 |
+
self.writer.flush()
|
| 1142 |
|
| 1143 |
+
# Log text modality stats periodically
|
| 1144 |
+
if prefix == "train" and self.global_step % self.config.log_every == 0:
|
| 1145 |
total = sum(self.text_dropout_stats.values()) or 1
|
| 1146 |
for mode, count in self.text_dropout_stats.items():
|
| 1147 |
self.writer.add_scalar(f"text_modality/{mode}_pct", 100 * count / total, self.global_step)
|
| 1148 |
|
| 1149 |
+
# Log fusion diagnostics periodically
|
| 1150 |
+
if prefix == "train" and self.global_step % (self.config.log_every * 10) == 0:
|
| 1151 |
fusion_diag = self.get_fusion_diagnostics()
|
| 1152 |
|
| 1153 |
for i, w in enumerate(fusion_diag.get('layer_weights', [])):
|
|
|
|
| 1161 |
|
| 1162 |
for i, b in enumerate(fusion_diag.get('beta_per_scale', [])):
|
| 1163 |
self.writer.add_scalar(f"fusion/beta_scale_{i}", b, self.global_step)
|
| 1164 |
+
|
| 1165 |
+
self.writer.flush()
|
| 1166 |
|
| 1167 |
@torch.no_grad()
|
| 1168 |
def validate(self, max_batches: int = 100) -> Dict[str, float]:
|
|
|
|
| 1218 |
continue
|
| 1219 |
|
| 1220 |
if stats_with_text['count'] == 0 or stats_vision_only['count'] == 0:
|
| 1221 |
+
return {
|
| 1222 |
+
'val_with_text_loss': float('inf'),
|
| 1223 |
+
'val_with_text_acc': 0.0,
|
| 1224 |
+
'val_vision_only_loss': float('inf'),
|
| 1225 |
+
'val_vision_only_acc': 0.0,
|
| 1226 |
+
'loss/val': float('inf'),
|
| 1227 |
+
'acc/val': 0.0
|
| 1228 |
+
}
|
| 1229 |
|
| 1230 |
return {
|
| 1231 |
+
'val_with_text_loss': stats_with_text['loss'] / stats_with_text['count'],
|
| 1232 |
+
'val_with_text_acc': stats_with_text['acc'] / stats_with_text['count'],
|
| 1233 |
+
'val_vision_only_loss': stats_vision_only['loss'] / stats_vision_only['count'],
|
| 1234 |
+
'val_vision_only_acc': stats_vision_only['acc'] / stats_vision_only['count'],
|
|
|
|
| 1235 |
'loss/val': stats_vision_only['loss'] / stats_vision_only['count'],
|
| 1236 |
'acc/val': stats_vision_only['acc'] / stats_vision_only['count'],
|
| 1237 |
}
|
|
|
|
| 1239 |
except Exception as e:
|
| 1240 |
print(f"\n⚠️ Validation completely failed: {e}")
|
| 1241 |
traceback.print_exc()
|
| 1242 |
+
return {
|
| 1243 |
+
'val_with_text_loss': float('inf'),
|
| 1244 |
+
'val_with_text_acc': 0.0,
|
| 1245 |
+
'val_vision_only_loss': float('inf'),
|
| 1246 |
+
'val_vision_only_acc': 0.0,
|
| 1247 |
+
'loss/val': float('inf'),
|
| 1248 |
+
'acc/val': 0.0
|
| 1249 |
+
}
|
| 1250 |
|
| 1251 |
def save_checkpoint_and_upload(self, epoch: int, val_loss: float = float('inf'), is_best: bool = False):
|
| 1252 |
"""Save checkpoint first, then optionally upload."""
|
|
|
|
| 1328 |
print("\n🔍 Running validation...")
|
| 1329 |
val_metrics = self.validate(max_batches=50)
|
| 1330 |
self.log_metrics(val_metrics, prefix="val")
|
| 1331 |
+
print(f"✓ Val (with text) - Loss: {val_metrics['val_with_text_loss']:.4f}, Acc: {val_metrics['val_with_text_acc']:.4f}")
|
| 1332 |
+
print(f"✓ Val (vision-only) - Loss: {val_metrics['val_vision_only_loss']:.4f}, Acc: {val_metrics['val_vision_only_acc']:.4f}")
|
| 1333 |
|
| 1334 |
# HuggingFace upload
|
| 1335 |
if (self.config.hf_repo_id and
|
|
|
|
| 1339 |
if self.accelerator.is_main_process:
|
| 1340 |
print("\n🔍 Running validation for upload...")
|
| 1341 |
val_metrics = self.validate(max_batches=50)
|
| 1342 |
+
print(f"✓ Val (with text) - Loss: {val_metrics['val_with_text_loss']:.4f}, Acc: {val_metrics['val_with_text_acc']:.4f}")
|
| 1343 |
+
print(f"✓ Val (vision-only) - Loss: {val_metrics['val_vision_only_loss']:.4f}, Acc: {val_metrics['val_vision_only_acc']:.4f}")
|
| 1344 |
|
| 1345 |
if self._interrupt_received:
|
| 1346 |
break
|
|
|
|
| 1354 |
|
| 1355 |
print(f"\n📊 Validation Results:")
|
| 1356 |
print(f" With Text:")
|
| 1357 |
+
print(f" Loss: {val_metrics['val_with_text_loss']:.4f}")
|
| 1358 |
+
print(f" Acc: {val_metrics['val_with_text_acc']:.4f}")
|
| 1359 |
print(f" Vision-Only (PRIMARY METRIC):")
|
| 1360 |
+
print(f" Loss: {val_metrics['val_vision_only_loss']:.4f}")
|
| 1361 |
+
print(f" Acc: {val_metrics['val_vision_only_acc']:.4f}")
|
| 1362 |
|
| 1363 |
self.log_metrics(val_metrics, prefix="val")
|
| 1364 |
|
|
|
|
| 1403 |
if self.writer:
|
| 1404 |
self.writer.close()
|
| 1405 |
|
|
|
|
| 1406 |
# ============================================================================
|
| 1407 |
# MAIN
|
| 1408 |
# ============================================================================
|
|
|
|
| 1410 |
if __name__ == "__main__":
|
| 1411 |
config = DanbooruTrainingConfig(
|
| 1412 |
# Run identifier
|
| 1413 |
+
sub_name="danbooru-50k-v1-512-2",
|
| 1414 |
|
| 1415 |
# Model architecture
|
| 1416 |
num_opinion_anchors=225,
|
| 1417 |
+
pentachoron_dim=512,
|
| 1418 |
scales=[128, 256, 512, 1024],
|
| 1419 |
+
scale_hidden_dims={128: 256, 256: 512, 512: 1024, 1024: 2048},
|
| 1420 |
|
| 1421 |
# Fusion controller
|
| 1422 |
+
alpha_init=0.125,
|
| 1423 |
alpha_learnable=True,
|
| 1424 |
beta_init=0.5,
|
| 1425 |
beta_learnable=True,
|
|
|
|
| 1427 |
learn_layer_weights=True,
|
| 1428 |
|
| 1429 |
# Encoders
|
| 1430 |
+
clip_skip=1,
|
| 1431 |
+
siglip_layer_indices=[1, 2, 3, 4, 5, 6, 9, 12, 18, 21, 23, 24, 25, 26],
|
| 1432 |
|
| 1433 |
# Optimizations
|
| 1434 |
use_gradient_checkpointing=False,
|
| 1435 |
share_scale_embeddings=False,
|
| 1436 |
|
| 1437 |
# Training
|
| 1438 |
+
batch_size=24,
|
| 1439 |
+
num_epochs=20,
|
| 1440 |
learning_rate=1e-4,
|
| 1441 |
save_every=500,
|
| 1442 |
|
|
|
|
| 1450 |
text_dropout_end=0.5,
|
| 1451 |
|
| 1452 |
# Resume
|
| 1453 |
+
resume=False,
|
| 1454 |
|
| 1455 |
# HuggingFace
|
| 1456 |
hf_repo_id="AbstractPhil/liminal-staircase-v2",
|