Spaces:
Sleeping
Sleeping
| <html lang="en"> | |
| <head> | |
| <meta charset="UTF-8"> | |
| <meta name="viewport" content="width=device-width, initial-scale=1.0"> | |
| <title>Image Generator</title> | |
| <script src="https://cdn.tailwindcss.com"></script> | |
| </head> | |
| <body class="bg-black text-white"> | |
| <div class="flex flex-col items-center justify-center min-h-screen space-y-6 p-4"> | |
| <h1 class="text-3xl font-semibold">Image Generator</h1> | |
| <div class="flex flex-wrap items-center justify-center space-x-2"> | |
| <input id="digitInput" type="number" min="0" max="9" value="7" | |
| class="w-16 px-2 py-1 bg-gray-800 border border-gray-600 rounded text-white focus:outline-none"/> | |
| <input id="stepsInput" type="number" min="1" max="1000" value="50" placeholder="steps" | |
| class="hidden w-20 px-2 py-1 bg-gray-800 border border-gray-600 rounded text-white focus:outline-none"/> | |
| <button id="generateBtn" onclick="generateDigit()" | |
| class="px-4 py-1 bg-gray-700 hover:bg-gray-600 rounded disabled:opacity-50"> | |
| Generate | |
| </button> | |
| <select id="modelSelector" onchange="selectModel()" | |
| class="px-2 py-1 bg-gray-800 border border-gray-600 rounded text-white focus:outline-none"> | |
| {% for name, available in available_models.items() %} | |
| {% if available %} | |
| <option value="{{ name }}" {% if selected_model == name %}selected{% endif %}> | |
| {{ name|capitalize }} | |
| </option> | |
| {% endif %} | |
| {% endfor %} | |
| </select> | |
| </div> | |
| <canvas id="canvas" width="28" height="28" | |
| class="w-[280px] h-[280px] border border-gray-600 bg-black" | |
| style="image-rendering: pixelated;"></canvas> | |
| <div id="progress-container" class="w-[280px] bg-gray-800 rounded overflow-hidden" | |
| style="display: none;"> | |
| <div id="progress-fill" class="bg-white h-1 w-0 transition-all"></div> | |
| </div> | |
| <div id="log" class="text-sm text-gray-400"></div> | |
| </div> | |
| <script> | |
// Page-level state shared by the handlers in this script.
// Model name injected by the server-side template; the selectModel() call made
// when this script loads replaces it with the dropdown's current value.
let currentModel = '{{ selected_model }}';
// True while a generation request or stream is in flight.
let isGenerating = false;
// EventSource of the current pixel stream, or null when none has been opened.
let currentEventSource = null;
// Number of streamed pixel values written for the current generation.
let pixelCounter = 0;
/**
 * Apply the model chosen in the #modelSelector dropdown.
 *
 * Stores the value in `currentModel` and notifies the server with
 * POST /select_model ({ model_type }). It also updates the UI:
 * - the progress bar is shown only for the 'vq' and 'vq-vae' models;
 * - the steps input is shown only for 'diffusion'.
 * The POST is not awaited. Network failures and non-2xx responses are logged
 * and shown in #log instead of becoming an unhandled promise rejection.
 */
function selectModel() {
  const modelSelector = document.getElementById('modelSelector');
  currentModel = modelSelector.value;

  fetch('/select_model', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ model_type: currentModel }),
  })
    .then((response) => {
      // fetch() only rejects on network errors; surface HTTP errors too.
      if (!response.ok) {
        throw new Error(`HTTP error! status: ${response.status}`);
      }
    })
    .catch((error) => {
      console.error('Error:', error);
      const log = document.getElementById('log');
      log.textContent = `Error selecting model: ${error.message}`;
      log.className = 'text-red-500';
    });

  const usesProgressBar = currentModel === 'vq' || currentModel === 'vq-vae';
  document.getElementById('progress-container').style.display = usesProgressBar ? 'block' : 'none';

  // force=true adds 'hidden', force=false removes it.
  document.getElementById('stepsInput').classList.toggle('hidden', currentModel !== 'diffusion');
}
| // Initialize from the dropdown's pre-selected option: selectModel() copies it into | |
| // currentModel, POSTs it to /select_model, and shows/hides the progress bar and | |
| // steps input to match. | |
| selectModel(); | |
/**
 * Record whether a generation is running and lock the controls to match.
 * While `generating` is true, the Generate button and the model dropdown are
 * disabled; passing false re-enables both.
 * @param {boolean} generating - true when a generation starts, false when it ends.
 */
function setGenerating(generating) {
  isGenerating = generating;
  ['generateBtn', 'modelSelector'].forEach((controlId) => {
    document.getElementById(controlId).disabled = generating;
  });
}
/**
 * Generate an image of the digit in #digitInput with the selected model and
 * draw it onto the 28x28 canvas.
 *
 * Routing by model:
 * - 'conv', 'diffusion' (with the #stepsInput step count) and 'vq-vae' request
 *   a single image from their /generate_*_digit endpoint.
 * - 'vq', and any other model name, stream the image from /stream_digit via
 *   Server-Sent Events.
 *
 * Controls stay disabled (setGenerating) until the run succeeds or fails.
 * Calls made while a run is in progress are ignored.
 */
