#!/usr/bin/env python3
"""
Pre-download and cache models for Hugging Face Spaces deployment.
Run this during Docker build to avoid runtime downloads.

PRE-CACHED MODELS (downloaded during build):
- facebook/bart-large-cnn (Summarization)
- patrickvonplaten/longformer2roberta-cnn_dailymail-fp16 (Seq2Seq)
- google/flan-t5-large (Summarization)
- microsoft/Phi-3-mini-4k-instruct (Causal OpenVINO)
- OpenVINO/Phi-3-mini-4k-instruct-fp16-ov (Causal OpenVINO)
- microsoft/Phi-3-mini-4k-instruct-gguf/Phi-3-mini-4k-instruct-q4.gguf (GGUF - PRIMARY)

RUNTIME BEHAVIOR:
- If you request a pre-cached model: loads from the local cache with no network
  download (typically 30-60 sec to load into memory; resolve_cached_gguf()
  below sketches this lookup)
- If you request a different model: downloads and uses it at runtime automatically
- System supports both pre-cached and on-demand model loading

PRIMARY MODEL for patient summaries:
- microsoft/Phi-3-mini-4k-instruct-gguf/Phi-3-mini-4k-instruct-q4.gguf (is_active: true)
"""
import os
import sys
import logging
from pathlib import Path

# Add src to path for benchmarking
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
sys.path.insert(0, os.path.join(project_root, "services", "ai-service", "src"))

try:
    from ai_med_extract.utils.benchmark import BenchmarkContext
except ImportError:
    # Fall back to a no-op context manager if the benchmark module is unavailable.
    logging.warning("Benchmark module not found; using a no-op BenchmarkContext.")
    class BenchmarkContext:
        def __init__(self, *args, **kwargs): pass
        def __enter__(self): return self
        def __exit__(self, *args): pass

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Set cache directories - these will be baked into the Docker image
MODEL_CACHE_DIR = os.environ.get('MODEL_CACHE_DIR', '/app/models')
HF_HOME = os.environ.get('HF_HOME', '/app/.cache/huggingface')
TORCH_HOME = os.environ.get('TORCH_HOME', '/app/.cache/torch')
WHISPER_CACHE = os.environ.get('WHISPER_CACHE', '/app/.cache/whisper')

# Create cache directories
for cache_dir in [MODEL_CACHE_DIR, HF_HOME, TORCH_HOME, WHISPER_CACHE]:
    Path(cache_dir).mkdir(parents=True, exist_ok=True)
    logger.info(f"Created cache directory: {cache_dir}")

def preload_transformers_models():
    """Pre-download Hugging Face transformers models"""
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
    from huggingface_hub import snapshot_download
    
    # Models for patient summary generation - as specified by user
    models = [
        # Summarization models
        {
            "name": "facebook/bart-large-cnn",
            "type": "seq2seq",
            "description": "BART Large CNN - Summarization",
            "is_active": False  # Available but not primary
        },
        {
            "name": "patrickvonplaten/longformer2roberta-cnn_dailymail-fp16",
            "type": "seq2seq",
            "description": "Longformer2Roberta - Seq2Seq Summarization",
            "is_active": False
        },
        {
            "name": "google/flan-t5-large",
            "type": "seq2seq",
            "description": "FLAN-T5 Large - Summarization",
            "is_active": False
        },
        # OpenVINO models for patient summaries
        {
            "name": "microsoft/Phi-3-mini-4k-instruct",
            "type": "causal",
            "description": "Phi-3 Mini - Causal OpenVINO (base model)",
            "is_active": False
        },
        {
            "name": "OpenVINO/Phi-3-mini-4k-instruct-fp16-ov",
            "type": "causal",
            "description": "Phi-3 Mini - FP16 OpenVINO optimized",
            "is_active": False
        },
    ]
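
    # NOTE: "is_active" is informational metadata only; this script downloads
    # every model listed so any of them can be selected at runtime.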
    
    for model_info in models:
        model_name = model_info["name"]
        model_type = model_info["type"]
        description = model_info["description"]
        
        try:
            logger.info(f"πŸ“₯ Downloading {description}: {model_name}")
            
            # Download tokenizer
            logger.info(f"  ↳ Downloading tokenizer...")
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                cache_dir=HF_HOME,
                trust_remote_code=False
            )
            
            # Download model
            logger.info(f"  ↳ Downloading model weights...")
            if model_type == "seq2seq":
                model = AutoModelForSeq2SeqLM.from_pretrained(
                    model_name,
                    cache_dir=HF_HOME,
                    trust_remote_code=False
                )
            else:
                # Causal LMs (e.g. Phi-3) must be fetched via AutoModelForCausalLM
                # (already imported above), not the bare AutoModel, so the LM head
                # weights are cached as well.
                model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    cache_dir=HF_HOME,
                    trust_remote_code=False
                )
            
            logger.info(f"  βœ… Successfully cached {model_name}")
            
            # Clean up memory
            del model
            del tokenizer
            
        except Exception as e:
            logger.error(f"  ❌ Failed to download {model_name}: {e}")
            # Don't fail the entire script if one model fails
            continue

def preload_gguf_models():
    """Pre-download GGUF models"""
    from huggingface_hub import hf_hub_download
    
    # GGUF model for patient summaries - PRIMARY MODEL (is_active: true)
    gguf_models = [
        {
            "repo_id": "microsoft/Phi-3-mini-4k-instruct-gguf",
            "filename": "Phi-3-mini-4k-instruct-q4.gguf",
            "description": "Phi-3 Mini GGUF (Q4 quantized) - PRIMARY for patient summaries",
            "is_active": True  # This is the active model for patient summaries
        }
    ]
    
    for model_info in gguf_models:
        try:
            logger.info(f"πŸ“₯ Downloading GGUF: {model_info['description']}")
            
            file_path = hf_hub_download(
                repo_id=model_info["repo_id"],
                filename=model_info["filename"],
                cache_dir=HF_HOME,
                local_dir=MODEL_CACHE_DIR,
                local_dir_use_symlinks=False  # copy real files (deprecated no-op on newer huggingface_hub, which always copies)
            )
            
            logger.info(f"  βœ… Successfully cached GGUF model at: {file_path}")
            
        except Exception as e:
            logger.error(f"  ❌ Failed to download GGUF model: {e}")
            continue
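
def resolve_cached_gguf(filename="Phi-3-mini-4k-instruct-q4.gguf"):
    """Illustrative sketch of the runtime lookup described in the module
    docstring (an assumption about the loader, not the service's actual code).

    Because hf_hub_download() above uses local_dir=MODEL_CACHE_DIR, the GGUF
    file is materialized directly under MODEL_CACHE_DIR, so a simple path
    check distinguishes a cache hit from a miss.
    """
    cached = Path(MODEL_CACHE_DIR) / filename
    if cached.exists():
        # Pre-cached during the Docker build: loads without network access.
        return str(cached)
    # Cache miss: fall back to an on-demand download, mirroring the documented
    # runtime behavior for models that were not pre-cached.
    from huggingface_hub import hf_hub_download
    return hf_hub_download(
        repo_id="microsoft/Phi-3-mini-4k-instruct-gguf",
        filename=filename,
        cache_dir=HF_HOME,
        local_dir=MODEL_CACHE_DIR,
    )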

def preload_whisper_models():
    """Pre-download Whisper models"""
    try:
        logger.info(f"πŸ“₯ Downloading Whisper tiny model...")
        
        import whisper
        model = whisper.load_model(
            "tiny",
            device="cpu",
            download_root=WHISPER_CACHE
        )
        
        logger.info(f"  βœ… Successfully cached Whisper tiny model")
        del model
        
    except Exception as e:
        logger.error(f"  ❌ Failed to download Whisper model: {e}")

def preload_spacy_models():
    """Verify the spaCy en_core_web_sm model is importable. spacy.load() does
    not download models, so the package is assumed to be installed at build
    time (e.g. via `python -m spacy download en_core_web_sm`)."""
    try:
        logger.info("πŸ“₯ Loading spaCy en_core_web_sm model...")

        import spacy
        nlp = spacy.load("en_core_web_sm")

        logger.info("  βœ… Successfully loaded spaCy model")
        del nlp

    except Exception as e:
        logger.error(f"  ❌ Failed to load spaCy model: {e}")

def preload_nltk_data():
    """Pre-download NLTK data"""
    try:
        logger.info(f"πŸ“₯ Downloading NLTK data...")
        
        import nltk
        nltk_data_dir = os.path.join(HF_HOME, 'nltk_data')
        Path(nltk_data_dir).mkdir(parents=True, exist_ok=True)
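        # NOTE (runtime assumption): the serving process must point NLTK at
        # this directory (e.g. set NLTK_DATA, or nltk.data.path.append(...)),
        # or these packages will not be found at inference time.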
        
        # Download common NLTK datasets
        for package in ['punkt', 'stopwords', 'wordnet', 'averaged_perceptron_tagger']:
            try:
                nltk.download(package, download_dir=nltk_data_dir, quiet=True)
                logger.info(f"  βœ… Downloaded NLTK package: {package}")
            except Exception as e:
                logger.warning(f"  ⚠️  Failed to download NLTK package {package}: {e}")
        
    except Exception as e:
        logger.error(f"  ❌ Failed to download NLTK data: {e}")

def print_cache_summary():
    """Print summary of cached models"""
    logger.info("\n" + "="*80)
    logger.info("CACHE SUMMARY")
    logger.info("="*80)
    
    for cache_dir in [MODEL_CACHE_DIR, HF_HOME, TORCH_HOME, WHISPER_CACHE]:
        if os.path.exists(cache_dir):
            # Calculate directory size
            total_size = 0
            file_count = 0
            for dirpath, dirnames, filenames in os.walk(cache_dir):
                for f in filenames:
                    fp = os.path.join(dirpath, f)
                    if os.path.exists(fp):
                        total_size += os.path.getsize(fp)
                        file_count += 1
            
            size_mb = total_size / (1024 * 1024)
            size_gb = size_mb / 1024
            
            logger.info(f"\nπŸ“ {cache_dir}")
            logger.info(f"   Files: {file_count}")
            logger.info(f"   Size: {size_mb:.2f} MB ({size_gb:.2f} GB)")
    
    logger.info("\n" + "="*80)

def main():
    """Main preload function"""
    logger.info("πŸš€ Starting model pre-download process...")
    logger.info(f"   HF_HOME: {HF_HOME}")
    logger.info(f"   MODEL_CACHE_DIR: {MODEL_CACHE_DIR}")
    logger.info(f"   TORCH_HOME: {TORCH_HOME}")
    logger.info(f"   WHISPER_CACHE: {WHISPER_CACHE}")
    logger.info("")
    
    # Log PyTorch / CUDA environment info up front (useful for build debugging)
    try:
        import torch
        logger.info(f"πŸ”§ PyTorch version: {torch.__version__}")
        logger.info(f"πŸ”§ CUDA available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            logger.info(f"πŸ”§ CUDA version: {torch.version.cuda}")
            logger.info(f"πŸ”§ GPU: {torch.cuda.get_device_name(0)}")
    except Exception as e:
        logger.warning(f"⚠️  Could not detect PyTorch/CUDA info: {e}")
    
    logger.info("")
    
    # Preload all models
    steps = [
        ("Transformers Models", preload_transformers_models),
        ("GGUF Models", preload_gguf_models),
        ("Whisper Models", preload_whisper_models),
        ("spaCy Models", preload_spacy_models),
        ("NLTK Data", preload_nltk_data),
    ]
    
    for step_name, step_func in steps:
        logger.info(f"\n{'='*80}")
        logger.info(f"STEP: {step_name}")
        logger.info(f"{'='*80}\n")
        
        try:
            with BenchmarkContext(f"preload_{step_name.replace(' ', '_')}"):
                step_func()
        except Exception as e:
            logger.error(f"❌ Failed during {step_name}: {e}")
            import traceback
            traceback.print_exc()
    
    # Print summary
    print_cache_summary()
    
    logger.info("\nβœ… Model pre-download completed!")

if __name__ == "__main__":
    main()
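
# Build-time usage sketch (file path and Dockerfile layout are assumptions;
# adjust to where this script lives in your image):
#
#   ENV HF_HOME=/app/.cache/huggingface \
#       MODEL_CACHE_DIR=/app/models
#   RUN python preload_models.py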