File size: 8,796 Bytes
e1ced8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
"""ChromaDB tools for NYC code lookup — with re-ranking, budget tracking, and caching."""
from __future__ import annotations

import hashlib
from collections import Counter

import chromadb
from chromadb.utils import embedding_functions

from config import (
    CHROMA_COLLECTION_NAME,
    CHROMA_DB_PATH,
    DISCOVER_N_RESULTS,
    EMBEDDING_MODEL_NAME,
    FETCH_MAX_SECTIONS,
    RERANK_TOP_K,
)


# ---------------------------------------------------------------------------
# Singleton collection loader
# ---------------------------------------------------------------------------

# Lazily-created ChromaDB collection singleton; populated by get_collection().
_collection = None
# Set True once warmup_collection() has successfully loaded the collection.
_warmup_done = False


def warmup_collection() -> bool:
    """Eagerly load the embedding model and connect to ChromaDB.

    Call during app startup so the heavy model download + load happens
    visibly (e.g. behind a progress spinner) rather than on the first query.

    Returns True if the collection is available, False otherwise.
    """
    global _warmup_done
    try:
        get_collection()
    except Exception:
        # Best-effort warmup: any failure just means "not warmed up".
        _warmup_done = False
        return False
    _warmup_done = True
    return True


def is_warmed_up() -> bool:
    """Report whether the last warmup_collection() call succeeded."""
    return _warmup_done


def get_collection():
    """Return the ChromaDB collection, connecting on first use (singleton)."""
    global _collection
    if _collection is not None:
        return _collection

    client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
    embed_fn = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=EMBEDDING_MODEL_NAME,
    )
    _collection = client.get_collection(
        name=CHROMA_COLLECTION_NAME,
        embedding_function=embed_fn,
    )
    return _collection


# ---------------------------------------------------------------------------
# Query cache for deduplication
# ---------------------------------------------------------------------------

class QueryCache:
    """Memoizes results keyed on a word-order-insensitive form of the query.

    Two queries containing the same set of lowercase words map to the same
    cache slot, so rephrasings of an identical topic are deduplicated.
    """

    def __init__(self) -> None:
        # Maps normalized query key -> previously computed result string.
        self._cache: dict[str, str] = {}

    def _normalize(self, query: str) -> str:
        """Canonical key: the query's unique lowercase words, sorted."""
        return " ".join(sorted(set(query.lower().split())))

    def get(self, query: str) -> str | None:
        """Return the cached result for *query*, or None on a miss."""
        return self._cache.get(self._normalize(query))

    def put(self, query: str, result: str) -> None:
        """Store *result* under *query*'s normalized key."""
        self._cache[self._normalize(query)] = result


# ---------------------------------------------------------------------------
# discover_code_locations — semantic search with re-ranking
# ---------------------------------------------------------------------------

def _rerank_hits(metas, docs, distances):
    """Score each hit and return (score, meta, doc) tuples, best first.

    score = -distance (higher is better)
          + hierarchy bonus: up to +0.15 for shallow (top-level) sections
          + exception bonus: +0.1 when the section contains exceptions
    """
    scored = []
    for meta, doc, dist in zip(metas, docs, distances):
        score = -dist  # lower distance = better match; negate for sorting

        # Hierarchy bonus: shallower sections (fewer dots) rank higher for
        # broad queries.
        depth = meta.get("section_full", "").count(".")
        score += max(0, 3 - depth) * 0.05

        # Exception bonus: sections with exceptions are more useful for
        # compliance work.
        if meta.get("has_exceptions", False):
            score += 0.1

        scored.append((score, meta, doc))

    scored.sort(key=lambda item: item[0], reverse=True)
    return scored


def _format_discovery_report(top_results):
    """Render ranked (score, meta, doc) hits as the discovery report string."""
    pair_counts = Counter(
        f"{m['code_type']} | Ch. {m['parent_major']}" for _, m, _ in top_results
    )
    chapter_summary = "\n".join(
        f"- {pair} ({count} hits)" for pair, count in pair_counts.most_common(5)
    )

    section_reports = []
    for _score, m, doc in top_results:
        exceptions_tag = " [HAS EXCEPTIONS]" if m.get("has_exceptions", False) else ""
        xrefs = m.get("cross_references", "")
        xref_tag = f"\n  Cross-refs: {xrefs}" if xrefs else ""
        section_reports.append(
            f"ID: {m['section_full']} | Code: {m['code_type']} | Chapter: {m['parent_major']}"
            f"{exceptions_tag}{xref_tag}\n"
            f"Snippet: {doc[:500]}"  # truncate long snippets
        )

    return (
        "### CODE DISCOVERY REPORT ###\n"
        f"MOST RELEVANT CHAPTERS:\n{chapter_summary}\n\n"
        "TOP RELEVANT SECTIONS:\n"
        + "\n---\n".join(section_reports)
    )


