from functools import partial, reduce
import torch
from modules import extra_networks, script_callbacks
from modules import prompt_parser
from modules.devices import device, dtype
from modules.sd_hijack import model_hijack
import gradio as gr
import matplotlib.pyplot as ax
# `ax` is the matplotlib.pyplot module itself; select the Agg backend for it.
ax.switch_backend("agg")
# Model currently tracked by this script; assigned in on_model_loaded().
sd_model = None
# Snapshots of the tracked model's schedule attributes, taken in
# on_model_loaded() and written back by do_restore_model_params().
sd_model_betas = None
sd_model_alphas_cumprod = None
sd_model_alphas_cumprod_prev = None
def do_restore_model_params():
    """Write the saved schedule back onto the tracked model and redraw the plot.

    The ``betas``, ``alphas_cumprod`` and ``alphas_cumprod_prev`` snapshots
    taken by ``on_model_loaded`` are reassigned to the model, and the current
    pyplot figure is replaced by a single curve labelled "original".

    Returns:
        The module-level pyplot handle ``ax`` (the UI passes it to its plot
        component).

    Raises:
        RuntimeError: If no model has been registered yet.
    """
    if sd_model is None:
        # Without a registered model every attribute access below would fail
        # on None; report the actual cause instead.
        raise RuntimeError("No model has been loaded yet.")

    sd_model.betas = sd_model_betas
    sd_model.alphas_cumprod = sd_model_alphas_cumprod
    sd_model.alphas_cumprod_prev = sd_model_alphas_cumprod_prev

    values = sd_model_alphas_cumprod.tolist()
    x_values = list(range(sd_model.num_timesteps))

    ax.clf()  # start from an empty figure so older curves do not linger
    ax.plot(x_values, values, label="original")
    ax.legend()
    ax.title("Alphas Cumulative Product")
    ax.xlabel("step")
    ax.ylabel("alphas cumprod")
    return ax
def do_update_model_params(beta_start_mil: int, beta_end_mil: int):
    """Install a linearly spaced beta schedule on the tracked model and plot it.

    Args:
        beta_start_mil: First beta of the schedule, scaled by 1e5 (the value
            is multiplied by 1e-5 before use).
        beta_end_mil: Last beta of the schedule, scaled the same way.

    Returns:
        The module-level pyplot handle ``ax``, holding both the saved
        "original" curve and the new "update" curve.

    Raises:
        RuntimeError: If no model has been registered yet.
    """
    if sd_model is None:
        # Without a registered model every attribute access below would fail
        # on None; report the actual cause instead.
        raise RuntimeError("No model has been loaded yet.")

    # NOTE(review): this alters tensor printing process-wide although nothing
    # is printed here; kept only so existing behaviour is unchanged.
    torch.set_printoptions(precision=8, threshold=50)

    values = sd_model_alphas_cumprod.tolist()

    beta_start = beta_start_mil * 1.e-5
    beta_end = beta_end_mil * 1.e-5

    # beta_schedule = "linear"
    # Build the schedule in at least float32 and cast afterwards: when `dtype`
    # is a half-precision type, `1 - beta` and the long running product would
    # otherwise be rounded at every step.
    compute_dtype = torch.promote_types(dtype, torch.float32)
    betas_wide = torch.linspace(
        beta_start, beta_end, sd_model.num_timesteps,
        device=device, dtype=compute_dtype)
    alphas_cumprod_wide = torch.cumprod(1.0 - betas_wide, dim=0)

    betas = betas_wide.to(dtype)
    alphas_cumprod = alphas_cumprod_wide.to(dtype)
    # Same sequence shifted right by one, with 1.0 as the first entry.
    alphas_cumprod_prev = torch.cat(
        (torch.tensor([1.0], device=device, dtype=dtype), alphas_cumprod[:-1]))

    new_values = alphas_cumprod.tolist()
    x_values = list(range(sd_model.num_timesteps))

    ax.clf()  # start from an empty figure so older curves do not linger
    ax.plot(x_values, values, label="original")
    ax.plot(x_values, new_values, label="update")
    ax.legend()
    ax.title("Alphas Cumulative Product")
    ax.xlabel("step")
    ax.ylabel("alphas cumprod")

    # Only touch the model once plotting has succeeded.
    sd_model.betas = betas
    sd_model.alphas_cumprod = alphas_cumprod
    sd_model.alphas_cumprod_prev = alphas_cumprod_prev
    return ax
