Spaces:
Sleeping
Sleeping
| """ | |
| 图谱构建服务 | |
| 接口2:使用Zep API构建Standalone Graph | |
| """ | |
| import os | |
| import uuid | |
| import time | |
| import threading | |
| from typing import Dict, Any, List, Optional, Callable | |
| from dataclasses import dataclass | |
| from zep_cloud.client import Zep | |
| from zep_cloud import EpisodeData, EntityEdgeSourceTarget | |
| from ..config import Config | |
| from ..models.task import TaskManager, TaskStatus | |
| from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges | |
| from .text_processor import TextProcessor | |
@dataclass
class GraphInfo:
    """Summary statistics for one graph.

    Declared as a dataclass so it can be constructed with keyword arguments
    (``GraphInfo(graph_id=..., node_count=..., ...)``).
    """

    # Identifier of the graph these statistics describe.
    graph_id: str
    # Number of nodes counted in the graph.
    node_count: int
    # Number of edges counted in the graph.
    edge_count: int
    # Entity type labels found on the graph's nodes.
    entity_types: List[str]

    def to_dict(self) -> Dict[str, Any]:
        """Return the fields as a plain dictionary keyed by field name."""
        return {
            "graph_id": self.graph_id,
            "node_count": self.node_count,
            "edge_count": self.edge_count,
            "entity_types": self.entity_types,
        }
