#!/usr/bin/env python3
"""
Pre-download and cache models for Hugging Face Spaces deployment.
Run this during Docker build to avoid runtime downloads.
PRE-CACHED MODELS (downloaded during build):
- facebook/bart-large-cnn (Summarization)
- patrickvonplaten/longformer2roberta-cnn_dailymail-fp16 (Seq2Seq)
- google/flan-t5-large (Summarization)
- microsoft/Phi-3-mini-4k-instruct (Causal OpenVINO)
- OpenVINO/Phi-3-mini-4k-instruct-fp16-ov (Causal OpenVINO)
- microsoft/Phi-3-mini-4k-instruct-gguf/Phi-3-mini-4k-instruct-q4.gguf (GGUF - PRIMARY)
RUNTIME BEHAVIOR:
- If you request a pre-cached model: Loads instantly from cache (30-60 sec)
- If you request a different model: Downloads and uses at runtime automatically
- System supports both pre-cached and on-demand model loading
PRIMARY MODEL for patient summaries:
- microsoft/Phi-3-mini-4k-instruct-gguf/Phi-3-mini-4k-instruct-q4.gguf (is_active: true)
"""
import os
import sys
import logging
from pathlib import Path
# Add src to path for benchmarking
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
sys.path.insert(0, os.path.join(project_root, "services", "ai-service", "src"))
try:
    from ai_med_extract.utils.benchmark import BenchmarkContext
except ImportError:
    # Fall back to a no-op context manager if the benchmark module is missing
    # or the path above is wrong
    logging.warning("Benchmark module not found; creating dummy context.")

    class BenchmarkContext:
        def __init__(self, *args, **kwargs): pass
        def __enter__(self): return self
        def __exit__(self, *args): pass
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Set cache directories - these will be baked into the Docker image
MODEL_CACHE_DIR = os.environ.get('MODEL_CACHE_DIR', '/app/models')
HF_HOME = os.environ.get('HF_HOME', '/app/.cache/huggingface')
TORCH_HOME = os.environ.get('TORCH_HOME', '/app/.cache/torch')
WHISPER_CACHE = os.environ.get('WHISPER_CACHE', '/app/.cache/whisper')
# Create cache directories
for cache_dir in [MODEL_CACHE_DIR, HF_HOME, TORCH_HOME, WHISPER_CACHE]:
    Path(cache_dir).mkdir(parents=True, exist_ok=True)
    logger.info(f"Created cache directory: {cache_dir}")

def preload_transformers_models():
    """Pre-download Hugging Face transformers models."""
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM

    # Models for patient summary generation - as specified by user
    models = [
        # Summarization models
        {
            "name": "facebook/bart-large-cnn",
            "type": "seq2seq",
            "description": "BART Large CNN - Summarization",
            "is_active": False  # Available but not primary
        },
        {
            "name": "patrickvonplaten/longformer2roberta-cnn_dailymail-fp16",
            "type": "seq2seq",
            "description": "Longformer2Roberta - Seq2Seq Summarization",
            "is_active": False
        },
        {
            "name": "google/flan-t5-large",
            "type": "seq2seq",
            "description": "FLAN-T5 Large - Summarization",
            "is_active": False
        },
        # OpenVINO models for patient summaries
        {
            "name": "microsoft/Phi-3-mini-4k-instruct",
            "type": "causal",
            "description": "Phi-3 Mini - Causal OpenVINO (base model)",
            "is_active": False
        },
        {
            "name": "OpenVINO/Phi-3-mini-4k-instruct-fp16-ov",
            "type": "causal",
            "description": "Phi-3 Mini - FP16 OpenVINO optimized",
            "is_active": False
        },
    ]

    for model_info in models:
        model_name = model_info["name"]
        model_type = model_info["type"]
        description = model_info["description"]
        try:
            logger.info(f"📥 Downloading {description}: {model_name}")
            # Download tokenizer
            logger.info("  ↳ Downloading tokenizer...")
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                cache_dir=HF_HOME,
                trust_remote_code=False
            )
            # Download model weights
            logger.info("  ↳ Downloading model weights...")
            if model_type == "seq2seq":
                model = AutoModelForSeq2SeqLM.from_pretrained(
                    model_name,
                    cache_dir=HF_HOME,
                    trust_remote_code=False
                )
            else:
                # Causal LM models (the Phi-3 variants above)
                model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    cache_dir=HF_HOME,
                    trust_remote_code=False
                )
            logger.info(f"  ✅ Successfully cached {model_name}")
            # Free memory before the next download
            del model
            del tokenizer
        except Exception as e:
            logger.error(f"  ❌ Failed to download {model_name}: {e}")
            # Don't fail the entire script if one model fails
            continue
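
# Illustrative only (not executed during the build): a minimal sketch of how a
# runtime consumer could reload one of the summarizers above from the same
# HF_HOME cache. The function name and default model are example values, not
# part of the deployed service.
def _example_load_cached_summarizer(model_name="facebook/bart-large-cnn"):
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
    # Hits the baked cache when the model was pre-downloaded; otherwise
    # transformers falls back to an on-demand download (the documented
    # runtime behavior above).
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=HF_HOME)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, cache_dir=HF_HOME)
    return pipeline("summarization", model=model, tokenizer=tokenizer)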

def preload_gguf_models():
    """Pre-download GGUF models."""
    from huggingface_hub import hf_hub_download

    # GGUF model for patient summaries - PRIMARY MODEL (is_active: true)
    gguf_models = [
        {
            "repo_id": "microsoft/Phi-3-mini-4k-instruct-gguf",
            "filename": "Phi-3-mini-4k-instruct-q4.gguf",
            "description": "Phi-3 Mini GGUF (Q4 quantized) - PRIMARY for patient summaries",
            "is_active": True  # This is the active model for patient summaries
        }
    ]

    for model_info in gguf_models:
        try:
            logger.info(f"📥 Downloading GGUF: {model_info['description']}")
            file_path = hf_hub_download(
                repo_id=model_info["repo_id"],
                filename=model_info["filename"],
                cache_dir=HF_HOME,
                local_dir=MODEL_CACHE_DIR,
                # Copy real files instead of symlinks so the model survives in
                # the image (deprecated and ignored in huggingface_hub >= 0.23,
                # where local_dir downloads always produce real files)
                local_dir_use_symlinks=False
            )
            logger.info(f"  ✅ Successfully cached GGUF model at: {file_path}")
        except Exception as e:
            logger.error(f"  ❌ Failed to download GGUF model: {e}")
            continue
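
# Illustrative only (not executed during the build): a minimal sketch of how the
# PRIMARY GGUF model could be loaded at runtime. llama-cpp-python is assumed to
# be the GGUF runtime here; the actual service may wire this up differently.
def _example_load_primary_gguf():
    from llama_cpp import Llama
    # hf_hub_download(local_dir=...) above leaves the file directly under MODEL_CACHE_DIR
    model_path = os.path.join(MODEL_CACHE_DIR, "Phi-3-mini-4k-instruct-q4.gguf")
    # n_ctx=4096 matches the model's 4k context window
    return Llama(model_path=model_path, n_ctx=4096)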

def preload_whisper_models():
    """Pre-download Whisper models."""
    try:
        logger.info("📥 Downloading Whisper tiny model...")
        import whisper
        model = whisper.load_model(
            "tiny",
            device="cpu",
            download_root=WHISPER_CACHE
        )
        logger.info("  ✅ Successfully cached Whisper tiny model")
        del model
    except Exception as e:
        logger.error(f"  ❌ Failed to download Whisper model: {e}")

def preload_spacy_models():
    """Warm the spaCy model (installed as a pip package, not downloaded here)."""
    try:
        logger.info("📥 Loading spaCy en_core_web_sm model...")
        import spacy
        # en_core_web_sm must already be installed (e.g. via pip) for this to succeed
        nlp = spacy.load("en_core_web_sm")
        logger.info("  ✅ Successfully loaded spaCy model")
        del nlp
    except Exception as e:
        logger.error(f"  ❌ Failed to load spaCy model: {e}")

def preload_nltk_data():
    """Pre-download NLTK data."""
    try:
        logger.info("📥 Downloading NLTK data...")
        import nltk
        nltk_data_dir = os.path.join(HF_HOME, 'nltk_data')
        Path(nltk_data_dir).mkdir(parents=True, exist_ok=True)
        # Download common NLTK datasets. Since this is a non-default location,
        # the runtime must point NLTK at it (see the sketch below this function).
        for package in ['punkt', 'stopwords', 'wordnet', 'averaged_perceptron_tagger']:
            try:
                nltk.download(package, download_dir=nltk_data_dir, quiet=True)
                logger.info(f"  ✅ Downloaded NLTK package: {package}")
            except Exception:
                logger.warning(f"  ⚠️ Failed to download NLTK package: {package}")
    except Exception as e:
        logger.error(f"  ❌ Failed to download NLTK data: {e}")

def print_cache_summary():
    """Print summary of cached models."""
    logger.info("\n" + "=" * 80)
    logger.info("CACHE SUMMARY")
    logger.info("=" * 80)
    for cache_dir in [MODEL_CACHE_DIR, HF_HOME, TORCH_HOME, WHISPER_CACHE]:
        if os.path.exists(cache_dir):
            # Walk the tree and total up file count and size
            total_size = 0
            file_count = 0
            for dirpath, dirnames, filenames in os.walk(cache_dir):
                for f in filenames:
                    fp = os.path.join(dirpath, f)
                    if os.path.exists(fp):
                        total_size += os.path.getsize(fp)
                        file_count += 1
            size_mb = total_size / (1024 * 1024)
            size_gb = size_mb / 1024
            logger.info(f"\n📁 {cache_dir}")
            logger.info(f"  Files: {file_count}")
            logger.info(f"  Size: {size_mb:.2f} MB ({size_gb:.2f} GB)")
    logger.info("\n" + "=" * 80)

def main():
    """Main preload function."""
    logger.info("🚀 Starting model pre-download process...")
    logger.info(f"  HF_HOME: {HF_HOME}")
    logger.info(f"  MODEL_CACHE_DIR: {MODEL_CACHE_DIR}")
    logger.info(f"  TORCH_HOME: {TORCH_HOME}")
    logger.info(f"  WHISPER_CACHE: {WHISPER_CACHE}")
    logger.info("")

    # Import torch early to ensure CUDA detection works
    try:
        import torch
        logger.info(f"🔧 PyTorch version: {torch.__version__}")
        logger.info(f"🔧 CUDA available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            logger.info(f"🔧 CUDA version: {torch.version.cuda}")
            logger.info(f"🔧 GPU: {torch.cuda.get_device_name(0)}")
    except Exception as e:
        logger.warning(f"⚠️ Could not detect PyTorch/CUDA info: {e}")
    logger.info("")

    # Preload all models
    steps = [
        ("Transformers Models", preload_transformers_models),
        ("GGUF Models", preload_gguf_models),
        ("Whisper Models", preload_whisper_models),
        ("spaCy Models", preload_spacy_models),
        ("NLTK Data", preload_nltk_data),
    ]
    for step_name, step_func in steps:
        logger.info(f"\n{'=' * 80}")
        logger.info(f"STEP: {step_name}")
        logger.info(f"{'=' * 80}\n")
        try:
            with BenchmarkContext(f"preload_{step_name.replace(' ', '_')}"):
                step_func()
        except Exception as e:
            logger.error(f"❌ Failed during {step_name}: {e}")
            import traceback
            traceback.print_exc()

    # Print summary
    print_cache_summary()
    logger.info("\n✅ Model pre-download completed!")


if __name__ == "__main__":
    main()