Spaces:
Runtime error
Runtime error
| import os | |
| import json | |
| import pandas as pd | |
| import numpy as np | |
| import onnxruntime as ort | |
| import huggingface_hub | |
| from PIL import Image | |
| from PIL.Image import Resampling | |
| from modules.taggers.base import TaggerProcessor | |
# --- PixAI tagger global state ---
# Module-level cache filled in by _load_pixai_model_components(); every entry
# stays None until the first successful load.

# Identity of what is currently loaded.
PIXAI_MODEL_NAME = None        # repo id of the loaded model
PIXAI_MODEL = None             # onnxruntime InferenceSession

# Tag metadata.
PIXAI_TAGS_DF = None           # DataFrame read from selected_tags.csv
PIXAI_D_IPS = None             # set to an empty dict by the loader

# Image preprocessing.
PIXAI_PREPROCESS_FUNC = None   # callable: PIL image -> float32 CHW array

# Per-category settings.
PIXAI_THRESHOLDS = None        # dict: category id -> default threshold
PIXAI_CATEGORY_NAMES = None    # dict: category id -> category name
def _download_pixai_files(model_name: str):
    """Download the PixAI tagger artifacts from the Hugging Face Hub.

    Args:
        model_name: Either a full repo id (contains ``/``) or a short version
            string that is expanded to ``deepghs/pixai-tagger-<name>-onnx``.

    Returns:
        A 4-tuple of local file paths
        ``(model.onnx, selected_tags.csv, preprocess.json, thresholds.csv)``.
        The last element is ``None`` when the repo has no ``thresholds.csv``.

    Raises:
        Whatever ``hf_hub_download`` raises for the three required files
        (missing repo/file, network failure, ...).
    """
    # Local import: the file only imports the top-level package.
    from huggingface_hub.utils import EntryNotFoundError

    repo_id = model_name if '/' in model_name else f'deepghs/pixai-tagger-{model_name}-onnx'

    def fetch(filename):
        return huggingface_hub.hf_hub_download(
            repo_id=repo_id, filename=filename, library_name="pixai-tagger"
        )

    model_path = fetch('model.onnx')
    tags_path = fetch('selected_tags.csv')
    preprocess_path = fetch('preprocess.json')

    # thresholds.csv is optional: the loader falls back to built-in defaults
    # when no path is returned, so a missing file must not abort the download.
    try:
        thresholds_path = fetch('thresholds.csv')
    except EntryNotFoundError:
        thresholds_path = None

    return (model_path, tags_path, preprocess_path, thresholds_path)
def _load_pixai_model_components(device_pref: str):
    """Load (once) and return the cached PixAI tagger components.

    The model repo is fixed to ``deepghs/pixai-tagger-v0.9-onnx``. After the
    first successful load, later calls return the cached objects without
    reloading, so a different ``device_pref`` on a later call has no effect.

    Args:
        device_pref: ``"CUDA"`` requests the CUDA execution provider with a
            CPU fallback; any other value uses the CPU provider only.

    Returns:
        ``(session, tags_df, d_ips, preprocess_func, thresholds, category_names)``
        where ``thresholds`` and ``category_names`` are dicts keyed by
        category id.

    Raises:
        Exception: any download/parse/session error is printed and re-raised.
            The module-level cache is left untouched in that case.
    """
    global PIXAI_MODEL, PIXAI_MODEL_NAME, PIXAI_TAGS_DF, PIXAI_D_IPS, PIXAI_PREPROCESS_FUNC, PIXAI_THRESHOLDS, PIXAI_CATEGORY_NAMES
    model_name = 'deepghs/pixai-tagger-v0.9-onnx'

    if PIXAI_MODEL_NAME != model_name:
        try:
            # preprocess.json is downloaded with the other files but is not
            # consulted here; the transform below is hard-coded.
            m_path, t_path, _unused_preprocess_path, th_path = _download_pixai_files(model_name)

            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if device_pref == "CUDA" else ['CPUExecutionProvider']
            session = ort.InferenceSession(m_path, providers=providers)

            tags_df = pd.read_csv(t_path)
            if 'ips' in tags_df.columns:
                # Decode the JSON-encoded 'ips' cells; blanks and '{}' become {}.
                tags_df['ips'] = tags_df['ips'].apply(
                    lambda x: json.loads(x) if pd.notna(x) and x != '{}' else {}
                )

            # Normalisation constants are built once instead of per image.
            mean = np.array([0.48145466, 0.4578275, 0.40821073]).astype(np.float32)
            std = np.array([0.26862954, 0.26130258, 0.27577711]).astype(np.float32)

            def transform(img):
                """Convert a PIL image to a normalised float32 CHW array (448x448)."""
                if img.mode != 'RGB':
                    img = img.convert('RGB')
                img = img.resize((448, 448), Resampling.LANCZOS)
                img_array = np.array(img).astype(np.float32) / 255.0
                img_array = (img_array - mean) / std
                return np.transpose(img_array, (2, 0, 1))

            if th_path and os.path.exists(th_path):
                df_th = pd.read_csv(th_path)
                thresholds = dict(zip(df_th['category'], df_th['threshold']))
                category_names = dict(zip(df_th['category'], df_th['name']))
            else:
                # Built-in defaults used when no thresholds file is available.
                thresholds = {0: 0.3, 4: 0.85, 9: 0.85}
                category_names = {0: 'general', 4: 'character', 9: 'rating'}

            # Commit everything together only after every step succeeded, so a
            # failure above can never leave the cache half-updated.
            PIXAI_MODEL = session
            PIXAI_TAGS_DF = tags_df
            PIXAI_D_IPS = {}
            PIXAI_PREPROCESS_FUNC = transform
            PIXAI_THRESHOLDS = thresholds
            PIXAI_CATEGORY_NAMES = category_names
            PIXAI_MODEL_NAME = model_name
        except Exception as e:
            print(f"PixAI yükleme hatası: {e}")
            raise

    return PIXAI_MODEL, PIXAI_TAGS_DF, PIXAI_D_IPS, PIXAI_PREPROCESS_FUNC, PIXAI_THRESHOLDS, PIXAI_CATEGORY_NAMES
