Spaces:
Running
Running
| # -*- coding: utf-8 -*- | |
| """ | |
| 向量服务层 - Qdrant 版 | |
| 特性: | |
| 1. 混合搜索 - Qdrant 向量 + BM25 关键词,RRF 融合 | |
| 2. 异步原生 - 全链路异步 | |
| 3. 会话隔离 - 每个 session 独立集合 | |
| 4. 状态持久化 - 仓库信息、BM25 索引缓存 | |
| """ | |
| import asyncio | |
| import json | |
| import logging | |
| import os | |
| import pickle | |
| import re | |
| import tempfile | |
| import time | |
| from dataclasses import dataclass, field | |
| from typing import List, Dict, Any, Optional, Set | |
| from rank_bm25 import BM25Okapi | |
| from app.core.config import settings | |
| from app.storage.base import Document, SearchResult, CollectionStats | |
| from app.storage.qdrant_store import QdrantVectorStore, QdrantConfig, get_qdrant_factory | |
| from app.utils.embedding import get_embedding_service, EmbeddingConfig | |
| logger = logging.getLogger(__name__) | |
# ============================================================
# Shared configuration
# ============================================================
from app.core.config import vector_config as config

# Ensure the context directory exists before any store is created
os.makedirs(config.context_dir, exist_ok=True)

# === Backward-compatible exports (used by main.py) ===
vector_config = config  # legacy alias for the old module-level name
CONTEXT_DIR = config.context_dir
QDRANT_DIR = config.data_dir  # Qdrant data directory
_embedding_service = None


def get_embedding():
    """Return the process-wide Embedding service singleton, creating it lazily."""
    global _embedding_service
    if _embedding_service is not None:
        return _embedding_service
    _embedding_service = get_embedding_service(
        EmbeddingConfig(
            api_base_url=config.embedding_api_url,
            model_name=config.embedding_model,
            batch_size=config.embedding_batch_size,
            max_text_length=config.embedding_max_length,
            max_concurrent_batches=config.embedding_concurrency,
        )
    )
    return _embedding_service
