JBlitzar commited on
Commit
6b36151
·
1 Parent(s): 3802079
Files changed (1) hide show
  1. pipeline.py +4 -4
pipeline.py CHANGED
@@ -12,7 +12,7 @@ from bert_vectorize import vectorize_text_with_bert
12
  from logger import save_grid_with_label
13
 
14
  class TextToImagePipeline(Pipeline):
15
- def __init__(self, model_dir: str, device: str = "cpu"):
16
  # Initialize model, diffusion manager, and set up environment
17
  self.device = device
18
  self.model_dir = model_dir
@@ -29,9 +29,9 @@ class TextToImagePipeline(Pipeline):
29
  self.wrapper = DiffusionManager(self.net, device=self.device, noise_steps=1000)
30
  self.wrapper.set_schedule(Schedule.LINEAR)
31
 
32
- def __call__(self, prompt: str, amt: int = 1):
33
- # Generate images based on the prompt
34
- return self.generate_sample_save_images(prompt, amt)
35
 
36
  def generate_sample_save_images(self, prompt: str, amt: int = 1):
37
  # Prepare the output path
 
12
  from logger import save_grid_with_label
13
 
14
  class TextToImagePipeline(Pipeline):
15
+ def __init__(self, model_dir: str = "runs/run_3_jxa", device: str = "cpu"):
16
  # Initialize model, diffusion manager, and set up environment
17
  self.device = device
18
  self.model_dir = model_dir
 
29
  self.wrapper = DiffusionManager(self.net, device=self.device, noise_steps=1000)
30
  self.wrapper.set_schedule(Schedule.LINEAR)
31
 
32
+ def __call__(self, inputs):
33
+
34
+ return self.generate_sample_save_images(inputs, 8)
35
 
36
  def generate_sample_save_images(self, prompt: str, amt: int = 1):
37
  # Prepare the output path