# (removed scraped page chrome: "Spaces: Running / Running" — not part of the module)
"""
Direct LLM Knowledge Extraction Method

A streamlined approach that uses direct LLM API calls with structured output
instead of the CrewAI framework for better performance and cost efficiency.
Supports both 3-stage (original) and 2-stage (hybrid) processing modes.
"""
import asyncio
import logging
import os
import sys
import time
from asyncio import gather
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

from pydantic_ai import Agent
from pydantic_ai.agent import AgentRunResult
from pydantic_ai.settings import ModelSettings
from pydantic_ai.usage import Usage

# Add the repository root to the path BEFORE any `evaluation.*` imports so they
# resolve when this file is executed as a script. (Previously the prompts
# import ran before this path hack, which would fail outside package context.)
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))

# Import shared prompt templates (schema v3)
from evaluation.knowledge_extraction.utils.prompts import (
    ENTITY_EXTRACTION_INSTRUCTION_PROMPT,
    ENTITY_EXTRACTION_SYSTEM_PROMPT,
    GRAPH_BUILDER_INSTRUCTION_PROMPT,
    GRAPH_BUILDER_SYSTEM_PROMPT,
    RELATION_EXTRACTION_INSTRUCTION_PROMPT,
    RELATION_EXTRACTION_SYSTEM_PROMPT,
)
from evaluation.knowledge_extraction.baselines.base_method import BaseKnowledgeExtractionMethod
from evaluation.knowledge_extraction.utils.models import Entity, KnowledgeGraph, Relation

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Silence noisy HTTP/client libraries
logging.getLogger("openai").setLevel(logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)
async def get_agent_graph_entities(trace_content: str, temperature: float = 0.0) -> AgentRunResult[List[Entity]]:
    """Extract entities from trace content via a structured-output agent.

    The model name is resolved from OPENAI_MODEL_NAME at call time
    (falling back to "gpt-5-mini").
    """
    model_name = os.environ.get("OPENAI_MODEL_NAME", "gpt-5-mini")
    # Shared schema-v3 prompt templates drive both system and user messages.
    extraction_agent = Agent(
        model_name,
        model_settings=ModelSettings(temperature=temperature),
        output_type=List[Entity],
        system_prompt=ENTITY_EXTRACTION_SYSTEM_PROMPT
    )
    prompt = ENTITY_EXTRACTION_INSTRUCTION_PROMPT.format(input_data=trace_content)
    result: AgentRunResult[List[Entity]] = await extraction_agent.run(prompt)
    return result
async def get_agent_graph_relations(
    trace_content: str, entities: Optional[List[Entity]] = None, temperature: float = 0
) -> AgentRunResult[List[Relation]]:
    """Extract relations from trace content, optionally conditioned on known entities.

    When `entities` is provided (and non-empty), the entity list is appended to
    the instruction so the agent can link relations to already-known entities.
    """
    model_name = os.environ.get("OPENAI_MODEL_NAME", "gpt-5-mini")
    template = RELATION_EXTRACTION_INSTRUCTION_PROMPT
    if entities:
        template = template + "\n\nAvailable Entities: {entities}"
        prompt = template.format(input_data=trace_content, entities=entities)
    else:
        prompt = template.format(input_data=trace_content)
    relation_agent = Agent(
        model_name,
        model_settings=ModelSettings(temperature=temperature),
        output_type=List[Relation],
        system_prompt=RELATION_EXTRACTION_SYSTEM_PROMPT
    )
    result: AgentRunResult[List[Relation]] = await relation_agent.run(prompt)
    return result
def remove_duplicate_relations(relations: List[Relation]) -> List[Relation]:
    """Remove duplicate relations, keeping the last occurrence (latest created)."""
    # Dict insertion order preserves each key's first-seen position, while the
    # value overwrite keeps the most recent relation per (source, target, type).
    unique: Dict[Tuple[Any, Any, Any], Relation] = {}
    for rel in relations:
        unique[(rel.source, rel.target, rel.type)] = rel
    return list(unique.values())
