"""JTP-3 Hydra ("Joint") image tagger.

Importing this module attempts to download and load the JTP-3 Hydra
checkpoint. On success ``JOINT_MODEL``, ``JOINT_TAGS`` and
``run_joint_classifier`` are populated; on any failure the error is printed
and ``run_joint_classifier`` stays ``None`` so callers can detect that the
tagger is unavailable.
"""

import os

import timm
import torch
import huggingface_hub
from safetensors import safe_open
from PIL import Image

from modules.hydra_layers import HydraPool
from modules.taggers.image_utils import process_image_jtp, patchify_image
from modules.taggers.base import TaggerProcessor

# Global State
INITIAL_TORCH_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
JOINT_MODEL = None  # loaded model, or None if loading failed
JOINT_TAGS = []  # display labels, index-aligned with the classifier outputs
PATCH_SIZE = 16
MAX_SEQ_LEN = 1024


def get_torch_device(device_pref: str) -> str:
    """Translate a device preference into a torch device string.

    Args:
        device_pref: ``"CUDA"`` or ``"Auto"`` request the GPU; any other
            value selects the CPU.

    Returns:
        ``"cuda"`` when the GPU was requested and CUDA is available,
        otherwise ``"cpu"``.
    """
    if device_pref in ("CUDA", "Auto") and torch.cuda.is_available():
        return "cuda"
    return "cpu"


def run_joint_classifier_func(image: Image.Image, threshold, execution_device: str):
    """Tag ``image`` with the loaded JTP-3 model.

    Args:
        image: Image to classify.
        threshold: Minimum score for a tag to be kept. Scores are
            ``2 * sigmoid(logit) - 1`` and therefore lie in ``[-1, 1]``.
        execution_device: Device preference understood by
            ``get_torch_device``.

    Returns:
        A ``(text, scores)`` pair: ``text`` is the kept tags joined with
        ``", "`` in descending score order, ``scores`` maps each kept tag to
        its score.

    Raises:
        RuntimeError: If the model has not been loaded.
    """
    if JOINT_MODEL is None:
        raise RuntimeError("JTP-3 model is not loaded")

    target_device = get_torch_device(execution_device)

    # NOTE(review): the helpers presumably resize the image and cut it into
    # PATCH_SIZE patches capped at MAX_SEQ_LEN -- confirm in image_utils.
    prepared = process_image_jtp(image, PATCH_SIZE, MAX_SEQ_LEN)
    patches, patch_coords, patch_valid = patchify_image(prepared, PATCH_SIZE, MAX_SEQ_LEN)

    # Add a batch dimension and move every input to the execution device.
    patches = patches.unsqueeze(0).to(device=target_device, non_blocking=True)
    patch_coords = patch_coords.unsqueeze(0).to(device=target_device, non_blocking=True)
    patch_valid = patch_valid.unsqueeze(0).to(device=target_device, non_blocking=True)

    # x / 127.5 - 1; presumably maps 0-255 pixel values onto [-1, 1].
    patches = patches.to(dtype=torch.bfloat16).div_(127.5).sub_(1.0)
    patch_coords = patch_coords.to(dtype=torch.int32)

    # Move the shared model only when the requested device type differs from
    # the one it currently lives on.
    if next(JOINT_MODEL.parameters()).device.type != target_device:
        JOINT_MODEL.to(target_device)

    with torch.no_grad():
        features = JOINT_MODEL.forward_intermediates(
            patches,
            patch_coord=patch_coords,
            patch_valid=patch_valid,
            output_dict=True,
            output_fmt='NLC'
        )
        logits = JOINT_MODEL.forward_head(features["image_features"], patch_valid=patch_valid)
        probits = logits[0].float().sigmoid_().mul_(2.0).sub_(1.0)
        # topk over every label yields all scores sorted in descending order.
        values, indices = probits.cpu().topk(len(JOINT_TAGS))

    # Convert once to Python lists instead of calling .item() per element.
    raw_results = [
        (JOINT_TAGS[tag_index], score)
        for tag_index, score in zip(indices.tolist(), values.tolist())
        if score >= threshold
    ]

    text_no_impl = ", ".join(tag for tag, _ in raw_results)
    sorted_tag_score = dict(raw_results)
    return text_no_impl, sorted_tag_score


# Entry point used by JointTaggerProcessor; stays None when loading fails.
run_joint_classifier = None

