manbeast3b committed on
Commit
cf17da0
·
verified ·
1 Parent(s): 604d61f

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +9 -7
src/pipeline.py CHANGED
@@ -19,11 +19,10 @@ os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
19
  torch.backends.cuda.matmul.allow_tf32 = True
20
  torch.backends.cudnn.enabled = True
21
  torch.backends.cudnn.benchmark = True
22
-
23
  ckpt_id = "black-forest-labs/FLUX.1-schnell"
24
  ckpt_revision = "741f7c3ce8b383c54771c7003378a50191e9efe9"
25
-
26
-
27
  Pipeline = None
28
  def empty_cache():
29
  gc.collect()
@@ -52,15 +51,18 @@ def load_pipeline() -> Pipeline:
52
  ).to(device)
53
  quantize_(pipeline.vae, int8_weight_only())
54
 
55
- pipeline(prompt="imprisonable, forechamber, demagogic, monotropic, blandiloquious, blechnoid", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
56
-
57
  empty_cache()
 
58
  return pipeline
59
 
60
-
61
  @torch.no_grad()
62
  def infer(request: TextToImageRequest, pipeline: Pipeline, generator: Generator) -> Image:
63
-
 
 
 
 
64
  image=pipeline(request.prompt,
65
  generator=generator,
66
  guidance_scale=0.0,
 
19
  torch.backends.cuda.matmul.allow_tf32 = True
20
  torch.backends.cudnn.enabled = True
21
  torch.backends.cudnn.benchmark = True
22
+ torch.backends.cudnn.deterministic = False
23
  ckpt_id = "black-forest-labs/FLUX.1-schnell"
24
  ckpt_revision = "741f7c3ce8b383c54771c7003378a50191e9efe9"
25
+ sample = 0
 
26
  Pipeline = None
27
  def empty_cache():
28
  gc.collect()
 
51
  ).to(device)
52
  quantize_(pipeline.vae, int8_weight_only())
53
 
54
+ pipeline(prompt="imprisonable, forechamber, demagogic, monotropic, blandiloquious, blechnoid, blechnoid, blechnoid, blechnoid", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
 
55
  empty_cache()
56
+ pipeline(prompt="", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
57
  return pipeline
58
 
 
59
  @torch.no_grad()
60
  def infer(request: TextToImageRequest, pipeline: Pipeline, generator: Generator) -> Image:
61
+ global sample
62
+ if not sample:
63
+ sample = 1
64
+ empty_cache()
65
+
66
  image=pipeline(request.prompt,
67
  generator=generator,
68
  guidance_scale=0.0,