# experience-eats / main.py
# Author: Hasthika — "New U2net Rem Bg Model" (commit 9c3faeb, verified)
from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks, Form
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
import os
import shutil
import uuid
import requests
from dotenv import load_dotenv
# Pull configuration (e.g. API keys) from a local .env file before anything
# else reads the environment.
load_dotenv()

# FastAPI application exposing the 2.5D dish-processing pipeline.
app = FastAPI(title="Experience Eats 2.5D Processing API")

# Configure CORS for local development
# NOTE(review): origin is hard-coded to the local frontend dev server;
# production deployments will need their origin(s) added here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Setup directories
# All artifacts live under <this file's dir>/storage: raw uploads and
# processed (background-removed / depth) outputs.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
UPLOAD_DIR = os.path.join(BASE_DIR, "storage", "uploads")
PROCESSED_DIR = os.path.join(BASE_DIR, "storage", "processed")
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(PROCESSED_DIR, exist_ok=True)

# Mount static files to serve images
# Exposes both uploads/ and processed/ so the frontend can fetch the URLs
# returned by the processing job directly.
app.mount("/storage", StaticFiles(directory=os.path.join(BASE_DIR, "storage")), name="storage")
# Initialize Depth Estimator
# Loaded once at import time. If transformers (or the model download) fails,
# depth_estimator stays None and generate_depth_map() degrades gracefully.
depth_estimator = None
try:
    from transformers import pipeline
    print("Loading Depth Anything model... (this may take a minute on first run)")
    # Using the V1 model which has native Hugging Face transformers pipeline support
    depth_estimator = pipeline(task="depth-estimation", model="LiheYoung/depth-anything-small-hf")
    print("Depth model loaded successfully!")
except Exception as e:
    print(f"Warning: Failed to load depth estimator. {e}")
def generate_depth_map(input_path: str, output_path: str) -> bool:
    """Generate a depth map for the image at ``input_path``.

    Uses the globally loaded Depth Anything V1 pipeline
    (``LiheYoung/depth-anything-small-hf``).  If the model is unavailable,
    the source image is copied to ``output_path`` as a stand-in so that
    downstream URLs remain valid.

    Args:
        input_path: Path to the source image (any mode PIL can open).
        output_path: Destination path for the depth-map PNG.

    Returns:
        True if a real depth map was produced, False when simulated or on error.
    """
    if not depth_estimator:
        print("Depth estimator not loaded, simulating depth map.")
        shutil.copy(input_path, output_path)
        return False
    try:
        from PIL import Image

        # Close the file handle promptly; work on an in-memory RGB copy.
        with Image.open(input_path) as src:
            if src.mode == 'RGBA':
                # Depth models expect RGB: flatten transparency onto white.
                image = Image.new('RGB', src.size, (255, 255, 255))
                image.paste(src, mask=src.split()[3])  # index 3 is the alpha channel
            elif src.mode != 'RGB':
                image = src.convert('RGB')
            else:
                image = src.copy()
        result = depth_estimator(image)
        # The pipeline returns a dict; "depth" holds a PIL image of the map.
        result["depth"].save(output_path)
        return True
    except Exception as e:
        print(f"Depth generation failed: {e}")
        return False
# remove_bg_exhausted boolean no longer needed since we use local AI
# Lazily initialized rembg session, shared process-wide so the u2net model
# is loaded at most once instead of per image.
rmbg_session = None


def get_rmbg_session():
    """Return the shared rembg session, loading the u2net model on first use.

    Returns None (after logging the error) if rembg cannot be imported or the
    model cannot be loaded; a later call will retry.
    """
    global rmbg_session
    if rmbg_session is not None:
        return rmbg_session
    try:
        from rembg import new_session
        # Using the default u2net model which offers exceptional quality
        # equivalent to RMBG-1.4 but strictly compatible with this environment
        print("Loading local AI background removal model... (this may take a minute on first run)")
        rmbg_session = new_session('u2net')
        print("Local AI Background removal model loaded successfully!")
    except Exception as exc:
        print(f"Failed to load local AI background remover: {exc}")
    return rmbg_session
def remove_background(input_path: str, output_path: str) -> bool:
    """Remove an image's background locally with rembg/u2net. No API keys needed!

    Always writes *something* to ``output_path``: the transparent cut-out PNG
    when the model works, a PNG re-encode when the session is unavailable, or
    a raw byte copy if even PIL fails.

    Args:
        input_path: Path of the image to process.
        output_path: Destination path (saved as PNG to keep transparency).

    Returns:
        True in every handled case; callers additionally verify the output
        file exists before using it.
    """
    try:
        from PIL import Image
        from rembg import remove

        with Image.open(input_path) as img:
            session = get_rmbg_session()
            if session:
                # Remove background locally; save as PNG to keep transparency.
                remove(img, session=session).save(output_path, format="PNG")
            else:
                # Session could not be loaded: pass the image through untouched.
                img.save(output_path, format="PNG")
        return True
    except Exception as e:
        print(f"Local AI Background removal failed: {e}")
        # Graceful fallback: re-encode as PNG, or byte-copy if PIL itself
        # was the failure.  (Was a bare `except:`; narrowed to Exception so
        # KeyboardInterrupt/SystemExit are no longer swallowed.)
        try:
            from PIL import Image

            with Image.open(input_path) as img:
                img.save(output_path, format="PNG")
        except Exception:
            shutil.copy(input_path, output_path)
        return True
# NOTE(review): BackgroundTasks is already imported at the top of the file;
# this re-import is redundant but harmless.
from fastapi import BackgroundTasks
from typing import Dict, Any, List

