| | import spaces |
| |
|
| | import os |
| | import gradio as gr |
| | import torch |
| | from threading import Thread |
| |
|
| | from typing import Union |
| | from pathlib import Path |
| | from transformers import ( |
| | AutoModelForCausalLM, |
| | AutoTokenizer, |
| | PreTrainedModel, |
| | PreTrainedTokenizer, |
| | PreTrainedTokenizerFast, |
| | StoppingCriteria, |
| | StoppingCriteriaList, |
| | TextIteratorStreamer |
| | ) |
| |
|
# Type aliases for the loader's return annotation.
# (``Union`` of a single member is just that member, so spell it directly.)
ModelType = PreTrainedModel
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]

# Model / tokenizer locations, overridable through environment variables.
# The tokenizer location falls back to the model location when unset.
MODEL_PATH = os.getenv('MODEL_PATH', 'THUDM/glm-4-9b-chat')
TOKENIZER_PATH = os.getenv("TOKENIZER_PATH", MODEL_PATH)
| |
|
| |
|
def _resolve_path(path: Union[str, Path]) -> Path:
    """Return *path* as an absolute ``Path`` with ``~`` expanded and symlinks resolved."""
    expanded = Path(path).expanduser()
    return expanded.resolve()
| |
|
| |
|
def load_model_and_tokenizer(
    model_dir: Union[str, Path],
    trust_remote_code: bool = True,
    tokenizer_dir: Union[str, Path, None] = None,
) -> tuple[ModelType, TokenizerType]:
    """Load a causal language model and its tokenizer via ``transformers``.

    Args:
        model_dir: Hub repo id or local directory to load the model from.
        trust_remote_code: Forwarded to both ``from_pretrained`` calls.
        tokenizer_dir: Where to load the tokenizer from. Defaults to
            ``model_dir`` so existing two-argument callers are unaffected.

    Returns:
        ``(model, tokenizer)``. The model is loaded with ``device_map='auto'``
        and the slow (non-fast) tokenizer is requested.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_dir, trust_remote_code=trust_remote_code, device_map='auto'
    )
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_dir if tokenizer_dir is not None else model_dir,
        trust_remote_code=trust_remote_code,
        use_fast=False,
    )
    return model, tokenizer


# Loaded once at import time; ``predict`` and ``StopOnTokens`` read these globals.
# TOKENIZER_PATH was previously read from the environment but never honoured.
model, tokenizer = load_model_and_tokenizer(
    MODEL_PATH, trust_remote_code=True, tokenizer_dir=TOKENIZER_PATH
)
| |
|
| |
|
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that fires when the newest token is a stop token.

    Only the first sequence of the batch (``input_ids[0]``) is inspected.
    """

    def __init__(self, stop_ids: Union[int, list[int], None] = None) -> None:
        """
        Args:
            stop_ids: Token id(s) that end generation. When omitted, the
                module-level ``model.config.eos_token_id`` is read on every call.
        """
        super().__init__()
        self.stop_ids = stop_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = self.stop_ids if self.stop_ids is not None else model.config.eos_token_id
        if stop_ids is None:
            return False
        # ``eos_token_id`` may be a single int rather than a list; iterating an
        # int would raise TypeError, so normalise it first.
        if isinstance(stop_ids, int):
            stop_ids = [stop_ids]
        return int(input_ids[0][-1]) in stop_ids
| |
|
| |
|
# Per-character escapes for lines inside fenced code blocks, so the chatbot
# renders them verbatim instead of interpreting HTML/Markdown.
# ``str.translate`` applies them in one pass, so escaping ``&`` cannot
# double-escape the entities produced for the other characters.
_CODE_ESCAPE_TABLE = str.maketrans({
    "`": "\\`",
    "&": "&amp;",
    "<": "&lt;",
    ">": "&gt;",
    " ": "&nbsp;",
    "*": "&#42;",
    "_": "&#95;",
    "-": "&#45;",
    ".": "&#46;",
    "!": "&#33;",
    "(": "&#40;",
    ")": "&#41;",
    "$": "&#36;",
})


