Spaces:
Sleeping
Sleeping
| """ | |
| Generate embeddings for RAG chunks using MedEmbed-large-v0.1. | |
| Reads rag_chunks.json, encodes every chunk's text with the MedEmbed model, | |
| and saves the resulting vectors alongside their chunk IDs to a compressed | |
| numpy archive (.npz) for downstream loading into Qdrant. | |
| Usage: | |
| python embed_chunks.py # full run | |
| python embed_chunks.py --limit 100 # embed only first 100 chunks (for testing) | |
| python embed_chunks.py --batch-size 64 # override batch size | |
| """ | |
| import argparse | |
| import json | |
| import time | |
| import numpy as np | |
| import torch | |
| from sentence_transformers import SentenceTransformer | |
| from tqdm import tqdm | |
| from config import ( | |
| RAG_CHUNKS_PATH, | |
| EMBEDDINGS_DIR, | |
| EMBEDDINGS_FILE, | |
| EMBEDDING_MODEL_NAME, | |
| BATCH_SIZE, | |
| MAX_SEQ_LENGTH, | |
| ) | |
def select_device() -> str:
    """Pick the torch device string to run the encoder on.

    Preference order is CUDA, then Apple MPS, then CPU. The MPS backend is
    only queried when CUDA is unavailable.

    Returns:
        One of ``"cuda"``, ``"mps"`` or ``"cpu"``.
    """
    if torch.cuda.is_available():
        chosen = "cuda"
    elif torch.backends.mps.is_available():
        chosen = "mps"
    else:
        chosen = "cpu"
    return chosen
def load_chunks(path, limit=None):
    """Load the chunk collection from a UTF-8 JSON file.

    Args:
        path: Location of the JSON file; handed directly to ``open``.
        limit: Maximum number of leading items to keep. ``None`` (the
            default) keeps everything and an explicit ``0`` yields an empty
            result. Any other integer is applied as ``chunks[:limit]``, so a
            negative value drops items from the end.

    Returns:
        The decoded JSON value, truncated when ``limit`` is given. Slicing
        presumes the top-level JSON value is a list.
    """
    with open(path, encoding="utf-8") as f:
        chunks = json.load(f)
    # Test against None instead of truthiness: `if limit:` would treat an
    # explicit limit of 0 as "no limit" and return every chunk.
    if limit is not None:
        chunks = chunks[:limit]
    return chunks
def build_embedding_text(chunk: dict) -> str:
    """Build the string that is embedded for a single chunk.

    Key metadata is prepended so the embedding captures policy context and
    not just the raw paragraph.

    Args:
        chunk: Mapping with a required ``"text"`` entry and optional
            ``"policy_name"`` and ``"section"`` entries. Missing, empty or
            null metadata values are skipped.

    Returns:
        ``"Policy: <name>"`` and ``"Section: <section>"`` (each only when
        present) followed by the chunk text, joined with ``" | "``. Hyphens
        in the policy name become spaces and the name is title-cased.

    Raises:
        KeyError: If ``chunk`` has no ``"text"`` key.
    """
    parts = []

    # `or ""` also covers a key that exists with a null value; a plain
    # `.get(key, "")` would return None there and `.replace` would fail.
    policy = (chunk.get("policy_name") or "").replace("-", " ").title()
    if policy:
        parts.append(f"Policy: {policy}")

    section = chunk.get("section")
    if section:
        parts.append(f"Section: {section}")

    parts.append(chunk["text"])
    return " | ".join(parts)
