manbeast3b commited on
Commit
b842cf8
·
verified ·
1 Parent(s): d7a1c7a

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +5 -13
src/pipeline.py CHANGED
@@ -45,16 +45,8 @@ def load_pipeline():
45
 
46
 
47
  @torch.inference_mode()
48
- def infer(request: TextToImageRequest, pipeline):
49
- generator = torch.Generator("cuda").manual_seed(request.seed)
50
- image = pipeline(
51
- request.prompt,
52
- generator=generator,
53
- guidance_scale=0.0, # Match original repos except -25, test higher values
54
- num_inference_steps=4,
55
- max_sequence_length=256,
56
- height=request.height,
57
- width=request.width,
58
- output_type="pil"
59
- ).images[0]
60
- return image
 
45
 
46
 
47
  @torch.inference_mode()
48
+ def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
49
+ torch.cuda.reset_peak_memory_stats()
50
+ generator = Generator("cuda").manual_seed(request.seed)
51
+ image=pipeline(request.prompt,generator=generator, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256, height=request.height, width=request.width, output_type="pil").images[0]
52
+ return(image)