|
|
from fastapi import FastAPI, Form |
|
|
from fastapi.responses import Response |
|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
import torch |
|
|
from twilio.twiml.messaging_response import MessagingResponse |
|
|
import os |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Writable cache directory for Hugging Face downloads; the HF_HOME env var
# overrides the /tmp default (useful on read-only container filesystems).
HF_CACHE_DIR = os.getenv("HF_HOME", "/tmp/hf_cache")

# Hugging Face Hub repository id of the fine-tuned classifier.
# NOTE(review): "Verifaction" looks like a typo, but this string must match
# the actual repo id on the Hub — do not "fix" it without checking there.
model_id = "ST-THOMAS-OF-AQUINAS/DocumentVerifaction"

# Load the tokenizer and model once at import time so every request reuses
# the same in-memory instances.
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=HF_CACHE_DIR)

model = AutoModelForSequenceClassification.from_pretrained(model_id, cache_dir=HF_CACHE_DIR)

# Put the model in inference mode (disables dropout / batch-norm updates).
model.eval()

# Prefer GPU when available; predict_classification moves inputs to the
# same device before the forward pass.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.to(device)

# Maps model class ids to human-readable labels returned to clients.
# NOTE(review): "Registral" is presumably meant to be "Registrar", but it is
# user-facing output — confirm with the service owner before changing.
label_map = {
    0: "Registral",
    1: "Dean of Students"
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict_classification(text: str, debug: bool = False):
    """Classify ``text`` with the module-level sequence-classification model.

    Args:
        text: Raw input text to classify (truncated to 256 tokens).
        debug: When True, return raw logits alongside the prediction.

    Returns:
        A dict with ``predicted_label``, ``predicted_class_id`` and either
        ``class_probabilities`` (default) or raw ``logits``/``probs``
        (when ``debug`` is True).
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=256
    )
    # Move every input tensor onto the same device as the model.
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        # squeeze() drops the batch dimension for the single input text.
        probs = torch.nn.functional.softmax(logits, dim=-1).squeeze()

    predicted_class_id = torch.argmax(probs).item()
    predicted_label = label_map.get(predicted_class_id, str(predicted_class_id))
    score_list = probs.tolist()

    if debug:
        return {
            "logits": logits.tolist(),
            "probs": score_list,
            "predicted_class_id": predicted_class_id,
            "predicted_label": predicted_label
        }

    return {
        "predicted_label": predicted_label,
        "predicted_class_id": predicted_class_id,
        "class_probabilities": {
            # Use .get with a str fallback (consistent with predicted_label
            # above) so an unexpected class id cannot raise KeyError here.
            label_map.get(i, str(i)): float(score)
            for i, score in enumerate(score_list)
        }
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# FastAPI application instance; all routes below are registered on it.
app = FastAPI(title="Document Classification API")
|
|
|
|
|
|
|
|
@app.get("/")
async def health_check():
    """Liveness probe: confirms the service process is up and responding."""
    payload = {"status": "✅ API is running"}
    return payload
|
|
|
|
|
|
|
|
@app.get("/predict")
async def get_predict(text: str):
    """Classify ``text`` and return the prediction payload as JSON."""
    return predict_classification(text)
|
|
|
|
|
|
|
|
@app.get("/debug")
async def debug_predict(text: str):
    """Classify ``text`` in debug mode (raw logits included) for inspection."""
    return {"debug_output": predict_classification(text, debug=True)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/whatsapp")
async def whatsapp_reply(Body: str = Form(...)):
    """Twilio WhatsApp webhook: classify the message body, answer as TwiML.

    ``Body`` is the form field name Twilio posts; its capitalization is part
    of the webhook contract and must not be renamed.
    """
    twiml = MessagingResponse()

    if not Body.strip():
        reply = "⚠️ No text detected."
    else:
        prediction = predict_classification(Body)
        reply = (
            f"Prediction: {prediction['predicted_label']}\n"
            f"Probabilities: {prediction['class_probabilities']}"
        )

    twiml.message(reply)

    # Twilio expects a TwiML (XML) response body.
    return Response(content=str(twiml), media_type="application/xml")
|
|
|