telecomadm1145 commited on
Commit
d057383
·
verified ·
1 Parent(s): d142a55

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -159
app.py CHANGED
@@ -1,16 +1,4 @@
1
- import os
2
- import json
3
- import torch
4
- import torch.nn.functional as F
5
- import timm
6
- import numpy as np
7
- import gradio as gr
8
- import requests
9
- from io import BytesIO
10
- from PIL import Image
11
- from huggingface_hub import hf_hub_download
12
-
13
- # ============== 1. 配置参数 ==============
14
  class Config:
15
  model_repo = "telecomadm1145/cmodel_v2_test"
16
  weights_file = "pytorch_model.bin"
@@ -18,190 +6,90 @@ class Config:
18
  rev = "6350d5e35f883ca058bbc84a82853407874b68da"
19
  model_name = "convnext_base.dinov3_lvd1689m"
20
  image_size = 384
21
-
22
  device = "cuda" if torch.cuda.is_available() else "cpu"
23
-
24
- # 你的“男娘概念”特征向量 (确保同目录下有这个文件)
25
  npy_file = "target_vector.npy"
26
-
27
  cfg = Config()
28
 
29
- # ============== 2. 核心逻辑 ==============
30
- def download_file(repo_id, filename, rev):
31
- print(f"📥 Downloading {filename} from {repo_id}...")
32
- try:
33
- return hf_hub_download(repo_id=repo_id, filename=filename, revision=rev)
34
- except Exception as e:
35
- print(f"⚠️ Failed to download {filename}: {e}")
36
- return None
37
-
38
- # --- 获取 QQ 头像的功能 ---
39
- def fetch_qq_avatar(qq_num):
40
- if not qq_num or not qq_num.isdigit():
41
- return None, "❌ 请输入纯数字的有效QQ号!"
42
-
43
- # 使用 QQ 头像官方接口 (s=640 获取高清图, q1/q2 均可)
44
- url = f"http://q1.qlogo.cn/g?b=qq&nk={qq_num}&s=640"
45
- try:
46
- response = requests.get(url, timeout=5)
47
- response.raise_for_status()
48
- img = Image.open(BytesIO(response.content)).convert("RGB")
49
- return img, "✅ 获取 QQ 头像成功!点击下方按钮开始检测吧~"
50
- except Exception as e:
51
- return None, f"❌ 获取失败,可能是网络波动或QQ号不存在: {e}"
52
-
53
  def map_similarity(sim: float) -> float:
54
- """将余弦相似度 [sim_min, sim_max] 线性映射到 [0, 1],并裁剪。"""
55
- mapped = (sim - 0.8) / (0.9 - 0.8)
56
  return max(0.0, min(1.0, mapped))
57
 
58
- # --- 向量匹配核心 ---
59
- class SingleVectorMatcher:
60
- def __init__(self):
61
- print(f"🔄 Initializing on device: {cfg.device}")
62
-
63
- # 1. 初始化模型架构
64
- name_map_path = download_file(cfg.model_repo, cfg.name_mapping_file, "main")
65
- with open(name_map_path, 'r', encoding='utf-8') as f:
66
- full_map = json.load(f)
67
- num_classes = len(full_map.get('id_to_model_name', {})) + len(full_map.get('id_to_base_model', {}))
68
-
69
- self.model = timm.create_model(cfg.model_name, pretrained=False, num_classes=num_classes)
70
-
71
- # 2. 加载权重
72
- weights_path = download_file(cfg.model_repo, cfg.weights_file, cfg.rev)
73
- state_dict = torch.load(weights_path, map_location='cpu')
74
- self.model.load_state_dict(state_dict)
75
- self.model.to(cfg.device)
76
- self.model.eval()
77
-
78
- # 3. 图像预处理
79
- self.transform = timm.data.create_transform(
80
- input_size=(3, cfg.image_size, cfg.image_size),
81
- is_training=False,
82
- mean=timm.data.IMAGENET_DEFAULT_MEAN,
83
- std=timm.data.IMAGENET_DEFAULT_STD
84
- )
85
-
86
- # 4. 加载单一向量 NPY
87
- self.target_tensor = None
88
- self._load_single_npy()
89
-
90
- def _load_single_npy(self):
91
- if not os.path.exists(cfg.npy_file):
92
- print(f"❌ Error: '{cfg.npy_file}' not found.")
93
- return
94
-
95
- print(f"🎯 Loading target vector from {cfg.npy_file}...")
96
- try:
97
- vector_array = np.load(cfg.npy_file)
98
- tensor = torch.tensor(vector_array, dtype=torch.float32)
99
- if tensor.dim() == 1:
100
- tensor = tensor.unsqueeze(0)
101
- self.target_tensor = tensor.to(cfg.device)
102
- print(f"✅ Target vector loaded. Shape: {self.target_tensor.shape}")
103
- except Exception as e:
104
- print(f"❌ Failed to load {cfg.npy_file}: {e}")
105
-
106
  @torch.no_grad()
107
  def match(self, image: Image.Image):
108
  if image is None:
109
- return "请先上传图片或获取 QQ 头像"
110
  if self.target_tensor is None:
111
- return f"❌ 引擎故障:找灵魂向量 `{cfg.npy_file}`"
112
 
113
  if image.mode != 'RGB':
114
  image = image.convert('RGB')
115
 
116
- # 提取图像特征
117
  img_tensor = self.transform(image).unsqueeze(0).to(cfg.device)
118
  features = self.model.forward_features(img_tensor)
119
- embedding = self.model.forward_head(features, pre_logits=True)
120
 
121
- # 计算余弦相似度
122
- raw = F.cosine_similarity(embedding, self.target_tensor).item()
123
- similarity = map_similarity(raw)
124
 
125
- # ================= Meme 文案逻辑 =================
126
- if similarity > 0.85:
127
- status = "🚨 **最警报!纯极高的小男娘!**<br>这什么神仙画风,快让他/她穿上小裙子!"
128
- color = "#ff4d4f" # 红色
129
- elif similarity > 0.65:
130
- status = "👀 **疑似男娘...**<br>成分复杂,眼神逐渐变得不清白,建议严查!"
131
- color = "#faad14" # 橙黄
132
- elif similarity > 0.45:
133
- status = "🤔 **薛定谔的男娘**<br>处于男娘与普通路人的量子叠加态,有点东西但不多。"
134
- color = "#1890ff" # 蓝色
 
 
 
135
  else:
136
- status = "🗿 **纯爷们 / 铁直女无误**<br>完全没有任何男娘气息,钢铁直,散了吧。"
137
- color = "#52c41a" # 绿色
 
138
 
139
  md_out = f"""
140
- <div style="text-align: center; padding: 30px; border-radius: 15px; background-color: #f8f9fa; border: 2px solid {color}; box-shadow: 0 4px 12px rgba(0,0,0,0.1);">
141
- <h1 style="color: {color}; font-size: 4em; margin: 10px 0;">{similarity:.2%}(raw:{raw})</h1>
 
 
 
 
 
 
 
 
142
  </div>
143
  """
144
  return md_out
145
 
146
- # ============== 3. 构建 Gradio UI ==============
147
- try:
148
- matcher = SingleVectorMatcher()
149
- is_ready = matcher.target_tensor is not None
150
- except Exception as e:
151
- print(f"Initialization Failed: {e}")
152
- matcher, is_ready = None, False
153
-
154
- def run_inference(image):
155
- if not is_ready:
156
- return "❌ 系统未就绪,请确保 `target_vector.npy` 已经上传至根目录。"
157
- return matcher.match(image)
158
 
159
- # 自定义 CSS 让界面更二次元/Meme一点
160
- css = """
161
- .gradio-container { font-family: 'Comic Sans MS', 'Microsoft YaHei', sans-serif !important; }
162
- """
163
-
164
- with gr.Blocks(title="小男娘浓度检测器", theme=gr.themes.Soft(primary_hue="pink"), css=css) as demo:
165
  gr.Markdown(
166
  """
167
- <div style="text-align: center;">
168
- <h1>小男娘浓度检测器</h1>
169
- <p>基于先进的深度学习卷积神经网络(确信),精准检测你的头像成分!<br>
170
- <i>只需输入 QQ 号,或者直接上传图片即可判定!</i></p>
171
- </div>
172
  """
173
  )
174
 
175
- with gr.Row():
176
- with gr.Column(scale=1):
177
- with gr.Group():
178
- gr.Markdown("### 方式一:一键查成分")
179
- with gr.Row():
180
- qq_input = gr.Textbox(label="", placeholder="输入TA的QQ号...", scale=3, show_label=False)
181
- qq_btn = gr.Button("🔍 获取头像", variant="secondary", scale=1)
182
-
183
- gr.Markdown("### 方式二:自己传黑照")
184
- input_img = gr.Image(type="pil", label="被检测目标")
185
-
186
- # 主按钮
187
- btn = gr.Button("开始成分鉴定", variant="primary", size="lg")
188
-
189
  with gr.Column(scale=1):
190
- out_md = gr.Markdown("### 鉴定报告将在此生成...")
 
191
 
192
- # 事件绑定
193
- qq_btn.click(fetch_qq_avatar, inputs=[qq_input], outputs=[input_img, out_md])
194
  btn.click(run_inference, inputs=[input_img], outputs=[out_md])
195
-
196
- # 免责声明 (Disclaimer)
197
  gr.Markdown(
198
  """
199
  ---
200
- ### ⚠️ 免责声(Disclaimer)
201
- 1. **仅供娱乐**:本页面仅供群友整活、Meme(梗)交流与娱乐使用
202
- 2. **AI 的数学游戏**:本程序的“浓度”仅代表图像高维矩阵特征与特定预设模型向量的**余弦相似度(Cosine Similarity)**。
203
- 3. **无关真实身份**:检测结果**绝对不代表**任何真实人物的生理性别、身份认同、性取向或道德评价。
204
- 4. **请勿滥用**:请大家图一乐就好,**严禁**将本测试结果用于网暴、人身攻击、造谣或任何严肃场合。最终解释权归作者所有。
205
  """
206
  )