function generateDigit() {
  if (isGenerating) return;
  setGenerating(true);
  if (currentEventSource) {
    currentEventSource.close();
    currentEventSource = null;
  }

  const digit = document.getElementById('digitInput').value;
  const canvas = document.getElementById('canvas');
  const ctx = canvas.getContext('2d');
  const log = document.getElementById('log');
  const progressBar = document.getElementById('progress-fill');

  // Reset the canvas, status line and progress bar for the new run.
  pixelCounter = 0;
  ctx.fillStyle = 'black';
  ctx.fillRect(0, 0, canvas.width, canvas.height);
  log.textContent = 'Generating...';
  log.className = '';
  progressBar.style.width = '0%';

  const ui = { ctx, log, progressBar };
  switch (currentModel) {
    case 'conv':
      renderImageFromEndpoint(buildQueryUrl('/generate_conv_digit', { digit }), ui);
      break;
    case 'diffusion': {
      const steps = document.getElementById('stepsInput').value;
      renderImageFromEndpoint(buildQueryUrl('/generate_diffusion_digit', { digit, steps }), ui);
      break;
    }
    case 'vq-vae':
      renderImageFromEndpoint(buildQueryUrl('/generate_vq_vae_digit', { digit }), ui);
      break;
    default:
      // 'vq' and unrecognised model names both use the streaming endpoint.
      streamDigitPixels(buildQueryUrl('/stream_digit', { digit }), ui);
      break;
  }
}

// Build "<path>?key=value&..." with every query value URL-encoded.
function buildQueryUrl(path, params) {
  return `${path}?${new URLSearchParams(params)}`;
}

/**
 * Decode an image Blob into an <img>.
 * Resolves with the loaded image, or rejects with
 * Error('Failed to load generated image') if it cannot be decoded.
 * The temporary object URL is revoked in both cases so the blob is not leaked.
 */
function loadImageFromBlob(blob) {
  return new Promise((resolve, reject) => {
    const objectUrl = URL.createObjectURL(blob);
    const img = new Image();
    img.onload = () => {
      URL.revokeObjectURL(objectUrl);
      resolve(img);
    };
    img.onerror = () => {
      URL.revokeObjectURL(objectUrl);
      // Reject rather than throw: a throw inside an event handler never reaches
      // the promise chain, which used to leave the controls disabled forever.
      reject(new Error('Failed to load generated image'));
    };
    img.src = objectUrl;
  });
}

/**
 * GET `url`, expecting an image body, and draw it at the canvas origin.
 * Network failures, non-2xx responses (using the response text as the message)
 * and undecodable images are reported in the log in red.
 * The controls are re-enabled in every case.
 */
function renderImageFromEndpoint(url, { ctx, log }) {
  fetch(url)
    .then((response) => {
      if (!response.ok) {
        return response.text().then((text) => {
          throw new Error(text || `HTTP error! status: ${response.status}`);
        });
      }
      return response.blob();
    })
    .then(loadImageFromBlob)
    .then((img) => {
      ctx.drawImage(img, 0, 0);
      log.textContent = 'Generated!';
    })
    .catch((error) => {
      console.error('Error:', error);
      log.textContent = `Error generating image: ${error.message}`;
      log.className = 'text-red-500';
    })
    // Reached after either the success or the error branch above (like `finally`).
    .then(() => setGenerating(false));
}

/**
 * Stream a digit from `url` over Server-Sent Events and paint it onto the canvas.
 * The EventSource is stored in `currentEventSource`.
 *
 * Message formats handled:
 *   "Error:..."          shown in red in the log; the stream is closed.
 *   "token:<n>:<pct>"    progress update for the progress bar and log.
 *   "frame:<v>,<v>,..."  full repaint from comma-separated grayscale values.
 *   "<integer>"          next grayscale pixel in row-major order. The canvas is
 *                        refreshed after each completed row, and the stream is
 *                        closed after 28*28 pixels.
 * Any other message is ignored.
 */
function streamDigitPixels(url, { ctx, log, progressBar }) {
  const size = 28;
  const totalPixels = size * size;
  const imageData = ctx.createImageData(size, size);
  let receivedImage = false;

  // Write one opaque gray pixel into the RGBA buffer.
  const setGray = (pixelIndex, value) => {
    const offset = pixelIndex * 4;
    imageData.data[offset] = value;
    imageData.data[offset + 1] = value;
    imageData.data[offset + 2] = value;
    imageData.data[offset + 3] = 255;
  };

  const source = new EventSource(url);
  currentEventSource = source;

  const stop = () => {
    source.close();
    setGenerating(false);
  };

  source.onmessage = (event) => {
    const { data } = event;
    if (data.startsWith('Error:')) {
      log.textContent = data;
      log.className = 'text-red-500';
      stop();
      return;
    }
    if (data.startsWith('token:')) {
      const [, tokenNum, progress] = data.split(':');
      progressBar.style.width = `${progress}%`;
      log.textContent = `Generating tokens: ${tokenNum}/49 (${progress}%)`;
      return;
    }
    if (data.startsWith('frame:')) {
      data
        .slice('frame:'.length)
        .split(',')
        .forEach((value, pixelIndex) => setGray(pixelIndex, Number(value)));
      ctx.putImageData(imageData, 0, 0);
      receivedImage = true;
      return;
    }

    const pixelValue = Number.parseInt(data, 10);
    if (Number.isNaN(pixelValue)) return;
    setGray(pixelCounter, pixelValue);
    pixelCounter += 1;
    receivedImage = true;
    // Repaint once per completed row rather than once per pixel.
    if (pixelCounter % size === 0 || pixelCounter >= totalPixels) {
      ctx.putImageData(imageData, 0, 0);
    }
    if (pixelCounter >= totalPixels) {
      log.textContent = 'Generation complete!';
      stop();
    }
  };

  source.onerror = () => {
    // `error` fires both on connection failure and when the server ends the
    // stream; close() stops the browser from reconnecting. Report a failure only
    // if nothing was drawn, so the log does not keep saying "Generating...".
    if (!receivedImage) {
      log.textContent = 'Error: connection to the generation stream was lost';
      log.className = 'text-red-500';
    }
    stop();
  };
}
| </script> | |
| </body> | |
| </html> |