# Uploaded by Hydragee to the Hugging Face Hub ("Upload folder using huggingface_hub", commit 772b344, verified).
import os
import timm
import torch
import huggingface_hub
from safetensors import safe_open
from PIL import Image
from modules.hydra_layers import HydraPool
from modules.taggers.image_utils import process_image_jtp, patchify_image
from modules.taggers.base import TaggerProcessor
# Global State
# Device the tagger model is placed on right after loading: "cuda" when a
# CUDA device is available, otherwise "cpu".
INITIAL_TORCH_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Loaded JTP-3 model; stays None if loading fails.
JOINT_MODEL = None
# Display tag names, one per classifier output; empty until the model loads.
JOINT_TAGS = []
# Patch edge length in pixels (the backbone is a "patch16" NaFlexViT).
PATCH_SIZE = 16
# Maximum sequence length handed to the image preprocessing helpers
# (presumably caps the number of patches per image — confirm in image_utils).
MAX_SEQ_LEN = 1024
def get_torch_device(device_pref: str) -> str:
    """Resolve a UI device preference to a torch device string.

    "CUDA" and "Auto" select "cuda" when CUDA is available. Any other
    preference, or a machine without CUDA, falls back to "cpu".
    """
    gpu_requested = device_pref in ("CUDA", "Auto")
    if gpu_requested and torch.cuda.is_available():
        return "cuda"
    return "cpu"
run_joint_classifier = None

# The model is loaded eagerly at import time. Any failure (download, checkpoint
# parsing, model construction) is caught at the bottom of this section and only
# disables the tagger (run_joint_classifier stays None); the import succeeds.


def _load_joint_model():
    """Download the JTP-3 Hydra checkpoint and build the tagger model.

    Keeping this in a function means the checkpoint ``state_dict`` (a full
    CPU copy of the weights) is released when loading finishes, instead of
    lingering as a module global next to the model.

    Returns:
        A ``(model, tags)`` pair. ``model`` is the bfloat16 NaFlexViT in eval
        mode, moved to ``INITIAL_TORCH_DEVICE``. ``tags`` holds one display name
        per classifier label, with "_" replaced by " " and "vulva" renamed to
        "pussy".

    Raises:
        Whatever the download, safetensors, timm or HydraPool calls raise.
    """
    jtp3_path = huggingface_hub.hf_hub_download(
        repo_id="RedRocket/JTP-3", filename="models/jtp-3-hydra.safetensors"
    )
    with safe_open(jtp3_path, framework="pt", device="cpu") as f:
        metadata = f.metadata()
        state_dict = {key: f.get_tensor(key) for key in f.keys()}
    # The checkpoint metadata stores the label list newline-separated.
    tags = metadata["classifier.labels"].split("\n")

    model = timm.create_model(
        "naflexvit_so400m_patch16_siglip",
        pretrained=False, num_classes=0,
        pos_embed_interp_mode="bilinear",
        weight_init="skip", fix_init=False,
        device="cpu", dtype=torch.bfloat16,
    )
    # Install a HydraPool built from the checkpoint's "attn_pool." tensors as the
    # pooling layer, and derive the classifier head from it.
    model.attn_pool = HydraPool.for_state(state_dict, "attn_pool.", device="cpu", dtype=torch.bfloat16)
    model.head = model.attn_pool.create_head()
    model.num_classes = len(tags)
    model.load_state_dict(state_dict, strict=False)
    # NOTE(review): private HydraPool flag; presumably marks the stored queries
    # as already normalized after loading — confirm in modules.hydra_layers.
    model.attn_pool._q_normed = True
    model.eval().to(dtype=torch.bfloat16)
    model.to(INITIAL_TORCH_DEVICE)

    display_tags = [t.replace("_", " ").replace("vulva", "pussy") for t in tags]
    return model, display_tags


