"""
OpenAI Agent Knowledge Extraction Method

Two-stage pipeline:
1. Direct LLM Method -> Generate Agent Graph
2. Validation Agent -> Validate the generated graph only (no trace data)
"""

import asyncio
import concurrent.futures
import json
import logging
import os
import time
from datetime import datetime
from typing import Any, Coroutine, Dict, List, Optional, Set, Tuple, TypeVar

from agents import Agent, Runner, function_tool
from pydantic import BaseModel

from agentgraph.methods.baseline.base_method import BaseKnowledgeExtractionMethod
from agentgraph.methods.baseline.pydantic_method import PydanticKnowledgeExtractor
from agentgraph.shared.models.direct_based import Entity, Relation

# Import graph builder prompts for validation agent
from evaluation.knowledge_extraction.utils.prompts import (
    GRAPH_BUILDER_INSTRUCTION_PROMPT,
    GRAPH_BUILDER_SYSTEM_PROMPT,
)


class AdditionalGraphData(BaseModel):
    """New entities and relations to add to existing knowledge graph"""

    entities: List[Entity] = []
    relations: List[Relation] = []


# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set higher log levels for noisy libraries
logging.getLogger("openai").setLevel(logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)

_T = TypeVar("_T")

# Relation types (compared upper-cased) that count as task-to-task links.
TASK_LINK_RELATION_TYPES = frozenset({"NEXT", "SUBTASK"})

# Entry appended to ``missing_required_types`` when tasks are not all linked.
_TASK_CONNECTIVITY_REQUIREMENT = (
    "Connected Tasks (all tasks must be linked together via 'NEXT' or 'SUBTASK')"
)

# System instructions for the validation agent (evaluated once at import time).
_VALIDATION_AGENT_INSTRUCTIONS = f"""You are a Knowledge Graph Validator and Improver.

{GRAPH_BUILDER_SYSTEM_PROMPT}

Your job is to validate a knowledge graph and OUTPUT ONLY NEW entities/relations to add:

CRITICAL RULES:
- You will ONLY receive the generated knowledge graph JSON (no trace data)
- You will NOT do entity/relation extraction from scratch
- NEVER DELETE OR MODIFY existing entities or relations
- OUTPUT ONLY the NEW entities and relations that need to be ADDED
- Do NOT repeat existing entities or relations in your output

VALIDATION & IMPROVEMENT PROCESS:
1. Use validate_knowledge_graph tool to check the provided graph
2. If validation fails, IDENTIFY what's missing (do not remove anything)
3. ADD missing required entity types (Agent, Task, Input, Output) if needed
4. ADD missing connections between isolated entities
5. ENSURE ALL TASKS ARE LINKED: Add 'NEXT' or 'SUBTASK' relations to connect isolated task groups
6. APPEND new relations to connect disconnected entities
7. Return ONLY the NEW entities and relations to be added

CRITICAL TASK CONNECTIVITY REQUIREMENT:
- ALL Task entities must be linked together via 'NEXT' or 'SUBTASK' relations
- If validation shows disconnected_tasks, output NEW 'NEXT' relations between them
- Use relation type 'NEXT' to connect Task_A → Task_B in sequential order
- Use relation type 'SUBTASK' for parent-child Task relationships
- Keep adding relations until task_connectivity_valid becomes true

OUTPUT FORMAT:
- entities: [list of new entities to add]
- relations: [list of new relations to connect tasks]

{GRAPH_BUILDER_INSTRUCTION_PROMPT}

Focus on outputting ONLY the missing pieces that need to be added, never the existing ones."""

# Per-iteration user message; filled with ``str.format(kg_json=...)``. The
# template itself must contain no other braces.
_VALIDATION_REQUEST_TEMPLATE = """Please validate this knowledge graph and output ONLY NEW additions needed.

Current Graph:
{kg_json}

VALIDATION INSTRUCTIONS:
1. Use validate_knowledge_graph tool to check what's wrong
2. If task_connectivity_valid is FALSE, add relations between disconnected tasks
3. Fix other validation issues
4. Return ONLY the NEW entities/relations to add (not existing ones)

Output format: entities: [], relations: []"""


