seesaw112233 committed on
Commit
391a38b
·
verified ·
1 Parent(s): 17d3dc7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +123 -85
app.py CHANGED
@@ -1,92 +1,130 @@
1
- import gradio as gr
2
  import cv2
 
 
 
 
3
  import mediapipe as mp
4
- import tempfile
5
- import os
6
 
7
- def process_video(input_video):
8
- """处理视频"""
9
-
10
- if input_video is None:
 
 
 
 
11
  return None
12
-
13
- print("📹 开始处理...")
14
-
15
- try:
16
- # 初始化 MediaPipe
17
- mp_pose = mp.solutions.pose
18
- mp_drawing = mp.solutions.drawing_utils
19
-
20
- pose = mp_pose.Pose(
21
- static_image_mode=False,
22
- model_complexity=1,
23
- min_detection_confidence=0.5,
24
- min_tracking_confidence=0.5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  )
26
-
27
- # 打开视频
28
- cap = cv2.VideoCapture(input_video)
29
-
30
- if not cap.isOpened():
31
- return None
32
-
33
- # 获取视频信息
34
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
35
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
36
- fps = cap.get(cv2.CAP_PROP_FPS)
37
-
38
- print(f"视频: {width}x{height} @ {fps}fps")
39
-
40
- # 创建输出
41
- output_path = tempfile.mktemp(suffix='.mp4')
42
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
43
- out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
44
-
45
- frame_count = 0
46
-
47
- # 处理每一帧
48
- while cap.isOpened():
49
- ret, frame = cap.read()
50
- if not ret:
51
- break
52
-
53
- # 转 RGB
54
- rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
55
-
56
- # 检测姿态
57
- results = pose.process(rgb)
58
-
59
- # 绘制骨架
60
- if results.pose_landmarks:
61
- mp_drawing.draw_landmarks(
62
- frame,
63
- results.pose_landmarks,
64
- mp_pose.POSE_CONNECTIONS,
65
- mp_drawing.DrawingSpec(color=(0,255,0), thickness=2, circle_radius=2),
66
- mp_drawing.DrawingSpec(color=(0,0,255), thickness=2)
67
- )
68
-
69
- out.write(frame)
70
- frame_count += 1
71
-
72
- cap.release()
73
- out.release()
74
- pose.close()
75
-
76
- print(f"✅ 完成! {frame_count} 帧")
77
- return output_path
78
-
79
- except Exception as e:
80
- print(f"❌ 错误: {e}")
81
- return None
 
 
 
 
 
 
 
82
 
83
- # 创建界面
84
- demo = gr.Interface(
85
- fn=process_video,
86
- inputs=gr.Video(),
87
- outputs=gr.Video(),
88
- title="👶 Baby Pose Estimation",
89
- description="上传视频,自动识别姿态"
90
- )
91
 
92
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
1
+ import os
2
  import cv2
3
+ import numpy as np
4
+ import pandas as pd
5
+ import gradio as gr
6
+
7
  import mediapipe as mp
 
 
8
 
9
+
10
+ mp_pose = mp.solutions.pose
11
+ mp_drawing = mp.solutions.drawing_utils
12
+
13
+
14
+ def _ensure_rgb(img: np.ndarray) -> np.ndarray:
15
+ # Gradio Image returns RGB np.uint8
16
+ if img is None:
17
  return None
18
+ if img.dtype != np.uint8:
19
+ img = np.clip(img, 0, 255).astype(np.uint8)
20
+ if img.ndim == 2:
21
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
22
+ return img
23
+
24
+
25
def estimate_pose(image: np.ndarray, model_complexity: int, min_det: float, min_track: float):
    """Run single-image pose detection with MediaPipe.

    Returns:
        A tuple of (annotated RGB image, DataFrame of keypoints with
        columns id/name/x/y/z/visibility). Coordinates are normalized
        to the input image's width/height.
    """
    frame = _ensure_rgb(image)
    if frame is None:
        return None, pd.DataFrame()

    # One-shot detector: static_image_mode treats the input as an
    # independent photo rather than a frame of a video stream.
    with mp_pose.Pose(
        static_image_mode=True,
        model_complexity=model_complexity,
        enable_segmentation=False,
        min_detection_confidence=float(min_det),
        min_tracking_confidence=float(min_track),
    ) as detector:
        results = detector.process(frame)

    output = frame.copy()
    landmarks = results.pose_landmarks
    if not landmarks:
        # Nothing detected: return the untouched image and an empty table.
        return output, pd.DataFrame([])

    # Draw on a BGR copy (OpenCV convention), then convert back to RGB
    # for display in Gradio.
    bgr = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
    mp_drawing.draw_landmarks(
        bgr,
        landmarks,
        mp_pose.POSE_CONNECTIONS,
        landmark_drawing_spec=mp_drawing.DrawingSpec(thickness=2, circle_radius=2),
        connection_drawing_spec=mp_drawing.DrawingSpec(thickness=2),
    )
    output = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    # One row per landmark, in enum order.
    table = [
        {
            "id": idx,
            "name": mp_pose.PoseLandmark(idx).name,
            "x": float(pt.x),
            "y": float(pt.y),
            "z": float(pt.z),
            "visibility": float(pt.visibility),
        }
        for idx, pt in enumerate(landmarks.landmark)
    ]
    return output, pd.DataFrame(table)
77
+
78
+
79
def build_demo():
    """Assemble and return the Gradio Blocks UI for the pose estimator."""
    with gr.Blocks(title="Pose Estimation") as demo:
        # Intro / usage notes shown at the top of the page.
        gr.Markdown(
            "## 🕺 Pose Estimation (MediaPipe)\n"
            "上传一张图片 输出骨架标注图 + 关键点表格。\n\n"
            "如果你之前遇到 `TypeError: argument of type 'bool' is not iterable`,这是 Gradio 4.x 的一个坑,"
            "本 Space 已升级到 Gradio 5.x 来避免。"
        )

        # Input image on the left, annotated result on the right.
        with gr.Row():
            image_in = gr.Image(label="Input Image", type="numpy")
            image_out = gr.Image(label="Annotated Output", type="numpy")

        # Detector tuning controls.
        with gr.Row():
            complexity = gr.Radio(
                choices=[0, 1, 2],
                value=1,
                label="Model Complexity (0=light, 2=accurate)",
            )
            det_threshold = gr.Slider(0.1, 0.99, value=0.5, step=0.01, label="Min Detection Confidence")
            track_threshold = gr.Slider(0.1, 0.99, value=0.5, step=0.01, label="Min Tracking Confidence")

        keypoints_out = gr.Dataframe(
            label="Keypoints (normalized coords)",
            headers=["id", "name", "x", "y", "z", "visibility"],
            interactive=False,
            wrap=True,
        )

        run_button = gr.Button("Run Pose Estimation", variant="primary")
        run_button.click(
            fn=estimate_pose,
            inputs=[image_in, complexity, det_threshold, track_threshold],
            outputs=[image_out, keypoints_out],
        )

        gr.Markdown(
            "### Notes\n"
            "- `x/y/z` 是相对坐标(0~1),相对于输入图像宽高。\n"
            "- 这是 CPU 友好版本,适合 Hugging Face Spaces。"
        )

    return demo
122
+
123
 
124
# Build the app at import time so Spaces can pick up `demo` directly.
demo = build_demo()


if __name__ == "__main__":
    # Hugging Face Spaces normally works without share=True; set
    # GRADIO_SHARE=1 as a fallback if localhost isn't reachable.
    use_share = os.getenv("GRADIO_SHARE", "0") == "1"
    demo.launch(server_name="0.0.0.0", server_port=7860, share=use_share)