File size: 4,631 Bytes
80b7188
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
"""

FireRed-Image-Edit HTTP service matching GenSearcher Qwen /generate contract.



Request/response aligned with qwen_image_api_server and gen_image_deepresearch_reward.call_qwen_edit_to_generate_image.

"""
from __future__ import annotations

import argparse
import base64
import io
import os
import re
import threading
from typing import List, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from PIL import Image

# ASGI application; served by uvicorn from main().
app = FastAPI(title="FireRed-Image-Edit GenSearcher adapter")

# Cached diffusers pipeline. Stays None until get_pipeline() loads the model.
_pipe = None


def _load_image_from_url_or_data(url_or_data: str) -> Image.Image:
    if url_or_data.startswith("data:image/"):
        m = re.match(r"data:image/[^;]+;base64,(.*)", url_or_data, re.DOTALL)
        if not m:
            raise ValueError("Invalid data URL")
        raw = base64.b64decode(m.group(1))
        return Image.open(io.BytesIO(raw)).convert("RGB")
    raise ValueError("Only data:image/...;base64,... URLs are supported in Space adapter")


class GenerateRequest(BaseModel):
    """JSON request body for ``POST /generate``.

    Every field except ``prompt`` has a default. All sampling fields are
    forwarded to the FireRed pipeline call in ``generate``.
    """

    # Base64 data URLs (``data:image/<type>;base64,...``). generate() uses at
    # most the first 3 entries and skips empty ones. When None/empty, a blank
    # light-gray canvas is edited instead.
    image_urls: Optional[List[str]] = None
    # Edit / generation instruction passed to the pipeline as ``prompt``.
    prompt: str
    # Seed for the ``torch.Generator`` created on the CUDA device.
    seed: int = 0
    # Forwarded as ``true_cfg_scale``.
    true_cfg_scale: float = 4.0
    # Forwarded as ``negative_prompt``. generate() substitutes " " for an
    # empty string.
    negative_prompt: str = " "
    # Number of denoising steps forwarded to the pipeline.
    num_inference_steps: int = 40
    # Forwarded as ``guidance_scale``.
    guidance_scale: float = 1.0
    # Forwarded to the pipeline, but generate() returns only ``images[0]``.
    # NOTE(review): values > 1 compute extra images that are discarded.
    num_images_per_prompt: int = 1


# Guards one-time pipeline construction. FastAPI runs sync ``def`` endpoints
# in a threadpool, so several first requests can reach get_pipeline() at once.
_PIPE_LOCK = threading.Lock()


def get_pipeline():
    """Return the process-wide FireRed pipeline, loading it on first use.

    The model id/path is read from ``FIRERED_MODEL_ID`` (default
    ``FireRedTeam/FireRed-Image-Edit-1.1``). The pipeline is loaded in
    bfloat16, moved to ``cuda``, and has its progress bar disabled.

    Loading happens under ``_PIPE_LOCK`` (double-checked), so concurrent first
    callers load the model only once. The global ``_pipe`` is assigned only
    after every setup step succeeds. If loading or moving to the GPU fails,
    ``_pipe`` stays ``None`` and the next call retries, instead of caching a
    half-initialized (e.g. still-on-CPU) pipeline.

    Raises:
        Whatever ``from_pretrained`` or ``.to("cuda")`` raises (missing
        weights, no CUDA device, ...). ``generate`` turns this into HTTP 503.
    """
    global _pipe
    if _pipe is not None:  # fast path once loaded: no locking needed
        return _pipe
    with _PIPE_LOCK:
        # Re-check: another thread may have finished loading while we waited.
        if _pipe is None:
            import torch
            from diffusers import QwenImageEditPlusPipeline

            model_path = os.environ.get(
                "FIRERED_MODEL_ID", "FireRedTeam/FireRed-Image-Edit-1.1"
            )
            pipe = QwenImageEditPlusPipeline.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16,
            )
            pipe.to("cuda")
            pipe.set_progress_bar_config(disable=True)
            # Publish only a fully-initialized pipeline.
            _pipe = pipe
        return _pipe


@app.get("/health")
def health():
    """Liveness probe that also reports whether the pipeline is loaded.

    Only inspects the cached ``_pipe`` global. It never triggers a model load.
    """
    return dict(status="ok", model_loaded=(_pipe is not None))


# Serializes pipeline calls. FastAPI runs this sync endpoint in a threadpool,
# and a diffusers pipeline mutates shared per-call state (scheduler timesteps
# and step index, guidance attributes) while it runs. Two overlapping requests
# on the same ``pipe`` would corrupt each other (or double GPU memory).
_INFERENCE_LOCK = threading.Lock()


@app.post("/generate")
def generate(request: GenerateRequest):
    """Run one FireRed image edit and return the first output as a base64 PNG.

    At most the first three ``image_urls`` entries are considered; empty
    entries are skipped. The rest are decoded and passed to the pipeline as
    conditioning images. When no image is supplied, a 1024x1024 light-gray
    canvas is edited instead, since FireRed is an edit model.

    Returns:
        ``{"success": True, "image": <base64 PNG>}`` on success, or
        ``{"success": False, "message": <error + traceback>}`` if the
        pipeline call fails.

    Raises:
        HTTPException(503): the pipeline could not be loaded.
        HTTPException(400): an ``image_urls`` entry could not be decoded.
    """
    try:
        pipe = get_pipeline()
    except Exception as e:
        raise HTTPException(status_code=503, detail=f"Model not ready: {e}") from e

    import torch

    images: List[Image.Image] = []
    for u in (request.image_urls or [])[:3]:
        if not u:
            continue
        try:
            images.append(_load_image_from_url_or_data(u))
        except Exception as ex:
            raise HTTPException(
                status_code=400, detail=f"Bad image_urls entry: {ex}"
            ) from ex

    if not images:
        # Text-only request: edit a neutral canvas so the same pipeline call works.
        images = [Image.new("RGB", (1024, 1024), (240, 240, 240))]

    gen = torch.Generator(device="cuda").manual_seed(int(request.seed))

    try:
        with _INFERENCE_LOCK, torch.inference_mode():
            out = pipe(
                image=images,
                prompt=request.prompt,
                generator=gen,
                true_cfg_scale=float(request.true_cfg_scale),
                negative_prompt=request.negative_prompt or " ",
                num_inference_steps=int(request.num_inference_steps),
                guidance_scale=float(request.guidance_scale),
                num_images_per_prompt=int(request.num_images_per_prompt),
            )
        # Only the first image is returned, even if several were requested.
        pil = out.images[0]
    except Exception as e:
        import traceback

        return {
            "success": False,
            "message": f"{e}\n{traceback.format_exc()}",
        }

    buf = io.BytesIO()
    pil.save(buf, format="PNG")
    b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    return {"success": True, "image": b64}


def main():
    """Command-line entry point: serve ``app`` with uvicorn.

    Flags:
        --host  interface to bind (default ``0.0.0.0``)
        --port  TCP port to listen on (default ``8765``)
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--host", default="0.0.0.0")
    cli.add_argument("--port", type=int, default=8765)
    opts = cli.parse_args()

    # Imported only after argument parsing, so ``--help`` does not need uvicorn.
    import uvicorn

    uvicorn.run(app, host=opts.host, port=opts.port)


if __name__ == "__main__":
    main()