aiqtech committed on
Commit
94ded05
·
verified ·
1 Parent(s): 8d84d68

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +533 -205
app.py CHANGED
@@ -1,205 +1,533 @@
1
- import gradio as gr
2
- import numpy as np
3
- import cv2
4
- from fastapi import FastAPI, Request, Response
5
- from src.body import Body
6
- import json as js
7
-
8
- body_estimation = Body('model/body_pose_model.pth')
9
-
10
- def pil2cv(image):
11
- ''' PIL型 -> OpenCV型 '''
12
- new_image = np.array(image, dtype=np.uint8)
13
- if new_image.ndim == 2: # モノクロ
14
- pass
15
- elif new_image.shape[2] == 3: # カラー
16
- new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
17
- elif new_image.shape[2] == 4: # 透過
18
- new_image = cv2.cvtColor(new_image, cv2.COLOR_RGBA2BGRA)
19
- return new_image
20
-
21
- with open("static/poseEditor.js", "r") as f:
22
- file_contents = f.read()
23
-
24
- app = FastAPI()
25
-
26
- @app.middleware("http")
27
- async def some_fastapi_middleware(request: Request, call_next):
28
- path = request.scope['path'] # get the request route
29
- response = await call_next(request)
30
-
31
- if path == "/":
32
- response_body = ""
33
- async for chunk in response.body_iterator:
34
- response_body += chunk.decode()
35
-
36
- some_javascript = f"""
37
- <script type="text/javascript" defer>
38
- {file_contents}
39
- </script>
40
- """
41
-
42
- response_body = response_body.replace("</body>", some_javascript + "</body>")
43
-
44
- del response.headers["content-length"]
45
-
46
- return Response(
47
- content=response_body,
48
- status_code=response.status_code,
49
- headers=dict(response.headers),
50
- media_type=response.media_type
51
- )
52
-
53
- return response
54
-
55
- # make cndidate to json
56
- def candidate_to_json_string(arr):
57
- a = [f'[{x:.2f}, {y:.2f}]' for x, y, *_ in arr]
58
- return '[' + ', '.join(a) + ']'
59
-
60
- # make subset to json
61
- def subset_to_json_string(arr):
62
- arr_str = ','.join(['[' + ','.join([f'{num:.2f}' for num in row]) + ']' for row in arr])
63
- return '[' + arr_str + ']'
64
-
65
- def estimate_body(source):
66
- if source == None:
67
- return None
68
-
69
- candidate, subset = body_estimation(pil2cv(source))
70
- return "{ \"candidate\": " + candidate_to_json_string(candidate) + ", \"subset\": " + subset_to_json_string(subset) + " }"
71
-
72
- def image_changed(image):
73
- if image == None:
74
- return "estimation", {}
75
-
76
- if 'openpose' in image.info:
77
- print("pose found")
78
- jsonText = image.info['openpose']
79
- jsonObj = js.loads(jsonText)
80
- subset = jsonObj['subset']
81
- return f"""{image.width}px x {image.height}px, {len(subset)} indivisual(s)""", jsonText
82
- else:
83
- print("pose not found")
84
- candidate, subset = body_estimation(pil2cv(image))
85
- jsonText = "{ \"candidate\": " + candidate_to_json_string(candidate) + ", \"subset\": " + subset_to_json_string(subset) + " }"
86
- return f"""{image.width}px x {image.height}px, {subset.shape[0]} indivisual(s)""", jsonText
87
-
88
- html_text = f"""
89
- <canvas id="canvas" width="512" height="512"></canvas>
90
- <script type="text/javascript" defer>{file_contents}</script>
91
- """
92
-
93
- with gr.Blocks(css="""button { min-width: 80px; }""") as demo:
94
- gr.Markdown(f"""
95
- ## This project is no longer being updated. Please use [PoseMaker2](https://huggingface.co/spaces/jonigata/PoseMaker2) instead.
96
- ### (That project uses MMPose for pose estimation.)
97
- """)
98
- with gr.Row():
99
- with gr.Column(scale=1):
100
- width = gr.Slider(label="Width", minimum=512, maximum=1024, step=64, value=512, interactive=True)
101
- height = gr.Slider(label="Height", minimum=512, maximum=1024, step=64, value=512, interactive=True)
102
- with gr.Accordion(label="Pose estimation", open=False):
103
- source = gr.Image(type="pil")
104
- estimationResult = gr.Markdown("""estimation""")
105
- with gr.Row():
106
- with gr.Column(min_width=80):
107
- applySizeBtn = gr.Button(value="Apply size")
108
- with gr.Column(min_width=80):
109
- replaceBtn = gr.Button(value="Replace")
110
- with gr.Column(min_width=80):
111
- importBtn = gr.Button(value="Import")
112
- with gr.Accordion(label="Json", open=False):
113
- with gr.Row():
114
- with gr.Column(min_width=80):
115
- replaceWithJsonBtn = gr.Button(value="Replace")
116
- with gr.Column(min_width=80):
117
- importJsonBtn = gr.Button(value="Import")
118
- gr.Markdown("""
119
- | inout | how to |
120
- | -----------------| ----------------------------------------------------------------------------------------- |
121
- | Import | Paste json to "Json source" and click "Read", edit the width/height, then click "Replace" or "Import". |
122
- | Export | click "Save" and "Copy to clipboard" of "Json" section. |
123
- """)
124
- json = gr.JSON(label="Json")
125
- jsonSource = gr.Textbox(label="Json source", lines=10)
126
- with gr.Accordion(label="Notes", open=False):
127
- gr.Markdown("""
128
- #### How to bring pose to ControlNet
129
- 1. Press **Save** button
130
- 2. **Drag** the file placed at the bottom left corder of browser
131
- 3. **Drop** the file into ControlNet
132
-
133
- #### Points to note for pseudo-3D rotation
134
- When performing pseudo-3D rotation on the X and Y axes, the projection is converted to 2D and Z-axis information is lost when the mouse button is released. This means that if you finish dragging while the shape is collapsed, you may not be able to restore it to its original state. In such a case, please use the "undo" function.
135
-
136
- #### Reuse pose image
137
- Pose image generated by this tool has pose data in the image itself. You can reuse pose information by loading it as the image source instead of a regular image.
138
- """)
139
- with gr.Column(scale=2):
140
- html = gr.HTML(html_text)
141
- with gr.Row():
142
- with gr.Column(scale=1, min_width=60):
143
- saveBtn = gr.Button(value="Save")
144
- with gr.Column(scale=7):
145
- gr.Markdown("""
146
- - "ctrl + drag" to **scale**
147
- - "alt + drag" to **move**
148
- - "shift + drag" to **rotate** (move right first, release shift, then up or down)
149
- - "space + drag" to **range-move**
150
- - "[", "]" or "Alt + wheel" or "Space + wheel" to shrink or expand **range**
151
- - "ctrl + Z", "shift + ctrl + Z" to **undo**, **redo**
152
- - "ctrl + E" **add** new person
153
- - "D + click" to **delete** person
154
- - "Q + click" to **cut off** limb
155
- - "X + drag" to **x-axis** pseudo-3D rotation
156
- - "C + drag" to **y-axis** pseudo-3D rotation
157
- - "R + click" to **repair**
158
-
159
- When using Q, X, C, R, pressing and dont release until the operation is complete.
160
-
161
- [Contact us for feature requests or bug reports (anonymous)](https://t.co/UC3jJOJJtS)
162
- """)
163
-
164
- width.change(fn=None, inputs=[width], _js="(w) => { resizeCanvas(w,null); }")
165
- height.change(fn=None, inputs=[height], _js="(h) => { resizeCanvas(null,h); }")
166
-
167
- source.change(
168
- fn = image_changed,
169
- inputs = [source],
170
- outputs = [estimationResult, json])
171
- applySizeBtn.click(
172
- fn = lambda x: (x.width, x.height),
173
- inputs = [source],
174
- outputs = [width, height])
175
- replaceBtn.click(
176
- fn = None,
177
- inputs = [json],
178
- outputs = [],
179
- _js="(json) => { initializeEditor(); importPose(json); return []; }")
180
- importBtn.click(
181
- fn = None,
182
- inputs = [json],
183
- outputs = [],
184
- _js="(json) => { importPose(json); return []; }")
185
-
186
- saveBtn.click(
187
- fn = None,
188
- inputs = [], outputs = [json],
189
- _js="() => { return [savePose()]; }")
190
- jsonSource.change(
191
- fn = lambda x: x,
192
- inputs = [jsonSource], outputs = [json])
193
- replaceWithJsonBtn.click(
194
- fn = None,
195
- inputs = [json],
196
- outputs = [],
197
- _js="(json) => { initializeEditor(); importPose(json); return []; }")
198
- importJsonBtn.click(
199
- fn = None,
200
- inputs = [json],
201
- outputs = [],
202
- _js="(json) => { importPose(json); return []; }")
203
- demo.load(fn=None, inputs=[], outputs=[], _js="() => { initializeEditor(); importPose(); return []; }")
204
-
205
- gr.mount_gradio_app(app, demo, path="/")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import cv2
4
+ from fastapi import FastAPI, Request, Response
5
+ from src.body import Body
6
+ import json as js
7
+ import requests
8
+ import os
9
+ from typing import Dict, List, Tuple
10
+ import asyncio
11
+ import aiohttp
12
+
13
+ # Initialize body estimation model
14
+ body_estimation = Body('model/body_pose_model.pth')
15
+
16
+ # Fireworks AI configuration
17
+ FIREWORKS_API_KEY = os.getenv("FIREWORKS_API_KEY", "YOUR_API_KEY_HERE")
18
+ FIREWORKS_API_URL = "https://api.fireworks.ai/inference/v1/chat/completions"
19
+
20
+ # OpenPose keypoint definitions
21
+ BODY_PARTS = {
22
+ "Nose": 0, "Neck": 1, "RShoulder": 2, "RElbow": 3, "RWrist": 4,
23
+ "LShoulder": 5, "LElbow": 6, "LWrist": 7, "RHip": 8, "RKnee": 9,
24
+ "RAnkle": 10, "LHip": 11, "LKnee": 12, "LAnkle": 13, "REye": 14,
25
+ "LEye": 15, "REar": 16, "LEar": 17
26
+ }
27
+
28
+ # Pose templates for common positions
29
+ POSE_TEMPLATES = {
30
+ "standing": {
31
+ "keypoints": {
32
+ "Neck": [256, 120],
33
+ "RShoulder": [220, 140], "RElbow": [200, 200], "RWrist": [190, 260],
34
+ "LShoulder": [292, 140], "LElbow": [312, 200], "LWrist": [322, 260],
35
+ "RHip": [230, 280], "RKnee": [225, 380], "RAnkle": [220, 480],
36
+ "LHip": [282, 280], "LKnee": [287, 380], "LAnkle": [292, 480]
37
+ }
38
+ },
39
+ "sitting": {
40
+ "keypoints": {
41
+ "Neck": [256, 180],
42
+ "RShoulder": [220, 200], "RElbow": [200, 260], "RWrist": [190, 320],
43
+ "LShoulder": [292, 200], "LElbow": [312, 260], "LWrist": [322, 320],
44
+ "RHip": [230, 340], "RKnee": [225, 400], "RAnkle": [280, 420],
45
+ "LHip": [282, 340], "LKnee": [287, 400], "LAnkle": [232, 420]
46
+ }
47
+ },
48
+ "running": {
49
+ "keypoints": {
50
+ "Neck": [256, 120],
51
+ "RShoulder": [220, 140], "RElbow": [180, 180], "RWrist": [150, 220],
52
+ "LShoulder": [292, 140], "LElbow": [332, 180], "LWrist": [362, 140],
53
+ "RHip": [230, 280], "RKnee": [260, 380], "RAnkle": [290, 470],
54
+ "LHip": [282, 280], "LKnee": [252, 360], "LAnkle": [222, 440]
55
+ }
56
+ }
57
+ }
58
+
59
def pil2cv(image):
    """Convert a PIL-style image into an OpenCV-style ``uint8`` array.

    Two-dimensional (monochrome) data is returned as-is, three-channel data
    is reordered RGB -> BGR, and four-channel data RGBA -> BGRA. Any other
    channel count is passed through without conversion.
    """
    pixels = np.array(image, dtype=np.uint8)

    if pixels.ndim == 2:
        # Monochrome: there are no colour channels to reorder.
        return pixels

    channel_count = pixels.shape[2]
    if channel_count == 3:
        # Colour
        return cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR)
    if channel_count == 4:
        # Colour with transparency
        return cv2.cvtColor(pixels, cv2.COLOR_RGBA2BGRA)
    return pixels
