| """
|
| FireRed-Image-Edit HTTP service matching GenSearcher Qwen /generate contract.
|
|
|
| Request/response aligned with qwen_image_api_server and gen_image_deepresearch_reward.call_qwen_edit_to_generate_image.
|
| """
|
from __future__ import annotations

import argparse
import base64
import io
import os
import re
import threading
from typing import List, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from PIL import Image
|
|
|
_APP_TITLE = "FireRed-Image-Edit GenSearcher adapter"

# ASGI application; routes are attached below with @app.get / @app.post and
# main() hands it to uvicorn.
app = FastAPI(title=_APP_TITLE)

# Cached diffusion pipeline; stays None until get_pipeline() loads it.
_pipe = None
|
|
|
|
|
| def _load_image_from_url_or_data(url_or_data: str) -> Image.Image:
|
| if url_or_data.startswith("data:image/"):
|
| m = re.match(r"data:image/[^;]+;base64,(.*)", url_or_data, re.DOTALL)
|
| if not m:
|
| raise ValueError("Invalid data URL")
|
| raw = base64.b64decode(m.group(1))
|
| return Image.open(io.BytesIO(raw)).convert("RGB")
|
| raise ValueError("Only data:image/...;base64,... URLs are supported in Space adapter")
|
|
|
|
|
class GenerateRequest(BaseModel):
    """JSON request body for ``POST /generate``.

    Each field is forwarded to the image-edit pipeline call made by ``generate``.
    """

    # Input images as ``data:image/<subtype>;base64,...`` strings (other URL
    # forms are rejected). Only the first three entries are considered and
    # empty entries are skipped; if no image remains, a blank light-grey
    # 1024x1024 canvas is used as the input image.
    image_urls: Optional[List[str]] = None
    # Text instruction passed to the pipeline as ``prompt``.
    prompt: str
    # Seeds the CUDA ``torch.Generator`` used for sampling.
    seed: int = 0
    # Passed to the pipeline as ``true_cfg_scale`` (semantics defined by the
    # diffusers pipeline; presumably the true classifier-free guidance scale).
    true_cfg_scale: float = 4.0
    # Passed as ``negative_prompt``; an empty string is replaced by " ".
    negative_prompt: str = " "
    # Passed as ``num_inference_steps`` (number of sampling steps).
    num_inference_steps: int = 40
    # Passed as ``guidance_scale``.
    guidance_scale: float = 1.0
    # Passed as ``num_images_per_prompt``. Only the first output image is
    # returned, so values > 1 generate extra images that are discarded.
    num_images_per_prompt: int = 1
|
|
|
|
|
# Serializes the one-time model load performed by get_pipeline().
_pipe_lock = threading.Lock()


def get_pipeline():
    """Return the shared image-edit pipeline, loading it on first use.

    The model id or local path comes from the ``FIRERED_MODEL_ID`` environment
    variable (default ``FireRedTeam/FireRed-Image-Edit-1.1``). The pipeline is
    loaded in bfloat16, moved to CUDA, and has its progress bar disabled.

    Sync FastAPI endpoints execute in a thread pool. Loading is therefore done
    under ``_pipe_lock``, so concurrent first requests do not each load a full
    copy of the model onto the GPU.

    The module-level ``_pipe`` is assigned only after the pipeline is fully
    configured. If ``from_pretrained`` or ``.to("cuda")`` fails, nothing is
    cached and the next call retries the load.

    Returns:
        The cached ``QwenImageEditPlusPipeline`` instance.

    Raises:
        Any exception from importing torch/diffusers, ``from_pretrained``, or
        moving the pipeline to CUDA.
    """
    global _pipe

    # Fast path once loaded: no locking needed to read an already-set reference.
    if _pipe is not None:
        return _pipe

    with _pipe_lock:
        # Re-check: another thread may have completed the load while we waited.
        if _pipe is None:
            import torch
            from diffusers import QwenImageEditPlusPipeline

            model_path = os.environ.get(
                "FIRERED_MODEL_ID", "FireRedTeam/FireRed-Image-Edit-1.1"
            )
            pipe = QwenImageEditPlusPipeline.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16,
            )
            pipe.to("cuda")
            pipe.set_progress_bar_config(disable=True)

            # Publish only after full initialization.
            _pipe = pipe

    return _pipe
