swarm-chat / semantic_encoder.py
lk080424's picture
Upload semantic_encoder.py with huggingface_hub
87ea277 verified
Raw
History Blame Contribute Delete
11.5 kB
"""
语义编码器 — vecs75字符级语义编码
替代brain.py中的hash编码,让输入向量具有语义信息
原理:
"你好" → 查表"你"vec + "好"vec → 取平均 → 归一化 → 75维语义向量
语义相近的文字编码后余弦相似度高,hash编码做不到
"""
import json
import os
import pickle
import zlib
from typing import Optional

import numpy as np
class SemanticEncoder:
    """Character-level semantic encoder backed by the vecs75 vocabulary.

    ``encode`` maps text to a 75-dimensional unit vector via a
    position-weighted average of per-character embeddings;
    ``encode_sentence`` averages unigram and bigram embeddings;
    ``decode_nearest`` returns the vocabulary entries closest to a vector by
    cosine similarity. When no vocabulary could be loaded (or no character of
    the input is in the vocabulary) encoding falls back to a deterministic
    bag-of-characters hash encoding.
    """

    # Only the first MAX_CHARS characters of the input are used by encode()
    # and _hash_encode().
    MAX_CHARS = 20

    # Clean-vocabulary pickle tried when none sits next to the loaded index.
    _CLEAN_PKL_FALLBACK = '/home/admin/swarm_product/models/vocab75_clean.pkl'

    def __init__(self, model_dir: Optional[str] = None):
        """Search the known locations and load the first vocabulary that loads.

        Args:
            model_dir: optional directory checked first for
                ``vocab75_index.pkl`` / ``vocab75_index.npz``. After it come
                ``~/.swarm/models``, ``~/swarm_product/models`` and a few
                hard-coded deployment paths.
        """
        self.words = []
        self.vecs75 = None
        self.vecs_n = None  # row-normalised vecs75, used for cosine similarity
        self.w2i = {}  # word -> row index into vecs75 / vecs_n
        self.dim = 75
        self._loaded = False
        self._clean_loaded = False  # whether the clean decoding vocabulary loaded
        self._clean_words = []
        self._clean_vecs_n = None

        candidates = []
        if model_dir:
            candidates.append(os.path.join(model_dir, 'vocab75_index.pkl'))
            candidates.append(os.path.join(model_dir, 'vocab75_index.npz'))
        candidates += [
            os.path.expanduser('~/.swarm/models/vocab75_index.pkl'),
            os.path.expanduser('~/.swarm/models/vocab75_index.npz'),
            os.path.expanduser('~/swarm_product/models/vocab75_index.pkl'),
            os.path.expanduser('~/swarm_product/models/vocab75_index.npz'),
            # Hard-coded deployment paths.
            '/home/admin/swarm_product/models/vocab75_index.pkl',
            '/home/admin/swarm_product/models/vocab75_index.npz',
            '/app/models/vocab75_index.npz',
        ]
        # Keep going until a candidate actually loads: a corrupt file at an
        # early location must not prevent a good one further down the list.
        for path in candidates:
            if os.path.exists(path):
                self._load(path)
                if self._loaded:
                    break

    def _load(self, path: str):
        """Load the vecs75 index referenced by ``path`` (``.pkl`` or ``.npz``).

        If ``<stem>.npz`` and ``<stem>_words.json`` both exist next to
        ``path`` they are used (the npz must hold ``vecs75`` and may hold
        ``vecs75_normed``); otherwise ``path`` itself is unpickled and must
        hold ``words`` and ``vecs75``. Instance state is replaced only after
        everything loaded and validated; any error is printed and leaves the
        encoder unchanged. On success the clean decoding vocabulary is
        loaded as well.
        """
        try:
            stem = os.path.splitext(path)[0]
            npz_path = stem + '.npz'
            json_path = stem + '_words.json'
            vecs_n = None
            if os.path.exists(npz_path) and os.path.exists(json_path):
                with open(json_path, 'r', encoding='utf-8') as f:
                    words = json.load(f)
                # The context manager closes the npz's underlying file handle.
                with np.load(npz_path) as data:
                    vecs75 = data['vecs75']
                    # Prefer the stored normalised matrix when it is present.
                    if 'vecs75_normed' in data.files:
                        vecs_n = data['vecs75_normed']
            elif os.path.exists(path):
                # NOTE(review): pickle.load can execute arbitrary code; only
                # load model files from trusted locations.
                with open(path, 'rb') as f:
                    data = pickle.load(f)
                words = data['words']
                vecs75 = data['vecs75']
            else:
                return

            vecs75 = np.asarray(vecs75)
            if vecs75.ndim != 2 or len(words) != vecs75.shape[0]:
                raise ValueError(
                    f'words/vecs75 size mismatch: {len(words)} vs {vecs75.shape}')
            if vecs_n is None:
                norms = np.linalg.norm(vecs75, axis=1, keepdims=True)
                norms[norms < 1e-8] = 1  # keep all-zero rows at zero, avoid 0/0
                vecs_n = vecs75 / norms
            elif np.shape(vecs_n) != vecs75.shape:
                raise ValueError(
                    f'vecs75_normed shape {np.shape(vecs_n)} != vecs75 shape {vecs75.shape}')

            self.words = words
            self.vecs75 = vecs75
            self.vecs_n = vecs_n
            self.w2i = {w: i for i, w in enumerate(words)}
            self._loaded = True
            print(f'[SemanticEncoder] 加载: {len(self.words)}词, {self.vecs75.shape}')
            self._load_clean(path)
        except Exception as e:
            print(f'[SemanticEncoder] 加载失败: {e}')

    def _load_clean(self, orig_path: str):
        """Load the "clean" decoding vocabulary used by ``decode_nearest``.

        Looks next to the loaded index ``orig_path`` (``.pkl`` or ``.npz``)
        for ``vocab75_clean.npz`` + ``vocab75_clean_words.json``, then for
        ``vocab75_clean.pkl``, and finally tries a hard-coded pickle path.
        Errors are printed and leave the clean vocabulary disabled.
        """
        directory, name = os.path.split(orig_path)
        # Sibling clean files are only derived for the standard index names;
        # otherwise we would point back at the index file itself.
        beside_index = name in ('vocab75_index.pkl', 'vocab75_index.npz')

        if beside_index:
            npz_path = os.path.join(directory, 'vocab75_clean.npz')
            json_path = os.path.join(directory, 'vocab75_clean_words.json')
            if os.path.exists(npz_path) and os.path.exists(json_path):
                try:
                    with open(json_path, 'r', encoding='utf-8') as f:
                        words = json.load(f)
                    with np.load(npz_path) as data:
                        vecs_n = data['vecs75_normed']
                    self._set_clean(words, vecs_n)
                    return
                except Exception as e:
                    print(f'[SemanticEncoder] 干净词表(npz)加载失败: {e}')

        clean_path = os.path.join(directory, 'vocab75_clean.pkl') if beside_index else None
        if clean_path is None or not os.path.exists(clean_path):
            clean_path = self._CLEAN_PKL_FALLBACK
        if os.path.exists(clean_path):
            try:
                # NOTE(review): unpickling executes code; trusted model files only.
                with open(clean_path, 'rb') as f:
                    data = pickle.load(f)
                self._set_clean(data['words'], data['vecs75_normed'])
            except Exception as e:
                print(f'[SemanticEncoder] 干净词表加载失败: {e}')

    def _set_clean(self, words, vecs_n):
        """Validate and install the clean decoding vocabulary.

        Raises:
            ValueError: if ``vecs_n`` is not 2-D with one row per word.
        """
        vecs_n = np.asarray(vecs_n)
        if vecs_n.ndim != 2 or len(words) != vecs_n.shape[0]:
            raise ValueError(
                f'clean words/vecs size mismatch: {len(words)} vs {vecs_n.shape}')
        self._clean_words = words
        self._clean_vecs_n = vecs_n
        self._clean_loaded = True
        print(f'[SemanticEncoder] 干净词表: {len(self._clean_words)}词')

    def encode(self, text: str) -> np.ndarray:
        """Encode text as a 75-dim vector by character lookup.

        Each of the first ``MAX_CHARS`` characters found in the vocabulary
        contributes its vecs75 row with weight ``1/(position+1)`` (earlier
        characters count more; unmatched characters still advance the
        position). The weighted mean is L2-normalised and returned as
        float32. Falls back to ``_hash_encode`` when the vocabulary is not
        loaded, ``text`` is empty, or no character matches.
        """
        if not self._loaded or not text:
            return self._hash_encode(text or '')

        idxs = []
        weights = []
        for pos, ch in enumerate(text[:self.MAX_CHARS]):
            idx = self.w2i.get(ch)
            if idx is not None:
                idxs.append(idx)
                weights.append(1.0 / (pos + 1))
        if not idxs:
            return self._hash_encode(text)

        w = np.asarray(weights).reshape(-1, 1)
        vec = (self.vecs75[idxs] * w).sum(axis=0) / w.sum()
        norm = np.linalg.norm(vec)
        if norm > 1e-8:
            vec = vec / norm
        return vec.astype(np.float32)

    def _hash_encode(self, text: str) -> np.ndarray:
        """Fallback bag-of-characters encoding (no semantics).

        Each of the first ``MAX_CHARS`` characters adds ``1/(position+1)``
        to a bucket chosen by the CRC-32 of its UTF-8 bytes; the vector is
        then scaled so its largest entry is 1 (all zeros for empty text).
        CRC-32 replaces the builtin ``hash`` because ``str`` hashes are
        salted per process, which made these vectors differ between runs.
        """
        vec = np.zeros(self.dim, dtype=np.float32)
        for pos, ch in enumerate(text[:self.MAX_CHARS]):
            # surrogatepass: lone surrogates must not make encoding crash.
            bucket = zlib.crc32(ch.encode('utf-8', 'surrogatepass')) % self.dim
            vec[bucket] += 1.0 / (pos + 1)
        peak = vec.max()
        if peak > 0:
            vec = vec / peak
        return vec

    @staticmethod
    def _is_cjk(ch: str) -> bool:
        """True if ``ch`` lies in the CJK Unified Ideographs block U+4E00..U+9FFF."""
        return '\u4e00' <= ch <= '\u9fff'

    def _is_preferred_word(self, w, max_word_len: int) -> bool:
        """Candidate filter for the main vocabulary when ``prefer_chinese``.

        Accepts words of at most ``max_word_len`` characters containing a CJK
        character, except dirty mixed entries such as ``"n这个"`` (a non-CJK
        letter immediately followed by a CJK character).
        """
        if len(w) > max_word_len or not any(self._is_cjk(c) for c in w):
            return False
        if self._is_cjk(w[0]):
            return True
        return not (w[0].isalpha() and len(w) > 1 and self._is_cjk(w[1]))

    @staticmethod
    def _top_indices(sims: np.ndarray, n: int) -> np.ndarray:
        """Indices of the ``n`` largest entries of ``sims``, best first."""
        n = min(n, len(sims))
        if n <= 0:
            return np.empty(0, dtype=np.intp)
        return np.argsort(sims)[-n:][::-1]

    @staticmethod
    def _backfill(results: list, candidates, words, sims, top_k: int) -> list:
        """Append ``candidates`` (in order) to ``results`` until it has
        ``top_k`` entries, skipping words already present."""
        seen = {w for w, _ in results}
        for i in candidates:
            if len(results) >= top_k:
                break
            w = words[i]
            if w not in seen:
                seen.add(w)
                results.append((w, float(sims[i])))
        return results

    def decode_nearest(self, vec: np.ndarray, top_k: int = 5,
                       prefer_chinese: bool = True, max_word_len: int = 4) -> list:
        """Return the vocabulary entries nearest to ``vec`` (output decoder).

        Args:
            vec: vector to decode; flattened, then truncated or zero-padded
                to ``dim`` entries.
            top_k: maximum number of results; ``<= 0`` returns ``[]``.
            prefer_chinese: for the main vocabulary, rank short CJK words
                (see ``_is_preferred_word``) ahead of others.
            max_word_len: maximum word length for the CJK preference filter.

        When the clean vocabulary is loaded it is used instead, and
        ``prefer_chinese`` / ``max_word_len`` are not applied: words starting
        with a CJK character are ranked first and the rest backfill.

        Returns:
            ``[(word, cosine_similarity), ...]``; ``[]`` if the vocabulary is
            not loaded or ``vec`` is (near) zero.
        """
        if not self._loaded or top_k <= 0:
            return []
        vec = np.asarray(vec, dtype=np.float32).ravel()[:self.dim]
        if len(vec) < self.dim:
            vec = np.pad(vec, (0, self.dim - len(vec)))
        norm = np.linalg.norm(vec)
        if norm < 1e-8:
            return []
        vec_n = vec / norm

        if self._clean_loaded:
            words = self._clean_words
            sims = self._clean_vecs_n @ vec_n
            candidates = self._top_indices(sims, top_k * 5)
            # CJK-initial words first; `words[i] and` guards against '' entries.
            results = [(words[i], float(sims[i])) for i in candidates
                       if words[i] and self._is_cjk(words[i][0])][:top_k]
            return self._backfill(results, candidates, words, sims, top_k)

        words = self.words
        sims = self.vecs_n @ vec_n
        if not prefer_chinese:
            return [(words[i], float(sims[i])) for i in self._top_indices(sims, top_k)]
        # Over-fetch top_k*5 candidates so filtering still leaves enough.
        candidates = self._top_indices(sims, top_k * 5)
        results = [(words[i], float(sims[i])) for i in candidates
                   if self._is_preferred_word(words[i], max_word_len)][:top_k]
        # Not enough CJK words: fill up with English / longer entries.
        return self._backfill(results, candidates, words, sims, top_k)

    def encode_sentence(self, text: str) -> np.ndarray:
        """Sentence-level encoding from unigrams and bigrams.

        Collects every distinct single character and every adjacent
        two-character substring of ``text`` that is in the vocabulary,
        averages their vecs75 rows, L2-normalises and returns float32. Falls
        back to ``encode`` when the vocabulary is not loaded, ``text`` is
        empty, or nothing matches.
        """
        if not self._loaded or not text:
            return self.encode(text or '')

        # An insertion-ordered dict deduplicates tokens like the old set did,
        # but with a fixed summation order, so the float result no longer
        # depends on per-process string-hash salting.
        token_idx = {}
        for ch in text:
            if ch in self.w2i:
                token_idx.setdefault(ch, self.w2i[ch])
        for i in range(len(text) - 1):
            bigram = text[i:i + 2]
            if bigram in self.w2i:
                token_idx.setdefault(bigram, self.w2i[bigram])
        if not token_idx:
            return self.encode(text)

        vec = self.vecs75[list(token_idx.values())].mean(axis=0)
        norm = np.linalg.norm(vec)
        if norm > 1e-8:
            vec = vec / norm
        return vec.astype(np.float32)
# Process-wide encoder instance, created lazily on the first get_encoder() call.
_encoder = None


def get_encoder() -> SemanticEncoder:
    """Return the shared SemanticEncoder, constructing it on first use.

    Later calls return the same instance without reloading the vocabulary.
    """
    global _encoder
    if _encoder is not None:
        return _encoder
    _encoder = SemanticEncoder()
    return _encoder