File size: 7,124 Bytes
efc0ff1
 
 
 
 
 
 
 
 
 
 
d057383
83fe6c9
 
 
 
 
 
 
 
 
 
35d488e
 
 
 
 
 
 
 
d057383
55d0d4a
d057383
 
55d0d4a
 
22b1c49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83fe6c9
 
 
d057383
83fe6c9
d057383
83fe6c9
 
 
 
 
 
d057383
83fe6c9
d057383
 
83fe6c9
d057383
 
 
 
 
 
 
 
 
 
 
 
 
83fe6c9
d057383
 
 
83fe6c9
 
d057383
 
 
 
 
 
 
 
 
 
83fe6c9
 
 
 
22b1c49
 
 
 
 
 
 
 
 
 
 
 
83fe6c9
d057383
83fe6c9
 
d057383
35d488e
83fe6c9
 
 
d057383
83fe6c9
d057383
 
83fe6c9
d057383
 
83fe6c9
d057383
83fe6c9
 
 
d057383
 
83fe6c9
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import os
import json
import torch
import torch.nn.functional as F
import timm
import numpy as np
import gradio as gr
import requests
from io import BytesIO
from PIL import Image
from huggingface_hub import hf_hub_download
# ============== 配置参数 ==============
class Config:
    """Static settings shared by the model loader, the matcher and the UI."""

    # --- Hugging Face Hub artifacts ---
    model_repo = "telecomadm1145/cmodel_v2_test"
    rev = "6350d5e35f883ca058bbc84a82853407874b68da"
    weights_file = "pytorch_model.bin"
    name_mapping_file = "label_id_mapping.json"

    # --- Local artifacts ---
    npy_file = "target_vector.npy"

    # --- Backbone and preprocessing ---
    model_name = "convnext_base.dinov3_lvd1689m"
    image_size = 384

    # --- Runtime device: use CUDA when torch reports it as available ---
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"


# Single shared settings instance used by the rest of the module.
cfg = Config()
# ============== 2. 核心逻辑 ==============
def download_file(repo_id, filename, rev):
    """Download ``filename`` at revision ``rev`` from the Hub repo ``repo_id``.

    Returns whatever ``hf_hub_download`` returns on success. Any exception is
    reported on stdout and swallowed, in which case ``None`` is returned, so
    callers must check for ``None`` before using the result.
    """
    print(f"📥 Downloading {filename} from {repo_id}...")
    try:
        local_path = hf_hub_download(repo_id=repo_id, filename=filename, revision=rev)
    except Exception as exc:
        # Best-effort: log the failure and signal it with None.
        print(f"⚠️ Failed to download {filename}: {exc}")
        local_path = None
    return local_path
# ============== 相似度映射 ==============
def map_similarity(sim: float, low: float = 0.75, high: float = 0.85) -> float:
    """Linearly rescale a cosine similarity from ``[low, high]`` onto ``[0, 1]``.

    Values at or below ``low`` map to 0.0 and values at or above ``high`` map
    to 1.0. With the defaults the source interval is [0.75, 0.85].

    Args:
        sim: Raw cosine similarity.
        low: Similarity that maps to 0.0.
        high: Similarity that maps to 1.0; must be greater than ``low``.

    Returns:
        The rescaled score clamped to [0.0, 1.0]. A NaN input yields 0.0.

    Raises:
        ValueError: If ``high`` is not strictly greater than ``low``.
    """
    import math

    if high <= low:
        raise ValueError(f"high ({high}) must be greater than low ({low})")
    # min(1.0, nan) evaluates to 1.0, so an undefined similarity would
    # otherwise be reported as a perfect match; treat it as no match instead.
    if math.isnan(sim):
        return 0.0
    mapped = (sim - low) / (high - low)
    return max(0.0, min(1.0, mapped))

