# chatbotdeep / language.py
# Language detection and translation utilities for the chatbot.
import torch
import re
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM
# --- Language Detection Setup ---
# Multilingual language-identification classifier; its config's id2label map
# provides the ISO-style codes returned by detect_language().
DET_MODEL = "papluca/xlm-roberta-base-language-detection"
# Loaded once at import time so repeated calls reuse the same weights.
det_tokenizer = AutoTokenizer.from_pretrained(DET_MODEL)
det_model = AutoModelForSequenceClassification.from_pretrained(DET_MODEL)
# Unicode ranges for Indic scripts (hard overrides applied before the model).
# Checked in insertion order; the first script whose characters appear wins.
INDIC_RANGES = {
    "ta": re.compile(r"[\u0B80-\u0BFF]"),  # Tamil
    "te": re.compile(r"[\u0C00-\u0C7F]"),  # Telugu
    "kn": re.compile(r"[\u0C80-\u0CFF]"),  # Kannada
    "ml": re.compile(r"[\u0D00-\u0D7F]"),  # Malayalam
    "hi": re.compile(r"[\u0900-\u097F]"),  # Devanagari (mapped to Hindi)
}


def detect_language(text: str) -> str:
    """Return a short language code (e.g. "en", "ta", "hi") for *text*.

    Indic scripts are detected first via Unicode-range regexes (hard
    overrides), since short or code-mixed inputs can confuse the neural
    classifier. Everything else falls through to the XLM-RoBERTa detector.

    Args:
        text: Input string; may be empty.

    Returns:
        A language code from INDIC_RANGES or from the detector's id2label
        map. Empty or whitespace-only input defaults to "en" (the
        classifier's output on empty input is meaningless).
    """
    # Guard: don't run the model on nothing.
    if not text or not text.strip():
        return "en"

    # Script-based hard overrides for Indic languages.
    for lang, regex in INDIC_RANGES.items():
        if regex.search(text):
            return lang

    # Fall back to the neural detector; truncation keeps long inputs within
    # the model's maximum sequence length.
    inputs = det_tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        logits = det_model(**inputs).logits
    lang_id = torch.argmax(logits, dim=1).item()
    return det_model.config.id2label[lang_id]
# --- Translation Setup ---
# NLLB-200 distilled 600M multilingual translation model; target language is
# selected at generation time via forced_bos_token_id.
TRANS_MODEL = "facebook/nllb-200-distilled-600M"
# Loaded once at import time so repeated calls reuse the same weights.
trans_tokenizer = AutoTokenizer.from_pretrained(TRANS_MODEL)
trans_model = AutoModelForSeq2SeqLM.from_pretrained(TRANS_MODEL)
# Map short language codes (as returned by detect_language) to NLLB-200's
# FLORES-style codes (language + script).
LANG_MAP = {
    "en": "eng_Latn",
    "hi": "hin_Deva",
    "ta": "tam_Taml",
    "te": "tel_Telu",
    "kn": "kan_Knda",
    "ml": "mal_Mlym",
    "mr": "mar_Deva",
    "bn": "ben_Beng",
}


def translate_text(text: str, target_lang_code: str, src_lang_code: str = "en") -> str:
    """Translate *text* into the target language with NLLB-200.

    Args:
        text: Source text. Empty or whitespace-only input is returned
            unchanged without invoking the model.
        target_lang_code: Short code of the target language; codes missing
            from LANG_MAP fall back to English.
        src_lang_code: Short code of the source language (optional; defaults
            to "en", which matches the NLLB tokenizer's default source
            language and therefore the previous behavior of this function).

    Returns:
        The translated string.
    """
    # Nothing to translate: skip the (garbage-prone) model call entirely.
    if not text or not text.strip():
        return text

    # Resolve NLLB-specific codes (default to English).
    tgt_nllb = LANG_MAP.get(target_lang_code, "eng_Latn")
    # NLLB tokenization is language-aware: setting src_lang makes the
    # tokenizer prepend the correct source-language tag to the input.
    trans_tokenizer.src_lang = LANG_MAP.get(src_lang_code, "eng_Latn")

    # truncation=True keeps very long inputs within the model's max length.
    inputs = trans_tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        output = trans_model.generate(
            **inputs,
            # Force the first generated token to be the target-language tag.
            forced_bos_token_id=trans_tokenizer.convert_tokens_to_ids(tgt_nllb),
            max_length=256,
        )
    return trans_tokenizer.decode(output[0], skip_special_tokens=True)