FireEdit-Deployment / server.py
toiram's picture
Upload server.py
7e5e3f8 verified
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from gradio_client import Client, handle_file
import os
import shutil
import uuid
from dotenv import load_dotenv
from typing import List
# Load variables from a .env file (if one is present) before the
# environment is read further down in this module.
load_dotenv()

# The ASGI application object that every route below is registered on.
app = FastAPI()

# CORS is left fully open: any origin, method and header is accepted.
_cors_policy = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_policy)
# All working directories live next to this source file.
base_dir = os.path.dirname(os.path.abspath(__file__))
UPLOAD_DIR = os.path.join(base_dir, "temp_uploads")
OUTPUT_DIR = os.path.join(base_dir, "outputs")

# Create both directories up front; an already-existing directory is fine.
for _directory in (UPLOAD_DIR, OUTPUT_DIR):
    os.makedirs(_directory, exist_ok=True)

# Serve everything written to OUTPUT_DIR under the /outputs URL prefix.
app.mount("/outputs", StaticFiles(directory=OUTPUT_DIR), name="outputs")

# Hugging Face access token; None when the variable is not set.
HF_TOKEN = os.getenv("HF_TOKEN")

# Client registry: model key -> connected client, or None if connecting failed.
clients = {}
def init_clients():
    """Connect to every configured Space and record the outcome in ``clients``.

    For each model key a ``Client`` is created with ``HF_TOKEN``. A key whose
    connection attempt raises is stored as ``None`` so the rest of the
    server can tell it is unavailable; the failure is printed, not raised.
    """
    spaces = {
        "firered": "prithivMLmods/FireRed-Image-Edit-1.0-Fast",
        "qwen": "prithivMLmods/Qwen-Image-Edit-2511-LoRAs-Fast",
        "flux": "prithivMLmods/FLUX.2-Klein-LoRA-Studio",
        "turbo": "mrfakename/Z-Image-Turbo",
        "banana": "multimodalart/nano-banana",
        "3d_camera": "multimodalart/qwen-image-multiple-angles-3d-camera",
        "qwen_rapid": "IllyaS08/qwen-image-edit-rapid-aio-sfw-v23",
        "anypose": "linoyts/Qwen-Image-Edit-2511-anypose",
        "qwen3_vl": "prithivMLmods/Qwen3-VL-abliterated-MAX-Fast",
    }
    for name in spaces:
        space_id = spaces[name]
        try:
            print(f"DEBUG: Conectando a {space_id}...")
            clients[name] = Client(space_id, token=HF_TOKEN)
            print(f"DEBUG: {name} conectado.")
        except Exception as exc:
            # Best effort: one unreachable Space must not stop the others.
            print(f"Error connecting to {name}: {exc}")
            clients[name] = None
# Connect to every configured Space once, when this module is imported.
init_clients()
@app.get("/")
async def read_index():
    """Serve the ``index.html`` file that sits next to this module."""
    index_path = os.path.join(base_dir, "index.html")
    return FileResponse(index_path)
@app.get("/status")
async def get_status():
    """Report, per model key, whether a client is currently registered.

    Returns a mapping of model key to ``True`` when the entry in ``clients``
    is not ``None`` and ``False`` otherwise.
    """
    return {name: connection is not None for name, connection in clients.items()}
def _extract_output_path(output):
    """Return the file location of the first image in a model output value.

    Accepted shapes are the ones this module checks for: a plain string, a
    dict carrying a ``path``/``name``/``url`` key, a dict wrapping a string
    or a ``path``/``name`` dict under an ``image`` key, or a non-empty
    list/tuple whose first element is one of those.

    Raises:
        ValueError: if no location can be found in ``output``.
    """
    # Some models return a list of images as the first element.
    if isinstance(output, (list, tuple)) and len(output) > 0:
        output = output[0]

    location = None
    if isinstance(output, str):
        location = output
    elif isinstance(output, dict):
        # Try common Gradio keys for image paths.
        location = output.get("path") or output.get("name") or output.get("url")
        # Special case: nested 'image' key (found in some models).
        if not location and "image" in output:
            nested = output["image"]
            if isinstance(nested, str):
                location = nested
            elif isinstance(nested, dict):
                location = nested.get("path") or nested.get("name")

    if not location:
        raise ValueError(f"Could not extract image path from result: {output}")
    return location


