|
|
|
|
|
import argparse |
|
|
import os |
|
|
import re |
|
|
import sys |
|
|
from dataclasses import dataclass |
|
|
from typing import List, Sequence, Set, Tuple, Dict, Union, Optional |
|
|
from hindi_xlit import HindiTransliterator |
|
|
|
|
|
import torch |
|
|
from transformers import AutoTokenizer, AutoModelForTokenClassification |
|
|
|
|
|
# Paths are resolved relative to this file so the script works from any CWD.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_DIR = os.path.join(BASE_DIR, 'hing_bert_module', 'hing-bert-lid')
DICTIONARY_PATH = os.path.join(BASE_DIR, 'hing_bert_module', 'dictionary.txt')
OUTPUT_PATH = os.path.join(BASE_DIR, 'output2.txt')

# True once the session log file has been (re)created; see _init_output_log().
LOG_INITIALIZED = False

# Populated by load_model() from the model config: id -> label and label -> id.
LABEL_MAP = None
LABEL_TO_ID = None

# Word tokenizer: runs of Latin letters (incl. IAST diacritics), apostrophes
# and hyphens; punctuation and digits are dropped.
TOKEN_RE = re.compile(r"[A-Za-zĀāĪīŪūṚṛṝḶḷḸḹēēōōṃḥśṣṭḍṇñṅ'’-]+")

# English function words that the Hindi-override heuristics in classify_text()
# must never relabel as Hindi.  (Fixed: the original literal listed 'he' twice.)
COMMON_ENGLISH_STOPWORDS = {
    'a', 'an', 'and', 'are', 'as', 'at', 'be', 'because', 'been', 'but', 'by', 'for', 'from',
    'had', 'has', 'have', 'he', 'her', 'here', 'him', 'his', 'how', 'i', 'in', 'is', 'it',
    'its', 'me', 'my', 'no', 'not', 'of', 'on', 'or', 'our', 'she', 'so', 'that', 'the',
    'their', 'them', 'there', 'they', 'this', 'those', 'to', 'was', 'we', 'were', 'what',
    'when', 'where', 'which', 'who', 'whom', 'why', 'will', 'with', 'you', 'your'
}
|
|
|
|
|
|
|
|
@dataclass
class TokenPrediction:
    """Per-word language-identification result."""
    token: str  # the surface word as it appeared in the input
    label: str  # predicted language tag, e.g. 'HI', 'EN', or 'N/A' if unscored
    confidence: float  # score attached to the label (0.0 when label is 'N/A')
|
|
|
|
|
|
|
|
|
|
|
def load_model(device: str | None = None):
    """Load the local HingBERT-LID tokenizer and model onto a device.

    Also fills the module-level LABEL_MAP / LABEL_TO_ID tables from the
    model config (falling back to numeric labels when the config lacks
    id2label / label2id).

    Returns:
        (tokenizer, model, torch.device) ready for inference.
    """
    global LABEL_MAP, LABEL_TO_ID

    # Explicit request wins; otherwise prefer CUDA when available.
    target = torch.device(device) if device else torch.device(
        'cuda' if torch.cuda.is_available() else 'cpu'
    )

    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, local_files_only=True)
    model = AutoModelForTokenClassification.from_pretrained(MODEL_DIR, local_files_only=True)
    model.to(target)
    model.eval()

    cfg = model.config
    if getattr(cfg, 'id2label', None):
        LABEL_MAP = {int(i): lbl for i, lbl in cfg.id2label.items()}
    else:
        LABEL_MAP = {i: str(i) for i in range(cfg.num_labels)}
    if getattr(cfg, 'label2id', None):
        LABEL_TO_ID = {str(lbl): int(i) for lbl, i in cfg.label2id.items()}
    else:
        LABEL_TO_ID = {lbl: i for i, lbl in LABEL_MAP.items()}

    return tokenizer, model, target
