aiqtech committed on
Commit
f15b9a3
·
verified ·
1 Parent(s): d9e7eed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +124 -261
app.py CHANGED
@@ -1,11 +1,8 @@
1
  import gradio as gr
2
- import numpy as np
3
  import json
4
  import requests
5
  import os
6
- from typing import Dict, List, Tuple
7
  from PIL import Image, ImageDraw
8
- import io
9
 
10
  # Fireworks AI configuration
11
  FIREWORKS_API_KEY = os.getenv("FIREWORKS_API_KEY", "YOUR_API_KEY_HERE")
@@ -19,7 +16,7 @@ BODY_PARTS = {
19
  "LEye": 15, "REar": 16, "LEar": 17
20
  }
21
 
22
- # Skeleton connections for drawing
23
  POSE_CONNECTIONS = [
24
  ("Neck", "RShoulder"), ("RShoulder", "RElbow"), ("RElbow", "RWrist"),
25
  ("Neck", "LShoulder"), ("LShoulder", "LElbow"), ("LElbow", "LWrist"),
@@ -82,7 +79,7 @@ POSE_TEMPLATES = {
82
  }
83
  }
84
 
85
- def draw_pose(keypoints: Dict, width: int = 512, height: int = 512) -> Image.Image:
86
  """Draw pose skeleton on image"""
87
  img = Image.new('RGB', (width, height), color='white')
88
  draw = ImageDraw.Draw(img)
@@ -101,12 +98,10 @@ def draw_pose(keypoints: Dict, width: int = 512, height: int = 512) -> Image.Ima
101
  x, y = point
102
  radius = 5
103
  draw.ellipse([x-radius, y-radius, x+radius, y+radius], fill='red', outline='darkred')
104
- # Add label
105
- draw.text((x+8, y-8), part[:3], fill='black')
106
 
107
  return img
108
 
109
- def generate_pose_from_llm(prompt: str) -> Dict:
110
  """Generate pose using LLM"""
111
  system_prompt = """You are an expert in generating human pose keypoints.
112
  Given a description, generate 18 keypoint coordinates for OpenPose.
@@ -116,9 +111,7 @@ def generate_pose_from_llm(prompt: str) -> Dict:
116
  RHip, RKnee, RAnkle, LHip, LKnee, LAnkle, REye, LEye, REar, LEar
117
 
118
  Return ONLY a JSON object with keypoint names and [x, y] coordinates.
119
- Example: {"Nose": [256, 80], "Neck": [256, 120], ...}
120
-
121
- Ensure anatomically correct proportions and center the pose."""
122
 