def do_schedule(text, steps, current_step):
    """Expand a prompt into its per-step schedule and tokenize the active entry.

    Args:
        text: Raw prompt text.
        steps: Total number of sampling steps used to resolve the schedule.
        current_step: Step whose prompt should be tokenized; the first
            schedule entry whose end step is >= this value is used.

    Returns:
        A ``(ht, md)`` pair. ``ht`` is a list of ``[text, label]`` entries
        (label ``'B'``/``'E'`` for the start/end markers, ``None`` for
        ordinary tokens); it is empty when no schedule entry covers
        ``current_step``. ``md`` is a Markdown string listing every schedule
        entry, preceded by the token count when an entry was tokenized.
    """
    #
    # update_token_counter in modules/ui.py
    #
    try:
        text, _ = extra_networks.parse_prompt(text)
        _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
        prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
    except Exception:
        # a parsing error can happen here during typing, and we don't want to bother the user with
        # messages related to it in console
        prompt_schedules = [[[steps, text]]]

    # Flatten one level. Unlike reduce() without an initial value, this also
    # copes with an empty list of schedules.
    flat_prompts = [entry for schedule in prompt_schedules for entry in schedule]

    # First entry that is still active at `current_step`, if any.
    current_prompt = next(
        (prompt for when, prompt in flat_prompts if current_step <= when),
        None,
    )

    ht = []
    md_parts = []

    if current_prompt is not None:
        #
        # in modules/sd_hijack_clip.py
        #
        clip = model_hijack.clip
        batch_chunks, token_count = clip.process_texts([current_prompt])

        # default=0 keeps an empty batch from raising inside max().
        chunk_count = max((len(chunks) for chunks in batch_chunks), default=0)
        for i in range(chunk_count):
            for chunks in batch_chunks:
                # Texts with fewer chunks are padded with an empty chunk.
                chunk = chunks[i] if i < len(chunks) else clip.empty_chunk()
                for token in clip.tokenizer.convert_ids_to_tokens(chunk.tokens):
                    if token.startswith('<|'):
                        # Only the start/end markers are shown; any other
                        # '<|...' token is skipped.
                        if token == '<|startoftext|>':
                            ht.append(['.', 'B'])
                        elif token == '<|endoftext|>':
                            ht.append(['.', 'E'])
                    else:
                        # Drop the tokenizer's end-of-word suffix for display.
                        ht.append([token[:-4] if token.endswith('</w>') else token, None])

        md_parts.append(f'{token_count} tokens at step {current_step}\n')

    for when, prompt in flat_prompts:
        md_parts.append(f'### step {when}\n{prompt}\n')

    return ht, ''.join(md_parts)
def on_model_loaded(sd_model_):
    """Start tracking a newly loaded model and snapshot its schedule.

    Stores the model in the module-level ``sd_model`` and keeps copies of its
    ``betas``, ``alphas_cumprod`` and ``alphas_cumprod_prev`` converted to the
    configured ``device``/``dtype``. Calling it again with the object that is
    already tracked does nothing, so an existing snapshot is not replaced by
    values that may have been modified since.

    Args:
        sd_model_: The model object handed over by the model-loaded callback.
    """
    global sd_model
    global sd_model_betas, sd_model_alphas_cumprod, sd_model_alphas_cumprod_prev

    # Identity, not equality: the question is whether this is the very object
    # already tracked. `==` would defer to a custom __eq__, which could even
    # claim equality with the initial None and block registration entirely.
    if sd_model_ is sd_model:
        return

    sd_model = sd_model_
    sd_model_betas = sd_model_.betas.to(device, dtype)
    sd_model_alphas_cumprod = sd_model_.alphas_cumprod.to(device, dtype)
    sd_model_alphas_cumprod_prev = sd_model_.alphas_cumprod_prev.to(device, dtype)
def on_ui_tabs():
    """Build the "Sched." tab and wire its three buttons to their handlers.

    Returns:
        A one-element tuple holding ``(blocks, tab title, element id)``.
    """
    with gr.Blocks(analytics_enabled=False, variant="compact") as tab:
        # Schedule plot.
        with gr.Row(), gr.Column():
            schedule_plot = gr.Plot(value=ax)

        # Beta range sliders (values are scaled by 1e5).
        with gr.Row():
            with gr.Column():
                beta_start_slider = gr.Slider(
                    minimum=5, maximum=125, step=1,
                    label="Beta start * 1.e+5", value=85)  # 85.020
            with gr.Column():
                beta_end_slider = gr.Slider(
                    minimum=400, maximum=2000, step=20,
                    label="Beta end * 1.e+5", value=1200)  # 1200.104

        with gr.Row():
            with gr.Column():
                restore_btn = gr.Button(value="Restore")
            with gr.Column():
                update_btn = gr.Button(value="Update", variant="primary")

        # Prompt schedule inspector.
        with gr.Row(), gr.Column():
            prompt_box = gr.Textbox(label="Prompt", show_label=False, lines=3, placeholder="Prompt")

        with gr.Row():
            with gr.Column():
                steps_slider = gr.Slider(
                    minimum=1, maximum=150, step=1,
                    label="Sampling steps", value=20)
            with gr.Column():
                current_step_slider = gr.Slider(
                    minimum=1, maximum=150, step=1,
                    label="Count tokens at this step", value=1)

        with gr.Row(), gr.Column():
            schedule_btn = gr.Button(value="Schedule", variant="primary")

        tokens_view = gr.HighlightedText(
            combine_adjacent=True, adjacent_separator=' ', label="CLIP",
        ).style(color_map={'B': 'green', 'E': 'red'})
        schedule_view = gr.Markdown()

        # Event wiring.
        restore_btn.click(
            fn=do_restore_model_params,
            inputs=[],
            outputs=[schedule_plot],
        )
        update_btn.click(
            fn=do_update_model_params,
            inputs=[beta_start_slider, beta_end_slider],
            outputs=[schedule_plot],
        )
        schedule_btn.click(
            fn=do_schedule,
            inputs=[prompt_box, steps_slider, current_step_slider],
            outputs=[tokens_view, schedule_view],
        )

    tab_entry = (tab, "Sched.", "sched")
    return (tab_entry,)
# Register this script's hooks: capture each loaded model, and contribute the
# tab built by on_ui_tabs().
script_callbacks.on_model_loaded(on_model_loaded)
script_callbacks.on_ui_tabs(on_ui_tabs)