Arsooo commited on
Commit
9d59b1e
·
verified ·
1 Parent(s): f81a1ee

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +277 -0
app.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import cv2
3
+ import numpy as np
4
+ import gradio as gr
5
+ import math
6
+ import logging
7
+ import time
8
+ import os
9
+ import tempfile
10
+ from urllib.parse import urlparse
11
+
12
+ # Configure logging
13
+ logging.basicConfig(level=logging.INFO)
14
+
15
class ShopliftingPrediction:
    """Runs an LRCN-style (time-distributed CNN + LSTM) Keras classifier over a
    video file or HTTP stream and writes an annotated copy of the video.

    The model consumes clips of ``sequence_length`` frame-difference images,
    each resized to ``frame_width`` x ``frame_height`` and converted to
    normalized grayscale. Label 0 is treated as the theft class and label 1
    as the normal class (see generate_message_content).
    """

    def __init__(self, model_path, frame_width, frame_height, sequence_length):
        # Target frame size fed to the network and frames-per-clip expected
        # by the LSTM. The model itself is loaded lazily (see load_model).
        self.frame_width = frame_width
        self.frame_height = frame_height
        self.sequence_length = sequence_length
        self.model_path = model_path
        self.message = ''   # last human-readable detection message
        self.model = None   # populated on first load_model() call

    def load_model(self):
        """Load the Keras model from ``self.model_path`` exactly once.

        Subsequent calls are no-ops, so callers may invoke this defensively.
        """
        if self.model is not None:
            return

        # Custom objects required to deserialize the saved architecture
        # (older HDF5 models may reference these by name).
        custom_objects = {
            'Conv2D': tf.keras.layers.Conv2D,
            'MaxPooling2D': tf.keras.layers.MaxPooling2D,
            'TimeDistributed': tf.keras.layers.TimeDistributed,
            'LSTM': tf.keras.layers.LSTM,
            'Dense': tf.keras.layers.Dense,
            'Flatten': tf.keras.layers.Flatten,
            'Dropout': tf.keras.layers.Dropout,
            'Orthogonal': tf.keras.initializers.Orthogonal,
        }

        self.model = tf.keras.models.load_model(self.model_path, custom_objects=custom_objects)
        logging.info("Model loaded successfully.")

    def generate_message_content(self, probability, label):
        """Set ``self.message`` from a predicted label and its confidence.

        Args:
            probability: Confidence percentage (0-100) of the winning class.
            label: 0 (theft class) or 1 (normal class).

        Note: if ``label`` is neither 0 nor 1 the previous message is kept
        unchanged; callers only ever pass an argmax over two classes, so
        that cannot occur with the current model.
        """
        if label == 0:
            if probability <= 50:
                self.message = "No theft"
            elif probability <= 75:
                self.message = "There is little chance of theft"
            elif probability <= 85:
                self.message = "High probability of theft"
            else:
                self.message = "Very high probability of theft"
        elif label == 1:
            if probability <= 50:
                self.message = "No theft"
            elif probability <= 75:
                self.message = "The movement is confusing, watch"
            elif probability <= 85:
                self.message = "I think it's normal, but it's better to watch"
            else:
                self.message = "Movement is normal"

    def Pre_Process_Video(self, current_frame, previous_frame):
        """Return a normalized grayscale motion frame for the model.

        Computes the absolute frame difference (motion signal), blurs it,
        resizes it to the model's input size, converts to grayscale and
        scales pixel values into [0, 1].
        """
        diff = cv2.absdiff(current_frame, previous_frame)
        diff = cv2.GaussianBlur(diff, (3, 3), 0)
        # NOTE(review): cv2.resize expects dsize as (width, height); passing
        # (frame_height, frame_width) only works because both are 90 here.
        resized_frame = cv2.resize(diff, (self.frame_height, self.frame_width))
        gray_frame = cv2.cvtColor(resized_frame, cv2.COLOR_BGR2GRAY)
        normalized_frame = gray_frame / 255
        return normalized_frame

    def Open_Video_Stream(self, stream_url):
        """Opens a video stream from a URL or local file path.

        Sets ``self.video_reader``, the original frame dimensions and
        ``self.fps`` (defaulting to 25 when the source reports none).

        Raises:
            ValueError: if the stream/file cannot be opened.
        """
        self.video_reader = cv2.VideoCapture(stream_url)

        # Check if the stream is opened successfully
        if not self.video_reader.isOpened():
            raise ValueError(f"Could not open video stream: {stream_url}")

        self.original_video_width = int(self.video_reader.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.original_video_height = int(self.video_reader.get(cv2.CAP_PROP_FRAME_HEIGHT))
        self.fps = self.video_reader.get(cv2.CAP_PROP_FPS)

        # For streams without a defined FPS, use a default value
        if self.fps == 0 or math.isnan(self.fps):
            self.fps = 25  # Default FPS for streaming
            logging.info(f"Using default FPS of {self.fps} for stream")

        logging.info(f"Stream opened: {self.original_video_width}x{self.original_video_height} at {self.fps} FPS")

    def Single_Frame_Predict(self, frames_queue):
        """Run the model on one clip and return [probability, label].

        ``probability`` is the winning class's confidence floored to an
        integer percentage; ``label`` is the argmax over the two classes.
        """
        probabilities = self.model.predict(np.expand_dims(frames_queue, axis=0), verbose=0)[0]
        predicted_label = np.argmax(probabilities)
        probability = math.floor(max(probabilities[0], probabilities[1]) * 100)
        return [probability, predicted_label]

    def Process_Stream(self, stream_url, output_file_path=None, max_duration=30, buffer_size=None):
        """
        Process a live video stream for shoplifting detection

        Args:
            stream_url: URL to the HTTP live stream or path to local video file
            output_file_path: Where to save the processed video (if None, a temp file is created)
            max_duration: Maximum duration to process in seconds (for streams)
            buffer_size: Size of frames to buffer before processing (if None, use sequence_length)

        Returns:
            Path to the processed video file, or None if no frame could be read
        """
        self.load_model()

        # Create temporary file if output path not specified
        if output_file_path is None:
            with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as temp_file:
                output_file_path = temp_file.name
            logging.info(f"Creating temporary output file: {output_file_path}")

        # Set buffer size to sequence length if not specified
        if buffer_size is None:
            buffer_size = self.sequence_length

        # Check if input is a URL or local file
        is_url = bool(urlparse(stream_url).scheme)
        if is_url:
            logging.info(f"Opening HTTP stream: {stream_url}")
        else:
            logging.info(f"Opening local video file: {stream_url}")

        self.Open_Video_Stream(stream_url)

        # Setup video writer with the same parameters as the input stream
        video_writer = cv2.VideoWriter(
            output_file_path,
            cv2.VideoWriter_fourcc('M', 'P', '4', 'V'),
            self.fps,
            (self.original_video_width, self.original_video_height)
        )

        # try/finally guarantees the capture and writer are released even if
        # the first read fails or an exception is raised mid-loop (the
        # original code leaked the writer on early return).
        try:
            # Read first frame
            success, frame = self.video_reader.read()
            if not success:
                logging.error("Failed to read first frame from stream")
                return None

            previous = frame.copy()
            frames_queue = []
            start_time = time.time()
            frame_count = 0

            while self.video_reader.isOpened():
                # Check if we've exceeded the max duration for streams
                if is_url and (time.time() - start_time) > max_duration:
                    logging.info(f"Reached maximum stream capture duration of {max_duration} seconds")
                    break

                # Read the next frame
                ok, frame = self.video_reader.read()
                if not ok:
                    if is_url:
                        # For streams, we might have temporary connection issues, wait and retry.
                        # The max_duration check above bounds how long we retry.
                        logging.warning("Stream frame read failed, waiting...")
                        time.sleep(0.5)
                        continue
                    else:
                        # For local files, end of file means we're done
                        logging.info("End of video file reached")
                        break

                # Process the frame into a motion image and queue it
                frame_count += 1
                normalized_frame = self.Pre_Process_Video(frame, previous)
                previous = frame.copy()
                frames_queue.append(normalized_frame)

                # When we have enough frames in our queue, make a prediction
                if len(frames_queue) >= buffer_size:
                    # Use only the most recent sequence_length frames for prediction
                    prediction_frames = frames_queue[-self.sequence_length:]
                    if len(prediction_frames) == self.sequence_length:
                        [probability, predicted_label] = self.Single_Frame_Predict(prediction_frames)
                        self.generate_message_content(probability, predicted_label)
                        message = f"{self.message}:{probability}%"
                        logging.info(message)

                    # Keep only the most recent frame in the queue for HTTP streams to avoid lag
                    if is_url:
                        frames_queue = frames_queue[-1:]
                    else:
                        # For video files, we can slide the window
                        frames_queue = frames_queue[-(self.sequence_length // 2):]

                # Add detection information to the frame (white banner + text)
                cv2.rectangle(frame, (0, 0), (640, 40), (255, 255, 255), -1)
                cv2.putText(frame, self.message, (1, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)

                # Write the processed frame
                video_writer.write(frame)

                # For streams, periodically log progress
                if is_url and frame_count % 100 == 0:
                    logging.info(f"Processed {frame_count} frames, elapsed time: {time.time() - start_time:.2f} seconds")
        finally:
            # Clean up resources unconditionally
            self.video_reader.release()
            video_writer.release()

        logging.info(f"Processing complete. Output saved to: {output_file_path}")
        return output_file_path
209
+
210
def inference(model_path):
    """Build and return a processing callable bound to a single predictor.

    The returned closure is used directly as the Gradio handler for both
    the file-upload tab and the stream-URL tab.
    """
    predictor = ShopliftingPrediction(model_path, 90, 90, sequence_length=160)

    def process_input(input_source, max_duration=30):
        """
        Process either a video file upload or a streaming URL

        Args:
            input_source: Either a URL string or a path to an uploaded video file
            max_duration: Maximum duration to process for streams (in seconds)

        Returns:
            Path to the processed video file
        """
        destination = os.path.join(tempfile.gettempdir(), 'output.mp4')

        # Non-string inputs come from a file upload widget; strings are
        # assumed to be URLs and carry the duration limit through.
        if not isinstance(input_source, str):
            logging.info(f"Processing input as uploaded file: {input_source}")
            return predictor.Process_Stream(input_source, destination)

        logging.info(f"Processing input as URL: {input_source}")
        return predictor.Process_Stream(input_source, destination, max_duration)

    return process_input
237
+
238
# Path to the pre-trained LRCN weights; expected alongside this script.
# Name suggests 160-frame sequences of 90x90 inputs — matches inference().
model_path = 'lrcn_160S_90_90Q.h5'
process_input = inference(model_path)

# Create Gradio interface with both file upload and URL input options
with gr.Blocks(title="Shoplifting Detection System") as iface:
    gr.Markdown("# Shoplifting Detection with HTTP Stream Support")

    with gr.Tabs():
        # Tab 1: process an uploaded video file end-to-end.
        with gr.TabItem("Video File"):
            video_input = gr.Video()
            video_submit = gr.Button("Process Video")
            video_output = gr.Video()
            # max_duration is omitted here, so process_input's default (30s)
            # applies; it is only enforced for URL sources anyway.
            video_submit.click(
                fn=process_input,
                inputs=[video_input],
                outputs=video_output
            )

        # Tab 2: capture a bounded window from a live HTTP stream.
        with gr.TabItem("HTTP Stream URL"):
            stream_url = gr.Textbox(
                label="Enter HTTP Live Stream URL",
                placeholder="https://example.com/stream.m3u8"
            )
            # Capture-window limit passed as process_input's max_duration.
            max_duration = gr.Slider(
                minimum=5,
                maximum=120,
                value=30,
                step=5,
                label="Max Stream Duration (seconds)"
            )
            stream_submit = gr.Button("Process Stream")
            stream_output = gr.Video()
            stream_submit.click(
                fn=process_input,
                inputs=[stream_url, max_duration],
                outputs=stream_output
            )

# Launch the UI only when run as a script (not when imported).
if __name__ == "__main__":
    iface.launch()