# Uploaded by Jiwoo77 as app.py ("Upload app.py", commit d77592f, verified).
from read_bpm import bpm_value
import os
import time
import cv2
import numpy as np
import tensorflow as tf
import gradio as gr
import plotly.graph_objects as go
import matplotlib.pyplot as plt
from fpdf import FPDF
from PIL import Image
# Facial-expression classifier used for every webcam frame. The path is
# relative, so it is resolved against the current working directory.
# compile=False: this file only calls model.predict(), so no optimizer/loss
# state is needed.
MODEL_PATH = "fer_surprise_softmax.h5"
model = tf.keras.models.load_model(MODEL_PATH, compile=False)
# (width, height) passed to cv2.resize for the face crop; presumably the
# model's training input resolution — TODO confirm against the model.
IMG_SIZE = (96, 96)
# Label for each index of the model's softmax output (zipped with the
# prediction vector in detect_surprise). NOTE(review): assumed to match the
# label order used at training time — verify.
CLASS_NAMES = ["angry", "disgust", "fear", "happy", "sad", "surprise", "neutral"]
SURPRISE_IDX = CLASS_NAMES.index("surprise")
# OpenCV's bundled Haar cascade for frontal-face detection.
face_cascade = cv2.CascadeClassifier(
    cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
)
# ---- Session state (module-level, so shared by every connected client) ----
# Recorded surprise events: dicts with "time", "score" and "frame" keys.
events = []
# One {"time", "score"} sample per frame in which a face was scored.
surprise_history = []
# time.time() of the first non-empty frame; nothing in this file resets it
# after it is set, so a "session" lasts for the life of the process.
start_time = None
# Seconds: a detection closer than this to the previous event is merged
# into that event instead of creating a new one.
MIN_EVENT_GAP = 1.0
# Number of frames in which at least one face was detected.
frames_with_face = 0
# Highest P(surprise) observed so far in the session.
max_p_surprise = 0.0
def format_time(seconds: float) -> str:
    """Format an elapsed time in seconds as a zero-padded ``MM:SS`` string.

    Minutes are not capped at 59 (3600 s -> ``"60:00"``), and fractional
    seconds are dropped (65.9 s -> ``"01:05"``).
    """
    mins, secs = divmod(seconds, 60)
    return "{:02d}:{:02d}".format(int(mins), int(secs))
def _stats_markdown(duration_str, threshold, n_frames, n_events, max_p):
    """Render the "Session Stats" Markdown panel shown under the webcam feed."""
    return (
        "### Session Stats\n"
        f"- Session duration: {duration_str}\n"
        f"- Current threshold: {threshold:.2f}\n"
        f"- Frames with face detected: {n_frames}\n"
        f"- Surprise events detected: {n_events}\n"
        f"- Max P(surprise): {max_p:.2f}\n"
    )


def _record_surprise_event(current_time, p_surprise, frame):
    """Add a surprise detection to the module-level ``events`` list.

    A detection within MIN_EVENT_GAP seconds of the previous event is treated
    as part of that event: it replaces the stored time/score/frame only when it
    scores higher. Any other detection starts a new event.
    """
    if events and current_time - events[-1]["time"] <= MIN_EVENT_GAP:
        last = events[-1]
        if p_surprise > last["score"]:
            last["time"] = current_time
            last["score"] = p_surprise
            last["frame"] = frame.copy()
        return
    events.append({
        "time": current_time,
        "score": p_surprise,
        "frame": frame.copy(),
    })


def _draw_label(img_bgr, text, color, base_scale=1.6, thickness=3, margin=10):
    """Draw ``text`` in the bottom-left corner of ``img_bgr``, shrunk to fit.

    Hershey fonts have a fixed glyph size per scale, so a long label at the
    base scale would run off the right edge of a typical webcam frame.
    """
    h_img, w_img = img_bgr.shape[:2]
    font = cv2.FONT_HERSHEY_SIMPLEX
    (text_w, _), _ = cv2.getTextSize(text, font, base_scale, thickness)
    max_w = max(w_img - 2 * margin, 1)
    if text_w <= max_w:
        scale, thick = base_scale, thickness
    else:
        scale = base_scale * max_w / text_w
        thick = max(1, round(thickness * scale / base_scale))
    cv2.putText(img_bgr, text, (margin, h_img - margin), font, scale, color, thick)


