krystv commited on
Commit
3ab3cbd
Β·
verified Β·
1 Parent(s): 4c58a98

v2: Updated notebook with real dataset training

Browse files
Files changed (1) hide show
  1. ArtFlow_Training.ipynb +47 -73
ArtFlow_Training.ipynb CHANGED
@@ -4,10 +4,10 @@
4
  "cell_type": "markdown",
5
  "metadata": {},
6
  "source": [
7
- "# 🎨 ArtFlow Complete Training Notebook\n",
8
- "**Reasoning-Native Artistic Image Generation for Mobile Devices**\n",
9
  "\n",
10
- "Downloads model + training scripts from the HF repo, then trains all 5 stages."
11
  ]
12
  },
13
  {
@@ -18,7 +18,7 @@
18
  "source": [
19
  "# ===== 0. Setup =====\n",
20
  "!pip install -q torch torchvision torchaudio\n",
21
- "!pip install -q huggingface_hub matplotlib numpy tqdm\n",
22
  "\n",
23
  "from huggingface_hub import hf_hub_download\n",
24
  "import shutil\n",
@@ -41,7 +41,7 @@
41
  "# ===== 1. Create Model =====\n",
42
  "from artflow_model import ArtFlow, ArtFlowConfig\n",
43
  "from artflow_train import (\n",
44
- " TrainConfig, SyntheticDataset, freeze_for_stage, train\n",
45
  ")\n",
46
  "\n",
47
  "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
@@ -61,7 +61,7 @@
61
  "\n",
62
  "model = ArtFlow(config).to(DEVICE)\n",
63
  "p = sum(x.numel() for x in model.parameters())\n",
64
- "print(f'Model: {p:,} params ({p/1e6:.1f}M)')"
65
  ]
66
  },
67
  {
@@ -70,9 +70,22 @@
70
  "metadata": {},
71
  "outputs": [],
72
  "source": [
73
- "# ===== 2. Dataset =====\n",
74
- "# Synthetic for smoke-testing. Replace with real latents for real training.\n",
75
- "dataset = SyntheticDataset(n=10000, config=config)"
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  ]
77
  },
78
  {
@@ -82,7 +95,7 @@
82
  "---\n",
83
  "## Stage 1: Base Generation\n",
84
  "**Frozen:** style, mood, concept \n",
85
- "**Trains:** WaveMamba backbone + cross-attention \n",
86
  "**Goal:** Learn denoising dynamics"
87
  ]
88
  },
@@ -114,7 +127,7 @@
114
  " print(f'Loss: {np.mean(L[:10]):.4f} β†’ {np.mean(L[-10:]):.4f}')\n",
115
  " sm = np.convolve(L, np.ones(min(20, len(L)//4))/min(20, len(L)//4), 'valid')\n",
116
  " plt.figure(figsize=(10,3))\n",
117
- " plt.plot(sm); plt.title('Stage 1 Loss'); plt.xlabel('Step'); plt.show()"
118
  ]
119
  },
120
  {
@@ -122,10 +135,8 @@
122
  "metadata": {},
123
  "source": [
124
  "---\n",
125
- "## Stage 2: Style Matrix Training\n",
126
- "**Frozen:** mood, concept \n",
127
- "**Trains:** style matrix + backbone (joint fine-tune) \n",
128
- "**Goal:** Disentangled art style vectors"
129
  ]
130
  },
131
  {
@@ -134,24 +145,12 @@
134
  "metadata": {},
135
  "outputs": [],
136
  "source": [
 
137
  "model = freeze_for_stage(model, 2)\n",
138
- "tcfg2 = TrainConfig(\n",
139
- " lr=5e-5, batch_size=2, grad_accum=32,\n",
140
- " num_steps=25000, warmup_steps=500,\n",
141
- " log_every=100, save_every=5000, stage=2,\n",
142
- ")\n",
143
- "engine2 = train(model, config, tcfg2, dataset, DEVICE)"
144
- ]
145
- },
146
- {
147
- "cell_type": "markdown",
148
- "metadata": {},
149
- "source": [
150
- "---\n",
151
- "## Stage 3: Resolution Scaling + Reasoning\n",
152
- "**Frozen:** mood, concept \n",
153
- "**Trains:** backbone + style + reasoning \n",
154
- "**Goal:** Higher res, enable recursive latent reasoning"
155
  ]
156
  },
