# DanbooruSearch / core / engine.py — auto-synced from GitHub Actions (commit 5b3d6a6)
"""
core/engine.py
──────────────
DanbooruTagger 核心引擎
缓存格式(存于 cache_dir/ 目录):
embeddings.safetensors — 四路向量矩阵(FP16),行顺序与 metadata.parquet 完全对齐
metadata.parquet — DataFrame(name/cn_name/cn_core/wiki/nsfw/category/post_count)
meta.json — 标量元数据(max_log_count、schema_version)
"""
from __future__ import annotations
import asyncio
import json
import os
import re
import time
from pathlib import Path
from datetime import datetime
from typing import Optional
import jieba
import numpy as np
import pandas as pd
import torch
from safetensors.torch import load_file as st_load, save_file as st_save
from sentence_transformers import SentenceTransformer, util
from .models import SearchRequest, SearchResponse, TagResult
from platform_utils import (
PLATFORM,
is_cloud,
download_file,
resolve_model_path,
)
# ──────────────────────────────────────────────
# 常量
# ──────────────────────────────────────────────
# Tokens dropped from segmented queries before searching: ASCII punctuation,
# CJK punctuation, whitespace, and high-frequency Chinese function words
# (particles, pronouns, measure words, degree adverbs, numerals, negations).
STOP_WORDS: frozenset[str] = frozenset({
    ',', '.', ':', ';', '?', '!', '"', "'", '`',
    '(', ')', '[', ']', '{', '}', '<', '>',
    '-', '_', '=', '+', '/', '\\', '|', '@', '#', '$', '%', '^', '&', '*', '~',
    ',', '。', ':', ';', '?', '!', '\u201c', '\u201d', '\u2018', '\u2019',
    '(', ')', '【', '】', '《', '》', '、', '…', '—', '·',
    ' ', '\t', '\n', '\r',
    '的', '地', '得', '了', '着', '过',
    '是', '为', '被', '给', '把', '让', '由',
    '在', '从', '自', '向', '往', '对', '于',
    '和', '与', '及', '或', '且', '而', '但', '并', '即', '又', '也',
    '啊', '吗', '吧', '呢', '噢', '哦', '哈', '呀', '哇',
    '我', '你', '他', '她', '它', '我们', '你们', '他们',
    '这', '那', '此', '其', '谁', '啥', '某', '每',
    '这个', '那个', '这些', '那些', '这里', '那里',
    '个', '位', '只', '条', '张', '幅', '件', '套', '双', '对', '副',
    '种', '类', '群', '些', '点', '份', '部', '名',
    '很', '太', '更', '最', '挺', '特', '好', '真',
    '一', '一个', '一种', '一下', '一点', '一些',
    '有', '无', '非', '没', '不',
})
# Danbooru category codes -> display names; codes not listed here fall back to
# 'Other' at the lookup sites (search / get_related use CAT_MAP.get(..., 'Other')).
CAT_MAP: dict[str, str] = {
    '0': 'General', '1': 'Artist', '3': 'Copyright', '4': 'Character', '5': 'Meta',
}
# Bumping this value forces a full cache rebuild — use for breaking format changes.
SCHEMA_VERSION = 2
# ──────────────────────────────────────────────
# 缓存路径助手
# ──────────────────────────────────────────────
class _CachePaths:
def __init__(self, cache_dir: str | Path):
self.dir = Path(cache_dir)
self.embeddings = self.dir / 'danbooru_multiview_embeddings.safetensors'
self.metadata = self.dir / 'tags_metadata.parquet'
self.meta_json = self.dir / 'version_data.json'
def exists(self) -> bool:
return (
self.embeddings.is_file()
and self.metadata.is_file()
and self.meta_json.is_file()
)
def ensure_dir(self):
self.dir.mkdir(parents=True, exist_ok=True)
class DanbooruTagger:
    """Core search engine (process-wide singleton)."""
    # Singleton instance plus the asyncio lock guarding its one-time construction.
    _instance: Optional['DanbooruTagger'] = None
    _lock: Optional[asyncio.Lock] = None

    @classmethod
    def is_ready(cls) -> bool:
        # True once the singleton exists AND finished load().
        return cls._instance is not None and cls._instance.is_loaded

    @classmethod
    async def get_instance(cls, **kwargs) -> 'DanbooruTagger':
        """Return the singleton, constructing and loading it on first call.

        The lock is created lazily so the class can be imported before any
        event loop exists; load() is blocking, so it runs in a worker thread.
        """
        if cls._lock is None:
            cls._lock = asyncio.Lock()
        async with cls._lock:
            # Re-check under the lock: concurrent callers may have raced here.
            if cls._instance is None:
                inst = cls(**kwargs)
                await asyncio.to_thread(inst.load)
                cls._instance = inst
        return cls._instance
