Spaces:
Running
Running
| # -*- coding: utf-8 -*- | |
| """app.ipynb | |
| Automatically generated by Colab. | |
| Original file is located at | |
| https://colab.research.google.com/drive/104LGaVaxm2qBgfuuNORSOHuzKCmgOVKm | |
| """ | |
| ### 1. Imports and setup ### | |
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from model import create_diffusion_model | |
| from timeit import default_timer as timer | |
| from typing import Tuple | |
| # ------------------------------ | |
| # 2. Model preparation | |
| # ------------------------------ | |
# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Digit vocabulary: the position of each word in the list is its class index.
digit_words = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
word_to_idx = {word: idx for idx, word in enumerate(digit_words)}

# Number of diffusion steps. Defined once and shared by the model constructor
# and the noise schedule below so the two can never drift apart.
timesteps = 100

# Create diffusion model (create_diffusion_model is a project-local factory).
model = create_diffusion_model(
    emb_dim=64,
    timesteps=timesteps,
    seed=42,
    device=device,
)

# Load trained weights.
# NOTE: torch.load unpickles the file, so only load checkpoints from a
# trusted source.
model.load_state_dict(
    torch.load(
        "text_to_digit_diffusion_mnist.pth",
        map_location=device,
    )
)
model.eval()

# ------------------------------
# 3. Diffusion scheduler utils
# ------------------------------
# Linear beta schedule from 1e-4 to 0.02 over `timesteps` steps.
betas = torch.linspace(1e-4, 0.02, timesteps).to(device)
alphas = 1.0 - betas
# Running product of the alphas (alpha-bar), indexed by timestep.
alphas_cumprod = torch.cumprod(alphas, dim=0)
def extract(a, t, x_shape):
    """Look up schedule values at the given timesteps and broadcast them.

    Args:
        a: Tensor of per-timestep values, gathered along its last dimension.
        t: Integer index tensor selecting which entries of ``a`` to take.
        x_shape: Target shape the selected values are expanded to.

    Returns:
        A view of the selected values reshaped to ``(-1, 1, 1, 1)`` and
        expanded to ``x_shape``.
    """
    picked = torch.gather(a, -1, t)
    per_sample = picked.reshape(-1, 1, 1, 1)
    return per_sample.expand(x_shape)
| # ------------------------------ | |
| # 4. Predict / Generate function | |
| # ------------------------------ | |
def _parse_digit_label(text: str) -> int:
    """Convert user text such as ``"7"`` or ``"seven"`` to a class index 0-9."""
    cleaned = text.strip().lower()
    # str.isdecimal() only accepts characters that int() can parse, whereas
    # str.isdigit() also accepts e.g. superscript digits and would let int()
    # fail with an unrelated error message.
    if cleaned.isdecimal():
        label = int(cleaned)
    else:
        label = word_to_idx.get(cleaned)
    if label is None or not 0 <= label <= 9:
        raise ValueError("Please enter a digit (0–9) or its word form.")
    return label


def predict(text: str) -> Tuple[np.ndarray, float]:
    """Generate a digit image from text using guided diffusion sampling.

    Args:
        text: A digit ``0``-``9`` or its English word form; surrounding
            whitespace and letter case are ignored.

    Returns:
        A tuple ``(img, gen_time)`` where ``img`` is the first channel of the
        first sample as a NumPy array scaled to ``[0, 1]`` and ``gen_time`` is
        the elapsed wall-clock time in seconds, rounded to 4 decimals.

    Raises:
        ValueError: If ``text`` is neither a digit 0-9 nor a known digit word.
    """
    start_time = timer()

    label = _parse_digit_label(text)
    labels = torch.tensor([label], device=device)
    # NOTE(review): the "unconditional" branch reuses class index 0, which is
    # also the digit "zero" -- confirm this matches how the model was trained.
    uncond_labels = torch.tensor([0], device=device)

    guidance_scale = 3.0

    # Inference only: without no_grad every forward pass would extend the
    # autograd graph, and .numpy() below fails on tensors that require grad.
    with torch.no_grad():
        # Start from noise
        x = torch.randn(1, 1, 28, 28, device=device)

        for i in reversed(range(1, timesteps)):
            t = torch.full((1,), i, device=device)

            # Guidance: extrapolate from the unconditional prediction towards
            # the conditional one.
            pred_cond = model(x, t, labels)
            pred_uncond = model(x, t, uncond_labels)
            pred_noise = pred_uncond + guidance_scale * (pred_cond - pred_uncond)

            beta_t = extract(betas, t, x.shape)
            alpha_t = extract(alphas, t, x.shape)
            alpha_bar_t = extract(alphas_cumprod, t, x.shape)

            # Estimate the clean image from the current sample and the
            # predicted noise.
            pred_x0 = (x - torch.sqrt(1 - alpha_bar_t) * pred_noise) / torch.sqrt(alpha_bar_t)
            # NOTE(review): this update uses the per-step alpha_t / beta_t
            # rather than the cumulative alpha-bar of the previous step as in
            # the usual DDPM/DDIM formulas; kept as-is, verify against the
            # training code.
            x = torch.sqrt(alpha_t) * pred_x0 + torch.sqrt(beta_t) * pred_noise

        # Map from [-1, 1] to [0, 1] for display.
        img = (x.clamp(-1, 1) + 1) / 2

    img = img[0, 0].cpu().numpy()

    end_time = timer()
    gen_time = round(end_time - start_time, 4)
    return img, gen_time
| # ------------------------------ | |
| # 5. Gradio app | |
| # ------------------------------ | |
# Static text shown on the demo page.
title = "Text-to-Digit Diffusion (MNIST)"
description = (
    "A **conditional diffusion model** trained on MNIST. "
    "Type a digit (e.g. `seven` or `7`) to generate a handwritten number."
)
article = "Created by [Programming Ocean Academy](https://www.programming-ocean.com/)"

# UI components: one text input, and an image plus a number as outputs,
# matching the (image, seconds) tuple returned by predict.
_digit_input = gr.Textbox(
    label="Enter digit (0–9 or word)",
    placeholder="seven or 7",
)
_digit_image = gr.Image(
    label="Generated Digit",
    type="numpy",
    width=256,
    height=256,
)
_elapsed_seconds = gr.Number(label="generation time (s)")

demo = gr.Interface(
    fn=predict,
    inputs=_digit_input,
    outputs=[_digit_image, _elapsed_seconds],
    title=title,
    description=description,
    article=article,
)

# Launch demo
demo.launch(debug=False)