v2: Updated notebook with real dataset training
Browse files- ArtFlow_Training.ipynb +47 -73
ArtFlow_Training.ipynb
CHANGED
|
@@ -4,10 +4,10 @@
|
|
| 4 |
"cell_type": "markdown",
|
| 5 |
"metadata": {},
|
| 6 |
"source": [
|
| 7 |
-
"# π¨ ArtFlow Complete Training Notebook\n",
|
| 8 |
-
"**
|
| 9 |
"\n",
|
| 10 |
-
"Downloads model + training scripts from
|
| 11 |
]
|
| 12 |
},
|
| 13 |
{
|
|
@@ -18,7 +18,7 @@
|
|
| 18 |
"source": [
|
| 19 |
"# ===== 0. Setup =====\n",
|
| 20 |
"!pip install -q torch torchvision torchaudio\n",
|
| 21 |
-
"!pip install -q huggingface_hub matplotlib numpy tqdm\n",
|
| 22 |
"\n",
|
| 23 |
"from huggingface_hub import hf_hub_download\n",
|
| 24 |
"import shutil\n",
|
|
@@ -41,7 +41,7 @@
|
|
| 41 |
"# ===== 1. Create Model =====\n",
|
| 42 |
"from artflow_model import ArtFlow, ArtFlowConfig\n",
|
| 43 |
"from artflow_train import (\n",
|
| 44 |
-
" TrainConfig, SyntheticDataset, freeze_for_stage, train\n",
|
| 45 |
")\n",
|
| 46 |
"\n",
|
| 47 |
"DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
|
@@ -61,7 +61,7 @@
|
|
| 61 |
"\n",
|
| 62 |
"model = ArtFlow(config).to(DEVICE)\n",
|
| 63 |
"p = sum(x.numel() for x in model.parameters())\n",
|
| 64 |
-
"print(f'Model: {p:,} params ({p/1e6:.1f}M)')"
|
| 65 |
]
|
| 66 |
},
|
| 67 |
{
|
|
@@ -70,9 +70,22 @@
|
|
| 70 |
"metadata": {},
|
| 71 |
"outputs": [],
|
| 72 |
"source": [
|
| 73 |
-
"# ===== 2. Dataset =====\n",
|
| 74 |
-
"#
|
| 75 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
]
|
| 77 |
},
|
| 78 |
{
|
|
@@ -82,7 +95,7 @@
|
|
| 82 |
"---\n",
|
| 83 |
"## Stage 1: Base Generation\n",
|
| 84 |
"**Frozen:** style, mood, concept \n",
|
| 85 |
-
"**Trains:** WaveMamba backbone + cross-attention \n",
|
| 86 |
"**Goal:** Learn denoising dynamics"
|
| 87 |
]
|
| 88 |
},
|
|
@@ -114,7 +127,7 @@
|
|
| 114 |
" print(f'Loss: {np.mean(L[:10]):.4f} β {np.mean(L[-10:]):.4f}')\n",
|
| 115 |
" sm = np.convolve(L, np.ones(min(20, len(L)//4))/min(20, len(L)//4), 'valid')\n",
|
| 116 |
" plt.figure(figsize=(10,3))\n",
|
| 117 |
-
" plt.plot(sm); plt.title('Stage 1 Loss'); plt.xlabel('Step'); plt.show()"
|
| 118 |
]
|
| 119 |
},
|
| 120 |
{
|
|
@@ -122,10 +135,8 @@
|
|
| 122 |
"metadata": {},
|
| 123 |
"source": [
|
| 124 |
"---\n",
|
| 125 |
-
"## Stage 2: Style
|
| 126 |
-
"
|
| 127 |
-
"**Trains:** style matrix + backbone (joint fine-tune) \n",
|
| 128 |
-
"**Goal:** Disentangled art style vectors"
|
| 129 |
]
|
| 130 |
},
|
| 131 |
{
|
|
@@ -134,24 +145,12 @@
|
|
| 134 |
"metadata": {},
|
| 135 |
"outputs": [],
|
| 136 |
"source": [
|
|
|
|
| 137 |
"model = freeze_for_stage(model, 2)\n",
|
| 138 |
-
"
|
| 139 |
-
" lr=5e-5, batch_size=2, grad_accum=32,\n",
|
| 140 |
-
"
|
| 141 |
-
"
|
| 142 |
-
")\n",
|
| 143 |
-
"engine2 = train(model, config, tcfg2, dataset, DEVICE)"
|
| 144 |
-
]
|
| 145 |
-
},
|
| 146 |
-
{
|
| 147 |
-
"cell_type": "markdown",
|
| 148 |
-
"metadata": {},
|
| 149 |
-
"source": [
|
| 150 |
-
"---\n",
|
| 151 |
-
"## Stage 3: Resolution Scaling + Reasoning\n",
|
| 152 |
-
"**Frozen:** mood, concept \n",
|
| 153 |
-
"**Trains:** backbone + style + reasoning \n",
|
| 154 |
-
"**Goal:** Higher res, enable recursive latent reasoning"
|
| 155 |
]
|
| 156 |
},
|
| 157 |
{
|
|
@@ -160,24 +159,12 @@
|
|
| 160 |
"metadata": {},
|
| 161 |
"outputs": [],
|
| 162 |
"source": [
|
|
|
|
| 163 |
"model = freeze_for_stage(model, 3)\n",
|
| 164 |
-
"
|
| 165 |
-
" lr=3e-5, batch_size=2, grad_accum=32,\n",
|
| 166 |
-
"
|
| 167 |
-
"
|
| 168 |
-
")\n",
|
| 169 |
-
"engine3 = train(model, config, tcfg3, dataset, DEVICE)"
|
| 170 |
-
]
|
| 171 |
-
},
|
| 172 |
-
{
|
| 173 |
-
"cell_type": "markdown",
|
| 174 |
-
"metadata": {},
|
| 175 |
-
"source": [
|
| 176 |
-
"---\n",
|
| 177 |
-
"## Stage 4: Concept & Mood Training\n",
|
| 178 |
-
"**Frozen:** backbone + style \n",
|
| 179 |
-
"**Trains:** concept engine + mood controller only \n",
|
| 180 |
-
"**Goal:** Scene understanding, emotional atmosphere"
|
| 181 |
]
|
| 182 |
},
|
| 183 |
{
|
|
@@ -186,24 +173,12 @@
|
|
| 186 |
"metadata": {},
|
| 187 |
"outputs": [],
|
| 188 |
"source": [
|
|
|
|
| 189 |
"model = freeze_for_stage(model, 4)\n",
|
| 190 |
-
"
|
| 191 |
-
" lr=2e-5, batch_size=2, grad_accum=32,\n",
|
| 192 |
-
"
|
| 193 |
-
"
|
| 194 |
-
")\n",
|
| 195 |
-
"engine4 = train(model, config, tcfg4, dataset, DEVICE)"
|
| 196 |
-
]
|
| 197 |
-
},
|
| 198 |
-
{
|
| 199 |
-
"cell_type": "markdown",
|
| 200 |
-
"metadata": {},
|
| 201 |
-
"source": [
|
| 202 |
-
"---\n",
|
| 203 |
-
"## Stage 5: Quality Post-Training\n",
|
| 204 |
-
"**Frozen:** nothing (all trainable) \n",
|
| 205 |
-
"**Trains:** everything at low LR \n",
|
| 206 |
-
"**Goal:** Final quality alignment"
|
| 207 |
]
|
| 208 |
},
|
| 209 |
{
|
|
@@ -212,13 +187,12 @@
|
|
| 212 |
"metadata": {},
|
| 213 |
"outputs": [],
|
| 214 |
"source": [
|
|
|
|
| 215 |
"model = freeze_for_stage(model, 5)\n",
|
| 216 |
-
"
|
| 217 |
-
" lr=1e-5, batch_size=2, grad_accum=32,\n",
|
| 218 |
-
"
|
| 219 |
-
"
|
| 220 |
-
")\n",
|
| 221 |
-
"engine5 = train(model, config, tcfg5, dataset, DEVICE)"
|
| 222 |
]
|
| 223 |
},
|
| 224 |
{
|
|
@@ -231,7 +205,7 @@
|
|
| 231 |
"# from huggingface_hub import HfApi\n",
|
| 232 |
"# engine5.save('./artflow_final.pt')\n",
|
| 233 |
"# HfApi().upload_file('./artflow_final.pt', 'artflow_final.pt', 'krystv/ArtFlow')\n",
|
| 234 |
-
"print('π All 5 stages complete!')"
|
| 235 |
]
|
| 236 |
}
|
| 237 |
],
|
|
|
|
| 4 |
"cell_type": "markdown",
|
| 5 |
"metadata": {},
|
| 6 |
"source": [
|
| 7 |
+
"# π¨ ArtFlow v2 Complete Training Notebook\n",
|
| 8 |
+
"**Real Mamba SSM backbone β No CUDA extensions needed!**\n",
|
| 9 |
"\n",
|
| 10 |
+
"Downloads model + training scripts from HF repo, then trains with real art datasets."
|
| 11 |
]
|
| 12 |
},
|
| 13 |
{
|
|
|
|
| 18 |
"source": [
|
| 19 |
"# ===== 0. Setup =====\n",
|
| 20 |
"!pip install -q torch torchvision torchaudio\n",
|
| 21 |
+
"!pip install -q huggingface_hub matplotlib numpy tqdm datasets\n",
|
| 22 |
"\n",
|
| 23 |
"from huggingface_hub import hf_hub_download\n",
|
| 24 |
"import shutil\n",
|
|
|
|
| 41 |
"# ===== 1. Create Model =====\n",
|
| 42 |
"from artflow_model import ArtFlow, ArtFlowConfig\n",
|
| 43 |
"from artflow_train import (\n",
|
| 44 |
+
" TrainConfig, SyntheticDataset, RealArtDataset, freeze_for_stage, train\n",
|
| 45 |
")\n",
|
| 46 |
"\n",
|
| 47 |
"DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
|
|
|
| 61 |
"\n",
|
| 62 |
"model = ArtFlow(config).to(DEVICE)\n",
|
| 63 |
"p = sum(x.numel() for x in model.parameters())\n",
|
| 64 |
+
"print(f'Model: {p:,} params ({p/1e6:.1f}M) β Real Mamba SSM!')"
|
| 65 |
]
|
| 66 |
},
|
| 67 |
{
|
|
|
|
| 70 |
"metadata": {},
|
| 71 |
"outputs": [],
|
| 72 |
"source": [
|
| 73 |
+
"# ===== 2. Dataset (REAL β not synthetic!) =====\n",
|
| 74 |
+
"# Choose one:\n",
|
| 75 |
+
"# - 'huggan/wikiart' (80K art paintings, 27 styles)\n",
|
| 76 |
+
"# - 'Fazzie/Teyvat' (anime illustrations)\n",
|
| 77 |
+
"# - 'diffusers/pokemon-gpt4-captions' (800 pokemon, good captions)\n",
|
| 78 |
+
"# - 'lambdalabs/naruto-blip-captions' (anime faces)\n",
|
| 79 |
+
"\n",
|
| 80 |
+
"dataset = RealArtDataset(\n",
|
| 81 |
+
" 'diffusers/pokemon-gpt4-captions', # Small but high quality\n",
|
| 82 |
+
" config=config,\n",
|
| 83 |
+
" max_samples=None, # Use all samples\n",
|
| 84 |
+
")\n",
|
| 85 |
+
"print(f'Dataset: {len(dataset)} samples')\n",
|
| 86 |
+
"\n",
|
| 87 |
+
"# For smoke test, use synthetic:\n",
|
| 88 |
+
"# dataset = SyntheticDataset(n=10000, config=config)"
|
| 89 |
]
|
| 90 |
},
|
| 91 |
{
|
|
|
|
| 95 |
"---\n",
|
| 96 |
"## Stage 1: Base Generation\n",
|
| 97 |
"**Frozen:** style, mood, concept \n",
|
| 98 |
+
"**Trains:** WaveMamba backbone (Real Mamba SSM!) + cross-attention \n",
|
| 99 |
"**Goal:** Learn denoising dynamics"
|
| 100 |
]
|
| 101 |
},
|
|
|
|
| 127 |
" print(f'Loss: {np.mean(L[:10]):.4f} β {np.mean(L[-10:]):.4f}')\n",
|
| 128 |
" sm = np.convolve(L, np.ones(min(20, len(L)//4))/min(20, len(L)//4), 'valid')\n",
|
| 129 |
" plt.figure(figsize=(10,3))\n",
|
| 130 |
+
" plt.plot(sm); plt.title('Stage 1 Loss (Real Mamba SSM)'); plt.xlabel('Step'); plt.show()"
|
| 131 |
]
|
| 132 |
},
|
| 133 |
{
|
|
|
|
| 135 |
"metadata": {},
|
| 136 |
"source": [
|
| 137 |
"---\n",
|
| 138 |
+
"## Stage 2-5: Style, Resolution, Concept, Quality\n",
|
| 139 |
+
"Same as before but now with Real Mamba SSM backbone."
|
|
|
|
|
|
|
| 140 |
]
|
| 141 |
},
|
| 142 |
{
|
|
|
|
| 145 |
"metadata": {},
|
| 146 |
"outputs": [],
|
| 147 |
"source": [
|
| 148 |
+
"# Stage 2: Style Matrix\n",
|
| 149 |
"model = freeze_for_stage(model, 2)\n",
|
| 150 |
+
"engine2 = train(model, config, TrainConfig(\n",
|
| 151 |
+
" lr=5e-5, batch_size=2, grad_accum=32, num_steps=25000,\n",
|
| 152 |
+
" warmup_steps=500, log_every=100, save_every=5000, stage=2,\n",
|
| 153 |
+
"), dataset, DEVICE)"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
]
|
| 155 |
},
|
| 156 |
{
|
|
|
|
| 159 |
"metadata": {},
|
| 160 |
"outputs": [],
|
| 161 |
"source": [
|
| 162 |
+
"# Stage 3: Resolution + Reasoning\n",
|
| 163 |
"model = freeze_for_stage(model, 3)\n",
|
| 164 |
+
"engine3 = train(model, config, TrainConfig(\n",
|
| 165 |
+
" lr=3e-5, batch_size=2, grad_accum=32, num_steps=25000,\n",
|
| 166 |
+
" warmup_steps=500, log_every=100, save_every=5000, stage=3,\n",
|
| 167 |
+
"), dataset, DEVICE)"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
]
|
| 169 |
},
|
| 170 |
{
|
|
|
|
| 173 |
"metadata": {},
|
| 174 |
"outputs": [],
|
| 175 |
"source": [
|
| 176 |
+
"# Stage 4: Concept & Mood\n",
|
| 177 |
"model = freeze_for_stage(model, 4)\n",
|
| 178 |
+
"engine4 = train(model, config, TrainConfig(\n",
|
| 179 |
+
" lr=2e-5, batch_size=2, grad_accum=32, num_steps=15000,\n",
|
| 180 |
+
" warmup_steps=300, log_every=100, save_every=5000, stage=4,\n",
|
| 181 |
+
"), dataset, DEVICE)"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
]
|
| 183 |
},
|
| 184 |
{
|
|
|
|
| 187 |
"metadata": {},
|
| 188 |
"outputs": [],
|
| 189 |
"source": [
|
| 190 |
+
"# Stage 5: Quality Alignment\n",
|
| 191 |
"model = freeze_for_stage(model, 5)\n",
|
| 192 |
+
"engine5 = train(model, config, TrainConfig(\n",
|
| 193 |
+
" lr=1e-5, batch_size=2, grad_accum=32, num_steps=5000,\n",
|
| 194 |
+
" warmup_steps=200, log_every=50, save_every=2500, stage=5,\n",
|
| 195 |
+
"), dataset, DEVICE)"
|
|
|
|
|
|
|
| 196 |
]
|
| 197 |
},
|
| 198 |
{
|
|
|
|
| 205 |
"# from huggingface_hub import HfApi\n",
|
| 206 |
"# engine5.save('./artflow_final.pt')\n",
|
| 207 |
"# HfApi().upload_file('./artflow_final.pt', 'artflow_final.pt', 'krystv/ArtFlow')\n",
|
| 208 |
+
"print('π All 5 stages complete β Real Mamba SSM, real datasets!')"
|
| 209 |
]
|
| 210 |
}
|
| 211 |
],
|