import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import spaces
import warnings

warnings.filterwarnings("ignore")

MODEL_NAME = "YUGOROU/Short-TeenEmo-Reasoning"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
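
# NOTE (assumption): on a Hugging Face ZeroGPU Space the `spaces` runtime is
# expected to report CUDA as available at startup, so the quantized GPU path
# below is taken; on plain CPU hardware the float32 fallback is used instead.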

# Default system prompt (Japanese). English gloss: "You are a compassionate
# counselor. Write your thought process between <think> and </think>."
DEFAULT_SYSTEM_PROMPT = """あなたは思いやりのあるカウンセラーです。思考の過程は<think>と</think>の間に記してください"""

# Populated once by load_model(); used by generate_response().
model = None
tokenizer = None


def load_model():
    """Load the tokenizer and model once at startup."""
    global model, tokenizer

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
            use_fast=True
        )

        if DEVICE == "cuda":
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
            )
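
            # With the NF4 config above, weights are stored at roughly half a
            # byte per parameter (double quantization also compresses the
            # quantization scales), while matmuls run in float16 via
            # bnb_4bit_compute_dtype. The exact VRAM saving is model-dependent;
            # treat this as an estimate.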
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                quantization_config=quantization_config,
                device_map="auto",
                trust_remote_code=True,
                torch_dtype=torch.float16,
                low_cpu_mem_usage=True,
            )
        else:
            # float16 is poorly supported on CPU; fall back to float32 there.
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                trust_remote_code=True,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
            )

        print(f"Model loaded successfully on {DEVICE}")
        return True

    except Exception as e:
        print(f"Error loading model: {e}")
        return False


print("Loading model...")
model_loaded = load_model()
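
# `spaces.GPU` marks the function below for Hugging Face ZeroGPU: each call is
# given a GPU slice for at most `duration` seconds (120 s here). Outside a
# Space the decorator should be a no-op, but that is an assumption about the
# `spaces` package rather than something this app verifies.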
@spaces.GPU(duration=120)
def generate_response(
    message,
    history,
    system_prompt=DEFAULT_SYSTEM_PROMPT,
    temperature=1.3,
    max_new_tokens=512,
    top_k=50,
    top_p=0.9,
    repetition_penalty=1.1
):
    """Generate a reply to `message`, conditioned on the chat history."""

    if not model_loaded:
        return "モデルの読み込みに失敗しました。"

    try:
        conversation = []

        # Optional system prompt
        if system_prompt.strip():
            conversation.append({"role": "system", "content": system_prompt})

        # Replay prior turns from the (user, assistant) tuples kept by gr.Chatbot
        for user_msg, assistant_msg in history:
            conversation.append({"role": "user", "content": user_msg})
            if assistant_msg:
                conversation.append({"role": "assistant", "content": assistant_msg})

        # Current user turn
        conversation.append({"role": "user", "content": message})
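
        # apply_chat_template renders the message list into the model's prompt
        # format; add_generation_prompt=True appends the assistant header so
        # the model continues as the assistant (and, per the system prompt,
        # may open with a <think> block).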
        input_text = tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True
        )

        inputs = tokenizer(input_text, return_tensors="pt").to(DEVICE)
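
        # Sampling setup: do_sample=True makes temperature/top_k/top_p take
        # effect; pad_token_id falls back to EOS because many causal LMs ship
        # without a dedicated pad token.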
        generation_kwargs = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_k": top_k,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "do_sample": True,
            "pad_token_id": tokenizer.eos_token_id,
            "eos_token_id": tokenizer.eos_token_id,
        }

        with torch.no_grad():
            outputs = model.generate(**generation_kwargs)
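
        # Decode only the newly generated tokens by slicing off the prompt.
        # Note: the reply may still contain the <think>...</think> reasoning
        # block requested by the system prompt; it is returned as-is.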
        response = tokenizer.decode(
            outputs[0][len(inputs["input_ids"][0]):],
            skip_special_tokens=True
        ).strip()

        return response

    except Exception as e:
        return f"エラーが発生しました: {str(e)}"


def clear_chat():
    """Clear the chat history and the input box."""
    return [], ""


with gr.Blocks(
    title="TeenEmo Reasoning v2 - 10代向けカウンセリングAI",
    theme=gr.themes.Soft(),
    css=""".gradio-container {max-width: 1000px; margin: auto;}"""
) as demo:

    gr.HTML("""
    <div style="text-align: center; padding: 20px;">
        <h1>🤖 TeenEmo Reasoning v2</h1>
        <h3>10代向けカウンセリングAIチャットボット</h3>
        <p>このAIは、10代の若者の感情的な悩みや課題に寄り添い、思いやりのあるサポートを提供するように設計されています。</p>
        <p><strong>機能:</strong> 感情理解、共感的対話、建設的なアドバイス</p>
        <p><em>※ このAIは補助的なツールです。深刻な問題については専門家にご相談ください。</em></p>
    </div>
    """)

    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                height=500,
                placeholder="👋 こんにちは!何でも気軽に話しかけてくださいね。あなたの気持ちに寄り添います。",
                show_label=False,
                container=True,
                bubble_full_width=False
            )

            with gr.Row():
                msg = gr.Textbox(
                    placeholder="メッセージを入力してください...",
                    container=False,
                    scale=4,
                    show_label=False
                )
                submit_btn = gr.Button("送信", variant="primary", scale=1)
                clear_btn = gr.Button("クリア", variant="secondary", scale=1)

        with gr.Column(scale=1, min_width=300):
            with gr.Accordion("⚙️ 高度な設定", open=False):
                system_prompt = gr.Textbox(
                    label="システムプロンプト",
                    value=DEFAULT_SYSTEM_PROMPT,
                    lines=4,
                    placeholder="AIの役割や行動指針を設定してください"
                )

                with gr.Group():
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=1.3,
                        step=0.1,
                        label="Temperature (創造性)",
                        info="高いほど創造的、低いほど一貫性のある応答"
                    )

                    max_new_tokens = gr.Slider(
                        minimum=50,
                        maximum=1024,
                        value=512,
                        step=50,
                        label="最大新規トークン数",
                        info="応答の最大長を制御"
                    )

                    top_k = gr.Slider(
                        minimum=1,
                        maximum=100,
                        value=50,
                        step=1,
                        label="Top K サンプリング",
                        info="選択候補の語彙数を制限"
                    )

                    top_p = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.9,
                        step=0.05,
                        label="Top P サンプリング",
                        info="累積確率でトークンを選択"
                    )

                    repetition_penalty = gr.Slider(
                        minimum=1.0,
                        maximum=2.0,
                        value=1.1,
                        step=0.05,
                        label="繰り返しペナルティ",
                        info="同じ表現の繰り返しを抑制"
                    )

                reset_settings = gr.Button("設定をリセット", variant="secondary")

    with gr.Row():
        gr.Examples(
            examples=[
                ["最近、学校で友達関係に悩んでいます。どうすればいいでしょうか?"],
                ["将来について不安で眠れません。"],
                ["親との関係がうまくいかなくて困っています。"],
                ["勉強のプレッシャーでストレスを感じています。"],
                ["自分に自信が持てません。どうしたら自信をつけられますか?"]
            ],
            inputs=msg,
            label="💬 サンプル質問(クリックして試してみてください)"
        )

    gr.HTML("""
    <div style="background-color: #f0f8ff; padding: 15px; border-radius: 10px; margin-top: 20px;">
        <h4>🔔 ご利用上の注意</h4>
        <ul>
            <li>このAIは研究・学習目的で開発されたプロトタイプです</li>
            <li>深刻な心理的問題や危機的状況では、専門家や信頼できる大人にご相談ください</li>
            <li>緊急時は適切な支援機関(いのちの電話: 0570-783-556など)にご連絡ください</li>
            <li>生成される応答の正確性は保証されません</li>
        </ul>
    </div>
    """)

    def respond(message, chat_history, system_prompt, temperature, max_new_tokens, top_k, top_p, repetition_penalty):
        """Event handler: generate a reply and append the turn to the history."""
        if not message.strip():
            return chat_history, ""

        bot_message = generate_response(
            message,
            chat_history,
            system_prompt,
            temperature,
            max_new_tokens,
            top_k,
            top_p,
            repetition_penalty
        )

        chat_history.append((message, bot_message))
        return chat_history, ""

    def reset_advanced_settings():
        """Restore the advanced settings to the defaults declared above."""
        return (
            DEFAULT_SYSTEM_PROMPT,
            1.3,   # temperature: match the slider and generate_response defaults
            512,
            50,
            0.9,
            1.1
        )

    # Wire up events: the button and Enter both submit; clear resets the chat.
    submit_btn.click(
        respond,
        inputs=[msg, chatbot, system_prompt, temperature, max_new_tokens, top_k, top_p, repetition_penalty],
        outputs=[chatbot, msg]
    )

    msg.submit(
        respond,
        inputs=[msg, chatbot, system_prompt, temperature, max_new_tokens, top_k, top_p, repetition_penalty],
        outputs=[chatbot, msg]
    )

    clear_btn.click(clear_chat, outputs=[chatbot, msg])

    reset_settings.click(
        reset_advanced_settings,
        outputs=[system_prompt, temperature, max_new_tokens, top_k, top_p, repetition_penalty]
    )
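

# queue() bounds how many requests can wait at once; under ZeroGPU the queue
# is also what serializes access to the shared GPU. max_size=10 is this app's
# own limit, not a Gradio default.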
if __name__ == "__main__":
    demo.queue(max_size=10).launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )