from fastapi import FastAPI, Form
from fastapi.responses import Response
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from twilio.twiml.messaging_response import MessagingResponse
import os

# -----------------------------
# Environment-safe cache path
# -----------------------------
# Honor HF_HOME when set (e.g. on managed hosts); fall back to a writable tmp dir.
HF_CACHE_DIR = os.getenv("HF_HOME", "/tmp/hf_cache")

# -----------------------------
# Load regression model from Hugging Face
# -----------------------------
model_id = "ST-THOMAS-OF-AQUINAS/impersonation-bart"
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=HF_CACHE_DIR)
model = AutoModelForSequenceClassification.from_pretrained(model_id, cache_dir=HF_CACHE_DIR)
model.eval()  # inference mode: disables dropout for deterministic scores

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


# -----------------------------
# Helper function
# -----------------------------
def predict_score(text: str, debug: bool = False, max_length: int = 256):
    """Score *text* for impersonation likelihood.

    Tokenizes the input, runs the sequence-classification model, and maps
    the raw logit through a sigmoid into a [0, 1] score.

    Parameters
    ----------
    text : str
        The message to score.
    debug : bool, optional
        When True, return a dict with the raw logit and the unrounded
        sigmoid score instead of the rounded score.
    max_length : int, optional
        Tokenizer truncation length. Defaults to 256 (the original
        hard-coded value), but can now be tuned per call.

    Returns
    -------
    float | dict
        Score rounded to 3 decimals, or a debug dict when ``debug=True``.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=max_length,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    raw_logits = outputs.logits.squeeze()
    # Sigmoid converts the raw logit into a probability-like [0, 1] value.
    # NOTE(review): .item() assumes the checkpoint has a single-logit
    # (regression-style) head — confirm against the model config.
    score = torch.sigmoid(raw_logits).item()

    if debug:
        return {
            "raw_logits": raw_logits.item(),
            "sigmoid_score": score,
        }
    return round(score, 3)


# -----------------------------
# FastAPI app
# -----------------------------
app = FastAPI(title="Impersonation Detector st thomas of aquinas")


# Health-check route
@app.get("/")
async def health_check():
    """Liveness probe: confirms the service is up."""
    return {"status": "✅ API is running"}


# Simple GET test
@app.get("/predict")
async def get_predict(text: str):
    """Score ``text`` and return the rounded impersonation score."""
    score = predict_score(text)
    return {"impersonation_score": score}


# Debugging route
@app.get("/debug")
async def debug_predict(text: str):
    """Return raw logit and unrounded sigmoid score for ``text``."""
    debug_data = predict_score(text, debug=True)
    return {"debug_output": debug_data}


# -----------------------------
# Twilio WhatsApp POST
# -----------------------------
@app.post("/whatsapp")
async def whatsapp_reply(Body: str = Form(...)):
    """Handle an inbound Twilio WhatsApp message.

    Reads the ``Body`` form field Twilio posts, scores it with the model,
    and answers with a TwiML XML document containing the score (or a
    warning when the message is blank).
    """
    resp = MessagingResponse()

    # Guard: an empty / whitespace-only message gets a warning reply.
    if not Body.strip():
        reply = "⚠️ No text detected."
    else:
        score = predict_score(Body)
        reply = f"Impersonation Score: {score}\n(0.0 = genuine, 1.0 = impersonation)"

    resp.message(reply)
    # Twilio requires the TwiML response to be served as XML.
    return Response(content=str(resp), media_type="application/xml")