# ============================================================================
# vectorstore.py — ChromaDB-backed vector store for the training dataset
# ============================================================================
#
# PURPOSE
# -------
# Semantic vector storage and retrieval using ChromaDB as the backend.
# Unlike training.py (which only holds vectors in RAM during a single
# classifier fit or clustering run), this module PERSISTS vectors to disk
# so students can index once and then run many semantic searches against
# the stored collection.
#
# Uses the same sentence-transformers model as training.py so vectors are
# comparable across all parts of the demo.
#
# WHAT GETS STORED
# ----------------
# For each of the 100 training_data.py sentences we store:
#   - sentence text (the document)
#   - 384-dim embedding vector (from all-MiniLM-L6-v2)
#   - metadata: {label, index}
#
# Persistence: ChromaDB writes to ./chroma_db/ under the app directory.
# On HuggingFace Spaces this persists for the life of the container but
# is wiped on Space restart (Spaces are ephemeral). That is fine for a
# teaching demo — students re-index at the start of each session.
#
# CONTRACT (what app.py imports from here)
# ----------------------------------------
#   get_collection()            -> chroma collection (creates on first call)
#   index_training_data()       -> {indexed, sentence_count, vector_dim,
#                                   embedding_provider, embedding_model,
#                                   persist_dir, collection_name}
#   search(query, n_results=5)  -> list of dicts with sentence, label, index,
#                                  distance, similarity
#   clear_collection()          -> drops all vectors, returns {cleared}
#   collection_stats()          -> {count, persist_dir, collection_name}
#   preview_vectors(n=10)       -> list of {index, sentence, label,
#                                  vector_head, vector_dim} dicts
#                                  used by the Vectorize sub-tab for inspection
# ============================================================================

import os

import providers
from training_data import TRAINING_EXAMPLES

# ----------------------------------------------------------------
# Configuration
# ----------------------------------------------------------------
PERSIST_DIR = os.environ.get("CHROMA_PERSIST_DIR", "./chroma_db")
COLLECTION_NAME = "training_sentences"
DEFAULT_EMBEDDING_PROVIDER = "MiniLM (local)"

# ----------------------------------------------------------------
# Lazy client for chromadb
# ----------------------------------------------------------------
_CLIENT = None
_COLLECTION = None


def _get_client():
    """Return the process-wide Chroma client, creating it on first use.

    chromadb is imported lazily so that importing this module stays cheap
    and does not require chromadb until a vector-store call is made.
    """
    global _CLIENT
    if _CLIENT is None:
        import chromadb

        os.makedirs(PERSIST_DIR, exist_ok=True)
        _CLIENT = chromadb.PersistentClient(path=PERSIST_DIR)
    return _CLIENT


def get_collection():
    """Get or create the Chroma collection. Safe to call many times.

    The handle is cached at module level. Re-indexing and clearing replace
    the underlying collection, so callers should call this function each
    time instead of holding on to a previously returned handle.
    """
    global _COLLECTION
    if _COLLECTION is None:
        client = _get_client()
        _COLLECTION = client.get_or_create_collection(
            name=COLLECTION_NAME,
            metadata={"hnsw:space": "cosine"},
        )
    return _COLLECTION


def _reset_collection():
    """Drop the collection and return a fresh, empty one.

    Deleting rows by id leaves the collection's embedding dimensionality
    locked to whatever was inserted first, so switching to an embedding
    provider with a different vector size would fail on the next add.
    Dropping and recreating the collection avoids that.
    """
    global _COLLECTION
    get_collection()  # make sure the collection exists before deleting it
    _COLLECTION = None  # never keep a handle to a deleted collection
    _get_client().delete_collection(name=COLLECTION_NAME)
    return get_collection()


