srivatsavdamaraju committed on
Commit
692a2fe
·
verified ·
1 Parent(s): a1347bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +198 -47
app.py CHANGED
@@ -27,23 +27,72 @@ pose_model = mp_pose.Pose(static_image_mode=True, model_complexity=2)
27
  mp_drawing = mp.solutions.drawing_utils
28
  mp_styles = mp.solutions.drawing_styles
29
 
30
- # 🔧 Process function
31
- def process(image):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  if image is None:
33
- return "❌ Please upload an image."
34
-
35
  # Timestamp-based ID
36
  ts = datetime.now().strftime("%Y%m%d_%H%M%S")
37
  pose_id = f"pose_{ts}"
38
-
39
  # Convert to OpenCV
40
  img_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
41
-
42
  # Detect pose
43
  results = pose_model.process(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
44
  if not results.pose_landmarks:
45
- return "❌ No pose detected."
46
-
47
  # Draw overlay
48
  overlay = img_bgr.copy()
49
  mp_drawing.draw_landmarks(
@@ -52,49 +101,151 @@ def process(image):
52
  mp_pose.POSE_CONNECTIONS,
53
  landmark_drawing_spec=mp_styles.get_default_pose_landmarks_style()
54
  )
55
-
 
 
 
56
  # Save overlay image
57
  overlay_path = f"pose_images/{pose_id}.png"
58
  cv2.imwrite(overlay_path, overlay)
59
-
60
- # Extract coordinates
61
- pose_coords = {}
62
- for idx, lm in enumerate(results.pose_landmarks.landmark):
63
- name = mp_pose.PoseLandmark(idx).name
64
- pose_coords[name] = {
65
- "x": round(lm.x, 4),
66
- "y": round(lm.y, 4),
67
- "z": round(lm.z, 4),
68
- "visibility": round(lm.visibility, 3)
69
- }
70
-
71
- # Save to JSON
 
 
72
  pose_dataset[pose_id] = {
73
- "correct pose name": pose_id,
74
- "Image with pose overlay": overlay_path,
75
- "Pose coordinates": pose_coords,
76
- "Pose description": "To be filled"
 
77
  }
78
-
79
  with open(json_path, "w") as f:
80
  json.dump(pose_dataset, f, indent=2)
81
-
82
- # Return basic preview
83
- preview = f"✅ Pose saved as `{pose_id}` with {len(pose_coords)} joints.\n"
84
- for joint, v in list(pose_coords.items())[:5]: # show first 5
85
- preview += f"{joint}: x={v['x']} y={v['y']} z={v['z']}\n"
86
-
87
- preview += f"\n🖼 Overlay image saved at `{overlay_path}`"
88
-
89
- return preview
90
-
91
- # ✅ Interface-style Gradio UI
92
- interface = gr.Interface(
93
- fn=process,
94
- inputs=gr.Image(type="numpy", label="Upload Pose Image"),
95
- outputs="text",
96
- title="🧘 Pose Analysis with MediaPipe",
97
- description="Upload a yoga or archery pose image. This tool will extract pose keypoints using MediaPipe and save them in a JSON file."
98
- )
99
-
100
- interface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  mp_drawing = mp.solutions.drawing_utils
28
  mp_styles = mp.solutions.drawing_styles
29
 
30
+ # Define pose connections (edges) for graph structure
31
+ POSE_CONNECTIONS = [
32
+ # Face
33
+ (0, 1), (1, 2), (2, 3), (3, 7), (0, 4), (4, 5), (5, 6), (6, 8),
34
+ # Torso
35
+ (9, 10), (11, 12), (11, 13), (13, 15), (15, 17), (15, 19), (15, 21),
36
+ (12, 14), (14, 16), (16, 18), (16, 20), (16, 22), (11, 23), (12, 24),
37
+ (23, 24),
38
+ # Left arm
39
+ (11, 13), (13, 15), (15, 17), (17, 19), (19, 21),
40
+ # Right arm
41
+ (12, 14), (14, 16), (16, 18), (18, 20), (20, 22),
42
+ # Left leg
43
+ (23, 25), (25, 27), (27, 29), (29, 31), (27, 31),
44
+ # Right leg
45
+ (24, 26), (26, 28), (28, 30), (30, 32), (28, 32)
46
+ ]
47
+
48
+ def create_pose_graph_data(pose_landmarks):
49
+ """Create nodes and edges data structure from pose landmarks"""
50
+ nodes = {}
51
+ edges = []
52
+
53
+ # Create nodes
54
+ for idx, lm in enumerate(pose_landmarks.landmark):
55
+ name = mp_pose.PoseLandmark(idx).name
56
+ nodes[idx] = {
57
+ "id": idx,
58
+ "name": name,
59
+ "x": round(lm.x, 4),
60
+ "y": round(lm.y, 4),
61
+ "z": round(lm.z, 4),
62
+ "visibility": round(lm.visibility, 3)
63
+ }
64
+
65
+ # Create edges based on MediaPipe connections
66
+ for connection in mp_pose.POSE_CONNECTIONS:
67
+ start_idx = connection[0]
68
+ end_idx = connection[1]
69
+ if start_idx < len(pose_landmarks.landmark) and end_idx < len(pose_landmarks.landmark):
70
+ edges.append({
71
+ "from": start_idx,
72
+ "to": end_idx,
73
+ "from_name": mp_pose.PoseLandmark(start_idx).name,
74
+ "to_name": mp_pose.PoseLandmark(end_idx).name
75
+ })
76
+
77
+ return nodes, edges
78
+
79
+ def process_pose(image, pose_description=""):
80
+ """Process pose image and return overlay with pose data"""
81
  if image is None:
82
+ return None, "❌ Please upload an image.", ""
83
+
84
  # Timestamp-based ID
85
  ts = datetime.now().strftime("%Y%m%d_%H%M%S")
86
  pose_id = f"pose_{ts}"
87
+
88
  # Convert to OpenCV
89
  img_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
90
+
91
  # Detect pose
92
  results = pose_model.process(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
93
  if not results.pose_landmarks:
94
+ return None, "❌ No pose detected.", ""
95
+
96
  # Draw overlay
97
  overlay = img_bgr.copy()
98
  mp_drawing.draw_landmarks(
 
101
  mp_pose.POSE_CONNECTIONS,
102
  landmark_drawing_spec=mp_styles.get_default_pose_landmarks_style()
103
  )
104
+
105
+ # Convert back to RGB for display
106
+ overlay_rgb = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
107
+
108
  # Save overlay image
109
  overlay_path = f"pose_images/{pose_id}.png"
110
  cv2.imwrite(overlay_path, overlay)
111
+
112
+ # Create graph data structure
113
+ nodes, edges = create_pose_graph_data(results.pose_landmarks)
114
+
115
+ # Create pose data summary
116
+ pose_data = {
117
+ "pose_id": pose_id,
118
+ "total_nodes": len(nodes),
119
+ "total_edges": len(edges),
120
+ "nodes": nodes,
121
+ "edges": edges,
122
+ "description": pose_description if pose_description else "No description provided"
123
+ }
124
+
125
+ # Save to JSON dataset
126
  pose_dataset[pose_id] = {
127
+ "pose_name": pose_id,
128
+ "image_path": overlay_path,
129
+ "pose_description": pose_description,
130
+ "pose_data": pose_data,
131
+ "timestamp": ts
132
  }
133
+
134
  with open(json_path, "w") as f:
135
  json.dump(pose_dataset, f, indent=2)
136
+
137
+ # Format pose data for display
138
+ data_display = f"""
139
+ 🎯 **Pose Analysis Results**
140
+
141
+ 📊 **Graph Structure:**
142
+ - Total Nodes (Keypoints): {len(nodes)}
143
+ - Total Edges (Connections): {len(edges)}
144
+
145
+ 📝 **Pose Description:** {pose_description if pose_description else "No description provided"}
146
+
147
+ 🔍 **Key Nodes (First 10):**
148
+ """
149
+
150
+ for i, (idx, node) in enumerate(list(nodes.items())[:10]):
151
+ data_display += f"• {node['name']}: ({node['x']:.3f}, {node['y']:.3f}, {node['z']:.3f}) [visibility: {node['visibility']}]\n"
152
+
153
+ if len(nodes) > 10:
154
+ data_display += f"... and {len(nodes) - 10} more nodes\n"
155
+
156
+ data_display += f"""
157
+ 🔗 **Sample Connections:**
158
+ """
159
+
160
+ for i, edge in enumerate(edges[:5]):
161
+ data_display += f"• {edge['from_name']} → {edge['to_name']}\n"
162
+
163
+ if len(edges) > 5:
164
+ data_display += f"... and {len(edges) - 5} more connections\n"
165
+
166
+ data_display += f"""
167
+ 💾 **Data Saved:**
168
+ - Image: {overlay_path}
169
+ - JSON: {json_path}
170
+ - Pose ID: {pose_id}
171
+ """
172
+
173
+ return overlay_rgb, data_display, f"✅ Pose '{pose_id}' saved successfully!"
174
+
175
+ def save_with_description(image, description):
176
+ """Save pose with description"""
177
+ if image is None:
178
+ return None, "❌ Please upload an image first.", "❌ No image to process"
179
+
180
+ return process_pose(image, description)
181
+
182
+ # Create Gradio interface
183
+ with gr.Blocks(title="🧘 Advanced Pose Analysis Tool") as demo:
184
+ gr.Markdown("# 🧘 Advanced Pose Analysis with MediaPipe")
185
+ gr.Markdown("Upload a yoga, archery, or any pose image to extract keypoints and analyze the pose structure as nodes and edges.")
186
+
187
+ with gr.Row():
188
+ with gr.Column(scale=1):
189
+ # Input section
190
+ gr.Markdown("## 📤 Input")
191
+ input_image = gr.Image(type="numpy", label="Upload Pose Image")
192
+ pose_description = gr.Textbox(
193
+ label="Pose Description",
194
+ placeholder="Enter a description of the pose (e.g., 'Warrior II pose with arms extended')",
195
+ lines=3
196
+ )
197
+
198
+ with gr.Row():
199
+ analyze_btn = gr.Button("🔍 Analyze Pose", variant="primary")
200
+ save_btn = gr.Button("💾 Save with Description", variant="secondary")
201
+
202
+ with gr.Column(scale=1):
203
+ # Output section
204
+ gr.Markdown("## 📊 Results")
205
+ output_image = gr.Image(label="Pose with MediaPipe Overlay")
206
+ status_text = gr.Textbox(label="Status", lines=1)
207
+
208
+ # Pose data display
209
+ gr.Markdown("## 📋 Pose Data (Nodes & Edges)")
210
+ pose_data_display = gr.Textbox(
211
+ label="Pose Analysis Data",
212
+ lines=15,
213
+ max_lines=20,
214
+ show_copy_button=True
215
+ )
216
+
217
+ # Button actions
218
+ analyze_btn.click(
219
+ fn=lambda img: process_pose(img, ""),
220
+ inputs=[input_image],
221
+ outputs=[output_image, pose_data_display, status_text]
222
+ )
223
+
224
+ save_btn.click(
225
+ fn=save_with_description,
226
+ inputs=[input_image, pose_description],
227
+ outputs=[output_image, pose_data_display, status_text]
228
+ )
229
+
230
+ # Auto-analyze on image upload
231
+ input_image.change(
232
+ fn=lambda img: process_pose(img, ""),
233
+ inputs=[input_image],
234
+ outputs=[output_image, pose_data_display, status_text]
235
+ )
236
+
237
+ gr.Markdown("""
238
+ ## 📖 How to Use:
239
+ 1. **Upload Image**: Upload a pose image using the image uploader
240
+ 2. **Auto Analysis**: The pose will be automatically analyzed showing keypoints
241
+ 3. **Add Description**: Enter a description of the pose in the text box
242
+ 4. **Save**: Click "Save with Description" to save the pose data with your description
243
+
244
+ ## πŸ“Š Data Structure:
245
+ - **Nodes**: 33 body keypoints (face, torso, arms, legs) with 3D coordinates
246
+ - **Edges**: Connections between keypoints following human body structure
247
+ - **Visibility**: Confidence score for each keypoint detection
248
+ """)
249
+
250
+ # Launch the interface
251
+ demo.launch(share=True)