# Simple in-memory storage for job status
# In production, this would be a database (Redis/Postgres)
# Keyed by job_id (uuid4 string); values are dicts with at least a "status"
# key ("pending" | "processing" | "success" | "error").
jobs_db: Dict[str, Any] = {}
def process_photos_background(job_id: str, files_data: list, job_upload_dir: str, job_processed_dir: str) -> None:
    """Background task that runs the 2.5D pipeline for one job.

    For each saved upload: remove the background, then derive a depth map,
    and publish `/storage/...` URLs into ``jobs_db``.  Running off the
    request thread avoids blocking the API and triggering proxy timeouts.

    Args:
        job_id: Key into ``jobs_db`` used for status updates.
        files_data: List of ``(safe_filename, absolute_input_path)`` tuples.
        job_upload_dir: Directory holding this job's raw uploads.
        job_processed_dir: Directory receiving processed outputs.
    """
    try:
        jobs_db[job_id]["status"] = "processing"
        # URL path of this job's folder relative to the storage root.  The
        # directory layout is identical under uploads/ and processed/, so the
        # same relative path works for both URL prefixes.  Loop-invariant, so
        # computed once here instead of per image.
        rel_path_to_job = os.path.relpath(job_upload_dir, UPLOAD_DIR)
        processed_files = []
        for i, (_, input_file_path) in enumerate(files_data):
            output_file_path = os.path.join(job_processed_dir, f"angle_{i:02d}_nobg.png")
            depth_file_path = os.path.join(job_processed_dir, f"angle_{i:02d}_depth.png")
            # 1. Try to remove the background.
            bg_success = remove_background(input_file_path, output_file_path)
            # Use the cut-out only if removal succeeded AND the file exists;
            # computed once so the depth source and the URL prefix can't drift.
            use_processed = bg_success and os.path.exists(output_file_path)
            source_for_depth = output_file_path if use_processed else input_file_path
            # 2. Generate the depth map from whichever image we settled on.
            generate_depth_map(source_for_depth, depth_file_path)
            # The URL folder prefix must match where source_for_depth lives.
            source_folder = "processed" if use_processed else "uploads"
            processed_files.append({
                "angle": i,
                "image_url": f"/storage/{source_folder}/{rel_path_to_job}/{os.path.basename(source_for_depth)}",
                "depth_url": f"/storage/processed/{rel_path_to_job}/{os.path.basename(depth_file_path)}",
            })
        # Update job as complete (replace the whole record in one assignment).
        jobs_db[job_id] = {
            "status": "success",
            "layers": processed_files,
        }
    except Exception as e:
        import traceback
        traceback.print_exc()
        jobs_db[job_id] = {
            "status": "error",
            "message": str(e),
        }
@app.get("/")
def read_root():
    """Health-check endpoint confirming the backend is running."""
    payload = {"status": "ok", "message": "Experience Eats Backend is running"}
    return payload
@app.post("/api/process-dish")
async def process_dish_photos(
    background_tasks: BackgroundTasks,
    shop_slug: str = Form(...),
    category: str = Form("uncategorized"),
    files: List[UploadFile] = File(...)
):
    """
    Receives 12 photos of a dish, saves them, and starts the 2.5D processing pipeline in the background.

    Returns a job_id immediately; clients poll /api/job-status/{job_id}.
    Raises 400 if the count is not exactly 12 or any file is not an image.
    """
    if len(files) != 12:
        raise HTTPException(status_code=400, detail="Exactly 12 photos are required")

    # Generate common job ID
    job_id = str(uuid.uuid4())

    # Ensure nested shop directory structure: storage/<area>/<shop>/<category>/<job>.
    job_upload_dir = os.path.join(UPLOAD_DIR, shop_slug, category, job_id)
    job_processed_dir = os.path.join(PROCESSED_DIR, shop_slug, category, job_id)
    os.makedirs(job_upload_dir, exist_ok=True)
    os.makedirs(job_processed_dir, exist_ok=True)

    files_data = []
    # Save uploaded files synchronously before passing to background task,
    # because UploadFile handles are closed once the request ends.
    for i, file in enumerate(files):
        # Validate format.  content_type can be None if the client omitted it,
        # so guard before startswith() to return a clean 400 instead of a 500.
        if not file.content_type or not file.content_type.startswith("image/"):
            raise HTTPException(status_code=400, detail=f"File {file.filename} is not an image")

        file_extension = os.path.splitext(file.filename)[1]
        if not file_extension:
            file_extension = ".jpg"  # fallback
        # Index-based names prevent path traversal via client filenames.
        safe_filename = f"angle_{i:02d}{file_extension}"
        input_file_path = os.path.join(job_upload_dir, safe_filename)

        with open(input_file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        files_data.append((safe_filename, input_file_path))

    # Give initial status
    jobs_db[job_id] = {"status": "pending"}

    # Send to background task
    background_tasks.add_task(process_photos_background, job_id, files_data, job_upload_dir, job_processed_dir)

    return {
        "status": "accepted",
        "job_id": job_id,
        "message": "Processing started in the background. Poll /api/job-status/{job_id} for completion."
    }
@app.get("/api/job-status/{job_id}")
def get_job_status(job_id: str):
    """
    Endpoint for the frontend to poll the status of a long-running 2.5D crop/depth job.
    """
    # EAFP: a single lookup instead of a membership test plus an index.
    try:
        return jobs_db[job_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Job not found") from None
if __name__ == "__main__":
    # Dev entry point: `python main.py` serves on all interfaces at :8000.
    # reload=True restarts the server on code changes (development only).
    import uvicorn
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)