# Initialize Model Loading on Import (or lazily)
# To preserve behavior, we'll try to load it immediately but wrap in try/except
try:
    print("Joint Tagger (JTP-3 Hydra) Yükleniyor...")
    jtp3_path = huggingface_hub.hf_hub_download(
        repo_id="RedRocket/JTP-3",
        filename="models/jtp-3-hydra.safetensors",
    )
    with safe_open(jtp3_path, framework="pt", device="cpu") as f:
        metadata = f.metadata()
        state_dict = {key: f.get_tensor(key) for key in f.keys()}

    # Labels are stored newline-separated in the checkpoint metadata.
    tags = metadata["classifier.labels"].split("\n")
    JOINT_TAGS = [t.replace("_", " ").replace("vulva", "pussy") for t in tags]

    joint_model = timm.create_model(
        'naflexvit_so400m_patch16_siglip',
        pretrained=False,
        num_classes=0,
        pos_embed_interp_mode="bilinear",
        weight_init="skip",
        fix_init=False,
        device="cpu",
        dtype=torch.bfloat16
    )
    # Swap in the Hydra pooling layer and its head before loading weights.
    joint_model.attn_pool = HydraPool.for_state(state_dict, "attn_pool.", device="cpu", dtype=torch.bfloat16)
    joint_model.head = joint_model.attn_pool.create_head()
    joint_model.num_classes = len(tags)
    joint_model.load_state_dict(state_dict, strict=False)
    # load_state_dict copies values into the model by default, so drop the
    # checkpoint dict rather than keeping it referenced as a module global
    # for the lifetime of the process.
    del state_dict
    joint_model.attn_pool._q_normed = True
    joint_model.eval().to(dtype=torch.bfloat16)
    joint_model.to(INITIAL_TORCH_DEVICE)
    JOINT_MODEL = joint_model

    run_joint_classifier = run_joint_classifier_func
    print(f"JTP-3 Hydra Modeli Başarıyla Yüklendi ({INITIAL_TORCH_DEVICE})")
except Exception as e:
    # Best-effort load: report the problem and leave the tagger disabled so
    # importing this module never fails.
    print(f"Joint Tagger (JTP-3) yüklenirken hata: {e}")
    run_joint_classifier = None


class JointTaggerProcessor(TaggerProcessor):
    """Tagger processor backed by the JTP-3 Hydra model."""

    def predict(self, image, threshold, replacement_file_path, synonym_file_path,
                addition_file_path, sort_order="Alfabetik", device_pref: str = "Auto"):
        """Tag ``image`` and post-process the resulting tags.

        Args:
            image: Image to tag; ``None`` yields a warning result.
            threshold: Minimum score for a tag to be kept.
            replacement_file_path: Stored on ``self.replacement_file``.
            synonym_file_path: Stored on ``self.synonym_file``.
            addition_file_path: Stored on ``self.addition_file``.
            sort_order: Passed through to ``self.process_tags``.
            device_pref: Device preference forwarded to the classifier.

        Returns:
            A ``(tags, status_message, raw_tag_order)`` tuple. When the model
            is unavailable or no image is given, ``tags`` is empty; when
            classification raises, ``tags`` carries the error text. In all
            three cases ``raw_tag_order`` is an empty list.
        """
        # NOTE(review): these attributes are presumably read by
        # process_tags on the base class -- confirm in TaggerProcessor.
        self.replacement_file = replacement_file_path
        self.synonym_file = synonym_file_path
        self.addition_file = addition_file_path

        if run_joint_classifier is None:
            return "", "❌ Joint Tagger (JTP-3) yüklenemedi.", []
        if image is None:
            return "", "⚠️ Resim yüklenmedi.", []

        try:
            ai_tags_string_raw, raw_tags_sorted_by_confidence = run_joint_classifier(
                image, threshold, device_pref
            )
            # Dict insertion order is the descending-score order.
            original_order_for_joint = list(raw_tags_sorted_by_confidence)
            final_tags = self.process_tags(ai_tags_string_raw, sort_order, original_order_for_joint)
            return final_tags, "✅ Joint (JTP-3) işlemi tamamlandı!", original_order_for_joint
        except Exception as e:
            # Report failures through the returned status instead of raising.
            return f"Hata: {e}", f"❌ Joint hata: {e}", []