DamageLensAI / app.py
junaid17's picture
Upload 15 files
1ae016f verified
import os
import uuid
import shutil
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.staticfiles import StaticFiles
from PIL import Image
from fastapi.middleware.cors import CORSMiddleware
from dotenv import load_dotenv
from scripts.gradcam import get_resnet_gradcam, get_fusion_gradcam
from scripts.yolo import get_yolo_damage_boxes
from scripts.model_loader import initialize_models
# Pull settings from a local .env (if present) before anything else runs.
load_dotenv()

app = FastAPI()

# NOTE(review): wildcard origins combined with allow_credentials=True may let
# any site make credentialed cross-origin calls (Starlette reflects the
# request Origin in this setup); tighten allow_origins if the API ever relies
# on cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Uploaded inputs and generated visualizations both live under ./static so
# the StaticFiles mount below can serve them directly.
UPLOAD_DIR = "static/uploads"
RESULT_DIR = "static/results"
for _storage_dir in (UPLOAD_DIR, RESULT_DIR):
    os.makedirs(_storage_dir, exist_ok=True)

# Mounted only after the directories exist: StaticFiles checks its root
# directory when it is constructed.
app.mount("/static", StaticFiles(directory="static"), name="static")
# Integer index -> human-readable damage label (presumably the classifiers'
# output index); handed to initialize_models when the predictors are built.
class_map = dict(
    enumerate(
        [
            "Front Breakage",
            "Front Crushed",
            "Front Normal",
            "Rear Breakage",
            "Rear Crushed",
            "Rear Normal",
        ]
    )
)
# Build both predictors once, at import time, so every request handler below
# reuses the same instances. class_map is passed through, presumably so the
# predictors can label their outputs — confirm in scripts.model_loader.
resnet_predictor, fusion_predictor = initialize_models(class_map)
@app.get("/")
def api_status():
    """Health check: confirm the service is up and answering requests."""
    status_message = "API is running"
    return {"status": status_message}
@app.post("/predict")
async def predict_and_generate_cams(file: UploadFile = File(...), mode: str = "resnet"):
    """Store an uploaded image and render a Grad-CAM visualization for it.

    Args:
        file: The uploaded image (multipart form field ``file``).
        mode: Which predictor to explain, ``"resnet"`` or ``"fusion"``
            (case-insensitive). Defaults to ``"resnet"``.

    Returns:
        A dict with URL paths (under ``/static``) to the stored input image
        and the generated visualization. ``selected_viz`` always points at the
        generated image; exactly one of ``resnet_viz`` / ``fusion_viz``
        mirrors it and the other is ``None``. ``mode`` is echoed lower-cased.

    Raises:
        HTTPException: 400 if ``mode`` is unknown or the upload is not a
            readable image.
    """
    # NOTE(review): this handler is `async` but does blocking disk I/O and
    # model inference, which stalls the event loop. It is left `async` because
    # that presumably keeps calls into the shared predictors serialized;
    # confirm the Grad-CAM helpers are thread-safe before switching to a plain
    # `def` (which FastAPI would run in a worker thread pool).
    mode = mode.lower()
    if mode not in {"resnet", "fusion"}:
        raise HTTPException(status_code=400, detail="mode must be 'resnet' or 'fusion'")

    unique_id = str(uuid.uuid4())
    input_filename = f"{unique_id}_input.jpg"
    input_path = os.path.join(UPLOAD_DIR, input_filename)
    with open(input_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    # Reject non-images before they reach the Grad-CAM helpers. Otherwise the
    # request fails with an unhandled error, and the arbitrary upload would
    # remain publicly served under /static/uploads.
    try:
        with Image.open(input_path) as probe:
            probe.verify()
    except Exception:
        os.remove(input_path)
        raise HTTPException(status_code=400, detail="Invalid image file")

    # `mode` is validated and lower-cased above, so it doubles as the suffix
    # ("<id>_resnet.jpg" / "<id>_fusion.jpg").
    output_name = f"{unique_id}_{mode}.jpg"
    output_path = os.path.join(RESULT_DIR, output_name)
    if mode == "resnet":
        get_resnet_gradcam(input_path, resnet_predictor, output_path)
    else:
        get_fusion_gradcam(input_path, fusion_predictor, output_path)

    selected_viz = f"/static/results/{output_name}"
    return {
        "status": "success",
        "original_image": f"/static/uploads/{input_filename}",
        "selected_viz": selected_viz,
        "resnet_viz": selected_viz if mode == "resnet" else None,
        "fusion_viz": selected_viz if mode == "fusion" else None,
        "mode": mode,
    }
@app.post("/predict/resnet")
async def resnet_prediction(image: UploadFile = File(...)):
    """Classify an uploaded image with the ResNet predictor.

    Args:
        image: The uploaded image (multipart form field ``image``).

    Returns:
        Whatever ``resnet_predictor.resnet_predict`` returns for the decoded
        RGB image.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    # Decode straight from the upload stream; any failure while opening or
    # converting is reported to the client as a bad request.
    try:
        rgb_image = Image.open(image.file).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image file")
    else:
        return resnet_predictor.resnet_predict(image_input=rgb_image)
@app.post("/predict/fusion")
async def fusion_prediction(image: UploadFile = File(...)):
    """Classify an uploaded image with the fusion predictor.

    Args:
        image: The uploaded image (multipart form field ``image``).

    Returns:
        Whatever ``fusion_predictor.predict`` returns for the decoded RGB
        image.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    # Decode straight from the upload stream; any failure while opening or
    # converting is reported to the client as a bad request.
    try:
        rgb_image = Image.open(image.file).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image file")
    else:
        return fusion_predictor.predict(image_input=rgb_image)
@app.post("/predict/yolo")
async def yolo_detection(file: UploadFile = File(...)):
    """Store an uploaded image and run YOLO damage detection on it.

    Args:
        file: The uploaded image (multipart form field ``file``).

    Returns:
        A dict with URL paths (under ``/static``) to the stored input and the
        annotated output image, plus the ``detections``, ``total_detections``
        and ``message`` entries copied from ``get_yolo_damage_boxes``'s result.

    Raises:
        HTTPException: 400 if the upload is not a readable image.
    """
    # NOTE(review): `async` handler doing blocking I/O and inference; see the
    # same note on /predict before converting it to a plain `def`.
    unique_id = str(uuid.uuid4())
    input_filename = f"{unique_id}_input.jpg"
    yolo_out_name = f"{unique_id}_yolo.jpg"
    input_path = os.path.join(UPLOAD_DIR, input_filename)
    yolo_path = os.path.join(RESULT_DIR, yolo_out_name)

    with open(input_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    # Reject non-images before they reach the detector. Otherwise the request
    # fails with an unhandled error, and the arbitrary upload would remain
    # publicly served under /static/uploads.
    try:
        with Image.open(input_path) as probe:
            probe.verify()
    except Exception:
        os.remove(input_path)
        raise HTTPException(status_code=400, detail="Invalid image file")

    result = get_yolo_damage_boxes(input_path, yolo_path)
    return {
        "status": "success",
        "original_image": f"/static/uploads/{input_filename}",
        "yolo_image": f"/static/results/{yolo_out_name}",
        "detections": result["detections"],
        "total_detections": result["total_detections"],
        "message": result["message"],
    }