File size: 10,764 Bytes
9d59b1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99648d6
9d59b1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99648d6
9d59b1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99648d6
9d59b1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99648d6
9d59b1e
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
import tensorflow as tf
import cv2
import numpy as np
import gradio as gr
import math
import logging
import time
import os
import tempfile
from urllib.parse import urlparse

# Configure the root logger so module-level logging.info/warning calls below
# are actually emitted (default root level is WARNING).
logging.basicConfig(level=logging.INFO)

class ShopliftingPrediction:
    """Run an LRCN shoplifting classifier over a video file or HTTP live stream.

    Frames are pre-processed into blurred inter-frame differences, buffered up
    to ``sequence_length``, and classified by a Keras model that is loaded
    lazily from ``model_path`` on first use.
    """

    # Give up on a live stream after this many *consecutive* failed reads,
    # instead of retrying forever (the previous behaviour was an infinite loop
    # on a dead stream).
    MAX_STREAM_READ_RETRIES = 20

    def __init__(self, model_path, frame_width, frame_height, sequence_length):
        """
        Args:
            model_path: Path to the Keras ``.h5`` model file.
            frame_width: Frame width expected by the model.
            frame_height: Frame height expected by the model.
            sequence_length: Number of frames in one prediction window.
        """
        self.frame_width = frame_width
        self.frame_height = frame_height
        self.sequence_length = sequence_length
        self.model_path = model_path
        self.message = ''
        self.model = None  # loaded lazily by load_model()

    def load_model(self):
        """Load the Keras model once; subsequent calls are no-ops."""
        if self.model is not None:
            return

        # Custom objects are passed explicitly so that .h5 files whose
        # serialized layer/initializer names need resolving still deserialize.
        custom_objects = {
            'Conv2D': tf.keras.layers.Conv2D,
            'MaxPooling2D': tf.keras.layers.MaxPooling2D,
            'TimeDistributed': tf.keras.layers.TimeDistributed,
            'LSTM': tf.keras.layers.LSTM,
            'Dense': tf.keras.layers.Dense,
            'Flatten': tf.keras.layers.Flatten,
            'Dropout': tf.keras.layers.Dropout,
            'Orthogonal': tf.keras.initializers.Orthogonal,
        }

        self.model = tf.keras.models.load_model(self.model_path, custom_objects=custom_objects)
        logging.info("Model loaded successfully.")

    def generate_message_content(self, probability, label):
        """Set ``self.message`` from a predicted label and its probability (0-100).

        NOTE(review): label 0 appears to be the "suspicious" class and label 1
        the "normal" class, judging by the message wording — confirm against
        the model's training labels. A label outside {0, 1} leaves the previous
        message unchanged.
        """
        if label == 0:
            if probability <= 50:
                self.message = "No theft"
            elif probability <= 75:
                self.message = "There is little chance of theft"
            elif probability <= 85:
                self.message = "High probability of theft"
            else:
                self.message = "Very high probability of theft"
        elif label == 1:
            if probability <= 50:
                self.message = "No theft"
            elif probability <= 75:
                self.message = "The movement is confusing, watch"
            elif probability <= 85:
                self.message = "I think it's normal, but it's better to watch"
            else:
                self.message = "Movement is normal"

    def Pre_Process_Video(self, current_frame, previous_frame):
        """Return a normalized grayscale motion frame (abs-diff of two frames)."""
        diff = cv2.absdiff(current_frame, previous_frame)
        diff = cv2.GaussianBlur(diff, (3, 3), 0)
        # NOTE(review): cv2.resize takes dsize as (width, height); passing
        # (frame_height, frame_width) only works because both are 90 in the
        # current configuration — confirm before using non-square frames.
        resized_frame = cv2.resize(diff, (self.frame_height, self.frame_width))
        gray_frame = cv2.cvtColor(resized_frame, cv2.COLOR_BGR2GRAY)
        normalized_frame = gray_frame / 255
        return normalized_frame

    def Open_Video_Stream(self, stream_url):
        """Open a video stream from a URL or local file path.

        Sets ``self.video_reader``, ``self.original_video_width/height`` and
        ``self.fps``. Raises ValueError if the source cannot be opened.
        """
        self.video_reader = cv2.VideoCapture(stream_url)

        if not self.video_reader.isOpened():
            raise ValueError(f"Could not open video stream: {stream_url}")

        self.original_video_width = int(self.video_reader.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.original_video_height = int(self.video_reader.get(cv2.CAP_PROP_FRAME_HEIGHT))
        self.fps = self.video_reader.get(cv2.CAP_PROP_FPS)

        # Live streams often report 0 (or NaN) FPS; fall back to a sane default.
        if self.fps == 0 or math.isnan(self.fps):
            self.fps = 25  # Default FPS for streaming
            logging.info(f"Using default FPS of {self.fps} for stream")

        logging.info(f"Stream opened: {self.original_video_width}x{self.original_video_height} at {self.fps} FPS")

    def Single_Frame_Predict(self, frames_queue):
        """Run the model on one window of frames.

        Returns ``[probability, predicted_label]`` where probability is the
        floored percentage of the winning class (model assumed 2-class).
        """
        probabilities = self.model.predict(np.expand_dims(frames_queue, axis=0), verbose=0)[0]
        predicted_label = np.argmax(probabilities)
        probability = math.floor(max(probabilities[0], probabilities[1]) * 100)
        return [probability, predicted_label]

    def Process_Stream(self, stream_url, output_file_path=None, buffer_size=None):
        """
        Process a live video stream for shoplifting detection

        Args:
            stream_url: URL to the HTTP live stream or path to local video file
            output_file_path: Where to save the processed video (if None, a temp file is created)
            buffer_size: Size of frames to buffer before processing (if None, use sequence_length)

        Returns:
            Path to the processed video file, or None if no frame could be read.
        """
        self.load_model()

        # Create temporary file if output path not specified
        if output_file_path is None:
            with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as temp_file:
                output_file_path = temp_file.name
                logging.info(f"Creating temporary output file: {output_file_path}")

        # Set buffer size to sequence length if not specified
        if buffer_size is None:
            buffer_size = self.sequence_length

        # Check if input is a URL or local file
        is_url = bool(urlparse(stream_url).scheme)
        if is_url:
            logging.info(f"Opening HTTP stream: {stream_url}")
        else:
            logging.info(f"Opening local video file: {stream_url}")

        self.Open_Video_Stream(stream_url)

        # Setup video writer with the same parameters as the input stream
        video_writer = cv2.VideoWriter(
            output_file_path,
            cv2.VideoWriter_fourcc('M', 'P', '4', 'V'),
            self.fps,
            (self.original_video_width, self.original_video_height)
        )

        # Guarantee both reader and writer are released on every exit path
        # (the previous version leaked the writer on the early-return and
        # exception paths).
        try:
            success, frame = self.video_reader.read()
            if not success:
                logging.error("Failed to read first frame from stream")
                return None

            previous = frame.copy()
            frames_queue = []
            start_time = time.time()
            frame_count = 0
            consecutive_failures = 0

            while self.video_reader.isOpened():
                ok, frame = self.video_reader.read()
                if not ok:
                    if is_url:
                        # Streams can hiccup; retry a bounded number of times.
                        consecutive_failures += 1
                        if consecutive_failures > self.MAX_STREAM_READ_RETRIES:
                            logging.error("Stream read failed repeatedly; stopping capture")
                            break
                        logging.warning("Stream frame read failed, waiting...")
                        time.sleep(0.5)
                        continue
                    else:
                        # For local files, end of file means we're done
                        logging.info("End of video file reached")
                        break
                consecutive_failures = 0

                # Process the frame
                frame_count += 1
                normalized_frame = self.Pre_Process_Video(frame, previous)
                previous = frame.copy()
                frames_queue.append(normalized_frame)

                # When we have enough frames in our queue, make a prediction
                if len(frames_queue) >= buffer_size:
                    # Use only the most recent sequence_length frames
                    prediction_frames = frames_queue[-self.sequence_length:]
                    if len(prediction_frames) == self.sequence_length:
                        [probability, predicted_label] = self.Single_Frame_Predict(prediction_frames)
                        self.generate_message_content(probability, predicted_label)
                        message = f"{self.message}:{probability}%"
                        logging.info(message)

                    # Keep only the most recent frame for HTTP streams to avoid lag
                    if is_url:
                        frames_queue = frames_queue[-1:]
                    else:
                        # For video files, we can slide the window
                        frames_queue = frames_queue[-(self.sequence_length // 2):]

                # Overlay the current detection message on the frame
                cv2.rectangle(frame, (0, 0), (640, 40), (255, 255, 255), -1)
                cv2.putText(frame, self.message, (1, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)

                # Write the processed frame
                video_writer.write(frame)

                # For streams, periodically log progress
                if is_url and frame_count % 100 == 0:
                    logging.info(f"Processed {frame_count} frames, elapsed time: {time.time() - start_time:.2f} seconds")
        finally:
            # Clean up resources
            self.video_reader.release()
            video_writer.release()

        logging.info(f"Processing complete. Output saved to: {output_file_path}")
        return output_file_path

def inference(model_path):
    """Build the Gradio handler bound to one ShopliftingPrediction instance.

    Args:
        model_path: Path to the Keras .h5 model file.

    Returns:
        A callable that accepts either a streaming URL (str) or an uploaded
        video file path and returns the path to the processed video.
    """
    shoplifting_prediction = ShopliftingPrediction(model_path, 90, 90, sequence_length=160)

    def process_input(input_source):
        """Process either a video file upload or a streaming URL.

        Args:
            input_source: Either a URL string or a path to an uploaded video file

        Returns:
            Path to the processed video file
        """
        output_file_path = os.path.join(tempfile.gettempdir(), 'output.mp4')

        # Both input kinds are handled identically by Process_Stream (the
        # original if/else had two identical branches); only the log message
        # distinguishes them.
        if isinstance(input_source, str):
            logging.info(f"Processing input as URL: {input_source}")
        else:
            logging.info(f"Processing input as uploaded file: {input_source}")
        return shoplifting_prediction.Process_Stream(input_source, output_file_path)

    return process_input

# Path to the pretrained LRCN model (160-frame sequences of 90x90 inputs,
# matching the constructor arguments inside inference()).
model_path = 'lrcn_160S_90_90Q.h5'
process_input = inference(model_path)

# Create Gradio interface with both file upload and URL input options
with gr.Blocks(title="Shoplifting Detection System") as iface:
    gr.Markdown("# Shoplifting Detection with HTTP Stream Support")
    
    with gr.Tabs():
        # Tab 1: process an uploaded video file end-to-end.
        with gr.TabItem("Video File"):
            video_input = gr.Video()
            video_submit = gr.Button("Process Video")
            video_output = gr.Video()
            video_submit.click(
                fn=process_input, 
                inputs=[video_input], 
                outputs=video_output
            )
            
        # Tab 2: process a live HTTP stream given by URL; both tabs share the
        # same process_input handler.
        with gr.TabItem("HTTP Stream URL"):
            stream_url = gr.Textbox(
                label="Enter HTTP Live Stream URL", 
                placeholder="https://example.com/stream.m3u8"
            )
            stream_submit = gr.Button("Process Stream")
            stream_output = gr.Video()
            stream_submit.click(
                fn=process_input, 
                inputs=[stream_url], 
                outputs=stream_output
            )

# Launch the web UI only when run as a script (not on import).
if __name__ == "__main__":
    iface.launch()