Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| 虫群智能体系统 — 智能缓存层 | |
| 缓存重复查询结果,减少API调用 | |
| LRU + TTL + 语义相似度匹配 | |
| """ | |
import hashlib
import json
import logging
import os
import re
import threading
import time
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple
| logger = logging.getLogger(__name__) | |
| DEFAULT_CACHE_DIR = "/home/admin/swarm/data/cache" | |
class CacheEntry:
    """A single cached query/response pair with TTL and hit tracking."""

    def __init__(self, query: str, response: str, model_id: str,
                 confidence: float, ttl: int = 3600):
        self.query = query
        self.response = response
        self.model_id = model_id
        self.confidence = confidence
        self.created_at = time.time()
        self.ttl = ttl  # lifetime in seconds
        self.hit_count = 0
        self.last_hit_at = None  # timestamp of most recent hit, None if never hit

    def is_expired(self) -> bool:
        """Return True once the entry has outlived its TTL."""
        return time.time() - self.created_at > self.ttl

    def age_seconds(self) -> float:
        """Return seconds elapsed since the entry was created."""
        return time.time() - self.created_at

    def hit(self):
        """Record one cache hit."""
        self.hit_count += 1
        self.last_hit_at = time.time()

    def to_dict(self) -> Dict:
        """Serialize to a JSON-compatible dict (inverse of from_dict)."""
        return {
            "query": self.query,
            "response": self.response,
            "model_id": self.model_id,
            "confidence": self.confidence,
            "created_at": self.created_at,
            "ttl": self.ttl,
            "hit_count": self.hit_count,
            "last_hit_at": self.last_hit_at,
        }

    @classmethod
    def from_dict(cls, d: Dict) -> "CacheEntry":
        """Rebuild an entry from to_dict() output.

        BUG FIX: the original was missing @classmethod, so the call
        CacheEntry.from_dict(d) bound the dict to `cls` and crashed.
        """
        entry = cls(d["query"], d["response"], d["model_id"],
                    d["confidence"], d.get("ttl", 3600))
        entry.created_at = d.get("created_at", time.time())
        entry.hit_count = d.get("hit_count", 0)
        entry.last_hit_at = d.get("last_hit_at")
        return entry
class SmartCache:
    """
    Smart response cache -- singleton.

    Features:
      - LRU eviction (entry-count cap)
      - TTL expiry (time cap)
      - semantic similarity matching (similar queries hit the cache)
      - keyword dedup (queries with the same core keywords are treated alike)
    """
    _instance = None
    _lock = threading.Lock()

    # Cache configuration
    MAX_ENTRIES = 500           # maximum number of cached entries
    DEFAULT_TTL = 3600          # default expiry: 1 hour
    SIMILARITY_THRESHOLD = 0.8  # minimum Jaccard similarity for a "similar" hit

    def __new__(cls):
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
                cls._instance._cache = OrderedDict()  # key -> CacheEntry (LRU order)
                cls._instance._keyword_index = {}     # frozenset(keywords) -> cache_key
                cls._instance._stats = {
                    "hits": 0, "misses": 0, "evictions": 0, "expirations": 0
                }
                cls._instance._put_count = 0          # writes since start; drives save throttling
                cls._instance._initialized = False
        return cls._instance

    def initialize(self):
        """Load any persisted cache from disk and mark the instance ready."""
        self._load()
        self._initialized = True
        logger.info(f"智能缓存初始化: {len(self._cache)}条缓存")

    # ----------------------------------------------------------
    # Core: lookup / store
    # ----------------------------------------------------------
    def get(self, query: str) -> Optional[CacheEntry]:
        """
        Look up a cached response for `query`.

        Strategy: exact match -> keyword similarity match -> miss (None).
        """
        # 1. Exact match
        key = self._make_key(query)
        entry = self._cache.get(key)
        if entry is not None:
            # BUG FIX: the original tested `entry.is_expired` without calling
            # it; the always-truthy bound method meant expired entries were
            # served forever and never purged.
            if not entry.is_expired():
                entry.hit()
                self._stats["hits"] += 1
                # LRU: promote to most-recently-used position
                self._cache.move_to_end(key)
                logger.debug(f"缓存命中(精确): {query[:30]}")
                return entry
            # Expired: drop the entry and its keyword-index link
            del self._cache[key]
            self._drop_index_for(key)
            self._stats["expirations"] += 1
        # 2. Keyword similarity match
        similar = self._find_similar(query)
        if similar:
            similar.hit()
            self._stats["hits"] += 1
            logger.debug(f"缓存命中(相似): {query[:30]} → {similar.query[:30]}")
            return similar
        self._stats["misses"] += 1
        return None

    def put(self, query: str, response: str, model_id: str,
            confidence: float, ttl: int = None):
        """
        Store a response in the cache.

        Only meaningful queries are cached; very short queries, greetings,
        and acknowledgements are skipped.
        """
        if self._should_skip(query):
            return
        key = self._make_key(query)
        entry = CacheEntry(query, response, model_id, confidence,
                           ttl or self.DEFAULT_TTL)
        # Capacity check: evict least-recently-used entries
        while len(self._cache) >= self.MAX_ENTRIES:
            oldest_key = next(iter(self._cache))
            del self._cache[oldest_key]
            # BUG FIX: also drop the keyword-index entries, otherwise the
            # index grows without bound and keeps dangling cache keys.
            self._drop_index_for(oldest_key)
            self._stats["evictions"] += 1
        self._cache[key] = entry
        # Keep the keyword index in sync
        self._update_keyword_index(key, query)
        # Persist, throttled to every 10th write.
        # BUG FIX: the original throttled on the *miss* counter, so a
        # workload with few misses could buffer many writes without ever
        # persisting (and save on every put while misses == 0).
        self._put_count += 1
        if self._put_count % 10 == 0:
            self._save()

    def invalidate(self, query: str):
        """Drop the cached entry for `query`, if present."""
        key = self._make_key(query)
        if key in self._cache:
            del self._cache[key]
            # Keep the keyword index consistent with the cache
            self._drop_index_for(key)

    def clear(self):
        """Remove every entry and persist the now-empty cache."""
        self._cache.clear()
        self._keyword_index.clear()
        self._save()
        logger.info("缓存已清空")

    # ----------------------------------------------------------
    # Similarity matching
    # ----------------------------------------------------------
    def _find_similar(self, query: str) -> Optional[CacheEntry]:
        """Return the best live entry by Jaccard similarity of keyword sets."""
        query_keywords = self._extract_keywords(query)
        if not query_keywords:
            return None
        best_match = None
        best_score = 0.0
        for idx_key, cache_key in self._keyword_index.items():
            entry = self._cache.get(cache_key)
            # BUG FIX: `entry.is_expired` was not called, so the truthy bound
            # method made every entry look expired and similarity matching
            # never returned anything.
            if entry is None or entry.is_expired():
                continue
            # idx_key is the stored query's keyword set (see
            # _update_keyword_index), so reuse it instead of re-extracting
            # keywords from entry.query on every comparison.
            stored_keywords = idx_key
            if not stored_keywords:
                continue
            union = len(query_keywords | stored_keywords)
            if union == 0:
                continue
            score = len(query_keywords & stored_keywords) / union
            if score >= self.SIMILARITY_THRESHOLD and score > best_score:
                best_score = score
                best_match = entry
        return best_match

    def _extract_keywords(self, text: str) -> set:
        """Extract keywords: English words plus Chinese character 2-grams."""
        keywords = set()
        # English words: runs of 2+ letters, lowercased
        keywords.update(re.findall(r'[a-zA-Z]{2,}', text.lower()))
        # Chinese: 2-grams within each contiguous CJK segment
        for segment in re.findall(r'[\u4e00-\u9fff]+', text):
            if len(segment) >= 2:
                for i in range(len(segment) - 1):
                    keywords.add(segment[i:i + 2])
            else:
                keywords.add(segment)
        # Drop overly common bigrams
        stop_words = {"的是", "在了", "和有", "这不", "一我"}
        keywords -= stop_words
        return keywords

    def _update_keyword_index(self, cache_key: str, query: str):
        """Index `cache_key` under the query's keyword set."""
        keywords = self._extract_keywords(query)
        if keywords:
            self._keyword_index[frozenset(keywords)] = cache_key

    def _drop_index_for(self, cache_key: str):
        """Remove every keyword-index entry pointing at `cache_key`."""
        stale = [k for k, v in self._keyword_index.items() if v == cache_key]
        for k in stale:
            del self._keyword_index[k]

    # ----------------------------------------------------------
    # Helpers
    # ----------------------------------------------------------
    def _make_key(self, query: str) -> str:
        """Derive the cache key: md5 of the normalized query (non-cryptographic use)."""
        normalized = query.strip().lower()
        return hashlib.md5(normalized.encode()).hexdigest()

    def _should_skip(self, query: str) -> bool:
        """Return True for queries not worth caching (too short, greetings, acks)."""
        q = query.strip()
        # Too short
        if len(q) < 3:
            return True
        # Greetings
        greetings = ["你好", "您好", "嗨", "hi", "hello", "早上好", "下午好", "晚上好"]
        if q.lower() in greetings:
            return True
        # Acknowledgements
        confirms = ["好的", "明白", "收到", "ok", "谢谢", "感谢"]
        if q.lower() in confirms:
            return True
        return False

    # ----------------------------------------------------------
    # Stats & persistence
    # ----------------------------------------------------------
    def get_stats(self) -> Dict:
        """Return a snapshot of cache size and hit/miss counters."""
        total = self._stats["hits"] + self._stats["misses"]
        return {
            "total_entries": len(self._cache),
            "max_entries": self.MAX_ENTRIES,
            "hits": self._stats["hits"],
            "misses": self._stats["misses"],
            "hit_rate": round(self._stats["hits"] / max(1, total), 3),
            "evictions": self._stats["evictions"],
            "expirations": self._stats["expirations"],
        }

    def get_top_entries(self, top_k: int = 10) -> List[Dict]:
        """Return the `top_k` entries with the most hits."""
        entries = sorted(
            self._cache.values(),
            key=lambda e: e.hit_count, reverse=True
        )
        return [
            # BUG FIX: original did round(e.age_seconds) on the bound method
            # (TypeError); it must be called.
            {"query": e.query[:50], "hits": e.hit_count,
             "model": e.model_id, "age_s": round(e.age_seconds())}
            for e in entries[:top_k]
        ]

    def _save(self):
        """Persist entries and counters to disk (best-effort; failures are logged)."""
        try:
            os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
            data = {
                "entries": {k: v.to_dict() for k, v in self._cache.items()},
                "stats": self._stats,
            }
            path = os.path.join(DEFAULT_CACHE_DIR, "smart_cache.json")
            with open(path, "w", encoding="utf-8") as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
        except Exception as e:
            logger.warning(f"缓存持久化失败: {e}")

    def _load(self):
        """Load persisted entries, skipping expired ones (best-effort)."""
        path = os.path.join(DEFAULT_CACHE_DIR, "smart_cache.json")
        if not os.path.exists(path):
            return
        try:
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
            for k, v in data.get("entries", {}).items():
                entry = CacheEntry.from_dict(v)
                # BUG FIX: `entry.is_expired` was not called, so the check
                # was always truthy and no entry was ever loaded.
                if not entry.is_expired():
                    self._cache[k] = entry
                    # Rebuild the keyword index for loaded entries so
                    # similarity matching works after a restart
                    self._update_keyword_index(k, entry.query)
            self._stats.update(data.get("stats", {}))
            logger.info(f"加载缓存: {len(self._cache)}条")
        except Exception as e:
            logger.warning(f"缓存加载失败: {e}")
# ============================================================
# Convenience accessors
# ============================================================
_cache = None


def get_cache() -> SmartCache:
    """Return the process-wide SmartCache instance, initializing it on first use."""
    global _cache
    instance = _cache
    if instance is None:
        instance = SmartCache()
        _cache = instance
    if not instance._initialized:
        instance.initialize()
    return instance