File size: 6,330 Bytes
29a5ed9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
from functools import partial, reduce

import torch

from modules import extra_networks, script_callbacks

from modules import prompt_parser
from modules.devices import device, dtype
from modules.sd_hijack import model_hijack

import gradio as gr
import matplotlib.pyplot as ax

# Use the non-interactive "agg" backend: figures are drawn off-screen and handed
# to gr.Plot by the handlers in this module instead of being shown in a window.
ax.switch_backend("agg")

# Model currently tracked by on_model_loaded(), plus a snapshot of its betas /
# alphas_cumprod / alphas_cumprod_prev tensors; all None until a model is loaded.
sd_model = sd_model_alphas_cumprod_prev = sd_model_alphas_cumprod = sd_model_betas = None


def do_restore_model_params():
    """Reinstall the noise-schedule tensors saved by ``on_model_loaded`` and plot them.

    Writes the snapshot back to ``sd_model.betas``, ``sd_model.alphas_cumprod``
    and ``sd_model.alphas_cumprod_prev``, then redraws the current pyplot figure
    with the original ``alphas_cumprod`` curve.

    Returns:
        The ``matplotlib.pyplot`` module (imported here as ``ax``); the ``gr.Plot``
        output bound to this handler renders its current figure.

    Raises:
        gr.Error: if no model has been loaded yet, so there is nothing to restore.
    """
    # Before the first model load the module globals are still None; report that
    # clearly instead of failing with an AttributeError on None.
    if sd_model is None:
        raise gr.Error("No Stable Diffusion model is loaded yet.")

    sd_model.betas = sd_model_betas
    sd_model.alphas_cumprod = sd_model_alphas_cumprod
    sd_model.alphas_cumprod_prev = sd_model_alphas_cumprod_prev

    values = sd_model_alphas_cumprod.tolist()
    x_values = list(range(sd_model.num_timesteps))

    ax.clf()  # clear current figure.
    ax.plot(x_values, values, label="original")
    ax.legend()
    ax.title("Alphas Cumulative Product")
    ax.xlabel("step")
    ax.ylabel("alphas cumprod")

    return ax


def do_update_model_params(beta_start_mil: int, beta_end_mil: int):
    """Install an evenly spaced beta schedule on the model and plot it against the original.

    The arguments are the first and last beta multiplied by 1e+5 (the UI sliders
    are labelled "Beta start * 1.e+5" / "Beta end * 1.e+5"), so ``85`` means
    ``0.00085``. ``sd_model.num_timesteps`` betas are spaced evenly between the
    two endpoints; ``alphas_cumprod`` and ``alphas_cumprod_prev`` are derived from
    them, and all three replace the corresponding attributes of ``sd_model``.

    Args:
        beta_start_mil: first beta, multiplied by 1e+5.
        beta_end_mil: last beta, multiplied by 1e+5.

    Returns:
        The ``matplotlib.pyplot`` module (imported here as ``ax``) whose current
        figure shows the "original" and "update" alphas-cumprod curves.

    Raises:
        gr.Error: if no model has been loaded yet.
    """
    if sd_model is None:
        raise gr.Error("No Stable Diffusion model is loaded yet.")

    # "original" curve: the snapshot taken when the model was loaded.
    values = sd_model_alphas_cumprod.tolist()

    beta_start = beta_start_mil * 1.e-5
    beta_end = beta_end_mil * 1.e-5
    num_timesteps = sd_model.num_timesteps

    # NOTE(review): betas are interpolated directly. LDM's make_beta_schedule("linear")
    # appears to interpolate sqrt(beta) and square the result, so the default slider
    # values may not reproduce the original curve exactly -- confirm which is intended.

    # Build the schedule in float64 on the CPU and convert once at the end. In half
    # precision, 1 - beta (beta ~ 1e-3) snaps to the fp16 grid near 1.0 (spacing
    # ~4.9e-4) and the cumulative product over every timestep compounds that error.
    # float64 is unavailable on some accelerators (e.g. Apple MPS), hence the CPU.
    betas64 = torch.linspace(beta_start, beta_end, num_timesteps, dtype=torch.float64)
    alphas_cumprod64 = torch.cumprod(1.0 - betas64, dim=0)
    alphas_cumprod_prev64 = torch.cat((torch.ones(1, dtype=torch.float64), alphas_cumprod64[:-1]))

    betas = betas64.to(device, dtype)
    alphas_cumprod = alphas_cumprod64.to(device, dtype)
    alphas_cumprod_prev = alphas_cumprod_prev64.to(device, dtype)

    # Plot the values actually installed on the model (after dtype conversion).
    new_values = alphas_cumprod.tolist()
    x_values = list(range(num_timesteps))

    ax.clf()  # clear current figure.
    ax.plot(x_values, values, label="original")
    ax.plot(x_values, new_values, label="update")
    ax.legend()
    ax.title("Alphas Cumulative Product")
    ax.xlabel("step")
    ax.ylabel("alphas cumprod")

    # Rebind (never modify in place) so the saved originals stay untouched.
    sd_model.betas = betas
    sd_model.alphas_cumprod = alphas_cumprod
    sd_model.alphas_cumprod_prev = alphas_cumprod_prev

    return ax


