Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import torch | |
| import torch.nn.functional as F | |
| import timm | |
| import numpy as np | |
| import gradio as gr | |
| import requests | |
| from io import BytesIO | |
| from PIL import Image | |
| from huggingface_hub import hf_hub_download | |
| # ============== 配置参数 ============== | |
class Config:
    """Static settings for the checkpoint, preprocessing and reference vector."""

    # Hugging Face repository holding the model files, and the revision
    # used when fetching the weights.
    model_repo: str = "telecomadm1145/cmodel_v2_test"
    rev: str = "6350d5e35f883ca058bbc84a82853407874b68da"

    # Files fetched from that repository.
    weights_file: str = "pytorch_model.bin"
    name_mapping_file: str = "label_id_mapping.json"

    # timm architecture name and the input edge length in pixels.
    model_name: str = "convnext_base.dinov3_lvd1689m"
    image_size: int = 384

    # Local .npy file holding the reference vector.
    npy_file: str = "target_vector.npy"

    # Use the GPU when torch reports one, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"


# Shared settings instance read by the rest of the module.
cfg = Config()
| # ============== 2. 核心逻辑 ============== | |
def download_file(repo_id, filename, rev):
    """Fetch one file from a Hugging Face repo, returning its local path.

    Best-effort: any exception raised by the download is printed and
    swallowed, and ``None`` is returned instead of a path, so callers must
    check the result.

    Args:
        repo_id: Repository identifier passed to ``hf_hub_download``.
        filename: Name of the file inside the repository.
        rev: Revision (branch, tag or commit) to download from.

    Returns:
        Whatever ``hf_hub_download`` returns on success, or ``None`` on failure.
    """
    print(f"📥 Downloading {filename} from {repo_id}...")
    local_path = None
    try:
        local_path = hf_hub_download(repo_id=repo_id, filename=filename, revision=rev)
    except Exception as err:
        print(f"⚠️ Failed to download {filename}: {err}")
    return local_path
| # ============== 相似度映射 ============== | |
def map_similarity(sim: float, low: float = 0.75, high: float = 0.85) -> float:
    """Linearly rescale a cosine similarity to [0, 1] and clamp the result.

    ``low`` maps to 0.0 and ``high`` maps to 1.0; values outside that window
    are clipped. The defaults (0.75 and 0.85) are the bounds the computation
    has always used.

    Args:
        sim: Raw cosine similarity.
        low: Similarity that maps to 0.0.
        high: Similarity that maps to 1.0; must be greater than ``low``.

    Returns:
        The rescaled score in [0.0, 1.0]. A NaN input yields 0.0.

    Raises:
        ValueError: If ``high`` is not strictly greater than ``low``.
    """
    import math

    # Written as "not high > low" so that NaN bounds are rejected as well.
    if not high > low:
        raise ValueError(f"high ({high}) must be greater than low ({low})")

    # NaN would slip through max()/min() and come out as 1.0 (a "perfect"
    # match), so treat it explicitly as no similarity.
    if math.isnan(sim):
        return 0.0

    mapped = (sim - low) / (high - low)
    return max(0.0, min(1.0, mapped))