def validate_knowledge_graph(kg: KnowledgeGraph) -> KnowledgeGraph:
    """Validate and clean a knowledge graph in place.

    Deduplicates relations, then drops every relation whose source or target id
    does not appear among the graph's entities.

    Args:
        kg: Knowledge graph to validate; its `relations` list is mutated.

    Returns:
        The same KnowledgeGraph instance with only valid, deduplicated relations.
    """
    if not kg.entities or not kg.relations:
        logger.warning("Knowledge graph has no entities or relations")
        return kg
    entity_ids = {entity.id for entity in kg.entities}
    cleaned_relations = remove_duplicate_relations(kg.relations)
    valid_relations = []
    for relation in cleaned_relations:
        if relation.source in entity_ids and relation.target in entity_ids:
            valid_relations.append(relation)
        else:
            logger.warning(f"Removing invalid relation: {relation.source} -> {relation.target} (missing entities)")
    # BUG FIX: previously this assigned `cleaned_relations`, so relations
    # logged as "Removing invalid relation" above were actually kept (and the
    # "relations kept" summary below was inaccurate). Keep only the valid ones.
    kg.relations = valid_relations
    logger.info(f"Validation complete: {len(kg.entities)} entities, "
                f"{len(valid_relations)}/{len(cleaned_relations)} relations kept")
    return kg
async def build_agent_graph(entities: List[Entity], relations: List[Relation], temperature: float = 0.0) -> AgentRunResult[KnowledgeGraph]:
    """Assemble extracted entities and relations into a validated KnowledgeGraph."""
    model_name = os.environ.get("OPENAI_MODEL_NAME", "gpt-5-mini")
    builder = Agent(
        model_name,
        model_settings=ModelSettings(temperature=temperature),
        output_type=KnowledgeGraph,
        system_prompt=GRAPH_BUILDER_SYSTEM_PROMPT
    )
    prompt = GRAPH_BUILDER_INSTRUCTION_PROMPT + "\n\nEntities: " + str(entities) + "\n\nRelations: " + str(relations)
    graph_run: AgentRunResult[KnowledgeGraph] = await builder.run(prompt)
    # Clean the model output (dedup + drop dangling relations) before returning.
    graph_run.output = validate_knowledge_graph(graph_run.output)
    return graph_run
# Hybrid method functions
async def get_hybrid_extraction(trace_content: str, temperature: float = 0.0) -> AgentRunResult[str]:
    """First stage of hybrid method: combined entity and relation extraction (text output).

    Runs a single agent whose prompts merge the entity- and relation-extraction
    templates; the free-text result feeds get_hybrid_validation().
    """
    model = os.environ.get("OPENAI_MODEL_NAME", "gpt-5-mini")
    role = "Knowledge Extraction Specialist"
    goal = "Extract comprehensive entities and relationships from agent system data efficiently"
    system_prompt = f"""You are {role}.
Your goal is: {goal}
{ENTITY_EXTRACTION_SYSTEM_PROMPT}
{RELATION_EXTRACTION_SYSTEM_PROMPT}"""
    # Hybrid extraction instruction (combines both tasks); {input_data}
    # placeholders in the embedded templates are filled by .format() below.
    instruction_template = f"""
{ENTITY_EXTRACTION_INSTRUCTION_PROMPT}
{RELATION_EXTRACTION_INSTRUCTION_PROMPT}
Expected Output: Structured extraction with entities, relations, and preliminary analysis
"""
    extraction_agent = Agent(
        model,
        model_settings=ModelSettings(temperature=temperature),
        # FIX: use `output_type` like every other agent in this module;
        # `result_type` is the legacy/deprecated pydantic_ai parameter name.
        output_type=str,
        system_prompt=system_prompt
    )
    extraction_result: AgentRunResult[str] = await extraction_agent.run(
        instruction_template.format(input_data=trace_content)
    )
    return extraction_result
