| from functools import partial, reduce |
|
|
| import torch |
|
|
| from modules import extra_networks, script_callbacks |
|
|
| from modules import prompt_parser |
| from modules.devices import device, dtype |
| from modules.sd_hijack import model_hijack |
|
|
| import gradio as gr |
| import matplotlib.pyplot as ax |
|
|
# Use a headless backend so figures render off-screen for Gradio's Plot component.
ax.switch_backend("agg")

# Currently loaded model plus snapshots of its original noise-schedule tensors,
# populated by on_model_loaded and restored by do_restore_model_params.
sd_model = None
sd_model_betas = None
sd_model_alphas_cumprod = None
sd_model_alphas_cumprod_prev = None
|
|
|
|
def do_restore_model_params():
    """Put the model's original noise-schedule tensors back in place.

    Restores the betas / alphas_cumprod / alphas_cumprod_prev snapshots
    captured at load time, then redraws the alphas-cumprod curve.

    Returns:
        The pyplot module (aliased ``ax``) for Gradio's Plot output.
    """
    sd_model.betas = sd_model_betas
    sd_model.alphas_cumprod = sd_model_alphas_cumprod
    sd_model.alphas_cumprod_prev = sd_model_alphas_cumprod_prev

    step_axis = list(range(sd_model.num_timesteps))
    original_curve = sd_model_alphas_cumprod.tolist()

    # Redraw from scratch so stale curves from earlier updates disappear.
    ax.clf()
    ax.plot(step_axis, original_curve, label="original")
    ax.legend()
    ax.title("Alphas Cumulative Product")
    ax.xlabel("step")
    ax.ylabel("alphas cumprod")

    return ax
|
|
|
|
def do_update_model_params(beta_start_mil: int, beta_end_mil: int):
    """Rebuild the model's beta schedule from the two slider values.

    Args:
        beta_start_mil: starting beta expressed in units of 1e-5.
        beta_end_mil: ending beta expressed in units of 1e-5.

    Replaces ``betas``, ``alphas_cumprod`` and ``alphas_cumprod_prev`` on the
    live model and plots the new alphas-cumprod curve next to the original.

    Returns:
        The pyplot module (aliased ``ax``) for Gradio's Plot output.
    """
    torch.set_printoptions(precision=8, threshold=50)

    original_curve = sd_model_alphas_cumprod.tolist()

    # Sliders carry integers scaled by 1e+5; convert back to real beta values.
    start = beta_start_mil * 1.e-5
    end = beta_end_mil * 1.e-5

    betas = torch.linspace(start, end, sd_model.num_timesteps, device=device, dtype=dtype)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
    # Previous-step cumprod: shift right by one, prepending 1.0 for step 0.
    leading_one = torch.tensor([1.0], device=device, dtype=dtype)
    alphas_cumprod_prev = torch.cat((leading_one, alphas_cumprod[:-1]))

    updated_curve = alphas_cumprod.tolist()
    step_axis = list(range(sd_model.num_timesteps))

    # Draw both curves so the effect of the new schedule is visible.
    ax.clf()
    ax.plot(step_axis, original_curve, label="original")
    ax.plot(step_axis, updated_curve, label="update")
    ax.legend()
    ax.title("Alphas Cumulative Product")
    ax.xlabel("step")
    ax.ylabel("alphas cumprod")

    sd_model.betas = betas
    sd_model.alphas_cumprod = alphas_cumprod
    sd_model.alphas_cumprod_prev = alphas_cumprod_prev

    return ax
|
|
|
|
def do_schedule(text, steps, current_step):
    """Inspect how a prompt is scheduled over sampling steps.

    Args:
        text: the raw prompt, possibly containing extra-network syntax
            and prompt-editing schedules.
        steps: total number of sampling steps.
        current_step: the step at which to tokenize the active prompt.

    Returns:
        A pair ``(ht, md)``: highlighted-token pairs for a
        gr.HighlightedText component, and a markdown summary of the
        per-step schedule.
    """
    try:
        text, _ = extra_networks.parse_prompt(text)
        _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
        prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
    except Exception:
        # Best effort: if the prompt cannot be parsed, treat the raw text
        # as a single schedule entry covering all steps.
        prompt_schedules = [[[steps, text]]]

    flat_prompts = reduce(lambda acc, part: acc + part, prompt_schedules)

    ht = []
    md = ''

    # The active prompt is the first schedule entry whose end step covers
    # current_step.
    current_prompt = None
    for when, prompt in flat_prompts:
        if current_step <= when:
            current_prompt = prompt
            break

    if current_prompt is not None:
        clip = model_hijack.clip
        batch_chunks, token_count = clip.process_texts([current_prompt])
        chunk_count = max(len(chunks) for chunks in batch_chunks)

        for i in range(chunk_count):
            # Pad shorter token streams with an empty chunk so every entry
            # contributes a chunk at index i.
            padded = [chunks[i] if i < len(chunks) else clip.empty_chunk() for chunks in batch_chunks]
            for chunk in padded:
                for token in clip.tokenizer.convert_ids_to_tokens(chunk.tokens):
                    if token.startswith('<|'):
                        # Render BOS/EOS markers as labelled dots; other
                        # special tokens are dropped.
                        if token == '<|startoftext|>':
                            ht.append(['.', 'B'])
                        elif token == '<|endoftext|>':
                            ht.append(['.', 'E'])
                    else:
                        # Strip the tokenizer's end-of-word suffix.
                        ht.append([token[:-4] if token.endswith('</w>') else token, None])

        md += f'{token_count} tokens at step {current_step}\n'

    for when, prompt in flat_prompts:
        md += f'### step {when}\n'
        md += prompt
        md += '\n'

    return ht, md
