3v324v23 committed on
Commit
0b95948
Β·
1 Parent(s): 46fe1ee
Files changed (1) hide show
  1. app.py +111 -278
app.py CHANGED
@@ -1,333 +1,166 @@
1
  """
2
  OpenPose Preprocessor for ControlNet
3
- A Gradio application for pose detection with multiple models and customization options.
4
  """
5
 
6
  import gradio as gr
7
  import numpy as np
8
  from PIL import Image
9
  import torch
10
- import json
11
- from typing import Tuple, Optional, Dict, Any
12
 
13
  # Global device detection
14
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
15
  print(f"Using device: {DEVICE}")
16
 
17
- # Model cache to avoid reloading
18
- _model_cache: Dict[str, Any] = {}
 
19
 
20
 
21
  def get_openpose_detector():
22
  """Get or create OpenPose detector."""
23
- if "openpose" not in _model_cache:
 
24
  from controlnet_aux import OpenposeDetector
25
- _model_cache["openpose"] = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
26
- return _model_cache["openpose"]
27
 
28
 
29
  def get_dwpose_detector():
30
  """Get or create DWPose detector."""
31
- if "dwpose" not in _model_cache:
 
32
  from controlnet_aux import DWposeDetector
33
- _model_cache["dwpose"] = DWposeDetector.from_pretrained("yolox_l.onnx", "dw-ll_ucoco_384.onnx")
34
- return _model_cache["dwpose"]
35
 
36
 
37
- def process_with_openpose(
38
- image: Image.Image,
39
- mode: str,
40
- detect_hand: bool,
41
- detect_face: bool,
42
- detect_resolution: int,
43
- ) -> Tuple[Image.Image, Optional[dict]]:
44
- """Process image using OpenPose detector."""
45
- detector = get_openpose_detector()
46
-
47
- # Determine hand_and_face parameter based on mode and toggles
48
- if mode == "OpenPose (Full)":
49
- hand_and_face = True
50
- elif mode == "OpenPose (Hand)":
51
- hand_and_face = detect_hand
52
- elif mode == "OpenPose (Face)":
53
- hand_and_face = detect_face
54
- elif mode == "OpenPose (Face Only)":
55
- # Face only mode
56
- result = detector(
57
- image,
58
- detect_resolution=detect_resolution,
59
- include_body=False,
60
- include_hand=False,
61
- include_face=True,
62
- output_type="pil"
63
- )
64
- return result, None
65
- else:
66
- # Basic OpenPose
67
- hand_and_face = detect_hand and detect_face
68
-
69
- result = detector(
70
- image,
71
- detect_resolution=detect_resolution,
72
- hand_and_face=hand_and_face,
73
- output_type="pil"
74
- )
75
-
76
- return result, None
77
-
78
-
79
- def process_with_dwpose(
80
- image: Image.Image,
81
- detect_hand: bool,
82
- detect_face: bool,
83
- detect_resolution: int,
84
- ) -> Tuple[Image.Image, Optional[dict]]:
85
- """Process image using DWPose detector."""
86
- detector = get_dwpose_detector()
87
-
88
- # controlnet-aux DWposeDetector API
89
- result = detector(
90
- image,
91
- detect_resolution=detect_resolution,
92
- image_resolution=detect_resolution,
93
- include_hand=detect_hand,
94
- include_face=detect_face,
95
- include_body=True,
96
- output_type="pil"
97
- )
98
-
99
- return result, None
100
-
101
-
102
- def detect_pose(
103
- image: Image.Image,
104
- model_type: str,
105
- detect_hand: bool,
106
- detect_face: bool,
107
- detect_resolution: int,
108
- output_resolution: int,
109
- output_format: str,
110
- ) -> Tuple[Optional[Image.Image], str]:
111
- """
112
- Main pose detection function.
113
-
114
- Args:
115
- image: Input PIL Image
116
- model_type: Selected model type
117
- detect_hand: Whether to detect hands
118
- detect_face: Whether to detect face
119
- detect_resolution: Resolution for detection
120
- output_resolution: Resolution for output image
121
- output_format: "Image", "JSON", or "Both"
122
-
123
- Returns:
124
- Tuple of (output_image, json_string)
125
- """
126
  if image is None:
