ccclllwww commited on
Commit
8aa70c8
·
verified ·
1 Parent(s): 333c990

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -34
app.py CHANGED
@@ -1,3 +1,9 @@
 
 
 
 
 
 
1
  import streamlit as st
2
  import cv2
3
  import time
@@ -5,10 +11,10 @@ from streamlit_webrtc import VideoTransformerBase, webrtc_streamer
5
  from PIL import Image
6
  from transformers import pipeline
7
  import os
8
- from twilio.rest import Client
9
- from collections import Counter
10
  import base64
11
-
 
 
12
 
13
  # ======================
14
  # 模型加载函数(缓存)
@@ -30,15 +36,17 @@ def load_age_pipeline():
30
  return pipeline("image-classification", model="akashmaggon/vit-base-age-classification", use_fast=True)
31
 
32
  # 预先加载所有模型
33
- load_smoke_pipeline()
34
- load_gender_pipeline()
35
- load_age_pipeline()
 
36
 
37
  # ======================
38
  # remote settings
39
  # ======================
40
  # Find your Account SID and Auth Token at twilio.com/console
41
  # and set the environment variables. See http://twil.io/secure
 
42
  account_sid = os.environ['TWILIO_ACCOUNT_SID']
43
  auth_token = os.environ['TWILIO_AUTH_TOKEN']
44
  client = Client(account_sid, auth_token)
@@ -69,15 +77,46 @@ def load_all_audios():
69
  # 应用启动时加载所有音频
70
  audio_data = load_all_audios()
71
 
 
72
  # ======================
73
- # 核心处理函数
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  # ======================
75
 
76
  @st.cache_data(show_spinner=False, max_entries=3)
77
  def smoking_classification(image: Image.Image) -> str:
78
  """接受 PIL 图片并利用吸烟分类 pipeline 进行判定,返回标签(如 "smoking")。"""
79
  try:
80
- smoke_pipeline = load_smoke_pipeline()
81
  output = smoke_pipeline(image)
82
  status = max(output, key=lambda x: x["score"])['label']
83
  return status
@@ -89,7 +128,6 @@ def smoking_classification(image: Image.Image) -> str:
89
  def gender_classification(image: Image.Image) -> str:
90
  """进行性别分类,返回模型输出的性别(依模型输出)。"""
91
  try:
92
- gender_pipeline = load_gender_pipeline()
93
  output = gender_pipeline(image)
94
  status = max(output, key=lambda x: x["score"])['label']
95
  return status
@@ -101,7 +139,6 @@ def gender_classification(image: Image.Image) -> str:
101
  def age_classification(image: Image.Image) -> str:
102
  """进行年龄分类,返回年龄范围,例如 "10-19" 等。"""
103
  try:
104
- age_pipeline = load_age_pipeline()
105
  output = age_pipeline(image)
106
  age_range = max(output, key=lambda x: x["score"])['label']
107
  return age_range
@@ -122,15 +159,14 @@ def play_audio_via_js(audio_bytes):
122
  """
123
  audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
124
  html_content = f"""
125
- <audio id="audio_player" controls style="width: 100%;">
126
  <source src="data:audio/wav;base64,{audio_base64}" type="audio/wav">
127
  Your browser does not support the audio element.
128
  </audio>
129
  <script type="text/javascript">
