"""Gradio app: detect traffic accidents in a CCTV frame, estimate severity,
and caption the scene with BLIP.

Pipeline:
    1. ResNet-50 binary classifier (accident vs. normal traffic).
    2. Optional ResNet-50 3-class severity classifier.
    3. BLIP image-captioning model for a natural-language summary.
"""

import os

import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms
from transformers import BlipForConditionalGeneration, BlipProcessor

# --- 1. SETUP & CONFIG ---
DEVICE = torch.device("cpu")  # HF Spaces (Free) uses CPU

# Preprocessing transforms (Must match training!)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # ImageNet mean/std — standard for ResNet-50 backbones.
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

print("⏳ Loading Models... (This might take a minute)")

# --- 2. LOAD DETECTOR MODEL ---
detector = models.resnet50(weights=None)
detector.fc = nn.Linear(detector.fc.in_features, 2)  # Binary: Accident vs Non-Accident
try:
    detector.load_state_dict(torch.load("accident_detector.pth", map_location=DEVICE))
    print("✅ Accident Detector Loaded!")
except FileNotFoundError:
    # NOTE(review): if the checkpoint is missing the app keeps serving with
    # randomly initialized weights, so predictions are meaningless until the
    # file is uploaded. Consider surfacing this in the UI as well.
    print("❌ ERROR: 'accident_detector.pth' not found. Please upload it.")
detector.to(DEVICE).eval()

# --- 3. LOAD SEVERITY MODEL (Optional) ---
severity_model = None
try:
    if os.path.exists("severity_classifier.pth"):
        severity_net = models.resnet50(weights=None)
        severity_net.fc = nn.Linear(severity_net.fc.in_features, 3)  # Minor, Substantial, Critical
        severity_net.load_state_dict(
            torch.load("severity_classifier.pth", map_location=DEVICE)
        )
        severity_net.to(DEVICE).eval()
        severity_model = severity_net
        print("✅ Severity Model Loaded!")
    else:
        print("ℹ️ Severity model not found. Skipping severity check.")
except Exception as e:
    # Best-effort: severity is an optional feature, so any load failure
    # degrades gracefully instead of crashing the app.
    print(f"⚠️ Could not load severity model: {e}")

# --- 4. LOAD SUMMARIZER (BLIP) ---
# This downloads from HuggingFace Hub automatically
print("⏳ Loading BLIP Model...")
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-large"
).to(DEVICE)
print("✅ BLIP Summarizer Loaded!")


# --- 5. INFERENCE FUNCTION ---
def analyze_frame(image):
    """Run the full accident-analysis pipeline on one frame.

    Args:
        image: RGB frame as a numpy array (Gradio ``type="numpy"``), or
            ``None`` if the user submitted without uploading.

    Returns:
        A human-readable report string: either a "Normal Traffic" status or
        an accident report with confidence, severity, and a BLIP summary.
    """
    if image is None:
        return "Please upload an image."

    # A. Preprocess
    img_pil = Image.fromarray(image).convert('RGB')
    input_tensor = transform(img_pil).unsqueeze(0).to(DEVICE)

    # B. Detect Accident
    with torch.no_grad():
        out = detector(input_tensor)
        probs = torch.nn.functional.softmax(out, dim=1)

    # Class 0 = Accident (Standard alphabetical sorting by ImageFolder)
    # — assumes the training folders sorted "Accident" first; confirm
    # against the training ImageFolder class order.
    accident_conf = probs[0][0].item()
    is_accident = accident_conf > 0.5  # Threshold

    if not is_accident:
        return f"✅ Status: Normal Traffic\nConfidence: {1-accident_conf:.2%}"

    # C. If Accident -> Assess Severity (if model exists)
    severity_status = "Unknown (Model not loaded)"
    # FIX: identity check instead of nn.Module truthiness — the sentinel
    # value is None, so test for it explicitly.
    if severity_model is not None:
        with torch.no_grad():
            sev_out = severity_model(input_tensor)
            sev_idx = torch.argmax(sev_out).item()
        # Mapping based on folder names: 1, 2, 3
        classes = ["Minor Impact", "Substantial Impact", "Critical Impact"]
        severity_status = classes[sev_idx]

    # D. If Accident -> Generate Summary
    # FIX: run captioning under no_grad too, consistent with the classifier
    # paths (recent transformers' generate() is no-grad internally, but the
    # explicit guard also covers the processor call and older versions).
    with torch.no_grad():
        inputs = processor(
            img_pil,
            "a cctv footage of a car accident showing",
            return_tensors="pt",
        ).to(DEVICE)
        out_ids = blip_model.generate(**inputs, max_new_tokens=50)
    summary = processor.decode(out_ids[0], skip_special_tokens=True)

    # E. Format Output
    return f"""🚨 ACCIDENT DETECTED 🚨
--------------------------
Confidence: {accident_conf:.2%}
Severity: {severity_status}

📝 AI Summary: "{summary}"
"""


# --- 6. DEFINE UI ---
# Removed the 'examples' list to fix the InvalidPathError
interface = gr.Interface(
    fn=analyze_frame,
    inputs=gr.Image(type="numpy", label="Upload CCTV Frame"),
    outputs=gr.Textbox(label="Analysis Report"),
    title="🛡️ AI Accident Detection System",
    description="Upload a traffic image. The system will detect if an accident occurred, estimate severity, and describe the scene.",
)

# --- 7. LAUNCH ---
if __name__ == "__main__":
    interface.launch()