File size: 5,874 Bytes
1166e57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2aa9275
 
 
 
a327f7b
 
2aa9275
 
a327f7b
 
2aa9275
 
 
 
 
 
 
 
 
 
 
 
a327f7b
465af28
2aa9275
 
1166e57
2aa9275
 
465af28
 
 
 
2aa9275
 
 
 
1166e57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2aa9275
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
465af28
 
 
2aa9275
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a327f7b
 
2aa9275
465af28
a327f7b
 
 
1166e57
2aa9275
 
 
 
465af28
 
 
 
2aa9275
 
a327f7b
 
 
2aa9275
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
def respond(message, history):
    """
    チャットボットの応答を生成する関数
    historyはタプルのリスト形式: [(user_msg, bot_msg), ...]
    """
    try:
        # システムメッセージとして使用するプロンプト
        system_message = "あなたは親切で知識豊富な日本語アシスタントです。ユーザーの質問に丁寧に答えてください。"
        
        # 会話履歴を文字列形式に変換
        conversation = ""
        if system_message.strip():
            conversation += f"システム: {system_message}\n"
        
        # 会話履歴を追加(タプル形式)
        for user_msg, bot_msg in history:
            if user_msg:
                conversation += f"ユーザー: {user_msg}\n"
            if bot_msg:
                conversation += f"アシスタント: {bot_msg}\n"import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import warnings
warnings.filterwarnings("ignore")

"""
Sarashinaモデルを使用したGradioチャットボット
Hugging Face Transformersライブラリを使用してローカルでモデルを実行
"""

# モデルとトークナイザーの初期化
MODEL_NAME = "sbintuitions/sarashina2.2-3b-instruct-v0.1"

print("モデルを読み込み中...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto" if torch.cuda.is_available() else None,
    trust_remote_code=True
)
print("モデルの読み込みが完了しました。")

def respond(message, history):
    """
    チャットボットの応答を生成する関数
    ChatInterfaceではtype="messages"でもhistoryの形式が異なる場合があります
    """
    try:
        # システムメッセージとして使用するプロンプト
        system_message = "あなたは親切で知識豊富な日本語アシスタントです。ユーザーの質問に丁寧に答えてください。"
        
        # 会話履歴を文字列形式に変換
        conversation = ""
        if system_message.strip():
            conversation += f"システム: {system_message}\n"
        
        # 会話履歴を追加(historyの形式を確認して処理)
        if history:
            for item in history:
                if isinstance(item, dict):
                    # messages形式の場合
                    if item.get("role") == "user":
                        conversation += f"ユーザー: {item.get('content', '')}\n"
                    elif item.get("role") == "assistant":
                        conversation += f"アシスタント: {item.get('content', '')}\n"
                elif isinstance(item, (list, tuple)) and len(item) >= 2:
                    # タプル形式の場合 (user_msg, bot_msg)
                    user_msg, bot_msg = item[0], item[1]
                    if user_msg:
                        conversation += f"ユーザー: {user_msg}\n"
                    if bot_msg:
                        conversation += f"アシスタント: {bot_msg}\n"
        
        # 現在のメッセージを追加
        conversation += f"ユーザー: {message}\nアシスタント: "
        
        # トークン化
        inputs = tokenizer.encode(conversation, return_tensors="pt")
        
        # GPU使用時はCUDAに移動
        if torch.cuda.is_available():
            inputs = inputs.cuda()
        
        # 応答生成(ストリーミング対応)
        response = ""
        with torch.no_grad():
            # 一度に生成してからストリーミング風に出力
            outputs = model.generate(
                inputs,
                max_new_tokens=512,  # デフォルト値を使用
                temperature=0.7,     # デフォルト値を使用
                top_p=0.95,          # デフォルト値を使用
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.1
            )
        
        # 生成されたテキストをデコード
        generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        # 応答部分のみを抽出
        full_response = generated[len(conversation):].strip()
        
        # 不要な部分を除去
        if "ユーザー:" in full_response:
            full_response = full_response.split("ユーザー:")[0].strip()
        
        # ストリーミング風の出力
        for i in range(len(full_response)):
            response = full_response[:i+1]
            yield response
            
    except Exception as e:
        yield f"エラーが発生しました: {str(e)}"

"""
Gradio ChatInterfaceを使用したシンプルなチャットボット
type="messages"を設定してOpenAI形式のメッセージを使用
"""
demo = gr.ChatInterface(
    respond,
    # type="messages"を削除 - デフォルトのタプル形式を使用
    title="🤖 Sarashina Chatbot",
    description="Sarashina2.2-3b-instruct モデルを使用した日本語チャットボットです。",
    theme=gr.themes.Soft(),
    examples=[
        "こんにちは!今日はどんなことを話しましょうか?",
        "日本の文化について教えてください。",
        "簡単なレシピを教えてもらえますか?",
        "プログラミングについて質問があります。",
    ],
    cache_examples=False,
)

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_api=True,  # API documentation を表示
        debug=True
    )