|
|
def _tokenize(text: str) -> List[str]:
    """Split *text* into word tokens via TOKEN_RE; fall back to whitespace split."""
    # TOKEN_RE has no capturing groups, so findall yields whole-match strings.
    matches = TOKEN_RE.findall(text)
    return matches if matches else text.strip().split()
|
|
|
|
|
|
|
|
def _hindi_pattern_score(token: str) -> float: |
|
|
t = token.lower() |
|
|
if len(t) <= 1: |
|
|
return 0.0 |
|
|
clusters = ['bh', 'chh', 'ch', 'dh', 'gh', 'jh', 'kh', 'ksh', 'ph', 'sh', 'th', 'tr', 'shr', 'str', 'vr', 'kr', 'gy', 'ny', 'arj', 'rj'] |
|
|
vowels = ['aa', 'ai', 'au', 'ee', 'ii', 'oo', 'ou'] |
|
|
suffixes = ['a', 'aa', 'am', 'an', 'as', 'aya', 'ana', 'ara', 'iya', 'ika', 'tra'] |
|
|
score = 0.0 |
|
|
for c in clusters: |
|
|
if c in t: |
|
|
score += 0.4 |
|
|
for v in vowels: |
|
|
if v in t: |
|
|
score += 0.2 |
|
|
for suf in suffixes: |
|
|
if t.endswith(suf) and len(t) > len(suf): |
|
|
score += 0.3 |
|
|
if t.endswith(('a', 'i', 'o', 'u')): |
|
|
score += 0.1 |
|
|
if re.search(r'[kgcjtdpb]h', t): |
|
|
score += 0.2 |
|
|
return score |
|
|
|
|
|
|
|
|
def classify_text(
    text: str,
    tokenizer,
    model,
    device,
    threshold: float,
) -> List[TokenPrediction]:
    """Run token-level language identification over *text*.

    Words are scored with the HingBERT-LID model; sub-word logits are
    averaged back to word level.  Heuristics (HI probability vs. *threshold*
    plus _hindi_pattern_score) may override a non-HI model label in favour
    of 'HI', except for words in COMMON_ENGLISH_STOPWORDS.

    Args:
        text: raw input string.
        tokenizer, model, device: objects returned by load_model().
        threshold: HI probability at/above which a word is labelled 'HI'.

    Returns:
        One TokenPrediction per word; label 'N/A' (confidence 0.0) for
        words the tokenizer truncated away.
    """
    words = _tokenize(text)
    if not words:
        return []

    batch = tokenizer(
        words,
        return_tensors='pt',
        padding=True,
        truncation=True,
        is_split_into_words=True
    )
    # Map each sub-token back to the index of the word it came from.
    word_ids = batch.word_ids(batch_index=0)
    batch = {k: v.to(device) for k, v in batch.items()}

    with torch.no_grad():
        outputs = model(**batch)
    logits = outputs.logits.squeeze(0)

    # Sum sub-token logits per word so they can be averaged afterwards.
    word_logits: dict[int, torch.Tensor] = {}
    word_counts: dict[int, int] = {}
    for idx, word_id in enumerate(word_ids):
        if word_id is None:
            continue  # special tokens ([CLS], [SEP], padding)
        if word_id not in word_logits:
            word_logits[word_id] = logits[idx]
            word_counts[word_id] = 1
        else:
            word_logits[word_id] += logits[idx]
            word_counts[word_id] += 1

    # Label ids are loop-invariant; resolve them once (hoisted out of the loop).
    hi_idx = LABEL_TO_ID.get('HI') if LABEL_TO_ID else None
    en_idx = LABEL_TO_ID.get('EN') if LABEL_TO_ID else None

    predictions: List[TokenPrediction] = []
    for word_index, word in enumerate(words):
        logits_sum = word_logits.get(word_index)
        if logits_sum is None:
            # Word was truncated away by the tokenizer; nothing to score.
            predictions.append(TokenPrediction(word, 'N/A', 0.0))
            continue
        avg_logits = logits_sum / word_counts[word_index]
        probs = torch.softmax(avg_logits, dim=-1)

        conf, idx = torch.max(probs, dim=-1)
        raw_label = LABEL_MAP.get(int(idx), str(int(idx)))

        hi_prob = float(probs[hi_idx]) if hi_idx is not None else 0.0
        en_prob = float(probs[en_idx]) if en_idx is not None else float(conf)

        final_label = raw_label
        conf_value = float(conf)

        if hi_idx is not None and hi_prob >= threshold:
            # Confident Hindi prediction from the model itself.
            final_label = 'HI'
            conf_value = hi_prob
        elif raw_label == 'HI':
            final_label = 'HI'
            conf_value = hi_prob
        else:
            # Heuristic overrides for borderline non-HI predictions.
            lower = word.lower()
            pattern_score = _hindi_pattern_score(word)
            is_capitalized = word[:1].isupper() and not word.isupper()

            override = False
            if hi_prob >= threshold - 0.05:
                override = True
            elif hi_prob >= 0.60 and pattern_score >= 0.5:
                override = True
            elif hi_prob >= 0.45 and pattern_score >= 0.6 and is_capitalized:
                override = True
            elif pattern_score >= 0.8 and hi_prob >= 0.40 and lower not in COMMON_ENGLISH_STOPWORDS:
                override = True

            if override and lower not in COMMON_ENGLISH_STOPWORDS:
                final_label = 'HI'
                conf_value = max(hi_prob, threshold - 0.05)
            else:
                final_label = 'EN'
                conf_value = en_prob

        # BUG FIX: the previous version unconditionally relabelled any word
        # with confidence < 0.97 as 'HI' (and inflated its confidence to
        # 0.96), which defeated both the model prediction and every
        # heuristic above.  That clamp has been removed so the computed
        # label/confidence are reported as-is.

        predictions.append(TokenPrediction(word, final_label, conf_value))

    return predictions
