Spaces:
Running
Running
| """ | |
| Base Interface for Knowledge Extraction Methods | |
| Defines the standard interface that all knowledge extraction baselines must implement. | |
| """ | |
| import asyncio | |
| import logging | |
| import time | |
| from abc import ABC, abstractmethod | |
| from typing import Any, Dict | |
| logger = logging.getLogger(__name__) | |
class BaseKnowledgeExtractionMethod(ABC):
    """Abstract base class for knowledge extraction methods.

    Subclasses implement :meth:`process_text`. The base class supplies an
    async wrapper, a structural check of the result dictionary
    (:meth:`validate_output`), a schema check of the extracted graph
    (:meth:`check_success`), and a timing / error-handling wrapper
    (:meth:`process_with_timing`).
    """

    def __init__(self, method_name: str, **kwargs):
        """
        Initialize the knowledge extraction method.

        Args:
            method_name: Name of the method
            **kwargs: Additional method-specific parameters, stored
                unchanged in ``self.config``
        """
        self.method_name = method_name
        self.config = kwargs

    @abstractmethod
    def process_text(self, text: str) -> Dict[str, Any]:
        """
        Process input text and extract knowledge graph.

        Must be overridden by every concrete subclass.

        Args:
            text: Input text to process

        Returns:
            Dictionary containing:
                - kg_data: Knowledge graph data with entities and relations
                - metadata: Processing metadata (timing, method info, etc.)
                - success: Boolean indicating if processing was successful
                - error: Error message if processing failed
        """

    async def process_text_async(self, text: str) -> Dict[str, Any]:
        """
        Async wrapper for process_text method.

        Runs the synchronous ``process_text`` via ``asyncio.to_thread``.

        Args:
            text: Input text to process

        Returns:
            Dictionary containing the same format as process_text
        """
        return await asyncio.to_thread(self.process_text, text)

    def get_method_info(self) -> Dict[str, Any]:
        """Get information about this method (name, config, description)."""
        return {
            "name": self.method_name,
            "config": self.config,
            "description": self.__doc__ or "No description available",
        }

    @staticmethod
    def _endpoint_error(rel_type: str, role: str, expected: Any, actual: Any) -> str:
        """Return an error message if ``actual`` violates ``expected``, else ``""``.

        Args:
            rel_type: Relation type being validated (used in the message)
            role: Either ``"source"`` or ``"target"``
            expected: A single type name, a list of allowed type names, or a
                falsy value meaning "no constraint"
            actual: The entity type found on that end of the relation
        """
        if not expected:
            return ""
        if isinstance(expected, list):
            if actual not in expected:
                return f"{rel_type} requires {role} type in {expected}, got {actual}"
            return ""
        if actual != expected:
            return f"{rel_type} requires {role} type {expected}, got {actual}"
        return ""

    @classmethod
    def check_success(cls, kg_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Check if the knowledge graph was created successfully with required
        elements and valid relations.

        Callable on the class or on an instance.

        Args:
            kg_data: Knowledge graph data dictionary. ``entities`` and
                ``relations`` default to empty lists when absent.

        Returns:
            Dictionary with ``success`` (bool) and ``validation`` (counts,
            missing required types, invalid relations, isolated entities and
            summary flags). ``success`` is True only when all required entity
            types are present, no relation is invalid, and both entities and
            relations are non-empty. Isolated entities are reported but do
            not affect ``success``.

        Raises:
            KeyError: If an entity lacks an ``"id"`` or ``"type"`` key.
        """
        # Schema v3 relation definitions: relation type -> (source, target).
        # A list means any of the listed types is accepted.
        valid_relations = {
            "CONSUMED_BY": ("Input", "Agent"),  # Direction is Input -> Agent
            "PERFORMS": ("Agent", "Task"),
            "ASSIGNED_TO": ("Task", "Agent"),
            "USES": ("Agent", "Tool"),
            "REQUIRED_BY": ("Tool", "Task"),
            "SUBTASK_OF": ("Task", "Task"),
            "NEXT": ("Task", "Task"),
            "PRODUCES": ("Task", "Output"),
            "DELIVERS_TO": ("Output", "Human"),
            "INTERVENES": (["Agent", "Human"], "Task"),
        }

        # Required & optional entity types for schema v3
        required_types = {"Agent", "Task", "Input", "Output"}
        optional_types = {"Tool", "Human"}

        entities = kg_data.get("entities", [])
        relations = kg_data.get("relations", [])

        # Entity lookup used to resolve relation endpoints
        entity_lookup = {e["id"]: e for e in entities}

        found_types = {e["type"] for e in entities}
        missing_required = required_types - found_types

        # Validate relations; each relation yields at most one record
        invalid_relations = []
        for relation in relations:
            rel_type = relation.get("type")

            if rel_type not in valid_relations:
                invalid_relations.append({
                    "relation": relation,
                    "error": f"Invalid relation type: {rel_type}",
                })
                continue

            source_entity = entity_lookup.get(relation.get("source"))
            target_entity = entity_lookup.get(relation.get("target"))
            if not source_entity or not target_entity:
                invalid_relations.append({
                    "relation": relation,
                    "error": "Source or target entity not found",
                })
                continue

            # Validate source->target type constraints
            expected_source, expected_target = valid_relations[rel_type]
            candidate_messages = (
                cls._endpoint_error(rel_type, "source", expected_source, source_entity["type"]),
                cls._endpoint_error(rel_type, "target", expected_target, target_entity["type"]),
            )
            errors = [message for message in candidate_messages if message]
            if errors:
                invalid_relations.append({
                    "relation": relation,
                    "error": "; ".join(errors),
                })

        # Count entities by type
        entity_counts: Dict[Any, int] = {}
        for entity in entities:
            entity_type = entity["type"]
            entity_counts[entity_type] = entity_counts.get(entity_type, 0) + 1

        # Count relations by type; relations without a type count as UNKNOWN
        relation_counts: Dict[Any, int] = {}
        for relation in relations:
            rel_type = relation.get("type", "UNKNOWN")
            relation_counts[rel_type] = relation_counts.get(rel_type, 0) + 1

        # Find isolated entities (entities not referenced by any relation)
        connected_entity_ids = set()
        for relation in relations:
            connected_entity_ids.add(relation.get("source"))
            connected_entity_ids.add(relation.get("target"))
        isolated_entities = [
            entity for entity in entities
            if entity.get("id") not in connected_entity_ids
        ]

        has_required_elements = not missing_required
        all_relations_valid = not invalid_relations

        # Determine overall success
        success = (
            has_required_elements
            and all_relations_valid
            and len(entities) > 0
            and len(relations) > 0
        )

        return {
            "success": success,
            "validation": {
                "entity_counts": entity_counts,
                "relation_counts": relation_counts,
                # Sorted so reports are reproducible across runs
                "missing_required_types": sorted(missing_required),
                "found_optional_types": sorted(found_types & optional_types),
                "invalid_relations": invalid_relations,
                "isolated_entities": isolated_entities,
                "total_entities": len(entities),
                "total_relations": len(relations),
                "total_invalid_relations": len(invalid_relations),
                "total_isolated_entities": len(isolated_entities),
                "has_required_elements": has_required_elements,
                "all_relations_valid": all_relations_valid,
                "no_isolated_entities": len(isolated_entities) == 0,
            },
        }

    def validate_output(self, result: Dict[str, Any]) -> bool:
        """
        Validate that the method output follows the expected format.

        Checks that ``result`` is a dict with ``kg_data``, ``metadata`` and
        ``success`` keys, and that ``kg_data`` is a dict whose ``entities``
        and ``relations`` are lists. Each failure is logged at ERROR level.

        Args:
            result: Result from process_text method

        Returns:
            True if output is valid, False otherwise
        """
        required_keys = ["kg_data", "metadata", "success"]

        if not isinstance(result, dict):
            logger.error("Method %s output is not a dictionary", self.method_name)
            return False

        for key in required_keys:
            if key not in result:
                logger.error("Method %s output missing required key: %s", self.method_name, key)
                return False

        # Validate kg_data structure
        kg_data = result.get("kg_data", {})
        if not isinstance(kg_data, dict):
            logger.error("Method %s kg_data is not a dictionary", self.method_name)
            return False

        if "entities" not in kg_data or "relations" not in kg_data:
            logger.error("Method %s kg_data missing entities or relations", self.method_name)
            return False

        if not isinstance(kg_data["entities"], list) or not isinstance(kg_data["relations"], list):
            logger.error("Method %s entities/relations are not lists", self.method_name)
            return False

        return True

    def process_with_timing(self, text: str) -> Dict[str, Any]:
        """
        Process text with automatic timing and error handling.

        Any exception raised by ``process_text`` (or an output rejected by
        ``validate_output``) is converted into a failure result with an
        empty knowledge graph, so this method does not propagate processing
        errors. The returned dictionary is the one produced by
        ``process_text`` (mutated in place) when it is valid.

        Args:
            text: Input text to process

        Returns:
            Result dictionary whose ``metadata`` contains ``method``,
            ``processing_time``, ``start_time`` and ``end_time``
        """
        start_time = time.time()
        try:
            result = self.process_text(text)
            # Ensure result has required structure
            if not self.validate_output(result):
                result = {
                    "kg_data": {"entities": [], "relations": []},
                    "metadata": {"method": self.method_name, "error": "Invalid output format"},
                    "success": False,
                    "error": "Method produced invalid output format",
                }
        except Exception as e:
            # Boundary handler: keep the traceback in the log, report failure
            logger.exception("Error in %s: %s", self.method_name, e)
            result = {
                "kg_data": {"entities": [], "relations": []},
                "metadata": {"method": self.method_name, "error": str(e)},
                "success": False,
                "error": str(e),
            }

        # Add timing information
        end_time = time.time()
        processing_time = end_time - start_time

        # validate_output only checks that "metadata" exists, not its type;
        # replace a non-dict value so the update below cannot fail.
        if not isinstance(result.get("metadata"), dict):
            result["metadata"] = {}

        result["metadata"].update({
            "method": self.method_name,
            "processing_time": processing_time,
            "start_time": start_time,
            "end_time": end_time,
        })
        return result