# test / app.py
# Author: telecomadm1145 — last commit 35d488e ("Update app.py", verified)
import os
import json
import torch
import torch.nn.functional as F
import timm
import numpy as np
import gradio as gr
import requests
from io import BytesIO
from PIL import Image
from huggingface_hub import hf_hub_download
# ============== 配置参数 ==============
class Config:
    """Static application settings, shared through the module-level ``cfg``."""

    # Hugging Face Hub location of the checkpoint and its label mapping.
    model_repo = "telecomadm1145/cmodel_v2_test"
    rev = "6350d5e35f883ca058bbc84a82853407874b68da"  # Hub revision used for the weights file
    weights_file = "pytorch_model.bin"
    name_mapping_file = "label_id_mapping.json"

    # timm architecture name and square input resolution (pixels).
    model_name = "convnext_base.dinov3_lvd1689m"
    image_size = 384

    # Reference embedding, read from the working directory.
    npy_file = "target_vector.npy"

    # Evaluated once, when the class body executes.
    device = "cuda" if torch.cuda.is_available() else "cpu"


cfg = Config()
# ============== 2. 核心逻辑 ==============
def download_file(repo_id, filename, rev):
    """Fetch ``filename`` from Hub repo ``repo_id`` at revision ``rev``.

    Returns the local path reported by ``hf_hub_download``. If the download
    raises, the error is printed and ``None`` is returned instead of
    propagating the exception.
    """
    print(f"📥 Downloading {filename} from {repo_id}...")
    local_path = None
    try:
        local_path = hf_hub_download(repo_id=repo_id, filename=filename, revision=rev)
    except Exception as err:
        print(f"⚠️ Failed to download {filename}: {err}")
    return local_path
# ============== 相似度映射 ==============
def map_similarity(sim: float, *, low: float = 0.75, high: float = 0.85) -> float:
    """Linearly rescale a cosine similarity from ``[low, high]`` to ``[0, 1]``.

    Values at or below ``low`` map to 0.0 and values at or above ``high`` map
    to 1.0. The defaults (0.75-0.85) match the raw output range described in
    the UI footer.

    Args:
        sim: Raw cosine similarity.
        low: Raw value mapped to 0.0 (keyword-only).
        high: Raw value mapped to 1.0 (keyword-only).

    Returns:
        The clipped, rescaled score. A NaN input yields 0.0.

    Raises:
        ValueError: If ``high`` is not strictly greater than ``low``.
    """
    import math  # local import: the module header does not import math

    if not high > low:
        raise ValueError(f"high ({high}) must be greater than low ({low})")
    # Without this guard, min(1.0, nan) evaluates to 1.0, so a NaN similarity
    # would be reported as a perfect match.
    if math.isnan(sim):
        return 0.0
    mapped = (sim - low) / (high - low)
    return max(0.0, min(1.0, mapped))