|
|
|
|
|
|
|
|
def _print_predictions(predictions: Sequence[TokenPrediction]):
    """Print a tab-separated table of token, label and confidence to stdout."""
    print("Token\tLabel\tConfidence")
    rows = (f"{p.token}\t{p.label}\t{p.confidence:.4f}" for p in predictions)
    for row in rows:
        print(row)
|
|
|
|
|
|
|
|
def _init_output_log():
    """Create/truncate the session log exactly once per process run."""
    global LOG_INITIALIZED
    if not LOG_INITIALIZED:
        # First call this run: start a fresh log file with a header line.
        with open(OUTPUT_PATH, 'w', encoding='utf-8') as log_file:
            log_file.write('HingBERT-LID session log\n')
        LOG_INITIALIZED = True
|
|
|
|
|
|
|
|
def load_dictionary(filename: Optional[str] = None) -> Dict[str, str]:
    """
    Load the mythology dictionary from file.

    The file is expected to contain a Python-style literal of the form
    ``MYTHOLOGY_DICTIONARY = { 'english': 'देवनागरी', ... }``.  Blank lines
    and ``#`` comments are skipped; parsing stops at the closing brace.

    Args:
        filename: path to the dictionary file; defaults to DICTIONARY_PATH.

    Returns:
        Mapping of lowercased English words to Hindi transliterations, or
        an empty dict if the file is missing or unreadable.
    """
    filename = filename or DICTIONARY_PATH
    dictionary = {}
    try:
        with open(filename, 'r', encoding='utf-8') as f:
            in_dict = False
            for line in f:
                line = line.strip()

                # Skip blanks and comment lines.
                if not line or line.startswith('#'):
                    continue

                # Start of the dictionary literal; keep whatever follows '{'.
                if 'MYTHOLOGY_DICTIONARY = {' in line:
                    in_dict = True
                    line = line.split('{', 1)[1].strip()
                    if not line:
                        continue

                if not in_dict:
                    continue

                if ':' in line:
                    # An entry may wrap across lines; keep pulling lines
                    # until one ends with a comma (or input runs out).
                    while not line.rstrip().endswith(','):
                        next_line = next(f, '').strip()
                        if not next_line:
                            break
                        line += ' ' + next_line

                    # Drop anything after a closing brace on this line.
                    line = line.split('}')[0].strip()

                    entries = [e.strip() for e in line.split(',') if ':' in e]
                    for entry in entries:
                        try:
                            key_part, value_part = entry.split(':', 1)
                            key = key_part.strip().strip("'\"")
                            value = value_part.strip().strip("'\"").rstrip('}')
                            if key and value:
                                dictionary[key.lower()] = value
                        except (ValueError, IndexError):
                            continue  # malformed entry; skip it

                # End of the dictionary literal.
                if '}' in line and in_dict:
                    break

        print(f"✓ Dictionary loaded successfully: {len(dictionary)} words")
        return dictionary

    except FileNotFoundError:
        # BUG FIX: the previous message printed the literal text '(unknown)'
        # instead of the path that was actually tried.
        print(f"Warning: Dictionary file '{filename}' not found.")
        print("Proceeding with model-only transliteration.")
        return {}
    except Exception as e:
        print(f"Warning: Error loading dictionary: {str(e)}")
        print("Proceeding with model-only transliteration.")
        return {}
