| import os | |
| import gradio as gr | |
| from text_generation import Client, InferenceAPIClient | |
def get_client(model: str):
    """Build an InferenceAPIClient for *model*.

    Authenticates with the HF_TOKEN environment variable when it is set;
    falls back to anonymous access otherwise.
    """
    hf_token = os.getenv("HF_TOKEN", None)
    return InferenceAPIClient(model, token=hf_token)
def get_usernames(model: str):
    """
    Returns:
        (str, str, str, str): pre-prompt, username, bot name, separator
    """
    # Models trained on the OpenAssistant special-token chat format.
    oasst_models = {
        "OpenAssistant/oasst-sft-1-pythia-12b",
        "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
    }
    if model in oasst_models:
        return "", "<|prompter|>", "<|assistant|>", "<|endoftext|>"
    # Generic plain-text chat format for every other model.
    return "", "User: ", "Assistant: ", "\n"
def predict(
    inputs: str,
):
    """Generate a single assistant reply for *inputs* and yield its text.

    Builds an OpenAssistant-format prompt around the user input, sends it
    to the hosted inference API, and yields the generated text once
    (generator shape kept for Gradio streaming compatibility).

    Args:
        inputs: the raw user message.

    Yields:
        str: the model's generated reply.
    """
    model = "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5"
    client = get_client(model)
    preprompt, user_name, assistant_name, sep = get_usernames(model)

    # BUG FIX: the original omitted user_name ("<|prompter|>"), producing a
    # malformed OASST prompt. The expected shape is:
    #   <|prompter|>{input}<|endoftext|><|assistant|>
    prompt = preprompt + user_name + inputs + sep + assistant_name.rstrip()

    # NOTE(review): the original gated this call on `model in (...oasst...)`,
    # but model is hardcoded to a member of that tuple, so the guard was
    # always true — and had it been false the generator would have silently
    # yielded nothing. The guard is removed; dead locals (past,
    # partial_words) are removed as well.
    response = client.generate(
        prompt,
        typical_p=0.1,
        truncate=1000,
        watermark=False,  # the API parameter is a bool; 0 relied on coercion
        max_new_tokens=500,
    )
    yield response.generated_text
# Wire the predict generator into a simple one-box-in / one-box-out UI.
g = gr.Interface(
    fn=predict,
    inputs=[
        gr.components.Textbox(lines=2, label="Input", placeholder="none"),
    ],
    outputs=[
        # BUG FIX: the original used gr.inputs.Textbox here — a deprecated
        # *input* class — for an output component. Outputs must come from
        # gr.components (or gr.outputs in legacy code).
        gr.components.Textbox(
            lines=5,
            label="Output",
        )
    ],
)
# Single-worker queue: the upstream inference API is rate-limited, so
# serialize requests rather than fan them out.
g.queue(concurrency_count=1)
g.launch()