|
|
import re |
|
|
import os |
|
|
import torch |
|
|
from dataclasses import dataclass |
|
|
from typing import List, Sequence, Dict |
|
|
from transformers import AutoTokenizer, AutoModelForTokenClassification |
|
|
|
|
|
# Absolute path of the directory that contains this module.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Placeholder only: rebound further down to the result of _resolve_model_dir().
MODEL_DIR = None
|
|
|
|
|
def _candidate_model_dirs() -> List[str]:
    """Return ordered candidate directories for the Hing-BERT model.

    Candidates, in priority order:

    1. The value of the ``HING_BERT_MODEL_DIR`` environment variable, when it
       is set to a non-empty string.
    2. ``<parent of BASE_DIR>/hing-bert-lid``.

    Returns:
        The candidate paths as strings. Unset or empty entries are omitted so
        the result always matches the ``List[str]`` annotation. No check is
        made here that the paths exist.
    """
    project_root = os.path.dirname(BASE_DIR)
    candidates = [
        os.environ.get("HING_BERT_MODEL_DIR"),
        os.path.join(project_root, 'hing-bert-lid'),
    ]
    # os.environ.get() yields None when the variable is unset; filter it (and
    # an empty string) out instead of leaking it to callers.
    return [path for path in candidates if path]
|
|
|
|
|
def _resolve_model_dir() -> str:
    """Resolve the model directory from the candidates.

    Walks the paths returned by ``_candidate_model_dirs()`` in order and
    returns the first one that is an existing directory.

    Returns:
        The first candidate path that is a directory.

    Raises:
        FileNotFoundError: if no candidate is an existing directory. The
            message lists every path that was tried.
    """
    # Skip unset/empty entries so a missing environment variable is harmless.
    candidates = [path for path in _candidate_model_dirs() if path]
    for candidate in candidates:
        # isdir rather than exists: a regular file at the path is not a
        # usable model directory.
        if os.path.isdir(candidate):
            return candidate
    searched = ", ".join(candidates) or "<none>"
    raise FileNotFoundError(
        f"Model directory not found (searched: {searched}). "
        "Set HING_BERT_MODEL_DIR to the directory containing the model."
    )
|
|
|
|
|
# Resolved once at import time. _resolve_model_dir() raises FileNotFoundError
# when no candidate is found, so importing this module fails in that case.
MODEL_DIR = _resolve_model_dir()

# Label lookup tables. Both stay None until load_model() fills them from the
# loaded model's config.
LABEL_MAP = None  # label id (int) -> label name
LABEL_TO_ID = None  # label name (str) -> label id (int)
|
|
|
|
|
# Word pattern used by _tokenize(): one or more ASCII letters, the listed
# diacritic letters, straight/curly apostrophes, or hyphens.
# NOTE(review): "ē" and "ō" each appear twice in the character class; this is
# harmless, but an uppercase variant may have been intended - confirm.
TOKEN_RE = re.compile(r"[A-Za-zĀāĪīŪūṚṛṝḶḷḸḹēēōōṃḥśṣṭḍṇñṅ'’-]+")

# Lower-cased English function words that classify_text() never labels 'HI'
# through its override rules. Grouped by first letter for easier maintenance.
COMMON_ENGLISH_STOPWORDS = {
    'a', 'an', 'and', 'are', 'as', 'at',
    'be', 'because', 'been', 'but', 'by',
    'for', 'from',
    'had', 'has', 'have', 'he', 'her', 'here', 'him', 'his', 'how',
    'i', 'in', 'is', 'it', 'its',
    'me', 'my',
    'no', 'not',
    'of', 'on', 'or', 'our',
    'she', 'so',
    'that', 'the', 'their', 'them', 'there', 'they', 'this', 'those', 'to',
    'was', 'we', 'were', 'what', 'when', 'where', 'which', 'who', 'whom',
    'why', 'will', 'with',
    'you', 'your',
}
|
|
|
|
|
|
|
|
@dataclass
class TokenPrediction:
    """Classification result for one word of the input text."""

    # The word exactly as produced by the word-splitting step.
    token: str
    # Label assigned by classify_text(): 'HI', 'EN', or 'N/A' when no logits
    # were available for the word.
    label: str
    # Confidence attached to the label (0.0 for 'N/A').
    confidence: float
