# Model pre-cache script (run at build time, e.g. in a Hugging Face Space).
import os
import sys

# Monkeypatch for a transformers/FlagEmbedding compatibility issue:
# FlagEmbedding appears to call transformers.utils.import_utils.is_torch_fx_available(),
# which some transformers releases do not expose. Provide a stub that reports
# "unavailable" so the FlagEmbedding import below succeeds.
try:
    import transformers.utils.import_utils

    if not hasattr(transformers.utils.import_utils, 'is_torch_fx_available'):
        transformers.utils.import_utils.is_torch_fx_available = lambda: False
except Exception:
    # Best-effort: if transformers is missing or its layout differs, carry on
    # and let the imports below surface any real problem.
    pass

from FlagEmbedding import BGEM3FlagModel
from sentence_transformers import CrossEncoder
def download():
    """Pre-download and cache the embedding and reranker models.

    Each model is loaded once so its weights land in the local cache.
    Failures are printed but do not abort, so a partial cache is still
    produced if one model fails to download.
    """
    print("--- STARTING MODEL PRE-CACHE ---")

    # 1. BGE-M3 embedding model
    model_name = "BAAI/bge-m3"
    print(f"Downloading/Loading {model_name}...")
    try:
        # Instantiation triggers the download if the weights are not present.
        _ = BGEM3FlagModel(model_name, use_fp16=True)
        print(f"Successfully cached {model_name}")
    except Exception as e:
        print(f"Error caching {model_name}: {e}")

    # 2. Reranker (cross-encoder)
    reranker_name = "cross-encoder/ms-marco-TinyBERT-L-2-v2"
    print(f"Downloading/Loading {reranker_name}...")
    try:
        _ = CrossEncoder(reranker_name)
        print(f"Successfully cached {reranker_name}")
    except Exception as e:
        print(f"Error caching {reranker_name}: {e}")

    print("--- PRE-CACHE COMPLETE ---")
| if __name__ == "__main__": | |
| download() | |