Spaces:
Runtime error
Runtime error
| # app/core/model_manager.py | |
| import logging | |
| import os | |
| import asyncio | |
| from pathlib import Path | |
| from typing import Callable, Optional, Dict, List | |
| # Imports for downloading specific model types | |
| import nltk | |
| from huggingface_hub import snapshot_download | |
| import spacy.cli | |
| # Internal application imports | |
| from app.core.config import ( | |
| MODELS_DIR, | |
| NLTK_DATA_DIR, | |
| SPACY_MODEL_ID, | |
| SENTENCE_TRANSFORMER_MODEL_ID, | |
| TONE_MODEL_ID, | |
| TRANSLATION_MODEL_ID, | |
| WORDNET_NLTK_ID, | |
| APP_NAME | |
| ) | |
| from app.core.exceptions import ModelNotDownloadedError, ModelDownloadFailedError, ServiceError | |
# Module-scoped logger, namespaced under the application's logger name.
logger = logging.getLogger(f"{APP_NAME}.core.model_manager")

# Signature of a download-progress listener. It is invoked positionally as
# (model_id, status, progress, message); nothing is expected back.
ProgressCallback = Callable[
    [str, str, float, Optional[str]],
    None,
]
| def _get_hf_model_local_path(model_id: str) -> Path: | |
| """Helper to get the expected local path for a Hugging Face model.""" | |
| # snapshot_download creates a specific folder structure inside MODELS_DIR/hf_cache | |
| # For example, for "bert-base-uncased", it might be MODELS_DIR/hf_cache/models--bert-base-uncased | |
| # The actual model files are inside that. | |
| # The `transformers` library usually handles this resolution. | |
| # We just need to check if the directory created by snapshot_download exists. | |
| # A robust check involves looking inside that directory. | |
| return MODELS_DIR / "hf_cache" / model_id.replace("/", "--") # Standard HF cache path logic | |
def check_model_exists(model_id: str, model_type: str) -> bool:
    """Report whether a model or NLTK resource is already available locally.

    Args:
        model_id: Identifier of the resource (Hub repo id, spaCy pipeline
            name, or NLTK corpus name).
        model_type: One of ``"huggingface"``, ``"spacy"`` or ``"nltk"``.

    Returns:
        ``True`` when the resource was found, ``False`` otherwise. An
        unrecognised ``model_type`` is logged as a warning and yields
        ``False``.
    """
    if model_type == "huggingface":
        local_path = _get_hf_model_local_path(model_id)
        # An empty directory does not count as a downloaded model.
        return local_path.is_dir() and any(local_path.iterdir())

    if model_type == "spacy":
        # A pipeline that was copied into MODELS_DIR by hand still counts.
        spacy_target_path = MODELS_DIR / model_id
        if spacy_target_path.is_dir() and any(spacy_target_path.iterdir()):
            return True
        # `spacy download` installs the pipeline as a Python package rather
        # than writing it under MODELS_DIR, so the directory probe alone
        # would report "missing" after every successful download.
        from spacy.util import is_package

        return bool(is_package(model_id))

    if model_type == "nltk":
        try:
            # nltk.data.find raises LookupError when the corpus is absent.
            return nltk.data.find(f"corpora/{model_id}") is not None
        except LookupError:
            return False

    logger.warning(f"Unknown model type for check_model_exists: {model_type}")
    return False
| # --- Download Functions --- | |
async def download_hf_model_async(
    model_id: str,
    feature_name: str,
    progress_callback: Optional[ProgressCallback] = None,
) -> None:
    """Download a Hugging Face Hub model without blocking the event loop.

    The download is skipped when ``check_model_exists`` already finds the
    model. ``snapshot_download`` offers no live progress hook here, so the
    callback only receives a start, a completion and (on error) a failure
    notification.

    Args:
        model_id: Hub repository id to download.
        feature_name: Human-readable feature the model serves; used in log
            messages and in the raised error.
        progress_callback: Optional listener called as
            ``(model_id, status, progress, message)``.

    Raises:
        ModelDownloadFailedError: If the download (or a callback invoked
            around it) raises; the original exception is chained as the cause.
    """
    logger.info(f"Initiating download for Hugging Face model '{model_id}' for '{feature_name}'...")
    if check_model_exists(model_id, "huggingface"):
        logger.info(f"Hugging Face model '{model_id}' already exists locally. Skipping download.")
        if progress_callback:
            progress_callback(model_id, "completed", 1.0, "Already downloaded.")
        return

    def _blocking_download() -> None:
        # `local_dir_use_symlinks` is intentionally not passed: it only
        # applies together with `local_dir` and is deprecated in recent
        # huggingface_hub releases, so it had no effect on a cache download.
        snapshot_download(
            repo_id=model_id,
            cache_dir=str(MODELS_DIR / "hf_cache"),  # Explicit cache directory
        )
        logger.info(f"Hugging Face model '{model_id}' download complete.")

    try:
        if progress_callback:
            progress_callback(model_id, "downloading", 0.05, "Starting download...")
        # snapshot_download is blocking; run it in a worker thread.
        await asyncio.to_thread(_blocking_download)
        if progress_callback:
            progress_callback(model_id, "completed", 1.0, "Download successful.")
    except Exception as e:
        logger.error(f"Failed to download Hugging Face model '{model_id}': {e}", exc_info=True)
        if progress_callback:
            progress_callback(model_id, "failed", 0.0, f"Error: {e}")
        # Chain the cause so the original traceback is not lost.
        raise ModelDownloadFailedError(model_id, feature_name, original_error=str(e)) from e
