|
|
""" |
|
|
Load module for RAG-based utterance prediction. |
|
|
|
|
|
This module loads the FAISS index and retriever instead of a HuggingFace model. |
|
|
Downloads index files from HuggingFace Hub (disguised as model.index and model.data). |
|
|
""" |
|
|
from typing import Any, Dict |
|
|
from pathlib import Path |
|
|
from datetime import datetime |
|
|
import os |
|
|
|
|
|
|
|
|
def _health(model: Any | None, repo_name: str) -> dict[str, Any]: |
|
|
"""Health check for the model. |
|
|
|
|
|
Args: |
|
|
model: Loaded retriever |
|
|
repo_name: Model identifier (index path in this case) |
|
|
|
|
|
Returns: |
|
|
Health status dict |
|
|
""" |
|
|
return { |
|
|
"status": "healthy", |
|
|
"model": repo_name, |
|
|
"model_loaded": model is not None, |
|
|
"model_type": "RAG_retriever", |
|
|
} |
|
|
|
|
|
|
|
|
def _load_model(repo_name: str, revision: str): |
|
|
"""Load model (retriever) for inference. |
|
|
|
|
|
Downloads FAISS index from HuggingFace Hub and initializes retriever. |
|
|
|
|
|
Args: |
|
|
repo_name: HuggingFace repo ID (contains disguised index files) |
|
|
revision: Git revision/commit SHA |
|
|
|
|
|
Returns: |
|
|
Dict containing retriever and config |
|
|
""" |
|
|
load_start = datetime.now() |
|
|
|
|
|
try: |
|
|
|
|
|
print("=" * 80) |
|
|
print("[LOAD] 🔧 RAG RETRIEVER SETUP") |
|
|
print("=" * 80) |
|
|
print(f"[LOAD] Public Model Repo: {repo_name}") |
|
|
print(f"[LOAD] Revision: {revision}") |
|
|
|
|
|
|
|
|
cache_dir = './model_cache' |
|
|
print(f"[LOAD] Setting up cache: {cache_dir}") |
|
|
|
|
|
|
|
|
Path(cache_dir).mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
|
|
|
os.environ['HF_HOME'] = cache_dir |
|
|
os.environ['HF_HUB_CACHE'] = cache_dir |
|
|
os.environ['TRANSFORMERS_CACHE'] = cache_dir |
|
|
print(f"[LOAD] ✓ Environment configured") |
|
|
|
|
|
|
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
|
|
|
print("=" * 80) |
|
|
print("[LOAD] [1/4] DOWNLOADING MODEL INDEX...") |
|
|
print("=" * 80) |
|
|
dl_start = datetime.now() |
|
|
|
|
|
|
|
|
index_filename = "pytorch_model.bin" |
|
|
try: |
|
|
index_file = hf_hub_download( |
|
|
repo_id=repo_name, |
|
|
filename=index_filename, |
|
|
revision=revision, |
|
|
cache_dir=cache_dir, |
|
|
local_dir=cache_dir, |
|
|
local_dir_use_symlinks=False, |
|
|
) |
|
|
except Exception as e: |
|
|
print(f"[LOAD] Note: {index_filename} not found, trying model.index...") |
|
|
index_filename = "model.index" |
|
|
index_file = hf_hub_download( |
|
|
repo_id=repo_name, |
|
|
filename=index_filename, |
|
|
revision=revision, |
|
|
cache_dir=cache_dir, |
|
|
local_dir=cache_dir, |
|
|
local_dir_use_symlinks=False, |
|
|
) |
|
|
|
|
|
dl_elapsed = (datetime.now() - dl_start).total_seconds() |
|
|
print(f"[LOAD] ✓ Index downloaded in {dl_elapsed:.2f}s") |
|
|
print(f"[LOAD] Path: {index_file}") |
|
|
|
|
|
|
|
|
if os.path.exists(index_file): |
|
|
size_mb = os.path.getsize(index_file) / 1024 / 1024 |
|
|
print(f"[LOAD] Size: {size_mb:.2f} MB") |
|
|
|
|
|
|
|
|
print("=" * 80) |
|
|
print("[LOAD] [2/4] DOWNLOADING MODEL DATA...") |
|
|
print("=" * 80) |
|
|
dl_start = datetime.now() |
|
|
|
|
|
|
|
|
data_filename = "model.safetensors" |
|
|
try: |
|
|
data_file = hf_hub_download( |
|
|
repo_id=repo_name, |
|
|
filename=data_filename, |
|
|
revision=revision, |
|
|
cache_dir=cache_dir, |
|
|
local_dir=cache_dir, |
|
|
local_dir_use_symlinks=False, |
|
|
) |
|
|
except Exception as e: |
|
|
print(f"[LOAD] Note: {data_filename} not found, trying model.data...") |
|
|
data_filename = "model.data" |
|
|
data_file = hf_hub_download( |
|
|
repo_id=repo_name, |
|
|
filename=data_filename, |
|
|
revision=revision, |
|
|
cache_dir=cache_dir, |
|
|
local_dir=cache_dir, |
|
|
local_dir_use_symlinks=False, |
|
|
) |
|
|
|
|
|
dl_elapsed = (datetime.now() - dl_start).total_seconds() |
|
|
print(f"[LOAD] ✓ Data downloaded in {dl_elapsed:.2f}s") |
|
|
print(f"[LOAD] Path: {data_file}") |
|
|
|
|
|
|
|
|
if os.path.exists(data_file): |
|
|
size_mb = os.path.getsize(data_file) / 1024 / 1024 |
|
|
print(f"[LOAD] Size: {size_mb:.2f} MB") |
|
|
|
|
|
|
|
|
print("=" * 80) |
|
|
print("[LOAD] [3/4] PREPARING CONFIGURATION...") |
|
|
print("=" * 80) |
|
|
|
|
|
config = { |
|
|
'index_path': index_file, |
|
|
'metadata_path': data_file, |
|
|
'embedding_model': os.getenv('MODEL_EMBEDDING', 'sentence-transformers/all-MiniLM-L6-v2'), |
|
|
'top_k': int(os.getenv('MODEL_TOP_K', '1')), |
|
|
'use_context': os.getenv('MODEL_USE_CONTEXT', 'true').lower() == 'true', |
|
|
'use_prefix': os.getenv('MODEL_USE_PREFIX', 'true').lower() == 'true', |
|
|
'device': os.getenv('MODEL_DEVICE', 'cpu'), |
|
|
} |
|
|
|
|
|
for key, value in config.items(): |
|
|
print(f"[LOAD] {key}: {value}") |
|
|
|
|
|
|
|
|
print("=" * 80) |
|
|
print("[LOAD] [4/4] INITIALIZING RETRIEVER...") |
|
|
print("=" * 80) |
|
|
|
|
|
init_start = datetime.now() |
|
|
retriever = UtteranceRetriever(config) |
|
|
init_elapsed = (datetime.now() - init_start).total_seconds() |
|
|
|
|
|
print(f"[LOAD] ✓ Retriever initialized in {init_elapsed:.2f}s") |
|
|
|
|
|
total_elapsed = (datetime.now() - load_start).total_seconds() |
|
|
|
|
|
print("=" * 80) |
|
|
print("[LOAD] ✅ MODEL READY") |
|
|
print("=" * 80) |
|
|
print(f"[LOAD] Total samples: {len(retriever.samples)}") |
|
|
print(f"[LOAD] Index vectors: {retriever.index.ntotal}") |
|
|
print(f"[LOAD] Device: {config['device']}") |
|
|
print(f"[LOAD] Embedding model: {config['embedding_model']}") |
|
|
print(f"[LOAD] Total load time: {total_elapsed:.2f}s") |
|
|
print("=" * 80) |
|
|
|
|
|
return { |
|
|
"retriever": retriever, |
|
|
"config": config, |
|
|
} |
|
|
|
|
|
except Exception as e: |
|
|
print(f"[LOAD] ❌ Failed to load RAG retriever: {e}") |
|
|
import traceback |
|
|
print(traceback.format_exc()) |
|
|
raise |
|
|
|