69
+
70
async def generate_pose_from_llm(prompt: str) -> Dict:
    """Generate pose data from a text prompt using the Fireworks chat API.

    The model is asked for an OpenPose-style JSON object. The first
    ``{...}`` span found in the reply is parsed and returned when it
    contains both ``candidate`` and ``subset``.

    Any failure (non-200 status, timeout, network error, unparsable or
    incomplete reply) falls back to ``generate_template_pose(prompt)``, so
    this coroutine always returns a dict with ``candidate`` and ``subset``.

    Args:
        prompt: Free-text description of the desired pose.

    Returns:
        A dict with ``candidate`` and ``subset`` entries.
    """
    import re

    system_prompt = """You are an expert in human pose generation. Given a description, generate precise OpenPose keypoint coordinates.

Rules:
1. Canvas size is 512x512 pixels
2. Return JSON with 18 keypoints (0-17)
3. Each keypoint has [x, y, confidence] where confidence is always 1.0
4. Maintain anatomically correct proportions
5. Center the pose in the canvas

Keypoint indices:
0: Nose, 1: Neck, 2: Right Shoulder, 3: Right Elbow, 4: Right Wrist,
5: Left Shoulder, 6: Left Elbow, 7: Left Wrist, 8: Right Hip, 9: Right Knee,
10: Right Ankle, 11: Left Hip, 12: Left Knee, 13: Left Ankle, 14: Right Eye,
15: Left Eye, 16: Right Ear, 17: Left Ear

Return ONLY valid JSON in this format:
{
"candidate": [[x, y, confidence], ...],
"subset": [[indices of connected keypoints, score, number of keypoints]]
}"""

    headers = {
        "Accept": "application/json",
        "Content-Type": "application/json",
        "Authorization": f"Bearer {FIREWORKS_API_KEY}"
    }

    payload = {
        "model": "accounts/fireworks/models/qwen3-235b-a22b-instruct-2507",
        "max_tokens": 2048,
        "temperature": 0.3,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"Generate OpenPose keypoints for: {prompt}"}
        ]
    }

    try:
        # Bound the whole request so a stalled API call cannot hang the UI.
        timeout = aiohttp.ClientTimeout(total=60)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(FIREWORKS_API_URL, headers=headers, json=payload) as response:
                if response.status != 200:
                    print(f"LLM Error: HTTP {response.status}")
                    return generate_template_pose(prompt)
                data = await response.json()

        content = data['choices'][0]['message']['content']

        # The model may wrap the JSON in prose or code fences; take the
        # outermost brace-delimited span.
        json_match = re.search(r'\{.*\}', content, re.DOTALL)
        if json_match:
            pose_data = js.loads(json_match.group())
            # Only hand back replies that carry both keys callers index into.
            if isinstance(pose_data, dict) and 'candidate' in pose_data and 'subset' in pose_data:
                return pose_data
            print("LLM Error: response JSON lacks 'candidate'/'subset'")
    except Exception as e:
        # Best-effort boundary: any failure degrades to the template pose.
        print(f"LLM Error: {e}")

    return generate_template_pose(prompt)