def __init__(
    self,
    model_path: Optional[str] = None,
    csv_file: str = 'origin_database/tags_enhanced.csv',
    cache_dir: str = 'tags_embedding',
    cooc_file: str = 'origin_database/cooccurrence_clean.csv',
):
    """Configure paths and device only; all heavy work is deferred to load().

    Args:
        model_path: explicit SentenceTransformer path; when omitted it is
            resolved via platform_utils.resolve_model_path().
        csv_file: enhanced tag table CSV (name/cn_name/wiki/...).
        cache_dir: directory holding the embedding + metadata cache.
        cooc_file: tag co-occurrence table used by get_related().
    """
    # Model path: an explicit argument wins, otherwise platform_utils resolves it.
    self.model_path = model_path or resolve_model_path()
    self.csv_path = csv_file
    self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    self.paths = _CachePaths(cache_dir)
    self.cooc_file = cooc_file
    # The following state is populated by load():
    self.model: Optional[SentenceTransformer] = None
    self.df: Optional[pd.DataFrame] = None  # one row per tag, aligned with embeddings
    # Four embedding matrices, row-aligned with self.df.
    self.emb_en: Optional[torch.Tensor] = None
    self.emb_cn: Optional[torch.Tensor] = None
    self.emb_wiki: Optional[torch.Tensor] = None
    self.emb_cn_core: Optional[torch.Tensor] = None
    self.max_log_count: float = 15.0  # log1p of the dataset's max post_count
    self.cooc: dict[str, list[tuple[str, int]]] = {}  # tag -> [(neighbor, count), ...]
    self._name_to_idx: dict[str, int] = {}  # tag name -> row position in self.df
    self.is_loaded: bool = False
# ── 初始化 ────────────────────────────────────────────────────────────
def load(self) -> None:
    """Blocking initialisation; intended to run via asyncio.to_thread().

    Pulls cloud data files when needed, then builds or loads the embedding
    cache, prepares jieba, the co-occurrence table and the name->row index.
    """
    if self.is_loaded:
        return
    t0 = time.time()
    # Cloud environment: fetch data files from the corresponding platform hub.
    if is_cloud():
        self._pull_cloud_files()
    # Cache validation & (re)build.
    if not self.paths.exists():
        print('\n' + '=' * 50)
        print('[Engine] 未找到缓存,开始首次构建(约 1~3 分钟)...')
        print('=' * 50 + '\n')
        self._load_model()
        self._build_full()
    elif self._cached_schema_version() != SCHEMA_VERSION:
        # FIX: check the schema version BEFORE deserializing the cache.
        # Previously the cache was loaded first and only then version-checked,
        # so an incompatible cache could crash _load_from_cache() (or waste
        # work) before the rebuild decision was ever reached.
        print('[Engine] 缓存格式版本不符,触发全量重建...')
        self._load_model()
        self._build_full()
    else:
        print(f'[Engine] 加载缓存 ({self.paths.dir}) ...')
        self._load_from_cache()
        # With a valid cache, sync incrementally against the CSV if present.
        if os.path.exists(self.csv_path):
            self._load_model()
            self._smart_update()
    if self.model is None:
        self._load_model()
    self._setup_jieba_from_memory()
    self._load_cooc()
    self._name_to_idx = {n: i for i, n in enumerate(self.df['name'])}
    self.is_loaded = True
    print(f'[Engine] 初始化完成,耗时 {time.time() - t0:.2f}s')
