# Provenance (scraped from the Hugging Face web UI, not part of the program):
# uploaded by Hydragee — "Upload folder using huggingface_hub" — commit 772b344 (verified)
import os
import json
import pandas as pd
import numpy as np
import onnxruntime as ort
import huggingface_hub
from PIL import Image
from PIL.Image import Resampling
from modules.taggers.base import TaggerProcessor
# --- PixAI tagger: module-level cache ---
# Every slot starts out empty (None) and is populated by the model-loading
# function, which declares all of them `global` and assigns them together.
PIXAI_MODEL: "ort.InferenceSession | None" = None  # ONNX Runtime session
PIXAI_MODEL_NAME: "str | None" = None  # repo id the cached components belong to
PIXAI_TAGS_DF: "pd.DataFrame | None" = None  # parsed selected_tags.csv
PIXAI_D_IPS: "dict | None" = None
PIXAI_PREPROCESS_FUNC = None  # callable mapping a PIL image to a model input array
PIXAI_THRESHOLDS: "dict | None" = None  # category id -> default score threshold
PIXAI_CATEGORY_NAMES: "dict | None" = None  # category id -> category name
def _download_pixai_files(model_name: str):
    """Download the four PixAI tagger artifacts from the Hugging Face Hub.

    Args:
        model_name: a full repo id when it contains a ``/``; otherwise a short
            version string expanded to ``deepghs/pixai-tagger-<model_name>-onnx``.

    Returns:
        A 4-tuple of the paths returned by ``hf_hub_download``, in the order
        ``model.onnx``, ``selected_tags.csv``, ``preprocess.json``, ``thresholds.csv``.
    """
    if '/' in model_name:
        repo_id = model_name
    else:
        repo_id = f'deepghs/pixai-tagger-{model_name}-onnx'

    artifact_names = ('model.onnx', 'selected_tags.csv', 'preprocess.json', 'thresholds.csv')
    # Files are fetched one after another, in the order listed above.
    return tuple(
        huggingface_hub.hf_hub_download(repo_id=repo_id, filename=artifact, library_name="pixai-tagger")
        for artifact in artifact_names
    )
def _select_pixai_providers(device_pref: str) -> list:
    """Return the ONNX Runtime execution providers to request for ``device_pref``.

    Only the exact value ``"CUDA"`` asks for the GPU provider (with CPU as
    fallback); any other value, including ``"Auto"``, requests CPU only.
    Providers not shipped by the installed onnxruntime build are dropped so
    the session is never asked for something it cannot provide.
    """
    if device_pref == "CUDA":
        requested = ['CUDAExecutionProvider', 'CPUExecutionProvider']
    else:
        requested = ['CPUExecutionProvider']
    available = set(ort.get_available_providers())
    return [p for p in requested if p in available] or ['CPUExecutionProvider']


def _build_pixai_transform(size=(448, 448),
                           mean=(0.48145466, 0.4578275, 0.40821073),
                           std=(0.26862954, 0.26130258, 0.27577711)):
    """Return a callable converting a PIL image into a normalised CHW float32 array.

    The image is converted to RGB, resized with Lanczos to ``size``, scaled
    to [0, 1], normalised per channel with ``mean``/``std``, then transposed
    from HWC to CHW.
    """
    # Built once per transform, not on every call.
    mean_arr = np.asarray(mean, dtype=np.float32)
    std_arr = np.asarray(std, dtype=np.float32)

    def transform(img):
        if img.mode != 'RGB':
            img = img.convert('RGB')
        img = img.resize(size, Resampling.LANCZOS)
        img_array = np.asarray(img, dtype=np.float32) / 255.0
        return np.transpose((img_array - mean_arr) / std_arr, (2, 0, 1))

    return transform


