#!/usr/bin/env python3
"""
Pre-download and cache models for Hugging Face Spaces deployment.

Run this during the Docker build to avoid runtime downloads.

PRE-CACHED MODELS (downloaded during build):
- facebook/bart-large-cnn (Summarization)
- patrickvonplaten/longformer2roberta-cnn_dailymail-fp16 (Seq2Seq)
- google/flan-t5-large (Summarization)
- microsoft/Phi-3-mini-4k-instruct (Causal OpenVINO)
- OpenVINO/Phi-3-mini-4k-instruct-fp16-ov (Causal OpenVINO)
- microsoft/Phi-3-mini-4k-instruct-gguf/Phi-3-mini-4k-instruct-q4.gguf (GGUF - PRIMARY)

RUNTIME BEHAVIOR:
- Requesting a pre-cached model loads it from the local cache (typically
  30-60 seconds, no download).
- Requesting any other model downloads it on demand at runtime.
- The system supports both pre-cached and on-demand model loading.

PRIMARY MODEL for patient summaries:
- microsoft/Phi-3-mini-4k-instruct-gguf/Phi-3-mini-4k-instruct-q4.gguf (is_active: true)
"""

import os
import sys
import logging
from pathlib import Path

# Configure logging first, so the format applies to every log call below
# (including the fallback warning, which would otherwise implicitly configure
# the root logger with default settings)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Add the AI service's src directory to the path for the benchmarking utilities
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
sys.path.insert(0, os.path.join(project_root, "services", "ai-service", "src"))

try:
    from ai_med_extract.utils.benchmark import BenchmarkContext
except ImportError:
    # Fall back to a no-op context if the module is missing or the path is wrong
    logging.warning("Benchmark module not found; using a dummy context.")

    class BenchmarkContext:
        def __init__(self, *args, **kwargs): pass
        def __enter__(self): return self
        def __exit__(self, *args): pass

# Set cache directories - these will be baked into the Docker image
MODEL_CACHE_DIR = os.environ.get('MODEL_CACHE_DIR', '/app/models')
HF_HOME = os.environ.get('HF_HOME', '/app/.cache/huggingface')
TORCH_HOME = os.environ.get('TORCH_HOME', '/app/.cache/torch')
WHISPER_CACHE = os.environ.get('WHISPER_CACHE', '/app/.cache/whisper')

# Create cache directories
for cache_dir in [MODEL_CACHE_DIR, HF_HOME, TORCH_HOME, WHISPER_CACHE]:
    Path(cache_dir).mkdir(parents=True, exist_ok=True)
    logger.info(f"Created cache directory: {cache_dir}")


def preload_transformers_models():
    """Pre-download Hugging Face transformers models."""
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM

    # Models for patient summary generation
    models = [
        # Summarization models
        {
            "name": "facebook/bart-large-cnn",
            "type": "seq2seq",
            "description": "BART Large CNN - Summarization",
            "is_active": False,  # Available but not primary
        },
        {
            "name": "patrickvonplaten/longformer2roberta-cnn_dailymail-fp16",
            "type": "seq2seq",
            "description": "Longformer2Roberta - Seq2Seq Summarization",
            "is_active": False,
        },
        {
            "name": "google/flan-t5-large",
            "type": "seq2seq",
            "description": "FLAN-T5 Large - Summarization",
            "is_active": False,
        },
        # OpenVINO models for patient summaries
        {
            "name": "microsoft/Phi-3-mini-4k-instruct",
            "type": "causal",
            "description": "Phi-3 Mini - Causal OpenVINO (base model)",
            "is_active": False,
        },
        {
            "name": "OpenVINO/Phi-3-mini-4k-instruct-fp16-ov",
            "type": "causal",
            "description": "Phi-3 Mini - FP16 OpenVINO optimized",
            "is_active": False,
        },
    ]

    for model_info in models:
        model_name = model_info["name"]
        model_type = model_info["type"]
        description = model_info["description"]
        try:
            logger.info(f"📥 Downloading {description}: {model_name}")

            # Download tokenizer
            logger.info("   ⏳ Downloading tokenizer...")
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                cache_dir=HF_HOME,
                trust_remote_code=False
            )

            # Download model weights
            logger.info("   ⏳ Downloading model weights...")
            if model_type == "seq2seq":
                model = AutoModelForSeq2SeqLM.from_pretrained(
                    model_name,
                    cache_dir=HF_HOME,
                    trust_remote_code=False
                )
            else:
                # Causal language models (the Phi-3 variants above)
                model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    cache_dir=HF_HOME,
                    trust_remote_code=False
                )

            logger.info(f"   ✅ Successfully cached {model_name}")

            # Free memory before the next download
            del model
            del tokenizer
        except Exception as e:
            logger.error(f"   ❌ Failed to download {model_name}: {e}")
            # Don't fail the entire script if one model fails
            continue
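
# At runtime, a pre-cached model can be loaded without network access by
# pointing at the same cache (a minimal sketch; the service's actual loading
# code may differ, and local_files_only assumes the model was fully cached):
#
#   tokenizer = AutoTokenizer.from_pretrained(
#       "facebook/bart-large-cnn", cache_dir=HF_HOME, local_files_only=True)
#   model = AutoModelForSeq2SeqLM.from_pretrained(
#       "facebook/bart-large-cnn", cache_dir=HF_HOME, local_files_only=True)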


def preload_gguf_models():
    """Pre-download GGUF models."""
    from huggingface_hub import hf_hub_download

    # GGUF model for patient summaries - PRIMARY MODEL (is_active: true)
    gguf_models = [
        {
            "repo_id": "microsoft/Phi-3-mini-4k-instruct-gguf",
            "filename": "Phi-3-mini-4k-instruct-q4.gguf",
            "description": "Phi-3 Mini GGUF (Q4 quantized) - PRIMARY for patient summaries",
            "is_active": True,  # This is the active model for patient summaries
        }
    ]

    for model_info in gguf_models:
        try:
            logger.info(f"📥 Downloading GGUF: {model_info['description']}")
            file_path = hf_hub_download(
                repo_id=model_info["repo_id"],
                filename=model_info["filename"],
                cache_dir=HF_HOME,
                local_dir=MODEL_CACHE_DIR,
                local_dir_use_symlinks=False  # Copy files instead of symlinks
            )
            logger.info(f"   ✅ Successfully cached GGUF model at: {file_path}")
        except Exception as e:
            logger.error(f"   ❌ Failed to download GGUF model: {e}")
            continue
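
# The GGUF file lands at MODEL_CACHE_DIR/<filename>, so it can later be loaded
# with llama-cpp-python (a sketch, assuming that package is installed; the
# service's actual loader may differ):
#
#   from llama_cpp import Llama
#   llm = Llama(model_path=os.path.join(
#       MODEL_CACHE_DIR, "Phi-3-mini-4k-instruct-q4.gguf"))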


def preload_whisper_models():
    """Pre-download Whisper models."""
    try:
        logger.info("📥 Downloading Whisper tiny model...")
        import whisper
        model = whisper.load_model(
            "tiny",
            device="cpu",
            download_root=WHISPER_CACHE
        )
        logger.info("   ✅ Successfully cached Whisper tiny model")
        del model
    except Exception as e:
        logger.error(f"   ❌ Failed to download Whisper model: {e}")


def preload_spacy_models():
    """Load the spaCy en_core_web_sm model to verify it is installed (spacy.load does not download it)."""
    try:
        logger.info("📥 Loading spaCy en_core_web_sm model...")
        import spacy
        nlp = spacy.load("en_core_web_sm")
        logger.info("   ✅ Successfully loaded spaCy model")
        del nlp
    except Exception as e:
        logger.error(f"   ❌ Failed to load spaCy model: {e}")


def preload_nltk_data():
    """Pre-download NLTK data."""
    try:
        logger.info("📥 Downloading NLTK data...")
        import nltk
        nltk_data_dir = os.path.join(HF_HOME, 'nltk_data')
        Path(nltk_data_dir).mkdir(parents=True, exist_ok=True)

        # Download common NLTK datasets
        for package in ['punkt', 'stopwords', 'wordnet', 'averaged_perceptron_tagger']:
            try:
                nltk.download(package, download_dir=nltk_data_dir, quiet=True)
                logger.info(f"   ✅ Downloaded NLTK package: {package}")
            except Exception:
                logger.warning(f"   ⚠️ Failed to download NLTK package: {package}")
    except Exception as e:
        logger.error(f"   ❌ Failed to download NLTK data: {e}")


def print_cache_summary():
    """Print a summary of cached models."""
    logger.info("\n" + "="*80)
    logger.info("CACHE SUMMARY")
    logger.info("="*80)

    for cache_dir in [MODEL_CACHE_DIR, HF_HOME, TORCH_HOME, WHISPER_CACHE]:
        if os.path.exists(cache_dir):
            # Calculate the directory's total size and file count
            total_size = 0
            file_count = 0
            for dirpath, dirnames, filenames in os.walk(cache_dir):
                for f in filenames:
                    fp = os.path.join(dirpath, f)
                    if os.path.exists(fp):
                        total_size += os.path.getsize(fp)
                        file_count += 1
            size_mb = total_size / (1024 * 1024)
            size_gb = size_mb / 1024
            logger.info(f"\n📁 {cache_dir}")
            logger.info(f"   Files: {file_count}")
            logger.info(f"   Size: {size_mb:.2f} MB ({size_gb:.2f} GB)")

    logger.info("\n" + "="*80)


def main():
    """Main preload function."""
    logger.info("🚀 Starting model pre-download process...")
    logger.info(f"   HF_HOME: {HF_HOME}")
    logger.info(f"   MODEL_CACHE_DIR: {MODEL_CACHE_DIR}")
    logger.info(f"   TORCH_HOME: {TORCH_HOME}")
    logger.info(f"   WHISPER_CACHE: {WHISPER_CACHE}")
    logger.info("")

    # Import torch early to ensure CUDA detection works
    try:
        import torch
        logger.info(f"🔧 PyTorch version: {torch.__version__}")
        logger.info(f"🔧 CUDA available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            logger.info(f"🔧 CUDA version: {torch.version.cuda}")
            logger.info(f"🔧 GPU: {torch.cuda.get_device_name(0)}")
    except Exception as e:
        logger.warning(f"⚠️ Could not detect PyTorch/CUDA info: {e}")

    logger.info("")

    # Preload all models
    steps = [
        ("Transformers Models", preload_transformers_models),
        ("GGUF Models", preload_gguf_models),
        ("Whisper Models", preload_whisper_models),
        ("spaCy Models", preload_spacy_models),
        ("NLTK Data", preload_nltk_data),
    ]

    for step_name, step_func in steps:
        logger.info(f"\n{'='*80}")
        logger.info(f"STEP: {step_name}")
        logger.info(f"{'='*80}\n")
        try:
            with BenchmarkContext(f"preload_{step_name.replace(' ', '_')}"):
                step_func()
        except Exception as e:
            logger.error(f"❌ Failed during {step_name}: {e}")
            import traceback
            traceback.print_exc()

    # Print summary
    print_cache_summary()
    logger.info("\n✅ Model pre-download completed!")


if __name__ == "__main__":
    main()