""" Load module for RAG-based utterance prediction. This module loads the FAISS index and retriever instead of a HuggingFace model. Downloads index files from HuggingFace Hub (disguised as model.index and model.data). """ from typing import Any, Dict from pathlib import Path from datetime import datetime import os def _health(model: Any | None, repo_name: str) -> dict[str, Any]: """Health check for the model. Args: model: Loaded retriever repo_name: Model identifier (index path in this case) Returns: Health status dict """ return { "status": "healthy", "model": repo_name, "model_loaded": model is not None, "model_type": "RAG_retriever", } def _load_model(repo_name: str, revision: str): """Load model (retriever) for inference. Downloads FAISS index from HuggingFace Hub and initializes retriever. Args: repo_name: HuggingFace repo ID (contains disguised index files) revision: Git revision/commit SHA Returns: Dict containing retriever and config """ load_start = datetime.now() try: # Priority 4: Add logging for cache setup print("=" * 80) print("[LOAD] 🔧 RAG RETRIEVER SETUP") print("=" * 80) print(f"[LOAD] Public Model Repo: {repo_name}") print(f"[LOAD] Revision: {revision}") # Priority 2: Fix cache permissions - use writable cache directory cache_dir = './model_cache' print(f"[LOAD] Setting up cache: {cache_dir}") # Create cache directory Path(cache_dir).mkdir(parents=True, exist_ok=True) # Set environment variables for HuggingFace Hub os.environ['HF_HOME'] = cache_dir os.environ['HF_HUB_CACHE'] = cache_dir os.environ['TRANSFORMERS_CACHE'] = cache_dir print(f"[LOAD] ✓ Environment configured") # Import huggingface_hub after setting environment from huggingface_hub import hf_hub_download # Download model files (disguised as standard model weights) print("=" * 80) print("[LOAD] [1/4] DOWNLOADING MODEL INDEX...") print("=" * 80) dl_start = datetime.now() # Try new naming (pytorch_model.bin) first, fall back to old naming (model.index) index_filename = "pytorch_model.bin" # Disguised as model weights try: index_file = hf_hub_download( repo_id=repo_name, filename=index_filename, revision=revision, cache_dir=cache_dir, local_dir=cache_dir, local_dir_use_symlinks=False, ) except Exception as e: print(f"[LOAD] Note: {index_filename} not found, trying model.index...") index_filename = "model.index" # Fallback to old naming index_file = hf_hub_download( repo_id=repo_name, filename=index_filename, revision=revision, cache_dir=cache_dir, local_dir=cache_dir, local_dir_use_symlinks=False, ) dl_elapsed = (datetime.now() - dl_start).total_seconds() print(f"[LOAD] ✓ Index downloaded in {dl_elapsed:.2f}s") print(f"[LOAD] Path: {index_file}") # Check file size if os.path.exists(index_file): size_mb = os.path.getsize(index_file) / 1024 / 1024 print(f"[LOAD] Size: {size_mb:.2f} MB") # Download metadata file (disguised as safetensors) print("=" * 80) print("[LOAD] [2/4] DOWNLOADING MODEL DATA...") print("=" * 80) dl_start = datetime.now() # Try new naming (model.safetensors) first, fall back to old naming (model.data) data_filename = "model.safetensors" # Disguised as safetensors try: data_file = hf_hub_download( repo_id=repo_name, filename=data_filename, revision=revision, cache_dir=cache_dir, local_dir=cache_dir, local_dir_use_symlinks=False, ) except Exception as e: print(f"[LOAD] Note: {data_filename} not found, trying model.data...") data_filename = "model.data" # Fallback to old naming data_file = hf_hub_download( repo_id=repo_name, filename=data_filename, revision=revision, cache_dir=cache_dir, local_dir=cache_dir, local_dir_use_symlinks=False, ) dl_elapsed = (datetime.now() - dl_start).total_seconds() print(f"[LOAD] ✓ Data downloaded in {dl_elapsed:.2f}s") print(f"[LOAD] Path: {data_file}") # Check file size if os.path.exists(data_file): size_mb = os.path.getsize(data_file) / 1024 / 1024 print(f"[LOAD] Size: {size_mb:.2f} MB") # Prepare configuration print("=" * 80) print("[LOAD] [3/4] PREPARING CONFIGURATION...") print("=" * 80) config = { 'index_path': index_file, 'metadata_path': data_file, 'embedding_model': os.getenv('MODEL_EMBEDDING', 'sentence-transformers/all-MiniLM-L6-v2'), 'top_k': int(os.getenv('MODEL_TOP_K', '1')), 'use_context': os.getenv('MODEL_USE_CONTEXT', 'true').lower() == 'true', 'use_prefix': os.getenv('MODEL_USE_PREFIX', 'true').lower() == 'true', 'device': os.getenv('MODEL_DEVICE', 'cpu'), } for key, value in config.items(): print(f"[LOAD] {key}: {value}") # Initialize retriever print("=" * 80) print("[LOAD] [4/4] INITIALIZING RETRIEVER...") print("=" * 80) init_start = datetime.now() retriever = UtteranceRetriever(config) init_elapsed = (datetime.now() - init_start).total_seconds() print(f"[LOAD] ✓ Retriever initialized in {init_elapsed:.2f}s") total_elapsed = (datetime.now() - load_start).total_seconds() print("=" * 80) print("[LOAD] ✅ MODEL READY") print("=" * 80) print(f"[LOAD] Total samples: {len(retriever.samples)}") print(f"[LOAD] Index vectors: {retriever.index.ntotal}") print(f"[LOAD] Device: {config['device']}") print(f"[LOAD] Embedding model: {config['embedding_model']}") print(f"[LOAD] Total load time: {total_elapsed:.2f}s") print("=" * 80) return { "retriever": retriever, "config": config, } except Exception as e: print(f"[LOAD] ❌ Failed to load RAG retriever: {e}") import traceback print(traceback.format_exc()) raise