File size: 9,143 Bytes
935d1f8
336b59e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c3faeb
336b59e
9c3faeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336b59e
9c3faeb
 
 
 
 
336b59e
9c3faeb
 
 
 
 
336b59e
 
9c3faeb
 
 
 
336b59e
9c3faeb
 
 
 
 
 
 
 
 
336b59e
 
9c3faeb
336b59e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
935d1f8
 
 
336b59e
 
935d1f8
 
336b59e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
935d1f8
 
 
 
9c3faeb
935d1f8
336b59e
 
 
 
 
 
 
 
935d1f8
 
 
 
336b59e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks, Form
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
import os
import shutil
import uuid
import requests
from dotenv import load_dotenv

load_dotenv()

app = FastAPI(title="Experience Eats 2.5D Processing API")

# Configure CORS for local development: browsers may call this API only from
# http://localhost:3000 (presumably the frontend dev server — confirm), with
# credentials allowed and any method/header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Setup directories: original uploads go to storage/uploads and generated images
# to storage/processed. Both are resolved relative to this file, not the CWD.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
UPLOAD_DIR = os.path.join(BASE_DIR, "storage", "uploads")
PROCESSED_DIR = os.path.join(BASE_DIR, "storage", "processed")

# Created at import time, before the StaticFiles mount below, which expects the
# storage directory to exist.
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(PROCESSED_DIR, exist_ok=True)

# Mount static files to serve images: everything under storage/ (uploads and
# processed outputs) is served at /storage/... with no auth in front of it.
app.mount("/storage", StaticFiles(directory=os.path.join(BASE_DIR, "storage")), name="storage")

# Initialize Depth Estimator once, at import time. It stays None when
# transformers is unavailable or the model fails to load; generate_depth_map()
# checks for None and falls back to a placeholder copy.
depth_estimator = None
try:
    from transformers import pipeline
    print("Loading Depth Anything model... (this may take a minute on first run)")
    # Using the V1 model which has native Hugging Face transformers pipeline support
    depth_estimator = pipeline(task="depth-estimation", model="LiheYoung/depth-anything-small-hf")
    print("Depth model loaded successfully!")
except Exception as e:
    # Best effort: a load failure is only reported; the module keeps importing
    # and the API starts without depth estimation.
    print(f"Warning: Failed to load depth estimator. {e}")

def generate_depth_map(input_path: str, output_path: str):
    """Generate a depth map for an image with the Depth Anything (V1, small) pipeline.

    Any transparency is flattened onto a white background first, because the
    estimator is fed a 3-channel RGB image.

    Args:
        input_path: Path of the source image (e.g. a background-removed PNG).
        output_path: Where the depth map image is saved.

    Returns:
        True if a depth map was produced and saved. False if the estimator is
        not loaded (the input file is then copied to ``output_path`` as a
        stand-in) or if estimation failed (nothing is written in that case).
    """
    if not depth_estimator:
        print("Depth estimator not loaded, simulating depth map.")
        shutil.copy(input_path, output_path)
        return False

    try:
        from PIL import Image

        # The context manager closes the file handle. convert()/paste() load the
        # pixel data, so `image` stays usable after the source file is closed.
        with Image.open(input_path) as opened:
            has_alpha = opened.mode in ("RGBA", "LA", "PA") or "transparency" in opened.info
            if has_alpha:
                # Composite onto white so fully transparent pixels don't expose
                # arbitrary colour data hidden under alpha=0.
                rgba = opened.convert("RGBA")
                image = Image.new("RGB", rgba.size, (255, 255, 255))
                image.paste(rgba, mask=rgba.getchannel("A"))
            else:
                image = opened.convert("RGB")

        result = depth_estimator(image)
        depth_img = result["depth"]
        depth_img.save(output_path)
        return True
    except Exception as e:
        print(f"Depth generation failed: {e}")
        return False

# Background removal now runs locally (rembg), so the old remote-API quota flag (remove_bg_exhausted) is no longer tracked.

# Shared rembg session, created lazily on first use so the u2net model is not
# reloaded for every image.
rmbg_session = None

def get_rmbg_session():
    """Return the shared rembg session, creating it on the first call.

    Returns ``None`` if the session could not be created; a later call will
    attempt the load again.
    """
    global rmbg_session
    if rmbg_session is not None:
        return rmbg_session

    try:
        from rembg import new_session

        # 'u2net' is the rembg model used for all background removal here.
        print("Loading local AI background removal model... (this may take a minute on first run)")
        rmbg_session = new_session('u2net')
        print("Local AI Background removal model loaded successfully!")
    except Exception as load_error:
        print(f"Failed to load local AI background remover: {load_error}")

    return rmbg_session

