File size: 3,807 Bytes
5c32ed1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""
Generate embeddings for RAG chunks using MedEmbed-large-v0.1.

Reads rag_chunks.json, encodes every chunk's text with the MedEmbed model,
and saves the resulting vectors alongside their chunk IDs to a compressed
numpy archive (.npz) for downstream loading into Qdrant.

Usage:
    python embed_chunks.py                      # full run
    python embed_chunks.py --limit 100          # embed only first 100 chunks (for testing)
    python embed_chunks.py --batch-size 64      # override batch size
"""

import argparse
import json
import time
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

from config import (
    RAG_CHUNKS_PATH,
    EMBEDDINGS_DIR,
    EMBEDDINGS_FILE,
    EMBEDDING_MODEL_NAME,
    BATCH_SIZE,
    MAX_SEQ_LENGTH,
)


def select_device() -> str:
    """Return the best available torch device name: cuda > mps > cpu."""
    # Probe accelerators in preference order; fall through to CPU.
    for name, is_available in (
        ("cuda", torch.cuda.is_available),
        ("mps", torch.backends.mps.is_available),
    ):
        if is_available():
            return name
    return "cpu"


def load_chunks(path, limit=None):
    """
    Load RAG chunks from a JSON file.

    Args:
        path: Path to a JSON file containing a list of chunk dicts.
        limit: Optional maximum number of chunks to return. ``None`` means
            "all chunks"; an explicit ``0`` returns an empty list.

    Returns:
        List of chunk dicts, truncated to ``limit`` when given.
    """
    with open(path, "r", encoding="utf-8") as f:
        chunks = json.load(f)
    # `is not None` (rather than truthiness) so --limit 0 is honored
    # instead of silently loading every chunk.
    if limit is not None:
        chunks = chunks[:limit]
    return chunks


def build_embedding_text(chunk: dict) -> str:
    """
    Construct the text that gets embedded. Prepend key metadata so the
    embedding captures policy context, not just the raw paragraph.

    Args:
        chunk: Chunk dict with required key ``"text"`` and optional keys
            ``"policy_name"`` and ``"section"``.

    Returns:
        ``"Policy: ... | Section: ... | <text>"`` with missing or empty
        metadata parts omitted.

    Raises:
        KeyError: If ``chunk`` has no ``"text"`` key.
    """
    parts = []

    # `or ""` guards against a key that is present but holds None —
    # .get(key, "") would pass None straight to .replace and crash.
    policy = (chunk.get("policy_name") or "").replace("-", " ").title()
    if policy:
        parts.append(f"Policy: {policy}")

    section = chunk.get("section") or ""
    if section:
        parts.append(f"Section: {section}")

    parts.append(chunk["text"])

    return " | ".join(parts)


def embed_in_batches(model, texts, batch_size, device):
    """
    Encode ``texts`` in fixed-size batches and stack the results.

    Args:
        model: SentenceTransformer (or compatible) object exposing
            ``encode`` and ``get_sentence_embedding_dimension``.
        texts: Sequence of strings to embed.
        batch_size: Number of texts per ``encode`` call.
        device: Torch device string forwarded to ``encode``.

    Returns:
        Array of shape ``(len(texts), dim)`` of L2-normalized embeddings;
        a well-shaped empty ``(0, dim)`` array when ``texts`` is empty.
    """
    all_embeddings = []

    for i in tqdm(range(0, len(texts), batch_size), desc="Embedding batches"):
        batch = texts[i : i + batch_size]
        embeddings = model.encode(
            batch,
            batch_size=batch_size,
            show_progress_bar=False,
            convert_to_numpy=True,
            normalize_embeddings=True,
            device=device,
        )
        all_embeddings.append(embeddings)

    # np.vstack([]) raises ValueError; return an empty array with the
    # model's embedding dimension so callers can handle "no chunks".
    if not all_embeddings:
        dim = model.get_sentence_embedding_dimension()
        return np.empty((0, dim), dtype=np.float32)

    return np.vstack(all_embeddings)


def main():
    """Script entry point: load chunks, embed them all, save an .npz archive."""
    cli = argparse.ArgumentParser(description="Embed RAG chunks with MedEmbed")
    cli.add_argument("--limit", type=int, default=None, help="Limit chunks to embed (for testing)")
    cli.add_argument("--batch-size", type=int, default=BATCH_SIZE, help="Batch size for encoding")
    opts = cli.parse_args()

    device = select_device()
    print(f"Device: {device}")
    print(f"Model:  {EMBEDDING_MODEL_NAME}")
    print(f"Batch:  {opts.batch_size}")

    print("\nLoading chunks...")
    chunks = load_chunks(RAG_CHUNKS_PATH, limit=opts.limit)
    print(f"Loaded {len(chunks)} chunks")

    # Embed metadata-augmented text; keep ids in the same order so the
    # saved arrays line up row-for-row.
    ids = [chunk["id"] for chunk in chunks]
    texts = [build_embedding_text(chunk) for chunk in chunks]

    print(f"\nLoading model {EMBEDDING_MODEL_NAME}...")
    encoder = SentenceTransformer(EMBEDDING_MODEL_NAME, trust_remote_code=True)
    encoder.max_seq_length = MAX_SEQ_LENGTH
    print(f"Model loaded — embedding dim: {encoder.get_sentence_embedding_dimension()}")

    print("\nGenerating embeddings...")
    t0 = time.time()
    vectors = embed_in_batches(encoder, texts, opts.batch_size, device)
    elapsed = time.time() - t0
    rate = len(texts) / elapsed

    print(f"\nEmbeddings shape: {vectors.shape}")
    print(f"Time: {elapsed:.1f}s ({rate:.1f} chunks/sec)")

    # Persist ids alongside vectors for downstream loading into Qdrant.
    EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)
    np.savez_compressed(
        EMBEDDINGS_FILE,
        ids=np.array(ids, dtype=object),
        embeddings=vectors,
    )
    size_mb = EMBEDDINGS_FILE.stat().st_size / (1024 * 1024)
    print(f"\nSaved to {EMBEDDINGS_FILE} ({size_mb:.1f} MB)")


# Run the embedding pipeline only when executed as a script, keeping the
# module importable (e.g. for reusing build_embedding_text) without side effects.
if __name__ == "__main__":
    main()