# Author: gnikhilchand
# Last commit: "Update app.py" (0132850, verified)
import gradio as gr
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
import os
# --- 1. SETUP & CONFIG ---
# Free Hugging Face Spaces hardware is CPU-only, so inference is pinned to CPU.
DEVICE = torch.device("cpu")

# Preprocessing pipeline. It must match the transforms used during training,
# otherwise the trained models see inputs scaled differently from training.
_INPUT_SIZE = (224, 224)
# Standard ImageNet per-channel mean / std (RGB order), as used by torchvision.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose(
    [
        transforms.Resize(_INPUT_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)
print("⏳ Loading Models... (This might take a minute)")


# --- 2. LOAD DETECTOR MODEL ---
def _load_detector(path="accident_detector.pth"):
    """Build the binary accident-detector ResNet-50 and load its checkpoint.

    Args:
        path: Location of the fine-tuned state dict.

    Returns:
        The model, moved to DEVICE and switched to eval mode.

    Raises:
        FileNotFoundError: if ``path`` does not exist. The detector is
            required: without its checkpoint the network would keep the
            random initialisation from ``weights=None`` and every prediction
            would be noise, so startup fails loudly instead of serving it.
    """
    net = models.resnet50(weights=None)
    net.fc = nn.Linear(net.fc.in_features, 2)  # Binary: Accident vs Non-Accident
    try:
        # NOTE: torch.load unpickles the file by default; only point it at a
        # trusted checkpoint shipped with the app.
        state_dict = torch.load(path, map_location=DEVICE)
    except FileNotFoundError as err:
        print(f"❌ ERROR: '{path}' not found. Please upload it.")
        raise FileNotFoundError(
            f"Required checkpoint {path!r} is missing; refusing to serve "
            "predictions from an untrained detector."
        ) from err
    net.load_state_dict(state_dict)
    print("✅ Accident Detector Loaded!")
    return net.to(DEVICE).eval()


detector = _load_detector()
# --- 3. LOAD SEVERITY MODEL (Optional) ---
def _load_severity_model(path="severity_classifier.pth"):
    """Build the 3-class severity ResNet-50 and load its checkpoint.

    Returns the model on DEVICE in eval mode, or None when ``path`` does not
    exist. Any error raised while building or loading propagates to the caller.
    """
    if not os.path.exists(path):
        print("ℹ️ Severity model not found. Skipping severity check.")
        return None
    net = models.resnet50(weights=None)
    # 3 output classes: Minor, Substantial, Critical
    net.fc = nn.Linear(net.fc.in_features, 3)
    net.load_state_dict(torch.load(path, map_location=DEVICE))
    net.to(DEVICE).eval()
    print("✅ Severity Model Loaded!")
    return net


# The severity head is optional: any failure leaves severity_model as None,
# in which case analyze_frame reports the severity as unknown.
try:
    severity_model = _load_severity_model()
except Exception as e:
    print(f"⚠️ Could not load severity model: {e}")
    severity_model = None
# --- 4. LOAD SUMMARIZER (BLIP) ---
# Processor and captioning model share one Hugging Face Hub checkpoint, which
# from_pretrained downloads automatically.
BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-large"
print("⏳ Loading BLIP Model...")
processor = BlipProcessor.from_pretrained(BLIP_CHECKPOINT)
blip_model = BlipForConditionalGeneration.from_pretrained(BLIP_CHECKPOINT)
blip_model = blip_model.to(DEVICE)
print("✅ BLIP Summarizer Loaded!")
# --- 5. INFERENCE FUNCTION ---
def analyze_frame(image):
if image is None:
return "Please upload an image."
# A. Preprocess
img_pil = Image.fromarray(image).convert('RGB')
input_tensor = transform(img_pil).unsqueeze(0).to(DEVICE)
# B. Detect Accident
with torch.no_grad():
out = detector(input_tensor)
probs = torch.nn.functional.softmax(out, dim=1)
# Class 0 = Accident (Standard alphabetical sorting by ImageFolder)
accident_conf = probs[0][0].item()
is_accident = accident_conf > 0.5 # Threshold
if not is_accident:
return f"✅ Status: Normal Traffic\nConfidence: {1-accident_conf:.2%}"
# C. If Accident -> Assess Severity (if model exists)
severity_status = "Unknown (Model not loaded)"
if severity_model:
with torch.no_grad():
sev_out = severity_model(input_tensor)
sev_idx = torch.argmax(sev_out).item()
# Mapping based on folder names: 1, 2, 3
classes = ["Minor Impact", "Substantial Impact", "Critical Impact"]
severity_status = classes[sev_idx]
# D. If Accident -> Generate Summary
inputs = processor(img_pil, "a cctv footage of a car accident showing", return_tensors="pt").to(DEVICE)
out_ids = blip_model.generate(**inputs, max_new_tokens=50)
summary = processor.decode(out_ids[0], skip_special_tokens=True)
# E. Format Output
return f"""🚨 ACCIDENT DETECTED 🚨
--------------------------
Confidence: {accident_conf:.2%}
Severity: {severity_status}
📝 AI Summary:
"{summary}"
"""
# --- 6. DEFINE UI ---
# Single image in, text report out. No `examples` gallery is configured: an
# earlier examples list was dropped because it raised an InvalidPathError.
_frame_input = gr.Image(type="numpy", label="Upload CCTV Frame")
_report_output = gr.Textbox(label="Analysis Report")
interface = gr.Interface(
    fn=analyze_frame,
    inputs=_frame_input,
    outputs=_report_output,
    title="🛡️ AI Accident Detection System",
    description="Upload a traffic image. The system will detect if an accident occurred, estimate severity, and describe the scene.",
)
# --- 7. LAUNCH ---
def main():
    """Start the Gradio server for ``interface``."""
    interface.launch()


# Serve only when executed as a script, not when this module is imported.
if __name__ == "__main__":
    main()