Werli committed on
Commit
eaf0408
·
verified ·
1 Parent(s): f78d889

Update modules/pixai.py

Browse files
Files changed (1) hide show
  1. modules/pixai.py +803 -801
modules/pixai.py CHANGED
@@ -1,801 +1,803 @@
1
- import os, json, zipfile, tempfile, time, traceback
2
- import gradio as gr
3
- import pandas as pd
4
- import numpy as np
5
- import onnxruntime as ort
6
- from collections import defaultdict
7
- from typing import Union, Dict, Any, Tuple, List
8
- from PIL import Image
9
- from huggingface_hub import hf_hub_download
10
- from huggingface_hub.errors import EntryNotFoundError
11
- from datetime import datetime
12
- from modules.media_handler import handle_single_media_upload, handle_multiple_media_uploads
13
-
14
- # Global variables for model components (for memory management)
15
- CURRENT_MODEL = None
16
- CURRENT_MODEL_NAME = None
17
- CURRENT_TAGS_DF = None
18
- CURRENT_D_IPS = None
19
- CURRENT_PREPROCESS_FUNC = None
20
- CURRENT_THRESHOLDS = None
21
- CURRENT_CATEGORY_NAMES = None
22
-
23
- css = """
24
- #custom-gallery {--row-height: 180px;display: grid;grid-auto-rows: min-content;gap: 10px;}
25
- #custom-gallery .thumbnail-item {height: var(--row-height);width: 100%;position: relative;overflow: hidden;border-radius: 8px;box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);transition: transform 0.2s ease, box-shadow 0.2s ease;}
26
- #custom-gallery .thumbnail-item:hover {transform: translateY(-3px);box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15);}
27
- #custom-gallery .thumbnail-item img {width: auto;height: 100%;max-width: 100%;max-height: var(--row-height);object-fit: contain;margin: 0 auto;display: block;}
28
- #custom-gallery .thumbnail-item img.portrait {max-width: 100%;}
29
- #custom-gallery .thumbnail-item img.landscape {max-height: 100%;}
30
- .gallery-container {max-height: 500px;overflow-y: auto;padding-right: 0px;--size-80: 500px;}
31
- .thumbnails {display: flex;position: absolute;bottom: 0;width: 120px;overflow-x: scroll;padding-top: 320px;padding-bottom: 280px;padding-left: 4px;flex-wrap: wrap;}
32
- #custom-gallery .thumbnail-item img {width: auto;height: 100%;max-width: 100%;max-height: var(--row-height);object-fit: initial;width: fit-content;margin: 0px auto;display: block;}
33
- """
34
-
35
def preprocess_on_gpu(img, device='cuda'):
    """Resize and normalize a PIL image with torchvision and return a NumPy batch.

    Args:
        img: PIL image to preprocess. Non-RGB images are converted to RGB.
        device: Torch device the tensor is moved to when CUDA is available.

    Returns:
        A NumPy array of shape ``(1, 3, 448, 448)``.
    """
    # Imported lazily: torch/torchvision are only needed by this helper.
    import torch
    import torchvision.transforms as transforms

    # Normalize() below is configured with three channel means/stds, so
    # grayscale, palette and RGBA inputs must be converted first.
    if img.mode != 'RGB':
        img = img.convert('RGB')

    transform = transforms.Compose([
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073],
            std=[0.26862954, 0.26130258, 0.27577711],
        ),
    ])
    tensor_img = transform(img).unsqueeze(0)  # add the batch dimension

    # NOTE: the resize/normalize above already ran on the CPU; the tensor is
    # only moved to `device` here and is copied straight back for NumPy.
    if torch.cuda.is_available():
        tensor_img = tensor_img.to(device)
    return tensor_img.cpu().numpy()
46
-
47
class Timer:
    """Record labelled checkpoints and print the time elapsed between them."""

    def __init__(self):
        # The first entry is the baseline; it is never printed as an interval.
        started = time.perf_counter()
        self.start_time = started
        self.checkpoints = [('Start', started)]

    def checkpoint(self, label='Checkpoint'):
        """Append a checkpoint named ``label`` stamped with the current time."""
        self.checkpoints.append((label, time.perf_counter()))

    def _label_width(self):
        """Return the longest recorded label length (0 when nothing is recorded)."""
        return max((len(name) for name, _ in self.checkpoints), default=0)

    def _print_intervals(self, baseline, width):
        """Print the time between consecutive checkpoints, skipping the first entry."""
        previous = baseline
        for name, stamp in self.checkpoints[1:]:
            print(f"{name.ljust(width)}: {stamp - previous:.3f} seconds")
            previous = stamp

    def report(self, is_clear_checkpoints=True):
        """Print the intervals since the first recorded checkpoint.

        When ``is_clear_checkpoints`` is true the list is emptied afterwards
        and a fresh checkpoint is stored as the new baseline.
        """
        width = self._label_width()
        if self.checkpoints:
            baseline = self.checkpoints[0][1]
        else:
            baseline = self.start_time
        self._print_intervals(baseline, width)

        if is_clear_checkpoints:
            self.checkpoints.clear()
            self.checkpoint()

    def report_all(self):
        """Print every interval plus the total since ``start_time``, then clear."""
        print('\n> Execution Time Report:')
        width = self._label_width()
        self._print_intervals(self.start_time, width)

        if self.checkpoints:
            total_time = self.checkpoints[-1][1] - self.start_time
        else:
            total_time = 0
        print(f"{'Total Execution Time'.ljust(width)}: {total_time:.3f} seconds\n")  # Performance tests
        self.checkpoints.clear()

    def restart(self):
        """Reset the start time and drop all recorded checkpoints."""
        started = time.perf_counter()
        self.start_time = started
        self.checkpoints = [('Start', started)]
86
-
87
- def _get_repo_id(model_name: str) -> str:
88
- """Get the repository ID for the specified model name."""
89
- if '/' in model_name:
90
- return model_name
91
- else:
92
- return f'deepghs/pixai-tagger-{model_name}-onnx'
93
-
94
def _download_model_files(model_name: str):
    """Fetch the files belonging to ``model_name`` via ``hf_hub_download``.

    Returns:
        Tuple ``(model_path, tags_path, preprocess_path, thresholds_path)``.
        ``thresholds_path`` is ``None`` when the repository has no
        ``thresholds.csv``.
    """
    hub_kwargs = {
        'repo_id': _get_repo_id(model_name),
        'library_name': "pixai-tagger",
    }

    # Required files: an error while fetching any of them propagates.
    model_path, tags_path, preprocess_path = (
        hf_hub_download(filename=required_file, **hub_kwargs)
        for required_file in ('model.onnx', 'selected_tags.csv', 'preprocess.json')
    )

    # thresholds.csv is optional; its absence is reported as None.
    try:
        thresholds_path = hf_hub_download(filename='thresholds.csv', **hub_kwargs)
    except EntryNotFoundError:
        thresholds_path = None

    return model_path, tags_path, preprocess_path, thresholds_path
124
-
125
def create_optimized_ort_session(model_path):
    """Build an ONNX Runtime ``InferenceSession`` for ``model_path``.

    The CUDA provider is listed first when ONNX Runtime reports it as
    available; the CPU provider is always appended as a fallback.

    Raises:
        Exception: Whatever ``ort.InferenceSession`` raises, re-raised after
            the failure has been printed.
    """
    # Test: Session options for better performance
    options = ort.SessionOptions()
    options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    options.intra_op_num_threads = 0  # Use all available cores
    options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
    options.enable_mem_pattern = True
    options.enable_cpu_mem_arena = True

    installed = ort.get_available_providers()
    print(f"Available ONNX Runtime providers: {installed}")

    # Providers are listed in order of preference.
    if 'CUDAExecutionProvider' in installed:
        cuda_settings = {
            'device_id': 0,
            'arena_extend_strategy': 'kNextPowerOfTwo',
            'gpu_mem_limit': 4 * 1024 * 1024 * 1024,  # 4GB VRAM
            'cudnn_conv_algo_search': 'EXHAUSTIVE',
            'do_copy_in_default_stream': True,
        }
        provider_chain = [('CUDAExecutionProvider', cuda_settings)]
        print("Using CUDA provider for ONNX inference")
    else:
        provider_chain = []
        print("CUDA provider not available, falling back to CPU")

    # Always include CPU as fallback (FOR HF)
    provider_chain.append('CPUExecutionProvider')

    try:
        session = ort.InferenceSession(model_path, options, providers=provider_chain)
        print(f"Model loaded with providers: {session.get_providers()}")
        return session
    except Exception as e:
        print(f"Failed to create ONNX session: {e}")
        raise
