File size: 7,431 Bytes
a52bae4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
# ============================================================================
# vectorstore.py — ChromaDB-backed vector store for the training dataset
# ============================================================================
#
# PURPOSE
# -------
# Semantic vector storage and retrieval using ChromaDB as the backend.
# Unlike training.py (which only holds vectors in RAM during a single
# classifier fit or clustering run), this module PERSISTS vectors to disk
# so students can index once and then run many semantic searches against
# the stored collection.
#
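# Embeddings come from providers.embed_texts(). The default provider
# (DEFAULT_EMBEDDING_PROVIDER = "MiniLM (local)") is intended to be the same
# sentence-transformers model training.py uses, so vectors are comparable
# across all parts of the demo. Other providers can be selected per call;
# index and search with the same provider, because vectors produced by
# different models are not comparable.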
#
# WHAT GETS STORED
# ----------------
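# For each of the 100 training_data.py sentences we store:
#   - sentence text (the document)
#   - its embedding vector (384-dim with the default all-MiniLM-L6-v2
#     provider; the dimension depends on the chosen embedding provider)
#   - metadata: {label, index}
#   - an id of the form "sent_000", "sent_001", ...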
#
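# Persistence: ChromaDB writes to $CHROMA_PERSIST_DIR, defaulting to
# ./chroma_db/ (relative to the process working directory, i.e. the app
# directory when the app is launched from there).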
# On HuggingFace Spaces this persists for the life of the container but
# is wiped on Space restart (Spaces are ephemeral). That is fine for a
# teaching demo — students re-index at the start of each session.
#
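# CONTRACT (what app.py imports from here)
# ----------------------------------------
# index_training_data, search and preview_vectors also accept
# embedding_provider= and embedding_api_key= keyword arguments, which are
# forwarded to providers.embed_texts.
#
#   get_collection()            -> chroma collection (creates on first call)
#   index_training_data()       -> {indexed, sentence_count, vector_dim,
#                                   embedding_provider, embedding_model,
#                                   persist_dir, collection_name}
#   search(query, n_results=5)  -> list of dicts with sentence, label, index,
#                                  distance, similarity (= 1 - cosine distance)
#   clear_collection()          -> drops all vectors; returns {cleared}
#   collection_stats()          -> {count, persist_dir, collection_name}
#   preview_vectors(n=10)       -> list of {index, sentence, label,
#                                  vector_head, vector_dim} dicts used by the
#                                  Vectorize sub-tab for inspection
# ============================================================================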


import os
import providers
from training_data import TRAINING_EXAMPLES


# ----------------------------------------------------------------
# Configuration
# ----------------------------------------------------------------
# Where Chroma keeps its files; override with the CHROMA_PERSIST_DIR env var.
PERSIST_DIR = os.getenv("CHROMA_PERSIST_DIR", "./chroma_db")
COLLECTION_NAME = "training_sentences"
# Key into providers.EMBEDDING_PROVIDERS used when a caller does not pick one.
DEFAULT_EMBEDDING_PROVIDER = "MiniLM (local)"


# ----------------------------------------------------------------
# Lazy client for chromadb
# ----------------------------------------------------------------
# Process-wide caches filled on first use by _get_client() and
# get_collection(); None means "not created yet".
_CLIENT = None
_COLLECTION = None


def _get_client():
    """Return the cached Chroma PersistentClient, building it on first call.

    chromadb is imported lazily so importing this module does not pull it in
    until a vector-store operation is actually requested. The persist
    directory is created (if missing) before the client is opened on it.
    """
    global _CLIENT
    if _CLIENT is not None:
        return _CLIENT

    import chromadb

    os.makedirs(PERSIST_DIR, exist_ok=True)
    _CLIENT = chromadb.PersistentClient(path=PERSIST_DIR)
    return _CLIENT


def get_collection():
    """Return the cached Chroma collection, creating it on first use.

    Safe to call many times: after the first call the same collection object
    is returned. The collection is configured with ``hnsw:space = "cosine"``,
    which search() relies on when it turns distances into similarities.
    """
    global _COLLECTION
    if _COLLECTION is not None:
        return _COLLECTION

    _COLLECTION = _get_client().get_or_create_collection(
        name=COLLECTION_NAME,
        metadata={"hnsw:space": "cosine"},
    )
    return _COLLECTION


# ----------------------------------------------------------------
# Indexing — embed all 100 training sentences and persist to disk
# ----------------------------------------------------------------
def index_training_data(embedding_provider=DEFAULT_EMBEDDING_PROVIDER,
                        embedding_api_key=""):
    """Embed every sentence in TRAINING_EXAMPLES and write it to the collection.

    Rows already in the collection are removed first, so re-indexing always
    leaves exactly one row per training sentence. Everything that can fail
    before storage is touched (provider lookup, embedding) runs *before* the
    old rows are removed. A failed re-index therefore raises without wiping
    the existing index, for example when a remote provider rejects
    ``embedding_api_key``.

    Args:
        embedding_provider: key into ``providers.EMBEDDING_PROVIDERS``;
            forwarded to ``providers.embed_texts``.
        embedding_api_key: forwarded to ``providers.embed_texts``.

    Returns:
        dict of summary fields for UI display: ``indexed``,
        ``sentence_count``, ``vector_dim``, ``embedding_provider``,
        ``embedding_model``, ``persist_dir``, ``collection_name``.
    """
    # Resolve the model name up front so an unknown provider fails before
    # any stored data is touched.
    model_name = providers.EMBEDDING_PROVIDERS[embedding_provider]["default_model"]

    sentences = [e["sentence"] for e in TRAINING_EXAMPLES]
    labels = [e["label"] for e in TRAINING_EXAMPLES]

    # Embed before clearing: if this raises, the previous index survives.
    vectors = providers.embed_texts(
        sentences, embedding_provider, embedding_api_key,
    )

    # Reset so re-indexing is predictable (no stale or duplicate ids).
    # NOTE(review): Chroma may keep a collection's vector dimension after its
    # rows are deleted; re-indexing with a provider of a different dimension
    # could then fail in add() below. Confirm against the installed chromadb.
    clear_collection()
    collection = get_collection()

    ids = [f"sent_{i:03d}" for i in range(len(sentences))]
    metadatas = [{"label": lab, "index": i} for i, lab in enumerate(labels)]

    collection.add(
        ids=ids,
        documents=sentences,
        embeddings=vectors.tolist(),
        metadatas=metadatas,
    )

    return {
        "indexed": len(sentences),
        "sentence_count": len(sentences),
        "vector_dim": int(vectors.shape[1]),
        "embedding_provider": embedding_provider,
        "embedding_model": model_name,
        "persist_dir": PERSIST_DIR,
        "collection_name": COLLECTION_NAME,
    }


# ----------------------------------------------------------------
# Search — embed a query and retrieve nearest neighbors
# ----------------------------------------------------------------
def search(query, n_results=5,
           embedding_provider=DEFAULT_EMBEDDING_PROVIDER,
           embedding_api_key=""):
    """Embed ``query`` and return the top-N nearest training sentences.

    Args:
        query: free-text search string. A blank or None query returns ``[]``.
        n_results: number of neighbours wanted. It is clamped to the number
            of stored vectors; a value <= 0 returns ``[]``.
        embedding_provider: forwarded to ``providers.embed_texts``. Use the
            provider the collection was indexed with, because vectors from
            different models are not comparable.
        embedding_api_key: forwarded to ``providers.embed_texts``.

    Returns:
        list of dicts in Chroma's result order (nearest first), each with
        ``sentence``, ``label``, ``index``, ``distance`` and ``similarity``.
        ``similarity`` is ``1 - distance``, and the collection uses cosine
        distance. An empty collection returns ``[]``.
    """
    # A blank query has no meaningful neighbours; skip the embedding call.
    if query is None or not str(query).strip():
        return []

    collection = get_collection()
    stored = collection.count()
    # Clamp to what is actually stored. Chroma rejects non-positive
    # n_results and, depending on version, errors or warns when asked for
    # more results than the index holds.
    k = min(int(n_results), stored)
    if k <= 0:
        return []

    q_vec = providers.embed_texts(
        [query], embedding_provider, embedding_api_key,
    )[0]

    res = collection.query(
        query_embeddings=[q_vec.tolist()],
        n_results=k,
    )

    # Chroma returns one inner list per query embedding; we sent exactly one.
    docs = (res.get("documents") or [[]])[0]
    metas = (res.get("metadatas") or [[]])[0]
    dists = (res.get("distances") or [[]])[0]

    hits = []
    for doc, meta, dist in zip(docs, metas, dists):
        meta = meta or {}
        hits.append({
            "sentence": doc,
            "label": meta.get("label"),
            "index": meta.get("index"),
            "distance": float(dist),
            # Cosine space (see get_collection): 1 - distance = similarity.
            "similarity": float(1.0 - dist),
        })
    return hits


# ----------------------------------------------------------------
# Utilities — clear, stats, preview
# ----------------------------------------------------------------
def clear_collection():
    """Delete every vector from the collection (the collection itself is kept).

    Returns:
        ``{"cleared": n}`` where ``n`` is the number of rows deleted.
    """
    collection = get_collection()
    # include=[] asks Chroma for ids only; the documents and metadatas are
    # not needed just to know what to delete. Treat a missing/None ids
    # field as "nothing stored".
    ids = collection.get(include=[]).get("ids") or []
    if ids:
        collection.delete(ids=ids)
    return {"cleared": len(ids)}


def collection_stats():
    """Summarize the stored collection for display.

    Returns:
        dict with ``count`` (number of vectors currently stored),
        ``persist_dir`` and ``collection_name``.
    """
    stored = get_collection().count()
    return dict(
        count=stored,
        persist_dir=PERSIST_DIR,
        collection_name=COLLECTION_NAME,
    )


def preview_vectors(n=10,
                    embedding_provider=DEFAULT_EMBEDDING_PROVIDER,
                    embedding_api_key=""):
    """Embed the first ``n`` training sentences and show the head of each vector.

    Nothing is read from or written to the Chroma collection; the vectors are
    computed fresh for inspection.

    Args:
        n: number of sentences to preview, taken from the start of
            TRAINING_EXAMPLES. Values <= 0 return ``[]``; values larger than
            the dataset return every sentence.
        embedding_provider: forwarded to ``providers.embed_texts``.
        embedding_api_key: forwarded to ``providers.embed_texts``.

    Returns:
        list of dicts with ``index``, ``sentence``, ``label``,
        ``vector_head`` (string form of the first 8 components, rounded to
        4 decimals) and ``vector_dim``.
    """
    # Clamp at 0: a negative slice bound would take items from the *end*
    # of the list instead of the first N.
    count = max(0, int(n))
    sample = TRAINING_EXAMPLES[:count]
    if not sample:
        # Nothing to embed; don't call the provider with an empty batch.
        return []

    vectors = providers.embed_texts(
        [ex["sentence"] for ex in sample], embedding_provider, embedding_api_key,
    )

    return [
        {
            "index": i,
            "sentence": ex["sentence"],
            "label": ex["label"],
            "vector_head": str([round(float(x), 4) for x in vec[:8]]),
            "vector_dim": int(vec.shape[0]),
        }
        for i, (ex, vec) in enumerate(zip(sample, vectors))
    ]