File size: 1,842 Bytes
4a45da4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03c88c6
4a45da4
03c88c6
4a45da4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler

from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
import torch
import torch._dynamo
import gc
from PIL import Image as img
from PIL.Image import Image
from pipelines.models import TextToImageRequest
from torch import Generator
import time
from diffusers import FluxTransformer2DModel, DiffusionPipeline
from torchao.quantization import quantize_, int8_weight_only
#from torchao.quantization import autoquant
# Alias used in the `load_pipeline` / `infer` annotations. It points at
# `DiffusionPipeline` (the class `load_pipeline` builds via `from_pretrained`)
# instead of `None`, so the annotations describe the real object type rather
# than evaluating to `None`.
Pipeline = DiffusionPipeline

# Model id passed to `DiffusionPipeline.from_pretrained` when loading weights.
ckpt_id = "black-forest-labs/FLUX.1-schnell"

def clear():
    """Free unreferenced memory and reset CUDA peak-memory statistics.

    Runs a full garbage-collection pass. When a CUDA device is available, it
    then returns cached-but-unused blocks from PyTorch's CUDA caching
    allocator and resets the peak-memory counters, so later peak readings
    start fresh. On machines without CUDA only the garbage collection runs.
    """
    gc.collect()
    if not torch.cuda.is_available():
        # The CUDA memory helpers would initialise CUDA and fail here.
        return
    torch.cuda.empty_cache()
    # `reset_max_memory_allocated` is a deprecated alias. It emits a
    # FutureWarning and then calls `reset_peak_memory_stats`, which resets
    # *all* peak stats. So this single call covers both.
    torch.cuda.reset_peak_memory_stats()

def load_pipeline() -> Pipeline:
    """Load the FLUX.1-schnell pipeline, enable CPU offload, and warm it up.

    The weights named by `ckpt_id` are loaded in bfloat16. Sequential CPU
    offload is enabled, and two throw-away 1024x1024 generations are run so
    that first-call overheads are paid before real requests arrive.

    Returns:
        The ready-to-use diffusers pipeline object.
    """
    clear()
    pipeline = DiffusionPipeline.from_pretrained(
        ckpt_id,
        torch_dtype=torch.bfloat16,
    )

    # Per the diffusers docs, this keeps component weights on the CPU and
    # moves each submodule to the GPU only while its forward pass runs. That
    # trades speed for a much lower peak GPU memory footprint.
    pipeline.enable_sequential_cpu_offload()
    # NOTE(review): oneDNN fusion is a TorchScript (CPU) setting. It is not
    # clear that it affects this eager-mode pipeline, but it is kept to
    # preserve existing global state.
    torch.jit.enable_onednn_fusion(True)

    # Warm-up passes use the same sampling settings as `infer` (guidance 0.0,
    # 4 steps, 256-token prompts). Memory is cleared before each one.
    for _ in range(2):
        clear()
        pipeline(
            prompt="testing testing 123",
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )

    return pipeline

# True until the first `infer` call, which runs `clear()` once and then
# resets this flag (to None) so later calls skip that cleanup.
sample = True


@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate one image for `request` with the loaded pipeline.

    On the very first call, Python/CUDA memory is cleared once. CUDA
    peak-memory statistics are reset before every generation; presumably
    this lets the caller measure per-request peak usage (TODO confirm).

    Args:
        request: Prompt, output size and optional seed for the generation.
        pipeline: The pipeline returned by `load_pipeline`.

    Returns:
        The first generated image, as a PIL image.
    """
    global sample
    if sample:
        clear()
        sample = None
    torch.cuda.reset_peak_memory_stats()

    # A seeded CUDA generator makes the output reproducible. Without a seed
    # (e.g. `seed=None`) we pass no generator, because `manual_seed(None)`
    # raises TypeError.
    generator = (
        Generator("cuda").manual_seed(request.seed)
        if request.seed is not None
        else None
    )
    image = pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil",
    ).images[0]
    return image