AbstractPhil committed on
Commit
9c5f7a4
·
verified ·
1 Parent(s): a5c7471

Update scripts/trainer_v4_testing.py

Browse files
Files changed (1) hide show
  1. scripts/trainer_v4_testing.py +54 -54
scripts/trainer_v4_testing.py CHANGED
@@ -45,8 +45,8 @@ warnings.filterwarnings('ignore', message='.*TF32.*')
45
  # ============================================================================
46
  # CONFIG
47
  # ============================================================================
48
- BATCH_SIZE = 16
49
- GRAD_ACCUM = 2
50
  LR = 3e-4
51
  EPOCHS = 10
52
  MAX_SEQ = 128
@@ -986,7 +986,7 @@ def build_object_relations_prompt(item):
986
 
987
  if ENABLE_OBJECT_RELATIONS:
988
  print(f"\n[6/6] Loading Object Relations from {OBJECT_RELATIONS_REPO}...")
989
- object_relations_ds = load_dataset(OBJECT_RELATIONS_REPO, split="train")
990
  print(f" Raw samples: {len(object_relations_ds)}")
991
 
992
  # Use columnar access - MUCH faster than row iteration
@@ -1431,7 +1431,7 @@ def get_sol_features_for_batch(
1431
 
1432
  B = local_indices.shape[0]
1433
  device = timesteps.device
1434
- stats = torch.zeros(B, 4, device=device, dtype=DTYPE)
1435
  spatial = torch.zeros(B, SOL_SPATIAL_SIZE, SOL_SPATIAL_SIZE, device=device, dtype=DTYPE)
1436
 
1437
  for ds_id, cache in enumerate(caches):
@@ -1443,7 +1443,7 @@ def get_sol_features_for_batch(
1443
  ds_local_indices = local_indices[mask]
1444
  ds_timesteps = timesteps[mask]
1445
  ds_stats, ds_spatial = cache.get_features(ds_local_indices, ds_timesteps)
1446
- stats[mask] = ds_stats
1447
  spatial[mask] = ds_spatial
1448
 
1449
  return stats, spatial
@@ -1758,30 +1758,39 @@ def upload_logs():
1758
  # ============================================================================
1759
  # WEIGHT UPGRADE LOADING (v3 -> v4.1)
1760
  # ============================================================================
1761
- V3_TO_V4_REMAP = {
1762
- # ExpertPredictor -> LunePredictor
1763
- 'expert_predictor.t_embed.0.weight': 'lune_predictor.t_embed.0.weight',
1764
- 'expert_predictor.t_embed.0.bias': 'lune_predictor.t_embed.0.bias',
1765
- 'expert_predictor.t_embed.2.weight': 'lune_predictor.t_embed.2.weight',
1766
- 'expert_predictor.t_embed.2.bias': 'lune_predictor.t_embed.2.bias',
1767
- 'expert_predictor.clip_proj.weight': 'lune_predictor.clip_proj.weight',
1768
- 'expert_predictor.clip_proj.bias': 'lune_predictor.clip_proj.bias',
1769
- 'expert_predictor.out_proj.0.weight': 'lune_predictor.out_proj.0.weight',
1770
- 'expert_predictor.out_proj.0.bias': 'lune_predictor.out_proj.0.bias',
1771
- 'expert_predictor.out_proj.2.weight': 'lune_predictor.out_proj.2.weight',
1772
- 'expert_predictor.out_proj.2.bias': 'lune_predictor.out_proj.2.bias',
1773
- 'expert_predictor.gate': 'lune_predictor.gate',
1774
- # expert_features -> lune_features
1775
- 'expert_features': 'lune_features',
1776
- }
1777
 
1778
 
1779
  def load_with_weight_upgrade(model, state_dict):
1780
- """Load state dict with v3 -> v4.1 remapping support."""
 
 
 
 
 
1781
  model_state = model.state_dict()
1782
-
1783
- # New modules in v4.1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1784
  NEW_WEIGHT_PATTERNS = [
 
1785
  'lune_predictor.',
1786
  'sol_prior.',
1787
  't5_vec_proj.',
@@ -1791,12 +1800,10 @@ def load_with_weight_upgrade(model, state_dict):
1791
  '.norm_added_k.weight',
1792
  ]
1793
 
1794
- # Deprecated keys from v3
1795
  DEPRECATED_PATTERNS = [
1796
  'guidance_in.',
1797
  '.sin_basis',
1798
- 'expert_predictor.', # Renamed to lune_predictor
1799
- 'expert_features', # Renamed to lune_features
1800
  ]
1801
 
1802
  loaded_keys = []
@@ -1805,15 +1812,16 @@ def load_with_weight_upgrade(model, state_dict):
1805
  initialized_keys = []
1806
  remapped_keys = []
1807
 
1808
- # First pass: remap v3 keys to v4 keys
1809
  remapped_state = {}
1810
  for k, v in state_dict.items():
1811
- if k in V3_TO_V4_REMAP:
1812
- new_key = V3_TO_V4_REMAP[k]
1813
- remapped_state[new_key] = v
1814
- remapped_keys.append(f"{k} -> {new_key}")
1815
- else:
1816
- remapped_state[k] = v
 
1817
 
1818
  # Second pass: load matching weights
1819
  for key, v in remapped_state.items():
@@ -2014,13 +2022,6 @@ print("\nCreating TinyFlux v4.1 model with Lune + Sol...")
2014
  # If running as a script, uncomment the import below:
2015
  # from model_v4 import TinyFluxConfig, TinyFlux
2016
 
2017
- # Check that model classes exist
2018
- if 'TinyFluxConfig' not in dir() or 'TinyFlux' not in dir():
2019
- raise RuntimeError(
2020
- "TinyFluxConfig and TinyFlux not found! "
2021
- "Run model_v4.py cell first, or add: from model_v4 import TinyFluxConfig, TinyFlux"
2022
- )
2023
-
2024
  config = TinyFluxConfig(
2025
  hidden_size=512,
2026
  num_attention_heads=4,
@@ -2047,7 +2048,7 @@ config = TinyFluxConfig(
2047
  huber_delta=HUBER_DELTA,
2048
  guidance_embeds=False,
2049
  )
2050
- model = TinyFlux(config).to(device=DEVICE, dtype=DTYPE)
2051
 
2052
  total_params = sum(p.numel() for p in model.parameters())
2053
  print(f"Total parameters: {total_params:,}")
@@ -2102,9 +2103,9 @@ if ema_state is not None:
2102
  # Remap v3 EMA keys to v4
2103
  remapped_ema = {}
2104
  for k, v in ema_state.items():
2105
- if k in V3_TO_V4_REMAP:
2106
- remapped_ema[V3_TO_V4_REMAP[k]] = v
2107
- else:
2108
  remapped_ema[k] = v
2109
  ema.load_shadow(remapped_ema, model=model)
2110
 
@@ -2198,7 +2199,7 @@ for ep in range(start_epoch, EPOCHS):
2198
  x_t = (1 - t_expanded) * noise + t_expanded * data
2199
  v_target = data - noise
2200
 
2201
- img_ids = TinyFlux.create_img_ids(B, H, W, DEVICE)
2202
 
2203
  # Get expert features from CACHE
2204
  lune_features = None
@@ -2248,19 +2249,18 @@ for ep in range(start_epoch, EPOCHS):
2248
 
2249
  # Lune distillation loss
2250
  lune_loss = torch.tensor(0.0, device=DEVICE)
2251
- if lune_features is not None and expert_info.get('lune_pred') is not None:
2252
  lune_loss = compute_lune_loss(
2253
- expert_info['lune_pred'], lune_features, mode=LUNE_DISTILL_MODE
2254
  )
2255
 
2256
  # Sol distillation loss
2257
  sol_loss = torch.tensor(0.0, device=DEVICE)
2258
- if sol_stats is not None and expert_info.get('sol_stats_pred') is not None:
2259
- sol_loss = compute_sol_loss(
2260
- expert_info['sol_stats_pred'], expert_info.get('sol_spatial_pred'),
2261
- sol_stats, sol_spatial
2262
- )
2263
-
2264
  # Total loss with warmup weights
2265
  total_loss = main_loss
2266
  total_loss = total_loss + get_lune_weight(step) * lune_loss
 
45
  # ============================================================================
46
  # CONFIG
47
  # ============================================================================
48
+ BATCH_SIZE = 8
49
+ GRAD_ACCUM = 4
50
  LR = 3e-4
51
  EPOCHS = 10
52
  MAX_SEQ = 128
 
986
 
987
  if ENABLE_OBJECT_RELATIONS:
988
  print(f"\n[6/6] Loading Object Relations from {OBJECT_RELATIONS_REPO}...")
989
+ object_relations_ds = load_dataset(OBJECT_RELATIONS_REPO, "schnell_512_1", split="train")
990
  print(f" Raw samples: {len(object_relations_ds)}")
991
 
992
  # Use columnar access - MUCH faster than row iteration
 
1431
 
1432
  B = local_indices.shape[0]
1433
  device = timesteps.device
1434
+ stats = torch.zeros(B, 3, device=device, dtype=DTYPE) # 3 stats: locality, entropy, clustering
1435
  spatial = torch.zeros(B, SOL_SPATIAL_SIZE, SOL_SPATIAL_SIZE, device=device, dtype=DTYPE)
1436
 
1437
  for ds_id, cache in enumerate(caches):
 
1443
  ds_local_indices = local_indices[mask]
1444
  ds_timesteps = timesteps[mask]
1445
  ds_stats, ds_spatial = cache.get_features(ds_local_indices, ds_timesteps)
1446
+ stats[mask] = ds_stats[:, :3] # Drop redundant sparsity (was copy of locality)
1447
  spatial[mask] = ds_spatial
1448
 
1449
  return stats, spatial
 
1758
  # ============================================================================
1759
  # WEIGHT UPGRADE LOADING (v3 -> v4.1)
1760
  # ============================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1761
 
1762
 
1763
  def load_with_weight_upgrade(model, state_dict):
1764
+ """Load state dict with bidirectional remapping support.
1765
+
1766
+ Handles:
1767
+ - v3 checkpoint (expert_predictor) -> v4 model (lune_predictor)
1768
+ - v4 checkpoint (lune_predictor) -> model with (expert_predictor)
1769
+ """
1770
  model_state = model.state_dict()
1771
+
1772
+ # Detect which naming the MODEL uses
1773
+ model_has_expert = any('expert_predictor' in k for k in model_state.keys())
1774
+ model_has_lune = any('lune_predictor' in k for k in model_state.keys())
1775
+
1776
+ # Detect which naming the CHECKPOINT uses
1777
+ ckpt_has_expert = any('expert_predictor' in k for k in state_dict.keys())
1778
+ ckpt_has_lune = any('lune_predictor' in k for k in state_dict.keys())
1779
+
1780
+ # Build remap based on mismatch
1781
+ REMAP = {}
1782
+ if model_has_expert and ckpt_has_lune:
1783
+ # Checkpoint has lune_predictor, model expects expert_predictor
1784
+ print(" Remapping: lune_predictor -> expert_predictor")
1785
+ REMAP = {'lune_predictor.': 'expert_predictor.'}
1786
+ elif model_has_lune and ckpt_has_expert:
1787
+ # Checkpoint has expert_predictor, model expects lune_predictor
1788
+ print(" Remapping: expert_predictor -> lune_predictor")
1789
+ REMAP = {'expert_predictor.': 'lune_predictor.'}
1790
+
1791
+ # New modules that may not exist in checkpoint
1792
  NEW_WEIGHT_PATTERNS = [
1793
+ 'expert_predictor.',
1794
  'lune_predictor.',
1795
  'sol_prior.',
1796
  't5_vec_proj.',
 
1800
  '.norm_added_k.weight',
1801
  ]
1802
 
1803
+ # Deprecated keys
1804
  DEPRECATED_PATTERNS = [
1805
  'guidance_in.',
1806
  '.sin_basis',
 
 
1807
  ]
1808
 
1809
  loaded_keys = []
 
1812
  initialized_keys = []
1813
  remapped_keys = []
1814
 
1815
+ # First pass: remap checkpoint keys to match model
1816
  remapped_state = {}
1817
  for k, v in state_dict.items():
1818
+ new_k = k
1819
+ for old_pat, new_pat in REMAP.items():
1820
+ if old_pat in k:
1821
+ new_k = k.replace(old_pat, new_pat)
1822
+ remapped_keys.append(f"{k} -> {new_k}")
1823
+ break
1824
+ remapped_state[new_k] = v
1825
 
1826
  # Second pass: load matching weights
1827
  for key, v in remapped_state.items():
 
2022
  # If running as a script, uncomment the import below:
2023
  # from model_v4 import TinyFluxConfig, TinyFlux
2024
 
 
 
 
 
 
 
 
2025
  config = TinyFluxConfig(
2026
  hidden_size=512,
2027
  num_attention_heads=4,
 
2048
  huber_delta=HUBER_DELTA,
2049
  guidance_embeds=False,
2050
  )
2051
+ model = TinyFluxDeep(config).to(device=DEVICE, dtype=DTYPE)
2052
 
2053
  total_params = sum(p.numel() for p in model.parameters())
2054
  print(f"Total parameters: {total_params:,}")
 
2103
  # Remap v3 EMA keys to v4
2104
  remapped_ema = {}
2105
  for k, v in ema_state.items():
2106
+ #if k in V3_TO_V4_REMAP:
2107
+ # remapped_ema[V3_TO_V4_REMAP[k]] = v
2108
+ #else:
2109
  remapped_ema[k] = v
2110
  ema.load_shadow(remapped_ema, model=model)
2111
 
 
2199
  x_t = (1 - t_expanded) * noise + t_expanded * data
2200
  v_target = data - noise
2201
 
2202
+ img_ids = TinyFluxDeep.create_img_ids(B, H, W, DEVICE)
2203
 
2204
  # Get expert features from CACHE
2205
  lune_features = None
 
2249
 
2250
  # Lune distillation loss
2251
  lune_loss = torch.tensor(0.0, device=DEVICE)
2252
+ if lune_features is not None and expert_info.get('lune') is not None:
2253
  lune_loss = compute_lune_loss(
2254
+ expert_info['lune']['expert_pred'], lune_features, mode=LUNE_DISTILL_MODE
2255
  )
2256
 
2257
  # Sol distillation loss
2258
  sol_loss = torch.tensor(0.0, device=DEVICE)
2259
+ if sol_stats is not None and expert_info.get('sol') is not None:
2260
+ sol_loss = compute_sol_loss(
2261
+ expert_info['sol']['pred_stats'], expert_info['sol']['pred_spatial'],
2262
+ sol_stats, sol_spatial
2263
+ )
 
2264
  # Total loss with warmup weights
2265
  total_loss = main_loss
2266
  total_loss = total_loss + get_lune_weight(step) * lune_loss