import json
import math
import os
from io import BytesIO
from typing import Optional

import gradio as gr
import numpy as np
import requests
import timm
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from PIL import Image


# ============== Configuration ==============
class Config:
    """Static settings for the checkpoint, preprocessing and target vector."""

    # Hugging Face repository holding the weights and the label mapping.
    model_repo = "telecomadm1145/cmodel_v2_test"
    weights_file = "pytorch_model.bin"
    name_mapping_file = "label_id_mapping.json"
    # Revision the weights download is pinned to.
    rev = "6350d5e35f883ca058bbc84a82853407874b68da"
    # Architecture name handed to timm.create_model.
    model_name = "convnext_base.dinov3_lvd1689m"
    # Side length of the square input passed to the preprocessing transform.
    image_size = 384
    # Evaluated once, when the class body runs.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Local path of the reference vector compared against image embeddings.
    npy_file = "target_vector.npy"


cfg = Config()


# ============== 2. Core logic ==============
def download_file(repo_id: str, filename: str, rev: str) -> Optional[str]:
    """Fetch one file from a Hugging Face repository.

    Args:
        repo_id: Repository identifier passed to ``hf_hub_download``.
        filename: Name of the file inside the repository.
        rev: Revision (branch name or commit hash) to download from.

    Returns:
        The value returned by ``hf_hub_download`` (the local file path), or
        ``None`` if the download raised; the error is printed, not re-raised.
    """
    print(f"📥 Downloading {filename} from {repo_id}...")
    try:
        return hf_hub_download(repo_id=repo_id, filename=filename, revision=rev)
    except Exception as e:
        # Deliberately best-effort: the caller decides how to react to None.
        print(f"⚠️ Failed to download {filename}: {e}")
        return None


# ============== Similarity mapping ==============
def map_similarity(sim: float, low: float = 0.75, high: float = 0.85) -> float:
    """Linearly rescale a cosine similarity from [low, high] onto [0, 1].

    Values at or below ``low`` map to 0.0 and values at or above ``high`` map
    to 1.0. With the defaults the source interval is [0.75, 0.85].

    Args:
        sim: Raw cosine similarity.
        low: Similarity that maps to 0.0.
        high: Similarity that maps to 1.0; must be greater than ``low``.

    Returns:
        The rescaled score clamped to [0.0, 1.0]. NaN input yields 0.0.

    Raises:
        ValueError: If ``high`` is not strictly greater than ``low``.
    """
    # NOTE(review): the result text assembled in SingleVectorMatcher.match
    # still quotes the formula with 0.80 / 0.90; reconcile it with these
    # defaults.
    if high <= low:
        raise ValueError(f"high ({high}) must be greater than low ({low})")
    # min(1.0, nan) evaluates to 1.0, so without this guard a NaN similarity
    # would be reported as a perfect match.
    if math.isnan(sim):
        return 0.0
    mapped = (sim - low) / (high - low)
    return max(0.0, min(1.0, mapped))


