# =========================================================
# INDICBERT MODEL — CATEGORY CLASSIFICATION (HINDI + TELUGU)
# =========================================================
"""Category classification for civic grievance text using a fine-tuned BERT.

Importing this module loads the tokenizer and model from the Hugging Face Hub
and the label encoder from the local ``artifacts`` directory, so the import
itself performs network and disk I/O.
"""

import os
import pickle
import re
from typing import Optional

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# ── Path config ───────────────────────────────────────────
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ARTIFACT_DIR = os.path.join(BASE_DIR, "artifacts")
MAX_LENGTH = 128

# Hub repository that holds both the tokenizer and the fine-tuned weights.
MODEL_ID = "mohanbot799s/civicconnect-bert-indic"

# Predictions whose top probability is below this are reported as "Other".
LOW_CONFIDENCE_THRESHOLD = 0.30

# ── Load artifacts ────────────────────────────────────────
# ── Load tokenizer from HF Hub ───────────────────────────
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# SECURITY: pickle.load can execute arbitrary code while unpickling. This is
# only acceptable because the file ships with the application; never point
# this at a user-supplied or downloaded file.
with open(os.path.join(ARTIFACT_DIR, "label_encoder.pkl"), "rb") as f:
    # Presumably a scikit-learn LabelEncoder (only inverse_transform is used
    # below) — verify against the training code.
    label_encoder = pickle.load(f)

# ── Load model from HF Hub ────────────────────────────────
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
model.eval()  # inference only

# ── Edge-case constants ───────────────────────────────────
# Inputs that consist solely of one of these words are rejected as
# "label_only".
# NOTE(review): every non-English entry below is written in Telugu script and
# there are no Devanagari entries, although the module targets Hindi as well —
# confirm this is intended.
LABEL_WORDS = {
    "water", "electricity", "roads", "garbage",
    "sanitation", "pollution", "transport", "animals",
    "పానీ", "బిజలీ", "సడక", "కచరా",
    "నీరు", "విద్యుత్", "రోడ్డు", "చెత్త",
}

# Inputs that exactly match one of these are rejected as "non_grievance_text".
NON_GRIEVANCE_PHRASES = {
    "hello", "hi", "good morning", "good evening",
    "thank you", "thanks", "all good", "no issues",
    "test", "demo",
    "నమస్తే", "ధన్యవాదాలు", "అన్నీ బాగున్నాయి", "సమస్య లేదు",
}

# Patterns used by clean_text, compiled once at import time.
_HTML_TAG_RE = re.compile(r"<.*?>")
_DISALLOWED_CHARS_RE = re.compile(r"[^\u0900-\u097F\u0C00-\u0C7F\u0020-\u007F]")
_WHITESPACE_RE = re.compile(r"\s+")


# ── Text cleaning (Indic-safe) ────────────────────────────
def clean_text(text: str) -> str:
    """Return *text* with markup and unsupported characters removed.

    The input is coerced with ``str()``. Anything that looks like an HTML tag
    is replaced by a space, every character outside the Devanagari block
    (U+0900–U+097F), the Telugu block (U+0C00–U+0C7F) and U+0020–U+007F is
    replaced by a space, and runs of whitespace are collapsed to one space
    with leading/trailing whitespace stripped.
    """
    text = str(text)
    text = _HTML_TAG_RE.sub(" ", text)
    text = _DISALLOWED_CHARS_RE.sub(" ", text)
    return _WHITESPACE_RE.sub(" ", text).strip()


# ── Input validation ──────────────────────────────────────
def validate_input(text: str) -> Optional[str]:
    """Return a rejection reason for *text*, or ``None`` if it is acceptable.

    Possible reasons: ``"empty_text"``, ``"label_only"``,
    ``"non_grievance_text"``, ``"too_short"`` and ``"too_few_words"``.
    Matching is done on the stripped, lower-cased text.
    """
    if not text or not text.strip():
        return "empty_text"

    text_l = text.strip().lower()

    # Exact-match checks must run before the length checks: every label word
    # and many non-grievance phrases are a single short word, so they would
    # otherwise always be reported as too_short / too_few_words instead.
    if text_l in LABEL_WORDS:
        return "label_only"
    if text_l in NON_GRIEVANCE_PHRASES:
        return "non_grievance_text"

    if len(text_l) < 5:
        return "too_short"
    if len(text_l.split()) < 2:
        return "too_few_words"

    return None


def _failed_result(reason: str) -> dict:
    """Build the result dict returned when no prediction is made."""
    return {
        "status": "failed",
        "reason": reason,
        "category": None,
        "confidence": 0.0,
        "class_index": None,
    }


# ── Predict ───────────────────────────────────────────────
def predict(
    text: str,
    input_ids=None,
    attention_mask=None,
) -> dict:
    """Classify *text* into a grievance category.

    Args:
        text: Raw user text. It is always validated, even when ``input_ids``
            is supplied.
        input_ids: Optional pre-tokenized ids. When ``None`` the text is
            cleaned and tokenized here (truncated to ``MAX_LENGTH``).
        attention_mask: Optional mask matching ``input_ids``. It is
            overwritten when the text is tokenized here and otherwise passed
            to the model as given.

    Returns:
        A dict with ``status``, ``category``, ``confidence`` and
        ``class_index``. Failed and low-confidence results also carry a
        ``reason``; a low-confidence result has ``category`` set to
        ``"Other"``.
    """
    reason = validate_input(text)
    if reason:
        return _failed_result(reason)

    if input_ids is None:
        cleaned = clean_text(text)
        if not cleaned:
            # The raw text passed validation but held only characters that
            # cleaning removes (emoji, other scripts, bare tags); there is
            # nothing meaningful to classify.
            return _failed_result("empty_text")

        enc = tokenizer(
            cleaned,
            return_tensors="pt",
            truncation=True,
            padding=False,
            max_length=MAX_LENGTH,
        )
        input_ids = enc["input_ids"]
        attention_mask = enc["attention_mask"]

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        probs = torch.softmax(outputs.logits, dim=1)
        conf, pred = torch.max(probs, dim=1)

    # .item() requires a single-element tensor, i.e. one sequence per call.
    confidence = conf.item()
    predicted_index = pred.item()

    if confidence < LOW_CONFIDENCE_THRESHOLD:
        return {
            "status": "success",
            "reason": "low_confidence",
            "category": "Other",
            "confidence": round(confidence, 4),
            "class_index": predicted_index,
        }

    label = label_encoder.inverse_transform([predicted_index])[0]

    return {
        "status": "success",
        "category": label,
        "confidence": round(confidence, 4),
        "class_index": predicted_index,
    }


def get_model_and_tokenizer():
    """Return the module-level ``(model, tokenizer)`` pair."""
    return model, tokenizer