Spaces:
Runtime error
Runtime error
| import os | |
| import timm | |
| import torch | |
| import huggingface_hub | |
| from safetensors import safe_open | |
| from PIL import Image | |
| from modules.hydra_layers import HydraPool | |
| from modules.taggers.image_utils import process_image_jtp, patchify_image | |
| from modules.taggers.base import TaggerProcessor | |
# Module-level state shared by the JTP-3 tagger.

# Device the model is moved to at import time: "cuda" when available, else "cpu".
INITIAL_TORCH_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Loaded model instance; stays None unless import-time loading succeeds.
JOINT_MODEL = None

# Display names of the classifier labels; empty until the checkpoint metadata is read.
JOINT_TAGS = []

# Patch size handed to the image preprocessing / patchify helpers.
PATCH_SIZE = 16

# Maximum sequence length handed to the image preprocessing / patchify helpers.
MAX_SEQ_LEN = 1024
def get_torch_device(device_pref: str) -> str:
    """Translate a device preference into a torch device string.

    Args:
        device_pref: Preference label. Only the exact values ``"CUDA"`` and
            ``"Auto"`` can select the GPU; any other value selects the CPU.

    Returns:
        ``"cuda"`` when the preference allows it and CUDA is available,
        otherwise ``"cpu"``.
    """
    wants_gpu = device_pref in ("CUDA", "Auto")
    # CUDA availability is only queried when the preference asks for it.
    if wants_gpu and torch.cuda.is_available():
        return "cuda"
    return "cpu"
# Inference entry point. Stays None when the model cannot be loaded, which is
# how callers detect that the tagger is unavailable.
run_joint_classifier = None

# The model is loaded eagerly at import time. Every failure (download, file
# parsing, model construction) is caught so that importing this module never
# raises; the error is printed and run_joint_classifier is left as None.
try:
    print("Joint Tagger (JTP-3 Hydra) Yükleniyor...")
    jtp3_path = huggingface_hub.hf_hub_download(
        repo_id="RedRocket/JTP-3",
        filename="models/jtp-3-hydra.safetensors",
    )

    # Read the label metadata and every tensor from the checkpoint onto the CPU.
    with safe_open(jtp3_path, framework="pt", device="cpu") as f:
        metadata = f.metadata()
        state_dict = {key: f.get_tensor(key) for key in f.keys()}

    # Labels are stored newline-separated; display names use spaces instead of
    # underscores and apply one fixed word substitution.
    tags = metadata["classifier.labels"].split("\n")
    JOINT_TAGS = [t.replace("_", " ").replace("vulva", "pussy") for t in tags]

    # Build the backbone without pretrained weights or a classification head;
    # the pooling layer and head are attached below from the checkpoint.
    joint_model = timm.create_model(
        'naflexvit_so400m_patch16_siglip',
        pretrained=False, num_classes=0,
        pos_embed_interp_mode="bilinear",
        weight_init="skip", fix_init=False,
        device="cpu", dtype=torch.bfloat16
    )
    joint_model.attn_pool = HydraPool.for_state(state_dict, "attn_pool.", device="cpu", dtype=torch.bfloat16)
    joint_model.head = joint_model.attn_pool.create_head()
    joint_model.num_classes = len(tags)
    # strict=False: key mismatches between checkpoint and model are tolerated.
    joint_model.load_state_dict(state_dict, strict=False)
    # NOTE(review): private HydraPool flag; presumably marks the queries as
    # already normalised in the checkpoint -- confirm against HydraPool.
    joint_model.attn_pool._q_normed = True
    joint_model.eval().to(dtype=torch.bfloat16)
    joint_model.to(INITIAL_TORCH_DEVICE)
    JOINT_MODEL = joint_model

    def run_joint_classifier_func(image: Image.Image, threshold, execution_device: str):
        """Tag one image with the loaded JTP-3 model.

        Args:
            image: Image to classify.
            threshold: Minimum score a tag needs to be kept. Scores are the
                sigmoid outputs rescaled to the range -1..1, so the threshold
                is compared on that scale.
            execution_device: Device preference string, resolved through
                ``get_torch_device``.

        Returns:
            A tuple ``(tag_text, tag_scores)``: ``tag_text`` is the kept tag
            names joined by ``", "`` and ``tag_scores`` maps each kept tag
            name to its score. Both follow descending score order.
        """
        device_for_tensor = get_torch_device(execution_device)

        processed_img = process_image_jtp(image, PATCH_SIZE, MAX_SEQ_LEN)
        patches, patch_coords, patch_valid = patchify_image(processed_img, PATCH_SIZE, MAX_SEQ_LEN)

        # Add a batch dimension and move every input to the execution device.
        patches = patches.unsqueeze(0).to(device=device_for_tensor, non_blocking=True)
        patch_coords = patch_coords.unsqueeze(0).to(device=device_for_tensor, non_blocking=True)
        patch_valid = patch_valid.unsqueeze(0).to(device=device_for_tensor, non_blocking=True)

        # Rescale pixel values with x / 127.5 - 1 (assumes 0..255 input -- TODO confirm).
        patches = patches.to(dtype=torch.bfloat16).div_(127.5).sub_(1.0)
        patch_coords = patch_coords.to(dtype=torch.int32)

        # Move the model only when it is not already on the requested device type.
        if next(JOINT_MODEL.parameters()).device.type != device_for_tensor:
            JOINT_MODEL.to(device_for_tensor)

        with torch.no_grad():
            features = JOINT_MODEL.forward_intermediates(
                patches,
                patch_coord=patch_coords,
                patch_valid=patch_valid,
                output_dict=True,
                output_fmt='NLC'
            )
            logits = JOINT_MODEL.forward_head(features["image_features"], patch_valid=patch_valid)

        # Sigmoid probabilities rescaled from 0..1 to -1..1.
        probits = logits[0].float().sigmoid_().mul_(2.0).sub_(1.0)
        # topk over every label yields all scores sorted in descending order.
        values, indices = probits.cpu().topk(len(JOINT_TAGS))

        # Convert both tensors to Python lists once instead of calling .item()
        # on every element inside the loop.
        raw_results = [
            (JOINT_TAGS[idx], score)
            for idx, score in zip(indices.tolist(), values.tolist())
            if score >= threshold
        ]

        text_no_impl = ", ".join(tag for tag, _ in raw_results)
        sorted_tag_score = dict(raw_results)
        return text_no_impl, sorted_tag_score

    run_joint_classifier = run_joint_classifier_func
    print(f"JTP-3 Hydra Modeli Başarıyla Yüklendi ({INITIAL_TORCH_DEVICE})")