207
 
 
1
+ # ============== 配置参数 ==============
 
 
 
 
 
 
 
 
 
 
 
 
2
  class Config:
3
  model_repo = "telecomadm1145/cmodel_v2_test"
4
  weights_file = "pytorch_model.bin"
 
6
  rev = "6350d5e35f883ca058bbc84a82853407874b68da"
7
  model_name = "convnext_base.dinov3_lvd1689m"
8
  image_size = 384
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
10
  npy_file = "target_vector.npy"
 
11
  cfg = Config()
12
 
13
+ # ============== 相似度映射 ==============
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  def map_similarity(sim: float) -> float:
15
+ """将余弦相似度线性映射到 [0, 1],原始区间 [0.8, 0.9]"""
16
+ mapped = (sim - 0.75) / (0.85 - 0.75)
17
  return max(0.0, min(1.0, mapped))
18
 
19
+ # ============== 推理输出 ==============
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  @torch.no_grad()
21
  def match(self, image: Image.Image):
22
  if image is None:
23
+ return "请先上传图片或输入 QQ 号获取头像"
24
  if self.target_tensor is None:
25
+ return f"初始化失败找到目标向量文件 `{cfg.npy_file}`"
26
 
27
  if image.mode != 'RGB':
28
  image = image.convert('RGB')
29
 
 
30
  img_tensor = self.transform(image).unsqueeze(0).to(cfg.device)
31
  features = self.model.forward_features(img_tensor)
32
+ embedding = self.model.forward_head(features, pre_logits=True)
33
 
34
+ raw_sim = F.cosine_similarity(embedding, self.target_tensor).item()
35
+ mapped_sim = map_similarity(raw_sim)
 
36
 
37
+ # 风格判断
38
+ if mapped_sim > 0.85:
39
+ label = "高度相似"
40
+ desc = "图像风格与 Nano Banana 高度吻合,具有强烈的标志性特征。"
41
+ color = "#4096ff"
42
+ elif mapped_sim > 0.65:
43
+ label = "较为相似"
44
+ desc = "图像风格与 Nano Banana 有一定相似度,部分特征较为接近。"
45
+ color = "#36cfc9"
46
+ elif mapped_sim > 0.45:
47
+ label = "轻微相似"
48
+ desc = "图像与 Nano Banana 风格存在少量共同特征,整体差异明显。"
49
+ color = "#9254de"
50
  else:
51
+ label = "风格不符"
52
+ desc = "图像风格与 Nano Banana 差异显著,几乎不具备相关特征。"
53
+ color = "#8c8c8c"
54
 
55
  md_out = f"""
56
+ <div style="padding: 24px; border-radius: 10px; border: 1px solid #e8e8e8; background: #fafafa;">
57
+ <p style="margin: 0 0 6px; font-size: 13px; color: #8c8c8c;">风格相似度评估</p>
58
+ <h2 style="margin: 0 0 4px; color: {color}; font-size: 2.8em; font-weight: 700;">{mapped_sim:.2%}</h2>
59
+ <span style="display: inline-block; padding: 2px 10px; border-radius: 4px; background: {color}20; color: {color}; font-size: 13px; font-weight: 600;">{label}</span>
60
+ <p style="margin: 14px 0 16px; color: #444; font-size: 14px;">{desc}</p>
61
+ <hr style="border: none; border-top: 1px solid #eee; margin: 0 0 14px;">
62
+ <div style="font-size: 12px; color: #999; font-family: monospace;">
63
+ 原始余弦相似度(Raw Cosine Similarity):<b style="color:#555">{raw_sim:.6f}</b><br>
64
+ 映射函数:<code>(x − 0.80) / (0.90 − 0.80),裁剪至 [0, 1]</code>
65
+ </div>
66
  </div>
67
  """
68
  return md_out
69
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
+ with gr.Blocks(title="Nano Banana 风格检测") as demo:
 
 
 
 
 
72
  gr.Markdown(
73
  """
74
+ ## Nano Banana 风格相似度检测
75
+ 上传图片,检测其视觉风格与 Nano Banana 的相似程度。支持直接上传图片或通过 QQ 号拉取头像。
 
 
 
76
  """
77
  )
78
 
79
+ with gr.Row(equal_height=True):
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  with gr.Column(scale=1):
81
+ input_img = gr.Image(type="pil", label="输入图片")
82
+ btn = gr.Button("开始检测", variant="primary")
83
 
84
+ with gr.Column(scale=1):
85
+ out_md = gr.Markdown("检测结果将在此显示。")
86
  btn.click(run_inference, inputs=[input_img], outputs=[out_md])
87
+
 
88
  gr.Markdown(
89
  """
90
  ---
91
+ **说:** 相似度基于图像高维嵌入与预设参考向量的余弦相似度计算,结果仅供参考,不代表对风格的完整评价。
92
+ 原始输出区间约为 [.75, .85+],经线性映射后展示为百分比
 
 
 
93
  """
94
  )
95