43 / Meissonic /inference_fp8.py
BryanW's picture
Upload folder using huggingface_hub
3d1c0e1 verified
import os
import sys
sys.path.append("./")
import torch
from src.transformer import Transformer2DModel
from src.pipeline import Pipeline
from src.scheduler import Scheduler
from transformers import (
CLIPTextModelWithProjection,
CLIPTokenizer,
)
from diffusers import VQModel
import time
import argparse
from torchao.quantization.quant_api import (
quantize_,
float8_weight_only, # FP8 weight-only quantization (activations are not quantized)
)
# Device that load_models() moves the assembled pipeline to.
device = 'cuda'
def get_quantization_method(method):
    """Look up the quantization config factory registered under *method*.

    Args:
        method: Name of the quantization scheme; only ``'fp8'`` is registered.

    Returns:
        A zero-argument callable that builds the torchao config when invoked,
        or ``None`` when *method* is not a registered name.
    """
    def _fp8_config():
        # Resolved lazily: the torchao config is only built when the caller
        # actually invokes the returned factory.
        return float8_weight_only()

    registry = {'fp8': _fp8_config}
    try:
        return registry[method]
    except KeyError:
        return None
def load_models(quantization_method=None):
    """Build the inference pipeline, optionally quantizing the transformer.

    The transformer, VQ model, tokenizer and scheduler are loaded from the
    "MeissonFlow/Meissonic" repository subfolders; the CLIP text encoder is
    loaded from "laion/CLIP-ViT-H-14-laion2B-s32B-b79K". The three models are
    requested in float16.

    Args:
        quantization_method: Name understood by ``get_quantization_method``.
            A falsy value skips quantization. An unrecognised name prints a
            message and the transformer is left unquantized.

    Returns:
        The ``Pipeline`` after it has been moved to the module-level ``device``.
    """
    repo_id = "MeissonFlow/Meissonic"
    half = torch.float16

    transformer = Transformer2DModel.from_pretrained(
        repo_id, subfolder="transformer", torch_dtype=half
    )
    vqvae = VQModel.from_pretrained(repo_id, subfolder="vqvae", torch_dtype=half)
    clip_text_encoder = CLIPTextModelWithProjection.from_pretrained(
        "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
        torch_dtype=half,
    )
    clip_tokenizer = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer")
    sampler = Scheduler.from_pretrained(repo_id, subfolder="scheduler")

    if quantization_method:
        make_config = get_quantization_method(quantization_method)
        if not make_config:
            print(f"Unsupported quantization method: {quantization_method}")
        else:
            # Only the transformer is quantized, in place; the VQ model and
            # text encoder keep their loaded dtype.
            quantize_(transformer, make_config())

    pipeline = Pipeline(
        vqvae,
        tokenizer=clip_tokenizer,
        text_encoder=clip_text_encoder,
        transformer=transformer,
        scheduler=sampler,
    )
    return pipeline.to(device)
def run_inference(pipe, prompt, negative_prompt, resolution, cfg, steps):
    """Run *pipe* once and return the first image of its result.

    Args:
        pipe: Callable pipeline; its return value must expose ``.images``.
        prompt: Text prompt passed as ``prompt``.
        negative_prompt: Text passed as ``negative_prompt``.
        resolution: Used for both ``height`` and ``width`` (square output).
        cfg: Passed as ``guidance_scale``.
        steps: Passed as ``num_inference_steps``.

    Returns:
        ``result.images[0]`` from the pipeline call.
    """
    call_kwargs = dict(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=resolution,
        width=resolution,
        guidance_scale=cfg,
        num_inference_steps=steps,
    )
    result = pipe(**call_kwargs)
    return result.images[0]
def main(quantization_method):
    """Benchmark image generation over a fixed prompt list.

    One 1024x1024 image is generated per prompt with 64 steps and a guidance
    scale of 9. Each image is saved under ``./output``. The per-image elapsed
    time and peak reserved CUDA memory are printed, followed by the total
    elapsed time and the average peak memory.

    Args:
        quantization_method: Forwarded to ``load_models``; also embedded in
            the output file names and in the summary line.
    """
    steps = 64
    CFG = 9
    resolution = 1024
    negative_prompts = "worst quality, low quality, low res, blurry, distortion, watermark, logo, signature, text, jpeg artifacts, signature, sketch, duplicate, ugly, identifying mark"
    prompts = [
        "Two actors are posing for a pictur with one wearing a black and white face paint.",
        "A large body of water with a rock in the middle and mountains in the background.",
        "A white and blue coffee mug with a picture of a man on it.",
        "The sun is setting over a city skyline with a river in the foreground.",
        "A black and white cat with blue eyes.",
        "Three boats in the ocean with a rainbow in the sky.",
        "A robot playing the piano.",
        "A cat wearing a hat.",
        "A dog in a jungle.",
    ]

    output_dir = "./output"
    os.makedirs(output_dir, exist_ok=True)

    pipe = load_models(quantization_method)

    # perf_counter() is monotonic and high-resolution, so measured intervals
    # cannot be skewed by system clock adjustments the way time.time() can.
    start_time = time.perf_counter()
    total_memory_used = 0
    for image_number, prompt in enumerate(prompts, start=1):
        # Reset the peak counters so the reading below covers this image only.
        torch.cuda.reset_peak_memory_stats()

        image_start_time = time.perf_counter()
        image = run_inference(pipe, prompt, negative_prompts, resolution, CFG, steps)
        image_elapsed = time.perf_counter() - image_start_time

        # The first 10 characters of the prompt serve as the file name stem.
        image.save(os.path.join(output_dir, f"{prompt[:10]}_{resolution}_{steps}_{CFG}_{quantization_method}.png"))

        memory_used = torch.cuda.max_memory_reserved() / (1024 ** 3)  # Convert to GB
        total_memory_used += memory_used

        print(f"Image {image_number} time: {image_elapsed:.2f} seconds")
        print(f"Image {image_number} max memory used: {memory_used:.2f} GB")

    # Note: the total spans the whole loop, so it includes image saving time.
    total_time = time.perf_counter() - start_time
    avg_memory_used = total_memory_used / len(prompts)
    print(f"Total inference time ({quantization_method}): {total_time:.2f} seconds")
    print(f"Average memory used per image: {avg_memory_used:.2f} GB")
if __name__ == "__main__":
    # Command-line entry point: --quantization is optional and, when given,
    # must be one of the listed choices. Its value (or None) goes to main().
    cli = argparse.ArgumentParser(
        description="Run inference with specified quantization method.",
    )
    cli.add_argument(
        "--quantization",
        type=str,
        choices=['fp8'],
        help="Quantization method to use",
    )
    main(cli.parse_args().quantization)