Update src/pipeline.py
Browse files- src/pipeline.py +5 -13
src/pipeline.py
CHANGED
|
@@ -45,16 +45,8 @@ def load_pipeline():
|
|
| 45 |
|
| 46 |
|
| 47 |
@torch.inference_mode()
|
| 48 |
-
def infer(request: TextToImageRequest, pipeline):
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
guidance_scale=0.0, # Match original repos except -25, test higher values
|
| 54 |
-
num_inference_steps=4,
|
| 55 |
-
max_sequence_length=256,
|
| 56 |
-
height=request.height,
|
| 57 |
-
width=request.width,
|
| 58 |
-
output_type="pil"
|
| 59 |
-
).images[0]
|
| 60 |
-
return image
|
|
|
|
| 45 |
|
| 46 |
|
| 47 |
@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate one image from a text-to-image request.

    Args:
        request: Carries the prompt, RNG seed, and output height/width.
        pipeline: A loaded diffusers-style pipeline callable.

    Returns:
        The first generated PIL image.
    """
    # Reset peak-VRAM counters so memory usage of this pass can be
    # measured afterwards — NOTE(review): assumes a caller reads
    # torch.cuda.max_memory_allocated(); confirm a consumer exists.
    torch.cuda.reset_peak_memory_stats()
    # Seed a CUDA generator from the request for reproducible output.
    generator = Generator("cuda").manual_seed(request.seed)
    image = pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,  # Match original repos except -25, test higher values
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil",
    ).images[0]
    return image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|