peacock / thumbnail_generator.py
jaspermurphy1989's picture
Create thumbnail_generator.py
e0a52b4 verified
raw
history blame contribute delete
506 Bytes
import torch
from torchvision.utils import save_image
from PIL import Image
import io
class ThumbnailGenerator:
    """Wrap a saved PyTorch model object and sample thumbnails from it.

    NOTE(review): the model architecture is not visible here; this class
    assumes the loaded model maps a ``(1, latent_dim)`` latent batch to an
    image tensor — confirm against the checkpoint.
    """

    # Length of the latent vector sampled by ``generate``. Kept as a class
    # attribute too, so instances created without ``__init__`` still work.
    latent_dim = 512

    def __init__(self, model_path="pytorch_model.bin", latent_dim=512):
        """Load the model from *model_path* onto the CPU and put it in eval mode.

        Args:
            model_path: File written by ``torch.save`` holding a whole model
                object (not just a ``state_dict``).
            latent_dim: Length of the random latent vector ``generate`` draws.

        Raises:
            TypeError: If the file does not contain a callable object with an
                ``eval`` method (e.g. it holds a bare ``state_dict``).
        """
        # SECURITY: torch.load unpickles the file and can execute arbitrary
        # code; only load checkpoints from trusted sources.
        # NOTE(review): torch >= 2.6 defaults to weights_only=True, which
        # refuses to unpickle whole nn.Module objects — confirm the torch
        # version and the checkpoint format.
        model = torch.load(model_path, map_location=torch.device("cpu"))
        if not callable(model) or not hasattr(model, "eval"):
            # A state_dict loads as a plain mapping; calling .eval() on it
            # would fail with an opaque AttributeError.
            raise TypeError(
                f"{model_path!r} does not contain a callable model "
                f"(got {type(model).__name__}); a state_dict must be loaded "
                "into an instantiated architecture first"
            )
        self.model = model
        self.model.eval()
        self.latent_dim = latent_dim

    def generate(self, seed=None):
        """Run the model on one latent input and return its output.

        Args:
            seed: ``None`` draws a fresh standard-normal latent of shape
                ``(1, latent_dim)``. An integer scalar (Python int, NumPy
                integer or 0-d integer tensor) seeds a private RNG, so the
                same seed always yields the same latent. Any other value is
                used as the latent itself (converted to float32). Inputs with
                0 or 1 dimensions get a batch axis prepended; inputs that
                already have one are used as-is.

        Returns:
            The model output after ``squeeze(0)``, which drops the leading
            axis only when it has size 1.
        """
        if seed is None:
            z = torch.randn(1, self.latent_dim)
        else:
            # Copy tensor inputs explicitly: torch.tensor(tensor) warns, and
            # the copy keeps the model from aliasing the caller's tensor.
            if isinstance(seed, torch.Tensor):
                value = seed.detach().clone()
            else:
                value = torch.tensor(seed)
            is_int_seed = (
                value.dim() == 0
                and not value.is_floating_point()
                and not value.is_complex()
                and value.dtype != torch.bool
            )
            if is_int_seed:
                # A private generator keeps the global RNG state untouched.
                rng = torch.Generator().manual_seed(int(value))
                z = torch.randn(1, self.latent_dim, generator=rng)
            else:
                z = value.float()
                if z.dim() <= 1:
                    z = z.unsqueeze(0)
        with torch.no_grad():
            thumbnail = self.model(z)
        return thumbnail.squeeze(0)