# emb1024 / client.py
# gcharanteja
# ch4
# 5b1c825
# Raw
# History Blame Contribute Delete
# 2.93 kB
import argparse
import os
from typing import List, Optional, Sequence, Tuple

import chromadb
from chromadb.config import Settings
def _positive_int(raw: str) -> int:
    """Convert *raw* to an int of at least 1 (used as an argparse ``type=``)."""
    try:
        value = int(raw)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected an integer, got {raw!r}"
        ) from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be at least 1, got {value}")
    return value


def _parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the demo's command-line options.

    Every option falls back to an environment variable and then to a
    built-in default:

    * ``--collection`` / ``CHROMA_COLLECTION`` (default ``knowledge_base``)
    * ``--query`` / ``CHROMA_QUERY`` (default ``Tell me about vector stores``)
    * ``--n-results`` / ``CHROMA_N_RESULTS`` (default ``2``)

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what the script entry point uses.

    Returns:
        Namespace with ``collection`` (str), ``query`` (str) and
        ``n_results`` (int, at least 1).

    Raises:
        SystemExit: Raised by argparse, after printing a usage message, when
            an option is invalid. This includes a non-numeric or
            non-positive ``CHROMA_N_RESULTS`` that is not overridden on the
            command line.
    """
    parser = argparse.ArgumentParser(description="ChromaDB client demo")
    parser.add_argument(
        "--collection",
        default=os.getenv("CHROMA_COLLECTION", "knowledge_base"),
        help="Collection name",
    )
    parser.add_argument(
        "--query",
        default=os.getenv("CHROMA_QUERY", "Tell me about vector stores"),
        help="Query text",
    )
    parser.add_argument(
        "--n-results",
        type=_positive_int,
        # Keep the environment value as a string: argparse passes string
        # defaults through ``type`` only when the option is absent, so a bad
        # env value becomes a normal usage error instead of a traceback
        # raised while the parser is still being built.
        default=os.getenv("CHROMA_N_RESULTS", "2"),
        help="Number of results to return",
    )
    return parser.parse_args(argv)
def _build_client() -> chromadb.Client:
    """Create a Chroma client configured from ``CHROMA_*`` environment variables.

    Reads ``CHROMA_HOST``, ``CHROMA_PORT``, ``CHROMA_SSL`` and
    ``CHROMA_HTTP_PATH``; each falls back to the literal default below when
    unset. ``CHROMA_SSL`` counts as enabled when its lower-cased value is one
    of ``1``, ``true``, ``yes`` or ``y``.

    Raises:
        ValueError: If ``CHROMA_PORT`` is set to something ``int`` rejects.
    """
    read_env = os.getenv
    truthy = {"1", "true", "yes", "y"}

    # NOTE(review): these keyword names are handed to Settings unchanged;
    # confirm they match the fields of the installed chromadb version.
    server_settings = Settings(
        chroma_server_host=read_env("CHROMA_HOST", "maxxcarl-emb1024.hf.space"),
        chroma_server_http_port=int(read_env("CHROMA_PORT", "443")),
        chroma_server_ssl=read_env("CHROMA_SSL", "true").lower() in truthy,
        chroma_server_http_path=read_env("CHROMA_HTTP_PATH", "/chroma"),
    )
    return chromadb.Client(server_settings)
def _seed_collection(collection: chromadb.Collection) -> Tuple[List[str], List[str]]:
    """Add three fixed sample documents to *collection*.

    Args:
        collection: Target collection; its ``add`` method is called once
            with all three records.

    Returns:
        ``(ids, documents)`` for the records that were added, in order.
    """
    # One (id, text, metadata) record per sample keeps related values together.
    samples = [
        (
            "doc1",
            "Chroma is a lightweight, open-source vector database built for AI.",
            {"category": "tech", "source": "docs"},
        ),
        (
            "doc2",
            "Python is a high-level programming language used extensively in data science.",
            {"category": "tech", "source": "wiki"},
        ),
        (
            "doc3",
            "The celestial body closest to Earth is the Moon.",
            {"category": "science", "source": "space-facts"},
        ),
    ]
    ids = [doc_id for doc_id, _, _ in samples]
    documents = [text for _, text, _ in samples]
    metadatas = [meta for _, _, meta in samples]

    collection.add(documents=documents, metadatas=metadatas, ids=ids)
    return ids, documents
def main() -> None:
    """Run the demo end to end.

    Parses the command line, builds the client, opens (or creates) the
    requested collection, seeds it with sample documents when it is empty,
    runs a single text query and prints each hit with its metadata and
    distance.
    """
    args = _parse_args()
    client = _build_client()

    print("Connecting to Chroma...")
    collection = client.get_or_create_collection(name=args.collection)
    print(f"Using collection: {args.collection}")

    # Call count() once and reuse the value so the branch taken and the
    # number reported cannot disagree, and the lookup is not repeated.
    existing = collection.count()
    if existing == 0:
        print("Seeding collection with sample documents...")
        ids, _ = _seed_collection(collection)
        print(f"Added {len(ids)} documents.")
    else:
        print(f"Collection already has {existing} documents.")

    print(f"Query: {args.query}")
    results = collection.query(query_texts=[args.query], n_results=args.n_results)

    print("\n--- Search Results ---")
    # Only one query text is sent, so index 0 of each field is read and the
    # three lists are walked in lockstep.
    for doc, meta, distance in zip(
        results["documents"][0],
        results["metadatas"][0],
        results["distances"][0],
    ):
        print(f"Matched Document: {doc}")
        print(f"Metadata: {meta}")
        print(f"Distance Score (Lower is better): {distance:.4f}")
        print()
    print("----------------------")
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()