Alogotron committed on
Commit a564872 · verified · 1 Parent(s): 066e1de

Upload sdxl/train_sdxl_adapter.py with huggingface_hub

Files changed (1)
  1. sdxl/train_sdxl_adapter.py +263 -0
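The commit message notes the file was pushed with huggingface_hub. A minimal sketch of such an upload is shown below; the repo id and local path are hypothetical, not taken from the commit.

from huggingface_hub import HfApi

api = HfApi()  # picks up the token saved by `huggingface-cli login`
api.upload_file(
    path_or_fileobj="sdxl/train_sdxl_adapter.py",   # hypothetical local path
    path_in_repo="sdxl/train_sdxl_adapter.py",
    repo_id="Alogotron/example-repo",               # hypothetical repo id
    commit_message="Upload sdxl/train_sdxl_adapter.py with huggingface_hub",
)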
sdxl/train_sdxl_adapter.py ADDED
@@ -0,0 +1,263 @@
"""
Train SDXL adapter: Qwen3-4B activations -> SDXL prompt embeddings.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import json
import os
from pathlib import Path
from datetime import datetime

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from sdxl_adapter import SDXLCrossAttentionAdapter, LayerWeightedInput


class ActivationEmbeddingDataset(Dataset):
    """Dataset of (Qwen activation, SDXL embedding) pairs."""

    def __init__(self, activation_dir, embedding_dir, metadata_path):
        with open(metadata_path) as f:
            self.metadata = json.load(f)
        self.activation_dir = Path(activation_dir)
        self.embedding_dir = Path(embedding_dir)

        # Validate all files exist
        valid = []
        for item in self.metadata:
            emotion = item['emotion']
            idx = item['index']
            act_file = self.activation_dir / f"{emotion}_{idx:02d}.npy"
            emb_file = self.embedding_dir / f"{emotion}_{idx:02d}.npz"
            if act_file.exists() and emb_file.exists():
                valid.append(item)
        self.metadata = valid
        print(f"Dataset: {len(self.metadata)} valid pairs")

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        item = self.metadata[idx]
        emotion = item['emotion']
        i = item['index']

        # Load Qwen activation [7680]
        act = np.load(self.activation_dir / f"{emotion}_{i:02d}.npy")
        act = torch.from_numpy(act).float()

        # Load SDXL embeddings
        emb = np.load(self.embedding_dir / f"{emotion}_{i:02d}.npz")
        prompt_embeds = torch.from_numpy(emb['prompt_embeds']).float().squeeze(0)  # [77, 2048]
        pooled_embeds = torch.from_numpy(emb['pooled_prompt_embeds']).float().squeeze(0)  # [1280]

        return act, prompt_embeds, pooled_embeds, emotion


def compute_normalization(dataset):
    """Compute mean/std of activations and targets for normalization."""
    all_acts = []
    all_main = []
    all_pooled = []
    for i in range(len(dataset)):
        act, main, pooled, _ = dataset[i]
        all_acts.append(act)
        all_main.append(main)
        all_pooled.append(pooled)

    acts = torch.stack(all_acts)
    mains = torch.stack(all_main)
    pooleds = torch.stack(all_pooled)

    return {
        'act_mean': acts.mean(dim=0),
        'act_std': acts.std(dim=0).clamp(min=1e-6),
        'main_mean': mains.mean(dim=(0, 1)),  # [2048] - mean across batch and tokens
        'main_std': mains.std(dim=(0, 1)).clamp(min=1e-6),
        'pooled_mean': pooleds.mean(dim=0),  # [1280]
        'pooled_std': pooleds.std(dim=0).clamp(min=1e-6),
    }


def train():
    device = torch.device('cuda')
    base_dir = Path('/home/beta1/milady-training')

    print(f"[{datetime.now()}] Loading dataset...")
    dataset = ActivationEmbeddingDataset(
        activation_dir=base_dir / 'qwen_activations',
        embedding_dir=base_dir / 'sdxl_embeddings',
        metadata_path=base_dir / 'sdxl_emotions' / 'metadata.json',
    )

    # Compute normalization stats
    print(f"[{datetime.now()}] Computing normalization statistics...")
    norm = compute_normalization(dataset)
    for k, v in norm.items():
        print(f"  {k}: shape={v.shape}, mean={v.mean():.4f}, std={v.std():.4f}")

    # Train/val split (80/20)
    n = len(dataset)
    n_val = max(1, n // 5)
    n_train = n - n_val
    train_ds, val_ds = torch.utils.data.random_split(
        dataset, [n_train, n_val],
        generator=torch.Generator().manual_seed(42)
    )
    print(f"Train: {n_train}, Val: {n_val}")

    train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_ds, batch_size=16, shuffle=False)

    # Models
    layer_weight = LayerWeightedInput(n_layers=3, layer_dim=2560).to(device)
    adapter = SDXLCrossAttentionAdapter(
        in_dim=2560, rank=256, n_input_tokens=8,
        n_heads=8, n_layers=3,
        n_output_tokens=77, main_dim=2048, pooled_dim=1280,
    ).to(device)

    print(f"LayerWeight params: {sum(p.numel() for p in layer_weight.parameters()):,}")
    print(f"Adapter params: {sum(p.numel() for p in adapter.parameters()):,}")

    # Move norm stats to device
    act_mean = norm['act_mean'].to(device)
    act_std = norm['act_std'].to(device)

    # Optimizer
    params = list(adapter.parameters()) + list(layer_weight.parameters())
    optimizer = torch.optim.AdamW(params, lr=1e-4, weight_decay=1e-2)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=500, eta_min=1e-6)

    # Training loop
    n_epochs = 500
    best_val_loss = float('inf')
    patience = 50
    patience_counter = 0
    save_dir = base_dir / 'sdxl_adapter_checkpoints'
    save_dir.mkdir(exist_ok=True)

    print(f"\n[{datetime.now()}] Starting training for {n_epochs} epochs...")

    for epoch in range(n_epochs):
        # Train
        adapter.train()
        layer_weight.train()
        train_loss_sum = 0
        train_main_loss_sum = 0
        train_pooled_loss_sum = 0
        n_batches = 0

        for acts, target_main, target_pooled, _ in train_loader:
            acts = acts.to(device)
            target_main = target_main.to(device)
            target_pooled = target_pooled.to(device)

            # Normalize activations
            acts_norm = (acts - act_mean) / act_std

            # Layer weighting
            x = layer_weight(acts_norm)  # [B, 2560]

            # Forward
            pred_main, pred_pooled = adapter(x)  # [B, 77, 2048], [B, 1280]

            # Losses
            main_loss = F.mse_loss(pred_main, target_main)
            pooled_loss = F.mse_loss(pred_pooled, target_pooled)
            loss = main_loss + 0.5 * pooled_loss

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(params, 1.0)
            optimizer.step()

            train_loss_sum += loss.item()
            train_main_loss_sum += main_loss.item()
            train_pooled_loss_sum += pooled_loss.item()
            n_batches += 1

        scheduler.step()

        # Validate
        adapter.eval()
        layer_weight.eval()
        val_loss_sum = 0
        val_main_sum = 0
        val_pooled_sum = 0
        v_batches = 0

        with torch.no_grad():
            for acts, target_main, target_pooled, _ in val_loader:
                acts = acts.to(device)
                target_main = target_main.to(device)
                target_pooled = target_pooled.to(device)

                acts_norm = (acts - act_mean) / act_std
                x = layer_weight(acts_norm)
                pred_main, pred_pooled = adapter(x)

                main_loss = F.mse_loss(pred_main, target_main)
                pooled_loss = F.mse_loss(pred_pooled, target_pooled)
                loss = main_loss + 0.5 * pooled_loss

                val_loss_sum += loss.item()
                val_main_sum += main_loss.item()
                val_pooled_sum += pooled_loss.item()
                v_batches += 1

        avg_train = train_loss_sum / max(n_batches, 1)
        avg_val = val_loss_sum / max(v_batches, 1)
        avg_train_main = train_main_loss_sum / max(n_batches, 1)
        avg_train_pooled = train_pooled_loss_sum / max(n_batches, 1)
        avg_val_main = val_main_sum / max(v_batches, 1)
        avg_val_pooled = val_pooled_sum / max(v_batches, 1)

        if (epoch + 1) % 10 == 0 or epoch == 0:
            lr = scheduler.get_last_lr()[0]
            print(f"  Epoch {epoch+1:3d} | Train: {avg_train:.6f} (main={avg_train_main:.6f} pool={avg_train_pooled:.6f}) | Val: {avg_val:.6f} (main={avg_val_main:.6f} pool={avg_val_pooled:.6f}) | LR: {lr:.2e}")

        # Early stopping & best model
        if avg_val < best_val_loss:
            best_val_loss = avg_val
            patience_counter = 0
            checkpoint = {
                'epoch': epoch + 1,
                'model_type': 'sdxl_cross_attention',
                'adapter_state_dict': adapter.state_dict(),
                'layer_weight_state_dict': layer_weight.state_dict(),
                'in_dim': 2560,
                'out_dim_main': 2048,
                'out_dim_pooled': 1280,
                'n_tokens': 77,
                'rank': 256,
                'n_input_tokens': 8,
                'n_heads': 8,
                'n_layers': 3,
                'input_layers': 'learned_weight',
                'hook_layers': [9, 18, 27],
                'act_mean': act_mean.cpu(),
                'act_std': act_std.cpu(),
                'train_loss': avg_train,
                'val_loss': avg_val,
            }
            torch.save(checkpoint, save_dir / 'best_sdxl_adapter.pt')
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"  Early stopping at epoch {epoch+1}")
                break

        # Periodic save
        if (epoch + 1) % 100 == 0:
            torch.save(checkpoint, save_dir / f'sdxl_adapter_epoch{epoch+1}.pt')

    print(f"\n[{datetime.now()}] Training complete!")
    print(f"Best val loss: {best_val_loss:.6f}")
    print(f"Best model saved to: {save_dir / 'best_sdxl_adapter.pt'}")


if __name__ == '__main__':
    train()
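For reference, a minimal sketch (not part of this commit) of how the saved best checkpoint could be reloaded for inference. It assumes the same sdxl_adapter module is importable and that the constructor arguments match those used in train() above; the activation tensor here is a placeholder.

import torch
from sdxl_adapter import SDXLCrossAttentionAdapter, LayerWeightedInput

ckpt_path = '/home/beta1/milady-training/sdxl_adapter_checkpoints/best_sdxl_adapter.pt'
ckpt = torch.load(ckpt_path, map_location='cpu')

# Rebuild the modules from the hyperparameters stored in the checkpoint.
layer_weight = LayerWeightedInput(n_layers=ckpt['n_layers'], layer_dim=ckpt['in_dim'])
adapter = SDXLCrossAttentionAdapter(
    in_dim=ckpt['in_dim'], rank=ckpt['rank'], n_input_tokens=ckpt['n_input_tokens'],
    n_heads=ckpt['n_heads'], n_layers=ckpt['n_layers'],
    n_output_tokens=ckpt['n_tokens'], main_dim=ckpt['out_dim_main'],
    pooled_dim=ckpt['out_dim_pooled'],
)
layer_weight.load_state_dict(ckpt['layer_weight_state_dict'])
adapter.load_state_dict(ckpt['adapter_state_dict'])
layer_weight.eval()
adapter.eval()

# Normalize a raw Qwen activation exactly as during training, then map it
# to SDXL prompt embeddings.
acts = torch.randn(1, 7680)  # placeholder for a real [1, 7680] activation
acts_norm = (acts - ckpt['act_mean']) / ckpt['act_std']
with torch.no_grad():
    prompt_embeds, pooled_embeds = adapter(layer_weight(acts_norm))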