Spaces:
Paused
Paused
"""
Data Source Manager

This module provides the central manager for all data sources, implementing
the singleton pattern for thread-safe access across the application.
"""
| import logging | |
| import threading | |
| from typing import Any, Dict, Iterator, List, Optional, Type, TYPE_CHECKING | |
| from potato.data_sources.base import DataSource, SourceConfig, SourceType | |
| from potato.data_sources.credentials import CredentialManager | |
| from potato.data_sources.cache_manager import CacheManager | |
| from potato.data_sources.partial_reader import PartialReader, PartialLoadingConfig | |
| if TYPE_CHECKING: | |
| from potato.item_state_management import ItemStateManager | |
# Module-wide logger, named after this module.
logger = logging.getLogger(__name__)

# Process-wide manager instance; stays None until init_data_source_manager()
# creates it. _MANAGER_LOCK serializes creation and teardown of the instance.
DATA_SOURCE_MANAGER: Optional["DataSourceManager"] = None
_MANAGER_LOCK = threading.Lock()

# Maps each SourceType to the DataSource subclass registered for it.
_SOURCE_REGISTRY: Dict[SourceType, Type[DataSource]] = dict()
def register_source_type(source_type: SourceType, source_class: Type[DataSource]) -> None:
    """
    Record which DataSource subclass implements a given source type.

    A later registration for the same type replaces the earlier one.

    Args:
        source_type: The SourceType enum value used as the registry key
        source_class: The DataSource subclass to associate with it
    """
    _SOURCE_REGISTRY.update({source_type: source_class})
    message = f"Registered source type: {source_type.value} -> {source_class.__name__}"
    logger.debug(message)
def get_source_class(source_type: SourceType) -> Optional[Type[DataSource]]:
    """
    Look up the implementation registered for a source type.

    Args:
        source_type: The SourceType to resolve

    Returns:
        The registered DataSource subclass, or None when nothing has been
        registered for that type
    """
    try:
        return _SOURCE_REGISTRY[source_type]
    except KeyError:
        return None
def get_registered_types() -> List[str]:
    """Return the `.value` of every registered source type, in registration order."""
    names: List[str] = []
    for registered_type in _SOURCE_REGISTRY:
        names.append(registered_type.value)
    return names
