# flan-t5-chatbot / app.py
import os

# ✅ Cache dir. Set these env vars before importing transformers,
# since the library reads them at import time.
CACHE_DIR = "/tmp/hf_cache"
os.environ["HF_HOME"] = CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR  # legacy name, kept for older transformers versions
os.makedirs(CACHE_DIR, exist_ok=True)

from fastapi import FastAPI, Request
from transformers import BlenderbotTokenizer, BlenderbotForConditionalGeneration

# FastAPI app
app = FastAPI()

# ✅ Model name
MODEL_NAME = "facebook/blenderbot-400M-distill"

# Load tokenizer & model once at startup so each request skips the load cost
tokenizer = BlenderbotTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
model = BlenderbotForConditionalGeneration.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
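
# Optional, standard PyTorch step: the model is a torch nn.Module, so
# eval() disables dropout for deterministic inference behavior.
model.eval()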
@app.get("/")
async def root():
return {"message": "BlenderBot-400M Chatbot API is running!"}
@app.post("/chat")
async def chat(req: Request):
data = await req.json()
user_message = data.get("message", "").strip()
if not user_message:
return {"reply": "Please send a valid message."}
# Encode input
inputs = tokenizer([user_message], return_tensors="pt")
# Generate response
reply_ids = model.generate(
**inputs,
max_length=100,
do_sample=True,
temperature=0.7,
top_p=0.9,
use_cache=False
)
# Decode output
reply = tokenizer.decode(reply_ids[0], skip_special_tokens=True)
return {"reply": reply}
@app.get("/health")
async def health():
return {"ready": True}