# tts_project/imp_scripts/text_processor.py
# Author: PraveenSharma08
# Initial project upload: Hindi/English Text-to-Speech pipeline (commit 8a02978)
#!/usr/bin/env python3
import argparse
import os
import re
import sys
from dataclasses import dataclass
from typing import List, Sequence, Set, Tuple, Dict, Union, Optional
from hindi_xlit import HindiTransliterator
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
# Paths are resolved relative to this script so it runs from any working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_DIR = os.path.join(BASE_DIR, 'hing_bert_module', 'hing-bert-lid')
DICTIONARY_PATH = os.path.join(BASE_DIR, 'hing_bert_module', 'dictionary.txt')
OUTPUT_PATH = os.path.join(BASE_DIR, 'output2.txt')
# Module-level session state: one-shot log-file guard plus label tables
# that load_model() fills in from the model config.
LOG_INITIALIZED = False
LABEL_MAP = None     # id -> label string (e.g. {0: 'EN', 1: 'HI'}) after load_model()
LABEL_TO_ID = None   # label string -> id, inverse of LABEL_MAP
# Word tokenizer: runs of Latin letters, common IAST diacritics, apostrophes
# and hyphens. NOTE(review): a few diacritic chars appear twice in the class
# (harmless in a regex character class).
TOKEN_RE = re.compile(r"[A-Za-zĀāĪīŪūṚṛṝḶḷḸḹēēōōṃḥśṣṭḍṇñṅ'’-]+")
# Frequent English function words that the HI-override heuristics in
# classify_text must never re-label as Hindi.
# NOTE(review): 'he' is listed twice; harmless in a set literal.
COMMON_ENGLISH_STOPWORDS = {
    'a','he', 'an', 'and', 'are', 'as', 'at', 'be', 'because', 'been', 'but', 'by', 'for', 'from',
    'had', 'has', 'have', 'he', 'her', 'here', 'him', 'his', 'how', 'i', 'in', 'is', 'it',
    'its', 'me', 'my', 'no', 'not', 'of', 'on', 'or', 'our', 'she', 'so', 'that', 'the',
    'their', 'them', 'there', 'they', 'this', 'those', 'to', 'was', 'we', 'were', 'what',
    'when', 'where', 'which', 'who', 'whom', 'why', 'will', 'with', 'you', 'your'
}
@dataclass
class TokenPrediction:
    """Language-ID result for a single input word token."""
    token: str         # the surface word exactly as it appeared in the input
    label: str         # predicted language tag, e.g. 'HI', 'EN', or 'N/A'
    confidence: float  # probability associated with the chosen label (0.0-1.0)
def load_model(device: str | None = None):
    """Load the local HingBERT-LID tokenizer and model onto a device.

    Also populates the module-level LABEL_MAP (id -> label) and
    LABEL_TO_ID (label -> id) tables from the model configuration.

    Args:
        device: Explicit torch device string ('cpu'/'cuda'); when None,
            CUDA is used if available, else CPU.

    Returns:
        Tuple of (tokenizer, model, torch.device).
    """
    global LABEL_MAP, LABEL_TO_ID
    if device:
        target = torch.device(device)
    else:
        target = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, local_files_only=True)
    model = AutoModelForTokenClassification.from_pretrained(MODEL_DIR, local_files_only=True)
    model.to(target)
    model.eval()
    cfg = model.config
    id2label = getattr(cfg, 'id2label', None)
    if id2label:
        LABEL_MAP = {int(label_id): label for label_id, label in id2label.items()}
    else:
        # No mapping in the config: fall back to stringified indices.
        LABEL_MAP = {i: str(i) for i in range(cfg.num_labels)}
    label2id = getattr(cfg, 'label2id', None)
    if label2id:
        LABEL_TO_ID = {str(label): int(label_id) for label, label_id in label2id.items()}
    else:
        # Derive the reverse table by inverting LABEL_MAP.
        LABEL_TO_ID = {label: label_id for label_id, label in LABEL_MAP.items()}
    return tokenizer, model, target
