manbeast3b commited on
Commit
813fdcc
·
verified ·
1 Parent(s): 2ffbe90

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +2 -2
src/pipeline.py CHANGED
@@ -37,7 +37,7 @@ def load_pipeline() -> Pipeline:
37
  pipeline.enable_sequential_cpu_offload()
38
  for _ in range(1):
39
  clear()
40
- with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
41
  pipeline(prompt="unpervaded, unencumber, froggish, groundneedle, transnatural, fatherhood, outjump, cinerator", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
42
  return pipeline
43
 
@@ -52,6 +52,6 @@ def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
52
  # torch.cuda.reset_peak_memory_stats()
53
  generator = Generator("cuda").manual_seed(request.seed)
54
  image = None
55
- with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
56
  image=pipeline(request.prompt,generator=generator, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256, height=request.height, width=request.width, output_type="pil").images[0]
57
  return(image)
 
37
  pipeline.enable_sequential_cpu_offload()
38
  for _ in range(1):
39
  clear()
40
+ with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=True):
41
  pipeline(prompt="unpervaded, unencumber, froggish, groundneedle, transnatural, fatherhood, outjump, cinerator", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
42
  return pipeline
43
 
 
52
  # torch.cuda.reset_peak_memory_stats()
53
  generator = Generator("cuda").manual_seed(request.seed)
54
  image = None
55
+ with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=True):
56
  image=pipeline(request.prompt,generator=generator, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256, height=request.height, width=request.width, output_type="pil").images[0]
57
  return(image)