|
|
|
|
def on_model_loaded(sd_model_):
    """Cache the newly loaded model and snapshot its schedule tensors.

    Stores the model plus copies of its betas / alphas_cumprod /
    alphas_cumprod_prev (moved to the working device/dtype) so that
    do_restore_model_params can undo later edits.

    Args:
        sd_model_: the model object passed by the webui's model-loaded
            callback.
    """
    global sd_model
    global sd_model_betas, sd_model_alphas_cumprod, sd_model_alphas_cumprod_prev

    # Identity check, not equality: we only care whether this is the very
    # same model object we already snapshotted (`is` is the correct idiom
    # and cannot be confused by a custom __eq__).
    if sd_model_ is sd_model:
        return

    sd_model = sd_model_
    # NOTE(review): `.to` returns the same tensor when no conversion is
    # needed, so these snapshots may alias the model's tensors; that is
    # safe here because do_update_model_params replaces (never mutates)
    # the model's schedule tensors.
    sd_model_betas = sd_model_.betas.to(device, dtype)
    sd_model_alphas_cumprod = sd_model_.alphas_cumprod.to(device, dtype)
    sd_model_alphas_cumprod_prev = sd_model_.alphas_cumprod_prev.to(device, dtype)
|
|
|
|
def on_ui_tabs():
    """Build the extension's tab: schedule plot with beta controls plus a
    prompt-schedule inspector.

    Returns:
        A one-element tuple of (blocks, tab title, tab id) as expected by
        script_callbacks.on_ui_tabs.
    """
    with gr.Blocks(analytics_enabled=False, variant="compact") as root:
        # --- noise-schedule plot and beta sliders ---
        with gr.Row():
            with gr.Column():
                plot = gr.Plot(value=ax)
        with gr.Row():
            with gr.Column():
                beta_start = gr.Slider(minimum=5, maximum=125, step=1, label="Beta start * 1.e+5", value=85)
            with gr.Column():
                beta_end = gr.Slider(minimum=400, maximum=2000, step=20, label="Beta end * 1.e+5", value=1200)
        with gr.Row():
            with gr.Column():
                restore_button = gr.Button(value="Restore")
            with gr.Column():
                update_button = gr.Button(value="Update", variant="primary")

        # --- prompt-schedule inspector ---
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", show_label=False, lines=3, placeholder="Prompt")
        with gr.Row():
            with gr.Column():
                steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling steps", value=20)
            with gr.Column():
                current_step = gr.Slider(minimum=1, maximum=150, step=1, label="Count tokens at this step", value=1)
        with gr.Row():
            with gr.Column():
                schedule_button = gr.Button(value="Schedule", variant="primary")
                report_ht = gr.HighlightedText(combine_adjacent=True, adjacent_separator=' ', label="CLIP").style(color_map={'B': 'green', 'E': 'red'})
                report_md = gr.Markdown()

        # Wire the buttons to their handlers.
        restore_button.click(fn=do_restore_model_params, inputs=[], outputs=[plot])
        update_button.click(fn=do_update_model_params, inputs=[beta_start, beta_end], outputs=[plot])
        schedule_button.click(fn=do_schedule, inputs=[prompt, steps, current_step], outputs=[report_ht, report_md])

    return (root, "Sched.", "sched"),
|
|
|
|
# Register with the webui: snapshot schedule tensors on model load and add
# the "Sched." tab to the UI.
script_callbacks.on_model_loaded(on_model_loaded)
script_callbacks.on_ui_tabs(on_ui_tabs)
|
|