def run_joint_classifier_func(image: "Image.Image", threshold: float, execution_device: str):
    """Run the JTP-3 tagger on a single image.

    Args:
        image: PIL image to tag.
        threshold: minimum score for a tag to be kept. Scores are
            ``2 * sigmoid(logit) - 1``, so they lie in [-1, 1].
        execution_device: device preference ("CUDA", "Auto", or anything else
            for CPU), resolved with ``get_torch_device``.

    Returns:
        ``(text, tag_scores)``. ``text`` is the kept tags joined with ", " in
        descending score order. ``tag_scores`` maps each kept tag to its score
        in the same order.

    Raises:
        RuntimeError: if the model has not been loaded.
    """
    if JOINT_MODEL is None:
        raise RuntimeError("JTP-3 model is not loaded")
    device = get_torch_device(execution_device)

    processed_img = process_image_jtp(image, PATCH_SIZE, MAX_SEQ_LEN)
    patches, patch_coords, patch_valid = patchify_image(processed_img, PATCH_SIZE, MAX_SEQ_LEN)
    # Add the batch dimension and move the inputs to the target device.
    patches = patches.unsqueeze(0).to(device=device, non_blocking=True)
    patch_coords = patch_coords.unsqueeze(0).to(device=device, non_blocking=True)
    patch_valid = patch_valid.unsqueeze(0).to(device=device, non_blocking=True)
    # x / 127.5 - 1 maps [0, 255] onto [-1, 1] (assumes 0-255 pixel values from
    # patchify_image — TODO confirm).
    patches = patches.to(dtype=torch.bfloat16).div_(127.5).sub_(1.0)
    patch_coords = patch_coords.to(dtype=torch.int32)

    # The model may sit on a different device (load-time default or a previous call).
    if next(JOINT_MODEL.parameters()).device.type != device:
        JOINT_MODEL.to(device)

    with torch.no_grad():
        features = JOINT_MODEL.forward_intermediates(
            patches,
            patch_coord=patch_coords,
            patch_valid=patch_valid,
            output_dict=True,
            output_fmt="NLC",
        )
        logits = JOINT_MODEL.forward_head(features["image_features"], patch_valid=patch_valid)

    probits = logits[0].float().sigmoid_().mul_(2.0).sub_(1.0)
    values, indices = probits.cpu().topk(len(JOINT_TAGS))
    # A single bulk tensor->list conversion replaces two .item() calls per tag.
    # The comparison still happens on Python floats, exactly as before.
    raw_results = [
        (JOINT_TAGS[idx], score)
        for idx, score in zip(indices.tolist(), values.tolist())
        if score >= threshold
    ]
    text_no_impl = ", ".join(tag for tag, _ in raw_results)
    # NOTE: if two labels share a display name, the dict keeps a single entry.
    sorted_tag_score = dict(raw_results)
    return text_no_impl, sorted_tag_score


try:
    print("Joint Tagger (JTP-3 Hydra) Yükleniyor...")
    # Assign both globals only once loading has fully succeeded.
    JOINT_MODEL, JOINT_TAGS = _load_joint_model()
    run_joint_classifier = run_joint_classifier_func
    print(f"JTP-3 Hydra Modeli Başarıyla Yüklendi ({INITIAL_TORCH_DEVICE})")
except Exception as e:
    # Best effort: report the error and leave the tagger disabled instead of
    # failing the import.
    print(f"Joint Tagger (JTP-3) yüklenirken hata: {e}")
    run_joint_classifier = None
class JointTaggerProcessor(TaggerProcessor):
    """Tagger processor backed by this module's JTP-3 Hydra classifier."""

    def predict(self, image, threshold, replacement_file_path, synonym_file_path, addition_file_path, sort_order="Alfabetik", device_pref: str = "Auto"):
        """Tag ``image`` and post-process the tags with ``process_tags``.

        Returns:
            A 3-tuple ``(tags_text, status_message, tags_by_confidence)``.
            If the model is unavailable, no image is given, or tagging raises,
            ``tags_by_confidence`` is an empty list and the status explains why.
        """
        # Store the tag-editing file paths on the instance; presumably read by
        # process_tags in the base class (not visible here).
        self.replacement_file, self.synonym_file, self.addition_file = (
            replacement_file_path,
            synonym_file_path,
            addition_file_path,
        )

        if run_joint_classifier is None:
            return "", "❌ Joint Tagger (JTP-3) yüklenemedi.", []
        if image is None:
            return "", "⚠️ Resim yüklenmedi.", []

        try:
            raw_text, tag_scores = run_joint_classifier(image, threshold, device_pref)
            # The classifier returns tags in descending-confidence order.
            confidence_order = list(tag_scores.keys())
            processed = self.process_tags(raw_text, sort_order, confidence_order)
        except Exception as err:
            return f"Hata: {err}", f"❌ Joint hata: {err}", []
        return processed, "✅ Joint (JTP-3) işlemi tamamlandı!", confidence_order