srivatsavdamaraju committed on
Commit
8554fb0
·
verified ·
1 Parent(s): 4ccef7f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -228
app.py CHANGED
@@ -1,230 +1,93 @@
1
  import gradio as gr
2
- import cv2
3
- import os
4
- import uuid
5
- import threading
6
- import time
7
  import mediapipe as mp
8
- import pandas as pd
9
- from concurrent.futures import ThreadPoolExecutor
10
- import queue
11
-
12
- # === Setup ===
13
- OUTPUT_DIR = "captured_frames"
14
- os.makedirs(OUTPUT_DIR, exist_ok=True)
15
- df = pd.DataFrame(columns=["filename", "caption", "pose_coords"])
16
-
17
- pose = mp.solutions.pose.Pose()
18
-
19
- state = {
20
- "cap": None,
21
- "frame": None,
22
- "frame_rgb": None, # Pre-converted RGB frame
23
- "play": False,
24
- "video_path": None,
25
- "capture_queue": queue.Queue(),
26
- "processing_thread": None
27
- }
28
-
29
- # Thread pool for background processing
30
- executor = ThreadPoolExecutor(max_workers=2)
31
-
32
- # === Background pose processing ===
33
- def process_pose_async(frame_bgr, filename, caption):
34
- """Process pose estimation in background thread"""
35
- try:
36
- # Convert to RGB for MediaPipe
37
- frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
38
- results = pose.process(frame_rgb)
39
-
40
- coords = []
41
- if results.pose_landmarks:
42
- for lm in results.pose_landmarks.landmark:
43
- coords.append((round(lm.x, 5), round(lm.y, 5), round(lm.z, 5)))
44
-
45
- # Add to dataframe
46
- global df
47
- new_row = pd.DataFrame([{
48
- "filename": filename,
49
- "caption": caption,
50
- "pose_coords": coords
51
- }])
52
- df = pd.concat([df, new_row], ignore_index=True)
53
-
54
- except Exception as e:
55
- print(f"Error processing pose: {e}")
56
-
57
- # === Load Video ===
58
- def load_video(video_file):
59
- try:
60
- if hasattr(video_file, "name"):
61
- video_path = video_file.name
62
- else:
63
- video_path = video_file
64
- state["video_path"] = video_path
65
-
66
- if state["cap"]:
67
- state["cap"].release()
68
-
69
- state["cap"] = cv2.VideoCapture(video_path)
70
-
71
- # Set buffer size to reduce lag
72
- state["cap"].set(cv2.CAP_PROP_BUFFERSIZE, 1)
73
-
74
- state["frame"] = None
75
- state["frame_rgb"] = None
76
- state["play"] = False
77
- return "✅ Video loaded successfully!"
78
- except Exception as e:
79
- return f"❌ Error loading video: {e}"
80
-
81
- # === Play video in background ===
82
- def play_video():
83
- if not state["cap"]:
84
- return "⚠️ Load a video first."
85
- state["play"] = True
86
-
87
- def stream():
88
- while state["cap"] and state["cap"].isOpened() and state["play"]:
89
- ret, frame = state["cap"].read()
90
- if not ret:
91
- state["play"] = False
92
- break
93
-
94
- # Store both BGR and RGB versions
95
- state["frame"] = frame
96
- state["frame_rgb"] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
97
-
98
- # Adaptive delay based on video FPS
99
- fps = state["cap"].get(cv2.CAP_PROP_FPS)
100
- if fps > 0:
101
- time.sleep(1.0 / fps)
102
- else:
103
- time.sleep(0.033) # ~30 FPS fallback
104
-
105
- threading.Thread(target=stream, daemon=True).start()
106
- return "▶️ Playing..."
107
-
108
- # === Pause playback ===
109
- def pause_video():
110
- state["play"] = False
111
- return "⏸️ Paused."
112
-
113
- # === Show current frame ===
114
- def show_frame():
115
- if state["frame_rgb"] is not None:
116
- return state["frame_rgb"] # Already in RGB
117
- return None
118
-
119
- # === Fast capture frame (immediate pause + async processing) ===
120
- def capture_frame(caption):
121
- if state["frame"] is None:
122
- return "⚠️ No frame to capture.", None
123
-
124
- # IMMEDIATE pause - this is the key optimization
125
- state["play"] = False
126
-
127
- # Capture current frame immediately
128
- frame_bgr = state["frame"].copy() # Copy to avoid race conditions
129
- frame_rgb = state["frame_rgb"].copy() if state["frame_rgb"] is not None else cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
130
-
131
- # Generate filename and save immediately
132
- filename = f"{uuid.uuid4().hex[:8]}.jpg"
133
- path = os.path.join(OUTPUT_DIR, filename)
134
- cv2.imwrite(path, frame_bgr)
135
-
136
- # Process pose estimation in background (non-blocking)
137
- executor.submit(process_pose_async, frame_bgr, filename, caption)
138
-
139
- return f"✅ Captured & paused: {filename} (processing pose...)", frame_rgb
140
-
141
- # === Show dataset info ===
142
- def show_dataset_info():
143
- return f"📊 Dataset contains {len(df)} samples"
144
-
145
- # === Download CSV ===
146
- def download_csv():
147
- path = os.path.join(OUTPUT_DIR, "pose_dataset.csv")
148
- df.to_csv(path, index=False)
149
- return path
150
-
151
- # === Reset all ===
152
- def reset_all():
153
- global df
154
- df = pd.DataFrame(columns=["filename", "caption", "pose_coords"])
155
-
156
- # Clean up files
157
- try:
158
- for f in os.listdir(OUTPUT_DIR):
159
- file_path = os.path.join(OUTPUT_DIR, f)
160
- if os.path.isfile(file_path):
161
- os.remove(file_path)
162
- except Exception as e:
163
- print(f"Error cleaning files: {e}")
164
-
165
- # Reset video state
166
- if state["cap"]:
167
- state["cap"].release()
168
- state.update({
169
- "video_path": None,
170
- "cap": None,
171
- "frame": None,
172
- "frame_rgb": None,
173
- "play": False
174
- })
175
- return "🔁 Reset done.", None
176
-
177
- # === UI ===
178
- with gr.Blocks(title="Fast Archery Pose Capture") as app:
179
- gr.Markdown("## 🏹 Archery Pose Dataset Tool (Optimized for Speed)")
180
- gr.Markdown("⚡ **Optimized**: Instant capture with background pose processing")
181
-
182
- # Top section - Video loading
183
- video_input = gr.Video(label="🎞️ Upload Video")
184
- load_btn = gr.Button("📂 Load Video", variant="primary")
185
- status = gr.Textbox(label="Status", interactive=False)
186
-
187
- # Main section - Side by side layout
188
- with gr.Row():
189
- # Left column - Video display and controls
190
- with gr.Column(scale=1):
191
- gr.Markdown("### 🎥 Video Player")
192
- with gr.Row():
193
- play_btn = gr.Button("▶️ Play", variant="secondary")
194
- pause_btn = gr.Button("⏸️ Pause", variant="secondary")
195
- show_btn = gr.Button("🖼️ Show Frame", variant="secondary")
196
-
197
- image_output = gr.Image(label="Current Frame", height=400)
198
-
199
- # Right column - Capture controls
200
- with gr.Column(scale=1):
201
- gr.Markdown("### 📸 Capture Controls")
202
- caption_input = gr.Textbox(label="Caption", placeholder="Describe the pose...", lines=2)
203
- capture_btn = gr.Button("📸 Capture & Pause", variant="primary", size="lg")
204
-
205
- gr.Markdown("### 📊 Dataset Management")
206
- with gr.Row():
207
- info_btn = gr.Button("📊 Dataset Info")
208
- download_btn = gr.Button("📥 Download CSV")
209
-
210
- reset_btn = gr.Button("🔄 Reset All", variant="stop")
211
- dataset_info = gr.Textbox(label="Dataset Info", interactive=False, lines=2)
212
-
213
- # Bottom section - File download
214
- csv_file = gr.File(label="📄 Dataset CSV")
215
-
216
- # Bind actions
217
- load_btn.click(load_video, inputs=video_input, outputs=status)
218
- play_btn.click(play_video, outputs=status)
219
- pause_btn.click(pause_video, outputs=status)
220
- show_btn.click(show_frame, outputs=image_output)
221
- capture_btn.click(capture_frame, inputs=caption_input, outputs=[status, image_output])
222
- info_btn.click(show_dataset_info, outputs=dataset_info)
223
- download_btn.click(download_csv, outputs=csv_file)
224
- reset_btn.click(reset_all, outputs=[status, image_output])
225
-
226
- # Auto-refresh frame display while playing
227
- app.load(lambda: None) # Initialize
228
-
229
- if __name__ == "__main__":
230
- app.launch(share=False, server_name="0.0.0.0", server_port=7860)
 
