yuto0o
ストリーミング
06e3619
from threading import Thread
import torch
from django.http import StreamingHttpResponse
from ninja import NinjaAPI
from transformers import TextIteratorStreamer
from .model_loader import get_model
from .schemas import ChatInput, ChatOutput
# APIインスタンスの作成
api = NinjaAPI()
@api.post("/chat", response=ChatOutput)
def chat(request, data: ChatInput):
"""
Qwenモデルを使用したチャットAPI
"""
user_input = data.text # Schema経由で安全にアクセス
# モデルのロード(初回のみロードが走る)
model, tokenizer = get_model()
# 1. 会話フォーマットの作成
messages = [
{
"role": "system",
"content": "あなたは親切でフレンドリーなAIアシスタント「qwen」です。自然な日本語で簡潔に返事をしてください。",
},
{"role": "user", "content": user_input},
]
# 2. プロンプトへの変換
text = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
inputs = tokenizer([text], return_tensors="pt").to(model.device)
# 3. 生成
with torch.no_grad():
generated_ids = model.generate(
**inputs,
max_new_tokens=1024,
do_sample=True,
temperature=0.7,
top_p=0.9,
)
# 4. デコード
generated_ids = [
output_ids[len(input_ids) :]
for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
]
response_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
# ChatOutputスキーマに合わせてdictを返す
return {"result": response_text}
# ストリーミング用
@api.post("/chat/stream")
def chat_stream(request, data: ChatInput):
"""
Qwenモデルを使用したストリーミングチャットAPI
"""
user_input = data.text
model, tokenizer = get_model()
# 1. 会話フォーマットの作成
messages = [
{
"role": "system",
"content": "あなたは親切でフレンドリーなAIアシスタント「qwen」です。自然な日本語で簡潔に返事をしてください。",
},
{"role": "user", "content": user_input},
]
# 2. プロンプトへの変換
text = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
inputs = tokenizer([text], return_tensors="pt").to(model.device)
# 3. ストリーマーの準備
# skip_prompt=True にしないと、質問文も一緒に返ってきてしまいます
streamer = TextIteratorStreamer(
tokenizer, skip_prompt=True, skip_special_tokens=True
)
# generateに渡す引数を準備
generation_kwargs = dict(
inputs,
streamer=streamer,
max_new_tokens=1024,
do_sample=True,
temperature=0.7,
top_p=0.9,
)
# 4. 別スレッドで生成を開始
# model.generateはブロッキング処理なので、スレッドに逃がさないと
# ストリーミング(yield)が開始されません。
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
# 5. ジェネレーター関数の定義
def event_stream():
# streamerはイテレータとして動作し、新しいトークンが生成されるたびにループが回る
for new_text in streamer:
yield new_text
# StreamingHttpResponseにジェネレーターを渡して返す
return StreamingHttpResponse(event_stream(), content_type="text/event-stream")