class DataSourceManager:
    """
    Central manager for all data sources.

    This class provides:
    - Registration and lifecycle management of data sources
    - Credential management with environment variable substitution
    - Caching for remote sources
    - Partial/incremental loading coordination
    - Thread-safe access to all sources

    Attributes:
        credential_manager: Handles credential resolution
        cache_manager: Manages cached remote files (None when caching is disabled)
        partial_reader: Coordinates incremental loading (None when disabled)
    """

    def __init__(
        self,
        config: Dict[str, Any],
        item_state_manager: "ItemStateManager"
    ):
        """
        Initialize the data source manager.

        Builds the credential, cache and partial-loading helpers from
        ``config`` and then instantiates every configured source.

        Args:
            config: Application configuration dictionary
            item_state_manager: The ItemStateManager for adding items
        """
        self._config = config
        self._item_state_manager = item_state_manager
        self._sources: Dict[str, DataSource] = {}
        # Re-entrant so public methods holding the lock can call each other.
        self._lock = threading.RLock()

        # Initialize sub-managers
        self.credential_manager = CredentialManager.from_config(config)

        # Set up cache manager if caching is enabled. ``or {}`` tolerates a
        # key that is present but empty (None), e.g. a blank YAML section.
        cache_config = config.get("data_cache") or {}
        if cache_config.get("enabled", True):
            import os

            cache_dir = cache_config.get(
                "cache_dir",
                ".potato_cache/data_sources"
            )
            # Resolve relative paths against task_dir; isabs() is used rather
            # than a "/" prefix test so absolute paths are recognized on
            # every platform.
            task_dir = config.get("task_dir", ".")
            if not os.path.isabs(cache_dir):
                cache_dir = os.path.join(task_dir, cache_dir)
            self.cache_manager = CacheManager(
                cache_dir=cache_dir,
                ttl_seconds=cache_config.get("ttl_seconds", 3600),
                max_size_mb=cache_config.get("max_size_mb", 500)
            )
        else:
            self.cache_manager = None

        # Set up partial reader if incremental loading is configured
        partial_config = PartialLoadingConfig.from_dict(config)
        if partial_config.enabled:
            output_dir = config.get("output_annotation_dir", ".")
            self.partial_reader = PartialReader(partial_config, output_dir)
        else:
            self.partial_reader = None

        # Get item property keys
        item_props = config.get("item_properties") or {}
        self._id_key = item_props.get("id_key", "id")
        self._text_key = item_props.get("text_key", "text")

        # Initialize sources from configuration
        self._init_sources()
        logger.info(f"DataSourceManager initialized with {len(self._sources)} sources")

    def _init_sources(self) -> None:
        """
        Initialize data sources from configuration.

        Each entry of ``config["data_sources"]`` is handled independently:
        a disabled, unimplemented, invalid or failing entry is logged and
        skipped so that the remaining sources are still set up.
        """
        data_sources = self._config.get("data_sources") or []
        for index, source_dict in enumerate(data_sources):
            try:
                # Process credentials in the source config
                processed_config = self.credential_manager.process_config(source_dict)

                # Parse source configuration
                source_config = SourceConfig.from_dict(processed_config, index)
                if not source_config.enabled:
                    logger.debug(f"Skipping disabled source: {source_config.source_id}")
                    continue

                # Get the source class for this type
                source_class = get_source_class(source_config.source_type)
                if not source_class:
                    logger.warning(
                        f"No implementation for source type: {source_config.source_type.value}. "
                        f"Available types: {get_registered_types()}"
                    )
                    continue

                # Create the source instance
                source = source_class(source_config)

                # Validate configuration
                errors = source.validate_config()
                if errors:
                    logger.error(
                        f"Invalid configuration for source {source_config.source_id}: "
                        f"{'; '.join(errors)}"
                    )
                    continue

                # Check availability. An unavailable source is still
                # registered because it may become available later.
                if not source.is_available():
                    logger.warning(
                        f"Source {source_config.source_id} is not available. "
                        f"Check dependencies and credentials."
                    )

                self._sources[source_config.source_id] = source
                logger.info(
                    f"Initialized source: {source_config.source_id} "
                    f"(type={source_config.source_type.value})"
                )
            except Exception as e:
                # Deliberate best-effort: one bad entry must not prevent the
                # other sources from being initialized.
                logger.error(f"Failed to initialize source at index {index}: {e}")

    def get_source(self, source_id: str) -> Optional[DataSource]:
        """
        Get a data source by ID.

        Args:
            source_id: The source identifier

        Returns:
            The DataSource instance, or None if not found
        """
        with self._lock:
            return self._sources.get(source_id)

    def get_all_sources(self) -> Dict[str, DataSource]:
        """
        Get all registered sources.

        Returns:
            A shallow copy of the mapping from source_id to DataSource
        """
        with self._lock:
            return dict(self._sources)

    def list_sources(self) -> List[Dict[str, Any]]:
        """
        List all sources with their status.

        Returns:
            List of source status dictionaries; when a partial reader is
            active and has state for a source, its loading progress is
            merged into that source's entry
        """
        with self._lock:
            statuses = []
            for source in self._sources.values():
                status = source.get_status()
                # Add partial loading state if available
                if self.partial_reader:
                    state = self.partial_reader.get_state(source.source_id)
                    if state:
                        status["items_loaded"] = state.items_loaded
                        status["is_complete"] = state.is_complete
                        status["last_loaded_at"] = state.last_loaded_at
                statuses.append(status)
            return statuses

    def load_initial_data(self) -> int:
        """
        Load initial data from all sources.

        If partial loading is enabled, the partial reader decides how many
        items to take from each source. Otherwise, loads all data. A source
        that fails is logged and skipped.

        Returns:
            Total number of items loaded
        """
        total_loaded = 0
        with self._lock:
            for source_id, source in self._sources.items():
                try:
                    count = self._load_from_source(source, is_initial=True)
                    total_loaded += count
                    logger.info(f"Loaded {count} items from {source_id}")
                except Exception as e:
                    logger.error(f"Failed to load from {source_id}: {e}")
        return total_loaded

    def load_more(
        self,
        source_id: str,
        count: Optional[int] = None
    ) -> int:
        """
        Load more items from a specific source.

        Args:
            source_id: The source to load from
            count: Number of items to load (the partial reader chooses the
                amount if not specified)

        Returns:
            Number of items loaded

        Raises:
            ValueError: If source_id is not found
        """
        with self._lock:
            source = self._sources.get(source_id)
            if not source:
                raise ValueError(f"Unknown source: {source_id}")
            return self._load_from_source(source, is_initial=False, count=count)

    def _load_from_source(
        self,
        source: DataSource,
        is_initial: bool = True,
        count: Optional[int] = None
    ) -> int:
        """
        Load items from a source into the ItemStateManager.

        Items without the configured id key and items whose ID is already
        known to the ItemStateManager are skipped.

        Args:
            source: The data source
            is_initial: Whether this is the initial load
            count: Number of items to load (overrides config); ignored when
                partial loading is not active, in which case everything is
                read

        Returns:
            Number of items actually added to the ItemStateManager
        """
        source_id = source.source_id
        reader = self.partial_reader

        # Check if source is complete
        if reader:
            state = reader.get_state(source_id)
            if state and state.is_complete:
                logger.debug(f"Source {source_id} is already complete")
                return 0

        # Determine how many items to load and from what position
        if reader and reader.config.enabled:
            start = reader.get_start_position(source_id)
            if count is None:
                count = reader.get_load_count(source_id, is_initial)
        else:
            start = 0
            count = None  # Load all

        # Check if source supports partial reading
        if start > 0 and not source.supports_partial_reading():
            logger.warning(
                f"Source {source_id} does not support partial reading, "
                f"cannot continue from position {start}"
            )
            return 0

        # Load items. items_read counts everything the source yielded,
        # items_loaded only what was actually added.
        items_read = 0
        items_loaded = 0
        for item in source.read_items(start=start, count=count):
            items_read += 1

            # Validate ID key exists
            if self._id_key not in item:
                logger.warning(
                    f"Missing id_key '{self._id_key}' in item from {source_id}"
                )
                continue
            instance_id = str(item[self._id_key])

            # Check for duplicates
            if self._item_state_manager.has_item(instance_id):
                logger.debug(f"Skipping duplicate ID: {instance_id}")
                continue

            # Add item to state manager
            try:
                self._item_state_manager.add_item(instance_id, item)
                items_loaded += 1
            except ValueError as e:
                logger.warning(f"Failed to add item {instance_id}: {e}")

        # The source is exhausted when it yielded fewer items than requested.
        # This compares items *read*, not items *loaded*: skipped duplicates
        # or invalid items must not mark a source with remaining data as
        # complete.
        is_complete = count is not None and items_read < count

        # Update partial reader state
        if reader:
            # NOTE(review): items_added reports items actually added. If
            # PartialReader derives the next start position from this value,
            # skipped items will be re-read on the next call - confirm
            # against PartialReader.update_state / get_start_position.
            total_estimate = source.get_total_count()
            reader.update_state(
                source_id=source_id,
                items_added=items_loaded,
                is_complete=is_complete,
                total_estimate=total_estimate
            )
        return items_loaded

    def refresh_source(self, source_id: str) -> bool:
        """
        Refresh a data source (re-fetch from remote).

        Invalidates the source's cache entry and resets its partial-loading
        state before delegating to the source's own ``refresh()``.

        Args:
            source_id: The source to refresh

        Returns:
            True if refresh was successful

        Raises:
            ValueError: If source_id is not found
        """
        with self._lock:
            source = self._sources.get(source_id)
            if not source:
                raise ValueError(f"Unknown source: {source_id}")

            # Invalidate cache
            if self.cache_manager:
                self.cache_manager.invalidate(source_id)

            # Reset partial reader state
            if self.partial_reader:
                self.partial_reader.reset_state(source_id)

            return source.refresh()

    def check_auto_load(
        self,
        annotated_count: int,
        total_loaded: int
    ) -> Dict[str, int]:
        """
        Check if any sources should auto-load more data.

        Args:
            annotated_count: Total number of annotated items
            total_loaded: Total number of loaded items

        Returns:
            Dictionary mapping source_id to items loaded (only sources that
            triggered and loaded at least one item); empty when auto-load
            is not enabled
        """
        if not self.partial_reader or not self.partial_reader.config.auto_load_enabled:
            return {}

        results: Dict[str, int] = {}
        with self._lock:
            for source_id, source in self._sources.items():
                if self.partial_reader.should_load_more(
                    source_id, annotated_count, total_loaded
                ):
                    try:
                        loaded = self._load_from_source(source, is_initial=False)
                        if loaded > 0:
                            results[source_id] = loaded
                            logger.info(
                                f"Auto-loaded {loaded} items from {source_id}"
                            )
                    except Exception as e:
                        # Best-effort: a failing source must not block the rest.
                        logger.error(f"Auto-load failed for {source_id}: {e}")
        return results

    def clear_cache(self) -> int:
        """
        Clear the download cache for all sources.

        Returns:
            Number of cache entries cleared (0 when caching is disabled)
        """
        if self.cache_manager:
            return self.cache_manager.clear()
        return 0

    def get_stats(self) -> Dict[str, Any]:
        """
        Get comprehensive statistics.

        Returns:
            Dictionary with source and manager statistics
        """
        # Take the count and the listing under one lock acquisition so the
        # two values describe the same snapshot of the source table.
        with self._lock:
            stats: Dict[str, Any] = {
                "source_count": len(self._sources),
                "sources": self.list_sources(),
            }
        if self.cache_manager:
            stats["cache"] = self.cache_manager.get_stats()
        if self.partial_reader:
            stats["partial_loading"] = self.partial_reader.get_stats()
        return stats

    def close(self) -> None:
        """Close all sources and release resources."""
        with self._lock:
            for source in self._sources.values():
                try:
                    source.close()
                except Exception as e:
                    # Keep closing the remaining sources even if one fails.
                    logger.warning(f"Error closing source {source.source_id}: {e}")
            self._sources.clear()
