gensearcher-firered / services /firered_generate.py
JSCPPProgrammer's picture
Initial: GenSearcher workflow + FireRed /generate adapter + Gradio
80b7188 verified
"""
FireRed-Image-Edit HTTP service matching GenSearcher Qwen /generate contract.
Request/response aligned with qwen_image_api_server and gen_image_deepresearch_reward.call_qwen_edit_to_generate_image.
"""
from __future__ import annotations

import argparse
import base64
import io
import os
import re
import threading
from typing import List, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from PIL import Image
# ASGI application; main() serves it with uvicorn.
app = FastAPI(title="FireRed-Image-Edit GenSearcher adapter")
# Process-wide pipeline cache: None until get_pipeline() loads the
# QwenImageEditPlusPipeline on first use. /health reports whether it is set.
_pipe = None
def _load_image_from_url_or_data(url_or_data: str) -> Image.Image:
if url_or_data.startswith("data:image/"):
m = re.match(r"data:image/[^;]+;base64,(.*)", url_or_data, re.DOTALL)
if not m:
raise ValueError("Invalid data URL")
raw = base64.b64decode(m.group(1))
return Image.open(io.BytesIO(raw)).convert("RGB")
raise ValueError("Only data:image/...;base64,... URLs are supported in Space adapter")
class GenerateRequest(BaseModel):
    """JSON request body for ``POST /generate``.

    The module docstring states this request/response shape is aligned with the
    GenSearcher Qwen ``/generate`` server, so field names and defaults are part
    of the wire contract and should not be renamed.
    """
    # Optional conditioning images as base64 ``data:image/...`` URLs. generate()
    # looks at most at the first three entries, skips empty ones, and uses a blank
    # canvas when no image remains.
    image_urls: Optional[List[str]] = None
    # Text instruction, forwarded to the pipeline as ``prompt``.
    prompt: str
    # Seed for the CUDA ``torch.Generator`` handed to the pipeline.
    seed: int = 0
    # Forwarded to the pipeline as ``true_cfg_scale``.
    true_cfg_scale: float = 4.0
    # Forwarded as ``negative_prompt``; an empty string is replaced by " " first.
    negative_prompt: str = " "
    # Forwarded as ``num_inference_steps``.
    num_inference_steps: int = 40
    # Forwarded as ``guidance_scale``.
    guidance_scale: float = 1.0
    # Forwarded as ``num_images_per_prompt``. NOTE(review): generate() only
    # returns ``out.images[0]``, so values > 1 generate images that are discarded.
    num_images_per_prompt: int = 1
# Guards the one-time pipeline load in get_pipeline().
_PIPE_LOCK = threading.Lock()


def get_pipeline():
    """Return the process-wide FireRed pipeline, loading it on first use.

    The model id comes from the ``FIRERED_MODEL_ID`` environment variable
    (default ``FireRedTeam/FireRed-Image-Edit-1.1``). The model is loaded as a
    ``QwenImageEditPlusPipeline`` in bfloat16, moved to CUDA, and its progress
    bar is disabled.

    Loading uses double-checked locking. FastAPI runs sync ``def`` endpoints in
    a worker threadpool, so without the lock several concurrent first requests
    could each load a full copy of the model onto the GPU. The module-level
    ``_pipe`` is assigned only after the pipeline is fully initialized. A failure
    in ``.to("cuda")`` therefore no longer leaves a half-initialized pipeline
    cached, which ``/health`` would report as loaded.

    Raises:
        Any exception from importing torch/diffusers, ``from_pretrained`` or
        ``.to("cuda")``. ``_pipe`` stays None, so the next call retries.
    """
    global _pipe
    if _pipe is not None:  # fast path: already loaded, no locking needed
        return _pipe
    with _PIPE_LOCK:
        # Re-check: another thread may have finished loading while we waited.
        if _pipe is None:
            import torch
            from diffusers import QwenImageEditPlusPipeline

            model_path = os.environ.get(
                "FIRERED_MODEL_ID", "FireRedTeam/FireRed-Image-Edit-1.1"
            )
            pipe = QwenImageEditPlusPipeline.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16,
            )
            pipe.to("cuda")
            pipe.set_progress_bar_config(disable=True)
            # Publish only once every initialization step has succeeded.
            _pipe = pipe
    return _pipe