166
-
167
def _load_model_components_optimized(model_name: str):
    """Return the cached model components, (re)loading them when the model changes.

    The components live in module-level globals so that repeated calls with
    the same ``model_name`` reuse the loaded session.

    Args:
        model_name: Model name or repository ID passed to the downloader.

    Returns:
        Tuple ``(session, tags_df, d_ips, preprocess_func, thresholds,
        category_names)``.
    """
    global CURRENT_MODEL, CURRENT_MODEL_NAME, CURRENT_TAGS_DF, CURRENT_D_IPS
    global CURRENT_PREPROCESS_FUNC, CURRENT_THRESHOLDS, CURRENT_CATEGORY_NAMES

    # Only reload if model changed
    if CURRENT_MODEL_NAME != model_name:
        model_path, tags_path, preprocess_path, thresholds_path = _download_model_files(model_name)

        # Everything is built in locals and committed to the globals in one
        # step at the end. If any step below raises, the previously cached
        # model stays fully consistent instead of ending up with a new session
        # paired with the old name and tag table.
        session = create_optimized_ort_session(model_path)

        # Load tags
        tags_df = pd.read_csv(tags_path)
        d_ips = {}
        if 'ips' in tags_df.columns:
            tags_df['ips'] = tags_df['ips'].fillna('{}').map(json.loads)
            d_ips = {
                name: ips
                for name, ips in zip(tags_df['name'], tags_df['ips'])
                if ips
            }

        # preprocess.json is still parsed so a missing or corrupt file fails
        # loudly, but its contents are not used: the pipeline below is
        # hard-coded.
        with open(preprocess_path, 'r') as f:
            json.load(f)

        def transform(img):
            """Convert a PIL image to a normalized float32 (C, H, W) array."""
            # Ensure image is in RGB mode
            if img.mode != 'RGB':
                img = img.convert('RGB')

            # Resize to 448x448 <- Very important.
            img = img.resize((448, 448), Image.Resampling.LANCZOS)

            # Scale pixel values to [0, 1]
            img_array = np.array(img).astype(np.float32) / 255.0

            # Per-channel normalization. These are the CLIP mean/std
            # constants, not the ImageNet ones.
            mean = np.array([0.48145466, 0.4578275, 0.40821073]).astype(np.float32)
            std = np.array([0.26862954, 0.26130258, 0.27577711]).astype(np.float32)
            img_array = (img_array - mean) / std

            # Transpose (H, W, C) -> (C, H, W)
            return np.transpose(img_array, (2, 0, 1))

        # Load thresholds
        thresholds = {}
        category_names = {}
        if thresholds_path and os.path.exists(thresholds_path):
            df_category_thresholds = pd.read_csv(thresholds_path, keep_default_na=False)
            for item in df_category_thresholds.to_dict('records'):
                # The first threshold listed for a category wins.
                if item['category'] not in thresholds:
                    thresholds[item['category']] = item['threshold']
                category_names[item['category']] = item['name']
        else:
            # Default thresholds if file doesn't exist
            thresholds = {0: 0.3, 4: 0.85, 9: 0.85}
            category_names = {0: 'general', 4: 'character', 9: 'rating'}

        # Commit: from here on the cache describes `model_name`.
        CURRENT_MODEL = session
        CURRENT_TAGS_DF = tags_df
        CURRENT_D_IPS = d_ips
        CURRENT_PREPROCESS_FUNC = transform
        CURRENT_THRESHOLDS = thresholds
        CURRENT_CATEGORY_NAMES = category_names
        CURRENT_MODEL_NAME = model_name

    return (CURRENT_MODEL, CURRENT_TAGS_DF, CURRENT_D_IPS, CURRENT_PREPROCESS_FUNC,
            CURRENT_THRESHOLDS, CURRENT_CATEGORY_NAMES)
237
-
238
def _raw_predict(image: Image.Image, model_name: str):
    """Run the PixAI tagger on one image and return its raw outputs.

    Args:
        image: PIL image to tag.
        model_name: Model name or repository ID passed to the model loader.

    Returns:
        Dict mapping each ONNX output name to the first batch row of that
        output.

    Raises:
        RuntimeError: If validation, preprocessing or inference fails. The
            original exception is chained as ``__cause__``.
    """
    try:
        # Ensure we have a PIL Image
        if not isinstance(image, Image.Image):
            raise ValueError("Input must be a PIL Image")  # <-

        # Load model components
        model, _, _, preprocess_func, _, _ = _load_model_components_optimized(model_name)

        # Preprocess image
        input_tensor = preprocess_func(image)

        # Add batch dimension
        if len(input_tensor.shape) == 3:
            input_tensor = np.expand_dims(input_tensor, axis=0)

        # Run inference
        output_names = [output.name for output in model.get_outputs()]
        output_values = model.run(output_names, {'input': input_tensor.astype(np.float32)})

        return {name: value[0] for name, value in zip(output_names, output_values)}

    except Exception as e:
        # Chain explicitly so the root cause stays visible in tracebacks.
        raise RuntimeError(f"Error processing image: {str(e)}") from e
263
-
264
def get_pixai_tags(
        image: Union[str, Image.Image],
        model_name: str = 'deepghs/pixai-tagger-v0.9-onnx',
        thresholds: Union[float, Dict[Any, float], None] = None,
        fmt='all'
):
    """Tag an image with a PixAI tagger model.

    Args:
        image: File path or PIL image.
        model_name: Model name or repository ID passed to the model loader.
        thresholds: ``None`` to use the model's default per-category
            thresholds, a single number applied to every category, or a dict
            keyed by category id or category name. Categories missing from
            the dict fall back to the defaults.
        fmt: ``'all'`` returns a tuple with one ``{tag: score}`` dict per
            category, ordered by category id. A key of the result dict (a
            model output name, a category name, ``'tag'``, and, when the tag
            table has an ``ips`` column, ``'ips'``, ``'ips_mapping'`` or
            ``'ips_count'``) returns just that entry. Any other value returns
            the whole result dict.

    Raises:
        RuntimeError: On any failure. The original exception is chained as
            ``__cause__``.
    """
    try:
        # Load image if it's a path
        opened_here = isinstance(image, str)
        if opened_here:
            pil_image = Image.open(image)
        elif isinstance(image, Image.Image):
            pil_image = image
        else:
            raise ValueError("Image must be a file path or PIL Image")

        try:
            # Load model components
            _, df_tags, d_ips, _, default_thresholds, category_names = \
                _load_model_components_optimized(model_name)
            values = _raw_predict(pil_image, model_name)
        finally:
            # Only close what this function opened; a caller-supplied image
            # stays under the caller's control.
            if opened_here:
                pil_image.close()

        prediction = values.get('prediction', np.array([]))
        if prediction.size == 0:
            raise RuntimeError("Model did not return valid predictions")

        categories = sorted(set(df_tags['category'].tolist()))
        # bool is a subclass of int, so it is excluded explicitly.
        is_scalar_threshold = (
            isinstance(thresholds, (int, float)) and not isinstance(thresholds, bool)
        )
        tags = {}

        # Process tags by category
        for category in categories:
            mask = df_tags['category'] == category
            tag_names = df_tags.loc[mask, 'name']
            category_pred = prediction[mask]
            category_label = category_names.get(category, '')

            # Determine threshold for this category: scalar, then dict entry
            # by id, then dict entry by name, then the model default.
            if is_scalar_threshold:
                category_threshold = thresholds
            elif isinstance(thresholds, dict) and category in thresholds:
                category_threshold = thresholds[category]
            elif isinstance(thresholds, dict) and category_label in thresholds:
                category_threshold = thresholds[category_label]
            else:
                category_threshold = default_thresholds.get(category, 0.85)

            # Apply threshold
            pred_mask = category_pred >= category_threshold
            filtered_tag_names = tag_names[pred_mask].tolist()
            filtered_predictions = category_pred[pred_mask].tolist()

            # Sort by confidence (descending), ties broken by tag name
            cate_tags = dict(sorted(
                zip(filtered_tag_names, filtered_predictions),
                key=lambda x: (-x[1], x[0])
            ))

            category_name = category_names.get(category, f"category_{category}")
            values[category_name] = cate_tags
            tags.update(cate_tags)

        values['tag'] = tags

        # Handle IPs if available
        if 'ips' in df_tags.columns:
            ips_mapping, ips_counts = {}, defaultdict(int)
            for tag in tags:
                if tag in d_ips:
                    ips_mapping[tag] = d_ips[tag]
                    for ip_name in d_ips[tag]:
                        ips_counts[ip_name] += 1
            values['ips_mapping'] = ips_mapping
            values['ips_count'] = dict(ips_counts)
            # Most frequently referenced IP first, ties broken by name
            values['ips'] = [x for x, _ in sorted(ips_counts.items(), key=lambda x: (-x[1], x[0]))]

        # Return based on format
        if fmt == 'all':
            available_categories = [category_names.get(cat, f"category_{cat}")
                                    for cat in categories]
            return tuple(values.get(cat, {}) for cat in available_categories)
        elif fmt in values:
            return values[fmt]
        else:
            return values

    except Exception as e:
        raise RuntimeError(f"Error processing image: {str(e)}") from e
