File size: 3,671 Bytes
06e3619
 
9332d9b
06e3619
9332d9b
06e3619
9332d9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dd95ee1
9332d9b
 
 
 
 
 
 
 
 
 
 
 
 
 
06e3619
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from threading import Thread

import torch
from django.http import StreamingHttpResponse
from ninja import NinjaAPI
from transformers import TextIteratorStreamer

from .model_loader import get_model
from .schemas import ChatInput, ChatOutput

# APIインスタンスの作成
api = NinjaAPI()


@api.post("/chat", response=ChatOutput)
def chat(request, data: ChatInput):
    """
    Qwenモデルを使用したチャットAPI
    """
    user_input = data.text  # Schema経由で安全にアクセス

    # モデルのロード(初回のみロードが走る)
    model, tokenizer = get_model()

    # 1. 会話フォーマットの作成
    messages = [
        {
            "role": "system",
            "content": "あなたは親切でフレンドリーなAIアシスタント「qwen」です。自然な日本語で簡潔に返事をしてください。",
        },
        {"role": "user", "content": user_input},
    ]

    # 2. プロンプトへの変換
    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # 3. 生成
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )

    # 4. デコード
    generated_ids = [
        output_ids[len(input_ids) :]
        for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
    ]
    response_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # ChatOutputスキーマに合わせてdictを返す
    return {"result": response_text}


# ストリーミング用
@api.post("/chat/stream")
def chat_stream(request, data: ChatInput):
    """
    Qwenモデルを使用したストリーミングチャットAPI
    """
    user_input = data.text
    model, tokenizer = get_model()

    # 1. 会話フォーマットの作成
    messages = [
        {
            "role": "system",
            "content": "あなたは親切でフレンドリーなAIアシスタント「qwen」です。自然な日本語で簡潔に返事をしてください。",
        },
        {"role": "user", "content": user_input},
    ]

    # 2. プロンプトへの変換
    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # 3. ストリーマーの準備
    # skip_prompt=True にしないと、質問文も一緒に返ってきてしまいます
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    # generateに渡す引数を準備
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
    )

    # 4. 別スレッドで生成を開始
    # model.generateはブロッキング処理なので、スレッドに逃がさないと
    # ストリーミング(yield)が開始されません。
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # 5. ジェネレーター関数の定義
    def event_stream():
        # streamerはイテレータとして動作し、新しいトークンが生成されるたびにループが回る
        for new_text in streamer:
            yield new_text

    # StreamingHttpResponseにジェネレーターを渡して返す
    return StreamingHttpResponse(event_stream(), content_type="text/event-stream")