132
+
133
def generate_template_pose(prompt: str) -> Dict:
    """Build a pose from a built-in template (fallback when the LLM is unavailable).

    The template is chosen by keyword match on the lower-cased prompt:
    sitting words select "sitting", running words select "running", and
    anything else selects "standing".

    Args:
        prompt: Free-text description of the desired pose.

    Returns:
        A dict with:
          * ``candidate``: 18 ``[x, y, confidence]`` entries, one per
            ``BODY_PARTS`` index. Parts missing from the template get the
            placeholder ``[256, 256, 0.0]``.
          * ``subset``: a single row of 20 values. Slot ``i`` (0-17) holds
            the candidate index for body part ``i`` or ``-1`` when absent,
            followed by the summed confidence and the count of present parts.
    """
    prompt_lower = prompt.lower()

    # Detect pose type from prompt
    if any(word in prompt_lower for word in ("sit", "sitting", "seated", "chair")):
        template = POSE_TEMPLATES["sitting"]
    elif any(word in prompt_lower for word in ("run", "running", "jog", "sprint")):
        template = POSE_TEMPLATES["running"]
    else:
        template = POSE_TEMPLATES["standing"]

    keypoints = template["keypoints"]
    # Invert the name -> index table once instead of rescanning it per index.
    names_by_index = {index: name for name, index in BODY_PARTS.items()}

    # Convert template to OpenPose format
    candidate = []
    for i in range(18):
        part_name = names_by_index.get(i)
        if i == 0:
            # The templates do not define a nose; use a fixed head position.
            candidate.append([256, 100, 1.0])
        elif part_name is None:
            candidate.append([0, 0, 0.0])
        elif part_name in keypoints:
            x, y = keypoints[part_name]
            candidate.append([x, y, 1.0])
        else:
            # Not in the template: placeholder with zero confidence.
            candidate.append([256, 256, 0.0])

    # Keep one slot per body part so positions stay aligned with part
    # indices; absent parts are marked -1 rather than being dropped.
    part_slots = [i if candidate[i][2] > 0 else -1 for i in range(18)]
    present_count = sum(1 for slot in part_slots if slot >= 0)
    subset = [part_slots + [float(present_count), present_count]]

    return {"candidate": candidate, "subset": subset}