async def get_hybrid_validation(extraction_text: str, temperature: float = 0.0) -> AgentRunResult[KnowledgeGraph]:
    """Second stage of hybrid method: validation and enhancement (matches original).

    Takes the free-text output of get_hybrid_extraction() and asks a
    structured-output agent to produce a KnowledgeGraph, which is then cleaned
    with validate_knowledge_graph() before being returned.
    """
    # Model name resolved from the environment at call time (set by the extractor).
    model = os.environ.get("OPENAI_MODEL_NAME", "gpt-5-mini")
    role = "Knowledge Graph Validator and Enhancer"
    goal = "Validate, enhance, and structure extracted knowledge into a comprehensive knowledge graph"
    system_prompt = f"""You are {role}.
Your goal is: {goal}
You are a knowledge graph validation and enhancement specialist who ensures
the quality, completeness, and coherence of extracted knowledge graphs. You take raw
extracted entities and relationships and transform them into polished, well-structured
knowledge graphs.
Your expertise includes:
- Validating entity and relationship consistency
- Identifying and filling gaps in knowledge extraction
- Ensuring proper connectivity and graph coherence
- Creating meaningful system summaries and assessments
- Optimizing knowledge graph structure for clarity and usability
You serve as the quality assurance layer that transforms good extractions into
excellent knowledge graphs."""
    # Validation instruction (plain string; {extraction_text} filled via .format below)
    instruction_template = """
Validate, enhance, and structure the extracted knowledge into a comprehensive knowledge graph.
Take the extracted entities and relationships from the previous task and:
1. VALIDATION AND ENHANCEMENT:
- Verify all entities have proper IDs, types, names, and descriptions
- Ensure all relationships use correct predefined types
- Check that every entity connects to at least one other entity
- Fill any gaps in entity descriptions or relationship mappings
- Validate that relationship directions and types are correct
2. CONNECTIVITY OPTIMIZATION:
- Ensure no isolated entities (all must be connected)
- Verify logical flow from inputs through processing to outputs
- Add missing relationships if entities should be connected
- Optimize relationship network for clarity and completeness
3. KNOWLEDGE GRAPH CONSTRUCTION:
- Create descriptive system name (3-7 words)
- Write comprehensive 2-3 sentence system summary explaining purpose, coordination, and value
- Include metadata with timestamp, statistics, and processing information
- Ensure all components are reachable (no isolated subgraphs)
- Validate connectivity: inputs consumed, outputs produced, agents have roles
4. QUALITY ASSURANCE:
- Double-check entity uniqueness and proper categorization
- Verify relationship consistency and logical flow
- Ensure system summary accurately reflects the extracted knowledge
- Validate that the knowledge graph tells a coherent story
EXTRACTION RESULTS FROM PREVIOUS TASK:
{extraction_text}
Expected Output: A complete, validated knowledge graph with entities, relations, and metadata
Output a complete, validated KnowledgeGraph object with entities, relations, system_name,
system_summary, and metadata. Ensure the knowledge graph is comprehensive, accurate,
well-connected, and represents the system effectively.
"""
    validation_agent = Agent(
        model,
        model_settings=ModelSettings(temperature=temperature),
        output_type=KnowledgeGraph,
        system_prompt=system_prompt
    )
    validation_result: AgentRunResult[KnowledgeGraph] = await validation_agent.run(
        instruction_template.format(extraction_text=extraction_text)
    )
    # Drop duplicate/dangling relations before handing the graph to the caller.
    cleaned_kg = validate_knowledge_graph(validation_result.output)
    validation_result.output = cleaned_kg
    return validation_result
async def get_agent_graph(trace_content: str, sequential: bool = False, hybrid: bool = False, temperature: float = 0) -> Tuple[KnowledgeGraph, Usage]:
    """Run the full extraction pipeline and return (knowledge graph, combined usage).

    Modes:
      - hybrid: 2-stage pipeline (free-text extraction -> structured validation).
      - sequential: extract entities first, then relations conditioned on them.
      - default: entity and relation extraction run concurrently, then a final
        graph-building stage combines them.
    """
    usage_total = Usage()
    if hybrid:
        # Stage 1 produces free text; stage 2 turns it into a KnowledgeGraph.
        extraction = await get_hybrid_extraction(trace_content, temperature)
        validated = await get_hybrid_validation(extraction.output, temperature)
        usage_total.incr(extraction.usage())
        usage_total.incr(validated.usage())
        return validated.output, usage_total
    if sequential:
        # Entities first so relation extraction can reference them.
        entity_run = await get_agent_graph_entities(trace_content, temperature)
        relation_run = await get_agent_graph_relations(trace_content, entity_run.output, temperature)
    else:
        # Independent extractions run concurrently for speed.
        entity_run, relation_run = await gather(
            get_agent_graph_entities(trace_content, temperature),
            get_agent_graph_relations(trace_content, temperature=temperature)
        )
    # Final stage: combine extracted pieces into a single validated graph.
    graph_run = await build_agent_graph(entity_run.output, relation_run.output, temperature)
    usage_total.incr(entity_run.usage())
    usage_total.incr(relation_run.usage())
    usage_total.incr(graph_run.usage())
    return graph_run.output, usage_total