# --- Dependencies (stdlib first, then third-party) ---
import base64
import tempfile

import requests
import cv2
import numpy as np
import gradio as gr
import mediapipe as mp
from openai import OpenAI

# One shared MediaPipe Pose detector for the whole app.
# static_image_mode=True runs full detection on every image rather than
# tracking across frames -- appropriate for single uploaded photos.
mp_pose = mp.solutions.pose
pose = mp_pose.Pose(static_image_mode=True)
# Function to extract pose landmarks
def extract_pose(image):
    """Run MediaPipe Pose on *image* and return its landmarks.

    Args:
        image: H x W x 3 uint8 numpy array. Gradio's ``gr.Image(type="numpy")``
            delivers RGB, which is exactly what MediaPipe expects, so the
            frame is passed through unchanged. (The previous BGR2RGB call
            was swapping an already-RGB frame into BGR.)

    Returns:
        tuple: ``(pose_data, image)`` where ``pose_data`` is a list of
        ``{"id", "x", "y", "z", "visibility"}`` dicts (normalized
        coordinates), or an error string when no pose is detected.
        The input image is always returned unmodified as the second element.
    """
    results = pose.process(image)

    # MediaPipe returns None landmarks when no person is found; callers
    # distinguish failure by checking for a str result (see process()).
    if not results.pose_landmarks:
        return "No pose landmarks found.", image

    pose_data = [
        {
            "id": i,
            "x": lm.x,
            "y": lm.y,
            "z": lm.z,
            "visibility": lm.visibility,
        }
        for i, lm in enumerate(results.pose_landmarks.landmark)
    ]
    return pose_data, image
