File size: 7,662 Bytes
b6f9fa8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
"""
scripts/ingest_incremental.py
==============================
Adds new chunks to an EXISTING FAISS index without rebuilding from scratch.
Only the new chunks are embedded — existing vectors are untouched.

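Usage:
    python scripts/ingest_incremental.py --input data/dailymed_chunks.jsonl
    python scripts/ingest_incremental.py --input data/dailymed_chunks.jsonl --dry-run

    # Refresh stored chunk_text for chunks mentioning a section keyword
    # (no vectors are added or re-embedded in this mode):
    python scripts/ingest_incremental.py --input data/dailymed_chunks.jsonl \
        --force-update-section "ADVERSE REACTIONS"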
"""
from __future__ import annotations

import argparse
import json
import logging
import pickle
import sys
from pathlib import Path

# Put the directory two levels above this file (the project root when the
# script lives in scripts/) at the front of the import search path.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

import faiss
import numpy as np
import yaml

# Configure root logging once at import time: INFO level, timestamped lines.
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
# Module-level logger shared by every function in this script.
logger = logging.getLogger(__name__)


def load_config() -> dict:
    """Parse ``config.yaml`` found in the current working directory.

    Returns:
        The object produced by ``yaml.safe_load`` for the file's contents.

    Raises:
        FileNotFoundError: If ``config.yaml`` is absent from the working
            directory.
    """
    with open("config.yaml", "r", encoding="utf-8") as config_file:
        settings = yaml.safe_load(config_file)
    return settings


def load_new_chunks(path: str) -> list[dict]:
    """Read chunk records from a JSON Lines file.

    Blank and whitespace-only lines are skipped. A leading UTF-8 byte-order
    mark (as written by some Windows tools) is tolerated.

    Args:
        path: Location of the JSONL file; one JSON document per line.

    Returns:
        The decoded records, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON. The
            message is prefixed with ``<path>:<line number>`` so the
            offending record can be located.
    """
    chunks = []
    # "utf-8-sig" decodes plain UTF-8 identically but drops a leading BOM,
    # which json.loads would otherwise reject on the first record.
    with open(path, "r", encoding="utf-8-sig") as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                chunks.append(json.loads(line))
            except json.JSONDecodeError as err:
                # Same exception type as before, with file/line context added.
                raise json.JSONDecodeError(
                    f"{path}:{line_no}: {err.msg}", err.doc, err.pos
                ) from err
    logger.info("Loaded %d new chunks from %s", len(chunks), path)
    return chunks


def embed_chunks(chunks: list[dict], model_name: str) -> np.ndarray:
    """Encode the ``chunk_text`` of every chunk with a SentenceTransformer.

    Args:
        chunks: Chunk records; each one must contain a ``chunk_text`` key.
        model_name: Model identifier handed to ``SentenceTransformer``.

    Returns:
        The encoder output cast to ``float32``. Embeddings are requested
        normalized and as a NumPy array.

    Raises:
        KeyError: If a chunk has no ``chunk_text`` key.
    """
    # Imported inside the function, so the dependency is only loaded when
    # embedding actually runs.
    from sentence_transformers import SentenceTransformer

    encoder = SentenceTransformer(model_name)
    passages = []
    for record in chunks:
        passages.append(record["chunk_text"])
    logger.info("Embedding %d new chunks with %s...", len(passages), model_name)

    encode_options = {
        "batch_size": 32,
        "show_progress_bar": True,
        "normalize_embeddings": True,
        "convert_to_numpy": True,
    }
    vectors = encoder.encode(passages, **encode_options)
    return vectors.astype(np.float32)


def _invalidate_bm25_cache(meta_path: str) -> None:
    """Delete ``bm25_cache.pkl`` next to the metadata file, if it exists."""
    bm25_cache = Path(meta_path).parent / "bm25_cache.pkl"
    if bm25_cache.exists():
        bm25_cache.unlink()
        logger.info("BM25 cache invalidated — will rebuild on next startup.")


def _force_update_section(chunks: list[dict], metadata_store: dict[int, dict],
                          section_kw: str) -> int:
    """Overwrite stored ``chunk_text`` for input chunks mentioning *section_kw*.

    A chunk is matched to a stored entry by ``chunk_id`` first, then by
    ``(doc_id, chunk_index)``. Missing identifiers never match: a chunk
    without ``chunk_id`` skips the primary lookup and a chunk without
    ``doc_id`` skips the fallback, so an unrelated entry is never overwritten.

    Returns:
        The number of input chunks whose text was written into the store.
    """
    id_to_key = {}
    docidx_to_key = {}
    for key, meta in metadata_store.items():
        if meta.get("chunk_id"):
            id_to_key[meta["chunk_id"]] = key
        if meta.get("doc_id"):
            docidx_to_key[(meta["doc_id"], meta.get("chunk_index", 0))] = key

    updated = 0
    for chunk in chunks:
        text = chunk.get("chunk_text") or ""
        if section_kw not in text.upper():
            continue
        faiss_key = None
        # Primary match: chunk_id (deterministic IDs).
        if chunk.get("chunk_id"):
            faiss_key = id_to_key.get(chunk["chunk_id"])
        # Fallback: (doc_id, chunk_index), for sources whose IDs change between runs.
        if faiss_key is None and chunk.get("doc_id"):
            faiss_key = docidx_to_key.get((chunk["doc_id"], chunk.get("chunk_index", 0)))
        if faiss_key is not None:
            metadata_store[faiss_key]["chunk_text"] = text
            updated += 1
    return updated


def _select_new_chunks(chunks: list[dict], metadata_store: dict[int, dict]) -> list[dict]:
    """Return the input chunks that are not already represented in the store.

    A chunk is a duplicate when its ``chunk_id`` is already stored, or when
    its ``(doc_id, chunk_index)`` pair is already stored (non-empty ``doc_id``
    and a non-negative integer index required). A ``chunk_id`` repeated inside
    the input batch is only kept once.
    """
    seen_ids = set()
    stored_doc_keys = set()
    for meta in metadata_store.values():
        if meta.get("chunk_id"):
            seen_ids.add(meta["chunk_id"])
        index_in_doc = meta.get("chunk_index", -1)
        if meta.get("doc_id") and isinstance(index_in_doc, int) and index_in_doc >= 0:
            stored_doc_keys.add((meta["doc_id"], index_in_doc))

    fresh = []
    for chunk in chunks:
        chunk_id = chunk.get("chunk_id")
        if chunk_id and chunk_id in seen_ids:
            continue
        index_in_doc = chunk.get("chunk_index", -1)
        if (chunk.get("doc_id") and isinstance(index_in_doc, int) and index_in_doc >= 0
                and (chunk["doc_id"], index_in_doc) in stored_doc_keys):
            continue
        fresh.append(chunk)
        if chunk_id:
            # Remember accepted IDs so a repeat later in the same file is skipped.
            seen_ids.add(chunk_id)
    return fresh


def _metadata_entry(chunk: dict, faiss_key: int) -> dict:
    """Build the metadata record stored under *faiss_key* for *chunk*."""
    return {
        "chunk_id":     chunk.get("chunk_id", f"chunk_{faiss_key}"),
        "doc_id":       chunk.get("doc_id", ""),
        "source":       chunk.get("source", ""),
        "title":        chunk.get("title", ""),
        "pub_type":     chunk.get("pub_type", "unknown"),
        "pub_year":     chunk.get("pub_year"),
        "journal":      chunk.get("journal", ""),
        "chunk_index":  chunk.get("chunk_index", 0),
        "total_chunks": chunk.get("total_chunks", 1),
        "chunk_text":   chunk.get("chunk_text", ""),
    }


def main() -> None:
    """Command-line entry point: extend an existing FAISS index with new chunks.

    Reads index/metadata paths and the embedding model name from the
    ``retrieval`` section of the config, then runs one of two modes:

    * ``--force-update-section KEYWORD``: rewrite the stored ``chunk_text`` of
      already-indexed chunks whose new text contains KEYWORD. No vectors are
      added. The metadata file is saved and the BM25 cache removed.
    * default: embed chunks not yet in the store, add them to the index, save
      index and metadata, remove the BM25 cache, and append the new chunks to
      ``data/processed/chunks.jsonl``.

    ``--dry-run`` logs what would happen without writing anything. The
    process exits with status 1 if the index or metadata file is missing or
    the embedding dimension does not match the index.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", required=True, help="JSONL file of new chunks")
    parser.add_argument("--dry-run", action="store_true",
                        help="Show what would be added without writing to disk")
    parser.add_argument("--force-update-section", default=None,
                        help="Force-update chunk_text for existing chunks matching this section keyword (e.g. 'ADVERSE REACTIONS')")
    args = parser.parse_args()

    cfg = load_config()
    idx_path   = cfg["retrieval"]["index_path"]
    meta_path  = cfg["retrieval"]["metadata_path"]
    model_name = cfg["retrieval"]["embedding_model"]

    if not Path(idx_path).exists():
        logger.error("FAISS index not found at %s. Run embedder.py first.", idx_path)
        sys.exit(1)
    if not Path(meta_path).exists():
        logger.error("Metadata store not found at %s. Run embedder.py first.", meta_path)
        sys.exit(1)

    # Load existing index + metadata
    logger.info("Loading existing FAISS index from %s ...", idx_path)
    index = faiss.read_index(idx_path)
    existing_count = index.ntotal
    logger.info("Existing index: %d vectors", existing_count)

    # NOTE(security): pickle.load executes arbitrary code from the file; the
    # metadata path must only ever point at a file this project wrote itself.
    with open(meta_path, "rb") as f:
        metadata_store: dict[int, dict] = pickle.load(f)

    if len(metadata_store) != existing_count:
        # New metadata keys start at existing_count, so a mismatch means they
        # may not line up with the vectors added below.
        logger.warning(
            "Metadata store has %d entries but the index has %d vectors; "
            "new metadata keys start at %d and may not line up.",
            len(metadata_store), existing_count, existing_count,
        )

    all_input_chunks = load_new_chunks(args.input)

    # Force-update existing chunk_text for a specific section (no new FAISS vectors needed)
    if args.force_update_section:
        section_kw = args.force_update_section.upper()
        updated = _force_update_section(all_input_chunks, metadata_store, section_kw)
        logger.info("Force-updated chunk_text for %d '%s' entries", updated, section_kw)
        if not args.dry_run:
            with open(meta_path, "wb") as f:
                pickle.dump(metadata_store, f, protocol=pickle.HIGHEST_PROTOCOL)
            logger.info("Metadata store saved.")
            _invalidate_bm25_cache(meta_path)
        return

    # Deduplicate — skip chunks already in the index (and repeats in the input).
    new_chunks = _select_new_chunks(all_input_chunks, metadata_store)

    if not new_chunks:
        logger.info("All chunks already in index. Nothing to add.")
        return

    logger.info("%d new chunks to add (%d duplicates skipped)",
                len(new_chunks), len(all_input_chunks) - len(new_chunks))

    if args.dry_run:
        logger.info("DRY RUN — no changes written.")
        for c in new_chunks[:5]:
            # `or ""` keeps the preview working when title is present but null.
            logger.info("  Would add: %s | %s", c.get("chunk_id"), str(c.get("title") or "")[:60])
        return

    # Embed new chunks only
    embeddings = embed_chunks(new_chunks, model_name)

    # Fail with a readable message before touching the index if the model
    # produces vectors of a different width than the index stores.
    if embeddings.ndim != 2 or embeddings.shape[1] != index.d:
        logger.error(
            "Embedding shape %s does not match index dimension %d. "
            "Check retrieval.embedding_model in the config.",
            embeddings.shape, index.d,
        )
        sys.exit(1)

    # Add to existing FAISS index
    index.add(embeddings)
    logger.info("Index now has %d vectors (+%d)", index.ntotal, len(new_chunks))

    # Extend metadata store (new keys start from existing_count)
    for offset, chunk in enumerate(new_chunks):
        faiss_key = existing_count + offset
        metadata_store[faiss_key] = _metadata_entry(chunk, faiss_key)

    # Save updated artifacts
    faiss.write_index(index, idx_path)
    logger.info("FAISS index saved to %s", idx_path)

    with open(meta_path, "wb") as f:
        pickle.dump(metadata_store, f, protocol=pickle.HIGHEST_PROTOCOL)
    logger.info("Metadata store saved (%d total entries)", len(metadata_store))

    # The corpus changed, so a cached BM25 index built from the old metadata
    # is stale — drop it here just as the force-update mode does.
    _invalidate_bm25_cache(meta_path)

    # Also append to chunks.jsonl for future full rebuilds
    chunks_jsonl = Path("data/processed/chunks.jsonl")
    chunks_jsonl.parent.mkdir(parents=True, exist_ok=True)
    with open(chunks_jsonl, "a", encoding="utf-8") as f:
        for chunk in new_chunks:
            f.write(json.dumps(chunk) + "\n")
    logger.info("Appended %d chunks to %s", len(new_chunks), chunks_jsonl)

    logger.info("Done. Restart the backend to reload the updated index.")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()