|
|
|
|
|
|
|
|
def get_transliteration(word: str, dictionary: Dict[str, str], transliterator, show_source: bool = False) -> Union[str, tuple]:
    """
    Get transliteration for a word.

    Case-insensitive dictionary lookups take precedence; otherwise the
    transliteration model is used, and on any model failure the original
    word is returned unchanged.

    Args:
        word: English word to transliterate
        dictionary: Dictionary mapping English to Hindi
        transliterator: HindiTransliterator instance
        show_source: If True, returns (transliteration, source)

    Returns:
        Transliteration string, or tuple (transliteration, source) if show_source=True
    """
    key = word.lower().strip()

    # Dictionary hit wins over the model.
    if key in dictionary:
        entry = dictionary[key]
        return (entry, "dictionary") if show_source else entry

    try:
        raw = transliterator.transliterate(word)
        # Some transliterators return a ranked list; take the top candidate.
        translit = raw[0] if isinstance(raw, list) else raw
        return (translit, "model") if show_source else translit
    except Exception:
        # Best effort: fall back to the untouched input word.
        return (word, "error") if show_source else word
|
|
|
|
|
|
|
|
def _write_predictions(predictions: Sequence[TokenPrediction], source_text: str):
    """Append the source line and its token predictions to the session log.

    _init_output_log() guarantees the log file exists before appending.

    Args:
        predictions: per-word results from classify_text().
        source_text: the raw input line the predictions came from.
    """
    _init_output_log()
    with open(OUTPUT_PATH, 'a', encoding='utf-8') as f:
        # NOTE: removed the old `if not LOG_INITIALIZED:` separator branch —
        # it was dead code, since _init_output_log() always leaves
        # LOG_INITIALIZED set to True before this point.
        f.write(f'\nSource: {source_text}\n')
        for pred in predictions:
            f.write(f"{pred.token}\t{pred.label}\t{pred.confidence:.4f}\n")
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: classify --text once, or run an interactive loop.

    In interactive mode, stdin lines are classified until 'QUIT' (or
    EOF/Ctrl-C); afterwards every word labelled 'HI' is transliterated to
    Devanagari and a reconstructed version of the input is written to
    final_output.txt in BASE_DIR.
    """
    parser = argparse.ArgumentParser(description='Test l3cube-pune/hing-bert-lid on text (token-level).')
    parser.add_argument('--device', type=str, default=None, help='torch device (cpu or cuda)')
    parser.add_argument('--text', type=str, default=None, help='Text to classify.')
    parser.add_argument('--threshold', type=float, default=0.80, help='Confidence threshold for Hindi override heuristics (default=0.80)')
    parser.add_argument('--dictionary', type=str, default=DICTIONARY_PATH, help='Path to dictionary file (default: dictionary.txt)')
    args = parser.parse_args()

    tokenizer, model, device = load_model(args.device)

    # Dictionary lookups take precedence over the transliteration model.
    dictionary = load_dictionary(args.dictionary)
    transliterator = HindiTransliterator()

    hindi_words = set()  # unique tokens labelled 'HI' across all input

    # One-shot mode: classify the given text, report, and exit.
    if args.text:
        preds = classify_text(args.text, tokenizer, model, device, args.threshold)
        _print_predictions(preds)
        _write_predictions(preds, args.text)

        hindi_words.update(pred.token for pred in preds if pred.label == 'HI')
        print("\nHindi words found:", ", ".join(hindi_words) if hindi_words else "None")
        return

    # Interactive mode: read stdin line by line until QUIT/EOF.
    print("Interactive mode. Type text lines (QUIT to exit).")
    _init_output_log()
    all_input = []  # every classified line, kept for the final report
    try:
        for line in sys.stdin:
            line = line.rstrip('\n')
            if not line:
                continue
            if line.strip().upper() == 'QUIT':
                break
            all_input.append(line)
            preds = classify_text(line, tokenizer, model, device, args.threshold)
            _print_predictions(preds)
            _write_predictions(preds, line)

            hindi_words.update(pred.token for pred in preds if pred.label == 'HI')
            print()
    except (KeyboardInterrupt, EOFError):
        # Treat Ctrl-C / EOF like QUIT: fall through to the summary below.
        pass
    finally:
        if hindi_words:
            print("\nAll Hindi words found:", ", ".join(sorted(hindi_words)))

            # Transliterate each Hindi word once and remember the mapping.
            print("\nTransliterated to Devanagari:")
            hindi_to_devanagari = {}

            for word in sorted(hindi_words):
                try:
                    devanagari = get_transliteration(word, dictionary, transliterator)
                    hindi_to_devanagari[word] = devanagari
                    # NOTE(review): source is re-derived from the dictionary here
                    # rather than via get_transliteration(show_source=True).
                    source = " (dictionary)" if word.lower() in dictionary else " (model)"
                    print(f"{word} -> {devanagari}{source}")
                except Exception as e:
                    print(f"{word} -> [Error: {str(e)}]")

            # Write the full session report: original text, reconstructed
            # text with Hindi words replaced by Devanagari, and the mapping.
            output_file = os.path.join(BASE_DIR, 'final_output.txt')
            with open(output_file, 'w', encoding='utf-8') as f:

                f.write("=== Original Text ===\n")
                f.write('\n'.join(all_input) + '\n\n')

                f.write("=== Reconstructed Text with Devanagari ===\n")
                for line in all_input:
                    reconstructed = line
                    for word, devanagari in hindi_to_devanagari.items():
                        # Whole-word, case-insensitive replacement; the
                        # lookarounds stop matches inside hyphenated words.
                        reconstructed = re.sub(
                            rf'(?<![\w\-])({re.escape(word)})(?![\w\-])',
                            devanagari,
                            reconstructed,
                            flags=re.IGNORECASE
                        )
                    f.write(reconstructed + '\n')

                f.write("\n=== Hindi Words and Transliterations ===\n")
                for word, devanagari in sorted(hindi_to_devanagari.items()):
                    f.write(f"{word} -> {devanagari}\n")

            print(f"\nOutput saved to: {output_file}")
        else:
            print("\nNo Hindi words found.")
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # On Windows the console may not default to UTF-8, which breaks printing
    # Devanagari; force UTF-8 output before running.
    if sys.platform == 'win32':
        try:
            sys.stdout.reconfigure(encoding='utf-8')
        except (AttributeError, TypeError):
            # reconfigure() unavailable or stdout was replaced; wrap the raw
            # binary buffer in a UTF-8 text wrapper instead.
            import io
            sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace')
    main()
|
|
|