Spaces:
Running
Running
| from fastapi import APIRouter, Request, UploadFile, File, Form, HTTPException | |
| from fastapi.responses import HTMLResponse, FileResponse, JSONResponse | |
| from fastapi.templating import Jinja2Templates | |
| from starlette.background import BackgroundTask | |
| import shutil | |
| import os | |
| import uuid | |
| from pathlib import Path | |
| from typing import Optional | |
| import json | |
| import base64 | |
| from ultralytics import YOLO | |
| import cv2 | |
| import numpy as np | |
| from ..utils.llm_client import GroqAnalyzer | |
# Jinja2 templates are read from the "templates" directory located in the
# parent of this module's directory.
TEMPLATES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "templates")
templates = Jinja2Templates(directory=TEMPLATES_DIR)
router = APIRouter()

# Scratch storage under /tmp: saved uploads and generated artifacts.
_SCRATCH_ROOT = "/tmp"
UPLOAD_DIR = os.path.join(_SCRATCH_ROOT, "uploads")
RESULTS_DIR = os.path.join(_SCRATCH_ROOT, "results")
for _scratch_dir in (UPLOAD_DIR, RESULTS_DIR):
    os.makedirs(_scratch_dir, exist_ok=True)

# Lower-case extensions (no leading dot) accepted by the upload handler.
ALLOWED_EXTENSIONS = {"jpg", "jpeg", "png", "tiff", "tif"}

# YOLO weights. The damage model is currently disabled; its weights would live at
# os.path.join(_SCRATCH_ROOT, "models", "damage", "weights", "weights", "best.pt").
PARTS_MODEL_PATH = os.path.join(_SCRATCH_ROOT, "models", "parts", "weights", "weights", "best.pt")

# Labels for the parts model, indexed by predicted class id.
PARTS_CLASS_NAMES = ['headlamp', 'front_bumper', 'hood', 'door', 'rear_bumper']

# Single LLM client instance shared by every request handled by this module.
groq_analyzer = GroqAnalyzer()
# Loaded YOLO models keyed by weights path, so a request does not re-read the
# weights from disk and rebuild the network every time.
_YOLO_MODEL_CACHE = {}


def run_yolo_inference(model_path, image_path, task='segment', imgsz=640, conf=0.25):
    """Run YOLO prediction on one image and return the first result.

    Args:
        model_path: Path to the YOLO weights file (e.g. ``best.pt``).
        image_path: Path of the image to predict on.
        task: YOLO task name forwarded to ``predict`` (default ``"segment"``).
        imgsz: Inference image size forwarded to ``predict`` (default 640).
        conf: Confidence threshold forwarded to ``predict`` (default 0.25).

    Returns:
        The first element of the list returned by ``model.predict``.

    Note:
        The model for each ``model_path`` is loaded once and reused for the
        life of the process, so replacing the weights file requires a restart.
        NOTE(review): Ultralytics advises against sharing one model instance
        across threads; if this is ever called from worker threads, switch
        to a per-thread cache.
    """
    model = _YOLO_MODEL_CACHE.get(model_path)
    if model is None:
        model = YOLO(model_path)
        _YOLO_MODEL_CACHE[model_path] = model
    results = model.predict(source=image_path, imgsz=imgsz, conf=conf, save=False, task=task)
    return results[0]
