# embedding/scripts/embed_chunks.py
"""
Generate embeddings for RAG chunks using MedEmbed-large-v0.1.
Reads rag_chunks.json, encodes every chunk's text with the MedEmbed model,
and saves the resulting vectors alongside their chunk IDs to a compressed
numpy archive (.npz) for downstream loading into Qdrant.
Usage:
python embed_chunks.py # full run
python embed_chunks.py --limit 100 # embed only first 100 chunks (for testing)
python embed_chunks.py --batch-size 64 # override batch size
"""
import argparse
import json
import time
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from config import (
    RAG_CHUNKS_PATH,
    EMBEDDINGS_DIR,
    EMBEDDINGS_FILE,
    EMBEDDING_MODEL_NAME,
    BATCH_SIZE,
    MAX_SEQ_LENGTH,
)
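# config.py is assumed to expose roughly the following. The names match the
# import above; the values shown are illustrative guesses, not the real config:
#
#   RAG_CHUNKS_PATH      = Path("data/rag_chunks.json")
#   EMBEDDINGS_DIR       = Path("data/embeddings")
#   EMBEDDINGS_FILE      = EMBEDDINGS_DIR / "chunk_embeddings.npz"
#   EMBEDDING_MODEL_NAME = "abhinand/MedEmbed-large-v0.1"
#   BATCH_SIZE           = 32
#   MAX_SEQ_LENGTH       = 512
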
def select_device() -> str:
    """Pick the best available torch device: CUDA, then Apple-silicon MPS, then CPU."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

def load_chunks(path, limit=None):
    """Load chunk records from JSON, optionally truncating for test runs."""
    with open(path, "r", encoding="utf-8") as f:
        chunks = json.load(f)
    if limit is not None:
        chunks = chunks[:limit]
    return chunks

def build_embedding_text(chunk: dict) -> str:
    """
    Construct the text that gets embedded. Prepend key metadata so the
    embedding captures policy context, not just the raw paragraph.
    """
    parts = []
    # Turn the hyphenated policy slug into a human-readable title.
    policy = chunk.get("policy_name", "").replace("-", " ").title()
    if policy:
        parts.append(f"Policy: {policy}")
    section = chunk.get("section", "")
    if section:
        parts.append(f"Section: {section}")
    parts.append(chunk["text"])
    return " | ".join(parts)
def embed_in_batches(model, texts, batch_size, device):
    """Encode texts batch by batch so the outer tqdm bar tracks overall progress."""
    all_embeddings = []
    for i in tqdm(range(0, len(texts), batch_size), desc="Embedding batches"):
        batch = texts[i : i + batch_size]
        embeddings = model.encode(
            batch,
            batch_size=batch_size,
            show_progress_bar=False,  # the outer tqdm loop already reports progress
            convert_to_numpy=True,
            normalize_embeddings=True,  # unit vectors: cosine similarity reduces to a dot product
            device=device,
        )
        all_embeddings.append(embeddings)
    return np.vstack(all_embeddings)
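# Memory sketch (assumes MedEmbed-large's 1024-dim float32 output, which this
# script does not itself confirm): 10,000 chunks take about
# 10,000 * 1024 * 4 bytes, i.e. roughly 40 MB of raw vectors before
# .npz compression.
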
def main():
    parser = argparse.ArgumentParser(description="Embed RAG chunks with MedEmbed")
    parser.add_argument("--limit", type=int, default=None,
                        help="Limit the number of chunks to embed (for testing)")
    parser.add_argument("--batch-size", type=int, default=BATCH_SIZE,
                        help="Batch size for encoding")
    args = parser.parse_args()

    device = select_device()
    print(f"Device: {device}")
    print(f"Model: {EMBEDDING_MODEL_NAME}")
    print(f"Batch size: {args.batch_size}")

    print("\nLoading chunks...")
    chunks = load_chunks(RAG_CHUNKS_PATH, limit=args.limit)
    print(f"Loaded {len(chunks)} chunks")

    chunk_ids = [c["id"] for c in chunks]
    texts = [build_embedding_text(c) for c in chunks]

    print(f"\nLoading model {EMBEDDING_MODEL_NAME}...")
    model = SentenceTransformer(EMBEDDING_MODEL_NAME, trust_remote_code=True)
    model.max_seq_length = MAX_SEQ_LENGTH  # truncate overly long chunks at the model's limit
    print(f"Model loaded (embedding dim: {model.get_sentence_embedding_dimension()})")

    print("\nGenerating embeddings...")
    start = time.time()
    embeddings = embed_in_batches(model, texts, args.batch_size, device)
    elapsed = time.time() - start
    print(f"\nEmbeddings shape: {embeddings.shape}")
    print(f"Time: {elapsed:.1f}s ({len(texts) / elapsed:.1f} chunks/sec)")

    # IDs are stored as a dtype=object array so variable-length string IDs
    # round-trip; np.load therefore needs allow_pickle=True when reading.
    EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)
    np.savez_compressed(
        EMBEDDINGS_FILE,
        ids=np.array(chunk_ids, dtype=object),
        embeddings=embeddings,
    )
    size_mb = EMBEDDINGS_FILE.stat().st_size / (1024 * 1024)
    print(f"\nSaved to {EMBEDDINGS_FILE} ({size_mb:.1f} MB)")


if __name__ == "__main__":
    main()
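
# Downstream loading sketch (illustrative; the actual Qdrant upload lives in a
# separate script). allow_pickle=True is required because the ids array is
# stored with dtype=object:
#
#   data = np.load(EMBEDDINGS_FILE, allow_pickle=True)
#   ids, vectors = data["ids"], data["embeddings"]
#   assert len(ids) == vectors.shape[0]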