File size: 11,391 Bytes
17fba62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
#!/usr/bin/env python3
"""
虫群智能体系统 — 智能缓存层
缓存重复查询结果,减少API调用
LRU + TTL + 语义相似度匹配
"""

import hashlib
import json
import logging
import os
import threading
import time
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

logger = logging.getLogger(__name__)

DEFAULT_CACHE_DIR = "/home/admin/swarm/data/cache"


class CacheEntry:
    """One cached response together with its bookkeeping metadata.

    Attributes:
        query: The original query text.
        response: The cached response text.
        model_id: Identifier of the model that produced the response.
        confidence: Confidence score supplied by the caller.
        created_at: ``time.time()`` timestamp taken at construction.
        ttl: Lifetime in seconds, measured from ``created_at``.
        hit_count: Number of times ``hit()`` has been called.
        last_hit_at: Timestamp of the most recent ``hit()``, or ``None``.
    """

    def __init__(self, query: str, response: str, model_id: str,
                 confidence: float, ttl: int = 3600):
        # Payload
        self.query = query
        self.response = response
        self.model_id = model_id
        self.confidence = confidence
        # Lifetime bookkeeping
        self.ttl = ttl  # seconds
        self.created_at = time.time()
        # Usage bookkeeping
        self.hit_count = 0
        self.last_hit_at = None

    @property
    def age_seconds(self) -> float:
        """Seconds elapsed since the entry was created."""
        now = time.time()
        return now - self.created_at

    @property
    def is_expired(self) -> bool:
        """True once the entry's age is strictly greater than its TTL."""
        return self.age_seconds > self.ttl

    def hit(self):
        """Record one cache hit: bump the counter and stamp the time."""
        self.last_hit_at = time.time()
        self.hit_count += 1

    def to_dict(self) -> Dict:
        """Return a plain-dict snapshot of the entry (used for persistence)."""
        field_names = (
            "query", "response", "model_id", "confidence",
            "created_at", "ttl", "hit_count", "last_hit_at",
        )
        return {name: getattr(self, name) for name in field_names}

    @classmethod
    def from_dict(cls, d: Dict) -> "CacheEntry":
        """Rebuild an entry from a dict produced by ``to_dict``.

        ``query``, ``response``, ``model_id`` and ``confidence`` are required
        (``KeyError`` if absent); the remaining fields fall back to defaults.
        """
        required = ("query", "response", "model_id", "confidence")
        restored = cls(*(d[name] for name in required), d.get("ttl", 3600))
        restored.created_at = d.get("created_at", time.time())
        restored.hit_count = d.get("hit_count", 0)
        restored.last_hit_at = d.get("last_hit_at")
        return restored


