Kyryll Kochkin committed
Commit ad9ba57 · 1 Parent(s): b9e551d

new frontend

Files changed (9):
  1. .DS_Store +0 -0
  2. .gitignore +1 -0
  3. app.py +79 -19
  4. dataset.py +1 -0
  5. templates/index.html +116 -176
  6. train_conv.py +15 -3
  7. train_diff.py +67 -0
  8. vq_transformer.py +1 -1
  9. vq_vae.py +8 -2
.DS_Store ADDED
Binary file (8.2 kB)
 
.gitignore ADDED
@@ -0,0 +1 @@
+.DS_Store
app.py CHANGED
@@ -17,8 +17,8 @@ from vq_vae import VQVAE
 
 app = Flask(__name__, template_folder="templates")
 
-
-device = "mps" if torch.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"  # don't know why but mps runs slower than cpu on m2 mac
+# Detect device; adjust if you want to force 'mps' or 'cuda'
+device = 'cpu'  # don't know why but mps runs slower than cpu
 
 # ------------------------------------------------------------------------------
 # Model Status Tracking
@@ -28,14 +28,13 @@ available_models = {
     "moe": False,    # MoEPixelTransformer (mixture of experts)
     "conv": False,   # ConvGenerator (direct generation)
     "vq": False,     # VQTransformer (autoregressive by token)
-    "vq-vae": False  # VQ-VAE only (encode/decode)
+    "vq-vae": False,    # VQ-VAE only (encode/decode)
+    "diffusion": False  # Diffusion model (DDPM)
 }
 
 # ------------------------------------------------------------------------------
 # Load models
 # ------------------------------------------------------------------------------
-
-
 # 1. Load ConvGenerator
 try:
     conv_config = ConvConfig.from_pretrained("my_conv")
@@ -93,6 +92,22 @@ try:
 except Exception as e:
     print(f"✗ Error loading VQ models: {str(e)}")
 
+# 5. Load Diffusion pipeline if available
+diffusion_pipe = None
+try:
+    from diffusers import DDPMPipeline
+    diffusion_model_dir = "my_diffusion_model"
+    if os.path.exists(diffusion_model_dir):
+        diffusion_pipe = DDPMPipeline.from_pretrained(
+            diffusion_model_dir, torch_dtype=torch.float32
+        ).to(device)
+        available_models["diffusion"] = True
+        print("✓ Diffusion pipeline loaded successfully")
+    else:
+        print(f"✗ Diffusion model directory '{diffusion_model_dir}' not found, skipping diffusion.")
+except Exception as e:
+    diffusion_pipe = None
+    print(f"✗ Error loading Diffusion pipeline: {str(e)}")
 # Select default model (use the first available one)
 for model_name, is_available in available_models.items():
     if is_available:
@@ -254,6 +269,52 @@ def generate_vq_vae_digit():
         print(f"Error in VQ-VAE reconstruction: {str(e)}")
         return Response(str(e), status=500)
 
+
+# ------------------------------------------------------------------------------
+# DIFFUSION GENERATION (using DDPM pipeline)
+# ------------------------------------------------------------------------------
+@app.route("/generate_diffusion_digit", methods=["GET"])
+def generate_diffusion_digit():
+    """Generate image using diffusion model (DDPM)."""
+    if diffusion_pipe is None:
+        return Response("Diffusion model not loaded", status=500)
+    try:
+        digit = int(request.args.get("digit", 0))
+        steps = int(request.args.get("steps", 50))
+        print(f"Generating diffusion image for digit {digit} with {steps} steps...")
+        num_steps = steps
+        scheduler = diffusion_pipe.scheduler
+        scheduler.set_timesteps(num_steps)
+
+        img = torch.randn(
+            (
+                1,
+                diffusion_pipe.unet.config.in_channels,
+                diffusion_pipe.unet.config.sample_size,
+                diffusion_pipe.unet.config.sample_size,
+            ),
+            device=device,
+            dtype=torch.float32,
+        )
+        labels = torch.tensor([digit], device=device)
+
+        for t in scheduler.timesteps:
+            with torch.no_grad():
+                model_output = diffusion_pipe.unet(img, t, class_labels=labels).sample
+            img = scheduler.step(model_output, t, img).prev_sample
+
+        img = (img / 2 + 0.5).clamp(0, 1)
+        array = img.cpu().permute(0, 2, 3, 1).numpy()[0]
+        array = (array * 255).astype(np.uint8)
+        image = Image.fromarray(array.squeeze(), mode="L").resize((28, 28))
+        buf = BytesIO()
+        image.save(buf, format="PNG")
+        buf.seek(0)
+        return Response(buf.getvalue(), mimetype="image/png")
+    except Exception as e:
+        print(f"Error generating diffusion image: {str(e)}")
+        return Response(str(e), status=500)
+
 # ------------------------------------------------------------------------------
 # STREAM DIGIT (Pixel-by-pixel generation or token-by-token for VQ)
 # ------------------------------------------------------------------------------
@@ -292,28 +353,27 @@ def stream_digit():
                 time.sleep(0.005)
 
         elif model_name == "vq" and vq_transformer_model and vq_model:
-            # VQ-Transformer (token by token, then decode)
+            # VQ-Transformer (token by token, with streaming decode)
            generator = vq_transformer_model.generate_token_stream(digit, device)
            tokens = []
 
-            # Stream token generation progress
+            # Stream token generation progress and partial image patches
            for i, token in enumerate(generator):
                tokens.append(token)
-                progress = int((i+1) * 100 / 49)  # 49 tokens total
+                progress = int((i + 1) * 100 / 49)
                yield f"data: token:{i+1}:{progress}\n\n"
                time.sleep(0.01)
 
-            # Then decode tokens to image
-            if len(tokens) == 49:  # Make sure we have all tokens
-                token_tensor = torch.tensor(tokens, dtype=torch.long, device=device).reshape(1, 7, 7)
+                # Partial decode: pad remaining tokens with zero index
+                pad_tokens = tokens + [0] * (49 - len(tokens))
+                token_tensor = torch.tensor(pad_tokens, dtype=torch.long, device=device).reshape(1, 7, 7)
                decoded_img = vq_model.decode(token_tensor)
                img_array = (decoded_img.cpu().squeeze().numpy() * 255).astype(np.uint8)
 
-                # Stream the final image pixels
-                flattened_pixels = img_array.flatten()
-                for pixel in flattened_pixels:
-                    yield f"data: {int(pixel)}\n\n"
-                    time.sleep(0.001)
+                # Stream full frame as CSV
+                flat_pixels = img_array.flatten().tolist()
+                yield f"data: frame:{','.join(str(int(p)) for p in flat_pixels)}\n\n"
+                time.sleep(0.001)
        else:
            yield "data: Error: Invalid model selected or model not available.\n\n"
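A minimal sketch for exercising the two new server behaviors from Python, assuming the Flask app is running on the default dev server at 127.0.0.1:5000 (a hypothetical host/port) and that `requests` is installed; it fetches a diffusion PNG and prints the SSE events (`token:` progress and `frame:` CSV payloads when the VQ model is selected, raw pixel values otherwise):

import requests

BASE = "http://127.0.0.1:5000"  # assumption: default Flask dev server

# 1. The diffusion endpoint returns a finished PNG
r = requests.get(f"{BASE}/generate_diffusion_digit", params={"digit": 3, "steps": 50})
r.raise_for_status()
open("diffusion_3.png", "wb").write(r.content)

# 2. The SSE stream now interleaves 'token:' progress and full 'frame:' payloads
with requests.get(f"{BASE}/stream_digit", params={"digit": 3}, stream=True) as resp:
    for raw in resp.iter_lines(decode_unicode=True):
        if raw and raw.startswith("data: "):
            payload = raw[len("data: "):]
            if payload.startswith("frame:"):
                print(f"frame with {len(payload[6:].split(','))} pixels")
            else:
                print(payload)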
 
dataset.py CHANGED
@@ -12,6 +12,7 @@ class ConditionalMNISTDataset(Dataset):
         super().__init__()
         self.label_offset = label_offset
 
+        # Load MNIST from torchvision
         transform = transforms.ToTensor()
         self.data = datasets.MNIST(root="./data", train=(split=="train"),
                                    download=True, transform=transform)
templates/index.html CHANGED
@@ -1,104 +1,68 @@
 <!DOCTYPE html>
-<html>
+<html lang="en">
 <head>
-  <meta charset="utf-8" />
-  <title>Conditional MNIST Generation (Pixel-by-Pixel)</title>
-  <style>
-    body {
-      font-family: sans-serif;
-      margin: 20px;
-    }
-    #canvas {
-      width: 280px;  /* 10x zoom for 28px images */
-      height: 280px;
-      border: 1px solid #ccc;
-      image-rendering: pixelated;  /* keep blocky pixels */
-      display: block;
-      margin-top: 10px;
-      background: #fff;
-    }
-    #log {
-      margin-top: 10px;
-      white-space: pre-wrap;
-      font-size: 14px;
-      color: #666;
-    }
-    .error {
-      color: #ff0000;
-    }
-    button:disabled {
-      opacity: 0.5;
-      cursor: not-allowed;
-    }
-    .progress-bar {
-      height: 20px;
-      background-color: #f0f0f0;
-      border-radius: 5px;
-      margin-top: 10px;
-      display: none;
-    }
-    .progress-fill {
-      height: 100%;
-      background-color: #4CAF50;
-      border-radius: 5px;
-      width: 0%;
-      transition: width 0.1s;
-    }
-  </style>
+  <meta charset="UTF-8">
+  <meta name="viewport" content="width=device-width, initial-scale=1.0">
+  <title>Image Generator</title>
+  <script src="https://cdn.tailwindcss.com"></script>
 </head>
-<body>
-  <h2>Conditional MNIST Generator</h2>
-  <p>Enter a digit (0-9) to generate:</p>
-  <input type="number" id="digitInput" value="7" min="0" max="9" style="width: 60px;">
-  <button id="generateBtn" onclick="generateDigit()">Generate</button>
-
-  <select id="modelSelector" onchange="selectModel()">
-    {% if available_models.get('pixel', False) %}
-    <option value="pixel" {% if selected_model == 'pixel' %}selected{% endif %}>PixelTransformer</option>
-    {% endif %}
-    {% if available_models.get('moe', False) %}
-    <option value="moe" {% if selected_model == 'moe' %}selected{% endif %}>MoEPixelTransformer</option>
-    {% endif %}
-    {% if available_models.get('conv', False) %}
-    <option value="conv" {% if selected_model == 'conv' %}selected{% endif %}>ConvGenerator</option>
-    {% endif %}
-    {% if available_models.get('vq', False) %}
-    <option value="vq" {% if selected_model == 'vq' %}selected{% endif %}>VQ-Transformer</option>
-    {% endif %}
-    {% if available_models.get('vq-vae', False) %}
-    <option value="vq-vae" {% if selected_model == 'vq-vae' %}selected{% endif %}>VQ-VAE Only</option>
-    {% endif %}
-  </select>
-
-  <canvas id="canvas" width="28" height="28"></canvas>
-  <div id="progress-container" class="progress-bar">
-    <div id="progress-fill" class="progress-fill"></div>
+<body class="bg-black text-white">
+  <div class="flex flex-col items-center justify-center min-h-screen space-y-6 p-4">
+    <h1 class="text-3xl font-semibold">Image Generator</h1>
+    <div class="flex flex-wrap items-center justify-center space-x-2">
+      <input id="digitInput" type="number" min="0" max="9" value="7"
+             class="w-16 px-2 py-1 bg-gray-800 border border-gray-600 rounded text-white focus:outline-none"/>
+      <input id="stepsInput" type="number" min="1" max="1000" value="50" placeholder="steps"
+             class="hidden w-20 px-2 py-1 bg-gray-800 border border-gray-600 rounded text-white focus:outline-none"/>
+      <button id="generateBtn" onclick="generateDigit()"
+              class="px-4 py-1 bg-gray-700 hover:bg-gray-600 rounded disabled:opacity-50">
+        Generate
+      </button>
+      <select id="modelSelector" onchange="selectModel()"
+              class="px-2 py-1 bg-gray-800 border border-gray-600 rounded text-white focus:outline-none">
+        {% for name, available in available_models.items() %}
+        {% if available %}
+        <option value="{{ name }}" {% if selected_model == name %}selected{% endif %}>
+          {{ name|capitalize }}
+        </option>
+        {% endif %}
+        {% endfor %}
+      </select>
+    </div>
+    <canvas id="canvas" width="28" height="28"
+            class="w-[280px] h-[280px] border border-gray-600 bg-black"
+            style="image-rendering: pixelated;"></canvas>
+    <div id="progress-container" class="w-[280px] bg-gray-800 rounded overflow-hidden"
+         style="display: none;">
+      <div id="progress-fill" class="bg-white h-1 w-0 transition-all"></div>
+    </div>
+    <div id="log" class="text-sm text-gray-400"></div>
   </div>
-  <div id="log"></div>
-
   <script>
     let currentModel = '{{ selected_model }}';
     let currentEventSource = null;
     let isGenerating = false;
-    let pixelCounter = 0; // Track pixels for VQ model rendering
+    let pixelCounter = 0;
 
     function selectModel() {
       const modelSelector = document.getElementById('modelSelector');
       currentModel = modelSelector.value;
-
-      // Update model selection on server
       fetch('/select_model', {
         method: 'POST',
-        headers: {
-          'Content-Type': 'application/json',
-        },
-        body: JSON.stringify({ model_type: currentModel })
+        headers: {'Content-Type': 'application/json'},
+        body: JSON.stringify({model_type: currentModel})
       });
-
-      // Show/hide progress bar if VQ model
       document.getElementById('progress-container').style.display =
         (currentModel === 'vq' || currentModel === 'vq-vae') ? 'block' : 'none';
+      const stepsInput = document.getElementById('stepsInput');
+      if (currentModel === 'diffusion') {
+        stepsInput.classList.remove('hidden');
+      } else {
+        stepsInput.classList.add('hidden');
+      }
     }
+    // initialize UI based on default selected model
+    selectModel();
 
     function setGenerating(generating) {
       isGenerating = generating;
@@ -109,35 +73,27 @@
     function generateDigit() {
       if (isGenerating) return;
       setGenerating(true);
-
-      // Clean up any existing EventSource
       if (currentEventSource) {
         currentEventSource.close();
         currentEventSource = null;
       }
-
       const digit = document.getElementById('digitInput').value;
       const canvas = document.getElementById('canvas');
       const ctx = canvas.getContext('2d');
       const log = document.getElementById('log');
       const progressBar = document.getElementById('progress-fill');
-      pixelCounter = 0; // Reset pixel counter
-
-      // Clear previous content
-      ctx.fillStyle = 'white';
+      pixelCounter = 0;
+      ctx.fillStyle = 'black';
       ctx.fillRect(0, 0, canvas.width, canvas.height);
       log.textContent = 'Generating...';
       log.className = '';
       progressBar.style.width = '0%';
 
       if (currentModel === 'conv') {
-        // For ConvGenerator (instant generation)
         fetch(`/generate_conv_digit?digit=${digit}`)
         .then(response => {
           if (!response.ok) {
-            return response.text().then(text => {
-              throw new Error(text || `HTTP error! status: ${response.status}`);
-            });
+            return response.text().then(text => { throw new Error(text || `HTTP error! status: ${response.status}`); });
           }
           return response.blob();
         })
@@ -148,34 +104,51 @@
           log.textContent = 'Generated!';
           setGenerating(false);
         };
-        img.onerror = () => {
-          throw new Error('Failed to load generated image');
-        };
+        img.onerror = () => { throw new Error('Failed to load generated image'); };
         img.src = URL.createObjectURL(blob);
       })
       .catch(error => {
         console.error('Error:', error);
         log.textContent = `Error generating image: ${error.message}`;
-        log.className = 'error';
+        log.className = 'text-red-500';
         setGenerating(false);
       });
+    } else if (currentModel === 'diffusion') {
+      const steps = document.getElementById('stepsInput').value;
+      fetch(`/generate_diffusion_digit?digit=${digit}&steps=${steps}`)
+      .then(response => {
+        if (!response.ok) {
+          return response.text().then(text => { throw new Error(text || `HTTP error! status: ${response.status}`); });
+        }
+        return response.blob();
+      })
+      .then(blob => {
+        const img = new Image();
+        img.onload = () => {
+          ctx.drawImage(img, 0, 0);
+          log.textContent = 'Generated!';
+          setGenerating(false);
+        };
+        img.onerror = () => { throw new Error('Failed to load generated image'); };
+        img.src = URL.createObjectURL(blob);
+      })
+      .catch(error => {
+        console.error('Error:', error);
+        log.textContent = `Error generating image: ${error.message}`;
+        log.className = 'text-red-500';
+        setGenerating(false);
+      });
     } else if (currentModel === 'vq' || currentModel === 'vq-vae') {
-      // Special handling for VQ models
       const imageData = ctx.createImageData(28, 28);
-
-      // Use a specific endpoint for vq-vae direct reconstruction
-      const endpoint = currentModel === 'vq-vae' ?
-        `/generate_vq_vae_digit?digit=${digit}` :
-        `/stream_digit?digit=${digit}`;
+      const endpoint = currentModel === 'vq-vae'
+        ? `/generate_vq_vae_digit?digit=${digit}`
+        : `/stream_digit?digit=${digit}`;
 
       if (currentModel === 'vq-vae') {
-        // For VQ-VAE direct reconstruction (non-streamed)
         fetch(endpoint)
        .then(response => {
          if (!response.ok) {
-            return response.text().then(text => {
-              throw new Error(text || `HTTP error! status: ${response.status}`);
-            });
+            return response.text().then(text => { throw new Error(text || `HTTP error! status: ${response.status}`); });
          }
          return response.blob();
        })
@@ -186,135 +159,102 @@
          log.textContent = 'Generated!';
          setGenerating(false);
        };
-        img.onerror = () => {
-          throw new Error('Failed to load generated image');
-        };
+        img.onerror = () => { throw new Error('Failed to load generated image'); };
        img.src = URL.createObjectURL(blob);
      })
      .catch(error => {
        console.error('Error:', error);
        log.textContent = `Error generating image: ${error.message}`;
-        log.className = 'error';
+        log.className = 'text-red-500';
        setGenerating(false);
      });
    } else {
-      // For VQ-Transformer (streamed)
      currentEventSource = new EventSource(endpoint);
-
      currentEventSource.onmessage = function(event) {
        const data = event.data;
-
        if (data.startsWith('Error:')) {
          log.textContent = data;
-          log.className = 'error';
+          log.className = 'text-red-500';
          currentEventSource.close();
          setGenerating(false);
          return;
        }
-
-        // Check if it's a token progress update
        if (data.startsWith('token:')) {
-          const parts = data.split(':');
-          const tokenNum = parseInt(parts[1]);
-          const progress = parseInt(parts[2]);
-
-          // Update progress bar
+          const [, tokenNum, progress] = data.split(':');
          progressBar.style.width = `${progress}%`;
          log.textContent = `Generating tokens: ${tokenNum}/49 (${progress}%)`;
          return;
        }
-
-        // Otherwise it's a pixel value
-        const pixelValue = parseInt(data);
-        if (isNaN(pixelValue)) {
-          console.error('Invalid pixel value:', data);
+        if (data.startsWith('frame:')) {
+          const pixels = data.slice(6).split(',').map(Number);
+          for (let idx = 0; idx < pixels.length; idx++) {
+            const x = idx % 28;
+            const y = Math.floor(idx / 28);
+            const i = (y * 28 + x) * 4;
+            imageData.data[i] = pixels[idx];
+            imageData.data[i + 1] = pixels[idx];
+            imageData.data[i + 2] = pixels[idx];
+            imageData.data[i + 3] = 255;
+          }
+          ctx.putImageData(imageData, 0, 0);
          return;
        }
-
-        // Calculate pixel position
+        const pixelValue = parseInt(data);
+        if (isNaN(pixelValue)) return;
        const x = pixelCounter % 28;
        const y = Math.floor(pixelCounter / 28);
-
-        // Set RGB values for this pixel
        const idx = (y * 28 + x) * 4;
-        imageData.data[idx] = pixelValue;     // R
-        imageData.data[idx + 1] = pixelValue; // G
-        imageData.data[idx + 2] = pixelValue; // B
-        imageData.data[idx + 3] = 255;        // A (opacity)
-
+        imageData.data[idx] = pixelValue;
+        imageData.data[idx + 1] = pixelValue;
+        imageData.data[idx + 2] = pixelValue;
+        imageData.data[idx + 3] = 255;
        pixelCounter++;
-
-        // Update canvas every 28 pixels (full row)
-        if (x === 27 || pixelCounter === 28*28) {
+        if (x === 27 || pixelCounter === 28 * 28) {
          ctx.putImageData(imageData, 0, 0);
-
-          if (pixelCounter >= 28*28) {
+          if (pixelCounter >= 28 * 28) {
            currentEventSource.close();
            log.textContent = 'Generation complete!';
            setGenerating(false);
          }
        }
      };
-
-      currentEventSource.onerror = function(error) {
-        console.error('EventSource error:', error);
+      currentEventSource.onerror = function(e) {
        currentEventSource.close();
-        log.textContent = 'Error in streaming!';
-        log.className = 'error';
        setGenerating(false);
      };
    }
  } else {
-    // For PixelTransformer & MoEPixelTransformer (pixel streaming)
    const imageData = ctx.createImageData(28, 28);
    let index = 0;
-
    currentEventSource = new EventSource(`/stream_digit?digit=${digit}`);
-
    currentEventSource.onmessage = function(event) {
      const data = event.data;
      if (data.startsWith('Error:')) {
        log.textContent = data;
-        log.className = 'error';
+        log.className = 'text-red-500';
        currentEventSource.close();
-        currentEventSource = null;
        setGenerating(false);
        return;
      }
-
      const pixelValue = parseInt(data);
-      if (isNaN(pixelValue)) {
-        console.error('Invalid pixel value:', data);
-        return;
-      }
-
-      // Set RGB values to the same value for grayscale
-      imageData.data[index] = pixelValue;     // R
-      imageData.data[index + 1] = pixelValue; // G
-      imageData.data[index + 2] = pixelValue; // B
-      imageData.data[index + 3] = 255;        // A (opacity)
+      if (isNaN(pixelValue)) return;
+      imageData.data[index] = pixelValue;
+      imageData.data[index + 1] = pixelValue;
+      imageData.data[index + 2] = pixelValue;
+      imageData.data[index + 3] = 255;
      index += 4;
-
-      // Update canvas every row (28 pixels)
      if (index % (28 * 4) === 0) {
        ctx.putImageData(imageData, 0, 0);
      }
-
      if (index >= 28 * 28 * 4) {
        currentEventSource.close();
-        currentEventSource = null;
        log.textContent = 'Generation complete!';
        setGenerating(false);
      }
    };
-
-    currentEventSource.onerror = function(error) {
-      console.error('EventSource error:', error);
-      currentEventSource.close();
-      currentEventSource = null;
-      log.textContent = 'Error in streaming!';
-      log.className = 'error';
-      setGenerating(false);
+    currentEventSource.onerror = function() {
+      currentEventSource.close();
+      setGenerating(false);
    };
  }
 }
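A quick way to sanity-check the rewritten template without a browser is Flask's built-in test client. This sketch assumes app.py exposes an index route that renders templates/index.html with the template variables used above (implied by the Jinja markup, but the route itself is not shown in this diff):

# Hedged smoke test; assumes `app` is the Flask object in app.py and a "/" route
# renders templates/index.html.
from app import app

client = app.test_client()

resp = client.get("/")
assert resp.status_code == 200
assert b"modelSelector" in resp.data  # the model dropdown rendered
assert b"stepsInput" in resp.data     # the new diffusion steps input rendered

# Mirrors what the page's selectModel() sends to the server
resp = client.post("/select_model", json={"model_type": "diffusion"})
print(resp.status_code)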
train_conv.py CHANGED
@@ -5,12 +5,18 @@ from torch.utils.data import DataLoader
 from transformers import PreTrainedModel, PretrainedConfig
 from dataset import ConditionalMNISTDataset
 
+############################
+#       Config Class       #
+############################
 class ConvConfig(PretrainedConfig):
     model_type = "conv_generator"
     def __init__(self, latent_dim=100, **kwargs):
         super().__init__(**kwargs)
         self.latent_dim = latent_dim
 
+############################
+#       Model Class        #
+############################
 class ConvGeneratorModel(PreTrainedModel):
     config_class = ConvConfig
     def __init__(self, config):
@@ -51,11 +57,16 @@ class ConvGeneratorModel(PreTrainedModel):
         return out
 
 def main():
-    device = torch.device("mps" if torch.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
+    # Ensure MPS is available
+    if not torch.backends.mps.is_available():
+        print("MPS not available. Falling back to CPU.")
+        device = "cpu"
+    else:
+        device = "mps"
     print(f"Using device: {device}")
 
     dataset = ConditionalMNISTDataset("train")
-    loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=0)
+    loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=0)  # Reduced workers for MPS
 
     config = ConvConfig(latent_dim=100)
     model = ConvGeneratorModel(config).to(device)
@@ -67,6 +78,7 @@ def main():
     model.train()
     total_loss = 0
     for step, (x, y) in enumerate(loader, 1):
+        # Move both inputs to device
         x = x.to(device)
         y = y.to(device)
 
@@ -76,7 +88,7 @@ def main():
         optimizer.zero_grad()
         generated_images = model(labels)  # (bsz, 1, 28, 28)
 
-        # Mean Squared Error loss is typical for image generation
+        # Use Mean Squared Error loss
         loss = F.mse_loss(generated_images, real_images)
         loss.backward()
         optimizer.step()
train_diff.py ADDED
@@ -0,0 +1,67 @@
+import os
+
+import torch
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from torchvision import transforms, datasets
+from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline
+from tqdm import tqdm
+
+def train_diffusion():
+    # Train and save a DDPM diffusion model on MNIST.
+    device = "mps" if torch.backends.mps.is_available() else "cpu"
+    print(f"Using device: {device}")
+
+    transform = transforms.Compose([transforms.ToTensor()])
+    train_ds = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
+    loader = DataLoader(train_ds, batch_size=128, shuffle=True)
+
+    # Conditional DDPM UNet for MNIST digits
+    unet = UNet2DModel(
+        sample_size=28,
+        in_channels=1,
+        out_channels=1,
+        block_out_channels=(32, 64, 128),
+        down_block_types=("DownBlock2D", "AttnDownBlock2D", "DownBlock2D"),
+        up_block_types=("UpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
+        num_class_embeds=10,
+    ).to(device)
+    scheduler = DDPMScheduler(num_train_timesteps=1000)
+    pipeline = DDPMPipeline(unet=unet, scheduler=scheduler).to(device)
+    optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-4, weight_decay=1e-4)  # changed from Adam
+
+    epochs = 5
+    print(f"Training DDPM for {epochs} epochs...")
+    try:
+        for epoch in range(1, epochs + 1):
+            pbar = tqdm(loader, desc=f"Epoch {epoch}/{epochs}")
+            for images, labels in pbar:
+                images = images.to(device)
+                labels = labels.to(device)
+                noise = torch.randn_like(images)
+                timesteps = torch.randint(
+                    0, scheduler.num_train_timesteps, (images.shape[0],), device=device
+                ).long()
+                noisy = scheduler.add_noise(images, noise, timesteps)
+
+                # Conditional noise prediction
+                model_pred = unet(noisy, timesteps, class_labels=labels, return_dict=False)[0]
+                loss = F.mse_loss(model_pred, noise)
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+                pbar.set_postfix(loss=f"{loss.item():.4f}")
+    except KeyboardInterrupt:
+        print("\nKeyboard interrupt, saving model...")
+        output_dir = "my_diffusion_model"
+        pipeline.save_pretrained(output_dir)
+        print(f"Model saved to {output_dir}/")
+        return pipeline
+
+    output_dir = "my_diffusion_model"
+    pipeline.save_pretrained(output_dir)
+    print(f"Training complete. Model saved to {output_dir}/")
+    return pipeline
+
+if __name__ == "__main__":
+    train_diffusion()
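After training, a class-conditional sampling sketch can verify the checkpoint. It mirrors the manual reverse loop in generate_diffusion_digit() in app.py (used there presumably because the stock DDPMPipeline call path does not take class labels); the directory name matches the one saved above, the digit and step count are arbitrary:

# Hedged sampling sketch, assuming the pipeline saved above in "my_diffusion_model".
import torch
from diffusers import DDPMPipeline
from PIL import Image

pipe = DDPMPipeline.from_pretrained("my_diffusion_model")
pipe.scheduler.set_timesteps(50)

img = torch.randn(1, 1, 28, 28)   # (B, C, H, W) for the MNIST UNet
labels = torch.tensor([7])        # digit to condition on

for t in pipe.scheduler.timesteps:
    with torch.no_grad():
        noise_pred = pipe.unet(img, t, class_labels=labels).sample
    img = pipe.scheduler.step(noise_pred, t, img).prev_sample

img = (img / 2 + 0.5).clamp(0, 1)  # matches app.py's de-normalization
Image.fromarray((img[0, 0].numpy() * 255).astype("uint8"), mode="L").save("sample_7.png")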
vq_transformer.py CHANGED
@@ -19,7 +19,7 @@ class VQTransformerConfig:
     epochs: int = 10
     warmup_steps: int = 500
     label_offset: int = 512  # Labels are tokens 512-521 (for digits 0-9)
-    device: str = field(default="mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
+    device: str = field(default="mps" if torch.backends.mps.is_available() else "cpu")
 
     @classmethod
     def from_pretrained(cls, path: str):
vq_vae.py CHANGED
@@ -64,6 +64,7 @@ class VQVAE(nn.Module):
             nn.Conv2d(32, embedding_dim, 1, stride=1)  # 7x7 -> 7x7xembedding_dim
         )
 
+        # Vector Quantizer
         self.vq = VectorQuantizer(num_embeddings, embedding_dim)
 
         # Decoder for MNIST (7x7 -> 28x28)
@@ -94,6 +95,7 @@ class VQVAE(nn.Module):
         quantized = torch.matmul(one_hot.permute(0, 2, 3, 1), self.vq.embedding.weight)
         quantized = quantized.permute(0, 3, 1, 2)
 
+        # Decode
         reconstructed = self.decoder(quantized)
         return reconstructed
 
@@ -137,14 +139,18 @@ class VQVAE(nn.Module):
 
     @staticmethod
     def train_and_save(output_path="vq_vae_model.pt", device='cpu', batch_size=128, epochs=10):
+        # Setup data
         transform = transforms.Compose([transforms.ToTensor()])
         train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
         train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
 
+        # Create model
         model = VQVAE().to(device)
 
+        # Train
         model.train_model(train_loader, epochs=epochs, device=device)
 
+        # Save model
         torch.save(model.state_dict(), output_path)
         print(f"Model saved to {output_path}")
156