def embed_in_batches(model, texts, batch_size, device):
    """Encode ``texts`` batch by batch and stack the results into one array.

    Args:
        model: Encoder exposing ``encode(...)`` and
            ``get_sentence_embedding_dimension()``; in this script a
            ``SentenceTransformer``.
        texts: Sequence of strings to embed; must support ``len`` and slicing.
        batch_size: Number of texts sent to each ``encode`` call. The same
            value is forwarded as ``encode``'s own ``batch_size``.
        device: Device string forwarded to ``encode``.

    Returns:
        A 2-D numpy array with one row per input text, in input order.
        ``encode`` is asked for normalised embeddings. Empty input returns an
        array of shape ``(0, dim)`` instead of raising.
    """
    if len(texts) == 0:
        # np.vstack raises ValueError on an empty list, so build the empty
        # result directly. float32 is assumed to match encode's default
        # output precision -- confirm if a different precision is configured.
        dim = model.get_sentence_embedding_dimension() or 0
        return np.empty((0, dim), dtype=np.float32)

    all_embeddings = [
        model.encode(
            texts[start : start + batch_size],
            batch_size=batch_size,
            show_progress_bar=False,  # tqdm below already reports progress
            convert_to_numpy=True,
            normalize_embeddings=True,
            device=device,
        )
        for start in tqdm(range(0, len(texts), batch_size), desc="Embedding batches")
    ]
    return np.vstack(all_embeddings)
def main():
    """Command-line entry point: load chunks, embed them, save an ``.npz``.

    Steps:
        1. Parse ``--limit`` and ``--batch-size``.
        2. Load chunks from ``RAG_CHUNKS_PATH`` and build the text to embed
           for each one.
        3. Load ``EMBEDDING_MODEL_NAME`` and cap its sequence length at
           ``MAX_SEQ_LENGTH``.
        4. Encode all texts and report throughput.
        5. Write chunk ids and vectors to ``EMBEDDINGS_FILE``.

    ``EMBEDDINGS_DIR`` and ``EMBEDDINGS_FILE`` are used through ``.mkdir`` and
    ``.stat``, so they are expected to be ``pathlib.Path``-like objects.
    """
    parser = argparse.ArgumentParser(description="Embed RAG chunks with MedEmbed")
    parser.add_argument("--limit", type=int, default=None, help="Limit chunks to embed (for testing)")
    parser.add_argument("--batch-size", type=int, default=BATCH_SIZE, help="Batch size for encoding")
    args = parser.parse_args()
    # Reject a non-positive batch size up front; otherwise it only fails
    # after the model has been loaded, with an opaque range/vstack error.
    if args.batch_size < 1:
        parser.error("--batch-size must be a positive integer")

    device = select_device()
    print(f"Device: {device}")
    print(f"Model: {EMBEDDING_MODEL_NAME}")
    print(f"Batch: {args.batch_size}")

    print("\nLoading chunks...")
    chunks = load_chunks(RAG_CHUNKS_PATH, limit=args.limit)
    print(f"Loaded {len(chunks)} chunks")
    chunk_ids = [c["id"] for c in chunks]
    texts = [build_embedding_text(c) for c in chunks]

    print(f"\nLoading model {EMBEDDING_MODEL_NAME}...")
    # NOTE(review): trust_remote_code=True permits running code shipped with
    # the model repository -- confirm this model actually requires it.
    model = SentenceTransformer(EMBEDDING_MODEL_NAME, trust_remote_code=True)
    model.max_seq_length = MAX_SEQ_LENGTH
    print(f"Model loaded — embedding dim: {model.get_sentence_embedding_dimension()}")

    print("\nGenerating embeddings...")
    # perf_counter is monotonic; time.time() can jump if the system clock is
    # adjusted while encoding is running.
    start = time.perf_counter()
    embeddings = embed_in_batches(model, texts, args.batch_size, device)
    elapsed = time.perf_counter() - start
    print(f"\nEmbeddings shape: {embeddings.shape}")
    # Guard the rate against a zero-length interval.
    rate = len(texts) / elapsed if elapsed > 0 else float("inf")
    print(f"Time: {elapsed:.1f}s ({rate:.1f} chunks/sec)")

    EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)
    # ids are stored as an object array, so readers must call
    # np.load(..., allow_pickle=True) to get them back.
    np.savez_compressed(
        EMBEDDINGS_FILE,
        ids=np.array(chunk_ids, dtype=object),
        embeddings=embeddings,
    )
    size_mb = EMBEDDINGS_FILE.stat().st_size / (1024 * 1024)
    print(f"\nSaved to {EMBEDDINGS_FILE} ({size_mb:.1f} MB)")
# Run the embedding pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()