352
-
353
def format_ips_output(ips_result, ips_mapping):
    """Render detected IPs and the character-to-IP mapping as one string.

    Every name has ``(`` and ``)`` backslash-escaped and ``_`` replaced by a
    space. The detected IPs come first, followed by one ``character: ip, ip``
    entry per mapping item; all pieces are joined with ``", "``. An empty
    string is returned when both arguments are empty.
    """
    if not (ips_result or ips_mapping):
        return ""

    escape_table = str.maketrans({"(": "\\(", ")": "\\)", "_": " "})

    def pretty(name):
        return name.translate(escape_table)

    pieces = []

    # Detected IPs form a single leading piece.
    if ips_result:
        pieces.append(", ".join(pretty(ip) for ip in ips_result))

    # One "character: ip, ip" piece per mapping entry.
    if ips_mapping:
        for character, owners in ips_mapping.items():
            pieces.append(f"{pretty(character)}: {', '.join(pretty(ip) for ip in owners)}")

    return ", ".join(pieces)
381
-
382
def process_single_image(
        image_path,
        model_name="deepghs/pixai-tagger-v0.9-onnx",  ###
        general_threshold=0.3,
        character_threshold=0.85,
        progress=None,
        idx=0,
        total_images=1
):
    """Process a single image and return all formatted outputs.

    Args:
        image_path: Path of the image to tag; ``None`` yields empty outputs.
        model_name: Model name or repository ID passed to ``get_pixai_tags``.
        general_threshold: Threshold used for the ``general`` category.
        character_threshold: Threshold used for the ``character`` category.
        progress: Optional callable invoked as ``progress(fraction, desc=...)``.
        idx: Zero-based position of this image, used for progress reporting.
        total_images: Number of images in the run, used for progress reporting.

    Returns:
        Tuple ``(character_output, general_output, ips_output,
        combined_output, json_data, rating_output)``. On failure the four
        text outputs carry the error message and both dicts are empty.
    """

    def _escape(name):
        # Backslash-escape parentheses and turn underscores into spaces.
        return name.replace("(", "\\(").replace(")", "\\)").replace("_", " ")

    try:
        if image_path is None:
            return "", "", "", "", {}, {}

        if progress:
            progress((idx) / total_images, desc=f"Processing image {idx+1}/{total_images}")

        # Set thresholds
        thresholds = {
            'general': general_threshold,
            'character': character_threshold
        }

        # The context manager closes the file handle once tagging is done.
        with Image.open(image_path) as pil_image:
            # Get all tag categories
            all_categories = get_pixai_tags(
                pil_image, model_name, thresholds, fmt='all'
            )

            # Any fmt that is neither 'all' nor a result key makes
            # get_pixai_tags return the complete result dict. Reading both IP
            # entries from it needs one inference instead of two, and a tag
            # table without an 'ips' column yields empty values rather than
            # the whole result dict.
            full_result = get_pixai_tags(
                pil_image, model_name, thresholds, fmt=None
            )

        # Ensure we have at least 3 categories (general, character, rating)
        all_categories = tuple(all_categories)
        all_categories += ({},) * max(0, 3 - len(all_categories))
        general_tags, character_tags, rating_tags = all_categories[:3]

        # Get IP detection data
        if not isinstance(full_result, dict):
            full_result = {}
        ips_result = full_result.get('ips') or []
        ips_mapping = full_result.get('ips_mapping') or {}

        # Format character tags (names only)
        character_names = [_escape(name) for name in character_tags.keys()]
        character_output = ", ".join(character_names)

        # Format general tags (names only)
        general_names = [_escape(name) for name in general_tags.keys()]
        general_output = ", ".join(general_names)

        # Format IP detection output
        ips_output = format_ips_output(ips_result, ips_mapping)

        # Format combined tags (Character tags first, then General tags, then IP tags)
        combined_parts = []
        if character_names:
            combined_parts.append(character_output)
        if general_names:
            combined_parts.append(general_output)
        if ips_output:
            combined_parts.append(ips_output)
        combined_output = ", ".join(combined_parts)

        # Get detailed JSON data
        json_data = {
            "character_tags": character_tags,
            "general_tags": general_tags,
            "rating_tags": rating_tags,
            "ips_result": ips_result,
            "ips_mapping": ips_mapping
        }

        # Format rating as label-compatible dict
        rating_output = {_escape(k): v for k, v in rating_tags.items()}

        return (
            character_output,  # Character tags
            general_output,    # General tags
            ips_output,        # IP Detection
            combined_output,   # Combined tags
            json_data,         # Detailed JSON
            rating_output      # Rating <- Not working atm
        )
    except Exception as e:
        error_msg = f"Error: {str(e)}"
        # Return error message for all 6 outputs
        return error_msg, error_msg, error_msg, error_msg, {}, {}  # 6
474
-
475
- """GPU"""
476
def unload_model():
    """Drop the cached model and its companion data, then reclaim memory.

    All ``CURRENT_*`` globals are reset to ``None``, a garbage collection is
    forced and, when PyTorch is importable and CUDA is available, the CUDA
    cache is emptied.
    """
    global CURRENT_MODEL, CURRENT_MODEL_NAME, CURRENT_TAGS_DF, CURRENT_D_IPS
    global CURRENT_PREPROCESS_FUNC, CURRENT_THRESHOLDS, CURRENT_CATEGORY_NAMES

    # Rebinding to None drops this module's reference to the session.
    CURRENT_MODEL = None

    # Clear the remaining cached objects together with the cache key.
    CURRENT_TAGS_DF = CURRENT_D_IPS = None
    CURRENT_PREPROCESS_FUNC = None
    CURRENT_THRESHOLDS = CURRENT_CATEGORY_NAMES = None
    CURRENT_MODEL_NAME = None

    # Force garbage collection
    import gc
    gc.collect()

    # Clear CUDA cache if using GPU; PyTorch is optional here.
    try:
        import torch
        cuda_in_use = torch.cuda.is_available()
        if cuda_in_use:
            torch.cuda.empty_cache()
    except ImportError:
        pass
501
- # print("Model unloaded and memory cleared")
502
def cleanup_after_processing():
    """Release the cached model by delegating to ``unload_model``."""
    unload_model()