def _tokenize(text: str) -> List[str]:
    """Split *text* into word tokens via TOKEN_RE, falling back to a plain
    whitespace split when the regex finds nothing."""
    matches = TOKEN_RE.findall(text)
    return matches if matches else text.strip().split()
def _hindi_pattern_score(token: str) -> float:
t = token.lower()
if len(t) <= 1:
return 0.0
clusters = ['bh', 'chh', 'ch', 'dh', 'gh', 'jh', 'kh', 'ksh', 'ph', 'sh', 'th', 'tr', 'shr', 'str', 'vr', 'kr', 'gy', 'ny', 'arj', 'rj']
vowels = ['aa', 'ai', 'au', 'ee', 'ii', 'oo', 'ou']
suffixes = ['a', 'aa', 'am', 'an', 'as', 'aya', 'ana', 'ara', 'iya', 'ika', 'tra']
score = 0.0
for c in clusters:
if c in t:
score += 0.4
for v in vowels:
if v in t:
score += 0.2
for suf in suffixes:
if t.endswith(suf) and len(t) > len(suf):
score += 0.3
if t.endswith(('a', 'i', 'o', 'u')):
score += 0.1
if re.search(r'[kgcjtdpb]h', t):
score += 0.2
return score
def classify_text(
    text: str,
    tokenizer,
    model,
    device,
    threshold: float,
) -> List[TokenPrediction]:
    """Token-level language ID (HI/EN) for one line of code-mixed text.

    Runs the token-classification model over the words of *text*, averages
    sub-word logits per word, then applies a cascade of heuristics that
    bias ambiguous tokens toward 'HI'.

    Args:
        text: Raw input line.
        tokenizer, model, device: Objects as returned by load_model().
        threshold: HI-probability cutoff used by the override heuristics.

    Returns:
        One TokenPrediction per word; label 'N/A' with confidence 0.0 when
        the tokenizer produced no sub-tokens for a word.
    """
    words = _tokenize(text)
    if not words:
        return []
    batch = tokenizer(
        words,
        return_tensors='pt',
        padding=True,
        truncation=True,
        is_split_into_words=True
    )
    # Sub-token -> source-word mapping must be taken before tensors move to device.
    word_ids = batch.word_ids(batch_index=0)
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**batch)
    logits = outputs.logits.squeeze(0)
    # Accumulate per-word logits across sub-tokens so they can be averaged.
    word_logits: dict[int, torch.Tensor] = {}
    word_counts: dict[int, int] = {}
    for idx, word_id in enumerate(word_ids):
        if word_id is None:
            continue  # special tokens ([CLS]/[SEP]/padding) map to no word
        if word_id not in word_logits:
            # NOTE(review): this stores a view into `logits`, and the `+=`
            # below mutates it in place; safe here because each logits row is
            # read exactly once, but worth confirming if this code is reused.
            word_logits[word_id] = logits[idx]
            word_counts[word_id] = 1
        else:
            word_logits[word_id] += logits[idx]
            word_counts[word_id] += 1
    predictions: List[TokenPrediction] = []
    for word_index, word in enumerate(words):
        logits_sum = word_logits.get(word_index)
        if logits_sum is None:
            # Word produced no sub-tokens (e.g. dropped by truncation).
            predictions.append(TokenPrediction(word, 'N/A', 0.0))
            continue
        avg_logits = logits_sum / word_counts[word_index]
        probs = torch.softmax(avg_logits, dim=-1)
        conf, idx = torch.max(probs, dim=-1)
        raw_label = LABEL_MAP.get(int(idx), str(int(idx)))
        # Explicit HI/EN probabilities; label ids come from the model config.
        hi_idx = LABEL_TO_ID.get('HI') if LABEL_TO_ID else None
        en_idx = LABEL_TO_ID.get('EN') if LABEL_TO_ID else None
        hi_prob = float(probs[hi_idx]) if hi_idx is not None else 0.0
        en_prob = float(probs[en_idx]) if en_idx is not None else float(conf)
        final_label = raw_label
        conf_value = float(conf)
        if hi_idx is not None and hi_prob >= threshold:
            # Confident Hindi: accept directly.
            final_label = 'HI'
            conf_value = hi_prob
        elif raw_label == 'HI':
            # Model's argmax says Hindi even below threshold: keep it.
            final_label = 'HI'
            conf_value = hi_prob
        else:
            # Borderline cases: combine model probability with spelling
            # heuristics, but never override common English stopwords.
            lower = word.lower()
            pattern_score = _hindi_pattern_score(word)
            is_capitalized = word[:1].isupper() and not word.isupper()
            override = False
            if hi_prob >= threshold - 0.05:
                override = True
            elif hi_prob >= 0.60 and pattern_score >= 0.5:
                override = True
            elif hi_prob >= 0.45 and pattern_score >= 0.6 and is_capitalized:
                override = True
            elif pattern_score >= 0.8 and hi_prob >= 0.40 and lower not in COMMON_ENGLISH_STOPWORDS:
                override = True
            if override and lower not in COMMON_ENGLISH_STOPWORDS:
                final_label = 'HI'
                conf_value = max(hi_prob, threshold - 0.05)
            else:
                final_label = 'EN'
                conf_value = en_prob
                # NOTE(review): this re-labels ANY token whose EN confidence
                # is below 0.97 as 'HI', which dominates all the heuristics
                # above (stopwords included at this point only via en_prob).
                # Looks like a debugging hack — confirm it is intentional.
                if conf_value < 0.97:
                    final_label = 'HI'
                    conf_value = max(conf_value, 0.96)
        predictions.append(TokenPrediction(word, final_label, conf_value))
    return predictions
