from fastapi import FastAPI, Form
from fastapi.responses import Response
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from twilio.twiml.messaging_response import MessagingResponse
import os
# -----------------------------
# Environment-safe cache path
# -----------------------------
# Respect HF_HOME if set (e.g. on a managed host); otherwise fall back to a
# writable /tmp directory so model downloads don't hit a read-only home dir.
HF_CACHE_DIR = os.getenv("HF_HOME", "/tmp/hf_cache")
# -----------------------------
# Load classification model from Hugging Face
# -----------------------------
# ⚠️ Change this to your classification model
# NOTE: loading happens at import time — the first start downloads the model
# into HF_CACHE_DIR and may take a while / require network access.
# NOTE(review): "DocumentVerifaction" looks like a typo for "Verification",
# but it is the actual Hub repo id — do not "fix" it without renaming the repo.
model_id = "ST-THOMAS-OF-AQUINAS/DocumentVerifaction"
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=HF_CACHE_DIR)
model = AutoModelForSequenceClassification.from_pretrained(model_id, cache_dir=HF_CACHE_DIR)
# Inference mode: disables dropout / batch-norm updates.
model.eval()
# Prefer GPU when available; all input tensors are moved to this device too.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# -----------------------------
# Define label mapping
# -----------------------------
# ⚠️ Update this mapping to match your model's labels
# Maps the model's output class index -> human-readable department label.
# NOTE(review): "Registral" may be a typo for "Registrar" — confirm before
# changing, since downstream consumers may match on this exact string.
label_map = {
    0: "Registral",
    1: "Dean of Students"
}
# -----------------------------
# Helper function
# -----------------------------
def predict_classification(text: str, debug: bool = False) -> dict:
    """Classify *text* with the sequence-classification model.

    Parameters
    ----------
    text : str
        Raw input text; truncated/padded to at most 256 tokens.
    debug : bool, optional
        When True, return raw logits and probabilities instead of the
        user-facing payload.

    Returns
    -------
    dict
        Debug mode: ``logits``, ``probs``, ``predicted_class_id``,
        ``predicted_label``. Normal mode: ``predicted_label``,
        ``predicted_class_id`` and a ``class_probabilities`` mapping of
        label name -> probability.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=256
    )
    # Move every input tensor to the same device as the model.
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Inference only — skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        # squeeze() drops the batch dimension (batch size is always 1 here).
        probs = torch.nn.functional.softmax(logits, dim=-1).squeeze()

    predicted_class_id = torch.argmax(probs).item()
    # Fall back to the numeric id as a string when the model has more
    # classes than label_map knows about.
    predicted_label = label_map.get(predicted_class_id, str(predicted_class_id))
    score_list = probs.tolist()

    if debug:
        return {
            "logits": logits.tolist(),
            "probs": score_list,
            "predicted_class_id": predicted_class_id,
            "predicted_label": predicted_label
        }

    return {
        "predicted_label": predicted_label,
        "predicted_class_id": predicted_class_id,
        "class_probabilities": {
            # BUG FIX: use .get() so an index missing from label_map maps to
            # its numeric string instead of raising KeyError — consistent with
            # the predicted_label fallback above.
            label_map.get(i, str(i)): float(score_list[i])
            for i in range(len(score_list))
        }
    }
# -----------------------------
# FastAPI app
# -----------------------------
# Single application instance; all routes below register against it.
app = FastAPI(title="Document Classification API")
# Health-check route
@app.get("/")
async def health_check():
    """Liveness probe: confirms the service is up and responding."""
    status_payload = {"status": "✅ API is running"}
    return status_payload
# Simple GET test
@app.get("/predict")
async def get_predict(text: str):
    """Classify the text given as a query parameter and return the prediction."""
    return predict_classification(text)
# Debugging route
@app.get("/debug")
async def debug_predict(text: str):
    """Expose raw model output (logits + probabilities) for inspection."""
    return {"debug_output": predict_classification(text, debug=True)}
# -----------------------------
# Twilio WhatsApp POST
# -----------------------------
@app.post("/whatsapp")
async def whatsapp_reply(Body: str = Form(...)):
    """Handle an inbound Twilio WhatsApp message and answer with TwiML XML."""
    twiml = MessagingResponse()
    # Guard clause: an empty/whitespace-only message gets a canned warning.
    if not Body.strip():
        reply = "⚠️ No text detected."
    else:
        prediction = predict_classification(Body)
        reply = (
            f"Prediction: {prediction['predicted_label']}\n"
            f"Probabilities: {prediction['class_probabilities']}"
        )
    twiml.message(reply)
    # Twilio requires a TwiML XML document in the HTTP response body.
    return Response(content=str(twiml), media_type="application/xml")