File size: 4,117 Bytes
e1ced8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
"""Ingest preprocessed NYC code JSON files into ChromaDB with bge-large-en-v1.5."""
from __future__ import annotations

import json
import os
import sys

import chromadb
from chromadb.utils import embedding_functions


# Model name handed to the SentenceTransformer embedding function of the collection.
EMBEDDING_MODEL = "BAAI/bge-large-en-v1.5"
# Name of the single ChromaDB collection that receives every ingested section.
COLLECTION_NAME = "nyc_building_codes"
# Default location of the persistent store: a directory next to this module.
DB_PATH = os.path.join(os.path.dirname(__file__), "nyc_code_db")

# Maps each preprocessed JSON file name to the code type assigned to its
# sections; the code type becomes the ID prefix and the "code_type" metadata
# value during ingestion.
CODE_FILES = {
    "BUILDING_CODE.json": "Building",
    "FUEL_GAS_CODE.json": "FuelGas",
    "GENERAL_ADMINISTRATIVE_PROVISIONS.json": "Administrative",
    "MECHANICAL_CODE.json": "Mechanical",
    "PLUMBING_CODE.json": "Plumbing",
}


def create_collection(db_path: str = DB_PATH, reset: bool = True):
    """Open the persistent ChromaDB store and return its code collection.

    Args:
        db_path: Directory of the persistent ChromaDB store.
        reset: When true, an existing collection named ``COLLECTION_NAME`` is
            dropped and a new, empty one is created. When false, an existing
            collection is reused and only created if it is missing.

    Returns:
        A ``(client, collection)`` tuple. The collection is configured with
        the ``EMBEDDING_MODEL`` sentence-transformer embedding function.
    """
    client = chromadb.PersistentClient(path=db_path)
    embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=EMBEDDING_MODEL,
    )

    if not reset:
        # A plain create_collection() rejects a name that is already taken,
        # so the non-reset path must reuse the existing collection instead.
        collection = client.get_or_create_collection(
            name=COLLECTION_NAME,
            embedding_function=embedding_fn,
        )
        return client, collection

    try:
        client.delete_collection(name=COLLECTION_NAME)
    except Exception:
        # Deliberate best effort: on a first run there is nothing to delete.
        # If the collection survived for any other reason, the
        # create_collection() call below fails loudly instead of silently
        # reusing stale data.
        pass
    else:
        print(f"Deleted existing collection '{COLLECTION_NAME}'.")

    collection = client.create_collection(
        name=COLLECTION_NAME,
        embedding_function=embedding_fn,
    )
    return client, collection


def _flatten_metadata(meta: dict) -> dict:
    """Return a copy of *meta* whose values are all str/int/float/bool."""
    flat = {}
    for key, value in meta.items():
        if isinstance(value, list):
            # Lists are not accepted as metadata values; store them as a
            # comma-separated string (an empty list becomes "").
            flat[key] = ", ".join(str(item) for item in value)
        elif isinstance(value, (bool, int, float)):
            flat[key] = value
        else:
            flat[key] = str(value)
    return flat


def ingest_json_file(
    collection, json_path: str, code_type: str, batch_size: int = 200
) -> int:
    """Ingest a single JSON file into the collection.

    The file must contain a JSON array of objects, each with the keys
    ``"id"``, ``"text"`` and ``"metadata"``. Every section is stored under the
    ID ``f"{code_type}_{id}"``; when that ID occurs more than once in the
    file, only the first occurrence is kept.

    Args:
        collection: Target collection; only its ``upsert(documents=...,
            metadatas=..., ids=...)`` method is used.
        json_path: Path of the preprocessed JSON file.
        code_type: Code type written to each section's ``"code_type"``
            metadata (overriding any value in the file) and used as ID prefix.
        batch_size: Maximum number of sections sent per ``upsert`` call.

    Returns:
        The number of sections upserted (duplicates excluded).

    Raises:
        ValueError: If *batch_size* is not a positive integer.
        KeyError: If an entry lacks ``"id"``, ``"text"`` or ``"metadata"``.
    """
    if batch_size <= 0:
        # A negative step would make the batching range empty and silently
        # drop every section, so reject it up front.
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    documents = []
    metadatas = []
    ids = []
    seen_ids: set[str] = set()

    for entry in data:
        # Build a new dict rather than mutating the parsed entry; the
        # caller-supplied code_type always wins over the file's value.
        meta = {**entry["metadata"], "code_type": code_type}

        unique_id = f"{code_type}_{entry['id']}"
        if unique_id in seen_ids:
            continue
        seen_ids.add(unique_id)

        documents.append(entry["text"])
        metadatas.append(_flatten_metadata(meta))
        ids.append(unique_id)

    # Upsert in bounded batches so a large file is not sent in one request.
    for start in range(0, len(documents), batch_size):
        stop = min(start + batch_size, len(documents))
        collection.upsert(
            documents=documents[start:stop],
            metadatas=metadatas[start:stop],
            ids=ids[start:stop],
        )
        print(f"  Batch {start // batch_size + 1}: upserted {stop - start} sections")

    return len(ids)


def ingest_all(data_dir: str, db_path: str = DB_PATH) -> dict[str, int]:
    """Rebuild the collection from the known code files found in *data_dir*.

    The collection is always reset before ingestion. Each file listed in
    ``CODE_FILES`` is looked up inside *data_dir*; missing files are reported
    with a warning and skipped.

    Returns:
        A mapping of code type to the number of sections ingested for it,
        containing only the code types whose file was present.
    """
    print(f"Creating ChromaDB at {db_path} with embedding model: {EMBEDDING_MODEL}")
    _, collection = create_collection(db_path, reset=True)

    per_type: dict[str, int] = {}
    for source_name, label in CODE_FILES.items():
        source_path = os.path.join(data_dir, source_name)

        # Guard clause: a missing file is reported and does not abort the run.
        if not os.path.exists(source_path):
            print(f"WARNING: {source_path} not found, skipping.")
            continue

        print(f"\nIngesting {source_name} as '{label}'...")
        added = ingest_json_file(collection, source_path, label)
        per_type[label] = added
        print(f"  -> {added} sections ingested")

    grand_total = sum(per_type.values())
    print(
        f"\nIngestion complete. Total: {grand_total} sections "
        f"across {len(per_type)} code types."
    )
    return per_type


if __name__ == "__main__":
    # The optional first CLI argument names the directory holding the JSON
    # files; without it, the directory containing this script is used.
    if len(sys.argv) > 1:
        data_dir = sys.argv[1]
    else:
        data_dir = os.path.dirname(__file__)
    ingest_all(data_dir)