#!/usr/bin/env python3
import argparse
import os
import re
import sys
from dataclasses import dataclass
from typing import List, Sequence, Set, Tuple, Dict, Union, Optional

from hindi_xlit import HindiTransliterator
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification

# All resources are resolved relative to this script's own directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_DIR = os.path.join(BASE_DIR, 'hing_bert_module', 'hing-bert-lid')
DICTIONARY_PATH = os.path.join(BASE_DIR, 'hing_bert_module', 'dictionary.txt')
OUTPUT_PATH = os.path.join(BASE_DIR, 'output2.txt')

# True once the session log at OUTPUT_PATH has been (re)created by this process.
LOG_INITIALIZED = False

# Label tables; both stay None until load_model() fills them in.
LABEL_MAP: Optional[Dict[int, str]] = None    # class id -> label name
LABEL_TO_ID: Optional[Dict[str, int]] = None  # label name -> class id

# A token is a run of ASCII letters, the listed diacritic letters,
# apostrophes and hyphens.
# NOTE(review): the class repeats 'ē' and 'ō' and has no capital forms of
# 'ṃḥśṣṭḍṇñṅ' — possibly upper-case variants were intended; confirm before
# changing, since it alters tokenisation.
TOKEN_RE = re.compile(r"[A-Za-zĀāĪīŪūṚṛṝḶḷḸḹēēōōṃḥśṣṭḍṇñṅ'’-]+")

# Lower-cased English function words that the Hindi override heuristics
# must not relabel.
COMMON_ENGLISH_STOPWORDS = {
    'a', 'an', 'and', 'are', 'as', 'at', 'be', 'because', 'been', 'but',
    'by', 'for', 'from', 'had', 'has', 'have', 'he', 'her', 'here', 'him',
    'his', 'how', 'i', 'in', 'is', 'it', 'its', 'me', 'my', 'no', 'not',
    'of', 'on', 'or', 'our', 'she', 'so', 'that', 'the', 'their', 'them',
    'there', 'they', 'this', 'those', 'to', 'was', 'we', 'were', 'what',
    'when', 'where', 'which', 'who', 'whom', 'why', 'will', 'with', 'you',
    'your',
}


@dataclass
class TokenPrediction:
    """Classification result for a single input token."""

    token: str         # the token text as it appeared in the input
    label: str         # assigned label name
    confidence: float  # confidence value reported alongside the label


def load_model(device: Optional[str] = None):
    """Load the tokenizer and token-classification model from ``MODEL_DIR``.

    Both are loaded with ``local_files_only=True``. The model is moved to
    the chosen device and put in eval mode. As a side effect the
    module-level ``LABEL_MAP`` and ``LABEL_TO_ID`` tables are populated
    from the model config.

    Args:
        device: Torch device string. When falsy, CUDA is used if
            ``torch.cuda.is_available()`` and the CPU otherwise.

    Returns:
        A ``(tokenizer, model, torch.device)`` tuple.
    """
    global LABEL_MAP, LABEL_TO_ID

    if device:
        dev = torch.device(device)
    else:
        dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, local_files_only=True)
    model = AutoModelForTokenClassification.from_pretrained(
        MODEL_DIR, local_files_only=True
    )
    model.to(dev)
    model.eval()

    config = model.config

    # Prefer the label names stored in the config; otherwise fall back to
    # the stringified class indices.
    id2label = getattr(config, 'id2label', None)
    if id2label:
        LABEL_MAP = {int(k): v for k, v in id2label.items()}
    else:
        LABEL_MAP = {i: str(i) for i in range(config.num_labels)}

    # Likewise for the reverse table; derive it from LABEL_MAP if absent.
    label2id = getattr(config, 'label2id', None)
    if label2id:
        LABEL_TO_ID = {str(k): int(v) for k, v in label2id.items()}
    else:
        LABEL_TO_ID = {v: k for k, v in LABEL_MAP.items()}

    return tokenizer, model, dev

def _tokenize(text: str) -> List[str]:
    """Split *text* into word tokens.

    Tokens are the runs matched by ``TOKEN_RE``. When the pattern matches
    nothing, the stripped text is split on whitespace instead.
    """
    return TOKEN_RE.findall(text) or text.strip().split()


