from fastapi import FastAPI, Form
from fastapi.responses import Response
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from twilio.twiml.messaging_response import MessagingResponse
import os

# -----------------------------
# Environment-safe cache path
# -----------------------------
# Use HF_HOME when set (e.g. in Docker), otherwise fall back to a
# world-writable tmp dir so the model cache works in restricted runtimes.
HF_CACHE_DIR = os.getenv("HF_HOME", "/tmp/hf_cache")

# -----------------------------
# Load classification model from Hugging Face
# -----------------------------
# ⚠️ Change this to your classification model
model_id = "ST-THOMAS-OF-AQUINAS/DocumentVerifaction"

tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=HF_CACHE_DIR)
model = AutoModelForSequenceClassification.from_pretrained(
    model_id, cache_dir=HF_CACHE_DIR
)
model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# -----------------------------
# Define label mapping
# -----------------------------
# ⚠️ Update this mapping to match your model's labels
label_map = {
    0: "Registral",
    1: "Dean of Students"
}


# -----------------------------
# Helper function
# -----------------------------
def predict_classification(text: str, debug: bool = False) -> dict:
    """Classify `text` with the sequence-classification model.

    Args:
        text: Input document text to classify.
        debug: When True, include raw logits and the full probability
            vector in the result for inspection.

    Returns:
        When debug is False: a dict with `predicted_label`,
        `predicted_class_id`, and `class_probabilities` (label -> prob).
        When debug is True: a dict with `logits`, `probs`,
        `predicted_class_id`, and `predicted_label`.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=256
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Inference only — no gradients needed.
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        probs = torch.nn.functional.softmax(logits, dim=-1).squeeze()

    predicted_class_id = torch.argmax(probs).item()
    # Fall back to the raw class id if the mapping is incomplete.
    predicted_label = label_map.get(predicted_class_id, str(predicted_class_id))
    score_list = probs.tolist()

    if debug:
        return {
            "logits": logits.tolist(),
            "probs": score_list,
            "predicted_class_id": predicted_class_id,
            "predicted_label": predicted_label
        }

    return {
        "predicted_label": predicted_label,
        "predicted_class_id": predicted_class_id,
        # Use .get with a str(i) fallback (consistent with predicted_label
        # above) so an unmapped class id cannot raise KeyError here.
        "class_probabilities": {
            label_map.get(i, str(i)): float(score_list[i])
            for i in range(len(score_list))
        }
    }
# -----------------------------
# FastAPI app
# -----------------------------
app = FastAPI(title="Document Classification API")


# Health-check route
@app.get("/")
async def health_check():
    """Liveness probe: confirms the service is running."""
    return {"status": "✅ API is running"}


# Simple GET test
@app.get("/predict")
async def get_predict(text: str):
    """Classify `text` and return the predicted label with per-class
    probabilities."""
    result = predict_classification(text)
    return result


# Debugging route
@app.get("/debug")
async def debug_predict(text: str):
    """Return raw logits and the full probability vector for `text`."""
    debug_data = predict_classification(text, debug=True)
    return {"debug_output": debug_data}


# -----------------------------
# Twilio WhatsApp POST
# -----------------------------
@app.post("/whatsapp")
async def whatsapp_reply(Body: str = Form(...)):
    """Twilio WhatsApp webhook.

    Twilio POSTs the inbound message as the `Body` form field (the
    capitalized name is Twilio's contract — do not rename). Replies
    with TwiML XML containing the classification result.
    """
    resp = MessagingResponse()

    # Classify the stripped text so leading/trailing whitespace from the
    # messaging client doesn't leak into tokenization; previously the
    # unstripped Body was classified even though the guard used .strip().
    text = Body.strip()
    if text:
        result = predict_classification(text)
        reply = (
            f"Prediction: {result['predicted_label']}\n"
            f"Probabilities: {result['class_probabilities']}"
        )
    else:
        reply = "⚠️ No text detected."

    resp.message(reply)

    # Return proper TwiML XML — Twilio requires an XML content type.
    return Response(content=str(resp), media_type="application/xml")