# =========================================================
# INDICBERT MODEL — CATEGORY CLASSIFICATION (HINDI + TELUGU)
# =========================================================
import os
import re
import torch
import pickle
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# ── Path config ───────────────────────────────────────────
# Artifacts (label encoder) live in an "artifacts" dir next to this file.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ARTIFACT_DIR = os.path.join(BASE_DIR, "artifacts")
# Max token length used when encoding input text for the model.
MAX_LENGTH = 128

# ── Load tokenizer from HF Hub ────────────────────────────
tokenizer = AutoTokenizer.from_pretrained("mohanbot799s/civicconnect-bert-indic")

# Label encoder maps the model's class indices back to category names.
# NOTE(review): pickle.load executes arbitrary code from the file — fine for
# a bundled artifact, but confirm label_encoder.pkl is trusted/provenanced.
with open(os.path.join(ARTIFACT_DIR, "label_encoder.pkl"), "rb") as f:
    label_encoder = pickle.load(f)

# ── Load model from HF Hub ────────────────────────────────
model = AutoModelForSequenceClassification.from_pretrained(
    "mohanbot799s/civicconnect-bert-indic"
)
model.eval()  # inference mode: disables dropout / training-only behavior
# ── Edge-case constants ───────────────────────────────────
# Bare category labels: if the user types ONLY one of these words
# (e.g. "water"), there is no grievance text to classify.
# NOTE(review): the non-ASCII entries below are mojibake (Telugu UTF-8 that
# was decoded with the wrong codec at some point in this file's history) —
# restore the original Telugu strings, otherwise they can never match
# genuine user input. Kept byte-for-byte here pending that fix.
LABEL_WORDS = {
    "water", "electricity", "roads", "garbage",
    "sanitation", "pollution", "transport", "animals",
    "เฐชเฐพเฐจเฑ", "เฐฌเฐฟเฐเฐฒเฑ", "เฐธเฐกเฐ", "เฐเฐเฐฐเฐพ",
    "เฐจเฑเฐฐเฑ", "เฐตเฐฟเฐฆเฑเฐฏเฑเฐคเฑ", "เฐฐเฑเฐกเฑเฐกเฑ", "เฐเฑเฐคเฑเฐค",
}

# Greetings / small talk / test messages that are not grievances.
# Fix applied: the third Telugu entry was split across two physical lines
# (an unterminated string literal — a SyntaxError); rejoined into one literal.
NON_GRIEVANCE_PHRASES = {
    "hello", "hi", "good morning", "good evening",
    "thank you", "thanks", "all good", "no issues", "test", "demo",
    "เฐจเฐฎเฐธเฑเฐคเฑ", "เฐงเฐจเฑเฐฏเฐตเฐพเฐฆเฐพเฐฒเฑ",
    "เฐเฐจเฑเฐจเฑ เฐฌเฐพเฐเฑเฐจเฑเฐจเฐพเฐฏเฐฟ", "เฐธเฐฎเฐธเฑเฐฏ เฐฒเฑเฐฆเฑ",
}
# ── Text cleaning (Indic-safe) ────────────────────────────
def clean_text(text: str) -> str:
    """Normalise raw input text before tokenisation.

    Strips HTML-like tags, replaces every character outside Devanagari
    (U+0900–U+097F), Telugu (U+0C00–U+0C7F) and printable ASCII with a
    space, then collapses whitespace runs and trims the ends.
    """
    cleaned = str(text)
    substitutions = (
        (r"<.*?>", " "),                                        # markup tags
        (r"[^\u0900-\u097F\u0C00-\u0C7F\u0020-\u007F]", " "),   # non-Indic/ASCII
        (r"\s+", " "),                                          # squeeze whitespace
    )
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()
# ── Input validation ──────────────────────────────────────
def validate_input(text: str):
    """Return a rejection reason for unusable input, or None if usable.

    Possible reasons: "empty_text", "label_only", "non_grievance_text",
    "too_short", "too_few_words".

    Fix: the exact-match checks now run BEFORE the length/word-count
    checks. Previously "label_only" was unreachable (every LABEL_WORDS
    entry is a single word, so "too_few_words" always fired first) and
    short greetings like "hi" surfaced as "too_short" instead of
    "non_grievance_text".
    """
    if not text or not text.strip():
        return "empty_text"
    text_l = text.strip().lower()
    # Exact-match rejections first — most specific reasons win.
    if text_l in LABEL_WORDS:
        return "label_only"
    if text_l in NON_GRIEVANCE_PHRASES:
        return "non_grievance_text"
    # Generic length heuristics.
    if len(text_l) < 5:
        return "too_short"
    if len(text_l.split()) < 2:
        return "too_few_words"
    return None
# ── Predict ───────────────────────────────────────────────
def predict(
    text: str,
    input_ids=None,
    attention_mask=None,
) -> dict:
    """Classify a grievance text into a category.

    Parameters
    ----------
    text : str
        Raw user text. Always validated and cleaned, even when
        pre-tokenised tensors are supplied.
    input_ids, attention_mask : optional
        Pre-computed encodings; when ``input_ids`` is None the cleaned
        text is tokenised here (truncated to MAX_LENGTH).

    Returns
    -------
    dict with a uniform schema on every path:
    status, reason, category, confidence, class_index.
    (Fix: the high-confidence success return previously omitted "reason",
    giving callers an inconsistent dict shape; it now carries None.)
    """
    reason = validate_input(text)
    if reason:
        return {
            "status": "failed",
            "reason": reason,
            "category": None,
            "confidence": 0.0,
            "class_index": None,
        }

    cleaned = clean_text(text)
    if input_ids is None:
        enc = tokenizer(
            cleaned,
            return_tensors="pt",
            truncation=True,
            padding=False,  # single sequence — no padding needed
            max_length=MAX_LENGTH,
        )
        input_ids = enc["input_ids"]
        attention_mask = enc["attention_mask"]

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        probs = torch.softmax(outputs.logits, dim=1)
        conf, pred = torch.max(probs, dim=1)

    confidence = conf.item()
    predicted_index = pred.item()

    # Below this threshold the model is effectively guessing — report
    # the generic "Other" category instead of a specific label.
    if confidence < 0.30:
        return {
            "status": "success",
            "reason": "low_confidence",
            "category": "Other",
            "confidence": round(confidence, 4),
            "class_index": predicted_index,
        }

    label = label_encoder.inverse_transform([predicted_index])[0]
    return {
        "status": "success",
        "reason": None,  # present on every path for a uniform schema
        "category": label,
        "confidence": round(confidence, 4),
        "class_index": predicted_index,
    }
def get_model_and_tokenizer():
    """Expose the module-level model and tokenizer (e.g. for an API layer)."""
    return model, tokenizer