def _upper_type(raw_type: Any) -> str:
    """Return ``raw_type`` upper-cased if it is a string (incl. str enums), else ``""``."""
    return raw_type.upper() if isinstance(raw_type, str) else ""


def check_task_connectivity(kg_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Check if all tasks are connected via 'next' or 'subtask' relations.

    Relation types are compared case-insensitively and edges are treated as
    undirected: a NEXT/SUBTASK relation in either direction links its two
    Task endpoints. Relations touching non-Task entities are ignored.

    Args:
        kg_dict: Knowledge graph dictionary whose ``entities`` and
            ``relations`` entries are plain dicts.

    Returns:
        Dictionary with connectivity check results. ``disconnected_tasks``
        lists every task outside the largest connected component, or every
        task when no task-to-task links exist at all.
    """
    entities = kg_dict.get("entities") or []
    relations = kg_dict.get("relations") or []

    # Find all Task entities
    task_entities = [e for e in entities if e.get("type") == "Task"]
    task_ids = {e.get("id") for e in task_entities}

    if len(task_entities) <= 1:
        return {
            "task_connectivity_valid": True,
            "disconnected_tasks": [],
            "total_tasks": len(task_entities),
            "message": "Single or no tasks - connectivity not required",
        }

    # Undirected adjacency restricted to task-to-task NEXT/SUBTASK edges.
    adjacency: Dict[Any, Set[Any]] = {task_id: set() for task_id in task_ids}
    edges: Set[frozenset] = set()
    for relation in relations:
        source_id = relation.get("source")
        target_id = relation.get("target")
        if (
            _upper_type(relation.get("type")) in TASK_LINK_RELATION_TYPES
            and source_id in task_ids
            and target_id in task_ids
        ):
            adjacency[source_id].add(target_id)
            adjacency[target_id].add(source_id)
            edges.add(frozenset((source_id, target_id)))

    # Sort by str (ids may have mixed types) so output order is stable
    # regardless of set iteration order.
    ordered_ids = sorted(task_ids, key=str)

    if not edges:
        return {
            "task_connectivity_valid": len(task_ids) <= 1,
            "disconnected_tasks": ordered_ids if len(task_ids) > 1 else [],
            "total_tasks": len(task_entities),
            "message": f"No task connections found. All {len(task_ids)} tasks are disconnected.",
        }

    # Connected components via iterative DFS (no recursion-depth limit).
    visited: Set[Any] = set()
    components: List[Set[Any]] = []
    for start in ordered_ids:
        if start in visited:
            continue
        visited.add(start)
        component = {start}
        stack = [start]
        while stack:
            node = stack.pop()
            for neighbor in adjacency[node]:
                if neighbor not in visited:
                    visited.add(neighbor)
                    component.add(neighbor)
                    stack.append(neighbor)
        components.append(component)

    # All tasks should be linked together
    is_connected = len(components) == 1
    disconnected_tasks: List[Any] = []
    if not is_connected:
        # Report tasks outside the largest component (first one wins ties).
        largest = max(components, key=len)
        for component in components:
            if component is not largest:
                disconnected_tasks.extend(sorted(component, key=str))

    return {
        "task_connectivity_valid": is_connected,
        "disconnected_tasks": disconnected_tasks,
        "total_tasks": len(task_entities),
        "connected_components": len(components),
        "task_connections_found": len(edges),
        "message": f"Found {len(components)} connected components. Tasks should be in 1 component.",
    }


def _validate_with_connectivity(kg_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Run ``check_success`` plus the task-connectivity check on one graph.

    This is the single definition of "valid" shared by the agent tool and the
    improvement loop, so the loop cannot accept a graph the tool would still
    reject for disconnected tasks.

    Args:
        kg_dict: Knowledge graph dictionary.

    Returns:
        The ``check_success`` result with ``validation.task_connectivity``
        added and ``success`` forced to False when tasks are disconnected.
    """
    result = BaseKnowledgeExtractionMethod.check_success(kg_dict)
    validation = result.setdefault("validation", {})
    task_connectivity = check_task_connectivity(kg_dict)
    validation["task_connectivity"] = task_connectivity
    if not task_connectivity["task_connectivity_valid"]:
        result["success"] = False
        validation.setdefault("missing_required_types", []).append(
            _TASK_CONNECTIVITY_REQUIREMENT
        )
    return result


# NOTE: the docstring below is exported as the tool's description/argument
# docs for the validation agent, so change its wording deliberately.
@function_tool
def validate_knowledge_graph(kg_data: str) -> str:
    """
    Validate knowledge graph using BaseKnowledgeExtractionMethod.check_success.

    Args:
        kg_data: Knowledge graph data as JSON string

    Returns:
        Validation result as JSON string with success status and details
    """
    logger.info("Validation Agent Tool Called: Validating knowledge graph")
    try:
        kg_dict = json.loads(kg_data)
        if not isinstance(kg_dict, dict):
            raise ValueError("kg_data must be a JSON object with 'entities' and 'relations'")
        result = _validate_with_connectivity(kg_dict)
        logger.info("Validation Result: %s", result)
        return json.dumps(result)
    except Exception as e:
        # Hand the failure back to the agent as a structured result rather
        # than raising out of the tool.
        result = {
            "success": False,
            "validation": {
                "error": str(e),
                "entity_counts": {},
                "relation_counts": {},
                "missing_required_types": [],
                "invalid_relations": [],
                "isolated_entities": [],
                "total_entities": 0,
                "total_relations": 0,
                "task_connectivity": {
                    "task_connectivity_valid": False,
                    "disconnected_tasks": [],
                    "total_tasks": 0,
                },
            },
        }
        logger.warning("Validation Error: %s", result)
        return json.dumps(result)


async def create_validation_agent(model: Optional[str] = None) -> Agent:
    """
    Create validation + improvement agent that receives generated graph and fixes it.

    Args:
        model: Model name for the agent; ``None`` keeps the Agents SDK default.

    Returns:
        An ``Agent`` that can call ``validate_knowledge_graph`` and whose final
        output is an ``AdditionalGraphData`` holding only new items.
    """
    return Agent(
        name="Graph Validator & Improver",
        instructions=_VALIDATION_AGENT_INSTRUCTIONS,
        model=model,
        tools=[validate_knowledge_graph],
        output_type=AdditionalGraphData,  # Return only new entities/relations to add
    )


def _to_plain_dict(item: Any) -> Any:
    """Convert a pydantic model (v2 or v1) to a plain dict; return anything else unchanged."""
    if hasattr(item, "model_dump"):
        # mode="json" makes any enum/datetime fields JSON-safe so the merged
        # graph can be re-serialised with json.dumps on the next iteration.
        return item.model_dump(mode="json")
    if hasattr(item, "dict"):
        return item.dict()
    return item


def _extract_additions(final_output: Any) -> Optional[Dict[str, List[Any]]]:
    """
    Normalise the agent's final output to ``{"entities": [...], "relations": [...]}``.

    Returns None when the output is neither a pydantic model nor a dict.
    """
    if hasattr(final_output, "model_dump") or hasattr(final_output, "dict"):
        data = _to_plain_dict(final_output)
    elif isinstance(final_output, dict):
        data = final_output
    else:
        return None
    return {
        "entities": [_to_plain_dict(e) for e in data.get("entities") or []],
        "relations": [_to_plain_dict(r) for r in data.get("relations") or []],
    }


def _relation_signature(relation: Dict[str, Any]) -> Optional[Tuple[Any, str, Any]]:
    """Return (source, upper-cased type, target), or None if an endpoint is missing."""
    source_id = relation.get("source")
    target_id = relation.get("target")
    if source_id is None or target_id is None:
        return None
    return (source_id, _upper_type(relation.get("type")), target_id)


def _merge_additions(current_kg: Dict[str, Any], additions: Dict[str, List[Any]]) -> Tuple[int, int]:
    """
    Append new entities/relations to ``current_kg`` in place, skipping repeats.

    An entity is skipped when its ``id`` is already present. A relation is
    skipped when its (source, type, target) triple is already present. Items
    that are not dicts are appended unchanged.

    Returns:
        (entities_added, relations_added) actually appended.
    """
    entities = current_kg.setdefault("entities", [])
    relations = current_kg.setdefault("relations", [])

    entity_ids = {e.get("id") for e in entities if isinstance(e, dict)}
    entities_added = 0
    for entity in additions.get("entities") or []:
        entity_id = entity.get("id") if isinstance(entity, dict) else None
        if entity_id is not None and entity_id in entity_ids:
            continue
        entity_ids.add(entity_id)
        entities.append(entity)
        entities_added += 1

    signatures = {
        _relation_signature(r) for r in relations if isinstance(r, dict)
    }
    signatures.discard(None)
    relations_added = 0
    for relation in additions.get("relations") or []:
        if isinstance(relation, dict):
            signature = _relation_signature(relation)
            if signature is not None:
                if signature in signatures:
                    continue
                signatures.add(signature)
        relations.append(relation)
        relations_added += 1

    return entities_added, relations_added


async def validate_and_improve_graph(
    kg_dict: Dict[str, Any],
    max_iterations: int = 5,
    max_turns: int = 5,
    model: Optional[str] = None,
) -> Tuple[Dict[str, Any], Dict[str, Any], bool]:
    """
    Validate and improve knowledge graph using validation agent with external iterative loop.

    Each iteration sends the current graph to the validation agent and merges
    the additions it proposes; existing entity ids and relation triples are
    skipped. The loop stops early when the agent proposes nothing new, or when
    the graph passes validation (``check_success`` plus task connectivity). If
    an agent run raises, the loop stops and the graph built so far is kept.
    ``kg_dict`` itself is not modified.

    Args:
        kg_dict: Generated knowledge graph dictionary
        max_iterations: Maximum number of agent rounds.
        max_turns: ``max_turns`` passed to each ``Runner.run`` call.
        model: Model for the validation agent (``None`` = SDK default).

    Returns:
        Tuple of (improved_kg_dict, validation_result, is_valid)
    """
    validation_agent = await create_validation_agent(model=model)

    # Copy the top-level dict and both lists so merging never mutates kg_dict.
    current_kg = dict(kg_dict)
    current_kg["entities"] = list(kg_dict.get("entities") or [])
    current_kg["relations"] = list(kg_dict.get("relations") or [])

    total_entities_added = 0
    total_relations_added = 0
    iteration_count = 0

    while iteration_count < max_iterations:
        iteration_count += 1
        logger.info("Validation iteration %d/%d", iteration_count, max_iterations)

        request = _VALIDATION_REQUEST_TEMPLATE.format(kg_json=json.dumps(current_kg, indent=2))
        try:
            result = await Runner.run(validation_agent, request, max_turns=max_turns)
        except Exception:
            # A failed round (e.g. turn limit hit, API error) must not discard
            # the graph produced so far.
            logger.exception(
                "Iteration %d: validation agent run failed; keeping current graph",
                iteration_count,
            )
            break

        additions = _extract_additions(result.final_output)
        if additions is None:
            logger.warning("Iteration %d: Agent didn't return proper format", iteration_count)
            additions = {"entities": [], "relations": []}

        entities_added, relations_added = _merge_additions(current_kg, additions)
        if entities_added == 0 and relations_added == 0:
            logger.info("Iteration %d: No additions needed, validation complete", iteration_count)
            break

        total_entities_added += entities_added
        total_relations_added += relations_added
        logger.info(
            "Iteration %d: Added %d entities, %d relations",
            iteration_count, entities_added, relations_added,
        )

        if _validate_with_connectivity(current_kg).get("success", False):
            logger.info("Validation passed after %d iterations!", iteration_count)
            break

    # Final verdict uses the same rules as the agent tool.
    is_valid = bool(_validate_with_connectivity(current_kg).get("success", False))

    original_entities = len(kg_dict.get("entities") or [])
    original_relations = len(kg_dict.get("relations") or [])
    improved_entities = len(current_kg["entities"])
    improved_relations = len(current_kg["relations"])

    validation_result = {
        "iterations": iteration_count,
        "improvement_applied": total_entities_added > 0 or total_relations_added > 0,
        "original_entities": original_entities,
        "improved_entities": improved_entities,
        "original_relations": original_relations,
        "improved_relations": improved_relations,
        "entities_added": total_entities_added,
        "relations_added": total_relations_added,
    }

    logger.info(
        "Final result: %d → %d entities (+%d)",
        original_entities, improved_entities, total_entities_added,
    )
    logger.info(
        "Final result: %d → %d relations (+%d)",
        original_relations, improved_relations, total_relations_added,
    )
    logger.info(
        "Validation %s after %d iterations",
        "PASSED" if is_valid else "FAILED", iteration_count,
    )

    return current_kg, validation_result, is_valid


def _run_coroutine_sync(coro: Coroutine[Any, Any, _T]) -> _T:
    """
    Run ``coro`` to completion from synchronous code.

    ``asyncio.run`` raises RuntimeError if this thread already has a running
    event loop (e.g. Jupyter or an async web handler). In that case the
    coroutine runs on a fresh loop in a worker thread and this call blocks
    until it finishes.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        return asyncio.run(coro)
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        return executor.submit(asyncio.run, coro).result()


class OpenAIAgentKnowledgeExtractor(BaseKnowledgeExtractionMethod):
    """
    Two-stage pipeline:
    1. Direct LLM Method -> Generate Agent Graph
    2. Validation Agent -> Validate generated graph only
    """

    def __init__(self, model: str = "gpt-5-mini", **kwargs):
        """
        Args:
            model: Model name used for both the stage-1 extractor and the
                stage-2 validation agent.
            **kwargs: Forwarded to the base class and the Pydantic extractor.
        """
        super().__init__("openai_agent_method", **kwargs)
        self.model = model

        # Stage 1: Pydantic Hybrid extractor
        self.pydantic_extractor = PydanticKnowledgeExtractor(model=model, mode="hybrid_2_stage", **kwargs)

        # The key is not set here; only warn early when it is absent from the
        # environment.
        if "OPENAI_API_KEY" not in os.environ:
            logger.warning("OPENAI_API_KEY not set in environment")

    def _extract_with_pipeline(self, text: str) -> Dict[str, Any]:
        """
        Two-stage extraction pipeline:
        1. Direct LLM -> Generate graph
        2. Validation Agent -> Validate graph only

        Returns:
            On success ``{"success": True, "kg_data", "validation_passed",
            "validation_details"}``. If stage 1 reports failure, its result is
            returned unchanged. On an exception, ``{"success": False, "error",
            "kg_data"}`` with an empty placeholder graph.
        """
        try:
            logger.info("Stage 1: Generating graph with Pydantic Hybrid Method")
            direct_result = self.pydantic_extractor.process_text(text)

            if not direct_result.get("success", False):
                logger.error("Pydantic Hybrid extraction failed")
                return direct_result

            kg_data = direct_result.get("kg_data") or {}
            logger.info(
                "Pydantic Hybrid generated %d entities and %d relations",
                len(kg_data.get("entities") or []),
                len(kg_data.get("relations") or []),
            )

            # Stage 2: Validate and Improve with Validation Agent (graph only, no trace data)
            logger.info("Stage 2: Validating and improving graph with Validation Agent")
            improved_kg_data, validation_result, is_valid = _run_coroutine_sync(
                validate_and_improve_graph(kg_data, model=self.model)
            )

            improvement_stats = {
                "entities_added": validation_result.get("entities_added", 0),
                "relations_added": validation_result.get("relations_added", 0),
                "iterations": validation_result.get("iterations", 0),
            }

            # Build a new metadata dict: improved_kg_data is a shallow copy of
            # kg_data, so updating "metadata" in place would also mutate the
            # stage-1 result.
            metadata = dict(improved_kg_data.get("metadata") or {})
            metadata.update({
                "method": "openai_agent_pipeline",
                "pipeline_stages": ["pydantic_hybrid", "validation_improvement_agent"],
                "validation_result": validation_result,
                "validation_passed": is_valid,
                "model": self.model,
                "timestamp": datetime.now().isoformat(),
                "improvement_stats": improvement_stats,
            })
            improved_kg_data["metadata"] = metadata

            logger.info("Validation %s", "PASSED" if is_valid else "FAILED")
            logger.info(
                "Graph improved: +%d entities, +%d relations in %d iterations",
                improvement_stats["entities_added"],
                improvement_stats["relations_added"],
                improvement_stats["iterations"],
            )

            return {
                "success": True,
                "kg_data": improved_kg_data,  # Return improved graph
                "validation_passed": is_valid,
                "validation_details": validation_result,
            }

        except Exception as e:
            logger.exception("Pipeline extraction failed: %s", e)
            return {
                "success": False,
                "error": str(e),
                "kg_data": {
                    "entities": [],
                    "relations": [],
                    "system_name": "Failed Extraction",
                    "system_summary": "Pipeline extraction failed.",
                    "metadata": {"error": str(e), "method": "openai_agent_pipeline"},
                },
            }

    def process_text(self, text: str) -> Dict[str, Any]:
        """
        Process input text using two-stage pipeline.

        Args:
            text: Input text to process

        Returns:
            Dictionary with kg_data, metadata, success, and validation info
        """
        start_time = time.time()

        try:
            logger.info("Processing text with OpenAI Agent Pipeline (length: %d)", len(text))

            # Run two-stage pipeline
            result = self._extract_with_pipeline(text)
            processing_time = time.time() - start_time

            if result.get("success", False):
                kg_data = result.get("kg_data", {})
                return {
                    "success": True,
                    "kg_data": kg_data,
                    "metadata": {
                        "approach": "openai_agent_pipeline",
                        "model": self.model,
                        "method": self.method_name,
                        "processing_time_seconds": processing_time,
                        "entity_count": len(kg_data.get("entities", [])),
                        "relation_count": len(kg_data.get("relations", [])),
                        "validation_passed": result.get("validation_passed", False),
                        "pipeline_stages": ["pydantic_hybrid", "validation_agent"],
                    },
                }

            return {
                "success": False,
                "error": result.get("error", "Unknown error"),
                "kg_data": result.get("kg_data", {"entities": [], "relations": []}),
                "metadata": {
                    "approach": "openai_agent_pipeline",
                    "model": self.model,
                    "method": self.method_name,
                    "processing_time_seconds": processing_time,
                    "error": result.get("error", "Unknown error"),
                },
            }

        except Exception as e:
            processing_time = time.time() - start_time
            logger.exception("Error in OpenAI Agent pipeline: %s", e)
            return {
                "success": False,
                "error": str(e),
                "kg_data": {"entities": [], "relations": []},
                "metadata": {
                    "approach": "openai_agent_pipeline",
                    "model": self.model,
                    "method": self.method_name,
                    "processing_time_seconds": processing_time,
                    "error": str(e),
                },
            }

    def extract_knowledge_graph(self, trace_data: str) -> Dict[str, Any]:
        """
        Extract knowledge graph using two-stage pipeline.

        Args:
            trace_data: Agent trace data as JSON string

        Returns:
            Dictionary with entities and relations; an empty graph on failure.
        """
        try:
            logger.info("OpenAI Agent Pipeline called with trace_data type: %s", type(trace_data))
            if isinstance(trace_data, str):
                logger.info("trace_data length: %d", len(trace_data))

            # Process the trace data with pipeline
            result = self.process_text(trace_data)

            # Return just the knowledge graph data
            if result.get("success", False):
                return result.get("kg_data", {"entities": [], "relations": []})
            # Return empty knowledge graph on failure
            return {"entities": [], "relations": []}

        except Exception as e:
            logger.exception("Error in extract_knowledge_graph: %s", e)
            return {"entities": [], "relations": []}