Spaces:
Running
Running
"""Pre-download all models and build index during Docker build."""
import os

# Set before the Hugging Face libraries are imported further down, so the
# setting is already in the environment when they initialise.
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"

# Model ids fetched into the local cache by this script.
# NOTE(review): these presumably must match the models the application loads
# at runtime (e.g. whatever BioRAGConfig / KnowledgeBaseBuilder use) — confirm
# and keep them in sync, otherwise the runtime will download again.
EMBEDDING_MODEL_ID = "NeuML/pubmedbert-base-embeddings"
NLI_MODEL_ID = "pritamdeka/PubMedBERT-MNLI-MedNLI"


def _log(message: str) -> None:
    """Print one progress line and flush it immediately.

    stdout is block-buffered when it is not a terminal (as in a Docker build),
    so without flushing, these stage markers could show up late or out of order
    relative to the download progress the libraries write.
    """
    print(message, flush=True)


_log("=== PRE-DOWNLOADING MODELS ===")

# 1. Download embedding model. Instantiating it populates the local model
#    cache; the instance itself is not kept.
_log("[1/3] Downloading embedding model...")
from sentence_transformers import SentenceTransformer  # deferred: heavy import, logged first

SentenceTransformer(EMBEDDING_MODEL_ID)
_log(" Done.")

# 2. Download NLI model (tokenizer and sequence-classification weights).
_log("[2/3] Downloading NLI model...")
from transformers import AutoTokenizer, AutoModelForSequenceClassification  # deferred: heavy import

AutoTokenizer.from_pretrained(NLI_MODEL_ID)
AutoModelForSequenceClassification.from_pretrained(NLI_MODEL_ID)
_log(" Done.")

# 3. Build FAISS index from the configured dataset so it ships in the image.
_log("[3/3] Building FAISS index...")
from src.bio_rag.config import BioRAGConfig
from src.bio_rag.data_loader import load_diabetes_pubmedqa
from src.bio_rag.knowledge_base import KnowledgeBaseBuilder

config = BioRAGConfig()
samples = load_diabetes_pubmedqa(config.dataset_name, max_samples=config.max_samples)
_log(f" Filtered {len(samples)} diabetes samples.")

# Only the side effect matters here (the index written by load_or_build);
# the returned vector store is not used by this script.
KnowledgeBaseBuilder(config).load_or_build(samples)
_log(" FAISS index built and saved.")
_log("=== ALL MODELS READY ===")