Spaces:
Running
Running
| """ | |
| Extraction Factory for Knowledge Extraction Methods | |
| This module provides a factory for creating instances of knowledge extraction methods | |
| based on their registry configuration. It handles dynamic loading and instantiation. | |
| """ | |
| import importlib | |
| import inspect | |
| from typing import Any, Dict, Optional, Union, Type | |
| from .method_registry import ( | |
| get_method_info, | |
| get_schema_for_method, | |
| is_valid_method, | |
| SchemaType, | |
| MethodType, | |
| DEFAULT_METHOD | |
| ) | |
class ExtractionFactory:
    """Factory for creating knowledge extraction method instances.

    Method metadata is read from the method registry (``is_valid_method``,
    ``get_method_info``, ``get_schema_for_method``). Loaded method classes and
    schema model dictionaries are cached on the factory instance so repeated
    requests do not repeat the import work.
    """

    def __init__(self):
        # Maps "<module_path>.<class_name>" to the object loaded from that module.
        self._method_cache = {}
        # Maps a schema type to its {'Entity', 'Relation', 'KnowledgeGraph'} dict.
        self._schema_cache = {}

    def create_method(self, method_name: str, **kwargs) -> Any:
        """
        Create an instance of the specified extraction method

        Args:
            method_name: Name of the method to create
            **kwargs: Additional arguments to pass to the method constructor
                (only used when the processing type is "direct_call")

        Returns:
            For "direct_call" methods, a new instance of the method class.
            For "async_crew" methods, the loaded object itself, returned
            without being called.

        Raises:
            ValueError: If method_name is invalid or the registry entry
                declares an unknown processing type
            ImportError: If method module cannot be loaded
            AttributeError: If method class cannot be found
        """
        if not is_valid_method(method_name):
            raise ValueError(f"Unknown method: {method_name}")

        method_info = get_method_info(method_name)

        # Loading happens before the processing type is inspected, so import
        # problems surface even for entries with an unknown processing type.
        method_class = self._load_method_class(method_info)

        processing_type = method_info.get("processing_type", "direct_call")
        if processing_type == "async_crew":
            # CrewAI-based methods are handed back as loaded; kwargs are not applied.
            return method_class
        if processing_type == "direct_call":
            # Baseline methods are instantiated with the caller's kwargs.
            return method_class(**kwargs)
        raise ValueError(f"Unknown processing type: {processing_type}")

    def _load_method_class(self, method_info: Dict[str, Any]) -> Type:
        """
        Load the method class from its module

        Args:
            method_info: Registry entry; must contain "module_path" and "class_name"

        Returns:
            The attribute named class_name of the imported module (cached
            after the first successful load)

        Raises:
            ImportError: If the module cannot be imported
            AttributeError: If the module has no attribute named class_name
        """
        module_path = method_info["module_path"]
        class_name = method_info["class_name"]

        cache_key = f"{module_path}.{class_name}"
        if cache_key in self._method_cache:
            return self._method_cache[cache_key]

        try:
            module = importlib.import_module(module_path)
            method_class = getattr(module, class_name)
        except ImportError as e:
            # Chain explicitly so the original failure is kept as __cause__.
            raise ImportError(f"Cannot import module {module_path}: {e}") from e
        except AttributeError as e:
            raise AttributeError(f"Cannot find class {class_name} in module {module_path}: {e}") from e

        # Only successful loads are cached; failures are retried on the next call.
        self._method_cache[cache_key] = method_class
        return method_class

    def get_schema_models(self, method_name: str) -> Dict[str, Type]:
        """
        Get the schema models for a specific method

        Args:
            method_name: Name of the method

        Returns:
            Dictionary with 'Entity', 'Relation', 'KnowledgeGraph' model classes.
            The same dictionary object is returned for every method that
            shares a schema type, until the cache is cleared.

        Raises:
            ValueError: If method_name is invalid or its schema type is unknown
        """
        if not is_valid_method(method_name):
            raise ValueError(f"Unknown method: {method_name}")

        schema_type = get_schema_for_method(method_name)

        if schema_type in self._schema_cache:
            return self._schema_cache[schema_type]

        # The model packages are imported lazily, only for the schema requested.
        if schema_type == SchemaType.REFERENCE_BASED:
            from agentgraph.shared.models.reference_based import Entity, Relation, KnowledgeGraph
        elif schema_type == SchemaType.DIRECT_BASED:
            from agentgraph.shared.models.direct_based.models import Entity, Relation, KnowledgeGraph
        else:
            raise ValueError(f"Unknown schema type: {schema_type}")

        models = {
            'Entity': Entity,
            'Relation': Relation,
            'KnowledgeGraph': KnowledgeGraph,
        }
        self._schema_cache[schema_type] = models
        return models

    def get_method_schema_type(self, method_name: str) -> SchemaType:
        """
        Get the schema type for a method

        Raises:
            ValueError: If method_name is invalid
        """
        if not is_valid_method(method_name):
            raise ValueError(f"Unknown method: {method_name}")
        return get_schema_for_method(method_name)

    def _has_feature(self, method_name: str, feature: str) -> bool:
        """Return True if feature is listed in the method's "supported_features".

        Unknown method names yield False rather than raising.
        """
        if not is_valid_method(method_name):
            return False
        method_info = get_method_info(method_name)
        return feature in method_info.get("supported_features", [])

    def requires_content_references(self, method_name: str) -> bool:
        """Check if a method requires content references (line numbers)

        True when "content_references" is among the method's supported
        features; False for unknown methods.
        """
        return self._has_feature(method_name, "content_references")

    def requires_line_numbers(self, method_name: str) -> bool:
        """Check if a method requires line numbers to be added to content

        True when "line_numbers" is among the method's supported features;
        False for unknown methods.
        """
        return self._has_feature(method_name, "line_numbers")

    def supports_failure_detection(self, method_name: str) -> bool:
        """Check if a method supports failure detection

        True when "failure_detection" is among the method's supported
        features; False for unknown methods.
        """
        return self._has_feature(method_name, "failure_detection")

    def get_processing_type(self, method_name: str) -> str:
        """
        Get the processing type for a method

        Returns:
            The registry entry's "processing_type", or "direct_call" when the
            entry does not specify one

        Raises:
            ValueError: If method_name is invalid
        """
        if not is_valid_method(method_name):
            raise ValueError(f"Unknown method: {method_name}")
        method_info = get_method_info(method_name)
        return method_info.get("processing_type", "direct_call")

    def clear_cache(self):
        """Clear the internal caches (loaded method classes and schema models)"""
        self._method_cache.clear()
        self._schema_cache.clear()