166
+
167
def refine_pose_with_llm(current_pose: Dict, refinement_prompt: str) -> Dict:
    """Ask the LLM to adjust an existing pose according to instructions.

    This is a blocking call (it uses ``requests``). On any failure
    (non-200 status, timeout, network error, unparsable or incomplete reply)
    the original ``current_pose`` is returned unchanged.

    Args:
        current_pose: The pose data to refine; serialised into the prompt.
        refinement_prompt: Free-text adjustment instructions.

    Returns:
        The refined pose dict, or ``current_pose`` when refinement fails.
    """
    import re

    system_prompt = """You are an expert in pose refinement. Given current pose data and adjustment instructions,
modify the keypoints precisely while maintaining anatomical correctness.

Return the modified pose in the same JSON format."""

    headers = {
        "Accept": "application/json",
        "Content-Type": "application/json",
        "Authorization": f"Bearer {FIREWORKS_API_KEY}"
    }

    payload = {
        "model": "accounts/fireworks/models/qwen3-235b-a22b-instruct-2507",
        "max_tokens": 2048,
        "temperature": 0.2,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"Current pose: {js.dumps(current_pose)}\nAdjustment needed: {refinement_prompt}"}
        ]
    }

    try:
        # requests has no default timeout; without one a stalled connection
        # would block the caller indefinitely.
        response = requests.post(FIREWORKS_API_URL, headers=headers, json=payload, timeout=60)
        if response.status_code == 200:
            data = response.json()
            content = data['choices'][0]['message']['content']

            # Take the outermost brace-delimited span from the reply.
            json_match = re.search(r'\{.*\}', content, re.DOTALL)
            if json_match:
                refined = js.loads(json_match.group())
                # Only accept replies that still look like pose data.
                if isinstance(refined, dict) and 'candidate' in refined and 'subset' in refined:
                    return refined
                print("Refinement error: response JSON lacks 'candidate'/'subset'")
        else:
            print(f"Refinement error: HTTP {response.status_code}")
    except Exception as e:
        # Best-effort boundary: keep the current pose on any failure.
        print(f"Refinement error: {e}")

    return current_pose
