# Danbooru tag fuzzy-search tool — NiceGUI app (runs locally or as a HuggingFace Space).
import pandas as pd
from sentence_transformers import SentenceTransformer, util
import os
import pickle
import jieba
import numpy as np
from nicegui import ui, app, run
import asyncio
import torch
import re
import time
def is_running_on_huggingface_space():
    """Return True when running inside a HuggingFace Space (SPACE_ID env var is set)."""
    return "SPACE_ID" in os.environ
# Core search engine (singleton).
class DanbooruTagger:
    """Fuzzy Danbooru-tag search over four embedding views.

    The views are: English tag names, Chinese expansion words, wiki
    descriptions, and the first ("core") Chinese word of each tag.
    Embeddings are built once from the CSV, saved as FP16, and reloaded
    from a pickle cache on later runs.
    """

    _instance = None        # lazily created singleton
    _lock = asyncio.Lock()  # serializes concurrent first-time initialization

    @classmethod
    async def get_instance(cls):
        """Return the singleton, building and loading it on first use.

        NOTE(review): the original lacked @classmethod, so the call site
        ``await DanbooruTagger.get_instance()`` raised a TypeError (no
        ``cls`` supplied); the decorator restores the intended behavior.
        """
        async with cls._lock:
            if cls._instance is None:
                cls._instance = cls()
                # Heavy disk/model work runs in a worker thread so the
                # event loop stays responsive.
                await run.io_bound(cls._instance.load)
            return cls._instance

    def __init__(self, model_path=None, csv_file='tags_enhanced.csv',
                 cache_file='danbooru_vectors_multiview_'):
        """Configure paths and defaults; no heavy work happens here.

        :param model_path: explicit SentenceTransformer path/id; when None,
            a local folder is preferred over the HuggingFace model id.
        :param csv_file: tag table with name/cn_name/category/wiki/nsfw columns.
        :param cache_file: prefix for the pickled embedding cache file.
        """
        local_model_path = 'my_model_bge_m3'
        hf_model_id = 'BAAI/bge-m3'
        if model_path:
            self.model_path = model_path
        elif os.path.exists(local_model_path):
            print(f"DTOOL: 检测到本地模型文件夹 '{local_model_path}',将使用本地模型。")
            self.model_path = local_model_path
        else:
            print(f"DTOOL: 未找到本地模型,将使用 HuggingFace ID '{hf_model_id}'。")
            self.model_path = hf_model_id
        self.csv_path = csv_file
        # BUG FIX: the original ternary read "'cpu' if cuda else 'cpu'",
        # making the CUDA check dead code. Use the GPU when available.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Device name is baked into the cache filename so CPU/GPU caches
        # (and old incompatible caches) never collide.
        self.cache_path = cache_file + self.device + '_fp16.pkl'
        self.model = None        # SentenceTransformer, loaded lazily
        self.df = None           # tag table (pandas DataFrame)
        self.emb_en = None       # embeddings: English tag names
        self.emb_cn = None       # embeddings: Chinese expansion words
        self.emb_wiki = None     # embeddings: wiki descriptions
        self.emb_cn_core = None  # embeddings: first Chinese word per tag
        self.max_log_count = 15.0  # log1p of the most popular tag's post count
        self.is_loaded = False
        # Stop words stripped from segmented queries: ASCII and full-width
        # punctuation plus common Chinese function words.
        self.stop_words = {
            ',', '.', ':', ';', '?', '!', '"', "'", '`',
            '(', ')', '[', ']', '{', '}', '<', '>',
            '-', '_', '=', '+', '/', '\\', '|', '@', '#', '$', '%', '^', '&', '*', '~',
            ',', '。', ':', ';', '?', '!', '“', '”', '‘', '’',
            '(', ')', '【', '】', '《', '》', '、', '…', '—', '·',
            ' ', '\t', '\n', '\r',
            '的', '地', '得', '了', '着', '过',
            '是', '为', '被', '给', '把', '让', '由',
            '在', '从', '自', '向', '往', '对', '于',
            '和', '与', '及', '或', '且', '而', '但', '并', '即', '又', '也',
            '啊', '吗', '吧', '呢', '噢', '哦', '哈', '呀', '哇',
            '我', '你', '他', '她', '它', '我们', '你们', '他们',
            '这', '那', '此', '其', '谁', '啥', '某', '每',
            '这个', '那个', '这些', '那些', '这里', '那里',
            '个', '位', '只', '条', '张', '幅', '件', '套', '双', '对', '副',
            '种', '类', '群', '些', '点', '份', '部', '名',
            '很', '太', '更', '最', '挺', '特', '好', '真',
            '一', '一个', '一种', '一下', '一点', '一些',
            '有', '无', '非', '没', '不'
        }
        # Danbooru category id -> display name (ids without a mapping
        # render as 'Other' and are filtered out by the default categories).
        self.cat_map = {
            '0': 'General', '1': 'Artist', '3': 'Copyright', '4': 'Character', '5': 'Meta'
        }

    def load(self):
        """Build or load the vector index, then prepare the model and jieba."""
        if self.is_loaded: return
        t0 = time.time()
        # Load (or build) the embedding cache.
        if not os.path.exists(self.cache_path):
            print("\n" + "=" * 50)
            print("DTOOL: 未找到四视图索引缓存,开始首次构建(这可能需要 1~3 分钟)...")
            print("=" * 50 + "\n")
            self._build_from_csv()
        else:
            print(f"DTOOL: 正在加载缓存索引 ({self.cache_path})...")
            self._load_from_cache()
            if os.path.exists(self.csv_path):
                self._smart_update()
        # The model may already be loaded if we just rebuilt the index.
        if self.model is None:
            self._load_model()
        self._setup_jieba_from_memory()
        self.is_loaded = True
        print(f"DTOOL: 系统初始化完成,耗时 {time.time() - t0:.2f} 秒")

    def _load_model(self):
        """Load the SentenceTransformer, falling back to the HF hub id."""
        print(f"DTOOL: 正在加载模型 (路径: {self.model_path}, Device: {self.device})...")
        try:
            self.model = SentenceTransformer(self.model_path, device=self.device)
        except Exception as e:
            print(f"DTOOL: 本地模型加载失败,尝试从 HF 拉取: {e}")
            self.model = SentenceTransformer('BAAI/bge-m3', device=self.device)

    def _read_csv_robust(self, path):
        """Read the CSV trying utf-8, then gbk, then gb18030 encodings."""
        for enc in ['utf-8', 'gbk', 'gb18030']:
            try:
                return pd.read_csv(path, dtype=str, encoding=enc).fillna("")
            except UnicodeDecodeError:
                continue
        raise ValueError("CSV 读取失败,请检查编码")

    def _preprocess_raw_df(self, df):
        """Normalize the raw tag table in place and return it.

        Drops empty/duplicate names, guarantees the optional columns exist,
        canonicalizes cn_name separators to the full-width comma, coerces
        post_count to numeric, and derives the cn_core column.
        """
        df.dropna(subset=['name'], inplace=True)
        df = df[df['name'].str.strip() != '']
        for col in ['cn_name', 'category', 'wiki', 'nsfw']:
            if col not in df.columns: df[col] = ''
        df['category'] = df['category'].fillna('0')
        df['nsfw'] = df['nsfw'].fillna('0')
        # Unify every separator variant to the full-width comma used by the UI.
        for char in [',', '|', '、']:
            df['cn_name'] = df['cn_name'].str.replace(char, ',', regex=False)
        if 'post_count' not in df.columns: df['post_count'] = 0
        df['post_count'] = pd.to_numeric(df['post_count'], errors='coerce').fillna(0)
        df['cn_name'] = df['cn_name'].fillna("")
        df['wiki'] = df['wiki'].fillna("")
        # The first comma-separated word of cn_name acts as the "core" word.
        df['cn_core'] = df['cn_name'].str.split(',', n=1).str[0].str.strip()
        df['cn_core'] = df['cn_core'].fillna("")
        df.drop_duplicates(subset=['name'], inplace=True)
        df.reset_index(drop=True, inplace=True)
        return df

    def _build_from_csv(self):
        """Full rebuild: read the CSV, preprocess, encode, and cache."""
        print(f"DTOOL: 正在全量读取 {self.csv_path} ...")
        raw_df = self._read_csv_robust(self.csv_path)
        print("DTOOL: 正在预处理数据...")
        self.df = self._preprocess_raw_df(raw_df)
        self.max_log_count = np.log1p(self.df['post_count'].max())
        self._load_model()
        self._encode_and_save()

    def _encode_and_save(self):
        """Encode all four views and pickle them (FP16 on disk)."""
        print("DTOOL: 正在生成向量索引...")
        print("DTOOL: 正在生成英文索引...")
        self.emb_en = self.model.encode(self.df['name'].tolist(), batch_size=64, show_progress_bar=True,
                                        convert_to_tensor=True)
        print("DTOOL: 正在生成中文扩展索引...")
        self.emb_cn = self.model.encode(self.df['cn_name'].tolist(), batch_size=64, show_progress_bar=True,
                                        convert_to_tensor=True)
        print("DTOOL: 正在生成释义索引...")
        self.emb_wiki = self.model.encode(self.df['wiki'].tolist(), batch_size=64, show_progress_bar=True,
                                          convert_to_tensor=True)
        print("DTOOL: 正在生成中文核心词索引...")
        self.emb_cn_core = self.model.encode(self.df['cn_core'].tolist(), batch_size=64, show_progress_bar=True,
                                             convert_to_tensor=True)
        print("DTOOL: 正在保存 (FP16)...")
        # Halve the on-disk size; tensors are promoted back to FP32 on load.
        cache_data = {
            'df': self.df,
            'embeddings_en': self.emb_en.half(),
            'embeddings_cn': self.emb_cn.half(),
            'embeddings_wiki': self.emb_wiki.half(),
            'embeddings_cn_core': self.emb_cn_core.half(),
            'max_log_count': self.max_log_count
        }
        with open(self.cache_path, 'wb') as f:
            pickle.dump(cache_data, f)
        print("DTOOL: 缓存保存成功!")

    def _load_from_cache(self):
        """Restore DataFrame and embeddings (FP32) from the pickle cache.

        Older caches may lack the wiki/core views; zero tensors keep the
        search code shape-compatible in that case.
        """
        with open(self.cache_path, 'rb') as f:
            data = pickle.load(f)
        self.df = data['df']
        self.emb_en = data['embeddings_en'].float()
        self.emb_cn = data['embeddings_cn'].float()
        self.emb_wiki = data.get('embeddings_wiki', torch.zeros_like(self.emb_en)).float()
        self.emb_cn_core = data.get('embeddings_cn_core', torch.zeros_like(self.emb_en)).float()
        self.max_log_count = data.get('max_log_count', 15.0)

    def _smart_update(self):
        """Diff the CSV against the cached table; rebuild if anything changed."""
        print("DTOOL: 正在检查数据变更...")
        raw_df = self._read_csv_robust(self.csv_path)
        compare_df = raw_df.copy()
        # Apply the same normalization as preprocessing so the diff is fair.
        for col in ['cn_name', 'wiki', 'nsfw']:
            if col not in compare_df.columns: compare_df[col] = ""
            compare_df[col] = compare_df[col].fillna("")
        for char in [',', '|', '、']:
            compare_df['cn_name'] = compare_df['cn_name'].str.replace(char, ',', regex=False)
        current_map = {}
        for _, row in compare_df.iterrows():
            current_map[row['name']] = (row['cn_name'], row['wiki'], row['nsfw'])
        cached_df = self.df
        cached_map = {}
        has_wiki = 'wiki' in cached_df.columns
        has_nsfw = 'nsfw' in cached_df.columns
        for _, row in cached_df.iterrows():
            w = row['wiki'] if has_wiki else ""
            n = row['nsfw'] if has_nsfw else "0"
            cached_map[row['name']] = (row['cn_name'], w, n)
        new_tags = []
        changed_tags = []
        for name, new_tuple in current_map.items():
            if name not in cached_map:
                new_tags.append(name)
            else:
                if new_tuple != cached_map[name]:
                    changed_tags.append(name)
        if not new_tags and not changed_tags:
            print("DTOOL: 数据已是最新。")
            return
        print(f"DTOOL: 检测到变更 -> 新增: {len(new_tags)}, 修改: {len(changed_tags)}")
        print("DTOOL: 检测到数据变动,触发重建索引...")
        # Any change triggers a full re-encode of every view.
        self._build_from_csv()

    def _setup_jieba_from_memory(self):
        """Seed jieba with every multi-character cn_name word as a dictionary entry."""
        print("DTOOL: 正在从内存构建 Jieba 词典...")
        if self.df is not None:
            target_col = 'cn_name'
            texts = self.df[target_col].dropna().astype(str).tolist()
            unique_words = set()
            for text in texts:
                parts = text.replace(',', ' ').split()
                for part in parts:
                    part = part.strip()
                    if len(part) > 1: unique_words.add(part)
            for word in unique_words:
                # High frequency so these domain words win during segmentation.
                jieba.add_word(word, 2000)

    def _smart_split(self, text):
        """Split mixed Chinese/English text into search tokens.

        Chinese runs go through jieba; other runs are split on punctuation
        and purely numeric tokens are discarded.
        """
        tokens = []
        chunks = re.split(r'([\u4e00-\u9fa5]+)', text)
        for chunk in chunks:
            if not chunk.strip(): continue
            if re.match(r'[\u4e00-\u9fa5]+', chunk):
                tokens.extend(jieba.cut(chunk))
            else:
                cleaned = re.sub(r'[,()\[\]{}:]', ' ', chunk)
                parts = cleaned.split()
                for part in parts:
                    try:
                        float(part)  # pure numbers carry no semantics; drop them
                    except ValueError:
                        tokens.append(part)
        return tokens

    def search(self, user_query, top_k=5, limit=80, popularity_weight=0.15, show_nsfw=False,
               use_segmentation=True, target_layers=None, target_categories=None):
        """Semantic tag search across the selected embedding views.

        :param user_query: free-text query (Chinese and/or English).
        :param top_k: hits per view per (sub-)query.
        :param limit: cap on the number of returned tags.
        :param popularity_weight: blend factor between semantic score and
            log-popularity (0 = pure semantics).
        :param show_nsfw: include rows flagged nsfw == '1'.
        :param use_segmentation: also search each jieba keyword separately.
        :param target_layers: subset of ['英文', '中文扩展词', '释义', '中文核心词'].
        :param target_categories: category names to keep.
        :returns: (comma-joined tag string, result dicts sorted by score, keywords)
        """
        if not self.is_loaded: self.load()
        if target_layers is None: target_layers = ['英文', '中文扩展词', '释义', '中文核心词']
        if target_categories is None: target_categories = ['General', 'Character', 'Copyright']
        if use_segmentation:
            raw_keywords = self._smart_split(user_query)
            keywords = [w.strip() for w in raw_keywords if w.strip() and w.strip() not in self.stop_words]
            # The whole sentence is always query #0; keywords follow.
            search_queries = [user_query] + keywords
        else:
            keywords = []
            search_queries = [user_query]
        query_embeddings = self.model.encode(search_queries, convert_to_tensor=True).float()
        empty_hits = [[] for _ in search_queries]
        hits_en = util.semantic_search(query_embeddings, self.emb_en,
                                       top_k=top_k) if '英文' in target_layers else empty_hits
        hits_cn = util.semantic_search(query_embeddings, self.emb_cn,
                                       top_k=top_k) if '中文扩展词' in target_layers else empty_hits
        hits_wiki = util.semantic_search(query_embeddings, self.emb_wiki,
                                         top_k=top_k) if '释义' in target_layers else empty_hits
        hits_cn_core = util.semantic_search(query_embeddings, self.emb_cn_core,
                                            top_k=top_k) if '中文核心词' in target_layers else empty_hits
        final_results = {}
        for i, _ in enumerate(search_queries):
            source_word = search_queries[i]
            combined = []
            for h in hits_en[i]: combined.append((h, '英文'))
            for h in hits_cn[i]: combined.append((h, '中文扩展词'))
            for h in hits_wiki[i]: combined.append((h, '释义'))
            for h in hits_cn_core[i]: combined.append((h, '中文核心词'))
            for hit, layer in combined:
                score = hit['score']
                if score < 0.35: continue  # drop weak semantic matches early
                idx = hit['corpus_id']
                row = self.df.iloc[idx]
                nsfw_flag = str(row.get('nsfw', '0'))
                if not show_nsfw and nsfw_flag == '1': continue
                cat_val = str(row.get('category', '0'))
                cat_text = self.cat_map.get(cat_val, 'Other')
                if cat_text not in target_categories: continue
                tag_name = row['name']
                count = row['post_count']
                log_count = np.log1p(count)
                pop_score = log_count / self.max_log_count
                # Blend semantics with normalized log-popularity.
                final_score = (score * (1 - popularity_weight)) + (pop_score * popularity_weight)
                # Keep only the best-scoring occurrence of each tag.
                if tag_name not in final_results or final_score > final_results[tag_name]['final_score']:
                    final_results[tag_name] = {
                        'tag': tag_name,
                        'final_score': round(float(final_score), 4),
                        'semantic_score': round(float(score), 4),
                        'cn_name': row['cn_name'],
                        'count': int(count),
                        'source': source_word,
                        'layer': layer,
                        'category': cat_text,
                        'nsfw': nsfw_flag,
                        'wiki': str(row.get('wiki', ''))
                    }
        sorted_tags = sorted(final_results.items(), key=lambda x: x[1]['final_score'], reverse=True)
        valid_tags = [item[1] for item in sorted_tags if item[1]['final_score'] > 0.45]
        if len(valid_tags) > limit:
            valid_tags = valid_tags[:limit]
        tag_string = ", ".join([item['tag'] for item in valid_tags])
        return tag_string, valid_tags, keywords
