<!--
image_generator/templates/index.html
Kyryll Kochkin
new frontend
ad9ba57
-->
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>Image Generator</title>
  <script src="https://cdn.tailwindcss.com"></script>
</head>
<body class="bg-black text-white">
  <main class="flex flex-col items-center justify-center min-h-screen space-y-6 p-4">
    <h1 class="text-3xl font-semibold">Image Generator</h1>
    <!-- gap-2 (rather than space-x-2) keeps spacing even when the controls wrap onto a new row. -->
    <div class="flex flex-wrap items-center justify-center gap-2">
      <input id="digitInput" type="number" min="0" max="9" value="7"
             aria-label="Digit to generate"
             class="w-16 px-2 py-1 bg-gray-800 border border-gray-600 rounded text-white focus:outline-none focus:ring-2 focus:ring-gray-400">
      <!-- Shown only for the diffusion model; its visibility is toggled by script via the "hidden" class. -->
      <input id="stepsInput" type="number" min="1" max="1000" value="50" placeholder="steps"
             aria-label="Diffusion steps"
             class="hidden w-20 px-2 py-1 bg-gray-800 border border-gray-600 rounded text-white focus:outline-none focus:ring-2 focus:ring-gray-400">
      <button id="generateBtn" type="button" onclick="generateDigit()"
              class="px-4 py-1 bg-gray-700 hover:bg-gray-600 rounded disabled:opacity-50">
        Generate
      </button>
      <select id="modelSelector" onchange="selectModel()" aria-label="Model"
              class="px-2 py-1 bg-gray-800 border border-gray-600 rounded text-white focus:outline-none focus:ring-2 focus:ring-gray-400">
        {% for name, available in available_models.items() %}
        {% if available %}
        <option value="{{ name }}" {% if selected_model == name %}selected{% endif %}>
          {{ name|capitalize }}
        </option>
        {% endif %}
        {% endfor %}
      </select>
    </div>
    <canvas id="canvas" width="28" height="28" role="img" aria-label="Generated digit"
            class="w-[280px] h-[280px] border border-gray-600 bg-black"
            style="image-rendering: pixelated;"></canvas>
    <!-- Purely visual; generation progress is also reported as text in #log. -->
    <div id="progress-container" class="w-[280px] bg-gray-800 rounded overflow-hidden"
         style="display: none;" aria-hidden="true">
      <div id="progress-fill" class="bg-white h-1 w-0 transition-all"></div>
    </div>
    <!-- role="status" makes assistive technology announce progress and error messages. -->
    <div id="log" role="status" class="text-sm text-gray-400"></div>
  </main>
<script>
// Page-level state shared by the handlers in this script.
let currentModel = '{{ selected_model }}', // model name rendered by the server; selectModel() re-reads it from #modelSelector
    currentEventSource = null,             // EventSource of an in-flight streamed generation, if any
    isGenerating = false,                  // set while a generation runs; generateDigit() bails out when true
    pixelCounter = 0;                      // running count of streamed pixels, reset per generation
/**
 * Handle a change of #modelSelector.
 *
 * Records the chosen model in `currentModel` and shows only the controls that
 * apply to it: the progress bar for 'vq' / 'vq-vae' and the steps input for
 * 'diffusion'. It then tells the server about the choice via
 * POST /select_model.
 *
 * The generate controls stay disabled until the server has answered, so a
 * generation request cannot race ahead of the model switch. A failed switch
 * is reported in #log instead of being silently ignored.
 */
function selectModel() {
  const modelSelector = document.getElementById('modelSelector');
  currentModel = modelSelector.value;

  document.getElementById('progress-container').style.display =
    (currentModel === 'vq' || currentModel === 'vq-vae') ? 'block' : 'none';
  document.getElementById('stepsInput').classList.toggle('hidden', currentModel !== 'diffusion');

  if (!currentModel) {
    // The template renders no <option> at all when no model is available,
    // so there is nothing to select and nothing to generate with.
    const log = document.getElementById('log');
    log.textContent = 'No models are available.';
    log.classList.remove('text-gray-400');
    log.classList.add('text-red-500');
    document.getElementById('generateBtn').disabled = true;
    return;
  }

  setGenerating(true);
  fetch('/select_model', {
    method: 'POST',
    headers: {'Content-Type': 'application/json'},
    body: JSON.stringify({model_type: currentModel})
  })
    .then(response => {
      if (!response.ok) {
        return response.text().then(text => {
          throw new Error(text || `HTTP error! status: ${response.status}`);
        });
      }
    })
    .catch(error => {
      console.error('Error:', error);
      const log = document.getElementById('log');
      log.textContent = `Error selecting model: ${error.message}`;
      log.classList.remove('text-gray-400');
      log.classList.add('text-red-500');
    })
    .finally(() => setGenerating(false));
}
// Run the change handler once at load so the model-specific controls match
// the <option> the template pre-selected, and the server is sent that choice.
selectModel();
/**
 * Record whether a generation is in progress and lock or unlock the controls
 * that would start another one.
 *
 * @param {boolean} generating - true to lock the UI, false to release it.
 */
