# rag-api-node-1 / download_models.py
# Peterase's picture
# feat(rag): implement hybrid search with live sources and production-grade intent classification
# a63c61f
import os
import sys
# Monkeypatch for transformers/FlagEmbedding compatibility issue.
# When `transformers.utils.import_utils` has no `is_torch_fx_available`
# attribute, install a stub that always returns False. This has to run before
# the FlagEmbedding import that follows (presumably FlagEmbedding still looks
# the helper up -- confirm against the installed versions).
try:
    import transformers.utils.import_utils

    if not hasattr(transformers.utils.import_utils, 'is_torch_fx_available'):
        transformers.utils.import_utils.is_torch_fx_available = lambda: False
except Exception as exc:
    # Best-effort patch: never abort the script here, but report why the
    # patch was skipped instead of hiding the failure.
    print(
        f"Warning: transformers compatibility patch skipped: {exc!r}",
        file=sys.stderr,
    )
from FlagEmbedding import BGEM3FlagModel
from sentence_transformers import CrossEncoder
def download():
    """Pre-cache the BGE-M3 model and the reranker.

    Each model class is instantiated once; per the loaders below, that
    instantiation is what triggers the download if the model is not already
    present. The models are attempted independently, so an error on one does
    not stop the other from being tried.

    Returns:
        bool: True if every model was instantiated without raising, False if
        at least one of them failed (the error is printed, not re-raised).
    """
    # Lambdas defer construction (and the class lookup) until we are inside
    # the try block, so any failure is reported per model.
    loaders = (
        # 1. BGE-M3
        ("BAAI/bge-m3", lambda name: BGEM3FlagModel(name, use_fp16=True)),
        # 2. Reranker
        ("cross-encoder/ms-marco-TinyBERT-L-2-v2", lambda name: CrossEncoder(name)),
    )
    failed = []

    print("--- STARTING MODEL PRE-CACHE ---")
    for model_name, load in loaders:
        print(f"Downloading/Loading {model_name}...")
        try:
            # This will trigger the download if not present
            load(model_name)
            print(f"Successfully cached {model_name}")
        except Exception as e:
            print(f"Error caching {model_name}: {e}")
            failed.append(model_name)

    if failed:
        # Make partial failure obvious right before the closing banner.
        print(f"WARNING: could not cache: {', '.join(failed)}")
    print("--- PRE-CACHE COMPLETE ---")
    return not failed
# Script entry point: run the pre-cache only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    download()