| import chromadb | |
| from datetime import datetime | |
# Module-level Chroma client; every helper function in this file goes through it.
chroma_client = chromadb.Client()
def get_or_create_collection(coll_name: str):
    """
    return the collection called coll_name, creating it when it does not exist

    The first six characters of the name are passed as the ``date`` metadata
    entry of the collection.
    """
    metadata = {"date": coll_name[:6]}
    return chroma_client.get_or_create_collection(name=coll_name, metadata=metadata)
def get_collection(coll_name: str):
    """
    return the existing collection called coll_name

    The lookup is delegated to ``chroma_client.get_collection``; nothing is
    created here.
    """
    return chroma_client.get_collection(name=coll_name)
def reset_collection(coll_name: str):
    """
    remove every record from an existing collection and return the collection

    The collection itself is kept; only its records are deleted.

    ``Collection.delete()`` without ``ids``, ``where`` or ``where_document``
    is rejected by current Chroma releases (they ask callers to ``get()`` the
    ids and delete those), so the stored ids are fetched first and removed
    explicitly.

    :param coll_name: name of the collection to empty
    :return: the collection object, now without records
    """
    coll = chroma_client.get_collection(name=coll_name)
    # include=[] avoids loading documents/metadatas/embeddings; ids are
    # always part of the result.
    stored_ids = coll.get(include=[])["ids"]
    if stored_ids:
        coll.delete(ids=stored_ids)
    return coll
def _stamp_to_datetime(stamp, now):
    """Interpret a year-less ``MMDDHH`` stamp as the occurrence nearest to *now*.

    Returns ``None`` when *stamp* is not a valid ``MMDDHH`` value.
    """
    text = str(stamp).zfill(6)
    if len(text) != 6 or not text.isdigit():
        return None
    candidates = []
    # The stamp carries no year, so try the current and the previous one
    # (covers collections created just before New Year).
    for year in (now.year - 1, now.year):
        try:
            candidates.append(datetime.strptime(f"{year:04d}{text}", "%Y%m%d%H"))
        except ValueError:
            # Impossible date for that year (month 13, hour 24, Feb 29 in a
            # non-leap year, ...).
            continue
    if not candidates:
        return None
    return min(candidates, key=lambda moment: abs(now - moment))


def delete_old_collections(old=2):
    """
    delete every collection whose ``date`` metadata is more than ``old`` hours old

    The ``date`` metadata is read as an ``MMDDHH`` stamp (the same format the
    current time is rendered in for the comparison). Ages are computed with
    real datetime arithmetic, so the comparison stays correct across midnight,
    month ends and New Year, which plain integer subtraction of ``MMDDHH``
    values does not.

    Collections without metadata, without a ``date`` entry, or with a stamp
    that is not a valid ``MMDDHH`` value are left untouched instead of
    aborting the whole sweep.

    :param old: maximum age in hours; strictly older collections are deleted
    """
    # Truncate to the hour: stamps only have hour resolution.
    now = datetime.now().replace(minute=0, second=0, microsecond=0)
    for coll in chroma_client.list_collections():
        stamp = (coll.metadata or {}).get('date')
        if stamp is None:
            continue
        created = _stamp_to_datetime(stamp, now)
        if created is None:
            continue
        age_hours = (now - created).total_seconds() / 3600
        if age_hours > old:
            chroma_client.delete_collection(coll.name)
def add_texts_to_collection(coll_name: str, texts: [str], file: str, source: str):
    """
    add texts to a collection : texts originate all from the same file

    Each text is stored under the id ``<file>-<index>`` with the metadata
    ``{file: 1, 'source': source}``. Entries already stored under those ids
    are deleted first, so adding the same file again replaces them.

    Errors raised by the delete/add calls are reported on stdout (including
    the exception) and not re-raised; a missing collection still raises from
    ``get_collection``.

    :param coll_name: name of an existing collection
    :param texts: the text chunks to store; an empty list is a no-op
    :param file: file name, used as id prefix and as metadata key
    :param source: value stored under the ``'source'`` metadata key
    """
    coll = chroma_client.get_collection(name=coll_name)
    if not texts:
        # Nothing to store, so skip the delete/add round trip entirely.
        return
    metadatas = [{file: 1, 'source': source} for _ in texts]
    ids = [f"{file}-{index}" for index, _ in enumerate(texts)]
    try:
        # NOTE(review): only ids 0..len(texts)-1 are removed; chunks with a
        # higher index left over from an earlier, longer version of the file
        # are not touched.
        coll.delete(ids=ids)
        coll.add(documents=texts, metadatas=metadatas, ids=ids)
    except Exception as exc:
        # Exception (not a bare except) so Ctrl-C / SystemExit still propagate.
        print(f"exception raised for collection :{coll_name}, texts: {texts} from file {file} and source {source}"
              f": {exc!r}")
def delete_collection(coll_name: str):
    """
    delete the collection called coll_name

    Thin wrapper around ``chroma_client.delete_collection``; returns nothing.
    """
    chroma_client.delete_collection(name=coll_name)
def list_collections():
    """
    return whatever ``chroma_client.list_collections()`` reports
    """
    known_collections = chroma_client.list_collections()
    return known_collections
def query_collection(coll_name: str, query: str, from_files: [str], n_results: int = 4):
    """
    query a collection, restricted to the texts that come from the given files

    The filter matches records whose metadata has ``{file: 1}`` for at least
    one entry of ``from_files`` (a single condition for one file, an ``$or``
    of conditions for several). The number of requested results is capped at
    the number of records in the collection.

    :param coll_name: name of an existing collection
    :param query: the query text
    :param from_files: file names to search in; must not be empty
    :param n_results: maximum number of results to return
    :return: the result of ``coll.query``; the empty string ``""`` when the
        collection holds no records or the query raised an error
    :raises AssertionError: when ``from_files`` is empty
    """
    if not from_files:
        # Raised explicitly so the check survives ``python -O``; the type is
        # kept as AssertionError to match what the former assert raised.
        raise AssertionError("from_files must contain at least one file name")
    coll = chroma_client.get_collection(name=coll_name)
    filters = [{file: 1} for file in from_files]
    where_ = filters[0] if len(filters) == 1 else {'$or': filters}
    n_results_ = min(n_results, coll.count())
    if n_results_ < 1:
        # Empty collection (or non-positive n_results): nothing can be
        # returned, so skip a query that would only raise.
        return ""
    try:
        return coll.query(query_texts=query, n_results=n_results_, where=where_)
    except Exception as exc:
        # Exception (not a bare except) so Ctrl-C / SystemExit still propagate.
        print(f"exception raised at query collection for collection {coll_name} and query {query} from files "
              f"{from_files}"
              f": {exc!r}")
        return ""