def _print_predictions(predictions: Sequence[TokenPrediction]):
    """Print predictions to stdout as a tab-separated table."""
    print("Token\tLabel\tConfidence")
    rows = (f"{p.token}\t{p.label}\t{p.confidence:.4f}" for p in predictions)
    for row in rows:
        print(row)
def _init_output_log():
    """Create/truncate the session log at OUTPUT_PATH, at most once per run."""
    global LOG_INITIALIZED
    if not LOG_INITIALIZED:
        with open(OUTPUT_PATH, 'w', encoding='utf-8') as log_file:
            log_file.write('HingBERT-LID session log\n')
        LOG_INITIALIZED = True
def load_dictionary(filename: Optional[str] = None) -> Dict[str, str]:
    """Load the mythology dictionary from file.

    The file is expected to contain a Python-style literal of the form
    ``MYTHOLOGY_DICTIONARY = { 'english': 'देवनागरी', ... }``; blank lines,
    ``#`` comments and anything before the marker are ignored.

    Args:
        filename: Path to the dictionary file; defaults to DICTIONARY_PATH.

    Returns:
        Mapping of lowercase English words to Hindi transliterations.
        Empty dict on any read/parse failure (model-only fallback).
    """
    filename = filename or DICTIONARY_PATH
    dictionary: Dict[str, str] = {}
    try:
        with open(filename, 'r', encoding='utf-8') as f:
            in_dict = False
            for line in f:
                line = line.strip()
                # Skip empty lines and comments.
                if not line or line.startswith('#'):
                    continue
                # Start of the dictionary literal; keep any text after '{'.
                if 'MYTHOLOGY_DICTIONARY = {' in line:
                    in_dict = True
                    line = line.split('{', 1)[1].strip()
                    if not line:  # the line ends right after '{'
                        continue
                if not in_dict:
                    continue
                # Process key-value pairs.
                if ':' in line:
                    # Join continuation lines until the entry ends with ','.
                    while not line.rstrip().endswith(','):
                        next_line = next(f, '').strip()
                        if not next_line:
                            break
                        line += ' ' + next_line
                    # Drop a trailing '}' that may close the literal.
                    line = line.split('}')[0].strip()
                    # Split into individual key-value entries.
                    entries = [e.strip() for e in line.split(',') if ':' in e]
                    for entry in entries:
                        try:
                            key_part, value_part = entry.split(':', 1)
                            key = key_part.strip().strip("'\"")
                            value = value_part.strip().strip("'\"").rstrip('}')
                            if key and value:
                                dictionary[key.lower()] = value
                        except (ValueError, IndexError):
                            continue
                # A '}' after the opening marker ends the dictionary.
                if '}' in line and in_dict:
                    break
        print(f"✓ Dictionary loaded successfully: {len(dictionary)} words")
        return dictionary
    except FileNotFoundError:
        # BUG FIX: the old message was f"... '(unknown)' ..." with no
        # placeholder, so it never showed which path was missing.
        print(f"Warning: Dictionary file '{filename}' not found.")
        print("Proceeding with model-only transliteration.")
        return {}
    except Exception as e:
        # Deliberate best-effort: any parse error degrades to model-only mode.
        print(f"Warning: Error loading dictionary: {str(e)}")
        print("Proceeding with model-only transliteration.")
        return {}