# --- 向量匹配核心 ---
class SingleVectorMatcher:
    """Compare an image's embedding against a single reference vector.

    Construction does the following:
    - downloads the label mapping and weights from ``cfg.model_repo``;
    - builds the timm model ``cfg.model_name`` in eval mode on ``cfg.device``;
    - prepares the evaluation transform;
    - loads the reference vector from ``cfg.npy_file``.

    If that file is missing or unreadable, ``target_tensor`` stays ``None``
    and ``match`` reports the problem instead of scoring.
    """

    # Verdict table, checked top to bottom. The first row whose threshold the
    # mapped similarity strictly exceeds supplies (label, description, color).
    _GRADES = (
        (0.85, "高度相似", "图像风格与 Nano Banana 高度吻合,具有强烈的标志性特征。", "#4096ff"),
        (0.65, "较为相似", "图像风格与 Nano Banana 有一定相似度,部分特征较为接近。", "#36cfc9"),
        (0.45, "轻微相似", "图像与 Nano Banana 风格存在少量共同特征,整体差异明显。", "#9254de"),
    )
    # Verdict used when none of the thresholds above is exceeded.
    _FALLBACK_GRADE = ("风格不符", "图像风格与 Nano Banana 差异显著,几乎不具备相关特征。", "#8c8c8c")

    def __init__(self):
        print(f"🔄 Initializing on device: {cfg.device}")
        # 1. Architecture + weights.
        self.model = self._build_model()
        # 2. Evaluation-time preprocessing with ImageNet mean/std normalisation.
        self.transform = timm.data.create_transform(
            input_size=(3, cfg.image_size, cfg.image_size),
            is_training=False,
            mean=timm.data.IMAGENET_DEFAULT_MEAN,
            std=timm.data.IMAGENET_DEFAULT_STD,
        )
        # 3. Reference vector; stays None if the .npy cannot be loaded.
        self.target_tensor = None
        self._load_single_npy()

    @staticmethod
    def _download_or_raise(repo_id, filename, rev):
        """Return the local path of a Hub file, raising if the download failed."""
        path = download_file(repo_id, filename, rev)
        if path is None:
            # download_file already printed the cause. Fail here with a clear
            # message rather than crash later inside open(None)/torch.load(None).
            raise RuntimeError(f"Could not download {filename} from {repo_id} (revision {rev})")
        return path

    def _build_model(self):
        """Create the timm model, load its weights and switch it to eval mode."""
        name_map_path = self._download_or_raise(cfg.model_repo, cfg.name_mapping_file, "main")
        # NOTE(review): the label mapping comes from "main", but the weights below
        # are pinned to cfg.rev. If upstream changes the number of classes,
        # load_state_dict will fail with a shape mismatch. Consider pinning both.
        with open(name_map_path, 'r', encoding='utf-8') as f:
            full_map = json.load(f)
        # Head size = number of model-name labels + number of base-model labels.
        num_classes = len(full_map.get('id_to_model_name', {})) + len(full_map.get('id_to_base_model', {}))
        model = timm.create_model(cfg.model_name, pretrained=False, num_classes=num_classes)

        weights_path = self._download_or_raise(cfg.model_repo, cfg.weights_file, cfg.rev)
        # The checkpoint is fetched over the network and only needs to be a
        # plain state_dict. weights_only=True refuses arbitrary pickled objects.
        state_dict = torch.load(weights_path, map_location='cpu', weights_only=True)
        model.load_state_dict(state_dict)
        model.to(cfg.device)
        model.eval()
        return model

    def _load_single_npy(self):
        """Load ``cfg.npy_file`` into ``self.target_tensor`` as float32 on ``cfg.device``.

        A 1-D vector is promoted to shape (1, D). Failures are printed and leave
        ``target_tensor`` as ``None``, so the app can still start and report
        the problem.
        """
        if not os.path.exists(cfg.npy_file):
            print(f"❌ Error: '{cfg.npy_file}' not found.")
            return
        print(f"🎯 Loading target vector from {cfg.npy_file}...")
        try:
            vector_array = np.load(cfg.npy_file)
            tensor = torch.tensor(vector_array, dtype=torch.float32)
            if tensor.dim() == 1:
                tensor = tensor.unsqueeze(0)
            self.target_tensor = tensor.to(cfg.device)
            print(f"✅ Target vector loaded. Shape: {self.target_tensor.shape}")
        except Exception as e:
            print(f"❌ Failed to load {cfg.npy_file}: {e}")

    @torch.no_grad()
    def match(self, image: Image.Image) -> str:
        """Score ``image`` against the reference vector and return an HTML card.

        Args:
            image: PIL image from the Gradio input, or ``None`` if nothing was
                uploaded. It is converted to RGB when needed.

        Returns:
            An HTML snippet with the mapped percentage, the verdict and the raw
            cosine similarity. If there is no image or no reference vector, a
            short message is returned instead.
        """
        if image is None:
            return "请先上传图片。"
        if self.target_tensor is None:
            return f"初始化失败:未找到目标向量文件 `{cfg.npy_file}`。"
        if image.mode != 'RGB':
            image = image.convert('RGB')
        img_tensor = self.transform(image).unsqueeze(0).to(cfg.device)
        features = self.model.forward_features(img_tensor)
        # Embedding taken before the classifier layer (timm's pre_logits path).
        embedding = self.model.forward_head(features, pre_logits=True)
        # .item() requires a single similarity value, so this assumes
        # target_tensor holds one row (a 1-D .npy is promoted to (1, D)).
        raw_sim = F.cosine_similarity(embedding, self.target_tensor).item()
        mapped_sim = map_similarity(raw_sim)
        return self._render_result(raw_sim, mapped_sim)

    @classmethod
    def _grade(cls, mapped_sim):
        """Return (label, description, color) for a mapped similarity."""
        for threshold, label, desc, color in cls._GRADES:
            if mapped_sim > threshold:
                return label, desc, color
        return cls._FALLBACK_GRADE

    @classmethod
    def _render_result(cls, raw_sim, mapped_sim):
        """Build the HTML result card shown in the Markdown output component."""
        label, desc, color = cls._grade(mapped_sim)
        # The formula shown here must match map_similarity's default range
        # (0.75 -> 0, 0.85 -> 1). Lines stay unindented so Markdown does not
        # treat the HTML as a code block.
        return f"""
<div style="padding: 24px; border-radius: 10px; border: 1px solid #e8e8e8; background: #fafafa;">
<p style="margin: 0 0 6px; font-size: 13px; color: #8c8c8c;">风格相似度评估</p>
<h2 style="margin: 0 0 4px; color: {color}; font-size: 2.8em; font-weight: 700;">{mapped_sim:.2%}</h2>
<span style="display: inline-block; padding: 2px 10px; border-radius: 4px; background: {color}20; color: {color}; font-size: 13px; font-weight: 600;">{label}</span>
<p style="margin: 14px 0 16px; color: #444; font-size: 14px;">{desc}</p>
<hr style="border: none; border-top: 1px solid #eee; margin: 0 0 14px;">
<div style="font-size: 12px; color: #999; font-family: monospace;">
原始余弦相似度(Raw Cosine Similarity):<b style="color:#555">{raw_sim:.6f}</b><br>
映射函数:<code>(x − 0.75) / (0.85 − 0.75),裁剪至 [0, 1]</code>
</div>
</div>
"""
# ============== 3. 构建 Gradio UI ==============
# Build the matcher once at import time. If start-up raises, or no reference
# vector was loaded, the UI falls back to an error message.
matcher = None
is_ready = False
try:
    matcher = SingleVectorMatcher()
