"""FastAPI gateway that proxies image-editing requests to several Hugging Face
Gradio Spaces (FireRed, Qwen, FLUX, Z-Image-Turbo, nano-banana, ...).

Uploaded images are staged under ``temp_uploads``, forwarded to the selected
Space via ``gradio_client``, and the resulting image is copied into
``outputs`` (served statically at ``/outputs``).
"""

from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from gradio_client import Client, handle_file
import os
import shutil
import uuid
from dotenv import load_dotenv
from typing import List

load_dotenv()

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

base_dir = os.path.dirname(os.path.abspath(__file__))
UPLOAD_DIR = os.path.join(base_dir, "temp_uploads")
OUTPUT_DIR = os.path.join(base_dir, "outputs")
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)

app.mount("/outputs", StaticFiles(directory=OUTPUT_DIR), name="outputs")

HF_TOKEN = os.environ.get("HF_TOKEN")

# Registry of Gradio clients keyed by short model name. A value of None means
# the connection attempt failed at startup (reported via /status).
clients = {}


def init_clients():
    """Connect to every configured Hugging Face Space.

    Connection failures are logged and stored as ``None`` so the server can
    still start and report per-model availability instead of crashing.
    """
    models = {
        "firered": "prithivMLmods/FireRed-Image-Edit-1.0-Fast",
        "qwen": "prithivMLmods/Qwen-Image-Edit-2511-LoRAs-Fast",
        "flux": "prithivMLmods/FLUX.2-Klein-LoRA-Studio",
        "turbo": "mrfakename/Z-Image-Turbo",
        "banana": "multimodalart/nano-banana",
        "3d_camera": "multimodalart/qwen-image-multiple-angles-3d-camera",
        "qwen_rapid": "IllyaS08/qwen-image-edit-rapid-aio-sfw-v23",
        "anypose": "linoyts/Qwen-Image-Edit-2511-anypose",
        "qwen3_vl": "prithivMLmods/Qwen3-VL-abliterated-MAX-Fast",
    }
    for key, space in models.items():
        try:
            print(f"DEBUG: Conectando a {space}...")
            clients[key] = Client(space, token=HF_TOKEN)
            print(f"DEBUG: {key} conectado.")
        except Exception as e:
            # Best-effort: one unreachable Space must not take the API down.
            print(f"Error connecting to {key}: {e}")
            clients[key] = None


init_clients()


@app.get("/")
async def read_index():
    """Serve the single-page frontend."""
    return FileResponse(os.path.join(base_dir, 'index.html'))


@app.get("/status")
async def get_status():
    """Report which model backends connected successfully at startup."""
    return {key: client is not None for key, client in clients.items()}


def _save_uploads(images):
    """Persist uploaded files to UPLOAD_DIR.

    Returns (temp_paths, gradio_images): the local paths (for later cleanup)
    and the gallery-style payload list expected by the Gradio Spaces.
    """
    temp_paths = []
    gradio_images = []
    if images:
        for img in images:
            temp_filename = f"{uuid.uuid4()}_{img.filename}"
            temp_path = os.path.join(UPLOAD_DIR, temp_filename)
            with open(temp_path, "wb") as buffer:
                shutil.copyfileobj(img.file, buffer)
            temp_paths.append(temp_path)
            gradio_images.append({"image": handle_file(temp_path), "caption": None})
    return temp_paths, gradio_images


def _extract_image_path(output_image_data):
    """Extract a local file path from one element of a Gradio predict() result.

    Handles the common shapes returned by different Spaces: a plain string
    path, a FileData-style dict ('path'/'name'/'url'), or a dict nesting the
    file under an 'image' key.

    Raises:
        Exception: if no path-like value can be found.
    """
    # Some models return a list of images as the first element.
    if isinstance(output_image_data, list) and len(output_image_data) > 0:
        output_image_data = output_image_data[0]

    gradio_temp_path = None
    if isinstance(output_image_data, dict):
        # Try common Gradio keys for image paths.
        # NOTE(review): a 'url' value may be remote; shutil.copy would fail on
        # it — presumably these Spaces return local cache paths. Verify.
        gradio_temp_path = (
            output_image_data.get('path')
            or output_image_data.get('name')
            or output_image_data.get('url')
        )
        # Special case: nested 'image' key (found in some models).
        if not gradio_temp_path and 'image' in output_image_data:
            img_val = output_image_data['image']
            if isinstance(img_val, str):
                gradio_temp_path = img_val
            elif isinstance(img_val, dict):
                gradio_temp_path = img_val.get('path') or img_val.get('name')
    elif isinstance(output_image_data, str):
        gradio_temp_path = output_image_data

    if not gradio_temp_path:
        raise Exception(f"Could not extract image path from result: {output_image_data}")
    return gradio_temp_path


