Idiot-llm / main.py
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import re
app = FastAPI()
# ★ Use the 1b model
MODEL_NAME = "cyberagent/open-calm-1b"
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=dtype,
)
# If the tokenizer has no pad_token, fall back to the EOS token
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model.to(device)
model.eval()
class GenerateRequest(BaseModel):
    prompt: str
    # Just enough tokens for a complete sentence, while keeping latency low
    max_new_tokens: int = 25
def build_prompt(user_prompt: str) -> str:
    """
    open-calm-1b has no chat_template, so the "slightly dumb character"
    persona is specified as plain text.
    """
    # Japanese system prompt; roughly: "You are a slightly dumb assistant who
    # chats in Japanese. You mostly understand questions, but sometimes
    # misread them or give answers that are slightly off on the key point.
    # Avoid irrelevant or incoherent text. Keep replies short (about 1-2
    # sentences / 40 characters) and casual."
    system = (
        "### 指示\n"
        "あなたは日本語で会話する、少しアホなアシスタントです。\n"
        "質問の意味はだいたい分かりますが、ときどき勘違いしたり、"
        "重要なところを少しズラした答え方をします。\n"
        "ただし、質問と完全に無関係な話や、意味不明な文章は避けてください。\n"
        "なるべく短く、1〜2文・40文字以内を目安に、フランクな口調で答えてください。\n\n"
        "### 会話\n"
    )
    prompt = system + f"ユーザー: {user_prompt}\nアシスタント:"
    return prompt
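# For illustration: build_prompt("好きな食べ物は？") (a hypothetical user
# question) yields the system text above followed by
# "ユーザー: 好きな食べ物は？\nアシスタント:", and the model is left to
# continue the text after "アシスタント:".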
def shorten(text: str, max_chars: int = 40, max_sentences: int = 2) -> str:
    """
    - Replace newlines with spaces.
    - Join at most max_sentences sentences from the start.
    - If the result is still too long, cut it at max_chars.
    """
    text = text.replace("\n", " ")
    # Roughly split into sentences on 。！？!?
    parts = [p.strip() for p in re.split(r"[。！？!?]", text) if p.strip()]
    if not parts:
        result = text.strip()
    else:
        # Pick at most max_sentences sentences from the start
        picked = parts[:max_sentences]
        result = "。".join(picked)
    # Make sure the reply ends with 「。」
    if not result.endswith("。"):
        result += "。"
    # Enforce the character limit
    if len(result) > max_chars:
        result = result[:max_chars]
    return result
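# Example, traced from the logic above (hypothetical input):
#   shorten("こんにちは。元気です。今日はいい天気。")
#   -> "こんにちは。元気です。"  (first two sentences, re-joined with 「。」)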
@app.post("/generate")
def generate(req: GenerateRequest):
    prompt = build_prompt(req.prompt)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=req.max_new_tokens,
            do_sample=True,
            temperature=1.3,  # <- tune the "dumbness level" here (try 1.4-1.8)
            top_p=0.9,
            no_repeat_ngram_size=3,
            repetition_penalty=1.0,
            pad_token_id=tokenizer.pad_token_id,
        )
    # Strip the prompt tokens and keep only the generated continuation
    gen_ids = outputs[0][inputs["input_ids"].shape[-1]:]
    text = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
    # Drop a leading "アシスタント:" / "assistant:" prefix if present
    for kw in ["アシスタント:", "assistant:", "Assistant:"]:
        if text.startswith(kw):
            text = text[len(kw):].strip()
    # ★ Cut everything from the first "ユーザー:" / "User:" onwards;
    # the [:：]+ also swallows doubled-colon patterns like "ユーザー::"
    text = re.split(r"(ユーザー|User)\s*[:：]+", text)[0].strip()
    # Trim to a short, well-formed reply
    text = shorten(text, max_chars=40)
    # Final safety cut, just in case
    if len(text) > 60:
        text = text[:60]
    return {"text": text}
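# A minimal way to exercise the endpoint locally (a sketch; assumes this file
# is saved as main.py and that uvicorn is installed):
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "今日の天気は？", "max_new_tokens": 25}'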