import torch
from modules import chat, shared
from modules.text_generation import (
    decode,
    encode,
    generate_reply,
)
from transformers import LogitsProcessor
import gradio as gr

params = {
    "display_name": "Long replies",
    "is_tab": False,
    "min_length": 120,
}

# Length of the prompt in tokens; set in tokenizer_modifier() below so that
# MyLogits can measure how many tokens of the reply have been generated.
initial_size = 0

class MyLogits(LogitsProcessor):
    """
    Manipulates the probabilities for the next token before it gets sampled.
    Used in the logits_processor_modifier function below.
    """
    def __init__(self):
        # Id of the "\n" token; [-1] skips any special tokens that the
        # tokenizer may prepend.
        self.newline_id = shared.tokenizer.encode('\n')[-1]

    def __call__(self, input_ids, scores):
        # Until the reply is at least min_length tokens long, make the
        # newline token (which would end the paragraph) extremely unlikely.
        if input_ids.shape[-1] - initial_size < params["min_length"]:
            scores[..., self.newline_id] = -1000

        return scores

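# A sketch of other manipulations that could go in __call__ above
# (illustrative only, not used by this extension): convert the logits to
# probabilities, renormalize them, and convert back to logits.
#
#     probs = torch.softmax(scores, dim=-1, dtype=torch.float)
#     probs[0] /= probs[0].sum()
#     scores = torch.log(probs / (1 - probs))
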
def history_modifier(history):
    """
    Modifies the chat history.
    Only used in chat mode.
    """
    return history

def state_modifier(state):
    """
    Modifies the state variable, which is a dictionary containing the input
    values in the UI like sliders and checkboxes.
    """
    return state

def chat_input_modifier(text, visible_text, state):
    """
    Modifies the user input string in chat mode (visible_text).
    You can also modify the internal representation of the user
    input (text) to change how it will appear in the prompt.
    """
    return text, visible_text

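# For illustration only (not part of this extension): an extension could
# append a hidden instruction to the internal prompt (text) while leaving
# the visible message untouched.
#
#     def chat_input_modifier(text, visible_text, state):
#         return text + '\nWrite a detailed reply.', visible_text
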
def input_modifier(string, state):
    """
    In default/notebook modes, modifies the whole prompt.

    In chat mode, it is the same as chat_input_modifier but only applied
    to "text", here called "string", and not to "visible_text".
    """
    return string

def bot_prefix_modifier(string, state):
    """
    Modifies the prefix for the next bot reply in chat mode.
    By default, the prefix will be something like "Bot Name:".
    """
    return string

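# A sketch of what this hook allows (illustrative, not used by this
# extension): extending the prefix so every reply opens with a fixed phrase.
#
#     def bot_prefix_modifier(string, state):
#         return string + ' Sure thing!'
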
def tokenizer_modifier(state, prompt, input_ids, input_embeds):
    """
    Modifies the input ids and embeds.
    Used by the multimodal extension to put image embeddings in the prompt.
    Only used by loaders that use the transformers library for sampling.
    """
    # Record the prompt length so that MyLogits.__call__ can tell how many
    # tokens of the reply have been generated so far.
    global initial_size
    initial_size = input_ids.shape[-1]

    return prompt, input_ids, input_embeds

def logits_processor_modifier(processor_list, input_ids):
    """
    Adds logits processors to the list, allowing you to access and modify
    the next token probabilities.
    Only used by loaders that use the transformers library for sampling.
    """
    processor_list.append(MyLogits())
    return processor_list

def output_modifier(string, state):
    """
    Modifies the LLM output before it gets presented.

    In chat mode, the modified version goes into history['visible'],
    and the original version goes into history['internal'].
    """
    return string

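# Illustrative sketch (not part of this extension): a visible-only
# substitution, since only history['visible'] receives the modified text.
#
#     def output_modifier(string, state):
#         return string.replace(':)', '😊')
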
def custom_generate_chat_prompt(user_input, state, **kwargs):
    """
    Replaces the function that generates the prompt from the chat history.
    Only used in chat mode.
    """
    result = chat.generate_chat_prompt(user_input, state, **kwargs)
    return result

def custom_css():
    """
    Returns a CSS string that gets appended to the CSS for the webui.
    """
    return ''

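# A minimal sketch of a custom_css() return value (the selector below is
# hypothetical; inspect the webui's HTML for the real element ids/classes):
#
#     return '.my-extension-note { color: gray; font-style: italic; }'
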
def custom_js():
    """
    Returns a javascript string that gets appended to the javascript
    for the webui.
    """
    return ''

def setup():
    """
    Gets executed only once, when the extension is imported.
    """
    pass

def ui():
    """
    Gets executed when the UI is drawn. Custom gradio elements and
    their corresponding event handlers should be defined here.

    To learn about gradio components, check out the docs:
    https://gradio.app/docs/
    """
    min_length = gr.Slider(0, 800, step=10, value=params['min_length'], label='Minimum reply length')
    min_length.change(lambda x: params.update({'min_length': x}), min_length, None)
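
# Additional controls would follow the same pattern as the slider above.
# A sketch (the "verbose" key is hypothetical, not one of this extension's
# params):
#
#     verbose = gr.Checkbox(value=False, label='Verbose logging')
#     verbose.change(lambda x: params.update({'verbose': x}), verbose, None)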
|