except Exception as e:
    # Deliberate best-effort: report the failure and leave the tagger disabled.
    print(f"Joint Tagger (JTP-3) yüklenirken hata: {e}")
    run_joint_classifier = None
class JointTaggerProcessor(TaggerProcessor):
    """Tagger processor that runs the JTP-3 joint classifier on an image."""

    def predict(
        self,
        image,
        threshold,
        replacement_file_path,
        synonym_file_path,
        addition_file_path,
        sort_order="Alfabetik",
        device_pref: str = "Auto",
    ):
        """Classify ``image`` and post-process the resulting tags.

        Args:
            image: Image to tag; ``None`` produces a warning result.
            threshold: Minimum score passed on to the classifier.
            replacement_file_path: Stored on ``self.replacement_file``.
            synonym_file_path: Stored on ``self.synonym_file``.
            addition_file_path: Stored on ``self.addition_file``.
            sort_order: Ordering mode forwarded to ``self.process_tags``.
            device_pref: Device preference forwarded to the classifier.

        Returns:
            A tuple ``(tags, status_message, confidence_order)``. On any
            failure the last element is an empty list and the first element
            is either an empty string or the error text.
        """
        # The file paths are recorded first, even when prediction is skipped.
        self.replacement_file = replacement_file_path
        self.synonym_file = synonym_file_path
        self.addition_file = addition_file_path

        # Guard: the model failed to load at import time.
        if run_joint_classifier is None:
            return "", "❌ Joint Tagger (JTP-3) yüklenemedi.", []
        # Guard: nothing to classify.
        if image is None:
            return "", "⚠️ Resim yüklenmedi.", []

        try:
            raw_tag_text, tag_scores = run_joint_classifier(image, threshold, device_pref)
            # The classifier's mapping is already ordered by confidence.
            confidence_order = list(tag_scores.keys())
            # process_tags is not defined in this class (presumably inherited
            # from TaggerProcessor).
            processed_tags = self.process_tags(raw_tag_text, sort_order, confidence_order)
        except Exception as e:
            return f"Hata: {e}", f"❌ Joint hata: {e}", []
        else:
            return processed_tags, "✅ Joint (JTP-3) işlemi tamamlandı!", confidence_order