# Helper: Draw masks and confidence on image
def draw_masks_and_conf(image_path, yolo_result, class_names=None):
    """Draw YOLO boxes, labels and segmentation masks on top of an image.

    Args:
        image_path: Path of the image the detections were computed on.
        yolo_result: One Ultralytics ``Results`` object. ``boxes`` is required;
            ``masks`` is used when it is not ``None``.
        class_names: Optional sequence mapping class id -> label. When falsy,
            every box is labelled ``"damage"``; ids outside the sequence are
            labelled with the numeric id.

    Returns:
        A ``uint8`` image as loaded by ``cv2.imread``: the annotated overlay
        blended 70/30 with the untouched original.

    Raises:
        ValueError: If ``image_path`` cannot be read as an image.
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread reports unreadable or missing files by returning None.
        raise ValueError(f"Could not read image: {image_path}")
    height, width = img.shape[:2]
    overlay = img.copy()
    colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255), (0, 255, 255)]

    # masks.xy holds one polygon per detection in original-image pixel
    # coordinates. masks.data is a full-frame mask at the model's input
    # resolution, so resizing it into the bounding box misplaces the mask.
    masks = getattr(yolo_result, 'masks', None)
    polygons = masks.xy if masks is not None else []

    for i, box in enumerate(yolo_result.boxes):
        conf = float(box.conf[0])
        cls = int(box.cls[0])
        color = colors[cls % len(colors)]
        x1, y1, x2, y2 = map(int, box.xyxy[0])

        # Tint only the pixels inside the segmentation polygon (50/50 blend).
        # This runs before the outline and label are drawn, so those are not faded.
        if i < len(polygons) and len(polygons[i]) >= 3:
            region = np.zeros((height, width), dtype=np.uint8)
            cv2.fillPoly(region, [np.round(polygons[i]).astype(np.int32)], 255)
            inside = region > 0
            overlay[inside] = (overlay[inside] * 0.5 + np.array(color) * 0.5).astype(np.uint8)

        cv2.rectangle(overlay, (x1, y1), (x2, y2), color, 2)
        if class_names:
            name = class_names[cls] if 0 <= cls < len(class_names) else str(cls)
        else:
            name = 'damage'
        # Keep the label inside the frame for boxes touching the top edge.
        text_y = max(y1 - 10, 15)
        cv2.putText(overlay, f"{name}: {conf:.2f}", (x1, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)

    return cv2.addWeighted(overlay, 0.7, img, 0.3, 0)
# Helper: Generate JSON output
def _polygon_area(points):
    """Return the area of a simple polygon given as an (N, 2) array of x, y points (shoelace formula)."""
    pts = np.asarray(points, dtype=np.float64)
    if pts.ndim != 2 or len(pts) < 3:
        return 0.0
    x, y = pts[:, 0], pts[:, 1]
    return float(0.5 * abs(np.dot(x, np.roll(y, -1)) - np.dot(y, np.roll(x, -1))))


def generate_json_output(filename, damage_result, parts_result, class_names=None):
    """Build the JSON-serialisable report for one analysed image.

    Args:
        filename: Name of the analysed image, copied into the report.
        damage_result: Optional YOLO ``Results`` from a damage model, or ``None``.
            Its boxes become ``damage.regions`` and the highest box confidence
            becomes ``damage.severity_score``.
        parts_result: YOLO ``Results`` from the parts model, or ``None``.
        class_names: Optional class-id -> label sequence for parts. Defaults to
            ``PARTS_CLASS_NAMES``. Ids outside the sequence are reported as the
            numeric id string.

    Returns:
        dict with keys ``filename``, ``damage`` (``severity_score``, ``regions``),
        ``parts`` (one entry per detected part) and ``cost_estimate`` (always
        ``None``). All numbers are plain Python floats, so the result is
        ``json.dump``-able.
    """
    names = PARTS_CLASS_NAMES if class_names is None else class_names

    # Damage: severity is the highest box confidence (0.0 when there are no boxes).
    damage_regions = []
    if damage_result is not None and hasattr(damage_result, 'boxes'):
        for box in damage_result.boxes:
            x1, y1, x2, y2 = map(float, box.xyxy[0])
            damage_regions.append({"bbox": [x1, y1, x2, y2], "confidence": float(box.conf[0])})
    severity_score = max((region["confidence"] for region in damage_regions), default=0.0)

    # Parts
    parts = []
    boxes = parts_result.boxes if parts_result is not None else []
    masks = getattr(parts_result, 'masks', None)
    # masks.xy: one polygon per detection, in the same pixel space as boxes.xyxy.
    # (masks.data is at model input resolution and must not be mixed with bbox areas.)
    polygons = masks.xy if masks is not None else None
    for i, box in enumerate(boxes):
        x1, y1, x2, y2 = map(float, box.xyxy[0])
        cls = int(box.cls[0])
        # "damage_percentage" (key kept for compatibility) is the fraction of the
        # part's bounding box covered by its segmentation mask; None without masks.
        damage_percentage = None
        bbox_area = (x2 - x1) * (y2 - y1)
        if polygons is not None and i < len(polygons) and bbox_area > 0:
            damage_percentage = _polygon_area(polygons[i]) / bbox_area
        parts.append({
            "part": names[cls] if 0 <= cls < len(names) else str(cls),
            # Hard-coded: every detected part is reported as damaged.
            "damaged": True,
            "confidence": float(box.conf[0]),
            "damage_percentage": damage_percentage,
            "bbox": [x1, y1, x2, y2],
        })

    return {
        "filename": filename,
        "damage": {
            "severity_score": severity_score,
            "regions": damage_regions,
        },
        "parts": parts,
        "cost_estimate": None,
    }
# Dummy login credentials (hard-coded demo account, not real authentication).
def check_login(username: str, password: str) -> bool:
    """Return True when *username* and *password* match the demo account.

    Both fields are always compared, each with ``hmac.compare_digest``, so the
    response time does not reveal which field was wrong or how much of it
    matched. The strings are UTF-8 encoded first because ``compare_digest``
    rejects ``str`` arguments containing non-ASCII characters.
    """
    import hmac

    user_ok = hmac.compare_digest(username.encode("utf-8"), b"demo")
    pass_ok = hmac.compare_digest(password.encode("utf-8"), b"demo123")
    return user_ok and pass_ok
def home(request: Request):
    """Render the main page in its initial state (no analysis result yet)."""
    context = {"request": request, "result": None}
    return templates.TemplateResponse("index.html", context)
def login(request: Request, username: str = Form(...), password: str = Form(...)):
    """Handle the login form submission.

    Valid demo credentials render the main page tagged with the user's name.
    Anything else re-renders the login page with an error message.
    """
    if not check_login(username, password):
        failure_context = {"request": request, "error": "Invalid credentials"}
        return templates.TemplateResponse("login.html", failure_context)
    success_context = {"request": request, "result": None, "user": username}
    return templates.TemplateResponse("index.html", success_context)
def login_page(request: Request):
    """Render the empty login form."""
    context = dict(request=request)
    return templates.TemplateResponse("login.html", context)
# Seconds an upload and its derived result files stay on disk before removal.
UPLOAD_RETENTION_SECONDS = 300


def _upload_extension(filename):
    """Return the lower-cased extension of *filename* without the dot, or "" if it has none."""
    _, dot, ext = filename.rpartition(".")
    return ext.lower() if dot else ""


def _schedule_cleanup(upload_path, session_id, delay=UPLOAD_RETENTION_SECONDS):
    """Delete the upload and this session's result files after *delay* seconds (best effort)."""
    import threading

    def cleanup():
        paths = [upload_path] + [
            os.path.join(RESULTS_DIR, f"{session_id}{suffix}")
            for suffix in ("_parts.png", "_result.json")
        ]
        for path in paths:
            try:
                os.remove(path)
                print(f"[DEBUG] Cleaned up: {path}")
            except FileNotFoundError:
                # Result files are only written when parts were detected.
                pass
            except OSError as ce:
                print(f"[DEBUG] Cleanup error: {ce}")

    timer = threading.Timer(delay, cleanup)
    timer.daemon = True  # do not keep the process alive just to clean up
    timer.start()