async def download_spacy_model_async(
    model_id: str,
    feature_name: str,
    progress_callback: Optional[ProgressCallback] = None,
) -> None:
    """Download a spaCy pipeline without blocking the event loop.

    ``SPACY_DATA`` is pointed at ``MODELS_DIR`` for the duration of the call
    and restored afterwards. The download is skipped when
    ``check_model_exists`` already finds the pipeline.

    NOTE(review): the environment variable is process-wide, so two
    overlapping calls would see each other's value — confirm callers run
    these downloads sequentially.

    Args:
        model_id: Name of the spaCy pipeline to download.
        feature_name: Human-readable feature the model serves; used in log
            messages and in the raised error.
        progress_callback: Optional listener called as
            ``(model_id, status, progress, message)``.

    Raises:
        ModelDownloadFailedError: If the download fails; the original
            exception is chained as the cause.
    """
    logger.info(f"Initiating download for spaCy model '{model_id}' for '{feature_name}'...")
    original_spacy_data = os.environ.get("SPACY_DATA")
    try:
        os.environ["SPACY_DATA"] = str(MODELS_DIR)
        if check_model_exists(model_id, "spacy"):
            logger.info(f"SpaCy model '{model_id}' already exists locally. Skipping download.")
            if progress_callback:
                progress_callback(model_id, "completed", 1.0, "Already downloaded.")
            return

        def _blocking_download() -> None:
            try:
                spacy.cli.download(model_id)
            except SystemExit as exit_signal:
                # spaCy's CLI helpers can signal failure via sys.exit().
                # SystemExit is not an Exception subclass, so it would slip
                # past the handler below and skip the "failed" notification;
                # turn a non-zero exit into an ordinary error instead.
                if exit_signal.code not in (0, None):
                    raise RuntimeError(
                        f"spaCy download exited with status {exit_signal.code!r}"
                    ) from exit_signal
            logger.info(f"SpaCy model '{model_id}' download complete.")

        if progress_callback:
            progress_callback(model_id, "downloading", 0.05, "Starting download...")
        # The spaCy download is blocking; run it in a worker thread.
        await asyncio.to_thread(_blocking_download)
        if progress_callback:
            progress_callback(model_id, "completed", 1.0, "Download successful.")
    except Exception as e:
        logger.error(f"Failed to download spaCy model '{model_id}': {e}", exc_info=True)
        if progress_callback:
            progress_callback(model_id, "failed", 0.0, f"Error: {e}")
        # Chain the cause so the original traceback is not lost.
        raise ModelDownloadFailedError(model_id, feature_name, original_error=str(e)) from e
    finally:
        # Put SPACY_DATA back exactly as it was found (set or unset).
        if original_spacy_data is None:
            os.environ.pop("SPACY_DATA", None)
        else:
            os.environ["SPACY_DATA"] = original_spacy_data
