"""ChromaDB tools for NYC code lookup — with re-ranking, budget tracking, and caching.""" from __future__ import annotations import hashlib from collections import Counter import chromadb from chromadb.utils import embedding_functions from config import ( CHROMA_COLLECTION_NAME, CHROMA_DB_PATH, DISCOVER_N_RESULTS, EMBEDDING_MODEL_NAME, FETCH_MAX_SECTIONS, RERANK_TOP_K, ) # --------------------------------------------------------------------------- # Singleton collection loader # --------------------------------------------------------------------------- _collection = None _warmup_done = False def warmup_collection() -> bool: """Eagerly load the embedding model and connect to ChromaDB. Returns True if collection is available, False otherwise. Call this during app startup so the heavy model download + load happens visibly (with a progress spinner) rather than on the first query. """ global _warmup_done try: get_collection() _warmup_done = True return True except Exception: _warmup_done = False return False def is_warmed_up() -> bool: return _warmup_done def get_collection(): """Lazy-load the ChromaDB collection (singleton).""" global _collection if _collection is None: client = chromadb.PersistentClient(path=CHROMA_DB_PATH) embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction( model_name=EMBEDDING_MODEL_NAME, ) _collection = client.get_collection( name=CHROMA_COLLECTION_NAME, embedding_function=embedding_fn, ) return _collection # --------------------------------------------------------------------------- # Query cache for deduplication # --------------------------------------------------------------------------- class QueryCache: """Simple cache to avoid re-querying semantically identical topics.""" def __init__(self): self._cache: dict[str, str] = {} # normalized_key -> result def _normalize(self, query: str) -> str: words = sorted(set(query.lower().split())) return " ".join(words) def get(self, query: str) -> str | None: key = self._normalize(query) return self._cache.get(key) def put(self, query: str, result: str) -> None: key = self._normalize(query) self._cache[key] = result # --------------------------------------------------------------------------- # discover_code_locations — semantic search with re-ranking # --------------------------------------------------------------------------- def discover_code_locations(query: str, cache: QueryCache | None = None) -> str: """Semantic search over NYC codes with hierarchical re-ranking. Returns a formatted report of the most relevant code sections. """ # Check cache if cache is not None: cached = cache.get(query) if cached is not None: return f"[CACHED RESULT]\n{cached}" collection = get_collection() results = collection.query( query_texts=[query], n_results=DISCOVER_N_RESULTS, include=["metadatas", "documents", "distances"], ) if not results["metadatas"][0]: return "No results found. Try a different query phrasing." metas = results["metadatas"][0] docs = results["documents"][0] distances = results["distances"][0] # ------ Re-ranking ------ # Score = semantic_similarity + hierarchy_bonus + exception_bonus ranked = [] for meta, doc, dist in zip(metas, docs, distances): score = -dist # Lower distance = better match, negate for sorting # Hierarchy bonus: shallower sections (fewer dots) rank higher for broad queries depth = meta.get("section_full", "").count(".") score += max(0, 3 - depth) * 0.05 # Up to +0.15 for top-level sections # Exception bonus: sections with exceptions are more useful for compliance if meta.get("has_exceptions", False): score += 0.1 ranked.append((score, meta, doc)) ranked.sort(key=lambda x: x[0], reverse=True) top_results = ranked[:RERANK_TOP_K] # ------ Format output ------ category_chapter_pairs = [ f"{m['code_type']} | Ch. {m['parent_major']}" for _, m, _ in top_results ] counts = Counter(category_chapter_pairs) chapter_summary = "\n".join( f"- {pair} ({count} hits)" for pair, count in counts.most_common(5) ) section_reports = [] for _score, m, doc in top_results: exceptions_tag = " [HAS EXCEPTIONS]" if m.get("has_exceptions", False) else "" xrefs = m.get("cross_references", "") xref_tag = f"\n Cross-refs: {xrefs}" if xrefs else "" report = ( f"ID: {m['section_full']} | Code: {m['code_type']} | Chapter: {m['parent_major']}" f"{exceptions_tag}{xref_tag}\n" f"Snippet: {doc[:500]}" # Truncate long snippets ) section_reports.append(report) output = ( "### CODE DISCOVERY REPORT ###\n" f"MOST RELEVANT CHAPTERS:\n{chapter_summary}\n\n" "TOP RELEVANT SECTIONS:\n" + "\n---\n".join(section_reports) ) # Cache the result if cache is not None: cache.put(query, output) return output # --------------------------------------------------------------------------- # fetch_full_chapter — with section filtering and pagination # --------------------------------------------------------------------------- def fetch_full_chapter( code_type: str, chapter_id: str, section_filter: str | None = None, ) -> str: """Retrieve sections from a specific chapter, with optional keyword filtering. Parameters ---------- code_type : str One of: Administrative, Building, FuelGas, Mechanical, Plumbing chapter_id : str The parent_major chapter ID (e.g., "10", "602") section_filter : str, optional If provided, only return sections containing this keyword """ collection = get_collection() try: chapter_data = collection.get( where={ "$and": [ {"code_type": {"$eq": code_type}}, {"parent_major": {"$eq": chapter_id}}, ] }, include=["documents", "metadatas"], ) if not chapter_data["documents"]: return f"No documentation found for {code_type} Chapter {chapter_id}." pairs = list(zip(chapter_data["metadatas"], chapter_data["documents"])) # Apply keyword filter if provided if section_filter: filter_lower = section_filter.lower() pairs = [(m, d) for m, d in pairs if filter_lower in d.lower()] if not pairs: return ( f"No sections in {code_type} Chapter {chapter_id} " f"match filter '{section_filter}'." ) # Sort by section number and limit pairs.sort(key=lambda x: x[0]["section_full"]) total_sections = len(pairs) pairs = pairs[:FETCH_MAX_SECTIONS] # Build output header = f"## {code_type.upper()} CODE - CHAPTER {chapter_id}" if total_sections > FETCH_MAX_SECTIONS: header += f" (showing {FETCH_MAX_SECTIONS} of {total_sections} sections)" if section_filter: header += f" [filtered by: '{section_filter}']" header += "\n\n" full_text = header for meta, doc in pairs: # Deduplicate [CONT.] blocks within the document blocks = doc.split("[CONT.]:") unique_blocks = [] seen = set() for b in blocks: clean_b = b.strip() if clean_b: h = hashlib.md5(clean_b.encode()).hexdigest() if h not in seen: unique_blocks.append(clean_b) seen.add(h) clean_doc = " ".join(unique_blocks) exceptions_tag = "" if meta.get("has_exceptions", False): exceptions_tag = f" [CONTAINS {meta.get('exception_count', '?')} EXCEPTION(S)]" full_text += ( f"### SECTION {meta['section_full']}{exceptions_tag}\n" f"{clean_doc}\n\n---\n\n" ) return full_text except Exception as e: return f"Error retrieving chapter content: {e!s}"