# app.py — last revision: "Update app.py" by basil-ahmad (commit bc1b8ff, marked verified).
import gradio as gr
import torch
from helper import generate_img, DDPM
import torch.nn.functional as F
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Linear noise schedule: timesteps + 1 beta values evenly spaced from beta1 to
# beta2, then alpha_t = 1 - beta_t and the running product of the alphas.
# All three tensors end up on `device`.
timesteps = 500
beta1, beta2 = 1e-4, 0.02
betas = torch.linspace(0, 1, timesteps + 1).mul(beta2 - beta1).add(beta1)
betas = betas.to(device)
alpha = 1.0 - betas
alpha_bar = alpha.cumprod(dim=0)
# model and sampler
# No architecture is constructed here before loading, so model.pt is
# presumably a fully pickled module rather than a state_dict. torch >= 2.6
# defaults to weights_only=True, which refuses such pickles, so opt out
# explicitly. Unpickling can execute arbitrary code — only acceptable because
# model.pt is a local file shipped with this app (assumed trusted).
model = torch.load("model.pt", map_location=device, weights_only=False)
# Inference only: disable training-mode behavior (dropout, batch-norm running
# statistic updates) in case the module was saved in training mode.
# NOTE(review): assumes the loaded object is an nn.Module — confirm.
model.eval()
sampler = DDPM(betas)
# UI category names mapped to the integer context index that generate()
# forwards to generate_img. The dict keeps this order, and it is also the
# order the radio buttons are shown in.
# NOTE(review): the order presumably matches the class order the model was
# trained with — confirm before reordering or inserting categories.
_CATEGORY_NAMES = (
    'hero',
    'non-hero -not recommended-',
    'food',
    'spells & weapons',
    'side-facing',
)
label_to_index = dict(zip(_CATEGORY_NAMES, range(len(_CATEGORY_NAMES))))

# Sampling settings passed unchanged to generate_img on every request.
sampling_count = 500
batch_size = 1
def generate(context_label):
    """Sample one image for the chosen category and return it for display.

    Args:
        context_label: The selected radio-button value; expected to be one of
            the keys of ``label_to_index``. It is ``None`` when the user submits
            without choosing a category.

    Returns:
        A NumPy array of shape (320, 320, C) with values clamped to [0, 1].
        C is the channel count produced by ``generate_img`` (presumably 3 —
        confirm against helper).

    Raises:
        gr.Error: If ``context_label`` is not a known category. Gradio shows
            this message in the UI instead of an opaque KeyError.
    """
    if context_label not in label_to_index:
        raise gr.Error("Please pick one of the categories before generating.")
    index = [label_to_index[context_label]]

    # Sampling is inference only. no_grad keeps autograd from recording the
    # graph across every denoising step (and holding its activations in
    # memory), in case generate_img does not disable gradients itself.
    with torch.no_grad():
        img = generate_img(model, sampler, betas, alpha, alpha_bar,
                           batch_size, sampling_count, context=index)
        # Nearest-neighbour upscaling of the two spatial dims keeps hard pixel
        # edges. [0] takes the first sample; with batch_size == 1 it is the
        # only one.
        img = F.interpolate(img, size=(320, 320), mode="nearest")[0]
        img = torch.clamp(img, 0, 1)

    # (C, H, W) tensor -> (H, W, C) NumPy array for gr.Image. Detach before
    # copying to host so no autograd history is carried along.
    return img.detach().cpu().permute(1, 2, 0).numpy()
# Gradio UI: a radio button per category goes in, the sampled image comes out.
category_picker = gr.Radio(list(label_to_index), label="Pick one:")
image_output = gr.Image(label="Generated Image")

interface = gr.Interface(
    fn=generate,
    inputs=category_picker,
    outputs=image_output,
    title="DDPM Image Generator",
    description="Select a category to generate an image",
)


def _main():
    """Start the Gradio server for this app."""
    interface.launch()


if __name__ == "__main__":
    _main()