|
|
import os |
|
|
import requests |
|
|
from fastapi import FastAPI, HTTPException |
|
|
from pydantic import BaseModel |
|
|
|
|
|
app = FastAPI() |
|
|
|
|
|
|
|
|
HF_MODEL_URL = "https://router.huggingface.co/hf-inference/meta-llama/Meta-Llama-3-8B-Instruct" |
|
|
|
|
|
|
|
|
HF_API_TOKEN = os.environ.get("HF_API_TOKEN") |
|
|
|
|
|
|
|
|
class GenerateRequest(BaseModel): |
|
|
prompt: str |
|
|
max_new_tokens: int = 256 |
|
|
|
|
|
|
|
|
def build_prompt(user_prompt: str) -> str: |
|
|
""" |
|
|
「そこそこ賢いけど、会話が微妙に噛み合わない日本語アシスタント」 |
|
|
用のプロンプトを組み立てる |
|
|
""" |
|
|
return f"""[INST] |
|
|
あなたは、一見まともそうだが会話が少し噛み合わない日本語アシスタントです。 |
|
|
質問にはある程度ちゃんと答えますが、 |
|
|
・重要なポイントを微妙に外したり |
|
|
・ちょっと話題をズラしたり |
|
|
・余計な例え話を挟んだりしてください。 |
|
|
|
|
|
ただし、完全に的外れにはならず、質問と多少は関係のあることを話してください。 |
|
|
必ず日本語で、2〜4段落くらいで返答してください。 |
|
|
|
|
|
ユーザー: {user_prompt} |
|
|
アシスタント: |
|
|
[/INST] |
|
|
""" |
|
|
|
|
|
|
|
|
@app.post("/generate") |
|
|
def generate(req: GenerateRequest): |
|
|
if HF_API_TOKEN is None: |
|
|
|
|
|
raise HTTPException(status_code=500, detail="HF_API_TOKEN が環境変数に設定されていません。") |
|
|
|
|
|
headers = { |
|
|
"Authorization": f"Bearer {HF_API_TOKEN}", |
|
|
"Content-Type": "application/json", |
|
|
} |
|
|
|
|
|
payload = { |
|
|
"inputs": build_prompt(req.prompt), |
|
|
"parameters": { |
|
|
"max_new_tokens": req.max_new_tokens, |
|
|
"temperature": 1.2, |
|
|
"top_p": 0.9, |
|
|
}, |
|
|
} |
|
|
|
|
|
try: |
|
|
resp = requests.post(HF_MODEL_URL, headers=headers, json=payload, timeout=60) |
|
|
except Exception as e: |
|
|
raise HTTPException(status_code=500, detail=f"Inference API request error: {e}") |
|
|
|
|
|
if resp.status_code != 200: |
|
|
|
|
|
raise HTTPException(status_code=resp.status_code, detail=resp.text) |
|
|
|
|
|
data = resp.json() |
|
|
|
|
|
|
|
|
if isinstance(data, list) and len(data) > 0 and "generated_text" in data[0]: |
|
|
text = data[0]["generated_text"] |
|
|
else: |
|
|
raise HTTPException(status_code=500, detail=f"Unexpected response format: {data}") |
|
|
|
|
|
|
|
|
marker = "アシスタント:" |
|
|
if marker in text: |
|
|
text = text.split(marker, 1)[-1].strip() |
|
|
|
|
|
|