Spaces:
Running
Running
| """ | |
| OpenAI Agent Knowledge Extraction Method | |
| Two-stage pipeline: | |
| 1. Direct LLM Method -> Generate Agent Graph | |
| 2. Validation Agent -> Validate the generated graph only (no trace data) | |
| """ | |
| import asyncio | |
| import json | |
| import logging | |
| import os | |
| import time | |
| from datetime import datetime | |
| from typing import Any, Dict, List, Tuple | |
| from agents import Agent, Runner, function_tool | |
| from pydantic import BaseModel | |
| from agentgraph.methods.baseline.base_method import BaseKnowledgeExtractionMethod | |
| from agentgraph.methods.baseline.pydantic_method import PydanticKnowledgeExtractor | |
| from agentgraph.shared.models.direct_based import Entity, Relation | |
| # Import graph builder prompts for validation agent | |
| from evaluation.knowledge_extraction.utils.prompts import ( | |
| GRAPH_BUILDER_INSTRUCTION_PROMPT, | |
| GRAPH_BUILDER_SYSTEM_PROMPT, | |
| ) | |
class AdditionalGraphData(BaseModel):
    """Delta payload produced by the validation agent.

    Holds only the entities and relations that should be appended to an
    already existing knowledge graph. Both lists default to empty, which
    means "nothing to add".
    """

    # Entities to append to the existing graph.
    entities: List[Entity] = []
    # Relations to append to the existing graph.
    relations: List[Relation] = []
# Install a default root handler at INFO so pipeline progress is visible.
logging.basicConfig(level=logging.INFO)

# Third-party HTTP/LLM client loggers are chatty at INFO; only surface
# warnings and above from them.
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("openai").setLevel(logging.WARNING)

