File size: 8,593 Bytes
c2ea5ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
"""
Extraction Factory for Knowledge Extraction Methods

This module provides a factory for creating instances of knowledge extraction methods
based on their registry configuration. It handles dynamic loading and instantiation.
"""

import importlib
import inspect
from typing import Any, Dict, Optional, Union, Type
from .method_registry import (
    get_method_info, 
    get_schema_for_method, 
    is_valid_method,
    SchemaType,
    MethodType,
    DEFAULT_METHOD
)


class ExtractionFactory:
    """Factory for creating knowledge extraction method instances.

    Looks up a method's entry in the method registry, dynamically imports
    the configured module, and either instantiates the configured class or
    returns the loaded object as-is, depending on the entry's
    ``processing_type``. Loaded objects and schema-model mappings are
    memoised per factory instance.
    """

    # Processing type assumed when a registry entry does not specify one.
    _DEFAULT_PROCESSING_TYPE = "direct_call"

    def __init__(self):
        # "module_path.class_name" -> object fetched from that module
        self._method_cache: Dict[str, Any] = {}
        # schema type -> {'Entity': ..., 'Relation': ..., 'KnowledgeGraph': ...}
        self._schema_cache: Dict[Any, Dict[str, Type]] = {}

    def create_method(self, method_name: str, **kwargs) -> Any:
        """
        Create an instance of the specified extraction method

        Args:
            method_name: Name of the method to create
            **kwargs: Additional arguments to pass to the method constructor.
                Only used for ``direct_call`` methods; ignored for
                ``async_crew`` methods.

        Returns:
            For ``direct_call`` methods, a new instance built with ``**kwargs``.
            For ``async_crew`` methods, the object loaded from the module,
            returned as-is (it is not called).

        Raises:
            ValueError: If method_name is invalid or its processing type is unknown
            ImportError: If method module cannot be loaded
            AttributeError: If method class cannot be found
        """
        if not is_valid_method(method_name):
            raise ValueError(f"Unknown method: {method_name}")

        # Get method info from registry and load the class/object it names.
        method_info = get_method_info(method_name)
        method_class = self._load_method_class(method_info)

        processing_type = method_info.get(
            "processing_type", self._DEFAULT_PROCESSING_TYPE
        )

        if processing_type == "async_crew":
            # For CrewAI-based methods, return the loaded object directly
            return method_class
        if processing_type == "direct_call":
            # For baseline methods, instantiate the class
            return method_class(**kwargs)
        raise ValueError(f"Unknown processing type: {processing_type}")

    def _load_method_class(self, method_info: Dict[str, Any]) -> Type:
        """Load (and memoise) the object named by a registry entry.

        Args:
            method_info: Registry entry containing ``module_path`` and
                ``class_name``.

        Returns:
            The attribute ``class_name`` of module ``module_path``.

        Raises:
            ImportError: If the module cannot be imported (cause chained).
            AttributeError: If the module has no attribute ``class_name``
                (cause chained).
        """
        module_path = method_info["module_path"]
        class_name = method_info["class_name"]

        cache_key = f"{module_path}.{class_name}"
        if cache_key in self._method_cache:
            return self._method_cache[cache_key]

        try:
            module = importlib.import_module(module_path)
        except ImportError as e:
            raise ImportError(f"Cannot import module {module_path}: {e}") from e

        # Kept separate from the import so an AttributeError raised while the
        # module itself is executing is not misreported as a missing class.
        try:
            method_class = getattr(module, class_name)
        except AttributeError as e:
            raise AttributeError(
                f"Cannot find class {class_name} in module {module_path}: {e}"
            ) from e

        self._method_cache[cache_key] = method_class
        return method_class

    def get_schema_models(self, method_name: str) -> Dict[str, Type]:
        """
        Get the schema models for a specific method

        Args:
            method_name: Name of the method

        Returns:
            A fresh dictionary with 'Entity', 'Relation', 'KnowledgeGraph'
            model classes. Mutating it does not affect the factory's cache.

        Raises:
            ValueError: If method_name is invalid or its schema type is unknown
        """
        if not is_valid_method(method_name):
            raise ValueError(f"Unknown method: {method_name}")

        schema_type = get_schema_for_method(method_name)

        models = self._schema_cache.get(schema_type)
        if models is None:
            models = self._import_schema_models(schema_type)
            self._schema_cache[schema_type] = models

        # Return a copy so callers cannot corrupt the shared cached mapping.
        return dict(models)

    @staticmethod
    def _import_schema_models(schema_type) -> Dict[str, Type]:
        """Import the Entity/Relation/KnowledgeGraph classes for a schema type.

        Raises:
            ValueError: If ``schema_type`` is not a supported SchemaType.
        """
        if schema_type == SchemaType.REFERENCE_BASED:
            from agentgraph.shared.models.reference_based import (
                Entity, Relation, KnowledgeGraph,
            )
        elif schema_type == SchemaType.DIRECT_BASED:
            from agentgraph.shared.models.direct_based.models import (
                Entity, Relation, KnowledgeGraph,
            )
        else:
            raise ValueError(f"Unknown schema type: {schema_type}")
        return {
            'Entity': Entity,
            'Relation': Relation,
            'KnowledgeGraph': KnowledgeGraph,
        }

    def get_method_schema_type(self, method_name: str) -> "SchemaType":
        """Get the schema type for a method

        Raises:
            ValueError: If method_name is invalid
        """
        if not is_valid_method(method_name):
            raise ValueError(f"Unknown method: {method_name}")

        return get_schema_for_method(method_name)

    def _has_feature(self, method_name: str, feature: str) -> bool:
        """Return True if the registry lists ``feature`` for ``method_name``.

        Unknown methods report False instead of raising.
        """
        if not is_valid_method(method_name):
            return False
        method_info = get_method_info(method_name)
        return feature in method_info.get("supported_features", [])

    def requires_content_references(self, method_name: str) -> bool:
        """Check if a method requires content references (line numbers)"""
        return self._has_feature(method_name, "content_references")

    def requires_line_numbers(self, method_name: str) -> bool:
        """Check if a method requires line numbers to be added to content"""
        return self._has_feature(method_name, "line_numbers")

    def supports_failure_detection(self, method_name: str) -> bool:
        """Check if a method supports failure detection"""
        return self._has_feature(method_name, "failure_detection")

    def get_processing_type(self, method_name: str) -> str:
        """Get the processing type for a method

        Returns "direct_call" when the registry entry does not specify one.

        Raises:
            ValueError: If method_name is invalid
        """
        if not is_valid_method(method_name):
            raise ValueError(f"Unknown method: {method_name}")

        method_info = get_method_info(method_name)
        return method_info.get("processing_type", self._DEFAULT_PROCESSING_TYPE)

    def clear_cache(self):
        """Clear the internal caches"""
        self._method_cache.clear()
        self._schema_cache.clear()


