Spjimr / vectorstore.py
shahidshaikh's picture
Upload 40 files
a52bae4 verified
# ============================================================================
# vectorstore.py — ChromaDB-backed vector store for the training dataset
# ============================================================================
#
# PURPOSE
# -------
# Semantic vector storage and retrieval using ChromaDB as the backend.
# Unlike training.py (which only holds vectors in RAM during a single
# classifier fit or clustering run), this module PERSISTS vectors to disk
# so students can index once and then run many semantic searches against
# the stored collection.
#
# Uses the same sentence-transformers model as training.py so vectors are
# comparable across all parts of the demo.
#
# WHAT GETS STORED
# ----------------
# For each of the 100 training_data.py sentences we store:
# - sentence text (the document)
# - 384-dim embedding vector (from all-MiniLM-L6-v2)
# - metadata: {label, index}
#
# Persistence: ChromaDB writes to ./chroma_db/ under the app directory.
# On HuggingFace Spaces this persists for the life of the container but
# is wiped on Space restart (Spaces are ephemeral). That is fine for a
# teaching demo — students re-index at the start of each session.
#
# CONTRACT (what app.py imports from here)
# ----------------------------------------
# get_collection() -> chroma collection (creates on first call)
# index_training_data() -> {indexed, sentence_count, vector_dim}
# search(query, n_results=5) -> list of dicts with sentence, label, index, distance, similarity
# clear_collection() -> drops all vectors
# collection_stats() -> {count, persist_dir, collection_name}
# preview_vectors(n=10) -> list of {sentence, label, vector_head} dicts
# used by the Vectorize sub-tab for inspection
# ============================================================================
import os
import providers
from training_data import TRAINING_EXAMPLES
# ----------------------------------------------------------------
# Configuration
# ----------------------------------------------------------------
# On-disk location handed to the Chroma client; the CHROMA_PERSIST_DIR
# environment variable overrides the default relative path.
PERSIST_DIR = os.getenv("CHROMA_PERSIST_DIR", "./chroma_db")

# Name of the single collection this module reads from and writes to.
COLLECTION_NAME = "training_sentences"

# Provider name used by the public functions when the caller does not
# choose one explicitly.
DEFAULT_EMBEDDING_PROVIDER = "MiniLM (local)"
# ----------------------------------------------------------------
# Lazy client for chromadb
# ----------------------------------------------------------------
# Module-level caches. Both start empty and are filled on first use by
# _get_client() and get_collection() respectively, so importing this
# module does not import chromadb or touch the disk.
_CLIENT = None  # cached chromadb client, set by _get_client()
_COLLECTION = None  # cached collection handle, set by get_collection()
def _get_client():
    """Return the cached Chroma client, building it on the first call.

    chromadb is imported here rather than at module load, and the persist
    directory is created if it is missing before the client is opened.
    """
    global _CLIENT
    if _CLIENT is not None:
        return _CLIENT

    # Deferred import: only pay for chromadb when a client is really needed
    import chromadb

    os.makedirs(PERSIST_DIR, exist_ok=True)
    _CLIENT = chromadb.PersistentClient(path=PERSIST_DIR)
    return _CLIENT
def get_collection():
    """Return the shared Chroma collection, creating it if necessary.

    The handle is cached in the module-level ``_COLLECTION``, so repeated
    calls return the same object without going back to the client.
    """
    global _COLLECTION
    if _COLLECTION is not None:
        return _COLLECTION

    # Ask for the cosine space on the HNSW index when the collection is made
    _COLLECTION = _get_client().get_or_create_collection(
        name=COLLECTION_NAME,
        metadata={"hnsw:space": "cosine"},
    )
    return _COLLECTION
# ----------------------------------------------------------------
# Indexing — embed all 100 training sentences and persist to disk
# ----------------------------------------------------------------
def index_training_data(embedding_provider=DEFAULT_EMBEDDING_PROVIDER,
                        embedding_api_key=""):
    """Embed every sentence in TRAINING_EXAMPLES and write it to the collection.

    Any rows already in the collection are removed first, so calling this
    repeatedly leaves exactly one row per training sentence.

    Args:
        embedding_provider: Key into ``providers.EMBEDDING_PROVIDERS``; also
            forwarded to ``providers.embed_texts``.
        embedding_api_key: Credential forwarded to ``providers.embed_texts``.

    Returns:
        A summary dict for UI display with the keys ``indexed``,
        ``sentence_count``, ``vector_dim``, ``embedding_provider``,
        ``embedding_model``, ``persist_dir`` and ``collection_name``.

    Raises:
        KeyError: If ``embedding_provider`` is not in
            ``providers.EMBEDDING_PROVIDERS``.
        Exception: Whatever ``providers.embed_texts`` or the Chroma calls
            raise is propagated unchanged.
    """
    collection = get_collection()

    sentences = [e["sentence"] for e in TRAINING_EXAMPLES]
    labels = [e["label"] for e in TRAINING_EXAMPLES]

    # Do everything that can fail for provider reasons (bad key, network,
    # unknown provider) BEFORE deleting anything, so a failed re-index
    # leaves the previously stored vectors intact instead of an empty store.
    # assumes embed_texts returns a 2-D numpy-like array (rows = sentences)
    vectors = providers.embed_texts(
        sentences, embedding_provider, embedding_api_key,
    )
    embedding_model = (
        providers.EMBEDDING_PROVIDERS[embedding_provider]["default_model"]
    )

    # Reset so re-indexing is predictable
    # NOTE(review): deleting rows may not reset the collection's vector
    # dimension; switching to a provider with a different dimension could
    # still be rejected by Chroma — confirm and recreate the collection if so.
    if collection.count() > 0:
        existing_ids = collection.get().get("ids", [])
        if existing_ids:
            collection.delete(ids=existing_ids)

    ids = [f"sent_{i:03d}" for i in range(len(sentences))]
    metadatas = [
        {"label": lab, "index": i}
        for i, lab in enumerate(labels)
    ]
    collection.add(
        ids=ids,
        documents=sentences,
        embeddings=vectors.tolist(),
        metadatas=metadatas,
    )

    return {
        "indexed": len(sentences),
        "sentence_count": len(sentences),
        "vector_dim": int(vectors.shape[1]),
        "embedding_provider": embedding_provider,
        "embedding_model": embedding_model,
        "persist_dir": PERSIST_DIR,
        "collection_name": COLLECTION_NAME,
    }
