File size: 5,195 Bytes
6f78bf3
25d1b57
f8184cb
2b65d25
28de333
c2609dd
 
2b65d25
c2609dd
53ee96a
6f78bf3
28de333
809867d
66ca71d
28de333
f8184cb
28de333
25d1b57
4d9abbf
c2609dd
f8184cb
e4b129b
5383485
 
 
 
 
 
 
6f78bf3
28de333
5383485
 
 
 
f8184cb
5383485
f8184cb
6f78bf3
25d1b57
 
 
2b65d25
25d1b57
2b65d25
 
 
 
 
 
 
 
 
25d1b57
2b65d25
 
25d1b57
2b65d25
 
 
 
4c67cec
 
 
 
 
 
 
 
 
 
 
 
 
 
d696de5
 
 
 
 
 
 
6e0aaee
d696de5
 
 
 
6e0aaee
 
 
d696de5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77b42b3
d696de5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77b42b3
2b65d25
 
25d1b57
 
 
 
 
 
 
 
 
 
 
6f78bf3
f8184cb
e4b129b
25d1b57
e4b129b
 
 
 
4b3ff1b
28de333
 
 
 
 
 
 
 
5383485
7cd4b81
5383485
 
 
53ee96a
28de333
f8184cb
28de333
 
 
c2609dd
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
import os
import uvicorn
import threading

app = FastAPI()

# Load the model & tokenizer once at startup (module import time), so every
# request shares the same in-memory model.
# MODEL_NAME = "Qwen/Qwen1.5-1.8B-Chat"
MODEL_NAME = "Qwen/Qwen1.5-4B-Chat"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map="cpu",  # CPU-only deployment
    trust_remote_code=True,
)
model.config.use_cache = True  # enable the KV cache for faster autoregressive decoding

# Fallback in case the tokenizer ships without a chat template.
if not tokenizer.chat_template:
    tokenizer.chat_template = """{% for message in messages %}
{{ message['role'] }}: {{ message['content'] }}
{% endfor %}
assistant:"""

# Request schema
class Message(BaseModel):
    """One chat turn: a role (presumably "user"/"assistant"/"system" — values
    are passed straight into the chat template) and its text content."""
    role: str
    content: str

class ChatRequest(BaseModel):
    """Request body for /chat and /stream: full message history plus an
    optional generation budget."""
    messages: list[Message]
    max_new_tokens: int = 128  # upper bound on generated tokens per request

# Generator for streaming tokens
def generate_stream(prompt, max_new_tokens=128):
    """Yield decoded text fragments for *prompt* as the model produces them.

    Runs ``model.generate`` in a background daemon thread wired to a
    ``TextIteratorStreamer`` and yields each decoded chunk as it becomes
    available, so the HTTP response can stream token-by-token.

    Args:
        prompt: Fully formatted prompt string (chat template already applied).
        max_new_tokens: Maximum number of new tokens to generate.

    Yields:
        Text fragments with the prompt and special tokens stripped.
    """
    # Tokenize once and move tensors to the model's device in the same step.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    def run_generation():
        # If generate() raises, terminate the streamer in `finally` so the
        # consuming loop below does not block forever waiting for tokens; the
        # exception itself is reported by threading's default excepthook.
        # (generate() already calls streamer.end() on success; a second end()
        # is harmless.)
        try:
            model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                streamer=streamer,
                use_cache=True,
            )
        finally:
            streamer.end()

    thread = threading.Thread(target=run_generation, daemon=True)
    thread.start()

    # Relay decoded chunks to the caller as they arrive from the worker.
    for token in streamer:
        yield token


@app.post("/stream")
async def chat(req: ChatRequest):
    """Stream a generated reply as plain text, token by token.

    NOTE(review): this function shares the name `chat` with the /chat handler
    below; FastAPI registers routes at decoration time so both endpoints work,
    but the duplicate name is confusing — consider renaming one of them.
    """
    # Format the prompt with the model's chat template. The Pydantic models
    # must be converted to plain dicts first (chat templates index messages
    # as mappings), matching what the /chat endpoint does.
    text = tokenizer.apply_chat_template(
        [m.model_dump() for m in req.messages],
        tokenize=False,
        add_generation_prompt=True,
    )

    generator = generate_stream(text, req.max_new_tokens)
    return StreamingResponse(generator, media_type="text/plain")

@app.post("/chat")
def chat(req: ChatRequest):
    """Generate a complete (non-streaming) reply for the given chat history.

    Returns:
        A JSON object ``{"response": <generated text>}``.
    """
    # Pydantic models -> plain dicts for the chat template.
    text = tokenizer.apply_chat_template(
        [m.model_dump() for m in req.messages],
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(text, return_tensors="pt").to(model.device)

    outputs = model.generate(
        **inputs,
        max_new_tokens=req.max_new_tokens,
        do_sample=True,
        top_p=0.9,
        temperature=0.7,
    )

    # Decode only the newly generated tokens, skipping the echoed prompt.
    response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:],
        skip_special_tokens=True,
    )

    return {"response": response}

@app.get("/")
def root():
    """Liveness endpoint: confirms the service is up and responding."""
    status_message = "Qwen FastAPI running 🚀"
    return {"message": status_message}

if __name__ == "__main__":
    # Default to port 7860 (the Hugging Face Spaces convention) unless PORT
    # is set in the environment.
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run("app:app", host="0.0.0.0", port=port)