# Captured from a Hugging Face Spaces page (Space status at capture time: "Sleeping").
import base64
import io
import sys
import time
from collections import deque

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoImageProcessor, TimesformerForVideoClassification
# --- Configuration ---
# Hugging Face Hub repository that provides both the image processor and the model weights.
HF_MODEL_REPO_ID = "owinymarvin/timesformer-crime-detection"

# Clip geometry.
MODEL_INPUT_NUM_FRAMES = 8  # number of frames handed to the model per prediction
TARGET_IMAGE_HEIGHT = TARGET_IMAGE_WIDTH = 224  # not referenced elsewhere in this file as captured

# Live-demo timing.
RAW_RECORDING_DURATION_SECONDS = 10.0  # length of each webcam recording window
FRAMES_TO_SAMPLE_PER_CLIP = 20  # first-stage down-sample of the raw frame buffer
DELAY_BETWEEN_PREDICTIONS_SECONDS = 2 * 60.0  # 2-minute pause after each prediction, sized for CPU hosting
# --- Load Model and Processor ---
# Both the processor and the classifier come from HF_MODEL_REPO_ID. Any load
# failure is fatal: the app cannot serve predictions without them.
print(f"Loading model and processor from {HF_MODEL_REPO_ID}...")
try:
    processor = AutoImageProcessor.from_pretrained(HF_MODEL_REPO_ID)
    model = TimesformerForVideoClassification.from_pretrained(HF_MODEL_REPO_ID)
except Exception as e:  # top-level startup boundary: report and stop
    print(f"Error loading model: {e}", file=sys.stderr)
    # The bare exit() builtin comes from the `site` module and exits with status 0.
    # sys.exit(1) always exists and signals the failure to whatever launched the app.
    sys.exit(1)

model.eval()  # inference mode (disables dropout etc.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(f"Model loaded on {device}.")
# --- Global State Variables for Live Demo ---
# These live at module level, so they are shared by every caller of the
# functions below (i.e. every session served by this process).
app_state: str = "recording"  # one of "recording", "predicting", "processing_delay"
raw_frames_buffer: deque = deque()  # PIL frames captured during the current recording window
current_clip_start_time: float = time.time()  # when the current recording window began
last_prediction_completion_time: float = time.time()  # reference point for the post-prediction delay
# --- Helper function to sample frames ---
def sample_frames(frames_list, target_count):
    """Pick ``target_count`` frames spread evenly across ``frames_list``.

    Positions come from ``np.linspace`` over ``[0, len(frames_list) - 1]``
    with an integer dtype, so for ``target_count >= 2`` both the first and
    the last frame are kept.

    Args:
        frames_list: an indexable, sized sequence of frames.
        target_count: how many frames to keep.

    Returns:
        ``[]`` for an empty input; the very same ``frames_list`` object when it
        already holds ``target_count`` frames or fewer; otherwise a new list of
        ``target_count`` frames in their original order.
    """
    if not frames_list:
        return []
    if target_count >= len(frames_list):
        # Nothing to drop; hand back the caller's sequence unchanged.
        return frames_list
    positions = np.linspace(0, len(frames_list) - 1, target_count, dtype=int)
    return [frames_list[pos] for pos in positions.tolist()]
# --- Main processing function for Live Demo Stream ---
def _classify_clip(frames_for_model):
    """Run the video classifier on one clip and return ``(label, confidence)``.

    Args:
        frames_for_model: list of PIL images forming one clip; passed as-is to
            ``processor(images=...)``.

    Returns:
        The ``model.config.id2label`` name of the arg-max class (``"Unknown"``
        if the id is missing from that mapping) and its softmax probability.
    """
    processed_input = processor(images=frames_for_model, return_tensors="pt")
    pixel_values = processed_input.pixel_values.to(device)
    with torch.no_grad():
        logits = model(pixel_values).logits
    # .item() below requires the batch to hold exactly one clip.
    predicted_class_id = logits.argmax(-1).item()
    predicted_label = model.config.id2label.get(predicted_class_id, "Unknown")
    confidence = torch.nn.functional.softmax(logits, dim=-1)[0][predicted_class_id].item()
    return predicted_label, confidence