|
|
|
|
|
|
|
|
def load_model(device: str | None = None):
    """Load the Hing-BERT tokenizer and model and publish the label tables.

    Both artefacts are read from ``MODEL_DIR`` with ``local_files_only=True``.
    The model is moved to the chosen device and switched to eval mode.

    Side effects:
        Rebinds the module globals ``LABEL_MAP`` (id -> label) and
        ``LABEL_TO_ID`` (label -> id) from the model config.

    Args:
        device: device string passed to ``torch.device``. When falsy, 'cuda'
            is used if available, otherwise 'cpu'.

    Returns:
        A ``(tokenizer, model, torch_device)`` tuple.
    """
    global LABEL_MAP, LABEL_TO_ID

    device_name = device or ('cuda' if torch.cuda.is_available() else 'cpu')
    target_device = torch.device(device_name)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, local_files_only=True)
    model = AutoModelForTokenClassification.from_pretrained(MODEL_DIR, local_files_only=True)
    model.to(target_device)
    model.eval()

    cfg = model.config

    # id -> label: prefer the mapping stored in the config, otherwise fall
    # back to stringified class indices.
    id2label = getattr(cfg, 'id2label', None)
    if id2label:
        LABEL_MAP = {int(label_id): name for label_id, name in id2label.items()}
    else:
        LABEL_MAP = {index: str(index) for index in range(cfg.num_labels)}

    # label -> id: prefer the config mapping, otherwise invert LABEL_MAP.
    label2id = getattr(cfg, 'label2id', None)
    if label2id:
        LABEL_TO_ID = {str(name): int(label_id) for name, label_id in label2id.items()}
    else:
        LABEL_TO_ID = {name: label_id for label_id, name in LABEL_MAP.items()}

    return tokenizer, model, target_device
|
|
|
|
|
|
|
|
def _tokenize(text: str) -> List[str]:
    """Split *text* into word tokens.

    Every ``TOKEN_RE`` match becomes a token. If the pattern matches nothing,
    the text is split on whitespace instead, so the result is empty only for
    empty or whitespace-only input.
    """
    matched = [match.group() for match in TOKEN_RE.finditer(text)]
    if matched:
        return matched
    # str.split() with no arguments already ignores leading/trailing blanks.
    return text.split()
