ccclllwww committed on
Commit
af773bf
·
verified ·
1 Parent(s): b99112d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +202 -233
app.py CHANGED
@@ -17,156 +17,151 @@ import uuid
17
  import pandas as pd
18
 
19
  # ======================
20
- # 模型加载函数(缓存)
21
  # ======================
22
 
23
  @st.cache_resource
24
  def load_smoke_pipeline():
25
- """初始化并缓存吸烟图片分类 pipeline"""
26
  return pipeline("image-classification", model="ccclllwww/smoker_cls_base_V9", use_fast=True)
27
 
28
  @st.cache_resource
29
  def load_gender_pipeline():
30
- """初始化并缓存性别图片分类 pipeline"""
31
  return pipeline("image-classification", model="rizvandwiki/gender-classification-2", use_fast=True)
32
 
33
  @st.cache_resource
34
  def load_age_pipeline():
35
- """初始化并缓存年龄图片分类 pipeline"""
36
  return pipeline("image-classification", model="akashmaggon/vit-base-age-classification", use_fast=True)
37
 
38
- # 预先加载所有模型
39
  smoke_pipeline = load_smoke_pipeline()
40
  gender_pipeline = load_gender_pipeline()
41
  age_pipeline = load_age_pipeline()
42
 
43
  # ======================
44
- # remote settings
45
  # ======================
46
- # Find your Account SID and Auth Token at twilio.com/console
47
- # and set the environment variables. See http://twil.io/secure
48
 
49
- account_sid = os.environ['TWILIO_ACCOUNT_SID']
50
- auth_token = os.environ['TWILIO_AUTH_TOKEN']
51
- client = Client(account_sid, auth_token)
 
 
 
 
 
 
52
 
53
- token = client.tokens.create()
54
 
55
  # ======================
56
- # 音频加载函数(缓存)
57
  # ======================
58
 
59
  @st.cache_resource
60
- def load_all_audios():
61
- """加载 audio 目录中的所有 .wav 文件,并返回一个字典,
62
- 键为文件名(不带扩展名),值为音频字节数据。"""
63
  audio_dir = "audio"
 
 
 
64
  audio_files = [f for f in os.listdir(audio_dir) if f.endswith(".wav")]
65
  audio_dict = {}
66
  for audio_file in audio_files:
67
- file_path = os.path.join(audio_dir, audio_file)
68
- with open(file_path, "rb") as af:
69
- audio_bytes = af.read()
70
- # 去掉扩展名作为键
71
- key = os.path.splitext(audio_file)[0]
72
- audio_dict[key] = audio_bytes
73
  return audio_dict
74
 
75
- # 应用启动时加载所有音频
76
- audio_data = load_all_audios()
77
 
78
  # ======================
79
- # 照片檢測处理函数
80
  # ======================
81
 
82
- def smoking_detection(image: Image.Image) -> str:
 
83
  try:
84
  output = smoke_pipeline(image)
85
- status = output[0]["label"]
86
- return status
87
  except Exception as e:
88
- st.error(f"🔍 图像处理错误: {str(e)}")
89
  st.stop()
90
-
91
- def gender_detection(image: Image.Image) -> str:
 
92
  try:
93
  output = gender_pipeline(image)
94
- status = output[0]["label"]
95
- return status
96
  except Exception as e:
97
- st.error(f"🔍 图像处理错误: {str(e)}")
98
  st.stop()
99
-
100
- def age_detection(image: Image.Image) -> str:
 
101
  try:
102
  output = age_pipeline(image)
103
- status = output[0]["label"]
104
- return status
105
  except Exception as e:
106
- st.error(f"🔍 图像处理错误: {str(e)}")
107
  st.stop()
108
-
109
  # ======================
110
- # 實時檢測核心处理函数
111
  # ======================
112
 
113
  @st.cache_data(show_spinner=False, max_entries=3)
114
- def smoking_classification(image: Image.Image) -> str:
115
- """接受 PIL 图片并利用吸烟分类 pipeline 进行判定,返回标签(如 "smoking")。"""
116
  try:
117
  output = smoke_pipeline(image)
118
- status = max(output, key=lambda x: x["score"])['label']
119
- return status
120
  except Exception as e:
121
- st.error(f"🔍 图像处理错误: {str(e)}")
122
  st.stop()
123
 
124
  @st.cache_data(show_spinner=False, max_entries=3)
125
- def gender_classification(image: Image.Image) -> str:
126
- """进行性别分类,返回模型输出的性别(依模型输出)。"""
127
  try:
128
  output = gender_pipeline(image)
129
- status = max(output, key=lambda x: x["score"])['label']
130
- return status
131
  except Exception as e:
132
- st.error(f"🔍 图像处理错误: {str(e)}")
133
  st.stop()
134
 
135
  @st.cache_data(show_spinner=False, max_entries=3)
