| import hashlib |
| import os |
| import unittest |
|
|
| from PIL import Image |
|
|
| from autogpt.commands.image_gen import generate_image, generate_image_with_sd_webui |
| from autogpt.config import Config |
| from autogpt.workspace import path_in_workspace |
|
|
|
|
def lst(txt):
    """Return the part of *txt* that follows its first colon, stripped.

    The tests below pass it the status string returned by the image
    generation commands and use the result as a workspace-relative path.

    Only the first ``:`` is treated as the separator, so a value that
    itself contains colons (for example a Windows drive path) is returned
    intact instead of being cut at the second colon.

    Raises:
        IndexError: if *txt* contains no colon at all.
    """
    # maxsplit=1 keeps everything after the first separator in one piece.
    return txt.split(":", 1)[1].strip()
|
|
|
|
@unittest.skipIf(os.getenv("CI"), "Skipping image generation tests")
class TestImageGen(unittest.TestCase):
    """Integration tests for the image generation commands.

    Each test selects an image provider on the config, generates one or more
    images, verifies the dimensions of the file written to the workspace and
    removes that file again. The whole class is skipped when the ``CI``
    environment variable is set to a non-empty value.
    """

    PROMPT = "astronaut riding a horse"

    def setUp(self):
        self.config = Config()

    def _check_image(self, result, size):
        """Verify and delete the image described by *result*.

        Args:
            result: status string returned by an image generation command;
                the path is taken from the text after its first colon.
            size: expected width and height of the image, in pixels.

        Returns:
            The MD5 hex digest of the decoded pixel data, so callers can
            compare two generated images.
        """
        image_path = path_in_workspace(lst(result))
        self.assertTrue(image_path.exists())
        try:
            with Image.open(image_path) as img:
                self.assertEqual(img.size, (size, size))
                return hashlib.md5(img.tobytes()).hexdigest()
        finally:
            # Runs after the image handle is closed, and also when an
            # assertion above fails, so no generated file is left behind.
            image_path.unlink()

    def test_dalle(self):
        self.config.image_provider = "dalle"

        for size in (256, 512):
            self._check_image(generate_image(self.PROMPT, size), size)

    def test_huggingface(self):
        self.config.image_provider = "huggingface"

        # (model, image size) pairs exercised against the provider.
        cases = (
            ("CompVis/stable-diffusion-v1-4", 512),
            ("stabilityai/stable-diffusion-2-1", 768),
        )
        for model, size in cases:
            self.config.huggingface_image_model = model
            self._check_image(generate_image(self.PROMPT, size), size)

    # This test used to be switched off by a bare ``return`` at the top of
    # the body, which reported it as passing and left the rest unreachable.
    # An explicit skip keeps it disabled but visible in the test report.
    @unittest.skip("sd_webui image generation test is disabled")
    def test_sd_webui(self):
        self.config.image_provider = "sd_webui"

        self._check_image(generate_image_with_sd_webui(self.PROMPT, 128), 128)

        # Same seed with and without a negative prompt: both must be 64x64
        # but the pixel data is expected to differ.
        neg_image_hash = self._check_image(
            generate_image_with_sd_webui(
                self.PROMPT,
                negative_prompt="horse",
                size=64,
                extra={"seed": 123},
            ),
            64,
        )

        # NOTE(review): this call previously passed ``image_size=64, size=1``
        # while asserting a 64x64 result; aligned with the call above, which
        # uses ``size`` for the dimension. Confirm against the signature of
        # generate_image_with_sd_webui before re-enabling the test.
        image_hash = self._check_image(
            generate_image_with_sd_webui(
                self.PROMPT,
                size=64,
                extra={"seed": 123},
            ),
            64,
        )

        self.assertNotEqual(image_hash, neg_image_hash)
|
|
|
|
# Allow running this module directly as a script; when it is merely imported
# (for example by a test collector) nothing is executed here.
if __name__ == "__main__":
    unittest.main()
|
|