@app.post("/edit-image")
async def edit_image(
    images: List[UploadFile] = File(None),
    prompt: str = Form(...),
    model: str = Form("firered"),
    lora_adapter: str = Form("Photo-to-Anime"),
    style_name: str = Form("None"),
    seed: int = Form(0),
    randomize_seed: bool = Form(True),
    guidance_scale: float = Form(1.0),
    steps: int = Form(4),
    width: int = Form(1024),
    height: int = Form(1024),
    azimuth: float = Form(0),
    elevation: float = Form(0),
    distance: float = Form(1.0),
    rewrite_prompt: bool = Form(False)
):
    """Dispatch an image-editing request to the selected model backend.

    Saves the uploads, calls the matching Space with its expected argument
    names, copies the produced image into OUTPUT_DIR, and returns its public
    URL plus the seed used.

    Raises:
        HTTPException 503: the requested model never connected.
        HTTPException 400: the model requires image input that is missing.
        HTTPException 500: the remote inference itself failed.
    """
    if model not in clients or not clients[model]:
        raise HTTPException(status_code=503, detail=f"Model {model} not connected")

    temp_paths = []
    try:
        temp_paths, gradio_images = _save_uploads(images)
        client = clients[model]

        if model == "firered":
            if not gradio_images:
                raise HTTPException(status_code=400, detail="Images required for FireRed")
            result = client.predict(
                images=gradio_images,
                prompt=prompt,
                seed=seed,
                randomize_seed=randomize_seed,
                guidance_scale=guidance_scale,
                steps=steps,
                api_name="/infer"
            )
        elif model == "qwen":
            if not gradio_images:
                raise HTTPException(status_code=400, detail="Images required for Qwen")
            result = client.predict(
                images=gradio_images,
                prompt=prompt,
                lora_adapter=lora_adapter,
                seed=seed,
                randomize_seed=randomize_seed,
                guidance_scale=guidance_scale,
                steps=steps,
                api_name="/infer"
            )
        elif model == "qwen_rapid":
            if not gradio_images:
                raise HTTPException(status_code=400, detail="Images required for Qwen Rapid")
            # Parameters per the Space's API docs: images, prompt, seed,
            # randomize_seed, true_guidance_scale, num_inference_steps,
            # height, width, rewrite_prompt.
            result = client.predict(
                images=gradio_images,
                prompt=prompt,
                seed=seed,
                randomize_seed=randomize_seed,
                true_guidance_scale=guidance_scale,
                num_inference_steps=steps,
                height=height,
                width=width,
                rewrite_prompt=rewrite_prompt,
                api_name="/infer"
            )
        elif model == "anypose":
            if len(gradio_images) < 2:
                raise HTTPException(status_code=400, detail="AnyPose requires two images: Reference and Pose")
            result = client.predict(
                reference_image=gradio_images[0]["image"],
                pose_image=gradio_images[1]["image"],
                prompt=prompt,
                seed=seed,
                randomize_seed=randomize_seed,
                true_guidance_scale=guidance_scale,
                num_inference_steps=steps,
                height=height,
                width=width,
                rewrite_prompt=rewrite_prompt,
                api_name="/infer"
            )
        elif model == "qwen3_vl":
            if not gradio_images:
                raise HTTPException(status_code=400, detail="Image required for Qwen3-VL")
            # This Space takes text-generation style parameters, so the
            # generic form fields are adapted: steps -> max_new_tokens,
            # guidance_scale -> temperature.
            result = client.predict(
                text=prompt,
                image=gradio_images[0]["image"],
                max_new_tokens=steps * 100,  # Adapting steps to tokens
                temperature=guidance_scale,
                top_p=0.9,
                top_k=50,
                repetition_penalty=1.1,
                gpu_timeout=60,
                api_name="/generate_image"
            )
        elif model == "flux":
            if not gradio_images:
                raise HTTPException(status_code=400, detail="Images required for Flux")
            result = client.predict(
                input_images=gradio_images,
                prompt=prompt,
                style_name=style_name,
                seed=seed,
                randomize_seed=randomize_seed,
                guidance_scale=guidance_scale,
                steps=steps,
                api_name="/infer"
            )
        elif model == "3d_camera":
            if not gradio_images:
                raise HTTPException(status_code=400, detail="Image required for 3D Camera")
            result = client.predict(
                image=gradio_images[0]["image"],
                azimuth=azimuth,
                elevation=elevation,
                distance=distance,
                seed=seed,
                randomize_seed=randomize_seed,
                guidance_scale=guidance_scale,
                num_inference_steps=steps,
                height=height,
                width=width,
                api_name="/infer_camera_edit"
            )
        elif model == "turbo":
            # Text-to-image: no input images needed.
            result = client.predict(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=steps,
                seed=seed,
                randomize_seed=randomize_seed,
                api_name="/generate_image"
            )
        elif model == "banana":
            # Nano Banana (Gemini 2.5 Flash Image) uses fn_index 2 with
            # prompt, image list and token.
            result = client.predict(
                prompt=prompt,
                images=gradio_images if gradio_images else [],
                oauth_token=HF_TOKEN,
                fn_index=2
            )
        else:
            # Defensive: clients{} keys and this dispatch must stay in sync.
            raise HTTPException(status_code=400, detail=f"Model {model} not connected")

        print(f"DEBUG: Result from {model}: {result}")

        # Most Spaces return (image, seed); some return a single value.
        if isinstance(result, (list, tuple)):
            output_image_data = result[0]
            result_seed = result[1] if len(result) > 1 else seed
        else:
            output_image_data = result
            result_seed = seed

        gradio_temp_path = _extract_image_path(output_image_data)

        output_filename = f"{model}_edited_{uuid.uuid4()}.webp"
        final_output_path = os.path.join(OUTPUT_DIR, output_filename)
        shutil.copy(gradio_temp_path, final_output_path)

        return {
            "success": True,
            "images": [f"/outputs/{output_filename}"],
            "seed": str(result_seed)
        }
    except HTTPException:
        # Preserve deliberate 4xx responses instead of collapsing them to 500.
        raise
    except Exception as e:
        print(f"Inference error ({model}): {e}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        for path in temp_paths:
            if os.path.exists(path):
                os.remove(path)


if __name__ == "__main__":
    import uvicorn
    # Hugging Face Spaces uses port 7860 by default
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)