# NOTE: "Spaces: Running" header residue from the Hugging Face Spaces page this
# file was copied from — not part of the program.
import os
import shutil
import uuid
from typing import List

from dotenv import load_dotenv
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from gradio_client import Client, handle_file
# Pull environment variables (notably HF_TOKEN) from a local .env file.
load_dotenv()

app = FastAPI()

# The frontend may be served from another origin, so allow everything.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Resolve all working directories relative to this file so the app does not
# depend on the current working directory.
base_dir = os.path.dirname(os.path.abspath(__file__))
UPLOAD_DIR = os.path.join(base_dir, "temp_uploads")
OUTPUT_DIR = os.path.join(base_dir, "outputs")
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Finished images are exposed statically under /outputs.
app.mount("/outputs", StaticFiles(directory=OUTPUT_DIR), name="outputs")

HF_TOKEN = os.environ.get("HF_TOKEN")

# Registry of gradio_client connections keyed by short model name.
# A value of None means the Space failed to connect at startup.
clients = {}
def init_clients():
    """Open a gradio_client connection to every configured Hugging Face Space.

    Populates the module-level ``clients`` dict in place.  A Space that
    cannot be reached is recorded as ``None`` so the rest of the app can
    report it as offline instead of crashing at startup.
    """
    spaces = {
        "firered": "prithivMLmods/FireRed-Image-Edit-1.0-Fast",
        "qwen": "prithivMLmods/Qwen-Image-Edit-2511-LoRAs-Fast",
        "flux": "prithivMLmods/FLUX.2-Klein-LoRA-Studio",
        "turbo": "mrfakename/Z-Image-Turbo",
        "banana": "multimodalart/nano-banana",
        "3d_camera": "multimodalart/qwen-image-multiple-angles-3d-camera",
        "qwen_rapid": "IllyaS08/qwen-image-edit-rapid-aio-sfw-v23",
        "anypose": "linoyts/Qwen-Image-Edit-2511-anypose",
        "qwen3_vl": "prithivMLmods/Qwen3-VL-abliterated-MAX-Fast",
    }
    for name, repo in spaces.items():
        try:
            print(f"DEBUG: Conectando a {repo}...")
            clients[name] = Client(repo, token=HF_TOKEN)
            print(f"DEBUG: {name} conectado.")
        except Exception as exc:
            print(f"Error connecting to {name}: {exc}")
            clients[name] = None


init_clients()
# NOTE(review): the route decorator appears to have been lost in the paste that
# produced this file — without it the handler is never registered.  "/" is the
# natural path for serving the SPA entry point; confirm against the frontend.
@app.get("/")
async def read_index():
    """Serve the single-page frontend (index.html next to this file)."""
    return FileResponse(os.path.join(base_dir, 'index.html'))
# NOTE(review): route decorator apparently lost in formatting; "/status" is a
# guess consistent with the handler's purpose — confirm against the frontend.
@app.get("/status")
async def get_status():
    """Report which model backends connected successfully at startup.

    Returns:
        dict mapping each model key to True (connected) or False (offline).
    """
    return {key: client is not None for key, client in clients.items()}
def _extract_image_path(output_image_data):
    """Return a filesystem path (or URL) for one Gradio result item, or None.

    Spaces return images in several shapes: a plain string path, a dict with
    'path'/'name'/'url' keys, or a dict nesting another image dict (or a
    string) under an 'image' key.  This normalizes all of those.
    """
    if isinstance(output_image_data, dict):
        # Try common Gradio keys for image paths
        path = output_image_data.get('path') or output_image_data.get('name') or output_image_data.get('url')
        # Special case: nested 'image' key (found in some models)
        if not path and 'image' in output_image_data:
            img_val = output_image_data['image']
            if isinstance(img_val, str):
                path = img_val
            elif isinstance(img_val, dict):
                path = img_val.get('path') or img_val.get('name')
        return path
    if isinstance(output_image_data, str):
        return output_image_data
    return None