# --- Vector matching core ---
class SingleVectorMatcher:
    """Compare an image embedding against a single stored target vector."""

    def __init__(self):
        """Download the checkpoint, build the model and load the target vector.

        Raises:
            RuntimeError: If the label mapping or the weights cannot be
                downloaded.
        """
        print(f"🔄 Initializing on device: {cfg.device}")

        # 1. Build the model architecture. The classifier width is the total
        #    number of entries in the two label tables of the mapping file.
        # NOTE(review): the mapping comes from "main" while the weights are
        # pinned to cfg.rev. If the mapping on main changes, num_classes may
        # stop matching the checkpoint and load_state_dict will fail. Confirm
        # this is intentional.
        name_map_path = download_file(cfg.model_repo, cfg.name_mapping_file, "main")
        if name_map_path is None:
            raise RuntimeError(
                f"Could not download '{cfg.name_mapping_file}' from {cfg.model_repo}"
            )
        with open(name_map_path, 'r', encoding='utf-8') as f:
            full_map = json.load(f)
        num_classes = (
            len(full_map.get('id_to_model_name', {}))
            + len(full_map.get('id_to_base_model', {}))
        )
        self.model = timm.create_model(
            cfg.model_name, pretrained=False, num_classes=num_classes
        )

        # 2. Load the weights.
        weights_path = download_file(cfg.model_repo, cfg.weights_file, cfg.rev)
        if weights_path is None:
            raise RuntimeError(
                f"Could not download '{cfg.weights_file}' from {cfg.model_repo}"
            )
        # NOTE(security): torch.load unpickles the file. The download is
        # pinned to cfg.rev, but consider weights_only=True if the installed
        # torch version supports it.
        state_dict = torch.load(weights_path, map_location='cpu')
        self.model.load_state_dict(state_dict)
        self.model.to(cfg.device)
        self.model.eval()

        # 3. Image preprocessing.
        self.transform = timm.data.create_transform(
            input_size=(3, cfg.image_size, cfg.image_size),
            is_training=False,
            mean=timm.data.IMAGENET_DEFAULT_MEAN,
            std=timm.data.IMAGENET_DEFAULT_STD
        )

        # 4.
        # Load the single target vector (NPY).
        # target_tensor stays None when loading fails; match() checks for that.
        self.target_tensor = None
        self._load_single_npy()

    def _load_single_npy(self):
        """Load the target vector from ``cfg.npy_file`` into ``self.target_tensor``.

        If the file does not exist, or loading raises, a message is printed
        and ``self.target_tensor`` is left unchanged (``None`` when called
        from ``__init__``). On success the array is converted to a float32
        tensor, given a leading dimension if it is 1-D, and moved to
        ``cfg.device``.
        """
        if not os.path.exists(cfg.npy_file):
            print(f"❌ Error: '{cfg.npy_file}' not found.")
            return
        print(f"🎯 Loading target vector from {cfg.npy_file}...")
        try:
            vector_array = np.load(cfg.npy_file)
            tensor = torch.tensor(vector_array, dtype=torch.float32)
            # A 1-D vector gets a leading dimension so it is 2-D when compared
            # with the embedding in match().
            if tensor.dim() == 1:
                tensor = tensor.unsqueeze(0)
            self.target_tensor = tensor.to(cfg.device)
            print(f"✅ Target vector loaded. Shape: {self.target_tensor.shape}")
        except Exception as e:
            # Best-effort: the failure is reported here and surfaces later
            # through the target_tensor-is-None check in match().
            print(f"❌ Failed to load {cfg.npy_file}: {e}")

    @torch.no_grad()
    def match(self, image: Image.Image):
        # Guard clauses: no image supplied, or the target vector never loaded.
        if image is None:
            return "请先上传图片或输入 QQ 号获取头像。"
        if self.target_tensor is None:
            return f"初始化失败:未找到目标向量文件 `{cfg.npy_file}`。"
        if image.mode != 'RGB':
            image = image.convert('RGB')
        # Add a leading dimension so the model receives a batch of one.
        img_tensor = self.transform(image).unsqueeze(0).to(cfg.device)
        features = self.model.forward_features(img_tensor)
        # presumably pre_logits=True returns the pooled embedding rather than
        # class logits (timm convention); confirm against the timm version in use.
        embedding = self.model.forward_head(features, pre_logits=True)
        # .item() requires a single similarity value, i.e. one target row.
        raw_sim = F.cosine_similarity(embedding, self.target_tensor).item()
        mapped_sim = map_similarity(raw_sim)
        # Style verdict: bucket the mapped score into four tiers.
        if mapped_sim > 0.85:
            label = "高度相似"
            desc = "图像风格与 Nano Banana 高度吻合,具有强烈的标志性特征。"
            color = "#4096ff"
        elif mapped_sim > 0.65:
            label = "较为相似"
            desc = "图像风格与 Nano Banana 有一定相似度,部分特征较为接近。"
            color = "#36cfc9"
        elif mapped_sim > 0.45:
            label = "轻微相似"
            desc = "图像与 Nano Banana 风格存在少量共同特征,整体差异明显。"
            color = "#9254de"
        else:
            label = "风格不符"
            desc = "图像风格与 Nano Banana 差异显著,几乎不具备相关特征。"
            color = "#8c8c8c"
        md_out = f"""
风格相似度评估
{desc}
(x − 0.80) / (0.90 − 0.80),裁剪至 [0, 1]