127
- return None, "Please upload an image first."
128
 
129
  try:
 
 
 
 
130
  # Convert to RGB if necessary
131
  if image.mode != "RGB":
132
  image = image.convert("RGB")
133
 
134
  # Process based on model type
135
  if model_type == "DWPose":
136
- result_image, keypoints = process_with_dwpose(
137
- image, detect_hand, detect_face, detect_resolution
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  )
139
  else:
140
- result_image, keypoints = process_with_openpose(
141
- image, model_type, detect_hand, detect_face, detect_resolution
 
 
 
 
 
142
  )
143
 
144
- # Resize output if needed
145
- if output_resolution > 0:
146
- orig_w, orig_h = result_image.size
147
- scale = output_resolution / max(orig_w, orig_h)
148
- new_w, new_h = int(orig_w * scale), int(orig_h * scale)
149
- result_image = result_image.resize((new_w, new_h), Image.Resampling.LANCZOS)
150
-
151
- # Prepare outputs based on format
152
- json_output = ""
153
- if output_format == "JSON" or output_format == "Both":
154
- json_output = json.dumps({
155
- "model": model_type,
156
- "detect_hand": detect_hand,
157
- "detect_face": detect_face,
158
- "detect_resolution": detect_resolution,
159
- "output_resolution": output_resolution,
160
- "device": DEVICE,
161
- "status": "success",
162
- "note": "Keypoint extraction requires additional processing. Use the output image for ControlNet."
163
- }, indent=2, ensure_ascii=False)
164
-
165
- if output_format == "JSON":
166
- return None, json_output
167
- elif output_format == "Image":
168
- return result_image, "Processing complete. Image ready for ControlNet."
169
- else: # Both
170
- return result_image, json_output
171
 
172
  except Exception as e:
173
- error_msg = f"Error during processing: {str(e)}"
174
- return None, error_msg
175
 
176
 
177
- def create_ui() -> gr.Blocks:
178
- """Create the Gradio UI."""
179
-
180
- css = """
181
- .main-title {
182
- text-align: center;
183
- margin-bottom: 1rem;
184
- }
185
- .settings-panel {
186
- background: var(--background-fill-secondary);
187
- padding: 1rem;
188
- border-radius: 8px;
189
- }
190
- """
191
 
192
- with gr.Blocks(
193
- title="🦴 OpenPose Preprocessor",
194
- css=css,
195
- theme=gr.themes.Soft()
196
- ) as demo:
197
 
198
- # Header
199
- gr.Markdown(
200
- """
201
- # 🦴 OpenPose Preprocessor for ControlNet
 
 
 
 
 
202
 
203
- High-quality pose detection with multiple models and customization options.
204
- Upload an image and get pose skeleton for ControlNet.
205
- """
206
- )
207
-
208
- # Device info
209
- gr.Markdown(f"**Device**: `{DEVICE}` {'πŸš€' if DEVICE == 'cuda' else '🐒'}")
210
-
211
- with gr.Row():
212
- # Left column - Input
213
- with gr.Column(scale=1):
214
- input_image = gr.Image(
215
- label="πŸ“· Input Image",
216
- type="pil",
217
- height=400
218
- )
219
-
220
- # Settings
221
- with gr.Accordion("βš™οΈ Settings", open=True):
222
- model_type = gr.Dropdown(
223
- label="πŸ€– Model",
224
- choices=[
225
- "DWPose",
226
- "OpenPose",
227
- "OpenPose (Face)",
228
- "OpenPose (Hand)",
229
- "OpenPose (Full)",
230
- "OpenPose (Face Only)"
231
- ],
232
- value="DWPose",
233
- info="DWPose is recommended for better accuracy"
234
- )
235
-
236
- with gr.Row():
237
- detect_hand = gr.Checkbox(
238
- label="πŸ‘† Detect Hands",
239
- value=True
240
- )
241
- detect_face = gr.Checkbox(
242
- label="😊 Detect Face",
243
- value=True
244
- )
245
-
246
- detect_resolution = gr.Slider(
247
- label="πŸ“ Detection Resolution",
248
- minimum=256,
249
- maximum=2048,
250
- value=512,
251
- step=64,
252
- info="Higher = more accurate but slower"
253
- )
254
-
255
- output_resolution = gr.Slider(
256
- label="πŸ–ΌοΈ Output Resolution",
257
- minimum=256,
258
- maximum=2048,
259
- value=512,
260
- step=64,
261
- info="Final output image resolution"
262
- )
263
-
264
- output_format = gr.Radio(
265
- label="πŸ“Š Output Format",
266
- choices=["Image", "JSON", "Both"],
267
- value="Both"
268
- )
269
-
270
- # Process button
271
- process_btn = gr.Button(
272
- "πŸš€ Detect Pose",
273
- variant="primary",
274
- size="lg"
275
- )
276
 
