File size: 5,276 Bytes
8da2d54
 
 
 
 
 
 
 
 
 
 
 
 
6dfb8b7
8da2d54
 
 
 
 
 
 
 
6dfb8b7
8da2d54
6dfb8b7
8da2d54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6dfb8b7
 
8da2d54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6dfb8b7
8da2d54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# =========================================================
# BERT MODEL β€” CATEGORY CLASSIFICATION (ENGLISH)
# =========================================================
# Module-level setup: loads tokenizer/label-encoder artifacts from disk and
# the fine-tuned BERT weights from the Hugging Face Hub at import time.

import os
import re
import torch
import pickle
from transformers import BertForSequenceClassification

# ── Path config ───────────────────────────────────────────
# Artifacts are resolved relative to this file so the module works
# regardless of the process's current working directory.
BASE_DIR     = os.path.dirname(os.path.abspath(__file__))
ARTIFACT_DIR = os.path.join(BASE_DIR, "artifacts")
MAX_LENGTH   = 128  # max token length passed to the tokenizer in predict()

# ── Load artifacts ────────────────────────────────────────
# NOTE(review): pickle.load executes arbitrary code from the file — these
# artifacts must be trusted/shipped with the app, never user-supplied.
with open(os.path.join(ARTIFACT_DIR, "tokenizer.pkl"), "rb") as f:
    tokenizer = pickle.load(f)

with open(os.path.join(ARTIFACT_DIR, "label_encoder.pkl"), "rb") as f:
    label_encoder = pickle.load(f)

# ── Load model from HF Hub ────────────────────────────────
# Downloads (or reads from the local HF cache) on first import; this makes
# importing the module a network-dependent operation — TODO confirm a cached
# copy is available in offline deployments.
model = BertForSequenceClassification.from_pretrained(
    "mohanbot799s/civicconnect-bert-en"
)
model.eval()  # inference mode: disables dropout

# ── Edge-case constants ───────────────────────────────────
# Bare category names: a message consisting only of one of these words is
# rejected by validate_input() as "label_only".
LABEL_WORDS = {
    "water", "electricity", "roads", "garbage",
    "sanitation", "pollution", "transport", "animals",
}

# Exact (lowercased, stripped) phrases that are greetings/chit-chat rather
# than grievances; matched whole-string in validate_input().
NON_GRIEVANCE_PHRASES = {
    "hello", "hi", "hi there", "hey", "hey there",
    "good morning", "good afternoon", "good evening", "good day",
    "greetings", "namaste", "how are you", "how are you doing",
    "hope you are doing well", "hope everything is fine",
    "just checking in", "nice to meet you", "long time no see",
    "good weather", "nice weather", "weather is nice", "weather is good",
    "it is a sunny day", "it is raining today", "pleasant weather",
    "cool weather today", "hot weather today", "cold weather today",
    "it is a good day", "everything is fine", "all good", "no issues",
    "no problem", "things are okay", "everything looks good",
    "nothing to complain", "all services are working",
    "thank you", "thanks", "thanks a lot", "thank you very much",
    "appreciate it", "appreciate your help", "great work", "good job",
    "well done", "excellent service", "for your information",
    "just informing", "sharing information", "today is a holiday",
    "office opens at 10 am", "school reopens next week",
    "meeting scheduled tomorrow", "okay", "ok", "alright", "fine",
    "cool", "great", "nice", "regards", "best regards", "with regards",
    "kind regards", "thank you and regards", "thank you very much sir",
    "test", "testing", "demo", "sample text", "random text",
    "πŸ™‚", "πŸ‘", "πŸ™", "πŸ˜‚", "πŸ”₯", "!!!", "???",
}


# ── Text cleaning ─────────────────────────────────────────
def clean_text(text: str) -> str:
    """Return *text* with HTML-like tags removed and whitespace collapsed."""
    without_tags = re.sub(r"<.*?>", " ", str(text))
    return re.sub(r"\s+", " ", without_tags).strip()


# ── Input validation ──────────────────────────────────────
def validate_input(text: str):
    """Return a rejection reason for *text*, or None if it is classifiable.

    Checks, in order: empty input, exact category-label word, exact
    non-grievance phrase, too short (<10 chars), too few words (<3).

    Bug fix: the membership checks now run BEFORE the length checks.
    Every entry in LABEL_WORDS is a single word (and most are <10 chars),
    so with the old order "too_short"/"too_few_words" always fired first
    and the "label_only" branch was unreachable; short greetings such as
    "hello" likewise never produced "non_grievance_text".
    """
    if not text or not text.strip():
        return "empty_text"
    text_l = text.strip().lower()
    if text_l in LABEL_WORDS:
        return "label_only"
    if text_l in NON_GRIEVANCE_PHRASES:
        return "non_grievance_text"
    if len(text_l) < 10:
        return "too_short"
    if len(text_l.split()) < 3:
        return "too_few_words"
    return None


# ── Predict ───────────────────────────────────────────────
def predict(
    text: str,
    input_ids=None,
    attention_mask=None,
) -> dict:
    """Classify *text* into a grievance category.

    Pre-tokenized tensors may be passed via ``input_ids`` /
    ``attention_mask``; otherwise the text is cleaned and tokenized here.
    Returns a dict with ``status``, ``category``, ``confidence``,
    ``class_index`` and, on rejection or low confidence, a ``reason``.
    """
    failure = validate_input(text)
    if failure:
        return {
            "status":      "failed",
            "reason":      failure,
            "category":    None,
            "confidence":  0.0,
            "class_index": None,
        }

    # Tokenize only when the caller did not supply tensors.
    if input_ids is None:
        encoded = tokenizer(
            clean_text(text),
            return_tensors="pt",
            truncation=True,
            padding=False,
            max_length=MAX_LENGTH,
        )
        input_ids = encoded["input_ids"]
        attention_mask = encoded["attention_mask"]

    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

    probabilities = torch.softmax(logits, dim=1)
    top_prob, top_idx = torch.max(probabilities, dim=1)
    score = top_prob.item()
    index = top_idx.item()

    # Below the confidence floor, fall back to the catch-all bucket.
    if score < 0.30:
        return {
            "status":      "success",
            "reason":      "low_confidence",
            "category":    "Other",
            "confidence":  round(score, 4),
            "class_index": index,
        }

    return {
        "status":      "success",
        "category":    label_encoder.inverse_transform([index])[0],
        "confidence":  round(score, 4),
        "class_index": index,
    }


def get_model_and_tokenizer():
    """Expose the module-level model and tokenizer (e.g. for warm-up or tests)."""
    return (model, tokenizer)