123
  headers = {
124
  "Accept": "application/json",
@@ -142,7 +135,6 @@ def generate_pose_from_llm(prompt: str) -> Dict:
142
  data = response.json()
143
  content = data['choices'][0]['message']['content']
144
 
145
- # Extract JSON from response
146
  import re
147
  json_match = re.search(r'\{.*\}', content, re.DOTALL)
148
  if json_match:
@@ -151,11 +143,13 @@ def generate_pose_from_llm(prompt: str) -> Dict:
151
  except Exception as e:
152
  print(f"LLM Error: {e}")
153
 
154
- # Fallback to template
155
  return get_template_from_prompt(prompt)
156
 
157
- def get_template_from_prompt(prompt: str) -> Dict:
158
- """Select appropriate template based on prompt"""
 
 
 
159
  prompt_lower = prompt.lower()
160
 
161
  if any(word in prompt_lower for word in ["sit", "chair", "seated"]):
@@ -171,56 +165,41 @@ def get_template_from_prompt(prompt: str) -> Dict:
171
  else:
172
  return POSE_TEMPLATES["Standing"]
173
 
174
- def refine_pose(current_keypoints: Dict, instruction: str) -> Dict:
175
- """Refine existing pose based on instruction"""
 
 
 
176
  keypoints = current_keypoints.copy()
177
  instruction_lower = instruction.lower()
178
 
179
- # Simple rule-based refinement
180
  if "raise" in instruction_lower or "lift" in instruction_lower:
181
- if "arm" in instruction_lower or "hand" in instruction_lower:
182
  if "left" in instruction_lower:
183
  if "LWrist" in keypoints:
184
  keypoints["LWrist"][1] -= 50
185
- if "LElbow" in keypoints:
186
- keypoints["LElbow"][1] -= 30
187
  elif "right" in instruction_lower:
188
  if "RWrist" in keypoints:
189
  keypoints["RWrist"][1] -= 50
190
- if "RElbow" in keypoints:
191
- keypoints["RElbow"][1] -= 30
192
- else: # Both arms
193
  for part in ["LWrist", "RWrist"]:
194
  if part in keypoints:
195
  keypoints[part][1] -= 50
196
- for part in ["LElbow", "RElbow"]:
197
- if part in keypoints:
198
- keypoints[part][1] -= 30
199
 
200
  elif "lower" in instruction_lower:
201
- if "arm" in instruction_lower or "hand" in instruction_lower:
202
  for part in ["LWrist", "RWrist"]:
203
  if part in keypoints:
204
  keypoints[part][1] += 50
205
 
206
- elif "spread" in instruction_lower or "wide" in instruction_lower:
207
- if "leg" in instruction_lower:
208
- if "LAnkle" in keypoints:
209
- keypoints["LAnkle"][0] -= 30
210
- if "RAnkle" in keypoints:
211
- keypoints["RAnkle"][0] += 30
212
-
213
- elif "bend" in instruction_lower:
214
- if "knee" in instruction_lower:
215
- for part in ["LKnee", "RKnee"]:
216
- if part in keypoints:
217
- keypoints[part][1] += 20
218
- keypoints[part][0] += 10 if "L" in part else -10
219
-
220
  return keypoints
221
 
222
- def keypoints_to_openpose_format(keypoints: Dict) -> str:
223
- """Convert keypoints to OpenPose JSON format"""
 
 
 
224
  candidate = []
225
  for i in range(18):
226
  part_name = None
@@ -240,231 +219,115 @@ def keypoints_to_openpose_format(keypoints: Dict) -> str:
240
 
241
  return json.dumps({"candidate": candidate, "subset": subset}, indent=2)
242
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
  # Create Gradio interface
244
- def create_demo():
245
- with gr.Blocks(title="AI Pose Generator", theme=gr.themes.Soft()) as demo:
246
- current_keypoints = gr.State({})
247
-
248
- gr.Markdown("""
249
- # 🎨 AI Line Art Pose Generator
250
- ### Generate precise poses from text descriptions using AI
251
- """)
252
-
253
- with gr.Tabs():
254
- with gr.TabItem("🤖 Generate Pose"):
255
- with gr.Row():
256
- with gr.Column(scale=1):
257
- # Input section
258
- use_llm = gr.Checkbox(
259
- label="🚀 Use Advanced AI Model (Fireworks API)",
260
- value=False,
261
- info="Enable for more accurate pose generation (requires API key)"
262
- )
263
-
264
- api_status = gr.Markdown("⚠️ API key not set - Template mode active")
265
-
266
- prompt = gr.Textbox(
267
- label="Pose Description",
268
- placeholder="e.g., A person sitting cross-legged reading a book",
269
- lines=3
270
- )
271
-
272
- gr.Examples(
273
- examples=[
274
- "A person standing with arms raised in victory",
275
- "Someone sitting at a desk typing on a laptop",
276
- "A dancer in arabesque position",
277
- "A person doing yoga warrior pose",
278
- "Someone waving hello",
279
- "A person running with arms pumping"
280
- ],
281
- inputs=prompt
282
- )
283
-
284
- generate_btn = gr.Button("🎯 Generate Pose", variant="primary", size="lg")
285
-
286
- # Template selection
287
- with gr.Accordion("📚 Quick Templates", open=False):
288
- template_select = gr.Dropdown(
289
- choices=list(POSE_TEMPLATES.keys()),
290
- label="Select Template",
291
- value="Standing"
292
- )
293
- use_template_btn = gr.Button("Apply Template")
294
-
295
- with gr.Column(scale=1):
296
- # Output section
297
- pose_image = gr.Image(
298
- label="Generated Pose",
299
- type="pil",
300
- height=512
301
- )
302
-
303
- with gr.Accordion("📋 OpenPose JSON", open=False):
304
- json_output = gr.Code(
305
- label="JSON Data",
306
- language="json",
307
- lines=10
308
- )
309
-
310
- with gr.TabItem("✏️ Refine Pose"):
311
- with gr.Row():
312
- with gr.Column():
313
- refinement_instruction = gr.Textbox(
314
- label="Refinement Instructions",
315
- placeholder="e.g., Raise the left arm higher",
316
- lines=2
317
- )
318
-
319
- refine_btn = gr.Button("✨ Refine Pose", variant="secondary")
320
-
321
- gr.Markdown("""
322
- **Quick commands:**
323
- - "Raise left/right arm"
324
- - "Lower arms"
325
- - "Spread legs wider"
326
- - "Bend knees"
327
- """)
328
-
329
- # Manual adjustment
330
- with gr.Accordion("🎛️ Manual Adjustment", open=False):
331
- selected_part = gr.Dropdown(
332
- choices=list(BODY_PARTS.keys()),
333
- label="Select Body Part",
334
- value="RWrist"
335
- )
336
- x_adjust = gr.Slider(-100, 100, 0, label="X Adjustment")
337
- y_adjust = gr.Slider(-100, 100, 0, label="Y Adjustment")
338
- apply_adjust_btn = gr.Button("Apply Adjustment")
339
-
340
- with gr.Column():
341
- refined_image = gr.Image(
342
- label="Refined Pose",
343
- type="pil",
344
- height=512
345
- )
346
-
347
- with gr.Accordion("📋 Updated JSON", open=False):
348
- refined_json = gr.Code(
349
- label="JSON Data",
350
- language="json",
351
- lines=10
352
- )
353
 
354
- with gr.TabItem("ℹ️ Help"):
355
- gr.Markdown("""
356
- ## How to Use
357
-
358
- ### 1. Generate Pose
359
- - Enter a natural language description of the pose
360
- - Click "Generate Pose" to create the pose
361
- - Or select a template for quick start
362
-
363
- ### 2. Refine Pose (Optional)
364
- - Use natural language commands to adjust the pose
365
- - Or manually adjust individual body parts
366
-
367
- ### 3. Export
368
- - Copy the OpenPose JSON format for use in other applications
369
- - Compatible with ControlNet and other pose-based tools
370
-
371
- ### API Setup (Optional)
372
- For better results, set up Fireworks API:
373
- ```bash
374
- export FIREWORKS_API_KEY="your_api_key"
375
- ```
376
-
377
- ### Features
378
- - 🚀 No GPU required - runs on CPU
379
- - 🎨 Clean line art style
380
- - 📊 OpenPose compatible format
381
- - 🔧 Easy refinement tools
382
- - 💾 JSON export for integration
383
- """)
384
-
385
- # Event handlers
386
- def check_api_status():
387
- if FIREWORKS_API_KEY != "YOUR_API_KEY_HERE":
388
- return "✅ API key configured - Advanced AI ready"
389
- return "⚠️ API key not set - Template mode active"
390
-
391
- def generate_pose(prompt_text, use_llm_flag):
392
- if not prompt_text:
393
- keypoints = POSE_TEMPLATES["Standing"]
394
- elif use_llm_flag and FIREWORKS_API_KEY != "YOUR_API_KEY_HERE":
395
- keypoints = generate_pose_from_llm(prompt_text)
396
- else:
397
- keypoints = get_template_from_prompt(prompt_text)
398
 
399
- pose_img = draw_pose(keypoints)
400
- json_str = keypoints_to_openpose_format(keypoints)
 
 
 
401
 
402
- return pose_img, json_str, keypoints
403
-
404
- def use_template(template_name):
405
- keypoints = POSE_TEMPLATES[template_name]
406
- pose_img = draw_pose(keypoints)
407
- json_str = keypoints_to_openpose_format(keypoints)
408
- return pose_img, json_str, keypoints
409
-
410
- def refine_existing_pose(instruction, keypoints_state):
411
- if not keypoints_state:
412
- gr.Warning("Please generate a pose first")
413
- return None, None, keypoints_state
414
 
415
- refined_keypoints = refine_pose(keypoints_state, instruction)
416
- pose_img = draw_pose(refined_keypoints)
417
- json_str = keypoints_to_openpose_format(refined_keypoints)
418
- return pose_img, json_str, refined_keypoints
419
-
420
- def manual_adjust(part, x_adj, y_adj, keypoints_state):
421
- if not keypoints_state:
422
- gr.Warning("Please generate a pose first")
423
- return None, None, keypoints_state
424
 
425
- if part not in keypoints_state:
426
- gr.Warning(f"Part {part} not found in current pose")
427
- return None, None, keypoints_state
 
 
 
 
428
 
429
- adjusted_keypoints = keypoints_state.copy()
430
- adjusted_keypoints[part][0] += x_adj
431
- adjusted_keypoints[part][1] += y_adj
432
 
433
- pose_img = draw_pose(adjusted_keypoints)
434
- json_str = keypoints_to_openpose_format(adjusted_keypoints)
435
- return pose_img, json_str, adjusted_keypoints
 
 
 
 
 
 
 
 
436
 
437
- # Connect events
438
- demo.load(check_api_status, outputs=api_status)
439
-
440
- generate_btn.click(
441
- generate_pose,
442
- inputs=[prompt, use_llm],
443
- outputs=[pose_image, json_output, current_keypoints]
444
- )
445
-
446
- use_template_btn.click(
447
- use_template,
448
- inputs=[template_select],
449
- outputs=[pose_image, json_output, current_keypoints]
450
- )
451
-
452
- refine_btn.click(
453
- refine_existing_pose,
454
- inputs=[refinement_instruction, current_keypoints],
455
- outputs=[refined_image, refined_json, current_keypoints]
456
- )
457
-
458
- apply_adjust_btn.click(
459
- manual_adjust,
460
- inputs=[selected_part, x_adjust, y_adjust, current_keypoints],
461
- outputs=[refined_image, refined_json, current_keypoints]
462
- )
463
-
464
- return demo
465
-
466
- # Create and launch the app
467
- app = create_demo()
468
 
 
469
  if __name__ == "__main__":
470
- app.launch()
 
1
  import gradio as gr
 
2
  import json
3
  import requests
4
  import os
 
5
  from PIL import Image, ImageDraw
 
6
 
7
  # Fireworks AI configuration
8
  FIREWORKS_API_KEY = os.getenv("FIREWORKS_API_KEY", "YOUR_API_KEY_HERE")
 
16
  "LEye": 15, "REar": 16, "LEar": 17
17
  }
