Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -14,6 +14,7 @@ import base64
|
|
| 14 |
from twilio.rest import Client
|
| 15 |
from collections import Counter
|
| 16 |
import uuid
|
|
|
|
| 17 |
|
| 18 |
# ======================
|
| 19 |
# 模型加载函数(缓存)
|
|
@@ -39,7 +40,6 @@ smoke_pipeline = load_smoke_pipeline()
|
|
| 39 |
gender_pipeline = load_gender_pipeline()
|
| 40 |
age_pipeline = load_age_pipeline()
|
| 41 |
|
| 42 |
-
|
| 43 |
# ======================
|
| 44 |
# remote settings
|
| 45 |
# ======================
|
|
@@ -52,7 +52,6 @@ client = Client(account_sid, auth_token)
|
|
| 52 |
|
| 53 |
token = client.tokens.create()
|
| 54 |
|
| 55 |
-
|
| 56 |
# ======================
|
| 57 |
# 音频加载函数(缓存)
|
| 58 |
# ======================
|
|
@@ -76,7 +75,6 @@ def load_all_audios():
|
|
| 76 |
# 应用启动时加载所有音频
|
| 77 |
audio_data = load_all_audios()
|
| 78 |
|
| 79 |
-
|
| 80 |
# ======================
|
| 81 |
# 照片檢測处理函数
|
| 82 |
# ======================
|
|
@@ -250,7 +248,6 @@ def cover_page():
|
|
| 250 |
- Set up Twilio environment variables (TWILIO_ACCOUNT_SID and TWILIO_AUTH_TOKEN) for WebRTC.
|
| 251 |
""")
|
| 252 |
|
| 253 |
-
|
| 254 |
# ======================
|
| 255 |
# 照片检测页面
|
| 256 |
# ======================
|
|
@@ -278,7 +275,6 @@ def photo_detection_page():
|
|
| 278 |
st.image(image, caption="拍攝的圖片", use_container_width=True)
|
| 279 |
|
| 280 |
if image is not None:
|
| 281 |
-
|
| 282 |
# 吸烟分类
|
| 283 |
with st.spinner("Wait for smoking detection"):
|
| 284 |
smoke_result = smoking_detection(image)
|
|
@@ -307,12 +303,16 @@ def photo_detection_page():
|
|
| 307 |
st.error(f"音频文件不存在: {audio_key}.wav")
|
| 308 |
|
| 309 |
# ======================
|
| 310 |
-
# 实时检测页面
|
| 311 |
# ======================
|
| 312 |
|
| 313 |
def real_time_detection_page():
|
| 314 |
st.title("实时视频检测")
|
| 315 |
-
st.write("程序在一分钟内捕获5张快照进行图片分类,首先判定是否吸烟。若检测到吸烟的快照超过2次,则
|
|
|
|
|
|
|
|
|
|
|
|
|
| 316 |
|
| 317 |
# 创建用于显示进度文字和进度条的占位容器
|
| 318 |
capture_text_placeholder = st.empty()
|
|
@@ -321,10 +321,12 @@ def real_time_detection_page():
|
|
| 321 |
classification_progress_placeholder = st.empty()
|
| 322 |
detection_info_placeholder = st.empty()
|
| 323 |
|
|
|
|
|
|
|
|
|
|
| 324 |
# 启动实时视频流
|
| 325 |
ctx = webrtc_streamer(key="unique_example", video_transformer_factory=VideoTransformer,
|
| 326 |
-
rtc_configuration={"iceServers": token.ice_servers}
|
| 327 |
-
)
|
| 328 |
image_placeholder = st.empty()
|
| 329 |
audio_placeholder = st.empty()
|
| 330 |
|
|
@@ -371,13 +373,7 @@ def real_time_detection_page():
|
|
| 371 |
most_common_gender = Counter(gender_results).most_common(1)[0][0]
|
| 372 |
most_common_age = Counter(age_results).most_common(1)[0][0]
|
| 373 |
|
| 374 |
-
|
| 375 |
-
f"**吸烟状态:** Smoking (检测到 {smoking_count} 次)\n\n"
|
| 376 |
-
f"**性别:** {most_common_gender}\n\n"
|
| 377 |
-
f"**年龄范围:** {most_common_age}"
|
| 378 |
-
)
|
| 379 |
-
classification_result_placeholder.markdown(result_text)
|
| 380 |
-
|
| 381 |
smoking_image = None
|
| 382 |
for idx, label in enumerate(smoke_results):
|
| 383 |
if label.lower() == "smoking":
|
|
@@ -385,8 +381,31 @@ def real_time_detection_page():
|
|
| 385 |
break
|
| 386 |
if smoking_image is None:
|
| 387 |
smoking_image = snapshots[0]
|
| 388 |
-
image_placeholder.image(smoking_image, caption="捕获的快照示例", use_container_width=True)
|
| 389 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 390 |
audio_placeholder.empty()
|
| 391 |
audio_key = f"{most_common_age} {most_common_gender.lower()}"
|
| 392 |
if audio_key in audio_data:
|
|
@@ -402,6 +421,18 @@ def real_time_detection_page():
|
|
| 402 |
classification_text_placeholder.text("分类进度: 分类完成!")
|
| 403 |
classification_progress.progress(100)
|
| 404 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 405 |
time.sleep(5)
|
| 406 |
classification_progress_placeholder.empty()
|
| 407 |
classification_text_placeholder.empty()
|
|
@@ -418,10 +449,10 @@ def real_time_detection_page():
|
|
| 418 |
|
| 419 |
def main():
|
| 420 |
st.sidebar.title("导航")
|
| 421 |
-
page = st.sidebar.selectbox("选择页面", ["coverpage","照片检测", "实时视频检测"])
|
| 422 |
|
| 423 |
if page == "coverpage":
|
| 424 |
-
|
| 425 |
if page == "照片检测":
|
| 426 |
photo_detection_page()
|
| 427 |
if page == "实时视频检测":
|
|
|
|
| 14 |
from twilio.rest import Client
|
| 15 |
from collections import Counter
|
| 16 |
import uuid
|
| 17 |
+
import pandas as pd
|
| 18 |
|
| 19 |
# ======================
|
| 20 |
# 模型加载函数(缓存)
|
|
|
|
| 40 |
gender_pipeline = load_gender_pipeline()
|
| 41 |
age_pipeline = load_age_pipeline()
|
| 42 |
|
|
|
|
| 43 |
# ======================
|
| 44 |
# remote settings
|
| 45 |
# ======================
|
|
|
|
| 52 |
|
| 53 |
token = client.tokens.create()
|
| 54 |
|
|
|
|
| 55 |
# ======================
|
| 56 |
# 音频加载函数(缓存)
|
| 57 |
# ======================
|
|
|
|
| 75 |
# 应用启动时加载所有音频
|
| 76 |
audio_data = load_all_audios()
|
| 77 |
|
|
|
|
| 78 |
# ======================
|
| 79 |
# 照片檢測处理函数
|
| 80 |
# ======================
|
|
|
|
| 248 |
- Set up Twilio environment variables (TWILIO_ACCOUNT_SID and TWILIO_AUTH_TOKEN) for WebRTC.
|
| 249 |
""")
|
| 250 |
|
|
|
|
| 251 |
# ======================
|
| 252 |
# 照片检测页面
|
| 253 |
# ======================
|
|
|
|
| 275 |
st.image(image, caption="拍攝的圖片", use_container_width=True)
|
| 276 |
|
| 277 |
if image is not None:
|
|
|
|
| 278 |
# 吸烟分类
|
| 279 |
with st.spinner("Wait for smoking detection"):
|
| 280 |
smoke_result = smoking_detection(image)
|
|
|
|
| 303 |
st.error(f"音频文件不存在: {audio_key}.wav")
|
| 304 |
|
| 305 |
# ======================
|
| 306 |
+
# 实时检测页面
|
| 307 |
# ======================
|
| 308 |
|
| 309 |
def real_time_detection_page():
|
| 310 |
st.title("实时视频检测")
|
| 311 |
+
st.write("程序在一分钟内捕获5张快照进行图片分类,首先判定是否吸烟。若检测到吸烟的快照超过2次,则将结果添加到表格中,包含快照、性别和年龄。")
|
| 312 |
+
|
| 313 |
+
# 初始化 session state 用于存储检测结果
|
| 314 |
+
if 'detection_results' not in st.session_state:
|
| 315 |
+
st.session_state.detection_results = []
|
| 316 |
|
| 317 |
# 创建用于显示进度文字和进度条的占位容器
|
| 318 |
capture_text_placeholder = st.empty()
|
|
|
|
| 321 |
classification_progress_placeholder = st.empty()
|
| 322 |
detection_info_placeholder = st.empty()
|
| 323 |
|
| 324 |
+
# 显示检测结果表格
|
| 325 |
+
table_placeholder = st.empty()
|
| 326 |
+
|
| 327 |
# 启动实时视频流
|
| 328 |
ctx = webrtc_streamer(key="unique_example", video_transformer_factory=VideoTransformer,
|
| 329 |
+
rtc_configuration={"iceServers": token.ice_servers})
|
|
|
|
| 330 |
image_placeholder = st.empty()
|
| 331 |
audio_placeholder = st.empty()
|
| 332 |
|
|
|
|
| 373 |
most_common_gender = Counter(gender_results).most_common(1)[0][0]
|
| 374 |
most_common_age = Counter(age_results).most_common(1)[0][0]
|
| 375 |
|
| 376 |
+
# 找到第一张吸烟快照
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 377 |
smoking_image = None
|
| 378 |
for idx, label in enumerate(smoke_results):
|
| 379 |
if label.lower() == "smoking":
|
|
|
|
| 381 |
break
|
| 382 |
if smoking_image is None:
|
| 383 |
smoking_image = snapshots[0]
|
|
|
|
| 384 |
|
| 385 |
+
# 添加结果到 session state
|
| 386 |
+
st.session_state.detection_results.append({
|
| 387 |
+
"Timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
|
| 388 |
+
"Snapshot": smoking_image,
|
| 389 |
+
"Gender": most_common_gender,
|
| 390 |
+
"Age Range": most_common_age,
|
| 391 |
+
"Smoking Count": smoking_count
|
| 392 |
+
})
|
| 393 |
+
|
| 394 |
+
# 更新表格显示
|
| 395 |
+
df = pd.DataFrame([
|
| 396 |
+
{
|
| 397 |
+
"Timestamp": result["Timestamp"],
|
| 398 |
+
"Gender": result["Gender"],
|
| 399 |
+
"Age Range": result["Age Range"],
|
| 400 |
+
"Smoking Count": result["Smoking Count"]
|
| 401 |
+
} for result in st.session_state.detection_results
|
| 402 |
+
])
|
| 403 |
+
table_placeholder.dataframe(df, use_container_width=True)
|
| 404 |
+
|
| 405 |
+
# 显示示例快照
|
| 406 |
+
image_placeholder.image(smoking_image, caption="捕获的吸烟快照", use_container_width=True)
|
| 407 |
+
|
| 408 |
+
# 播放音频
|
| 409 |
audio_placeholder.empty()
|
| 410 |
audio_key = f"{most_common_age} {most_common_gender.lower()}"
|
| 411 |
if audio_key in audio_data:
|
|
|
|
| 421 |
classification_text_placeholder.text("分类进度: 分类完成!")
|
| 422 |
classification_progress.progress(100)
|
| 423 |
|
| 424 |
+
# 更新表格显示,即使没有吸烟检测到
|
| 425 |
+
if st.session_state.detection_results:
|
| 426 |
+
df = pd.DataFrame([
|
| 427 |
+
{
|
| 428 |
+
"Timestamp": result["Timestamp"],
|
| 429 |
+
"Gender": result["Gender"],
|
| 430 |
+
"Age Range": result["Age Range"],
|
| 431 |
+
"Smoking Count": result["Smoking Count"]
|
| 432 |
+
} for result in st.session_state.detection_results
|
| 433 |
+
])
|
| 434 |
+
table_placeholder.dataframe(df, use_container_width=True)
|
| 435 |
+
|
| 436 |
time.sleep(5)
|
| 437 |
classification_progress_placeholder.empty()
|
| 438 |
classification_text_placeholder.empty()
|
|
|
|
| 449 |
|
| 450 |
def main():
|
| 451 |
st.sidebar.title("导航")
|
| 452 |
+
page = st.sidebar.selectbox("选择页面", ["coverpage", "照片检测", "实时视频检测"])
|
| 453 |
|
| 454 |
if page == "coverpage":
|
| 455 |
+
cover_page()
|
| 456 |
if page == "照片检测":
|
| 457 |
photo_detection_page()
|
| 458 |
if page == "实时视频检测":
|