# ----------------------------------------------------------------
# Indexing — embed all training sentences and persist to disk
# ----------------------------------------------------------------
def index_training_data(embedding_provider=DEFAULT_EMBEDDING_PROVIDER,
                        embedding_api_key=""):
    """Embed every sentence in TRAINING_EXAMPLES and write to the collection.

    Any previously stored rows are replaced, so re-indexing is idempotent.
    Embedding happens BEFORE the old rows are removed: if the embedding call
    raises, the existing index is left untouched.

    Args:
        embedding_provider: key into providers.EMBEDDING_PROVIDERS.
        embedding_api_key: forwarded unchanged to providers.embed_texts.

    Returns:
        A dict with summary fields for UI display: indexed, sentence_count,
        vector_dim, embedding_provider, embedding_model, persist_dir,
        collection_name. With no training examples, indexed and vector_dim
        are 0 and the collection is left empty.
    """
    sentences = [example["sentence"] for example in TRAINING_EXAMPLES]
    labels = [example["label"] for example in TRAINING_EXAMPLES]

    vectors = None
    vector_dim = 0
    if sentences:
        vectors = providers.embed_texts(
            sentences, embedding_provider, embedding_api_key,
        )
        vector_dim = int(vectors.shape[1])

    # Only now discard the old index; everything that can fail upstream
    # (model load, API call) has already succeeded.
    collection = _reset_collection()

    if sentences:
        collection.add(
            ids=[f"sent_{i:03d}" for i in range(len(sentences))],
            documents=sentences,
            embeddings=vectors.tolist(),
            metadatas=[
                {"label": label, "index": i} for i, label in enumerate(labels)
            ],
        )

    return {
        "indexed": len(sentences),
        "sentence_count": len(sentences),
        "vector_dim": vector_dim,
        "embedding_provider": embedding_provider,
        "embedding_model": providers.EMBEDDING_PROVIDERS[embedding_provider]["default_model"],
        "persist_dir": PERSIST_DIR,
        "collection_name": COLLECTION_NAME,
    }


# ----------------------------------------------------------------
# Search — embed a query and retrieve nearest neighbors
# ----------------------------------------------------------------
def search(query, n_results=5, embedding_provider=DEFAULT_EMBEDDING_PROVIDER,
           embedding_api_key=""):
    """Embed query and return top-N nearest training sentences.

    Args:
        query: text to search for. A blank or None query returns [].
        n_results: maximum number of hits. Values above the number of stored
            rows are clamped to it; zero or negative values return [].
        embedding_provider: key into providers.EMBEDDING_PROVIDERS. Should
            be the provider the collection was indexed with.
        embedding_api_key: forwarded unchanged to providers.embed_texts.

    Returns:
        A list of dicts with keys sentence, label, index, distance and
        similarity (1 - distance; the collection uses cosine space), in the
        order returned by Chroma. Empty when nothing is indexed.
    """
    if query is None or not str(query).strip():
        return []

    collection = get_collection()
    stored = collection.count()
    # Asking for more rows than exist (or for <= 0) is rejected or warned
    # about by Chroma depending on version, so normalise it here.
    limit = min(int(n_results), stored)
    if limit <= 0:
        return []

    q_vecs = providers.embed_texts(
        [query], embedding_provider, embedding_api_key,
    )
    res = collection.query(
        query_embeddings=[q_vecs[0].tolist()],
        n_results=limit,
    )

    # Each result field is a list-per-query; we sent exactly one query.
    docs = (res.get("documents") or [[]])[0]
    metas = (res.get("metadatas") or [[]])[0]
    dists = (res.get("distances") or [[]])[0]

    return [
        {
            "sentence": doc,
            "label": (meta or {}).get("label"),
            "index": (meta or {}).get("index"),
            "distance": float(dist),
            "similarity": float(1.0 - dist),
        }
        for doc, meta, dist in zip(docs, metas, dists)
    ]


# ----------------------------------------------------------------
# Utilities — clear, stats, preview
# ----------------------------------------------------------------
def clear_collection():
    """Remove every stored vector.

    Returns:
        {"cleared": <number of rows that were stored before clearing>}.
    """
    cleared = get_collection().count()
    _reset_collection()
    return {"cleared": cleared}


def collection_stats():
    """Return the row count plus the persist directory and collection name."""
    collection = get_collection()
    return {
        "count": collection.count(),
        "persist_dir": PERSIST_DIR,
        "collection_name": COLLECTION_NAME,
    }


def preview_vectors(n=10, embedding_provider=DEFAULT_EMBEDDING_PROVIDER,
                    embedding_api_key=""):
    """Return the first N sentences with the head of their embedding vectors.

    The vectors are computed on the fly from TRAINING_EXAMPLES; nothing is
    read from or written to the collection.

    Args:
        n: number of training examples to preview. Zero or negative
            values return [].
        embedding_provider: key into providers.EMBEDDING_PROVIDERS.
        embedding_api_key: forwarded unchanged to providers.embed_texts.

    Returns:
        A list of dicts with keys index, sentence, label, vector_head (the
        first 8 components rounded to 4 decimals, as a string) and
        vector_dim.
    """
    sample = TRAINING_EXAMPLES[:max(int(n), 0)]
    if not sample:
        return []

    vectors = providers.embed_texts(
        [example["sentence"] for example in sample],
        embedding_provider,
        embedding_api_key,
    )

    return [
        {
            "index": i,
            "sentence": example["sentence"],
            "label": example["label"],
            "vector_head": str([round(float(x), 4) for x in vec[:8]]),
            "vector_dim": int(vec.shape[0]),
        }
        for i, (example, vec) in enumerate(zip(sample, vectors))
    ]