Spaces:
Sleeping
Sleeping
File size: 7,124 Bytes
efc0ff1 d057383 83fe6c9 35d488e d057383 55d0d4a d057383 55d0d4a 22b1c49 83fe6c9 d057383 83fe6c9 d057383 83fe6c9 d057383 83fe6c9 d057383 83fe6c9 d057383 83fe6c9 d057383 83fe6c9 d057383 83fe6c9 22b1c49 83fe6c9 d057383 83fe6c9 d057383 35d488e 83fe6c9 d057383 83fe6c9 d057383 83fe6c9 d057383 83fe6c9 d057383 83fe6c9 d057383 83fe6c9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 | import os
import json
import torch
import torch.nn.functional as F
import timm
import numpy as np
import gradio as gr
import requests
from io import BytesIO
from PIL import Image
from huggingface_hub import hf_hub_download
# ============== 配置参数 ==============
class Config:
    """Static application settings, read everywhere through the ``cfg`` instance."""

    # Hugging Face Hub repository holding the checkpoint and its label mapping.
    model_repo: str = "telecomadm1145/cmodel_v2_test"
    weights_file: str = "pytorch_model.bin"
    name_mapping_file: str = "label_id_mapping.json"
    # Commit hash used when downloading the weights file.
    rev: str = "6350d5e35f883ca058bbc84a82853407874b68da"

    # timm architecture identifier and square input resolution in pixels.
    model_name: str = "convnext_base.dinov3_lvd1689m"
    image_size: int = 384

    # Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # Local path of the .npy file holding the reference vector.
    npy_file: str = "target_vector.npy"


cfg = Config()
# ============== 2. 核心逻辑 ==============
def download_file(repo_id, filename, rev):
    """Fetch ``filename`` from Hub repo ``repo_id`` at revision ``rev``.

    Returns the local path reported by ``hf_hub_download``. On any exception,
    prints a warning and returns None instead of raising.
    """
    print(f"📥 Downloading {filename} from {repo_id}...")
    local_path = None
    try:
        local_path = hf_hub_download(repo_id=repo_id, filename=filename, revision=rev)
    except Exception as err:
        print(f"⚠️ Failed to download {filename}: {err}")
    return local_path
# ============== 相似度映射 ==============
def map_similarity(sim: float, low: float = 0.75, high: float = 0.85) -> float:
    """Linearly rescale a cosine similarity onto [0, 1].

    ``low`` maps to 0.0 and ``high`` maps to 1.0. Values outside that range
    are clipped to the nearest end.

    Args:
        sim: Raw cosine similarity.
        low: Similarity that maps to 0.0 (default 0.75).
        high: Similarity that maps to 1.0 (default 0.85).

    Returns:
        The mapped score in [0.0, 1.0]. NaN maps to 0.0.

    Raises:
        ValueError: If ``high`` is not greater than ``low``.
    """
    import math

    if not high > low:
        raise ValueError(f"high ({high}) must be greater than low ({low})")
    # min()/max() silently turn NaN into 1.0 here (displayed as 100%), so treat
    # an undefined similarity as "no match" instead.
    if math.isnan(sim):
        return 0.0
    mapped = (sim - low) / (high - low)
    return max(0.0, min(1.0, mapped))