def discover_code_locations(query: str, cache: QueryCache | None = None) -> str:
    """Semantic search over NYC codes with hierarchical re-ranking.

    Parameters
    ----------
    query : str
        Natural-language search topic.
    cache : QueryCache, optional
        When provided, queries with the same normalized word set are
        answered from cache (result is tagged "[CACHED RESULT]").

    Returns a formatted report of the most relevant code sections, or a
    human-readable message when nothing matches.
    """
    if cache is not None:
        cached = cache.get(query)
        if cached is not None:
            return f"[CACHED RESULT]\n{cached}"

    collection = get_collection()
    results = collection.query(
        query_texts=[query],
        n_results=DISCOVER_N_RESULTS,
        include=["metadatas", "documents", "distances"],
    )

    if not results["metadatas"][0]:
        return "No results found. Try a different query phrasing."

    ranked = _rerank_hits(
        results["metadatas"][0],
        results["documents"][0],
        results["distances"][0],
    )
    output = _format_discovery_report(ranked[:RERANK_TOP_K])

    if cache is not None:
        cache.put(query, output)

    return output


# ---------------------------------------------------------------------------
# fetch_full_chapter — with section filtering and pagination
# ---------------------------------------------------------------------------

def _dedupe_cont_blocks(doc: str) -> str:
    """Collapse repeated "[CONT.]:" continuation blocks within one document.

    Blocks are compared by MD5 of their stripped text; the first occurrence
    wins and relative order is preserved.
    """
    unique_blocks: list[str] = []
    seen: set[str] = set()
    for block in doc.split("[CONT.]:"):
        text = block.strip()
        if not text:
            continue
        digest = hashlib.md5(text.encode()).hexdigest()
        if digest not in seen:
            seen.add(digest)
            unique_blocks.append(text)
    return " ".join(unique_blocks)


def fetch_full_chapter(
    code_type: str,
    chapter_id: str,
    section_filter: str | None = None,
) -> str:
    """Retrieve sections from a specific chapter, with optional keyword filtering.

    Parameters
    ----------
    code_type : str
        One of: Administrative, Building, FuelGas, Mechanical, Plumbing
    chapter_id : str
        The parent_major chapter ID (e.g., "10", "602")
    section_filter : str, optional
        If provided, only return sections containing this keyword

    Returns the formatted chapter text, or a human-readable message when the
    chapter is empty, the filter matches nothing, or retrieval fails.
    """
    collection = get_collection()

    try:
        chapter_data = collection.get(
            where={
                "$and": [
                    {"code_type": {"$eq": code_type}},
                    {"parent_major": {"$eq": chapter_id}},
                ]
            },
            include=["documents", "metadatas"],
        )

        if not chapter_data["documents"]:
            return f"No documentation found for {code_type} Chapter {chapter_id}."

        pairs = list(zip(chapter_data["metadatas"], chapter_data["documents"]))

        # Apply keyword filter if provided.
        if section_filter:
            filter_lower = section_filter.lower()
            pairs = [(m, d) for m, d in pairs if filter_lower in d.lower()]
            if not pairs:
                return (
                    f"No sections in {code_type} Chapter {chapter_id} "
                    f"match filter '{section_filter}'."
                )

        # Sort by section number, then cap the output size.
        pairs.sort(key=lambda pair: pair[0]["section_full"])
        total_sections = len(pairs)
        pairs = pairs[:FETCH_MAX_SECTIONS]

        header = f"## {code_type.upper()} CODE - CHAPTER {chapter_id}"
        if total_sections > FETCH_MAX_SECTIONS:
            header += f" (showing {FETCH_MAX_SECTIONS} of {total_sections} sections)"
        if section_filter:
            header += f" [filtered by: '{section_filter}']"

        # Accumulate parts and join once — avoids quadratic string rebuilds
        # from repeated `+=` on a growing string.
        parts = [header, "\n\n"]
        for meta, doc in pairs:
            exceptions_tag = ""
            if meta.get("has_exceptions", False):
                exceptions_tag = (
                    f" [CONTAINS {meta.get('exception_count', '?')} EXCEPTION(S)]"
                )
            parts.append(
                f"### SECTION {meta['section_full']}{exceptions_tag}\n"
                f"{_dedupe_cont_blocks(doc)}\n\n---\n\n"
            )

        return "".join(parts)

    except Exception as e:
        # Best-effort tool call: surface the failure as text rather than
        # crashing the calling agent loop.
        return f"Error retrieving chapter content: {e!s}"