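# Promptgen tab for stable-diffusion-webui: generates prompt text with a
# GPTQ-quantized danbooru-llama model and lets results be sent to the
# txt2img/img2img tabs.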
import html
import os

import gradio as gr
import torch
from auto_gptq import AutoGPTQForCausalLM
from transformers import AutoTokenizer

from modules import shared, generation_parameters_copypaste
from modules import scripts, script_callbacks, devices, ui
from modules.ui_components import FormRow


class Model:
    name = None
    model = None
    tokenizer = None


available_models = []
current = Model()

base_dir = scripts.basedir()
models_dir = os.path.join(base_dir, "models")


def device():
    return devices.cpu if shared.opts.promptgen_device == 'cpu' else devices.device
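

# Model discovery: local checkpoints live in this extension's "models"
# directory; additional Hugging Face repo names can be listed in the
# promptgen_names setting.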
def list_available_models():
    available_models.clear()

    os.makedirs(models_dir, exist_ok=True)

    for dirname in os.listdir(models_dir):
        if os.path.isdir(os.path.join(models_dir, dirname)):
            available_models.append(dirname)

    for name in [x.strip() for x in shared.opts.promptgen_names.split(",")]:
        if not name:
            continue

        available_models.append(name)


def get_model_path(name):
    """Return the local directory for `name` if it exists; otherwise return the
    name unchanged so it can be treated as a Hugging Face repo id."""
    dirname = os.path.join(models_dir, name)
    if not os.path.isdir(dirname):
        return name

    return dirname


def generate_batch(input_ids, min_length, max_length, num_beams, temperature, repetition_penalty, length_penalty, sampling_mode, top_k, top_p):
    # Only one of top_p / top_k is active, depending on the selected sampling mode.
    top_p = float(top_p) if sampling_mode == 'Top P' else None
    top_k = int(top_k) if sampling_mode == 'Top K' else None

    outputs = current.model.generate(
        input_ids,
        do_sample=True,
        temperature=max(float(temperature), 1e-6),
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        top_p=top_p,
        top_k=top_k,
        num_beams=int(num_beams),
        min_length=min_length,
        max_length=max_length,
        pad_token_id=current.tokenizer.pad_token_id or current.tokenizer.eos_token_id
    )
    texts = current.tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return texts
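

# A minimal sketch of calling generate_batch directly; it assumes a model and
# tokenizer are already loaded into `current`, and the parameter values are
# illustrative only:
#
#   ids = current.tokenizer("1girl, solo", return_tensors="pt").input_ids.to(device())
#   texts = generate_batch(ids, min_length=20, max_length=150, num_beams=1,
#                          temperature=1.0, repetition_penalty=1.0, length_penalty=1.0,
#                          sampling_mode='Top K', top_k=12, top_p=0.15)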


def model_selection_changed(model_name):
    if model_name == "None":
        current.tokenizer = None
        current.model = None
        current.name = None

        # Free the model's memory once it is deselected.
        devices.torch_gc()


def generate(id_task, model_name, batch_count, batch_size, text, *args):
    shared.state.textinfo = "Loading model..."
    shared.state.job_count = batch_count

    # Note: the dropdown selection is overridden here; the GPTQ model is always used.
    model_name = 'qwopqwop/danbooru-llama-gptq'
    if current.name != model_name:
        current.tokenizer = None
        current.model = None
        current.name = None

        if model_name != 'None':
            model = AutoGPTQForCausalLM.from_quantized("qwopqwop/danbooru-llama-gptq").model
            current.model = model

            DEFAULT_PAD_TOKEN = "[PAD]"
            tokenizer = AutoTokenizer.from_pretrained("pinkmanlove/llama-7b-hf", use_fast=False)

            def smart_tokenizer_and_embedding_resize(special_tokens_dict, tokenizer, model):
                """Resize tokenizer and embedding.

                Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
                """
                num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
                model.resize_token_embeddings(len(tokenizer))

                if num_new_tokens > 0:
                    input_embeddings = model.get_input_embeddings().weight.data
                    output_embeddings = model.get_output_embeddings().weight.data

                    # Initialize new token embeddings with the mean of the existing ones.
                    input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
                    output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

                    input_embeddings[-num_new_tokens:] = input_embeddings_avg
                    output_embeddings[-num_new_tokens:] = output_embeddings_avg

            # LLaMA tokenizers ship without a pad token; add one and resize the
            # embeddings to match (this helper appears to follow Stanford Alpaca's
            # token-resizing utility).
            if tokenizer._pad_token is None:
                smart_tokenizer_and_embedding_resize(
                    special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
                    tokenizer=tokenizer,
                    model=model,
                )

            tokenizer.add_special_tokens({
                "eos_token": tokenizer.convert_ids_to_tokens(model.config.eos_token_id),
                "bos_token": tokenizer.convert_ids_to_tokens(model.config.bos_token_id),
                "unk_token": tokenizer.convert_ids_to_tokens(
                    model.config.pad_token_id if model.config.pad_token_id != -1 else tokenizer.pad_token_id
                ),
            })

            current.tokenizer = tokenizer
            current.name = model_name

    assert current.model, 'No model available'
    assert current.tokenizer, 'No tokenizer available'

    current.model.to(device())

    shared.state.textinfo = ""

    input_ids = current.tokenizer(text, return_tensors="pt").input_ids
    if input_ids.shape[1] == 0:
        # Empty prompt: start generation from the BOS token.
        input_ids = torch.asarray([[current.tokenizer.bos_token_id]], dtype=torch.long)
    input_ids = input_ids.to(device())
    input_ids = input_ids.repeat((batch_size, 1))

    markup = '<table><tbody>'

    index = 0
    for i in range(batch_count):
        texts = generate_batch(input_ids, *args)
        shared.state.nextjob()

        for generated_text in texts:
            index += 1
            # The onclick helpers are defined in the extension's javascript and
            # route the generated text to the hidden send_to_* buttons below.
            markup += f"""
<tr>
<td>
<div class="prompt gr-box gr-text-input">
    <p id='promptgen_res_{index}'>{html.escape(generated_text)}</p>
</div>
</td>
<td class="sendto">
    <a class='gr-button gr-button-lg gr-button-secondary' onclick="promptgen_send_to_txt2img(gradioApp().getElementById('promptgen_res_{index}').textContent)">to txt2img</a>
    <a class='gr-button gr-button-lg gr-button-secondary' onclick="promptgen_send_to_img2img(gradioApp().getElementById('promptgen_res_{index}').textContent)">to img2img</a>
</td>
</tr>
"""

    markup += '</tbody></table>'

    return markup, ''


def find_prompts(fields):
    field_prompt = [x for x in fields if x[1] == "Prompt"][0]
    field_negative_prompt = [x for x in fields if x[1] == "Negative prompt"][0]
    return [field_prompt[0], field_negative_prompt[0]]


def send_prompts(text):
    params = generation_parameters_copypaste.parse_generation_parameters(text)
    negative_prompt = params.get("Negative prompt", "")
    return params.get("Prompt", ""), negative_prompt or gr.update()
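

# Example, assuming the webui's standard parameter parsing:
# send_prompts("a castle\nNegative prompt: blurry") yields ("a castle", "blurry"),
# while text without a "Negative prompt:" section leaves the negative field
# unchanged via gr.update().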


def add_tab():
    list_available_models()

    with gr.Blocks(analytics_enabled=False) as tab:
        with gr.Row():
            with gr.Column(scale=80):
                prompt = gr.Textbox(label="Prompt", elem_id="promptgen_prompt", show_label=False, lines=2, placeholder="Beginning of the prompt (press Ctrl+Enter or Alt+Enter to generate)").style(container=False)
            with gr.Column(scale=10):
                submit = gr.Button('Generate', elem_id="promptgen_generate", variant='primary')

        with gr.Row(elem_id="promptgen_main"):
            with gr.Column(variant="compact"):
                selected_text = gr.TextArea(elem_id='promptgen_selected_text', visible=False)
                send_to_txt2img = gr.Button(elem_id='promptgen_send_to_txt2img', visible=False)
                send_to_img2img = gr.Button(elem_id='promptgen_send_to_img2img', visible=False)

                with FormRow():
                    model_selection = gr.Dropdown(label="Model", elem_id="promptgen_model", value=available_models[0] if available_models else "None", choices=["None"] + available_models)

                with FormRow():
                    sampling_mode = gr.Radio(label="Sampling mode", elem_id="promptgen_sampling_mode", value="Top K", choices=["Top K", "Top P"])
                    top_k = gr.Slider(label="Top K", elem_id="promptgen_top_k", value=12, minimum=1, maximum=50, step=1)
                    top_p = gr.Slider(label="Top P", elem_id="promptgen_top_p", value=0.15, minimum=0, maximum=1, step=0.001)

                with gr.Row():
                    num_beams = gr.Slider(label="Number of beams", elem_id="promptgen_num_beams", value=1, minimum=1, maximum=8, step=1)
                    temperature = gr.Slider(label="Temperature", elem_id="promptgen_temperature", value=1, minimum=0, maximum=4, step=0.01)
                    repetition_penalty = gr.Slider(label="Repetition penalty", elem_id="promptgen_repetition_penalty", value=1, minimum=1, maximum=4, step=0.01)

                with FormRow():
                    length_penalty = gr.Slider(label="Length preference", elem_id="promptgen_length_preference", value=1, minimum=-10, maximum=10, step=0.1)
                    min_length = gr.Slider(label="Min length", elem_id="promptgen_min_length", value=20, minimum=1, maximum=400, step=1)
                    max_length = gr.Slider(label="Max length", elem_id="promptgen_max_length", value=150, minimum=1, maximum=400, step=1)

                with FormRow():
                    batch_count = gr.Slider(label="Batch count", elem_id="promptgen_batch_count", value=1, minimum=1, maximum=100, step=1)
                    batch_size = gr.Slider(label="Batch size", elem_id="promptgen_batch_size", value=10, minimum=1, maximum=100, step=1)

                with open(os.path.join(base_dir, "explanation.html"), encoding="utf8") as file:
                    footer = file.read()
                    gr.HTML(footer)

            with gr.Column():
                with gr.Group(elem_id="promptgen_results_column"):
                    res = gr.HTML()
                    res_info = gr.HTML()

        submit.click(
            fn=ui.wrap_gradio_gpu_call(generate, extra_outputs=['']),
            _js="submit_promptgen",
            # The first input is a placeholder that the submit_promptgen javascript replaces with the task id.
            inputs=[model_selection, model_selection, batch_count, batch_size, prompt, min_length, max_length, num_beams, temperature, repetition_penalty, length_penalty, sampling_mode, top_k, top_p, ],
            outputs=[res, res_info]
        )

        model_selection.change(
            fn=model_selection_changed,
            inputs=[model_selection],
            outputs=[],
        )

        send_to_txt2img.click(
            fn=send_prompts,
            inputs=[selected_text],
            outputs=find_prompts(ui.txt2img_paste_fields)
        )

        send_to_img2img.click(
            fn=send_prompts,
            inputs=[selected_text],
            outputs=find_prompts(ui.img2img_paste_fields)
        )

    return [(tab, "Promptgen", "promptgen")]


def on_ui_settings():
    section = ("promptgen", "Promptgen")

    shared.opts.add_option("promptgen_names", shared.OptionInfo("qwopqwop/danbooru-llama-gptq", "Hugging Face model names for promptgen, separated by comma", section=section))
    # Only "gpu" is offered here, so the cpu branch in device() is never taken.
    shared.opts.add_option("promptgen_device", shared.OptionInfo("gpu", "Device to use for text generation", gr.Radio, {"choices": ["gpu"]}, section=section))


def on_unload():
    current.model = None
    current.tokenizer = None


script_callbacks.on_ui_tabs(add_tab)
script_callbacks.on_ui_settings(on_ui_settings)
script_callbacks.on_script_unloaded(on_unload)