async def download_nltk_data_async(
    data_id: str,
    feature_name: str,
    progress_callback: Optional[ProgressCallback] = None,
) -> None:
    """Download an NLTK resource without blocking the event loop.

    The download is skipped when ``check_model_exists`` already finds the
    resource. Data is written to ``NLTK_DATA_DIR``.

    Args:
        data_id: NLTK resource identifier to download.
        feature_name: Human-readable feature the data serves; used in log
            messages and in the raised error.
        progress_callback: Optional listener called as
            ``(data_id, status, progress, message)``.

    Raises:
        ModelDownloadFailedError: If the download raises or reports failure;
            the original exception is chained as the cause.
    """
    logger.info(f"Initiating download for NLTK data '{data_id}' for '{feature_name}'...")
    if check_model_exists(data_id, "nltk"):
        logger.info(f"NLTK data '{data_id}' already exists locally. Skipping download.")
        if progress_callback:
            progress_callback(data_id, "completed", 1.0, "Already downloaded.")
        return

    def _blocking_download() -> None:
        # quiet=True keeps the downloader non-interactive.
        succeeded = nltk.download(data_id, download_dir=str(NLTK_DATA_DIR), quiet=True)
        # nltk.download reports failure through its boolean return value
        # rather than by raising, so an unchecked call would be announced
        # as a successful download.
        if not succeeded:
            raise RuntimeError(f"nltk.download reported failure for '{data_id}'")
        logger.info(f"NLTK data '{data_id}' download complete.")

    try:
        if progress_callback:
            progress_callback(data_id, "downloading", 0.05, "Starting download...")
        # The NLTK download is blocking; run it in a worker thread.
        await asyncio.to_thread(_blocking_download)
        if progress_callback:
            progress_callback(data_id, "completed", 1.0, "Download successful.")
    except Exception as e:
        logger.error(f"Failed to download NLTK data '{data_id}': {e}", exc_info=True)
        if progress_callback:
            progress_callback(data_id, "failed", 0.0, f"Error: {e}")
        # Chain the cause so the original traceback is not lost.
        raise ModelDownloadFailedError(data_id, feature_name, original_error=str(e)) from e
| # --- Comprehensive Model Management --- | |
def get_all_required_models() -> List[Dict]:
    """List every model the application needs.

    Returns:
        One dict per model with the keys ``"id"``, ``"type"`` (``"spacy"``,
        ``"huggingface"`` or ``"nltk"``) and ``"feature"`` (a human-readable
        label), in a fixed order.
    """
    # (identifier, loader type, feature label) — extend this table as the
    # application grows.
    catalogue = (
        (SPACY_MODEL_ID, "spacy", "Text Processing (General)"),
        (SENTENCE_TRANSFORMER_MODEL_ID, "huggingface", "Sentence Embeddings"),
        (TONE_MODEL_ID, "huggingface", "Tone Classification"),
        (TRANSLATION_MODEL_ID, "huggingface", "Translation"),
        (WORDNET_NLTK_ID, "nltk", "Synonym Suggestion"),
    )
    return [
        {"id": identifier, "type": kind, "feature": label}
        for identifier, kind, label in catalogue
    ]
async def download_all_required_models(progress_callback: Optional[ProgressCallback] = None) -> Dict[str, str]:
    """Try to download every required model, one after another.

    A failure for one model is recorded and does not stop the remaining
    downloads.

    Args:
        progress_callback: Optional listener called as
            ``(model_id, status, progress, message)``.

    Returns:
        A mapping of model id to ``"already_downloaded"``, ``"success"`` or
        ``"failed"``.
    """
    # Downloader coroutine for each supported model type.
    downloaders = {
        "huggingface": download_hf_model_async,
        "spacy": download_spacy_model_async,
        "nltk": download_nltk_data_async,
    }

    outcome: Dict[str, str] = {}
    for entry in get_all_required_models():
        model_id, model_type, feature_name = entry["id"], entry["type"], entry["feature"]

        # Nothing to fetch when the resource is already present.
        if check_model_exists(model_id, model_type):
            note = f"'{model_id}' ({feature_name}) already downloaded."
            logger.info(note)
            outcome[model_id] = "already_downloaded"
            if progress_callback:
                progress_callback(model_id, "completed", 1.0, note)
            continue

        logger.info(f"Attempting to download '{model_id}' ({feature_name})...")
        try:
            fetch = downloaders.get(model_type)
            if fetch is None:
                raise ValueError(f"Unsupported model type: {model_type}")
            await fetch(model_id, feature_name, progress_callback)
            note = f"'{model_id}' ({feature_name}) downloaded successfully."
            logger.info(note)
            outcome[model_id] = "success"
        except ModelDownloadFailedError as failure:
            # The specific downloader has already notified the callback.
            note = f"Failed to download '{model_id}' ({feature_name}): {failure.original_error}"
            logger.error(note)
            outcome[model_id] = "failed"
        except Exception as unexpected:
            note = f"An unexpected error occurred while downloading '{model_id}' ({feature_name}): {unexpected}"
            logger.error(note, exc_info=True)
            outcome[model_id] = "failed"
            if progress_callback:
                progress_callback(model_id, "failed", 0.0, note)

    logger.info("Finished attempting to download all required models.")
    return outcome