def _pull_cloud_files(self) -> None:
    """Fetch every data file from the current cloud hub and update paths.

    HF / ModelScope differences are fully hidden by
    platform_utils.download_file().
    """
    print(f'[Engine] 云端环境 ({PLATFORM}),开始拉取数据文件...')
    # HF Spaces additionally need the repo id (SPACE_ID) and repo type.
    extra_hf_kwargs = {}
    if PLATFORM == 'hf':
        extra_hf_kwargs = {
            'hf_repo_id': os.environ.get('SPACE_ID'),
            'hf_repo_type': 'space',
        }

    def pull(filename: str) -> str:
        try:
            return download_file(filename, **extra_hf_kwargs)
        except Exception as e:
            # FIX: the message used to print a literal "(unknown)" — the
            # f-string placeholder for the failing file had been lost.
            print(f'[Engine] 拉取 {filename} 失败(非致命): {e}')
            return filename  # fall back to the original path; later logic decides on rebuild

    self.csv_path = pull('origin_database/tags_enhanced.csv')
    self.cooc_file = pull('origin_database/cooccurrence_clean.parquet')
    meta_path = pull('tags_embedding/tags_metadata.parquet')
    emb_path = pull('tags_embedding/danbooru_multiview_embeddings.safetensors')
    json_path = pull('tags_embedding/version_data.json')
    # Only override the cache paths when ALL three files arrived; a partial
    # download must not make paths.exists() succeed spuriously.
    if all(
        Path(p).is_file()
        for p in (meta_path, emb_path, json_path)
    ):
        self.paths.metadata = Path(meta_path)
        self.paths.embeddings = Path(emb_path)
        self.paths.meta_json = Path(json_path)
        print('[Engine] 云端数据文件拉取完毕。')
    else:
        print('[Engine] 部分缓存文件拉取失败,将触发本地重建。')
# ── 搜索 ──────────────────────────────────────────────────────────────
def search(self, request: SearchRequest) -> SearchResponse:
    """Run a multi-view semantic search and return scored, filtered tags.

    The query (and, optionally, its segmented keywords) is encoded once and
    matched against up to four embedding views ("layers"); per tag, only the
    best-scoring hit across all queries and layers is kept.
    """
    if not self.is_loaded:
        self.load()
    # Optionally segment the query and search each keyword individually too.
    if request.use_segmentation:
        raw_kw = self._smart_split(request.query)
        keywords = [w.strip() for w in raw_kw if w.strip() and w.strip() not in STOP_WORDS]
        queries = [request.query] + keywords
    else:
        keywords = []
        queries = [request.query]
    q_emb = self.model.encode(queries, convert_to_tensor=True).float()
    # Placeholder hit lists for disabled layers (one empty list per query).
    empty = [[] for _ in queries]
    tl = request.target_layers
    k = request.top_k
    hits_en = util.semantic_search(q_emb, self.emb_en, top_k=k) if '英文' in tl else empty
    hits_cn = util.semantic_search(q_emb, self.emb_cn, top_k=k) if '中文扩展词' in tl else empty
    hits_wiki = util.semantic_search(q_emb, self.emb_wiki, top_k=k) if '释义' in tl else empty
    hits_core = util.semantic_search(q_emb, self.emb_cn_core, top_k=k) if '中文核心词' in tl else empty
    final: dict[str, TagResult] = {}
    for i, source_word in enumerate(queries):
        # Merge this query's hits from all four layers, tagging their origin.
        combined = (
            [(h, '英文') for h in hits_en[i]]
            + [(h, '中文扩展词') for h in hits_cn[i]]
            + [(h, '释义') for h in hits_wiki[i]]
            + [(h, '中文核心词') for h in hits_core[i]]
        )
        for hit, layer in combined:
            score = hit['score']
            if score < 0.35:  # discard weak semantic matches early
                continue
            idx = hit['corpus_id']
            row = self.df.iloc[idx]
            cat_text = CAT_MAP.get(str(row.get('category', '0')), 'Other')
            if cat_text not in request.target_categories:
                continue
            tag_name = row['name']
            count = row['post_count']
            # Popularity in [0, 1]: log1p(post_count) normalized by dataset max.
            pop_score = np.log1p(count) / self.max_log_count
            w = request.popularity_weight
            # Blend semantic similarity with popularity.
            final_score = score * (1 - w) + pop_score * w
            # Keep only the best-scoring hit per tag across queries and layers.
            if tag_name not in final or final_score > final[tag_name].final_score:
                final[tag_name] = TagResult(
                    tag=tag_name, cn_name=row['cn_name'], category=cat_text,
                    nsfw=str(row.get('nsfw', '0')),
                    final_score=round(float(final_score), 4),
                    semantic_score=round(float(score), 4),
                    count=int(count), source=source_word, layer=layer,
                    wiki=str(row.get('wiki', '')),
                )
    sorted_results = sorted(final.values(), key=lambda r: r.final_score, reverse=True)
    # Final cutoff on the blended score, then cap at the requested limit.
    valid = [r for r in sorted_results if r.final_score > 0.45][: request.limit]
    tags_all = ', '.join(r.tag for r in valid)
    tags_sfw = ', '.join(r.tag for r in valid if r.nsfw != '1')
    return SearchResponse(
        tags_all=tags_all, tags_sfw=tags_sfw,
        results=valid, keywords=keywords,
    )
