| | import hashlib |
| | import os |
| | import unittest |
| |
|
| | from PIL import Image |
| |
|
| | from autogpt.commands.image_gen import generate_image, generate_image_with_sd_webui |
| | from autogpt.config import Config |
| | from autogpt.workspace import path_in_workspace |
| |
|
| |
|
def lst(txt):
    """Return the text after the first colon in *txt*, with surrounding whitespace stripped.

    Used by the tests below to pull the saved file name out of an image
    command's result string (assumed to look like ``"<label>: <filename>"``).
    Only the first colon is treated as the separator, so a file name or path
    that itself contains a colon (e.g. a Windows drive letter) is kept intact.

    Raises:
        IndexError: if *txt* contains no colon.
    """
    return txt.split(":", 1)[1].strip()
| |
|
| |
|
@unittest.skipIf(os.getenv("CI"), "Skipping image generation tests")
class TestImageGen(unittest.TestCase):
    """End-to-end tests for the image generation commands.

    Each test generates real image files in the workspace, checks their
    dimensions and deletes them again. The whole class is skipped when the
    ``CI`` environment variable is set.
    """

    def setUp(self):
        # NOTE(review): assumes the Config() instance is the one the image
        # commands read (e.g. a shared singleton), so setting image_provider
        # here takes effect — confirm against autogpt.config.
        self.config = Config()

    def _check_image(self, result, size):
        """Assert the image named in *result* exists and is *size* x *size*, then delete it.

        Args:
            result: String returned by an image generation command; the file
                name is the text after its first colon (see ``lst``).
            size: Expected width and height in pixels.

        Returns:
            The MD5 hex digest of the decoded pixel data, so callers can
            compare two generated images.
        """
        image_path = path_in_workspace(lst(result))
        try:
            self.assertTrue(image_path.exists())
            with Image.open(image_path) as img:
                self.assertEqual(img.size, (size, size))
                return hashlib.md5(img.tobytes()).hexdigest()
        finally:
            # Delete even when an assertion fails, so a failing run does not
            # leave generated images behind in the workspace.
            if image_path.exists():
                image_path.unlink()

    def test_dalle(self):
        """DALL-E honours the requested 256px and 512px sizes."""
        self.config.image_provider = "dalle"
        for size in (256, 512):
            with self.subTest(size=size):
                self._check_image(
                    generate_image("astronaut riding a horse", size), size
                )

    def test_huggingface(self):
        """Each Hugging Face model produces an image of the requested size."""
        self.config.image_provider = "huggingface"
        cases = (
            ("CompVis/stable-diffusion-v1-4", 512),
            ("stabilityai/stable-diffusion-2-1", 768),
        )
        for model, size in cases:
            with self.subTest(model=model, size=size):
                self.config.huggingface_image_model = model
                self._check_image(
                    generate_image("astronaut riding a horse", size), size
                )

    # NOTE(review): disabled; presumably it needs a reachable SD WebUI
    # server. Remove this decorator to run it.
    @unittest.skip("sd_webui test disabled")
    def test_sd_webui(self):
        """SD WebUI honours size and seed, and the negative prompt changes the output."""
        self.config.image_provider = "sd_webui"

        self._check_image(
            generate_image_with_sd_webui("astronaut riding a horse", 128), 128
        )

        # 64px with a negative prompt and a fixed seed.
        neg_image_hash = self._check_image(
            generate_image_with_sd_webui(
                "astronaut riding a horse",
                negative_prompt="horse",
                size=64,
                extra={"seed": 123},
            ),
            64,
        )

        # Same prompt, size and seed, but no negative prompt: any difference
        # in the pixels must come from the negative prompt.
        image_hash = self._check_image(
            generate_image_with_sd_webui(
                "astronaut riding a horse", size=64, extra={"seed": 123}
            ),
            64,
        )

        self.assertNotEqual(image_hash, neg_image_hash)
| |
|
| |
|
if __name__ == "__main__":
    # Running this file directly discovers and runs the tests defined in it.
    unittest.main(module=__name__)
| |
|