# dikdimon's picture
# Upload extensions using SD-Hub extension
# 3dabe4a verified
import sys, unittest
from pathlib import Path
from loractl.lib.utils import sorted_positions, calculate_weight, params_to_weights
# Temporarily put the directory four levels above this file at the front of
# sys.path so that `modules.extra_networks` can be imported, then remove the
# entry again right after the import.
# NOTE(review): presumably that directory is the host application's root --
# confirm against the repository layout.
path = str(Path(__file__).parent.parent.parent.parent)
sys.path.insert(0, path)
from modules.extra_networks import ExtraNetworkParams
sys.path.remove(path)
class LoraCtlTests(unittest.TestCase):
def test_sorted_positions(self):
self.assertEqual(sorted_positions("1"), 1.0)
self.assertEqual(sorted_positions("1@0,0.5@3,1@6"),
[[1.0, 0.5, 1.0], [0.0, 3.0, 6.0]])
self.assertEqual(sorted_positions("0.5@3,1@6,1@0"),
[[1.0, 0.5, 1.0], [0.0, 3.0, 6.0]])
self.assertEqual(sorted_positions("0.5@0,0.5@0.5,0@1"),
[[0.5, 0.5, 0.0], [0.0, 0.5, 1.0]])
def test_sorted_position_semicolons(self):
self.assertEqual(sorted_positions("1@0;0.5@3;1@6"),
[[1.0, 0.5, 1.0], [0.0, 3.0, 6.0]])
def test_weight_interpolation(self):
# Bare weights are never interpolated
steps = sorted_positions("1.0")
self.assertEqual(calculate_weight(steps, 0, 30), 1.0)
self.assertEqual(calculate_weight(steps, 15, 30), 1.0)
self.assertEqual(calculate_weight(steps, 30, 30), 1.0)
# Weights are interpolated correctly
steps = sorted_positions("0.75@0;0.5@3;1@6")
self.assertEqual(calculate_weight(steps, 0, 30), 0.75)
self.assertEqual(calculate_weight(steps, 3, 30), 0.5)
self.assertEqual(calculate_weight(steps, 6, 30), 1.0)
self.assertEqual(calculate_weight(steps, 9, 30), 1.0)
# An implicit 0-step is added
steps = sorted_positions("0.5@5,1.0@10")
self.assertEqual(calculate_weight(steps, 0, 30), 0.5)
self.assertEqual(calculate_weight(steps, 5, 30), 0.5)
self.assertEqual(calculate_weight(steps, 8, 30), 0.8)
self.assertEqual(calculate_weight(steps, 15, 30), 1.0)
class LoraCtlNetworkTests(unittest.TestCase):
def assert_params(self, str, expected):
params = ExtraNetworkParams(str.split(":"))
self.assertEqual(params_to_weights(params), expected)
def test_params_to_weights(self):
# TE cascades to all
self.assert_params("loraname:1.0", {
'hrte': 1.0,
'hrunet': 1.0,
'te': 1.0,
'unet': 1.0
})
# HR can be specified separately
self.assert_params("loraname:0.5@0,1@1:hr=0.6", {
'hrte': 0.6,
'hrunet': 0.6,
'te': [[0.5, 1.0], [0.0, 1.0]],
'unet': [[0.5, 1.0], [0.0, 1.0]]
})
# Explicit TE cascades
self.assert_params("loraname:te=0.5@0,1@1", {
'te': [[0.5, 1.0], [0.0, 1.0]],
'unet': [[0.5, 1.0], [0.0, 1.0]],
'hrte': [[0.5, 1.0], [0.0, 1.0]],
'hrunet': [[0.5, 1.0], [0.0, 1.0]],
})
# Implicit TE cascades, explicit unet cascades
self.assert_params("loraname:unet=0.5@0,1@1", {
'te': 1.0,
'unet': [[0.5, 1.0], [0.0, 1.0]],
'hrte': 1.0,
'hrunet': [[0.5, 1.0], [0.0, 1.0]],
})
# Explicit HR TE overrides lowres TE
self.assert_params("loraname:unet=0.5@0,1@1:hrte=0.5", {
'te': 1.0,
'unet': [[0.5, 1.0], [0.0, 1.0]],
'hrte': 0.5,
'hrunet': [[0.5, 1.0], [0.0, 1.0]],
})
# Explicit HR TE overrides HR
self.assert_params("loraname:hr=0.6:hrte=0.5", {
'te': 1.0,
'unet': 1.0,
'hrte': 0.5,
'hrunet': 0.6,
})
self.assert_params("loraname:0.8@0.15,0@0.3:hr=0", {
'hrte': 0.0,
'hrunet': 0.0,
'te': [[0.8, 0.0], [0.15, 0.3]],
'unet': [[0.8, 0.0], [0.15, 0.3]]
})