504
-
505
def process_gallery_images(
        gallery,
        model_name,
        general_threshold,
        character_threshold,
        progress=gr.Progress()
):
    """Process all images in the gallery and return results with download file.

    Args:
        gallery: Gallery items; each is either a path or a list/tuple whose
            first element is the path.
        model_name: Model name or repository ID passed to the tagger.
        general_threshold: Threshold for general tags.
        character_threshold: Threshold for character tags.
        progress: Progress callback invoked as ``progress(fraction, desc=...)``.

    Returns:
        Tuple ``(tag_results, character_tags, general_tags, combined_tags,
        json_data, rating, ips_detection, download_zip_path)``. The six middle
        values belong to the first gallery image. ``download_zip_path`` is
        ``None`` when the gallery is empty or processing failed.
    """
    if not gallery:
        # Same shape as the failure return at the bottom: a dict for the
        # results state and a string for the IP-detection text output.
        return {}, "", "", "", {}, {}, "", None

    tag_results = {}
    txt_infos = []
    output_dir = tempfile.mkdtemp()  # mkdtemp() creates the directory itself

    total_images = len(gallery)
    timer = Timer()

    try:
        for idx, image_data in enumerate(gallery):
            # Bound before the inner try so the error message below can
            # always refer to it.
            image_path = image_data
            try:
                if isinstance(image_data, (list, tuple)):
                    image_path = image_data[0]

                # Process image
                results = process_single_image(
                    image_path, model_name, general_threshold, character_threshold,
                    progress, idx, total_images
                )

                # Store results
                tag_results[image_path] = {
                    'character_tags': results[0],
                    'general_tags': results[1],
                    'ips_detection': results[2],
                    'combined_tags': results[3],
                    'json_data': results[4],
                    'rating': results[5]
                }

                # Create output files with descriptive names
                image_name, image_ext = os.path.splitext(os.path.basename(image_path))

                # Save all output files with descriptive prefixes
                files_to_create = [
                    (f"character_tags-{image_name}.txt", results[0]),
                    (f"general_tags-{image_name}.txt", results[1]),
                    (f"ips_detection-{image_name}.txt", results[2]),
                    (f"combined_tags-{image_name}.txt", results[3]),
                    (f"detailed_json-{image_name}.json", json.dumps(results[4], indent=4, ensure_ascii=False))
                ]

                for file_name, content in files_to_create:
                    file_path = os.path.join(output_dir, file_name)
                    with open(file_path, 'w', encoding='utf-8') as f:
                        f.write(content)
                    txt_infos.append({'path': file_path, 'name': file_name})

                # Copy original image (re-saved through PIL under the same
                # extension); the context manager closes the source handle.
                image_copy_name = f"{image_name}{image_ext}"
                image_copy_path = os.path.join(output_dir, image_copy_name)
                with Image.open(image_path) as original_image:
                    original_image.save(image_copy_path)
                txt_infos.append({'path': image_copy_path, 'name': image_copy_name})

                timer.checkpoint(f"image{idx:02d}, processed")

            except Exception as e:
                # One broken image must not abort the whole batch.
                print(f"Error processing image {image_path}: {str(e)}")
                print(traceback.format_exc())
                continue

        # Create zip file
        download_zip_path = os.path.join(output_dir, f"Multi-Tagger-{datetime.now().strftime('%Y%m%d-%H%M%S')}.zip")
        with zipfile.ZipFile(download_zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for info in txt_infos:
                zipf.write(info['path'], arcname=info['name'])
        # If using GPU, model will auto unload after zip file creation
        cleanup_after_processing()  # Comment here to turn off this behavior

        progress(1.0, desc="Processing complete")
        timer.report_all()
        print('Processing is complete.')

        # Return first image results as default if available even if we are tagging 1000+ images.
        first_image_results = ("", "", "", {}, {}, "")  # 6
        first_image_path = gallery[0][0] if isinstance(gallery[0], (list, tuple)) else gallery[0]
        if first_image_path in tag_results:
            result = tag_results[first_image_path]
            first_image_results = (
                result['character_tags'],
                result['general_tags'],
                result['combined_tags'],
                result['json_data'],
                result['rating'],
                result['ips_detection']
            )

        return (tag_results, *first_image_results, download_zip_path)

    except Exception as e:
        print(f"Error in process_gallery_images: {str(e)}")
        print(traceback.format_exc())
        progress(1.0, desc="Processing failed")
        return {}, "", "", "", {}, {}, "", None
612
-
613
def get_selection_from_gallery(gallery, tag_results, selected_state: gr.SelectData):
    """Look up the stored tagging results for the gallery item that was selected.

    Returns the tuple ``(character_tags, general_tags, combined_tags,
    json_data, rating, ips_detection)`` for the selected image, or empty
    values when nothing is selected or no results are stored for it.
    """
    if not selected_state or not tag_results:
        return "", "", "", {}, {}, ""

    # The selection payload may be a dict with an 'image' entry, a
    # list/tuple whose first element is the path, or anything else, which is
    # then stringified.
    picked = selected_state.value
    if isinstance(picked, dict) and 'image' in picked:
        image_path = picked['image']['path']
    elif isinstance(picked, (list, tuple)) and len(picked) > 0:
        image_path = picked[0]
    else:
        image_path = str(picked)

    # Return empty if not found
    if image_path not in tag_results:
        return "", "", "", {}, {}, ""

    stored = tag_results[image_path]
    output_order = (
        'character_tags',
        'general_tags',
        'combined_tags',
        'json_data',
        'rating',
        'ips_detection',
    )
    return tuple(stored[field] for field in output_order)
641
-
642
def append_gallery(gallery, image):
    """Add a single media file (image or video) to the gallery.

    Delegates to ``handle_single_media_upload(image, gallery)`` and returns
    its result unchanged.
    """
    return handle_single_media_upload(image, gallery)
645
-
646
def extend_gallery(gallery, images):
    """Add multiple media files (images or videos) to the gallery.

    Delegates to ``handle_multiple_media_uploads(images, gallery)`` and
    returns its result unchanged.
    """
    return handle_multiple_media_uploads(images, gallery)
649
-
650
- def create_pixai_interface():
651
- """Create the PixAI Gradio interface"""
652
- with gr.Blocks(css=css, fill_width=True) as demo:
653
- # gr.Markdown("Upload anime-style images to extract tags using PixAI")
654
- # State to store results
655
- tag_results = gr.State({})
656
- selected_image = gr.Textbox(label='Selected Image', visible=False)
657
-
658
- with gr.Row():
659
- with gr.Column():
660
- # Image upload section
661
- with gr.Column(variant='panel'):
662
- image_input = gr.Image(
663
- label='Upload an Image (or paste from clipboard)',
664
- type='filepath',
665
- sources=['upload', 'clipboard'],
666
- height=150
667
- )
668
- with gr.Row():
669
- upload_button = gr.UploadButton(
670
- 'Upload multiple images or videos',
671
- file_types=['image', 'video'],
672
- file_count='multiple',
673
- size='sm'
674
- )
675
- gallery = gr.Gallery(
676
- columns=2,
677
- show_share_button=False,
678
- interactive=True,
679
- height='auto',
680
- label='Grid of images',
681
- preview=False,
682
- elem_id='custom-gallery'
683
- )
684
- run_button = gr.Button("Analyze Images", variant="primary", size='lg')
685
- model_dropdown = gr.Dropdown(
686
- choices=["deepghs/pixai-tagger-v0.9-onnx"],
687
- value="deepghs/pixai-tagger-v0.9-onnx",
688
- label="Model"
689
- )
690
- # Threshold controls
691
- with gr.Row():
692
- general_threshold = gr.Slider(
693
- minimum=0.0, maximum=1.0, value=0.30, step=0.05,
694
- label="General Tags Threshold", scale=3
695
- )
696
- character_threshold = gr.Slider(
697
- minimum=0.0, maximum=1.0, value=0.85, step=0.05,
698
- label="Character Tags Threshold", scale=3
699
- )
700
-
701
- with gr.Row():
702
- clear = gr.ClearButton(
703
- components=[gallery, model_dropdown, general_threshold, character_threshold],
704
- variant='secondary',
705
- size='lg'
706
- )
707
- clear.add([tag_results])
708
- detailed_json_output = gr.JSON(label="Detailed JSON")
709
-
710
- with gr.Column(variant='panel'):
711
-
712
- download_file = gr.File(label="Download")
713
-
714
- # Output blocks
715
- character_tags_output = gr.Textbox(
716
- label="Character tags",
717
- show_copy_button=True,
718
- lines=3
719
- )
720
- general_tags_output = gr.Textbox(
721
- label="General tags",
722
- show_copy_button=True,
723
- lines=3
724
- )
725
- ips_detection_output = gr.Textbox(
726
- label="IPs Detection",
727
- show_copy_button=True,
728
- lines=5
729
- )
730
- combined_tags_output = gr.Textbox(
731
- label="Combined tags",
732
- show_copy_button=True,
733
- lines=6
734
- )
735
- rating_output = gr.Label(label="Rating")
736
-
737
- # Clear button targets
738
- clear.add([
739
- download_file,
740
- character_tags_output,
741
- general_tags_output,
742
- ips_detection_output,
743
- combined_tags_output,
744
- rating_output,
745
- detailed_json_output
746
- ])
747
-
748
- # Event handlers
749
- image_input.change(
750
- append_gallery,
751
- inputs=[gallery, image_input],
752
- outputs=[gallery, image_input]
753
- )
754
-
755
- upload_button.upload(
756
- extend_gallery,
757
- inputs=[gallery, upload_button],
758
- outputs=gallery
759
- )
760
-
761
- gallery.select(
762
- get_selection_from_gallery,
763
- inputs=[gallery, tag_results],
764
- outputs=[
765
- character_tags_output,
766
- general_tags_output,
767
- combined_tags_output,
768
- detailed_json_output,
769
- rating_output,
770
- ips_detection_output
771
- ]
772
- )
773
-
774
- run_button.click(
775
- process_gallery_images,
776
- inputs=[gallery, model_dropdown, general_threshold, character_threshold],
777
- outputs=[
778
- tag_results,
779
- character_tags_output,
780
- general_tags_output,
781
- combined_tags_output,
782
- detailed_json_output,
783
- rating_output,
784
- ips_detection_output,
785
- download_file
786
- ]
787
- )
788
-
789
- gr.Markdown('[Based on Source code for imgutils.tagging.pixai](https://dghs-imgutils.deepghs.org/main/_modules/imgutils/tagging/pixai.html) & [pixai-labs/pixai-tagger-demo](https://huggingface.co/spaces/pixai-labs/pixai-tagger-demo)')
790
-
791
- return demo
792
-
793
- # Export public API
794
- __all__ = [
795
- 'get_pixai_tags',
796
- 'process_single_image',
797
- 'process_gallery_images',
798
- 'create_pixai_interface',
799
- 'unload_model',
800
- 'cleanup_after_processing'
801
- ]
 
 
 
1
+ import os, json, zipfile, tempfile, time, traceback
2
+ import gradio as gr
3
+ import pandas as pd
4
+ import numpy as np
5
+ import onnxruntime as ort
6
+ from collections import defaultdict
7
+ from typing import Union, Dict, Any, Tuple, List
8
+ from PIL import Image
9
+ from huggingface_hub import hf_hub_download
10
+ from huggingface_hub.errors import EntryNotFoundError
11
+ from datetime import datetime
12
+ from modules.media_handler import handle_single_media_upload, handle_multiple_media_uploads
13
+
14
+ # Global variables for model components (for memory management)
15
+ CURRENT_MODEL = None
16
+ CURRENT_MODEL_NAME = None
17
+ CURRENT_TAGS_DF = None
18
+ CURRENT_D_IPS = None
19
+ CURRENT_PREPROCESS_FUNC = None
20
+ CURRENT_THRESHOLDS = None
21
+ CURRENT_CATEGORY_NAMES = None
22
+
23
+ css = """
24
+ #custom-gallery {--row-height: 180px;display: grid;grid-auto-rows: min-content;gap: 10px;}
25
+ #custom-gallery .thumbnail-item {height: var(--row-height);width: 100%;position: relative;overflow: hidden;border-radius: 8px;box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);transition: transform 0.2s ease, box-shadow 0.2s ease;}
26
+ #custom-gallery .thumbnail-item:hover {transform: translateY(-3px);box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15);}
27
+ #custom-gallery .thumbnail-item img {width: auto;height: 100%;max-width: 100%;max-height: var(--row-height);object-fit: contain;margin: 0 auto;display: block;}
28
+ #custom-gallery .thumbnail-item img.portrait {max-width: 100%;}
29
+ #custom-gallery .thumbnail-item img.landscape {max-height: 100%;}
30
+ .gallery-container {max-height: 500px;overflow-y: auto;padding-right: 0px;--size-80: 500px;}
31
+ .thumbnails {display: flex;position: absolute;bottom: 0;width: 120px;overflow-x: scroll;padding-top: 320px;padding-bottom: 280px;padding-left: 4px;flex-wrap: wrap;}
32
+ #custom-gallery .thumbnail-item img {width: auto;height: 100%;max-width: 100%;max-height: var(--row-height);object-fit: initial;width: fit-content;margin: 0px auto;display: block;}
33
+ """
34
+
35
def preprocess_on_gpu(img, device='cuda'):
    """Resize and normalize a PIL image with torchvision and return a NumPy batch.

    Args:
        img: PIL image to preprocess. Non-RGB images are converted to RGB.
        device: Torch device the tensor is moved to when CUDA is available.

    Returns:
        A NumPy array of shape ``(1, 3, 448, 448)``.
    """
    # Imported lazily: torch/torchvision are only needed by this helper.
    import torch
    import torchvision.transforms as transforms

    # Normalize() below is configured with three channel means/stds, so
    # grayscale, palette and RGBA inputs must be converted first.
    if img.mode != 'RGB':
        img = img.convert('RGB')

    transform = transforms.Compose([
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073],
            std=[0.26862954, 0.26130258, 0.27577711],
        ),
    ])
    tensor_img = transform(img).unsqueeze(0)  # add the batch dimension

    # NOTE: the resize/normalize above already ran on the CPU; the tensor is
    # only moved to `device` here and is copied straight back for NumPy.
    if torch.cuda.is_available():
        tensor_img = tensor_img.to(device)
    return tensor_img.cpu().numpy()
