# bright-app-28 / models.py
# Last commit: 75a470f (AiCoderv2) — "Update Gradio app with multiple files"
import torch
from diffusers import DiffusionPipeline
import spaces
# Configuration
MODEL_ID = 'black-forest-labs/FLUX.1-dev'
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Precision: bfloat16 on GPU, float32 on CPU for compatibility.
dtype = torch.bfloat16 if DEVICE == "cuda" else torch.float32

# Load the pipeline in the chosen precision. diffusers' `from_pretrained`
# takes the precision through the documented `torch_dtype` keyword; a bare
# `dtype=` keyword may not be applied, which would leave the weights in
# float32 (roughly doubling memory versus bfloat16).
pipe = DiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=dtype)
pipe.to(DEVICE)
# Ahead-of-time (AoT) compilation of the transformer for faster inference.
# This step needs a GPU; `duration=1500` requests an extended GPU time
# budget for the export + compile.
@spaces.GPU(duration=1500)
def compile_transformer():
    """Export and AoT-compile the pipeline's transformer.

    A single dummy generation runs inside `spaces.aoti_capture`, which records
    the positional and keyword arguments the transformer is invoked with.
    Those recorded inputs are used to trace the transformer with
    `torch.export.export`, and the exported program is compiled with
    `spaces.aoti_compile`.

    Returns:
        The result of `spaces.aoti_compile`, intended to be passed to
        `spaces.aoti_apply` together with `pipe.transformer`.
    """
    transformer = pipe.transformer

    # Record one real call to the transformer so export sees concrete inputs.
    with spaces.aoti_capture(transformer) as captured_call:
        pipe("test prompt")

    exported_program = torch.export.export(
        transformer,
        args=captured_call.args,
        kwargs=captured_call.kwargs,
    )
    return spaces.aoti_compile(exported_program)
# Apply the compiled model. AoT compilation requires a GPU, so it is only
# attempted when the pipeline was placed on CUDA; on a CPU-only host the
# eager transformer is used unchanged and `compiled_transformer` stays None.
compiled_transformer = None
if DEVICE == "cuda":
    compiled_transformer = compile_transformer()
    spaces.aoti_apply(compiled_transformer, pipe.transformer)
@spaces.GPU
def generate_image(prompt, negative_prompt="", num_inference_steps=20, guidance_scale=7.5):
    """
    Generate a 1024x1024 image from a text prompt using FLUX.

    Args:
        prompt (str): The text prompt for image generation.
        negative_prompt (str): Accepted for interface compatibility but not
            forwarded to the pipeline (not used in FLUX).
        num_inference_steps (int): Number of denoising steps; coerced with
            `int()` so float-valued inputs are accepted.
        guidance_scale (float): Guidance scale forwarded to the pipeline;
            coerced with `float()`.

    Returns:
        PIL.Image: The first image of the pipeline's output.

    Raises:
        gradio.Error: If argument coercion or the pipeline call fails; the
            original exception is chained as `__cause__`.
    """
    try:
        result = pipe(
            prompt=prompt,
            num_inference_steps=int(num_inference_steps),
            guidance_scale=float(guidance_scale),
            # Fixed output resolution. NOTE(review): presumably this must match
            # the shapes seen when the transformer was AoT-exported (torch.export
            # traces static shapes unless dynamic_shapes is given) — confirm
            # before making height/width configurable.
            height=1024,
            width=1024,
        )
        return result.images[0]
    except Exception as e:
        # Imported locally because this module has no top-level gradio import;
        # raising gr.Error makes Gradio show the message to the user.
        import gradio as gr

        raise gr.Error(f"Generation failed: {e}") from e