| import gradio as gr |
| import gc |
| import os |
| import random |
| import tempfile |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| from PIL import Image |
| from fastapi.responses import HTMLResponse |
| from gradio.data_classes import FileData |
|
|
| |
# ``spaces`` is an optional dependency: it supplies the ``spaces.GPU``
# decorator used further down. Record whether it could be imported instead of
# failing at startup when it is absent.
try:
    import spaces
except ImportError:
    _HAS_SPACES = False
else:
    _HAS_SPACES = True
|
|
| |
| |
| |
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Precision requested for every set of model weights loaded by this module.
DTYPE = torch.bfloat16
|
|
if gr.NO_RELOAD:
    # Heavy imports and model construction are kept under ``gr.NO_RELOAD`` so
    # that Gradio's reload mode does not re-execute them on every code change.
    from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
    from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
    from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3

    # The pipeline is loaded from the FireRed checkpoint, with its transformer
    # replaced by one loaded from the Rapid-AIO checkpoint.
    PIPE = QwenImageEditPlusPipeline.from_pretrained(
        "FireRedTeam/FireRed-Image-Edit-1.1",
        transformer=QwenImageTransformer2DModel.from_pretrained(
            "prithivMLmods/Qwen-Image-Edit-Rapid-AIO-V19",
            torch_dtype=DTYPE,
            # Follow DEVICE instead of hard-coding "cuda": the rest of the
            # module already falls back to the CPU when CUDA is unavailable,
            # and a fixed "cuda" here would make that fallback unreachable.
            device_map=DEVICE,
        ),
        torch_dtype=DTYPE,
    ).to(DEVICE)

    # Best effort: if the FA3 attention processor cannot be installed, log the
    # reason and keep whatever processor the transformer already uses.
    try:
        PIPE.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
        print("Flash Attention 3 processor set.")
    except Exception as e:
        print(f"FA3 processor not set: {e}")
|
|
# Negative prompt passed to the pipeline on every call. It is assembled from
# chunks only to keep the source lines short; the result is one string.
NEGATIVE_PROMPT = "".join(
    (
        "worst quality, low quality, bad anatomy, bad hands, text, error, ",
        "missing fingers, extra digit, fewer digits, cropped, jpeg artifacts, ",
        "signature, watermark, username, blurry",
    )
)