| # --- 向量匹配核心 --- | |
class SingleVectorMatcher:
    """Score an image against one reference vector using a timm model.

    Construction downloads the label mapping and the checkpoint, builds the
    model and its eval-time preprocessing, then tries to load the reference
    vector from ``cfg.npy_file``. ``match`` returns an HTML card describing
    how close an image's pre-logits embedding is to that vector.
    """

    # (exclusive lower bound on the mapped score, label, description, colour),
    # checked in order from the strictest tier down.
    _TIERS = (
        (0.85, "高度相似", "图像风格与 Nano Banana 高度吻合,具有强烈的标志性特征。", "#4096ff"),
        (0.65, "较为相似", "图像风格与 Nano Banana 有一定相似度,部分特征较为接近。", "#36cfc9"),
        (0.45, "轻微相似", "图像与 Nano Banana 风格存在少量共同特征,整体差异明显。", "#9254de"),
    )
    # Used when the mapped score clears none of the tiers above.
    _FALLBACK_TIER = ("风格不符", "图像风格与 Nano Banana 差异显著,几乎不具备相关特征。", "#8c8c8c")

    def __init__(self):
        """Download the model files, build the model and load the target vector.

        Raises:
            RuntimeError: If the label mapping or the weights cannot be
                downloaded.
        """
        print(f"🔄 Initializing on device: {cfg.device}")

        # 1. Build the architecture. The classifier width is the combined
        #    size of the two label tables in the mapping file.
        # NOTE(review): the mapping is fetched from "main" while the weights
        # use cfg.rev -- confirm the two are meant to differ.
        name_map_path = download_file(cfg.model_repo, cfg.name_mapping_file, "main")
        if name_map_path is None:
            # download_file reports failures by returning None; stop here with
            # a clear message instead of failing inside open().
            raise RuntimeError(f"Could not download '{cfg.name_mapping_file}' from {cfg.model_repo}")
        with open(name_map_path, 'r', encoding='utf-8') as f:
            full_map = json.load(f)
        num_classes = len(full_map.get('id_to_model_name', {})) + len(full_map.get('id_to_base_model', {}))
        self.model = timm.create_model(cfg.model_name, pretrained=False, num_classes=num_classes)

        # 2. Load the weights on the CPU first, then move the model to the
        #    configured device.
        weights_path = download_file(cfg.model_repo, cfg.weights_file, cfg.rev)
        if weights_path is None:
            raise RuntimeError(f"Could not download '{cfg.weights_file}' from {cfg.model_repo}")
        # SECURITY: torch.load unpickles the file, so the checkpoint must come
        # from a trusted source. Consider weights_only=True if the installed
        # torch version supports it.
        state_dict = torch.load(weights_path, map_location='cpu')
        self.model.load_state_dict(state_dict)
        self.model.to(cfg.device)
        self.model.eval()

        # 3. Eval-time preprocessing with the ImageNet default mean/std.
        self.transform = timm.data.create_transform(
            input_size=(3, cfg.image_size, cfg.image_size),
            is_training=False,
            mean=timm.data.IMAGENET_DEFAULT_MEAN,
            std=timm.data.IMAGENET_DEFAULT_STD
        )

        # 4. Load the single reference vector; stays None when loading fails.
        self.target_tensor = None
        self._load_single_npy()

    def _load_single_npy(self):
        """Load ``cfg.npy_file`` into ``self.target_tensor`` on ``cfg.device``.

        A 1-D array gets a leading batch axis. A missing or unreadable file is
        reported on stdout and leaves ``self.target_tensor`` unchanged.
        """
        if not os.path.exists(cfg.npy_file):
            print(f"❌ Error: '{cfg.npy_file}' not found.")
            return
        print(f"🎯 Loading target vector from {cfg.npy_file}...")
        try:
            vector_array = np.load(cfg.npy_file)
            tensor = torch.tensor(vector_array, dtype=torch.float32)
            if tensor.dim() == 1:
                tensor = tensor.unsqueeze(0)
            self.target_tensor = tensor.to(cfg.device)
            print(f"✅ Target vector loaded. Shape: {self.target_tensor.shape}")
        except Exception as e:
            print(f"❌ Failed to load {cfg.npy_file}: {e}")

    @classmethod
    def _describe(cls, mapped_sim):
        """Return ``(label, description, colour)`` for a mapped score."""
        for floor, label, desc, color in cls._TIERS:
            if mapped_sim > floor:
                return label, desc, color
        return cls._FALLBACK_TIER

    def match(self, image: "Image.Image | None") -> str:
        """Compare ``image`` with the reference vector and render the result.

        Args:
            image: PIL image to score, or ``None`` when nothing was supplied.

        Returns:
            A plain message when there is no image or no reference vector,
            otherwise an HTML snippet with the mapped score, its tier and the
            raw cosine similarity.
        """
        if image is None:
            # NOTE(review): this message mentions a QQ-avatar input; no such
            # input appears in the UI visible in this file.
            return "请先上传图片或输入 QQ 号获取头像。"
        if self.target_tensor is None:
            return f"初始化失败:未找到目标向量文件 `{cfg.npy_file}`。"
        if image.mode != 'RGB':
            image = image.convert('RGB')

        img_tensor = self.transform(image).unsqueeze(0).to(cfg.device)
        # Inference only: skip autograd bookkeeping for the forward pass.
        with torch.no_grad():
            features = self.model.forward_features(img_tensor)
            embedding = self.model.forward_head(features, pre_logits=True)
            raw_sim = F.cosine_similarity(embedding, self.target_tensor).item()
        mapped_sim = map_similarity(raw_sim)

        label, desc, color = self._describe(mapped_sim)

        # The formula shown to the user matches the bounds map_similarity
        # applies (0.75 and 0.85).
        md_out = f"""
<div style="padding: 24px; border-radius: 10px; border: 1px solid #e8e8e8; background: #fafafa;">
<p style="margin: 0 0 6px; font-size: 13px; color: #8c8c8c;">风格相似度评估</p>
<h2 style="margin: 0 0 4px; color: {color}; font-size: 2.8em; font-weight: 700;">{mapped_sim:.2%}</h2>
<span style="display: inline-block; padding: 2px 10px; border-radius: 4px; background: {color}20; color: {color}; font-size: 13px; font-weight: 600;">{label}</span>
<p style="margin: 14px 0 16px; color: #444; font-size: 14px;">{desc}</p>
<hr style="border: none; border-top: 1px solid #eee; margin: 0 0 14px;">
<div style="font-size: 12px; color: #999; font-family: monospace;">
原始余弦相似度(Raw Cosine Similarity):<b style="color:#555">{raw_sim:.6f}</b><br>
映射函数:<code>(x − 0.75) / (0.85 − 0.75),裁剪至 [0, 1]</code>
</div>
</div>
"""
        return md_out
