Spaces:
Running
Running
| """ | |
| Direct LLM Knowledge Extraction Method | |
| A streamlined approach that uses direct LLM API calls with structured output | |
| instead of the CrewAI framework for better performance and cost efficiency. | |
| """ | |
import json
import logging
import os
import sys
import time
from datetime import datetime
from typing import Any, Dict

from openai import OpenAI
from pydantic import ValidationError

# Add the directory four levels above this file to sys.path so the `evaluation.*`
# imports below resolve when the file is run directly. This must happen BEFORE any
# project import, otherwise the path fix-up comes too late to help.
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))

from evaluation.knowledge_extraction.baselines.base_method import BaseKnowledgeExtractionMethod
from evaluation.knowledge_extraction.baselines.unified_method import KnowledgeGraph

# Import shared prompt templates
from evaluation.knowledge_extraction.utils.prompts import (
    ENTITY_EXTRACTION_INSTRUCTION_PROMPT,
    ENTITY_EXTRACTION_SYSTEM_PROMPT,
    GRAPH_BUILDER_INSTRUCTION_PROMPT,
    GRAPH_BUILDER_SYSTEM_PROMPT,
    RELATION_EXTRACTION_INSTRUCTION_PROMPT,
    RELATION_EXTRACTION_SYSTEM_PROMPT,
)
# Root logging setup at INFO level.
# NOTE(review): calling basicConfig at import time configures the root logger for
# the whole process (when it has no handlers yet), not just for this module.
logging.basicConfig(level=logging.INFO)

# Raise the threshold of chatty third-party client loggers to WARNING.
for _noisy_logger_name in ("openai", "httpx"):
    logging.getLogger(_noisy_logger_name).setLevel(logging.WARNING)
del _noisy_logger_name

