File size: 1,561 Bytes
3dabe4a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
import io
from PIL import Image
from modules import script_callbacks
import matplotlib
import pandas as pd
from loractl.lib.lora_ctl_network import networks
# Module-level plot state, mutated in place by on_step() and reset_plot().
# One row per recorded sampling step; each row holds the te/unet multipliers
# of every loaded network, in the same order as the labels in log_names.
log_weights = []
# Column labels ("<name>_te", "<name>_unet"), filled by on_step() whenever
# the list is empty.
log_names = []
# sampling_step of the most recently recorded row; -1 before any step is seen.
last_plotted_step = -1
# Copied from composable_lora
def plot_lora_weight(lora_weights, lora_names):
    """Render LoRA weights as a line chart and return the chart as an image.

    Args:
        lora_weights: Rows of weight values, one row per step; each row has
            one value per entry in ``lora_names``.
        lora_names: Column labels for the rows, shown as legend entries.

    Returns:
        The image object produced by ``fig2img`` for the rendered chart.
    """
    # Import pyplot explicitly rather than relying on ``matplotlib.pyplot``
    # having been loaded as a side effect of some other import.
    from matplotlib import pyplot

    data = pd.DataFrame(lora_weights, columns=lora_names)
    ax = data.plot()
    try:
        ax.set_xlabel("Steps")
        ax.set_ylabel("LoRA weight")
        ax.set_title("LoRA weight in all steps")
        ax.legend(loc=0)
        return fig2img(ax)
    finally:
        # pyplot keeps a figure registered until it is closed; close it on
        # the failure path too so a rendering error does not leak figures.
        pyplot.close(ax.figure)
# Copied from composable_lora
def fig2img(fig):
    """Save the figure behind ``fig`` into memory and open it as an image.

    Args:
        fig: Any object exposing a ``.figure`` attribute with a ``savefig``
            method (``plot_lora_weight`` passes the axes returned by
            ``DataFrame.plot``).

    Returns:
        The result of ``Image.open`` on the in-memory buffer.
    """
    buffer = io.BytesIO()
    fig.figure.savefig(buffer)
    # Rewind so the image reader starts at the first byte written.
    buffer.seek(0)
    return Image.open(buffer)
def reset_plot():
    """Discard all recorded plot data and return to the initial state.

    ``log_weights`` and ``log_names`` are emptied in place (the list objects
    themselves are kept), and ``last_plotted_step`` is set back to its
    initial value of -1 so no step marker survives the reset.
    """
    global last_plotted_step
    log_weights.clear()
    log_names.clear()
    # Previously declared global but never assigned, leaving a stale marker.
    last_plotted_step = -1
def make_plot():
    """Plot everything recorded so far and return the resulting image.

    Returns:
        Whatever ``plot_lora_weight`` returns for the module-level
        ``log_weights`` rows labelled with ``log_names``.
    """
    chart = plot_lora_weight(log_weights, log_names)
    return chart
# On each step, capture our lora weights for plotting
def on_step(params):
    """Record the current multipliers of every loaded network for one step.

    If the callback fires again for the step that was recorded last, the
    previous row is dropped first, so ``log_weights`` keeps a single row per
    sampling step. Column labels are collected whenever ``log_names`` is
    empty.

    Args:
        params: Callback parameters; only ``params.sampling_step`` is read.
    """
    global last_plotted_step

    current_step = params.sampling_step
    same_step_again = last_plotted_step == current_step
    if same_step_again and log_weights:
        # Replace, rather than duplicate, the row for a repeated step.
        log_weights.pop()
    last_plotted_step = current_step

    if not log_names:
        for network in networks.loaded_networks:
            log_names.extend((network.name + "_te", network.name + "_unet"))

    # Same ordering as the labels: te then unet, network by network.
    row = [
        multiplier
        for network in networks.loaded_networks
        for multiplier in (network.te_multiplier, network.unet_multiplier)
    ]
    log_weights.append(row)
# Register on_step with the webui callback system at import time.
# NOTE(review): presumably this hook fires once per sampling step after CFG
# has been applied -- confirm against modules.script_callbacks.
script_callbacks.on_cfg_after_cfg(on_step)
|