| import uuid |
|
|
| import weaviate |
| from weaviate import Client |
| from weaviate.embedded import EmbeddedOptions |
| from weaviate.util import generate_uuid5 |
|
|
| from autogpt.config import Config |
| from autogpt.memory.base import MemoryProviderSingleton, get_ada_embedding |
|
|
|
|
def default_schema(weaviate_index):
    """Return the Weaviate class definition used for stored memories.

    Args:
        weaviate_index: Name to use for the ``"class"`` entry of the schema.

    Returns:
        A dict with the class name and a single ``raw_text`` text property.
    """
    # The only stored property: the original text the embedding was built from.
    raw_text_property = {
        "name": "raw_text",
        "dataType": ["text"],
        "description": "original text for the embedding",
    }
    return {"class": weaviate_index, "properties": [raw_text_property]}
|
|
|
|
class WeaviateMemory(MemoryProviderSingleton):
    """Memory provider that keeps text and its embedding in a Weaviate class.

    Every entry is stored as an object with a single ``raw_text`` property and
    the vector returned by ``get_ada_embedding`` for that text.
    """

    def __init__(self, cfg):
        """Create the Weaviate client and make sure the memory class exists.

        Args:
            cfg: Configuration object. Reads ``weaviate_protocol``,
                ``weaviate_host``, ``weaviate_port``, ``use_weaviate_embedded``,
                ``weaviate_embedded_path``, ``memory_index`` and the
                credential fields consumed by ``_build_auth_credentials``.
        """
        auth_credentials = self._build_auth_credentials(cfg)

        url = f"{cfg.weaviate_protocol}://{cfg.weaviate_host}:{cfg.weaviate_port}"

        if cfg.use_weaviate_embedded:
            # Embedded mode: the client is built from EmbeddedOptions instead
            # of a URL, and no credentials are passed.
            self.client = Client(
                embedded_options=EmbeddedOptions(
                    hostname=cfg.weaviate_host,
                    port=int(cfg.weaviate_port),
                    persistence_data_path=cfg.weaviate_embedded_path,
                )
            )

            print(
                f"Weaviate Embedded running on: {url} with persistence path: {cfg.weaviate_embedded_path}"
            )
        else:
            self.client = Client(url, auth_client_secret=auth_credentials)

        self.index = WeaviateMemory.format_classname(cfg.memory_index)
        self._create_schema()

    @staticmethod
    def format_classname(index):
        """Turn an index name into a usable Weaviate class name.

        Hyphens are replaced with underscores and the first character is
        capitalised; the remaining characters are left unchanged.

        Args:
            index: The configured index name.

        Returns:
            The formatted class name.

        Raises:
            ValueError: If ``index`` is empty.
        """
        if not index:
            raise ValueError("Weaviate index name must not be empty")
        # Class names are exposed as GraphQL type names, which cannot contain
        # "-", so map hyphens to underscores before capitalising.
        index = index.replace("-", "_")
        return index[0].capitalize() + index[1:]

    def _create_schema(self):
        """Create this memory's class in Weaviate unless it is already there."""
        schema = default_schema(self.index)
        if not self.client.schema.contains(schema):
            self.client.schema.create_class(schema)

    def _build_auth_credentials(self, cfg):
        """Pick the auth object matching the credentials present in ``cfg``.

        Username/password takes precedence over an API key.

        Returns:
            ``weaviate.AuthClientPassword``, ``weaviate.AuthApiKey``, or
            ``None`` when no credentials are configured.
        """
        if cfg.weaviate_username and cfg.weaviate_password:
            return weaviate.AuthClientPassword(
                cfg.weaviate_username, cfg.weaviate_password
            )
        if cfg.weaviate_api_key:
            return weaviate.AuthApiKey(api_key=cfg.weaviate_api_key)
        return None

    def add(self, data):
        """Embed ``data`` and store it in the memory class.

        Args:
            data: The text to remember.

        Returns:
            A human-readable message containing the object's uuid and the text.
        """
        vector = get_ada_embedding(data)

        # The uuid is derived from the text and the class name.
        doc_uuid = generate_uuid5(data, self.index)
        data_object = {"raw_text": data}

        with self.client.batch as batch:
            batch.add_data_object(
                uuid=doc_uuid,
                data_object=data_object,
                class_name=self.index,
                vector=vector,
            )

        return f"Inserting data into memory at uuid: {doc_uuid}:\n data: {data}"

    def get(self, data):
        """Return the single most relevant stored text for ``data`` as a list."""
        return self.get_relevant(data, 1)

    def clear(self):
        """Remove everything stored in this memory and recreate its class.

        Returns:
            The string ``"Obliterated"``.
        """
        # Drop only this memory's class: schema.delete_all() would also wipe
        # every other class living in the same Weaviate instance.
        self.client.schema.delete_class(self.index)

        # Deleting the class removes its definition along with its objects,
        # so it has to be created again before the memory can be reused.
        self._create_schema()

        return "Obliterated"

    def get_relevant(self, data, num_relevant=5):
        """Return stored texts whose vectors are close to ``data``'s embedding.

        Args:
            data: The query text.
            num_relevant: Maximum number of results to return.

        Returns:
            A list of ``raw_text`` strings; empty when nothing matches or when
            the query fails (the error is printed, not raised).
        """
        query_embedding = get_ada_embedding(data)
        try:
            results = (
                self.client.query.get(self.index, ["raw_text"])
                .with_near_vector({"vector": query_embedding, "certainty": 0.7})
                .with_limit(num_relevant)
                .do()
            )
            hits = results["data"]["Get"][self.index]
            return [str(item["raw_text"]) for item in hits]
        except Exception as err:
            # Best-effort lookup: a failed query is reported and treated as
            # "no relevant memories" rather than aborting the caller.
            print(f"Unexpected error {err=}, {type(err)=}")
            return []

    def get_stats(self):
        """Return the aggregate ``meta`` data (object count) for this class.

        Returns:
            The ``meta`` dict of the first aggregate entry, or ``{}`` when the
            aggregate result for the class is empty.
        """
        result = self.client.query.aggregate(self.index).with_meta_count().do()
        class_data = result["data"]["Aggregate"][self.index]

        return class_data[0]["meta"] if class_data else {}
|
|