18
 
19
+ # Skeleton connections
20
  POSE_CONNECTIONS = [
21
  ("Neck", "RShoulder"), ("RShoulder", "RElbow"), ("RElbow", "RWrist"),
22
  ("Neck", "LShoulder"), ("LShoulder", "LElbow"), ("LElbow", "LWrist"),
 
79
  }
80
  }
81
 
82
+ def draw_pose(keypoints, width=512, height=512):
83
  """Draw pose skeleton on image"""
84
  img = Image.new('RGB', (width, height), color='white')
85
  draw = ImageDraw.Draw(img)
 
98
  x, y = point
99
  radius = 5
100
  draw.ellipse([x-radius, y-radius, x+radius, y+radius], fill='red', outline='darkred')
 
 
101
 
102
  return img
103
 
104
+ def generate_pose_from_llm(prompt):
105
  """Generate pose using LLM"""
106
  system_prompt = """You are an expert in generating human pose keypoints.
107
  Given a description, generate 18 keypoint coordinates for OpenPose.
 
111
  RHip, RKnee, RAnkle, LHip, LKnee, LAnkle, REye, LEye, REar, LEar
112
 
113
  Return ONLY a JSON object with keypoint names and [x, y] coordinates.
114
+ Example: {"Nose": [256, 80], "Neck": [256, 120], ...}"""
 
 
115
 