def parse_text(text):
    """Convert chat text into the HTML fragment shown in the chatbot.

    Empty lines are dropped and every line after the first is prefixed with
    ``<br>``. A line containing a triple-backtick fence opens or closes a
    ``<pre><code class="language-…">`` block; the text after the fence's last
    backtick is used as the language. Lines inside a code block are
    HTML-escaped via ``_CODE_ESCAPE_TABLE``; lines outside are left as-is.

    Args:
        text: Raw message text.

    Returns:
        The joined HTML string (``""`` for empty input).
    """
    import html  # local import: only needed to escape the language attribute

    lines = [line for line in text.split("\n") if line != ""]
    in_code = False
    for i, line in enumerate(lines):
        if "```" in line:
            in_code = not in_code
            if in_code:
                # The language tag is user-controlled and lands inside an HTML
                # attribute, so escape quotes and angle brackets.
                lang = html.escape(line.split("`")[-1], quote=True)
                lines[i] = f'<pre><code class="language-{lang}">'
            else:
                lines[i] = '<br></code></pre>'
        elif i > 0:
            if in_code:
                line = line.translate(_CODE_ESCAPE_TABLE)
            lines[i] = "<br>" + line
    return "".join(lines)
| |
|
@spaces.GPU
def predict(history, max_length, top_p, temperature):
    """Stream the model's reply to the last user turn into ``history``.

    Args:
        history: Chatbot state as ``[user_msg, model_msg]`` pairs. The last
            pair is the pending turn; its ``model_msg`` is filled in as text
            streams back.
        max_length: Maximum number of new tokens (slider value; may arrive as
            a float and may be 0).
        top_p: Nucleus-sampling threshold.
        temperature: Sampling temperature.

    Yields:
        ``history`` after each non-empty streamed chunk has been appended.
    """
    stop = StopOnTokens()
    messages = []
    for idx, (user_msg, model_msg) in enumerate(history):
        # Final turn without a reply yet: send only its user message.
        if idx == len(history) - 1 and not model_msg:
            messages.append({"role": "user", "content": user_msg})
            break
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if model_msg:
            messages.append({"role": "assistant", "content": model_msg})

    model_inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
    ).to(next(model.parameters()).device)
    streamer = TextIteratorStreamer(tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": model_inputs,
        "streamer": streamer,
        # generate() validates max_new_tokens > 0 and expects an int. Otherwise
        # the worker thread dies and the streamer only times out after 60s.
        "max_new_tokens": max(1, int(max_length)),
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "stopping_criteria": StoppingCriteriaList([stop]),
        "repetition_penalty": 1.2,
        "eos_token_id": model.config.eos_token_id,
    }
    # Generation runs in a background thread while this generator drains the streamer.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()
    for new_token in streamer:
        if new_token:
            history[-1][1] += new_token
            yield history
    worker.join()
| |
|
| |
|
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">GLM-4-9B Gradio Simple Chat Demo</h1>""")
    chatbot = gr.Chatbot()

    with gr.Row():
        with gr.Column(scale=4):
            with gr.Column(scale=12):
                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10, container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            # Minimum of 1: zero new tokens is not a valid generation request.
            max_length = gr.Slider(1, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
            top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
            temperature = gr.Slider(0.01, 1, value=0.6, step=0.01, label="Temperature", interactive=True)

    def user(query, history):
        """Clear the textbox and append *query* as a new turn awaiting a reply."""
        # "Clear History" sets the chatbot value to None, so the history may
        # arrive as None. Treat that as an empty conversation.
        return "", (history or []) + [[parse_text(query), ""]]

    # First record the user turn (unqueued, instant), then stream the reply.
    submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
        predict, [chatbot, max_length, top_p, temperature], chatbot
    )
    emptyBtn.click(lambda: None, None, chatbot, queue=False)

demo.queue().launch()
| |
|
| |
|