277
- # Right column - Output
278
- with gr.Column(scale=1):
279
- output_image = gr.Image(
280
- label="🎨 Output Pose",
281
- type="pil",
282
- height=400
283
- )
284
-
285
- output_json = gr.Textbox(
286
- label="πŸ“‹ Output Info",
287
- lines=8,
288
- max_lines=15
289
- )
290
-
291
- # Examples
292
- gr.Markdown("### πŸ“Œ Tips")
293
- gr.Markdown(
294
- """
295
- - **DWPose** is recommended for best accuracy, especially for hands
296
- - **OpenPose (Full)** detects body, face, and hands together
297
- - Higher **Detection Resolution** improves accuracy but increases processing time
298
- - The output image can be directly used with ControlNet OpenPose models
299
- """
300
- )
301
-
302
- # Connect events
303
- process_btn.click(
304
- fn=detect_pose,
305
- inputs=[
306
- input_image,
307
- model_type,
308
- detect_hand,
309
- detect_face,
310
- detect_resolution,
311
- output_resolution,
312
- output_format,
313
- ],
314
- outputs=[output_image, output_json]
315
- )
316
 
317
- # Also clear output on image upload for convenience
318
- input_image.change(
319
- fn=lambda x: (None, ""),
320
- inputs=[input_image],
321
- outputs=[output_image, output_json]
322
- )
 
 
 
 
 
 
323
 
324
- return demo
 
 
 
 
325
 
326
 
327
  if __name__ == "__main__":
328
- demo = create_ui()
329
  demo.launch(
330
  server_name="0.0.0.0",
331
  server_port=7860,
332
- share=False
333
  )
 
1
  """
2
  OpenPose Preprocessor for ControlNet
3
+ A simple Gradio application for pose detection.
4
  """
5
 
6
  import gradio as gr
7
  import numpy as np
8
  from PIL import Image
9
  import torch
 
 
10
 
11
# Global device detection: run on the GPU when CUDA is available, else fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# Model cache: detector instances are loaded lazily on first request and then
# reused for the lifetime of the process, so model weights are read only once.
_openpose_detector = None
_dwpose_detector = None
18
 
19
 
20
def get_openpose_detector():
    """Return the process-wide OpenPose detector, loading it on first call."""
    global _openpose_detector
    if _openpose_detector is not None:
        return _openpose_detector
    # Import lazily so the app starts even before the annotator package is needed.
    from controlnet_aux import OpenposeDetector
    _openpose_detector = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
    return _openpose_detector
27
 
28
 
29
def get_dwpose_detector():
    """Return the process-wide DWPose detector, loading it on first call."""
    global _dwpose_detector
    if _dwpose_detector is not None:
        return _dwpose_detector
    # Import lazily so the app starts even before the annotator package is needed.
    from controlnet_aux import DWposeDetector
    _dwpose_detector = DWposeDetector.from_pretrained("yolox_l.onnx", "dw-ll_ucoco_384.onnx")
    return _dwpose_detector
36
 
37
 