|
|
|
|
|
@app.get("/health")
def health():
    """Liveness probe that also reports whether the pipeline is cached.

    Only inspects the module-level ``_pipe``; it never triggers a model load.
    """
    model_loaded = _pipe is not None
    return {"status": "ok", "model_loaded": model_loaded}
|
|
|
|
|
# Serializes pipeline invocations. Sync endpoints run concurrently in
# FastAPI's thread pool, but all requests share one pipeline object (and its
# scheduler), which should not be driven by two requests at the same time.
_infer_lock = threading.Lock()

# At most this many image_urls entries are considered per request.
_MAX_INPUT_IMAGES = 3


def _decode_request_images(image_urls: Optional[List[str]]) -> List[Image.Image]:
    """Decode the first ``_MAX_INPUT_IMAGES`` entries of ``image_urls``.

    Empty entries are skipped. A missing/None list yields an empty result.

    Raises:
        HTTPException: status 400 if any considered entry fails to decode.
    """
    images: List[Image.Image] = []
    for entry in (image_urls or [])[:_MAX_INPUT_IMAGES]:
        if not entry:
            continue
        try:
            images.append(_load_image_from_url_or_data(entry))
        except Exception as ex:
            raise HTTPException(
                status_code=400, detail=f"Bad image_urls entry: {ex}"
            ) from ex
    return images


@app.post("/generate")
def generate(request: GenerateRequest):
    """Run one edit/generation request and return the first image as base64 PNG.

    Outcomes:
      * HTTP 503 if the pipeline cannot be loaded.
      * HTTP 400 if an ``image_urls`` entry cannot be decoded.
      * ``{"success": False, "message": ...}`` if the pipeline call fails.
        The message includes the formatted traceback.
      * ``{"success": True, "image": <base64 PNG>}`` otherwise.

    If no input images are supplied, a blank light-grey 1024x1024 canvas is
    passed to the pipeline as the input image.
    """
    try:
        pipe = get_pipeline()
    except Exception as e:
        raise HTTPException(status_code=503, detail=f"Model not ready: {e}") from e

    import torch

    images = _decode_request_images(request.image_urls)

    gen = torch.Generator(device="cuda").manual_seed(int(request.seed))

    try:
        if not images:
            # No input images: substitute a plain canvas as the edit source.
            images = [Image.new("RGB", (1024, 1024), (240, 240, 240))]

        with _infer_lock, torch.inference_mode():
            out = pipe(
                image=images,
                prompt=request.prompt,
                generator=gen,
                true_cfg_scale=float(request.true_cfg_scale),
                negative_prompt=request.negative_prompt or " ",
                num_inference_steps=int(request.num_inference_steps),
                guidance_scale=float(request.guidance_scale),
                num_images_per_prompt=int(request.num_images_per_prompt),
            )
            # Only the first image is returned even if more were generated.
            pil = out.images[0]
    except Exception as e:
        import traceback

        # NOTE(review): the full traceback is sent back to the client; fine for
        # an internal service, reconsider if this endpoint is exposed publicly.
        return {
            "success": False,
            "message": f"{e}\n{traceback.format_exc()}",
        }

    buf = io.BytesIO()
    pil.save(buf, format="PNG")
    b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    return {"success": True, "image": b64}
|
|
|
|
|
def main():
    """Command-line entry point: serve ``app`` with uvicorn.

    Options: ``--host`` (default ``0.0.0.0``) and ``--port`` (int, default 8765).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--host", default="0.0.0.0")
    cli.add_argument("--port", type=int, default=8765)
    opts = cli.parse_args()

    # Deferred import: argument parsing (including --help) happens first.
    import uvicorn

    uvicorn.run(app, host=opts.host, port=opts.port)


if __name__ == "__main__":
    main()
|
|
|