|
|
|
|
|
|
|
|
def _hindi_pattern_score(token: str) -> float: |
|
|
t = token.lower() |
|
|
if len(t) <= 1: |
|
|
return 0.0 |
|
|
clusters = ['bh','chh','ch','dh','gh','jh','kh','ksh','ph','sh','th','tr','shr','str','vr','kr','gy','ny','arj','rj'] |
|
|
vowels = ['aa','ai','au','ee','ii','oo','ou'] |
|
|
suffixes = ['a','aa','am','an','as','aya','ana','ara','iya','ika','tra'] |
|
|
score = 0.0 |
|
|
for c in clusters: |
|
|
if c in t: |
|
|
score += 0.4 |
|
|
for v in vowels: |
|
|
if v in t: |
|
|
score += 0.2 |
|
|
for suf in suffixes: |
|
|
if t.endswith(suf) and len(t) > len(suf): |
|
|
score += 0.3 |
|
|
if t.endswith(('a','i','o','u')): |
|
|
score += 0.1 |
|
|
if re.search(r'[kgcjtdpb]h', t): |
|
|
score += 0.2 |
|
|
return score |
|
|
|
|
|
|
|
|
def classify_text(text: str, tokenizer, model, device, threshold: float) -> List[TokenPrediction]:
    """Run the Hing-BERT model on *text* and label every word 'HI' or 'EN'.

    The text is split with ``_tokenize``, encoded as pre-split words, and the
    logits of all sub-tokens belonging to a word are averaged before softmax.
    The model probabilities are then combined with ``_hindi_pattern_score``
    and the stopword list to pick the final label.

    Args:
        text: raw input text.
        tokenizer: tokenizer as returned by ``load_model``; its encoding must
            provide ``word_ids()``.
        model: token-classification model as returned by ``load_model``.
        device: device the input tensors are moved to before the forward pass.
        threshold: Hindi-probability cut-off. A word whose 'HI' probability is
            at least ``threshold - 0.05`` is overridden to 'HI'.

    Returns:
        One ``TokenPrediction`` per word, in input order. A word that has no
        sub-token in the encoding (for instance when truncation dropped it)
        is reported as label 'N/A' with confidence 0.0. An empty list is
        returned when the text yields no words.
    """
    words = _tokenize(text)
    if not words:
        return []

    batch = tokenizer(words, return_tensors='pt', padding=True, truncation=True, is_split_into_words=True)
    # word_ids() must be read before the encoding is replaced by a plain dict.
    word_ids = batch.word_ids(batch_index=0)
    batch = {k: v.to(device) for k, v in batch.items()}

    with torch.no_grad():
        outputs = model(**batch)
        logits = outputs.logits.squeeze(0)

    # Sum the logits of every sub-token per word; special tokens carry a
    # word id of None and are skipped.
    word_logits, word_counts = {}, {}
    for idx, word_id in enumerate(word_ids):
        if word_id is None:
            continue
        word_logits[word_id] = word_logits.get(word_id, 0) + logits[idx]
        word_counts[word_id] = word_counts.get(word_id, 0) + 1

    # Label indices do not change per word, so look them up once. They stay
    # None when the label table is missing or lacks the label.
    hi_idx = LABEL_TO_ID.get('HI') if LABEL_TO_ID else None
    en_idx = LABEL_TO_ID.get('EN') if LABEL_TO_ID else None
    hi_floor = threshold - 0.05

    predictions = []
    for word_index, word in enumerate(words):
        logits_sum = word_logits.get(word_index)
        if logits_sum is None:
            predictions.append(TokenPrediction(word, 'N/A', 0.0))
            continue

        probs = torch.softmax(logits_sum / word_counts[word_index], dim=-1)
        hi_prob = float(probs[hi_idx]) if hi_idx is not None else 0.0
        # Without an 'EN' label, fall back to the top class probability.
        en_prob = float(probs[en_idx]) if en_idx is not None else float(torch.max(probs))

        is_stopword = word.lower() in COMMON_ENGLISH_STOPWORDS
        pattern_score = _hindi_pattern_score(word)
        # Title-case style capitalisation; all-caps words do not count.
        is_capitalized = word[:1].isupper() and not word.isupper()

        # Weaker model evidence is accepted when the spelling heuristic
        # (and, in one rule, capitalisation) supports Hindi.
        override = (
            hi_prob >= hi_floor
            or (hi_prob >= 0.60 and pattern_score >= 0.5)
            or (hi_prob >= 0.45 and pattern_score >= 0.6 and is_capitalized)
            or (pattern_score >= 0.8 and hi_prob >= 0.40)
        )

        if override and not is_stopword:
            final_label, conf_value = 'HI', max(hi_prob, hi_floor)
        else:
            final_label, conf_value = 'EN', en_prob

        # NOTE(review): this floor relabels every word whose confidence is
        # below 0.97 as 'HI' (confidence raised to at least 0.96), including
        # stopwords that the branch above deliberately kept 'EN'. Behaviour
        # is preserved as found; confirm that it is intended.
        if conf_value < 0.97:
            final_label, conf_value = 'HI', max(conf_value, 0.96)

        predictions.append(TokenPrediction(word, final_label, conf_value))

    return predictions
|
|
|