function setGenerating(generating) {
  isGenerating = generating;
  for (const controlId of ['generateBtn', 'modelSelector']) {
    document.getElementById(controlId).disabled = generating;
  }
}
/**
 * Generate an image of the digit in #digitInput with the current model and
 * draw it onto #canvas.
 *
 * Dispatch on `currentModel`:
 *   - 'conv', 'diffusion', 'vq-vae': fetch a finished image from
 *     /generate_conv_digit, /generate_diffusion_digit (plus the #stepsInput
 *     step count) or /generate_vq_vae_digit, and draw it.
 *   - 'vq': stream /stream_digit, handling `token:` progress messages,
 *     `frame:` whole-image updates and single pixel values.
 *   - any other model: stream /stream_digit as single pixel values only.
 *
 * Does nothing while `isGenerating` is set. If the inputs violate their
 * min/max/step attributes, the problem is reported in #log and no request
 * is made.
 */
function generateDigit() {
  if (isGenerating) return;

  const digitInput = document.getElementById('digitInput');
  const stepsInput = document.getElementById('stepsInput');
  const canvas = document.getElementById('canvas');
  const ctx = canvas.getContext('2d');
  const log = document.getElementById('log');
  const progressBar = document.getElementById('progress-fill');

  // A number input reports '' for empty or unparsable text, and
  // checkValidity() enforces its min/max/step attributes.
  if (digitInput.value === '' || !digitInput.checkValidity()) {
    _setLogStatus(log, `Please enter a whole number from ${digitInput.min} to ${digitInput.max}.`, true);
    return;
  }
  if (currentModel === 'diffusion' && (stepsInput.value === '' || !stepsInput.checkValidity())) {
    _setLogStatus(log, `Please enter a step count from ${stepsInput.min} to ${stepsInput.max}.`, true);
    return;
  }

  setGenerating(true);
  if (currentEventSource) {
    currentEventSource.close();
    currentEventSource = null;
  }

  pixelCounter = 0;
  ctx.fillStyle = 'black';
  ctx.fillRect(0, 0, canvas.width, canvas.height);
  // Also restores the normal log colour if a previous run left it red.
  _setLogStatus(log, 'Generating...', false);
  progressBar.style.width = '0%';

  // URLSearchParams percent-encodes the values instead of splicing raw text into the URL.
  const query = new URLSearchParams({digit: digitInput.value});
  if (currentModel === 'conv') {
    _fetchImageToCanvas(`/generate_conv_digit?${query}`, ctx, log);
  } else if (currentModel === 'diffusion') {
    query.set('steps', stepsInput.value);
    _fetchImageToCanvas(`/generate_diffusion_digit?${query}`, ctx, log);
  } else if (currentModel === 'vq-vae') {
    _fetchImageToCanvas(`/generate_vq_vae_digit?${query}`, ctx, log);
  } else {
    _streamToCanvas(`/stream_digit?${query}`, ctx, log, progressBar, currentModel === 'vq');
  }
}

/**
 * Show `message` in the #log element, styled red when `isError` is true.
 * Only the colour classes are toggled, so the element's other classes
 * (e.g. text-sm) are preserved.
 */
function _setLogStatus(log, message, isError) {
  log.textContent = message;
  log.classList.toggle('text-red-500', isError);
  log.classList.toggle('text-gray-400', !isError);
}

/**
 * Fetch an image from `url` and draw it at the top-left of `ctx`.
 *
 * Every failure path reaches the single `.catch`: a non-2xx response (its
 * body text becomes the message) or an image that fails to decode.
 * `setGenerating(false)` always runs at the end.
 */
