Updated_code_complaince / tools /chroma_tools.py
Ryan2219's picture
Upload 70 files
e1ced8e verified
"""ChromaDB tools for NYC code lookup — with re-ranking, budget tracking, and caching."""
from __future__ import annotations

import hashlib
import re
from collections import Counter

import chromadb
from chromadb.utils import embedding_functions
from config import (
CHROMA_COLLECTION_NAME,
CHROMA_DB_PATH,
DISCOVER_N_RESULTS,
EMBEDDING_MODEL_NAME,
FETCH_MAX_SECTIONS,
RERANK_TOP_K,
)
# ---------------------------------------------------------------------------
# Singleton collection loader
# ---------------------------------------------------------------------------
_collection = None
_warmup_done = False
def warmup_collection() -> bool:
"""Eagerly load the embedding model and connect to ChromaDB.
Returns True if collection is available, False otherwise.
Call this during app startup so the heavy model download + load
happens visibly (with a progress spinner) rather than on the first query.
"""
global _warmup_done
try:
get_collection()
_warmup_done = True
return True
except Exception:
_warmup_done = False
return False
def is_warmed_up() -> bool:
return _warmup_done
def get_collection():
"""Lazy-load the ChromaDB collection (singleton)."""
global _collection
if _collection is None:
client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(
model_name=EMBEDDING_MODEL_NAME,
)
_collection = client.get_collection(
name=CHROMA_COLLECTION_NAME,
embedding_function=embedding_fn,
)
return _collection
# ---------------------------------------------------------------------------
# Query cache for deduplication
# ---------------------------------------------------------------------------
class QueryCache:
"""Simple cache to avoid re-querying semantically identical topics."""
def __init__(self):
self._cache: dict[str, str] = {} # normalized_key -> result
def _normalize(self, query: str) -> str:
words = sorted(set(query.lower().split()))
return " ".join(words)
def get(self, query: str) -> str | None:
key = self._normalize(query)
return self._cache.get(key)
def put(self, query: str, result: str) -> None:
key = self._normalize(query)
self._cache[key] = result
# ---------------------------------------------------------------------------
# discover_code_locations — semantic search with re-ranking
# ---------------------------------------------------------------------------
def discover_code_locations(query: str, cache: QueryCache | None = None) -> str:
"""Semantic search over NYC codes with hierarchical re-ranking.
Returns a formatted report of the most relevant code sections.
"""
# Check cache
if cache is not None:
cached = cache.get(query)
if cached is not None:
return f"[CACHED RESULT]\n{cached}"
collection = get_collection()
results = collection.query(
query_texts=[query],
n_results=DISCOVER_N_RESULTS,
include=["metadatas", "documents", "distances"],
)
if not results["metadatas"][0]:
return "No results found. Try a different query phrasing."
metas = results["metadatas"][0]
docs = results["documents"][0]
distances = results["distances"][0]
# ------ Re-ranking ------
# Score = semantic_similarity + hierarchy_bonus + exception_bonus
ranked = []
for meta, doc, dist in zip(metas, docs, distances):
score = -dist # Lower distance = better match, negate for sorting
# Hierarchy bonus: shallower sections (fewer dots) rank higher for broad queries
depth = meta.get("section_full", "").count(".")
score += max(0, 3 - depth) * 0.05 # Up to +0.15 for top-level sections
# Exception bonus: sections with exceptions are more useful for compliance
if meta.get("has_exceptions", False):
score += 0.1
ranked.append((score, meta, doc))
ranked.sort(key=lambda x: x[0], reverse=True)
top_results = ranked[:RERANK_TOP_K]
# ------ Format output ------
category_chapter_pairs = [
f"{m['code_type']} | Ch. {m['parent_major']}" for _, m, _ in top_results
]
counts = Counter(category_chapter_pairs)
chapter_summary = "\n".join(
f"- {pair} ({count} hits)" for pair, count in counts.most_common(5)
)
section_reports = []
for _score, m, doc in top_results:
exceptions_tag = " [HAS EXCEPTIONS]" if m.get("has_exceptions", False) else ""
xrefs = m.get("cross_references", "")
xref_tag = f"\n Cross-refs: {xrefs}" if xrefs else ""
report = (
f"ID: {m['section_full']} | Code: {m['code_type']} | Chapter: {m['parent_major']}"
f"{exceptions_tag}{xref_tag}\n"
f"Snippet: {doc[:500]}" # Truncate long snippets
)
section_reports.append(report)
output = (
"### CODE DISCOVERY REPORT ###\n"
f"MOST RELEVANT CHAPTERS:\n{chapter_summary}\n\n"
"TOP RELEVANT SECTIONS:\n"
+ "\n---\n".join(section_reports)
)
# Cache the result
if cache is not None:
cache.put(query, output)
return output
# ---------------------------------------------------------------------------
# fetch_full_chapter — with section filtering and pagination
# ---------------------------------------------------------------------------
def fetch_full_chapter(
code_type: str,
chapter_id: str,
section_filter: str | None = None,
) -> str:
"""Retrieve sections from a specific chapter, with optional keyword filtering.
Parameters
----------
code_type : str
One of: Administrative, Building, FuelGas, Mechanical, Plumbing
chapter_id : str
The parent_major chapter ID (e.g., "10", "602")
section_filter : str, optional
If provided, only return sections containing this keyword
"""
collection = get_collection()
try:
chapter_data = collection.get(
where={
"$and": [
{"code_type": {"$eq": code_type}},
{"parent_major": {"$eq": chapter_id}},
]
},
include=["documents", "metadatas"],
)
if not chapter_data["documents"]:
return f"No documentation found for {code_type} Chapter {chapter_id}."
pairs = list(zip(chapter_data["metadatas"], chapter_data["documents"]))
# Apply keyword filter if provided
if section_filter:
filter_lower = section_filter.lower()
pairs = [(m, d) for m, d in pairs if filter_lower in d.lower()]
if not pairs:
return (
f"No sections in {code_type} Chapter {chapter_id} "
f"match filter '{section_filter}'."
)
# Sort by section number and limit
pairs.sort(key=lambda x: x[0]["section_full"])
total_sections = len(pairs)
pairs = pairs[:FETCH_MAX_SECTIONS]
# Build output
header = f"## {code_type.upper()} CODE - CHAPTER {chapter_id}"
if total_sections > FETCH_MAX_SECTIONS:
header += f" (showing {FETCH_MAX_SECTIONS} of {total_sections} sections)"
if section_filter:
header += f" [filtered by: '{section_filter}']"
header += "\n\n"
full_text = header
for meta, doc in pairs:
# Deduplicate [CONT.] blocks within the document
blocks = doc.split("[CONT.]:")
unique_blocks = []
seen = set()
for b in blocks:
clean_b = b.strip()
if clean_b:
h = hashlib.md5(clean_b.encode()).hexdigest()
if h not in seen:
unique_blocks.append(clean_b)
seen.add(h)
clean_doc = " ".join(unique_blocks)
exceptions_tag = ""
if meta.get("has_exceptions", False):
exceptions_tag = f" [CONTAINS {meta.get('exception_count', '?')} EXCEPTION(S)]"
full_text += (
f"### SECTION {meta['section_full']}{exceptions_tag}\n"
f"{clean_doc}\n\n---\n\n"
)
return full_text
except Exception as e:
return f"Error retrieving chapter content: {e!s}"