def live_predict_stream(image_np_array):
    """Advance the live-demo state machine by one webcam frame.

    ``app_state`` cycles through three phases:

    * ``"recording"``: buffer frames until ``RAW_RECORDING_DURATION_SECONDS``
      have elapsed, then switch to ``"predicting"``.
    * ``"predicting"``: on the next frame, down-sample the buffer to
      ``MODEL_INPUT_NUM_FRAMES`` frames, classify them, clear the buffer and
      switch to ``"processing_delay"`` (or back to ``"recording"`` if too few
      frames were captured).
    * ``"processing_delay"``: wait ``DELAY_BETWEEN_PREDICTIONS_SECONDS`` after
      the last prediction finished or failed, then start recording again.

    All state lives in module globals and is shared by every session.

    Args:
        image_np_array: the current webcam frame, passed to
            ``Image.fromarray``. ``None`` frames are ignored.

    Yields:
        ``(status_text, prediction_text)`` pairs for the two output textboxes.
        ``gr.update()`` leaves a textbox unchanged.
    """
    global raw_frames_buffer, current_clip_start_time, last_prediction_completion_time, app_state

    if image_np_array is None:
        # No frame to work with (e.g. the webcam has not started streaming yet).
        yield gr.update(), gr.update()
        return

    current_time = time.time()

    if app_state == "recording":
        raw_frames_buffer.append(Image.fromarray(image_np_array))
        elapsed_recording_time = current_time - current_clip_start_time
        yield (
            f"Recording: {elapsed_recording_time:.1f}/{RAW_RECORDING_DURATION_SECONDS}s. "
            f"Raw frames: {len(raw_frames_buffer)}",
            "Buffering...",
        )
        if elapsed_recording_time >= RAW_RECORDING_DURATION_SECONDS:
            app_state = "predicting"
            print("DEBUG: Transitioning to 'predicting' state.")
            # Inference itself runs when the next frame arrives.
            yield "Preparing to predict...", "Processing..."

    elif app_state == "predicting":
        if not raw_frames_buffer:
            # Nothing buffered (e.g. a manual reset cleared it mid-cycle). Restart
            # the cycle instead of staying in "predicting" with nothing to classify.
            app_state = "recording"
            current_clip_start_time = current_time
            yield "Starting new recording...", "Ready for new prediction."
            return

        print("DEBUG: Starting prediction.")
        sampled_raw_frames = sample_frames(list(raw_frames_buffer), FRAMES_TO_SAMPLE_PER_CLIP)
        frames_for_model = sample_frames(sampled_raw_frames, MODEL_INPUT_NUM_FRAMES)
        # Every outcome below consumes this recording, so drop it now.
        raw_frames_buffer.clear()

        if len(frames_for_model) < MODEL_INPUT_NUM_FRAMES:
            print(
                f"ERROR: Insufficient frames for model input: "
                f"{len(frames_for_model)}/{MODEL_INPUT_NUM_FRAMES}. Resetting state."
            )
            app_state = "recording"
            current_clip_start_time = time.time()
            last_prediction_completion_time = time.time()
            yield (
                "Error during frame sampling.",
                f"Error: Not enough frames ({len(frames_for_model)}/{MODEL_INPUT_NUM_FRAMES}). Resetting.",
            )
            return

        try:
            predicted_label, confidence = _classify_clip(frames_for_model)
        except Exception as e:  # UI boundary: surface any inference failure to the user
            print(f"ERROR during prediction: {e}")
            # Restart the delay timer so a failing model is retried only after the
            # full delay, not immediately based on an older timestamp.
            last_prediction_completion_time = time.time()
            app_state = "processing_delay"
            yield "Prediction error.", f"Error during prediction: {e}"
            return

        prediction_result = f"Predicted: {predicted_label} (Confidence: {confidence:.2f})"
        print(f"DEBUG: Prediction Result: {prediction_result}")
        # Measure the delay from when inference actually finished, and update the
        # state before yielding so it stays consistent even if the generator is
        # never resumed.
        last_prediction_completion_time = time.time()
        app_state = "processing_delay"
        print("DEBUG: Transitioning to 'processing_delay' state.")
        yield "Prediction complete.", prediction_result

    elif app_state == "processing_delay":
        elapsed_delay = current_time - last_prediction_completion_time
        if elapsed_delay < DELAY_BETWEEN_PREDICTIONS_SECONDS:
            # gr.update() with no arguments keeps the last prediction visible.
            yield (
                f"Delaying next prediction: {int(elapsed_delay)}/{int(DELAY_BETWEEN_PREDICTIONS_SECONDS)}s",
                gr.update(),
            )
        else:
            app_state = "recording"
            current_clip_start_time = current_time
            print("DEBUG: Transitioning back to 'recording' state.")
            yield "Starting new recording...", "Ready for new prediction."
def reset_app_state_manual():
    """Restart the live-demo cycle from the "recording" phase.

    Discards any buffered frames and restarts both timers. Returns the initial
    ``(status_text, prediction_text)`` pair for the two output textboxes.
    """
    global raw_frames_buffer, current_clip_start_time, last_prediction_completion_time, app_state
    app_state = "recording"
    raw_frames_buffer.clear()
    current_clip_start_time = time.time()
    last_prediction_completion_time = time.time()
    print("DEBUG: Manual reset triggered.")
    initial_status, initial_prediction = "Ready to record...", "Ready for new prediction."
    return initial_status, initial_prediction