@app.post("/edit-image")
async def edit_image(
    images: List[UploadFile] = File(None),
    prompt: str = Form(...),
    model: str = Form("firered"),
    lora_adapter: str = Form("Photo-to-Anime"),
    style_name: str = Form("None"),
    seed: int = Form(0),
    randomize_seed: bool = Form(True),
    guidance_scale: float = Form(1.0),
    steps: int = Form(4),
    width: int = Form(1024),
    height: int = Form(1024),
    azimuth: float = Form(0),
    elevation: float = Form(0),
    distance: float = Form(1.0),
    rewrite_prompt: bool = Form(False)
):
    """Run one request against the selected model and publish the result.

    Uploaded images are written to ``UPLOAD_DIR`` for the duration of the
    call, the client registered for ``model`` is invoked with the subset of
    form fields that model's branch uses, and the first output image is
    copied into ``OUTPUT_DIR`` under a fresh name.

    Returns:
        A dict with ``success``, ``images`` (a one-element list holding the
        ``/outputs/...`` URL of the copied file) and ``seed`` (as a string;
        the request's own seed is echoed when the model reports none).

    Raises:
        HTTPException: 503 when ``model`` has no connected client, 400 when
            the model's required images are missing or the model has no
            branch here, 500 for any other failure during inference or
            result handling.
    """
    client = clients.get(model)
    if not client:
        raise HTTPException(status_code=503, detail=f"Model {model} not connected")

    temp_paths = []
    try:
        # Save every uploaded image to a temporary file, if any were sent.
        gradio_images = []
        for img in images or []:
            # The filename comes from the client and is untrusted: keep only
            # its last path component so it cannot point outside UPLOAD_DIR.
            safe_name = os.path.basename((img.filename or "").replace("\\", "/")) or "upload"
            temp_path = os.path.join(UPLOAD_DIR, f"{uuid.uuid4()}_{safe_name}")
            # Register before writing so a failed write is still cleaned up.
            temp_paths.append(temp_path)
            with open(temp_path, "wb") as buffer:
                shutil.copyfileobj(img.file, buffer)
            gradio_images.append({"image": handle_file(temp_path), "caption": None})

        # NOTE(review): client.predict looks like a synchronous call made from
        # an async handler - confirm whether it should run in a worker thread.
        if model == "firered":
            if not gradio_images:
                raise HTTPException(status_code=400, detail="Images required for FireRed")
            result = client.predict(
                images=gradio_images,
                prompt=prompt,
                seed=seed,
                randomize_seed=randomize_seed,
                guidance_scale=guidance_scale,
                steps=steps,
                api_name="/infer"
            )
        elif model == "qwen":
            if not gradio_images:
                raise HTTPException(status_code=400, detail="Images required for Qwen")
            result = client.predict(
                images=gradio_images,
                prompt=prompt,
                lora_adapter=lora_adapter,
                seed=seed,
                randomize_seed=randomize_seed,
                guidance_scale=guidance_scale,
                steps=steps,
                api_name="/infer"
            )
        elif model == "qwen_rapid":
            if not gradio_images:
                raise HTTPException(status_code=400, detail="Images required for Qwen Rapid")
            # Using the parameters from the documentation:
            # images, prompt, seed, randomize_seed, true_guidance_scale,
            # num_inference_steps, height, width, rewrite_prompt
            result = client.predict(
                images=gradio_images,
                prompt=prompt,
                seed=seed,
                randomize_seed=randomize_seed,
                true_guidance_scale=guidance_scale,
                num_inference_steps=steps,
                height=height,
                width=width,
                rewrite_prompt=rewrite_prompt,
                api_name="/infer"
            )
        elif model == "anypose":
            if len(gradio_images) < 2:
                raise HTTPException(status_code=400, detail="AnyPose requires two images: Reference and Pose")
            result = client.predict(
                reference_image=gradio_images[0]["image"],
                pose_image=gradio_images[1]["image"],
                prompt=prompt,
                seed=seed,
                randomize_seed=randomize_seed,
                true_guidance_scale=guidance_scale,
                num_inference_steps=steps,
                height=height,
                width=width,
                rewrite_prompt=rewrite_prompt,
                api_name="/infer"
            )
        elif model == "qwen3_vl":
            if not gradio_images:
                raise HTTPException(status_code=400, detail="Image required for Qwen3-VL")
            # Using the parameters from the documentation:
            # text (prompt), image, max_new_tokens (mapped from steps),
            # temperature (mapped from guidance_scale), etc.
            result = client.predict(
                text=prompt,
                image=gradio_images[0]["image"],
                max_new_tokens=steps * 100,  # Adapting steps to tokens
                temperature=guidance_scale,
                top_p=0.9,
                top_k=50,
                repetition_penalty=1.1,
                gpu_timeout=60,
                api_name="/generate_image"
            )
        elif model == "flux":
            if not gradio_images:
                raise HTTPException(status_code=400, detail="Images required for Flux")
            result = client.predict(
                input_images=gradio_images,
                prompt=prompt,
                style_name=style_name,
                seed=seed,
                randomize_seed=randomize_seed,
                guidance_scale=guidance_scale,
                steps=steps,
                api_name="/infer"
            )
        elif model == "3d_camera":
            if not gradio_images:
                raise HTTPException(status_code=400, detail="Image required for 3D Camera")
            result = client.predict(
                image=gradio_images[0]["image"],
                azimuth=azimuth,
                elevation=elevation,
                distance=distance,
                seed=seed,
                randomize_seed=randomize_seed,
                guidance_scale=guidance_scale,
                num_inference_steps=steps,
                height=height,
                width=width,
                api_name="/infer_camera_edit"
            )
        elif model == "turbo":
            result = client.predict(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=steps,
                seed=seed,
                randomize_seed=randomize_seed,
                api_name="/generate_image"
            )
        elif model == "banana":
            # Nano Banana (Gemini 2.5 Flash Image)
            # uses fn_index 2 with prompt, image list and token
            result = client.predict(
                prompt=prompt,
                images=gradio_images if gradio_images else [],
                oauth_token=HF_TOKEN,
                fn_index=2
            )
        else:
            # A client is registered for this key but no call is defined here.
            raise HTTPException(status_code=400, detail=f"Unsupported model: {model}")

        print(f"DEBUG: Result from {model}: {result}")

        # A multi-output endpoint yields a sequence (image first, seed second);
        # a single-output endpoint yields the bare value, so fall back to the
        # requested seed instead of indexing past the end.
        if isinstance(result, (list, tuple)):
            output_image_data = result[0] if result else None
            result_seed = result[1] if len(result) > 1 else seed
        else:
            output_image_data = result
            result_seed = seed

        gradio_temp_path = _extract_output_path(output_image_data)

        output_filename = f"{model}_edited_{uuid.uuid4()}.webp"
        final_output_path = os.path.join(OUTPUT_DIR, output_filename)
        shutil.copy(gradio_temp_path, final_output_path)

        return {
            "success": True,
            "images": [f"/outputs/{output_filename}"],
            "seed": str(result_seed)
        }
    except HTTPException:
        # Deliberate client errors (e.g. the 400s above) must reach the caller
        # unchanged rather than being rewrapped as a 500 below.
        raise
    except Exception as e:
        print(f"Inference error ({model}): {e}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Uploaded copies are only needed for the duration of the call.
        for path in temp_paths:
            if os.path.exists(path):
                os.remove(path)
if __name__ == "__main__":
    import uvicorn

    # Hugging Face Spaces uses port 7860 by default; the PORT environment
    # variable takes precedence when it is set.
    listen_port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)