# ----------------------------------------------------------------
# Search — embed a query and retrieve nearest neighbors
# ----------------------------------------------------------------
def search(query, n_results=5,
           embedding_provider=DEFAULT_EMBEDDING_PROVIDER,
           embedding_api_key=""):
    """Embed ``query`` and return the nearest stored training sentences.

    Args:
        query: Text to search for.
        n_results: Maximum number of hits wanted. Values larger than the
            number of stored rows are clamped to that number; zero or
            negative values yield an empty list.
        embedding_provider: Provider name forwarded to
            ``providers.embed_texts``.
            NOTE(review): presumably must match the provider used when the
            collection was indexed, otherwise vector sizes may differ.
        embedding_api_key: Credential forwarded to ``providers.embed_texts``.

    Returns:
        A list of dicts, in the order Chroma returns them, each with the keys
        ``sentence``, ``label``, ``index``, ``distance`` and ``similarity``
        (``similarity`` is ``1.0 - distance``). Empty when the collection
        holds no rows.
    """
    collection = get_collection()
    stored = collection.count()
    if stored == 0:
        return []

    # Never request more neighbours than exist, and skip the embedding call
    # entirely when the caller asked for nothing.
    limit = min(int(n_results), stored)
    if limit <= 0:
        return []

    q_vecs = providers.embed_texts(
        [query], embedding_provider, embedding_api_key,
    )
    res = collection.query(
        query_embeddings=[q_vecs[0].tolist()],
        n_results=limit,
    )

    # Results are nested one list per query; only one query was sent.
    docs = (res.get("documents") or [[]])[0]
    metas = (res.get("metadatas") or [[]])[0]
    dists = (res.get("distances") or [[]])[0]

    hits = []
    for doc, meta, dist in zip(docs, metas, dists):
        meta = meta or {}  # tolerate rows stored without metadata
        hits.append({
            "sentence": doc,
            "label": meta.get("label"),
            "index": meta.get("index"),
            "distance": float(dist),
            "similarity": float(1.0 - dist),
        })
    return hits
# ----------------------------------------------------------------
# Utilities — clear, stats, preview
# ----------------------------------------------------------------
def clear_collection():
    """Delete every row in the collection and report how many were removed.

    Returns:
        ``{"cleared": <number of ids that were present>}``.
    """
    store = get_collection()
    snapshot = store.get()
    stored_ids = snapshot.get("ids", [])
    # Only issue the delete when there is something to remove
    if stored_ids:
        store.delete(ids=stored_ids)
    return {"cleared": len(stored_ids)}
def collection_stats():
    """Summarise the collection for display.

    Returns:
        A dict with the current row ``count`` plus the configured
        ``persist_dir`` and ``collection_name``.
    """
    row_count = get_collection().count()
    stats = {"count": row_count}
    stats["persist_dir"] = PERSIST_DIR
    stats["collection_name"] = COLLECTION_NAME
    return stats
def preview_vectors(n=10,
                    embedding_provider=DEFAULT_EMBEDDING_PROVIDER,
                    embedding_api_key=""):
    """Return the first ``n`` training sentences with the head of each vector.

    The sentences are embedded on the fly through ``providers.embed_texts``;
    the stored collection is not read.

    Args:
        n: Number of leading examples to preview. Zero or negative values
            return an empty list instead of slicing from the end of the
            dataset or sending an empty batch to the embedder.
        embedding_provider: Provider name forwarded to
            ``providers.embed_texts``.
        embedding_api_key: Credential forwarded to ``providers.embed_texts``.

    Returns:
        A list of dicts with the keys ``index``, ``sentence``, ``label``,
        ``vector_head`` (string form of the first 8 components rounded to
        4 decimals) and ``vector_dim``.
    """
    limit = int(n)
    if limit <= 0:
        return []

    sample = TRAINING_EXAMPLES[:limit]
    if not sample:
        return []

    # assumes embed_texts yields one array-like row per input sentence
    vectors = providers.embed_texts(
        [ex["sentence"] for ex in sample],
        embedding_provider,
        embedding_api_key,
    )

    return [
        {
            "index": i,
            "sentence": ex["sentence"],
            "label": ex["label"],
            # Stringified so the head renders as a single table cell
            "vector_head": str([round(float(x), 4) for x in vec[:8]]),
            "vector_dim": int(vec.shape[0]),
        }
        for i, (ex, vec) in enumerate(zip(sample, vectors))
    ]