Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from transformers import AutoImageProcessor, TimesformerForVideoClassification | |
| import cv2 | |
| from PIL import Image | |
| import numpy as np | |
| import time | |
| from collections import deque | |
# --- Configuration ---
# Hugging Face Hub repository containing the fine-tuned TimeSformer checkpoint
# and its image-processor configuration.
HF_MODEL_REPO_ID = "owinymarvin/timesformer-crime-detection"

# Clip geometry. Keep these in sync with the values used when the model was
# fine-tuned; NUM_FRAMES is the number of frames fed to the model per clip.
NUM_FRAMES = 16
TARGET_IMAGE_HEIGHT, TARGET_IMAGE_WIDTH = 224, 224
# --- Load Model and Processor ---
print(f"Loading model and image processor from {HF_MODEL_REPO_ID}...")
try:
    processor = AutoImageProcessor.from_pretrained(HF_MODEL_REPO_ID)
    model = TimesformerForVideoClassification.from_pretrained(HF_MODEL_REPO_ID)
except Exception as e:
    print(f"Error loading model from Hugging Face Hub: {e}")
    # Abort with a non-zero exit status so the failure is visible to whatever
    # launched this script. The bare ``exit()`` used previously is a helper
    # injected by the ``site`` module (missing under ``python -S``) and exits
    # with status 0, which makes a failed load look like a clean shutdown.
    raise SystemExit(1) from e

model.eval()  # Evaluation mode: affects layers such as dropout.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(f"Model loaded successfully on {device}.")
print(f"Model's class labels: {model.config.id2label}")
# Minimum number of seconds between two model invocations.
inference_interval = 1.0

# Sliding window of the most recent frames: with maxlen set, appending once
# NUM_FRAMES frames are held silently discards the oldest one.
# NOTE(review): this is module-level state, so it is shared by every client
# served by this process, not kept per browser session.
frame_buffer = deque([], maxlen=NUM_FRAMES)

# Timestamp of the last prediction; starts at import time so the first
# prediction waits at least one interval.
last_inference_time = time.time()

# Text shown in the UI until the first prediction replaces it.
current_prediction_text = "Buffering frames..."
def predict_video_frame(image_np_array):
    """Buffer one webcam frame and, when due, classify the buffered clip.

    Each call appends the incoming frame to the module-level ``frame_buffer``,
    which holds at most ``NUM_FRAMES`` frames. When the buffer is full and at
    least ``inference_interval`` seconds have passed since the last
    prediction, the buffered frames are run through the model and
    ``current_prediction_text`` is refreshed. Between predictions the previous
    text is returned unchanged, so the UI keeps showing the last result.

    NOTE(review): the buffer, timer and prediction text are module globals, so
    every client connected to this process feeds and reads the same state.

    Args:
        image_np_array: The latest frame from the Gradio ``Image`` component
            (a numpy array), or ``None`` when the component sends no frame.

    Returns:
        str: The most recent prediction text, or the initial buffering
        message if no prediction has been made yet.
    """
    # frame_buffer is only mutated, never rebound, so it needs no ``global``.
    global last_inference_time, current_prediction_text

    # Gradio can invoke the handler without a frame (e.g. before the webcam
    # has started or after the image is cleared); Image.fromarray(None) raises.
    if image_np_array is None:
        return current_prediction_text

    # Force 3-channel RGB so an RGBA or greyscale frame cannot reach the
    # processor with an unexpected channel count (a plain copy for RGB input).
    frame_buffer.append(Image.fromarray(image_np_array).convert("RGB"))

    current_time = time.time()
    buffer_full = len(frame_buffer) == NUM_FRAMES
    interval_elapsed = (current_time - last_inference_time) >= inference_interval
    if not (buffer_full and interval_elapsed):
        return current_prediction_text
    last_inference_time = current_time

    # The processor expects a list of PIL Images (or numpy arrays); the whole
    # window is passed as a single clip.
    processed_input = processor(images=list(frame_buffer), return_tensors="pt")
    pixel_values = processed_input.pixel_values.to(device)

    with torch.no_grad():
        logits = model(pixel_values).logits

    # A single clip was submitted, so logits holds exactly one row of scores.
    predicted_class_id = logits.argmax(-1).item()
    predicted_label = model.config.id2label[predicted_class_id]
    confidence = torch.nn.functional.softmax(logits, dim=-1)[0][predicted_class_id].item()

    current_prediction_text = f"Predicted: {predicted_label} ({confidence:.2f})"
    print(current_prediction_text)  # Also surface predictions in the Space logs.
    return current_prediction_text
# --- Gradio Interface ---
# Streaming webcam input: each captured frame is delivered to
# predict_video_frame as a numpy array (type="numpy", the handler's contract).
# ``shape=`` is deliberately not passed. Gradio 4 (the version that accepts
# ``sources=``) removed that argument, so combining the two fails when the
# component is constructed. Frame resizing is presumably handled by the
# model's image processor; confirm against the processor config.
webcam_input = gr.Image(
    sources=["webcam"],
    streaming=True,
    type="numpy",
    label="Live Webcam Feed",
)

# Text box that shows the latest prediction string returned by the handler.
prediction_output = gr.Textbox(label="Real-time Prediction")

# A plain Interface is enough here. live=True re-runs the handler as new
# frames arrive, so no submit button is needed.
demo = gr.Interface(
    fn=predict_video_frame,
    inputs=webcam_input,
    outputs=prediction_output,
    live=True,
    allow_flagging="never",  # No flagging on a public demo.
    title="TimesFormer Crime Detection Live Demo",
    description=f"This demo uses a finetuned TimesFormer model ({HF_MODEL_REPO_ID}) to predict crime actions from a live webcam feed. The model processes {NUM_FRAMES} frames at a time and makes a prediction every {inference_interval} seconds. Please allow webcam access.",
)

if __name__ == "__main__":
    demo.launch()