130
- // 等待 DOMContentLoaded 事件,并在1秒后自动调用 play() 方法
131
  window.addEventListener('DOMContentLoaded', function() {{
132
  setTimeout(function() {{
133
- var audioElement = document.getElementById("audio_player");
134
  if (audioElement) {{
135
  audioElement.play().catch(function(e) {{
136
  console.log("播放被浏览器阻止:", e);
@@ -165,11 +201,57 @@ class VideoTransformer(VideoTransformerBase):
165
  return img # 返回原始帧以供前端显示
166
 
167
  # ======================
168
- # 主函数:整合视频流、自动图分类并展示结果
169
  # ======================
170
 
171
- def main():
172
- st.title("Streamlit-WebRTC 自动图分类示例")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  st.write("程序在一分钟内捕获20张快照进行图片分类,首先判定是否吸烟。若检测到吸烟的快照超过2次,则展示年龄与性别分类结果。")
174
 
175
  # 创建用于显示进度文字和进度条的占位容器
@@ -177,48 +259,43 @@ def main():
177
  capture_progress_placeholder = st.empty()
178
  classification_text_placeholder = st.empty()
179
  classification_progress_placeholder = st.empty()
180
- detection_info_placeholder = st.empty() # 用于显示“开始侦测”
181
 
182
  # 启动实时视频流
183
- # Then, pass the ICE server information to webrtc_streamer().
184
-
185
- ctx = webrtc_streamer(key="unique_example", video_transformer_factory=VideoTransformer,rtc_configuration={"iceServers": token.ice_servers})
186
  image_placeholder = st.empty()
187
  audio_placeholder = st.empty()
188
 
189
- capture_target = 10 # 本轮捕获目标张数
190
 
191
  if ctx.video_transformer is not None:
192
- classification_result_placeholder = st.empty() # 用于显示分类结果
193
  detection_info_placeholder.info("开始侦测")
194
 
195
  while True:
196
  snapshots = ctx.video_transformer.snapshots
197
 
198
- # 更新捕获阶段进度:同时显示文字和进度条
199
  if len(snapshots) < capture_target:
200
  capture_text_placeholder.text(f"捕获进度: {len(snapshots)}/{capture_target} 张快照")
201
  progress_value = int(len(snapshots) / capture_target * 100)
202
  capture_progress_placeholder.progress(progress_value)
203
  else:
204
- # 捕获完成,清空捕获进度条,并显示完成提示
205
  capture_text_placeholder.text("捕获进度: 捕获完成!")
206
  capture_progress_placeholder.empty()
207
- detection_info_placeholder.empty() # 清除“开始侦测”提示
208
 
209
- # ---------- 分类阶段进度 ----------
210
  total = len(snapshots)
211
  classification_text_placeholder.text("分类进度: 正在分类...")
212
  classification_progress = classification_progress_placeholder.progress(0)
213
 
214
- # 1. 吸烟分类 (0 ~ 33%)
215
  smoke_results = []
216
  for idx, img in enumerate(snapshots):
217
  smoke_results.append(smoking_classification(img))
218
  smoking_count = sum(1 for result in smoke_results if result.lower() == "smoking")
219
  classification_progress.progress(33)
220
 
221
- # 2. 若吸烟次数超过2,再进行性别和年龄分类 (33% ~ 100%)
222
  if smoking_count > 2:
223
  gender_results = []
224
  for idx, img in enumerate(snapshots):
@@ -241,7 +318,6 @@ def main():
241
  )
242
  classification_result_placeholder.markdown(result_text)
243
 
244
- # 选择第一张分类结果为 "smoking" 的快照,如未检测到,则显示第一张
245
  smoking_image = None
246
  for idx, label in enumerate(smoke_results):
247
  if label.lower() == "smoking":
@@ -251,7 +327,6 @@ def main():
251
  smoking_image = snapshots[0]
252
  image_placeholder.image(smoking_image, caption="捕获的快照示例", use_container_width=True)
253
 
254
- # 清空播放区域后再播放对应音频
255
  audio_placeholder.empty()
256
  audio_key = f"{most_common_age} {most_common_gender.lower()}"
257
  if audio_key in audio_data:
@@ -267,18 +342,28 @@ def main():
267
  classification_text_placeholder.text("分类进度: 分类完成!")
268
  classification_progress.progress(100)
269
 
270
- # 分类阶段结束后清空分类进度占位区
271
  time.sleep(1)
272
  classification_progress_placeholder.empty()
273
  classification_text_placeholder.empty()
274
  capture_text_placeholder.empty()
275
 
276
-
277
- # 重置快照列表,准备下一轮捕获
278
  detection_info_placeholder.info("开始侦测")
279
  ctx.video_transformer.snapshots = []
280
  ctx.video_transformer.last_capture_time = time.time()
281
  time.sleep(0.1)
282
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
  if __name__ == "__main__":
284
- main()
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Tue May 20 11:00:14 2025
4
+
5
+ @author: ColinWang
6
+ """
7
  import streamlit as st
8
  import cv2
9
  import time
 
11
  from PIL import Image
12
  from transformers import pipeline
13
  import os
 
 
14
  import base64
15
+ #from twilio.rest import Client
16
+ from collections import Counter
17
+ import uuid
18
 
19
  # ======================
20
  # 模型加载函数(缓存)
 
36
  return pipeline("image-classification", model="akashmaggon/vit-base-age-classification", use_fast=True)
37
 
38
  # 预先加载所有模型
39
+ smoke_pipeline = load_smoke_pipeline()
40
+ gender_pipeline = load_gender_pipeline()
41
+ age_pipeline = load_age_pipeline()
42
+
43
 
44
  # ======================
45
  # remote settings
46
  # ======================
47
  # Find your Account SID and Auth Token at twilio.com/console
48
  # and set the environment variables. See http://twil.io/secure
49
+
50
  account_sid = os.environ['TWILIO_ACCOUNT_SID']
51
  auth_token = os.environ['TWILIO_AUTH_TOKEN']
52
  client = Client(account_sid, auth_token)
 
77
  # 应用启动时加载所有音频
78
  audio_data = load_all_audios()
79
 
80
+
81
  # ======================
82
+ # 照片檢測处理函数
83
+ # ======================
84
+
85
+ def smoking_detection(image: Image.Image) -> str:
86
+ try:
87
+ output = smoke_pipeline(image)
88
+ status = output[0]["label"]
89
+ return status
90
+ except Exception as e:
91
+ st.error(f"🔍 图像处理错误: {str(e)}")
92
+ st.stop()
93
+
94
+ def gender_detection(image: Image.Image) -> str:
95
+ try:
96
+ output = gender_pipeline(image)
97
+ status = output[0]["label"]
98
+ return status
99
+ except Exception as e:
100
+ st.error(f"🔍 图像处理错误: {str(e)}")
101
+ st.stop()
102
+
103
+ def age_detection(image: Image.Image) -> str:
104
+ try:
105
+ output = age_pipeline(image)
106
+ status = output[0]["label"]
107
+ return status
108
+ except Exception as e:
109
+ st.error(f"🔍 图像处理错误: {str(e)}")
110
+ st.stop()
111
+
112
+ # ======================
113
+ # 實時檢測核心处理函数
114
  # ======================
115
 
116
  @st.cache_data(show_spinner=False, max_entries=3)
117
  def smoking_classification(image: Image.Image) -> str:
118
  """接受 PIL 图片并利用吸烟分类 pipeline 进行判定,返回标签(如 "smoking")。"""
119
  try:
 
120
  output = smoke_pipeline(image)
121
  status = max(output, key=lambda x: x["score"])['label']
122
  return status
 
128
  def gender_classification(image: Image.Image) -> str:
129
  """进行性别分类,返回模型输出的性别(依模型输出)。"""
130
  try:
 
131
  output = gender_pipeline(image)
132
  status = max(output, key=lambda x: x["score"])['label']
133
  return status
 
139
  def age_classification(image: Image.Image) -> str:
140
  """进行年龄分类,返回年龄范围,例如 "10-19" 等。"""
141
  try:
 
142
  output = age_pipeline(image)
143
  age_range = max(output, key=lambda x: x["score"])['label']
144
  return age_range
 
159
  """
160
  audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
161
  html_content = f"""
162
+ <audio id="audio_player_{uuid.uuid4()}" controls style="width: 100%;">
163
  <source src="data:audio/wav;base64,{audio_base64}" type="audio/wav">
164
  Your browser does not support the audio element.
165
  </audio>
166
  <script type="text/javascript">
 
167
  window.addEventListener('DOMContentLoaded', function() {{
168
  setTimeout(function() {{
169
+ var audioElement = document.getElementById("audio_player_{uuid.uuid4()}");
170
  if (audioElement) {{
171
  audioElement.play().catch(function(e) {{
172
  console.log("播放被浏览器阻止:", e);
 
201
  return img # 返回原始帧以供前端显示
202
 
203
  # ======================
204
+ # 检测页面
205
  # ======================
206
 
207
+ def photo_detection_page():
208
+ st.title("检测")
209
+ st.write("上传一张图片或使用摄像头拍摄,检测是否吸烟,若检测到吸烟则进一步分析性别和年龄。")
210
+
211
+ # 提供上传和摄像头选项
212
+ option = st.radio("选择输入方式", ["上传图片", "使用摄像头拍摄"])
213
+
214
+ image = None
215
+ if option == "上传图片":
216
+ uploaded_file = st.file_uploader("选择一张图片", type=["jpg", "jpeg", "png"])
217
+ if uploaded_file is not None:
218
+ image = Image.open(uploaded_file)
219
+ st.image(image, caption="上传的图片", use_container_width=True)
220
+ else:
221
+ # 摄像头拍摄
222
+ enable = st.checkbox("启用摄像头")
223
+ camera_file = st.camera_input("拍摄照片", disabled=not enable)
224
+ if camera_file is not None:
225
+ image = Image.open(camera_file)
226
+ st.image(image, caption="拍攝的圖片", use_container_width=True)
227
+
228
+ if image is not None:
229
+
230
+ # 吸烟分类
231
+ with st.spinner("Wait for smoking detection"):
232
+ smoke_result = smoking_detection(image)
233
+ st.success("The smoke result is:")
234
+ st.write(smoke_result)
235
+
236
+ if smoke_result.lower() == "smoking":
237
+ # 性别分类
238
+ with st.spinner("Wait for gender detection"):
239
+ gender_result = gender_detection(image)
240
+ st.success("The gender result is:")
241
+ st.write(gender_result)
242
+
243
+ # 年龄分类
244
+ with st.spinner("Wait for age detection"):
245
+ age_result = age_detection(image)
246
+ st.success("The age result is:")
247
+ st.write(age_result)
248
+
249
+ # ======================
250
+ # 实时检测页面(原主函数)
251
+ # ======================
252
+
253
+ def real_time_detection_page():
254
+ st.title("实时视频检测")
255
  st.write("程序在一分钟内捕获20张快照进行图片分类,首先判定是否吸烟。若检测到吸烟的快照超过2次,则展示年龄与性别分类结果。")
256
 
257
  # 创建用于显示进度文字和进度条的占位容器
 
259
  capture_progress_placeholder = st.empty()
260
  classification_text_placeholder = st.empty()
261
  classification_progress_placeholder = st.empty()
262
+ detection_info_placeholder = st.empty()
263
 
264
  # 启动实时视频流
265
+ ctx = webrtc_streamer(key="unique_example", video_transformer_factory=VideoTransformer,
266
+ rtc_configuration={"iceServers": token.ice_servers}
267
+ )
268
  image_placeholder = st.empty()
269
  audio_placeholder = st.empty()
270
 
271
+ capture_target = 20
272
 
273
  if ctx.video_transformer is not None:
274
+ classification_result_placeholder = st.empty()
275
  detection_info_placeholder.info("开始侦测")
276
 
277
  while True:
278
  snapshots = ctx.video_transformer.snapshots
279
 
 
280
  if len(snapshots) < capture_target:
281
  capture_text_placeholder.text(f"捕获进度: {len(snapshots)}/{capture_target} 张快照")
282
  progress_value = int(len(snapshots) / capture_target * 100)
283
  capture_progress_placeholder.progress(progress_value)
284
  else:
 
285
  capture_text_placeholder.text("捕获进度: 捕获完成!")
286
  capture_progress_placeholder.empty()
287
+ detection_info_placeholder.empty()
288
 
 
289
  total = len(snapshots)
290
  classification_text_placeholder.text("分类进度: 正在分类...")
291
  classification_progress = classification_progress_placeholder.progress(0)
292
 
 
293
  smoke_results = []
294
  for idx, img in enumerate(snapshots):
295
  smoke_results.append(smoking_classification(img))
296
  smoking_count = sum(1 for result in smoke_results if result.lower() == "smoking")
297
  classification_progress.progress(33)
298
 
 
299
  if smoking_count > 2:
300
  gender_results = []
301
  for idx, img in enumerate(snapshots):
 
318
  )
319
  classification_result_placeholder.markdown(result_text)
320
 
 
321
  smoking_image = None
322
  for idx, label in enumerate(smoke_results):
323
  if label.lower() == "smoking":
 
327
  smoking_image = snapshots[0]
328
  image_placeholder.image(smoking_image, caption="捕获的快照示例", use_container_width=True)
329
 
 
330
  audio_placeholder.empty()
331
  audio_key = f"{most_common_age} {most_common_gender.lower()}"
332
  if audio_key in audio_data:
 
342
  classification_text_placeholder.text("分类进度: 分类完成!")
343
  classification_progress.progress(100)
344
 
 
345
  time.sleep(1)
346
  classification_progress_placeholder.empty()
347
  classification_text_placeholder.empty()
348
  capture_text_placeholder.empty()
349
 
 
 
350
  detection_info_placeholder.info("开始侦测")
351
  ctx.video_transformer.snapshots = []
352
  ctx.video_transformer.last_capture_time = time.time()
353
  time.sleep(0.1)
354
 
355
+ # ======================
356
+ # 主函数:多页面导航
357
+ # ======================
358
+
359
+ def main():
360
+ st.sidebar.title("导航")
361
+ page = st.sidebar.selectbox("选择页面", ["照片检测", "实时视频检测"])
362
+
363
+ if page == "照片检测":
364
+ photo_detection_page()
365
+ else:
366
+ real_time_detection_page()
367
+
368
  if __name__ == "__main__":
369
+ main()