206
+
207
+ # FastAPI setup
208
+ with open("static/poseEditor.js", "r") as f:
209
+ file_contents = f.read()
210
+
211
+ app = FastAPI()
212
+
213
@app.middleware("http")
async def some_fastapi_middleware(request: Request, call_next):
    """Inject the pose-editor script into the page served at ``/``.

    Responses for every other path pass through untouched. For ``/`` the
    body is read in full, the contents of ``static/poseEditor.js`` are
    inserted in a ``<script>`` tag just before ``</body>``, and a new
    ``Response`` is returned with the original status, headers and media type.
    """
    path = request.scope['path']
    response = await call_next(request)

    if path != "/":
        return response

    # Join the raw bytes first and decode once: decoding chunk by chunk
    # fails when a multi-byte UTF-8 character straddles a chunk boundary.
    chunks = [chunk async for chunk in response.body_iterator]
    response_body = b"".join(chunks).decode()

    some_javascript = f"""
<script type="text/javascript" defer>
{file_contents}
</script>
"""

    response_body = response_body.replace("</body>", some_javascript + "</body>")
    # The body length changed; drop the stale header so it is recomputed.
    del response.headers["content-length"]

    return Response(
        content=response_body,
        status_code=response.status_code,
        headers=dict(response.headers),
        media_type=response.media_type
    )
240
+
241
def candidate_to_json_string(arr):
    """Serialise keypoints as a JSON array of ``[x, y]`` pairs.

    Only the first two values of each entry are emitted, each formatted with
    two decimal places; any further values (e.g. confidence) are dropped.
    """
    def _format_point(point):
        x, y, *_ = point
        return '[' + format(x, '.2f') + ', ' + format(y, '.2f') + ']'

    return '[' + ', '.join(map(_format_point, arr)) + ']'
244
+
245
def subset_to_json_string(arr):
    """Serialise a 2-D sequence of numbers as a JSON array of arrays.

    Every number is formatted with two decimal places; rows and cells are
    separated by a bare comma.
    """
    rows = []
    for row in arr:
        cells = [format(num, '.2f') for num in row]
        rows.append('[' + ','.join(cells) + ']')
    return '[' + ','.join(rows) + ']'
