Kyryll Kochkin committed
Commit ad9ba57 · 1 Parent(s): b9e551d
new frontend
Browse files
- .DS_Store +0 -0
- .gitignore +1 -0
- app.py +79 -19
- dataset.py +1 -0
- templates/index.html +116 -176
- train_conv.py +15 -3
- train_diff.py +67 -0
- vq_transformer.py +1 -1
- vq_vae.py +8 -2
.DS_Store ADDED
Binary file (8.2 kB)
.gitignore ADDED
@@ -0,0 +1 @@
+.DS_Store
app.py CHANGED
@@ -17,8 +17,8 @@ from vq_vae import VQVAE
 
 app = Flask(__name__, template_folder="templates")
 
-
-device =
+# Detect device — adjust if you want to force 'mps' or 'cuda'
+device = 'cpu'  # don't know why but mps runs slower than cpu
 
 # ------------------------------------------------------------------------------
 # Model Status Tracking
@@ -28,14 +28,13 @@ available_models = {
     "moe": False,    # MoEPixelTransformer (mixture of experts)
     "conv": False,   # ConvGenerator (direct generation)
     "vq": False,     # VQTransformer (autoregressive by token)
-    "vq-vae": False
+    "vq-vae": False,     # VQ-VAE only (encode/decode)
+    "diffusion": False   # Diffusion model (DDPM)
 }
 
 # ------------------------------------------------------------------------------
 # Load models
 # ------------------------------------------------------------------------------
-
-
 # 1. Load ConvGenerator
 try:
     conv_config = ConvConfig.from_pretrained("my_conv")
@@ -93,6 +92,22 @@ try:
 except Exception as e:
     print(f"✗ Error loading VQ models: {str(e)}")
 
+# 5. Load Diffusion pipeline if available
+diffusion_pipe = None
+try:
+    from diffusers import DDPMPipeline
+    diffusion_model_dir = "my_diffusion_model"
+    if os.path.exists(diffusion_model_dir):
+        diffusion_pipe = DDPMPipeline.from_pretrained(
+            diffusion_model_dir, torch_dtype=torch.float32
+        ).to(device)
+        available_models["diffusion"] = True
+        print("✓ Diffusion pipeline loaded successfully")
+    else:
+        print(f"✗ Diffusion model directory '{diffusion_model_dir}' not found, skipping diffusion.")
+except Exception as e:
+    diffusion_pipe = None
+    print(f"✗ Error loading Diffusion pipeline: {str(e)}")
 # Select default model (use the first available one)
 for model_name, is_available in available_models.items():
     if is_available:
@@ -254,6 +269,52 @@ def generate_vq_vae_digit():
         print(f"Error in VQ-VAE reconstruction: {str(e)}")
         return Response(str(e), status=500)
 
+
+# ------------------------------------------------------------------------------
+# DIFFUSION GENERATION (using DDPM pipeline)
+# ------------------------------------------------------------------------------
+@app.route("/generate_diffusion_digit", methods=["GET"])
+def generate_diffusion_digit():
+    """Generate image using diffusion model (DDPM)."""
+    if diffusion_pipe is None:
+        return Response("Diffusion model not loaded", status=500)
+    try:
+        digit = int(request.args.get("digit", 0))
+        steps = int(request.args.get("steps", 50))
+        print(f"Generating diffusion image for digit {digit} with {steps} steps...")
+        num_steps = steps
+        scheduler = diffusion_pipe.scheduler
+        scheduler.set_timesteps(num_steps)
+
+        img = torch.randn(
+            (
+                1,
+                diffusion_pipe.unet.config.in_channels,
+                diffusion_pipe.unet.config.sample_size,
+                diffusion_pipe.unet.config.sample_size,
+            ),
+            device=device,
+            dtype=torch.float32,
+        )
+        labels = torch.tensor([digit], device=device)
+
+        for t in scheduler.timesteps:
+            with torch.no_grad():
+                model_output = diffusion_pipe.unet(img, t, class_labels=labels).sample
+            img = scheduler.step(model_output, t, img).prev_sample
+
+        img = (img / 2 + 0.5).clamp(0, 1)
+        array = img.cpu().permute(0, 2, 3, 1).numpy()[0]
+        array = (array * 255).astype(np.uint8)
+        image = Image.fromarray(array.squeeze(), mode="L").resize((28, 28))
+        buf = BytesIO()
+        image.save(buf, format="PNG")
+        buf.seek(0)
+        return Response(buf.getvalue(), mimetype="image/png")
+    except Exception as e:
+        print(f"Error generating diffusion image: {str(e)}")
+        return Response(str(e), status=500)
+
 # ------------------------------------------------------------------------------
 # STREAM DIGIT (Pixel-by-pixel generation or token-by-token for VQ)
 # ------------------------------------------------------------------------------
@@ -292,28 +353,27 @@ def stream_digit():
             time.sleep(0.005)
 
         elif model_name == "vq" and vq_transformer_model and vq_model:
-            # VQ-Transformer (token by token,
+            # VQ-Transformer (token by token, with streaming decode)
             generator = vq_transformer_model.generate_token_stream(digit, device)
             tokens = []
-
-            # Stream token generation progress
+
+            # Stream token generation progress and partial image patches
             for i, token in enumerate(generator):
                 tokens.append(token)
-                progress = int((i+1) * 100 / 49)
+                progress = int((i + 1) * 100 / 49)
                 yield f"data: token:{i+1}:{progress}\n\n"
                 time.sleep(0.01)
-
-
-
-            token_tensor = torch.tensor(
+
+            # Partial decode: pad remaining tokens with zero index
+            pad_tokens = tokens + [0] * (49 - len(tokens))
+            token_tensor = torch.tensor(pad_tokens, dtype=torch.long, device=device).reshape(1, 7, 7)
             decoded_img = vq_model.decode(token_tensor)
             img_array = (decoded_img.cpu().squeeze().numpy() * 255).astype(np.uint8)
-
-            # Stream
-
-            for
-
-            time.sleep(0.001)
+
+            # Stream full frame as CSV
+            flat_pixels = img_array.flatten().tolist()
+            yield f"data: frame:{','.join(str(int(p)) for p in flat_pixels)}\n\n"
+            time.sleep(0.001)
        else:
             yield "data: Error: Invalid model selected or model not available.\n\n"
 
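Note on the new route: the sketch below is a minimal Python client for /generate_diffusion_digit, assuming the Space is running locally on Flask's default port (the URL, output filename, and timeout are illustrative, not part of the commit).

import requests

# Query the hypothetical local server; "digit" and "steps" are the two
# parameters the route actually reads from the query string.
resp = requests.get(
    "http://127.0.0.1:5000/generate_diffusion_digit",
    params={"digit": 7, "steps": 50},
    timeout=600,  # full DDPM sampling on CPU can take a while
)
resp.raise_for_status()
with open("digit_7.png", "wb") as f:
    f.write(resp.content)  # the route responds with a PNG body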
dataset.py CHANGED
@@ -12,6 +12,7 @@ class ConditionalMNISTDataset(Dataset):
         super().__init__()
         self.label_offset = label_offset
 
+        # Load MNIST from torchvision
         transform = transforms.ToTensor()
         self.data = datasets.MNIST(root="./data", train=(split=="train"),
                                    download=True, transform=transform)
templates/index.html CHANGED
@@ -1,104 +1,68 @@
 <!DOCTYPE html>
-<html>
+<html lang="en">
 <head>
-<meta charset="
-<
-<
-
-    font-family: sans-serif;
-    margin: 20px;
-}
-#canvas {
-    width: 280px; /* 10x zoom for 28px images */
-    height: 280px;
-    border: 1px solid #ccc;
-    image-rendering: pixelated; /* keep blocky pixels */
-    display: block;
-    margin-top: 10px;
-    background: #fff;
-}
-#log {
-    margin-top: 10px;
-    white-space: pre-wrap;
-    font-size: 14px;
-    color: #666;
-}
-.error {
-    color: #ff0000;
-}
-button:disabled {
-    opacity: 0.5;
-    cursor: not-allowed;
-}
-.progress-bar {
-    height: 20px;
-    background-color: #f0f0f0;
-    border-radius: 5px;
-    margin-top: 10px;
-    display: none;
-}
-.progress-fill {
-    height: 100%;
-    background-color: #4CAF50;
-    border-radius: 5px;
-    width: 0%;
-    transition: width 0.1s;
-}
-</style>
+<meta charset="UTF-8">
+<meta name="viewport" content="width=device-width, initial-scale=1.0">
+<title>Image Generator</title>
+<script src="https://cdn.tailwindcss.com"></script>
 </head>
-<body>
-<
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-<div id="progress-
+<body class="bg-black text-white">
+<div class="flex flex-col items-center justify-center min-h-screen space-y-6 p-4">
+    <h1 class="text-3xl font-semibold">Image Generator</h1>
+    <div class="flex flex-wrap items-center justify-center space-x-2">
+        <input id="digitInput" type="number" min="0" max="9" value="7"
+               class="w-16 px-2 py-1 bg-gray-800 border border-gray-600 rounded text-white focus:outline-none"/>
+        <input id="stepsInput" type="number" min="1" max="1000" value="50" placeholder="steps"
+               class="hidden w-20 px-2 py-1 bg-gray-800 border border-gray-600 rounded text-white focus:outline-none"/>
+        <button id="generateBtn" onclick="generateDigit()"
+                class="px-4 py-1 bg-gray-700 hover:bg-gray-600 rounded disabled:opacity-50">
+            Generate
+        </button>
+        <select id="modelSelector" onchange="selectModel()"
+                class="px-2 py-1 bg-gray-800 border border-gray-600 rounded text-white focus:outline-none">
+            {% for name, available in available_models.items() %}
+            {% if available %}
+            <option value="{{ name }}" {% if selected_model == name %}selected{% endif %}>
+                {{ name|capitalize }}
+            </option>
+            {% endif %}
+            {% endfor %}
+        </select>
+    </div>
+    <canvas id="canvas" width="28" height="28"
+            class="w-[280px] h-[280px] border border-gray-600 bg-black"
+            style="image-rendering: pixelated;"></canvas>
+    <div id="progress-container" class="w-[280px] bg-gray-800 rounded overflow-hidden"
+         style="display: none;">
+        <div id="progress-fill" class="bg-white h-1 w-0 transition-all"></div>
+    </div>
+    <div id="log" class="text-sm text-gray-400"></div>
 </div>
-<div id="log"></div>
-
 <script>
 let currentModel = '{{ selected_model }}';
 let currentEventSource = null;
 let isGenerating = false;
-let pixelCounter = 0;
+let pixelCounter = 0;
 
 function selectModel() {
     const modelSelector = document.getElementById('modelSelector');
     currentModel = modelSelector.value;
-
-    // Update model selection on server
     fetch('/select_model', {
         method: 'POST',
-        headers: {
-
-        },
-        body: JSON.stringify({ model_type: currentModel })
+        headers: {'Content-Type': 'application/json'},
+        body: JSON.stringify({model_type: currentModel})
     });
-
-    // Show/hide progress bar if VQ model
-    document.getElementById('progress-container').style.display =
+    document.getElementById('progress-container').style.display =
        (currentModel === 'vq' || currentModel === 'vq-vae') ? 'block' : 'none';
+    const stepsInput = document.getElementById('stepsInput');
+    if (currentModel === 'diffusion') {
+        stepsInput.classList.remove('hidden');
+    } else {
+        stepsInput.classList.add('hidden');
+    }
 }
+// initialize UI based on default selected model
+selectModel();
 
 function setGenerating(generating) {
     isGenerating = generating;
@@ -109,35 +73,27 @@
 function generateDigit() {
     if (isGenerating) return;
     setGenerating(true);
-
-    // Clean up any existing EventSource
     if (currentEventSource) {
         currentEventSource.close();
         currentEventSource = null;
     }
-
     const digit = document.getElementById('digitInput').value;
     const canvas = document.getElementById('canvas');
     const ctx = canvas.getContext('2d');
     const log = document.getElementById('log');
     const progressBar = document.getElementById('progress-fill');
-    pixelCounter = 0;
-
-    // Clear previous content
-    ctx.fillStyle = 'white';
+    pixelCounter = 0;
+    ctx.fillStyle = 'black';
     ctx.fillRect(0, 0, canvas.width, canvas.height);
     log.textContent = 'Generating...';
     log.className = '';
     progressBar.style.width = '0%';
 
     if (currentModel === 'conv') {
-        // For ConvGenerator (instant generation)
         fetch(`/generate_conv_digit?digit=${digit}`)
         .then(response => {
             if (!response.ok) {
-                return response.text().then(text => {
-                    throw new Error(text || `HTTP error! status: ${response.status}`);
-                });
+                return response.text().then(text => { throw new Error(text || `HTTP error! status: ${response.status}`); });
             }
             return response.blob();
         })
@@ -148,34 +104,51 @@
                 log.textContent = 'Generated!';
                 setGenerating(false);
             };
-            img.onerror = () => {
-
+            img.onerror = () => { throw new Error('Failed to load generated image'); };
+            img.src = URL.createObjectURL(blob);
+        })
+        .catch(error => {
+            console.error('Error:', error);
+            log.textContent = `Error generating image: ${error.message}`;
+            log.className = 'text-red-500';
+            setGenerating(false);
+        });
+    } else if (currentModel === 'diffusion') {
+        const steps = document.getElementById('stepsInput').value;
+        fetch(`/generate_diffusion_digit?digit=${digit}&steps=${steps}`)
+        .then(response => {
+            if (!response.ok) {
+                return response.text().then(text => { throw new Error(text || `HTTP error! status: ${response.status}`); });
+            }
+            return response.blob();
+        })
+        .then(blob => {
+            const img = new Image();
+            img.onload = () => {
+                ctx.drawImage(img, 0, 0);
+                log.textContent = 'Generated!';
+                setGenerating(false);
             };
+            img.onerror = () => { throw new Error('Failed to load generated image'); };
             img.src = URL.createObjectURL(blob);
         })
         .catch(error => {
             console.error('Error:', error);
             log.textContent = `Error generating image: ${error.message}`;
-            log.className = '
+            log.className = 'text-red-500';
             setGenerating(false);
         });
     } else if (currentModel === 'vq' || currentModel === 'vq-vae') {
-        // Special handling for VQ models
         const imageData = ctx.createImageData(28, 28);
-
-
-
-
-            `/stream_digit?digit=${digit}`;
-
+        const endpoint = currentModel === 'vq-vae'
+            ? `/generate_vq_vae_digit?digit=${digit}`
+            : `/stream_digit?digit=${digit}`;
+
         if (currentModel === 'vq-vae') {
-            // For VQ-VAE direct reconstruction (non-streamed)
            fetch(endpoint)
             .then(response => {
                 if (!response.ok) {
-                    return response.text().then(text => {
-                        throw new Error(text || `HTTP error! status: ${response.status}`);
-                    });
+                    return response.text().then(text => { throw new Error(text || `HTTP error! status: ${response.status}`); });
                 }
                 return response.blob();
             })
@@ -186,135 +159,102 @@
                 log.textContent = 'Generated!';
                 setGenerating(false);
             };
-            img.onerror = () => {
-                throw new Error('Failed to load generated image');
-            };
+            img.onerror = () => { throw new Error('Failed to load generated image'); };
             img.src = URL.createObjectURL(blob);
         })
         .catch(error => {
             console.error('Error:', error);
             log.textContent = `Error generating image: ${error.message}`;
-            log.className = '
+            log.className = 'text-red-500';
             setGenerating(false);
         });
     } else {
-        // For VQ-Transformer (streamed)
        currentEventSource = new EventSource(endpoint);
-
        currentEventSource.onmessage = function(event) {
            const data = event.data;
-
            if (data.startsWith('Error:')) {
                log.textContent = data;
-                log.className = '
+                log.className = 'text-red-500';
                currentEventSource.close();
                setGenerating(false);
                return;
            }
-
-            // Check if it's a token progress update
            if (data.startsWith('token:')) {
-                const
-                const tokenNum = parseInt(parts[1]);
-                const progress = parseInt(parts[2]);
-
-                // Update progress bar
+                const [, tokenNum, progress] = data.split(':');
                progressBar.style.width = `${progress}%`;
                log.textContent = `Generating tokens: ${tokenNum}/49 (${progress}%)`;
                return;
            }
-
-
-
-
-
+            if (data.startsWith('frame:')) {
+                const pixels = data.slice(6).split(',').map(Number);
+                for (let idx = 0; idx < pixels.length; idx++) {
+                    const x = idx % 28;
+                    const y = Math.floor(idx / 28);
+                    const i = (y * 28 + x) * 4;
+                    imageData.data[i] = pixels[idx];
+                    imageData.data[i + 1] = pixels[idx];
+                    imageData.data[i + 2] = pixels[idx];
+                    imageData.data[i + 3] = 255;
+                }
+                ctx.putImageData(imageData, 0, 0);
                return;
            }
-
-
+            const pixelValue = parseInt(data);
+            if (isNaN(pixelValue)) return;
            const x = pixelCounter % 28;
            const y = Math.floor(pixelCounter / 28);
-
-            // Set RGB values for this pixel
            const idx = (y * 28 + x) * 4;
-            imageData.data[idx] = pixelValue;
-            imageData.data[idx + 1] = pixelValue;
-            imageData.data[idx + 2] = pixelValue;
-            imageData.data[idx + 3] = 255;
-
+            imageData.data[idx] = pixelValue;
+            imageData.data[idx + 1] = pixelValue;
+            imageData.data[idx + 2] = pixelValue;
+            imageData.data[idx + 3] = 255;
            pixelCounter++;
-
-            // Update canvas every 28 pixels (full row)
-            if (x === 27 || pixelCounter === 28*28) {
+            if (x === 27 || pixelCounter === 28 * 28) {
                ctx.putImageData(imageData, 0, 0);
-
-                if (pixelCounter >= 28*28) {
+                if (pixelCounter >= 28 * 28) {
                    currentEventSource.close();
                    log.textContent = 'Generation complete!';
                    setGenerating(false);
                }
            }
        };
-
-        currentEventSource.onerror = function(error) {
-            console.error('EventSource error:', error);
+        currentEventSource.onerror = function(e) {
            currentEventSource.close();
-            log.textContent = 'Error in streaming!';
-            log.className = 'error';
            setGenerating(false);
        };
    }
 } else {
-    // For PixelTransformer & MoEPixelTransformer (pixel streaming)
    const imageData = ctx.createImageData(28, 28);
    let index = 0;
-
    currentEventSource = new EventSource(`/stream_digit?digit=${digit}`);
-
    currentEventSource.onmessage = function(event) {
        const data = event.data;
        if (data.startsWith('Error:')) {
            log.textContent = data;
-            log.className = '
+            log.className = 'text-red-500';
            currentEventSource.close();
-            currentEventSource = null;
            setGenerating(false);
            return;
        }
-
        const pixelValue = parseInt(data);
-        if (isNaN(pixelValue))
-
-
-
-
-        // Set RGB values to the same value for grayscale
-        imageData.data[index] = pixelValue; // R
-        imageData.data[index + 1] = pixelValue; // G
-        imageData.data[index + 2] = pixelValue; // B
-        imageData.data[index + 3] = 255; // A (opacity)
+        if (isNaN(pixelValue)) return;
+        imageData.data[index] = pixelValue;
+        imageData.data[index + 1] = pixelValue;
+        imageData.data[index + 2] = pixelValue;
+        imageData.data[index + 3] = 255;
        index += 4;
-
-        // Update canvas every row (28 pixels)
        if (index % (28 * 4) === 0) {
            ctx.putImageData(imageData, 0, 0);
        }
-
        if (index >= 28 * 28 * 4) {
            currentEventSource.close();
-            currentEventSource = null;
            log.textContent = 'Generation complete!';
            setGenerating(false);
        }
    };
-
-    currentEventSource.
-
-        currentEventSource.close();
-        currentEventSource = null;
-        log.textContent = 'Error in streaming!';
-        log.className = 'error';
-        setGenerating(false);
+    currentEventSource.onerror = function() {
+        currentEventSource.close();
+        setGenerating(false);
    };
 }
 }
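Note on the streaming protocol: the page now handles three SSE message shapes from /stream_digit — "token:<i>:<progress>" for VQ token progress, "frame:<comma-separated pixels>" for a full 28x28 partial decode, and bare integers for per-pixel streaming. Below is a minimal Python consumer mirroring the JavaScript handler (the host/port are assumptions, not part of the commit).

import requests

with requests.get("http://127.0.0.1:5000/stream_digit",
                  params={"digit": 3}, stream=True, timeout=600) as r:
    for raw in r.iter_lines(decode_unicode=True):
        if not raw or not raw.startswith("data: "):
            continue  # skip blank SSE separator lines
        data = raw[len("data: "):]
        if data.startswith("token:"):
            _, token_num, progress = data.split(":")
            print(f"token {token_num}/49 ({progress}%)")
        elif data.startswith("frame:"):
            pixels = [int(p) for p in data[len("frame:"):].split(",")]
            print(f"frame received: {len(pixels)} pixels")  # 28*28 = 784
        else:
            pass  # bare pixel value (conv/moe pixel-streaming models)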
train_conv.py CHANGED
@@ -5,12 +5,18 @@ from torch.utils.data import DataLoader
 from transformers import PreTrainedModel, PretrainedConfig
 from dataset import ConditionalMNISTDataset
 
+############################
+#       Config Class       #
+############################
 class ConvConfig(PretrainedConfig):
     model_type = "conv_generator"
     def __init__(self, latent_dim=100, **kwargs):
         super().__init__(**kwargs)
         self.latent_dim = latent_dim
 
+############################
+#       Model Class        #
+############################
 class ConvGeneratorModel(PreTrainedModel):
     config_class = ConvConfig
     def __init__(self, config):
@@ -51,11 +57,16 @@ class ConvGeneratorModel(PreTrainedModel):
         return out
 
 def main():
-
+    # Ensure MPS is available
+    if not torch.backends.mps.is_available():
+        print("MPS not available. Falling back to CPU.")
+        device = "cpu"
+    else:
+        device = "mps"
     print(f"Using device: {device}")
 
     dataset = ConditionalMNISTDataset("train")
-    loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=0)
+    loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=0)  # Reduced workers for MPS
 
     config = ConvConfig(latent_dim=100)
     model = ConvGeneratorModel(config).to(device)
@@ -67,6 +78,7 @@ def main():
     model.train()
     total_loss = 0
     for step, (x, y) in enumerate(loader, 1):
+        # Move both inputs to device
         x = x.to(device)
         y = y.to(device)
 
@@ -76,7 +88,7 @@ def main():
         optimizer.zero_grad()
         generated_images = model(labels)  # (bsz, 1, 28, 28)
 
-        # Mean Squared Error loss
+        # Use Mean Squared Error loss
         loss = F.mse_loss(generated_images, real_images)
         loss.backward()
         optimizer.step()
train_diff.py ADDED
@@ -0,0 +1,67 @@
+import os
+
+import torch
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from torchvision import transforms, datasets
+from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline
+from tqdm import tqdm
+
+def train_diffusion():
+    # Train and save a DDPM diffusion model on MNIST.
+    device = "mps" if torch.backends.mps.is_available() else "cpu"
+    print(f"Using device: {device}")
+
+    transform = transforms.Compose([transforms.ToTensor()])
+    train_ds = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
+    loader = DataLoader(train_ds, batch_size=128, shuffle=True)
+
+    # Conditional DDPM UNet for MNIST digits
+    unet = UNet2DModel(
+        sample_size=28,
+        in_channels=1,
+        out_channels=1,
+        block_out_channels=(32, 64, 128),
+        down_block_types=("DownBlock2D", "AttnDownBlock2D", "DownBlock2D"),
+        up_block_types=("UpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
+        num_class_embeds=10,
+    ).to(device)
+    scheduler = DDPMScheduler(num_train_timesteps=1000)
+    pipeline = DDPMPipeline(unet=unet, scheduler=scheduler).to(device)
+    optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-4, weight_decay=1e-4)  # changed from Adam
+
+    epochs = 5
+    print(f"Training DDPM for {epochs} epochs...")
+    try:
+        for epoch in range(1, epochs + 1):
+            pbar = tqdm(loader, desc=f"Epoch {epoch}/{epochs}")
+            for images, labels in pbar:
+                images = images.to(device)
+                labels = labels.to(device)
+                noise = torch.randn_like(images)
+                timesteps = torch.randint(
+                    0, scheduler.num_train_timesteps, (images.shape[0],), device=device
+                ).long()
+                noisy = scheduler.add_noise(images, noise, timesteps)
+
+                # Conditional noise prediction
+                model_pred = unet(noisy, timesteps, class_labels=labels, return_dict=False)[0]
+                loss = F.mse_loss(model_pred, noise)
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+                pbar.set_postfix(loss=f"{loss.item():.4f}")
+    except KeyboardInterrupt:
+        print("\nKeyboard interrupt, saving model...")
+        output_dir = "my_diffusion_model"
+        pipeline.save_pretrained(output_dir)
+        print(f"Model saved to {output_dir}/")
+        return pipeline
+
+    output_dir = "my_diffusion_model"
+    pipeline.save_pretrained(output_dir)
+    print(f"Training complete. Model saved to {output_dir}/")
+    return pipeline
+
+if __name__ == "__main__":
+    train_diffusion()
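Note: once train_diff.py has saved the pipeline, sampling works exactly as in app.py's generate_diffusion_digit. A condensed sketch of that loop (directory name and step count taken from the diff; running it standalone on CPU is an assumption):

import torch
from diffusers import DDPMPipeline

pipe = DDPMPipeline.from_pretrained("my_diffusion_model")
sched = pipe.scheduler
sched.set_timesteps(50)  # fewer inference steps than the 1000 used in training

# Start from pure noise, shaped from the UNet config (1x1x28x28 here)
img = torch.randn(1, pipe.unet.config.in_channels,
                  pipe.unet.config.sample_size, pipe.unet.config.sample_size)
labels = torch.tensor([7])  # class label, enabled by num_class_embeds=10

for t in sched.timesteps:
    with torch.no_grad():
        noise_pred = pipe.unet(img, t, class_labels=labels).sample
    img = sched.step(noise_pred, t, img).prev_sample

img = (img / 2 + 0.5).clamp(0, 1)  # same de-normalization as the route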
vq_transformer.py CHANGED
@@ -19,7 +19,7 @@ class VQTransformerConfig:
     epochs: int = 10
     warmup_steps: int = 500
     label_offset: int = 512  # Labels are tokens 512-521 (for digits 0-9)
-    device: str = field(default="mps" if torch.backends.mps.is_available() else "
+    device: str = field(default="mps" if torch.backends.mps.is_available() else "cpu")
 
     @classmethod
     def from_pretrained(cls, path: str):
vq_vae.py CHANGED
@@ -64,6 +64,7 @@ class VQVAE(nn.Module):
             nn.Conv2d(32, embedding_dim, 1, stride=1)  # 7x7 -> 7x7xembedding_dim
         )
 
+        # Vector Quantizer
         self.vq = VectorQuantizer(num_embeddings, embedding_dim)
 
         # Decoder for MNIST (7x7 -> 28x28)
@@ -94,6 +95,7 @@ class VQVAE(nn.Module):
         quantized = torch.matmul(one_hot.permute(0, 2, 3, 1), self.vq.embedding.weight)
         quantized = quantized.permute(0, 3, 1, 2)
 
+        # Decode
         reconstructed = self.decoder(quantized)
         return reconstructed
 
@@ -137,14 +139,18 @@ class VQVAE(nn.Module):
 
     @staticmethod
     def train_and_save(output_path="vq_vae_model.pt", device='cpu', batch_size=128, epochs=10):
+        # Setup data
         transform = transforms.Compose([transforms.ToTensor()])
         train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
         train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
-
+
+        # Create model
         model = VQVAE().to(device)
-
+
+        # Train
         model.train_model(train_loader, epochs=epochs, device=device)
 
+        # Save model
         torch.save(model.state_dict(), output_path)
         print(f"Model saved to {output_path}")
 
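Note: with this refactor, retraining the VQ-VAE is a single call to the static helper (argument values are its defaults from the diff; the device choice is an assumption):

from vq_vae import VQVAE

# Trains on MNIST and writes the state dict to the given path.
VQVAE.train_and_save(output_path="vq_vae_model.pt", device="cpu",
                     batch_size=128, epochs=10)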
|