|
|
import sys, unittest |
|
|
from pathlib import Path |
|
|
from loractl.lib.utils import sorted_positions, calculate_weight, params_to_weights |
|
|
|
|
|
# Temporarily put the directory four levels above this file at the front of
# sys.path so the `modules` package it provides can be imported, then restore
# sys.path. The try/finally guarantees the temporary entry is removed even if
# the import fails; otherwise a failed import would leave the search path
# modified for whatever imports run afterwards.
path = str(Path(__file__).parent.parent.parent.parent)
sys.path.insert(0, path)
try:
    from modules.extra_networks import ExtraNetworkParams
finally:
    sys.path.remove(path)
|
|
|
|
|
|
|
|
class LoraCtlTests(unittest.TestCase):
    """Unit tests for ``sorted_positions`` and ``calculate_weight``."""

    def test_sorted_positions(self):
        """``sorted_positions`` parses ``weight@step`` lists into sorted columns.

        A bare number comes back as a plain float. Otherwise the result is
        ``[weights, steps]`` ordered by ascending step, regardless of the
        order in which the pairs were written.
        """
        self.assertEqual(sorted_positions("1"), 1.0)

        cases = [
            ("1@0,0.5@3,1@6", [[1.0, 0.5, 1.0], [0.0, 3.0, 6.0]]),
            # The same schedule written out of order must come back sorted.
            ("0.5@3,1@6,1@0", [[1.0, 0.5, 1.0], [0.0, 3.0, 6.0]]),
            # Fractional step positions are kept as given.
            ("0.5@0,0.5@0.5,0@1", [[0.5, 0.5, 0.0], [0.0, 0.5, 1.0]]),
        ]
        for text, expected in cases:
            with self.subTest(text=text):
                self.assertEqual(sorted_positions(text), expected)

    def test_sorted_position_semicolons(self):
        """Semicolons are accepted as pair separators, just like commas."""
        self.assertEqual(
            sorted_positions("1@0;0.5@3;1@6"),
            [[1.0, 0.5, 1.0], [0.0, 3.0, 6.0]],
        )

    def test_weight_interpolation(self):
        """``calculate_weight`` evaluates a parsed schedule at a given step.

        Results are compared with ``assertAlmostEqual`` because interpolated
        weights come out of floating-point arithmetic. An exact ``==`` check
        would depend on how the implementation happens to round.
        """
        schedules = [
            # A constant weight applies at every step.
            ("1.0", [(0, 1.0), (15, 1.0), (30, 1.0)]),
            # Exact hits on control points return that point's weight; past
            # the last point the last weight is held.
            ("0.75@0;0.5@3;1@6", [(0, 0.75), (3, 0.5), (6, 1.0), (9, 1.0)]),
            # Before the first point the first weight is held. Step 8 lies
            # 3/5 of the way from 0.5@5 to 1.0@10, so the expected value
            # matches linear interpolation.
            ("0.5@5,1.0@10", [(0, 0.5), (5, 0.5), (8, 0.8), (15, 1.0)]),
        ]
        for text, expectations in schedules:
            steps = sorted_positions(text)
            for step, expected in expectations:
                with self.subTest(schedule=text, step=step):
                    self.assertAlmostEqual(calculate_weight(steps, step, 30), expected)
|
|
|
|
|
|
|
|
class LoraCtlNetworkTests(unittest.TestCase):
    """Tests for ``params_to_weights`` applied to ``ExtraNetworkParams``."""

    def assert_params(self, text, expected):
        """Parse a colon-separated argument string and check the resulting weights.

        Args:
            text: Colon-separated arguments, e.g. ``"loraname:1.0"``. It is
                split on ``":"`` and handed to ``ExtraNetworkParams``.
            expected: Mapping of ``te``/``unet``/``hrte``/``hrunet`` to either
                a scalar weight or a ``[weights, steps]`` schedule.
        """
        # Parameter is named `text` rather than `str` to avoid shadowing the
        # builtin.
        params = ExtraNetworkParams(text.split(":"))
        # Include the input in the failure message so the failing case is
        # obvious from the test report alone.
        self.assertEqual(params_to_weights(params), expected, msg=text)

    def test_params_to_weights(self):
        """Check how positional and named arguments map onto the four weights."""
        ramp = [[0.5, 1.0], [0.0, 1.0]]
        cases = [
            # A single positional weight applies to all four keys.
            ("loraname:1.0", {
                'hrte': 1.0,
                'hrunet': 1.0,
                'te': 1.0,
                'unet': 1.0,
            }),
            # hr= sets both hr* weights; the positional schedule still
            # drives te/unet.
            ("loraname:0.5@0,1@1:hr=0.6", {
                'hrte': 0.6,
                'hrunet': 0.6,
                'te': ramp,
                'unet': ramp,
            }),
            # A te= schedule with no positional weight ends up on all four keys.
            ("loraname:te=0.5@0,1@1", {
                'te': ramp,
                'unet': ramp,
                'hrte': ramp,
                'hrunet': ramp,
            }),
            # unet= sets unet and hrunet; te and hrte stay at 1.0.
            ("loraname:unet=0.5@0,1@1", {
                'te': 1.0,
                'unet': ramp,
                'hrte': 1.0,
                'hrunet': ramp,
            }),
            # hrte= overrides only the hrte weight.
            ("loraname:unet=0.5@0,1@1:hrte=0.5", {
                'te': 1.0,
                'unet': ramp,
                'hrte': 0.5,
                'hrunet': ramp,
            }),
            # A specific hrte= takes precedence over hr= for that key.
            ("loraname:hr=0.6:hrte=0.5", {
                'te': 1.0,
                'unet': 1.0,
                'hrte': 0.5,
                'hrunet': 0.6,
            }),
            # An explicit hr=0 yields 0.0 for both hr* weights, and
            # fractional step positions pass through unchanged.
            ("loraname:0.8@0.15,0@0.3:hr=0", {
                'hrte': 0.0,
                'hrunet': 0.0,
                'te': [[0.8, 0.0], [0.15, 0.3]],
                'unet': [[0.8, 0.0], [0.15, 0.3]],
            }),
        ]
        for text, expected in cases:
            with self.subTest(text=text):
                self.assert_params(text, expected)
|
|
|