136
- def age_classification(image: Image.Image) -> str:
137
- """进行年龄分类,返回年龄范围,例如 "10-19" 等。"""
138
  try:
139
  output = age_pipeline(image)
140
- age_range = max(output, key=lambda x: x["score"])['label']
141
- return age_range
142
  except Exception as e:
143
- st.error(f"🔍 图像处理错误: {str(e)}")
144
  st.stop()
145
 
146
  # ======================
147
- # 自定义JS播放音频函数
148
  # ======================
149
 
150
- @st.cache_resource
151
- def play_audio_via_js(audio_bytes):
152
- """
153
- 利用自定义 HTML 和 JavaScript 播放音频。
154
- 将二进制音频数据转换为 Base64 后嵌入 audio 标签,
155
- 并用 JS 在页面加载后模拟点击进行播放。
156
- """
157
  audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
 
158
  html_content = f"""
159
- <audio id="audio_player_{uuid.uuid4()}" controls style="width: 100%;">
160
  <source src="data:audio/wav;base64,{audio_base64}" type="audio/wav">
161
  Your browser does not support the audio element.
162
  </audio>
163
  <script type="text/javascript">
164
  window.addEventListener('DOMContentLoaded', function() {{
165
  setTimeout(function() {{
166
- var audioElement = document.getElementById("audio_player_{uuid.uuid4()}");
167
  if (audioElement) {{
168
  audioElement.play().catch(function(e) {{
169
- console.log("播放被浏览器阻止:", e);
170
  }});
171
  }}
172
  }}, 1000);
@@ -176,213 +171,187 @@ def play_audio_via_js(audio_bytes):
176
  st.components.v1.html(html_content, height=150)
177
 
178
  # ======================
179
- # VideoTransformer 定义:处理摄像头帧与快照捕获
180
  # ======================
181
 
182
  class VideoTransformer(VideoTransformerBase):
183
  def __init__(self):
184
- self.snapshots = [] # 存储捕获的快照
185
- self.last_capture_time = time.time() # 上次捕获时间
186
- self.capture_interval = 1 # 每0.5秒捕获一张快照
 
187
 
188
  def transform(self, frame):
189
- """从摄像头流捕获单帧图像,并转换为 PIL Image。"""
190
  img = frame.to_ndarray(format="bgr24")
191
  current_time = time.time()
192
- # 每隔 capture_interval 秒捕获一张快照,直到捕获5张
193
- if current_time - self.last_capture_time >= self.capture_interval and len(self.snapshots) < 5:
194
  img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
195
  self.snapshots.append(Image.fromarray(img_rgb))
196
  self.last_capture_time = current_time
197
- st.write(f"已捕获快照 {len(self.snapshots)}/20")
198
- return img # 返回原始帧以供前端显示
199
 
200
  # ======================
201
  # Cover Page
202
  # ======================
203
 
204
  def cover_page():
205
- """Display the cover page with project overview and usage instructions."""
206
- st.title("Smoking Detection System")
207
-
208
- st.header("Project Overview")
209
- st.write("""
210
- The Smoking Detection System is a Streamlit-based web application designed to detect smoking behavior
211
- in images or real-time video streams. It leverages advanced machine learning models to classify images
212
- for smoking activity, gender, and age range. The system is structured to provide both static image analysis
213
- and real-time video processing, with audio feedback for detected smoking incidents.
214
 
215
- **Purpose**: The primary goal is to identify smoking behavior in public or controlled environments,
216
- providing insights into the demographics (gender and age) of individuals engaged in smoking. This can
217
- be used for monitoring compliance with no-smoking policies or conducting behavioral studies.
218
-
219
- **Significance**: The application promotes public health by enabling automated monitoring of smoking
220
- activities, potentially aiding in the enforcement of smoking regulations and raising awareness about
221
- smoking prevalence across different demographics.
222
 
223
- **Structure**:
224
- - **Cover Page**: Provides an overview and usage instructions.
225
- - **Photo Detection**: Analyzes a single uploaded or captured image for smoking, gender, and age.
226
- - **Real-Time Video Detection**: Processes video streams, capturing snapshots to detect smoking and
227
- analyze demographics if smoking is detected.
 
 
 
228
  """)
229
 
230
- st.header("Usage Instructions")
231
- st.write("""
232
- 1. **Navigation**: Use the sidebar to select a page:
233
- - **Cover Page**: View this project overview.
234
- - **Photo Detection**: Upload an image or use the camera to capture a photo for analysis.
235
- - **Real-Time Video Detection**: Enable the webcam for continuous monitoring.
236
  2. **Photo Detection**:
237
- - Choose to upload an image or capture one using the camera.
238
- - The system will classify the image for smoking. If smoking is detected, it will further analyze
239
- gender and age, and play an audio alert based on the results.
240
  3. **Real-Time Video Detection**:
241
- - Start the webcam to capture 20 snapshots over one minute.
242
- - The system analyzes each snapshot for smoking. If smoking is detected in more than two snapshots,
243
- it performs gender and age classification and displays the results.
244
- - An audio alert is played if smoking is confirmed, based on the most common gender and age range.
245
- 4. **Requirements**:
246
- - Ensure the 'audio' directory contains .wav files named in the format '<age_range> <gender>.wav'
247
- (e.g., '10-19 male.wav') for audio feedback.
248
- - Set up Twilio environment variables (TWILIO_ACCOUNT_SID and TWILIO_AUTH_TOKEN) for WebRTC.
249
  """)
 
 
 
250
 
251
  # ======================
252
- # 照片检测页面
253
  # ======================
254
 
255
  def photo_detection_page():
 
256
  audio_placeholder = st.empty()
257
- st.title("照片检测")
258
- st.write("上传一张图片或使用摄像头拍摄,检测是否吸烟,若检测到吸烟则进一步分析性别和年龄。")
259
-
260
- # 提供上传和摄像头选项
261
- option = st.radio("选择输入方式", ["上传图片", "使用摄像头拍摄"])
262
 
 
 
263
  image = None
264
- if option == "上传图片":
265
- uploaded_file = st.file_uploader("选择一张图片", type=["jpg", "jpeg", "png"])
266
- if uploaded_file is not None:
 
267
  image = Image.open(uploaded_file)
268
- st.image(image, caption="上传的图片", use_container_width=True)
269
  else:
270
- # 摄像头拍摄
271
- enable = st.checkbox("启用摄像头")
272
- camera_file = st.camera_input("拍摄照片", disabled=not enable)
273
- if camera_file is not None:
274
  image = Image.open(camera_file)
275
- st.image(image, caption="拍攝的圖片", use_container_width=True)
276
-
277
- if image is not None:
278
- # 吸烟分类
279
- with st.spinner("Wait for smoking detection"):
280
- smoke_result = smoking_detection(image)
281
- st.success("The smoke result is:")
282
- st.write(smoke_result)
283
-
284
  if smoke_result.lower() == "smoking":
285
- # 性别分类
286
- with st.spinner("Wait for gender detection"):
287
- gender_result = gender_detection(image)
288
- st.success("The gender result is:")
289
- st.write(gender_result)
290
-
291
- # 年龄分类
292
- with st.spinner("Wait for age detection"):
293
- age_result = age_detection(image)
294
- st.success("The age result is:")
295
- st.write(age_result)
296
-
297
  audio_placeholder.empty()
298
  audio_key = f"{age_result} {gender_result.lower()}"
299
  if audio_key in audio_data:
300
- audio_bytes = audio_data[audio_key]
301
- play_audio_via_js(audio_bytes)
302
  else:
303
- st.error(f"音频文件不存在: {audio_key}.wav")
304
 
305
  # ======================
306
- # 实时检测页面
307
  # ======================
308
 
309
  def real_time_detection_page():
310
- st.title("实时视频检测")
311
- st.write("程序在一分钟内捕获5张快照进行图片分类,首先判定是否吸烟。若检测到吸烟的快照超过2次,则将结果添加到表格中,包含快照、性别和年龄。")
 
312
 
313
- # 初始化 session state 用于存储检测结果
314
  if 'detection_results' not in st.session_state:
315
  st.session_state.detection_results = []
316
 
317
- # 创建用于显示进度文字和进度条的占位容器
318
- capture_text_placeholder = st.empty()
319
- capture_progress_placeholder = st.empty()
320
- classification_text_placeholder = st.empty()
321
- classification_progress_placeholder = st.empty()
322
- detection_info_placeholder = st.empty()
323
-
324
- # 显示检测结果表格
325
- table_placeholder = st.empty()
326
-
327
- # 启动实时视频流
328
- ctx = webrtc_streamer(key="unique_example", video_transformer_factory=VideoTransformer,
329
- rtc_configuration={"iceServers": token.ice_servers})
330
- image_placeholder = st.empty()
331
- audio_placeholder = st.empty()
 
332
 
333
  capture_target = 5
334
 
335
- if ctx.video_transformer is not None:
336
- classification_result_placeholder = st.empty()
337
- detection_info_placeholder.info("开始侦测")
338
 
339
  while True:
340
  snapshots = ctx.video_transformer.snapshots
341
 
342
  if len(snapshots) < capture_target:
343
- capture_text_placeholder.text(f"捕获进度: {len(snapshots)}/{capture_target} 张快照")
344
- progress_value = int(len(snapshots) / capture_target * 100)
345
- capture_progress_placeholder.progress(progress_value)
346
  else:
347
- capture_text_placeholder.text("捕获进度: 捕获完成!")
348
- capture_progress_placeholder.empty()
349
- detection_info_placeholder.empty()
350
 
351
- total = len(snapshots)
352
- classification_text_placeholder.text("分类进度: 正在分类...")
353
- classification_progress = classification_progress_placeholder.progress(0)
354
 
355
- smoke_results = []
356
- for idx, img in enumerate(snapshots):
357
- smoke_results.append(smoking_classification(img))
358
  smoking_count = sum(1 for result in smoke_results if result.lower() == "smoking")
359
- classification_progress.progress(33)
360
 
361
  if smoking_count > 2:
362
- gender_results = []
363
- for idx, img in enumerate(snapshots):
364
- gender_results.append(gender_classification(img))
365
- classification_progress.progress(66)
366
-
367
- age_results = []
368
- for idx, img in enumerate(snapshots):
369
- age_results.append(age_classification(img))
370
- classification_progress.progress(100)
371
- classification_text_placeholder.text("分类进度: 分类完成!")
372
 
 
373
  most_common_gender = Counter(gender_results).most_common(1)[0][0]
374
  most_common_age = Counter(age_results).most_common(1)[0][0]
375
 
376
- # 找到第一张吸烟快照
377
- smoking_image = None
378
- for idx, label in enumerate(smoke_results):
379
- if label.lower() == "smoking":
380
- smoking_image = snapshots[idx]
381
- break
382
- if smoking_image is None:
383
- smoking_image = snapshots[0]
384
 
385
- # 添加结果到 session state
386
  st.session_state.detection_results.append({
387
  "Timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
388
  "Snapshot": smoking_image,
@@ -391,7 +360,7 @@ def real_time_detection_page():
391
  "Smoking Count": smoking_count
392
  })
393
 
394
- # 更新表格显示
395
  df = pd.DataFrame([
396
  {
397
  "Timestamp": result["Timestamp"],
@@ -400,28 +369,26 @@ def real_time_detection_page():
400
  "Smoking Count": result["Smoking Count"]
401
  } for result in st.session_state.detection_results
402
  ])
403
- table_placeholder.dataframe(df, use_container_width=True)
404
 
405
- # 显示示例快照
406
- image_placeholder.image(smoking_image, caption="捕获的吸烟快照", use_container_width=True)
407
 
408
- # 播放音频
409
- audio_placeholder.empty()
410
  audio_key = f"{most_common_age} {most_common_gender.lower()}"
411
  if audio_key in audio_data:
412
- audio_bytes = audio_data[audio_key]
413
- play_audio_via_js(audio_bytes)
414
  else:
415
- st.error(f"音频文件不存在: {audio_key}.wav")
416
  else:
417
- result_text = "**吸烟状态:** Not Smoking"
418
- classification_result_placeholder.markdown(result_text)
419
- image_placeholder.empty()
420
- audio_placeholder.empty()
421
- classification_text_placeholder.text("分类进度: 分类完成!")
422
  classification_progress.progress(100)
423
 
424
- # 更新表格显示,即使没有吸烟检测到
425
  if st.session_state.detection_results:
426
  df = pd.DataFrame([
427
  {
@@ -431,31 +398,33 @@ def real_time_detection_page():
431
  "Smoking Count": result["Smoking Count"]
432
  } for result in st.session_state.detection_results
433
  ])
434
- table_placeholder.dataframe(df, use_container_width=True)
435
 
 
436
  time.sleep(5)
437
- classification_progress_placeholder.empty()
438
- classification_text_placeholder.empty()
439
- capture_text_placeholder.empty()
440
-
441
- detection_info_placeholder.info("开始侦测")
442
  ctx.video_transformer.snapshots = []
443
  ctx.video_transformer.last_capture_time = time.time()
 
444
  time.sleep(0.1)
445
 
446
  # ======================
447
- # 主函数:多页面导航
448
  # ======================
449
 
450
  def main():
451
- st.sidebar.title("导航")
452
- page = st.sidebar.selectbox("选择页面", ["coverpage", "照片检测", "实时视频检测"])
 
453
 
454
- if page == "coverpage":
455
  cover_page()
456
- if page == "照片检测":
457
  photo_detection_page()
458
- if page == "实时视频检测":
459
  real_time_detection_page()
460
 
461
  if __name__ == "__main__":
 
17
  import pandas as pd
18
 
19
  # ======================
20
+ # Model Loading Functions
21
  # ======================
22
 
@st.cache_resource
def load_smoke_pipeline():
    """Build the smoking image-classification pipeline.

    The function is wrapped in ``st.cache_resource``, so the pipeline object
    it returns is cached by Streamlit instead of being rebuilt on each call.
    """
    task = "image-classification"
    checkpoint = "ccclllwww/smoker_cls_base_V9"
    return pipeline(task, model=checkpoint, use_fast=True)
27
 
@st.cache_resource
def load_gender_pipeline():
    """Build the gender image-classification pipeline.

    The function is wrapped in ``st.cache_resource``, so the pipeline object
    it returns is cached by Streamlit instead of being rebuilt on each call.
    """
    task = "image-classification"
    checkpoint = "rizvandwiki/gender-classification-2"
    return pipeline(task, model=checkpoint, use_fast=True)
32
 
@st.cache_resource
def load_age_pipeline():
    """Build the age-range image-classification pipeline.

    The function is wrapped in ``st.cache_resource``, so the pipeline object
    it returns is cached by Streamlit instead of being rebuilt on each call.
    """
    task = "image-classification"
    checkpoint = "akashmaggon/vit-base-age-classification"
    return pipeline(task, model=checkpoint, use_fast=True)
37
 
# Preload all models at module level so every page can use them directly.
# Each loader is decorated with st.cache_resource, so these calls return the
# cached pipeline objects once they have been built.
smoke_pipeline = load_smoke_pipeline()
gender_pipeline = load_gender_pipeline()
age_pipeline = load_age_pipeline()
42
 
43
  # ======================
44
+ # Twilio Configuration
45
  # ======================
 
 
46
 
def initialize_twilio_client():
    """Create a Twilio token from credentials stored in the environment.

    Reads ``TWILIO_ACCOUNT_SID`` and ``TWILIO_AUTH_TOKEN``. If either one is
    missing or empty, an error is shown in the Streamlit app and the script
    run is stopped. Otherwise a Twilio ``Client`` is built and the result of
    ``client.tokens.create()`` is returned.
    """
    sid = os.getenv('TWILIO_ACCOUNT_SID')
    secret = os.getenv('TWILIO_AUTH_TOKEN')

    # Both values must be present and non-empty before contacting Twilio.
    if not (sid and secret):
        st.error("Twilio credentials not found in environment variables.")
        st.stop()

    twilio_client = Client(sid, secret)
    return twilio_client.tokens.create()
56
 
# Twilio token created at module level; its ``ice_servers`` attribute is passed
# to webrtc_streamer as the WebRTC ICE server configuration.
token = initialize_twilio_client()
58
 
59
  # ======================
60
+ # Audio Loading Function
61
  # ======================
62
 
@st.cache_resource
def load_audio_files():
    """Load every ``.wav`` file from the ``audio`` directory into memory.

    Returns:
        dict: Maps each file name without its extension (for example
        ``"10-19 male"``) to the raw bytes of that file.

    If the ``audio`` directory is missing, or the path is not a directory,
    an error is shown in the Streamlit app and the script run is stopped.
    """
    audio_dir = "audio"

    # isdir rather than exists: a regular file that happens to be named
    # "audio" would pass an existence check and then crash in os.listdir.
    if not os.path.isdir(audio_dir):
        st.error(f"Audio directory '{audio_dir}' not found.")
        st.stop()

    audio_dict = {}
    # Sorted so the resulting dictionary order does not depend on the
    # filesystem's listing order.
    for audio_file in sorted(os.listdir(audio_dir)):
        if not audio_file.endswith(".wav"):
            continue
        file_path = os.path.join(audio_dir, audio_file)
        # Skip directories whose name ends in ".wav"; open() would fail on them.
        if not os.path.isfile(file_path):
            continue
        with open(file_path, "rb") as file:
            audio_dict[os.path.splitext(audio_file)[0]] = file.read()
    return audio_dict
76
 
# Load audio files at startup. Keys are file names without the ".wav"
# extension; the pages look clips up with "<age_range> <gender>" keys.
audio_data = load_audio_files()
79
 
80
  # ======================
81
+ # Image Processing Functions
82
  # ======================
83
 
def detect_smoking(image: Image.Image) -> str:
    """Return the label of the first prediction from the smoking pipeline.

    Any exception raised while running the pipeline or reading its output is
    reported in the Streamlit app, after which the script run is stopped.
    """
    try:
        predictions = smoke_pipeline(image)
        first_prediction = predictions[0]
        return first_prediction["label"]
    except Exception as e:
        st.error(f"Image processing error: {str(e)}")
        st.stop()
92
+
def detect_gender(image: Image.Image) -> str:
    """Return the label of the first prediction from the gender pipeline.

    Any exception raised while running the pipeline or reading its output is
    reported in the Streamlit app, after which the script run is stopped.
    """
    try:
        predictions = gender_pipeline(image)
        first_prediction = predictions[0]
        return first_prediction["label"]
    except Exception as e:
        st.error(f"Image processing error: {str(e)}")
        st.stop()
101
+
def detect_age(image: Image.Image) -> str:
    """Return the label of the first prediction from the age pipeline.

    Any exception raised while running the pipeline or reading its output is
    reported in the Streamlit app, after which the script run is stopped.
    """
    try:
        predictions = age_pipeline(image)
        first_prediction = predictions[0]
        return first_prediction["label"]
    except Exception as e:
        st.error(f"Image processing error: {str(e)}")
        st.stop()
110
+
111
  # ======================
112
+ # Real-Time Classification Functions
113
  # ======================
114
 
@st.cache_data(show_spinner=False, max_entries=3)
def classify_smoking(image: Image.Image) -> str:
    """Return the highest-scoring label from the smoking pipeline.

    Results are cached through ``st.cache_data`` (at most 3 entries). Any
    exception is reported in the Streamlit app and the script run is stopped.
    """
    def _score(prediction):
        return prediction["score"]

    try:
        predictions = smoke_pipeline(image)
        best_prediction = max(predictions, key=_score)
        return best_prediction["label"]
    except Exception as e:
        st.error(f"Image processing error: {str(e)}")
        st.stop()
124
 
@st.cache_data(show_spinner=False, max_entries=3)
def classify_gender(image: Image.Image) -> str:
    """Return the highest-scoring label from the gender pipeline.

    Results are cached through ``st.cache_data`` (at most 3 entries). Any
    exception is reported in the Streamlit app and the script run is stopped.
    """
    def _score(prediction):
        return prediction["score"]

    try:
        predictions = gender_pipeline(image)
        best_prediction = max(predictions, key=_score)
        return best_prediction["label"]
    except Exception as e:
        st.error(f"Image processing error: {str(e)}")
        st.stop()
134
 
@st.cache_data(show_spinner=False, max_entries=3)
def classify_age(image: Image.Image) -> str:
    """Return the highest-scoring label from the age pipeline.

    Results are cached through ``st.cache_data`` (at most 3 entries). Any
    exception is reported in the Streamlit app and the script run is stopped.
    """
    def _score(prediction):
        return prediction["score"]

    try:
        predictions = age_pipeline(image)
        best_prediction = max(predictions, key=_score)
        return best_prediction["label"]
    except Exception as e:
        st.error(f"Image processing error: {str(e)}")
        st.stop()
144
 
145
  # ======================
146
+ # Audio Playback Function
147
  # ======================
148
 
149
+ def play_audio(audio_bytes: bytes):
150
+ """Play audio using HTML and JavaScript with Base64-encoded audio data."""
 
 
 
 
 
151
  audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
152
+ audio_id = f"audio_player_{uuid.uuid4()}"
153
  html_content = f"""
