yuto0o commited on
Commit
06e3619
·
1 Parent(s): dd95ee1

ストリーミング

Browse files
Files changed (2) hide show
  1. .gitattributes copy +0 -35
  2. ml_api/api.py +60 -0
.gitattributes copy DELETED
@@ -1,35 +0,0 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ml_api/api.py CHANGED
@@ -1,5 +1,9 @@
 
 
1
  import torch
 
2
  from ninja import NinjaAPI
 
3
 
4
  from .model_loader import get_model
5
  from .schemas import ChatInput, ChatOutput
@@ -53,3 +57,59 @@ def chat(request, data: ChatInput):
53
 
54
  # ChatOutputスキーマに合わせてdictを返す
55
  return {"result": response_text}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from threading import Thread
2
+
3
  import torch
4
+ from django.http import StreamingHttpResponse
5
  from ninja import NinjaAPI
6
+ from transformers import TextIteratorStreamer
7
 
8
  from .model_loader import get_model
9
  from .schemas import ChatInput, ChatOutput
 
57
 
58
  # ChatOutputスキーマに合わせてdictを返す
59
  return {"result": response_text}
60
+
61
+
62
+ # ストリーミング用
63
+ @api.post("/chat/stream")
64
+ def chat_stream(request, data: ChatInput):
65
+ """
66
+ Qwenモデルを使用したストリーミングチャットAPI
67
+ """
68
+ user_input = data.text
69
+ model, tokenizer = get_model()
70
+
71
+ # 1. 会話フォーマットの作成
72
+ messages = [
73
+ {
74
+ "role": "system",
75
+ "content": "あなたは親切でフレンドリーなAIアシスタント「qwen」です。自然な日本語で簡潔に返事をしてください。",
76
+ },
77
+ {"role": "user", "content": user_input},
78
+ ]
79
+
80
+ # 2. プロンプトへの変換
81
+ text = tokenizer.apply_chat_template(
82
+ messages, tokenize=False, add_generation_prompt=True
83
+ )
84
+ inputs = tokenizer([text], return_tensors="pt").to(model.device)
85
+
86
+ # 3. ストリーマーの準備
87
+ # skip_prompt=True にしないと、質問文も一緒に返ってきてしまいます
88
+ streamer = TextIteratorStreamer(
89
+ tokenizer, skip_prompt=True, skip_special_tokens=True
90
+ )
91
+
92
+ # generateに渡す引数を準備
93
+ generation_kwargs = dict(
94
+ inputs,
95
+ streamer=streamer,
96
+ max_new_tokens=1024,
97
+ do_sample=True,
98
+ temperature=0.7,
99
+ top_p=0.9,
100
+ )
101
+
102
+ # 4. 別スレッドで生成を開始
103
+ # model.generateはブロッキング処理なので、スレッドに逃がさないと
104
+ # ストリーミング(yield)が開始されません。
105
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
106
+ thread.start()
107
+
108
+ # 5. ジェネレーター関数の定義
109
+ def event_stream():
110
+ # streamerはイテレータとして動作し、新しいトークンが生成されるたびにループが回る
111
+ for new_text in streamer:
112
+ yield new_text
113
+
114
+ # StreamingHttpResponseにジェネレーターを渡して返す
115
+ return StreamingHttpResponse(event_stream(), content_type="text/event-stream")