# Global factory instance used by every module-level convenience function
# below; because caches live on the instance, those functions share one
# method-class cache and one schema-model cache.
_factory = ExtractionFactory()


def create_extraction_method(method_name: str = DEFAULT_METHOD, **kwargs) -> Any:
    """
    Build an extraction method through the module-level factory.

    Args:
        method_name: Registry name of the method; DEFAULT_METHOD when omitted
        **kwargs: Forwarded unchanged to ExtractionFactory.create_method

    Returns:
        Whatever ExtractionFactory.create_method produces for this method
    """
    factory = _factory
    extraction_method = factory.create_method(method_name, **kwargs)
    return extraction_method


def get_schema_models_for_method(method_name: str) -> Dict[str, Type]:
    """
    Look up a method's schema model classes via the module-level factory.

    Args:
        method_name: Registry name of the method

    Returns:
        Mapping of 'Entity', 'Relation' and 'KnowledgeGraph' to model classes
    """
    factory = _factory
    schema_models = factory.get_schema_models(method_name)
    return schema_models


def get_method_schema_type(method_name: str) -> SchemaType:
    """Return the registry schema type of a method (delegates to the shared factory)."""
    factory = _factory
    schema_type = factory.get_method_schema_type(method_name)
    return schema_type


def method_requires_content_references(method_name: str) -> bool:
    """Tell whether a method needs content references (delegates to the shared factory)."""
    factory = _factory
    needs_references = factory.requires_content_references(method_name)
    return needs_references


def method_requires_line_numbers(method_name: str) -> bool:
    """Tell whether a method needs line-numbered content (delegates to the shared factory)."""
    factory = _factory
    needs_line_numbers = factory.requires_line_numbers(method_name)
    return needs_line_numbers


def method_supports_failure_detection(method_name: str) -> bool:
    """Tell whether a method offers failure detection (delegates to the shared factory)."""
    factory = _factory
    has_failure_detection = factory.supports_failure_detection(method_name)
    return has_failure_detection


def get_method_processing_type(method_name: str) -> str:
    """Return how a method is executed, e.g. "direct_call" (delegates to the shared factory)."""
    factory = _factory
    processing_type = factory.get_processing_type(method_name)
    return processing_type


def clear_extraction_factory_cache():
    """Drop every cached method class and schema mapping held by the shared factory."""
    factory = _factory
    factory.clear_cache()