# ── 全量构建 ──────────────────────────────────────────────────────────
def _build_full(self) -> None:
    """Rebuild the tag frame and all four embedding matrices from the CSV."""
    print(f'[Engine] 全量读取 {self.csv_path} ...')
    frame = self._preprocess_raw_df(self._read_csv_robust(self.csv_path))
    self.df = frame
    self.max_log_count = float(np.log1p(frame['post_count'].max()))
    self._encode_all_and_save()
def _encode_all_and_save(self) -> None:
    """Encode the four text views of every tag and persist the cache."""
    print('[Engine] 全量编码...')
    self.emb_en, self.emb_cn, self.emb_wiki, self.emb_cn_core = (
        self._encode_texts(self.df[column].tolist())
        for column in ('name', 'cn_name', 'wiki', 'cn_core')
    )
    self._save_cache()
# ── 增量更新 ──────────────────────────────────────────────────────────
def _smart_update(self) -> None:
    """Incrementally sync self.df and the embeddings with the current CSV.

    Diffs the cached frame against a fresh CSV read, then (1) drops deleted
    rows, (2) re-encodes changed rows in place, and (3) appends new rows —
    keeping the embedding matrices row-aligned with self.df throughout.
    """
    print('[Engine] 检查增量变更...')
    t0 = time.time()
    raw_df = self._read_csv_robust(self.csv_path)
    new_df = self._preprocess_raw_df(raw_df)
    # Columns whose change requires re-encoding a row's vectors.
    _SIG_COLS = ['cn_name', 'wiki', 'cn_core']

    def _sig(df: pd.DataFrame, iloc_idx: int) -> tuple:
        # Content signature of one row, addressed by POSITION (iloc).
        row = df.iloc[iloc_idx]
        return tuple(str(row.get(c, '')) for c in _SIG_COLS)

    # name -> positional row index, for both the cached and the fresh frame.
    cached_idx: dict[str, int] = {n: i for i, n in enumerate(self.df['name'])}
    new_idx: dict[str, int] = {n: i for i, n in enumerate(new_df['name'])}
    added_names = [n for n in new_idx if n not in cached_idx]
    deleted_names = [n for n in cached_idx if n not in new_idx]
    changed_names = [
        n for n in new_idx
        if n in cached_idx and _sig(new_df, new_idx[n]) != _sig(self.df, cached_idx[n])
    ]
    if not added_names and not deleted_names and not changed_names:
        print('[Engine] 数据已是最新,无需更新。')
        return
    print(f'[Engine] 变更 → 新增: {len(added_names)} 修改: {len(changed_names)} 删除: {len(deleted_names)}')
    # 1) Deletions: drop frame rows and the matching embedding rows, then
    #    rebuild the name index since all positions shift.
    if deleted_names:
        keep_mask = ~self.df['name'].isin(set(deleted_names))
        keep_pos = [i for i, v in enumerate(keep_mask) if v]
        self.df = self.df[keep_mask].reset_index(drop=True)
        self.emb_en = self.emb_en[keep_pos]
        self.emb_cn = self.emb_cn[keep_pos]
        self.emb_wiki = self.emb_wiki[keep_pos]
        self.emb_cn_core = self.emb_cn_core[keep_pos]
        cached_idx = {n: i for i, n in enumerate(self.df['name'])}
    # 2) Changes: re-encode the four views and overwrite rows in place.
    #    (.at uses labels, but after reset_index the RangeIndex label equals
    #    the position, so cached_idx values are valid labels here.)
    if changed_names:
        changed_rows = new_df[new_df['name'].isin(set(changed_names))].reset_index(drop=True)
        vecs_en = self._encode_texts(changed_rows['name'].tolist())
        vecs_cn = self._encode_texts(changed_rows['cn_name'].tolist())
        vecs_wiki = self._encode_texts(changed_rows['wiki'].tolist())
        vecs_core = self._encode_texts(changed_rows['cn_core'].tolist())
        for j, name in enumerate(changed_rows['name']):
            ci = cached_idx[name]
            self.emb_en[ci] = vecs_en[j]
            self.emb_cn[ci] = vecs_cn[j]
            self.emb_wiki[ci] = vecs_wiki[j]
            self.emb_cn_core[ci] = vecs_core[j]
            for col in changed_rows.columns:
                self.df.at[ci, col] = changed_rows.at[j, col]
    # 3) Additions: encode and append at the end of every matrix + the frame.
    if added_names:
        added_rows = new_df[new_df['name'].isin(set(added_names))].reset_index(drop=True)
        vecs_en = self._encode_texts(added_rows['name'].tolist())
        vecs_cn = self._encode_texts(added_rows['cn_name'].tolist())
        vecs_wiki = self._encode_texts(added_rows['wiki'].tolist())
        vecs_core = self._encode_texts(added_rows['cn_core'].tolist())
        self.emb_en = torch.cat([self.emb_en, vecs_en], dim=0)
        self.emb_cn = torch.cat([self.emb_cn, vecs_cn], dim=0)
        self.emb_wiki = torch.cat([self.emb_wiki, vecs_wiki], dim=0)
        self.emb_cn_core = torch.cat([self.emb_cn_core, vecs_core], dim=0)
        self.df = pd.concat([self.df, added_rows], ignore_index=True)
    self.max_log_count = float(np.log1p(self.df['post_count'].max()))
    self._name_to_idx = {n: i for i, n in enumerate(self.df['name'])}
    self._save_cache()
    print(f'[Engine] 增量更新完成,耗时 {time.time() - t0:.2f}s(共 {len(self.df)} 条)')
