|
|
def respond(message, history): |
|
|
""" |
|
|
チャットボットの応答を生成する関数 |
|
|
historyはタプルのリスト形式: [(user_msg, bot_msg), ...] |
|
|
""" |
|
|
try: |
|
|
|
|
|
system_message = "あなたは親切で知識豊富な日本語アシスタントです。ユーザーの質問に丁寧に答えてください。" |
|
|
|
|
|
|
|
|
conversation = "" |
|
|
if system_message.strip(): |
|
|
conversation += f"システム: {system_message}\n" |
|
|
|
|
|
|
|
|
for user_msg, bot_msg in history: |
|
|
if user_msg: |
|
|
conversation += f"ユーザー: {user_msg}\n" |
|
|
if bot_msg: |
|
|
conversation += f"アシスタント: {bot_msg}\n"import gradio as gr |
|
|
import torch |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
import warnings |
|
|
warnings.filterwarnings("ignore") |
|
|
|
|
|
""" |
|
|
Sarashinaモデルを使用したGradioチャットボット |
|
|
Hugging Face Transformersライブラリを使用してローカルでモデルを実行 |
|
|
""" |
|
|
|
|
|
|
|
|
MODEL_NAME = "sbintuitions/sarashina2.2-3b-instruct-v0.1" |
|
|
|
|
|
print("モデルを読み込み中...") |
|
|
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) |
|
|
model = AutoModelForCausalLM.from_pretrained( |
|
|
MODEL_NAME, |
|
|
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, |
|
|
device_map="auto" if torch.cuda.is_available() else None, |
|
|
trust_remote_code=True |
|
|
) |
|
|
print("モデルの読み込みが完了しました。") |
|
|
|
|
|
def respond(message, history): |
|
|
""" |
|
|
チャットボットの応答を生成する関数 |
|
|
ChatInterfaceではtype="messages"でもhistoryの形式が異なる場合があります |
|
|
""" |
|
|
try: |
|
|
|
|
|
system_message = "あなたは親切で知識豊富な日本語アシスタントです。ユーザーの質問に丁寧に答えてください。" |
|
|
|
|
|
|
|
|
conversation = "" |
|
|
if system_message.strip(): |
|
|
conversation += f"システム: {system_message}\n" |
|
|
|
|
|
|
|
|
if history: |
|
|
for item in history: |
|
|
if isinstance(item, dict): |
|
|
|
|
|
if item.get("role") == "user": |
|
|
conversation += f"ユーザー: {item.get('content', '')}\n" |
|
|
elif item.get("role") == "assistant": |
|
|
conversation += f"アシスタント: {item.get('content', '')}\n" |
|
|
elif isinstance(item, (list, tuple)) and len(item) >= 2: |
|
|
|
|
|
user_msg, bot_msg = item[0], item[1] |
|
|
if user_msg: |
|
|
conversation += f"ユーザー: {user_msg}\n" |
|
|
if bot_msg: |
|
|
conversation += f"アシスタント: {bot_msg}\n" |
|
|
|
|
|
|
|
|
conversation += f"ユーザー: {message}\nアシスタント: " |
|
|
|
|
|
|
|
|
inputs = tokenizer.encode(conversation, return_tensors="pt") |
|
|
|
|
|
|
|
|
if torch.cuda.is_available(): |
|
|
inputs = inputs.cuda() |
|
|
|
|
|
|
|
|
response = "" |
|
|
with torch.no_grad(): |
|
|
|
|
|
outputs = model.generate( |
|
|
inputs, |
|
|
max_new_tokens=512, |
|
|
temperature=0.7, |
|
|
top_p=0.95, |
|
|
do_sample=True, |
|
|
pad_token_id=tokenizer.eos_token_id, |
|
|
eos_token_id=tokenizer.eos_token_id, |
|
|
repetition_penalty=1.1 |
|
|
) |
|
|
|
|
|
|
|
|
generated = tokenizer.decode(outputs[0], skip_special_tokens=True) |
|
|
|
|
|
|
|
|
full_response = generated[len(conversation):].strip() |
|
|
|
|
|
|
|
|
if "ユーザー:" in full_response: |
|
|
full_response = full_response.split("ユーザー:")[0].strip() |
|
|
|
|
|
|
|
|
for i in range(len(full_response)): |
|
|
response = full_response[:i+1] |
|
|
yield response |
|
|
|
|
|
except Exception as e: |
|
|
yield f"エラーが発生しました: {str(e)}" |
|
|
|
|
|
""" |
|
|
Gradio ChatInterfaceを使用したシンプルなチャットボット |
|
|
type="messages"を設定してOpenAI形式のメッセージを使用 |
|
|
""" |
|
|
demo = gr.ChatInterface( |
|
|
respond, |
|
|
|
|
|
title="🤖 Sarashina Chatbot", |
|
|
description="Sarashina2.2-3b-instruct モデルを使用した日本語チャットボットです。", |
|
|
theme=gr.themes.Soft(), |
|
|
examples=[ |
|
|
"こんにちは!今日はどんなことを話しましょうか?", |
|
|
"日本の文化について教えてください。", |
|
|
"簡単なレシピを教えてもらえますか?", |
|
|
"プログラミングについて質問があります。", |
|
|
], |
|
|
cache_examples=False, |
|
|
) |
|
|
|
|
|
if __name__ == "__main__": |
|
|
demo.launch( |
|
|
server_name="0.0.0.0", |
|
|
server_port=7860, |
|
|
share=False, |
|
|
show_api=True, |
|
|
debug=True |
|
|
) |