def _load_pixai_thresholds(th_path):
    """Return ``(thresholds, category_names)`` dicts keyed by category id.

    Reads the ``category``, ``threshold`` and ``name`` columns of
    ``thresholds.csv`` when the file exists; otherwise returns built-in
    defaults for categories 0, 4 and 9.
    """
    if th_path and os.path.exists(th_path):
        df_th = pd.read_csv(th_path)
        thresholds = dict(zip(df_th['category'], df_th['threshold']))
        category_names = dict(zip(df_th['category'], df_th['name']))
        return thresholds, category_names
    return {0: 0.3, 4: 0.85, 9: 0.85}, {0: 'general', 4: 'character', 9: 'rating'}


def _load_pixai_model_components(device_pref: str):
    """Load (once) and return the cached PixAI tagger components.

    Components are cached in the module-level ``PIXAI_*`` globals and reused
    while the cached model name matches. ``device_pref`` only influences the
    first successful load; later calls return the cached session unchanged.

    Returns:
        ``(session, tags_df, d_ips, preprocess_func, thresholds, category_names)``.

    Raises:
        Any exception from downloading or parsing the artifacts (after printing
        it). The cache is left exactly as it was before the failed attempt.
    """
    global PIXAI_MODEL, PIXAI_MODEL_NAME, PIXAI_TAGS_DF, PIXAI_D_IPS, PIXAI_PREPROCESS_FUNC, PIXAI_THRESHOLDS, PIXAI_CATEGORY_NAMES
    model_name = 'deepghs/pixai-tagger-v0.9-onnx'
    if PIXAI_MODEL_NAME != model_name:
        try:
            # NOTE(review): preprocess.json is downloaded but not consulted; the
            # transform hard-codes its size/mean/std — confirm they match the file.
            m_path, t_path, _p_path, th_path = _download_pixai_files(model_name)
            session = ort.InferenceSession(m_path, providers=_select_pixai_providers(device_pref))
            tags_df = pd.read_csv(t_path)
            if 'ips' in tags_df.columns:
                # The 'ips' column holds JSON strings; missing values and '{}' become {}.
                tags_df['ips'] = tags_df['ips'].apply(
                    lambda x: json.loads(x) if pd.notna(x) and x != '{}' else {})
            thresholds, category_names = _load_pixai_thresholds(th_path)
        except Exception as e:
            print(f"PixAI yükleme hatası: {e}")
            raise
        # Commit everything at once, only after every artifact loaded, so a
        # failure never leaves a half-updated cache behind.
        PIXAI_MODEL = session
        PIXAI_TAGS_DF = tags_df
        PIXAI_D_IPS = {}
        PIXAI_PREPROCESS_FUNC = _build_pixai_transform()
        PIXAI_THRESHOLDS = thresholds
        PIXAI_CATEGORY_NAMES = category_names
        PIXAI_MODEL_NAME = model_name
    return PIXAI_MODEL, PIXAI_TAGS_DF, PIXAI_D_IPS, PIXAI_PREPROCESS_FUNC, PIXAI_THRESHOLDS, PIXAI_CATEGORY_NAMES
def get_pixai_tags(image: Image.Image, thresholds: dict, device_pref: str):
    """Run the PixAI tagger on ``image`` and return the tags passing their thresholds.

    Args:
        image: PIL image to tag.
        thresholds: may contain ``'pixai_general_thresh'`` (applied to the
            ``general`` category) and ``'pixai_char_thresh'`` (applied to
            ``character``/``copyright``/``artist``). A missing key or a
            ``None`` value falls back to the loader's per-category default.
        device_pref: forwarded to the model loader.

    Returns:
        ``(tag_string, ranked_tags)``. ``tag_string`` is the comma-joined raw
        tag names, ordered by category id and then by row order in the tag
        table. ``ranked_tags`` holds the same tags sorted by descending score,
        with underscores replaced by spaces.

    Raises:
        ValueError: if the model's score vector length differs from the tag table.
    """
    model, df_tags, _, preprocess, default_thresh, cat_names = _load_pixai_model_components(device_pref)
    thresholds = thresholds or {}

    def _user_or_default(key, cat_id, fallback):
        # An absent *or* None user value falls back to the model default;
        # comparing scores against None would otherwise raise a TypeError.
        value = thresholds.get(key)
        return default_thresh.get(cat_id, fallback) if value is None else value

    mapped_thresh = {}
    for cat_id, cat_name in cat_names.items():
        if cat_name == 'general':
            mapped_thresh[cat_id] = _user_or_default('pixai_general_thresh', cat_id, 0.3)
        elif cat_name in ('character', 'copyright', 'artist'):
            mapped_thresh[cat_id] = _user_or_default('pixai_char_thresh', cat_id, 0.85)
        else:
            mapped_thresh[cat_id] = default_thresh.get(cat_id, 0.85)

    input_tensor = preprocess(image)
    if input_tensor.ndim == 3:
        # Add the batch dimension for a single image.
        input_tensor = np.expand_dims(input_tensor, axis=0)
    # Ask the session for its input name instead of assuming it is 'input'.
    input_name = model.get_inputs()[0].name
    # First output, first (only) batch item.
    out = model.run(None, {input_name: input_tensor.astype(np.float32)})[0][0]
    if len(out) != len(df_tags):
        raise ValueError(
            f"PixAI model returned {len(out)} scores but the tag table has {len(df_tags)} rows")

    categories = df_tags['category'].to_numpy()
    tag_names = df_tags['name'].to_numpy()
    all_tags = []
    for cat in sorted(set(categories)):
        in_cat = categories == cat
        preds = out[in_cat]
        keep = preds >= mapped_thresh.get(cat, 0.85)
        all_tags.extend(zip(tag_names[in_cat][keep], preds[keep].tolist()))

    tag_string = ", ".join(name for name, _ in all_tags)
    ranked = [name.replace("_", " ") for name, _ in sorted(all_tags, key=lambda t: t[1], reverse=True)]
    return tag_string, ranked
class PixaiTaggerProcessor(TaggerProcessor):
    """Tagger processor that runs the PixAI ONNX tagger and post-processes its tags."""

    def predict(self, image, pixai_general_thresh, pixai_char_thresh, replacement_file_path,
                synonym_file_path, addition_file_path, sort_order="Alfabetik", device_pref: str = "Auto"):
        """Tag ``image`` with PixAI and apply the inherited tag post-processing.

        Args:
            image: PIL image to tag, or None.
            pixai_general_thresh: score threshold for general tags.
            pixai_char_thresh: score threshold for character/copyright/artist tags.
            replacement_file_path, synonym_file_path, addition_file_path:
                stored on the instance for the inherited ``process_tags``.
            sort_order: forwarded to ``process_tags``.
            device_pref: ``"CUDA"`` requests the GPU provider; other values use CPU.

        Returns:
            ``(final_tags, status_message, original_order)``. ``final_tags`` is
            whatever the inherited ``process_tags`` returns. ``original_order``
            is the list of tags ranked by model score. On failure, a user-facing
            error message is returned instead of raising.
        """
        self.replacement_file = replacement_file_path
        self.synonym_file = synonym_file_path
        self.addition_file = addition_file_path

        # Validate the cheap input first so an empty request never triggers a
        # model download/load.
        if image is None:
            return "", "⚠️ Resim yüklenmedi.", []

        if PIXAI_MODEL is None:
            try:
                _load_pixai_model_components(device_pref)
            except Exception as e:
                return "", f"❌ PixAI Tagger modülü yüklenemedi: {e}", []

        try:
            thresholds = {'pixai_general_thresh': pixai_general_thresh, 'pixai_char_thresh': pixai_char_thresh}
            ai_tags_string_raw, original_order_for_pixai = get_pixai_tags(image, thresholds, device_pref)
            final_tags = self.process_tags(ai_tags_string_raw, sort_order, original_order_for_pixai)
            return final_tags, "✅ PixAI işlemi tamamlandı!", original_order_for_pixai
        except Exception as e:
            # Top-level UI boundary: report the failure to the user instead of raising.
            return f"Hata: {e}", f"❌ PixAI hata: {e}", []