def init_data_source_manager(config: Dict[str, Any]) -> Optional[DataSourceManager]:
    """
    Create the global DataSourceManager on first use and return it.

    Nothing is created when the configuration has no (or an empty)
    ``data_sources`` entry. Creation is guarded by ``_MANAGER_LOCK``, with
    the instance checked once before and once after acquiring the lock.

    Args:
        config: Application configuration dictionary

    Returns:
        The DataSourceManager instance, or None if not configured
    """
    global DATA_SOURCE_MANAGER

    # No sources configured: nothing to manage.
    if not config.get("data_sources"):
        return None

    # Fast path: an instance already exists.
    if DATA_SOURCE_MANAGER is not None:
        return DATA_SOURCE_MANAGER

    with _MANAGER_LOCK:
        # Re-check under the lock; another thread may have created it.
        if DATA_SOURCE_MANAGER is None:
            from potato.item_state_management import get_item_state_manager

            DATA_SOURCE_MANAGER = DataSourceManager(config, get_item_state_manager())
    return DATA_SOURCE_MANAGER
def get_data_source_manager() -> Optional[DataSourceManager]:
    """
    Return the current global DataSourceManager.

    Returns:
        The DataSourceManager instance, or None if not initialized
    """
    current_manager = DATA_SOURCE_MANAGER
    return current_manager
def clear_data_source_manager() -> None:
    """
    Tear down the global DataSourceManager.

    Closes the current instance, if any, and resets the module-level
    reference to None. Runs under ``_MANAGER_LOCK``. Used primarily for
    testing.
    """
    global DATA_SOURCE_MANAGER
    with _MANAGER_LOCK:
        existing = DATA_SOURCE_MANAGER
        if existing is None:
            return
        existing.close()
        DATA_SOURCE_MANAGER = None