usman-khn's picture
Upload 5 files
4ffed49 verified
raw
history blame
5.76 kB
import os
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np
import gradio as gr
# --- 1. Model Configuration ---
# Number of consecutive frames the model consumes per prediction.
SEQUENCE_LENGTH = 16
# Number of behavior classes the classifier outputs.
NUM_CLASSES = 4
MODEL_PATH = "best_model.pth" # Trained weights file; must sit next to this script.
# Device setup for loading the model: prefer GPU, fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Index -> label mapping; order must match the label encoding used in training.
# Per the training classification report: 0=aggressive, 1=idle, 2=panic, 3=normal.
CLASS_NAMES = ["aggressive", "idle", "panic", "normal"]
# --- 2. Model Definition (Copied from your notebook) ---
class CNNLSTM(nn.Module):
    """Per-frame CNN feature extractor followed by an LSTM over time.

    Each frame passes through a small two-stage convolutional stack; the
    flattened per-frame features form a sequence, and the LSTM's output at
    the final time step is projected to class logits.
    """

    def __init__(self, num_classes):
        super(CNNLSTM, self).__init__()
        # Two conv/pool stages: 3 -> 32 -> 64 channels, spatial size halved twice.
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Flattened per-frame feature size for 64x64 input: 64 * (64/4) * (64/4).
        self.lstm = nn.LSTM(input_size=64 * 16 * 16, hidden_size=128, batch_first=True)
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        """Classify a batch of frame sequences shaped (B, T, C, H, W)."""
        batch, steps, channels, height, width = x.size()
        # Fold the time axis into the batch so the CNN encodes each frame independently.
        frames = x.view(batch * steps, channels, height, width)
        features = self.cnn(frames)
        # Restore the sequence axis and flatten spatial features per frame.
        features = features.view(batch, steps, -1)
        # Only the LSTM output at the last time step feeds the classifier.
        seq_out, _ = self.lstm(features)
        last_step = seq_out[:, -1, :]
        return self.fc(last_step)
# --- 3. Model Loading and Prediction Function ---
def load_model():
    """Build a CNNLSTM and load the trained weights from MODEL_PATH.

    Returns:
        CNNLSTM: the model in eval mode on `device`.

    Raises:
        FileNotFoundError: if MODEL_PATH does not exist.
    """
    # Fail fast: check for the weights file BEFORE paying for model
    # construction (the original built the model first, then raised).
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"Model file not found at {MODEL_PATH}. Please ensure your 'best_model.pth' is uploaded.")
    model = CNNLSTM(num_classes=NUM_CLASSES).to(device)
    # map_location lets a GPU-trained checkpoint load on CPU-only hosts.
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    model.eval()
    return model
# Global model instance
# Load the model once at import time so every request reuses it. If the
# weights file is missing, the app still starts, but prediction requests
# will return an error until the file is provided.
model = None
try:
    model = load_model()
except FileNotFoundError as err:
    print(err)
# Per-frame preprocessing: resize to the 64x64 input the CNN expects, then
# convert the PIL image to a float tensor scaled to [0, 1].
_frame_steps = [
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_frame_steps)
def predict_crowd_behavior(input_images):
    """Classify crowd behavior from a sequence of video frames.

    Args:
        input_images (list[PIL.Image.Image]): exactly SEQUENCE_LENGTH frames,
            in temporal order.

    Returns:
        dict[str, float]: class name -> probability (for gr.Label display),
        or a str error message when the model or inputs are invalid.
    """
    if model is None:
        return "ERROR: Model could not be loaded. Check logs for missing file."
    if not input_images or len(input_images) != SEQUENCE_LENGTH:
        return f"ERROR: Please upload exactly {SEQUENCE_LENGTH} frames."
    try:
        # Normalize every frame to RGB before the tensor transform.
        frames = [
            transform(img if img.mode == 'RGB' else img.convert('RGB'))
            for img in input_images
        ]
        # Stack frames and add a batch dimension: (1, T, C, H, W).
        video_tensor = torch.stack(frames).unsqueeze(0).to(device)
        with torch.no_grad():
            output = model(video_tensor)
        # Softmax the logits into probabilities and convert to native Python
        # floats (tolist) so the dict is JSON-serializable for Gradio; the
        # original returned numpy float32 scalars.
        probabilities = torch.softmax(output, dim=1)[0].cpu().tolist()
        return dict(zip(CLASS_NAMES, probabilities))
    except Exception as e:
        # UI boundary: surface any failure as a message instead of crashing.
        return f"Prediction failed: {e}"
# --- 4. Gradio Interface ---
def _frame_slot(index):
    """Build the upload widget for the frame with the given 1-based label."""
    return gr.Image(label=f"Frame {index}", type="pil", width=100, height=100)

# One image upload slot per frame in the input sequence.
image_components = [_frame_slot(i + 1) for i in range(SEQUENCE_LENGTH)]
# Markdown shown above the interface; the f-string interpolates SEQUENCE_LENGTH.
description = f"""
# 🧠 CNN-LSTM Crowd Behavior Analysis from Aerial Video
This model analyzes a sequence of **{SEQUENCE_LENGTH} consecutive frames** extracted from an aerial video (e.g., drone footage) to classify the crowd's behavior.
## 🛠 Instructions
1. **Extract Frames:** Use the custom script you have (`extract_frames` from your notebook) or another tool to get **16 consecutive frames** from your video segment.
2. **Upload:** Upload each of the 16 frames to the image slots below.
3. **Predict:** Click the 'Predict Behavior' button to see the results.
The model classifies into one of these behaviors: **aggressive, idle, panic, or normal**.
"""
# Assemble the Gradio interface from the components defined above.
iface = gr.Interface(
    fn=predict_crowd_behavior,
    inputs=image_components,
    outputs=gr.Label(num_top_classes=NUM_CLASSES),
    title="Crowd Behavior Classifier (CNN-LSTM Hybrid)",
    description=description,
    live=False,  # predict only when the user clicks the button
    # NOTE(review): flagging is left at Gradio's default here. To disable it,
    # pass allow_flagging="never" (Gradio 3/4) or flagging_mode="never"
    # (Gradio 5) — 'flagging_enabled', suggested by an earlier comment, is
    # not a gr.Interface parameter in any released version; confirm against
    # the installed Gradio before adding it back.
)
if __name__ == "__main__":
    # Running as a script starts the local Gradio server; when deployed on
    # Hugging Face Spaces, the platform invokes launch() for the app itself.
    iface.launch()