Spaces:
Paused
Paused
| import os | |
| import re | |
| import json | |
| import numpy as np | |
| import pandas as pd | |
| from typing import List, Dict, Tuple, Optional | |
| from openai import OpenAI | |
| from datetime import datetime | |
| import csv | |
class KnowledgeBaseVectorizer:
    """Build and query a CSV-backed vector index over a markdown knowledge base.

    Entries are parsed from ``knowledge_base.md`` and embedded with the OpenAI
    embeddings API three times each: title, content, and title+content. They
    are stored as one wide CSV row per entry alongside a JSON metadata file.
    Searches combine the cosine similarities of the three vector types using
    configurable weights.
    """

    # Size of the zero placeholder vector used when an embedding request fails
    # and no successful embedding is available to copy the size from; 1536 is
    # the default output size of text-embedding-3-small.
    DEFAULT_EMBEDDING_DIM = 1536

    # Entry layout: "# NN-NN-NN <title>", then "**source** <source>", then
    # "**content** <body>"; the body runs until the next "# NN-NN-NN" heading
    # or the end of the file.
    _ENTRY_PATTERN = re.compile(
        r'#\s+(\d{2}-\d{2}-\d{2})\s+([^\n]+)\s+\*\*source\*\*\s+([^\n]+)'
        r'\s+\*\*content\*\*\s+(.*?)(?=\n#\s+\d{2}-\d{2}-\d{2}|$)',
        re.DOTALL,
    )

    # CSV column prefix holding each vector type's dimensions.
    _DIM_PREFIXES = {
        'title': 'title_dim_',
        'content': 'content_dim_',
        'full': 'full_dim_',
    }

    def __init__(self, api_key: str, data_path: str = "", vector_db_dir: str = ""):
        """Initialise the vectorizer (adapted for a student Space).

        Args:
            api_key: OpenAI API key.
            data_path: Path to ``knowledge_base.md``. Only used when
                ``vector_db_dir`` is empty; defaults to ``"knowledge_base.md"``
                in the working directory.
            vector_db_dir: Directory holding ``knowledge_base.md``,
                ``vector_database.csv`` and ``vector_metadata.json`` (usually
                the local directory of the data-storage repository). When
                given, all three paths come from this directory and
                ``data_path`` is ignored.
        """
        self.client = OpenAI(api_key=api_key)
        self.embedding_model = "text-embedding-3-small"

        if vector_db_dir:
            # vector_db_dir takes priority over data_path.
            self.data_path = os.path.join(vector_db_dir, "knowledge_base.md")
            self.vector_db_path = os.path.join(vector_db_dir, "vector_database.csv")
            self.metadata_path = os.path.join(vector_db_dir, "vector_metadata.json")
        else:
            # Original behaviour, kept for backward compatibility.
            self.data_path = data_path if data_path else "knowledge_base.md"
            self.vector_db_path = "vector_database.csv"
            self.metadata_path = "vector_metadata.json"

        # In-memory cache populated by load_vector_database().
        self._cached_df = None
        self._cached_metadata = None
        self._cached_embeddings = {}  # vector type -> {'raw': matrix, 'normalized': matrix}
        self._last_load_time = None

        print(f"[KnowledgeBaseVectorizer] Initialized with:")
        print(f" - Knowledge base: {self.data_path}")
        print(f" - Vector database: {self.vector_db_path}")
        print(f" - Metadata: {self.metadata_path}")

    def parse_knowledge_base(self) -> List[Dict]:
        """Parse ``self.data_path`` into a list of entry dicts.

        Markdown tables inside an entry's content are preserved. Blank lines
        are dropped and trailing whitespace is stripped from each remaining
        line.

        Returns:
            One dict per entry with keys ``id``, ``title``, ``source``,
            ``content`` and ``full_text``. ``full_text`` is the title, a space,
            then the content; it is the text embedded as the "full" vector.
            Returns an empty list if the file cannot be read or contains no
            entries.
        """
        entries = []
        try:
            with open(self.data_path, 'r', encoding='utf-8') as f:
                content = f.read()
            print(f"[parse_knowledge_base] Successfully read file: {self.data_path}")
        except FileNotFoundError:
            print(f"[parse_knowledge_base] Error: File not found - {self.data_path}")
            return entries
        except Exception as e:
            print(f"[parse_knowledge_base] Error reading file: {e}")
            return entries

        for match in self._ENTRY_PATTERN.findall(content):
            entry_id, title, source, raw_body = (part.strip() for part in match)
            # Drop blank lines only. Every non-blank line, including table
            # rows, is kept so table structure survives.
            cleaned_content = '\n'.join(
                line.rstrip() for line in raw_body.split('\n') if line.strip()
            )
            entries.append({
                'id': entry_id,
                'title': title,
                'source': source,
                'content': cleaned_content,
                'full_text': f"{title} {cleaned_content}",  # text embedded as the "full" vector
            })

        print(f"[parse_knowledge_base] Successfully parsed {len(entries)} entries")
        if entries:
            print("[parse_knowledge_base] First 3 entries info:")
            for entry in entries[:3]:
                content_lines = entry['content'].count('\n') + 1
                has_table = '|' in entry['content']
                print(f"  Entry {entry['id']}: {len(entry['content'])} chars, {content_lines} lines, has table: {has_table}")
        return entries

    def get_embedding(self, text: str) -> List[float]:
        """Embed a single text with the OpenAI embeddings API.

        Args:
            text: Text to embed.

        Returns:
            The embedding vector. Returns an empty list if the API call raised
            (the error is printed).
        """
        try:
            response = self.client.embeddings.create(
                input=text,
                model=self.embedding_model
            )
            return response.data[0].embedding
        except Exception as e:
            print(f"[get_embedding] Error: {e}")
            return []

    def batch_get_embeddings(self, texts: List[str], batch_size: int = 10) -> List[List[float]]:
        """Embed ``texts`` in batches, preserving input order.

        If a batch request fails, each text in it is retried individually. A
        text that still fails gets an all-zero placeholder vector. The
        placeholder has the size of the first embedding already obtained, or
        ``DEFAULT_EMBEDDING_DIM`` if there is none.

        Args:
            texts: Texts to embed.
            batch_size: Texts per API request. Values below 1 are treated as 1.

        Returns:
            One vector per input text.
        """
        # range() rejects a zero step; callers that size the batch from the
        # input length would otherwise crash on an empty input.
        batch_size = max(1, batch_size)
        num_batches = (len(texts) + batch_size - 1) // batch_size
        embeddings: List[List[float]] = []

        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            print(f"[batch_get_embeddings] Processing batch {i//batch_size + 1}/{num_batches}")
            try:
                response = self.client.embeddings.create(
                    input=batch,
                    model=self.embedding_model
                )
                embeddings.extend(item.embedding for item in response.data)
            except Exception as e:
                print(f"[batch_get_embeddings] Batch error: {e}")
                # Fall back to one request per text.
                for text in batch:
                    embedding = self.get_embedding(text)
                    if not embedding:
                        dim = len(embeddings[0]) if embeddings else self.DEFAULT_EMBEDDING_DIM
                        embedding = [0.0] * dim
                    embeddings.append(embedding)
        return embeddings

    def create_vector_database(self):
        """Parse the knowledge base, embed it and write the CSV and metadata files.

        Each entry gets three embeddings: title, content and full text. They
        are stored as ``title_dim_*``, ``content_dim_*`` and ``full_dim_*``
        columns of ``self.vector_db_path``. A JSON summary is written to
        ``self.metadata_path``. The in-memory cache is cleared afterwards so
        the next search reloads the new files.
        """
        print("[create_vector_database] Starting to create vector database...")

        # 1. Parse the knowledge base.
        entries = self.parse_knowledge_base()
        if not entries:
            print("[create_vector_database] No entries found")
            return

        # 2. Embed titles, contents and full texts.
        print("[create_vector_database] Vectorizing titles...")
        title_embeddings = self.batch_get_embeddings([entry['title'] for entry in entries])
        print("[create_vector_database] Vectorizing contents...")
        content_embeddings = self.batch_get_embeddings([entry['content'] for entry in entries])
        print("[create_vector_database] Vectorizing full texts...")
        full_embeddings = self.batch_get_embeddings([entry['full_text'] for entry in entries])

        # 3. Build one wide row per entry: text fields, then the title,
        #    content and full vector dimensions.
        print("[create_vector_database] Creating DataFrame...")
        vector_sets = (
            ('title', title_embeddings),
            ('content', content_embeddings),
            ('full', full_embeddings),
        )
        rows = []
        for i, entry in enumerate(entries):
            row = {
                'index': i,
                'id': entry['id'],
                'title': entry['title'],
                'source': entry['source'],
                'content': entry['content'],
                'full_text': entry['full_text'],
            }
            for vector_type, vectors in vector_sets:
                prefix = self._DIM_PREFIXES[vector_type]
                row.update({f'{prefix}{j}': val for j, val in enumerate(vectors[i])})
            rows.append(row)
        df = pd.DataFrame(rows)

        # 4. Save the vectors as CSV.
        print(f"[create_vector_database] Saving to {self.vector_db_path}...")
        df.to_csv(self.vector_db_path, index=False, encoding='utf-8')

        # 5. Save human-readable metadata as JSON.
        dimensions = len(title_embeddings[0]) if title_embeddings else 0
        metadata = {
            'embedding_model': self.embedding_model,
            'created_at': datetime.now().isoformat(),
            'num_entries': len(entries),
            'embedding_dimensions': dimensions,
            'vector_types': ['title', 'content', 'full'],
            'columns': list(df.columns),
            'entries_summary': [
                {
                    'id': entry['id'],
                    'title': entry['title'],
                    'source': entry['source']
                } for entry in entries
            ]
        }
        with open(self.metadata_path, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, ensure_ascii=False, indent=2)

        print(f"[create_vector_database] Vector database created successfully!")
        print(f"  - Vector database saved to: {self.vector_db_path}")
        print(f"  - Metadata saved to: {self.metadata_path}")
        print(f"  - Processed {len(entries)} entries")
        print(f"  - Vector dimensions: {dimensions}")

        # Force the next search to reload the freshly written files.
        self.clear_cache()

    def clear_cache(self):
        """Drop the cached DataFrame, metadata and vector matrices."""
        self._cached_df = None
        self._cached_metadata = None
        self._cached_embeddings = {}
        self._last_load_time = None
        print("[clear_cache] Vector database cache cleared")

    def load_vector_database(self, force_reload: bool = False) -> Tuple[Optional[pd.DataFrame], Optional[Dict]]:
        """Load the vector CSV and metadata JSON, using the in-memory cache.

        Args:
            force_reload: Re-read the files even if a cached copy exists.

        Returns:
            ``(DataFrame, metadata)``, or ``(None, None)`` if either file is
            missing or loading fails (the error is printed).
        """
        if not force_reload and self._cached_df is not None and self._cached_metadata is not None:
            return self._cached_df, self._cached_metadata

        try:
            print(f"[load_vector_database] Loading from {self.vector_db_path}")
            df = pd.read_csv(self.vector_db_path, encoding='utf-8')

            print(f"[load_vector_database] Loading metadata from {self.metadata_path}")
            with open(self.metadata_path, 'r', encoding='utf-8') as f:
                metadata = json.load(f)

            self._cached_df = df
            self._cached_metadata = metadata
            self._last_load_time = datetime.now()

            # Drop matrices built from any previously cached DataFrame.
            # Otherwise _preload_embeddings would skip every type and searches
            # would score the new rows against stale vectors.
            self._cached_embeddings = {}
            self._preload_embeddings()

            print(f"[load_vector_database] Successfully loaded vector database with {len(df)} entries")
            return df, metadata

        except FileNotFoundError as e:
            print(f"[load_vector_database] Error: File not found - {e}")
            return None, None
        except Exception as e:
            print(f"[load_vector_database] Error loading vector database: {e}")
            return None, None

    def _preload_embeddings(self):
        """Cache raw and L2-normalised matrices for every vector type of ``_cached_df``."""
        if self._cached_df is None:
            return

        vector_types = ['title', 'content', 'full']
        for vector_type in vector_types:
            if vector_type in self._cached_embeddings:
                continue
            embeddings = self.get_embeddings_from_df(self._cached_df, vector_type)
            norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
            # Zero rows are placeholders for failed embeddings. Dividing by a
            # zero norm would produce NaN, and NaN scores land at the top of
            # the descending ranking in batch_search_similar. Keep them as zero
            # vectors so they score 0 instead.
            normalized = embeddings / np.where(norms == 0, 1.0, norms)
            self._cached_embeddings[vector_type] = {
                'raw': embeddings,
                'normalized': normalized,
            }

        print(f"[_preload_embeddings] Preloaded {len(vector_types)} types of vector matrices")

    def get_embeddings_from_df(self, df: pd.DataFrame, vector_type: str = 'full') -> np.ndarray:
        """Extract one vector type's matrix (one row per entry) from ``df``.

        Args:
            df: DataFrame in the layout written by create_vector_database.
            vector_type: ``'title'`` or ``'content'``; any other value selects
                the ``'full'`` vectors.

        Returns:
            A float matrix built from the matching ``*_dim_*`` columns, in
            column order.
        """
        prefix = self._DIM_PREFIXES.get(vector_type, self._DIM_PREFIXES['full'])
        embedding_cols = [col for col in df.columns if col.startswith(prefix)]
        return df[embedding_cols].to_numpy(dtype=float)

    def batch_search_similar(self, queries: List[str], top_k: int = 5,
                             title_weight: float = 0.4,
                             content_weight: float = 0.3,
                             full_weight: float = 0.3) -> List[List[Tuple[Dict, float, Dict]]]:
        """Search for several queries, loading the vector database only once.

        Each entry's score is the weighted sum of the cosine similarities
        between the query and its title, content and full-text vectors. The
        weights are rescaled to sum to 1.

        Args:
            queries: Query texts.
            top_k: Maximum number of hits per query (negative values act as 0).
            title_weight: Weight of the title similarity.
            content_weight: Weight of the content similarity.
            full_weight: Weight of the full-text similarity.

        Returns:
            One list per query of ``(entry, combined_score, similarity_details)``
            tuples, best first. A query whose embedding failed gets an empty
            list. Every query gets an empty list if the database cannot be
            loaded.

        Raises:
            ValueError: If the three weights sum to zero.
        """
        if not queries:
            return []

        total_weight = title_weight + content_weight + full_weight
        if total_weight == 0:
            raise ValueError("title_weight, content_weight and full_weight must not sum to zero")
        title_weight /= total_weight
        content_weight /= total_weight
        full_weight /= total_weight
        top_k = max(top_k, 0)

        # Load once for all queries.
        df, _ = self.load_vector_database()
        if df is None:
            return [[] for _ in queries]

        print(f"[batch_search_similar] Generating vectors for {len(queries)} queries...")
        query_embeddings = self.batch_get_embeddings(queries, batch_size=min(10, len(queries)))
        if len(query_embeddings) != len(queries):
            print("[batch_search_similar] Query vector generation failed")
            return [[] for _ in queries]

        title_matrix = self._cached_embeddings['title']['normalized']
        content_matrix = self._cached_embeddings['content']['normalized']
        full_matrix = self._cached_embeddings['full']['normalized']

        all_results = []
        for i, (query, query_embedding) in enumerate(zip(queries, query_embeddings)):
            query_vec = np.asarray(query_embedding, dtype=float)
            query_norm = np.linalg.norm(query_vec)
            if query_norm == 0:
                # Empty or all-zero vector: the embedding request failed
                # (batch_get_embeddings substitutes zeros). Normalising it
                # would give NaN scores, so this query gets no results.
                all_results.append([])
                continue
            query_unit = query_vec / query_norm

            title_similarities = title_matrix @ query_unit
            content_similarities = content_matrix @ query_unit
            full_similarities = full_matrix @ query_unit
            combined_similarities = (
                title_weight * title_similarities
                + content_weight * content_similarities
                + full_weight * full_similarities
            )

            top_indices = np.argsort(combined_similarities)[::-1][:top_k]
            query_results = []
            for idx in top_indices:
                row = df.iloc[idx]
                entry = {
                    'id': row['id'],
                    'title': row['title'],
                    'source': row['source'],
                    'content': row['content']
                }
                similarity_details = {
                    'combined': float(combined_similarities[idx]),
                    'title': float(title_similarities[idx]),
                    'content': float(content_similarities[idx]),
                    'full': float(full_similarities[idx])
                }
                query_results.append((entry, float(combined_similarities[idx]), similarity_details))

            all_results.append(query_results)
            print(f"[batch_search_similar] Completed query {i+1}/{len(queries)}: '{query[:50]}...'")

        return all_results

    def search_similar(self, query: str, top_k: int = 5,
                       title_weight: float = 0.4,
                       content_weight: float = 0.3,
                       full_weight: float = 0.3) -> List[Tuple[Dict, float, Dict]]:
        """Search for one query by combining title, content and full-text similarity.

        This is a thin wrapper around batch_search_similar with a single query.

        Args:
            query: Query text.
            top_k: Maximum number of hits to return.
            title_weight: Weight of the title similarity.
            content_weight: Weight of the content similarity.
            full_weight: Weight of the full-text similarity.

        Returns:
            ``(entry, combined_score, similarity_details)`` tuples, best first.
        """
        results = self.batch_search_similar([query], top_k, title_weight, content_weight, full_weight)
        return results[0] if results else []

    def search_with_entities_optimized(self, entities: List[str], top_k: int = 5) -> List[Tuple[Dict, float, Dict]]:
        """Search the knowledge base once per entity and merge the hits.

        All entities are scored in a single batch_search_similar call, so the
        vector database is loaded at most once.

        Args:
            entities: Entity strings to search for.
            top_k: Number of hits requested per entity.

        Returns:
            Deduplicated ``(entry, combined_score, similarity_details)`` tuples
            sorted by score, highest first. When several entities hit the same
            entry id, the highest-scoring hit is kept.
        """
        if not entities:
            return []

        batch_results = self.batch_search_similar(
            entities,
            top_k=top_k,
            title_weight=0.3,  # entity search weights content similarity most heavily
            content_weight=0.5,
            full_weight=0.2
        )

        # Keep the best-scoring hit per entry id. Dict insertion order plus the
        # stable sort keep first-seen order among equal scores.
        best_hits: Dict[str, Tuple[Dict, float, Dict]] = {}
        for entity_results in batch_results:
            for entry, score, details in entity_results:
                previous = best_hits.get(entry['id'])
                if previous is None or score > previous[1]:
                    best_hits[entry['id']] = (entry, score, details)

        return sorted(best_hits.values(), key=lambda hit: hit[1], reverse=True)

    def get_cache_info(self) -> Dict:
        """Report the cache state and the configured file paths.

        Returns:
            A dict with whether a DataFrame is cached, its row count, the
            cached vector types, the last load time (ISO format or None), and
            the three data paths.
        """
        return {
            'is_cached': self._cached_df is not None,
            'cache_size': len(self._cached_df) if self._cached_df is not None else 0,
            'cached_embeddings': list(self._cached_embeddings.keys()),
            'last_load_time': self._last_load_time.isoformat() if self._last_load_time else None,
            'data_paths': {
                'knowledge_base': self.data_path,
                'vector_database': self.vector_db_path,
                'metadata': self.metadata_path
            }
        }