# ── 缓存 I/O ──────────────────────────────────────────────────────────
def _save_cache(self) -> None:
    """Persist embeddings (as FP16), the metadata parquet and the meta JSON."""
    self.paths.ensure_dir()
    # Embeddings are stored half-precision to halve the cache size.
    tensors = {
        'emb_en': self.emb_en.half(),
        'emb_cn': self.emb_cn.half(),
        'emb_wiki': self.emb_wiki.half(),
        'emb_cn_core': self.emb_cn_core.half(),
    }
    st_save(tensors, str(self.paths.embeddings))
    save_cols = ['name', 'cn_name', 'cn_core', 'wiki', 'nsfw', 'category', 'post_count']
    self.df[save_cols].to_parquet(str(self.paths.metadata), index=False)
    current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    payload = {
        'schema_version': SCHEMA_VERSION,
        'updated_at': current_time,
    }
    with open(self.paths.meta_json, 'w', encoding='utf-8') as f:
        json.dump(payload, f, ensure_ascii=False, indent=4)
    print(f'[Engine] 缓存保存完成({len(self.df)} 条记录),生成时间: {current_time}')
def _load_from_cache(self) -> None:
    """Restore the four matrices (upcast to FP32) and the metadata frame."""
    tensors = st_load(str(self.paths.embeddings), device=self.device)
    for key in ('emb_en', 'emb_cn', 'emb_wiki', 'emb_cn_core'):
        setattr(self, key, tensors[key].float())
    self.df = pd.read_parquet(str(self.paths.metadata))
    self.max_log_count = float(np.log1p(self.df['post_count'].max()))
def _cached_schema_version(self) -> int:
try:
with open(self.paths.meta_json, 'r', encoding='utf-8') as f:
return int(json.load(f).get('schema_version', 1))
except Exception:
return 0
# ── 编码 & 预处理 ──────────────────────────────────────────────────────
def _encode_texts(self, texts: list[str]) -> torch.Tensor:
return self.model.encode(
texts, batch_size=64, show_progress_bar=len(texts) > 500, convert_to_tensor=True,
).float()
def _load_model(self) -> None:
    """Load the SentenceTransformer once; retry with a re-resolved path on failure."""
    if self.model is not None:
        return
    print(f'[Engine] 加载模型 (path={self.model_path}, device={self.device})...')
    try:
        self.model = SentenceTransformer(self.model_path, device=self.device)
    except Exception as e:
        # The configured path may be stale (e.g. cache moved) — resolve again.
        print(f'[Engine] 指定路径加载失败,尝试重新解析: {e}')
        fallback = resolve_model_path()
        self.model = SentenceTransformer(fallback, device=self.device)