def get_pixai_tags(image: Image.Image, thresholds: dict, device_pref: str):
    """Run the PixAI tagger on one image and return the selected tags.

    Args:
        image: PIL image to tag.
        thresholds: Optional overrides. ``'pixai_general_thresh'`` applies to
            the ``general`` category; ``'pixai_char_thresh'`` applies to
            ``character``, ``copyright`` and ``artist``. A missing key, a
            ``None`` value, or ``thresholds=None`` falls back to the model's
            own per-category default.
        device_pref: Passed through to ``_load_pixai_model_components``.

    Returns:
        A pair ``(tag_string, ranked_tags)``:
        ``tag_string`` is the selected tag names joined with ``", "`` in
        ascending category order (names unchanged);
        ``ranked_tags`` is the same tags sorted by descending score with
        underscores replaced by spaces.
    """
    model, df_tags, _, preprocess, default_thresh, cat_names = _load_pixai_model_components(device_pref)

    input_tensor = preprocess(image)
    if len(input_tensor.shape) == 3:
        # Add the batch dimension the session input expects.
        input_tensor = np.expand_dims(input_tensor, axis=0)
    # NOTE(review): the first model output is thresholded directly — confirm
    # that this output holds the per-tag scores.
    out = model.run(None, {'input': input_tensor.astype(np.float32)})[0][0]

    overrides = thresholds or {}

    def pick(key, cat_id, fallback):
        # A None override would make `preds >= thresh` raise; treat it as unset.
        value = overrides.get(key)
        return default_thresh.get(cat_id, fallback) if value is None else value

    mapped_thresh = {}
    for cat_id, cat_name in cat_names.items():
        if cat_name == 'general':
            mapped_thresh[cat_id] = pick('pixai_general_thresh', cat_id, 0.3)
        elif cat_name in ('character', 'copyright', 'artist'):
            mapped_thresh[cat_id] = pick('pixai_char_thresh', cat_id, 0.85)
        else:
            mapped_thresh[cat_id] = default_thresh.get(cat_id, 0.85)

    # Plain arrays avoid pandas index alignment when applying boolean masks.
    categories = df_tags['category'].to_numpy()
    tag_names = df_tags['name'].to_numpy()

    all_tags = []
    for cat in sorted(set(df_tags['category'])):
        in_cat = categories == cat
        preds = out[in_cat]
        selected = preds >= mapped_thresh.get(cat, 0.85)
        all_tags.extend(
            (name, float(score))
            for name, score in zip(tag_names[in_cat][selected], preds[selected])
        )

    ranked = sorted(all_tags, key=lambda item: item[1], reverse=True)
    return ", ".join(name for name, _score in all_tags), [name.replace("_", " ") for name, _score in ranked]
class PixaiTaggerProcessor(TaggerProcessor):
    """Tagger front-end that runs the PixAI ONNX model and post-processes its tags."""

    def predict(
        self,
        image,
        pixai_general_thresh,
        pixai_char_thresh,
        replacement_file_path,
        synonym_file_path,
        addition_file_path,
        sort_order="Alfabetik",
        device_pref: str = "Auto",
    ):
        """Tag ``image`` and return ``(tags, status_message, score_ordered_tags)``.

        The three file paths are stored on the instance before anything else.
        Post-processing is delegated to ``self.process_tags``, which is not
        defined in this class (presumably provided by ``TaggerProcessor``).

        On failure the third element is an empty list. The first element is
        ``""`` when the model could not be loaded or no image was given, and
        ``"Hata: <error>"`` when tagging itself raised.
        """
        self.replacement_file = replacement_file_path
        self.synonym_file = synonym_file_path
        self.addition_file = addition_file_path

        # Make sure the model is available; report a load failure separately.
        if PIXAI_MODEL is None:
            try:
                _load_pixai_model_components(device_pref)
            except Exception as load_error:
                return "", f"❌ PixAI Tagger modülü yüklenemedi: {load_error}", []

        if image is None:
            return "", "⚠️ Resim yüklenmedi.", []

        try:
            overrides = {
                'pixai_general_thresh': pixai_general_thresh,
                'pixai_char_thresh': pixai_char_thresh,
            }
            raw_tag_string, score_ordered = get_pixai_tags(image, overrides, device_pref)
            processed = self.process_tags(raw_tag_string, sort_order, score_ordered)
        except Exception as run_error:
            return f"Hata: {run_error}", f"❌ PixAI hata: {run_error}", []

        return processed, "✅ PixAI işlemi tamamlandı!", score_ordered