def do_schedule(text, steps, current_step):
    """Describe the prompt-editing schedule of ``text`` and tokenize one step's prompt.

    Follows ``update_token_counter`` in modules/ui.py: extra-network tags are
    removed, ``prompt_parser`` splits the prompt into sub-prompts, and each is
    expanded into ``[end_step, prompt]`` entries for ``steps`` sampling steps.
    If parsing fails (e.g. while the user is still typing), the whole text is
    treated as a single entry ending at ``steps``.

    Args:
        text: the prompt as typed.
        steps: number of sampling steps to build the schedule for.
        current_step: step whose prompt is tokenized with ``model_hijack.clip``.

    Returns:
        ``(ht, md)``: ``ht`` is a list of ``[text, label]`` pairs for
        ``gr.HighlightedText`` (``'B'`` marks ``<|startoftext|>``, ``'E'`` marks
        ``<|endoftext|>``, ``None`` an ordinary token); ``md`` is Markdown with the
        token count (only when some entry covers ``current_step``) followed by a
        ``### step N`` section for every schedule entry.
    """
    #
    # update_token_counter in modules/ui.py
    #
    try:
        text, _ = extra_networks.parse_prompt(text)

        _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
        prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)

    except Exception:
        # a parsing error can happen here during typing, and we don't want to bother the user with
        # messages related to it in console
        prompt_schedules = [[[steps, text]]]

    # One flat list of [end_step, prompt] entries across all sub-prompts. The
    # comprehension is linear; reduce() with `+` re-copied the accumulator each step.
    flat_prompts = [entry for schedule in prompt_schedules for entry in schedule]

    # The first entry whose end step is not before current_step wins.
    # NOTE(review): with AND sub-prompts this is presumably always an entry of the
    # first sub-prompt -- confirm that is the intended token count.
    current_prompt = next((prompt for when, prompt in flat_prompts if current_step <= when), None)

    ht = []
    md_parts = []

    if current_prompt is not None:
        #
        # in modules/sd_hijack_clip.py
        #
        clip = model_hijack.clip
        batch_chunks, token_count = clip.process_texts([current_prompt])

        # Start/end markers become '.' cells labelled 'B'/'E'; any other '<|...|>'
        # token is dropped, and the '</w>' word-end suffix is stripped.
        special_labels = {'<|startoftext|>': 'B', '<|endoftext|>': 'E'}

        chunk_count = max(len(chunks) for chunks in batch_chunks)
        for i in range(chunk_count):
            # Pad shorter prompts with empty chunks, as sd_hijack_clip does.
            batch_chunk = [chunks[i] if i < len(chunks) else clip.empty_chunk() for chunks in batch_chunks]
            for chunk in batch_chunk:
                for token in clip.tokenizer.convert_ids_to_tokens(chunk.tokens):
                    if token in special_labels:
                        ht.append(['.', special_labels[token]])
                    elif not token.startswith('<|'):
                        ht.append([token[:-4] if token.endswith('</w>') else token, None])

        md_parts.append(f'{token_count} tokens at step {current_step}\n')

    md_parts.extend(f'### step {when}\n{prompt}\n' for when, prompt in flat_prompts)

    return ht, ''.join(md_parts)