def _analyze_upload(upload_path, session_id, display_name):
    """Run parts detection and LLM analysis on a saved upload and build the template result.

    Failures are reported in the returned dict (``error`` key) rather than raised.
    """
    try:
        parts_result = run_yolo_inference(PARTS_MODEL_PATH, upload_path)
        print(f"[DEBUG] YOLO inference result: {parts_result}")
        parts_img_url = None
        json_output = None
        json_url = None
        warning = None
        if hasattr(parts_result, 'boxes') and len(parts_result.boxes) > 0:
            print(f"[DEBUG] Detected {len(parts_result.boxes)} parts.")
            parts_img = draw_masks_and_conf(upload_path, parts_result, class_names=PARTS_CLASS_NAMES)
            parts_img_filename = f"{session_id}_parts.png"
            parts_img_path = os.path.join(RESULTS_DIR, parts_img_filename)
            # cv2.imwrite signals failure by returning False instead of raising.
            if not cv2.imwrite(parts_img_path, parts_img):
                raise OSError(f"Could not write {parts_img_path}")
            print(f"[DEBUG] Parts image saved to: {parts_img_path}")
            parts_img_url = f"/download/result/{parts_img_filename}"
            # The damage model is disabled, so no damage result is passed.
            json_output = generate_json_output(display_name, None, parts_result)
            json_filename = f"{session_id}_result.json"
            json_path = os.path.join(RESULTS_DIR, json_filename)
            with open(json_path, "w", encoding="utf-8") as jf:
                json.dump(json_output, jf, indent=2)
            print(f"[DEBUG] JSON output saved to: {json_path}")
            json_url = f"/download/result/{json_filename}"
        else:
            warning = "No parts detected in the image."
            print("[DEBUG] No parts detected.")
        llm_analysis = groq_analyzer.analyze_damage(upload_path)
        print(f"[DEBUG] LLM analysis output: {llm_analysis}")
        result = {
            "filename": display_name,
            "parts_image": parts_img_url,
            "json": json_output,
            "json_download": json_url,
            "llm_analysis": llm_analysis,
            "warning": warning,
        }
        print("[DEBUG] Result dict:", result)
        return result
    except Exception as e:
        print("[ERROR] Inference failed:", e)
        return {
            "filename": display_name,
            "error": f"Inference failed: {str(e)}",
            "parts_image": None,
            "json": None,
            "json_download": None,
            "llm_analysis": None,
            "warning": None,
        }