class GraphBuilderService:
    """Graph building service.

    Calls the Zep API to create a graph, install an ontology, upload text
    chunks as episodes, wait for them to be processed, and read the
    resulting nodes and edges back.
    """

    # Names Zep reserves; they cannot be used as custom attribute names.
    _RESERVED_ATTR_NAMES = frozenset(
        {'uuid', 'name', 'group_id', 'name_embedding', 'summary', 'created_at'}
    )

    def __init__(self, api_key: Optional[str] = None):
        """Create the Zep client and the task manager.

        Args:
            api_key: Zep API key. Falls back to ``Config.ZEP_API_KEY`` when
                omitted or empty.

        Raises:
            ValueError: If no API key is available from either source.
        """
        self.api_key = api_key or Config.ZEP_API_KEY
        if not self.api_key:
            raise ValueError("ZEP_API_KEY 未配置")
        self.client = Zep(api_key=self.api_key)
        self.task_manager = TaskManager()

    @staticmethod
    def _stage_progress(base: int, span: int, fraction: float) -> int:
        """Map a 0..1 completion fraction onto ``base .. base + span``."""
        clamped = max(0.0, min(1.0, fraction))
        return base + int(clamped * span)

    @classmethod
    def _safe_attr_name(cls, attr_name: str) -> str:
        """Prefix reserved attribute names with ``entity_``; keep others."""
        if attr_name.lower() in cls._RESERVED_ATTR_NAMES:
            return f"entity_{attr_name}"
        return attr_name

    def build_graph_async(
        self,
        text: str,
        ontology: Dict[str, Any],
        graph_name: str = "MiroFish Graph",
        chunk_size: int = 500,
        chunk_overlap: int = 50,
        batch_size: int = 3
    ) -> str:
        """Start a graph build in a background thread.

        Args:
            text: Input text.
            ontology: Ontology definition (entity and edge types).
            graph_name: Name given to the new graph.
            chunk_size: Size of each text chunk.
            chunk_overlap: Overlap between consecutive chunks.
            batch_size: Number of chunks sent per batch.

        Returns:
            The ID of the task tracking the build.
        """
        task_id = self.task_manager.create_task(
            task_type="graph_build",
            metadata={
                "graph_name": graph_name,
                "chunk_size": chunk_size,
                "text_length": len(text),
            }
        )

        # Daemon thread: an in-flight build does not keep the process alive.
        worker = threading.Thread(
            target=self._build_graph_worker,
            args=(task_id, text, ontology, graph_name, chunk_size, chunk_overlap, batch_size),
            daemon=True,
        )
        worker.start()
        return task_id

    def _build_graph_worker(
        self,
        task_id: str,
        text: str,
        ontology: Dict[str, Any],
        graph_name: str,
        chunk_size: int,
        chunk_overlap: int,
        batch_size: int
    ):
        """Run the whole build pipeline and record progress on the task.

        Any exception is caught and stored on the task via ``fail_task``
        (message plus traceback) because this runs in a background thread
        with no caller to propagate to.
        """

        def stage_reporter(base: int, span: int) -> Callable:
            """Build a callback that maps a 0..1 fraction into one stage."""
            def report(msg, fraction):
                self.task_manager.update_task(
                    task_id,
                    progress=self._stage_progress(base, span, fraction),
                    message=msg
                )
            return report

        try:
            self.task_manager.update_task(
                task_id,
                status=TaskStatus.PROCESSING,
                progress=5,
                message="开始构建图谱..."
            )

            # 1. Create the graph.
            graph_id = self.create_graph(graph_name)
            self.task_manager.update_task(
                task_id,
                progress=10,
                message=f"图谱已创建: {graph_id}"
            )

            # 2. Install the ontology.
            self.set_ontology(graph_id, ontology)
            self.task_manager.update_task(
                task_id,
                progress=15,
                message="本体已设置"
            )

            # 3. Split the text into chunks.
            chunks = TextProcessor.split_text(text, chunk_size, chunk_overlap)
            total_chunks = len(chunks)
            self.task_manager.update_task(
                task_id,
                progress=20,
                message=f"文本已分割为 {total_chunks} 个块"
            )

            # 4. Upload the chunks in batches (progress 20-60%).
            episode_uuids = self.add_text_batches(
                graph_id, chunks, batch_size, stage_reporter(20, 40)
            )

            # 5. Wait for Zep to process the episodes (progress 60-90%).
            self.task_manager.update_task(
                task_id,
                progress=60,
                message="等待Zep处理数据..."
            )
            self._wait_for_episodes(episode_uuids, stage_reporter(60, 30))

            # 6. Collect graph statistics.
            self.task_manager.update_task(
                task_id,
                progress=90,
                message="获取图谱信息..."
            )
            graph_info = self._get_graph_info(graph_id)

            self.task_manager.complete_task(task_id, {
                "graph_id": graph_id,
                "graph_info": graph_info.to_dict(),
                "chunks_processed": total_chunks,
            })
        except Exception as e:
            import traceback
            error_msg = f"{str(e)}\n{traceback.format_exc()}"
            self.task_manager.fail_task(task_id, error_msg)

    def create_graph(self, name: str) -> str:
        """Create a Zep graph with a generated ID and return that ID."""
        graph_id = f"mirofish_{uuid.uuid4().hex[:16]}"
        self.client.graph.create(
            graph_id=graph_id,
            name=name,
            description="MiroFish Social Simulation Graph"
        )
        return graph_id

    def set_ontology(self, graph_id: str, ontology: Dict[str, Any]):
        """Build model classes from ``ontology`` and install them on a graph.

        Args:
            graph_id: Graph to configure.
            ontology: Dict with optional ``entity_types`` and ``edge_types``
                lists. Each item needs a ``name`` and may carry a
                ``description`` and ``attributes``; edge items may also carry
                ``source_targets``.

        Edge types without any ``source_targets`` are not sent. Nothing is
        sent at all when both resulting collections are empty.
        """
        import warnings

        from pydantic import Field
        from zep_cloud.external_clients.ontology import EntityModel, EntityText, EdgeModel

        # Suppress Pydantic v2 warnings about Field(default=None) raised by
        # the dynamic class creation below. Note this filter is installed
        # for the whole process, not just this call.
        warnings.filterwarnings('ignore', category=UserWarning, module='pydantic')

        def build_model(class_name, base, definition, description, attr_type):
            """Create a ``base`` subclass with one optional field per attribute."""
            attrs = {"__doc__": description}
            annotations = {}
            for attr_def in definition.get("attributes", []):
                attr_name = self._safe_attr_name(attr_def["name"])
                attr_desc = attr_def.get("description", attr_name)
                # Every field carries a description and defaults to None.
                attrs[attr_name] = Field(description=attr_desc, default=None)
                annotations[attr_name] = Optional[attr_type]
            attrs["__annotations__"] = annotations
            model = type(class_name, (base,), attrs)
            model.__doc__ = description
            return model

        # Entity types: attributes are typed Optional[EntityText].
        entity_types = {}
        for entity_def in ontology.get("entity_types", []):
            name = entity_def["name"]
            description = entity_def.get("description", f"A {name} entity.")
            entity_types[name] = build_model(
                name, EntityModel, entity_def, description, EntityText
            )

        # Edge types: attributes are typed Optional[str].
        edge_definitions = {}
        for edge_def in ontology.get("edge_types", []):
            name = edge_def["name"]
            description = edge_def.get("description", f"A {name} relationship.")
            # snake_case edge name -> PascalCase class name.
            class_name = ''.join(word.capitalize() for word in name.split('_'))
            edge_class = build_model(class_name, EdgeModel, edge_def, description, str)

            source_targets = [
                EntityEdgeSourceTarget(
                    source=st.get("source", "Entity"),
                    target=st.get("target", "Entity")
                )
                for st in edge_def.get("source_targets", [])
            ]
            if source_targets:
                edge_definitions[name] = (edge_class, source_targets)

        if entity_types or edge_definitions:
            self.client.graph.set_ontology(
                graph_ids=[graph_id],
                entities=entity_types if entity_types else None,
                edges=edge_definitions if edge_definitions else None,
            )

    def add_text_batches(
        self,
        graph_id: str,
        chunks: List[str],
        batch_size: int = 3,
        progress_callback: Optional[Callable] = None
    ) -> List[str]:
        """Upload text chunks to a graph in batches.

        Args:
            graph_id: Target graph.
            chunks: Text chunks; each becomes one text episode.
            batch_size: Chunks per request; must be at least 1.
            progress_callback: Optional ``callback(message, fraction)`` where
                ``fraction`` is in 0..1.

        Returns:
            The uuids of the episodes reported back by the API.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
            Exception: Whatever the API call raises, re-raised after the
                failure has been reported through ``progress_callback``.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        episode_uuids = []
        total_chunks = len(chunks)
        total_batches = (total_chunks + batch_size - 1) // batch_size

        for start in range(0, total_chunks, batch_size):
            batch_chunks = chunks[start:start + batch_size]
            batch_num = start // batch_size + 1

            if progress_callback:
                progress_callback(
                    f"发送第 {batch_num}/{total_batches} 批数据 ({len(batch_chunks)} 块)...",
                    (start + len(batch_chunks)) / total_chunks
                )

            episodes = [
                EpisodeData(data=chunk, type="text")
                for chunk in batch_chunks
            ]

            try:
                batch_result = self.client.graph.add_batch(
                    graph_id=graph_id,
                    episodes=episodes
                )

                # Collect the uuid of every returned episode.
                if batch_result and isinstance(batch_result, list):
                    for ep in batch_result:
                        ep_uuid = getattr(ep, 'uuid_', None) or getattr(ep, 'uuid', None)
                        if ep_uuid:
                            episode_uuids.append(ep_uuid)

                # Throttle between requests; no wait is needed after the
                # final batch.
                if start + batch_size < total_chunks:
                    time.sleep(1)
            except Exception as e:
                if progress_callback:
                    progress_callback(f"批次 {batch_num} 发送失败: {str(e)}", 0)
                raise

        return episode_uuids

    def _wait_for_episodes(
        self,
        episode_uuids: List[str],
        progress_callback: Optional[Callable] = None,
        timeout: int = 600
    ):
        """Poll until every episode reports ``processed`` or time runs out.

        Args:
            episode_uuids: Episodes to wait for; duplicates are counted once.
            progress_callback: Optional ``callback(message, fraction)`` where
                ``fraction`` is in 0..1.
            timeout: Maximum number of seconds to keep polling.

        On timeout the method reports partial completion and returns; it
        does not raise.
        """
        import logging

        logger = logging.getLogger(__name__)

        if not episode_uuids:
            if progress_callback:
                progress_callback("无需等待(没有 episode)", 1.0)
            return

        start_time = time.time()
        pending_episodes = set(episode_uuids)
        completed_count = 0
        # Count unique uuids so the completed/total ratio can reach 1.
        total_episodes = len(pending_episodes)

        if progress_callback:
            progress_callback(f"开始等待 {total_episodes} 个文本块处理...", 0)

        while pending_episodes:
            if time.time() - start_time > timeout:
                if progress_callback:
                    progress_callback(
                        f"部分文本块超时,已完成 {completed_count}/{total_episodes}",
                        completed_count / total_episodes
                    )
                break

            # Iterate over a copy: finished episodes are removed in the loop.
            for ep_uuid in list(pending_episodes):
                try:
                    episode = self.client.graph.episode.get(uuid_=ep_uuid)
                except Exception:
                    # Best effort: a failed lookup is retried on the next
                    # polling round instead of aborting the wait.
                    logger.debug("Episode status lookup failed for %s", ep_uuid, exc_info=True)
                    continue
                if getattr(episode, 'processed', False):
                    pending_episodes.remove(ep_uuid)
                    completed_count += 1

            elapsed = int(time.time() - start_time)
            if progress_callback:
                progress_callback(
                    f"Zep处理中... {completed_count}/{total_episodes} 完成, {len(pending_episodes)} 待处理 ({elapsed}秒)",
                    completed_count / total_episodes
                )

            if pending_episodes:
                time.sleep(3)  # poll every 3 seconds

        if progress_callback:
            progress_callback(f"处理完成: {completed_count}/{total_episodes}", 1.0)

    def _get_graph_info(self, graph_id: str) -> "GraphInfo":
        """Count a graph's nodes and edges and list its entity types.

        Labels "Entity" and "Node" are excluded from the entity types, which
        are returned sorted so the output is deterministic.
        """
        nodes = fetch_all_nodes(self.client, graph_id)
        edges = fetch_all_edges(self.client, graph_id)

        generic_labels = {"Entity", "Node"}
        entity_types = set()
        for node in nodes:
            entity_types.update(
                label for label in (node.labels or []) if label not in generic_labels
            )

        return GraphInfo(
            graph_id=graph_id,
            node_count=len(nodes),
            edge_count=len(edges),
            entity_types=sorted(entity_types)
        )

    def get_graph_data(self, graph_id: str) -> Dict[str, Any]:
        """Return a graph's full node and edge data.

        Args:
            graph_id: Graph ID.

        Returns:
            A dict with ``graph_id``, ``nodes``, ``edges``, ``node_count``
            and ``edge_count``. Nodes and edges are plain dicts that include
            attributes and timestamps; timestamps are converted with ``str``.
        """
        nodes = fetch_all_nodes(self.client, graph_id)
        edges = fetch_all_edges(self.client, graph_id)

        # uuid -> name lookup used to label edge endpoints.
        node_map = {node.uuid_: node.name or "" for node in nodes}

        def stamp(value):
            """Stringify a timestamp, mapping empty values to None."""
            return str(value) if value else None

        nodes_data = []
        for node in nodes:
            created_at = getattr(node, 'created_at', None)
            if created_at:
                created_at = str(created_at)
            nodes_data.append({
                "uuid": node.uuid_,
                "name": node.name,
                "labels": node.labels or [],
                "summary": node.summary or "",
                "attributes": node.attributes or {},
                "created_at": created_at,
            })

        edges_data = []
        for edge in edges:
            # Episode references may live under either attribute name, and
            # may be a list or a single value.
            raw_episodes = getattr(edge, 'episodes', None) or getattr(edge, 'episode_ids', None)
            if not raw_episodes:
                episodes = []
            elif isinstance(raw_episodes, list):
                episodes = [str(item) for item in raw_episodes]
            else:
                episodes = [str(raw_episodes)]

            # Fall back to the edge name when no fact_type is set.
            fact_type = getattr(edge, 'fact_type', None) or edge.name or ""

            edges_data.append({
                "uuid": edge.uuid_,
                "name": edge.name or "",
                "fact": edge.fact or "",
                "fact_type": fact_type,
                "source_node_uuid": edge.source_node_uuid,
                "target_node_uuid": edge.target_node_uuid,
                "source_node_name": node_map.get(edge.source_node_uuid, ""),
                "target_node_name": node_map.get(edge.target_node_uuid, ""),
                "attributes": edge.attributes or {},
                "created_at": stamp(getattr(edge, 'created_at', None)),
                "valid_at": stamp(getattr(edge, 'valid_at', None)),
                "invalid_at": stamp(getattr(edge, 'invalid_at', None)),
                "expired_at": stamp(getattr(edge, 'expired_at', None)),
                "episodes": episodes,
            })

        return {
            "graph_id": graph_id,
            "nodes": nodes_data,
            "edges": edges_data,
            "node_count": len(nodes_data),
            "edge_count": len(edges_data),
        }

    def delete_graph(self, graph_id: str):
        """Delete the graph with the given ID."""
        self.client.graph.delete(graph_id=graph_id)