# --- 向量匹配核心 ---
class SingleVectorMatcher:
    """Score images against a single reference embedding.

    On construction the class builds a timm model from ``cfg.model_name``,
    loads weights downloaded from ``cfg.model_repo``, creates the evaluation
    image transform and loads the reference vector from ``cfg.npy_file``.
    ``match`` then embeds an image, computes its cosine similarity to the
    reference vector and renders the result as an HTML snippet.

    Attributes:
        model: The timm model in eval mode on ``cfg.device``.
        transform: Evaluation-time preprocessing from ``timm.data.create_transform``.
        target_tensor: Reference vector on ``cfg.device``, or ``None`` when the
            ``.npy`` file is missing or could not be loaded.
    """

    def __init__(self):
        """Download model artifacts, build the model and load the reference vector.

        Raises:
            FileNotFoundError: If the label mapping or the weights file could
                not be downloaded.
        """
        print(f"🔄 Initializing on device: {cfg.device}")

        # 1. Build the model architecture. The classifier size is the sum of
        #    the two id tables stored in the label mapping file.
        # NOTE(review): the mapping is fetched from "main" while the weights
        # are pinned to cfg.rev — confirm this difference is intentional.
        name_map_path = download_file(cfg.model_repo, cfg.name_mapping_file, "main")
        if name_map_path is None:
            # download_file signals failure with None; fail with a clear
            # message instead of letting open(None) raise a TypeError.
            raise FileNotFoundError(
                f"Could not download '{cfg.name_mapping_file}' from {cfg.model_repo}"
            )
        with open(name_map_path, 'r', encoding='utf-8') as f:
            full_map = json.load(f)
        num_classes = len(full_map.get('id_to_model_name', {})) + len(full_map.get('id_to_base_model', {}))

        self.model = timm.create_model(cfg.model_name, pretrained=False, num_classes=num_classes)

        # 2. Load the weights.
        weights_path = download_file(cfg.model_repo, cfg.weights_file, cfg.rev)
        if weights_path is None:
            raise FileNotFoundError(
                f"Could not download '{cfg.weights_file}' from {cfg.model_repo}"
            )
        # SECURITY NOTE(review): torch.load unpickles the downloaded file.
        # Consider weights_only=True once the checkpoint is confirmed to be a
        # plain state dict and the installed torch version supports it.
        state_dict = torch.load(weights_path, map_location='cpu')
        self.model.load_state_dict(state_dict)
        self.model.to(cfg.device)
        self.model.eval()

        # 3. Image preprocessing (evaluation mode, ImageNet mean/std).
        self.transform = timm.data.create_transform(
            input_size=(3, cfg.image_size, cfg.image_size),
            is_training=False,
            mean=timm.data.IMAGENET_DEFAULT_MEAN,
            std=timm.data.IMAGENET_DEFAULT_STD
        )

        # 4. Load the single reference vector; stays None on failure.
        self.target_tensor = None
        self._load_single_npy()

    def _load_single_npy(self):
        """Load ``cfg.npy_file`` into ``self.target_tensor``.

        A missing file or a load error is printed and leaves
        ``self.target_tensor`` unchanged (``None`` after ``__init__``).
        """
        if not os.path.exists(cfg.npy_file):
            print(f"❌ Error: '{cfg.npy_file}' not found.")
            return

        print(f"🎯 Loading target vector from {cfg.npy_file}...")
        try:
            vector_array = np.load(cfg.npy_file)
            tensor = torch.tensor(vector_array, dtype=torch.float32)
            # A 1-D vector gets a leading dimension so it lines up with the
            # batched embedding produced in match().
            if tensor.dim() == 1:
                tensor = tensor.unsqueeze(0)
            self.target_tensor = tensor.to(cfg.device)
            print(f"✅ Target vector loaded. Shape: {self.target_tensor.shape}")
        except Exception as e:
            print(f"❌ Failed to load {cfg.npy_file}: {e}")

    @staticmethod
    def _style_verdict(mapped_sim):
        """Return ``(label, description, color)`` for a mapped score in [0, 1]."""
        if mapped_sim > 0.85:
            return (
                "高度相似",
                "图像风格与 Nano Banana 高度吻合,具有强烈的标志性特征。",
                "#4096ff",
            )
        if mapped_sim > 0.65:
            return (
                "较为相似",
                "图像风格与 Nano Banana 有一定相似度,部分特征较为接近。",
                "#36cfc9",
            )
        if mapped_sim > 0.45:
            return (
                "轻微相似",
                "图像与 Nano Banana 风格存在少量共同特征,整体差异明显。",
                "#9254de",
            )
        return (
            "风格不符",
            "图像风格与 Nano Banana 差异显著,几乎不具备相关特征。",
            "#8c8c8c",
        )

    @torch.no_grad()
    def match(self, image: Image.Image):
        """Compare ``image`` with the reference vector and render the result.

        Args:
            image: Input image, or ``None`` when nothing was provided.

        Returns:
            A user-facing message when the image or the reference vector is
            missing; otherwise an HTML snippet showing the mapped score, a
            verdict label and the raw cosine similarity.
        """
        if image is None:
            return "请先上传图片或输入 QQ 号获取头像。"
        if self.target_tensor is None:
            return f"初始化失败:未找到目标向量文件 `{cfg.npy_file}`。"

        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Embed the image: backbone features, then the head up to pre-logits.
        img_tensor = self.transform(image).unsqueeze(0).to(cfg.device)
        features = self.model.forward_features(img_tensor)
        embedding = self.model.forward_head(features, pre_logits=True)

        raw_sim = F.cosine_similarity(embedding, self.target_tensor).item()
        mapped_sim = map_similarity(raw_sim)

        # Style verdict for the mapped score.
        label, desc, color = self._style_verdict(mapped_sim)

        # The formula shown below must mirror map_similarity's interval
        # (0.75 to 0.85); it previously displayed a stale 0.80/0.90 interval.
        md_out = f"""
        <div style="padding: 24px; border-radius: 10px; border: 1px solid #e8e8e8; background: #fafafa;">
            <p style="margin: 0 0 6px; font-size: 13px; color: #8c8c8c;">风格相似度评估</p>
            <h2 style="margin: 0 0 4px; color: {color}; font-size: 2.8em; font-weight: 700;">{mapped_sim:.2%}</h2>
            <span style="display: inline-block; padding: 2px 10px; border-radius: 4px; background: {color}20; color: {color}; font-size: 13px; font-weight: 600;">{label}</span>
            <p style="margin: 14px 0 16px; color: #444; font-size: 14px;">{desc}</p>
            <hr style="border: none; border-top: 1px solid #eee; margin: 0 0 14px;">
            <div style="font-size: 12px; color: #999; font-family: monospace;">
                原始余弦相似度(Raw Cosine Similarity):<b style="color:#555">{raw_sim:.6f}</b><br>
                映射函数:<code>(x − 0.75) / (0.85 − 0.75),裁剪至 [0, 1]</code>
            </div>
        </div>
        """
        return md_out

