| | import os |
| | import unittest |
| |
|
| | import modules.flags |
| | from modules import util |
| |
|
| |
|
class TestUtils(unittest.TestCase):
    """Table-driven tests for ``util.parse_lora_references_from_prompt``."""

    def test_can_parse_tokens_with_lora(self):
        """Check the ``(loras, cleaned prompt)`` result for a table of prompts.

        Each case's ``input`` is ``(prompt, loras, loras_limit, skip_file_check)``
        and its ``output`` is the exact tuple the parser is expected to return.
        Every case passes ``skip_file_check=True`` and supplies no file-name list.
        """
        test_cases = [
            # Two tags separated by ordinary prompt text, limit not reached.
            {
                "input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 5, True),
                "output": (
                    [('hey-lora.safetensors', 0.4), ('you-lora.safetensors', 0.2)],
                    'some prompt, very cool, cool',
                ),
            },
            # Same prompt with a limit of 1: only the first tag is expected back,
            # while both tags are still expected to be removed from the prompt.
            {
                "input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 1, True),
                "output": (
                    [('hey-lora.safetensors', 0.4)],
                    'some prompt, very cool, cool',
                ),
            },
            # One pre-selected LoRA plus six tags with a limit of 5: the expected
            # list holds the pre-selected entry followed by the first four tags.
            {
                "input": (
                    "some prompt, very cool, <lora:l1:0.4>, <lora:l2:-0.2>, <lora:l3:0.3>, <lora:l4:0.5>, <lora:l6:0.24>, <lora:l7:0.1>",
                    [("hey-lora.safetensors", 0.4)],
                    5,
                    True,
                ),
                "output": (
                    [
                        ('hey-lora.safetensors', 0.4),
                        ('l1.safetensors', 0.4),
                        ('l2.safetensors', -0.2),
                        ('l3.safetensors', 0.3),
                        ('l4.safetensors', 0.5),
                    ],
                    'some prompt, very cool',
                ),
            },
            # Adjacent tags with no separator between them.
            {
                "input": ("some prompt, very cool, <lora:hey-lora:0.4><lora:you-lora:0.2>", [], 3, True),
                "output": (
                    [
                        ('hey-lora.safetensors', 0.4),
                        ('you-lora.safetensors', 0.2),
                    ],
                    'some prompt, very cool',
                ),
            },
            # A tag repeated in the prompt, plus a tag whose file is already
            # pre-selected with a different weight: the expected list keeps the
            # pre-selected weight and contains each file once.
            {
                "input": (
                    "some prompt, very cool, <lora:hey-lora:0.4><lora:hey-lora:0.4><lora:you-lora:0.2>",
                    [('you-lora.safetensors', 0.3)],
                    3,
                    True,
                ),
                "output": (
                    [
                        ('you-lora.safetensors', 0.3),
                        ('hey-lora.safetensors', 0.4),
                    ],
                    'some prompt, very cool',
                ),
            },
            # Malformed weights and a non-lora tag: no LoRAs are expected and the
            # prompt is expected to come back unchanged.
            {
                "input": ("<lora:foo:1..2>, <lora:bar:.>, <test:1.0>, <lora:baz:+> and <lora:quux:>", [], 6, True),
                "output": (
                    [],
                    '<lora:foo:1..2>, <lora:bar:.>, <test:1.0>, <lora:baz:+> and <lora:quux:>',
                ),
            },
        ]

        for index, test in enumerate(test_cases):
            prompt, loras, loras_limit, skip_file_check = test["input"]
            expected = test["output"]
            # subTest reports every failing case (with its index and prompt)
            # instead of stopping at the first failed assertion.
            with self.subTest(case=index, prompt=prompt):
                actual = util.parse_lora_references_from_prompt(
                    prompt, loras, loras_limit=loras_limit, skip_file_check=skip_file_check)
                self.assertEqual(expected, actual)

    def test_can_parse_tokens_and_strip_performance_lora(self):
        """Check parsing against file names filtered per performance mode.

        Each case's ``input`` is
        ``(prompt, loras, loras_limit, skip_file_check, performance)``. The
        available file names are filtered with ``util.remove_performance_lora``
        for the case's performance mode and handed to the parser as
        ``lora_filenames``. In every case the expected result contains only
        ``hey-lora.safetensors`` and the prompt ``'some prompt'``.
        """
        lora_filenames = [
            'hey-lora.safetensors',
            modules.flags.PerformanceLoRA.EXTREME_SPEED.value,
            modules.flags.PerformanceLoRA.LIGHTNING.value,
            os.path.join('subfolder', modules.flags.PerformanceLoRA.HYPER_SD.value),
        ]

        test_cases = [
            {
                "input": ("some prompt, <lora:hey-lora:0.4>", [], 5, True, modules.flags.Performance.QUALITY),
                "output": (
                    [('hey-lora.safetensors', 0.4)],
                    'some prompt',
                ),
            },
            {
                "input": ("some prompt, <lora:hey-lora:0.4>", [], 5, True, modules.flags.Performance.SPEED),
                "output": (
                    [('hey-lora.safetensors', 0.4)],
                    'some prompt',
                ),
            },
            {
                "input": ("some prompt, <lora:sdxl_lcm_lora:1>, <lora:hey-lora:0.4>", [], 5, True, modules.flags.Performance.EXTREME_SPEED),
                "output": (
                    [('hey-lora.safetensors', 0.4)],
                    'some prompt',
                ),
            },
            {
                "input": ("some prompt, <lora:sdxl_lightning_4step_lora:1>, <lora:hey-lora:0.4>", [], 5, True, modules.flags.Performance.LIGHTNING),
                "output": (
                    [('hey-lora.safetensors', 0.4)],
                    'some prompt',
                ),
            },
            {
                "input": ("some prompt, <lora:sdxl_hyper_sd_4step_lora:1>, <lora:hey-lora:0.4>", [], 5, True, modules.flags.Performance.HYPER_SD),
                "output": (
                    [('hey-lora.safetensors', 0.4)],
                    'some prompt',
                ),
            },
        ]

        for index, test in enumerate(test_cases):
            # NOTE(review): the skip_file_check element is unpacked only to match
            # the tuple shape; as before, it is not forwarded to the parser, so
            # the parser's own default applies. Presumably this is what makes the
            # lookup against ``lora_filenames`` take place - confirm against util.
            prompt, loras, loras_limit, _skip_file_check, performance = test["input"]
            expected = test["output"]
            with self.subTest(case=index, performance=performance):
                # Filter a fresh copy for every case: reusing (or rebinding) the
                # shared list would let one case's filtering leak into the next
                # and make the cases order-dependent.
                available_filenames = util.remove_performance_lora(list(lora_filenames), performance)
                actual = util.parse_lora_references_from_prompt(
                    prompt, loras, loras_limit=loras_limit, lora_filenames=available_filenames)
                self.assertEqual(expected, actual)
| |
|