157
  {
@@ -160,24 +159,12 @@
160
  "metadata": {},
161
  "outputs": [],
162
  "source": [
 
163
  "model = freeze_for_stage(model, 3)\n",
164
- "tcfg3 = TrainConfig(\n",
165
- " lr=3e-5, batch_size=2, grad_accum=32,\n",
166
- " num_steps=25000, warmup_steps=500,\n",
167
- " log_every=100, save_every=5000, stage=3,\n",
168
- ")\n",
169
- "engine3 = train(model, config, tcfg3, dataset, DEVICE)"
170
- ]
171
- },
172
- {
173
- "cell_type": "markdown",
174
- "metadata": {},
175
- "source": [
176
- "---\n",
177
- "## Stage 4: Concept & Mood Training\n",
178
- "**Frozen:** backbone + style \n",
179
- "**Trains:** concept engine + mood controller only \n",
180
- "**Goal:** Scene understanding, emotional atmosphere"
181
  ]
182
  },
183
  {
@@ -186,24 +173,12 @@
186
  "metadata": {},
187
  "outputs": [],
188
  "source": [
 
189
  "model = freeze_for_stage(model, 4)\n",
190
- "tcfg4 = TrainConfig(\n",
191
- " lr=2e-5, batch_size=2, grad_accum=32,\n",
192
- " num_steps=15000, warmup_steps=300,\n",
193
- " log_every=100, save_every=5000, stage=4,\n",
194
- ")\n",
195
- "engine4 = train(model, config, tcfg4, dataset, DEVICE)"
196
- ]
197
- },
198
- {
199
- "cell_type": "markdown",
200
- "metadata": {},
201
- "source": [
202
- "---\n",
203
- "## Stage 5: Quality Post-Training\n",
204
- "**Frozen:** nothing (all trainable) \n",
205
- "**Trains:** everything at low LR \n",
206
- "**Goal:** Final quality alignment"
207
  ]
208
  },
209
  {
@@ -212,13 +187,12 @@
212
  "metadata": {},
213
  "outputs": [],
214
  "source": [
 
215
  "model = freeze_for_stage(model, 5)\n",
216
- "tcfg5 = TrainConfig(\n",
217
- " lr=1e-5, batch_size=2, grad_accum=32,\n",
218
- " num_steps=5000, warmup_steps=200,\n",
219
- " log_every=50, save_every=2500, stage=5,\n",
220
- ")\n",
221
- "engine5 = train(model, config, tcfg5, dataset, DEVICE)"
222
  ]
223
  },
224
  {
@@ -231,7 +205,7 @@
231
  "# from huggingface_hub import HfApi\n",
232
  "# engine5.save('./artflow_final.pt')\n",
233
  "# HfApi().upload_file('./artflow_final.pt', 'artflow_final.pt', 'krystv/ArtFlow')\n",
234
- "print('πŸŽ‰ All 5 stages complete!')"
235
  ]
236
  }
237
  ],
 
4
  "cell_type": "markdown",
5
  "metadata": {},
6
  "source": [
7
+ "# 🎨 ArtFlow v2 Complete Training Notebook\n",
8
+ "**Real Mamba SSM backbone β€” No CUDA extensions needed!**\n",
9
  "\n",
10
+ "Downloads model + training scripts from HF repo, then trains with real art datasets."
11
  ]
12
  },
13
  {
 
18
  "source": [
19
  "# ===== 0. Setup =====\n",
20
  "!pip install -q torch torchvision torchaudio\n",
21
+ "!pip install -q huggingface_hub matplotlib numpy tqdm datasets\n",
22
  "\n",
23
  "from huggingface_hub import hf_hub_download\n",
24
  "import shutil\n",
 
41
  "# ===== 1. Create Model =====\n",
42
  "from artflow_model import ArtFlow, ArtFlowConfig\n",
43
  "from artflow_train import (\n",
44
+ " TrainConfig, SyntheticDataset, RealArtDataset, freeze_for_stage, train\n",
45
  ")\n",
46
  "\n",
47
  "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
 
61
  "\n",
62
  "model = ArtFlow(config).to(DEVICE)\n",
63
  "p = sum(x.numel() for x in model.parameters())\n",
64
+ "print(f'Model: {p:,} params ({p/1e6:.1f}M) β€” Real Mamba SSM!')"
65
  ]
66
  },
