# swarm-backend / core/smart_cache.py
# Uploaded by: lk080424
# Commit message: Upload folder using huggingface_hub
# Commit: 17fba62 (verified)
#!/usr/bin/env python3
"""
虫群智能体系统 — 智能缓存层
缓存重复查询结果,减少API调用
LRU + TTL + 语义相似度匹配
"""
import hashlib
import json
import logging
import os
import re
import threading
import time
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
# Directory holding smart_cache.json. The historical hard-coded path stays the
# default; set SWARM_CACHE_DIR (non-empty) to relocate it on other hosts.
DEFAULT_CACHE_DIR = os.environ.get("SWARM_CACHE_DIR") or "/home/admin/swarm/data/cache"
class CacheEntry:
    """
    One cached response plus the bookkeeping needed for expiry and hit stats.

    Args:
        query: Query text the response was produced for.
        response: Cached response text.
        model_id: Identifier of the model that produced the response.
        confidence: Confidence score stored alongside the response.
        ttl: Lifetime in seconds, counted from ``created_at``.
    """

    def __init__(self, query: str, response: str, model_id: str,
                 confidence: float, ttl: int = 3600):
        self.query, self.response = query, response
        self.model_id, self.confidence = model_id, confidence
        # Wall-clock timestamp: it is persisted by to_dict/from_dict, so it
        # must stay meaningful across process restarts.
        self.created_at = time.time()
        self.ttl = ttl                # seconds
        self.hit_count = 0
        self.last_hit_at = None       # wall-clock time of the latest hit, or None

    @property
    def age_seconds(self) -> float:
        """Seconds elapsed since the entry was created."""
        return time.time() - self.created_at

    @property
    def is_expired(self) -> bool:
        """True once the entry's age is strictly greater than its TTL."""
        return self.age_seconds > self.ttl

    def hit(self):
        """Count one cache hit and remember when it happened."""
        self.hit_count += 1
        self.last_hit_at = time.time()

    def to_dict(self) -> Dict:
        """Serialize every field into a plain dict (inverse of ``from_dict``)."""
        field_names = ("query", "response", "model_id", "confidence",
                       "created_at", "ttl", "hit_count", "last_hit_at")
        return {name: getattr(self, name) for name in field_names}

    @classmethod
    def from_dict(cls, d: Dict) -> "CacheEntry":
        """
        Rebuild an entry from ``to_dict`` output.

        ``query``, ``response``, ``model_id`` and ``confidence`` are required
        (KeyError otherwise); the remaining fields fall back to defaults.
        """
        required = [d[name] for name in ("query", "response", "model_id", "confidence")]
        restored = cls(*required, d.get("ttl", 3600))
        for name, fallback in (("created_at", time.time()),
                               ("hit_count", 0),
                               ("last_hit_at", None)):
            setattr(restored, name, d.get(name, fallback))
        return restored
