Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks, Form | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import os | |
| import shutil | |
| import uuid | |
| import requests | |
| from dotenv import load_dotenv | |
load_dotenv()

# --- Storage layout ----------------------------------------------------------
# storage/uploads   : photos exactly as uploaded by the client
# storage/processed : background-removed PNGs and depth maps
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
UPLOAD_DIR = os.path.join(BASE_DIR, "storage", "uploads")
PROCESSED_DIR = os.path.join(BASE_DIR, "storage", "processed")
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(PROCESSED_DIR, exist_ok=True)

# --- Application -------------------------------------------------------------
app = FastAPI(title="Experience Eats 2.5D Processing API")

# CORS: only the local frontend dev server (localhost:3000) may call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Expose the whole storage/ tree (uploads and processed outputs) at /storage/...
# The parent of UPLOAD_DIR is BASE_DIR/storage.
app.mount(
    "/storage",
    StaticFiles(directory=os.path.dirname(UPLOAD_DIR)),
    name="storage",
)
# Depth estimator: Depth Anything (V1, small checkpoint) through the Hugging
# Face transformers "depth-estimation" pipeline. It is loaded once, at import
# time. If transformers is unavailable or loading fails for any reason, the
# error is printed and `depth_estimator` stays None.
depth_estimator = None
try:
    from transformers import pipeline

    print("Loading Depth Anything model... (this may take a minute on first run)")
    # V1 is used because it has native transformers pipeline support.
    depth_estimator = pipeline(
        task="depth-estimation",
        model="LiheYoung/depth-anything-small-hf",
    )
    print("Depth model loaded successfully!")
except Exception as e:  # deliberately broad: any load failure just disables depth
    print(f"Warning: Failed to load depth estimator. {e}")
def generate_depth_map(input_path: str, output_path: str):
    """Generate a depth map for an image using Depth Anything (V1, small).

    Transparency in the input is flattened onto a white background before
    inference, because the depth pipeline is given an RGB image.

    When no real depth map can be produced (the estimator failed to load, or
    inference raised), a placeholder is written instead. The placeholder is a
    byte-for-byte copy of ``input_path`` (a "simulated" depth map), so
    ``output_path`` still exists for any URL that already points at it.

    Args:
        input_path: Path of the image to estimate depth for.
        output_path: Where the depth map (or placeholder) is saved.

    Returns:
        True if a real depth map was saved, False if the placeholder path
        was taken.
    """
    if depth_estimator is None:
        print("Depth estimator not loaded, simulating depth map.")
        shutil.copy(input_path, output_path)
        return False
    try:
        from PIL import Image

        with Image.open(input_path) as image:
            has_alpha = (
                image.mode in ("RGBA", "RGBa", "LA", "La", "PA")
                or "transparency" in image.info
            )
            if has_alpha:
                # Composite onto white so transparent pixels don't become black.
                rgba = image.convert("RGBA")
                rgb_image = Image.new("RGB", rgba.size, (255, 255, 255))
                rgb_image.paste(rgba, mask=rgba.getchannel("A"))
            elif image.mode != "RGB":
                rgb_image = image.convert("RGB")
            else:
                rgb_image = image
            # Run inference while the file is still open: for plain RGB input,
            # `rgb_image` is the lazily loaded, file-backed image itself.
            depth_img = depth_estimator(rgb_image)["depth"]
        depth_img.save(output_path)
        return True
    except Exception as e:
        print(f"Depth generation failed: {e}")
        # Write the same placeholder as the "not loaded" branch so the depth
        # file still exists. This fallback must never raise itself.
        try:
            shutil.copy(input_path, output_path)
        except OSError as copy_error:
            print(f"Could not write placeholder depth map: {copy_error}")
        return False
# Shared rembg session, created lazily by get_rmbg_session() so the
# background-removal model is not reloaded for every image.
rmbg_session = None


def get_rmbg_session():
    """Return the shared rembg session, creating it on first use.

    The session wraps rembg's default ``u2net`` model. If creating it fails,
    the error is printed and None is returned. Because the cached value then
    stays None, the next call will try to load the model again.
    """
    global rmbg_session
    if rmbg_session is not None:
        return rmbg_session
    try:
        from rembg import new_session

        print("Loading local AI background removal model... (this may take a minute on first run)")
        rmbg_session = new_session('u2net')
        print("Local AI Background removal model loaded successfully!")
    except Exception as e:
        print(f"Failed to load local AI background remover: {e}")
    return rmbg_session