# Romanised letter patterns that _hindi_pattern_score() rewards.
_HINDI_CLUSTERS = (
    'bh', 'chh', 'ch', 'dh', 'gh', 'jh', 'kh', 'ksh', 'ph', 'sh', 'th',
    'tr', 'shr', 'str', 'vr', 'kr', 'gy', 'ny', 'arj', 'rj',
)
_HINDI_VOWEL_PAIRS = ('aa', 'ai', 'au', 'ee', 'ii', 'oo', 'ou')
_HINDI_SUFFIXES = (
    'a', 'aa', 'am', 'an', 'as', 'aya', 'ana', 'ara', 'iya', 'ika', 'tra',
)
_CONSONANT_H_RE = re.compile(r'[kgcjtdpb]h')


def _hindi_pattern_score(token: str) -> float:
    """Return a heuristic score for how "romanised-Hindi" *token* looks.

    The score is a sum of fixed bonuses: 0.4 per listed consonant cluster
    found, 0.2 per listed vowel pair found, 0.3 per listed suffix the token
    ends with (the token must be longer than the suffix), 0.1 for ending
    in a/i/o/u, and 0.2 if a consonant from ``kgcjtdpb`` is followed by
    ``h``. Tokens of length <= 1 score 0.0.
    """
    t = token.lower()
    if len(t) <= 1:
        return 0.0

    # Accumulate with += in a fixed order: callers compare the result with
    # exact thresholds, so the floating-point summation order is kept stable.
    score = 0.0
    for cluster in _HINDI_CLUSTERS:
        if cluster in t:
            score += 0.4
    for pair in _HINDI_VOWEL_PAIRS:
        if pair in t:
            score += 0.2
    for suffix in _HINDI_SUFFIXES:
        if len(t) > len(suffix) and t.endswith(suffix):
            score += 0.3
    if t.endswith(('a', 'i', 'o', 'u')):
        score += 0.1
    if _CONSONANT_H_RE.search(t):
        score += 0.2
    return score


def _average_word_logits(logits: "torch.Tensor", word_ids) -> Dict[int, "torch.Tensor"]:
    """Average the rows of *logits* that belong to the same word index.

    ``word_ids[i]`` is the word index for row ``i`` or ``None`` for rows
    that belong to no word; those rows are ignored.
    """
    sums: Dict[int, "torch.Tensor"] = {}
    counts: Dict[int, int] = {}
    for row, word_id in enumerate(word_ids):
        if word_id is None:
            continue
        if word_id in sums:
            # Out-of-place add: logits[row] is a view, so `+=` would write
            # the running sum back into the model's output tensor.
            sums[word_id] = sums[word_id] + logits[row]
            counts[word_id] += 1
        else:
            sums[word_id] = logits[row]
            counts[word_id] = 1
    return {word_id: total / counts[word_id] for word_id, total in sums.items()}


def _resolve_label(
    word: str,
    raw_label: str,
    hi_prob: float,
    en_prob: float,
    hi_known: bool,
    threshold: float,
) -> Tuple[str, float]:
    """Apply the HI/EN decision heuristics to one word.

    Args:
        word: The token text.
        raw_label: Label name of the highest-probability class.
        hi_prob: Probability of the 'HI' class (0.0 if the model has none).
        en_prob: Probability of the 'EN' class, or the top probability if
            the model has no 'EN' class.
        hi_known: Whether the model has an 'HI' class.
        threshold: Hindi probability at or above which the word is HI.

    Returns:
        ``(label, confidence)`` where label is 'HI' or 'EN'.
    """
    if (hi_known and hi_prob >= threshold) or raw_label == 'HI':
        return 'HI', hi_prob

    # Stop-words never take the override path below.
    if word.lower() not in COMMON_ENGLISH_STOPWORDS:
        pattern_score = _hindi_pattern_score(word)
        is_capitalized = word[:1].isupper() and not word.isupper()
        near_threshold = threshold - 0.05
        override = (
            hi_prob >= near_threshold
            or (hi_prob >= 0.60 and pattern_score >= 0.5)
            or (hi_prob >= 0.45 and pattern_score >= 0.6 and is_capitalized)
            or (pattern_score >= 0.8 and hi_prob >= 0.40)
        )
        if override:
            return 'HI', max(hi_prob, near_threshold)

    # An English verdict below 0.97 is relabelled Hindi with a reported
    # confidence of at least 0.96.
    # NOTE(review): the source's indentation was lost; this check is placed
    # in the EN branch. The resulting labels are the same under any other
    # nesting, only the reported HI confidences could differ — confirm.
    if en_prob < 0.97:
        return 'HI', max(en_prob, 0.96)
    return 'EN', en_prob


