Spaces:
Running
Running
File size: 3,937 Bytes
1ae016f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 | import os
import uuid
import shutil
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.staticfiles import StaticFiles
from PIL import Image
from fastapi.middleware.cors import CORSMiddleware
from dotenv import load_dotenv
from scripts.gradcam import get_resnet_gradcam, get_fusion_gradcam
from scripts.yolo import get_yolo_damage_boxes
from scripts.model_loader import initialize_models
# Load .env before anything reads configuration from the environment.
load_dotenv()

app = FastAPI()

# Allowed CORS origins as a comma-separated list, e.g.
#   CORS_ALLOW_ORIGINS="https://example.com,https://app.example.com"
# Defaults to "*", which preserves the original allow-all behaviour.
# NOTE(review): with allow_credentials=True, Starlette answers a wildcard
# origin by echoing the caller's Origin header. That lets any site make
# credentialed requests, so set CORS_ALLOW_ORIGINS explicitly in production.
CORS_ALLOW_ORIGINS = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ALLOW_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Uploaded inputs and generated visualisations live under ./static (relative
# to the working directory). The mount below then serves them at
# /static/uploads/... and /static/results/...
UPLOAD_DIR = "static/uploads"
RESULT_DIR = "static/results"
# Create the directories before mounting: StaticFiles checks by default that
# its directory exists when it is constructed.
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)

app.mount("/static", StaticFiles(directory="static"), name="static")
# Output index -> human-readable damage label. Presumably this matches the
# class order the classifiers were trained with; verify against the training code.
class_map = dict(
    enumerate(
        [
            "Front Breakage",
            "Front Crushed",
            "Front Normal",
            "Rear Breakage",
            "Rear Crushed",
            "Rear Normal",
        ]
    )
)
# Build both predictors once at import time; every request handler below
# shares these module-level instances.
(resnet_predictor, fusion_predictor) = initialize_models(class_map)
@app.get("/")
def api_status():
    """Health-check endpoint: report that the service is up."""
    payload = {"status": "API is running"}
    return payload
@app.post("/predict")
async def predict_and_generate_cams(file: UploadFile = File(...), mode: str = "resnet"):
    """Save an uploaded image and render a Grad-CAM visualisation for it.

    Args:
        file: Multipart image upload.
        mode: ``"resnet"`` or ``"fusion"`` (case-insensitive). Selects which
            predictor and Grad-CAM helper are used.

    Returns:
        dict with these keys:

        - ``status``
        - ``original_image``: URL of the saved upload.
        - ``selected_viz``: URL of the generated visualisation.
        - ``resnet_viz`` / ``fusion_viz``: the same URL under the selected
          mode's key and ``None`` under the other.
        - ``mode``: the lower-cased mode.

    Raises:
        HTTPException: 400 if ``mode`` is unsupported or the upload cannot be
            decoded as an image.
    """
    mode = mode.lower()
    if mode not in {"resnet", "fusion"}:
        raise HTTPException(status_code=400, detail="mode must be 'resnet' or 'fusion'")

    # Reject non-images before anything is written into the publicly served
    # static dir. This matches the /predict/resnet and /predict/fusion endpoints.
    try:
        Image.open(file.file).load()
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid image file") from exc
    # Decoding consumed the stream; rewind so the raw upload bytes are saved intact.
    file.file.seek(0)

    unique_id = str(uuid.uuid4())
    input_filename = f"{unique_id}_input.jpg"
    input_path = os.path.join(UPLOAD_DIR, input_filename)
    with open(input_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    # mode was validated above, so it doubles as the result-file suffix
    # ("<id>_resnet.jpg" / "<id>_fusion.jpg").
    if mode == "resnet":
        make_cam, predictor = get_resnet_gradcam, resnet_predictor
    else:
        make_cam, predictor = get_fusion_gradcam, fusion_predictor
    output_name = f"{unique_id}_{mode}.jpg"
    output_path = os.path.join(RESULT_DIR, output_name)
    # NOTE(review): this synchronous call inside an async handler blocks the
    # event loop while it runs, which also serialises Grad-CAM requests.
    # Confirm the helpers are thread-safe before moving them to a threadpool.
    make_cam(input_path, predictor, output_path)

    selected_viz = f"/static/results/{output_name}"
    return {
        "status": "success",
        "original_image": f"/static/uploads/{input_filename}",
        "selected_viz": selected_viz,
        "resnet_viz": selected_viz if mode == "resnet" else None,
        "fusion_viz": selected_viz if mode == "fusion" else None,
        "mode": mode,
    }
@app.post("/predict/resnet")
async def resnet_prediction(image: UploadFile = File(...)):
    """Classify an uploaded image with the ResNet predictor.

    Responds with HTTP 400 "Invalid image file" if PIL cannot open the upload
    or convert it to RGB. Otherwise it returns whatever
    ``resnet_predictor.resnet_predict`` produces.
    """
    try:
        rgb = Image.open(image.file).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image file")
    else:
        return resnet_predictor.resnet_predict(image_input=rgb)
@app.post("/predict/fusion")
async def fusion_prediction(image: UploadFile = File(...)):
    """Classify an uploaded image with the fusion predictor.

    Responds with HTTP 400 "Invalid image file" if PIL cannot open the upload
    or convert it to RGB. Otherwise it returns whatever
    ``fusion_predictor.predict`` produces.
    """
    try:
        rgb = Image.open(image.file).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image file")
    else:
        return fusion_predictor.predict(image_input=rgb)
@app.post("/predict/yolo")
async def yolo_detection(file: UploadFile = File(...)):
    """Run YOLO damage detection on an uploaded image.

    The upload is validated with PIL and saved under UPLOAD_DIR.
    ``get_yolo_damage_boxes`` is then called with an output path under
    RESULT_DIR, and the response advertises that path as ``yolo_image``.

    Args:
        file: Multipart image upload.

    Returns:
        dict with these keys:

        - ``status``
        - ``original_image`` and ``yolo_image``: URLs.
        - ``detections``, ``total_detections`` and ``message``: copied from
          the detector's result.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    # Reject non-images before anything is written into the publicly served
    # static dir. This matches the /predict/resnet and /predict/fusion endpoints.
    try:
        Image.open(file.file).load()
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid image file") from exc
    # Decoding consumed the stream; rewind so the raw upload bytes are saved intact.
    file.file.seek(0)

    unique_id = str(uuid.uuid4())
    input_filename = f"{unique_id}_input.jpg"
    yolo_out_name = f"{unique_id}_yolo.jpg"
    input_path = os.path.join(UPLOAD_DIR, input_filename)
    yolo_path = os.path.join(RESULT_DIR, yolo_out_name)

    with open(input_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    result = get_yolo_damage_boxes(input_path, yolo_path)
    return {
        "status": "success",
        "original_image": f"/static/uploads/{input_filename}",
        "yolo_image": f"/static/results/{yolo_out_name}",
        "detections": result["detections"],
        "total_detections": result["total_detections"],
        "message": result["message"],
    }