diffusion-model / app.py
Dulamaa's picture
intital commit
88ba839
Raw
History Blame Contribute Delete
3.55 kB
# -*- coding: utf-8 -*-
"""app.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/104LGaVaxm2qBgfuuNORSOHuzKCmgOVKm
"""
### 1. Imports and setup ###
import gradio as gr
import torch
import numpy as np
from model import create_diffusion_model
from timeit import default_timer as timer
from typing import Tuple
# ------------------------------
# 2. Model preparation
# ------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
# Digit vocabulary
digit_words = ["zero","one","two","three","four","five","six","seven","eight","nine"]
word_to_idx = {w:i for i,w in enumerate(digit_words)}
# Create diffusion model
model = create_diffusion_model(
emb_dim=64,
timesteps=100,
seed=42,
device=device
)
# Load trained weights
model.load_state_dict(
torch.load(
"text_to_digit_diffusion_mnist.pth",
map_location=device
)
)
model.eval()
# ------------------------------
# 3. Diffusion scheduler utils
# ------------------------------
timesteps = 100
betas = torch.linspace(1e-4, 0.02, timesteps).to(device)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
def extract(a, t, x_shape):
return a.gather(-1, t).reshape(-1,1,1,1).expand(x_shape)
# ------------------------------
# 4. Predict / Generate function
# ------------------------------
@torch.inference_mode()
def predict(text: str) -> Tuple[np.ndarray, float]:
start_time = timer()
text = text.strip().lower()
# Normalize input
if text.isdigit():
label = int(text)
else:
label = word_to_idx.get(text, None)
if label is None or not (0 <= label <= 9):
raise ValueError("Please enter a digit (0–9) or its word form.")
labels = torch.tensor([label], device=device)
uncond_labels = torch.tensor([0], device=device)
# Start from noise
x = torch.randn(1, 1, 28, 28, device=device)
guidance_scale = 3.0
for i in reversed(range(1, timesteps)):
t = torch.full((1,), i, device=device)
pred_cond = model(x, t, labels)
pred_uncond = model(x, t, uncond_labels)
pred_noise = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
beta_t = extract(betas, t, x.shape)
alpha_t = extract(alphas, t, x.shape)
alpha_bar_t = extract(alphas_cumprod, t, x.shape)
pred_x0 = (x - torch.sqrt(1 - alpha_bar_t) * pred_noise) / torch.sqrt(alpha_bar_t)
x = torch.sqrt(alpha_t) * pred_x0 + torch.sqrt(beta_t) * pred_noise
img = (x.clamp(-1,1) + 1) / 2
img = img[0,0].cpu().numpy()
end_time = timer()
gen_time = round(end_time - start_time, 4)
return img, gen_time
# ------------------------------
# 5. Gradio app
# ------------------------------
title = "Text-to-Digit Diffusion (MNIST)"
description = (
"A **conditional diffusion model** trained on MNIST. "
"Type a digit (e.g. `seven` or `7`) to generate a handwritten number."
)
article = "Created by [Programming Ocean Academy](https://www.programming-ocean.com/)"
demo = gr.Interface(
fn=predict,
inputs=gr.Textbox(
label="Enter digit (0–9 or word)",
placeholder="seven or 7"
),
outputs=[
gr.Image(
label="Generated Digit",
type="numpy",
width=256,
height=256
),
gr.Number(label="generation time (s)")
],
title=title,
description=description,
article=article
)
# Launch demo
demo.launch(debug=False)