# NurseLex / local_search.py
# fix(ui): remove experimental webgpu tab and switch semantic cache back to
# stable numpy arrays (commit 0e0459c)
"""
local_search.py
Locally loads the `i-dot-ai/all-miniLM-L6-v2-UKPGA-6k-finetune` model to encode
the cached `nursing_sections.json` into semantic embeddings for fast, reliable local searches.
"""
import json
import os
import logging
import torch
import numpy as np
from sentence_transformers import SentenceTransformer, util
logger = logging.getLogger(__name__)
# Constants
MODEL_NAME = "i-dot-ai/all-miniLM-L6-v2-UKPGA-6k-finetune"
CACHE_FILE = os.path.join(os.path.dirname(__file__), "nursing_sections.json")
EMBEDDINGS_FILE = os.path.join(os.path.dirname(__file__), "nursing_sections_embeddings.npy")
# Global variables to hold the model and embeddings in memory
_model = None
_corpus_embeddings = None
_sections = []
def init_local_search():
    """Initialize the local semantic search engine (model + corpus embeddings).

    Loads the SentenceTransformer model, reads cached sections from
    CACHE_FILE, then either reuses precomputed embeddings from
    EMBEDDINGS_FILE (when present and consistent with the cache) or encodes
    the sections and saves the result for future runs.

    Idempotent: returns immediately if already initialized. On failure all
    module state is reset so a later call retries from scratch.
    """
    global _model, _corpus_embeddings, _sections
    if _model is not None:
        return  # Already initialized
    try:
        logger.info("Loading local embedding model: %s...", MODEL_NAME)
        _model = SentenceTransformer(MODEL_NAME)

        if not os.path.exists(CACHE_FILE):
            logger.error("Cache file not found at %s", CACHE_FILE)
            return
        with open(CACHE_FILE, "r", encoding="utf-8") as f:
            _sections = json.load(f)
        if not _sections:
            logger.warning("No sections found in cache.")
            return

        if os.path.exists(EMBEDDINGS_FILE):
            np_embeddings = np.load(EMBEDDINGS_FILE)
            # Bug fix: previously a stale embeddings file was loaded blindly.
            # If the sections cache was regenerated with a different number of
            # sections, topk indices would silently map to the wrong sections.
            # Only reuse the file when the row count matches.
            if len(np_embeddings) == len(_sections):
                logger.info("Loading precomputed numpy embeddings from disk (Instant)...")
                # Convert back to tensor for cosine similarity
                _corpus_embeddings = torch.from_numpy(np_embeddings)
                logger.info("Local semantic search engine ready.")
                return
            logger.warning(
                "Embeddings cache (%d rows) does not match sections cache (%d); recomputing.",
                len(np_embeddings), len(_sections),
            )

        logger.info(
            "Computing embeddings for %d cached sections. This may take a minute on first run...",
            len(_sections),
        )
        # Prepare text for embedding: combine legislation title, section title,
        # and text so each section is embedded with its legislative context.
        corpus_texts = []
        for s in _sections:
            # Reconstruct the act name roughly from the URL to give the model context
            leg_id = s.get("legislation_id", "")
            act_name = leg_id.split("/")[-2] if "/" in leg_id else leg_id
            # Create a rich text representation for the vector search
            corpus_texts.append(
                f"Act: {act_name}. Section {s.get('number', '')}: {s.get('title', '')}. {s.get('text', '')}"
            )

        # Encode all sections in one batch
        _corpus_embeddings = _model.encode(corpus_texts, convert_to_tensor=True, show_progress_bar=False)

        logger.info("Saving computed numpy embeddings for future use...")
        try:
            np.save(EMBEDDINGS_FILE, _corpus_embeddings.cpu().numpy())
        except Exception as save_err:
            # Best-effort cache write: search still works without it.
            logger.warning("Failed to save embeddings cache: %s", save_err)
        logger.info("Local semantic search engine ready.")
    except Exception as e:
        logger.error("Failed to initialize local search engine: %s", e)
        # Reset ALL state on failure (previously only _model was reset, which
        # could pair stale embeddings/sections with a fresh model on retry).
        _model = None
        _corpus_embeddings = None
        _sections = []
def search_scenarios_locally(query: str, top_k: int = 5, min_score: float = 0.4) -> list[dict]:
    """Semantic search over the locally cached sections using cosine similarity.

    Args:
        query: Free-text search query.
        top_k: Maximum number of matches to return.
        min_score: Minimum cosine similarity for a match to be included.
            Previously a hard-coded 0.4; now tunable, default unchanged.

    Returns:
        Copies of matching section dicts with an added "score" key, ranked
        best-first. Empty list if the engine is unavailable or an error occurs.
    """
    global _model, _corpus_embeddings, _sections
    if _model is None or _corpus_embeddings is None:
        init_local_search()  # Lazy initialization on first use
    if _model is None or _corpus_embeddings is None:
        logger.error("Local search engine is unavailable.")
        return []
    try:
        query_embedding = _model.encode(query, convert_to_tensor=True)
        # Compute cosine similarity of the query against every section
        cos_scores = util.cos_sim(query_embedding, _corpus_embeddings)[0]
        # topk returns scores in descending order, so results stay ranked
        top_results = torch.topk(cos_scores, k=min(top_k, len(_sections)))
        results = []
        for score, idx in zip(top_results.values, top_results.indices):
            similarity = score.item()
            # Only return sufficiently relevant matches
            if similarity > min_score:
                # int(idx): index the Python list with a plain int, not a tensor
                match = _sections[int(idx)].copy()
                match["score"] = similarity
                results.append(match)
        return results
    except Exception as e:
        logger.error("Error during local scenario search: %s", e)
        return []