46
+
47
class Timer:
    """Record labelled checkpoints and print the time elapsed between them."""

    def __init__(self):
        # The first entry is the baseline; it is never printed as an interval.
        started = time.perf_counter()
        self.start_time = started
        self.checkpoints = [('Start', started)]

    def checkpoint(self, label='Checkpoint'):
        """Append a checkpoint named ``label`` stamped with the current time."""
        self.checkpoints.append((label, time.perf_counter()))

    def _label_width(self):
        """Return the longest recorded label length (0 when nothing is recorded)."""
        return max((len(name) for name, _ in self.checkpoints), default=0)

    def _print_intervals(self, baseline, width):
        """Print the time between consecutive checkpoints, skipping the first entry."""
        previous = baseline
        for name, stamp in self.checkpoints[1:]:
            print(f"{name.ljust(width)}: {stamp - previous:.3f} seconds")
            previous = stamp

    def report(self, is_clear_checkpoints=True):
        """Print the intervals since the first recorded checkpoint.

        When ``is_clear_checkpoints`` is true the list is emptied afterwards
        and a fresh checkpoint is stored as the new baseline.
        """
        width = self._label_width()
        if self.checkpoints:
            baseline = self.checkpoints[0][1]
        else:
            baseline = self.start_time
        self._print_intervals(baseline, width)

        if is_clear_checkpoints:
            self.checkpoints.clear()
            self.checkpoint()

    def report_all(self):
        """Print every interval plus the total since ``start_time``, then clear."""
        print('\n> Execution Time Report:')
        width = self._label_width()
        self._print_intervals(self.start_time, width)

        if self.checkpoints:
            total_time = self.checkpoints[-1][1] - self.start_time
        else:
            total_time = 0
        print(f"{'Total Execution Time'.ljust(width)}: {total_time:.3f} seconds\n")  # Performance tests
        self.checkpoints.clear()

    def restart(self):
        """Reset the start time and drop all recorded checkpoints."""
        started = time.perf_counter()
        self.start_time = started
        self.checkpoints = [('Start', started)]
86
+
87
+ def _get_repo_id(model_name: str) -> str:
88
+ """Get the repository ID for the specified model name."""
89
+ if '/' in model_name:
90
+ return model_name
91
+ else:
92
+ return f'deepghs/pixai-tagger-{model_name}-onnx'
93
+
94
def _download_model_files(model_name: str):
    """Fetch the model's files from the Hugging Face Hub.

    Returns:
        ``(model_path, tags_path, preprocess_path, thresholds_path)``.
        ``thresholds_path`` is ``None`` when the repository has no
        ``thresholds.csv``; the other three files are required.
    """
    repo_id = _get_repo_id(model_name)

    def fetch(filename):
        # Every file comes from the same repo with the same library tag.
        return hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            library_name="pixai-tagger"
        )

    model_path = fetch('model.onnx')
    tags_path = fetch('selected_tags.csv')
    preprocess_path = fetch('preprocess.json')

    # thresholds.csv is optional.
    try:
        thresholds_path = fetch('thresholds.csv')
    except EntryNotFoundError:
        thresholds_path = None

    return model_path, tags_path, preprocess_path, thresholds_path
124
+
125
def create_optimized_ort_session(model_path):
    """Build an ONNX Runtime session for *model_path*, preferring CUDA over CPU.

    The CUDA provider is requested only when ONNX Runtime lists it as
    available; the CPU provider is always appended as a fallback. Any failure
    while creating the session is logged and re-raised.
    """
    # Session-level tuning
    options = ort.SessionOptions()
    options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    options.intra_op_num_threads = 0  # 0 leaves the thread count to ONNX Runtime
    options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
    options.enable_mem_pattern = True
    options.enable_cpu_mem_arena = True

    available_providers = ort.get_available_providers()
    print(f"Available ONNX Runtime providers: {available_providers}")

    # Providers are listed in order of preference.
    providers = []
    if 'CUDAExecutionProvider' not in available_providers:
        print("CUDA provider not available, falling back to CPU")
    else:
        cuda_options = {
            'device_id': 0,
            'arena_extend_strategy': 'kNextPowerOfTwo',
            'gpu_mem_limit': 4 * 1024 * 1024 * 1024,  # 4GB VRAM
            'cudnn_conv_algo_search': 'EXHAUSTIVE',
            'do_copy_in_default_stream': True,
        }
        providers.append(('CUDAExecutionProvider', cuda_options))
        print("Using CUDA provider for ONNX inference")

    # CPU is always present as the last resort (FOR HF)
    providers.append('CPUExecutionProvider')

    try:
        session = ort.InferenceSession(model_path, options, providers=providers)
        print(f"Model loaded with providers: {session.get_providers()}")
        return session
    except Exception as e:
        print(f"Failed to create ONNX session: {e}")
        raise