def remove_background(input_path: str, output_path: str):
    """Remove the background from an image with local AI (rembg / u2net).

    No API keys are needed. The result is saved to ``output_path`` as a PNG
    so transparency is preserved.

    This is best-effort and degrades in steps:
      1. If the rembg session could not be loaded, the image is re-encoded
         to PNG unchanged.
      2. If anything raises (missing packages, unreadable image, ...), the
         image is re-encoded to PNG with Pillow if possible;
      3. otherwise ``input_path`` is copied byte-for-byte.

    Returns:
        True whenever ``output_path`` was written. This does NOT mean the
        background was actually removed; the fallbacks return True too.

    Raises:
        Whatever ``shutil.copy`` raises, only if even the final copy fails.
    """
    try:
        from PIL import Image
        from rembg import remove

        with Image.open(input_path) as img:
            session = get_rmbg_session()
            if session is not None:
                # Remove the background locally; PNG keeps the alpha channel.
                remove(img, session=session).save(output_path, format="PNG")
            else:
                # Model unavailable: pass the image through unchanged.
                img.save(output_path, format="PNG")
        return True
    except Exception as e:
        print(f"Local AI Background removal failed: {e}")

    # Graceful fallback: re-encode with Pillow if possible, else plain copy.
    try:
        from PIL import Image

        with Image.open(input_path) as img:
            img.save(output_path, format="PNG")
    except Exception:
        shutil.copy(input_path, output_path)
    return True
| from fastapi import BackgroundTasks | |
| from typing import Dict, Any, List | |
# Simple in-memory job-status store, keyed by job_id.
# NOTE: this dict lives in a single process. Status is lost on restart and is
# not shared between multiple server workers. In production, this would be a
# database (Redis/Postgres).
jobs_db: Dict[str, Any] = {}


def process_photos_background(job_id: str, files_data: list, job_upload_dir: str, job_processed_dir: str):
    """Background task: remove backgrounds and build depth maps for one job.

    Runs outside the request so the upload endpoint returns immediately
    instead of tripping proxy timeouts.

    For the i-th ``(safe_filename, input_file_path)`` pair in ``files_data``,
    writes ``angle_{i:02d}_nobg.png`` and ``angle_{i:02d}_depth.png`` into
    ``job_processed_dir``.

    Progress is reported through ``jobs_db[job_id]``:
      - ``"processing"`` while running;
      - ``{"status": "success", "layers": [...]}`` when done;
      - ``{"status": "error", "message": ...}`` on failure.

    Each layer is ``{"angle": i, "image_url": ..., "depth_url": ...}``, with
    URLs under the ``/storage`` static mount.
    """
    try:
        jobs_db[job_id]["status"] = "processing"

        # Job path relative to the uploads root (e.g. shop/category/job_id).
        # The processed tree mirrors the same layout. URLs always use "/",
        # whatever the OS path separator is.
        rel_path_to_job = os.path.relpath(job_upload_dir, UPLOAD_DIR).replace(os.sep, "/")

        processed_files = []
        for i, (_safe_filename, input_file_path) in enumerate(files_data):
            output_file_path = os.path.join(job_processed_dir, f"angle_{i:02d}_nobg.png")
            depth_file_path = os.path.join(job_processed_dir, f"angle_{i:02d}_depth.png")

            # 1. Try to remove the background.
            bg_success = remove_background(input_file_path, output_file_path)

            # 2. Generate the depth map. Use the background-removed PNG when it
            #    was written, otherwise fall back to the original upload (which
            #    lives under 'uploads' rather than 'processed').
            use_processed = bg_success and os.path.exists(output_file_path)
            source_for_depth = output_file_path if use_processed else input_file_path
            source_folder = "processed" if use_processed else "uploads"
            generate_depth_map(source_for_depth, depth_file_path)

            processed_files.append({
                "angle": i,
                "image_url": f"/storage/{source_folder}/{rel_path_to_job}/{os.path.basename(source_for_depth)}",
                "depth_url": f"/storage/processed/{rel_path_to_job}/{os.path.basename(depth_file_path)}",
            })

        # Mark the job as complete.
        jobs_db[job_id] = {
            "status": "success",
            "layers": processed_files,
        }
    except Exception as e:
        import traceback
        traceback.print_exc()
        jobs_db[job_id] = {
            "status": "error",
            "message": str(e),
        }
def read_root():
    """Health check: report that the backend is up."""
    # NOTE(review): no route decorator is visible for this handler here;
    # confirm it is registered (presumably @app.get("/")).
    payload = {"status": "ok", "message": "Experience Eats Backend is running"}
    return payload