248
+
249
def estimate_body(source):
    """Run pose estimation on an image and return the result as JSON text.

    Args:
        source: The image to analyse, or ``None``.

    Returns:
        ``None`` when ``source`` is ``None``; otherwise a JSON string with
        ``candidate`` and ``subset`` entries.
    """
    # Identity check: ``== None`` would dispatch to the image type's own
    # ``__eq__``, which is not guaranteed to yield a plain bool.
    if source is None:
        return None

    candidate, subset = body_estimation(pil2cv(source))
    return "{ \"candidate\": " + candidate_to_json_string(candidate) + ", \"subset\": " + subset_to_json_string(subset) + " }"
255
+
256
def image_changed(image):
    """Produce a status line and pose JSON for a newly selected image.

    If the image carries an ``openpose`` entry in ``image.info``, that
    embedded JSON is reused. Otherwise pose estimation is run on the image.

    Args:
        image: The selected image (must expose ``info``, ``width`` and
            ``height``), or ``None`` when the input was cleared.

    Returns:
        A ``(status_text, json_text)`` tuple. For ``None`` input this is
        ``("estimation", {})``.
    """
    # Identity check: ``== None`` would dispatch to the image type's own
    # ``__eq__``, which is not guaranteed to yield a plain bool.
    if image is None:
        return "estimation", {}

    if 'openpose' in image.info:
        print("pose found")
        jsonText = image.info['openpose']
        subset = js.loads(jsonText)['subset']
    else:
        print("pose not found")
        candidate, subset = body_estimation(pil2cv(image))
        jsonText = "{ \"candidate\": " + candidate_to_json_string(candidate) + ", \"subset\": " + subset_to_json_string(subset) + " }"

    # One subset row per detected person; len() works for lists and arrays.
    return f"""{image.width}px x {image.height}px, {len(subset)} individual(s)""", jsonText
271
+
272
async def generate_pose_from_text(prompt: str, use_llm: bool = True):
    """Create pose JSON text for the editor from a text prompt.

    The LLM is used only when ``use_llm`` is true and an API key other than
    the placeholder is configured; otherwise the template generator is used.
    List-valued ``candidate`` / ``subset`` entries go through the project's
    fixed-precision formatters, anything else through ``js.dumps``.
    """
    llm_enabled = use_llm and FIREWORKS_API_KEY != "YOUR_API_KEY_HERE"
    if llm_enabled:
        pose_data = await generate_pose_from_llm(prompt)
    else:
        pose_data = generate_template_pose(prompt)

    # Format for the pose editor
    candidate = pose_data['candidate']
    candidate_str = (
        candidate_to_json_string(candidate)
        if isinstance(candidate, list)
        else js.dumps(candidate)
    )

    subset = pose_data['subset']
    subset_str = (
        subset_to_json_string(subset)
        if isinstance(subset, list)
        else js.dumps(subset)
    )

    return '{ "candidate": ' + candidate_str + ', "subset": ' + subset_str + ' }'