38
def detect_pose(image, model_type, detect_hand, detect_face, detect_resolution):
    """Run pose detection on an image and return the rendered skeleton.

    Args:
        image: Input image (PIL.Image, or numpy array as delivered by Gradio).
            None is treated as "nothing uploaded" and short-circuits.
        model_type: One of "DWPose", "OpenPose", "OpenPose (Full)",
            "OpenPose (Face Only)". Unknown values fall back to basic OpenPose.
        detect_hand: Include hand keypoints where the selected mode allows it.
        detect_face: Include face keypoints where the selected mode allows it.
        detect_resolution: Resolution (long side, px) used for detection.

    Returns:
        A PIL.Image with the pose skeleton, or None on missing input or error.
    """
    if image is None:
        return None

    try:
        # Gradio components may hand over a numpy array; normalize to PIL.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        # Detectors expect 3-channel RGB input.
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Select detector and per-mode keypoint options, then call once below.
        if model_type == "DWPose":
            detector = get_dwpose_detector()
            kwargs = {
                "detect_resolution": detect_resolution,
                "image_resolution": detect_resolution,
                "include_body": True,
                "include_hand": detect_hand,
                "include_face": detect_face,
            }
        elif model_type == "OpenPose (Full)":
            detector = get_openpose_detector()
            kwargs = {
                "detect_resolution": detect_resolution,
                "include_body": True,
                "include_hand": True,
                "include_face": True,
            }
        elif model_type == "OpenPose (Face Only)":
            detector = get_openpose_detector()
            kwargs = {
                "detect_resolution": detect_resolution,
                "include_body": False,
                "include_hand": False,
                "include_face": True,
            }
        else:
            # Basic OpenPose. Pass the two toggles independently via
            # include_hand/include_face; the previous
            # `hand_and_face=detect_hand and detect_face` only drew hands or
            # face when BOTH checkboxes were enabled.
            detector = get_openpose_detector()
            kwargs = {
                "detect_resolution": detect_resolution,
                "include_body": True,
                "include_hand": detect_hand,
                "include_face": detect_face,
            }

        return detector(image, output_type="pil", **kwargs)

    except Exception as e:
        # Best-effort UI: log the failure and show an empty output rather
        # than crashing the Gradio worker.
        print(f"Error during processing: {str(e)}")
        return None
97
 
98
 
99
# Create Gradio interface. Built at import time so `demo` is available both to
# the __main__ launcher below and to hosting platforms that import this module.
with gr.Blocks(
    title="🦴 OpenPose Preprocessor",
    theme=gr.themes.Soft()
) as demo:

    gr.Markdown(
        """
        # 🦴 OpenPose Preprocessor for ControlNet

        High-quality pose detection with multiple models. Upload an image and get pose skeleton for ControlNet.
        """
    )

    # Show which device inference will run on (set once at module import).
    gr.Markdown(f"**Device**: `{DEVICE}` {'πŸš€' if DEVICE == 'cuda' else '🐒'}")

    with gr.Row():
        # Left column: input image plus all detection settings.
        with gr.Column(scale=1):
            input_image = gr.Image(label="πŸ“· Input Image", type="pil", height=400)

            model_type = gr.Dropdown(
                label="πŸ€– Model",
                choices=["DWPose", "OpenPose", "OpenPose (Full)", "OpenPose (Face Only)"],
                value="DWPose",
                info="DWPose is recommended for better accuracy"
            )

            with gr.Row():
                detect_hand = gr.Checkbox(label="πŸ‘† Detect Hands", value=True)
                detect_face = gr.Checkbox(label="😊 Detect Face", value=True)

            detect_resolution = gr.Slider(
                label="πŸ“ Detection Resolution",
                minimum=256,
                maximum=2048,
                value=512,
                step=64,
                info="Higher = more accurate but slower"
            )

            process_btn = gr.Button("πŸš€ Detect Pose", variant="primary", size="lg")

        # Right column: rendered pose output and usage tips.
        with gr.Column(scale=1):
            output_image = gr.Image(label="🎨 Output Pose", type="pil", height=400)

            gr.Markdown(
                """
                ### πŸ“Œ Tips
                - **DWPose** is recommended for best accuracy, especially for hands
                - **OpenPose (Full)** detects body, face, and hands together
                - Higher **Detection Resolution** improves accuracy but increases processing time
                - The output image can be directly used with ControlNet OpenPose models
                """
            )

    # Wire the button to the detection function; the result replaces the output image.
    process_btn.click(
        fn=detect_pose,
        inputs=[input_image, model_type, detect_hand, detect_face, detect_resolution],
        outputs=[output_image]
    )


if __name__ == "__main__":
    # Bind to all interfaces on the conventional Gradio port (container/Space hosting).
    # ssr_mode=False presumably works around server-side-rendering issues on the
    # hosting platform — TODO confirm against the Gradio version pinned for this app.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        ssr_mode=False
    )