def detect_surprise(frame, threshold):
    """Process one webcam frame: find the largest face, score it, log surprise.

    Used as the streaming callback for the webcam input. Updates the
    module-level session state (events, surprise_history, frames_with_face,
    max_p_surprise) and starts the session clock on the first non-empty frame.

    Args:
        frame: RGB image as a numpy array, or None when no frame is available.
        threshold: Minimum P(surprise) for a frame to count as a surprise.

    Returns:
        A 4-tuple ``(annotated_rgb_frame, {class: probability}, plotly_figure,
        stats_markdown)``. For a None frame: ``(None, {"Error": 1.0}, None,
        zeroed stats)``. When no face is found the probability dict is empty.
    """
    global events, start_time, surprise_history
    global frames_with_face, max_p_surprise

    if frame is None:
        return None, {"Error": 1.0}, None, _stats_markdown("00:00", threshold, 0, 0, 0.0)

    # All per-session state is (re)initialised when the clock first starts.
    if start_time is None:
        start_time = time.time()
        surprise_history = []
        events = []
        frames_with_face = 0
        max_p_surprise = 0.0
    current_time = time.time() - start_time

    frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
    gray = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2GRAY)
    faces = face_cascade.detectMultiScale(gray, 1.1, 4)

    # OpenCV's Hershey fonts only render ASCII (other symbols are drawn as
    # "?"), so on-frame labels must not contain emoji.
    label = "NO FACE - Try brighter lighting or adjust angle"
    color = (0, 255, 255)  # yellow in BGR
    probs_dict = {}

    if len(faces) > 0:
        frames_with_face += 1
        # Score only the largest face; the first one wins on area ties.
        x, y, w, h = max(faces, key=lambda r: r[2] * r[3])
        roi = frame_bgr[y:y + h, x:x + w]
        rgb = cv2.cvtColor(roi, cv2.COLOR_BGR2RGB)
        inp = cv2.resize(rgb, IMG_SIZE).astype("float32") / 255.0
        probs = model.predict(np.expand_dims(inp, axis=0), verbose=0)[0]

        p_surprise = float(probs[SURPRISE_IDX])
        max_p_surprise = max(max_p_surprise, p_surprise)
        probs_dict = {cls: float(p) for cls, p in zip(CLASS_NAMES, probs)}
        surprise_history.append({"time": current_time, "score": p_surprise})

        if p_surprise >= threshold:
            _record_surprise_event(current_time, p_surprise, frame)
            label = f"SURPRISE (p={p_surprise:.2f})"
            color = (0, 255, 0)  # green in BGR
        else:
            label = f"Not Surprise (p={p_surprise:.2f})"
            color = (0, 0, 255)  # red in BGR
        cv2.rectangle(frame_bgr, (x, y), (x + w, y + h), color, 3)

    _draw_label(frame_bgr, label, color)
    out_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)

    fig = go.Figure()
    if probs_dict:
        fig.add_trace(go.Bar(
            x=list(probs_dict.keys()),
            y=list(probs_dict.values()),
            marker_color="lightskyblue",
        ))
    fig.update_layout(
        title="Emotion Probability Distribution",
        yaxis=dict(range=[0, 1]),
    )

    stats_text = _stats_markdown(
        format_time(current_time),
        threshold,
        frames_with_face,
        len(events),
        max_p_surprise,
    )
    return out_rgb, probs_dict, fig, stats_text