293
+
294
+ html_text = f"""
295
+ <canvas id="canvas" width="512" height="512"></canvas>
296
+ <script type="text/javascript" defer>{file_contents}</script>
297
+ """
298
+
299
+ # Gradio interface
300
+ with gr.Blocks(css="""
301
+ button { min-width: 80px; }
302
+ .prompt-box { border: 2px solid #667eea; border-radius: 8px; padding: 10px; }
303
+ .llm-status { color: #667eea; font-weight: bold; }
304
+ """) as demo:
305
+
306
+ gr.Markdown("""
307
+ # 🎨 AI-Powered Pose Generator with LLM
308
+ ### Generate precise line art poses from text descriptions using advanced AI
309
+ """)
310
+
311
+ with gr.Row():
312
+ with gr.Column(scale=1):
313
+ width = gr.Slider(label="Width", minimum=512, maximum=1024, step=64, value=512, interactive=True)
314
+ height = gr.Slider(label="Height", minimum=512, maximum=1024, step=64, value=512, interactive=True)
315
+
316
+ # LLM Pose Generation Section
317
+ with gr.Accordion(label="🤖 AI Pose Generation", open=True):
318
+ prompt_input = gr.Textbox(
319
+ label="Describe the pose",
320
+ placeholder="e.g., 'A person sitting cross-legged in meditation pose' or 'Someone running with arms pumping'",
321
+ lines=3,
322
+ elem_classes=["prompt-box"]
323
+ )
324
+
325
+ with gr.Row():
326
+ use_llm_checkbox = gr.Checkbox(label="Use Advanced LLM", value=True)
327
+ llm_status = gr.Markdown("", elem_classes=["llm-status"])
328
+
329
+ with gr.Row():
330
+ generate_btn = gr.Button("🎯 Generate Pose", variant="primary")
331
+ refine_btn = gr.Button("✨ Refine Current", variant="secondary")
332
+
333
+ refinement_prompt = gr.Textbox(
334
+ label="Refinement instructions",
335
+ placeholder="e.g., 'Raise the left arm higher' or 'Bend the knees more'",
336
+ lines=2,
337
+ visible=False
338
+ )
339
+
340
+ gr.Examples(
341
+ examples=[
342
+ "A person standing with arms raised in victory",
343
+ "Someone sitting at a desk typing on a keyboard",
344
+ "A dancer in arabesque position with one leg extended",
345
+ "A person doing a yoga warrior pose",
346
+ "Someone crouching in a ready position",
347
+ "A person walking casually with relaxed posture"
348
+ ],
349
+ inputs=prompt_input
350
+ )
351
+
352
+ with gr.Accordion(label="📸 Pose Estimation from Image", open=False):
353
+ source = gr.Image(type="pil")
354
+ estimationResult = gr.Markdown("""estimation""")
355
+ with gr.Row():
356
+ with gr.Column(min_width=80):
357
+ applySizeBtn = gr.Button(value="Apply size")
358
+ with gr.Column(min_width=80):
359
+ replaceBtn = gr.Button(value="Replace")
360
+ with gr.Column(min_width=80):
361
+ importBtn = gr.Button(value="Import")
362
+
363
+ with gr.Accordion(label="📋 Json Data", open=False):
364
+ with gr.Row():
365
+ with gr.Column(min_width=80):
366
+ replaceWithJsonBtn = gr.Button(value="Replace")
367
+ with gr.Column(min_width=80):
368
+ importJsonBtn = gr.Button(value="Import")
369
+ gr.Markdown("""
370
+ | Action | Instructions |
371
+ |----------|-------------|
372
+ | Import | Paste JSON and click "Replace" or "Import" |
373
+ | Export | Click "Save" to get pose data |
374
+ """)
375
+ json = gr.JSON(label="Json")
376
+ jsonSource = gr.Textbox(label="Json source", lines=10)
377
+
378
+ with gr.Accordion(label="📝 Notes & Controls", open=False):
379
+ gr.Markdown("""
380
+ #### Keyboard Controls
381
+ - **Ctrl + Drag**: Scale
382
+ - **Alt + Drag**: Move
383
+ - **Shift + Drag**: Rotate
384
+ - **Space + Drag**: Range move
385
+ - **Ctrl + Z/Shift + Ctrl + Z**: Undo/Redo
386
+ - **Ctrl + E**: Add person
387
+ - **D + Click**: Delete person
388
+ - **Q + Click**: Cut off limb
389
+ - **X/C + Drag**: 3D rotation
390
+ - **R + Click**: Repair
391
+
392
+ #### LLM Features
393
+ - Generate complex poses from natural language
394
+ - Refine existing poses with specific instructions
395
+ - Anatomically accurate keypoint generation
396
+ """)
397
+
398
+ with gr.Column(scale=2):
399
+ html = gr.HTML(html_text)
400
+ with gr.Row():
401
+ with gr.Column(scale=1, min_width=60):
402
+ saveBtn = gr.Button(value="💾 Save")
403
+ with gr.Column(scale=7):
404
+ generation_status = gr.Markdown("Ready to generate poses...")
405
+
406
+ # Event handlers
407
+ width.change(fn=None, inputs=[width], _js="(w) => { resizeCanvas(w,null); }")
408
+ height.change(fn=None, inputs=[height], _js="(h) => { resizeCanvas(null,h); }")
409
+
410
+ source.change(
411
+ fn=image_changed,
412
+ inputs=[source],
413
+ outputs=[estimationResult, json]
414
+ )
415
+
416
+ applySizeBtn.click(
417
+ fn=lambda x: (x.width, x.height),
418
+ inputs=[source],
419
+ outputs=[width, height]
420
+ )
421
+
422
+ replaceBtn.click(
423
+ fn=None,
424
+ inputs=[json],
425
+ outputs=[],
426
+ _js="(json) => { initializeEditor(); importPose(json); return []; }"
427
+ )
428
+
429
+ importBtn.click(
430
+ fn=None,
431
+ inputs=[json],
432
+ outputs=[],
433
+ _js="(json) => { importPose(json); return []; }"
434
+ )
435
+
436
+ # LLM generation events
437
async def handle_generate(prompt, use_llm):
    """Stream UI updates while generating a pose from ``prompt``.

    This is an async generator. Each item is a ``(json_value, status_text)``
    pair for the JSON component and the status line.

    Args:
        prompt: The pose description entered by the user.
        use_llm: Whether the LLM path was requested.
    """
    if not prompt:
        # An async generator may not ``return`` a value (that is a
        # SyntaxError), so emit the warning as an item and then stop.
        yield None, "⚠️ Please enter a pose description"
        return

    try:
        status = "🔄 Generating pose with AI..." if use_llm else "🔄 Using template..."
        yield None, status

        pose_json = await generate_pose_from_text(prompt, use_llm)
        yield pose_json, "✅ Pose generated successfully!"

    except Exception as e:
        # UI boundary: report the failure to the user instead of raising.
        yield None, f"❌ Error: {str(e)}"
