# -*- coding: utf-8 -*- """app.ipynb Automatically generated by Colab. Original file is located at https://colab.research.google.com/drive/104LGaVaxm2qBgfuuNORSOHuzKCmgOVKm """ ### 1. Imports and setup ### import gradio as gr import torch import numpy as np from model import create_diffusion_model from timeit import default_timer as timer from typing import Tuple # ------------------------------ # 2. Model preparation # ------------------------------ device = "cuda" if torch.cuda.is_available() else "cpu" # Digit vocabulary digit_words = ["zero","one","two","three","four","five","six","seven","eight","nine"] word_to_idx = {w:i for i,w in enumerate(digit_words)} # Create diffusion model model = create_diffusion_model( emb_dim=64, timesteps=100, seed=42, device=device ) # Load trained weights model.load_state_dict( torch.load( "text_to_digit_diffusion_mnist.pth", map_location=device ) ) model.eval() # ------------------------------ # 3. Diffusion scheduler utils # ------------------------------ timesteps = 100 betas = torch.linspace(1e-4, 0.02, timesteps).to(device) alphas = 1.0 - betas alphas_cumprod = torch.cumprod(alphas, dim=0) def extract(a, t, x_shape): return a.gather(-1, t).reshape(-1,1,1,1).expand(x_shape) # ------------------------------ # 4. Predict / Generate function # ------------------------------ @torch.inference_mode() def predict(text: str) -> Tuple[np.ndarray, float]: start_time = timer() text = text.strip().lower() # Normalize input if text.isdigit(): label = int(text) else: label = word_to_idx.get(text, None) if label is None or not (0 <= label <= 9): raise ValueError("Please enter a digit (0–9) or its word form.") labels = torch.tensor([label], device=device) uncond_labels = torch.tensor([0], device=device) # Start from noise x = torch.randn(1, 1, 28, 28, device=device) guidance_scale = 3.0 for i in reversed(range(1, timesteps)): t = torch.full((1,), i, device=device) pred_cond = model(x, t, labels) pred_uncond = model(x, t, uncond_labels) pred_noise = pred_uncond + guidance_scale * (pred_cond - pred_uncond) beta_t = extract(betas, t, x.shape) alpha_t = extract(alphas, t, x.shape) alpha_bar_t = extract(alphas_cumprod, t, x.shape) pred_x0 = (x - torch.sqrt(1 - alpha_bar_t) * pred_noise) / torch.sqrt(alpha_bar_t) x = torch.sqrt(alpha_t) * pred_x0 + torch.sqrt(beta_t) * pred_noise img = (x.clamp(-1,1) + 1) / 2 img = img[0,0].cpu().numpy() end_time = timer() gen_time = round(end_time - start_time, 4) return img, gen_time # ------------------------------ # 5. Gradio app # ------------------------------ title = "Text-to-Digit Diffusion (MNIST)" description = ( "A **conditional diffusion model** trained on MNIST. " "Type a digit (e.g. `seven` or `7`) to generate a handwritten number." ) article = "Created by [Programming Ocean Academy](https://www.programming-ocean.com/)" demo = gr.Interface( fn=predict, inputs=gr.Textbox( label="Enter digit (0–9 or word)", placeholder="seven or 7" ), outputs=[ gr.Image( label="Generated Digit", type="numpy", width=256, height=256 ), gr.Number(label="generation time (s)") ], title=title, description=description, article=article ) # Launch demo demo.launch(debug=False)