import argparse
import os
from typing import List, Optional, Sequence, Tuple

import chromadb
from chromadb.config import Settings
|
|
|
|
def _positive_int(text: str) -> int:
    """Parse *text* as a strictly positive integer (argparse ``type`` callback)."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def _parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the demo's command-line options.

    Each option falls back to a ``CHROMA_*`` environment variable when it is
    not given on the command line.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which matches the previous behaviour.

    Returns:
        Namespace with ``collection`` (str), ``query`` (str) and
        ``n_results`` (positive int) attributes.

    Raises:
        SystemExit: argparse exits with status 2 on invalid input, including a
            malformed ``CHROMA_N_RESULTS`` value.
    """
    parser = argparse.ArgumentParser(description="ChromaDB client demo")
    parser.add_argument(
        "--collection",
        default=os.getenv("CHROMA_COLLECTION", "knowledge_base"),
        help="Collection name",
    )
    parser.add_argument(
        "--query",
        default=os.getenv("CHROMA_QUERY", "Tell me about vector stores"),
        help="Query text",
    )
    parser.add_argument(
        "--n-results",
        type=_positive_int,
        # Keep the env fallback as a *string*. argparse runs string defaults
        # through ``type``, so a bad CHROMA_N_RESULTS becomes a usage error
        # rather than an uncaught ValueError raised while building the parser.
        default=os.getenv("CHROMA_N_RESULTS", "2"),
        help="Number of results to return",
    )
    return parser.parse_args(argv)
|
|
|
|
def _build_client() -> "chromadb.api.ClientAPI":
    """Create a Chroma client that talks to a remote server over HTTP.

    Connection details come from environment variables:

    * ``CHROMA_HOST``: server hostname (default ``maxxcarl-emb1024.hf.space``).
    * ``CHROMA_PORT``: server port (default ``443``).
    * ``CHROMA_SSL``: HTTPS is used when the value, lower-cased, is one of
      ``1``/``true``/``yes``/``y`` (default ``true``).
    * ``CHROMA_HTTP_PATH``: path prefix the server is mounted under
      (default ``/chroma``). Leave it empty for a server at the root.

    Returns:
        The client returned by ``chromadb.HttpClient``.

    Raises:
        ValueError: if ``CHROMA_PORT`` is not an integer.
    """
    host = os.getenv("CHROMA_HOST", "maxxcarl-emb1024.hf.space")
    port = int(os.getenv("CHROMA_PORT", "443"))
    ssl = os.getenv("CHROMA_SSL", "true").lower() in {"1", "true", "yes", "y"}
    # Normalise the prefix to "" or "/seg[/seg...]" with no trailing slash.
    path = os.getenv("CHROMA_HTTP_PATH", "/chroma").strip("/")
    prefix = f"/{path}" if path else ""

    # chromadb.Client(Settings(...)) does not select Chroma's HTTP API
    # implementation, so host/port would not yield a remote connection.
    # Settings also has no ``chroma_server_ssl`` / ``chroma_server_http_path``
    # fields. HttpClient is the documented remote client. Passing a full URL
    # as ``host`` lets it keep the path prefix; the client is expected to
    # append its own API route after it.
    scheme = "https" if ssl else "http"
    url = f"{scheme}://{host}:{port}{prefix}"
    return chromadb.HttpClient(host=url, port=port, ssl=ssl)
|
|
|
|
def _seed_collection(collection: chromadb.Collection) -> Tuple[List[str], List[str]]:
    """Add three fixed sample documents, with metadata, to *collection*.

    Returns:
        A pair ``(ids, documents)``: the exact id and text lists passed to
        ``collection.add``.
    """
    # One row per sample: (id, text, metadata).
    samples = [
        (
            "doc1",
            "Chroma is a lightweight, open-source vector database built for AI.",
            {"category": "tech", "source": "docs"},
        ),
        (
            "doc2",
            "Python is a high-level programming language used extensively in data science.",
            {"category": "tech", "source": "wiki"},
        ),
        (
            "doc3",
            "The celestial body closest to Earth is the Moon.",
            {"category": "science", "source": "space-facts"},
        ),
    ]
    # Transpose the rows into the parallel lists that Collection.add expects.
    ids, documents, metadatas = (list(column) for column in zip(*samples))

    collection.add(documents=documents, metadatas=metadatas, ids=ids)
    return ids, documents
|
|
|
|
def main() -> None:
    """Run the demo end to end.

    Steps:
      1. Connect to Chroma.
      2. Open (or create) the configured collection.
      3. Seed the collection if it is empty.
      4. Run one similarity query and print each hit with its metadata and
         distance.

    Options come from ``_parse_args`` and connection settings from
    ``_build_client``.
    """
    args = _parse_args()

    # Announce before building the client, so the message appears before any
    # delay or failure during client construction.
    print("Connecting to Chroma...")
    client = _build_client()

    collection = client.get_or_create_collection(name=args.collection)
    print(f"Using collection: {args.collection}")

    # Ask for the size once and reuse it instead of calling count() twice.
    existing = collection.count()
    if existing == 0:
        print("Seeding collection with sample documents...")
        ids, _ = _seed_collection(collection)
        print(f"Added {len(ids)} documents.")
    else:
        print(f"Collection already has {existing} documents.")

    print(f"Query: {args.query}")
    results = collection.query(query_texts=[args.query], n_results=args.n_results)

    print("\n--- Search Results ---")
    # Results are grouped per query text. Exactly one was sent, hence [0].
    for doc, meta, distance in zip(
        results["documents"][0],
        results["metadatas"][0],
        results["distances"][0],
    ):
        print(f"Matched Document: {doc}")
        print(f"Metadata: {meta}")
        print(f"Distance Score (Lower is better): {distance:.4f}")
        print()
    print("----------------------")
|
|
|
|
if __name__ == "__main__":
    # Only run the demo when executed as a script; importing this module
    # (e.g. to reuse the helpers) must not open a connection.
    main()
|
|