# Module-level logger used by the extractor below.
logger = logging.getLogger(__name__)
class DirectLLMKnowledgeExtractor(BaseKnowledgeExtractionMethod):
    """Direct LLM knowledge extraction method using OpenAI API with structured output.

    Instead of a multi-step agent pipeline, a single Structured Outputs request
    (retried with exponential backoff on failure) asks the model to return a
    ``KnowledgeGraph`` directly.
    """

    # Model-name prefixes of OpenAI reasoning models (GPT-5 family, o-series), which
    # reject any `temperature` other than the default in Chat Completions. For these
    # models the parameter is omitted instead of being sent as 0.
    # NOTE(review): prefix heuristic — extend if other models with the same
    # restriction are used.
    _NO_TEMPERATURE_MODEL_PREFIXES = ("gpt-5", "o1", "o3", "o4")

    def __init__(self, model: str = "gpt-5-mini", **kwargs):
        """Create the extractor.

        Args:
            model: OpenAI model name used for every extraction request.
            **kwargs: Forwarded unchanged to ``BaseKnowledgeExtractionMethod``.
        """
        super().__init__("direct_llm_method", **kwargs)
        self.client = OpenAI()
        self.model = model
        self.max_retries = 3  # total attempts per extraction (not extra retries)
        self.retry_delay = 1.0  # base backoff in seconds; doubled after each failed attempt

    def _get_optimized_system_prompt(self) -> str:
        """Return the entity, relation and graph-builder system prompts joined by newlines."""
        return (
            f"{ENTITY_EXTRACTION_SYSTEM_PROMPT}\n"
            f"{RELATION_EXTRACTION_SYSTEM_PROMPT}\n"
            f"{GRAPH_BUILDER_SYSTEM_PROMPT}"
        )

    def _get_extraction_instruction(self, text: str) -> str:
        """Build the user message: entity, relation and graph-building instructions in order.

        Args:
            text: Input data substituted for ``{input_data}`` in the entity and
                relation instruction templates.

        Returns:
            The combined instruction string.
        """
        # NOTE(review): `text` is substituted into BOTH the entity and the relation
        # templates, so if both reference {input_data} the input appears twice in the
        # prompt (roughly doubling input tokens). Left as-is because changing it would
        # alter this baseline's prompt.
        entity_instruction = ENTITY_EXTRACTION_INSTRUCTION_PROMPT.format(input_data=text)
        relation_instruction = RELATION_EXTRACTION_INSTRUCTION_PROMPT.format(input_data=text)
        graph_instruction = GRAPH_BUILDER_INSTRUCTION_PROMPT
        return (
            "Extract a complete knowledge graph from the following agent system data.\n"
            "First, extract entities following these instructions:\n"
            f"{entity_instruction}\n"
            "Then, extract relations following these instructions:\n"
            f"{relation_instruction}\n"
            "Finally, build the knowledge graph following these instructions:\n"
            f"{graph_instruction}\n"
        )

    def _supports_temperature(self) -> bool:
        """Return False for model families known to reject a non-default ``temperature``."""
        return not self.model.lower().startswith(self._NO_TEMPERATURE_MODEL_PREFIXES)

    def _extract_with_retry(self, text: str) -> Dict[str, Any]:
        """Extract a knowledge graph, retrying with exponential backoff on failure.

        Uses the OpenAI Structured Outputs ``parse`` helper with ``KnowledgeGraph``
        as the response format.

        Args:
            text: Input text to extract from.

        Returns:
            On success, ``KnowledgeGraph.model_dump()`` plus a ``metadata`` dict
            (method, model, attempt, api_calls, timestamp, token_usage). If every
            attempt fails, a placeholder graph with empty ``entities``/``relations``
            whose ``metadata`` carries the last ``error`` and the number of
            ``api_calls`` made.
        """
        # The request is identical on every attempt, so build it once.
        request_kwargs: Dict[str, Any] = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": self._get_optimized_system_prompt()},
                {"role": "user", "content": self._get_extraction_instruction(text)},
            ],
            "response_format": KnowledgeGraph,
        }
        if self._supports_temperature():
            # Deterministic decoding wherever the model accepts it.
            request_kwargs["temperature"] = 0

        last_error = None
        for attempt in range(self.max_retries):
            try:
                logger.info(f"Extraction attempt {attempt + 1}/{self.max_retries}")
                # Use the beta API with structured outputs
                response = self.client.beta.chat.completions.parse(**request_kwargs)
                message = response.choices[0].message

                # A refusal carries no usable payload, so check it first.
                if message.refusal:
                    raise ValueError(f"Model refused: {message.refusal}")
                parsed_response = message.parsed
                if not parsed_response:
                    raise ValueError("Empty parsed response from LLM")

                kg_dict = parsed_response.model_dump()
                usage = response.usage
                kg_dict["metadata"] = {
                    "method": "direct_llm",
                    "model": self.model,
                    "attempt": attempt + 1,
                    # One request per attempt, including the failed ones before this.
                    "api_calls": attempt + 1,
                    "timestamp": datetime.now().isoformat(),
                    "token_usage": {
                        "prompt_tokens": usage.prompt_tokens if usage else 0,
                        "completion_tokens": usage.completion_tokens if usage else 0,
                        "total_tokens": usage.total_tokens if usage else 0,
                    },
                }
                logger.info(
                    f"Successfully extracted KG with {len(kg_dict['entities'])} entities "
                    f"and {len(kg_dict['relations'])} relations"
                )
                return kg_dict
            # JSONDecodeError and pydantic's ValidationError are caught before the
            # catch-all so the recorded error keeps a specific label.
            except json.JSONDecodeError as e:
                last_error = f"JSON parsing error: {e}"
                logger.warning(f"Attempt {attempt + 1} failed: {last_error}")
            except ValidationError as e:
                last_error = f"Validation error: {e}"
                logger.warning(f"Attempt {attempt + 1} failed: {last_error}")
            except Exception as e:
                last_error = f"API error: {e}"
                logger.warning(f"Attempt {attempt + 1} failed: {last_error}")

            if attempt < self.max_retries - 1:
                time.sleep(self.retry_delay * (2**attempt))  # Exponential backoff

        # If all attempts failed, return empty structure
        logger.error(f"All extraction attempts failed. Last error: {last_error}")
        return {
            "entities": [],
            "relations": [],
            "system_name": "Failed Extraction",
            "system_summary": "Knowledge graph extraction failed after multiple attempts.",
            "metadata": {
                "error": last_error,
                "method": "direct_llm",
                "model": self.model,
                "api_calls": self.max_retries,
            },
        }

    def process_text(self, text: str) -> Dict[str, Any]:
        """Process input text using direct LLM API calls.

        Never raises: unexpected exceptions are logged and reported in the
        returned dict.

        Args:
            text: Input text to process.

        Returns:
            Dict with ``success`` (True when at least one entity or relation was
            extracted), ``kg_data``, ``metadata`` (timing, counts, API calls, token
            usage) and, when extraction failed with a recorded error, ``error``.
        """
        start_time = time.perf_counter()  # monotonic clock for durations
        try:
            logger.info(f"Processing text with Direct LLM method (length: {len(text)})")
            kg_data = self._extract_with_retry(text)
            processing_time = time.perf_counter() - start_time

            entity_count = len(kg_data.get("entities", []))
            relation_count = len(kg_data.get("relations", []))
            # Success means the graph is non-empty, not merely that no error occurred.
            success = entity_count > 0 or relation_count > 0

            extraction_meta = kg_data.setdefault("metadata", {})
            # Retries are separate API requests; fall back to 1 if not recorded.
            api_calls = extraction_meta.get("api_calls", 1)
            extraction_meta["processing_info"] = {
                "method": "direct_llm",
                "processing_time_seconds": processing_time,
                "processed_at": datetime.now().isoformat(),
                "model": self.model,
                "api_calls": api_calls,
                "entity_count": entity_count,
                "relation_count": relation_count,
            }

            result = {
                "success": success,
                "kg_data": kg_data,
                "metadata": {
                    "approach": "direct_llm",
                    "model": self.model,
                    "method": self.method_name,
                    "processing_time_seconds": processing_time,
                    "entity_count": entity_count,
                    "relation_count": relation_count,
                    "entities_per_second": entity_count / processing_time if processing_time > 0 else 0,
                    "relations_per_second": relation_count / processing_time if processing_time > 0 else 0,
                    "api_calls": api_calls,
                    "token_usage": extraction_meta.get("token_usage", {}),
                },
            }
            # Surface the error recorded by _extract_with_retry when all attempts failed.
            extraction_error = extraction_meta.get("error")
            if extraction_error:
                result["error"] = extraction_error
                result["metadata"]["error"] = extraction_error
            return result

        except Exception as e:
            processing_time = time.perf_counter() - start_time
            # logger.exception logs at ERROR level and appends the traceback.
            logger.exception(f"Error in direct LLM knowledge extraction: {e}")
            return {
                "success": False,
                "error": str(e),
                "kg_data": {"entities": [], "relations": []},
                "metadata": {
                    "approach": "direct_llm",
                    "model": self.model,
                    "method": self.method_name,
                    "processing_time_seconds": processing_time,
                    "api_calls": 1,
                    "error": str(e),
                },
            }

    def extract_knowledge_graph(self, trace_data: str) -> Dict[str, Any]:
        """Extract a knowledge graph from trace data.

        Args:
            trace_data: Agent trace data, expected as a JSON string.

        Returns:
            On success, the ``kg_data`` dict from :meth:`process_text` (entities,
            relations and any extra fields/metadata); otherwise an empty
            ``{"entities": [], "relations": []}``.
        """
        try:
            logger.info(f"extract_knowledge_graph called with trace_data type: {type(trace_data)}")
            if isinstance(trace_data, str):
                logger.info(f"trace_data length: {len(trace_data)}")
                # Raw trace content can be large or sensitive, so only preview it at DEBUG.
                logger.debug(f"trace_data first 200 chars: {repr(trace_data[:200])}")

            result = self.process_text(trace_data)
            if result.get("success", False):
                return result.get("kg_data", {"entities": [], "relations": []})
            # Return empty knowledge graph on failure
            return {"entities": [], "relations": []}

        except Exception as e:
            logger.error(f"Error in extract_knowledge_graph: {e}")
            logger.error(f"trace_data type: {type(trace_data)}")
            if isinstance(trace_data, str):
                logger.error(f"trace_data content (first 200 chars): {repr(trace_data[:200])}")
            return {"entities": [], "relations": []}