def get_transliteration(word: str, dictionary: Dict[str, str], transliterator, show_source: bool = False) -> Union[str, tuple]:
    """Transliterate a Roman-script word to Devanagari.

    The curated dictionary takes precedence; the transliterator model is the
    fallback, and the original word is returned unchanged on any failure.

    Args:
        word: English/Roman word to transliterate.
        dictionary: Mapping of lowercase English words to Hindi forms.
        transliterator: Object exposing a ``transliterate(word)`` method.
        show_source: When True, also report where the result came from.

    Returns:
        The transliteration, or ``(transliteration, source)`` when
        show_source is True; source is "dictionary", "model" or "error".
    """
    key = word.lower().strip()
    # Curated dictionary wins over the model.
    if key in dictionary:
        hit = dictionary[key]
        return (hit, "dictionary") if show_source else hit
    try:
        raw = transliterator.transliterate(word)
        # Some models return a ranked list; take the best candidate.
        best = raw[0] if isinstance(raw, list) else raw
    except Exception:
        # Best-effort: fall back to the untransliterated word.
        return (word, "error") if show_source else word
    return (best, "model") if show_source else best
def _write_predictions(predictions: Sequence[TokenPrediction], source_text: str):
    """Append the source line and its predictions to the session log.

    Ensures the log file exists first, then writes the source text followed
    by one ``token<TAB>label<TAB>confidence`` row per prediction.
    """
    _init_output_log()
    # _init_output_log() always leaves LOG_INITIALIZED True, so the old
    # "write a separator if not initialized" branch here was unreachable
    # dead code and has been removed.
    with open(OUTPUT_PATH, 'a', encoding='utf-8') as f:
        f.write(f'\nSource: {source_text}\n')
        for pred in predictions:
            f.write(f"{pred.token}\t{pred.label}\t{pred.confidence:.4f}\n")