def summarize_results():
    """Summarize the current session's surprise detections.

    Plots every recorded P(surprise) sample over time and, if any surprise
    events were recorded, highlights the (up to) three highest-scoring ones.

    Returns:
        A 6-tuple ``(summary_text, img1, img2, img3, fig, pdf_path)``.
        ``img1``..``img3`` are the stored frames of the top events (``None``
        for missing ranks), ``fig`` is the matplotlib timeline, and
        ``pdf_path`` is always ``None`` because PDF export is not implemented.
        With no recorded samples, returns ``"No data recorded."`` and five
        ``None`` values.
    """
    # A bare Figure is not registered with pyplot's global figure manager, so
    # it is garbage-collected once Gradio is done with it. Figures from
    # plt.subplots() stay alive until plt.close() and leak on every call.
    from matplotlib.figure import Figure

    if not surprise_history:
        return "No data recorded.", None, None, None, None, None

    times = [sample["time"] for sample in surprise_history]
    scores = [sample["score"] for sample in surprise_history]

    fig = Figure()
    ax = fig.subplots()
    ax.plot(times, scores, marker="o", linewidth=1)
    ax.set_title("Surprise Probability Timeline")
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("P(surprise)")
    ax.set_ylim(0, 1)
    ax.grid(True)

    if not events:
        summary_text = (
            "No surprise events detected above the current threshold.\n\n"
            "The timeline shows overall surprise probability over time."
        )
        return summary_text, None, None, None, fig, None

    top3 = sorted(events, key=lambda e: e["score"], reverse=True)[:3]
    captions = [
        f"#{rank} Time = {format_time(e['time'])} Score = {e['score']:.2f}"
        for rank, e in enumerate(top3, start=1)
    ]
    summary_text = "Top 3 surprise moments:\n" + "\n".join(captions)

    # Mark the top events on the timeline: rank 1 = red star, 2 = orange
    # triangle, 3 = gold square.
    markers = ["*", "^", "s"]
    colors = ["red", "darkorange", "gold"]
    for e, marker, color in zip(top3, markers, colors):
        ax.scatter(e["time"], e["score"], color=color, marker=marker, s=80, zorder=5)

    # Pad with None so there are always exactly three image outputs.
    frames = [e["frame"] for e in top3]
    img1, img2, img3 = frames + [None] * (3 - len(frames))
    # PDF generation omitted.
    return summary_text, img1, img2, img3, fig, None
# ===============================
# 🔥 Gradio UI + BPM display
# ===============================
# Polling script for the BPM counter. It is injected through
# gr.Blocks(head=...) because a <script> placed inside gr.HTML is inserted as
# innerHTML and never executes, which left the counter stuck at "--".
# NOTE(review): a browser fetch to http://127.0.0.1:8000 only reaches a BPM
# server on the viewer's own machine. It is blocked as mixed content when the
# app is served over https, and it needs CORS headers on the BPM server.
_BPM_POLL_SCRIPT = """
<script>
async function getBPM() {
  try {
    const res = await fetch("http://127.0.0.1:8000/get_bpm");
    const data = await res.json();
    return data.bpm;
  } catch (err) {
    console.log("BPM fetch error:", err);
    return "--";
  }
}
setInterval(async () => {
  const bpm = await getBPM();
  // The element only exists after Gradio has rendered the layout.
  const el = document.getElementById("bpm_display");
  if (el) {
    el.innerText = bpm;
  }
}, 1000);
</script>
"""

try:
    custom_theme = gr.themes.Soft(primary_hue="indigo", neutral_hue="slate")
except (AttributeError, TypeError):
    # Fallback for Gradio builds without gr.themes or these keyword arguments.
    custom_theme = "soft"

demo = gr.Blocks(theme=custom_theme, head=_BPM_POLL_SCRIPT)
# NOTE(review): summarize_results() is not wired to any event in this layout.
with demo:
    gr.Markdown("## 🎭 Real-Time Surprise Detector & Heart Rate Monitor")
    webcam = gr.Image(sources=["webcam"], type="numpy", label="Webcam")
    output_img = gr.Image(label="Detection")
    threshold = gr.Slider(0.0, 1.0, value=0.1, step=0.01, label="Threshold")
    output_label = gr.Label(label="Softmax")
    plot = gr.Plot(label="Emotion Plot")
    stats_md = gr.Markdown()
    # Send a webcam frame to detect_surprise every 0.1 s.
    webcam.stream(
        fn=detect_surprise,
        inputs=[webcam, threshold],
        outputs=[output_img, output_label, plot, stats_md],
        stream_every=0.1,
    )
    # Static container whose #bpm_display span is updated by _BPM_POLL_SCRIPT.
    gr.HTML("""
<div style='font-size:24px; font-weight:bold; margin-top:20px;'>
❀️ Current BPM: <span id="bpm_display">--</span>
</div>
""")

if __name__ == "__main__":
    demo.launch()