File size: 3,329 Bytes
85add87
 
 
bdfb62e
 
 
248aee5
bdfb62e
 
 
85add87
 
bdfb62e
 
 
 
85add87
 
bdfb62e
 
 
85add87
bdfb62e
 
 
 
 
 
 
85add87
 
bdfb62e
85add87
 
bdfb62e
 
 
 
 
 
 
 
 
 
 
85add87
248aee5
85add87
 
 
 
 
bdfb62e
 
 
 
 
 
 
 
 
 
 
 
85add87
bdfb62e
 
 
 
 
85add87
 
bdfb62e
 
 
 
 
 
85add87
bdfb62e
85add87
 
bdfb62e
85add87
bdfb62e
 
 
85add87
bdfb62e
85add87
bdfb62e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85add87
bdfb62e
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os

import torch
from fastapi import FastAPI, Form
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import Response
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from twilio.twiml.messaging_response import MessagingResponse

# -----------------------------
# Environment-safe cache path
# -----------------------------
# Hugging Face downloads are cached under $HF_HOME when that variable is set,
# otherwise under /tmp/hf_cache (presumably chosen because /tmp is writable on
# hosted containers — confirm for your deployment target).
HF_CACHE_DIR = os.getenv("HF_HOME", "/tmp/hf_cache")

# -----------------------------
# Load classification model from Hugging Face
# -----------------------------
# Everything below runs at import time, before the FastAPI app is created:
# the tokenizer and weights are fetched (or read from HF_CACHE_DIR) once, so
# individual requests do not pay the load cost.
# ⚠️ Change this to your classification model
# NOTE: the id must match the Hub repository name exactly (including its
# spelling), so do not "correct" it unless the repo itself is renamed.
model_id = "ST-THOMAS-OF-AQUINAS/DocumentVerifaction"

tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=HF_CACHE_DIR)
model = AutoModelForSequenceClassification.from_pretrained(model_id, cache_dir=HF_CACHE_DIR)
# Inference only: eval() disables training-time behavior such as dropout.
model.eval()

# Prefer a CUDA GPU when one is visible; otherwise run on the CPU. Inputs are
# moved to this same device before each forward pass.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# -----------------------------
# Define label mapping
# -----------------------------
# ⚠️ Update this mapping to match your model's labels.
# Maps each class index the classifier can predict to the human-readable
# name the API returns; position in the tuple below is the class index.
# NOTE(review): "Registral" looks like a typo for "Registrar", but it is a
# user-facing value — confirm with API consumers before changing it.
label_map = dict(enumerate(("Registral", "Dean of Students")))

# -----------------------------
# Helper function
# -----------------------------
def predict_classification(text: str, debug: bool = False) -> dict:
    """Classify ``text`` with the module-level tokenizer and model.

    The text is tokenized (truncated to at most 256 tokens), moved to
    ``device``, and run through ``model`` without gradient tracking. A softmax
    over the logits gives per-class probabilities, and the highest-probability
    class is reported as the prediction.

    Args:
        text: Raw input text to classify.
        debug: When True, return raw intermediate values instead of the
            normal response shape.

    Returns:
        When ``debug`` is False, a dict with:
            ``predicted_label``: name of the winning class.
            ``predicted_class_id``: its integer index.
            ``class_probabilities``: mapping of class name to probability.
        When ``debug`` is True, a dict with:
            ``logits``: nested list, one row per input.
            ``probs``: flat list of per-class probabilities.
            ``predicted_class_id`` and ``predicted_label``.

        Class ids missing from ``label_map`` are reported under their
        numeric id as a string, consistently in both the predicted label
        and ``class_probabilities``.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=256,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        # A single input string yields a batch of one row. Index that row
        # explicitly rather than calling squeeze(): squeeze() would also
        # collapse a 1-label output to a 0-d tensor, whose tolist() is a bare
        # float instead of a list.
        probs = torch.nn.functional.softmax(logits, dim=-1)[0]

    predicted_class_id = torch.argmax(probs).item()
    predicted_label = label_map.get(predicted_class_id, str(predicted_class_id))
    score_list = probs.tolist()

    if debug:
        return {
            "logits": logits.tolist(),
            "probs": score_list,
            "predicted_class_id": predicted_class_id,
            "predicted_label": predicted_label,
        }

    return {
        "predicted_label": predicted_label,
        "predicted_class_id": predicted_class_id,
        # Same fallback as predicted_label: a class the map does not know
        # is keyed by its id instead of raising KeyError.
        "class_probabilities": {
            label_map.get(class_id, str(class_id)): float(score)
            for class_id, score in enumerate(score_list)
        },
    }

# -----------------------------
# FastAPI app
# -----------------------------
# The single ASGI application; every route below is registered on it through
# decorators. The title is shown in the auto-generated OpenAPI docs.
app = FastAPI(title="Document Classification API")

# Health-check route
@app.get("/")
async def health_check():
    """Report that the API process is up.

    This does not touch the model, so it stays cheap to poll.
    """
    status_payload = {"status": "✅ API is running"}
    return status_payload

# Simple GET test
@app.get("/predict")
async def get_predict(text: str):
    """Classify the ``text`` query parameter and return the prediction dict.

    Tokenization and the model forward pass are synchronous. Running them
    directly in this ``async`` handler would block the event loop and stall
    every other request (including health checks) until inference finished,
    so the work is dispatched to the threadpool instead.
    """
    return await run_in_threadpool(predict_classification, text)

# Debugging route
@app.get("/debug")
async def debug_predict(text: str):
    """Classify ``text`` and return raw logits/probabilities for inspection.

    The response wraps ``predict_classification(text, debug=True)`` under the
    ``debug_output`` key. Inference is blocking, so it runs in the threadpool
    rather than on the event loop.
    """
    debug_data = await run_in_threadpool(predict_classification, text, debug=True)
    return {"debug_output": debug_data}

# -----------------------------
# Twilio WhatsApp POST
# -----------------------------
@app.post("/whatsapp")
async def whatsapp_reply(Body: str = Form(...)):
    """Answer an incoming WhatsApp message with a TwiML reply.

    The message text arrives in the ``Body`` form field. Non-blank text is
    classified, and the reply lists the predicted label and the per-class
    probabilities. Blank or whitespace-only text gets a warning reply instead.
    The response body is the serialized TwiML document.
    """
    resp = MessagingResponse()

    if Body.strip():
        # Inference is blocking; keep it off the event loop so concurrent
        # webhooks and other routes are not stalled.
        result = await run_in_threadpool(predict_classification, Body)
        reply = (
            f"Prediction: {result['predicted_label']}\n"
            f"Probabilities: {result['class_probabilities']}"
        )
    else:
        reply = "⚠️ No text detected."

    resp.message(reply)

    # Return proper TwiML XML
    return Response(content=str(resp), media_type="application/xml")