33
+
# Function to convert image to base64
def image_to_base64(img_np):
    """Encode an RGB numpy image as a base64 JPEG string.

    Gradio supplies frames in RGB channel order while OpenCV's ``imencode``
    assumes BGR, so the channels are swapped before encoding to keep the
    JPEG colors correct for the vision model.

    Args:
        img_np: H x W x 3 uint8 numpy array in RGB order.

    Returns:
        str: base64-encoded JPEG bytes (no data-URI prefix).

    Raises:
        ValueError: if OpenCV fails to encode the image.
    """
    ok, buffer = cv2.imencode('.jpg', cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR))
    if not ok:
        # The original code ignored this flag and would base64-encode garbage.
        raise ValueError("Failed to JPEG-encode image")
    return base64.b64encode(buffer).decode('utf-8')
38
+
# Call Vision LLM
def call_llama_vlm(image, pose_data):
    """Describe the pose in *image* using a vision LLM via OpenRouter.

    The image is sent inline as a base64 data URI together with the raw
    MediaPipe landmark data, and the model's free-text answer is returned.

    Args:
        image: H x W x 3 uint8 numpy array in RGB order (forwarded to
            ``image_to_base64``, which handles the BGR swap for OpenCV).
        pose_data: list of landmark dicts from ``extract_pose``.

    Returns:
        str: the model's description of the pose.
    """
    import os  # local import so this edit is self-contained

    img_base64 = image_to_base64(image)

    # Never hard-code secrets: the key comes from the environment.
    # (The previous version shipped a literal "<OPENROUTER_API_KEY>" placeholder.)
    client = OpenAI(
        base_url="https://openrouter.ai/api/v1",
        api_key=os.environ.get("OPENROUTER_API_KEY", ""),
    )

    # Optional OpenRouter attribution headers -- only sent when configured,
    # instead of leaking "<YOUR_SITE_URL>" placeholders to the API.
    extra_headers = {}
    if os.environ.get("OPENROUTER_SITE_URL"):
        extra_headers["HTTP-Referer"] = os.environ["OPENROUTER_SITE_URL"]
    if os.environ.get("OPENROUTER_SITE_NAME"):
        extra_headers["X-Title"] = os.environ["OPENROUTER_SITE_NAME"]

    completion = client.chat.completions.create(
        extra_headers=extra_headers,
        model="meta-llama/llama-3.2-11b-vision-instruct:free",
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": f"What is this pose doing? Pose data: {pose_data}",
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{img_base64}"
                        },
                    },
                ],
            }
        ],
    )

    return completion.choices[0].message.content
76
+
# Gradio Interface
def process(image):
    """Gradio handler: detect a pose in *image* and describe it with the VLM.

    Returns either the model's description or, when no pose is found,
    the error message produced by extract_pose.
    """
    landmarks, frame = extract_pose(image)
    # extract_pose signals failure by returning an error string in place
    # of the landmark list -- pass that message straight to the UI.
    if isinstance(landmarks, str):
        return landmarks
    return call_llama_vlm(frame, landmarks)
85
+
# Build and launch the single-page Gradio UI: one image in, one text answer out.
interface = gr.Interface(
    fn=process,
    inputs=gr.Image(type="numpy", label="Upload Pose Image"),
    outputs="text",
    title="Pose Analysis with MediaPipe and Vision LLM",
)

interface.launch()