from fastapi import FastAPI from pydantic import BaseModel from transformers import AutoTokenizer, AutoModelForCausalLM import torch import re app = FastAPI() # ★ 1b を使用 MODEL_NAME = "cyberagent/open-calm-1b" device = "cuda" if torch.cuda.is_available() else "cpu" dtype = torch.float16 if torch.cuda.is_available() else torch.float32 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) model = AutoModelForCausalLM.from_pretrained( MODEL_NAME, torch_dtype=dtype, ) # pad_token が無い場合は EOS を代わりに使う if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token model.to(device) model.eval() class GenerateRequest(BaseModel): prompt: str # 文として成立する程度+速さ重視 max_new_tokens: int = 25 def build_prompt(user_prompt: str) -> str: """ open-calm-1b は chat_template が無いので、 素のテキストで「ちょいアホキャラ」を指定する。 """ system = ( "### 指示\n" "あなたは日本語で会話する、少しアホなアシスタントです。\n" "質問の意味はだいたい分かりますが、ときどき勘違いしたり、" "重要なところを少しズラした答え方をします。\n" "ただし、質問と完全に無関係な話や、意味不明な文章は避けてください。\n" "なるべく短く、1〜2文・40文字以内を目安に、フランクな口調で答えてください。\n\n" "### 会話\n" ) prompt = system + f"ユーザー: {user_prompt}\nアシスタント:" return prompt def shorten(text: str, max_chars: int = 40, max_sentences: int = 2) -> str: """ ・改行をスペースに ・先頭から最大 max_sentences 文をつなげる ・それでも長ければ max_chars でカット """ text = text.replace("\n", " ") # 「。!?!?」でざっくり文区切り parts = [p.strip() for p in re.split(r"[。!?!?]", text) if p.strip()] if not parts: result = text.strip() else: # 先頭から最大 max_sentences 文を拾う picked = parts[:max_sentences] result = "。".join(picked) # 文の終わりに「。」を付けておく if not result.endswith("。"): result += "。" # 文字数制限 if len(result) > max_chars: result = result[:max_chars] return result @app.post("/generate") def generate(req: GenerateRequest): prompt = build_prompt(req.prompt) inputs = tokenizer(prompt, return_tensors="pt").to(device) with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=req.max_new_tokens, do_sample=True, temperature=1.3, # ← ここで「ちょいアホ度」を調整(1.4〜1.8くらいで遊べる) top_p=0.9, no_repeat_ngram_size=3, repetition_penalty=1.0, pad_token_id=tokenizer.pad_token_id, ) # プロンプト部分を切り落として生成分だけを取り出す gen_ids = outputs[0][inputs["input_ids"].shape[-1]:] text = tokenizer.decode(gen_ids, skip_special_tokens=True).strip() # 先頭に「アシスタント:」が付いた場合だけ削る for kw in ["アシスタント:", "assistant:", "Assistant:"]: if text.startswith(kw): text = text[len(kw):].strip() # ★ 途中で出てくる「ユーザー:」「User:」以降を全部カット # ユーザー:: みたいなパターンもまとめて消す text = re.split(r"(ユーザー|User)\s*[::]+", text)[0].strip() # 短く整形 text = shorten(text, max_chars=40) # 念のための最終カット if len(text) > 60: text = text[:60] return {"text": text}