# Hugging Face Space: Qwen-Image text-to-image demo (Gradio) with optional LoRA support.
| import os | |
| import traceback | |
| import gradio as gr | |
| import numpy as np | |
| import random | |
| import spaces | |
| import torch | |
| from diffusers import DiffusionPipeline | |
# --------------------
# Global config
# --------------------
# bfloat16 on CUDA for speed/memory; full float32 on CPU, where half
# precision is unsupported or extremely slow for most diffusion ops.
# (The original used float16 on CPU, which breaks many CPU kernels.)
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
device = "cuda" if torch.cuda.is_available() else "cpu"

MAX_SEED = np.iinfo(np.int32).max  # 2**31 - 1, inclusive RNG seed ceiling
MAX_IMAGE_SIZE = 2048              # hard cap on either image dimension
MIN_IMAGE_SIZE = 64                # floor on either image dimension

# Optional: environment override for model name
MODEL_ID = os.getenv("QWEN_IMAGE_MODEL_ID", "Qwen/Qwen-Image")

# --------------------
# Pipeline load with guard
# --------------------
pipe = None             # DiffusionPipeline instance once loaded, else None
pipe_load_error = None  # human-readable load-failure message, else None
def _load_pipeline():
    """
    Lazily load the diffusion pipeline into the module-global `pipe`.

    Returns:
        The cached pipeline on success, or None on failure, in which case
        `pipe_load_error` holds a human-readable description of the error.
    """
    global pipe, pipe_load_error
    if pipe is not None:
        return pipe
    try:
        pipe = DiffusionPipeline.from_pretrained(
            MODEL_ID,
            torch_dtype=dtype,
        )
        pipe = pipe.to(device)
        # Only meaningful on CUDA; skip on CPU-only hosts instead of
        # touching the CUDA runtime unconditionally as the original did.
        if device == "cuda":
            torch.cuda.empty_cache()
    except Exception as e:
        pipe_load_error = f"Failed to load model '{MODEL_ID}': {repr(e)}"
        traceback.print_exc()
    return pipe


