"""FastAPI service for damage classification, Grad-CAM visualisation and YOLO damage detection.

Uploaded images are written to ``static/uploads`` and generated visualisations
to ``static/results``; the whole ``static`` directory is served under ``/static``.
"""

import os
import shutil
import uuid

from dotenv import load_dotenv
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from PIL import Image

# Project imports kept in their original order in case they have import-time side effects.
from scripts.gradcam import get_resnet_gradcam, get_fusion_gradcam
from scripts.yolo import get_yolo_damage_boxes
from scripts.model_loader import initialize_models

load_dotenv()

app = FastAPI()

# NOTE(review): allow_origins=["*"] together with allow_credentials=True is very
# permissive for a public deployment — confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

UPLOAD_DIR = "static/uploads"
RESULT_DIR = "static/results"
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)

app.mount("/static", StaticFiles(directory="static"), name="static")

# Class index -> human-readable label, handed to the model loader.
class_map = {
    0: "Front Breakage",
    1: "Front Crushed",
    2: "Front Normal",
    3: "Rear Breakage",
    4: "Rear Crushed",
    5: "Rear Normal",
}

resnet_predictor, fusion_predictor = initialize_models(class_map)

# mode -> (Grad-CAM generator, predictor) used by /predict.
_GRADCAM_BACKENDS = {
    "resnet": (get_resnet_gradcam, resnet_predictor),
    "fusion": (get_fusion_gradcam, fusion_predictor),
}

# NOTE(review): the handlers below are `async def` but call blocking file I/O and
# model inference directly, so each request occupies the event loop while it runs.
# Moving that work to a thread pool would require the predictors / Grad-CAM helpers
# to be thread-safe — confirm before changing.


def _ensure_valid_image(upload: UploadFile) -> None:
    """Raise HTTP 400 unless *upload* parses as an image; always rewind the stream.

    Raises:
        HTTPException: status 400 with detail "Invalid image file".
    """
    try:
        with Image.open(upload.file) as probe:
            probe.verify()
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid image file") from exc
    finally:
        # verify() consumes the stream; rewind so the caller can re-read it.
        upload.file.seek(0)


def _save_upload(upload: UploadFile, unique_id: str) -> str:
    """Validate *upload* and copy its raw bytes into UPLOAD_DIR.

    The file is stored as ``<unique_id>_input.jpg`` whatever its real format,
    since the bytes are copied verbatim.

    Returns:
        The stored file name (relative to UPLOAD_DIR).

    Raises:
        HTTPException: status 400 if the upload is not a readable image.
    """
    # Validate before writing so rejected uploads never land in the public static dir.
    _ensure_valid_image(upload)
    input_filename = f"{unique_id}_input.jpg"
    with open(os.path.join(UPLOAD_DIR, input_filename), "wb") as buffer:
        shutil.copyfileobj(upload.file, buffer)
    return input_filename


def _open_rgb(upload: UploadFile) -> Image.Image:
    """Decode *upload* into an RGB PIL image, or raise HTTP 400 "Invalid image file"."""
    try:
        return Image.open(upload.file).convert("RGB")
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid image file") from exc


@app.get("/")
def api_status():
    """Liveness check."""
    return {"status": "API is running"}


@app.post("/predict")
async def predict_and_generate_cams(file: UploadFile = File(...), mode: str = "resnet"):
    """Store the uploaded image and render a Grad-CAM for the chosen model.

    Args:
        file: uploaded image.
        mode: "resnet" or "fusion" (case-insensitive).

    Returns:
        JSON with static URLs of the original image and of the visualisation.
        The entry for the model that was not selected is None.

    Raises:
        HTTPException: 400 for an unknown mode or a non-image upload.
    """
    mode = mode.lower()
    if mode not in _GRADCAM_BACKENDS:
        raise HTTPException(status_code=400, detail="mode must be 'resnet' or 'fusion'")

    unique_id = str(uuid.uuid4())
    input_filename = _save_upload(file, unique_id)
    input_path = os.path.join(UPLOAD_DIR, input_filename)

    generate_cam, predictor = _GRADCAM_BACKENDS[mode]
    output_name = f"{unique_id}_{mode}.jpg"
    generate_cam(input_path, predictor, os.path.join(RESULT_DIR, output_name))
    selected_viz = f"/static/results/{output_name}"

    return {
        "status": "success",
        "original_image": f"/static/uploads/{input_filename}",
        "selected_viz": selected_viz,
        "resnet_viz": selected_viz if mode == "resnet" else None,
        "fusion_viz": selected_viz if mode == "fusion" else None,
        "mode": mode,
    }


@app.post("/predict/resnet")
async def resnet_prediction(image: UploadFile = File(...)):
    """Run the ResNet predictor on the uploaded image (400 if it is not an image)."""
    return resnet_predictor.resnet_predict(image_input=_open_rgb(image))


@app.post("/predict/fusion")
async def fusion_prediction(image: UploadFile = File(...)):
    """Run the fusion predictor on the uploaded image (400 if it is not an image)."""
    return fusion_predictor.predict(image_input=_open_rgb(image))


@app.post("/predict/yolo")
async def yolo_detection(file: UploadFile = File(...)):
    """Store the uploaded image, run YOLO damage detection and return its results.

    Raises:
        HTTPException: 400 if the upload is not a readable image.
    """
    unique_id = str(uuid.uuid4())
    input_filename = _save_upload(file, unique_id)
    yolo_out_name = f"{unique_id}_yolo.jpg"

    result = get_yolo_damage_boxes(
        os.path.join(UPLOAD_DIR, input_filename),
        os.path.join(RESULT_DIR, yolo_out_name),
    )

    return {
        "status": "success",
        "original_image": f"/static/uploads/{input_filename}",
        "yolo_image": f"/static/results/{yolo_out_name}",
        "detections": result["detections"],
        "total_detections": result["total_detections"],
        "message": result["message"],
    }