ST-THOMAS-OF-AQUINAS's picture
Update app.py
c6e66b7 verified
from fastapi import FastAPI, Form
from fastapi.responses import Response
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from twilio.twiml.messaging_response import MessagingResponse
import os
# -----------------------------
# Environment-safe cache path
# -----------------------------
# Use HF_HOME when the environment provides it; otherwise fall back to a
# directory under /tmp.
HF_CACHE_DIR = os.getenv("HF_HOME", "/tmp/hf_cache")
# -----------------------------
# Load model + tokenizer from Hugging Face
# -----------------------------
# Pick the inference device first: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_id = "ST-THOMAS-OF-AQUINAS/impersonation-bart"
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=HF_CACHE_DIR)
model = AutoModelForSequenceClassification.from_pretrained(model_id, cache_dir=HF_CACHE_DIR)

# Place the weights on the chosen device and switch to inference mode
# (disables dropout etc.); these two calls are order-independent.
model.to(device)
model.eval()
# -----------------------------
# Helper function
# -----------------------------
def predict_score(text: str, debug: bool = False):
    """Score ``text`` with the loaded sequence-classification model.

    The text is tokenized (truncated to 256 tokens), run through the model
    without gradient tracking, and the squeezed logits are passed through a
    sigmoid so the score lies in [0, 1]. ``.item()`` requires the squeezed
    logits to contain a single value.
    NOTE(review): this assumes a one-logit head; confirm against the model config.

    Args:
        text: The message to score.
        debug: When true, return the raw logit and the unrounded sigmoid
            score instead of the rounded score.

    Returns:
        The sigmoid score rounded to 3 decimals, or, when ``debug`` is true,
        a dict with keys ``"raw_logits"`` and ``"sigmoid_score"``.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=256,
    )

    # Move every tensor the tokenizer produced onto the model's device.
    batch = {}
    for name, tensor in encoded.items():
        batch[name] = tensor.to(device)

    with torch.no_grad():
        logits = model(**batch).logits.squeeze()

    # Sigmoid converts the raw logit into a value in [0, 1].
    probability = torch.sigmoid(logits).item()

    if not debug:
        return round(probability, 3)
    return {
        "raw_logits": logits.item(),
        "sigmoid_score": probability,
    }
# -----------------------------
# FastAPI app
# -----------------------------
# The single application instance; every route below registers on it.
app = FastAPI(
    title="Impersonation Detector st thomas of aquinas",
)
# Health-check route: answers without touching the model.
@app.get("/")
async def health_check():
    """Report that the service is up."""
    payload = {"status": "✅ API is running"}
    return payload
# Plain GET endpoint for quick manual checks, e.g. /predict?text=hello
@app.get("/predict")
async def get_predict(text: str):
    """Return the rounded impersonation score for the ``text`` query parameter."""
    return {"impersonation_score": predict_score(text)}
# Debugging route: surfaces the raw model output next to the sigmoid score.
@app.get("/debug")
async def debug_predict(text: str):
    """Return ``predict_score``'s debug payload for the ``text`` query parameter."""
    return {"debug_output": predict_score(text, debug=True)}
# -----------------------------
# Twilio WhatsApp POST
# -----------------------------
@app.post("/whatsapp")
async def whatsapp_reply(Body: str = Form("")):
    """Answer an incoming Twilio WhatsApp webhook with an impersonation score.

    Twilio posts the message text in the ``Body`` form field. FastAPI treats
    an empty form value as missing, so a required ``Form(...)`` rejected
    empty bodies (e.g. a media-only message) with a 422 and the sender got no
    reply. Defaulting to ``""`` routes those requests to the "no text" reply.

    Args:
        Body: The message text from the webhook form; empty when absent.

    Returns:
        A TwiML XML response containing a single reply message.
    """
    resp = MessagingResponse()
    if Body.strip():
        # NOTE(review): predict_score is synchronous model inference; calling
        # it inside an async route blocks the event loop until it finishes.
        score = predict_score(Body)
        reply = f"Impersonation Score: {score}\n(0.0 = genuine, 1.0 = impersonation)"
    else:
        reply = "⚠️ No text detected."
    resp.message(reply)
    # Twilio expects TwiML (XML), not FastAPI's default JSON body.
    return Response(content=str(resp), media_type="application/xml")