# --- 向量匹配核心 ---
class SingleVectorMatcher:
    """Compare an image's embedding against one stored reference vector.

    Construction does the following:
      - downloads the label mapping and weights from ``cfg.model_repo``;
      - builds the timm model ``cfg.model_name`` in eval mode;
      - prepares the evaluation transform;
      - loads the reference vector from ``cfg.npy_file``.

    If the reference vector cannot be loaded, ``target_tensor`` stays None and
    :meth:`match` returns an initialization-failure message.
    """

    def __init__(self):
        print(f"🔄 Initializing on device: {cfg.device}")

        # 1. Build the architecture. The classifier width must equal the
        #    checkpoint's for load_state_dict to succeed. It is derived from the
        #    mapping file as (#model-name ids + #base-model ids).
        # NOTE(review): the mapping comes from "main" while the weights are pinned
        # to cfg.rev. If the mapping changes upstream, num_classes may stop
        # matching the pinned checkpoint.
        name_map_path = download_file(cfg.model_repo, cfg.name_mapping_file, "main")
        if name_map_path is None:
            raise FileNotFoundError(
                f"Could not download {cfg.name_mapping_file} from {cfg.model_repo}"
            )
        with open(name_map_path, 'r', encoding='utf-8') as f:
            full_map = json.load(f)
        num_classes = len(full_map.get('id_to_model_name', {})) + len(full_map.get('id_to_base_model', {}))
        self.model = timm.create_model(cfg.model_name, pretrained=False, num_classes=num_classes)

        # 2. Load the weights. The checkpoint comes from a remote repo, so
        #    weights_only=True restricts unpickling to tensors and plain
        #    containers rather than arbitrary objects.
        weights_path = download_file(cfg.model_repo, cfg.weights_file, cfg.rev)
        if weights_path is None:
            raise FileNotFoundError(
                f"Could not download {cfg.weights_file}@{cfg.rev} from {cfg.model_repo}"
            )
        state_dict = torch.load(weights_path, map_location='cpu', weights_only=True)
        self.model.load_state_dict(state_dict)
        self.model.to(cfg.device)
        self.model.eval()

        # 3. Evaluation-time preprocessing at cfg.image_size with ImageNet statistics.
        self.transform = timm.data.create_transform(
            input_size=(3, cfg.image_size, cfg.image_size),
            is_training=False,
            mean=timm.data.IMAGENET_DEFAULT_MEAN,
            std=timm.data.IMAGENET_DEFAULT_STD
        )

        # 4. Load the reference vector. On failure it stays None and match()
        #    reports the problem instead of scoring.
        self.target_tensor = None
        self._load_single_npy()

    def _load_single_npy(self):
        """Load ``cfg.npy_file`` into ``self.target_tensor`` as a (1, D) float32 tensor.

        Prints the reason and leaves ``target_tensor`` as None when:
          - the file is missing;
          - the file cannot be read;
          - the file does not hold exactly one vector.
        """
        if not os.path.exists(cfg.npy_file):
            print(f"❌ Error: '{cfg.npy_file}' not found.")
            return
        print(f"🎯 Loading target vector from {cfg.npy_file}...")
        try:
            vector_array = np.load(cfg.npy_file)
            tensor = torch.tensor(vector_array, dtype=torch.float32)
            if tensor.dim() == 1:
                tensor = tensor.unsqueeze(0)
            # match() reduces the similarity with .item(), which only works when
            # there is exactly one reference row. Reject anything else here
            # rather than crashing on every request later.
            if tensor.dim() != 2 or tensor.shape[0] != 1:
                raise ValueError(
                    f"expected a single vector, got shape {tuple(tensor.shape)}"
                )
            self.target_tensor = tensor.to(cfg.device)
            print(f"✅ Target vector loaded. Shape: {self.target_tensor.shape}")
        except Exception as e:
            print(f"❌ Failed to load {cfg.npy_file}: {e}")

    @torch.no_grad()
    def match(self, image: Image.Image):
        """Score ``image`` against the reference vector and build an HTML report.

        Args:
            image: PIL image from the Gradio input, or None if nothing was provided.

        Returns:
            A string for ``gr.Markdown``. It is either an error message or a card
            showing the mapped similarity, a tier label, a description and the
            raw cosine value.
        """
        # NOTE(review): this message mentions a QQ-number input that the
        # visible UI does not offer.
        if image is None:
            return "请先上传图片或输入 QQ 号获取头像。"
        if self.target_tensor is None:
            return f"初始化失败:未找到目标向量文件 `{cfg.npy_file}`。"
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Preprocess and add a batch dimension of 1.
        img_tensor = self.transform(image).unsqueeze(0).to(cfg.device)
        features = self.model.forward_features(img_tensor)
        # pre_logits=True stops before the classifier layer, so the comparison
        # uses the embedding rather than class scores.
        embedding = self.model.forward_head(features, pre_logits=True)
        # Assumes embedding is (1, D) with D equal to the stored vector's
        # length. TODO confirm against the checkpoint.
        raw_sim = F.cosine_similarity(embedding, self.target_tensor).item()
        mapped_sim = map_similarity(raw_sim)

        # Style tier, checked from the highest threshold down.
        if mapped_sim > 0.85:
            label = "高度相似"
            desc = "图像风格与 Nano Banana 高度吻合,具有强烈的标志性特征。"
            color = "#4096ff"
        elif mapped_sim > 0.65:
            label = "较为相似"
            desc = "图像风格与 Nano Banana 有一定相似度,部分特征较为接近。"
            color = "#36cfc9"
        elif mapped_sim > 0.45:
            label = "轻微相似"
            desc = "图像与 Nano Banana 风格存在少量共同特征,整体差异明显。"
            color = "#9254de"
        else:
            label = "风格不符"
            desc = "图像风格与 Nano Banana 差异显著,几乎不具备相关特征。"
            color = "#8c8c8c"

        # The HTML is kept at column 0 so Markdown does not treat indented lines
        # as a code block. "{color}20" appends an alpha byte to make a tinted
        # badge background. The formula text must agree with map_similarity's
        # default range (0.75 to 0.85).
        md_out = f"""
<div style="padding: 24px; border-radius: 10px; border: 1px solid #e8e8e8; background: #fafafa;">
<p style="margin: 0 0 6px; font-size: 13px; color: #8c8c8c;">风格相似度评估</p>
<h2 style="margin: 0 0 4px; color: {color}; font-size: 2.8em; font-weight: 700;">{mapped_sim:.2%}</h2>
<span style="display: inline-block; padding: 2px 10px; border-radius: 4px; background: {color}20; color: {color}; font-size: 13px; font-weight: 600;">{label}</span>
<p style="margin: 14px 0 16px; color: #444; font-size: 14px;">{desc}</p>
<hr style="border: none; border-top: 1px solid #eee; margin: 0 0 14px;">
<div style="font-size: 12px; color: #999; font-family: monospace;">
原始余弦相似度(Raw Cosine Similarity):<b style="color:#555">{raw_sim:.6f}</b><br>
映射函数:<code>(x − 0.75) / (0.85 − 0.75),裁剪至 [0, 1]</code>
</div>
</div>
"""
        return md_out
