| import gradio as gr |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer |
| import torch |
|
|
| |
# Identifier of the checkpoint handed to ``from_pretrained`` below.
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"

# Keyword options used when loading the model weights: float16 tensors and
# ``device_map="auto"`` so placement is decided by the library.
_MODEL_LOAD_OPTIONS = {"torch_dtype": torch.float16, "device_map": "auto"}

# Both objects are created once at import time; ``chat_fn`` reads them as
# module-level globals.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **_MODEL_LOAD_OPTIONS)
|
|
def chat_fn(message, history):
    """Generate the assistant's reply to ``message`` given the prior turns.

    Args:
        message: The new user message to answer.
        history: Earlier turns as ``(user, assistant)`` pairs (tuples or
            two-item lists). ``None`` is treated as an empty history. Turns
            whose assistant slot is ``None`` (a reply that has not been
            produced yet) are left out of the prompt instead of being
            rendered as the text "None".

    Returns:
        A ``(response, history)`` tuple: the reply text, and the ``history``
        list with ``(message, response)`` appended to it in place.
    """
    if history is None:
        history = []

    # Build the transcript in one pass; only completed turns are included.
    segments = [
        f"<s>[ユーザー]: {user}\n[アシスタント]: {assistant}</s>\n"
        for user, assistant in history
        if assistant is not None
    ]
    segments.append(f"<s>[ユーザー]: {message}\n[アシスタント]:")
    prompt = "".join(segments)

    # The prompt already spells out <s>/</s> itself, so ask the tokenizer not
    # to add its own special tokens on top (avoids a duplicated leading BOS).
    inputs = tokenizer(
        prompt, return_tensors="pt", add_special_tokens=False
    ).to(model.device)

    output_ids = model.generate(
        **inputs,
        max_new_tokens=200,
        temperature=0.7,
        do_sample=True,
        top_p=0.9,
        # Explicit pad id so generate() does not have to guess one.
        pad_token_id=tokenizer.eos_token_id,
    )

    # generate() returns prompt + continuation; decode only the continuation
    # so the reply is never confused with text that was part of the prompt.
    prompt_length = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(
        output_ids[0][prompt_length:], skip_special_tokens=True
    )

    # If the model keeps going and starts writing the next user turn, keep
    # only what precedes that marker.
    response = response.partition("[ユーザー]:")[0].strip()

    history.append((message, response))
    return response, history
|
|
|
|
| |
with gr.Blocks() as demo:
    gr.Markdown("# 🦙💬 Simple Llama / Mistral Chatbot")
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Message")

    def user_send(user_message, chat_history):
        """Clear the textbox and show the user's message as a pending turn.

        Returns the new textbox value (always empty) and the chat history
        with ``[user_message, None]`` appended. Blank submissions leave the
        history untouched so no empty turn is sent to the model.
        """
        chat_history = chat_history or []
        if not user_message or not user_message.strip():
            return "", chat_history
        return "", chat_history + [[user_message, None]]

    def bot_reply(chat_history):
        """Fill in the assistant's answer for the pending turn, if any.

        The textbox has already been cleared by ``user_send`` when this step
        runs, so the user's message is read back from the last history row
        rather than from the textbox.
        """
        if not chat_history or chat_history[-1][1] is not None:
            # Nothing is waiting for an answer (e.g. a blank submission).
            return chat_history
        user_message = chat_history[-1][0]
        earlier_turns = chat_history[:-1]
        # chat_fn returns (response, history); only the history, which now
        # ends with the completed turn, is sent back to the Chatbot.
        _, updated_history = chat_fn(user_message, earlier_turns)
        return updated_history

    msg.submit(user_send, [msg, chatbot], [msg, chatbot]).then(
        bot_reply, [chatbot], [chatbot]
    )

demo.launch()
|
|