File size: 1,362 Bytes
0774c8d
57abf81
1aeb97d
bcd596d
df73aaf
bcd596d
 
 
 
 
df73aaf
bcd596d
2c7bc87
2bfca34
1aeb97d
bcd596d
1aeb97d
 
 
3d0b576
f1943e1
 
1aeb97d
f1943e1
 
 
 
57abf81
bcd596d
57abf81
f1943e1
 
1aeb97d
 
bcd596d
5ba66ef
1aeb97d
bcd596d
1aeb97d
bcd596d
df73aaf
 
1aeb97d
bcd596d
 
1aeb97d
 
b63600b
df73aaf
cd480e5
 
 
df73aaf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import json
import os

from fastapi import FastAPI, Request
from transformers import BlenderbotTokenizer, BlenderbotForConditionalGeneration

# ✅ Cache dir
# Hugging Face downloads are redirected to /tmp so the service can run on
# read-only / ephemeral filesystems (e.g. serverless containers).
CACHE_DIR = "/tmp/hf_cache"
# NOTE(review): transformers reads HF_HOME/TRANSFORMERS_CACHE at import time,
# and the import above runs before these assignments — the env vars may be
# ineffective here. The explicit cache_dir= arguments below are what actually
# guarantee the cache location; confirm whether the env vars are still needed.
os.environ["HF_HOME"] = CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
os.makedirs(CACHE_DIR, exist_ok=True)

# FastAPI
app = FastAPI()

# ✅ Model name
MODEL_NAME = "facebook/blenderbot-400M-distill"

# Load tokenizer & model once
# Loaded at module import so every request reuses the same in-memory model
# (first startup downloads ~1.6 GB of weights into CACHE_DIR).
tokenizer = BlenderbotTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
model = BlenderbotForConditionalGeneration.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)

@app.get("/")
async def root():
    """Liveness message for the service root endpoint."""
    status_message = "BlenderBot-400M Chatbot API is running!"
    return {"message": status_message}

@app.post("/chat")
async def chat(req: Request):
    """Generate a chatbot reply for the posted JSON payload.

    Expects a JSON body like {"message": "..."} and returns {"reply": "..."}.
    Invalid JSON or a missing/non-string "message" yields a friendly fallback
    reply instead of an unhandled 500.
    """
    # A malformed body would otherwise raise out of req.json() as a 500.
    try:
        data = await req.json()
    except json.JSONDecodeError:
        return {"reply": "Please send a valid message."}

    # Guard the payload shape: the body may legally be a non-dict JSON value,
    # and "message" may be a non-string (number, list, null) with no .strip().
    raw_message = data.get("message", "") if isinstance(data, dict) else ""
    user_message = raw_message.strip() if isinstance(raw_message, str) else ""

    if not user_message:
        return {"reply": "Please send a valid message."}

    # Encode input as a single-item batch of PyTorch tensors.
    inputs = tokenizer([user_message], return_tensors="pt")

    # Generate response via nucleus sampling (so replies are non-deterministic).
    reply_ids = model.generate(
        **inputs,
        max_length=100,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        use_cache=False  # NOTE(review): disabling the KV cache slows decoding; confirm this is intentional
    )

    # Decode output, dropping BOS/EOS and padding tokens.
    reply = tokenizer.decode(reply_ids[0], skip_special_tokens=True)

    return {"reply": reply}

@app.get("/health")
async def health():
    """Readiness probe: always reports the service as ready."""
    status = {"ready": True}
    return status