except Exception as e:
    print(f"Initialization Failed: {e}")
else:
    is_ready = matcher.target_tensor is not None
def run_inference(image):
    """Gradio click handler that forwards ``image`` to the global matcher.

    If start-up did not produce a ready matcher, it returns a fixed
    not-ready message instead.
    """
    if is_ready:
        return matcher.match(image)
    return "❌ 系统未就绪,请确保 `target_vector.npy` 已经上传至根目录。"
# Gradio page: image input and button on the left, result card on the right.
with gr.Blocks(title="Nano Banana 风格检测") as demo:
    # Page header (title + one-line instructions).
    gr.Markdown(
        """
## Nano Banana 风格相似度检测
上传图片,检测其视觉风格与 Nano Banana 的相似程度。
"""
    )
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            # type="pil" hands run_inference a PIL image (or None when empty).
            input_img = gr.Image(type="pil", label="输入图片")
            btn = gr.Button("开始检测", variant="primary")
        with gr.Column(scale=1):
            # Placeholder text, later replaced by the HTML/message from run_inference.
            out_md = gr.Markdown("检测结果将在此显示。")
    # Clicking the button runs the matcher and writes its output into out_md.
    btn.click(run_inference, inputs=[input_img], outputs=[out_md])
    # Footer explaining how the displayed percentage is derived.
    gr.Markdown(
        """
---
**说明:** 相似度基于图像高维嵌入与预设参考向量的余弦相似度计算,结果仅供参考,不代表对风格的完整评价。
原始输出区间约为 [.75, .85+],经线性映射后展示为百分比。
"""
    )


if __name__ == "__main__":
    # Start the Gradio server only when executed as a script.
    demo.launch()