Spaces:
Paused
Paused
| import torch | |
| import torch.optim as optim | |
| import torch.nn as nn | |
| from tqdm import tqdm | |
| import numpy as np | |
| from PIL import Image | |
| import requests | |
| import io | |
| from unet import Unet, ConditionalUnet | |
| from diffusion import GaussianDiffusion, DiffusionImageAPI | |
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
def inference1(url="https://picsum.photos/120/80", timeout=10):
    """Download an image from the web and return it as a PIL image.

    Args:
        url: Address of the image to fetch. The default asks picsum.photos
            for a random 120x80 picture.
        timeout: Seconds to wait for the server before giving up, so the call
            can never hang indefinitely.

    Returns:
        A ``PIL.Image.Image`` decoded from the response body.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    response = requests.get(url, timeout=timeout)
    # Fail loudly on 4xx/5xx instead of handing an HTML error page to PIL,
    # which would surface as a confusing UnidentifiedImageError.
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))
def _build_diffusion(checkpoint):
    """Build the conditional U-Net, load its weights, and wrap it in a sampler."""
    model = Unet(
        image_channels=3,
        dropout=0.1,
    )
    model = ConditionalUnet(
        unet=model,
        num_classes=13,
    )
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    diffusion = GaussianDiffusion(
        model=model,
        noise_steps=1000,
        beta_0=1e-4,
        beta_T=0.02,
        image_size=(192, 128),
    )
    return model, diffusion


def _prepare_x0(diffusion, x0):
    """Turn an optional starting image into a batched tensor on ``device``."""
    if x0 is None:
        return None
    x0 = diffusion.normalize_image(x0)
    # Assumes normalize_image returns a channels-last (H, W, C) tensor: move the
    # last axis first, then add a leading batch axis of size 1.
    x0 = x0.permute(2, 0, 1).unsqueeze(0)
    # The model is moved to `device` below; keep the input on the same device.
    return x0.to(device)


def _save_versions_gif(images, gif_path):
    """Write PIL frames to ``gif_path`` as a looping GIF (no-op if empty)."""
    from pathlib import Path

    if not images:
        # Nothing to animate; indexing images[0] would raise IndexError.
        return
    # Create the output directory so saving does not fail on a fresh checkout.
    Path(gif_path).parent.mkdir(parents=True, exist_ok=True)
    images[0].save(
        gif_path,
        save_all=True,
        append_images=images[1:],
        duration=100,
        loop=0,
    )


def inference(cond, x0=None, gif=False, callback=None, *,
              checkpoint="./model_final2.pt",
              gif_path="./gif_output/versions.gif"):
    """Sample one image from the conditional diffusion model.

    Args:
        cond: Conditioning passed through unchanged to ``diffusion.sample`` as
            ``cond=``. The wrapped ConditionalUnet is built with
            ``num_classes=13``; the exact expected format is defined by
            GaussianDiffusion (TODO confirm).
        x0: Optional starting image, run through ``diffusion.normalize_image``
            and forwarded to ``diffusion.sample`` as ``x0=``.
        gif: If true, also write every intermediate version returned by
            ``diffusion.sample`` to ``gif_path`` as an animated GIF.
        callback: Forwarded to ``diffusion.sample`` as ``cb=``.
        checkpoint: Path of the state dict to load into the model.
        gif_path: Destination of the GIF when ``gif`` is true; its parent
            directory is created if missing.

    Returns:
        The final sample converted by ``DiffusionImageAPI.tensor_to_image``.
    """
    model, diffusion = _build_diffusion(checkpoint)
    x0 = _prepare_x0(diffusion, x0)
    model.to(device)
    diffusion.to(device)
    # The U-Net is built with dropout=0.1; switch dropout off for inference.
    model.eval()
    image_api = DiffusionImageAPI(diffusion)
    # Inference only: no autograd graph is needed across the sampling steps.
    with torch.no_grad():
        new_images, versions = diffusion.sample(1, cond=cond, x0=x0, cb=callback)
    if gif:
        frames = [image_api.tensor_to_image(step.squeeze(0)) for step in versions]
        _save_versions_gif(frames, gif_path)
    return image_api.tensor_to_image(new_images.squeeze(0))
# Script entry point: run one sampling pass and display the result.
if __name__ == "__main__":
    # NOTE(review): `inference` takes a required positional `cond` argument, so
    # this bare call raises TypeError as written. Supply a class condition here
    # (the ConditionalUnet is built with num_classes=13; the exact format expected
    # by GaussianDiffusion.sample is not visible in this file — TODO confirm), or
    # call `inference1()` if the intent was only to fetch and show a sample image.
    inference().show()