# --- API function for external clients ---
def predict_from_frames_api(frames_list):
    """Classify a clip sent as a JSON list of base64-encoded image strings.

    Each string may carry an optional ``data:<mime>;base64,`` prefix. The
    decoded images are sampled down to ``MODEL_INPUT_NUM_FRAMES`` with
    ``sample_frames`` and classified with the globally loaded model.

    Args:
        frames_list: the parsed value of the ``gr.Json`` input; expected to be
            a list of base64 strings.

    Returns:
        ``"Predicted: <label> (Confidence: <p>)"`` on success, or a string
        starting with ``"Error:"`` when the input cannot be used. Errors raised
        by the processor or model propagate to Gradio.
    """
    if not isinstance(frames_list, list) or not frames_list:
        return "Error: expected a non-empty JSON list of base64-encoded image strings."
    try:
        images = []
        for encoded in frames_list:
            if isinstance(encoded, str) and encoded.startswith("data:"):
                # Strip a data-URL header such as "data:image/jpeg;base64,".
                encoded = encoded.split(",", 1)[-1]
            raw_bytes = base64.b64decode(encoded)
            # convert() forces PIL to decode now, so corrupt data fails here.
            images.append(Image.open(io.BytesIO(raw_bytes)).convert("RGB"))
    except (TypeError, ValueError, OSError) as e:
        # binascii.Error (bad base64) is a ValueError; PIL's
        # UnidentifiedImageError (not an image) is an OSError.
        return f"Error: could not decode frames: {e}"

    frames_for_model = sample_frames(images, MODEL_INPUT_NUM_FRAMES)
    if len(frames_for_model) < MODEL_INPUT_NUM_FRAMES:
        return f"Error: Not enough frames ({len(frames_for_model)}/{MODEL_INPUT_NUM_FRAMES})."

    processed_input = processor(images=frames_for_model, return_tensors="pt")
    pixel_values = processed_input.pixel_values.to(device)
    with torch.no_grad():
        logits = model(pixel_values).logits
    predicted_class_id = logits.argmax(-1).item()
    predicted_label = model.config.id2label.get(predicted_class_id, "Unknown")
    confidence = torch.nn.functional.softmax(logits, dim=-1)[0][predicted_class_id].item()
    return f"Predicted: {predicted_label} (Confidence: {confidence:.2f})"


# --- Gradio UI Layout ---
with gr.Blocks() as demo:
    gr.Markdown(
        f"""
        # TimesFormer Crime Detection - Hugging Face Space Host
        This Space hosts the `{HF_MODEL_REPO_ID}` model.
        Live webcam demo with recording and prediction phases.
        """
    )

    with gr.Tab("Live Webcam Demo"):
        gr.Markdown(
            f"""
            Continuously captures live webcam feed for **{RAW_RECORDING_DURATION_SECONDS} seconds**,
            then makes a prediction. There is a **{DELAY_BETWEEN_PREDICTIONS_SECONDS/60:.0f} minute delay** afterwards.
            """
        )
        with gr.Row():
            with gr.Column():
                webcam_input = gr.Image(
                    sources=["webcam"],
                    streaming=True,
                    label="Live Webcam Feed",
                )
                status_output = gr.Textbox(label="Current Status", value="Initializing...")
                reset_button = gr.Button("Reset / Start New Cycle")
            with gr.Column():
                prediction_output = gr.Textbox(label="Prediction Result", value="Waiting...")

        # Each streamed frame drives one step of the generator live_predict_stream;
        # every yielded pair updates (status, prediction) progressively.
        webcam_input.stream(
            live_predict_stream,
            inputs=[webcam_input],
            outputs=[status_output, prediction_output],
        )
        # The reset button is a regular click event, not a stream.
        reset_button.click(
            reset_app_state_manual,
            inputs=[],
            outputs=[status_output, prediction_output],
        )

    with gr.Tab("API Endpoint for External Clients"):
        gr.Markdown(
            f"""
            Send a JSON list of base64-encoded image strings (one per frame, optionally
            prefixed with `data:image/...;base64,`) to the `predict_from_frames_api`
            endpoint. At least {MODEL_INPUT_NUM_FRAMES} frames are required; longer clips
            are sampled evenly down to {MODEL_INPUT_NUM_FRAMES} frames. The predicted label
            and confidence are returned as text.
            """
        )
        # Exposes predict_from_frames_api to programmatic clients under this api_name.
        gr.Interface(
            fn=predict_from_frames_api,
            inputs=gr.Json(label="List of Base64-encoded image strings"),
            outputs=gr.Textbox(label="API Response"),
            live=False,
            allow_flagging="never",
            api_name="predict_from_frames_api",
        )

if __name__ == "__main__":
    demo.launch()