Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from helper import generate_img, DDPM | |
| import torch.nn.functional as F | |
# Diffusion noise schedule -----------------------------------------------------
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

timesteps = 500
beta1, beta2 = 1e-4, 0.02

# Linearly spaced betas: timesteps + 1 values going from beta1 (index 0) up to
# beta2 (last index). Built on the CPU, then moved to the target device.
betas = (beta1 + (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1)).to(device)

# Per-step alpha and its running product along the time axis; both inherit
# the device of `betas`.
alpha = 1.0 - betas
alpha_bar = alpha.cumprod(dim=0)
# Model and sampler -------------------------------------------------------------
# model.pt is passed straight to generate_img as the network, so it presumably
# holds a whole pickled module rather than a bare state_dict (confirm). Loading
# a full pickle requires weights_only=False: torch>=2.6 defaults to True and
# would reject the file.
# SECURITY: weights_only=False lets the pickle run arbitrary code. That is
# acceptable only because model.pt ships with this app; never load a
# user-supplied file this way.
model = torch.load("model.pt", map_location=device, weights_only=False)
if isinstance(model, torch.nn.Module):
    # Sampling is inference: switch any dropout / batch-norm layers to their
    # deterministic eval behaviour instead of training behaviour.
    model.eval()
sampler = DDPM(betas)
# UI category name -> integer context index handed to generate_img. The
# indices follow the listing order below; presumably they must match the
# context ordering the model was trained with (confirm before editing).
label_to_index = {
    'hero': 0,
    'non-hero -not recommended-': 1,
    'food': 2,
    'spells & weapons': 3,
    'side-facing': 4,
}

# Sampling settings forwarded to generate_img on every request.
sampling_count = 500
batch_size = 1
def generate(context_label):
    """Sample one image from the diffusion model for the chosen category.

    Args:
        context_label: The category picked in the UI; expected to be one of
            the keys of ``label_to_index``. Gradio passes ``None`` when no
            option is selected.

    Returns:
        A NumPy array of shape ``(320, 320, C)`` (channels last) holding the
        first sample of the batch, with values clamped to ``[0, 1]``.

    Raises:
        gr.Error: If ``context_label`` is missing or not a known category, so
            the UI shows a readable message instead of a bare ``KeyError``.
    """
    if context_label not in label_to_index:
        raise gr.Error("Please pick one of the listed categories.")
    index = [label_to_index[context_label]]

    # Pure inference: disable autograd so the denoising loop inside
    # generate_img does not record a graph across every sampling step.
    with torch.no_grad():
        img = generate_img(model, sampler, betas, alpha, alpha_bar, batch_size, sampling_count, context=index)
        # Nearest-neighbour upscaling keeps hard pixel edges (no blurring);
        # [0] selects the first image of the (N, C, H, W) batch.
        img = F.interpolate(img, size=(320, 320), mode="nearest")[0]

    img = torch.clamp(img, 0, 1)
    # (C, H, W) -> (H, W, C) NumPy array, the layout gr.Image expects.
    return img.cpu().detach().permute(1, 2, 0).numpy()
# Gradio UI: a single-choice radio of categories -> one generated image.
# The first category is pre-selected so that pressing Submit right away does
# not send None (an empty selection) to generate().
interface = gr.Interface(
    fn=generate,
    inputs=gr.Radio(
        list(label_to_index),
        value=next(iter(label_to_index)),
        label="Pick one:",
    ),
    outputs=gr.Image(label="Generated Image"),
    title="DDPM Image Generator",
    description="Select a category to generate an image",
)

if __name__ == "__main__":
    # Start the Gradio web server only when run as a script, not on import.
    interface.launch()