# NOTE(review): route decorator apparently lost in formatting; "/edit" is a
# guess consistent with the handler's purpose — confirm against the frontend.
@app.post("/edit")
async def edit_image(
    images: List[UploadFile] = File(None),
    prompt: str = Form(...),
    model: str = Form("firered"),
    lora_adapter: str = Form("Photo-to-Anime"),
    style_name: str = Form("None"),
    seed: int = Form(0),
    randomize_seed: bool = Form(True),
    guidance_scale: float = Form(1.0),
    steps: int = Form(4),
    width: int = Form(1024),
    height: int = Form(1024),
    azimuth: float = Form(0),
    elevation: float = Form(0),
    distance: float = Form(1.0),
    rewrite_prompt: bool = Form(False)
):
    """Dispatch an image-edit request to the selected Hugging Face Space.

    Saves the uploaded images to temp files, forwards them (plus the
    model-specific parameters) to the Space's Gradio API, copies the
    resulting image into OUTPUT_DIR, and returns its public /outputs URL.

    Raises:
        HTTPException 503: the requested model never connected at startup.
        HTTPException 400: the model requires images and none were supplied.
        HTTPException 500: any inference or result-extraction failure.
    """
    if model not in clients or not clients[model]:
        raise HTTPException(status_code=503, detail=f"Model {model} not connected")
    temp_paths = []
    try:
        # Persist every upload to disk so gradio_client can stream it.
        gradio_images = []
        if images:
            for img in images:
                temp_filename = f"{uuid.uuid4()}_{img.filename}"
                temp_path = os.path.join(UPLOAD_DIR, temp_filename)
                with open(temp_path, "wb") as buffer:
                    shutil.copyfileobj(img.file, buffer)
                temp_paths.append(temp_path)
                gradio_images.append({"image": handle_file(temp_path), "caption": None})

        client = clients[model]

        # Each Space exposes a different signature; dispatch accordingly.
        if model == "firered":
            if not gradio_images:
                raise HTTPException(status_code=400, detail="Images required for FireRed")
            result = client.predict(
                images=gradio_images,
                prompt=prompt,
                seed=seed,
                randomize_seed=randomize_seed,
                guidance_scale=guidance_scale,
                steps=steps,
                api_name="/infer"
            )
        elif model == "qwen":
            if not gradio_images:
                raise HTTPException(status_code=400, detail="Images required for Qwen")
            result = client.predict(
                images=gradio_images,
                prompt=prompt,
                lora_adapter=lora_adapter,
                seed=seed,
                randomize_seed=randomize_seed,
                guidance_scale=guidance_scale,
                steps=steps,
                api_name="/infer"
            )
        elif model == "qwen_rapid":
            if not gradio_images:
                raise HTTPException(status_code=400, detail="Images required for Qwen Rapid")
            # Parameters per the Space documentation: images, prompt, seed,
            # randomize_seed, true_guidance_scale, num_inference_steps,
            # height, width, rewrite_prompt.
            result = client.predict(
                images=gradio_images,
                prompt=prompt,
                seed=seed,
                randomize_seed=randomize_seed,
                true_guidance_scale=guidance_scale,
                num_inference_steps=steps,
                height=height,
                width=width,
                rewrite_prompt=rewrite_prompt,
                api_name="/infer"
            )
        elif model == "anypose":
            # AnyPose needs exactly two inputs: reference first, pose second.
            if len(gradio_images) < 2:
                raise HTTPException(status_code=400, detail="AnyPose requires two images: Reference and Pose")
            result = client.predict(
                reference_image=gradio_images[0]["image"],
                pose_image=gradio_images[1]["image"],
                prompt=prompt,
                seed=seed,
                randomize_seed=randomize_seed,
                true_guidance_scale=guidance_scale,
                num_inference_steps=steps,
                height=height,
                width=width,
                rewrite_prompt=rewrite_prompt,
                api_name="/infer"
            )
        elif model == "qwen3_vl":
            if not gradio_images:
                raise HTTPException(status_code=400, detail="Image required for Qwen3-VL")
            # This Space takes text-generation style parameters; steps and
            # guidance_scale are repurposed as max_new_tokens and temperature.
            result = client.predict(
                text=prompt,
                image=gradio_images[0]["image"],
                max_new_tokens=steps * 100,  # Adapting steps to tokens
                temperature=guidance_scale,
                top_p=0.9,
                top_k=50,
                repetition_penalty=1.1,
                gpu_timeout=60,
                api_name="/generate_image"
            )
        elif model == "flux":
            if not gradio_images:
                raise HTTPException(status_code=400, detail="Images required for Flux")
            result = client.predict(
                input_images=gradio_images,
                prompt=prompt,
                style_name=style_name,
                seed=seed,
                randomize_seed=randomize_seed,
                guidance_scale=guidance_scale,
                steps=steps,
                api_name="/infer"
            )
        elif model == "3d_camera":
            if not gradio_images:
                raise HTTPException(status_code=400, detail="Image required for 3D Camera")
            result = client.predict(
                image=gradio_images[0]["image"],
                azimuth=azimuth,
                elevation=elevation,
                distance=distance,
                seed=seed,
                randomize_seed=randomize_seed,
                guidance_scale=guidance_scale,
                num_inference_steps=steps,
                height=height,
                width=width,
                api_name="/infer_camera_edit"
            )
        elif model == "turbo":
            # Text-to-image: no input images required.
            result = client.predict(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=steps,
                seed=seed,
                randomize_seed=randomize_seed,
                api_name="/generate_image"
            )
        elif model == "banana":
            # Nano Banana (Gemini 2.5 Flash Image)
            # uses fn_index 2 with prompt, image list and token
            result = client.predict(
                prompt=prompt,
                images=gradio_images if gradio_images else [],
                oauth_token=HF_TOKEN,
                fn_index=2
            )

        print(f"DEBUG: Result from {model}: {result}")
        output_image_data = result[0]
        # Some models return a list of images as the first element
        if isinstance(output_image_data, list) and len(output_image_data) > 0:
            output_image_data = output_image_data[0]

        gradio_temp_path = _extract_image_path(output_image_data)
        if not gradio_temp_path:
            raise Exception(f"Could not extract image path from result: {output_image_data}")

        output_filename = f"{model}_edited_{uuid.uuid4()}.webp"
        final_output_path = os.path.join(OUTPUT_DIR, output_filename)
        # NOTE(review): if the Space returned a remote 'url' rather than a local
        # path this copy will fail; presumably gradio_client downloads results
        # locally first — confirm.
        shutil.copy(gradio_temp_path, final_output_path)
        return {
            "success": True,
            "images": [f"/outputs/{output_filename}"],
            # NOTE(review): assumes every Space returns (image, seed, ...);
            # unverified for "banana" and "qwen3_vl".
            "seed": str(result[1])
        }
    except HTTPException:
        # Bug fix: the 400 validation errors raised above were previously
        # swallowed by the generic handler and re-surfaced as 500s.
        raise
    except Exception as e:
        print(f"Inference error ({model}): {e}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Always clean up the temporary upload files, success or failure.
        for path in temp_paths:
            if os.path.exists(path):
                os.remove(path)
if __name__ == "__main__":
    # uvicorn is only needed when running this module directly.
    import uvicorn

    # Hugging Face Spaces uses port 7860 by default; PORT overrides it.
    # (Redundant `import os` removed — os is already imported at module level.)
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)