class SmartCache:
    """
    Smart response cache (process-wide singleton).

    Features:
    - LRU eviction once ``MAX_ENTRIES`` entries are stored.
    - Per-entry TTL expiry.
    - Similarity lookup: a query whose keyword set has a Jaccard similarity
      of at least ``SIMILARITY_THRESHOLD`` with a cached query is answered
      from the cache.
    - Best-effort JSON persistence under ``DEFAULT_CACHE_DIR``.

    NOTE(review): ``_lock`` only guards creation of the singleton; the
    get/put paths take no lock.
    """

    _instance = None
    _lock = threading.Lock()

    # Cache configuration
    MAX_ENTRIES = 500            # maximum number of entries kept
    DEFAULT_TTL = 3600           # default lifetime: one hour
    SIMILARITY_THRESHOLD = 0.8   # minimum Jaccard score for a similar hit
    SAVE_EVERY_N_WRITES = 10     # persist to disk once per this many puts

    _CACHE_FILENAME = "smart_cache.json"
    # Bigrams that are dropped from every extracted keyword set.
    _STOP_WORDS = frozenset({"的是", "在了", "和有", "这不", "一我"})

    def __new__(cls):
        with cls._lock:
            if cls._instance is None:
                inst = super().__new__(cls)
                inst._cache = OrderedDict()   # cache key -> CacheEntry (LRU order)
                inst._keyword_index = {}      # frozenset of keywords -> cache key
                inst._stats = {
                    "hits": 0, "misses": 0, "evictions": 0, "expirations": 0
                }
                inst._writes_since_save = 0   # drives the persistence throttle
                inst._initialized = False
                cls._instance = inst
            return cls._instance

    def initialize(self):
        """Load any persisted cache from disk and mark the cache initialized."""
        self._load()
        self._initialized = True
        logger.info(f"智能缓存初始化: {len(self._cache)}条缓存")

    # ----------------------------------------------------------
    # Core: lookup / store
    # ----------------------------------------------------------

    def get(self, query: str) -> Optional["CacheEntry"]:
        """
        Look up a cached entry for ``query``.

        Strategy: exact key match first, then keyword-similarity match.
        A hit (of either kind) bumps the entry's hit counter and marks it
        most-recently-used. An expired exact match is removed.

        Returns:
            The matching entry, or ``None`` on a miss.
        """
        # 1. Exact match
        key = self._make_key(query)
        entry = self._cache.get(key)
        if entry is not None:
            if not entry.is_expired:
                entry.hit()
                self._stats["hits"] += 1
                # LRU: move to the most-recently-used end
                self._cache.move_to_end(key)
                logger.debug(f"缓存命中(精确): {query[:30]}")
                return entry
            # Expired: drop the entry and its keyword-index slot
            self._discard(key)
            self._stats["expirations"] += 1

        # 2. Keyword-similarity match
        similar = self._find_similar(query)
        if similar is not None:
            similar.hit()
            self._stats["hits"] += 1
            # A similar hit is still a use of that entry; refresh its recency.
            similar_key = self._make_key(similar.query)
            if similar_key in self._cache:
                self._cache.move_to_end(similar_key)
            logger.debug(f"缓存命中(相似): {query[:30]}{similar.query[:30]}")
            return similar

        self._stats["misses"] += 1
        return None

    def put(self, query: str, response: str, model_id: str,
            confidence: float, ttl: Optional[int] = None):
        """
        Store a response in the cache.

        Queries rejected by ``_should_skip`` (too short, greetings,
        acknowledgements) are ignored.

        Args:
            query: Query text; also determines the cache key.
            response: Response text to cache.
            model_id: Identifier of the model that produced the response.
            confidence: Confidence score stored with the entry.
            ttl: Lifetime in seconds; ``None`` (or 0) means ``DEFAULT_TTL``.
        """
        if self._should_skip(query):
            return

        key = self._make_key(query)
        entry = CacheEntry(query, response, model_id, confidence,
                           ttl or self.DEFAULT_TTL)

        # Replacing an existing key must not evict an unrelated entry, and the
        # refreshed entry should land at the most-recently-used end.
        self._discard(key)

        # Capacity check: evict least-recently-used entries
        while self._cache and len(self._cache) >= self.MAX_ENTRIES:
            oldest_key = next(iter(self._cache))
            self._discard(oldest_key)
            self._stats["evictions"] += 1

        self._cache[key] = entry
        self._update_keyword_index(key, query)

        # Persistence throttle: write to disk once every N stored entries.
        self._writes_since_save += 1
        if self._writes_since_save >= self.SAVE_EVERY_N_WRITES:
            self._writes_since_save = 0
            self._save()

    def invalidate(self, query: str):
        """Remove the cached entry for ``query`` (exact key), if present."""
        self._discard(self._make_key(query))

    def clear(self):
        """Empty the cache and keyword index, then persist the empty state."""
        self._cache.clear()
        self._keyword_index.clear()
        self._writes_since_save = 0
        self._save()
        logger.info("缓存已清空")

    def _discard(self, key: str) -> bool:
        """Remove ``key`` from the cache and prune its keyword-index slot.

        Returns True if an entry was removed, False if the key was absent.
        """
        entry = self._cache.pop(key, None)
        if entry is None:
            return False
        idx_key = frozenset(self._extract_keywords(entry.query))
        # Only drop the index slot if it still points at this entry; another
        # query with the same keyword set may have claimed it since.
        if self._keyword_index.get(idx_key) == key:
            del self._keyword_index[idx_key]
        return True

    # ----------------------------------------------------------
    # Similarity matching
    # ----------------------------------------------------------

    def _find_similar(self, query: str) -> Optional["CacheEntry"]:
        """Return the live entry whose keyword set best matches ``query``.

        Similarity is the Jaccard index between keyword sets; only scores of
        at least ``SIMILARITY_THRESHOLD`` qualify. Returns ``None`` when the
        query yields no keywords or nothing qualifies.
        """
        query_keywords = self._extract_keywords(query)
        if not query_keywords:
            return None

        best_match = None
        best_score = 0.0

        # Each index key already is the stored query's keyword set, so there
        # is no need to re-extract keywords per entry.
        for stored_keywords, cache_key in self._keyword_index.items():
            entry = self._cache.get(cache_key)
            if entry is None or entry.is_expired:
                continue

            # Jaccard similarity (union is non-empty: query_keywords is)
            intersection = len(query_keywords & stored_keywords)
            union = len(query_keywords | stored_keywords)
            score = intersection / union
            if score >= self.SIMILARITY_THRESHOLD and score > best_score:
                best_score = score
                best_match = entry

        return best_match

    def _extract_keywords(self, text: str) -> set:
        """Extract a keyword set from ``text``.

        Keywords are lower-cased ASCII words of two or more letters, plus
        character bigrams of every run of CJK characters (a lone CJK
        character is kept as-is). Entries in ``_STOP_WORDS`` are removed.
        """
        import re

        keywords = set(re.findall(r'[a-zA-Z]{2,}', text.lower()))
        for segment in re.findall(r'[\u4e00-\u9fff]+', text):
            if len(segment) >= 2:
                keywords.update(
                    segment[i:i + 2] for i in range(len(segment) - 1)
                )
            else:
                keywords.add(segment)
        keywords -= self._STOP_WORDS
        return keywords

    def _update_keyword_index(self, cache_key: str, query: str):
        """Point the keyword-set of ``query`` at ``cache_key`` in the index."""
        keywords = self._extract_keywords(query)
        if keywords:
            self._keyword_index[frozenset(keywords)] = cache_key

    # ----------------------------------------------------------
    # Helpers
    # ----------------------------------------------------------

    def _make_key(self, query: str) -> str:
        """Return the cache key: MD5 hex digest of the stripped, lower-cased query."""
        normalized = query.strip().lower()
        return hashlib.md5(normalized.encode()).hexdigest()

    def _should_skip(self, query: str) -> bool:
        """Return True for queries that should not be cached.

        Skipped: queries shorter than 3 characters after stripping, plain
        greetings, and plain acknowledgements (case-insensitive exact match).
        """
        q = query.strip()
        if len(q) < 3:
            return True
        lowered = q.lower()
        greetings = ["你好", "您好", "嗨", "hi", "hello", "早上好", "下午好", "晚上好"]
        confirms = ["好的", "明白", "收到", "ok", "谢谢", "感谢"]
        return lowered in greetings or lowered in confirms

    # ----------------------------------------------------------
    # Statistics and persistence
    # ----------------------------------------------------------

    def get_stats(self) -> Dict:
        """Return entry counts, hit/miss counters and the hit rate (3 decimals)."""
        total = self._stats["hits"] + self._stats["misses"]
        return {
            "total_entries": len(self._cache),
            "max_entries": self.MAX_ENTRIES,
            "hits": self._stats["hits"],
            "misses": self._stats["misses"],
            "hit_rate": round(self._stats["hits"] / max(1, total), 3),
            "evictions": self._stats["evictions"],
            "expirations": self._stats["expirations"],
        }

    def get_top_entries(self, top_k: int = 10) -> List[Dict]:
        """Return summaries of the ``top_k`` entries with the most hits."""
        entries = sorted(
            self._cache.values(),
            key=lambda e: e.hit_count, reverse=True
        )
        return [
            {"query": e.query[:50], "hits": e.hit_count,
             "model": e.model_id, "age_s": round(e.age_seconds)}
            for e in entries[:top_k]
        ]

    def _save(self):
        """Persist entries and stats to disk (best effort; failures are logged)."""
        try:
            os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
            data = {
                "entries": {k: v.to_dict() for k, v in self._cache.items()},
                "stats": self._stats,
            }
            path = os.path.join(DEFAULT_CACHE_DIR, self._CACHE_FILENAME)
            # Write to a sibling temp file and swap it in, so an interrupted
            # write cannot leave a truncated cache file behind.
            tmp_path = path + ".tmp"
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, path)
        except Exception as e:
            logger.warning(f"缓存持久化失败: {e}")

    def _load(self):
        """Load non-expired entries and stats from disk, rebuilding the keyword index."""
        path = os.path.join(DEFAULT_CACHE_DIR, self._CACHE_FILENAME)
        if not os.path.exists(path):
            return
        try:
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
            for k, v in data.get("entries", {}).items():
                entry = CacheEntry.from_dict(v)
                if not entry.is_expired:
                    self._cache[k] = entry
                    # Without this, similarity lookups never see loaded entries.
                    self._update_keyword_index(k, entry.query)
            self._stats.update(data.get("stats", {}))
            logger.info(f"加载缓存: {len(self._cache)}条")
        except Exception as e:
            logger.warning(f"缓存加载失败: {e}")


# ============================================================
# 便捷函数
# ============================================================

# Module-level handle to the shared cache; populated on first get_cache() call.
_cache = None


def get_cache() -> SmartCache:
    """Return the global SmartCache, creating and initializing it on first use."""
    global _cache
    if _cache is not None:
        return _cache
    _cache = SmartCache()
    # SmartCache is a singleton and may already have been initialized elsewhere.
    if not _cache._initialized:
        _cache.initialize()
    return _cache