"""Text-to-image generation with FLUX.1-dev.

Loads the diffusers pipeline once at import time, ahead-of-time (AoT) compiles
its transformer when a GPU is available, and exposes ``generate_image``.
"""

import gradio as gr
import spaces
import torch
from diffusers import DiffusionPipeline

# Configuration
MODEL_ID = "black-forest-labs/FLUX.1-dev"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Set dtype based on device for compatibility.
dtype = torch.bfloat16 if DEVICE == "cuda" else torch.float32

# Output resolution. torch.export below is called without dynamic shapes, so the
# compiled transformer is specialised to the shapes seen during capture; capture
# and generation must therefore use the same height/width.
HEIGHT = 1024
WIDTH = 1024

# Load pipeline with the appropriate dtype for the device. diffusers reads the
# weight dtype from the `torch_dtype` keyword.
pipe = DiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=dtype)
pipe.to(DEVICE)


@spaces.GPU(duration=1500)
def compile_transformer():
    """Ahead-of-time compile ``pipe.transformer``.

    Runs the pipeline once under ``spaces.aoti_capture`` to record the
    transformer's call arguments, exports the transformer with those
    arguments, and compiles the exported program.

    Returns:
        The result of ``spaces.aoti_compile``, to be passed to
        ``spaces.aoti_apply``.
    """
    with spaces.aoti_capture(pipe.transformer) as call:
        # Capture at the same resolution generate_image uses.
        pipe("test prompt", height=HEIGHT, width=WIDTH)
    exported = torch.export.export(
        pipe.transformer,
        args=call.args,
        kwargs=call.kwargs,
    )
    return spaces.aoti_compile(exported)


# AoT compilation for faster inference requires a GPU; skip it on CPU.
if DEVICE == "cuda":
    compiled_transformer = compile_transformer()
    spaces.aoti_apply(compiled_transformer, pipe.transformer)


@spaces.GPU
def generate_image(
    prompt, negative_prompt="", num_inference_steps=20, guidance_scale=7.5
):
    """Generate an image from a text prompt using FLUX.

    Args:
        prompt (str): The text prompt for image generation.
        negative_prompt (str): Accepted for interface compatibility only; it is
            not passed to the pipeline.
        num_inference_steps (int): Number of denoising steps (coerced with
            ``int``).
        guidance_scale (float): Guidance scale passed to the pipeline (coerced
            with ``float``). NOTE(review): FLUX.1-dev uses distilled guidance
            and its pipeline default is 3.5, so 7.5 may be high — confirm.

    Returns:
        PIL.Image: Generated image of size WIDTH x HEIGHT.

    Raises:
        gr.Error: If the pipeline call (or argument coercion) fails.
    """
    try:
        result = pipe(
            prompt=prompt,
            num_inference_steps=int(num_inference_steps),
            guidance_scale=float(guidance_scale),
            height=HEIGHT,
            width=WIDTH,
        )
    except Exception as e:
        # Top-level UI boundary: surface the failure as a Gradio error while
        # keeping the original exception chained for logs.
        raise gr.Error(f"Generation failed: {e}") from e
    return result.images[0]