@app.get("/health")
def health():
    """Liveness probe that also says whether the pipeline has been loaded.

    This never triggers a model load. ``model_loaded`` only reflects whether
    the module-level ``_pipe`` cache is currently populated.
    """
    return dict(status="ok", model_loaded=(_pipe is not None))
# Serializes inference on the single shared pipeline. Sync endpoints run in
# FastAPI's threadpool, and all requests use the same diffusers pipeline
# instance, whose per-call state (e.g. scheduler timesteps) lives on that
# instance. Concurrent runs would also multiply GPU memory use.
_INFERENCE_LOCK = threading.Lock()


def _decode_request_images(image_urls: Optional[List[str]]) -> List[Image.Image]:
    """Decode the non-empty entries among the first three ``image_urls``.

    Empty entries are skipped but still count toward the three-entry window.

    Raises:
        HTTPException: 400 when an entry cannot be decoded.
    """
    images: List[Image.Image] = []
    for url in (image_urls or [])[:3]:
        if not url:
            continue
        try:
            images.append(_load_image_from_url_or_data(url))
        except Exception as ex:
            raise HTTPException(
                status_code=400, detail=f"Bad image_urls entry: {ex}"
            ) from ex
    return images


@app.post("/generate")
def generate(request: GenerateRequest):
    """Run one FireRed edit/generation and return the first image as base64 PNG.

    Response bodies follow the GenSearcher Qwen ``/generate`` contract:

    * success: ``{"success": True, "image": <base64-encoded PNG>}``
    * pipeline failure: ``{"success": False, "message": <error + traceback>}``.
      This is returned with HTTP 200, so clients must check ``success``.

    Raises:
        HTTPException: 503 if the pipeline cannot be loaded; 400 if one of
            the first three ``image_urls`` entries cannot be decoded.
    """
    try:
        pipe = get_pipeline()
    except Exception as e:
        raise HTTPException(status_code=503, detail=f"Model not ready: {e}") from e
    import torch

    images = _decode_request_images(request.image_urls)
    if not images:
        # Text-only request: FireRed is edit-focused, so condition on a neutral
        # 1024x1024 light-grey canvas instead of a user image.
        images = [Image.new("RGB", (1024, 1024), (240, 240, 240))]

    gen = torch.Generator(device="cuda").manual_seed(int(request.seed))
    try:
        with _INFERENCE_LOCK, torch.inference_mode():
            out = pipe(
                image=images,
                prompt=request.prompt,
                generator=gen,
                true_cfg_scale=float(request.true_cfg_scale),
                negative_prompt=request.negative_prompt or " ",
                num_inference_steps=int(request.num_inference_steps),
                guidance_scale=float(request.guidance_scale),
                num_images_per_prompt=int(request.num_images_per_prompt),
            )
        # Only the first image is returned, even if num_images_per_prompt > 1.
        pil = out.images[0]
    except Exception as e:
        import traceback

        # NOTE(review): this sends the server-side traceback to the client.
        # The message is kept for contract compatibility.
        return {
            "success": False,
            "message": f"{e}\n{traceback.format_exc()}",
        }

    buf = io.BytesIO()
    pil.save(buf, format="PNG")
    b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    return {"success": True, "image": b64}
def main():
    """Command-line entry point: read ``--host``/``--port`` and serve ``app``.

    uvicorn is imported only after the arguments are parsed, so ``--help``
    exits before that import is attempted.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--host", default="0.0.0.0")
    cli.add_argument("--port", type=int, default=8765)
    opts = cli.parse_args()

    import uvicorn  # deferred: only needed once we actually serve

    uvicorn.run(app, host=opts.host, port=opts.port)


if __name__ == "__main__":
    main()