Jasmeet Singh commited on
Commit
ac62aed
·
verified ·
1 Parent(s): 0f78d89

Update generationPipeline.py

Browse files
Files changed (1) hide show
  1. generationPipeline.py +3 -1
generationPipeline.py CHANGED
@@ -3,6 +3,7 @@ import torch.nn as nn
3
  import numpy as np
4
  from sampler import DDPMSampler
5
  from tqdm import tqdm
 
6
 
7
 
8
  WIDTH = 512
@@ -82,7 +83,8 @@ def generate(
82
 
83
  latents_shape = (1, 4, LATENTS_HEIGHT, LATENTS_WIDTH)
84
 
85
- if input_image.any():
 
86
  encoder = models["encoder"]
87
  encoder.to(device)
88
 
 
3
  import numpy as np
4
  from sampler import DDPMSampler
5
  from tqdm import tqdm
6
+ from PIL import Image
7
 
8
 
9
  WIDTH = 512
 
83
 
84
  latents_shape = (1, 4, LATENTS_HEIGHT, LATENTS_WIDTH)
85
 
86
+ if input_image:
87
+ input_image = Image.open(input_image)
88
  encoder = models["encoder"]
89
  encoder.to(device)
90