"""Document classification API: FastAPI service with a Twilio WhatsApp webhook."""
# Provenance: "Update app.py", commit bdfb62e (verified), ST-THOMAS-OF-AQUINAS's Space.
from fastapi import FastAPI, Form
from fastapi.responses import Response
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from twilio.twiml.messaging_response import MessagingResponse
import os
# -----------------------------
# Environment-safe cache path
# -----------------------------
HF_CACHE_DIR = os.getenv("HF_HOME", "/tmp/hf_cache")
# -----------------------------
# Load classification model from Hugging Face
# -----------------------------
# ⚠️ Change this to your classification model
model_id = "ST-THOMAS-OF-AQUINAS/DocumentVerifaction"
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=HF_CACHE_DIR)
model = AutoModelForSequenceClassification.from_pretrained(model_id, cache_dir=HF_CACHE_DIR)
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# -----------------------------
# Define label mapping
# -----------------------------
# ⚠️ Update this mapping to match your model's labels
label_map = {
0: "Registral",
1: "Dean of Students"
}
# -----------------------------
# Helper function
# -----------------------------
def predict_classification(text: str, debug: bool = False):
inputs = tokenizer(
text,
return_tensors="pt",
truncation=True,
padding=True,
max_length=256
)
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
probs = torch.nn.functional.softmax(logits, dim=-1).squeeze()
predicted_class_id = torch.argmax(probs).item()
predicted_label = label_map.get(predicted_class_id, str(predicted_class_id))
score_list = probs.tolist()
if debug:
return {
"logits": logits.tolist(),
"probs": score_list,
"predicted_class_id": predicted_class_id,
"predicted_label": predicted_label
}
return {
"predicted_label": predicted_label,
"predicted_class_id": predicted_class_id,
"class_probabilities": {
label_map[i]: float(score_list[i]) for i in range(len(score_list))
}
}
# -----------------------------
# FastAPI app
# -----------------------------
app = FastAPI(title="Document Classification API")
# Health-check route
@app.get("/")
async def health_check():
return {"status": "✅ API is running"}
# Simple GET test
@app.get("/predict")
async def get_predict(text: str):
result = predict_classification(text)
return result
# Debugging route
@app.get("/debug")
async def debug_predict(text: str):
debug_data = predict_classification(text, debug=True)
return {"debug_output": debug_data}
# -----------------------------
# Twilio WhatsApp POST
# -----------------------------
@app.post("/whatsapp")
async def whatsapp_reply(Body: str = Form(...)):
resp = MessagingResponse()
if Body.strip():
result = predict_classification(Body)
reply = (
f"Prediction: {result['predicted_label']}\n"
f"Probabilities: {result['class_probabilities']}"
)
else:
reply = "⚠️ No text detected."
resp.message(reply)
# Return proper TwiML XML
return Response(content=str(resp), media_type="application/xml")