| import os
|
| import unittest
|
|
|
| import modules.flags
|
| from modules import util
|
|
|
|
|
class TestUtils(unittest.TestCase):
    """Table-driven tests for ``util.parse_lora_references_from_prompt``.

    Each case is a dict with an ``"input"`` tuple (the arguments for one call)
    and an ``"output"`` tuple ``(loras, cleaned_prompt)`` that the call is
    expected to return. Every case runs inside ``self.subTest`` so that a
    failing case is reported individually and does not hide later cases.
    """

    def test_can_parse_tokens_with_lora(self):
        """Check extraction of ``<lora:name:weight>`` tags from a prompt.

        Input tuples are ``(prompt, loras, loras_limit, skip_file_check)``.
        """
        test_cases = [
            # Two well-formed tags, limit not reached.
            {
                "input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 5, True),
                "output": (
                    [('hey-lora.safetensors', 0.4), ('you-lora.safetensors', 0.2)],
                    'some prompt, very cool, cool',
                ),
            },
            # Limit of 1: only the first tag is returned, but both tags are
            # removed from the prompt.
            {
                "input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 1, True),
                "output": (
                    [('hey-lora.safetensors', 0.4)],
                    'some prompt, very cool, cool',
                ),
            },
            # A pre-existing LoRA counts towards the limit of 5; negative
            # weights are accepted.
            {
                "input": (
                    "some prompt, very cool, <lora:l1:0.4>, <lora:l2:-0.2>, <lora:l3:0.3>, <lora:l4:0.5>, <lora:l6:0.24>, <lora:l7:0.1>",
                    [("hey-lora.safetensors", 0.4)],
                    5,
                    True,
                ),
                "output": (
                    [
                        ('hey-lora.safetensors', 0.4),
                        ('l1.safetensors', 0.4),
                        ('l2.safetensors', -0.2),
                        ('l3.safetensors', 0.3),
                        ('l4.safetensors', 0.5),
                    ],
                    'some prompt, very cool',
                ),
            },
            # Adjacent tags without a separator between them.
            {
                "input": ("some prompt, very cool, <lora:hey-lora:0.4><lora:you-lora:0.2>", [], 3, True),
                "output": (
                    [
                        ('hey-lora.safetensors', 0.4),
                        ('you-lora.safetensors', 0.2),
                    ],
                    'some prompt, very cool',
                ),
            },
            # Duplicate tags collapse to one entry, and the pre-existing
            # 'you-lora' entry (0.3) is kept over the prompt's 0.2.
            {
                "input": (
                    "some prompt, very cool, <lora:hey-lora:0.4><lora:hey-lora:0.4><lora:you-lora:0.2>",
                    [('you-lora.safetensors', 0.3)],
                    3,
                    True,
                ),
                "output": (
                    [
                        ('you-lora.safetensors', 0.3),
                        ('hey-lora.safetensors', 0.4),
                    ],
                    'some prompt, very cool',
                ),
            },
            # Malformed weights / non-lora tags: nothing is parsed and the
            # prompt is returned unchanged.
            {
                "input": ("<lora:foo:1..2>, <lora:bar:.>, <test:1.0>, <lora:baz:+> and <lora:quux:>", [], 6, True),
                "output": (
                    [],
                    '<lora:foo:1..2>, <lora:bar:.>, <test:1.0>, <lora:baz:+> and <lora:quux:>',
                ),
            },
        ]

        for index, test in enumerate(test_cases):
            prompt, loras, loras_limit, skip_file_check = test["input"]
            expected = test["output"]
            # subTest keeps the loop going after a failure and labels the
            # failing case in the report.
            with self.subTest(case=index, prompt=prompt):
                actual = util.parse_lora_references_from_prompt(
                    prompt,
                    loras,
                    loras_limit=loras_limit,
                    skip_file_check=skip_file_check,
                )
                self.assertEqual(expected, actual)

    def test_can_parse_tokens_and_strip_performance_lora(self):
        """Check that the active performance LoRA is not picked up from a prompt.

        Input tuples are
        ``(prompt, loras, loras_limit, skip_file_check, performance)``.
        For each case the available filenames are first filtered through
        ``util.remove_performance_lora`` and then handed to the parser via
        ``lora_filenames``.
        """
        all_lora_filenames = [
            'hey-lora.safetensors',
            modules.flags.PerformanceLoRA.EXTREME_SPEED.value,
            modules.flags.PerformanceLoRA.LIGHTNING.value,
            os.path.join('subfolder', modules.flags.PerformanceLoRA.HYPER_SD.value),
        ]

        test_cases = [
            {
                "input": ("some prompt, <lora:hey-lora:0.4>", [], 5, True, modules.flags.Performance.QUALITY),
                "output": (
                    [('hey-lora.safetensors', 0.4)],
                    'some prompt',
                ),
            },
            {
                "input": ("some prompt, <lora:hey-lora:0.4>", [], 5, True, modules.flags.Performance.SPEED),
                "output": (
                    [('hey-lora.safetensors', 0.4)],
                    'some prompt',
                ),
            },
            {
                "input": ("some prompt, <lora:sdxl_lcm_lora:1>, <lora:hey-lora:0.4>", [], 5, True, modules.flags.Performance.EXTREME_SPEED),
                "output": (
                    [('hey-lora.safetensors', 0.4)],
                    'some prompt',
                ),
            },
            {
                "input": ("some prompt, <lora:sdxl_lightning_4step_lora:1>, <lora:hey-lora:0.4>", [], 5, True, modules.flags.Performance.LIGHTNING),
                "output": (
                    [('hey-lora.safetensors', 0.4)],
                    'some prompt',
                ),
            },
            {
                "input": ("some prompt, <lora:sdxl_hyper_sd_4step_lora:1>, <lora:hey-lora:0.4>", [], 5, True, modules.flags.Performance.HYPER_SD),
                "output": (
                    [('hey-lora.safetensors', 0.4)],
                    'some prompt',
                ),
            },
        ]

        for index, test in enumerate(test_cases):
            # The fourth element (skip_file_check) is intentionally not
            # forwarded, so the parser runs with its default for that argument.
            # NOTE(review): presumably required so names are resolved against
            # `lora_filenames` - confirm against util before forwarding it.
            prompt, loras, loras_limit, _skip_file_check, performance = test["input"]
            expected = test["output"]
            with self.subTest(case=index, performance=performance):
                # Always filter the full list: rebinding the shared list would
                # make each case depend on what earlier cases removed.
                available_filenames = util.remove_performance_lora(all_lora_filenames, performance)
                actual = util.parse_lora_references_from_prompt(
                    prompt,
                    loras,
                    loras_limit=loras_limit,
                    lora_filenames=available_filenames,
                )
                self.assertEqual(expected, actual)
|
|
|