# Source: Hugging Face Space by AlekhyaC2005 — "Create app.py" (commit a6c4281, verified)
import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
# -----------------------------
# Configuration
# -----------------------------
# Path to the fine-tuned checkpoint loaded at startup.
MODEL_PATH = "./checkpoint-3900"

# System prompt steering the model toward short, empathetic therapist replies.
# NOTE: the original source had this as bare (unquoted) text, which is a
# SyntaxError; it is reproduced here verbatim as an implicitly concatenated
# string literal. "12 medium-length sentences" may be a scrape-mangled
# "1-2 medium-length sentences" — TODO confirm against the training setup.
SYSTEM_PROMPT = (
    "You are a professional, empathetic therapist.\n"
    "Think carefully and reason internally, but do NOT explain your reasoning.\n"
    "Respond only with the final therapeutic message.\n"
    "Approach:\n"
    "• First acknowledge and validate the user’s emotions\n"
    "• Reflect patterns or meaning you notice\n"
    "• Offer guidance only when appropriate\n"
    "• Ask a question only if it genuinely helps progress the conversation\n"
    "Response rules:\n"
    "12 medium-length sentences\n"
    "• Calm, warm, non-judgmental\n"
    "• Natural and human, not instructional\n"
    "• Use the same language/style as the user\n"
    "Remain in the therapist role at all times."
)
# -----------------------------
# App
# -----------------------------
# FastAPI application object; an ASGI server (e.g. `uvicorn app:app`) imports
# this name to serve the endpoints defined below.
app = FastAPI(title="Therapeutic Chat Model API")
# -----------------------------
# Load model & tokenizer once
# -----------------------------
# Performed a single time at import so every request reuses the same weights.
# fp16 when a GPU is present, fp32 otherwise (fp16 on CPU is poorly supported).
_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=_dtype,
    device_map="auto",
    trust_remote_code=True,
)
# Inference mode: disables dropout and other train-time behaviors.
model.eval()
# -----------------------------
# Request / Response schema
# -----------------------------
class ChatRequest(BaseModel):
    """Request body for POST /generate: one user turn plus sampling knobs."""
    user_message: str           # the user's message for this turn
    max_new_tokens: int = 150   # cap on generated tokens
    temperature: float = 0.7    # sampling temperature
    top_p: float = 0.9          # nucleus-sampling threshold
class ChatResponse(BaseModel):
    """Response body for POST /generate: the generated therapist reply."""
    response: str  # decoded model output, prompt echo removed
# -----------------------------
# Inference endpoint
# -----------------------------
@app.post("/generate", response_model=ChatResponse)
def generate(req: ChatRequest):
    """Generate one therapist-style reply for the user's message.

    Builds a two-message chat (system prompt + user turn), samples a
    completion from the module-level model, and returns only the newly
    generated text.
    """
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": req.user_message},
    ]
    # add_generation_prompt=True appends the assistant header so the model
    # starts a fresh reply rather than continuing the user's turn.
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    inputs = inputs.to(model.device)

    # Fall back to EOS for padding when the tokenizer defines no pad token
    # (common for causal LMs); avoids undefined padding during generation.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    with torch.no_grad():
        output = model.generate(
            inputs,
            max_new_tokens=req.max_new_tokens,
            temperature=req.temperature,
            top_p=req.top_p,
            do_sample=True,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=pad_id,
        )

    # Decode only the tokens generated after the prompt. Slicing by prompt
    # length is robust; splitting the full decode on a literal "Assistant:"
    # marker fails whenever the chat template never emits that string, which
    # would leak the system prompt into the response.
    new_tokens = output[0][inputs.shape[-1]:]
    assistant_reply = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    return ChatResponse(response=assistant_reply)