166
+
167
def _preprocess_image(img):
    """Turn a PIL image into the channel-first float32 array fed to the model.

    Steps: convert to RGB, resize to 448x448 with Lanczos resampling, scale
    pixel values to [0, 1], normalize per channel, and transpose HWC -> CHW.
    """
    if img.mode != 'RGB':
        img = img.convert('RGB')

    # Resize to 448x448 <- Very important.
    img = img.resize((448, 448), Image.Resampling.LANCZOS)

    img_array = np.array(img).astype(np.float32) / 255.0

    # These are the CLIP normalization constants (not the ImageNet ones).
    mean = np.array([0.48145466, 0.4578275, 0.40821073]).astype(np.float32)
    std = np.array([0.26862954, 0.26130258, 0.27577711]).astype(np.float32)
    img_array = (img_array - mean) / std

    # (H, W, C) -> (C, H, W)
    return np.transpose(img_array, (2, 0, 1))


def _load_model_components_optimized(model_name: str):
    """Load (or reuse) the model session and its metadata for *model_name*.

    The components are cached in the module-level ``CURRENT_*`` globals and are
    only rebuilt when *model_name* differs from ``CURRENT_MODEL_NAME``.
    Everything is loaded into locals first and the globals are assigned
    together at the end, so a failure part-way through leaves the previous
    cache intact instead of mixing components from two models.

    Returns:
        ``(model, tags_df, d_ips, preprocess_func, thresholds, category_names)``.
    """
    global CURRENT_MODEL, CURRENT_MODEL_NAME, CURRENT_TAGS_DF, CURRENT_D_IPS
    global CURRENT_PREPROCESS_FUNC, CURRENT_THRESHOLDS, CURRENT_CATEGORY_NAMES

    # Only reload if model changed
    if CURRENT_MODEL_NAME != model_name:
        # preprocess.json is still downloaded, but its contents were never
        # used: the preprocessing pipeline is the fixed _preprocess_image.
        model_path, tags_path, _preprocess_path, thresholds_path = _download_model_files(model_name)

        model = create_optimized_ort_session(model_path)

        # Tag table, plus the optional tag -> IPs lookup
        tags_df = pd.read_csv(tags_path)
        d_ips = {}
        if 'ips' in tags_df.columns:
            tags_df['ips'] = tags_df['ips'].fillna('{}').map(json.loads)
            d_ips = {
                name: ips
                for name, ips in zip(tags_df['name'], tags_df['ips'])
                if ips
            }

        # Per-category thresholds and display names
        thresholds = {}
        category_names = {}
        if thresholds_path and os.path.exists(thresholds_path):
            df_category_thresholds = pd.read_csv(thresholds_path, keep_default_na=False)
            for item in df_category_thresholds.to_dict('records'):
                # First row wins when a category appears more than once.
                if item['category'] not in thresholds:
                    thresholds[item['category']] = item['threshold']
                    category_names[item['category']] = item['name']
        else:
            # Default thresholds if file doesn't exist
            thresholds = {0: 0.3, 4: 0.85, 9: 0.85}
            category_names = {0: 'general', 4: 'character', 9: 'rating'}

        # Commit the fully-loaded set in one go.
        CURRENT_MODEL = model
        CURRENT_TAGS_DF = tags_df
        CURRENT_D_IPS = d_ips
        CURRENT_PREPROCESS_FUNC = _preprocess_image
        CURRENT_THRESHOLDS = thresholds
        CURRENT_CATEGORY_NAMES = category_names
        CURRENT_MODEL_NAME = model_name

    return (CURRENT_MODEL, CURRENT_TAGS_DF, CURRENT_D_IPS, CURRENT_PREPROCESS_FUNC,
            CURRENT_THRESHOLDS, CURRENT_CATEGORY_NAMES)
237
+
238
def _raw_predict(image: Image.Image, model_name: str):
    """Run the tagger on one PIL image and return its raw outputs.

    Args:
        image: PIL image to tag.
        model_name: Model identifier passed to the component loader.

    Returns:
        Dict mapping each model output name to that output's first batch entry.

    Raises:
        RuntimeError: If validation, preprocessing or inference fails; the
            original exception is attached as ``__cause__``.
    """
    try:
        # Ensure we have a PIL Image
        if not isinstance(image, Image.Image):
            raise ValueError("Input must be a PIL Image")  # <-

        model, _, _, preprocess_func, _, _ = _load_model_components_optimized(model_name)

        input_tensor = preprocess_func(image)

        # Add batch dimension
        if len(input_tensor.shape) == 3:
            input_tensor = np.expand_dims(input_tensor, axis=0)

        # Ask the session for its input name instead of assuming 'input', so
        # exports that name the tensor differently still work.
        input_name = model.get_inputs()[0].name
        output_names = [output.name for output in model.get_outputs()]
        output_values = model.run(output_names, {input_name: input_tensor.astype(np.float32)})

        return {name: value[0] for name, value in zip(output_names, output_values)}

    except Exception as e:
        raise RuntimeError(f"Error processing image: {str(e)}") from e
263
+
264
def get_pixai_tags(
    image: Union[str, Image.Image],
    model_name: str = 'deepghs/pixai-tagger-v0.9-onnx',
    thresholds: Union[float, Dict[Any, float], None] = None,
    fmt='all'
):
    """Tag an image with the PixAI tagger.

    Args:
        image: File path or PIL image.
        model_name: Model identifier passed to the component loader.
        thresholds: ``None`` to use the model's per-category defaults, a single
            number applied to every category, or a dict keyed by category id
            or category name. Categories missing from the dict use the default.
        fmt: ``'all'`` returns a tuple with one ``{tag: score}`` dict per
            category, ordered by category id. Any other key present in the
            result dict (a category name, ``'tag'``, ``'ips'``,
            ``'ips_mapping'``, ``'ips_count'``, a raw model output name)
            returns that entry. Anything else returns the whole result dict.

    Raises:
        RuntimeError: On any failure; the original exception is chained.
    """
    try:
        if isinstance(image, str):
            # Close the file once inference has consumed the pixels.
            with Image.open(image) as pil_image:
                values = _raw_predict(pil_image, model_name)
        elif isinstance(image, Image.Image):
            values = _raw_predict(image, model_name)
        else:
            raise ValueError("Image must be a file path or PIL Image")

        _, df_tags, d_ips, _, default_thresholds, category_names = _load_model_components_optimized(model_name)

        prediction = values.get('prediction', np.array([]))
        if prediction.size == 0:
            raise RuntimeError("Model did not return valid predictions")

        # An int such as 0 or 1 is a valid scalar threshold too; bool is excluded.
        has_scalar_threshold = isinstance(thresholds, (int, float)) and not isinstance(thresholds, bool)
        threshold_map = thresholds if isinstance(thresholds, dict) else {}

        categories = sorted(set(df_tags['category'].tolist()))
        tags = {}

        # Process tags by category
        for category in categories:
            mask = df_tags['category'] == category
            tag_names = df_tags.loc[mask, 'name']
            category_pred = prediction[mask]

            # Determine threshold for this category
            display_name = category_names.get(category)
            if has_scalar_threshold:
                category_threshold = thresholds
            elif category in threshold_map:
                category_threshold = threshold_map[category]
            elif display_name is not None and display_name in threshold_map:
                category_threshold = threshold_map[display_name]
            else:
                category_threshold = default_thresholds.get(category, 0.85)

            # Apply threshold
            pred_mask = category_pred >= category_threshold
            filtered_tag_names = tag_names[pred_mask].tolist()
            filtered_predictions = category_pred[pred_mask].tolist()

            # Highest score first, ties broken alphabetically
            cate_tags = dict(sorted(
                zip(filtered_tag_names, filtered_predictions),
                key=lambda x: (-x[1], x[0])
            ))

            category_name = category_names.get(category, f"category_{category}")
            values[category_name] = cate_tags
            tags.update(cate_tags)

        values['tag'] = tags

        # IP data. The keys are always present so that fmt='ips' and
        # fmt='ips_mapping' return an empty result, rather than the whole
        # result dict, when the tag table has no 'ips' column.
        ips_mapping, ips_counts = {}, defaultdict(int)
        if 'ips' in df_tags.columns:
            for tag in tags:
                if tag in d_ips:
                    ips_mapping[tag] = d_ips[tag]
                    for ip_name in d_ips[tag]:
                        ips_counts[ip_name] += 1
        values['ips_mapping'] = ips_mapping
        values['ips_count'] = dict(ips_counts)
        values['ips'] = [x for x, _ in sorted(ips_counts.items(), key=lambda x: (-x[1], x[0]))]

        # Return based on format
        if fmt == 'all':
            available_categories = [category_names.get(cat, f"category_{cat}") for cat in categories]
            return tuple(values.get(cat, {}) for cat in available_categories)
        elif fmt in values:
            return values[fmt]
        else:
            return values

    except Exception as e:
        raise RuntimeError(f"Error processing image: {str(e)}") from e