def classify_text(
    text: str,
    tokenizer,
    model,
    device,
    threshold: float,
) -> List["TokenPrediction"]:
    """Label every token of *text* as 'HI' or 'EN'.

    The text is tokenised with ``_tokenize``, run through the model as a
    single pre-split sequence, and the sub-word logits of each word are
    averaged before softmax. The heuristics in ``_resolve_label`` then
    pick the final label.

    Args:
        text: Input text.
        tokenizer: Tokenizer returned by ``load_model``.
        model: Model returned by ``load_model``.
        device: Device the model lives on.
        threshold: Hindi probability threshold for the heuristics.

    Returns:
        One ``TokenPrediction`` per token, in input order. Words that
        received no logits (e.g. cut off by truncation) get label 'N/A'
        and confidence 0.0. Empty input yields an empty list.
    """
    words = _tokenize(text)
    if not words:
        return []

    batch = tokenizer(
        words,
        return_tensors='pt',
        padding=True,
        truncation=True,
        is_split_into_words=True,
    )
    # word_ids must be read before the encoding is turned into a plain dict.
    word_ids = batch.word_ids(batch_index=0)
    batch = {k: v.to(device) for k, v in batch.items()}

    with torch.no_grad():
        outputs = model(**batch)
    logits = outputs.logits.squeeze(0)

    averaged = _average_word_logits(logits, word_ids)

    # The label tables are loop-invariant; tolerate them being unset.
    label_map = LABEL_MAP or {}
    label_to_id = LABEL_TO_ID or {}
    hi_idx = label_to_id.get('HI')
    en_idx = label_to_id.get('EN')

    predictions: List[TokenPrediction] = []
    for word_index, word in enumerate(words):
        avg_logits = averaged.get(word_index)
        if avg_logits is None:
            predictions.append(TokenPrediction(word, 'N/A', 0.0))
            continue

        probs = torch.softmax(avg_logits, dim=-1)
        conf, best = torch.max(probs, dim=-1)
        best_id = int(best)
        raw_label = label_map.get(best_id, str(best_id))

        hi_prob = float(probs[hi_idx]) if hi_idx is not None else 0.0
        en_prob = float(probs[en_idx]) if en_idx is not None else float(conf)

        label, confidence = _resolve_label(
            word, raw_label, hi_prob, en_prob, hi_idx is not None, threshold
        )
        predictions.append(TokenPrediction(word, label, confidence))

    return predictions


def _print_predictions(predictions: Sequence["TokenPrediction"]):
    """Print *predictions* to stdout as a tab-separated table."""
    print("Token\tLabel\tConfidence")
    for pred in predictions:
        print(f"{pred.token}\t{pred.label}\t{pred.confidence:.4f}")


def _init_output_log():
    """Create (truncate) the session log at ``OUTPUT_PATH`` once per process."""
    global LOG_INITIALIZED
    if LOG_INITIALIZED:
        return
    with open(OUTPUT_PATH, 'w', encoding='utf-8') as f:
        f.write('HingBERT-LID session log\n')
    LOG_INITIALIZED = True