def _read_csv_robust(self, path: str) -> pd.DataFrame:
for enc in ['utf-8', 'gbk', 'gb18030']:
try:
return pd.read_csv(path, dtype=str, encoding=enc).fillna('')
except UnicodeDecodeError:
continue
raise ValueError('CSV 读取失败,请检查编码')
def _preprocess_raw_df(self, df: pd.DataFrame) -> pd.DataFrame:
df = df.copy()
print("Current columns:", df.columns.tolist())
df.dropna(subset=['name'], inplace=True)
df = df[df['name'].str.strip() != '']
for col in ['cn_name', 'category', 'wiki', 'nsfw']:
if col not in df.columns:
df[col] = ''
df['category'] = df['category'].fillna('0')
df['nsfw'] = df['nsfw'].fillna('0')
for char in [',', '|', '、']:
df['cn_name'] = df['cn_name'].str.replace(char, ',', regex=False)
if 'post_count' not in df.columns:
df['post_count'] = 0
df['post_count'] = pd.to_numeric(df['post_count'], errors='coerce').fillna(0)
df['cn_name'] = df['cn_name'].fillna('')
df['wiki'] = df['wiki'].fillna('')
df['cn_core'] = df['cn_name'].str.split(',', n=1).str[0].str.strip().fillna('')
df.drop_duplicates(subset=['name'], inplace=True)
df.reset_index(drop=True, inplace=True)
return df
def _setup_jieba_from_memory(self) -> None:
if self.df is None:
return
unique_words: set[str] = set()
for text in self.df['cn_name'].dropna().astype(str):
for part in text.replace(',', ' ').split():
part = part.strip()
if len(part) > 1:
unique_words.add(part)
for word in unique_words:
jieba.add_word(word, 2000)
def _smart_split(self, text: str) -> list[str]:
tokens: list[str] = []
for chunk in re.split(r'([一-龥]+)', text):
if not chunk.strip():
continue
if re.match(r'[一-龥]+', chunk):
tokens.extend(jieba.cut(chunk))
else:
cleaned = re.sub(r'[,()\[\]{}:]', ' ', chunk)
for part in cleaned.split():
try:
float(part)
except ValueError:
tokens.append(part)
return tokens
# ── 关联推荐 ──────────────────────────────────────────────────────────
def get_related(
    self,
    seed_tags: list[str],
    exclude: set[str] | None = None,
    limit: int = 20,
    show_nsfw: bool = True,
) -> list:
    """Recommend tags that co-occur with *seed_tags*, ranked by summed NPMI.

    Args:
        seed_tags: tags already chosen by the user.
        exclude: tag names that must never be recommended.
        limit: maximum number of results returned.
        show_nsfw: when False, tags flagged nsfw == '1' are skipped.

    Returns:
        list of RelatedTag, best first; cooc_score is NPMI normalized to [0, 1].
    """
    from .models import RelatedTag
    import math
    if not self.cooc or not seed_tags:
        return []
    exclude = exclude or set()
    # Estimate corpus size N as the dataset's max post count, floored at 7M.
    total_posts = float(max(self.df['post_count'].max(), 7000000.0))
    npmi_scores: dict[str, float] = {}   # neighbor -> NPMI summed over seeds
    total_cooc: dict[str, int] = {}      # neighbor -> raw co-occurrence sum
    tag_sources: dict[str, list[str]] = {}  # neighbor -> seeds that proposed it
    name_to_idx = self._name_to_idx
    for seed in seed_tags:
        if seed not in name_to_idx:
            continue
        # `or 1` guards against zero/empty post counts in the data.
        seed_count = float(self.df.iloc[name_to_idx[seed]].get('post_count', 1) or 1)
        for neighbor, cnt in self.cooc.get(seed, []):
            if neighbor in exclude or neighbor == seed:
                continue
            if neighbor not in name_to_idx:
                continue
            neighbor_count = float(self.df.iloc[name_to_idx[neighbor]].get('post_count', 1) or 1)
            # Clamp: a pair cannot co-occur more often than either tag occurs.
            cooc = min(float(cnt), seed_count, neighbor_count)
            if cooc <= 0:
                continue
            # PMI numerator: (Cooc * N) / (Count(A) * Count(B))
            numerator = (cooc * total_posts) / (seed_count * neighbor_count)
            # Skip negatively correlated or unrelated pairs.
            if numerator <= 1.0:
                continue
            pmi = math.log(numerator)
            # NPMI denominator: -log(P(A, B))
            p_a_b = cooc / total_posts
            if p_a_b >= 1.0:
                npmi = 1.0
            else:
                npmi = pmi / -math.log(p_a_b)
            # Multi-seed queries accumulate NPMI per neighbor.
            npmi_scores[neighbor] = npmi_scores.get(neighbor, 0.0) + npmi
            total_cooc[neighbor] = total_cooc.get(neighbor, 0) + cnt
            tag_sources.setdefault(neighbor, []).append(seed)
    if not npmi_scores:
        return []
    # Normalize the scores to [0, 1] for frontend display.
    max_score = max(npmi_scores.values())
    sorted_candidates = sorted(npmi_scores.items(), key=lambda x: x[1], reverse=True)
    results = []
    for tag_name, raw_score in sorted_candidates:
        if len(results) >= limit:
            break
        row = self.df.iloc[name_to_idx[tag_name]]
        nsfw = str(row.get('nsfw', '0'))
        if nsfw == '1' and not show_nsfw:
            continue
        cat = CAT_MAP.get(str(row.get('category', '0')), 'Other')
        results.append(RelatedTag(
            tag=tag_name,
            cn_name=str(row.get('cn_name', '')),
            category=cat,
            nsfw=nsfw,
            cooc_count=total_cooc[tag_name],
            cooc_score=round(raw_score / max_score, 4),
            sources=tag_sources.get(tag_name, []),
        ))
    return results