def on_model_loaded(sd_model_):
    """Track the loaded model and snapshot its noise-schedule tensors.

    Registered with ``script_callbacks.on_model_loaded`` at the bottom of this
    module. The snapshot is what "Restore" writes back and what the plots draw
    as the "original" curve. When called again with the very same model object
    it does nothing, so the existing snapshot is kept.

    Args:
        sd_model_: the model object passed in by the callback.
    """
    global sd_model
    global sd_model_betas, sd_model_alphas_cumprod, sd_model_alphas_cumprod_prev

    # Identity, not equality: only "is this the same model object?" matters, and
    # `is` cannot be influenced by a custom __eq__.
    if sd_model_ is sd_model:
        return

    sd_model = sd_model_
    # Snapshot in the shared device/dtype. Tensor.to() returns the tensor itself
    # when no conversion is needed, so the snapshot may alias the model's live
    # buffer; that is safe as long as the schedule is only replaced by rebinding
    # the attributes, never modified in place.
    # NOTE(review): because of the dtype conversion, Restore writes back tensors in
    # `dtype`, which may differ from the model's original buffer dtype -- confirm.
    sd_model_betas = sd_model_.betas.to(device, dtype)
    sd_model_alphas_cumprod = sd_model_.alphas_cumprod.to(device, dtype)
    sd_model_alphas_cumprod_prev = sd_model_.alphas_cumprod_prev.to(device, dtype)


def on_ui_tabs():
    """Build the "Sched." tab: a beta-schedule editor and a prompt-schedule viewer.

    Upper part: the alphas-cumprod plot, beta start/end sliders (values x 1e+5)
    and Restore / Update buttons bound to the model-parameter handlers.
    Lower part: a prompt box, step sliders and a Schedule button that fills the
    CLIP token view and the Markdown report through ``do_schedule``.

    Returns:
        ``((blocks, "Sched.", "sched"),)`` -- one tab entry for
        ``script_callbacks.on_ui_tabs``.
    """
    with gr.Blocks(analytics_enabled=False, variant="compact") as tab:
        # --- beta schedule editor -------------------------------------------
        with gr.Row(), gr.Column():
            cumprod_plot = gr.Plot(value=ax)

        with gr.Row():
            with gr.Column():
                start_slider = gr.Slider(minimum=5, maximum=125, step=1, label="Beta start * 1.e+5", value=85)  # 85.020
            with gr.Column():
                end_slider = gr.Slider(minimum=400, maximum=2000, step=20, label="Beta end * 1.e+5", value=1200)  # 1200.104

        with gr.Row():
            with gr.Column():
                restore_btn = gr.Button(value="Restore")
            with gr.Column():
                update_btn = gr.Button(value="Update", variant="primary")

        # --- prompt schedule viewer -----------------------------------------
        with gr.Row(), gr.Column():
            prompt_box = gr.Textbox(label="Prompt", show_label=False, lines=3, placeholder="Prompt")

        with gr.Row():
            with gr.Column():
                steps_slider = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling steps", value=20)
            with gr.Column():
                token_step_slider = gr.Slider(minimum=1, maximum=150, step=1, label="Count tokens at this step", value=1)

        with gr.Row(), gr.Column():
            schedule_btn = gr.Button(value="Schedule", variant="primary")
            # 'B' / 'E' are the start/end-of-text labels produced by do_schedule.
            clip_tokens = gr.HighlightedText(combine_adjacent=True, adjacent_separator=' ', label="CLIP")
            clip_tokens = clip_tokens.style(color_map={'B': 'green', 'E': 'red'})
            schedule_md = gr.Markdown()

        restore_btn.click(fn=do_restore_model_params, inputs=[], outputs=[cumprod_plot])
        update_btn.click(fn=do_update_model_params, inputs=[start_slider, end_slider], outputs=[cumprod_plot])
        schedule_btn.click(fn=do_schedule, inputs=[prompt_box, steps_slider, token_step_slider], outputs=[clip_tokens, schedule_md])

    tab_entry = (tab, "Sched.", "sched")
    return (tab_entry,)


# Hook this module into the webui: register on_model_loaded as the model-loaded
# callback (it snapshots the noise schedule) and on_ui_tabs as the provider of
# the "Sched." tab.
script_callbacks.on_model_loaded(on_model_loaded)
script_callbacks.on_ui_tabs(on_ui_tabs)