|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import os
|
| import providers
|
| from training_data import TRAINING_EXAMPLES
|
|
|
|
|
|
|
|
|
|
|
# Directory holding the on-disk Chroma database; can be overridden through
# the CHROMA_PERSIST_DIR environment variable.
PERSIST_DIR = os.getenv("CHROMA_PERSIST_DIR", "./chroma_db")

# Name of the collection that stores the embedded training sentences.
COLLECTION_NAME = "training_sentences"

# Key into providers.EMBEDDING_PROVIDERS used when a caller picks none.
DEFAULT_EMBEDDING_PROVIDER = "MiniLM (local)"

# Module-level caches, filled lazily by _get_client() and get_collection().
_CLIENT = None
_COLLECTION = None
|
|
|
|
|
def _get_client():
    """Return the module-wide Chroma PersistentClient, building it on first use.

    On the first call the persist directory is created if missing and a
    ``chromadb.PersistentClient`` rooted at PERSIST_DIR is cached in
    ``_CLIENT``; later calls return that cached client.
    """
    global _CLIENT
    if _CLIENT is not None:
        return _CLIENT

    # chromadb is imported only when a client is actually needed.
    import chromadb

    os.makedirs(PERSIST_DIR, exist_ok=True)
    _CLIENT = chromadb.PersistentClient(path=PERSIST_DIR)
    return _CLIENT
|
|
|
|
|
def get_collection():
    """Return the cached Chroma collection, opening or creating it on first use.

    The collection is requested with ``{"hnsw:space": "cosine"}`` metadata.
    The result is cached in ``_COLLECTION``, so repeated calls are cheap and
    return the same object.
    """
    global _COLLECTION
    if _COLLECTION is not None:
        return _COLLECTION

    _COLLECTION = _get_client().get_or_create_collection(
        name=COLLECTION_NAME,
        metadata={"hnsw:space": "cosine"},
    )
    return _COLLECTION
|
|
|
|
|
|
|
|
|
|
|
def index_training_data(embedding_provider=DEFAULT_EMBEDDING_PROVIDER,
                        embedding_api_key=""):
    """Embed every sentence in TRAINING_EXAMPLES and write it to the collection.

    Rows already in the collection are removed before the new ones are added,
    so re-indexing replaces the previous contents instead of duplicating them.
    The embeddings are computed *before* anything is deleted. If the embedding
    provider fails (bad API key, network error, ...), the existing index is
    left untouched.

    Args:
        embedding_provider: Key into ``providers.EMBEDDING_PROVIDERS``.
        embedding_api_key: API key forwarded to ``providers.embed_texts``.

    Returns:
        A dict of summary fields for UI display:

        - ``indexed`` and ``sentence_count``: both the number of sentences written.
        - ``vector_dim``, ``embedding_provider`` and ``embedding_model``.
        - ``persist_dir`` and ``collection_name``.
    """
    collection = get_collection()

    sentences = [e["sentence"] for e in TRAINING_EXAMPLES]
    labels = [e["label"] for e in TRAINING_EXAMPLES]

    # Embed first: if this raises, the current index has not been touched yet.
    vectors = providers.embed_texts(
        sentences, embedding_provider, embedding_api_key,
    )

    if collection.count() > 0:
        # include=[] asks Chroma for ids only, avoiding a full fetch of every
        # document and metadata dict just to delete them.
        existing_ids = collection.get(include=[]).get("ids", [])
        if existing_ids:
            collection.delete(ids=existing_ids)
    # NOTE(review): deleting rows may not reset the embedding dimension Chroma
    # stores for the collection; re-indexing with a provider whose vector
    # size differs may still be rejected by collection.add — confirm.

    ids = [f"sent_{i:03d}" for i in range(len(sentences))]
    metadatas = [{"label": lab, "index": i} for i, lab in enumerate(labels)]

    collection.add(
        ids=ids,
        documents=sentences,
        embeddings=vectors.tolist(),
        metadatas=metadatas,
    )

    return {
        "indexed": len(sentences),
        "sentence_count": len(sentences),
        "vector_dim": int(vectors.shape[1]),
        "embedding_provider": embedding_provider,
        "embedding_model": providers.EMBEDDING_PROVIDERS[embedding_provider]["default_model"],
        "persist_dir": PERSIST_DIR,
        "collection_name": COLLECTION_NAME,
    }