# ============== 3. 构建 Gradio UI ==============
# Build the matcher once at import time. Any failure is printed and leaves the
# app in a "not ready" state (matcher is None, is_ready is False).
matcher, is_ready = None, False
try:
    matcher = SingleVectorMatcher()
    # Ready only when the reference vector was actually loaded.
    is_ready = matcher.target_tensor is not None
except Exception as exc:
    print(f"Initialization Failed: {exc}")
    matcher, is_ready = None, False

def run_inference(image):
    """Gradio callback: score ``image`` with the shared matcher.

    Returns the matcher's output when the module-level ``is_ready`` flag is
    set; otherwise returns a fixed "system not ready" message.
    """
    if is_ready:
        return matcher.match(image)
    return "❌ 系统未就绪,请确保 `target_vector.npy` 已经上传至根目录。"

# Assemble the page: a header, one row with two equal columns
# (image input + button | result panel), and a footer note.
with gr.Blocks(title="Nano Banana 风格检测") as demo:
    # Page header.
    gr.Markdown(
        """
        ## Nano Banana 风格相似度检测
        上传图片,检测其视觉风格与 Nano Banana 的相似程度。
        """
    )
    
    with gr.Row(equal_height=True):
        # Left column: the image (handed to the callback as a PIL image) and
        # the button that triggers detection.
        with gr.Column(scale=1):
            input_img = gr.Image(type="pil", label="输入图片")
            btn = gr.Button("开始检测", variant="primary")

        # Right column: placeholder text that is replaced by the result.
        with gr.Column(scale=1):
            out_md = gr.Markdown("检测结果将在此显示。")
    # Clicking the button passes the image to run_inference and writes its
    # return value into the result panel.
    btn.click(run_inference, inputs=[input_img], outputs=[out_md])

    # Footer note describing how the displayed percentage is derived.
    gr.Markdown(
        """
        ---
        **说明:** 相似度基于图像高维嵌入与预设参考向量的余弦相似度计算,结果仅供参考,不代表对风格的完整评价。  
        原始输出区间约为 [.75, .85+],经线性映射后展示为百分比。
        """
    )

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()