manbeast3b committed on
Commit
3c526ec
·
verified ·
1 Parent(s): bd0bb1c

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +5 -22
src/pipeline.py CHANGED
@@ -1,4 +1,4 @@
1
- from diffusers import AutoencoderKL, AutoencoderTiny
2
  from diffusers.image_processor import VaeImageProcessor
3
  import torch
4
  import torch._dynamo
@@ -7,7 +7,9 @@ from PIL.Image import Image
7
  from pipelines.models import TextToImageRequest
8
  from torch import Generator
9
  from diffusers import FluxPipeline
10
- from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only
 
 
11
 
12
  Pipeline = None
13
  MODEL_ID = "black-forest-labs/FLUX.1-schnell"
@@ -20,11 +22,7 @@ def clear():
20
 
21
  def load_pipeline() -> Pipeline:
22
  clear()
23
- # vae = AutoencoderKL.from_pretrained(
24
- # MODEL_ID, subfolder="vae", torch_dtype=torch.bfloat16
25
- # )
26
  vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=DTYPE)
27
- # quantize_(vae, fpx_weight_only(3, 2))
28
  quantize_(vae, int8_weight_only())
29
  pipeline = FluxPipeline.from_pretrained(MODEL_ID,vae=vae,
30
  torch_dtype=DTYPE)
@@ -43,18 +41,6 @@ def load_pipeline() -> Pipeline:
43
  pipeline(prompt="unpervaded, unencumber, froggish, groundneedle, transnatural, fatherhood, outjump, cinerator", width=1024, height=1024, guidance_scale=0.1, num_inference_steps=4, max_sequence_length=256)
44
  return pipeline
45
 
46
- # sample = True
47
- # @torch.inference_mode()
48
- # def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
49
- # global sample
50
- # if sample:
51
- # clear()
52
- # sample = None
53
- # # torch.cuda.reset_peak_memory_stats()
54
- # generator = Generator("cuda").manual_seed(request.seed)
55
- # image=pipeline(request.prompt,generator=generator, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256, height=request.height, width=request.width, output_type="pil").images[0]
56
- # return(image)
57
-
58
  sample = True
59
  @torch.inference_mode()
60
  def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
@@ -62,9 +48,6 @@ def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
62
  if sample:
63
  clear()
64
  sample = None
65
- # torch.cuda.reset_peak_memory_stats()
66
  generator = Generator("cuda").manual_seed(request.seed)
67
- image = None
68
- with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
69
- image=pipeline(request.prompt,generator=generator, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256, height=request.height, width=request.width, output_type="pil").images[0]
70
  return(image)
 
1
+ from diffusers import AutoencoderTiny
2
  from diffusers.image_processor import VaeImageProcessor
3
  import torch
4
  import torch._dynamo
 
7
  from pipelines.models import TextToImageRequest
8
  from torch import Generator
9
  from diffusers import FluxPipeline
10
+ from torchao.quantization import quantize_, int8_weight_only
11
+ import os
12
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:False,garbage_collection_threshold:0.02"
13
 
14
  Pipeline = None
15
  MODEL_ID = "black-forest-labs/FLUX.1-schnell"
 
22
 
23
  def load_pipeline() -> Pipeline:
24
  clear()
 
 
 
25
  vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=DTYPE)
 
26
  quantize_(vae, int8_weight_only())
27
  pipeline = FluxPipeline.from_pretrained(MODEL_ID,vae=vae,
28
  torch_dtype=DTYPE)
 
41
  pipeline(prompt="unpervaded, unencumber, froggish, groundneedle, transnatural, fatherhood, outjump, cinerator", width=1024, height=1024, guidance_scale=0.1, num_inference_steps=4, max_sequence_length=256)
42
  return pipeline
43
 
 
 
 
 
 
 
 
 
 
 
 
 
44
# One-shot flag: the first call to infer() clears CUDA caches before
# generating; it is set to None (falsy) afterwards so later calls skip it.
sample = True


@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate one image for *request* with the warmed-up Flux pipeline.

    Args:
        request: carries the prompt, RNG seed, and output height/width.
        pipeline: the FluxPipeline produced by load_pipeline().

    Returns:
        The first PIL image produced by the pipeline.
    """
    global sample
    if sample:
        # First request only: free leftover allocations from warm-up.
        clear()
        sample = None
    # Seed a CUDA RNG from the request so generation is deterministic.
    generator = Generator("cuda").manual_seed(request.seed)
    # guidance_scale=0.0 / 4 steps matches the schnell distilled model's
    # intended fast-inference settings used during warm-up.
    image = pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil",
    ).images[0]
    return image