# Upper bound for randomly drawn seeds: the largest signed 32-bit integer.
_INT32_LIMITS = np.iinfo(np.int32)
MAX_SEED = _INT32_LIMITS.max
|
|
|
|
def _round_dims(
    image: "Image.Image",
    *,
    long_side: int = 1024,
    multiple: int = 8,
) -> tuple[int, int]:
    """Return the ``(width, height)`` to generate at for *image*.

    The longer side is scaled to ``long_side`` pixels and the shorter side by
    the same factor; both are then rounded down to a multiple of ``multiple``.
    Neither side is allowed to drop below ``multiple``, so a very elongated
    input cannot produce a zero-sized dimension.

    Args:
        image: Object exposing a PIL-style ``size`` of ``(width, height)``.
        long_side: Target length in pixels of the longer side.
        multiple: Both returned dimensions are multiples of this value.

    Returns:
        ``(width, height)`` as integers.

    Raises:
        ValueError: If the image reports a non-positive width or height.
    """
    w, h = image.size
    if w <= 0 or h <= 0:
        raise ValueError(f"image has invalid size {w}x{h}")

    if w > h:
        new_w, new_h = long_side, int(long_side * h / w)
    else:
        new_h, new_w = long_side, int(long_side * w / h)

    def _snap(value: int) -> int:
        # Round down to the grid, but never below one grid cell.
        return max(multiple, (value // multiple) * multiple)

    return _snap(new_w), _snap(new_h)
|
|
|
|
| |
| |
| |
def _edit(image: Image.Image, prompt: str, seed: int, steps: int) -> Image.Image:
    """Run a single edit by delegating to ``_run_pipe`` with the same arguments."""
    return _run_pipe(image, prompt, seed, steps)


# When the ``spaces`` package was importable, wrap ``_edit`` with
# ``spaces.GPU``; otherwise the plain function above is used unchanged.
if _HAS_SPACES:
    _edit = spaces.GPU(_edit)
|
|
|
|
def _run_pipe(image, prompt, seed, steps):
    """Run the editing pipeline on one image and return the edited image.

    Args:
        image: Source image; ``_round_dims`` derives the output size from it.
        prompt: Edit instruction forwarded to the pipeline.
        seed: Seed for the ``torch.Generator`` handed to the pipeline.
        steps: Number of inference steps.

    Returns:
        The first entry of the pipeline output's ``images``.
    """
    print(f"[_run_pipe] start cuda_avail={torch.cuda.is_available()}", flush=True)

    # Release unreferenced Python objects, and cached CUDA memory when a GPU
    # is present, before starting the run.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    width, height = _round_dims(image)
    print(f"[_run_pipe] dims w={width} h={height} steps={steps}", flush=True)

    rng = torch.Generator(device=DEVICE)
    rng.manual_seed(seed)

    call_kwargs = {
        "image": [image],
        "prompt": prompt,
        "negative_prompt": NEGATIVE_PROMPT,
        "width": width,
        "height": height,
        "num_inference_steps": steps,
        "true_cfg_scale": 1.0,
        "generator": rng,
    }
    result = PIPE(**call_kwargs)
    out = result.images[0]

    print(f"[_run_pipe] done size={out.size}", flush=True)
    return out
|
|
|
|
| |
# Directory containing this file; ``index.html`` is read from here.
HOME = Path(__file__).parent

# Server object on which the API endpoint and the HTTP route are registered.
server = gr.Server()
|
|
|
|
@server.api(name="edit_image")
def edit_image(image: FileData, prompt: str) -> dict:
    """Edit an image guided by a text prompt using FireRed-Image-Edit 1.1.

    Args:
        image: Uploaded file descriptor. It is accessed like a mapping; its
            ``"path"`` entry locates the source image on disk.
        prompt: Edit instruction. Surrounding whitespace is stripped before
            it is passed to the model.

    Returns:
        ``{"image": FileData, "seed": int}`` on success, where ``image``
        points at a PNG written to a fresh temporary file, or
        ``{"error": str}`` when the prompt is empty or whitespace-only.
    """
    print(f"[edit_image] received prompt={prompt!r} path={image.get('path')}", flush=True)
    if not prompt or not prompt.strip():
        return {"error": "Please enter an edit prompt."}

    # ``convert`` returns an independent copy, so the file opened by
    # ``Image.open`` can be closed as soon as the conversion is done instead
    # of staying open until garbage collection.
    with Image.open(image["path"]) as opened:
        src = opened.convert("RGB")
    print(f"[edit_image] image opened size={src.size}", flush=True)

    seed = random.randint(0, MAX_SEED)
    print(f"[edit_image] calling _edit seed={seed}", flush=True)
    result = _edit(src, prompt.strip(), seed, steps=4)
    print(f"[edit_image] _edit returned size={result.size}", flush=True)

    # mkstemp only reserves the name; the descriptor is closed right away and
    # the image is written through the path.
    fd, out_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    try:
        result.save(out_path)
    except Exception:
        # Do not leave an empty or partial temp file behind when saving fails.
        try:
            os.unlink(out_path)
        except OSError:
            pass
        raise
    print(f"[edit_image] saved to {out_path} exists={os.path.exists(out_path)} size={os.path.getsize(out_path)}", flush=True)

    payload = {"image": FileData(path=out_path), "seed": seed}
    print(f"[edit_image] returning payload keys={list(payload.keys())} image={payload['image']}", flush=True)
    return payload
|
|
|
|
@server.get("/", response_class=HTMLResponse)
async def homepage():
    """Return the contents of the ``index.html`` located in ``HOME``."""
    index_file = HOME / "index.html"
    return index_file.read_text(encoding="utf-8")
|
|
|
|
if __name__ == "__main__":
    # The system temp directory is added to ``allowed_paths``; presumably
    # this lets the server hand back the result files written there.
    launch_options = {
        "show_error": True,
        "allowed_paths": [tempfile.gettempdir()],
    }
    server.launch(**launch_options)
|
|