# BlenderBot-400M chatbot API served with FastAPI.
import os
from fastapi import FastAPI, Request
from transformers import BlenderbotTokenizer, BlenderbotForConditionalGeneration
# Writable cache dir for Hugging Face downloads (containers often have a
# read-only home dir; /tmp is writable). Env vars are set before the model
# download below so from_pretrained honors them.
CACHE_DIR = "/tmp/hf_cache"
os.environ["HF_HOME"] = CACHE_DIR
# NOTE(review): TRANSFORMERS_CACHE is deprecated in newer transformers in
# favor of HF_HOME; kept for compatibility with older versions.
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
os.makedirs(CACHE_DIR, exist_ok=True)
# FastAPI application instance used by the route decorators below.
app = FastAPI()
# ✅ Model name (hosted on the Hugging Face Hub).
MODEL_NAME = "facebook/blenderbot-400M-distill"
# Load tokenizer & model once at import time so every request reuses them;
# first start downloads the weights into CACHE_DIR.
tokenizer = BlenderbotTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
model = BlenderbotForConditionalGeneration.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
@app.get("/")
async def root():
    """Liveness endpoint: confirm the API process is up."""
    status_message = "BlenderBot-400M Chatbot API is running!"
    return {"message": status_message}
@app.post("/chat")
async def chat(req: Request):
    """Generate a single-turn chatbot reply.

    Expects a JSON body like {"message": "<text>"} and returns
    {"reply": "<text>"}. Malformed bodies and empty/non-string messages
    get a friendly prompt instead of an unhandled 500.
    """
    # A non-JSON body makes req.json() raise; treat it like a missing message
    # rather than surfacing a 500 to the client.
    try:
        data = await req.json()
    except Exception:
        data = {}

    # Coerce to str before .strip(): {"message": null} or a numeric payload
    # would otherwise raise AttributeError.
    user_message = str(data.get("message") or "").strip()
    if not user_message:
        return {"reply": "Please send a valid message."}

    # Encode input as a single-element batch of PyTorch tensors.
    inputs = tokenizer([user_message], return_tensors="pt")

    # Sampled decoding; max_length bounds the generated sequence length.
    reply_ids = model.generate(
        **inputs,
        max_length=100,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        # NOTE(review): use_cache=False disables the KV cache and slows
        # generation noticeably — confirm it is intentional (e.g. a memory
        # workaround) before removing.
        use_cache=False
    )

    # Decode the first (only) sequence, dropping special tokens.
    reply = tokenizer.decode(reply_ids[0], skip_special_tokens=True)
    return {"reply": reply}
@app.get("/health")
async def health():
    """Readiness probe: the model is loaded once this module imports."""
    return dict(ready=True)