| # ============== 3. 构建 Gradio UI ============== | |
# Build the matcher once at import time. Any failure is printed and leaves
# the module in a "not ready" state rather than raising.
matcher = None
is_ready = False
try:
    matcher = SingleVectorMatcher()
    # Ready only when the reference vector was actually loaded.
    is_ready = matcher.target_tensor is not None
except Exception as init_error:
    print(f"Initialization Failed: {init_error}")
    matcher = None
    is_ready = False
def run_inference(image):
    """Gradio callback: score ``image``, or report that the app is not ready.

    Args:
        image: The value Gradio passes from the image input.

    Returns:
        The matcher's output when ``is_ready`` is true, otherwise a fixed
        "system not ready" message.
    """
    if is_ready:
        return matcher.match(image)
    return "❌ 系统未就绪,请确保 `target_vector.npy` 已经上传至根目录。"
with gr.Blocks(title="Nano Banana 风格检测") as demo:
    # Page heading and a one-line usage hint.
    gr.Markdown(
        "\n"
        "## Nano Banana 风格相似度检测\n"
        "上传图片,检测其视觉风格与 Nano Banana 的相似程度。\n"
    )

    # Two equal columns: image input and button on the left, result on the right.
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            input_img = gr.Image(type="pil", label="输入图片")
            btn = gr.Button("开始检测", variant="primary")
        with gr.Column(scale=1):
            out_md = gr.Markdown("检测结果将在此显示。")

    # Clicking the button sends the image to run_inference and shows its output.
    btn.click(run_inference, inputs=[input_img], outputs=[out_md])

    # Footer explaining how the percentage is derived.
    gr.Markdown(
        "\n"
        "---\n"
        "**说明:** 相似度基于图像高维嵌入与预设参考向量的余弦相似度计算,结果仅供参考,不代表对风格的完整评价。\n"
        "原始输出区间约为 [.75, .85+],经线性映射后展示为百分比。\n"
    )

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()