Update scripts/trainer_v4_testing.py
Browse files
- scripts/trainer_v4_testing.py +54 -54
scripts/trainer_v4_testing.py
CHANGED
|
@@ -45,8 +45,8 @@ warnings.filterwarnings('ignore', message='.*TF32.*')
|
|
| 45 |
# ============================================================================
|
| 46 |
# CONFIG
|
| 47 |
# ============================================================================
|
| 48 |
-
BATCH_SIZE =
|
| 49 |
-
GRAD_ACCUM =
|
| 50 |
LR = 3e-4
|
| 51 |
EPOCHS = 10
|
| 52 |
MAX_SEQ = 128
|
|
@@ -986,7 +986,7 @@ def build_object_relations_prompt(item):
|
|
| 986 |
|
| 987 |
if ENABLE_OBJECT_RELATIONS:
|
| 988 |
print(f"\n[6/6] Loading Object Relations from {OBJECT_RELATIONS_REPO}...")
|
| 989 |
-
object_relations_ds = load_dataset(OBJECT_RELATIONS_REPO, split="train")
|
| 990 |
print(f" Raw samples: {len(object_relations_ds)}")
|
| 991 |
|
| 992 |
# Use columnar access - MUCH faster than row iteration
|
|
@@ -1431,7 +1431,7 @@ def get_sol_features_for_batch(
|
|
| 1431 |
|
| 1432 |
B = local_indices.shape[0]
|
| 1433 |
device = timesteps.device
|
| 1434 |
-
stats = torch.zeros(B,
|
| 1435 |
spatial = torch.zeros(B, SOL_SPATIAL_SIZE, SOL_SPATIAL_SIZE, device=device, dtype=DTYPE)
|
| 1436 |
|
| 1437 |
for ds_id, cache in enumerate(caches):
|
|
@@ -1443,7 +1443,7 @@ def get_sol_features_for_batch(
|
|
| 1443 |
ds_local_indices = local_indices[mask]
|
| 1444 |
ds_timesteps = timesteps[mask]
|
| 1445 |
ds_stats, ds_spatial = cache.get_features(ds_local_indices, ds_timesteps)
|
| 1446 |
-
stats[mask] = ds_stats
|
| 1447 |
spatial[mask] = ds_spatial
|
| 1448 |
|
| 1449 |
return stats, spatial
|
|
@@ -1758,30 +1758,39 @@ def upload_logs():
|
|
| 1758 |
# ============================================================================
|
| 1759 |
# WEIGHT UPGRADE LOADING (v3 -> v4.1)
|
| 1760 |
# ============================================================================
|
| 1761 |
-
V3_TO_V4_REMAP = {
|
| 1762 |
-
# ExpertPredictor -> LunePredictor
|
| 1763 |
-
'expert_predictor.t_embed.0.weight': 'lune_predictor.t_embed.0.weight',
|
| 1764 |
-
'expert_predictor.t_embed.0.bias': 'lune_predictor.t_embed.0.bias',
|
| 1765 |
-
'expert_predictor.t_embed.2.weight': 'lune_predictor.t_embed.2.weight',
|
| 1766 |
-
'expert_predictor.t_embed.2.bias': 'lune_predictor.t_embed.2.bias',
|
| 1767 |
-
'expert_predictor.clip_proj.weight': 'lune_predictor.clip_proj.weight',
|
| 1768 |
-
'expert_predictor.clip_proj.bias': 'lune_predictor.clip_proj.bias',
|
| 1769 |
-
'expert_predictor.out_proj.0.weight': 'lune_predictor.out_proj.0.weight',
|
| 1770 |
-
'expert_predictor.out_proj.0.bias': 'lune_predictor.out_proj.0.bias',
|
| 1771 |
-
'expert_predictor.out_proj.2.weight': 'lune_predictor.out_proj.2.weight',
|
| 1772 |
-
'expert_predictor.out_proj.2.bias': 'lune_predictor.out_proj.2.bias',
|
| 1773 |
-
'expert_predictor.gate': 'lune_predictor.gate',
|
| 1774 |
-
# expert_features -> lune_features
|
| 1775 |
-
'expert_features': 'lune_features',
|
| 1776 |
-
}
|
| 1777 |
|
| 1778 |
|
| 1779 |
def load_with_weight_upgrade(model, state_dict):
|
| 1780 |
-
"""Load state dict with
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1781 |
model_state = model.state_dict()
|
| 1782 |
-
|
| 1783 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1784 |
NEW_WEIGHT_PATTERNS = [
|
|
|
|
| 1785 |
'lune_predictor.',
|
| 1786 |
'sol_prior.',
|
| 1787 |
't5_vec_proj.',
|
|
@@ -1791,12 +1800,10 @@ def load_with_weight_upgrade(model, state_dict):
|
|
| 1791 |
'.norm_added_k.weight',
|
| 1792 |
]
|
| 1793 |
|
| 1794 |
-
# Deprecated keys
|
| 1795 |
DEPRECATED_PATTERNS = [
|
| 1796 |
'guidance_in.',
|
| 1797 |
'.sin_basis',
|
| 1798 |
-
'expert_predictor.', # Renamed to lune_predictor
|
| 1799 |
-
'expert_features', # Renamed to lune_features
|
| 1800 |
]
|
| 1801 |
|
| 1802 |
loaded_keys = []
|
|
@@ -1805,15 +1812,16 @@ def load_with_weight_upgrade(model, state_dict):
|
|
| 1805 |
initialized_keys = []
|
| 1806 |
remapped_keys = []
|
| 1807 |
|
| 1808 |
-
# First pass: remap
|
| 1809 |
remapped_state = {}
|
| 1810 |
for k, v in state_dict.items():
|
| 1811 |
-
|
| 1812 |
-
|
| 1813 |
-
|
| 1814 |
-
|
| 1815 |
-
|
| 1816 |
-
|
|
|
|
| 1817 |
|
| 1818 |
# Second pass: load matching weights
|
| 1819 |
for key, v in remapped_state.items():
|
|
@@ -2014,13 +2022,6 @@ print("\nCreating TinyFlux v4.1 model with Lune + Sol...")
|
|
| 2014 |
# If running as a script, uncomment the import below:
|
| 2015 |
# from model_v4 import TinyFluxConfig, TinyFlux
|
| 2016 |
|
| 2017 |
-
# Check that model classes exist
|
| 2018 |
-
if 'TinyFluxConfig' not in dir() or 'TinyFlux' not in dir():
|
| 2019 |
-
raise RuntimeError(
|
| 2020 |
-
"TinyFluxConfig and TinyFlux not found! "
|
| 2021 |
-
"Run model_v4.py cell first, or add: from model_v4 import TinyFluxConfig, TinyFlux"
|
| 2022 |
-
)
|
| 2023 |
-
|
| 2024 |
config = TinyFluxConfig(
|
| 2025 |
hidden_size=512,
|
| 2026 |
num_attention_heads=4,
|
|
@@ -2047,7 +2048,7 @@ config = TinyFluxConfig(
|
|
| 2047 |
huber_delta=HUBER_DELTA,
|
| 2048 |
guidance_embeds=False,
|
| 2049 |
)
|
| 2050 |
-
model =
|
| 2051 |
|
| 2052 |
total_params = sum(p.numel() for p in model.parameters())
|
| 2053 |
print(f"Total parameters: {total_params:,}")
|
|
@@ -2102,9 +2103,9 @@ if ema_state is not None:
|
|
| 2102 |
# Remap v3 EMA keys to v4
|
| 2103 |
remapped_ema = {}
|
| 2104 |
for k, v in ema_state.items():
|
| 2105 |
-
if k in V3_TO_V4_REMAP:
|
| 2106 |
-
remapped_ema[V3_TO_V4_REMAP[k]] = v
|
| 2107 |
-
else:
|
| 2108 |
remapped_ema[k] = v
|
| 2109 |
ema.load_shadow(remapped_ema, model=model)
|
| 2110 |
|
|
@@ -2198,7 +2199,7 @@ for ep in range(start_epoch, EPOCHS):
|
|
| 2198 |
x_t = (1 - t_expanded) * noise + t_expanded * data
|
| 2199 |
v_target = data - noise
|
| 2200 |
|
| 2201 |
-
img_ids =
|
| 2202 |
|
| 2203 |
# Get expert features from CACHE
|
| 2204 |
lune_features = None
|
|
@@ -2248,19 +2249,18 @@ for ep in range(start_epoch, EPOCHS):
|
|
| 2248 |
|
| 2249 |
# Lune distillation loss
|
| 2250 |
lune_loss = torch.tensor(0.0, device=DEVICE)
|
| 2251 |
-
if lune_features is not None and expert_info.get('
|
| 2252 |
lune_loss = compute_lune_loss(
|
| 2253 |
-
expert_info['
|
| 2254 |
)
|
| 2255 |
|
| 2256 |
# Sol distillation loss
|
| 2257 |
sol_loss = torch.tensor(0.0, device=DEVICE)
|
| 2258 |
-
if sol_stats is not None and expert_info.get('
|
| 2259 |
-
|
| 2260 |
-
|
| 2261 |
-
|
| 2262 |
-
|
| 2263 |
-
|
| 2264 |
# Total loss with warmup weights
|
| 2265 |
total_loss = main_loss
|
| 2266 |
total_loss = total_loss + get_lune_weight(step) * lune_loss
|
|
|
|
| 45 |
# ============================================================================
|
| 46 |
# CONFIG
|
| 47 |
# ============================================================================
|
| 48 |
+
BATCH_SIZE = 8
|
| 49 |
+
GRAD_ACCUM = 4
|
| 50 |
LR = 3e-4
|
| 51 |
EPOCHS = 10
|
| 52 |
MAX_SEQ = 128
|
|
|
|
| 986 |
|
| 987 |
if ENABLE_OBJECT_RELATIONS:
|
| 988 |
print(f"\n[6/6] Loading Object Relations from {OBJECT_RELATIONS_REPO}...")
|
| 989 |
+
object_relations_ds = load_dataset(OBJECT_RELATIONS_REPO, "schnell_512_1", split="train")
|
| 990 |
print(f" Raw samples: {len(object_relations_ds)}")
|
| 991 |
|
| 992 |
# Use columnar access - MUCH faster than row iteration
|
|
|
|
| 1431 |
|
| 1432 |
B = local_indices.shape[0]
|
| 1433 |
device = timesteps.device
|
| 1434 |
+
stats = torch.zeros(B, 3, device=device, dtype=DTYPE) # 3 stats: locality, entropy, clustering
|
| 1435 |
spatial = torch.zeros(B, SOL_SPATIAL_SIZE, SOL_SPATIAL_SIZE, device=device, dtype=DTYPE)
|
| 1436 |
|
| 1437 |
for ds_id, cache in enumerate(caches):
|
|
|
|
| 1443 |
ds_local_indices = local_indices[mask]
|
| 1444 |
ds_timesteps = timesteps[mask]
|
| 1445 |
ds_stats, ds_spatial = cache.get_features(ds_local_indices, ds_timesteps)
|
| 1446 |
+
stats[mask] = ds_stats[:, :3] # Drop redundant sparsity (was copy of locality)
|
| 1447 |
spatial[mask] = ds_spatial
|
| 1448 |
|
| 1449 |
return stats, spatial
|
|
|
|
| 1758 |
# ============================================================================
|
| 1759 |
# WEIGHT UPGRADE LOADING (v3 -> v4.1)
|
| 1760 |
# ============================================================================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1761 |
|
| 1762 |
|
| 1763 |
def load_with_weight_upgrade(model, state_dict):
|
| 1764 |
+
"""Load state dict with bidirectional remapping support.
|
| 1765 |
+
|
| 1766 |
+
Handles:
|
| 1767 |
+
- v3 checkpoint (expert_predictor) -> v4 model (lune_predictor)
|
| 1768 |
+
- v4 checkpoint (lune_predictor) -> model that uses expert_predictor naming
|
| 1769 |
+
"""
|
| 1770 |
model_state = model.state_dict()
|
| 1771 |
+
|
| 1772 |
+
# Detect which naming the MODEL uses
|
| 1773 |
+
model_has_expert = any('expert_predictor' in k for k in model_state.keys())
|
| 1774 |
+
model_has_lune = any('lune_predictor' in k for k in model_state.keys())
|
| 1775 |
+
|
| 1776 |
+
# Detect which naming the CHECKPOINT uses
|
| 1777 |
+
ckpt_has_expert = any('expert_predictor' in k for k in state_dict.keys())
|
| 1778 |
+
ckpt_has_lune = any('lune_predictor' in k for k in state_dict.keys())
|
| 1779 |
+
|
| 1780 |
+
# Build remap based on mismatch
|
| 1781 |
+
REMAP = {}
|
| 1782 |
+
if model_has_expert and ckpt_has_lune:
|
| 1783 |
+
# Checkpoint has lune_predictor, model expects expert_predictor
|
| 1784 |
+
print(" Remapping: lune_predictor -> expert_predictor")
|
| 1785 |
+
REMAP = {'lune_predictor.': 'expert_predictor.'}
|
| 1786 |
+
elif model_has_lune and ckpt_has_expert:
|
| 1787 |
+
# Checkpoint has expert_predictor, model expects lune_predictor
|
| 1788 |
+
print(" Remapping: expert_predictor -> lune_predictor")
|
| 1789 |
+
REMAP = {'expert_predictor.': 'lune_predictor.'}
|
| 1790 |
+
|
| 1791 |
+
# New modules that may not exist in checkpoint
|
| 1792 |
NEW_WEIGHT_PATTERNS = [
|
| 1793 |
+
'expert_predictor.',
|
| 1794 |
'lune_predictor.',
|
| 1795 |
'sol_prior.',
|
| 1796 |
't5_vec_proj.',
|
|
|
|
| 1800 |
'.norm_added_k.weight',
|
| 1801 |
]
|
| 1802 |
|
| 1803 |
+
# Deprecated keys
|
| 1804 |
DEPRECATED_PATTERNS = [
|
| 1805 |
'guidance_in.',
|
| 1806 |
'.sin_basis',
|
|
|
|
|
|
|
| 1807 |
]
|
| 1808 |
|
| 1809 |
loaded_keys = []
|
|
|
|
| 1812 |
initialized_keys = []
|
| 1813 |
remapped_keys = []
|
| 1814 |
|
| 1815 |
+
# First pass: remap checkpoint keys to match model
|
| 1816 |
remapped_state = {}
|
| 1817 |
for k, v in state_dict.items():
|
| 1818 |
+
new_k = k
|
| 1819 |
+
for old_pat, new_pat in REMAP.items():
|
| 1820 |
+
if old_pat in k:
|
| 1821 |
+
new_k = k.replace(old_pat, new_pat)
|
| 1822 |
+
remapped_keys.append(f"{k} -> {new_k}")
|
| 1823 |
+
break
|
| 1824 |
+
remapped_state[new_k] = v
|
| 1825 |
|
| 1826 |
# Second pass: load matching weights
|
| 1827 |
for key, v in remapped_state.items():
|
|
|
|
| 2022 |
# If running as a script, uncomment the import below:
|
| 2023 |
# from model_v4 import TinyFluxConfig, TinyFlux
|
| 2024 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2025 |
config = TinyFluxConfig(
|
| 2026 |
hidden_size=512,
|
| 2027 |
num_attention_heads=4,
|
|
|
|
| 2048 |
huber_delta=HUBER_DELTA,
|
| 2049 |
guidance_embeds=False,
|
| 2050 |
)
|
| 2051 |
+
model = TinyFluxDeep(config).to(device=DEVICE, dtype=DTYPE)
|
| 2052 |
|
| 2053 |
total_params = sum(p.numel() for p in model.parameters())
|
| 2054 |
print(f"Total parameters: {total_params:,}")
|
|
|
|
| 2103 |
# EMA keys are copied through unchanged; the v3 -> v4 key remap below is commented out
|
| 2104 |
remapped_ema = {}
|
| 2105 |
for k, v in ema_state.items():
|
| 2106 |
+
#if k in V3_TO_V4_REMAP:
|
| 2107 |
+
# remapped_ema[V3_TO_V4_REMAP[k]] = v
|
| 2108 |
+
#else:
|
| 2109 |
remapped_ema[k] = v
|
| 2110 |
ema.load_shadow(remapped_ema, model=model)
|
| 2111 |
|
|
|
|
| 2199 |
x_t = (1 - t_expanded) * noise + t_expanded * data
|
| 2200 |
v_target = data - noise
|
| 2201 |
|
| 2202 |
+
img_ids = TinyFluxDeep.create_img_ids(B, H, W, DEVICE)
|
| 2203 |
|
| 2204 |
# Get expert features from CACHE
|
| 2205 |
lune_features = None
|
|
|
|
| 2249 |
|
| 2250 |
# Lune distillation loss
|
| 2251 |
lune_loss = torch.tensor(0.0, device=DEVICE)
|
| 2252 |
+
if lune_features is not None and expert_info.get('lune') is not None:
|
| 2253 |
lune_loss = compute_lune_loss(
|
| 2254 |
+
expert_info['lune']['expert_pred'], lune_features, mode=LUNE_DISTILL_MODE
|
| 2255 |
)
|
| 2256 |
|
| 2257 |
# Sol distillation loss
|
| 2258 |
sol_loss = torch.tensor(0.0, device=DEVICE)
|
| 2259 |
+
if sol_stats is not None and expert_info.get('sol') is not None:
|
| 2260 |
+
sol_loss = compute_sol_loss(
|
| 2261 |
+
expert_info['sol']['pred_stats'], expert_info['sol']['pred_spatial'],
|
| 2262 |
+
sol_stats, sol_spatial
|
| 2263 |
+
)
|
|
|
|
| 2264 |
# Total loss with warmup weights
|
| 2265 |
total_loss = main_loss
|
| 2266 |
total_loss = total_loss + get_lune_weight(step) * lune_loss
|