"""Gradio demo that samples images from a trained DDPM.

The user picks a context category in a Radio input; the model is run through
the DDPM sampler conditioned on that category and the result is upscaled to
320x320 for display.
"""
import gradio as gr
import torch
import torch.nn.functional as F

from helper import generate_img, DDPM

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Diffusion noise schedule: `timesteps + 1` betas linearly spaced in
# [beta1, beta2], plus alpha = 1 - beta and its cumulative product alpha_bar.
timesteps = 500
beta1, beta2 = 1e-4, 0.02
betas = (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1) + beta1
betas = betas.to(device)
alpha = 1.0 - betas  # inherits `device` from `betas`
alpha_bar = torch.cumprod(alpha, dim=0)  # likewise already on `device`

# Model and sampler.
# NOTE: model.pt appears to hold a full pickled model (it is passed straight
# to generate_img as the network), not a state_dict. torch>=2.6 defaults to
# weights_only=True and would refuse to load it, so opt out explicitly.
# Unpickling can execute arbitrary code: only ship a checkpoint you trust.
model = torch.load("model.pt", map_location=device, weights_only=False)
if isinstance(model, torch.nn.Module):
    # Inference only: disable dropout and use running batch-norm statistics.
    model.eval()
sampler = DDPM(betas)

# Maps the user-visible category name to the context index fed to the model.
label_to_index = {
    label: index
    for index, label in enumerate([
        'hero',
        'non-hero -not recommended-',
        'food',
        'spells & weapons',
        'side-facing',
    ])
}

sampling_count = 500
batch_size = 1


def generate(context_label):
    """Sample one image conditioned on *context_label*.

    Args:
        context_label: The category chosen in the Radio input; expected to be
            one of the keys of ``label_to_index`` (``None`` when nothing was
            selected).

    Returns:
        A numpy array of shape ``(320, 320, C)`` with values clamped to
        ``[0, 1]``, taken from the first sample of the generated batch.

    Raises:
        gr.Error: If no category or an unknown one was selected, so the UI
            shows a readable message instead of an opaque ``KeyError``.
    """
    if context_label not in label_to_index:
        raise gr.Error("Please select a category first.")
    index = [label_to_index[context_label]]

    # Sampling is pure inference. Unless generate_img already disables
    # gradients, autograd would otherwise record the whole reverse-diffusion
    # chain and can exhaust memory.
    with torch.no_grad():
        img = generate_img(model, sampler, betas, alpha, alpha_bar,
                           batch_size, sampling_count, context=index)

    # A 2-element `size` means img is (batch, C, H, W); upscale with nearest
    # neighbour to keep the pixel-art look, then keep the first sample.
    img = F.interpolate(img, size=(320, 320), mode="nearest")[0]
    img = torch.clamp(img, 0, 1)
    # CHW -> HWC for Gradio's image output.
    return img.detach().cpu().permute(1, 2, 0).numpy()


interface = gr.Interface(
    fn=generate,
    inputs=gr.Radio(list(label_to_index.keys()), label="Pick one:"),
    outputs=gr.Image(label="Generated Image"),
    title="DDPM Image Generator",
    description="Select a category to generate an image",
)

if __name__ == "__main__":
    interface.launch()