def remove_background(input_path: str, output_path: str):
    """Remove an image's background locally with rembg (u2net); no API keys needed.

    The result is written to ``output_path`` as a PNG so transparency survives.
    This is best effort:
    - If the rembg session cannot be created, the image is re-saved as a PNG
      unchanged.
    - If rembg/Pillow are unavailable or processing fails, the image is also
      re-saved as a PNG unchanged.
    - As a last resort, the input file is copied byte for byte, so
      ``output_path`` may then hold non-PNG data despite its name.

    Args:
        input_path: Path of the uploaded image.
        output_path: Destination path for the PNG result.

    Returns:
        True whenever something was written to ``output_path``; the value does
        not distinguish a real cut-out from a fallback copy. An error from the
        final ``shutil.copy`` fallback propagates.
    """
    try:
        from PIL import Image
        from rembg import remove

        # The context manager releases the source file handle once we're done.
        with Image.open(input_path) as img:
            session = get_rmbg_session()
            # Without a session, save the untouched image so output_path still exists.
            result = remove(img, session=session) if session else img
            # PNG keeps the alpha channel produced by rembg.
            result.save(output_path, format="PNG")
        return True

    except Exception as e:
        print(f"Local AI Background removal failed: {e}")
        # Graceful fallback: keep the pixels unchanged but still emit a PNG.
        try:
            from PIL import Image

            with Image.open(input_path) as img:
                img.save(output_path, format="PNG")
        except Exception:
            # Pillow missing or the file unreadable/unsaveable as PNG: copy raw bytes.
            # (Exception, not a bare except, so KeyboardInterrupt/SystemExit propagate.)
            shutil.copy(input_path, output_path)
        return True

from fastapi import BackgroundTasks
from typing import Dict, Any, List

# Simple in-memory storage for job status, keyed by job_id.
# It lives only in this Python process: it is emptied on restart and is not
# shared between separate worker processes.
# In production, this would be a database (Redis/Postgres)
jobs_db: Dict[str, Any] = {}

def process_photos_background(job_id: str, files_data: list, job_upload_dir: str, job_processed_dir: str):
    """Run background removal and depth estimation for every uploaded angle of a job.

    Scheduled as a FastAPI background task so the upload request does not block
    on model inference (and proxies don't time out).

    Args:
        job_id: Key of the job's entry in ``jobs_db``; the entry must already exist.
        files_data: ``(safe_filename, input_file_path)`` pairs, one per angle,
            in angle order.
        job_upload_dir: Job directory under ``UPLOAD_DIR`` holding the originals.
        job_processed_dir: Job directory under ``PROCESSED_DIR`` that receives
            ``angle_XX_nobg.png`` and ``angle_XX_depth.png``.

    Side effects:
        Sets ``jobs_db[job_id]["status"]`` to ``"processing"``. The entry is
        then replaced by ``{"status": "success", "layers": [...]}``, or by
        ``{"status": "error", "message": ...}`` if any step raises.
    """
    try:
        jobs_db[job_id]["status"] = "processing"

        # The job's path relative to the storage root (<shop>/<category>/<job_id>)
        # is the same for every angle and both storage trees, so compute it once.
        # Normalise separators: these are URL paths, and os.sep is "\\" on Windows.
        rel_path_to_job = os.path.relpath(job_upload_dir, UPLOAD_DIR).replace(os.sep, "/")

        processed_files = []
        for i, (_safe_filename, input_file_path) in enumerate(files_data):
            output_file_path = os.path.join(job_processed_dir, f"angle_{i:02d}_nobg.png")
            depth_file_path = os.path.join(job_processed_dir, f"angle_{i:02d}_depth.png")

            # 1. Try to remove the background.
            bg_success = remove_background(input_file_path, output_file_path)

            # 2. Build the depth map from the cut-out when it was produced, otherwise
            #    from the original upload. The layer's image URL must point at
            #    whichever file was used, so pick the storage folder alongside it.
            if bg_success and os.path.exists(output_file_path):
                source_for_depth, source_folder = output_file_path, "processed"
            else:
                source_for_depth, source_folder = input_file_path, "uploads"

            generate_depth_map(source_for_depth, depth_file_path)

            processed_files.append({
                "angle": i,
                "image_url": f"/storage/{source_folder}/{rel_path_to_job}/{os.path.basename(source_for_depth)}",
                "depth_url": f"/storage/processed/{rel_path_to_job}/{os.path.basename(depth_file_path)}",
            })

        # Update job as complete
        jobs_db[job_id] = {
            "status": "success",
            "layers": processed_files,
        }
    except Exception as e:
        import traceback
        traceback.print_exc()
        jobs_db[job_id] = {
            "status": "error",
            "message": str(e),
        }