def _load_cooc(self) -> None:
    """Load the tag co-occurrence table into self.cooc (symmetric adjacency).

    Prefers a parquet file next to the configured path; when only the CSV is
    available it is parsed once and re-saved as parquet to speed up the next
    start. Failure is non-fatal: related-tag recommendations simply disable.
    """
    csv_path = Path(self.cooc_file)
    parquet_path = csv_path.with_suffix('.parquet')
    # Use the parquet cache when it exists and is at least as new as the CSV.
    if parquet_path.is_file() and (
        not csv_path.is_file()
        or parquet_path.stat().st_mtime >= csv_path.stat().st_mtime
    ):
        read_path = parquet_path
        is_parquet = True
    elif csv_path.is_file():
        read_path = csv_path
        is_parquet = False
    else:
        print(f'[Engine] 未找到共现表 ({self.cooc_file}),关联推荐功能不可用。')
        return
    print(f'[Engine] 加载共现表 ({read_path.name})...')
    t0 = time.time()
    try:
        if is_parquet:
            df = pd.read_parquet(str(read_path))
        else:
            df = self._read_csv_robust(str(read_path))
            df['count'] = pd.to_numeric(df['count'], errors='coerce').fillna(0).astype(int)
            # Cache the parsed CSV as parquet so this parse happens only once.
            df.to_parquet(str(parquet_path), index=False)
            print(f'[Engine] 已将共现表缓存为 {parquet_path.name},下次启动将直接加载。')
        # Build a symmetric edge list: each undirected pair appears both ways.
        tag_a = df['tag_a'].astype(str).to_numpy()
        tag_b = df['tag_b'].astype(str).to_numpy()
        counts = df['count'].astype(int).to_numpy()
        src = np.concatenate([tag_a, tag_b])
        dst = np.concatenate([tag_b, tag_a])
        cnt = np.concatenate([counts, counts])
        # Sort by source tag (primary) and descending count (secondary) so
        # each source's neighbor list comes out pre-sorted by frequency.
        sort_idx = np.lexsort((-cnt, src))
        src = src[sort_idx]
        dst = dst[sort_idx]
        cnt = cnt[sort_idx]
        # Each source now occupies a contiguous run; slice the runs out.
        unique_srcs, first_pos = np.unique(src, return_index=True)
        end_pos = np.append(first_pos[1:], len(src))
        cooc: dict[str, list[tuple[str, int]]] = {}
        for s, start, end in zip(unique_srcs, first_pos, end_pos):
            cooc[s] = list(zip(dst[start:end].tolist(), cnt[start:end].tolist()))
        self.cooc = cooc
        print(
            f'[Engine] 共现表加载完成,{len(cooc):,} 个 tag,'
            f'耗时 {time.time() - t0:.2f}s'
        )
    except Exception as e:
        print(f'[Engine] 共现表加载失败: {e}')