Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from transformers import AutoImageProcessor, TimesformerForVideoClassification | |
| import cv2 | |
| from PIL import Image | |
| import numpy as np | |
| import time | |
| from collections import deque | |
# --- Configuration ---
# Hugging Face model repository ID the model and image processor are loaded from.
HF_MODEL_REPO_ID = "owinymarvin/timesformer-crime-detection"

# These must match the values used during training.
NUM_FRAMES = 16  # Number of frames that must be buffered before a prediction is made
# NOTE(review): the two target sizes below are not referenced anywhere in the
# visible part of this file -- presumably the image processor resizes on its own.
TARGET_IMAGE_HEIGHT = 224
TARGET_IMAGE_WIDTH = 224

# --- Load Model and Processor ---
print(f"Loading model and image processor from {HF_MODEL_REPO_ID}...")
try:
    processor = AutoImageProcessor.from_pretrained(HF_MODEL_REPO_ID)
    model = TimesformerForVideoClassification.from_pretrained(HF_MODEL_REPO_ID)
except Exception as e:
    print(f"Error loading model from Hugging Face Hub: {e}")
    # The app is useless without the model, so stop here. Raise SystemExit with
    # a non-zero status instead of calling the bare ``exit()`` helper: that
    # helper is injected by the ``site`` module (not guaranteed to exist) and,
    # called without arguments, reports success (status 0) to the host.
    raise SystemExit(1) from e

model.eval()  # Set model to evaluation mode
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(f"Model loaded successfully on {device}.")
print(f"Model's class labels: {model.config.id2label}")

# Global rolling buffer holding only the *latest* NUM_FRAMES webcam frames
# (older frames are dropped automatically once maxlen is reached).
# It is module-level so that its state persists across Gradio calls.
captured_frames_buffer = deque(maxlen=NUM_FRAMES)

# Length of the optional artificial wait (5 minutes) kept around for testing.
# Nothing in the visible code reads this value at runtime.
wait_duration_seconds = 300
# --- Function to continuously capture frames (without immediate processing) ---
def capture_frame_into_buffer(image_np_array):
    """Append one webcam frame to the global frame buffer.

    No inference happens here; the frame is only stored so that a later
    prediction request can use the most recent frames.

    Args:
        image_np_array: The frame as a NumPy array (RGB, as delivered by the
            Gradio image component), or ``None`` when no frame is available.

    Returns:
        A status string of the form ``"Frames buffered: <count>/<NUM_FRAMES>"``.
    """
    # NOTE(review): the streaming component can presumably fire without a frame
    # (e.g. before the webcam is ready or after it is cleared). In that case
    # Image.fromarray(None) would raise, so skip the append and just report
    # the current buffer state.
    if image_np_array is not None:
        # Store frames as PIL images, the form later handed to the processor.
        captured_frames_buffer.append(Image.fromarray(image_np_array))

    # Report how many frames are buffered so the UI can show progress.
    return f"Frames buffered: {len(captured_frames_buffer)}/{NUM_FRAMES}"
# --- Function to trigger prediction with the buffered frames ---
def make_prediction_from_buffer():
    """Run the model on the frames currently held in the global buffer.

    The buffer is deliberately left untouched afterwards, so calling this
    again without new frames arriving re-predicts on the same frames.

    Returns:
        ``"Predicted: <label> (<confidence>)"`` with the confidence formatted
        to two decimals, or a hint message when fewer than ``NUM_FRAMES``
        frames have been buffered so far.
    """
    # Guard clause: the model needs a full set of frames.
    if len(captured_frames_buffer) < NUM_FRAMES:
        return "Not enough frames buffered yet. Please capture more frames."

    # Snapshot the deque as a plain list, which is what the processor is given.
    clip = list(captured_frames_buffer)

    # --- Perform Inference ---
    print(f"Triggered inference on {len(clip)} frames...")
    batch = processor(images=clip, return_tensors="pt")
    pixel_values = batch.pixel_values.to(device)
    with torch.no_grad():
        logits = model(pixel_values).logits

    # Highest-scoring class, its human-readable label and its softmax probability.
    top_class = logits.argmax(-1).item()
    top_label = model.config.id2label[top_class]
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    top_probability = probabilities[0][top_class].item()

    prediction_text = f"Predicted: {top_label} ({top_probability:.2f})"
    print(prediction_text)  # Also goes to the Space logs
    return prediction_text
# --- Gradio Interface ---
with gr.Blocks() as demo:
    # Introductory text shown at the top of the page.
    gr.Markdown(
        f"""
# TimesFormer Crime Detection Live Demo (Manual Trigger)
This demo uses a finetuned TimesFormer model ({HF_MODEL_REPO_ID}) to predict crime actions from a live webcam feed.
It continuously buffers frames, but **only makes a prediction when you click the 'Predict' button**.
The model requires **{NUM_FRAMES} frames** for a prediction.
Please allow webcam access.
"""
    )

    with gr.Row():
        # Left column: live feed, buffer status and the manual trigger.
        with gr.Column():
            webcam_input = gr.Image(
                label="Live Webcam Feed",
                sources=["webcam"],
                streaming=True,
            )
            # Updated on every streamed frame with the current buffer fill level.
            buffer_status = gr.Textbox(
                value=f"Frames buffered: 0/{NUM_FRAMES}",
                label="Frame Buffer Status",
            )
            # Inference only happens when this button is clicked.
            predict_button = gr.Button("Predict Latest Frames")

        # Right column: result of the most recent prediction.
        with gr.Column():
            prediction_output = gr.Textbox(
                value="Click 'Predict Latest Frames' to start.",
                label="Prediction Result",
            )

    # Event wiring.
    # Every streamed frame goes into the buffer and refreshes the status box.
    webcam_input.stream(
        fn=capture_frame_into_buffer,
        inputs=[webcam_input],
        outputs=[buffer_status],
    )
    # A button click runs the model on the buffered frames.
    predict_button.click(
        fn=make_prediction_from_buffer,
        inputs=[],
        outputs=[prediction_output],
    )

if __name__ == "__main__":
    demo.launch()