| # ============================================================ | |
| # 向量存储服务 | |
| # ============================================================ | |
class VectorStore:
    """
    Vector storage service.

    Combines Qdrant vector search with an in-memory BM25 keyword index,
    fusing both result lists via Reciprocal Rank Fusion (RRF).

    Usage:
    ```python
    store = VectorStore("session_123")
    await store.initialize()
    # Reset (when analyzing a new repository)
    await store.reset()
    # Add documents
    await store.add_documents(documents, metadatas)
    # Hybrid search
    results = await store.search_hybrid("how does auth work?")
    await store.close()
    ```
    """

    def __init__(self, session_id: str):
        """
        Args:
            session_id: caller-supplied session identifier; sanitized to
                ``[a-zA-Z0-9_-]`` before being used in collection/file names.

        Raises:
            ValueError: if nothing remains after sanitization.
        """
        self.session_id = self._sanitize_id(session_id)
        self.collection_name = f"repo_{self.session_id}"
        # Qdrant store (created lazily in initialize())
        self._qdrant: Optional[QdrantVectorStore] = None
        # BM25 index (in-memory) and its backing document list
        self._bm25: Optional[BM25Okapi] = None
        self._doc_store: List[Document] = []
        self._indexed_files: Set[str] = set()
        # Repository context
        self.repo_url: Optional[str] = None
        self.global_context: Dict[str, Any] = {}
        # On-disk state file paths
        self._context_file = os.path.join(config.context_dir, f"{self.session_id}.json")
        self._cache_file = os.path.join(config.context_dir, f"{self.session_id}_bm25.pkl")
        self._initialized = False

    @staticmethod
    def _sanitize_id(session_id: str) -> str:
        """Strip characters unsafe for collection/file names from a session ID.

        BUGFIX: the original definition lacked both ``self`` and
        ``@staticmethod``, so calling it as ``self._sanitize_id(...)`` passed
        the instance as ``session_id`` and raised a TypeError in __init__.

        Raises:
            ValueError: if the sanitized ID is empty.
        """
        clean = re.sub(r'[^a-zA-Z0-9_-]', '', session_id)
        if not clean:
            raise ValueError("Invalid session_id")
        return clean

    async def initialize(self) -> None:
        """Initialize the store (idempotent): connect Qdrant, load local state."""
        if self._initialized:
            return
        # Initialize Qdrant
        factory = get_qdrant_factory()
        self._qdrant = factory.create(self.collection_name)
        await self._qdrant.initialize()
        # Load persisted local state (context JSON + BM25 cache)
        await self._load_state()
        self._initialized = True
        logger.debug(f"✅ VectorStore 初始化: {self.session_id}")

    async def close(self) -> None:
        """Close the Qdrant connection and mark the store uninitialized."""
        if self._qdrant:
            await self._qdrant.close()
            self._qdrant = None
        self._initialized = False

    async def _load_state(self) -> None:
        """Load persisted state: context JSON, then BM25 cache (or rebuild)."""
        # 1. Load the context JSON
        if os.path.exists(self._context_file):
            try:
                with open(self._context_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                self.repo_url = data.get("repo_url")
                self.global_context = data.get("global_context", {})
            except Exception as e:
                logger.warning(f"加载上下文失败: {e}")
        # 2. Try the pickled BM25 cache
        cache_loaded = False
        if os.path.exists(self._cache_file):
            try:
                with open(self._cache_file, 'rb') as f:
                    cache = pickle.load(f)
                # Version check guards against stale cache layouts
                if isinstance(cache, dict) and cache.get("version") == config.cache_version:
                    self._bm25 = cache.get("bm25")
                    self._doc_store = cache.get("doc_store", [])
                    self._indexed_files = cache.get("indexed_files", set())
                    cache_loaded = True
                    logger.debug(f"📦 BM25 缓存命中: {len(self._doc_store)} 文档")
            except Exception as e:
                logger.warning(f"BM25 缓存损坏: {e}")
                # BUGFIX: removal of the corrupt cache could itself raise
                # (e.g. permissions) and abort loading; best-effort delete.
                try:
                    os.remove(self._cache_file)
                except OSError:
                    pass
        # 3. Cache miss: rebuild the index from Qdrant
        if not cache_loaded and self._qdrant:
            await self._rebuild_bm25_index()

    async def _rebuild_bm25_index(self) -> None:
        """Rebuild the BM25 index from all documents stored in Qdrant."""
        logger.info(f"🔄 重建 BM25 索引: {self.session_id}")
        documents = await self._qdrant.get_all_documents()
        if documents:
            self._doc_store = documents
            self._indexed_files = {doc.file_path for doc in documents if doc.file_path}
            tokenized = [self._tokenize(doc.content) for doc in documents]
            if tokenized:
                self._bm25 = BM25Okapi(tokenized)
                self._save_bm25_cache()
            logger.info(f"✅ BM25 索引重建完成: {len(documents)} 文档")

    def _save_bm25_cache(self) -> None:
        """Persist the BM25 index atomically (write a temp file, then replace)."""
        if not self._doc_store:
            return
        tmp_path = None
        try:
            fd, tmp_path = tempfile.mkstemp(dir=config.context_dir)
            with os.fdopen(fd, 'wb') as f:
                pickle.dump({
                    "version": config.cache_version,
                    "bm25": self._bm25,
                    "doc_store": self._doc_store,
                    "indexed_files": self._indexed_files,
                }, f)
            # BUGFIX: os.replace() atomically overwrites on every platform.
            # The original remove()+rename() left a window with no cache file
            # and failed on Windows if the target reappeared in between.
            os.replace(tmp_path, self._cache_file)
            tmp_path = None
        except Exception as e:
            logger.error(f"保存 BM25 缓存失败: {e}")
        finally:
            # Don't leak the temp file when the dump or replace failed
            if tmp_path and os.path.exists(tmp_path):
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass

    def _tokenize(self, text: str) -> List[str]:
        """Split text on config.tokenize_regex and lowercase non-blank tokens."""
        return [
            t.lower() for t in re.split(config.tokenize_regex, text)
            if t.strip()
        ]

    async def save_context(self, repo_url: str, context_data: Dict[str, Any]) -> None:
        """Save repository context (async; file I/O runs in a worker thread)."""
        self.repo_url = repo_url
        self.global_context = context_data
        await asyncio.to_thread(self._write_context_file, {
            "repo_url": repo_url,
            "global_context": context_data,
        })

    def _write_context_file(self, updates: Dict[str, Any]) -> None:
        """Merge *updates* into the context JSON file (sync, thread-pool use)."""
        try:
            existing = {}
            if os.path.exists(self._context_file):
                with open(self._context_file, 'r', encoding='utf-8') as f:
                    existing = json.load(f)
            existing.update(updates)
            with open(self._context_file, 'w', encoding='utf-8') as f:
                json.dump(existing, f, ensure_ascii=False, indent=2)
        except Exception as e:
            logger.error(f"写入上下文失败: {e}")

    async def save_report(self, report: str, language: str = "en") -> None:
        """Save a technical report (async; file I/O runs in a worker thread)."""
        await asyncio.to_thread(self._write_report, report, language)

    def _write_report(self, report: str, language: str) -> None:
        """Write a report into the context JSON (sync, thread-pool use).

        Stores it under reports[language] and also mirrors it into the
        legacy top-level "report"/"report_language" fields.
        """
        try:
            existing = {}
            if os.path.exists(self._context_file):
                with open(self._context_file, 'r', encoding='utf-8') as f:
                    existing = json.load(f)
            if "reports" not in existing:
                existing["reports"] = {}
            existing["reports"][language] = report
            # Legacy single-report fields kept for old readers
            existing["report"] = report
            existing["report_language"] = language
            with open(self._context_file, 'w', encoding='utf-8') as f:
                json.dump(existing, f, ensure_ascii=False, indent=2)
            logger.info(f"📝 报告已保存: {self.session_id} ({language})")
        except Exception as e:
            logger.error(f"保存报告失败: {e}")

    def get_report(self, language: str = "en") -> Optional[str]:
        """
        Get the report for a language.

        Args:
            language: language code ('en', 'zh')

        Returns:
            Report content, or None if absent.
        """
        context = self.load_context()
        if not context:
            return None
        # Prefer the per-language reports dict
        reports = context.get("reports", {})
        if language in reports:
            return reports[language]
        # Legacy format: single "report" field, returned only if language matches
        if "report" in context:
            stored_lang = context.get("report_language", "en")
            if stored_lang == language:
                return context["report"]
        return None

    def get_available_languages(self) -> List[str]:
        """Return the language codes for which reports exist."""
        context = self.load_context()
        if not context:
            return []
        reports = context.get("reports", {})
        return list(reports.keys())

    def load_context(self) -> Optional[Dict[str, Any]]:
        """
        Load the repository context from disk.

        Also restores repo_url / global_context onto this instance.

        Returns:
            Dict with repo_url, global_context, report, ... or None if absent.
        """
        if not os.path.exists(self._context_file):
            return None
        try:
            with open(self._context_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            # Restore in-memory state from the file
            self.repo_url = data.get("repo_url")
            self.global_context = data.get("global_context", {})
            return data
        except Exception as e:
            logger.error(f"加载上下文失败: {e}")
            return None

    def has_index(self) -> bool:
        """True if a context exists on disk and records a repo_url."""
        context = self.load_context()
        return context is not None and context.get("repo_url") is not None

    async def reset(self) -> None:
        """Reset the store (called before analyzing a new repository)."""
        await self.initialize()
        # Drop and re-create the Qdrant collection
        if self._qdrant:
            await self._qdrant.delete_collection()
            await self._qdrant.initialize()
        # Remove local state files
        for f in [self._context_file, self._cache_file]:
            if os.path.exists(f):
                os.remove(f)
        # Reset in-memory state
        self._bm25 = None
        self._doc_store = []
        self._indexed_files = set()
        self.repo_url = None
        self.global_context = {}
        logger.info(f"🗑️ 重置存储: {self.session_id}")

    # Legacy interface
    def reset_collection(self) -> None:
        """Synchronous reset (legacy compatibility).

        BUGFIX: asyncio.get_event_loop().run_until_complete() is deprecated
        and raises on Python 3.12+ when no loop exists; asyncio.run() is the
        supported equivalent. Must not be called from inside a running loop.
        """
        asyncio.run(self.reset())

    async def add_documents(
        self,
        documents: List[str],
        metadatas: List[Dict[str, Any]]
    ) -> int:
        """
        Add documents to both Qdrant and the BM25 index.

        Args:
            documents: document content list
            metadatas: metadata list (parallel to documents)

        Returns:
            Number of documents successfully added.
        """
        if not documents:
            return 0
        await self.initialize()
        # 1. Batch-embed all contents
        logger.info(f"📊 Embedding: {len(documents)} 个文档")
        embedding_service = get_embedding()
        embeddings = await embedding_service.embed_batch(documents, show_progress=True)
        # Drop entries whose embedding failed (empty/None)
        valid_indices = [i for i, emb in enumerate(embeddings) if emb]
        if not valid_indices:
            logger.error("所有 Embedding 都失败了")
            return 0
        # 2. Build Document objects with ids unique within this store
        docs = []
        for i in valid_indices:
            doc_id = f"{metadatas[i].get('file', 'unknown')}_{len(self._doc_store) + len(docs)}"
            doc = Document(
                id=doc_id,
                content=documents[i],
                metadata=metadatas[i],
            )
            docs.append(doc)
        valid_embeddings = [embeddings[i] for i in valid_indices]
        # 3. Write to Qdrant
        added = await self._qdrant.add_documents(docs, valid_embeddings)
        # 4. Update the BM25 index off the event loop.
        # BUGFIX: filter out falsy file paths, matching _rebuild_bm25_index,
        # so _indexed_files never contains None.
        self._doc_store.extend(docs)
        self._indexed_files.update(doc.file_path for doc in docs if doc.file_path)
        await asyncio.to_thread(self._rebuild_bm25_sync)
        return added

    def _rebuild_bm25_sync(self) -> None:
        """Rebuild the BM25 index from _doc_store (sync, thread-pool use)."""
        tokenized = [self._tokenize(doc.content) for doc in self._doc_store]
        self._bm25 = BM25Okapi(tokenized)
        self._save_bm25_cache()

    async def embed_text(self, text: str) -> List[float]:
        """Embed a single text via the shared embedding service."""
        embedding_service = get_embedding()
        return await embedding_service.embed_text(text)

    async def search_hybrid(
        self,
        query: str,
        top_k: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """
        Hybrid search (vector + BM25, fused with RRF).

        Args:
            query: query text
            top_k: number of results (defaults to config.default_top_k)

        Returns:
            List of result dicts: id, content, file, metadata, score.
        """
        await self.initialize()
        top_k = top_k or config.default_top_k
        # Over-fetch candidates from each retriever before fusion
        candidate_k = top_k * config.search_oversample
        # 1. Vector search
        vector_results: List[SearchResult] = []
        query_embedding = await self.embed_text(query)
        if query_embedding and self._qdrant:
            vector_results = await self._qdrant.search(
                query_embedding,
                top_k=candidate_k
            )
        # 2. BM25 search
        bm25_results: List[SearchResult] = []
        if self._bm25 and self._doc_store:
            tokens = self._tokenize(query)
            if not tokens:
                # BM25Okapi requires a non-empty token list
                tokens = [""]
            try:
                scores = self._bm25.get_scores(tokens)
                top_indices = sorted(
                    range(len(scores)),
                    key=lambda i: scores[i],
                    reverse=True
                )[:candidate_k]
                for idx in top_indices:
                    if scores[idx] > 0:
                        doc = self._doc_store[idx]
                        bm25_results.append(SearchResult(
                            document=doc,
                            score=scores[idx],
                            source="bm25",
                        ))
            except Exception as e:
                logger.error(f"BM25 搜索失败: {e}")
        # 3. RRF fusion
        fused = self._rrf_fusion(vector_results, bm25_results)
        # 4. Format output (legacy-compatible dicts)
        results = []
        for item in fused[:top_k]:
            doc = item.document
            results.append({
                "id": doc.id,
                "content": doc.content,
                "file": doc.file_path,
                "metadata": doc.metadata,
                "score": item.score,
            })
        return results

    def _rrf_fusion(
        self,
        vector_results: List[SearchResult],
        bm25_results: List[SearchResult]
    ) -> List[SearchResult]:
        """Fuse two ranked lists with weighted Reciprocal Rank Fusion."""
        k = config.rrf_k
        fused: Dict[str, Dict] = {}
        # Vector contribution: weight / (k + rank + 1)
        for rank, result in enumerate(vector_results):
            doc_id = result.document.id
            if doc_id not in fused:
                fused[doc_id] = {"result": result, "score": 0}
            fused[doc_id]["score"] += config.rrf_weight_vector / (k + rank + 1)
        # BM25 contribution
        for rank, result in enumerate(bm25_results):
            doc_id = result.document.id
            if doc_id not in fused:
                fused[doc_id] = {"result": result, "score": 0}
            fused[doc_id]["score"] += config.rrf_weight_bm25 / (k + rank + 1)
        # Sort by fused score, descending
        sorted_items = sorted(
            fused.values(),
            key=lambda x: x["score"],
            reverse=True
        )
        return [
            SearchResult(
                document=item["result"].document,
                score=item["score"],
                source="hybrid",
            )
            for item in sorted_items
        ]

    def get_documents_by_file(self, file_path: str) -> List[Dict[str, Any]]:
        """Return this file's documents ordered by start_line (legacy API)."""
        docs = [
            doc for doc in self._doc_store
            if doc.file_path == file_path
        ]
        result = []
        for doc in sorted(docs, key=lambda d: d.metadata.get("start_line", 0)):
            result.append({
                "id": doc.id,
                "content": doc.content,
                "file": doc.file_path,
                "metadata": doc.metadata,
                "score": 1.0,
            })
        return result

    def indexed_files(self) -> Set[str]:
        """Return the set of file paths that have been indexed."""
        return self._indexed_files
| # ============================================================ | |
| # 管理器 - LRU Cache + 过期清理 | |
| # ============================================================ | |
class SessionEntry:
    """Bookkeeping record pairing a VectorStore with its access timestamps."""
    __slots__ = ('store', 'last_access', 'created_at')

    def __init__(self, store: VectorStore):
        now = time.time()
        self.store = store       # the owned VectorStore instance
        self.last_access = now   # most recent access (epoch seconds)
        self.created_at = now    # creation time (epoch seconds)

    def touch(self) -> None:
        """Record that this session was just accessed."""
        self.last_access = time.time()
class VectorStoreManager:
    """
    Vector store manager implemented as an LRU cache of sessions.

    Features:
    1. LRU eviction - when more than max_count sessions are held in memory,
       the least recently accessed one is closed and dropped.
    2. Repository data is persistent - eviction only closes the in-memory
       handle; on-disk indexes and reports are never deleted.
    3. Async eviction/close paths are guarded by an asyncio.Lock.
       NOTE(review): get_store() mutates the session dict without the lock
       because it is a sync API; safe only when all callers share one event
       loop thread — confirm at call sites.
    """

    def __init__(self, max_count: int = None):
        """
        Args:
            max_count: in-memory session limit; defaults to
                config.session_max_count.
        """
        self._max_count = max_count or config.session_max_count
        self._sessions: Dict[str, SessionEntry] = {}
        self._lock = asyncio.Lock()

    def get_store(self, session_id: str) -> VectorStore:
        """
        Get or create the store for a session (sync API, legacy-compatible).

        Refreshes the entry's LRU position; may schedule an async
        eviction pass when over capacity.
        """
        if session_id in self._sessions:
            entry = self._sessions[session_id]
            entry.touch()
            # Re-insert so the entry moves to the "most recent" end
            # (dicts preserve insertion order).
            self._sessions.pop(session_id)
            self._sessions[session_id] = entry
            return entry.store
        # Create a new session
        store = VectorStore(session_id)
        entry = SessionEntry(store)
        self._sessions[session_id] = entry
        # Schedule LRU eviction if over capacity.
        # BUGFIX: asyncio.create_task() raises RuntimeError when no event
        # loop is running — possible since this is a sync method. Guard it
        # so sync callers don't crash; eviction then happens on the next
        # call made from within a running loop.
        if len(self._sessions) > self._max_count:
            try:
                asyncio.create_task(self._evict_lru())
            except RuntimeError:
                logger.warning("无法调度 LRU 淘汰: 当前无运行中的事件循环")
        logger.info(f"📦 Session 创建: {session_id} (总数: {len(self._sessions)})")
        return store

    async def _evict_lru(self) -> None:
        """Evict least-recently-accessed sessions until within capacity."""
        async with self._lock:
            while len(self._sessions) > self._max_count:
                # Find the entry with the oldest last_access
                oldest_id = min(
                    self._sessions.keys(),
                    key=lambda k: self._sessions[k].last_access
                )
                entry = self._sessions.pop(oldest_id)
                await entry.store.close()
                logger.info(f"🗑️ LRU 淘汰: {oldest_id}")

    async def close_session(self, session_id: str) -> None:
        """Close and drop one session, if present."""
        async with self._lock:
            if session_id in self._sessions:
                entry = self._sessions.pop(session_id)
                await entry.store.close()
                logger.info(f"🔒 Session 关闭: {session_id}")

    async def close_all(self) -> None:
        """Close every session and clear the cache."""
        async with self._lock:
            for session_id, entry in list(self._sessions.items()):
                await entry.store.close()
            self._sessions.clear()
            logger.info("🔒 所有 Session 已关闭")

    def get_stats(self) -> Dict[str, Any]:
        """Return manager statistics: counts plus per-session age/idle times."""
        now = time.time()
        sessions_info = []
        for sid, entry in self._sessions.items():
            sessions_info.append({
                "session_id": sid,
                "age_hours": round((now - entry.created_at) / 3600, 2),
                "idle_minutes": round((now - entry.last_access) / 60, 2),
            })
        return {
            "total_sessions": len(self._sessions),
            "max_sessions": self._max_count,
            "sessions": sorted(sessions_info, key=lambda x: x["idle_minutes"], reverse=True)
        }
# Global singleton manager shared by the application
store_manager = VectorStoreManager()