File size: 4,104 Bytes
a951b90
 
 
 
 
 
 
 
 
0132850
a951b90
 
 
 
 
 
 
 
 
 
 
0132850
a951b90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0132850
a951b90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0132850
a951b90
 
 
 
 
0132850
a951b90
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import gradio as gr
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
import os

# --- 1. SETUP & CONFIG ---
# All inference is pinned to the CPU (the free HF Spaces tier provides no GPU).
DEVICE = torch.device("cpu")

# Image preprocessing shared by every model in this file:
#   1. resize to a fixed 224x224,
#   2. convert PIL image -> float tensor,
#   3. per-channel normalisation (the values are the widely used ImageNet mean/std).
# This MUST stay identical to the pipeline used when the .pth checkpoints were trained.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

print("⏳ Loading Models... (This might take a minute)")

# --- 2. LOAD DETECTOR MODEL ---
# ResNet-50 architecture with a 2-way classification head. weights=None means no
# torchvision-pretrained weights are downloaded: every parameter is expected to
# come from the fine-tuned checkpoint loaded below.
detector = models.resnet50(weights=None)
detector.fc = nn.Linear(detector.fc.in_features, 2) # Binary: Accident vs Non-Accident

try:
    # The checkpoint is a plain state_dict (it is fed to load_state_dict), so
    # weights_only=True suffices and prevents unpickling arbitrary objects.
    detector.load_state_dict(
        torch.load("accident_detector.pth", map_location=DEVICE, weights_only=True)
    )
    print("✅ Accident Detector Loaded!")
except FileNotFoundError:
    # Best-effort startup: the app keeps running, but the detector still has its
    # random initialisation, so its accident/normal verdicts are not meaningful.
    print("❌ ERROR: 'accident_detector.pth' not found. Please upload it.")
    print("⚠️ Detector is running with untrained weights; predictions are meaningless.")
detector.to(DEVICE).eval()

# --- 3. LOAD SEVERITY MODEL (Optional) ---
# Best-effort: if the checkpoint is missing or fails to load, severity_model stays
# None and the inference code reports the severity as unknown instead of failing.
severity_model = None
try:
    if os.path.exists("severity_classifier.pth"):
        severity_net = models.resnet50(weights=None)
        severity_net.fc = nn.Linear(severity_net.fc.in_features, 3) # Minor, Substantial, Critical
        # Plain state_dict checkpoint -> weights_only=True avoids executing
        # arbitrary pickled code while loading.
        severity_net.load_state_dict(
            torch.load("severity_classifier.pth", map_location=DEVICE, weights_only=True)
        )
        severity_net.to(DEVICE).eval()
        # Publish the model only after a fully successful load, so any failure
        # above leaves severity_model as None.
        severity_model = severity_net
        print("✅ Severity Model Loaded!")
    else:
        print("ℹ️ Severity model not found. Skipping severity check.")
except Exception as e:
    print(f"⚠️ Could not load severity model: {e}")

# --- 4. LOAD SUMMARIZER (BLIP) ---
# Both the processor and the captioning model are pulled from the Hugging Face
# Hub by from_pretrained, using the same checkpoint id.
print("⏳ Loading BLIP Model...")
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-large"
)
# Module.to() moves parameters in place and returns the same object.
blip_model.to(DEVICE)
print("✅ BLIP Summarizer Loaded!")

# --- 5. INFERENCE FUNCTION ---
# Softmax probability of the accident class above which a frame is flagged.
ACCIDENT_THRESHOLD = 0.5
# Logit index of the "accident" class. Assumes the detector was trained with
# torchvision ImageFolder (classes indexed in sorted folder-name order, so
# "Accident" sorts first) — TODO confirm against the training folder names.
ACCIDENT_CLASS_IDX = 0
# Severity labels in output-index order (training folders named 1, 2, 3).
SEVERITY_LABELS = ("Minor Impact", "Substantial Impact", "Critical Impact")
# Text prefix that conditions the BLIP caption.
CAPTION_PROMPT = "a cctv footage of a car accident showing"


def _accident_probability(input_tensor):
    """Return the detector's softmax probability of the accident class for batch item 0."""
    with torch.no_grad():
        logits = detector(input_tensor)
    probs = torch.nn.functional.softmax(logits, dim=1)
    return probs[0, ACCIDENT_CLASS_IDX].item()


def _assess_severity(input_tensor):
    """Return the predicted severity label, or a placeholder if no severity model is loaded."""
    if severity_model is None:
        return "Unknown (Model not loaded)"
    with torch.no_grad():
        logits = severity_model(input_tensor)
    # Argmax over the class dimension (not the flattened tensor) so the result
    # is always a valid index into SEVERITY_LABELS.
    label_idx = torch.argmax(logits, dim=1)[0].item()
    return SEVERITY_LABELS[label_idx]


def _summarize(img_pil):
    """Return a BLIP caption of the image, conditioned on CAPTION_PROMPT."""
    inputs = processor(img_pil, CAPTION_PROMPT, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        out_ids = blip_model.generate(**inputs, max_new_tokens=50)
    return processor.decode(out_ids[0], skip_special_tokens=True)


def analyze_frame(image):
    """Analyse one uploaded traffic frame and return a text report.

    Args:
        image: The frame as delivered by ``gr.Image(type="numpy")``, or ``None``
            when nothing was uploaded.

    Returns:
        A human-readable report string. Non-accident frames get a status line
        and confidence; accident frames additionally get a severity label and
        a BLIP-generated caption.
    """
    if image is None:
        return "Please upload an image."

    # A. Preprocess: convert('RGB') brings grayscale/RGBA uploads to 3 channels;
    # unsqueeze(0) adds the batch dimension the models expect.
    img_pil = Image.fromarray(image).convert('RGB')
    input_tensor = transform(img_pil).unsqueeze(0).to(DEVICE)

    # B. Detect accident
    accident_conf = _accident_probability(input_tensor)
    is_accident = accident_conf > ACCIDENT_THRESHOLD
    if not is_accident:
        return f"✅ Status: Normal Traffic\nConfidence: {1-accident_conf:.2%}"

    # C./D. Severity and caption are only computed for accident frames.
    severity_status = _assess_severity(input_tensor)
    summary = _summarize(img_pil)

    # E. Format Output
    return f"""🚨 ACCIDENT DETECTED 🚨
--------------------------
Confidence: {accident_conf:.2%}
Severity: {severity_status}

📝 AI Summary:
"{summary}"
"""

# --- 6. DEFINE UI ---
# One image input (delivered to analyze_frame as a numpy array) and one text
# output holding the report. Note: no `examples=` list is passed — an earlier
# version's examples list triggered an InvalidPathError.
interface = gr.Interface(
    title="🛡️ AI Accident Detection System",
    description="Upload a traffic image. The system will detect if an accident occurred, estimate severity, and describe the scene.",
    fn=analyze_frame,
    inputs=gr.Image(label="Upload CCTV Frame", type="numpy"),
    outputs=gr.Textbox(label="Analysis Report"),
)

# --- 7. LAUNCH ---
# Start the Gradio server only when run as a script, not on import.
if __name__ == "__main__":
    interface.launch()