|
|
|
|
|
|
|
|
|
|
|
def search(query, n_results=5,
           embedding_provider=DEFAULT_EMBEDDING_PROVIDER,
           embedding_api_key=""):
    """Embed *query* and return the top-N nearest training sentences.

    Args:
        query: Free-text query to embed.
        n_results: Maximum number of hits to return.
            - Values larger than the number of indexed rows are clamped to
              that number.
            - Values below 1 return an empty list without calling the
              embedding provider.
        embedding_provider: Key into ``providers.EMBEDDING_PROVIDERS``.
            Presumably this must match the provider used by
            index_training_data; otherwise the vectors are not comparable.
        embedding_api_key: API key forwarded to ``providers.embed_texts``.

    Returns:
        A list of dicts, in the order Chroma returns them. Each dict has the
        keys ``sentence``, ``label``, ``index``, ``distance`` and
        ``similarity``.
        - ``similarity`` is ``1 - distance``. That equals cosine similarity
          when the collection uses cosine distance, which is what
          get_collection requests.
        - Returns ``[]`` when the collection is empty.
    """
    collection = get_collection()

    # Never ask Chroma for more neighbours than it holds, and skip the
    # (possibly remote) embedding call entirely when nothing can be returned.
    limit = min(int(n_results), collection.count())
    if limit <= 0:
        return []

    q_vec = providers.embed_texts(
        [query], embedding_provider, embedding_api_key,
    )[0]

    res = collection.query(
        query_embeddings=[q_vec.tolist()],
        n_results=limit,
    )

    # Chroma returns one list per query embedding; we sent exactly one.
    docs = (res.get("documents") or [[]])[0]
    metas = (res.get("metadatas") or [[]])[0]
    dists = (res.get("distances") or [[]])[0]

    hits = []
    for doc, meta, dist in zip(docs, metas, dists):
        meta = meta or {}
        distance = float(dist)
        hits.append({
            "sentence": doc,
            "label": meta.get("label"),
            "index": meta.get("index"),
            "distance": distance,
            "similarity": 1.0 - distance,
        })
    return hits
|
|
|
|
|
|
|
|
|
|
|
def clear_collection():
    """Delete every row from the collection, keeping the collection itself.

    Returns:
        ``{"cleared": n}``, where ``n`` is the number of rows deleted
        (0 when the collection was already empty).
    """
    collection = get_collection()
    # include=[] fetches ids only, instead of every document and metadata
    # dict that get() returns by default.
    ids = collection.get(include=[]).get("ids", [])
    if ids:
        collection.delete(ids=ids)
    return {"cleared": len(ids)}
|
|
|
|
|
def collection_stats():
    """Summarise the vector store: current row count plus where it lives on disk."""
    row_count = get_collection().count()
    return dict(
        count=row_count,
        persist_dir=PERSIST_DIR,
        collection_name=COLLECTION_NAME,
    )
|
|
|
|
|
def preview_vectors(n=10,
                    embedding_provider=DEFAULT_EMBEDDING_PROVIDER,
                    embedding_api_key=""):
    """Return the first N training sentences with the head of their embedding vectors.

    Args:
        n: Number of leading TRAINING_EXAMPLES to embed. Values <= 0 return an
            empty list without calling the embedding provider. (Previously a
            negative value sliced from the *end* of the list.)
        embedding_provider: Key into ``providers.EMBEDDING_PROVIDERS``.
        embedding_api_key: API key forwarded to ``providers.embed_texts``.

    Returns:
        One dict per sampled example, with these keys:
        - ``index``, ``sentence``, ``label``.
        - ``vector_head``: the string form of the first 8 components,
          each rounded to 4 decimals.
        - ``vector_dim``: the full vector length.
    """
    count = max(int(n), 0)
    sample = TRAINING_EXAMPLES[:count]
    if not sample:
        return []

    sentences = [ex["sentence"] for ex in sample]
    vectors = providers.embed_texts(
        sentences, embedding_provider, embedding_api_key,
    )

    return [
        {
            "index": i,
            "sentence": ex["sentence"],
            "label": ex["label"],
            # Only a short prefix is shown; full vectors are too long for display.
            "vector_head": str([round(float(x), 4) for x in vec[:8]]),
            "vector_dim": int(vec.shape[0]),
        }
        for i, (ex, vec) in enumerate(zip(sample, vectors))
    ]
|
|
|