import torch
import gradio as gr
import torch.nn as nn
import torchvision
import cv2
import numpy as np
class MyModel(nn.Module):
    def __init__(self, num_classes=1):
        super(MyModel, self).__init__()  # Initialize nn.Module
        # Load a Kinetics-pretrained R3D-18 backbone; `pretrained=True` is
        # deprecated in newer torchvision in favor of `weights=...`
        self.model = torchvision.models.video.r3d_18(pretrained=True)
        # Replace the classification head with a single-logit binary output
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
    def preprocess_video(self, video_path, num_frames=40):
        """Sample frames evenly from the video, resize them, and return a 5D tensor."""
        cap = cv2.VideoCapture(video_path)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Pick num_frames evenly spaced frame indices across the whole clip
        frame_indices = np.linspace(0, total_frames - 1, num=num_frames, dtype=int)
        sampled_frames = []
        for idx in frame_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            if not ret:
                continue
            frame = cv2.resize(frame, (112, 112))  # Resize to 112x112 for R3D-18
            frame = np.transpose(frame, (2, 0, 1))  # Channels-first (C, H, W)
            sampled_frames.append(frame)
        cap.release()
        # Stack into a (frames, 3, 112, 112) array; zero-pad if any reads failed
        sampled_frames = (np.stack(sampled_frames) if sampled_frames
                          else np.zeros((0, 3, 112, 112)))
        if len(sampled_frames) < num_frames:
            padding = np.zeros((num_frames - len(sampled_frames), 3, 112, 112))
            sampled_frames = np.concatenate([sampled_frames, padding], axis=0)
        # Convert to tensor and rearrange to (1, 3, num_frames, 112, 112)
        return torch.tensor(sampled_frames).float().permute(1, 0, 2, 3).unsqueeze(0)
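    # Note: torchvision's video models, including r3d_18, expect input shaped
    # (batch, channels, frames, height, width); the tensor returned above is
    # (1, 3, num_frames, 112, 112). Frames are left as raw BGR pixel values
    # (no mean/std normalization), so the loaded checkpoint is assumed to have
    # been trained with this same preprocessing.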
    def forward(self, x):
        return self.model(x)

    def predict(self, video_path):
        """Run inference on a single video and return the predicted probability."""
        self.model.eval()
        predictions = []
        with torch.no_grad():
            X = self.preprocess_video(video_path)
            output = self.model(X)
            pred = torch.sigmoid(output)  # Apply sigmoid for binary classification
            # Track predictions
            predictions.append(pred.item())
        return predictions
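    # With num_classes=1 the model emits a single logit of shape (1, 1);
    # sigmoid maps it into (0, 1) and .item() unwraps it to a Python float,
    # so predict() returns a one-element list such as [0.87] (value illustrative).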
    def save_model(self, filepath):
        torch.save({
            'model_state_dict': self.state_dict(),
        }, filepath)
def load_model(filepath, num_classes=1):
    """Rebuild the architecture and restore weights from a checkpoint."""
    model = MyModel(num_classes)
    # weights_only=True restricts unpickling to tensors and plain containers
    # for safety (available in PyTorch >= 1.13)
    checkpoint = torch.load(filepath, weights_only=True)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model
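# Hypothetical save/load round-trip with the functions above (paths illustrative):
#   m = MyModel()
#   m.save_model('checkpoint.pth')
#   m = load_model('checkpoth')
#   m = load_model('checkpoint.pth')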
# load_model constructs the MyModel instance itself, so it is called directly
model = load_model('pre_3D_model.h5')
def classify_video(video):
    prob = model.predict(video)
    # As wired here, the sigmoid output is the probability of the "non-violent"
    # class, so the violence probability is its complement
    label = "Non-violent" if prob[0] >= 0.5 else "Violent"
    violent_prob_percentage = f"{round((1 - prob[0]) * 100, 2)}% chance of being violent"
    return label, violent_prob_percentage
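# Example call outside Gradio (file path and outputs are hypothetical):
#   label, pct = classify_video('sample_clip.mp4')
#   print(label, pct)  # e.g. "Violent", "73.21% chance of being violent"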
# Set up the Gradio interface
interface = gr.Interface(
    fn=classify_video,
    inputs=gr.Video(),  # Allows video upload; Gradio passes a local file path to fn
    outputs=[
        gr.Text(label="Classification"),  # Label for classification output
        gr.Text(label="Violence Probability"),  # Probability output as text
    ],
    title="Violence Detection in Videos",
    description="Upload a video to classify it as violent or non-violent with a probability score.",
)
interface.launch(share=True, debug=True)
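# share=True requests a temporary public *.gradio.live link (not needed when the
# app is hosted on Spaces), and debug=True keeps the process in the foreground
# so tracebacks surface in the logs.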