class PydanticKnowledgeExtractor(BaseKnowledgeExtractionMethod):
    """Direct LLM knowledge extraction method using pydantic_ai with structured output.

    Supports three modes selected at construction time: parallel 3-stage
    (default), sequential 3-stage, and hybrid 2-stage.
    """

    def __init__(self, model: str = "gpt-5-mini", sequential: bool = False, hybrid: bool = False, temperature: float = 0.0, **kwargs):
        """Configure the extractor and register its method name with the base class.

        Args:
            model: Model name, exported via OPENAI_MODEL_NAME for the agent helpers.
            sequential: Use sequential 3-stage mode (entities first, then relations).
            hybrid: Use hybrid 2-stage mode (takes precedence for method naming).
            temperature: Sampling temperature passed to every agent call.
            **kwargs: Forwarded to BaseKnowledgeExtractionMethod.
        """
        method_name = "pydantic_ai_method"
        if hybrid:
            method_name = "pydantic_hybrid_method"
        elif sequential:
            method_name = "pydantic_sequential_method"
        super().__init__(method_name, **kwargs)
        self.model = model
        self.sequential = sequential
        self.hybrid = hybrid
        self.temperature = temperature
        # The module-level agent helpers read the model name from the environment.
        os.environ["OPENAI_MODEL_NAME"] = model

    def process_text(self, text: str) -> Dict[str, Any]:
        """
        Process input text using pydantic_ai agents.

        Args:
            text: Input text to process

        Returns:
            Dictionary with kg_data, metadata, success, and optional error
        """
        start_time = time.time()
        try:
            mode = "hybrid_2_stage" if self.hybrid else ("sequential_3_stage" if self.sequential else "parallel_3_stage")
            logger.info(f"Processing text with pydantic_ai method in {mode} mode (length: {len(text)})")
            # Extract knowledge graph using async pipeline
            kg_data: KnowledgeGraph
            usage: Usage
            kg_data, usage = asyncio.run(get_agent_graph(text, self.sequential, self.hybrid, self.temperature))
            # Convert to dict format
            kg_dict = kg_data.model_dump()
            processing_time = time.time() - start_time
            # Success is determined by detailed validation.
            # (Removed dead code: a preliminary entity/relation-count success
            # check that was immediately overwritten, and a metadata .update()
            # whose result was later replaced wholesale — the .update() could
            # also crash if model_dump() yielded a non-dict metadata field.)
            validation_result = self.check_success(kg_dict)
            success = validation_result["success"]
            # Calculate statistics
            entity_count = len(kg_dict.get("entities", []))
            relation_count = len(kg_dict.get("relations", []))
            total_tokens = usage.total_tokens or 0
            request_tokens = usage.request_tokens or 0
            response_tokens = usage.response_tokens or 0
            token_usage = {
                "total_tokens": total_tokens,
                "prompt_tokens": request_tokens,
                "completion_tokens": response_tokens,
                "total_cost_usd": self._calculate_token_cost(total_tokens, request_tokens, response_tokens, self.model),
                "usage_available": True
            }
            # Create metadata with actual usage information
            metadata = {
                "approach": f"pydantic_ai_{mode}",
                "model": self.model,
                "method": self.method_name,
                "processing_time_seconds": processing_time,
                "entity_count": entity_count,
                "relation_count": relation_count,
                "entities_per_second": entity_count / processing_time if processing_time > 0 else 0,
                "relations_per_second": relation_count / processing_time if processing_time > 0 else 0,
                "api_calls": usage.requests,
                "request_tokens": usage.request_tokens,
                "response_tokens": usage.response_tokens,
                "token_usage": token_usage,
                "validation": validation_result["validation"]
            }
            # Deliberately replace any model-produced metadata with our own.
            kg_dict["metadata"] = metadata
            # Add token usage details if available (metadata and
            # kg_dict["metadata"] reference the same dict at this point).
            if usage.details:
                metadata["token_details"] = usage.details
            return {
                "success": success,
                "kg_data": kg_dict,
                "metadata": metadata
            }
        except Exception as e:
            processing_time = time.time() - start_time
            logger.error(f"Error in pydantic_ai knowledge extraction: {e}")
            import traceback
            logger.error(f"Traceback: {traceback.format_exc()}")
            mode = "hybrid_2_stage" if self.hybrid else ("sequential_3_stage" if self.sequential else "parallel_3_stage")
            return {
                "success": False,
                "error": str(e),
                "kg_data": {"entities": [], "relations": []},
                "metadata": {
                    "approach": f"pydantic_ai_{mode}",
                    "model": self.model,
                    "method": self.method_name,
                    "processing_time_seconds": processing_time,
                    "api_calls": 0,
                    "error": str(e),
                    "token_usage": {
                        "total_tokens": 0,
                        "prompt_tokens": 0,
                        "completion_tokens": 0,
                        "model_used": self.model,
                        "total_cost_usd": 0.0,
                        # BUG FIX: no usage was captured on the error path, so
                        # report it as unavailable (was incorrectly True).
                        "usage_available": False
                    }
                }
            }

    def _calculate_token_cost(self, total_tokens: int, prompt_tokens: int, completion_tokens: int, model_name: str) -> float:
        """
        Calculate token cost based on model pricing.

        Args:
            total_tokens: Total number of tokens (unused; kept for signature
                compatibility with existing callers).
            prompt_tokens: Number of input/prompt tokens
            completion_tokens: Number of output/completion tokens
            model_name: Name of the model used

        Returns:
            Total cost in USD
        """
        # Model pricing per 1k tokens (as of 2025)
        pricing = {
            "gpt-4o-mini": {"input": 0.00015, "output": 0.0006},
            "gpt-4o": {"input": 0.0025, "output": 0.01},
            "gpt-4": {"input": 0.03, "output": 0.06},
            "gpt-4-turbo": {"input": 0.01, "output": 0.03},
            "gpt-3.5-turbo": {"input": 0.0015, "output": 0.002},
            "gpt-4.1": {"input": 0.002, "output": 0.008},
            "gpt-4.1-mini": {"input": 0.0004, "output": 0.0016},
            "gpt-4.1-nano": {"input": 0.0001, "output": 0.0004},
            "gpt-4.5-preview": {"input": 0.075, "output": 0.15},
            "claude-3-opus": {"input": 0.015, "output": 0.075},
            "claude-3-sonnet": {"input": 0.003, "output": 0.015},
            "claude-3-haiku": {"input": 0.00025, "output": 0.00125},
            "claude-3.5-sonnet": {"input": 0.003, "output": 0.015},
            "claude-3.5-haiku": {"input": 0.0008, "output": 0.004}
        }
        # Normalize model name to match pricing keys; order matters (most
        # specific substrings first, e.g. "gpt-4o-mini" before "gpt-4o").
        model_key = model_name.lower()
        if "gpt-4o-mini" in model_key:
            model_key = "gpt-4o-mini"
        elif "gpt-4o" in model_key:
            model_key = "gpt-4o"
        elif "gpt-4.5-preview" in model_key:
            model_key = "gpt-4.5-preview"
        elif "gpt-4.1-nano" in model_key:
            model_key = "gpt-4.1-nano"
        elif "gpt-4.1-mini" in model_key:
            model_key = "gpt-4.1-mini"
        elif "gpt-4.1" in model_key:
            model_key = "gpt-4.1"
        elif "gpt-4" in model_key:
            model_key = "gpt-4"
        elif "gpt-3.5" in model_key:
            model_key = "gpt-3.5-turbo"
        elif "claude-3.5-sonnet" in model_key:
            model_key = "claude-3.5-sonnet"
        elif "claude-3.5-haiku" in model_key:
            model_key = "claude-3.5-haiku"
        elif "claude-3-opus" in model_key:
            model_key = "claude-3-opus"
        elif "claude-3-sonnet" in model_key:
            model_key = "claude-3-sonnet"
        elif "claude-3-haiku" in model_key:
            model_key = "claude-3-haiku"
        if model_key not in pricing:
            # Default to gpt-4o-mini pricing if model not found.
            # NOTE(review): the class default "gpt-5-mini" has no pricing entry
            # and therefore hits this fallback — confirm intended.
            model_key = "gpt-4o-mini"
        rates = pricing[model_key]
        # Calculate cost: (tokens / 1000) * rate_per_1k_tokens
        input_cost = (prompt_tokens / 1000) * rates["input"]
        output_cost = (completion_tokens / 1000) * rates["output"]
        return input_cost + output_cost

    def extract_knowledge_graph(self, trace_data: str) -> Dict[str, Any]:
        """
        Extract knowledge graph from trace data.

        Args:
            trace_data: Agent trace data as JSON string

        Returns:
            Dictionary with entities and relations
        """
        try:
            logger.info(f"extract_knowledge_graph called with trace_data type: {type(trace_data)}")
            if isinstance(trace_data, str):
                logger.info(f"trace_data length: {len(trace_data)}")
                logger.info(f"trace_data first 200 chars: {repr(trace_data[:200])}")
            # Process the trace data
            result = self.process_text(trace_data)
            # Return just the knowledge graph data
            return result.get("kg_data", {"entities": [], "relations": []})
        except Exception as e:
            logger.error(f"Error in extract_knowledge_graph: {e}")
            logger.error(f"trace_data type: {type(trace_data)}")
            if isinstance(trace_data, str):
                logger.error(f"trace_data content (first 200 chars): {repr(trace_data[:200])}")
            return {"entities": [], "relations": []}