# NiceGUI user interface
# Base table columns, always shown in the results table.
base_columns = [
    {'name': 'tag', 'label': '匹配标签', 'field': 'tag', 'align': 'left', 'sortable': True},
    {'name': 'cn_name', 'label': '含义', 'field': 'cn_name', 'align': 'left'},
    {'name': 'category', 'label': '类型', 'field': 'category', 'align': 'left', 'sortable': True},
    {'name': 'nsfw', 'label': '分级', 'field': 'nsfw', 'align': 'center', 'sortable': True},
    {'name': 'final_score', 'label': '综合分', 'field': 'final_score', 'sortable': True},
    {'name': 'count', 'label': '热度', 'field': 'count', 'sortable': True},
]
# Optional columns, appended when the corresponding display switch is on.
optional_col_map = {
    'semantic': {'name': 'semantic_score', 'label': '语义分', 'field': 'semantic_score', 'sortable': True},
    'layer': {'name': 'layer', 'label': '匹配层', 'field': 'layer'},
    'source': {'name': 'source', 'label': '匹配来源', 'field': 'source'},
}
@ui.page('/')
async def main_page():
    """Build the tag-search page.

    NOTE(review): the original defined this coroutine but never registered
    or called it, so the app served an empty index page. ``@ui.page('/')``
    registers it as the root page — confirm this matches the intended
    deployment (the alternative would be calling it at module level).
    """
    ui.colors(primary='#4A90E2', secondary='#5E6C84', accent='#FF6B6B')
    full_table_data = []    # unfiltered rows of the most recent search
    current_query_str = ""  # last whole-sentence query (drives the '整句' chip)

    # Notice banner
    with ui.card().classes('w-full max-w-6xl mx-auto bg-orange-50 border-l-4 border-orange-500 mb-2'):
        with ui.column().classes('gap-1'):
            ui.label('⚠️ 注意事项 / Note').classes('text-lg font-bold text-orange-800')
            ui.markdown("""
                - **本网站为AI工具,其结果未必正确无误** (Results may contain errors)
                - **查找结果可能会包括 NSFW 内容** (Results may include NSFW content)
                - **仅支持汉语、英语查找** (Only supports Chinese/English)
                - **仅显示Danbooru频数超过100的标签** (Frequency >= 100)
                - **仅显示特征、角色、作品标签** (General Character and Copyright tags only)
                - **Comfy UI 插件地址** : https://github.com/SuzumiyaAkizuki/ComfyUI-DanbooruSearcher
            """).classes('text-sm text-gray-800 ml-4')

    # Main body
    with ui.column().classes('w-full max-w-6xl mx-auto p-4 gap-6'):
        with ui.row().classes('items-center gap-2'):
            ui.icon('search', size='2em', color='primary')
            ui.label('Danbooru 标签模糊搜索').classes('text-2xl font-bold text-gray-800')

        # Basic control panel
        with ui.card().classes('w-full'):
            with ui.grid(columns=4).classes('w-full gap-8 items-center'):
                input_top_k = ui.number('Top K (语义相关)', value=5, min=1, max=50) \
                    .props('outlined dense suffix="个"').classes('w-full')
                input_limit = ui.number('结果上限', value=80, min=10, max=500) \
                    .props('outlined dense suffix="个"').classes('w-full')
                with ui.column().classes('gap-0'):
                    with ui.row().classes('w-full justify-between'):
                        ui.label('热度权重').classes('text-xs text-gray-500')
                        ui.label().bind_text_from(input_weight := ui.slider(min=0.0, max=1.0, value=0.15, step=0.05),
                                                  'value', lambda v: f"{v:.2f}")
                    input_weight.classes('w-full')
                input_nsfw = ui.switch('显示 NSFW', value=False).props('color=red').classes('w-full')

        # Advanced settings panel
        with ui.expansion('高级设置 (Advanced Settings)', icon='tune').classes('w-full bg-gray-50 border rounded-lg'):
            with ui.column().classes('w-full p-4 gap-4'):
                # Segmentation toggle
                input_segment = ui.switch('启用智能分词 (Segmentation)', value=True).props('color=primary')
                ui.label('关闭后系统将只匹配完整句子,适用于精准搜索整句。').classes('text-xs text-gray-500 -mt-2 ml-10')
                ui.separator()
                # Embedding-layer filter
                ui.label('匹配层筛选 (Target Layers):').classes('font-bold text-gray-700')
                with ui.row().classes('gap-4'):
                    layer_options = ['英文', '中文扩展词', '释义', '中文核心词']
                    selected_layers = {layer: True for layer in layer_options}

                    def toggle_layer(l, value): selected_layers[l] = value
                    for layer in layer_options:
                        ui.checkbox(layer, value=True, on_change=lambda e, l=layer: toggle_layer(l, e.value))
                ui.separator()
                # Category filter
                ui.label('标签类型筛选 (Categories):').classes('font-bold text-gray-700')
                with ui.row().classes('gap-4 flex-wrap'):
                    cat_options = ['General', 'Copyright', 'Character']
                    selected_cats = {cat: True for cat in cat_options}

                    def toggle_cat(c, value): selected_cats[c] = value
                    for cat in cat_options:
                        color_map = {'General': 'blue', 'Copyright': 'pink', 'Character': 'green'}
                        ui.checkbox(cat, value=True, on_change=lambda e, c=cat: toggle_cat(c, e.value)) \
                            .props(f'color={color_map.get(cat, "primary")}')
                ui.separator()
                # Table display options
                ui.label('表格显示选项 (Display Options):').classes('font-bold text-gray-700')
                with ui.row().classes('gap-6'):
                    sw_semantic = ui.switch('显示语义分', value=False)
                    sw_layer = ui.switch('显示匹配层', value=False)
                    sw_source = ui.switch('显示匹配来源', value=False)

                # Rebuild the table's column list from the display switches.
                def update_table_columns():
                    cols = list(base_columns)  # copy so base_columns stays untouched
                    if sw_semantic.value: cols.append(optional_col_map['semantic'])
                    if sw_layer.value: cols.append(optional_col_map['layer'])
                    if sw_source.value: cols.append(optional_col_map['source'])
                    result_table.columns = cols

                sw_semantic.on('update:model-value', update_table_columns)
                sw_layer.on('update:model-value', update_table_columns)
                sw_source.on('update:model-value', update_table_columns)

        # Search input area
        with ui.card().classes('w-full p-0 overflow-hidden'):
            with ui.column().classes('w-full p-6 gap-4'):
                ui.label('画面描述').classes('text-lg font-bold text-gray-700')
                search_input = ui.textarea(placeholder='例如:一个穿着白色水手服的女孩在雨中奔跑').classes(
                    'w-full text-lg').props('outlined rows=3')
                keywords_container = ui.row().classes('gap-2 items-center')
                spinner = ui.spinner(size='2em').classes('hidden')

                def filter_table_by_source(keyword):
                    # Show only rows whose match source equals ``keyword``
                    # ('ALL' restores everything), then re-color the chips.
                    if not keyword or keyword == 'ALL':
                        result_table.rows = full_table_data
                    else:
                        result_table.rows = [row for row in full_table_data if row['source'] == keyword]
                    for child in keywords_container.default_slot.children:
                        if isinstance(child, ui.chip):
                            is_selected = False
                            if keyword == 'ALL' and child.text == '全部':
                                is_selected = True
                            elif keyword == current_query_str and child.text == '整句':
                                is_selected = True
                            elif child.text == keyword:
                                is_selected = True
                            child.props(
                                f'color={"primary" if is_selected else "grey-4"} text-color={"white" if is_selected else "black"}')

                async def perform_search():
                    nonlocal full_table_data, current_query_str
                    query = search_input.value.strip()
                    if not query: return
                    current_query_str = query
                    search_btn.disable()
                    spinner.classes(remove='hidden')
                    ui.notify('正在搜索...', type='info')
                    target_layers_list = [k for k, v in selected_layers.items() if v]
                    target_cats_list = [k for k, v in selected_cats.items() if v]
                    if not target_layers_list:
                        ui.notify('请至少选择一个匹配层!', type='warning')
                        search_btn.enable()
                        spinner.classes(add='hidden')
                        return
                    try:
                        tagger = await DanbooruTagger.get_instance()
                        # Run the blocking search in a worker thread.
                        tags_str, table_data, keywords = await run.io_bound(
                            tagger.search,
                            query,
                            int(input_top_k.value),
                            int(input_limit.value),
                            float(input_weight.value),
                            input_nsfw.value,
                            input_segment.value,
                            target_layers_list,
                            target_cats_list
                        )
                        full_table_data = table_data
                        all_result_area.value = tags_str
                        result_table.rows = table_data
                        result_table.selected = []
                        update_selection_display(None)
                        # Rebuild the keyword-filter chips for this query.
                        keywords_container.clear()
                        with keywords_container:
                            ui.label('分词筛选:').classes('text-sm text-gray-500 font-bold mr-2')
                            ui.chip('全部', on_click=lambda: filter_table_by_source('ALL')) \
                                .props('color=primary text-color=white clickable')
                            if input_segment.value:
                                ui.chip('整句', on_click=lambda: filter_table_by_source(current_query_str)) \
                                    .props('color=grey-4 text-color=black clickable')
                                for kw in keywords:
                                    ui.chip(kw, on_click=lambda k=kw: filter_table_by_source(k)) \
                                        .props('color=grey-4 text-color=black clickable')
                            else:
                                ui.label('(分词已关闭)').classes('text-xs text-gray-400')
                        ui.notify(f'找到 {len(table_data)} 个标签', type='positive')
                    except Exception as e:
                        ui.notify(f'错误: {str(e)}', type='negative')
                    finally:
                        search_btn.enable()
                        spinner.classes(add='hidden')

                with ui.row().classes('w-full justify-end items-center gap-4'):
                    # NOTE(review): a bare name is a no-op — the spinner stays
                    # where it was created; use spinner.move() if it should
                    # actually live in this row.
                    spinner
                    search_btn = ui.button('开始搜索', on_click=perform_search, icon='search')
                    search_btn.classes('px-8 py-2 text-lg').props('unelevated color=primary')
                search_input.on('keydown.ctrl.enter', perform_search)

        # Results area
        with ui.row().classes('w-full gap-6'):
            with ui.card().classes('w-1/3 flex-grow'):
                ui.label('推荐 Prompt (全部)').classes('font-bold text-gray-600')
                all_result_area = ui.textarea().classes('w-full h-full bg-gray-50').props(
                    'readonly outlined input-class=text-sm')
            with ui.column().classes('w-2/3 flex-grow'):
                with ui.card().classes('w-full bg-blue-50 border-blue-200 border'):
                    with ui.row().classes('w-full items-center justify-between'):
                        with ui.row().classes('items-center gap-2'):
                            ui.icon('check_circle', color='primary')
                            ui.label('已选标签:').classes('font-bold text-primary')
                            selection_count_label = ui.label('0').classes(
                                'bg-primary text-white px-2 rounded-full text-sm')
                        copy_btn = ui.button('复制选中', icon='content_copy').props('dense unelevated color=primary')
                    selected_display = ui.textarea().classes('w-full mt-2').props(
                        'outlined dense rows=2 readonly bg-white')

                    def copy_selection():
                        ui.clipboard.write(selected_display.value)
                        ui.notify('已复制选中标签!', type='positive')

                    copy_btn.on_click(copy_selection)

                # Mirror the table's multi-selection into the copy box.
                def update_selection_display(e):
                    selected_rows = result_table.selected
                    tags = [row['tag'] for row in selected_rows]
                    selected_display.value = ", ".join(tags)
                    selection_count_label.text = str(len(tags))

                result_table = ui.table(
                    columns=base_columns,
                    rows=[],
                    pagination=10,
                    selection='multiple',
                    row_key='tag'
                ).classes('w-full')
                result_table.on('selection', update_selection_display)
                # Score badge colored by threshold.
                result_table.add_slot('body-cell-final_score', '''
                    <q-td :props="props">
                        <q-badge :color="props.value > 0.6 ? 'green' : (props.value > 0.5 ? 'teal' : 'orange')">
                            {{ props.value }}
                        </q-badge>
                    </q-td>
                ''')
                result_table.add_slot('body-cell-category', '''
                    <q-td :props="props">
                        <q-badge :color="
                            props.value === 'General' ? 'blue' :
                            (props.value === 'Character' ? 'green' :
                            (props.value === 'Copyright' ? 'pink' : 'red'))
                        " outline>
                            {{ props.value }}
                        </q-badge>
                    </q-td>
                ''')
                result_table.add_slot('body-cell-nsfw', '''
                    <q-td :props="props">
                        <div v-if="props.value === '1'" title="NSFW Content" class="text-red-500">🔴</div>
                        <div v-else class="text-green-500">🟢</div>
                    </q-td>
                ''')
                # Chinese meanings as badges + wiki tooltip on hover.
                result_table.add_slot('body-cell-cn_name', '''
                    <q-td :props="props">
                        <template v-if="props.value">
                            <q-badge
                                v-for="(item, index) in props.value.split(',')"
                                :key="index"
                                :color="index === 0 ? 'black' : 'grey'"
                                outline
                                style="font-size: 14px"
                                class="q-mr-xs q-mb-xs cursor-help"
                            >
                                {{ item }}
                            </q-badge>
                        </template>
                        <q-tooltip
                            v-if="props.row.wiki"
                            content-class="bg-black text-white shadow-4"
                            max-width="500px"
                            :offset="[10, 10]"
                        >
                            <div style="font-size: 14px; line-height: 1.5;">
                                {{ props.row.wiki }}
                            </div>
                        </q-tooltip>
                    </q-td>
                ''')
if __name__ in {"__main__", "__mp_main__"}:
    # On a HuggingFace Space: bind to all interfaces on the Space's fixed
    # port, and disable auto-reload and browser auto-open (the container is
    # headless and the original reload=True/show=True are wrong there).
    # Local runs keep the original behavior: reload and open a browser tab.
    on_space = is_running_on_huggingface_space()
    if on_space:
        host, port = '0.0.0.0', 7860
    else:
        host, port = '127.0.0.1', 1145
    ui.run(
        host=host,
        port=port,
        title='Danbooru Tags Searcher',
        reload=not on_space,
        show=not on_space,
    )