def load_dictionary(filename: Optional[str] = None) -> Dict[str, str]:
    """Load the mythology dictionary from file.

    The file is expected to contain a ``MYTHOLOGY_DICTIONARY = {`` line
    followed by ``'key': 'value',`` entries and a closing ``}``. Blank
    lines and lines starting with ``#`` are skipped. Parsing stops at the
    first ``}``.

    Args:
        filename: Path of the dictionary file; defaults to ``DICTIONARY_PATH``.

    Returns:
        Mapping of lower-cased English words to Hindi transliterations.
        An empty dict is returned (after printing a warning) when the file
        is missing or cannot be read.
    """
    filename = filename or DICTIONARY_PATH
    dictionary: Dict[str, str] = {}

    try:
        with open(filename, 'r', encoding='utf-8') as f:
            in_dict = False
            for line in f:
                line = line.strip()

                # Skip empty lines and comments
                if not line or line.startswith('#'):
                    continue

                # Check if we're entering the dictionary section
                if 'MYTHOLOGY_DICTIONARY = {' in line:
                    in_dict = True
                    line = line.split('{', 1)[1].strip()
                    if not line:  # nothing after the opening brace
                        continue

                if not in_dict:
                    continue

                if ':' in line:
                    # Join continuation lines of a multi-line entry. Stop at
                    # a closing brace so text after the dictionary is never
                    # pulled into it.
                    while '}' not in line and not line.endswith(','):
                        next_line = next(f, '').strip()
                        if not next_line:
                            break
                        line += ' ' + next_line

                # Remember the closing brace BEFORE cutting it off, otherwise
                # a brace on the same line as the last entry goes unnoticed
                # and later lines of the file leak into the dictionary.
                closes_dictionary = '}' in line
                body = line.split('}', 1)[0]

                for entry in body.split(','):
                    if ':' not in entry:
                        continue
                    key_part, value_part = entry.split(':', 1)
                    key = key_part.strip().strip("'\"")
                    value = value_part.strip().strip("'\"")
                    if key and value:
                        dictionary[key.lower()] = value

                if closes_dictionary:
                    break

        print(f"✓ Dictionary loaded successfully: {len(dictionary)} words")
        return dictionary

    except FileNotFoundError:
        print(f"Warning: Dictionary file '{filename}' not found.")
        print("Proceeding with model-only transliteration.")
        return {}
    except Exception as e:
        # The dictionary is optional, so any other failure is reported and
        # the caller continues with model-only transliteration.
        print(f"Warning: Error loading dictionary: {str(e)}")
        print("Proceeding with model-only transliteration.")
        return {}


def get_transliteration(word: str, dictionary: Dict[str, str], transliterator,
                        show_source: bool = False) -> Union[str, tuple]:
    """Get the transliteration for a word.

    The dictionary is checked first (lower-cased, stripped key); otherwise
    the model is used. If the model fails or returns an empty list, the
    word itself is returned.

    Args:
        word: English word to transliterate
        dictionary: Dictionary mapping English to Hindi
        transliterator: Object with a ``transliterate(word)`` method
        show_source: If True, returns (transliteration, source)

    Returns:
        Transliteration string, or tuple (transliteration, source) if
        show_source=True, where source is "dictionary", "model" or "error".
    """
    key = word.lower().strip()

    if key in dictionary:
        result, source = dictionary[key], "dictionary"
    else:
        try:
            model_result = transliterator.transliterate(word)
        except Exception:
            # Best effort: a failing model must not abort the whole run.
            result, source = word, "error"
        else:
            if isinstance(model_result, list):
                if model_result:
                    result, source = model_result[0], "model"  # first (best) candidate
                else:
                    result, source = word, "error"
            else:
                result, source = model_result, "model"

    return (result, source) if show_source else result


def _write_predictions(predictions: Sequence["TokenPrediction"], source_text: str):
    """Append *predictions* for *source_text* to the session log."""
    _init_output_log()
    with open(OUTPUT_PATH, 'a', encoding='utf-8') as f:
        f.write(f'\nSource: {source_text}\n')
        f.writelines(
            f"{pred.token}\t{pred.label}\t{pred.confidence:.4f}\n"
            for pred in predictions
        )


def _reconstruct_line(line: str, replacements: Dict[str, str]) -> str:
    """Replace every whole-word occurrence of a key of *replacements* in *line*.

    A match must not be preceded or followed by a word character, so
    surrounding punctuation and spacing are preserved and words that merely
    contain a key are left alone.
    """
    if not replacements:
        return line
    # Longest keys first so a key is never shadowed by one of its prefixes.
    keys = sorted(replacements, key=len, reverse=True)
    pattern = re.compile(
        r'(?<!\w)(?:' + '|'.join(re.escape(k) for k in keys) + r')(?!\w)'
    )
    # A function replacement keeps backslashes in the values literal.
    return pattern.sub(lambda m: replacements[m.group(0)], line)


