from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import re

app = FastAPI()

MODEL_NAME = "cyberagent/open-calm-1b"

# Use fp16 on GPU for speed/memory; fall back to fp32 on CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=dtype,
)

# Some tokenizers define no pad token; reuse EOS so generate() can pad
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model.to(device)
model.eval()


class GenerateRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 25
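
# Example request body: {"prompt": "こんにちは", "max_new_tokens": 25}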


def build_prompt(user_prompt: str) -> str:
    """
    open-calm-1b has no chat_template, so the "slightly clueless"
    persona is spelled out in a plain-text prompt.
    """
    # The Japanese system prompt below says, roughly: "You are a slightly
    # clueless assistant chatting in Japanese. You mostly understand the
    # question, but sometimes misread it or give an answer that is a bit
    # off the mark. Still, avoid replies that are completely unrelated to
    # the question or incoherent. Answer briefly (1-2 sentences, about 40
    # characters) in a casual tone."
    system = (
        "### 指示\n"
        "あなたは日本語で会話する少しアホなアシスタントです。\n"
        "質問の意味はだいたい分かりますが、ときどき勘違いしたり、"
        "重要なところを少しズラした答え方をします。\n"
        "ただし、質問と完全に無関係な話や、意味不明な文章は避けてください。\n"
        "なるべく短く、1〜2文・40文字以内を目安に、フランクな口調で答えてください。\n\n"
        "### 会話\n"
    )
    prompt = system + f"ユーザー: {user_prompt}\nアシスタント:"
    return prompt
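
# The prompt ends with "アシスタント:" so the model continues as the
# assistant; build_prompt("こんにちは") therefore ends with:
#   ユーザー: こんにちは
#   アシスタント: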


def shorten(text: str, max_chars: int = 40, max_sentences: int = 2) -> str:
    """
    - Replace newlines with spaces.
    - Keep at most max_sentences sentences from the start.
    - If the result is still too long, cut it at max_chars.
    """
    text = text.replace("\n", " ")

    # Split on Japanese and ASCII sentence-ending punctuation
    parts = [p.strip() for p in re.split(r"[。！？!?]", text) if p.strip()]

    if not parts:
        result = text.strip()
    else:
        picked = parts[:max_sentences]
        result = "。".join(picked)
        if not result.endswith("。"):
            result += "。"

    if len(result) > max_chars:
        result = result[:max_chars]

    return result
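
# e.g. shorten("こんにちは。元気です。またね。") -> "こんにちは。元気です。"
# (first two sentences kept, terminal "。" restored, 40-char cap applied)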


@app.post("/generate")
def generate(req: GenerateRequest):
    prompt = build_prompt(req.prompt)

    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=req.max_new_tokens,
            do_sample=True,
            temperature=1.3,
            top_p=0.9,
            no_repeat_ngram_size=3,
            repetition_penalty=1.0,  # neutral; repetition is curbed by no_repeat_ngram_size
            pad_token_id=tokenizer.pad_token_id,
        )

    # Decode only the newly generated tokens, not the echoed prompt
    gen_ids = outputs[0][inputs["input_ids"].shape[-1]:]
    text = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()

    # Strip a leading role label if the model repeated one
    for kw in ["アシスタント:", "assistant:", "Assistant:"]:
        if text.startswith(kw):
            text = text[len(kw):].strip()

    # Drop everything after the model starts an imagined next user turn
    text = re.split(r"(ユーザー|User)\s*[:：]+", text)[0].strip()

    text = shorten(text, max_chars=40)

    # Redundant after shorten(max_chars=40), but kept as a final hard cap
    if len(text) > 60:
        text = text[:60]

    return {"text": text}
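
# Quick local test (assuming this file is saved as app.py; adjust the
# module path to your filename):
#   uvicorn app:app --port 8000
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "好きな食べ物は？"}'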