from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
import uvicorn

app = FastAPI()

# Load the model & tokenizer once at startup
# MODEL_NAME = "Qwen/Qwen1.5-1.8B-Chat"
MODEL_NAME = "Qwen/Qwen1.5-4B-Chat"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,  # NB: float16 on CPU is slow and some CPU ops may not support it; float32 is the safer default
    device_map="cpu",
    trust_remote_code=True,
)
model.config.use_cache = True  # keep the KV cache enabled for faster autoregressive decoding

# Fallback in case chat_template is empty
if not tokenizer.chat_template:
    tokenizer.chat_template = """{% for message in messages %}
{{ message['role'] }}: {{ message['content'] }}
{% endfor %}
assistant:"""
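
# For reference: transformers renders chat templates with trim_blocks and
# lstrip_blocks enabled, so with this fallback a single user message
# [{"role": "user", "content": "Hi"}] renders roughly as:
#   user: Hi
#   assistant: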

# Request schema
class Message(BaseModel):
    role: str
    content: str

class ChatRequest(BaseModel):
    messages: list[Message]
    max_new_tokens: int = 128

@app.post("/chat")
def chat(req: ChatRequest):
    # Build the prompt from the chat history via the tokenizer's chat template
    text = tokenizer.apply_chat_template(
        [m.model_dump() for m in req.messages],  # model_dump() is the Pydantic v2 API (use .dict() on v1)
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(text, return_tensors="pt").to(model.device)  # tokenize and move tensors to the model's device

    # Sample a completion with nucleus sampling
    outputs = model.generate(
        **inputs,
        max_new_tokens=req.max_new_tokens,
        do_sample=True,
        top_p=0.9,
        temperature=0.7
    )

    # Decode only the newly generated tokens, slicing off the prompt
    response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:],
        skip_special_tokens=True,
    )

    return {"response": response}

@app.get("/")
def root():
    return {"message": "Qwen FastAPI running 🚀"}

if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))  # 7860 is the default port on Hugging Face Spaces
    uvicorn.run("app:app", host="0.0.0.0", port=port)