Spaces:
Sleeping
Sleeping
| import tensorflow as tf | |
| import cv2 | |
| import numpy as np | |
| import gradio as gr | |
| import math | |
| import logging | |
| import time | |
| import os | |
| import tempfile | |
| from urllib.parse import urlparse | |
# Configure the root logger at INFO level so the logging.info(...) progress and
# prediction messages emitted by this module are shown.
logging.basicConfig(level="INFO")
class ShopliftingPrediction:
    """Detect shoplifting in a video file or HTTP stream with a Keras sequence model.

    Consecutive frames are turned into blurred, grayscale frame-difference
    images, buffered into sequences of ``sequence_length`` frames, and
    classified by the model loaded from ``model_path``. The latest verdict
    (``self.message``) is drawn onto every frame of the output video.
    """

    def __init__(self, model_path, frame_width, frame_height, sequence_length):
        """Store configuration; the model is loaded lazily by ``load_model``.

        Args:
            model_path: Path to the saved Keras model.
            frame_width: Width, in pixels, each preprocessed frame is resized to.
            frame_height: Height, in pixels, each preprocessed frame is resized to.
            sequence_length: Number of preprocessed frames per prediction.
        """
        self.frame_width = frame_width
        self.frame_height = frame_height
        self.sequence_length = sequence_length
        self.model_path = model_path
        self.message = ''
        self.model = None
        # Set by Open_Video_Stream.
        self.video_reader = None

    def load_model(self):
        """Load the Keras model from ``self.model_path``; later calls are no-ops."""
        if self.model is not None:
            return
        # Define custom objects for loading the model
        custom_objects = {
            'Conv2D': tf.keras.layers.Conv2D,
            'MaxPooling2D': tf.keras.layers.MaxPooling2D,
            'TimeDistributed': tf.keras.layers.TimeDistributed,
            'LSTM': tf.keras.layers.LSTM,
            'Dense': tf.keras.layers.Dense,
            'Flatten': tf.keras.layers.Flatten,
            'Dropout': tf.keras.layers.Dropout,
            'Orthogonal': tf.keras.initializers.Orthogonal,
        }
        # Load the model with custom objects
        self.model = tf.keras.models.load_model(self.model_path, custom_objects=custom_objects)
        logging.info("Model loaded successfully.")

    def generate_message_content(self, probability, label):
        """Set ``self.message`` to a human-readable verdict.

        Args:
            probability: Confidence of ``label`` as an integer percentage.
            label: Predicted class index. Label 0 maps to increasingly strong
                theft warnings, label 1 to increasingly confident "normal"
                messages; any other label leaves ``self.message`` unchanged.
        """
        if label == 0:
            if probability <= 50:
                self.message = "No theft"
            elif probability <= 75:
                self.message = "There is little chance of theft"
            elif probability <= 85:
                self.message = "High probability of theft"
            else:
                self.message = "Very high probability of theft"
        elif label == 1:
            if probability <= 50:
                self.message = "No theft"
            elif probability <= 75:
                self.message = "The movement is confusing, watch"
            elif probability <= 85:
                self.message = "I think it's normal, but it's better to watch"
            else:
                self.message = "Movement is normal"

    def Pre_Process_Video(self, current_frame, previous_frame):
        """Turn two consecutive BGR frames into one normalized motion image.

        Returns:
            A ``(frame_height, frame_width)`` grayscale array scaled to [0, 1]:
            the blurred absolute difference of the two frames.
        """
        diff = cv2.absdiff(current_frame, previous_frame)
        diff = cv2.GaussianBlur(diff, (3, 3), 0)
        # cv2.resize expects dsize as (width, height).
        resized_frame = cv2.resize(diff, (self.frame_width, self.frame_height))
        gray_frame = cv2.cvtColor(resized_frame, cv2.COLOR_BGR2GRAY)
        return gray_frame / 255

    def Open_Video_Stream(self, stream_url):
        """Open a video stream from a URL or local file path.

        Sets ``video_reader``, ``original_video_width``,
        ``original_video_height`` and ``fps`` on the instance.

        Raises:
            ValueError: if OpenCV cannot open ``stream_url``.
        """
        self.video_reader = cv2.VideoCapture(stream_url)
        if not self.video_reader.isOpened():
            self.video_reader.release()
            raise ValueError(f"Could not open video stream: {stream_url}")
        self.original_video_width = int(self.video_reader.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.original_video_height = int(self.video_reader.get(cv2.CAP_PROP_FRAME_HEIGHT))
        self.fps = self.video_reader.get(cv2.CAP_PROP_FPS)
        # For streams without a usable FPS (0, negative or NaN), use a default value
        if math.isnan(self.fps) or self.fps <= 0:
            self.fps = 25  # Default FPS for streaming
            logging.info(f"Using default FPS of {self.fps} for stream")
        logging.info(f"Stream opened: {self.original_video_width}x{self.original_video_height} at {self.fps} FPS")

    def Single_Frame_Predict(self, frames_queue):
        """Classify one sequence of preprocessed frames.

        Args:
            frames_queue: ``sequence_length`` frames produced by Pre_Process_Video.

        Returns:
            ``[probability, predicted_label]``: the winning class's confidence as
            a floored integer percentage, and that class's index.
        """
        probabilities = self.model.predict(np.expand_dims(frames_queue, axis=0), verbose=0)[0]
        predicted_label = int(np.argmax(probabilities))
        probability = math.floor(float(probabilities[predicted_label]) * 100)
        return [probability, predicted_label]

    @staticmethod
    def _is_url(source):
        """Return True when *source* has a URL scheme and is not an existing local path.

        The existence check keeps Windows drive paths ("C:\\...") and file
        names containing ':' from being mistaken for URLs.
        """
        return bool(urlparse(source).scheme) and not os.path.exists(source)

    def _trim_queue(self, frames_queue, is_url):
        """Return the frames to keep after a prediction.

        URL sources keep only the newest frame to avoid lagging behind the live
        stream; files keep the newest half sequence (a sliding window).
        """
        keep = 1 if is_url else self.sequence_length // 2
        # Slice by start index: frames_queue[-0:] would keep everything.
        return frames_queue[max(len(frames_queue) - keep, 0):]

    def _annotate_frame(self, frame):
        """Draw ``self.message`` in a white banner across the top of *frame*, in place."""
        cv2.rectangle(frame, (0, 0), (frame.shape[1], 40), (255, 255, 255), -1)
        cv2.putText(frame, self.message, (1, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)

    def Process_Stream(self, stream_url, output_file_path=None, buffer_size=None,
                       max_consecutive_failures=20, retry_delay=0.5, max_frames=None):
        """
        Process a live video stream for shoplifting detection

        Args:
            stream_url: URL to the HTTP live stream or path to local video file
            output_file_path: Where to save the processed video (if None, a temp file is created)
            buffer_size: Number of frames to buffer before predicting (if None, use
                sequence_length; a prediction always waits for at least
                sequence_length frames)
            max_consecutive_failures: For URL sources, stop after this many
                consecutive failed reads instead of retrying forever
            retry_delay: Seconds to wait after a failed read of a URL source
            max_frames: Stop after writing this many frames (None means no
                limit); the first frame is always written

        Returns:
            Path to the processed video file, or None if no frame could be read

        Raises:
            ValueError: if the source cannot be opened
        """
        self.load_model()
        # The instance is reused across calls; don't draw a previous verdict on this video.
        self.message = ''

        # Create temporary file if output path not specified
        if output_file_path is None:
            with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as temp_file:
                output_file_path = temp_file.name
            logging.info(f"Creating temporary output file: {output_file_path}")

        if buffer_size is None:
            buffer_size = self.sequence_length
        # Never predict (or trim) before a full sequence is available.
        frames_needed = max(buffer_size, self.sequence_length)

        is_url = self._is_url(stream_url)
        if is_url:
            logging.info(f"Opening HTTP stream: {stream_url}")
        else:
            logging.info(f"Opening local video file: {stream_url}")
        self.Open_Video_Stream(stream_url)

        video_writer = None
        try:
            success, frame = self.video_reader.read()
            if not success:
                logging.error("Failed to read first frame from stream")
                return None

            # Size the writer from a real frame: some streams report 0x0 through
            # CAP_PROP_FRAME_WIDTH / CAP_PROP_FRAME_HEIGHT.
            frame_height, frame_width = frame.shape[:2]
            video_writer = cv2.VideoWriter(
                output_file_path,
                cv2.VideoWriter_fourcc('M', 'P', '4', 'V'),
                self.fps,
                (frame_width, frame_height),
            )

            # Copy before annotating so the banner never leaks into the frame diff.
            previous = frame.copy()
            self._annotate_frame(frame)
            video_writer.write(frame)

            frames_queue = []
            start_time = time.time()
            frame_count = 1
            consecutive_failures = 0
            while self.video_reader.isOpened():
                if max_frames is not None and frame_count >= max_frames:
                    logging.info(f"Reached max_frames={max_frames}, stopping")
                    break

                ok, frame = self.video_reader.read()
                if not ok:
                    if not is_url:
                        # For local files, end of file means we're done
                        logging.info("End of video file reached")
                        break
                    # Streams may have temporary connection issues: retry, but not forever.
                    consecutive_failures += 1
                    if consecutive_failures >= max_consecutive_failures:
                        logging.warning(f"Stream read failed {consecutive_failures} times in a row, stopping")
                        break
                    logging.warning("Stream frame read failed, waiting...")
                    time.sleep(retry_delay)
                    continue
                consecutive_failures = 0

                frame_count += 1
                frames_queue.append(self.Pre_Process_Video(frame, previous))
                previous = frame.copy()

                if len(frames_queue) >= frames_needed:
                    # Use only the most recent sequence_length frames for prediction
                    probability, predicted_label = self.Single_Frame_Predict(
                        frames_queue[-self.sequence_length:])
                    self.generate_message_content(probability, predicted_label)
                    logging.info(f"{self.message}:{probability}%")
                    frames_queue = self._trim_queue(frames_queue, is_url)

                self._annotate_frame(frame)
                video_writer.write(frame)

                # For streams, periodically log progress
                if is_url and frame_count % 100 == 0:
                    logging.info(f"Processed {frame_count} frames, elapsed time: {time.time() - start_time:.2f} seconds")
        finally:
            # Release both handles even if preprocessing or prediction raises.
            self.video_reader.release()
            if video_writer is not None:
                video_writer.release()

        logging.info(f"Processing complete. Output saved to: {output_file_path}")
        return output_file_path
def inference(model_path):
    """Build the Gradio callback that runs shoplifting detection with *model_path*.

    A single ShopliftingPrediction (90x90 frames, 160-frame sequences) is
    created here and shared by every call of the returned function.
    """
    shoplifting_prediction = ShopliftingPrediction(model_path, 90, 90, sequence_length=160)

    def process_input(input_source):
        """
        Process either a video file upload or a streaming URL

        Args:
            input_source: Either a URL string or a path to an uploaded video file

        Returns:
            Path to the processed video file

        Raises:
            gr.Error: if no file was uploaded or the URL box is empty
        """
        # Uploaded files and typed URLs both arrive as strings; tidy typed URLs.
        if isinstance(input_source, str):
            input_source = input_source.strip()
        if not input_source:
            raise gr.Error("Please upload a video file or enter a stream URL.")
        logging.info(f"Processing input: {input_source}")
        # Pass no output path so Process_Stream creates a unique temporary file;
        # a single fixed path would be overwritten if two requests overlapped.
        return shoplifting_prediction.Process_Stream(input_source)

    return process_input
# Keras model used by the detector; loaded lazily on the first request.
model_path = 'lrcn_160S_90_90Q.h5'
process_input = inference(model_path)


def _build_tab(tab_title, make_input, button_label):
    """Add a tab holding an input, a submit button and an output video.

    Clicking the button sends the input's value to ``process_input`` and shows
    the returned video. Returns the (input, button, output) components.
    """
    with gr.TabItem(tab_title):
        source_component = make_input()
        submit_button = gr.Button(button_label)
        result_video = gr.Video()
        submit_button.click(
            fn=process_input,
            inputs=[source_component],
            outputs=result_video,
        )
    return source_component, submit_button, result_video


# Gradio UI: one tab for uploaded video files, one for HTTP stream URLs.
with gr.Blocks(title="Shoplifting Detection System") as iface:
    gr.Markdown("# Shoplifting Detection with HTTP Stream Support")
    with gr.Tabs():
        video_input, video_submit, video_output = _build_tab(
            "Video File",
            gr.Video,
            "Process Video",
        )
        stream_url, stream_submit, stream_output = _build_tab(
            "HTTP Stream URL",
            lambda: gr.Textbox(
                label="Enter HTTP Live Stream URL",
                placeholder="https://example.com/stream.m3u8",
            ),
            "Process Stream",
        )

if __name__ == "__main__":
    iface.launch()