# Source: pictogram / app.py
# Last commit: "Update app.py" by FishKIM (711ee3f, verified)
import joblib
import cv2
import numpy as np
import mediapipe as mp
import gradio as gr
import matplotlib.pyplot as plt
import io
from PIL import Image
# Classifier deserialized from disk; process_frame calls its predict_proba().
model = joblib.load("image_model.pkl")

# Maps a predicted class index to the label shown to the user.
action_labels = {
    0: "Weightlifting",
    1: "Soccer",
    2: "Handball",
}

# MediaPipe helpers: the pose solution module and its drawing utilities.
mp_pose = mp.solutions.pose
mp_drawing = mp.solutions.drawing_utils

# Single pose estimator instance, created once and reused for every frame.
pose = mp_pose.Pose(
    min_detection_confidence=0.5,
    static_image_mode=False,
)
def _render_probability_chart(prediction_proba):
    """Draw the per-class probabilities as a bar chart and return an RGB PIL image."""
    fig, ax = plt.subplots()
    try:
        ax.bar(list(action_labels.values()), prediction_proba, color='skyblue')
        ax.set_ylim([0, 1])
        ax.set_ylabel("Confidence")
        ax.set_title("Prediction Probabilities")
        with io.BytesIO() as buf:
            # Save through the figure handle rather than pyplot's implicit
            # "current figure", which is shared global state.
            fig.savefig(buf, format="png")
            buf.seek(0)
            # convert() decodes the PNG into a new image, so the buffer can
            # be released as soon as this block exits.
            return Image.open(buf).convert("RGB")
    finally:
        # Always release the figure, even if rendering or saving failed.
        plt.close(fig)


def process_frame(frame):
    """Annotate one webcam frame with pose landmarks and the predicted action.

    Args:
        frame: Image array supplied by the Gradio image input, or ``None``
            when no frame is available.

    Returns:
        A 3-tuple ``(annotated_frame, label, graph_img)``:
        a copy of ``frame`` with landmarks and the label drawn on it, the
        label text ``"<action> (<confidence>%)"``, and a PIL image of the
        class-probability bar chart. ``(None, "", None)`` is returned when
        ``frame`` is ``None``.
    """
    # The streaming input can fire without a frame; skip instead of letting
    # cv2 raise on an empty image.
    if frame is None:
        return None, "", None

    # Gradio's Image component documents its numpy output as RGB, which is
    # what MediaPipe expects, so the frame is passed to the pose model as-is.
    results = pose.process(frame)

    # NOTE(review): this swap turns the RGB frame into BGR channel order. It
    # is kept unchanged for the classifier input so predictions stay exactly
    # as before -- confirm against the preprocessing used at training time.
    model_image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    resized = cv2.resize(model_image, (128, 128))
    input_data = resized.flatten().reshape(1, -1)

    prediction_proba = model.predict_proba(input_data)[0]
    predicted_class = int(np.argmax(prediction_proba))
    prediction = action_labels.get(predicted_class, "Unknown")
    confidence = prediction_proba[predicted_class] * 100
    label = f'{prediction} ({confidence:.2f}%)'

    # Draw on a copy so the array handed in by the caller is never mutated.
    annotated = frame.copy()
    if results.pose_landmarks:
        mp_drawing.draw_landmarks(annotated, results.pose_landmarks, mp_pose.POSE_CONNECTIONS)
    cv2.putText(annotated, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0,255,0), 2)

    graph_img = _render_probability_chart(prediction_proba)
    return annotated, label, graph_img
# Gradio UI: a streaming webcam image goes into process_frame, whose three
# return values fill the annotated image, the label textbox and the chart.
demo = gr.Interface(
    title="Pose Action Classifier",
    description="Real-time action prediction using MediaPipe and a trained model.",
    fn=process_frame,
    live=True,
    inputs=gr.Image(streaming=True, source="webcam"),
    outputs=[
        gr.Image(label="Pose + Prediction"),
        gr.Textbox(label="Predicted Action"),
        gr.Image(label="Confidence Graph"),
    ],
)

# Launch the interface.
demo.launch()