function _fetchImageToCanvas(url, ctx, log) {
  fetch(url)
    .then(response => {
      if (!response.ok) {
        return response.text().then(text => {
          throw new Error(text || `HTTP error! status: ${response.status}`);
        });
      }
      return response.blob();
    })
    .then(blob => new Promise((resolve, reject) => {
      const objectUrl = URL.createObjectURL(blob);
      const img = new Image();
      // Release the blob URL as soon as the browser is done with it, on
      // success or failure, so repeated generations do not leak memory.
      img.onload = () => {
        URL.revokeObjectURL(objectUrl);
        resolve(img);
      };
      img.onerror = () => {
        URL.revokeObjectURL(objectUrl);
        reject(new Error('Failed to load generated image'));
      };
      img.src = objectUrl;
    }))
    .then(img => {
      ctx.drawImage(img, 0, 0);
      _setLogStatus(log, 'Generated!', false);
    })
    .catch(error => {
      console.error('Error:', error);
      _setLogStatus(log, `Error generating image: ${error.message}`, true);
    })
    .finally(() => setGenerating(false));
}

/**
 * Stream one image from `url` over Server-Sent Events into `ctx`.
 *
 * Message handling:
 *   - `Error:...` is shown in #log and the stream is closed.
 *   - Any other message is read as a grey value and written to the next
 *     pixel in row-major order.
 *   - When `withProgress` is true, two more forms are recognised:
 *     `token:<n>:<percent>` updates the progress bar and log, and
 *     `frame:<v0>,<v1>,...` repaints the whole image.
 *
 * The canvas is repainted after each completed row. Once every canvas pixel
 * has arrived, the stream is closed.
 */
function _streamToCanvas(url, ctx, log, progressBar, withProgress) {
  const width = ctx.canvas.width;
  const totalPixels = width * ctx.canvas.height;
  const imageData = ctx.createImageData(width, ctx.canvas.height);
  const source = new EventSource(url);
  let sawFrame = false;
  currentEventSource = source;

  const finish = () => {
    source.close();
    if (currentEventSource === source) currentEventSource = null;
    setGenerating(false);
  };

  source.onmessage = event => {
    const data = event.data;
    if (data.startsWith('Error:')) {
      _setLogStatus(log, data, true);
      finish();
      return;
    }
    if (withProgress && data.startsWith('token:')) {
      const [, tokenNum, progress] = data.split(':');
      progressBar.style.width = `${progress}%`;
      log.textContent = `Generating tokens: ${tokenNum}/49 (${progress}%)`;
      return;
    }
    if (withProgress && data.startsWith('frame:')) {
      data.slice('frame:'.length).split(',').map(Number)
        .forEach((value, index) => _setGrayPixel(imageData, index, value));
      ctx.putImageData(imageData, 0, 0);
      sawFrame = true;
      return;
    }

    const pixelValue = parseInt(data, 10);
    if (Number.isNaN(pixelValue)) return;
    _setGrayPixel(imageData, pixelCounter, pixelValue);
    pixelCounter++;
    if (pixelCounter % width === 0 || pixelCounter >= totalPixels) {
      ctx.putImageData(imageData, 0, 0);
    }
    if (pixelCounter >= totalPixels) {
      _setLogStatus(log, 'Generation complete!', false);
      finish();
    }
  };

  source.onerror = () => {
    // Reached when the connection fails or the server ends the stream before
    // the image completed (a completed stream was already closed above, and
    // closed sources fire no further events). Closing stops the browser from
    // trying to reconnect. Frame-based streams finish this way, so they are
    // not reported as errors.
    if (!sawFrame) {
      _setLogStatus(log, `Error: generation stream closed after ${pixelCounter}/${totalPixels} pixels.`, true);
    }
    finish();
  };
}

/** Write an opaque grey pixel of intensity `value` at row-major index `pixelIndex`. */
function _setGrayPixel(imageData, pixelIndex, value) {
  const offset = pixelIndex * 4; // ImageData stores 4 bytes (RGBA) per pixel
  imageData.data[offset] = value;
  imageData.data[offset + 1] = value;
  imageData.data[offset + 2] = value;
  imageData.data[offset + 3] = 255;
}
</script>
</body>
</html>