# ============== 3. 构建 Gradio UI ==============
try:
    matcher = SingleVectorMatcher()
    # Construction can succeed without a usable reference vector:
    # target_tensor stays None when the .npy file is missing or unreadable.
    is_ready = matcher.target_tensor is not None
except Exception as e:
    # Startup is best-effort. Keep the UI importable and answer requests with
    # the "not ready" message instead of crashing the whole app. Print the
    # traceback so download/model errors remain diagnosable.
    import traceback

    print(f"Initialization Failed: {e}")
    traceback.print_exc()
    matcher, is_ready = None, False


def run_inference(image):
    """Gradio click handler: score ``image``, or explain why the app is not ready.

    Returns the string produced by ``matcher.match``. When initialization failed
    or no reference vector was loaded, returns a not-ready message instead.
    """
    if not is_ready:
        # Name the configured vector file rather than a hard-coded filename.
        return f"❌ 系统未就绪,请确保 `{cfg.npy_file}` 已经上传至根目录。"
    return matcher.match(image)
# Page layout: a title, an input/result row, and an explanatory footer.
# Clicking the button passes the uploaded image to run_inference and renders
# the returned string in the result panel.
with gr.Blocks(title="Nano Banana 风格检测") as demo:
    # Title and one-line usage instructions.
    gr.Markdown(
        """
## Nano Banana 风格相似度检测
上传图片,检测其视觉风格与 Nano Banana 的相似程度。
"""
    )
    with gr.Row(equal_height=True):
        # Left column: image input (handed to the handler as a PIL image) and
        # the trigger button.
        with gr.Column(scale=1):
            input_img = gr.Image(type="pil", label="输入图片")
            btn = gr.Button("开始检测", variant="primary")
        # Right column: placeholder text, replaced by the handler's output.
        with gr.Column(scale=1):
            out_md = gr.Markdown("检测结果将在此显示。")
    btn.click(run_inference, inputs=[input_img], outputs=[out_md])
    # Footer explaining how the displayed percentage is derived.
    gr.Markdown(
        """
---
**说明:** 相似度基于图像高维嵌入与预设参考向量的余弦相似度计算,结果仅供参考,不代表对风格的完整评价。
原始输出区间约为 [.75, .85+],经线性映射后展示为百分比。
"""
    )
if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not on import.
    demo.launch()