JBlitzar commited on
Commit ·
6b36151
1
Parent(s): 3802079
- pipeline.py +4 -4
pipeline.py
CHANGED
|
@@ -12,7 +12,7 @@ from bert_vectorize import vectorize_text_with_bert
|
|
| 12 |
from logger import save_grid_with_label
|
| 13 |
|
| 14 |
class TextToImagePipeline(Pipeline):
|
| 15 |
-
def __init__(self, model_dir: str, device: str = "cpu"):
|
| 16 |
# Initialize model, diffusion manager, and set up environment
|
| 17 |
self.device = device
|
| 18 |
self.model_dir = model_dir
|
|
@@ -29,9 +29,9 @@ class TextToImagePipeline(Pipeline):
|
|
| 29 |
self.wrapper = DiffusionManager(self.net, device=self.device, noise_steps=1000)
|
| 30 |
self.wrapper.set_schedule(Schedule.LINEAR)
|
| 31 |
|
| 32 |
-
def __call__(self,
|
| 33 |
-
|
| 34 |
-
return self.generate_sample_save_images(
|
| 35 |
|
| 36 |
def generate_sample_save_images(self, prompt: str, amt: int = 1):
|
| 37 |
# Prepare the output path
|
|
|
|
| 12 |
from logger import save_grid_with_label
|
| 13 |
|
| 14 |
class TextToImagePipeline(Pipeline):
|
| 15 |
+
def __init__(self, model_dir: str = "runs/run_3_jxa", device: str = "cpu"):
|
| 16 |
# Initialize model, diffusion manager, and set up environment
|
| 17 |
self.device = device
|
| 18 |
self.model_dir = model_dir
|
|
|
|
| 29 |
self.wrapper = DiffusionManager(self.net, device=self.device, noise_steps=1000)
|
| 30 |
self.wrapper.set_schedule(Schedule.LINEAR)
|
| 31 |
|
| 32 |
+
def __call__(self, inputs):
|
| 33 |
+
|
| 34 |
+
return self.generate_sample_save_images(inputs, 8)
|
| 35 |
|
| 36 |
def generate_sample_save_images(self, prompt: str, amt: int = 1):
|
| 37 |
# Prepare the output path
|