async def process_dish_photos(
    background_tasks: BackgroundTasks,
    shop_slug: str = Form(...),
    category: str = Form("uncategorized"),
    files: List[UploadFile] = File(...),
):
    """
    Receives 12 photos of a dish, saves them, and starts the 2.5D processing pipeline in the background.

    Photos are stored as angle_00<ext> ... angle_11<ext> under
    uploads/<shop_slug>/<category>/<job_id>/. Progress can be polled with the
    returned job_id.

    Responds with 400 if there are not exactly 12 files, if any file is not
    declared as an image, or if shop_slug/category would place files outside
    the storage directories.
    """
    # NOTE(review): no route decorator is visible for this handler here;
    # confirm it is registered with the app.
    if len(files) != 12:
        raise HTTPException(status_code=400, detail="Exactly 12 photos are required")

    # Validate every file before touching the disk, so a bad upload does not
    # leave a half-written job directory behind. content_type is supplied by
    # the client and may be missing entirely.
    for file in files:
        if not (file.content_type or "").startswith("image/"):
            raise HTTPException(status_code=400, detail=f"File {file.filename} is not an image")

    # Generate common job ID
    job_id = str(uuid.uuid4())

    # shop_slug and category come straight from the form. Joined unchecked,
    # "../" segments or absolute paths would escape the storage directories.
    job_upload_dir = os.path.join(UPLOAD_DIR, shop_slug, category, job_id)
    job_processed_dir = os.path.join(PROCESSED_DIR, shop_slug, category, job_id)
    if not (
        _is_inside_dir(UPLOAD_DIR, job_upload_dir)
        and _is_inside_dir(PROCESSED_DIR, job_processed_dir)
    ):
        raise HTTPException(status_code=400, detail="Invalid shop_slug or category")
    os.makedirs(job_upload_dir, exist_ok=True)
    os.makedirs(job_processed_dir, exist_ok=True)

    # Save uploads before handing off to the background task, because the
    # UploadFile objects are not usable once the request has finished.
    # NOTE: these are blocking writes inside an async handler.
    files_data = []
    for i, file in enumerate(files):
        safe_filename = f"angle_{i:02d}{_safe_image_extension(file.filename)}"
        input_file_path = os.path.join(job_upload_dir, safe_filename)
        with open(input_file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        files_data.append((safe_filename, input_file_path))

    # Give initial status, then send to the background task.
    jobs_db[job_id] = {"status": "pending"}
    background_tasks.add_task(process_photos_background, job_id, files_data, job_upload_dir, job_processed_dir)

    return {
        "status": "accepted",
        "job_id": job_id,
        "message": "Processing started in the background. Poll /api/job-status/{job_id} for completion."
    }


def _is_inside_dir(root: str, path: str) -> bool:
    """Return True if ``path`` resolves to ``root`` itself or to a location beneath it."""
    try:
        real_root = os.path.realpath(root)
        return os.path.commonpath([real_root, os.path.realpath(path)]) == real_root
    except ValueError:
        # e.g. paths on different Windows drives, or an embedded NUL byte.
        return False


def _safe_image_extension(filename) -> str:
    """Return the upload's extension if it is a known raster-image type, else ".jpg".

    Stored uploads are served back through the /storage static mount. Keeping
    an arbitrary client-chosen extension (".html", ".svg", ...) would let an
    upload be served as active content.
    """
    extension = os.path.splitext(filename or "")[1]
    allowed = (".jpg", ".jpeg", ".png", ".webp", ".gif", ".bmp",
               ".tif", ".tiff", ".heic", ".heif", ".avif")
    if extension.lower() in allowed:
        return extension
    return ".jpg"  # fallback
def get_job_status(job_id: str):
    """
    Return the stored status record for a long-running 2.5D crop/depth job,
    so the frontend can poll until it finishes.

    Raises HTTPException(404) when the job id is unknown.
    """
    # NOTE(review): no route decorator is visible for this handler here;
    # clients are told to poll /api/job-status/{job_id} — confirm registration.
    try:
        return jobs_db[job_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Job not found") from None
if __name__ == "__main__":
    import uvicorn

    # reload=True requires an import string rather than the app object.
    # Derive the module name from this file so the command works whatever the
    # file is called (e.g. main.py or app.py).
    # NOTE: with reload enabled, uvicorn re-imports this module in a worker
    # process, so the module-level model loading may run more than once.
    module_name = os.path.splitext(os.path.basename(__file__))[0]
    uvicorn.run(f"{module_name}:app", host="0.0.0.0", port=8000, reload=True)