def _write_final_output(output_file: str, all_input: Sequence[str],
                        replacements: Dict[str, str]) -> None:
    """Write the original text, the reconstructed text and the word mapping."""
    with open(output_file, 'w', encoding='utf-8') as f:
        # Write original full text
        f.write("=== Original Text ===\n")
        f.write('\n'.join(all_input) + '\n\n')

        # Write reconstructed text
        f.write("=== Reconstructed Text with Devanagari ===\n")
        for line in all_input:
            f.write(_reconstruct_line(line, replacements) + '\n')

        # NOTE(review): the source between the start of the substitution
        # regex and the final `-> {devanagari}` write was lost; the section
        # header below and the exact layout are reconstructed — confirm
        # against the original script.
        f.write("\n=== Word Mappings ===\n")
        for word, devanagari in replacements.items():
            f.write(f"{word} -> {devanagari}\n")


def main():
    """Command-line entry point.

    With ``--text`` a single string is classified and the Hindi words are
    listed. Otherwise lines are read from stdin until EOF or ``QUIT``; on
    exit all collected Hindi words are transliterated to Devanagari and
    ``final_output.txt`` is written next to this script.
    """
    parser = argparse.ArgumentParser(description='Test l3cube-pune/hing-bert-lid on text (token-level).')
    parser.add_argument('--device', type=str, default=None, help='torch device (cpu or cuda)')
    parser.add_argument('--text', type=str, default=None, help='Text to classify.')
    parser.add_argument('--threshold', type=float, default=0.80,
                        help='Confidence threshold for Hindi override heuristics (default=0.80)')
    parser.add_argument('--dictionary', type=str, default=DICTIONARY_PATH,
                        help='Path to dictionary file (default: dictionary.txt)')
    args = parser.parse_args()

    tokenizer, model, device = load_model(args.device)

    # Load dictionary for transliteration
    dictionary = load_dictionary(args.dictionary)
    transliterator = HindiTransliterator()

    hindi_words: Set[str] = set()  # unique Hindi words seen so far

    if args.text:
        preds = classify_text(args.text, tokenizer, model, device, args.threshold)
        _print_predictions(preds)
        _write_predictions(preds, args.text)
        hindi_words.update(pred.token for pred in preds if pred.label == 'HI')
        # Sorted so the output is deterministic, as in interactive mode.
        print("\nHindi words found:", ", ".join(sorted(hindi_words)) if hindi_words else "None")
        return

    print("Interactive mode. Type text lines (QUIT to exit).")
    _init_output_log()

    all_input: List[str] = []  # every accepted input line, in order

    try:
        for line in sys.stdin:
            line = line.rstrip('\n')
            if not line:
                continue
            if line.strip().upper() == 'QUIT':
                break

            all_input.append(line)
            preds = classify_text(line, tokenizer, model, device, args.threshold)
            _print_predictions(preds)
            _write_predictions(preds, line)
            hindi_words.update(pred.token for pred in preds if pred.label == 'HI')
            print()
    except (KeyboardInterrupt, EOFError):
        pass
    finally:
        # Report whatever was collected, even after an interrupt.
        if hindi_words:
            ordered = sorted(hindi_words)
            print("\nAll Hindi words found:", ", ".join(ordered))

            # Map each Hindi word to its Devanagari transliteration
            print("\nTransliterated to Devanagari:")
            hindi_to_devanagari: Dict[str, str] = {}
            for word in ordered:
                # Ask for the source so a failed model call is reported as
                # such instead of being shown as a model result.
                devanagari, source = get_transliteration(
                    word, dictionary, transliterator, show_source=True
                )
                hindi_to_devanagari[word] = devanagari
                print(f"{word} -> {devanagari} ({source})")

            # Save the output to a file
            output_file = os.path.join(BASE_DIR, 'final_output.txt')
            _write_final_output(output_file, all_input, hindi_to_devanagari)
            print(f"\nOutput saved to: {output_file}")
        else:
            print("\nNo Hindi words found.")


if __name__ == '__main__':
    if sys.platform == 'win32':
        # Make sure Devanagari output can be printed on Windows consoles.
        try:
            sys.stdout.reconfigure(encoding='utf-8')
        except (AttributeError, TypeError):
            import io
            sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace')
    main()