"""BlenderBot-400M chatbot served over a small FastAPI application.

Endpoints:
    GET  /        -- liveness banner.
    POST /chat    -- JSON body ``{"message": "<text>"}`` -> ``{"reply": "<text>"}``.
    GET  /health  -- readiness probe.
"""

import os

# The Hugging Face libraries resolve their cache locations from the
# environment when they are first imported, so the cache directory must be
# configured *before* ``transformers`` is imported for these settings to
# take effect.
CACHE_DIR = "/tmp/hf_cache"
os.environ["HF_HOME"] = CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
os.makedirs(CACHE_DIR, exist_ok=True)

from fastapi import FastAPI, Request  # noqa: E402
from fastapi.concurrency import run_in_threadpool  # noqa: E402
from transformers import (  # noqa: E402
    BlenderbotForConditionalGeneration,
    BlenderbotTokenizer,
)

app = FastAPI()

MODEL_NAME = "facebook/blenderbot-400M-distill"

# Reply sent whenever the request does not carry a usable message.
INVALID_MESSAGE_REPLY = "Please send a valid message."

# Load the tokenizer and model once at import time so every request reuses
# the same instances.
tokenizer = BlenderbotTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
model = BlenderbotForConditionalGeneration.from_pretrained(
    MODEL_NAME, cache_dir=CACHE_DIR
)


def _extract_message(payload: object) -> str:
    """Return the stripped ``message`` string from a decoded JSON payload.

    Returns an empty string when the payload is not a JSON object or when
    ``message`` is missing or not a string, so the caller can treat every
    malformed request the same way as an empty message.
    """
    if not isinstance(payload, dict):
        return ""
    message = payload.get("message", "")
    if not isinstance(message, str):
        return ""
    return message.strip()


def _generate_reply(user_message: str) -> str:
    """Run one blocking tokenize -> generate -> decode round trip."""
    # Truncate to the model's position-embedding limit; without this an
    # over-long message fails inside the model instead of being answered.
    inputs = tokenizer(
        [user_message],
        return_tensors="pt",
        truncation=True,
        max_length=model.config.max_position_embeddings,
    )
    # NOTE(review): use_cache=False disables key/value caching while
    # decoding, which slows generation; presumably chosen to reduce memory
    # use -- confirm before changing.
    reply_ids = model.generate(
        **inputs,
        max_length=100,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        use_cache=False,
    )
    return tokenizer.decode(reply_ids[0], skip_special_tokens=True)


@app.get("/")
async def root():
    """Return a static banner confirming the API is up."""
    return {"message": "BlenderBot-400M Chatbot API is running!"}


@app.post("/chat")
async def chat(req: Request):
    """Generate a chatbot reply for the ``message`` field of the JSON body.

    A body that is not valid JSON, is not a JSON object, or whose
    ``message`` is missing, empty or not a string yields the fixed
    "Please send a valid message." reply instead of an unhandled error.
    """
    try:
        data = await req.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError both derive from
        # ValueError; a malformed body is handled like an empty message.
        return {"reply": INVALID_MESSAGE_REPLY}

    user_message = _extract_message(data)
    if not user_message:
        return {"reply": INVALID_MESSAGE_REPLY}

    # Generation is CPU-bound; run it in the thread pool so the event loop
    # (and therefore the other endpoints) stays responsive meanwhile.
    reply = await run_in_threadpool(_generate_reply, user_message)
    return {"reply": reply}


@app.get("/health")
async def health():
    """Readiness probe; always reports ready once the module has loaded."""
    return {"ready": True}