class SmartCache:
    """
    Smart response cache (process-wide singleton).

    Features:
    - LRU eviction once MAX_ENTRIES entries are held
    - per-entry TTL expiry
    - similarity matching: a query whose keyword set has Jaccard similarity
      >= SIMILARITY_THRESHOLD with a cached query's keyword set is served
      from that entry
    - keyword de-duplication: queries with identical keyword sets share one
      keyword-index slot (the most recent put wins)

    Invariant: every removal from ``_cache`` goes through ``_drop`` so the
    keyword index never keeps pointing at keys that are gone.
    """

    _instance = None
    _lock = threading.Lock()  # guards singleton creation only

    # Cache configuration
    MAX_ENTRIES = 500            # maximum number of entries
    DEFAULT_TTL = 3600           # default lifetime in seconds (1 hour)
    SIMILARITY_THRESHOLD = 0.8   # minimum Jaccard score for a similar hit
    SAVE_EVERY_N_PUTS = 10       # persist to disk once per this many puts

    # Keyword extraction patterns (compiled once, shared by all calls)
    _EN_WORD_RE = re.compile(r'[a-zA-Z]{2,}')
    _CN_SEGMENT_RE = re.compile(r'[\u4e00-\u9fff]+')
    # NOTE(review): these look like single-character stop words glued into
    # pairs (的/是, 在/了, ...), so as bigrams they rarely match anything.
    # Kept unchanged to preserve matching behavior; confirm intent.
    _STOP_WORDS = frozenset({"的是", "在了", "和有", "这不", "一我"})

    # Queries not worth caching (compared after strip() + lower())
    _GREETINGS = frozenset({"你好", "您好", "嗨", "hi", "hello", "早上好", "下午好", "晚上好"})
    _CONFIRMS = frozenset({"好的", "明白", "收到", "ok", "谢谢", "感谢"})

    def __new__(cls):
        with cls._lock:
            if cls._instance is None:
                inst = super().__new__(cls)
                # cache_key -> CacheEntry; iteration order is LRU order
                # (least recently used first).
                inst._cache = OrderedDict()
                # frozenset(keywords) -> cache_key
                inst._keyword_index = {}
                inst._stats = {
                    "hits": 0, "misses": 0, "evictions": 0, "expirations": 0
                }
                inst._puts_since_save = 0   # puts since the last persist attempt
                inst._initialized = False
                cls._instance = inst
            return cls._instance

    def initialize(self):
        """Load any persisted cache from disk and mark the instance ready."""
        self._load()
        self._initialized = True
        logger.info(f"智能缓存初始化: {len(self._cache)}条缓存")

    # ----------------------------------------------------------
    # Core: lookup / store
    # ----------------------------------------------------------

    def get(self, query: str) -> Optional["CacheEntry"]:
        """
        Look up a cached response.

        Strategy: exact match on the normalized query, then keyword-similarity
        match, otherwise a miss. An expired exact match is removed on sight.
        Both kinds of hit bump the entry's hit counter and its LRU position.

        Returns:
            The matching CacheEntry, or None on a miss.
        """
        # 1. Exact match
        key = self._make_key(query)
        entry = self._cache.get(key)
        if entry is not None:
            if entry.is_expired:
                self._drop(key)
                self._stats["expirations"] += 1
            else:
                entry.hit()
                self._stats["hits"] += 1
                self._cache.move_to_end(key)  # LRU: most recently used goes last
                logger.debug(f"缓存命中(精确): {query[:30]}")
                return entry

        # 2. Keyword-similarity match
        found = self._find_similar_item(query)
        if found is not None:
            similar_key, similar = found
            similar.hit()
            self._stats["hits"] += 1
            # Refresh LRU position too; otherwise entries that are only ever
            # reached through similar queries would be evicted first.
            if similar_key in self._cache:
                self._cache.move_to_end(similar_key)
            logger.debug(f"缓存命中(相似): {query[:30]}{similar.query[:30]}")
            return similar

        self._stats["misses"] += 1
        return None

    def put(self, query: str, response: str, model_id: str,
            confidence: float, ttl: int = None):
        """
        Store a response in the cache.

        Trivial queries (too short, greetings, acknowledgements) are skipped.
        Re-putting an existing query replaces its entry and marks it most
        recently used. When full, least-recently-used entries are evicted.

        Args:
            query: Query text; its normalized form is the cache key.
            response: Response text to cache.
            model_id: Identifier of the model that produced the response.
            confidence: Confidence score stored with the entry.
            ttl: Lifetime in seconds; a falsy value falls back to DEFAULT_TTL.
        """
        if self._should_skip(query):
            return

        key = self._make_key(query)
        entry = CacheEntry(query, response, model_id, confidence,
                           ttl or self.DEFAULT_TTL)

        # Replacing an existing key: remove it first so the new entry lands at
        # the MRU end and does not count against capacity (no needless eviction).
        if key in self._cache:
            self._drop(key)

        # Capacity check: evict least-recently-used entries.
        while self._cache and len(self._cache) >= self.MAX_ENTRIES:
            self._drop(next(iter(self._cache)))
            self._stats["evictions"] += 1

        self._cache[key] = entry
        self._update_keyword_index(key, query)

        # Throttled persistence: one disk write per SAVE_EVERY_N_PUTS puts.
        self._puts_since_save += 1
        if self._puts_since_save >= self.SAVE_EVERY_N_PUTS:
            self._puts_since_save = 0
            self._save()

    def invalidate(self, query: str):
        """Remove the cached entry for ``query`` (exact key only), if present."""
        self._drop(self._make_key(query))

    def clear(self):
        """Remove every entry and keyword-index slot, then persist the empty cache."""
        self._cache.clear()
        self._keyword_index.clear()
        self._puts_since_save = 0
        self._save()
        logger.info("缓存已清空")

    def _drop(self, key: str) -> Optional["CacheEntry"]:
        """Remove ``key`` from the cache and from the keyword index; return the entry."""
        entry = self._cache.pop(key, None)
        if entry is not None:
            idx_key = frozenset(self._extract_keywords(entry.query))
            # Only release the slot if it still points here: another query with
            # the same keyword set may have taken it over since.
            if self._keyword_index.get(idx_key) == key:
                del self._keyword_index[idx_key]
        return entry

    # ----------------------------------------------------------
    # Similarity matching
    # ----------------------------------------------------------

    def _find_similar(self, query: str) -> Optional["CacheEntry"]:
        """Return the best keyword-similar, unexpired entry, or None."""
        found = self._find_similar_item(query)
        return found[1] if found is not None else None

    def _find_similar_item(self, query: str) -> Optional[Tuple[str, "CacheEntry"]]:
        """
        Find the (cache_key, entry) whose keyword set has the highest Jaccard
        similarity with ``query``'s, provided it reaches SIMILARITY_THRESHOLD.
        Ties keep the first candidate in index order.
        """
        query_keywords = self._extract_keywords(query)
        if not query_keywords:
            return None

        best = None
        best_score = 0.0
        # Snapshot the index so a concurrent put cannot break the iteration.
        # The index key already *is* the stored query's keyword set, so there
        # is no need to re-extract keywords per entry.
        for stored_keywords, cache_key in list(self._keyword_index.items()):
            entry = self._cache.get(cache_key)
            if entry is None or entry.is_expired or not stored_keywords:
                continue
            union = len(query_keywords | stored_keywords)
            if union == 0:
                continue
            score = len(query_keywords & stored_keywords) / union
            if score >= self.SIMILARITY_THRESHOLD and score > best_score:
                best_score = score
                best = (cache_key, entry)
        return best

    def _extract_keywords(self, text: str) -> set:
        """
        Extract a keyword set: lower-cased English words of 2+ letters plus
        Chinese character bigrams (a lone Chinese character is kept as-is),
        minus _STOP_WORDS.
        """
        keywords = set(self._EN_WORD_RE.findall(text.lower()))
        for segment in self._CN_SEGMENT_RE.findall(text):
            if len(segment) >= 2:
                keywords.update(segment[i:i + 2] for i in range(len(segment) - 1))
            else:
                keywords.add(segment)
        return keywords - self._STOP_WORDS

    def _update_keyword_index(self, cache_key: str, query: str):
        """Point the keyword-set slot for ``query`` at ``cache_key`` (last writer wins)."""
        keywords = self._extract_keywords(query)
        if keywords:
            self._keyword_index[frozenset(keywords)] = cache_key

    # ----------------------------------------------------------
    # Helpers
    # ----------------------------------------------------------

    def _make_key(self, query: str) -> str:
        """Cache key: MD5 hex digest of the stripped, lower-cased query (not a security use)."""
        normalized = query.strip().lower()
        return hashlib.md5(normalized.encode()).hexdigest()

    def _should_skip(self, query: str) -> bool:
        """True for queries not worth caching: under 3 chars, greetings, acknowledgements."""
        q = query.strip()
        if len(q) < 3:
            return True
        lowered = q.lower()
        return lowered in self._GREETINGS or lowered in self._CONFIRMS

    # ----------------------------------------------------------
    # Statistics and persistence
    # ----------------------------------------------------------

    def get_stats(self) -> Dict:
        """Return entry count, capacity, hit/miss counters and the hit rate."""
        total = self._stats["hits"] + self._stats["misses"]
        return {
            "total_entries": len(self._cache),
            "max_entries": self.MAX_ENTRIES,
            "hits": self._stats["hits"],
            "misses": self._stats["misses"],
            "hit_rate": round(self._stats["hits"] / max(1, total), 3),
            "evictions": self._stats["evictions"],
            "expirations": self._stats["expirations"],
        }

    def get_top_entries(self, top_k: int = 10) -> List[Dict]:
        """Return summaries of the ``top_k`` most-hit entries, most hits first."""
        entries = sorted(
            self._cache.values(),
            key=lambda e: e.hit_count, reverse=True
        )
        return [
            {"query": e.query[:50], "hits": e.hit_count,
             "model": e.model_id, "age_s": round(e.age_seconds)}
            for e in entries[:top_k]
        ]

    def _cache_path(self) -> str:
        """Path of the JSON persistence file."""
        return os.path.join(DEFAULT_CACHE_DIR, "smart_cache.json")

    def _save(self):
        """
        Persist entries and stats to disk as JSON (best effort).

        The file is written to a temporary sibling and atomically swapped in,
        so a crash mid-write never leaves a truncated cache file behind.
        Failures are logged, not raised.
        """
        path = self._cache_path()
        tmp_path = f"{path}.{os.getpid()}.{threading.get_ident()}.tmp"
        try:
            os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
            data = {
                # Snapshot before serializing so concurrent mutation cannot
                # interrupt the iteration.
                "entries": {k: v.to_dict() for k, v in list(self._cache.items())},
                "stats": dict(self._stats),
            }
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, path)
        except Exception as e:
            logger.warning(f"缓存持久化失败: {e}")
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # temp file may never have been created

    def _load(self):
        """
        Load persisted entries and stats from disk (best effort).

        Expired or malformed entries are skipped. The keyword index is rebuilt
        because it is not persisted. The entry count is trimmed to MAX_ENTRIES.
        Failures are logged, not raised.
        """
        path = self._cache_path()
        if not os.path.exists(path):
            return
        try:
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
            for k, v in data.get("entries", {}).items():
                try:
                    entry = CacheEntry.from_dict(v)
                except (KeyError, TypeError) as e:
                    logger.warning(f"跳过无效缓存条目 {k}: {e}")
                    continue
                if entry.is_expired:
                    continue
                self._cache[k] = entry
                # Without this, loaded entries could never be similarity hits.
                self._update_keyword_index(k, entry.query)
            # The file may predate a smaller MAX_ENTRIES; drop oldest extras.
            while len(self._cache) > self.MAX_ENTRIES:
                self._drop(next(iter(self._cache)))
            self._stats.update(data.get("stats", {}))
            logger.info(f"加载缓存: {len(self._cache)}条")
        except Exception as e:
            logger.warning(f"缓存加载失败: {e}")
# ============================================================
# Convenience accessors
# ============================================================
_cache = None


def get_cache() -> SmartCache:
    """
    Return the shared SmartCache, creating it on first use and running its
    one-time ``initialize()`` (disk load) if that has not happened yet.
    """
    global _cache
    instance = _cache
    if instance is None:
        instance = _cache = SmartCache()
    if not instance._initialized:
        instance.initialize()
    return instance