# Global factory instance: created once at import time and shared by the
# module-level convenience functions, so they also share its caches.
_factory = ExtractionFactory()
def create_extraction_method(method_name: str = DEFAULT_METHOD, **kwargs) -> Any:
    """
    Build an extraction method through the module-level factory

    Args:
        method_name: Registry name of the method; DEFAULT_METHOD when omitted
        **kwargs: Forwarded unchanged to the factory's create_method

    Returns:
        Whatever the factory's create_method returns for this method
    """
    extraction_method = _factory.create_method(method_name, **kwargs)
    return extraction_method
def get_schema_models_for_method(method_name: str) -> Dict[str, Type]:
    """
    Look up a method's schema models through the module-level factory

    Args:
        method_name: Registry name of the method

    Returns:
        Mapping with the keys 'Entity', 'Relation' and 'KnowledgeGraph',
        as produced by the factory's get_schema_models
    """
    schema_models = _factory.get_schema_models(method_name)
    return schema_models
def get_method_schema_type(method_name: str) -> SchemaType:
    """Return a method's schema type, delegating to the module-level factory"""
    schema_type = _factory.get_method_schema_type(method_name)
    return schema_type
def method_requires_content_references(method_name: str) -> bool:
    """Ask the module-level factory whether a method requires content references"""
    needs_references = _factory.requires_content_references(method_name)
    return needs_references
def method_requires_line_numbers(method_name: str) -> bool:
    """Ask the module-level factory whether a method requires line numbers"""
    needs_line_numbers = _factory.requires_line_numbers(method_name)
    return needs_line_numbers
def method_supports_failure_detection(method_name: str) -> bool:
    """Ask the module-level factory whether a method supports failure detection"""
    has_failure_detection = _factory.supports_failure_detection(method_name)
    return has_failure_detection
def get_method_processing_type(method_name: str) -> str:
    """Return a method's processing type, delegating to the module-level factory"""
    processing_type = _factory.get_processing_type(method_name)
    return processing_type
def clear_extraction_factory_cache():
    """Clear the global factory cache

    Delegates to clear_cache() on the module-level factory instance.
    """
    _factory.clear_cache()