async def upload_image(request: Request, file: UploadFile = File(...)):
    """Accept an image upload, analyse it, and render ``index.html`` with the outcome.

    The upload is saved under UPLOAD_DIR with a UUID prefix. Parts detection and
    LLM analysis run on it. The upload and any result files are removed
    UPLOAD_RETENTION_SECONDS later. Unsupported extensions and processing
    errors are rendered as an ``error`` message on the page instead of raising.

    NOTE(review): inference and the LLM call are synchronous and run directly
    inside this ``async`` handler, so they block the event loop while running.
    """
    from urllib.parse import quote

    try:
        # Keep only the final path component of the client-supplied name, so it
        # cannot add directories to the saved path or break the download URL.
        # Backslashes are normalised for clients that send Windows paths.
        display_name = os.path.basename((file.filename or "").replace("\\", "/"))
        ext = _upload_extension(display_name)
        print(f"[DEBUG] Uploaded file extension: {ext}")
        if ext not in ALLOWED_EXTENSIONS:
            print(f"[DEBUG] Unsupported file type: {ext}")
            return templates.TemplateResponse("index.html", {"request": request, "error": "Unsupported file type."})

        # Save uploaded file
        session_id = str(uuid.uuid4())
        upload_filename = f"{session_id}_{display_name}"
        upload_path = os.path.join(UPLOAD_DIR, upload_filename)
        print(f"[DEBUG] Saving uploaded file to: {upload_path}")
        try:
            with open(upload_path, "wb") as buffer:
                shutil.copyfileobj(file.file, buffer)
        except Exception:
            # Do not leave a partially written upload behind (best effort).
            try:
                os.remove(upload_path)
            except OSError:
                pass
            raise
        print("[DEBUG] File saved. Running inference...")

        try:
            result = _analyze_upload(upload_path, session_id, display_name)
        finally:
            # Always schedule removal of the upload and any result files.
            _schedule_cleanup(upload_path, session_id)

        return templates.TemplateResponse(
            "index.html",
            {
                "request": request,
                "result": result,
                # Quote so spaces, '#', '?' etc. in the name survive as a URL path.
                "original_image": f"/download/upload/{quote(upload_filename)}",
            },
        )
    except Exception as e:
        print(f"[ERROR] Inference failed: {str(e)}")
        return templates.TemplateResponse(
            "index.html",
            {"request": request, "error": f"Error processing image: {str(e)}"},
        )
# --- Serve files from /tmp/uploads and /tmp/results ---
def resolve_file_in_dir(base_dir, filename):
    """Return the real path of *filename* stored directly inside *base_dir*, or None.

    Rejects names that would escape *base_dir* (``..``, absolute paths outside
    it, nested paths, symlinks pointing elsewhere), names containing NUL bytes,
    and anything that is not a regular file. Request-supplied names can
    therefore only reach files stored directly in that directory.
    """
    if not filename or "\x00" in filename:
        return None
    base = os.path.realpath(base_dir)
    candidate = os.path.realpath(os.path.join(base, filename))
    if os.path.dirname(candidate) != base or not os.path.isfile(candidate):
        return None
    return candidate


def _file_response_or_404(base_dir, filename):
    """Serve *filename* from *base_dir*, or a JSON 404 if it is missing or not allowed."""
    file_path = resolve_file_in_dir(base_dir, filename)
    if file_path is None:
        return JSONResponse(status_code=404, content={"error": "File not found"})
    return FileResponse(file_path, filename=filename)


def download_uploaded_file(filename: str):
    """Serve a previously saved upload from UPLOAD_DIR."""
    return _file_response_or_404(UPLOAD_DIR, filename)


def download_result_file(filename: str):
    """Serve a generated artifact (annotated image or JSON report) from RESULTS_DIR."""
    return _file_response_or_404(RESULTS_DIR, filename)