352
+
353
def format_ips_output(ips_result, ips_mapping):
    """Render detected IPs and the character-to-IP mapping as one comma-separated string.

    Parentheses are backslash-escaped and underscores become spaces. Detected
    IP names come first, followed by one ``character: ip, ip`` entry per
    mapping item. Returns ``""`` when both inputs are empty.
    """
    if not (ips_result or ips_mapping):
        return ""

    def _escape(text):
        return text.replace("(", "\\(").replace(")", "\\)").replace("_", " ")

    # Detected IPs
    detected = [_escape(ip) for ip in ips_result] if ips_result else []

    # Character-to-IP mapping entries
    mapped = []
    if ips_mapping:
        for char, ips in ips_mapping.items():
            joined_ips = ', '.join([_escape(ip) for ip in ips])
            mapped.append(f"{_escape(char)}: {joined_ips}")

    segments = []
    if detected:
        segments.append(", ".join(detected))
    if mapped:
        segments.extend(mapped)

    return ", ".join(segments)
381
+
382
def process_single_image(
    image_path,
    model_name="deepghs/pixai-tagger-v0.9-onnx",  ###
    general_threshold=0.3,
    character_threshold=0.85,
    progress=None,
    idx=0,
    total_images=1
):
    """Tag one image file and return every formatted output.

    The model is run once per image; the category tags and the IP data are all
    read from that single result (previously inference ran three times).

    Args:
        image_path: Path of the image, or ``None``.
        model_name: Model identifier forwarded to the tagger.
        general_threshold: Score threshold for the ``general`` category.
        character_threshold: Score threshold for the ``character`` category.
        progress: Optional callable invoked as ``progress(fraction, desc=...)``.
        idx: Zero-based index of this image, used for the progress message.
        total_images: Number of images, used for the progress fraction.

    Returns:
        ``(character_output, general_output, ips_output, combined_output,
        json_data, rating_output)``. For ``image_path=None`` the strings are
        empty and the dicts are ``{}``. On failure the four strings hold
        ``"Error: ..."`` and the two dicts are empty.
    """
    def _escape(name):
        return name.replace("(", "\\(").replace(")", "\\)").replace("_", " ")

    try:
        if image_path is None:
            return "", "", "", "", {}, {}

        if progress:
            progress((idx)/total_images, desc=f"Processing image {idx+1}/{total_images}")

        thresholds = {
            'general': general_threshold,
            'character': character_threshold
        }

        # fmt=None matches no key, so get_pixai_tags returns its complete
        # result dict. The file is closed as soon as inference is done.
        with Image.open(image_path) as pil_image:
            values = get_pixai_tags(pil_image, model_name, thresholds, fmt=None)

        # Rebuild the per-category tuple the same way fmt='all' does:
        # one dict per category, ordered by category id.
        _, df_tags, _, _, _, category_names = _load_model_components_optimized(model_name)
        all_categories = tuple(
            values.get(category_names.get(cat, f"category_{cat}"), {})
            for cat in sorted(set(df_tags['category'].tolist()))
        )

        # Ensure we have at least 3 categories (general, character, rating)
        while len(all_categories) < 3:
            all_categories += ({},)
        general_tags, character_tags, rating_tags = all_categories[:3]

        # IP detection data from the same inference result
        ips_result = values.get('ips') or []
        ips_mapping = values.get('ips_mapping') or {}

        # Names only, escaped for prompt use
        character_names = [_escape(name) for name in character_tags.keys()]
        character_output = ", ".join(character_names)

        general_names = [_escape(name) for name in general_tags.keys()]
        general_output = ", ".join(general_names)

        ips_output = format_ips_output(ips_result, ips_mapping)

        # Combined tags: character first, then general, then IP tags
        combined_parts = []
        if character_names:
            combined_parts.append(", ".join(character_names))
        if general_names:
            combined_parts.append(", ".join(general_names))
        if ips_output:
            combined_parts.append(ips_output)
        combined_output = ", ".join(combined_parts)

        json_data = {
            "character_tags": character_tags,
            "general_tags": general_tags,
            "rating_tags": rating_tags,
            "ips_result": ips_result,
            "ips_mapping": ips_mapping
        }

        # Rating as a label-compatible dict
        rating_output = {_escape(k): v for k, v in rating_tags.items()}

        return (
            character_output,   # Character tags
            general_output,     # General tags
            ips_output,         # IP Detection
            combined_output,    # Combined tags
            json_data,          # Detailed JSON
            rating_output       # Rating <- Not working atm
        )
    except Exception as e:
        error_msg = f"Error: {str(e)}"
        # Return error message for all 6 outputs
        return error_msg, error_msg, error_msg, error_msg, {}, {}  # 6
474
+
475
+ """GPU"""
476
def unload_model():
    """Drop the cached model and its metadata, then reclaim memory.

    Resets every ``CURRENT_*`` global to ``None``, runs the garbage collector,
    and empties the CUDA cache when PyTorch is installed and CUDA is available.
    """
    global CURRENT_MODEL, CURRENT_MODEL_NAME, CURRENT_TAGS_DF, CURRENT_D_IPS
    global CURRENT_PREPROCESS_FUNC, CURRENT_THRESHOLDS, CURRENT_CATEGORY_NAMES
    import gc

    # Releasing the reference is what allows the session to be collected.
    CURRENT_MODEL = None
    # Metadata tied to the model goes with it.
    CURRENT_TAGS_DF = CURRENT_D_IPS = None
    CURRENT_PREPROCESS_FUNC = None
    CURRENT_THRESHOLDS = CURRENT_CATEGORY_NAMES = None
    CURRENT_MODEL_NAME = None

    gc.collect()

    # torch is optional; skip the CUDA cache flush when it is not installed.
    try:
        import torch
    except ImportError:
        return
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
502
def cleanup_after_processing():
    """Free model memory once a batch is done; delegates to ``unload_model``."""
    unload_model()
