# Monkeypatch for a transformers/FlagEmbedding compatibility issue:
# provide a no-op fallback if the installed transformers build lacks
# is_torch_fx_available, which FlagEmbedding tries to import.
try:
    import transformers.utils.import_utils
    if not hasattr(transformers.utils.import_utils, 'is_torch_fx_available'):
        transformers.utils.import_utils.is_torch_fx_available = lambda: False
except Exception:
    pass

from FlagEmbedding import BGEM3FlagModel
from sentence_transformers import CrossEncoder


def download():
    print("--- STARTING MODEL PRE-CACHE ---")

    # 1. BGE-M3 embedding model
    model_name = "BAAI/bge-m3"
    print(f"Downloading/Loading {model_name}...")
    try:
        # Instantiating the model triggers the download if it is not
        # already present in the local Hugging Face cache.
        _ = BGEM3FlagModel(model_name, use_fp16=True)
        print(f"Successfully cached {model_name}")
    except Exception as e:
        print(f"Error caching {model_name}: {e}")

    # 2. Reranker (cross-encoder)
    reranker_name = "cross-encoder/ms-marco-TinyBERT-L-2-v2"
    print(f"Downloading/Loading {reranker_name}...")
    try:
        _ = CrossEncoder(reranker_name)
        print(f"Successfully cached {reranker_name}")
    except Exception as e:
        print(f"Error caching {reranker_name}: {e}")

    print("--- PRE-CACHE COMPLETE ---")


if __name__ == "__main__":
    download()
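
# A minimal sketch of how a script like this is typically invoked during a
# container image build, so the model weights land in a cached layer instead
# of being downloaded at runtime. The filename precache_models.py and the
# cache path below are illustrative assumptions, not taken from this repo:
#
#   # Dockerfile excerpt (hypothetical)
#   ENV HF_HOME=/app/.cache/huggingface
#   COPY precache_models.py .
#   RUN python precache_models.py
#
# Both FlagEmbedding and sentence-transformers fetch weights via
# huggingface_hub, which respects HF_HOME, so setting it pins where the
# cached files are stored inside the image.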