Spaces:
Sleeping
Sleeping
| """ | |
| Zep实体读取与过滤服务 | |
| 从Zep图谱中读取节点,筛选出符合预定义实体类型的节点 | |
| """ | |
| import time | |
| from typing import Dict, Any, List, Optional, Set, Callable, TypeVar | |
| from dataclasses import dataclass, field | |
| from zep_cloud.client import Zep | |
| from ..config import Config | |
| from ..utils.logger import get_logger | |
| from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges | |
| logger = get_logger('mirofish.zep_entity_reader') | |
| # 用于泛型返回类型 | |
| T = TypeVar('T') | |
@dataclass
class EntityNode:
    """A graph node together with its related edges and neighbouring nodes.

    Declared as a dataclass so it can be built with keyword arguments and so
    the two ``related_*`` lists get a fresh empty list per instance.
    """

    uuid: str
    name: str
    labels: List[str]
    summary: str
    attributes: Dict[str, Any]
    # Edges that touch this node (filled in by the reader when enrichment is on).
    related_edges: List[Dict[str, Any]] = field(default_factory=list)
    # Basic information about the nodes at the other end of those edges.
    related_nodes: List[Dict[str, Any]] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Return the node as a plain dictionary keyed by field name."""
        return {
            "uuid": self.uuid,
            "name": self.name,
            "labels": self.labels,
            "summary": self.summary,
            "attributes": self.attributes,
            "related_edges": self.related_edges,
            "related_nodes": self.related_nodes,
        }

    def get_entity_type(self) -> Optional[str]:
        """Return the first label that is not a default one, or ``None``.

        The labels ``"Entity"`` and ``"Node"`` are treated as defaults and
        are skipped.
        """
        return next(
            (label for label in self.labels if label not in ("Entity", "Node")),
            None,
        )
@dataclass
class FilteredEntities:
    """The result of filtering a graph's nodes down to typed entities.

    Declared as a dataclass so it can be built with keyword arguments.
    """

    # Entities that passed the filter.
    entities: List["EntityNode"]
    # Distinct entity types seen among the kept entities.
    entity_types: Set[str]
    # Number of nodes inspected before filtering.
    total_count: int
    # Number of entities kept after filtering.
    filtered_count: int

    def to_dict(self) -> Dict[str, Any]:
        """Return the result as a plain dictionary.

        ``entity_types`` is emitted as a sorted list so the output does not
        depend on set iteration order.
        """
        return {
            "entities": [entity.to_dict() for entity in self.entities],
            "entity_types": sorted(self.entity_types),
            "total_count": self.total_count,
            "filtered_count": self.filtered_count,
        }
class ZepEntityReader:
    """Read nodes from a Zep graph and keep those with a specific entity type.

    Main capabilities:
    1. Read every node of a Zep graph.
    2. Keep the nodes whose labels include something other than the default
       "Entity" / "Node" labels, optionally restricted to a given set of types.
    3. Attach each kept entity's related edges and neighbouring nodes.
    """

    # Labels that do not count as a specific entity type.
    _DEFAULT_LABELS = ("Entity", "Node")

    def __init__(self, api_key: Optional[str] = None):
        """Create the Zep client.

        Args:
            api_key: Zep API key; falls back to ``Config.ZEP_API_KEY``.

        Raises:
            ValueError: if neither source provides a key.
        """
        self.api_key = api_key or Config.ZEP_API_KEY
        if not self.api_key:
            raise ValueError("ZEP_API_KEY 未配置")
        self.client = Zep(api_key=self.api_key)

    def _call_with_retry(
        self,
        func: "Callable[[], T]",
        operation_name: str,
        max_retries: int = 3,
        initial_delay: float = 2.0
    ) -> "T":
        """Call ``func`` and retry with exponential backoff when it raises.

        Args:
            func: zero-argument callable to execute.
            operation_name: name of the operation, used in log messages.
            max_retries: maximum number of attempts (default 3). Values below
                1 are treated as 1 so ``func`` is always tried at least once.
            initial_delay: seconds to wait before the second attempt; the
                wait doubles after every failed attempt.

        Returns:
            Whatever ``func`` returns.

        Raises:
            Exception: the exception of the last attempt, once all attempts
                have failed.
        """
        attempts = max(1, max_retries)
        delay = initial_delay
        for attempt in range(1, attempts + 1):
            try:
                return func()
            except Exception as exc:
                if attempt == attempts:
                    logger.error(f"Zep {operation_name} 在 {attempts} 次尝试后仍失败: {str(exc)}")
                    raise
                logger.warning(
                    f"Zep {operation_name} 第 {attempt} 次尝试失败: {str(exc)[:100]}, "
                    f"{delay:.1f}秒后重试..."
                )
                time.sleep(delay)
                delay *= 2  # exponential backoff
        # The loop always returns or re-raises on its final iteration.
        raise RuntimeError("unreachable")

    @staticmethod
    def _object_uuid(obj: Any) -> str:
        """Return ``obj.uuid_`` when set, otherwise ``obj.uuid`` (default '')."""
        return getattr(obj, 'uuid_', None) or getattr(obj, 'uuid', '')

    @classmethod
    def _node_to_dict(cls, node: Any) -> Dict[str, Any]:
        """Convert a node object into a plain dict, defaulting missing fields."""
        return {
            "uuid": cls._object_uuid(node),
            "name": node.name or "",
            "labels": node.labels or [],
            "summary": node.summary or "",
            "attributes": node.attributes or {},
        }

    @classmethod
    def _edge_to_dict(cls, edge: Any) -> Dict[str, Any]:
        """Convert an edge object into a plain dict, defaulting missing fields."""
        return {
            "uuid": cls._object_uuid(edge),
            "name": edge.name or "",
            "fact": edge.fact or "",
            "source_node_uuid": edge.source_node_uuid,
            "target_node_uuid": edge.target_node_uuid,
            "attributes": edge.attributes or {},
        }

    @staticmethod
    def _build_relations(
        entity_uuid: str,
        edges: List[Dict[str, Any]],
        node_map: Dict[str, Dict[str, Any]],
    ) -> tuple:
        """Build the related-edge and related-node lists for one entity.

        Args:
            entity_uuid: UUID of the entity the edges belong to.
            edges: edge dicts that touch the entity. An edge whose source is
                the entity is "outgoing"; any other edge is "incoming".
            node_map: node dicts keyed by UUID, used to describe neighbours.

        Returns:
            ``(related_edges, related_nodes)``. Neighbours are listed once
            each, in the order they are first seen, and only when present in
            ``node_map``.
        """
        related_edges = []
        # A dict is used as an insertion-ordered set of neighbour UUIDs.
        neighbour_uuids: Dict[Any, None] = {}
        for edge in edges:
            if edge["source_node_uuid"] == entity_uuid:
                direction, other_key = "outgoing", "target_node_uuid"
            else:
                direction, other_key = "incoming", "source_node_uuid"
            other_uuid = edge[other_key]
            related_edges.append({
                "direction": direction,
                "edge_name": edge["name"],
                "fact": edge["fact"],
                other_key: other_uuid,
            })
            neighbour_uuids[other_uuid] = None

        related_nodes = []
        for neighbour_uuid in neighbour_uuids:
            if neighbour_uuid in node_map:
                neighbour = node_map[neighbour_uuid]
                related_nodes.append({
                    "uuid": neighbour["uuid"],
                    "name": neighbour["name"],
                    "labels": neighbour["labels"],
                    "summary": neighbour.get("summary", ""),
                })
        return related_edges, related_nodes

    def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
        """Fetch every node of a graph via ``fetch_all_nodes``.

        Args:
            graph_id: graph ID.

        Returns:
            One dict per node with keys uuid, name, labels, summary and
            attributes.
        """
        logger.info(f"获取图谱 {graph_id} 的所有节点...")
        nodes = fetch_all_nodes(self.client, graph_id)
        nodes_data = [self._node_to_dict(node) for node in nodes]
        logger.info(f"共获取 {len(nodes_data)} 个节点")
        return nodes_data

    def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]:
        """Fetch every edge of a graph via ``fetch_all_edges``.

        Args:
            graph_id: graph ID.

        Returns:
            One dict per edge with keys uuid, name, fact, source_node_uuid,
            target_node_uuid and attributes.
        """
        logger.info(f"获取图谱 {graph_id} 的所有边...")
        edges = fetch_all_edges(self.client, graph_id)
        edges_data = [self._edge_to_dict(edge) for edge in edges]
        logger.info(f"共获取 {len(edges_data)} 条边")
        return edges_data

    def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]:
        """Fetch the edges of a single node, with retries.

        This is best-effort: any failure is logged as a warning and an empty
        list is returned instead of raising.

        Args:
            node_uuid: node UUID.

        Returns:
            One dict per edge, or ``[]`` when the lookup fails.
        """
        try:
            edges = self._call_with_retry(
                func=lambda: self.client.graph.node.get_entity_edges(node_uuid=node_uuid),
                operation_name=f"获取节点边(node={node_uuid[:8]}...)"
            )
            return [self._edge_to_dict(edge) for edge in edges]
        except Exception as e:
            logger.warning(f"获取节点 {node_uuid} 的边失败: {str(e)}")
            return []

    def filter_defined_entities(
        self,
        graph_id: str,
        defined_entity_types: Optional[List[str]] = None,
        enrich_with_edges: bool = True
    ) -> "FilteredEntities":
        """Keep the nodes of a graph that carry a specific entity type.

        Filtering rules:
        - A node whose labels are only "Entity" / "Node" is skipped.
        - A node with any other label is kept; when ``defined_entity_types``
          is non-empty, at least one of those labels must be in it.

        Args:
            graph_id: graph ID.
            defined_entity_types: optional list of entity types to keep; an
                empty list or ``None`` keeps every typed node.
            enrich_with_edges: when true, also fetch all edges and attach the
                related edges and neighbouring nodes to each entity.

        Returns:
            FilteredEntities with the kept entities, the entity types found,
            the total node count and the kept count.
        """
        logger.info(f"开始筛选图谱 {graph_id} 的实体...")

        all_nodes = self.get_all_nodes(graph_id)
        total_count = len(all_nodes)
        all_edges = self.get_all_edges(graph_id) if enrich_with_edges else []
        node_map = {n["uuid"]: n for n in all_nodes}

        # Index edges by endpoint once, so each entity does not rescan every
        # edge. A self-loop is stored once and is reported as "outgoing".
        edges_by_node: Dict[Any, List[Dict[str, Any]]] = {}
        for edge in all_edges:
            source = edge["source_node_uuid"]
            target = edge["target_node_uuid"]
            edges_by_node.setdefault(source, []).append(edge)
            if target != source:
                edges_by_node.setdefault(target, []).append(edge)

        allowed_types = set(defined_entity_types) if defined_entity_types else None

        filtered_entities = []
        entity_types_found = set()
        for node in all_nodes:
            labels = node.get("labels", [])
            custom_labels = [
                label for label in labels if label not in self._DEFAULT_LABELS
            ]
            if not custom_labels:
                # Only default labels: not a typed entity.
                continue

            if allowed_types is not None:
                matching_labels = [
                    label for label in custom_labels if label in allowed_types
                ]
                if not matching_labels:
                    continue
                entity_type = matching_labels[0]
            else:
                entity_type = custom_labels[0]
            entity_types_found.add(entity_type)

            entity = EntityNode(
                uuid=node["uuid"],
                name=node["name"],
                labels=labels,
                summary=node["summary"],
                attributes=node["attributes"],
            )
            if enrich_with_edges:
                entity.related_edges, entity.related_nodes = self._build_relations(
                    node["uuid"],
                    edges_by_node.get(node["uuid"], []),
                    node_map,
                )
            filtered_entities.append(entity)

        logger.info(f"筛选完成: 总节点 {total_count}, 符合条件 {len(filtered_entities)}, "
                    f"实体类型: {entity_types_found}")
        return FilteredEntities(
            entities=filtered_entities,
            entity_types=entity_types_found,
            total_count=total_count,
            filtered_count=len(filtered_entities),
        )

    def get_entity_with_context(
        self,
        graph_id: str,
        entity_uuid: str
    ) -> Optional["EntityNode"]:
        """Fetch one entity with its related edges and neighbouring nodes.

        The node lookup is retried. Any failure is logged as an error and
        ``None`` is returned instead of raising.

        Args:
            graph_id: graph ID, used to load all nodes for neighbour lookup.
            entity_uuid: UUID of the entity to fetch.

        Returns:
            The EntityNode, or ``None`` when the node is missing or the
            lookup fails.
        """
        try:
            node = self._call_with_retry(
                func=lambda: self.client.graph.node.get(uuid_=entity_uuid),
                operation_name=f"获取节点详情(uuid={entity_uuid[:8]}...)"
            )
            if not node:
                return None

            edges = self.get_node_edges(entity_uuid)
            # All nodes of the graph are loaded to describe the neighbours.
            all_nodes = self.get_all_nodes(graph_id)
            node_map = {n["uuid"]: n for n in all_nodes}
            related_edges, related_nodes = self._build_relations(
                entity_uuid, edges, node_map
            )

            return EntityNode(
                uuid=self._object_uuid(node),
                name=node.name or "",
                labels=node.labels or [],
                summary=node.summary or "",
                attributes=node.attributes or {},
                related_edges=related_edges,
                related_nodes=related_nodes,
            )
        except Exception as e:
            logger.error(f"获取实体 {entity_uuid} 失败: {str(e)}")
            return None

    def get_entities_by_type(
        self,
        graph_id: str,
        entity_type: str,
        enrich_with_edges: bool = True
    ) -> List["EntityNode"]:
        """Return every entity of one type.

        Args:
            graph_id: graph ID.
            entity_type: entity type to keep (e.g. "Student", "PublicFigure").
            enrich_with_edges: whether to attach related edges and nodes.

        Returns:
            The matching entities.
        """
        result = self.filter_defined_entities(
            graph_id=graph_id,
            defined_entity_types=[entity_type],
            enrich_with_edges=enrich_with_edges
        )
        return result.entities