# Module-level logger used by everything below.
logger = logging.getLogger(__name__)
def check_task_connectivity(kg_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Check if all tasks are connected via 'NEXT' or 'SUBTASK' relations.

    Task entities are treated as nodes of an undirected graph whose edges are
    the task-to-task relations of type NEXT or SUBTASK (case-insensitive).
    The graph is valid when every task ends up in a single connected component.

    Args:
        kg_dict: Knowledge graph dictionary with optional "entities" and
            "relations" lists of dicts.

    Returns:
        Dictionary with connectivity check results. It always contains
        "task_connectivity_valid", "disconnected_tasks", "total_tasks" and
        "message"; when at least one task-to-task relation exists it also
        contains "connected_components" and "task_connections_found".
        "disconnected_tasks" is listed in entity order so the result is
        deterministic across runs.
    """
    entities = kg_dict.get("entities", [])
    relations = kg_dict.get("relations", [])

    # Find all Task entities
    task_entities = [e for e in entities if e.get("type") == "Task"]
    # dict.fromkeys de-duplicates ids while keeping first-seen (entity) order.
    task_ids = list(dict.fromkeys(e.get("id") for e in task_entities))

    if len(task_entities) <= 1:
        return {
            "task_connectivity_valid": True,
            "disconnected_tasks": [],
            "total_tasks": len(task_entities),
            "message": "Single or no tasks - connectivity not required"
        }

    # Collect task-to-task NEXT/SUBTASK relations as undirected edges.
    adjacency = {task_id: set() for task_id in task_ids}
    edges = set()
    for relation in relations:
        # `or ""` guards against an explicit None type value.
        rel_type = str(relation.get("type") or "").upper()
        if rel_type not in ("NEXT", "SUBTASK"):
            continue
        source_id = relation.get("source")
        target_id = relation.get("target")
        if source_id in adjacency and target_id in adjacency:
            # frozenset collapses A->B and B->A into one undirected edge.
            edges.add(frozenset((source_id, target_id)))
            adjacency[source_id].add(target_id)
            adjacency[target_id].add(source_id)

    if not edges:
        return {
            "task_connectivity_valid": len(task_ids) <= 1,
            "disconnected_tasks": list(task_ids) if len(task_ids) > 1 else [],
            "total_tasks": len(task_entities),
            "message": f"No task connections found. All {len(task_ids)} tasks are disconnected."
        }

    # Find connected components with an explicit stack: recursion would hit
    # Python's recursion limit on long task chains.
    visited = set()
    connected_components = []
    for start in task_ids:
        if start in visited:
            continue
        visited.add(start)
        component = {start}
        stack = [start]
        while stack:
            node = stack.pop()
            for neighbor in adjacency[node]:
                if neighbor not in visited:
                    visited.add(neighbor)
                    component.add(neighbor)
                    stack.append(neighbor)
        connected_components.append(component)

    # All tasks should be linked together
    is_connected = len(connected_components) == 1
    disconnected_tasks = []
    if not is_connected:
        # Everything outside the largest component is reported as
        # disconnected; ties go to the component that appears first.
        largest_component = max(connected_components, key=len)
        disconnected_tasks = [t for t in task_ids if t not in largest_component]

    return {
        "task_connectivity_valid": is_connected,
        "disconnected_tasks": disconnected_tasks,
        "total_tasks": len(task_entities),
        "connected_components": len(connected_components),
        "task_connections_found": len(edges),
        "message": f"Found {len(connected_components)} connected components. Tasks should be in 1 component."
    }
def validate_knowledge_graph(kg_data: str) -> str:
    """
    Run structural and task-connectivity validation on a serialized graph.

    The graph is parsed from JSON, checked with
    BaseKnowledgeExtractionMethod.check_success, and then checked with
    check_task_connectivity. A connectivity failure flips "success" to False
    and is appended to the "missing_required_types" list of the report.

    Args:
        kg_data: Knowledge graph data as JSON string

    Returns:
        Validation result as JSON string with success status and details.
        Any exception (including invalid JSON) yields a failure report whose
        "validation.error" holds the exception text.
    """
    print("Validation Agent Tool Called: Validating knowledge graph")
    try:
        graph = json.loads(kg_data)
        outcome = BaseKnowledgeExtractionMethod.check_success(graph)

        # Attach the task connectivity findings to the validation report.
        connectivity = check_task_connectivity(graph)
        report = outcome["validation"]
        report["task_connectivity"] = connectivity

        # Disconnected tasks make the whole graph invalid.
        if not connectivity["task_connectivity_valid"]:
            outcome["success"] = False
            report.setdefault("missing_required_types", []).append(
                "Connected Tasks (all tasks must be linked together via 'NEXT' or 'SUBTASK')"
            )

        print("Validation Result: ", outcome)
        return json.dumps(outcome)
    except Exception as e:
        # Report the failure in the same shape as a regular validation result.
        empty_connectivity = {
            "task_connectivity_valid": False,
            "disconnected_tasks": [],
            "total_tasks": 0,
        }
        failure_report = {
            "error": str(e),
            "entity_counts": {},
            "relation_counts": {},
            "missing_required_types": [],
            "invalid_relations": [],
            "isolated_entities": [],
            "total_entities": 0,
            "total_relations": 0,
            "task_connectivity": empty_connectivity,
        }
        outcome = {"success": False, "validation": failure_report}
        print("Validation Error: ", outcome)
        return json.dumps(outcome)
async def create_validation_agent() -> Agent:
    """
    Create the validation + improvement agent.

    The agent receives a generated knowledge graph (no trace data), may call
    the validate_knowledge_graph tool, and returns an AdditionalGraphData
    object containing only the entities/relations that should be added.

    Returns:
        Configured Agent instance.
    """
    # The Agents SDK expects tool objects, not bare Python functions, so wrap
    # the validator with function_tool. The callable() guard keeps this
    # working if validate_knowledge_graph is ever decorated at its definition
    # (a FunctionTool is not callable).
    validation_tool = (
        function_tool(validate_knowledge_graph)
        if callable(validate_knowledge_graph)
        else validate_knowledge_graph
    )
    return Agent(
        name="Graph Validator & Improver",
        instructions=f"""You are a Knowledge Graph Validator and Improver.
{GRAPH_BUILDER_SYSTEM_PROMPT}
Your job is to validate a knowledge graph and OUTPUT ONLY NEW entities/relations to add:
CRITICAL RULES:
- You will ONLY receive the generated knowledge graph JSON (no trace data)
- You will NOT do entity/relation extraction from scratch
- NEVER DELETE OR MODIFY existing entities or relations
- OUTPUT ONLY the NEW entities and relations that need to be ADDED
- Do NOT repeat existing entities or relations in your output
VALIDATION & IMPROVEMENT PROCESS:
1. Use validate_knowledge_graph tool to check the provided graph
2. If validation fails, IDENTIFY what's missing (do not remove anything)
3. ADD missing required entity types (Agent, Task, Input, Output) if needed
4. ADD missing connections between isolated entities
5. ENSURE ALL TASKS ARE LINKED: Add 'NEXT' or 'SUBTASK' relations to connect isolated task groups
6. APPEND new relations to connect disconnected entities
7. Return ONLY the NEW entities and relations to be added
CRITICAL TASK CONNECTIVITY REQUIREMENT:
- ALL Task entities must be linked together via 'NEXT' or 'SUBTASK' relations
- If validation shows disconnected_tasks, output NEW 'NEXT' relations between them
- Use relation type 'NEXT' to connect Task_A → Task_B in sequential order
- Use relation type 'SUBTASK' for parent-child Task relationships
- Keep adding relations until task_connectivity_valid becomes true
OUTPUT FORMAT:
- entities: [list of new entities to add]
- relations: [list of new relations to connect tasks]
{GRAPH_BUILDER_INSTRUCTION_PROMPT}
Focus on outputting ONLY the missing pieces that need to be added, never the existing ones.""",
        tools=[validation_tool],
        output_type=AdditionalGraphData  # Return only new entities/relations to add
    )
def _to_plain_dict(item: Any) -> Any:
    """Convert a pydantic model (v2 or v1 API) to a dict; return anything else unchanged."""
    if hasattr(item, "model_dump"):
        return item.model_dump()
    if hasattr(item, "dict"):
        return item.dict()
    return item


def _graph_is_valid(kg: Dict[str, Any]) -> bool:
    """Apply the same acceptance criteria the validate_knowledge_graph tool reports to the agent."""
    structural_ok = BaseKnowledgeExtractionMethod.check_success(kg).get("success", False)
    connectivity_ok = check_task_connectivity(kg)["task_connectivity_valid"]
    return bool(structural_ok and connectivity_ok)


async def validate_and_improve_graph(kg_dict: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any], bool]:
    """
    Validate and improve knowledge graph using validation agent with external iterative loop.

    Each iteration sends the current graph to the validation agent, appends
    whatever new entities/relations it returns, and re-validates. The loop
    stops when the agent proposes nothing, when the graph validates, or after
    five iterations. Validity means check_success passes AND all tasks are
    connected, matching what the agent's own validation tool enforces.

    Args:
        kg_dict: Generated knowledge graph dictionary. It is not modified;
            additions are applied to a shallow copy with freshly built lists.

    Returns:
        Tuple of (improved_kg_dict, validation_result, is_valid) where
        validation_result holds iteration and before/after count statistics.

    Raises:
        Whatever Runner.run or the validation helpers raise; errors are not
        caught here.
    """
    validation_agent = await create_validation_agent()
    current_kg = kg_dict.copy()
    total_entities_added = 0
    total_relations_added = 0
    iteration_count = 0
    max_iterations = 5

    while iteration_count < max_iterations:
        iteration_count += 1
        logger.info(f"Validation iteration {iteration_count}/{max_iterations}")

        # Convert current graph to JSON for validation
        kg_json = json.dumps(current_kg, indent=2)
        result = await Runner.run(
            validation_agent,
            f"""Please validate this knowledge graph and output ONLY NEW additions needed.
Current Graph:
{kg_json}
VALIDATION INSTRUCTIONS:
1. Use validate_knowledge_graph tool to check what's wrong
2. If task_connectivity_valid is FALSE, add relations between disconnected tasks
3. Fix other validation issues
4. Return ONLY the NEW entities/relations to add (not existing ones)
Output format: entities: [], relations: []""",
            max_turns=5
        )

        # Normalise the agent output to a plain dict
        additional_dict = _to_plain_dict(result.final_output)
        if not isinstance(additional_dict, dict):
            # Fallback to no additions
            logger.warning(f"Iteration {iteration_count}: Agent didn't return proper format")
            additional_dict = {"entities": [], "relations": []}

        # `or []` tolerates an explicit None; items may still be model objects
        # when the agent output was a dict, so convert each one.
        new_entities = [_to_plain_dict(e) for e in additional_dict.get("entities") or []]
        new_relations = [_to_plain_dict(r) for r in additional_dict.get("relations") or []]

        if not new_entities and not new_relations:
            logger.info(f"Iteration {iteration_count}: No additions needed, validation complete")
            break

        # Build new lists instead of extending in place so the caller's
        # original lists (shared through the shallow copy) stay untouched.
        if new_entities:
            current_kg["entities"] = (current_kg.get("entities") or []) + new_entities
            total_entities_added += len(new_entities)
        if new_relations:
            current_kg["relations"] = (current_kg.get("relations") or []) + new_relations
            total_relations_added += len(new_relations)

        logger.info(f"Iteration {iteration_count}: Added {len(new_entities)} entities, {len(new_relations)} relations")

        # Stop as soon as the graph satisfies every validation rule
        if _graph_is_valid(current_kg):
            logger.info(f"Validation passed after {iteration_count} iterations!")
            break

    # Final validation result
    validation_result = {
        "iterations": iteration_count,
        "improvement_applied": total_entities_added > 0 or total_relations_added > 0,
        "original_entities": len(kg_dict.get("entities", [])),
        "improved_entities": len(current_kg.get("entities", [])),
        "original_relations": len(kg_dict.get("relations", [])),
        "improved_relations": len(current_kg.get("relations", [])),
        "entities_added": total_entities_added,
        "relations_added": total_relations_added
    }

    # Final validation check (also covers the paths that left the loop
    # without re-validating)
    is_valid = _graph_is_valid(current_kg)

    logger.info(
        f"Final result: {validation_result['original_entities']} → "
        f"{validation_result['improved_entities']} entities (+{validation_result['entities_added']})"
    )
    logger.info(
        f"Final result: {validation_result['original_relations']} → "
        f"{validation_result['improved_relations']} relations (+{validation_result['relations_added']})"
    )
    logger.info(f"Validation {'PASSED' if is_valid else 'FAILED'} after {iteration_count} iterations")

    return current_kg, validation_result, is_valid
class OpenAIAgentKnowledgeExtractor(BaseKnowledgeExtractionMethod):
    """
    Two-stage pipeline:
    1. Pydantic hybrid extractor -> Generate Agent Graph
    2. Validation Agent -> Validate and extend the generated graph only (no trace data)
    """

    def __init__(self, model: str = "gpt-5-mini", **kwargs):
        """
        Args:
            model: Model name forwarded to the stage-1 Pydantic extractor and
                recorded in the result metadata.
            **kwargs: Forwarded to both the base class and the Pydantic extractor.
        """
        super().__init__("openai_agent_method", **kwargs)
        self.model = model
        # Initialize Pydantic Hybrid extractor
        self.pydantic_extractor = PydanticKnowledgeExtractor(model=model, mode="hybrid_2_stage", **kwargs)
        # Only warn here; the key itself is read from the environment by the client libraries.
        if "OPENAI_API_KEY" not in os.environ:
            logger.warning("OPENAI_API_KEY not set in environment")

    @staticmethod
    def _run_blocking(coro):
        """Run *coro* to completion from synchronous code and return its result."""
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No event loop is running in this thread, so asyncio.run can own one.
            return asyncio.run(coro)

        # A loop is already running in this thread (notebook, async web
        # server, ...). asyncio.run would raise RuntimeError here, so run the
        # coroutine on a fresh loop in a worker thread and wait for it.
        from concurrent.futures import ThreadPoolExecutor

        with ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, coro).result()

    def _extract_with_pipeline(self, text: str) -> Dict[str, Any]:
        """
        Two-stage extraction pipeline:
        1. Pydantic hybrid extractor -> Generate graph
        2. Validation Agent -> Validate and extend graph only

        Args:
            text: Input text to extract from.

        Returns:
            On success: dict with "success", "kg_data", "validation_passed"
            and "validation_details". If stage 1 fails its result is returned
            unchanged. On any exception: a failure dict with an empty graph.
        """
        try:
            logger.info("Stage 1: Generating graph with Pydantic Hybrid Method")
            # Stage 1: Generate graph using Pydantic Hybrid
            direct_result = self.pydantic_extractor.process_text(text)
            if not direct_result.get("success", False):
                logger.error("Pydantic Hybrid extraction failed")
                return direct_result

            kg_data = direct_result.get("kg_data", {})
            logger.info(f"Pydantic Hybrid generated {len(kg_data.get('entities', []))} entities and {len(kg_data.get('relations', []))} relations")

            # Stage 2: Validate and Improve with Validation Agent (graph only, no trace data)
            logger.info("Stage 2: Validating and improving graph with Validation Agent")
            improved_kg_data, validation_result, is_valid = self._run_blocking(
                validate_and_improve_graph(kg_data)
            )

            improvement_stats = {
                "entities_added": validation_result.get("entities_added", 0),
                "relations_added": validation_result.get("relations_added", 0),
                "iterations": validation_result.get("iterations", 0)
            }

            # Work on a copy: the improved graph is a shallow copy of the
            # stage-1 graph, so updating its metadata dict in place would
            # also rewrite the stage-1 result's metadata.
            metadata = dict(improved_kg_data.get("metadata") or {})
            metadata.update({
                "method": "openai_agent_pipeline",
                "pipeline_stages": ["pydantic_hybrid", "validation_improvement_agent"],
                "validation_result": validation_result,
                "validation_passed": is_valid,
                "model": self.model,
                "timestamp": datetime.now().isoformat(),
                "improvement_stats": improvement_stats
            })
            improved_kg_data["metadata"] = metadata

            logger.info(f"Validation {'PASSED' if is_valid else 'FAILED'}")
            entities_added = improvement_stats["entities_added"]
            relations_added = improvement_stats["relations_added"]
            iterations = improvement_stats["iterations"]
            logger.info(f"Graph improved: +{entities_added} entities, +{relations_added} relations in {iterations} iterations")

            return {
                "success": True,
                "kg_data": improved_kg_data,  # Return improved graph
                "validation_passed": is_valid,
                "validation_details": validation_result
            }
        except Exception as e:
            # logger.exception keeps the traceback, which the plain message loses.
            logger.exception(f"Pipeline extraction failed: {e}")
            return {
                "success": False,
                "error": str(e),
                "kg_data": {
                    "entities": [],
                    "relations": [],
                    "system_name": "Failed Extraction",
                    "system_summary": "Pipeline extraction failed.",
                    "metadata": {"error": str(e), "method": "openai_agent_pipeline"},
                },
            }

    def process_text(self, text: str) -> Dict[str, Any]:
        """
        Process input text using two-stage pipeline.

        Args:
            text: Input text to process

        Returns:
            Dictionary with "success", "kg_data" and "metadata" (timing,
            counts and validation status); failures additionally carry "error".
        """
        start_time = time.time()
        try:
            logger.info(f"Processing text with OpenAI Agent Pipeline (length: {len(text)})")
            # Run two-stage pipeline
            result = self._extract_with_pipeline(text)
            processing_time = time.time() - start_time

            if result.get("success", False):
                kg_data = result.get("kg_data", {})
                entity_count = len(kg_data.get("entities", []))
                relation_count = len(kg_data.get("relations", []))
                return {
                    "success": True,
                    "kg_data": kg_data,
                    "metadata": {
                        "approach": "openai_agent_pipeline",
                        "model": self.model,
                        "method": self.method_name,
                        "processing_time_seconds": processing_time,
                        "entity_count": entity_count,
                        "relation_count": relation_count,
                        "validation_passed": result.get("validation_passed", False),
                        "pipeline_stages": ["pydantic_hybrid", "validation_agent"],
                    },
                }

            error_message = result.get("error", "Unknown error")
            return {
                "success": False,
                "error": error_message,
                "kg_data": result.get("kg_data", {"entities": [], "relations": []}),
                "metadata": {
                    "approach": "openai_agent_pipeline",
                    "model": self.model,
                    "method": self.method_name,
                    "processing_time_seconds": processing_time,
                    "error": error_message,
                },
            }
        except Exception as e:
            processing_time = time.time() - start_time
            logger.exception(f"Error in OpenAI Agent pipeline: {e}")
            return {
                "success": False,
                "error": str(e),
                "kg_data": {"entities": [], "relations": []},
                "metadata": {
                    "approach": "openai_agent_pipeline",
                    "model": self.model,
                    "method": self.method_name,
                    "processing_time_seconds": processing_time,
                    "error": str(e),
                },
            }

    def extract_knowledge_graph(self, trace_data: str) -> Dict[str, Any]:
        """
        Extract knowledge graph using two-stage pipeline.

        Args:
            trace_data: Agent trace data as JSON string

        Returns:
            Dictionary with entities and relations; an empty graph
            ({"entities": [], "relations": []}) when extraction fails.
        """
        empty_graph = {"entities": [], "relations": []}
        try:
            logger.info(f"OpenAI Agent Pipeline called with trace_data type: {type(trace_data)}")
            if isinstance(trace_data, str):
                logger.info(f"trace_data length: {len(trace_data)}")

            # Process the trace data with pipeline
            result = self.process_text(trace_data)

            # Return just the knowledge graph data
            if result.get("success", False):
                return result.get("kg_data", empty_graph)

            # Make the failure visible before returning the empty graph
            logger.warning(f"Extraction failed, returning empty graph: {result.get('error', 'Unknown error')}")
            return empty_graph
        except Exception as e:
            logger.exception(f"Error in extract_knowledge_graph: {e}")
            return empty_graph