| | import os |
| | import sys |
| | import unittest |
| | from unittest import mock |
| | from uuid import uuid4 |
| |
|
| | from weaviate import Client |
| | from weaviate.util import get_valid_uuid |
| |
|
| | from autogpt.config import Config |
| | from autogpt.memory.base import get_ada_embedding |
| | from autogpt.memory.weaviate import WeaviateMemory |
| |
|
| |
|
class TestWeaviateMemory(unittest.TestCase):
    """Integration tests for ``WeaviateMemory`` against a Weaviate instance.

    In order to run these tests you will need a local instance of
    Weaviate running. Refer to https://weaviate.io/developers/weaviate/installation/docker-compose
    for creating local instances using docker.
    Alternatively in your .env file set the following environment variables
    to run Weaviate embedded
    (see: https://weaviate.io/developers/weaviate/installation/embedded):

        USE_WEAVIATE_EMBEDDED=True
        WEAVIATE_EMBEDDED_PATH="/home/me/.local/share/weaviate"
    """

    # Shared by every test method; assigned once in setUpClass.
    cfg = None
    client = None
    index = None

    @classmethod
    def setUpClass(cls):
        """Build the config, the Weaviate client and the index name once."""
        cls.cfg = Config()

        if cls.cfg.use_weaviate_embedded:
            # Only the embedded configuration needs this import.
            from weaviate.embedded import EmbeddedOptions

            cls.client = Client(
                embedded_options=EmbeddedOptions(
                    hostname=cls.cfg.weaviate_host,
                    port=int(cls.cfg.weaviate_port),
                    persistence_data_path=cls.cfg.weaviate_embedded_path,
                )
            )
        else:
            # This is a classmethod: there is no ``self`` here, so every
            # setting must be read through ``cls.cfg``.
            cls.client = Client(
                f"{cls.cfg.weaviate_protocol}://{cls.cfg.weaviate_host}:{cls.cfg.weaviate_port}"
            )

        cls.index = WeaviateMemory.format_classname(cls.cfg.memory_index)

    def setUp(self):
        """Delete the index class (best effort) and create a fresh memory."""
        try:
            self.client.schema.delete_class(self.index)
        except Exception:
            # Deliberately best effort: presumably this fails when the class
            # has not been created yet, which is fine for a clean start.
            # ``Exception`` (not a bare except) so Ctrl-C still interrupts.
            pass

        self.memory = WeaviateMemory(self.cfg)

    def test_add(self):
        """A document passed to ``add`` is retrievable as ``raw_text``."""
        doc = "You are a Titan name Thanos and you are looking for the Infinity Stones"
        self.memory.add(doc)
        result = self.client.query.get(self.index, ["raw_text"]).do()
        actual = result["data"]["Get"][self.index]

        self.assertEqual(len(actual), 1)
        self.assertEqual(actual[0]["raw_text"], doc)

    def test_get(self):
        """``get`` returns a document that was inserted directly via the client."""
        doc = "You are an Avenger and swore to defend the Galaxy from a menace called Thanos"

        # Insert through the raw client so only ``memory.get`` is under test.
        with self.client.batch as batch:
            batch.add_data_object(
                uuid=get_valid_uuid(uuid4()),
                data_object={"raw_text": doc},
                class_name=self.index,
                vector=get_ada_embedding(doc),
            )

            batch.flush()

        actual = self.memory.get(doc)

        self.assertEqual(len(actual), 1)
        self.assertEqual(actual[0], doc)

    def test_get_stats(self):
        """``get_stats`` reports a ``count`` equal to the number of added docs."""
        docs = [
            "You are now about to count the number of docs in this index",
            "And then you about to find out if you can count correctly",
        ]

        for doc in docs:
            self.memory.add(doc)

        stats = self.memory.get_stats()

        self.assertTrue(stats)
        self.assertIn("count", stats)
        self.assertEqual(stats["count"], 2)

    def test_clear(self):
        """``clear`` brings the reported document count back to zero."""
        docs = [
            "Shame this is the last test for this class",
            "Testing is fun when someone else is doing it",
        ]

        for doc in docs:
            self.memory.add(doc)

        self.assertEqual(self.memory.get_stats()["count"], 2)

        self.memory.clear()

        self.assertEqual(self.memory.get_stats()["count"], 0)
|
| |
|
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
| |
|