_load_pipeline()  # eager load on startup
| def _safe_clamp_size(width: int, height: int): | |
| """ | |
| Clamp image dimensions to safe boundaries and keep them multiples of 8/16. | |
| """ | |
| def _round_to_16(x): | |
| return int(max(MIN_IMAGE_SIZE, min(MAX_IMAGE_SIZE, x)) // 16 * 16) | |
| w = _round_to_16(width) | |
| h = _round_to_16(height) | |
| return w, h | |
| def _normalize_seed(seed, randomize_seed: bool): | |
| """ | |
| Normalize seed: if -1 or None, or randomize_seed=True, draw a fresh seed. | |
| """ | |
| if randomize_seed or seed is None or int(seed) < 0: | |
| return random.randint(0, MAX_SEED) | |
| return int(seed) % (MAX_SEED + 1) | |
| def _maybe_load_lora(lora_id: str, lora_scale: float): | |
| """ | |
| Load LoRA if provided. Returns (loaded: bool, message: str | None). | |
| """ | |
| if not lora_id or lora_id.strip() == "": | |
| return False, None | |
| lora_id = lora_id.strip() | |
| try: | |
| # Best-effort unload previous LoRA if supported | |
| if hasattr(pipe, "unload_lora_weights"): | |
| pipe.unload_lora_weights() | |
| if hasattr(pipe, "load_lora_weights"): | |
| pipe.load_lora_weights(lora_id, adapter_name="default", weight_name=None) | |
| else: | |
| return False, f"LoRA support not available in this pipeline. (Tried: {lora_id})" | |
| # Some pipelines support setting a scale attribute or passing scale in call. | |
| # Here we just report scale; the actual use depends on the underlying pipeline. | |
| return True, None | |
| except Exception as e: | |
| traceback.print_exc() | |
| return False, f"Failed to load LoRA '{lora_id}': {repr(e)}" | |
| def _maybe_unload_lora(): | |
| try: | |
| if hasattr(pipe, "unload_lora_weights"): | |
| pipe.unload_lora_weights() | |
| except Exception: | |
| traceback.print_exc() | |
# --------------------
# Inference function with robust error handling
# --------------------
def infer(
    prompt: str,
    seed: int = 42,
    randomize_seed: bool = False,
    width: int = 1024,
    height: int = 1024,
    guidance_scale: float = 4.0,
    num_inference_steps: int = 28,
    lora_id: str = None,
    lora_scale: float = 0.95,
    progress=gr.Progress(track_tqdm=True),
):
    """
    Main inference entrypoint for Gradio.

    Args:
        prompt: text prompt; must be non-empty.
        seed: RNG seed; -1/None (or randomize_seed=True) draws a fresh one.
        randomize_seed: when True, ignore `seed` and pick a random one.
        width, height: requested dimensions, clamped to safe multiples of 16.
        guidance_scale: forwarded to the pipeline as `true_cfg_scale`.
        num_inference_steps: number of denoising steps.
        lora_id: optional Hugging Face repo path of LoRA weights to apply.
        lora_scale: adapter strength forwarded to the LoRA loader.
        progress: Gradio progress tracker.

    Returns:
        (PIL.Image, seed) on success.

    Raises:
        gr.Error: on empty prompt, startup load failure, CUDA OOM, or any
            internal inference failure.
    """
    # Basic validation
    if not prompt or prompt.strip() == "":
        raise gr.Error("Prompt is empty. Please provide a text prompt.")

    # If model failed to load at startup, fail fast.
    # (Fix: the original split this f-string across a raw newline, which is
    # a syntax error; reconstructed as "\n" + implicit concatenation.)
    if pipe_load_error is not None:
        raise gr.Error(
            f"Model failed to load on startup: {pipe_load_error}\n"
            "Try restarting the Space or check the logs."
        )

    # Clamp dimensions and normalize the seed before seeding the generator.
    width, height = _safe_clamp_size(width, height)
    seed = _normalize_seed(seed, randomize_seed)
    generator = torch.Generator(device=device).manual_seed(seed)

    lora_loaded = False
    lora_warning = None
    try:
        # LoRA loading (best-effort: failure yields a warning, not a crash)
        if lora_id and lora_id.strip() != "":
            lora_loaded, lora_warning = _maybe_load_lora(lora_id, lora_scale)

        progress(0.1, desc="Running generation...")

        # Core pipeline call.
        # true_cfg_scale enables Qwen-style CFG. The pipeline's own
        # guidance_scale kwarg is intentionally NOT passed: the original
        # passed guidance_scale=None, which can break pipelines expecting
        # a float default.
        try:
            result = pipe(
                prompt=prompt,
                negative_prompt="",  # required even if empty for true_cfg_scale CFG
                width=width,
                height=height,
                num_inference_steps=int(num_inference_steps),
                generator=generator,
                true_cfg_scale=float(guidance_scale),
            )
        except torch.cuda.OutOfMemoryError:
            torch.cuda.empty_cache()
            raise gr.Error(
                "CUDA out-of-memory during generation. Try reducing image size or steps."
            )
        except Exception as e:
            traceback.print_exc()
            raise gr.Error(
                f"Inference failed with an internal error: {repr(e)}\n"
                "Please try again with smaller dimensions or fewer steps."
            )

        if not hasattr(result, "images") or not result.images:
            raise gr.Error(
                "Pipeline returned no images. This may indicate a model or configuration issue."
            )
        image = result.images[0]

        # Surface any LoRA warning as a non-fatal log line; Gradio shows
        # only the main outputs, so print() goes to the Space logs.
        if lora_warning:
            print(lora_warning)

        progress(1.0, desc="Done")
        return image, seed
    finally:
        # Always clean up LoRA weights and CUDA cache, even on errors.
        if lora_loaded:
            _maybe_unload_lora()
        if device == "cuda":
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass
| # -------------------- | |
| # UI | |
| # -------------------- | |
# Example prompts surfaced in the UI via gr.Examples.
examples = [
    "a tiny astronaut hatching from an egg on the moon",
    "a cat holding a sign that says hello world",
    "an anime illustration of a wiener schnitzel",
]
# Custom CSS: centers the main column and styles the generate button.
css = """
#col-container {
    margin: 0 auto;
    max-width: 960px;
}
.generate-btn {
    background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%) !important;
    border: none !important;
    color: white !important;
}
.generate-btn:hover {
    transform: translateY(-2px);
    box-shadow: 0 5px 15px rgba(0,0,0,0.2);
}
"""
# Gradio UI: left column = prompt + settings, right column = output + examples.
# NOTE(review): original indentation was lost in extraction; the with-block
# nesting below is reconstructed from the component order — confirm the
# intended layout of the Advanced Settings accordion against the live Space.
with gr.Blocks(css=css) as app:
    gr.HTML("<center><h1>Qwen Image with LoRA support</h1></center>")
    with gr.Column(elem_id="col-container"):
        with gr.Row():
            with gr.Column():
                # Main text prompt
                with gr.Row():
                    text_prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Enter a prompt here",
                        lines=3,
                        elem_id="prompt-text-input",
                    )
                # Optional LoRA repo path, forwarded to infer() as lora_id
                with gr.Row():
                    custom_lora = gr.Textbox(
                        label="Custom LoRA (optional)",
                        info="LoRA Hugging Face path (e.g. flymy-ai/qwen-image-realism-lora)",
                        placeholder="flymy-ai/qwen-image-realism-lora",
                    )
                # Tuning knobs, collapsed by default
                with gr.Row():
                    with gr.Accordion("Advanced Settings", open=False):
                        lora_scale = gr.Slider(
                            label="LoRA Scale",
                            minimum=0,
                            maximum=2,
                            step=0.01,
                            value=1,
                        )
                        with gr.Row():
                            width = gr.Slider(
                                label="Width",
                                value=1024,
                                minimum=MIN_IMAGE_SIZE,
                                maximum=MAX_IMAGE_SIZE,
                                step=16,
                            )
                            height = gr.Slider(
                                label="Height",
                                value=1024,
                                minimum=MIN_IMAGE_SIZE,
                                maximum=MAX_IMAGE_SIZE,
                                step=16,
                            )
                            # -1 sentinel means "pick a random seed"
                            seed = gr.Slider(
                                label="Seed (-1 = random)",
                                value=-1,
                                minimum=-1,
                                maximum=MAX_SEED,
                                step=1,
                            )
                            randomize_seed = gr.Checkbox(
                                label="Randomize seed each run",
                                value=True,
                            )
                        with gr.Row():
                            steps = gr.Slider(
                                label="Inference steps",
                                value=28,
                                minimum=1,
                                maximum=100,
                                step=1,
                            )
                            # Forwarded to infer() as guidance_scale
                            cfg = gr.Slider(
                                label="Guidance Scale (true_cfg_scale)",
                                value=4,
                                minimum=1,
                                maximum=20,
                                step=0.5,
                            )
                with gr.Row():
                    text_button = gr.Button(
                        "✨ Generate Image",
                        variant="primary",
                        elem_classes=["generate-btn"],
                    )
            with gr.Column():
                with gr.Row():
                    image_output = gr.Image(
                        type="pil",
                        label="Image Output",
                        elem_id="gallery",
                    )
                with gr.Column():
                    gr.Examples(
                        examples=examples,
                        inputs=[text_prompt],
                    )
    # Shared handler for button click and prompt submit; input order must
    # match infer()'s positional parameters. The returned seed is written
    # back into the seed slider.
    gr.on(
        triggers=[text_button.click, text_prompt.submit],
        fn=infer,
        inputs=[
            text_prompt,
            seed,
            randomize_seed,
            width,
            height,
            cfg,
            steps,
            custom_lora,
            lora_scale,
        ],
        outputs=[image_output, seed],
    )
# Entry point for local development. (Fix: the original line ended with
# stray digits "899492", a syntax error left by the extraction.)
if __name__ == "__main__":
    # In Spaces, HF will call app.launch() implicitly, but keeping this for local dev.
    app.launch(share=False)