116
  headers = {
117
  "Accept": "application/json",
 
135
  data = response.json()
136
  content = data['choices'][0]['message']['content']
137
 
 
138
  import re
139
  json_match = re.search(r'\{.*\}', content, re.DOTALL)
140
  if json_match:
 
143
  except Exception as e:
144
  print(f"LLM Error: {e}")
145
 
 
146
  return get_template_from_prompt(prompt)
147
 
148
+ def get_template_from_prompt(prompt):
149
+ """Select template based on prompt"""
150
+ if not prompt:
151
+ return POSE_TEMPLATES["Standing"]
152
+
153
  prompt_lower = prompt.lower()
154
 
155
  if any(word in prompt_lower for word in ["sit", "chair", "seated"]):
 
165
  else:
166
  return POSE_TEMPLATES["Standing"]
167
 
168
def refine_pose(current_keypoints, instruction):
    """Return a copy of the pose adjusted according to a text instruction.

    Two rule-based adjustments are recognised, both matched as substrings of
    the lower-cased instruction:

    * "raise" or "lift" together with "arm": subtract 50 from the wrist y
      coordinate. "left" selects ``LWrist``, otherwise "right" selects
      ``RWrist``, otherwise both wrists are moved.
    * "lower" together with "arm": add 50 to the y coordinate of both wrists.

    Any other instruction leaves the coordinates unchanged.

    Args:
        current_keypoints: Dict mapping keypoint names to mutable ``[x, y]``
            pairs. Missing wrist entries are simply skipped.
        instruction: Free-text refinement request.

    Returns:
        A new keypoint dict. The input dict and its coordinate lists are never
        modified. If either argument is empty, ``current_keypoints`` itself is
        returned unchanged.
    """
    import copy

    if not current_keypoints or not instruction:
        return current_keypoints

    # Deep copy on purpose: the coordinate lists are edited in place below and
    # the incoming dict may be a shared POSE_TEMPLATES entry. A shallow
    # dict.copy() would let the edit leak back into that shared template.
    keypoints = copy.deepcopy(current_keypoints)
    instruction_lower = instruction.lower()

    def shift_y(parts, delta):
        # Move the y coordinate of each listed part that exists in the pose.
        for part in parts:
            if part in keypoints:
                keypoints[part][1] += delta

    # Simple refinements
    if "raise" in instruction_lower or "lift" in instruction_lower:
        if "arm" in instruction_lower:
            if "left" in instruction_lower:
                shift_y(["LWrist"], -50)
            elif "right" in instruction_lower:
                shift_y(["RWrist"], -50)
            else:
                shift_y(["LWrist", "RWrist"], -50)
    elif "lower" in instruction_lower:
        if "arm" in instruction_lower:
            shift_y(["LWrist", "RWrist"], 50)

    return keypoints
197
 
198
+ def keypoints_to_openpose_format(keypoints):
199
+ """Convert to OpenPose JSON format"""
200
+ if not keypoints:
201
+ return "{}"
202
+
203
  candidate = []
204
  for i in range(18):
205
  part_name = None
 
219
 
220
  return json.dumps({"candidate": candidate, "subset": subset}, indent=2)
221
 
222
+ # Main generation function
223
def generate_pose(prompt, use_llm, template):
    """Build a pose from a template choice or a text prompt.

    Source of the keypoints, in priority order:

    1. An explicit template (any truthy value other than the string "None").
    2. No template and no prompt: the "Standing" template.
    3. A prompt with ``use_llm`` enabled and an API key configured: the LLM.
    4. Otherwise: the keyword-based template lookup for the prompt.

    Returns:
        A ``(pose image, OpenPose JSON string, keypoints)`` tuple.
    """
    template_chosen = bool(template) and template != "None"

    if template_chosen:
        pose = POSE_TEMPLATES[template]
    elif not prompt:
        pose = POSE_TEMPLATES["Standing"]
    elif use_llm and FIREWORKS_API_KEY != "YOUR_API_KEY_HERE":
        pose = generate_pose_from_llm(prompt)
    else:
        pose = get_template_from_prompt(prompt)

    return draw_pose(pose), keypoints_to_openpose_format(pose), pose
238
+
239
def refine_existing_pose(instruction, keypoints_json):
    """Apply a refinement instruction to the stored pose and re-render it.

    Returns:
        A ``(pose image, OpenPose JSON string, keypoints)`` tuple, or
        ``(None, "{}", {})`` when there is no stored pose to refine.
    """
    if keypoints_json:
        updated = refine_pose(keypoints_json, instruction)
        return draw_pose(updated), keypoints_to_openpose_format(updated), updated

    # Nothing has been generated yet, so there is nothing to refine.
    return None, "{}", {}
249
+
250
def check_api_status():
    """Return a status message saying whether the Fireworks API key is set."""
    # The placeholder default means no real key was supplied via the environment.
    key_is_set = FIREWORKS_API_KEY != "YOUR_API_KEY_HERE"
    return (
        "✅ API key configured - Advanced AI ready"
        if key_is_set
        else "⚠️ API key not set - Template mode active"
    )
255
+
256
  # Create Gradio interface
257
+ with gr.Blocks(title="AI Pose Generator") as demo:
258
+ keypoints_state = gr.State({})
259
+
260
+ gr.Markdown("""
261
+ # 🎨 AI Line Art Pose Generator
262
+ ### Generate precise poses from text descriptions
263
+ """)
264
+
265
+ with gr.Row():
266
+ with gr.Column(scale=1):
267
+ # Input controls
268
+ api_status = gr.Markdown(check_api_status())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
269
 
270
+ use_llm = gr.Checkbox(
271
+ label="Use Advanced AI (Fireworks API)",
272
+ value=False
273
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
 
275
+ prompt = gr.Textbox(
276
+ label="Describe the pose",
277
+ placeholder="e.g., A person sitting and reading a book",
278
+ lines=2
279
+ )
280
 
281
+ template = gr.Dropdown(
282
+ choices=["None"] + list(POSE_TEMPLATES.keys()),
283
+ label="Or select a template",
284
+ value="None"
285
+ )
 
 
 
 
 
 
 
286
 
287
+ generate_btn = gr.Button("🎯 Generate Pose", variant="primary")
 
 
 
 
 
 
 
 
288
 
289
+ # Refinement
290
+ gr.Markdown("### Refine Pose")
291
+ refinement = gr.Textbox(
292
+ label="Refinement instruction",
293
+ placeholder="e.g., Raise the left arm",
294
+ lines=1
295
+ )
296
 
297
+ refine_btn = gr.Button("✨ Refine", variant="secondary")
 
 
298
 
299
+ # Examples
300
+ gr.Examples(
301
+ examples=[
302
+ "A person standing with arms raised",
303
+ "Someone sitting at a desk",
304
+ "A person doing yoga",
305
+ "Someone waving hello",
306
+ "A person running"
307
+ ],
308
+ inputs=prompt
309
+ )
310
 
311
+ with gr.Column(scale=1):
312
+ # Output
313
+ pose_image = gr.Image(label="Generated Pose", type="pil")
314
+
315
+ with gr.Accordion("OpenPose JSON", open=False):
316
+ json_output = gr.Code(language="json", lines=10)
317
+
318
+ # Event handlers
319
+ generate_btn.click(
320
+ fn=generate_pose,
321
+ inputs=[prompt, use_llm, template],
322
+ outputs=[pose_image, json_output, keypoints_state]
323
+ )
324
+
325
+ refine_btn.click(
326
+ fn=refine_existing_pose,
327
+ inputs=[refinement, keypoints_state],
328
+ outputs=[pose_image, json_output, keypoints_state]
329
+ )
 
 
 
 
 
 
 
 
 
 
 
 
330
 
331
+ # Launch the app
332
  if __name__ == "__main__":
333
+ demo.launch()