67
  {
 
70
  "metadata": {},
71
  "outputs": [],
72
  "source": [
73
+ "# ===== 2. Dataset (REAL β€” not synthetic!) =====\n",
74
+ "# Choose one:\n",
75
+ "# - 'huggan/wikiart' (80K art paintings, 27 styles)\n",
76
+ "# - 'Fazzie/Teyvat' (anime illustrations)\n",
77
+ "# - 'diffusers/pokemon-gpt4-captions' (800 pokemon, good captions)\n",
78
+ "# - 'lambdalabs/naruto-blip-captions' (anime faces)\n",
79
+ "\n",
80
+ "dataset = RealArtDataset(\n",
81
+ " 'diffusers/pokemon-gpt4-captions', # Small but high quality\n",
82
+ " config=config,\n",
83
+ " max_samples=None, # Use all samples\n",
84
+ ")\n",
85
+ "print(f'Dataset: {len(dataset)} samples')\n",
86
+ "\n",
87
+ "# For smoke test, use synthetic:\n",
88
+ "# dataset = SyntheticDataset(n=10000, config=config)"
89
  ]
90
  },
91
  {
 
95
  "---\n",
96
  "## Stage 1: Base Generation\n",
97
  "**Frozen:** style, mood, concept \n",
98
+ "**Trains:** WaveMamba backbone (Real Mamba SSM!) + cross-attention \n",
99
  "**Goal:** Learn denoising dynamics"
100
  ]
101
  },
 
127
  " print(f'Loss: {np.mean(L[:10]):.4f} β†’ {np.mean(L[-10:]):.4f}')\n",
128
  " sm = np.convolve(L, np.ones(min(20, len(L)//4))/min(20, len(L)//4), 'valid')\n",
129
  " plt.figure(figsize=(10,3))\n",
130
+ " plt.plot(sm); plt.title('Stage 1 Loss (Real Mamba SSM)'); plt.xlabel('Step'); plt.show()"
131
  ]
132
  },
133
  {
 
135
  "metadata": {},
136
  "source": [
137
  "---\n",
138
+ "## Stage 2-5: Style, Resolution, Concept, Quality\n",
139
+ "Same as before but now with Real Mamba SSM backbone."
 
 
140
  ]
141
  },
142
  {
 
145
  "metadata": {},
146
  "outputs": [],
147
  "source": [
148
+ "# Stage 2: Style Matrix\n",
149
  "model = freeze_for_stage(model, 2)\n",
150
+ "engine2 = train(model, config, TrainConfig(\n",
151
+ " lr=5e-5, batch_size=2, grad_accum=32, num_steps=25000,\n",
152
+ " warmup_steps=500, log_every=100, save_every=5000, stage=2,\n",
153
+ "), dataset, DEVICE)"
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  ]
155
  },
156
  {
 
159
  "metadata": {},
160
  "outputs": [],
161
  "source": [
162
+ "# Stage 3: Resolution + Reasoning\n",
163
  "model = freeze_for_stage(model, 3)\n",
164
+ "engine3 = train(model, config, TrainConfig(\n",
165
+ " lr=3e-5, batch_size=2, grad_accum=32, num_steps=25000,\n",
166
+ " warmup_steps=500, log_every=100, save_every=5000, stage=3,\n",
167
+ "), dataset, DEVICE)"
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  ]
169
  },
170
  {
 
173
  "metadata": {},
174
  "outputs": [],
175
  "source": [
176
+ "# Stage 4: Concept & Mood\n",
177
  "model = freeze_for_stage(model, 4)\n",
178
+ "engine4 = train(model, config, TrainConfig(\n",
179
+ " lr=2e-5, batch_size=2, grad_accum=32, num_steps=15000,\n",
180
+ " warmup_steps=300, log_every=100, save_every=5000, stage=4,\n",
181
+ "), dataset, DEVICE)"
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  ]
183
  },
184
  {
 
187
  "metadata": {},
188
  "outputs": [],
189
  "source": [
190
+ "# Stage 5: Quality Alignment\n",
191
  "model = freeze_for_stage(model, 5)\n",
192
+ "engine5 = train(model, config, TrainConfig(\n",
193
+ " lr=1e-5, batch_size=2, grad_accum=32, num_steps=5000,\n",
194
+ " warmup_steps=200, log_every=50, save_every=2500, stage=5,\n",
195
+ "), dataset, DEVICE)"
 
 
196
  ]
197
  },
198
  {
 
205
  "# from huggingface_hub import HfApi\n",
206
  "# engine5.save('./artflow_final.pt')\n",
207
  "# HfApi().upload_file('./artflow_final.pt', 'artflow_final.pt', 'krystv/ArtFlow')\n",
208
+ "print('πŸŽ‰ All 5 stages complete β€” Real Mamba SSM, real datasets!')"
209
  ]
210
  }
211
  ],