| | import uuid |
| |
|
| | import weaviate |
| | from weaviate import Client |
| | from weaviate.embedded import EmbeddedOptions |
| | from weaviate.util import generate_uuid5 |
| |
|
| | from autogpt.config import Config |
| | from autogpt.memory.base import MemoryProviderSingleton, get_ada_embedding |
| |
|
| |
|
| | def default_schema(weaviate_index): |
| | return { |
| | "class": weaviate_index, |
| | "properties": [ |
| | { |
| | "name": "raw_text", |
| | "dataType": ["text"], |
| | "description": "original text for the embedding", |
| | } |
| | ], |
| | } |
| |
|
| |
|
| | class WeaviateMemory(MemoryProviderSingleton): |
| | def __init__(self, cfg): |
| | auth_credentials = self._build_auth_credentials(cfg) |
| |
|
| | url = f"{cfg.weaviate_protocol}://{cfg.weaviate_host}:{cfg.weaviate_port}" |
| |
|
| | if cfg.use_weaviate_embedded: |
| | self.client = Client( |
| | embedded_options=EmbeddedOptions( |
| | hostname=cfg.weaviate_host, |
| | port=int(cfg.weaviate_port), |
| | persistence_data_path=cfg.weaviate_embedded_path, |
| | ) |
| | ) |
| |
|
| | print( |
| | f"Weaviate Embedded running on: {url} with persistence path: {cfg.weaviate_embedded_path}" |
| | ) |
| | else: |
| | self.client = Client(url, auth_client_secret=auth_credentials) |
| |
|
| | self.index = WeaviateMemory.format_classname(cfg.memory_index) |
| | self._create_schema() |
| |
|
| | @staticmethod |
| | def format_classname(index): |
| | |
| | |
| | |
| | if len(index) == 1: |
| | return index.capitalize() |
| | return index[0].capitalize() + index[1:] |
| |
|
| | def _create_schema(self): |
| | schema = default_schema(self.index) |
| | if not self.client.schema.contains(schema): |
| | self.client.schema.create_class(schema) |
| |
|
| | def _build_auth_credentials(self, cfg): |
| | if cfg.weaviate_username and cfg.weaviate_password: |
| | return weaviate.AuthClientPassword( |
| | cfg.weaviate_username, cfg.weaviate_password |
| | ) |
| | if cfg.weaviate_api_key: |
| | return weaviate.AuthApiKey(api_key=cfg.weaviate_api_key) |
| | else: |
| | return None |
| |
|
| | def add(self, data): |
| | vector = get_ada_embedding(data) |
| |
|
| | doc_uuid = generate_uuid5(data, self.index) |
| | data_object = {"raw_text": data} |
| |
|
| | with self.client.batch as batch: |
| | batch.add_data_object( |
| | uuid=doc_uuid, |
| | data_object=data_object, |
| | class_name=self.index, |
| | vector=vector, |
| | ) |
| |
|
| | return f"Inserting data into memory at uuid: {doc_uuid}:\n data: {data}" |
| |
|
| | def get(self, data): |
| | return self.get_relevant(data, 1) |
| |
|
| | def clear(self): |
| | self.client.schema.delete_all() |
| |
|
| | |
| | |
| | |
| | self._create_schema() |
| |
|
| | return "Obliterated" |
| |
|
| | def get_relevant(self, data, num_relevant=5): |
| | query_embedding = get_ada_embedding(data) |
| | try: |
| | results = ( |
| | self.client.query.get(self.index, ["raw_text"]) |
| | .with_near_vector({"vector": query_embedding, "certainty": 0.7}) |
| | .with_limit(num_relevant) |
| | .do() |
| | ) |
| |
|
| | if len(results["data"]["Get"][self.index]) > 0: |
| | return [ |
| | str(item["raw_text"]) for item in results["data"]["Get"][self.index] |
| | ] |
| | else: |
| | return [] |
| |
|
| | except Exception as err: |
| | print(f"Unexpected error {err=}, {type(err)=}") |
| | return [] |
| |
|
| | def get_stats(self): |
| | result = self.client.query.aggregate(self.index).with_meta_count().do() |
| | class_data = result["data"]["Aggregate"][self.index] |
| |
|
| | return class_data[0]["meta"] if class_data else {} |
| |
|