Spaces:
Sleeping
Sleeping
| """ | |
| 语义编码器 — vecs75字符级语义编码 | |
| 替代brain.py中的hash编码,让输入向量具有语义信息 | |
| 原理: | |
| "你好" → 查表"你"vec + "好"vec → 取平均 → 归一化 → 75维语义向量 | |
| 语义相近的文字编码后余弦相似度高,hash编码做不到 | |
| """ | |
| import numpy as np | |
| import pickle | |
| import os | |
| from typing import Optional | |
class SemanticEncoder:
    """Character-level semantic encoder backed by the ``vecs75`` embedding table.

    Text is encoded by looking up each character in the vocabulary, taking a
    position-weighted average of the matching rows of ``vecs75`` and
    L2-normalising the result.  When no table could be loaded, or none of the
    characters is in the vocabulary, a bucket-count "hash" encoding is used
    instead.

    Attributes:
        words: Vocabulary entries; entry ``i`` belongs to row ``i`` of ``vecs75``.
        vecs75: Raw embedding matrix as read from disk, or ``None`` if not loaded.
        vecs_n: Row-normalised counterpart of ``vecs75`` (used for cosine similarity).
        w2i: Mapping from vocabulary entry to its row index.
        dim: Length of hash-encoded vectors and of the query accepted by
            ``decode_nearest`` (75).
    """

    def __init__(self, model_dir: Optional[str] = None):
        """Create the encoder and try to load a vocabulary from known locations.

        Args:
            model_dir: Optional directory searched first for
                ``vocab75_index.pkl`` / ``vocab75_index.npz``.

        Loading is best effort: if nothing can be loaded the instance stays
        usable and ``encode`` falls back to hash encoding.
        """
        self.words = []
        self.vecs75 = None
        self.vecs_n = None  # row-normalised version of vecs75
        self.w2i = {}
        self.dim = 75
        self._loaded = False
        self._clean_loaded = False  # True once the clean decoding vocabulary is available
        self._clean_words = []
        self._clean_vecs_n = None

        # Candidate files, in priority order (both pkl and npz+json layouts).
        paths = []
        if model_dir:
            paths.append(os.path.join(model_dir, 'vocab75_index.pkl'))
            paths.append(os.path.join(model_dir, 'vocab75_index.npz'))
        paths.append(os.path.expanduser('~/.swarm/models/vocab75_index.pkl'))
        paths.append(os.path.expanduser('~/.swarm/models/vocab75_index.npz'))
        paths.append(os.path.expanduser('~/swarm_product/models/vocab75_index.pkl'))
        paths.append(os.path.expanduser('~/swarm_product/models/vocab75_index.npz'))
        for p in paths:
            if os.path.exists(p):
                self._load(p)
                # Keep searching when an existing file turned out to be unreadable.
                if self._loaded:
                    break

        if not self._loaded:
            # Hard-coded locations (HF Space container + local checkout).
            check_paths = [
                '/home/admin/swarm_product/models/vocab75_index.pkl',
                '/home/admin/swarm_product/models/vocab75_index.npz',
                '/app/models/vocab75_index.pkl',
                '/app/models/vocab75_index.npz',
                'models/vocab75_index.pkl',  # relative to the working directory
                'models/vocab75_index.npz',
            ]
            import sys
            print(f'[Encoder] 检查路径: {check_paths}', file=sys.stderr)
            for p in check_paths:
                exists = os.path.exists(p)
                print(f'[Encoder] 检查 {p}, exists={exists}', file=sys.stderr)
                if exists:
                    self._load(p)
                    if self._loaded:
                        break

    @staticmethod
    def _normalize_rows(mat):
        """Return ``mat`` with every row scaled to unit L2 norm (near-zero rows are left as is)."""
        norms = np.linalg.norm(mat, axis=1, keepdims=True)
        norms[norms < 1e-8] = 1
        return mat / norms

    @staticmethod
    def _is_cjk(ch: str) -> bool:
        """Return True when ``ch`` lies in U+4E00..U+9FFF; an empty string yields False."""
        return '\u4e00' <= ch <= '\u9fff'

    @staticmethod
    def _top_indices(sims, count: int):
        """Return indices of the ``count`` largest entries of ``sims``, best first."""
        count = min(count, len(sims))
        if count <= 0:
            # A ``[-0:]`` slice would select the whole array instead of nothing.
            return np.empty(0, dtype=np.intp)
        return np.argsort(sims)[-count:][::-1]

    @staticmethod
    def _fill(results: list, words, sims, order, top_k: int) -> list:
        """Top up ``results`` to ``top_k`` entries from ``order``, skipping words already present."""
        if len(results) >= top_k:
            return results
        seen = {w for w, _ in results}
        for i in order:
            w = words[i]
            if w in seen:
                continue
            seen.add(w)
            results.append((w, float(sims[i])))
            if len(results) >= top_k:
                break
        return results

    def _load(self, path: str):
        """Load the vocabulary table from ``path`` (npz+json preferred, pkl fallback).

        For ``<stem>.pkl`` or ``<stem>.npz`` the pair ``<stem>.npz`` +
        ``<stem>_words.json`` is used when both exist; otherwise ``path`` itself
        is read as a pickle.  Failures are reported on stdout and leave the
        encoder unloaded; instance state is only updated after a full read.
        """
        try:
            # splitext only touches the extension, never a directory component.
            stem, _ = os.path.splitext(path)
            npz_path = stem + '.npz'
            json_path = stem + '_words.json'
            if os.path.exists(npz_path) and os.path.exists(json_path):
                import json as _json
                with open(json_path, 'r', encoding='utf-8') as f:
                    words = _json.load(f)
                # NpzFile keeps the archive open until closed.
                with np.load(npz_path) as data:
                    vecs75 = data['vecs75']
                    if 'vecs75_normed' in data.files:
                        vecs_n = data['vecs75_normed']
                    else:
                        vecs_n = self._normalize_rows(vecs75)
            elif os.path.exists(path):
                # Pickle fallback.  SECURITY: unpickling executes code embedded in
                # the file, so only point the encoder at trusted model files.
                with open(path, 'rb') as f:
                    data = pickle.load(f)
                words = data['words']
                vecs75 = data['vecs75']
                vecs_n = self._normalize_rows(vecs75)
            else:
                return
            self.words = words
            self.vecs75 = vecs75
            self.vecs_n = vecs_n
            self.w2i = {w: i for i, w in enumerate(words)}
            self._loaded = True
            print(f'[SemanticEncoder] 加载: {len(self.words)}词, {self.vecs75.shape}')
            # The clean decoding vocabulary lives next to the index file.
            self._load_clean(path)
        except Exception as e:
            print(f'[SemanticEncoder] 加载失败: {e}')

    def _load_clean(self, orig_path: str):
        """Load the clean decoding vocabulary stored alongside ``orig_path``.

        Looks for ``vocab75_clean.npz`` + ``vocab75_clean_words.json`` first and
        then ``vocab75_clean.pkl`` in the directory of ``orig_path`` (whatever
        the extension of ``orig_path`` is), with one hard-coded pkl location as
        the last resort.  Failures are printed and leave ``_clean_loaded`` False.
        """
        model_dir = os.path.dirname(orig_path)
        npz_path = os.path.join(model_dir, 'vocab75_clean.npz')
        json_path = os.path.join(model_dir, 'vocab75_clean_words.json')
        if os.path.exists(npz_path) and os.path.exists(json_path):
            try:
                import json as _json
                with open(json_path, 'r', encoding='utf-8') as f:
                    clean_words = _json.load(f)
                with np.load(npz_path) as data:
                    clean_vecs_n = data['vecs75_normed']
                self._clean_words = clean_words
                self._clean_vecs_n = clean_vecs_n
                self._clean_loaded = True
                print(f'[SemanticEncoder] 干净词表: {len(self._clean_words)}词')
                return
            except Exception as e:
                print(f'[SemanticEncoder] 干净词表(npz)加载失败: {e}')
        # Pickle fallback.
        clean_path = os.path.join(model_dir, 'vocab75_clean.pkl')
        if not os.path.exists(clean_path):
            clean_path = '/home/admin/swarm_product/models/vocab75_clean.pkl'
        if os.path.exists(clean_path):
            try:
                # SECURITY: see the note in _load about unpickling untrusted files.
                with open(clean_path, 'rb') as f:
                    data = pickle.load(f)
                clean_words = data['words']
                clean_vecs_n = data['vecs75_normed']
                self._clean_words = clean_words
                self._clean_vecs_n = clean_vecs_n
                self._clean_loaded = True
                print(f'[SemanticEncoder] 干净词表: {len(self._clean_words)}词')
            except Exception as e:
                print(f'[SemanticEncoder] 干净词表加载失败: {e}')

    def encode(self, text: str) -> np.ndarray:
        """Encode ``text`` as a float32 vector via per-character table lookup.

        Only the first 20 characters are considered.  The character at position
        ``i`` gets weight ``1 / (i + 1)``, so earlier characters dominate; the
        weighted average is L2-normalised.  Falls back to ``_hash_encode`` when
        no table is loaded, ``text`` is empty, or no character is in the
        vocabulary.
        """
        if not self._loaded or not text:
            return self._hash_encode(text or '')
        char_vecs = []
        weights = []
        for i, ch in enumerate(text[:20]):
            idx = self.w2i.get(ch)
            if idx is not None:
                char_vecs.append(self.vecs75[idx])
                weights.append(1.0 / (i + 1))
        if not char_vecs:
            # Nothing matched the vocabulary.
            return self._hash_encode(text)
        char_vecs = np.array(char_vecs)
        weights = np.array(weights).reshape(-1, 1)
        vec = (char_vecs * weights).sum(axis=0) / weights.sum()
        norm = np.linalg.norm(vec)
        if norm > 1e-8:
            vec = vec / norm
        return vec.astype(np.float32)

    def _hash_encode(self, text: str) -> np.ndarray:
        """Fallback encoding: position-weighted bucket counts, scaled by the maximum.

        NOTE: the built-in ``hash`` of a ``str`` is salted per process unless
        ``PYTHONHASHSEED`` is fixed, so these vectors are only comparable within
        one run.  ``hash`` is kept because this mirrors the legacy encoding.
        """
        vec = np.zeros(self.dim, dtype=np.float32)
        for i, ch in enumerate(text[:20]):
            vec[hash(ch) % self.dim] += 1.0 / (i + 1)
        peak = vec.max()
        if peak > 0:
            vec = vec / peak
        return vec

    def decode_nearest(self, vec: np.ndarray, top_k: int = 5,
                       prefer_chinese: bool = True, max_word_len: int = 4) -> list:
        """Return the vocabulary entries closest to ``vec`` by cosine similarity.

        Args:
            vec: Query vector; flattened, then truncated or zero-padded to ``dim``.
            top_k: Maximum number of results; values <= 0 yield ``[]``.
            prefer_chinese: With the full vocabulary, rank short Chinese words
                first and only then top up with the remaining candidates.
            max_word_len: Longest word accepted by the ``prefer_chinese`` filter.

        Returns:
            ``[(word, similarity), ...]``.  Empty when no table is loaded or the
            query is (near) zero.  When the clean vocabulary is loaded it is
            used instead of the full one, and words whose first character is
            Chinese are ranked first regardless of ``prefer_chinese`` and
            ``max_word_len``.
        """
        if not self._loaded or top_k <= 0:
            return []
        vec = np.asarray(vec, dtype=np.float32).ravel()[:self.dim]
        if len(vec) < self.dim:
            vec = np.pad(vec, (0, self.dim - len(vec)))
        norm = np.linalg.norm(vec)
        if norm < 1e-8:
            return []
        vec_n = vec / norm

        if self._clean_loaded:
            words = self._clean_words
            sims = self._clean_vecs_n @ vec_n
            # Over-fetch candidates so that filtering can still fill top_k.
            order = self._top_indices(sims, top_k * 5)
            results = []
            for i in order:
                w = words[i]
                # w[:1] instead of w[0]: an empty entry must not raise.
                if self._is_cjk(w[:1]):
                    results.append((w, float(sims[i])))
                    if len(results) >= top_k:
                        break
            return self._fill(results, words, sims, order, top_k)

        words = self.words
        sims = self.vecs_n @ vec_n
        if not prefer_chinese:
            return [(words[i], float(sims[i])) for i in self._top_indices(sims, top_k)]

        order = self._top_indices(sims, top_k * 5)
        results = []
        for i in order:
            w = words[i]
            if len(w) > max_word_len or not any(self._is_cjk(c) for c in w):
                continue
            # Skip dirty entries such as "n这个": a stray letter glued to a Chinese word.
            if (not self._is_cjk(w[0]) and w[0].isalpha()
                    and len(w) > 1 and self._is_cjk(w[1])):
                continue
            results.append((w, float(sims[i])))
            if len(results) >= top_k:
                break
        # Not enough after filtering: top up with the remaining candidates.
        return self._fill(results, words, sims, order, top_k)

    def encode_sentence(self, text: str) -> np.ndarray:
        """Encode ``text`` as the normalised mean of its known 1-grams and 2-grams.

        Every distinct character and every distinct pair of adjacent characters
        found in the vocabulary contributes one row of ``vecs75``.  Falls back
        to ``encode`` when no table is loaded, ``text`` is empty, or nothing
        matches.
        """
        if not self._loaded or not text:
            return self.encode(text or '')
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # floating-point summation order does not depend on set iteration order.
        tokens = dict.fromkeys(ch for ch in text if ch in self.w2i)
        for i in range(len(text) - 1):
            bigram = text[i:i + 2]
            if bigram in self.w2i:
                tokens[bigram] = None
        if not tokens:
            return self.encode(text)
        idxs = [self.w2i[t] for t in tokens]
        vec = self.vecs75[idxs].mean(axis=0)
        norm = np.linalg.norm(vec)
        if norm > 1e-8:
            vec = vec / norm
        return vec.astype(np.float32)
# Module-wide singleton, constructed lazily on first request.
_encoder = None


def get_encoder() -> SemanticEncoder:
    """Return the shared SemanticEncoder, building it on the first call."""
    global _encoder
    if _encoder is not None:
        return _encoder
    _encoder = SemanticEncoder()
    return _encoder