File size: 4,530 Bytes
8da2d54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6dfb8b7
 
8da2d54
 
 
 
6dfb8b7
8da2d54
6dfb8b7
8da2d54
 
 
 
 
 
 
6dfb8b7
8da2d54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6dfb8b7
 
8da2d54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
# =========================================================
# INDICBERT MODEL โ€” CATEGORY CLASSIFICATION (HINDI + TELUGU)
# =========================================================

import os
import re
import torch
import pickle
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# ── Path configuration ─────────────────────────────────────────
BASE_DIR     = os.path.dirname(os.path.abspath(__file__))  # directory containing this file
ARTIFACT_DIR = os.path.join(BASE_DIR, "artifacts")         # local artifacts (label encoder pickle)
MAX_LENGTH   = 128                                         # max token length used by predict()

# ── Tokenizer: downloaded from the Hugging Face Hub at import time ──
tokenizer = AutoTokenizer.from_pretrained("mohanbot799s/civicconnect-bert-indic")

# Label encoder: maps the model's class indices back to category names.
# NOTE(review): pickle.load is acceptable only because this artifact ships
# with the project — never unpickle untrusted files.
with open(os.path.join(ARTIFACT_DIR, "label_encoder.pkl"), "rb") as f:
    label_encoder = pickle.load(f)

# ── Classifier model: downloaded from the Hugging Face Hub ─────
model = AutoModelForSequenceClassification.from_pretrained(
    "mohanbot799s/civicconnect-bert-indic"
)
model.eval()  # inference only — disables dropout; this module never trains

# ── Edge-case constants used by validate_input() ───────────────
# Inputs that exactly equal one of these strings (after strip/lower) are
# rejected as a bare category label rather than a real grievance.
# NOTE(review): the non-ASCII entries below look mojibake-encoded (likely
# UTF-8 Telugu/Hindi misdecoded at some point); they are kept byte-identical
# here — verify against the original data source before "fixing" them.
LABEL_WORDS = {
    "water", "electricity", "roads", "garbage",
    "sanitation", "pollution", "transport", "animals",
    "เฐชเฐพเฐจเฑ€", "เฐฌเฐฟเฐœเฐฒเฑ€", "เฐธเฐกเฐ•", "เฐ•เฐšเฐฐเฐพ",
    "เฐจเฑ€เฐฐเฑ", "เฐตเฐฟเฐฆเฑเฐฏเฑเฐคเฑ", "เฐฐเฑ‹เฐกเฑเฐกเฑ", "เฐšเฑ†เฐคเฑเฐค",
}

# Greetings / test messages that are not grievances; validate_input()
# rejects exact (lowercased) matches. Same mojibake caveat as above.
NON_GRIEVANCE_PHRASES = {
    "hello", "hi", "good morning", "good evening",
    "thank you", "thanks", "all good", "no issues", "test", "demo",
    "เฐจเฐฎเฐธเฑเฐคเฑ‡", "เฐงเฐจเฑเฐฏเฐตเฐพเฐฆเฐพเฐฒเฑ", "เฐ…เฐจเฑเฐจเฑ€ เฐฌเฐพเฐ—เฑเฐจเฑเฐจเฐพเฐฏเฐฟ", "เฐธเฐฎเฐธเฑเฐฏ เฐฒเฑ‡เฐฆเฑ",
}


# ── Text cleaning (Indic-safe) ─────────────────────────────────
def clean_text(text: str) -> str:
    """Strip HTML tags, drop characters outside Devanagari/Telugu/printable
    ASCII, and collapse runs of whitespace into single spaces."""
    without_tags = re.sub(r"<.*?>", " ", str(text))
    # Keep only Devanagari (U+0900–U+097F), Telugu (U+0C00–U+0C7F),
    # and printable ASCII (U+0020–U+007F); everything else becomes a space.
    ascii_or_indic = re.sub(
        r"[^\u0900-\u097F\u0C00-\u0C7F\u0020-\u007F]", " ", without_tags
    )
    return re.sub(r"\s+", " ", ascii_or_indic).strip()


# ── Input validation ───────────────────────────────────────────
def validate_input(text: str):
    """Return a short rejection-reason string for unusable input, or None
    when the text looks like a real grievance worth classifying.

    Checks run in order; the first failing rule wins.
    """
    if not text or not text.strip():
        return "empty_text"
    normalized = text.strip().lower()
    # Lazy predicates: later rules (set membership) are only evaluated
    # when every earlier rule has passed.
    rules = (
        (lambda s: len(s) < 5, "too_short"),
        (lambda s: len(s.split()) < 2, "too_few_words"),
        (lambda s: s in LABEL_WORDS, "label_only"),
        (lambda s: s in NON_GRIEVANCE_PHRASES, "non_grievance_text"),
    )
    for failed, reason in rules:
        if failed(normalized):
            return reason
    return None


# ── Predict ────────────────────────────────────────────────────
def predict(
    text: str,
    input_ids=None,
    attention_mask=None,
) -> dict:
    """Classify grievance text into a civic category.

    Args:
        text: Raw user text; validated and cleaned before tokenization.
        input_ids / attention_mask: Optional pre-tokenized tensors; when
            input_ids is given, the internal tokenization step is skipped.

    Returns:
        A dict with "status", "category", "confidence", "class_index"
        (plus "reason" on validation failure or low-confidence results).
    """
    failure_reason = validate_input(text)
    if failure_reason:
        return {
            "status": "failed",
            "reason": failure_reason,
            "category": None,
            "confidence": 0.0,
            "class_index": None,
        }

    if input_ids is None:
        encoded = tokenizer(
            clean_text(text),
            return_tensors="pt",
            truncation=True,
            padding=False,
            max_length=MAX_LENGTH,
        )
        input_ids = encoded["input_ids"]
        attention_mask = encoded["attention_mask"]

    # Inference only — no gradients needed.
    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

    probabilities = torch.softmax(logits, dim=1)
    top_prob, top_class = torch.max(probabilities, dim=1)
    confidence = top_prob.item()
    class_index = top_class.item()

    # Below this threshold the prediction is not trusted; fall back to
    # the generic "Other" bucket instead of a specific category.
    if confidence < 0.30:
        return {
            "status": "success",
            "reason": "low_confidence",
            "category": "Other",
            "confidence": round(confidence, 4),
            "class_index": class_index,
        }

    category = label_encoder.inverse_transform([class_index])[0]
    return {
        "status": "success",
        "category": category,
        "confidence": round(confidence, 4),
        "class_index": class_index,
    }


def get_model_and_tokenizer():
    """Return the shared classifier and tokenizer as a (model, tokenizer) pair."""
    loaded_pair = (model, tokenizer)
    return loaded_pair