# Hugging Face Spaces app: CNN-LSTM crowd behavior classifier served via a Gradio UI.
import os

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms

# --- 1. Model Configuration ---
SEQUENCE_LENGTH = 16            # number of consecutive frames the model expects
NUM_CLASSES = 4
MODEL_PATH = "best_model.pth"   # Ensure this file is in the same directory

# Device setup for loading the model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Class names mapping (must match the order used in the training code).
# Per the classification report: 0=aggressive, 1=idle, 2=panic, 3=normal
CLASS_NAMES = ["aggressive", "idle", "panic", "normal"]
# --- 2. Model Definition (copied from the training notebook) ---
class CNNLSTM(nn.Module):
    """Per-frame CNN feature extractor followed by an LSTM over the frame sequence."""

    def __init__(self, num_classes):
        super(CNNLSTM, self).__init__()
        # Two conv/ReLU/pool stages: 3 -> 32 -> 64 channels, spatial size halved twice.
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Flattened feature size for a 64x64 input: 64 channels * (64/2/2) * (64/2/2).
        self.lstm = nn.LSTM(input_size=64 * 16 * 16, hidden_size=128, batch_first=True)
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        """Map x of shape (batch, time, channel, height, width) to class logits."""
        batch, steps, channels, height, width = x.size()
        # Run the CNN over every frame independently by folding time into the batch.
        frame_feats = self.cnn(x.view(batch * steps, channels, height, width))
        # Restore the time dimension and flatten spatial features for the LSTM.
        seq_out, _ = self.lstm(frame_feats.view(batch, steps, -1))
        # Classify from the hidden output of the final time step only.
        return self.fc(seq_out[:, -1, :])
# --- 3. Model Loading and Prediction Function ---
def load_model():
    """Load the trained CNN-LSTM weights and return the model in eval mode.

    Returns:
        CNNLSTM: the model on `device`, ready for inference.

    Raises:
        FileNotFoundError: if MODEL_PATH does not exist.
    """
    model = CNNLSTM(num_classes=NUM_CLASSES).to(device)
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(
            f"Model file not found at {MODEL_PATH}. Please ensure your 'best_model.pth' is uploaded."
        )
    # map_location lets a GPU-trained checkpoint load on CPU-only hosts.
    # weights_only=True restricts unpickling to tensors/containers — all a
    # state_dict needs — avoiding arbitrary code execution from the pickle.
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device, weights_only=True))
    model.eval()
    return model
# Global model instance. If the weights file is missing the app still starts,
# but the prediction function reports an error until the file is provided.
try:
    model = load_model()
except FileNotFoundError as e:
    print(e)
    model = None

# Preprocessing applied to each individual frame before stacking.
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
])
def predict_crowd_behavior(*input_images):
    """Classify crowd behavior from a sequence of frames.

    Gradio passes each of the SEQUENCE_LENGTH image components as a separate
    positional argument, so the frames are collected via *input_images
    (a single-list signature would raise TypeError when the UI calls it).

    Args:
        *input_images: PIL.Image frames, one per input slot; slots the user
            left empty arrive as None.

    Returns:
        dict: class name -> probability (for gr.Label), or an error string.
    """
    if model is None:
        return "ERROR: Model could not be loaded. Check logs for missing file."
    # Empty slots arrive as None; require a complete sequence before inference.
    if len(input_images) != SEQUENCE_LENGTH or any(img is None for img in input_images):
        return f"ERROR: Please upload exactly {SEQUENCE_LENGTH} frames."
    try:
        frames_tensor = []
        for img in input_images:
            # Normalize to RGB regardless of the uploaded image mode.
            if img.mode != 'RGB':
                img = img.convert('RGB')
            frames_tensor.append(transform(img))
        # Stack the frames and add a batch dimension: (1, T, C, H, W).
        video_tensor = torch.stack(frames_tensor).unsqueeze(0).to(device)
        with torch.no_grad():
            output = model(video_tensor)
        # Softmax over logits -> per-class probabilities for the single batch item.
        probabilities = torch.softmax(output, dim=1)[0].cpu().numpy()
        # Plain Python floats serialize cleanly for the gr.Label component.
        return {
            name: float(prob) for name, prob in zip(CLASS_NAMES, probabilities)
        }
    except Exception as e:
        # Surface the failure in the UI rather than crashing the app.
        return f"Prediction failed: {e}"
# --- 4. Gradio Interface ---
# One image slot per frame in the required sequence.
image_components = [
    gr.Image(label=f"Frame {idx + 1}", type="pil", width=100, height=100)
    for idx in range(SEQUENCE_LENGTH)
]

description = f"""
# 🧠 CNN-LSTM Crowd Behavior Analysis from Aerial Video
This model analyzes a sequence of **{SEQUENCE_LENGTH} consecutive frames** extracted from an aerial video (e.g., drone footage) to classify the crowd's behavior.
## 🛠 Instructions
1. **Extract Frames:** Use the custom script you have (`extract_frames` from your notebook) or another tool to get **16 consecutive frames** from your video segment.
2. **Upload:** Upload each of the 16 frames to the image slots below.
3. **Predict:** Click the 'Predict Behavior' button to see the results.
The model classifies into one of these behaviors: **aggressive, idle, panic, or normal**.
"""
# Assemble the app: 16 image inputs -> one probability label output.
iface = gr.Interface(
    fn=predict_crowd_behavior,
    inputs=image_components,
    outputs=gr.Label(num_top_classes=NUM_CLASSES),
    title="Crowd Behavior Classifier (CNN-LSTM Hybrid)",
    description=description,
    live=False,
)

if __name__ == "__main__":
    # Runs on localhost when executed directly; Hugging Face Spaces invokes
    # iface.launch() automatically on deploy.
    iface.launch()