@app.get("/")
def read_root():
    """Liveness endpoint: a fixed payload showing the backend process is serving requests."""
    health = {"status": "ok", "message": "Experience Eats Backend is running"}
    return health

# File extensions kept as-is for stored uploads. Anything else is stored as
# ".jpg", so a client-chosen name like "x.html" or "x.svg" cannot make the
# /storage StaticFiles mount serve attacker-controlled markup.
_ALLOWED_IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp", ".gif", ".bmp", ".tif", ".tiff"}


def _require_safe_path_segment(value: str, field: str) -> str:
    """Return *value* if it is usable as one directory name, else raise HTTP 400.

    Rejects path separators, NUL bytes, drive prefixes, and names made only of
    dots/spaces ("", ".", ".."), so client-supplied form fields cannot escape
    the storage directories via os.path.join.
    """
    if (
        not value.strip(" .")
        or any(ch in value for ch in ("/", "\\", "\x00"))
        or os.path.splitdrive(value)[0]
    ):
        raise HTTPException(status_code=400, detail=f"Invalid {field}")
    return value


@app.post("/api/process-dish")
async def process_dish_photos(
    background_tasks: BackgroundTasks,
    shop_slug: str = Form(...),
    category: str = Form("uncategorized"),
    files: List[UploadFile] = File(...)
):
    """
    Receives 12 photos of a dish, saves them, and starts the 2.5D processing pipeline in the background.

    Files are stored as ``UPLOAD_DIR/<shop_slug>/<category>/<job_id>/angle_XX<ext>``.
    A ``"pending"`` entry is recorded in ``jobs_db``, and the job id is returned
    for polling via ``/api/job-status/{job_id}``.

    Raises:
        HTTPException(400): in any of these cases:
            - the request does not contain exactly 12 files;
            - ``shop_slug`` or ``category`` is not a safe single directory name;
            - any file is not declared with an ``image/*`` content type.
    """
    if len(files) != 12:
        raise HTTPException(status_code=400, detail="Exactly 12 photos are required")

    # Both values come straight from the client and become directory names.
    _require_safe_path_segment(shop_slug, "shop_slug")
    _require_safe_path_segment(category, "category")

    # Validate every file before touching the disk, so a rejected request does
    # not leave a half-written job directory behind.
    for file in files:
        if not (file.content_type or "").startswith("image/"):
            raise HTTPException(status_code=400, detail=f"File {file.filename} is not an image")

    # Generate common job ID
    job_id = str(uuid.uuid4())

    # Ensure nested shop directory structure
    job_upload_dir = os.path.join(UPLOAD_DIR, shop_slug, category, job_id)
    job_processed_dir = os.path.join(PROCESSED_DIR, shop_slug, category, job_id)
    os.makedirs(job_upload_dir, exist_ok=True)
    os.makedirs(job_processed_dir, exist_ok=True)

    # Save uploaded files synchronously before passing to background task
    files_data = []
    for i, file in enumerate(files):
        file_extension = os.path.splitext(file.filename or "")[1]
        if file_extension.lower() not in _ALLOWED_IMAGE_EXTENSIONS:
            file_extension = ".jpg"  # fallback (also covers a missing extension)

        safe_filename = f"angle_{i:02d}{file_extension}"
        input_file_path = os.path.join(job_upload_dir, safe_filename)

        with open(input_file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        files_data.append((safe_filename, input_file_path))

    # Give initial status
    jobs_db[job_id] = {"status": "pending"}

    # Send to background task
    background_tasks.add_task(process_photos_background, job_id, files_data, job_upload_dir, job_processed_dir)

    return {
        "status": "accepted",
        "job_id": job_id,
        "message": "Processing started in the background. Poll /api/job-status/{job_id} for completion."
    }

@app.get("/api/job-status/{job_id}")
def get_job_status(job_id: str):
    """
    Polling endpoint for the frontend: report the current state of a 2.5D crop/depth job.

    Returns the job's ``jobs_db`` entry unchanged. Raises HTTP 404 for an unknown id.
    """
    try:
        return jobs_db[job_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Job not found") from None

if __name__ == "__main__":
    # Development entry point: serve `app` on all interfaces, port 8000, with
    # auto-reload. The "main:app" import string assumes this file is main.py.
    import uvicorn

    dev_server_options = {"host": "0.0.0.0", "port": 8000, "reload": True}
    uvicorn.run("main:app", **dev_server_options)