import gradio as gr
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
import os
# --- 1. SETUP & CONFIG ---
DEVICE = torch.device("cpu")  # HF Spaces (Free) uses CPU

# ImageNet normalization statistics — must match the values used at training time.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing pipeline applied to every incoming frame (must match training!).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

print("⏳ Loading Models... (This might take a minute)")
# --- 2. LOAD DETECTOR MODEL ---
# ResNet-50 backbone with a fresh 2-way head: Accident vs Non-Accident.
detector = models.resnet50(weights=None)
detector.fc = nn.Linear(detector.fc.in_features, 2)
try:
    state_dict = torch.load("accident_detector.pth", map_location=DEVICE)
    detector.load_state_dict(state_dict)
    print("✅ Accident Detector Loaded!")
except FileNotFoundError:
    # Best-effort: keep the app alive so the UI can still come up.
    print("❌ ERROR: 'accident_detector.pth' not found. Please upload it.")
detector.to(DEVICE).eval()
# --- 3. LOAD SEVERITY MODEL (Optional) ---
# Stays None when the checkpoint is absent; analyze_frame handles that case.
severity_model = None
try:
    if os.path.exists("severity_classifier.pth"):
        net = models.resnet50(weights=None)
        net.fc = nn.Linear(net.fc.in_features, 3)  # Minor, Substantial, Critical
        net.load_state_dict(torch.load("severity_classifier.pth", map_location=DEVICE))
        net.to(DEVICE).eval()
        severity_model = net
        print("✅ Severity Model Loaded!")
    else:
        print("ℹ️ Severity model not found. Skipping severity check.")
except Exception as e:
    # Best-effort: a corrupt checkpoint must not take the whole app down.
    print(f"⚠️ Could not load severity model: {e}")
# --- 4. LOAD SUMMARIZER (BLIP) ---
# Weights are downloaded from the HuggingFace Hub automatically on first run.
print("⏳ Loading BLIP Model...")
BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-large"
processor = BlipProcessor.from_pretrained(BLIP_CHECKPOINT)
blip_model = BlipForConditionalGeneration.from_pretrained(BLIP_CHECKPOINT).to(DEVICE)
print("✅ BLIP Summarizer Loaded!")
# --- 5. INFERENCE FUNCTION ---
def _assess_severity(input_tensor):
    """Classify accident severity from a preprocessed tensor, or report that no model is loaded."""
    if severity_model is None:  # explicit None check — clearer than truthiness on an nn.Module
        return "Unknown (Model not loaded)"
    with torch.no_grad():
        sev_out = severity_model(input_tensor)
    sev_idx = torch.argmax(sev_out).item()
    # Mapping based on folder names: 1, 2, 3
    classes = ["Minor Impact", "Substantial Impact", "Critical Impact"]
    return classes[sev_idx]


def _summarize_scene(img_pil):
    """Generate a BLIP caption for the frame, primed with an accident-themed prompt."""
    inputs = processor(img_pil, "a cctv footage of a car accident showing", return_tensors="pt").to(DEVICE)
    # generate() is no-grad internally in recent transformers, but be explicit for safety.
    with torch.no_grad():
        out_ids = blip_model.generate(**inputs, max_new_tokens=50)
    return processor.decode(out_ids[0], skip_special_tokens=True)


def analyze_frame(image):
    """Analyze one CCTV frame: detect an accident, grade severity, describe the scene.

    Args:
        image: HxWxC numpy array from the Gradio image component, or None
            when nothing was uploaded.

    Returns:
        A human-readable analysis report string.
    """
    if image is None:
        return "Please upload an image."
    # A. Preprocess to the 224x224 normalized tensor the detector expects.
    img_pil = Image.fromarray(image).convert('RGB')
    input_tensor = transform(img_pil).unsqueeze(0).to(DEVICE)
    # B. Detect accident.
    with torch.no_grad():
        out = detector(input_tensor)
        probs = torch.nn.functional.softmax(out, dim=1)
    # Class 0 = Accident (standard alphabetical sorting by ImageFolder).
    accident_conf = probs[0][0].item()
    if accident_conf <= 0.5:  # detection threshold
        return f"✅ Status: Normal Traffic\nConfidence: {1-accident_conf:.2%}"
    # C. Assess severity (optional model) and D. generate a summary.
    severity_status = _assess_severity(input_tensor)
    summary = _summarize_scene(img_pil)
    # E. Format output.
    return f"""🚨 ACCIDENT DETECTED 🚨
--------------------------
Confidence: {accident_conf:.2%}
Severity: {severity_status}
📝 AI Summary:
"{summary}"
"""
# --- 6. DEFINE UI ---
# Removed the 'examples' list to fix the InvalidPathError
frame_input = gr.Image(type="numpy", label="Upload CCTV Frame")
report_output = gr.Textbox(label="Analysis Report")
interface = gr.Interface(
    fn=analyze_frame,
    inputs=frame_input,
    outputs=report_output,
    title="🛡️ AI Accident Detection System",
    description="Upload a traffic image. The system will detect if an accident occurred, estimate severity, and describe the scene.",
)
# --- 7. LAUNCH ---
# Entry-point guard: importing this module will not start the web server.
# (Fixed: stray '|' token after launch() that broke the syntax.)
if __name__ == "__main__":
    interface.launch()