def main():
    """CLI entry point: classify text, collect Hindi tokens, transliterate them.

    Two modes:
      * ``--text "..."``: classify one string, log it, print Hindi words, exit.
      * no ``--text``: interactive stdin loop (type QUIT to stop); on exit,
        every detected Hindi word is transliterated to Devanagari and the
        input is reconstructed with Devanagari substitutions into
        ``final_output.txt`` next to this script.
    """
    parser = argparse.ArgumentParser(description='Test l3cube-pune/hing-bert-lid on text (token-level).')
    parser.add_argument('--device', type=str, default=None, help='torch device (cpu or cuda)')
    parser.add_argument('--text', type=str, default=None, help='Text to classify.')
    parser.add_argument('--threshold', type=float, default=0.80, help='Confidence threshold for Hindi override heuristics (default=0.80)')
    parser.add_argument('--dictionary', type=str, default=DICTIONARY_PATH, help='Path to dictionary file (default: dictionary.txt)')
    args = parser.parse_args()
    tokenizer, model, device = load_model(args.device)
    # Load dictionary for transliteration (dictionary entries win over the model).
    dictionary = load_dictionary(args.dictionary)
    transliterator = HindiTransliterator()
    hindi_words = set()  # unique tokens labelled 'HI' across all input
    if args.text:
        # One-shot mode: classify the given text, log it, and exit.
        preds = classify_text(args.text, tokenizer, model, device, args.threshold)
        _print_predictions(preds)
        _write_predictions(preds, args.text)
        # Collect tokens the classifier labelled Hindi.
        hindi_words.update(pred.token for pred in preds if pred.label == 'HI')
        print("\nHindi words found:", ", ".join(hindi_words) if hindi_words else "None")
        return
    print("Interactive mode. Type text lines (QUIT to exit).")
    _init_output_log()
    all_input = []  # every non-empty line the user typed, for reconstruction later
    try:
        for line in sys.stdin:
            line = line.rstrip('\n')
            if not line:
                continue
            if line.strip().upper() == 'QUIT':
                break
            all_input.append(line)  # keep the full input for the final report
            preds = classify_text(line, tokenizer, model, device, args.threshold)
            _print_predictions(preds)
            _write_predictions(preds, line)
            # Collect tokens the classifier labelled Hindi.
            hindi_words.update(pred.token for pred in preds if pred.label == 'HI')
            print()
    except (KeyboardInterrupt, EOFError):
        pass
    finally:
        # Report and transliterate all collected Hindi words before exiting.
        if hindi_words:
            print("\nAll Hindi words found:", ", ".join(sorted(hindi_words)))
            # Map each Hindi word to its Devanagari transliteration.
            print("\nTransliterated to Devanagari:")
            hindi_to_devanagari = {}
            for word in sorted(hindi_words):
                try:
                    devanagari = get_transliteration(word, dictionary, transliterator)
                    hindi_to_devanagari[word] = devanagari
                    source = " (dictionary)" if word.lower() in dictionary else " (model)"
                    print(f"{word} -> {devanagari}{source}")
                except Exception as e:
                    print(f"{word} -> [Error: {str(e)}]")
            # Save the full session report to a file.
            output_file = os.path.join(BASE_DIR, 'final_output.txt')
            with open(output_file, 'w', encoding='utf-8') as f:
                # Original input, verbatim.
                f.write("=== Original Text ===\n")
                f.write('\n'.join(all_input) + '\n\n')
                # The input again, with Hindi words swapped for Devanagari.
                f.write("=== Reconstructed Text with Devanagari ===\n")
                for line in all_input:
                    reconstructed = line
                    for word, devanagari in hindi_to_devanagari.items():
                        # Whole-word, case-insensitive replacement; the
                        # lookarounds preserve punctuation/spacing and avoid
                        # matching inside larger words.
                        reconstructed = re.sub(
                            rf'(?<![\w\-])({re.escape(word)})(?![\w\-])',
                            devanagari,
                            reconstructed,
                            flags=re.IGNORECASE
                        )
                    f.write(reconstructed + '\n')
                # Word -> transliteration table.
                f.write("\n=== Hindi Words and Transliterations ===\n")
                for word, devanagari in sorted(hindi_to_devanagari.items()):
                    f.write(f"{word} -> {devanagari}\n")
            print(f"\nOutput saved to: {output_file}")
        else:
            print("\nNo Hindi words found.")
if __name__ == '__main__':
    # Windows consoles may default to a legacy code page; force UTF-8 so
    # Devanagari output prints instead of raising UnicodeEncodeError.
    if sys.platform == 'win32':
        try:
            sys.stdout.reconfigure(encoding='utf-8')
        except (AttributeError, TypeError):
            # reconfigure() unavailable (pre-3.7 or wrapped stream): rewrap stdout.
            import io
            sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace')
    main()