File size: 1,901 Bytes
a00e120
 
 
 
 
 
 
711ee3f
a00e120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
711ee3f
a00e120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import joblib
import cv2
import numpy as np
import mediapipe as mp
import gradio as gr
import matplotlib.pyplot as plt
import io
from PIL import Image

# Action classifier restored from disk. joblib.load unpickles arbitrary
# objects, so only ever point this at a model file you trust.
model = joblib.load("image_model.pkl")

# Class id -> human-readable action name shown in the UI.
action_labels = dict(enumerate(("Weightlifting", "Soccer", "Handball")))

# MediaPipe pose estimator shared by every processed frame.
# static_image_mode=False treats successive inputs as a video stream.
mp_pose = mp.solutions.pose
pose = mp_pose.Pose(
    static_image_mode=False,
    min_detection_confidence=0.5,
)
mp_drawing = mp.solutions.drawing_utils

def _predict(image_rgb):
    """Classify one RGB frame with the loaded model.

    The frame is resized to 128x128 and flattened into a single feature row,
    the same layout this script always feeds to ``model.predict_proba``.

    Returns:
        tuple: ``(names, proba, best)`` where ``names[i]`` is the display
        name of probability column ``i``, ``proba`` is the probability
        vector and ``best`` is the index of the most likely column.
    """
    resized = cv2.resize(image_rgb, (128, 128))
    features = resized.reshape(1, -1)
    proba = model.predict_proba(features)[0]

    # predict_proba columns are ordered like model.classes_ (scikit-learn
    # convention), which need not be 0..n-1; fall back to positional ids for
    # models that do not expose classes_.
    class_ids = getattr(model, "classes_", None)
    if class_ids is None:
        class_ids = range(len(proba))
    names = [action_labels.get(class_id, "Unknown") for class_id in class_ids]

    best = int(np.argmax(proba))
    return names, proba, best


def _render_confidence_graph(names, proba):
    """Render per-class probabilities as a bar chart and return an RGB PIL image."""
    # Use the object-oriented Figure API rather than pyplot: pyplot keeps
    # global "current figure" state that is not thread-safe and leaks figures
    # unless each one is closed explicitly.
    from matplotlib.figure import Figure

    fig = Figure()
    ax = fig.subplots()
    # Numeric positions + tick labels (instead of categorical x values) keep
    # one bar per column even if two classes share a display name.
    positions = np.arange(len(proba))
    ax.bar(positions, proba, color="skyblue")
    ax.set_xticks(positions)
    ax.set_xticklabels(names)
    ax.set_ylim([0, 1])
    ax.set_ylabel("Confidence")
    ax.set_title("Prediction Probabilities")

    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    return Image.open(buf).convert("RGB")


def process_frame(frame):
    """Annotate one webcam frame with pose landmarks and an action prediction.

    Args:
        frame: image array from the Gradio ``gr.Image`` input, or ``None``
            when the stream has not produced a frame yet.

    Returns:
        tuple: ``(annotated_frame, label, graph_img)``. These are a copy of
        ``frame`` with landmarks and the label drawn on it, the label text
        ``"<action> (<confidence>%)"``, and a PIL bar chart of the class
        probabilities. Returns ``(None, "", None)`` when ``frame`` is ``None``.
    """
    if frame is None:
        return None, "", None

    # Gradio's gr.Image hands over numpy arrays in RGB order (its default
    # image_mode), which is what MediaPipe expects; a BGR->RGB conversion here
    # would actually swap the channels to BGR.
    # NOTE(review): if image_model.pkl was trained on BGR pixels, convert only
    # the classifier input with cv2.COLOR_RGB2BGR — confirm against training.
    image_rgb = frame
    results = pose.process(image_rgb)

    names, proba, best = _predict(image_rgb)
    confidence = proba[best] * 100
    label = f'{names[best]} ({confidence:.2f}%)'

    # Draw on a copy so neither the caller's array nor the classifier input
    # is mutated by the annotations.
    annotated = frame.copy()
    if results.pose_landmarks:
        mp_drawing.draw_landmarks(annotated, results.pose_landmarks, mp_pose.POSE_CONNECTIONS)
    cv2.putText(annotated, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

    return annotated, label, _render_confidence_graph(names, proba)

# Webcam frames are streamed one at a time into process_frame.
# NOTE(review): `source="webcam"` is the Gradio 3.x keyword; Gradio 4+ uses
# `sources=["webcam"]` instead — confirm against the pinned gradio version.
_webcam_input = gr.Image(source="webcam", streaming=True)

# One output component per element of process_frame's return tuple:
# annotated frame, label text, confidence bar chart.
_output_components = [
    gr.Image(label="Pose + Prediction"),
    gr.Textbox(label="Predicted Action"),
    gr.Image(label="Confidence Graph"),
]

demo = gr.Interface(
    process_frame,
    _webcam_input,
    _output_components,
    title="Pose Action Classifier",
    description="Real-time action prediction using MediaPipe and a trained model.",
    live=True,
)

demo.launch()