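"""
Semantic encoder — character-level semantic encoding with the vecs75 table.
Replaces the hash encoding in brain.py so that input vectors carry semantic information.

How it works:
"你好" → look up vec("你") and vec("好") → take the position-weighted average → normalise → 75-dim semantic vector
Semantically similar text yields a high cosine similarity after encoding, which hash encoding cannot provide.
"""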
import json
import os
import pickle
from typing import Optional

import numpy as np
|
|
|
|
class SemanticEncoder:
    """Encode text into 75-dimensional vectors using a vecs75 lookup table.

    The table maps vocabulary entries (single characters and longer words) to
    75-dimensional vectors.  When no table can be loaded, encoding falls back
    to a hash-bucket scheme so callers always receive a vector.

    Instance state:
        words: vocabulary entries, index-aligned with the rows of ``vecs75``.
        vecs75: raw vectors, one row per vocabulary entry.
        vecs_n: row-normalised copy of ``vecs75`` used for cosine similarity.
        w2i: mapping from vocabulary entry to row index.
        dim: vector dimensionality (75).
    """

    # Per-user locations probed after ``model_dir``; '~' is expanded at runtime.
    _HOME_PATHS = (
        '~/.swarm/models/vocab75_index.pkl',
        '~/.swarm/models/vocab75_index.npz',
        '~/swarm_product/models/vocab75_index.pkl',
        '~/swarm_product/models/vocab75_index.npz',
    )
    # Absolute deployment locations probed last.
    _FALLBACK_PATHS = (
        '/home/admin/swarm_product/models/vocab75_index.pkl',
        '/home/admin/swarm_product/models/vocab75_index.npz',
        '/app/models/vocab75_index.npz',
    )

    def __init__(self, model_dir: Optional[str] = None):
        """Locate and load the vocabulary table.

        Args:
            model_dir: optional directory that is searched first for
                ``vocab75_index.pkl`` / ``vocab75_index.npz``.

        Candidates are tried in priority order until one loads successfully;
        a file that exists but fails to load no longer stops the search.
        If nothing loads the encoder stays usable via the hash fallback.
        """
        self.words = []
        self.vecs75 = None
        self.vecs_n = None
        self.w2i = {}
        self.dim = 75
        self._loaded = False
        self._clean_loaded = False
        # Initialised here so the attributes exist even when no clean
        # vocabulary is found.
        self._clean_words = []
        self._clean_vecs_n = None

        for candidate in self._candidate_paths(model_dir):
            if os.path.exists(candidate):
                self._load(candidate)
                if self._loaded:
                    break

    @classmethod
    def _candidate_paths(cls, model_dir: Optional[str]) -> list:
        """Return every index path to probe, highest priority first."""
        paths = []
        if model_dir:
            paths.append(os.path.join(model_dir, 'vocab75_index.pkl'))
            paths.append(os.path.join(model_dir, 'vocab75_index.npz'))
        paths.extend(os.path.expanduser(p) for p in cls._HOME_PATHS)
        paths.extend(cls._FALLBACK_PATHS)
        return paths

    @staticmethod
    def _normalize_rows(matrix: np.ndarray) -> np.ndarray:
        """Return ``matrix`` with every row scaled to unit L2 norm.

        Rows whose norm is below 1e-8 are left unscaled to avoid dividing
        by zero.
        """
        norms = np.linalg.norm(matrix, axis=1, keepdims=True)
        norms[norms < 1e-8] = 1
        return matrix / norms

    def _load(self, path: str):
        """Load the vecs75 table from ``path`` (pickle or npz + json).

        If an ``.npz`` file and its ``_words.json`` companion exist next to
        ``path`` they are preferred; otherwise ``path`` itself is read as a
        pickle holding the keys ``'words'`` and ``'vecs75'``.  Failures are
        reported on stdout and leave the encoder state untouched.
        """
        try:
            stem, _ = os.path.splitext(path)
            npz_path = stem + '.npz'
            json_path = stem + '_words.json'

            vecs_n = None
            if os.path.exists(npz_path) and os.path.exists(json_path):
                with open(json_path, 'r', encoding='utf-8') as f:
                    words = json.load(f)
                with np.load(npz_path) as data:
                    vecs = data['vecs75']
                    # Use the stored normalised matrix when it is present;
                    # it is computed below only when it is missing.
                    if 'vecs75_normed' in data.files:
                        vecs_n = data['vecs75_normed']
            elif os.path.exists(path):
                # NOTE(security): pickle.load can execute arbitrary code;
                # only point this loader at trusted model files.
                with open(path, 'rb') as f:
                    data = pickle.load(f)
                words = data['words']
                vecs = np.asarray(data['vecs75'])
            else:
                return

            if vecs_n is None:
                vecs_n = self._normalize_rows(vecs)

            # Commit only after everything was read successfully so that a
            # failed load never leaves the encoder half-initialised.
            self.words = words
            self.vecs75 = vecs
            self.vecs_n = vecs_n
            self.w2i = {w: i for i, w in enumerate(words)}
            self._loaded = True
            print(f'[SemanticEncoder] 加载: {len(self.words)}词, {self.vecs75.shape}')

            self._load_clean(path)
        except Exception as e:
            # Deliberately best-effort: the caller falls back to hash encoding.
            print(f'[SemanticEncoder] 加载失败: {e}')

    def _load_clean(self, orig_path: str):
        """Load the optional "clean" decoding vocabulary.

        The clean files are looked up in the same directory as ``orig_path``
        (npz + json first, then pickle), with one absolute pickle path as a
        last resort.  Deriving the paths from the directory makes the lookup
        work for both ``.pkl`` and ``.npz`` index files.
        """
        model_dir = os.path.dirname(orig_path)
        npz_path = os.path.join(model_dir, 'vocab75_clean.npz')
        json_path = os.path.join(model_dir, 'vocab75_clean_words.json')
        if os.path.exists(npz_path) and os.path.exists(json_path):
            try:
                with open(json_path, 'r', encoding='utf-8') as f:
                    clean_words = json.load(f)
                with np.load(npz_path) as data:
                    clean_vecs_n = data['vecs75_normed']
                self._clean_words = clean_words
                self._clean_vecs_n = clean_vecs_n
                self._clean_loaded = True
                print(f'[SemanticEncoder] 干净词表: {len(self._clean_words)}词')
                return
            except Exception as e:
                print(f'[SemanticEncoder] 干净词表(npz)加载失败: {e}')

        clean_path = os.path.join(model_dir, 'vocab75_clean.pkl')
        if not os.path.exists(clean_path):
            clean_path = '/home/admin/swarm_product/models/vocab75_clean.pkl'
        if os.path.exists(clean_path):
            try:
                # NOTE(security): see _load — trusted files only.
                with open(clean_path, 'rb') as f:
                    data = pickle.load(f)
                clean_words = data['words']
                clean_vecs_n = data['vecs75_normed']
                self._clean_words = clean_words
                self._clean_vecs_n = clean_vecs_n
                self._clean_loaded = True
                print(f'[SemanticEncoder] 干净词表: {len(self._clean_words)}词')
            except Exception as e:
                print(f'[SemanticEncoder] 干净词表加载失败: {e}')

    def encode(self, text: str) -> np.ndarray:
        """Encode ``text`` into a 75-dimensional float32 vector.

        Each of the first 20 characters found in the vocabulary contributes
        its vector with weight ``1 / (position + 1)``, so earlier characters
        dominate.  The weighted mean is scaled to unit length.

        Falls back to :meth:`_hash_encode` when no table is loaded, the text
        is empty, or none of the characters is in the vocabulary.
        """
        if not self._loaded or not text:
            return self._hash_encode(text or '')

        char_vecs = []
        weights = []
        for pos, ch in enumerate(text[:20]):
            idx = self.w2i.get(ch)
            if idx is not None:
                char_vecs.append(self.vecs75[idx])
                weights.append(1.0 / (pos + 1))

        if not char_vecs:
            return self._hash_encode(text)

        char_vecs = np.array(char_vecs)
        weights = np.array(weights).reshape(-1, 1)
        vec = (char_vecs * weights).sum(axis=0) / weights.sum()

        norm = np.linalg.norm(vec)
        if norm > 1e-8:
            vec = vec / norm

        return vec.astype(np.float32)

    def _hash_encode(self, text: str) -> np.ndarray:
        """Fallback encoding: bucket the first 20 characters by ``hash()``.

        Each character adds ``1 / (position + 1)`` to bucket
        ``hash(ch) % dim``; the result is scaled so its maximum is 1.

        NOTE(review): the built-in ``hash()`` of a str is salted per process
        unless PYTHONHASHSEED is fixed, so these vectors are not comparable
        across runs.  Kept as-is to stay consistent with the legacy logic.
        """
        vec = np.zeros(self.dim, dtype=np.float32)
        for i, ch in enumerate(text[:20]):
            idx = hash(ch) % self.dim
            vec[idx] += 1.0 / (i + 1)
        if vec.max() > 0:
            vec = vec / vec.max()
        return vec

    @staticmethod
    def _is_cjk(ch: str) -> bool:
        """Return True when ``ch`` is in the U+4E00..U+9FFF ideograph range."""
        return '\u4e00' <= ch <= '\u9fff'

    @classmethod
    def _starts_with_cjk(cls, word: str) -> bool:
        """Return True for a non-empty word whose first character is CJK."""
        return bool(word) and cls._is_cjk(word[0])

    @classmethod
    def _is_preferred_word(cls, word: str, max_word_len: int) -> bool:
        """Filter used on the full vocabulary when Chinese output is preferred.

        Accepts words of at most ``max_word_len`` characters that contain a
        CJK character, except words made of a letter prefix immediately
        followed by a CJK character.
        """
        if len(word) > max_word_len or not any(cls._is_cjk(c) for c in word):
            return False
        first = word[0]
        if cls._is_cjk(first):
            return True
        return not (first.isalpha() and len(word) > 1 and cls._is_cjk(word[1]))

    def decode_nearest(self, vec: np.ndarray, top_k: int = 5,
                       prefer_chinese: bool = True, max_word_len: int = 4) -> list:
        """Return the vocabulary entries closest to ``vec`` by cosine similarity.

        Args:
            vec: query vector; flattened, then truncated or zero-padded to
                ``dim`` entries.
            top_k: maximum number of results; values <= 0 yield ``[]``.
            prefer_chinese: when True, entries passing the Chinese-word filter
                are ranked first and the rest only fill remaining slots.
            max_word_len: longest word accepted by the Chinese-word filter.

        Returns:
            A list of ``(word, similarity)`` tuples, or ``[]`` when no table
            is loaded or ``vec`` is (near) zero.

        When a clean vocabulary is loaded it is always used, and
        ``prefer_chinese`` / ``max_word_len`` are ignored: entries starting
        with a CJK character are ranked first.
        """
        if not self._loaded or top_k <= 0:
            return []

        vec = np.asarray(vec, dtype=np.float32).ravel()[:self.dim]
        if len(vec) < self.dim:
            vec = np.pad(vec, (0, self.dim - len(vec)))

        norm = np.linalg.norm(vec)
        if norm < 1e-8:
            return []
        vec_n = vec / norm

        if self._clean_loaded:
            words = self._clean_words
            sims = self._clean_vecs_n @ vec_n
            accept = self._starts_with_cjk
        elif prefer_chinese:
            words = self.words
            sims = self.vecs_n @ vec_n

            def accept(word):
                return self._is_preferred_word(word, max_word_len)
        else:
            sims = self.vecs_n @ vec_n
            top_indices = np.argsort(sims)[-top_k:][::-1]
            return [(self.words[i], float(sims[i])) for i in top_indices]

        # Over-fetch 5x so the filter has enough candidates to choose from.
        n_cand = min(top_k * 5, len(sims))
        top_indices = np.argsort(sims)[-n_cand:][::-1]

        results = []
        chosen = set()
        for i in top_indices:
            w = words[i]
            if accept(w):
                results.append((w, float(sims[i])))
                chosen.add(w)
                if len(results) >= top_k:
                    return results

        # Not enough preferred words: back-fill with the best remaining ones.
        for i in top_indices:
            w = words[i]
            if w not in chosen:
                results.append((w, float(sims[i])))
                chosen.add(w)
                if len(results) >= top_k:
                    break
        return results

    def encode_sentence(self, text: str) -> np.ndarray:
        """Encode ``text`` as the mean of its known unigram and bigram vectors.

        Every distinct single character and every distinct pair of adjacent
        characters found in the vocabulary contributes one vector with equal
        weight; the mean is scaled to unit length and returned as float32.
        Falls back to :meth:`encode` when no table is loaded, the text is
        empty, or nothing matches.
        """
        if not self._loaded or not text:
            return self.encode(text or '')

        candidates = list(text) + [text[i:i + 2] for i in range(len(text) - 1)]
        # Ordered de-duplication keeps the summation order, and therefore the
        # floating-point result, identical from run to run.
        tokens = list(dict.fromkeys(t for t in candidates if t in self.w2i))

        if not tokens:
            return self.encode(text)

        idxs = [self.w2i[t] for t in tokens]
        vec = self.vecs75[idxs].mean(axis=0)
        norm = np.linalg.norm(vec)
        if norm > 1e-8:
            vec = vec / norm

        return vec.astype(np.float32)
|
|
|
|
| |
# Process-wide encoder instance, created on first use by get_encoder().
_encoder = None


def get_encoder() -> SemanticEncoder:
    """Return the shared SemanticEncoder, constructing it on the first call."""
    global _encoder
    encoder = _encoder
    if encoder is None:
        encoder = _encoder = SemanticEncoder()
    return encoder
|
|