450
+
451
+ generate_btn.click(
452
+ fn=handle_generate,
453
+ inputs=[prompt_input, use_llm_checkbox],
454
+ outputs=[json, generation_status]
455
+ ).then(
456
+ fn=None,
457
+ inputs=[json],
458
+ outputs=[],
459
+ _js="(json) => { if(json) { initializeEditor(); importPose(json); } return []; }"
460
+ )
461
+
462
def toggle_refinement():
    """Return a component update that makes the refinement textbox visible."""
    show_update = gr.update(visible=True)
    return show_update
464
+
465
+ refine_btn.click(
466
+ fn=toggle_refinement,
467
+ outputs=[refinement_prompt]
468
+ )
469
+
470
async def handle_refine(current_json, refinement):
    """Refine the current pose according to the user's instructions.

    Args:
        current_json: The pose currently held by the JSON component.
        refinement: Free-text adjustment instructions.

    Returns:
        A ``(json_value, status_text)`` pair. When either input is missing
        the JSON value is ``None`` and the status is a warning; on error the
        current pose is returned together with the error message.
    """
    if not current_json or not refinement:
        return None, "⚠️ Need current pose and refinement instructions"

    try:
        # refine_pose_with_llm performs blocking HTTP I/O; run it in the
        # default executor so the event loop keeps serving other requests.
        loop = asyncio.get_running_loop()
        refined = await loop.run_in_executor(
            None, refine_pose_with_llm, current_json, refinement
        )
        return refined, "✅ Pose refined!"
    except Exception as e:
        # UI boundary: report the failure to the user instead of raising.
        return current_json, f"❌ Refinement error: {str(e)}"
479
+
480
+ refinement_prompt.submit(
481
+ fn=handle_refine,
482
+ inputs=[json, refinement_prompt],
483
+ outputs=[json, generation_status]
484
+ ).then(
485
+ fn=None,
486
+ inputs=[json],
487
+ outputs=[],
488
+ _js="(json) => { if(json) { importPose(json); } return []; }"
489
+ )
490
+
491
+ saveBtn.click(
492
+ fn=None,
493
+ inputs=[],
494
+ outputs=[json],
495
+ _js="() => { return [savePose()]; }"
496
+ )
497
+
498
+ jsonSource.change(
499
+ fn=lambda x: x,
500
+ inputs=[jsonSource],
501
+ outputs=[json]
502
+ )
503
+
504
+ replaceWithJsonBtn.click(
505
+ fn=None,
506
+ inputs=[json],
507
+ outputs=[],
508
+ _js="(json) => { initializeEditor(); importPose(json); return []; }"
509
+ )
510
+
511
+ importJsonBtn.click(
512
+ fn=None,
513
+ inputs=[json],
514
+ outputs=[],
515
+ _js="(json) => { importPose(json); return []; }"
516
+ )
517
+
518
+ demo.load(
519
+ fn=None,
520
+ inputs=[],
521
+ outputs=[],
522
+ _js="() => { initializeEditor(); importPose(); return []; }"
523
+ )
524
+
525
+ # Check API key status on load
526
def check_api_status():
    """Return the status text describing whether an LLM API key is configured."""
    key_missing = FIREWORKS_API_KEY == "YOUR_API_KEY_HERE"
    return (
        "⚠️ LLM API key not configured - using templates"
        if key_missing
        else "✅ LLM ready"
    )
530
+
531
+ demo.load(fn=check_api_status, outputs=[llm_status])
532
+
533
+ gr.mount_gradio_app(app, demo, path="/")