from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
import os
import uvicorn
import threading
app = FastAPI()
# Load model & tokenizer once at startup (module import time), so every
# request reuses the same in-memory weights.
# MODEL_NAME = "Qwen/Qwen1.5-1.8B-Chat"
MODEL_NAME = "Qwen/Qwen1.5-4B-Chat"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    # bfloat16 halves memory vs float32; CPU-only deployment per device_map.
    torch_dtype=torch.bfloat16,
    device_map="cpu",
    trust_remote_code=True,
)
# Enable the KV cache so incremental decoding reuses past attention states.
model.config.use_cache = True
# Fallback in case the tokenizer ships without a chat_template:
# a minimal "role: content" Jinja template ending with an assistant cue.
if not tokenizer.chat_template:
    tokenizer.chat_template = """{% for message in messages %}
{{ message['role'] }}: {{ message['content'] }}
{% endfor %}
assistant:"""
# Request schema
class Message(BaseModel):
    """A single chat turn: who is speaking and what was said."""

    # Speaker role, e.g. "system", "user", or "assistant".
    role: str
    # Text content of this turn.
    content: str
class ChatRequest(BaseModel):
    """Request body shared by the /stream and /chat endpoints."""

    # Conversation history, oldest turn first.
    messages: list[Message]
    # Upper bound on tokens generated for the reply.
    max_new_tokens: int = 128
# Generator untuk streaming token
def generate_stream(prompt, max_new_tokens=128):
    """Yield decoded text fragments for *prompt* as they are generated.

    Runs ``model.generate`` on a background thread and streams tokens
    through a :class:`TextIteratorStreamer`, so the HTTP response can be
    flushed incrementally instead of waiting for the full completion.

    Args:
        prompt: Fully formatted chat prompt string.
        max_new_tokens: Upper bound on the number of generated tokens.

    Yields:
        Text fragments with the prompt and special tokens stripped.
    """
    # tokenizer(...) already returns tensors; move them to the model's
    # device once here (the original moved them a second time redundantly).
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    # Collect any generation error so it can be reported to the client.
    # The original `except: pass` both hid the error AND never signalled
    # the streamer, leaving the consumer loop below blocked forever.
    errors = []

    def run_generation():
        # Background worker: runs generation and feeds the streamer.
        try:
            with torch.no_grad():  # inference only — skip autograd state
                model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    streamer=streamer,
                    use_cache=True,
                )
        except Exception as e:  # boundary: surface, don't swallow
            errors.append(str(e))
            # Push the end-of-stream marker so the iterator unblocks.
            streamer.end()

    thread = threading.Thread(target=run_generation, daemon=True)
    thread.start()
    for token in streamer:
        yield token
    # Report a failure after whatever partial output was produced.
    if errors:
        yield f"\n[generation error] {errors[0]}"
@app.post("/stream")
async def chat_stream(req: ChatRequest):
    """Stream a chat completion back to the client as plain text.

    Renamed from ``chat`` so it is not shadowed by the non-streaming
    ``/chat`` handler of the same name (the route itself is unchanged).
    """
    # apply_chat_template expects a list of plain dicts; the Jinja
    # template subscripts message['role'], which fails on pydantic
    # Message objects — dump them first (matches the /chat handler).
    text = tokenizer.apply_chat_template(
        [m.model_dump() for m in req.messages],
        tokenize=False,
        add_generation_prompt=True,
    )
    return StreamingResponse(
        generate_stream(text, req.max_new_tokens),
        media_type="text/plain",
    )
@app.post("/chat")
def chat(req: ChatRequest):
    """Return a complete (non-streaming) chat completion.

    Returns:
        ``{"response": <assistant reply text>}``.
    """
    text = tokenizer.apply_chat_template(
        [m.model_dump() for m in req.messages],
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    # Inference only: disable autograd bookkeeping to save memory.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=req.max_new_tokens,
            do_sample=True,
            top_p=0.9,
            temperature=0.7,
        )
    # Slice off the prompt tokens so only the newly generated reply
    # is decoded.
    response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:],
        skip_special_tokens=True,
    )
    return {"response": response}
@app.get("/")
def root():
    """Health-check endpoint confirming the service is up."""
    payload = {"message": "Qwen FastAPI running 🚀"}
    return payload
if __name__ == "__main__":
    # Hosting platforms (e.g. HF Spaces) inject PORT; default to 7860.
    serve_port = int(os.environ.get("PORT", 7860))
    uvicorn.run("app:app", host="0.0.0.0", port=serve_port)