504
+
505
def process_gallery_images(
    gallery,
    model_name,
    general_threshold,
    character_threshold,
    progress=gr.Progress()
):
    """Tag every image in the gallery and bundle the results into a zip file.

    For each image the tag text files, a JSON file and a copy of the source
    image are written to a fresh temporary directory, which is then zipped.
    An image that fails is logged and skipped. The model is unloaded after the
    zip has been created.

    Returns:
        An 8-tuple matching the UI outputs: ``(tag_results, character_tags,
        general_tags, combined_tags, json_data, rating, ips_detection,
        zip_path)``. The six middle values belong to the first gallery image.
        For an empty gallery or on failure they are empty and ``zip_path`` is
        ``None``.
    """
    import shutil

    # Same shape and types as the failure return at the bottom.
    if not gallery:
        return {}, "", "", "", {}, {}, "", None

    tag_results = {}
    txt_infos = []
    output_dir = tempfile.mkdtemp()  # mkdtemp already creates the directory

    total_images = len(gallery)
    timer = Timer()

    try:
        for idx, image_data in enumerate(gallery):
            image_path = None  # keeps the error message below safe if unpacking fails
            try:
                image_path = image_data[0] if isinstance(image_data, (list, tuple)) else image_data

                results = process_single_image(
                    image_path, model_name, general_threshold, character_threshold,
                    progress, idx, total_images
                )

                tag_results[image_path] = {
                    'character_tags': results[0],
                    'general_tags': results[1],
                    'ips_detection': results[2],
                    'combined_tags': results[3],
                    'json_data': results[4],
                    'rating': results[5]
                }

                image_name = os.path.splitext(os.path.basename(image_path))[0]
                image_ext = os.path.splitext(image_path)[1]

                # Output files with descriptive prefixes
                files_to_create = [
                    (f"character_tags-{image_name}.txt", results[0]),
                    (f"general_tags-{image_name}.txt", results[1]),
                    (f"ips_detection-{image_name}.txt", results[2]),
                    (f"combined_tags-{image_name}.txt", results[3]),
                    (f"detailed_json-{image_name}.json", json.dumps(results[4], indent=4, ensure_ascii=False))
                ]

                for file_name, content in files_to_create:
                    file_path = os.path.join(output_dir, file_name)
                    with open(file_path, 'w', encoding='utf-8') as f:
                        f.write(content)
                    txt_infos.append({'path': file_path, 'name': file_name})

                # Copy the original file byte-for-byte. Re-saving through PIL
                # re-encoded the image and left the source handle open.
                image_copy_name = f"{image_name}{image_ext}"
                image_copy_path = os.path.join(output_dir, image_copy_name)
                shutil.copy2(image_path, image_copy_path)
                txt_infos.append({'path': image_copy_path, 'name': image_copy_name})

                timer.checkpoint(f"image{idx:02d}, processed")

            except Exception as e:
                print(f"Error processing image {image_path}: {str(e)}")
                print(traceback.format_exc())
                continue

        # Create zip file
        download_zip_path = os.path.join(output_dir, f"Multi-Tagger-{datetime.now().strftime('%Y%m%d-%H%M%S')}.zip")
        with zipfile.ZipFile(download_zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for info in txt_infos:
                zipf.write(info['path'], arcname=info['name'])

        # If using GPU, model will auto unload after zip file creation
        cleanup_after_processing()  # Comment here to turn off this behavior

        progress(1.0, desc="Processing complete")
        timer.report_all()
        print('Processing is complete.')

        # Show the first image's results by default, however many were tagged.
        first_image_results = ("", "", "", {}, {}, "")  # 6
        first_image_path = gallery[0][0] if isinstance(gallery[0], (list, tuple)) else gallery[0]
        if first_image_path in tag_results:
            result = tag_results[first_image_path]
            first_image_results = (
                result['character_tags'],
                result['general_tags'],
                result['combined_tags'],
                result['json_data'],
                result['rating'],
                result['ips_detection']
            )

        return (tag_results, *first_image_results, download_zip_path)

    except Exception as e:
        print(f"Error in process_gallery_images: {str(e)}")
        print(traceback.format_exc())
        progress(1.0, desc="Processing failed")
        return {}, "", "", "", {}, {}, "", None
613
def get_selection_from_gallery(gallery, tag_results, selected_state: gr.SelectData):
    """Look up the stored results for the gallery item the user clicked.

    Returns ``(character_tags, general_tags, combined_tags, json_data, rating,
    ips_detection)`` for the selected image, or empty values when nothing is
    selected, nothing has been processed, or the image has no stored result.
    """
    if not (selected_state and tag_results):
        return "", "", "", {}, {}, ""

    # The selection payload is a dict with an 'image' entry, a sequence whose
    # first item is the path, or something that stringifies to the path.
    picked = selected_state.value
    if isinstance(picked, dict) and 'image' in picked:
        image_path = picked['image']['path']
    elif isinstance(picked, (list, tuple)) and len(picked) > 0:
        image_path = picked[0]
    else:
        image_path = str(picked)

    if image_path not in tag_results:
        # Nothing stored for this image
        return "", "", "", {}, {}, ""

    stored = tag_results[image_path]
    field_order = (
        'character_tags',
        'general_tags',
        'combined_tags',
        'json_data',
        'rating',
        'ips_detection',
    )
    return tuple(stored[field] for field in field_order)
641
+
642
def append_gallery(gallery, image):
    """Add one uploaded media file to the gallery via ``handle_single_media_upload``."""
    updated = handle_single_media_upload(image, gallery)
    return updated
645
+
646
def extend_gallery(gallery, images):
    """Add several uploaded media files to the gallery via ``handle_multiple_media_uploads``."""
    updated = handle_multiple_media_uploads(images, gallery)
    return updated
649
+
650
def create_pixai_interface():
    """Build and return the Gradio Blocks UI for the PixAI tagger.

    The left column holds the upload widgets, gallery, model selector,
    threshold sliders, clear buttons and the detailed JSON view; the right
    column holds the download link and the tag outputs.
    """
    # NOTE(review): the nesting of the `with` containers below was
    # reconstructed from the statement order; confirm it against the
    # rendered layout.
    with gr.Blocks(css=css, fill_width=True) as demo:
        # gr.Markdown("Upload anime-style images to extract tags using PixAI")
        # Per-image results, keyed by image path
        tag_results = gr.State({})
        selected_image = gr.Textbox(label='Selected Image', visible=False)

        with gr.Row():
            with gr.Column():
                # Image upload section
                with gr.Column(variant='panel'):
                    image_input = gr.Image(
                        label='Upload an Image (or paste from clipboard)',
                        type='filepath',
                        sources=['upload', 'clipboard'],
                        height=150
                    )
                    with gr.Row():
                        upload_button = gr.UploadButton(
                            'Upload multiple images or videos',
                            file_types=['image', 'video'],
                            file_count='multiple',
                            size='sm'
                        )
                    gallery = gr.Gallery(
                        columns=2,
                        show_share_button=False,
                        interactive=True,
                        height='auto',
                        label='Grid of images',
                        preview=False,
                        elem_id='custom-gallery'
                    )
                run_button = gr.Button("Analyze Images", variant="primary", size='lg')
                # Clears only the gallery
                clear_gallery_button = gr.ClearButton(
                    components=[gallery], value='Clear Gallery', variant='secondary', size='sm'
                )
                model_dropdown = gr.Dropdown(
                    choices=["deepghs/pixai-tagger-v0.9-onnx"],
                    value="deepghs/pixai-tagger-v0.9-onnx",
                    label="Model"
                )
                # Threshold controls
                with gr.Row():
                    general_threshold = gr.Slider(
                        minimum=0.0, maximum=1.0, value=0.30, step=0.05,
                        label="General Tags Threshold", scale=3
                    )
                    character_threshold = gr.Slider(
                        minimum=0.0, maximum=1.0, value=0.85, step=0.05,
                        label="Character Tags Threshold", scale=3
                    )

                with gr.Row():
                    # Clears the inputs now and the outputs registered further down
                    clear_all_button = gr.ClearButton(
                        components=[gallery, model_dropdown, general_threshold, character_threshold],
                        value="Clear Everything",
                        variant='secondary',
                        size='lg'
                    )
                    clear_all_button.add([tag_results])
                detailed_json_output = gr.JSON(label="Detailed JSON")

            with gr.Column(variant='panel'):

                download_file = gr.File(label="Download")

                # Output blocks
                character_tags_output = gr.Textbox(
                    label="Character tags",
                    show_copy_button=True,
                    lines=3
                )
                general_tags_output = gr.Textbox(
                    label="General tags",
                    show_copy_button=True,
                    lines=3
                )
                ips_detection_output = gr.Textbox(
                    label="IPs Detection",
                    show_copy_button=True,
                    lines=5
                )
                combined_tags_output = gr.Textbox(
                    label="Combined tags",
                    show_copy_button=True,
                    lines=6
                )
                rating_output = gr.Label(label="Rating")

                # Clear button targets
                clear_all_button.add([
                    download_file,
                    character_tags_output,
                    general_tags_output,
                    ips_detection_output,
                    combined_tags_output,
                    rating_output,
                    detailed_json_output
                ])

        # Event handlers
        image_input.change(
            append_gallery,
            inputs=[gallery, image_input],
            outputs=[gallery, image_input]
        )

        upload_button.upload(
            extend_gallery,
            inputs=[gallery, upload_button],
            outputs=gallery
        )

        # Order of the six per-image outputs shared by both handlers below
        per_image_outputs = [
            character_tags_output,
            general_tags_output,
            combined_tags_output,
            detailed_json_output,
            rating_output,
            ips_detection_output
        ]

        gallery.select(
            get_selection_from_gallery,
            inputs=[gallery, tag_results],
            outputs=list(per_image_outputs)
        )

        run_button.click(
            process_gallery_images,
            inputs=[gallery, model_dropdown, general_threshold, character_threshold],
            outputs=[tag_results, *per_image_outputs, download_file]
        )

        gr.Markdown('[Based on Source code for imgutils.tagging.pixai](https://dghs-imgutils.deepghs.org/main/_modules/imgutils/tagging/pixai.html) & [pixai-labs/pixai-tagger-demo](https://huggingface.co/spaces/pixai-labs/pixai-tagger-demo)')

    return demo
794
+
795
+ # Export public API
796
# Names re-exported by `from modules.pixai import *`.
__all__ = [
    'get_pixai_tags',
    'process_single_image',
    'process_gallery_images',
    'create_pixai_interface',
    'unload_model',
    'cleanup_after_processing',
]