154
+ <audio id="{audio_id}" controls style="width: 100%;">
155
  <source src="data:audio/wav;base64,{audio_base64}" type="audio/wav">
156
  Your browser does not support the audio element.
157
  </audio>
158
  <script type="text/javascript">
159
  window.addEventListener('DOMContentLoaded', function() {{
160
  setTimeout(function() {{
161
+ var audioElement = document.getElementById("{audio_id}");
162
  if (audioElement) {{
163
  audioElement.play().catch(function(e) {{
164
+ console.log("Playback prevented by browser:", e);
165
  }});
166
  }}
167
  }}, 1000);
 
171
  st.components.v1.html(html_content, height=150)
172
 
173
  # ======================
174
+ # Video Transformer Class
175
  # ======================
176
 
class VideoTransformer(VideoTransformerBase):
    """Pass webcam frames through unchanged while collecting snapshots.

    The real-time detection page reads ``snapshots`` and resets both
    ``snapshots`` and ``last_capture_time`` after each analysis cycle.
    """

    def __init__(self):
        self.snapshots = []  # captured frames, stored as RGB PIL images
        self.last_capture_time = time.time()
        self.capture_interval = 1  # minimum seconds between two snapshots
        self.max_snapshots = 5  # stop capturing once this many are stored

    def transform(self, frame):
        """Return the frame as a BGR array, storing a snapshot when one is due.

        A snapshot is taken when at least ``capture_interval`` seconds have
        passed since the previous one and fewer than ``max_snapshots`` are
        stored.
        """
        img = frame.to_ndarray(format="bgr24")
        current_time = time.time()
        interval_elapsed = (
            current_time - self.last_capture_time >= self.capture_interval
        )
        if interval_elapsed and len(self.snapshots) < self.max_snapshots:
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            self.snapshots.append(Image.fromarray(img_rgb))
            self.last_capture_time = current_time
            # No st.write() here: streamlit-webrtc runs this callback outside
            # the Streamlit script thread, where Streamlit commands have no
            # script context. The page reports capture progress itself by
            # reading ``snapshots``.
        return img
195
 
196
  # ======================
197
  # Cover Page
198
  # ======================
199
 
def cover_page():
    """Render the cover page: welcome text, project overview and usage guide.

    All content is static markdown. The multi-line strings are indented with
    the code; nested list items carry four extra spaces so they render as
    sub-items.
    """
    st.title("Smoking Detection System", anchor=False)

    st.markdown("### Welcome to the Smoking Detection System")
    st.markdown("""
    This Streamlit-based application harnesses cutting-edge machine learning to detect smoking behavior in images and real-time video streams. By analyzing smoking activity, gender, and age demographics, it provides valuable insights for public health monitoring and policy enforcement.
    """)

    st.markdown("#### Project Overview")
    st.markdown("""
    - **Purpose**: Automatically identify smoking behavior in public or controlled environments to support compliance with no-smoking policies and facilitate behavioral studies.
    - **Significance**: Enhances public health initiatives by enabling real-time monitoring and demographic analysis of smoking activities.
    - **Features**:
        - **Photo Detection**: Analyze a single image (uploaded or captured) for smoking, gender, and age.
        - **Real-Time Video Detection**: Process webcam streams, capturing snapshots to detect smoking and demographics.
        - **Audio Feedback**: Play alerts based on detected gender and age when smoking is confirmed.
    """)

    # The capture description below matches the video transformer, which takes
    # one snapshot per second until 5 are stored (not "over one minute").
    st.markdown("#### How to Use")
    st.markdown("""
    1. **Navigate**: Use the sidebar to select a page:
        - **Cover Page**: View this overview.
        - **Photo Detection**: Upload or capture an image for analysis.
        - **Real-Time Video Detection**: Monitor live webcam feed.
    2. **Photo Detection**:
        - Upload an image or capture one via webcam.
        - The system detects smoking; if detected, it analyzes gender and age, playing a corresponding audio alert.
    3. **Real-Time Video Detection**:
        - Captures 5 snapshots at one-second intervals.
        - If smoking is detected in more than 2 snapshots, it analyzes gender and age, displays results in a table, and plays an audio alert.
    4. **Setup Requirements**:
        - Ensure the 'audio' directory contains .wav files named as '<age_range> <gender>.wav' (e.g., '10-19 male.wav').
        - Configure Twilio environment variables (`TWILIO_ACCOUNT_SID` and `TWILIO_AUTH_TOKEN`) for WebRTC functionality.
    """)

    st.markdown("#### Get Started")
    st.markdown("Select a page from the sidebar to begin analyzing images or video streams.")
238
 
239
  # ======================
240
+ # Photo Detection Page
241
  # ======================
242
 
def photo_detection_page():
    """Render the photo detection page.

    The user uploads an image or captures one with the camera. The image is
    classified for smoking; only when the label is "smoking" are gender and
    age classified, and the matching audio clip is played (or an error is
    shown when no clip exists for that age/gender combination).
    """
    audio_placeholder = st.empty()
    st.title("Photo Detection", anchor=False)
    st.markdown("Upload an image or capture a photo to detect smoking behavior. If smoking is detected, gender and age will be analyzed.")

    # Image input selection
    input_methods = ["Upload Image", "Capture with Camera"]
    option = st.radio("Choose input method", input_methods, horizontal=True)

    image = None
    if option != "Upload Image":
        enable = st.checkbox("Enable Camera")
        camera_file = st.camera_input("Capture Photo", disabled=not enable)
        if camera_file:
            image = Image.open(camera_file)
            st.image(image, caption="Captured Photo", use_container_width=True)
    else:
        uploaded_file = st.file_uploader("Select an image", type=["jpg", "jpeg", "png"])
        if uploaded_file:
            image = Image.open(uploaded_file)
            st.image(image, caption="Uploaded Image", use_container_width=True)

    # Nothing to analyse until an image has been provided.
    if not image:
        return

    with st.spinner("Detecting smoking..."):
        smoke_result = detect_smoking(image)
    st.success(f"Smoking Status: {smoke_result}")

    # Gender and age are only analysed for images classified as smoking.
    if smoke_result.lower() != "smoking":
        return

    with st.spinner("Detecting gender..."):
        gender_result = detect_gender(image)
    st.success(f"Gender: {gender_result}")

    with st.spinner("Detecting age..."):
        age_result = detect_age(image)
    st.success(f"Age Range: {age_result}")

    audio_placeholder.empty()
    # Audio clips are keyed as "<age_range> <gender>" with a lower-case gender.
    audio_key = f"{age_result} {gender_result.lower()}"
    if audio_key not in audio_data:
        st.error(f"Audio file not found: {audio_key}.wav")
        return
    play_audio(audio_data[audio_key])
285
 
286
  # ======================
287
+ # Real-Time Detection Page
288
  # ======================
289
 
290
  def real_time_detection_page():
291
+ """Handle real-time video detection with snapshot capture and analysis."""
292
+ st.title("Real-Time Video Detection", anchor=False)
293
+ st.markdown("Captures 5 snapshots over one minute to detect smoking. If smoking is detected in more than 2 snapshots, results include gender, age, and a snapshot in a table.")
294
 
295
+ # Initialize session state for detection results
296
  if 'detection_results' not in st.session_state:
297
  st.session_state.detection_results = []
298
 
299
+ # Placeholders for UI elements
300
+ capture_text = st.empty()
301
+ capture_progress = st.empty()
302
+ classification_text = st.empty()
303
+ classification_progress = st.empty()
304
+ detection_info = st.empty()
305
+ table = st.empty()
306
+ image_display = st.empty()
307
+ audio = st.empty()
308
+
309
+ # Start video stream
310
+ ctx = webrtc_streamer(
311
+ key="unique_example",
312
+ video_transformer_factory=VideoTransformer,
313
+ rtc_configuration={"iceServers": token.ice_servers}
314
+ )
315
 
316
  capture_target = 5
317
 
318
+ if ctx.video_transformer:
319
+ detection_info.info("Starting detection...")
 
320
 
321
  while True:
322
  snapshots = ctx.video_transformer.snapshots
323
 
324
  if len(snapshots) < capture_target:
325
+ capture_text.text(f"Capture Progress: {len(snapshots)}/{capture_target} snapshots")
326
+ capture_progress.progress(int(len(snapshots) / capture_target * 100))
 
327
  else:
328
+ capture_text.text("Capture Progress: Completed!")
329
+ capture_progress.empty()
330
+ detection_info.empty()
331
 
332
+ classification_text.text("Classification Progress: Analyzing...")
333
+ classification = classification_progress.progress(0)
 
334
 
335
+ # Classify snapshots
336
+ smoke_results = [classify_smoking(img) for img in snapshots]
 
337
  smoking_count = sum(1 for result in smoke_results if result.lower() == "smoking")
338
+ classification.progress(33)
339
 
340
  if smoking_count > 2:
341
+ gender_results = [classify_gender(img) for img in snapshots]
342
+ classification.progress(66)
343
+ age_results = [classify_age(img) for img in snapshots]
344
+ classification.progress(100)
345
+ classification_text.text("Classification Progress: Completed!")
 
 
 
 
 
346
 
347
+ # Determine most common gender and age
348
  most_common_gender = Counter(gender_results).most_common(1)[0][0]
349
  most_common_age = Counter(age_results).most_common(1)[0][0]
350
 
351
+ # Select first smoking snapshot
352
+ smoking_image = next((snapshots[i] for i, label in enumerate(smoke_results) if label.lower() == "smoking"), snapshots[0])
 
 
 
 
 
 
353
 
354
+ # Store results
355
  st.session_state.detection_results.append({
356
  "Timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
357
  "Snapshot": smoking_image,
 
360
  "Smoking Count": smoking_count
361
  })
362
 
363
+ # Update table
364
  df = pd.DataFrame([
365
  {
366
  "Timestamp": result["Timestamp"],
 
369
  "Smoking Count": result["Smoking Count"]
370
  } for result in st.session_state.detection_results
371
  ])
372
+ table.dataframe(df, use_container_width=True)
373
 
374
+ # Display snapshot
375
+ image_display.image(smoking_image, caption="Detected Smoking Snapshot", use_container_width=True)
376
 
377
+ # Play audio
378
+ audio.empty()
379
  audio_key = f"{most_common_age} {most_common_gender.lower()}"
380
  if audio_key in audio_data:
381
+ play_audio(audio_data[audio_key])
 
382
  else:
383
+ st.error(f"Audio file not found: {audio_key}.wav")
384
  else:
385
+ st.markdown("**Smoking Status:** Not Smoking")
386
+ image_display.empty()
387
+ audio.empty()
388
+ classification_text.text("Classification Progress: Completed!")
 
389
  classification_progress.progress(100)
390
 
391
+ # Update table if results exist
392
  if st.session_state.detection_results:
393
  df = pd.DataFrame([
394
  {
 
398
  "Smoking Count": result["Smoking Count"]
399
  } for result in st.session_state.detection_results
400
  ])
401
+ table.dataframe(df, use_container_width=True)
402
 
403
+ # Reset for next cycle
404
  time.sleep(5)
405
+ classification_progress.empty()
406
+ classification_text.empty()
407
+ capture_text.empty()
408
+ detection_info.info("Starting detection...")
 
409
  ctx.video_transformer.snapshots = []
410
  ctx.video_transformer.last_capture_time = time.time()
411
+
412
  time.sleep(0.1)
413
 
414
  # ======================
415
+ # Main Application
416
  # ======================
417
 
def main():
    """Show the sidebar page selector and render the page that was chosen."""
    # Page label -> render function, in the order shown in the select box.
    pages = {
        "Cover Page": cover_page,
        "Photo Detection": photo_detection_page,
        "Real-Time Video Detection": real_time_detection_page,
    }

    st.sidebar.title("Navigation")
    page = st.sidebar.selectbox("Select Page", list(pages))

    render_page = pages.get(page)
    if render_page is not None:
        render_page()
429
 
430
  if __name__ == "__main__":