jaspermurphy1989 commited on
Commit
e0a52b4
·
verified ·
1 Parent(s): a6b006d

Create thumbnail_generator.py

Browse files
Files changed (1) hide show
  1. thumbnail_generator.py +15 -0
thumbnail_generator.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision.utils import save_image
3
+ from PIL import Image
4
+ import io
5
+
6
+ class ThumbnailGenerator:
7
+ def __init__(self, model_path="pytorch_model.bin"):
8
+ self.model = torch.load(model_path, map_location=torch.device("cpu"))
9
+ self.model.eval()
10
+
11
+ def generate(self, seed=None):
12
+ z = torch.randn(1, 512) if seed is None else torch.tensor(seed).float().unsqueeze(0)
13
+ with torch.no_grad():
14
+ thumbnail = self.model(z)
15
+ return thumbnail.squeeze(0)