# Streamlit demo: sample sushi images from a pretrained denoising diffusion model.
| from denoising_diffusion_pytorch import Unet, GaussianDiffusion | |
| import streamlit as st | |
| import torch | |
def get_model():
    """Build the diffusion model and load pretrained weights for CPU inference.

    Returns:
        The GaussianDiffusion wrapper in eval mode, ready for sampling.
    """
    backbone = Unet(dim=64, dim_mults=(1, 2, 4, 8))
    diffusion = GaussianDiffusion(
        backbone,
        image_size=64,
        timesteps=1000,           # number of diffusion steps
        sampling_timesteps=250,   # DDIM steps for faster inference
        loss_type='l1',           # L1 or L2
        p2_loss_weight_gamma=1.,
    )
    # Checkpoint holds a plain state_dict; map to CPU so no GPU is required.
    state = torch.load("./model-final.pt", map_location="cpu")
    diffusion.load_state_dict(state)
    diffusion.eval()
    return diffusion
def scale_to_255(x):
    """Map a float array from [-1, 1] to uint8 pixel values in [0, 255].

    Args:
        x: numpy array of floats, nominally in [-1, 1].

    Returns:
        uint8 array of the same shape with values in [0, 255].
    """
    scaled = (x + 1) / 2 * 255
    # Clamp before the cast: without this, out-of-range model outputs wrap
    # modulo 256 on the uint8 conversion and show up as speckle artifacts.
    return scaled.clip(0, 255).astype('uint8')
| if __name__ == "__main__": | |
| st.title("Sushi Diffusion") | |
| st.text("The generation process takes about 10 mins.") | |
| st.text("If you don't want to wait, please visit: https://thissushidoesnotexist.com/") | |
| model = get_model() | |
| st.text('Press the button below to generate sushi!') | |
| if st.button('🍣'): | |
| bar = st.progress(0) | |
| img = torch.randn((1,3,64,64), device="cpu") | |
| for t in reversed(range(0, model.num_timesteps)): | |
| img, _ = model.p_sample(img, t, None) | |
| bar.progress((model.num_timesteps-t) / model.num_timesteps) | |
| img = scale_to_255(img.squeeze().numpy().transpose(1,2,0)) | |
| st.image(img, caption='This sushi does not exist.') | |