# Hugging Face Space utilities: language detection (XLM-R) and translation (NLLB-200).
import re

import torch
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

# --- Language Detection Setup ---
# XLM-R classifier fine-tuned for language identification.
DET_MODEL = "papluca/xlm-roberta-base-language-detection"
det_tokenizer = AutoTokenizer.from_pretrained(DET_MODEL)
det_model = AutoModelForSequenceClassification.from_pretrained(DET_MODEL)

# Unicode block per Indic script, used as hard overrides ahead of the
# model-based detector.
_INDIC_BLOCKS = {
    "ta": "\u0B80-\u0BFF",  # Tamil
    "te": "\u0C00-\u0C7F",  # Telugu
    "kn": "\u0C80-\u0CFF",  # Kannada
    "ml": "\u0D00-\u0D7F",  # Malayalam
    "hi": "\u0900-\u097F",  # Devanagari
}
# Compiled character-class patterns, keyed by ISO 639-1 code.
INDIC_RANGES = {
    code: re.compile(f"[{block}]") for code, block in _INDIC_BLOCKS.items()
}
def detect_language(text: str) -> str:
    """Return an ISO 639-1 language code for *text*.

    Script-based overrides run first: if any character falls inside one of
    the Indic Unicode blocks in ``INDIC_RANGES``, that language code is
    returned immediately. Otherwise the XLM-R detection model is consulted
    and its top label returned.
    """
    # Hard overrides: one matching character from an Indic block decides.
    for code, pattern in INDIC_RANGES.items():
        if pattern.search(text):
            return code

    # Fall back to the neural detector.
    encoded = det_tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        scores = det_model(**encoded).logits
    predicted = int(scores.argmax(dim=1).item())
    return det_model.config.id2label[predicted]
# --- Translation Setup ---
# Distilled 600M-parameter NLLB-200 many-to-many translation model.
TRANS_MODEL = "facebook/nllb-200-distilled-600M"
trans_tokenizer = AutoTokenizer.from_pretrained(TRANS_MODEL)
trans_model = AutoModelForSeq2SeqLM.from_pretrained(TRANS_MODEL)

# ISO 639-1 code -> NLLB (FLORES-200) language tag.
LANG_MAP = {
    "en": "eng_Latn",
    "hi": "hin_Deva",
    "ta": "tam_Taml",
    "te": "tel_Telu",
    "kn": "kan_Knda",
    "ml": "mal_Mlym",
    "mr": "mar_Deva",
    "bn": "ben_Beng",
}
def translate_text(text: str, target_lang_code: str, src_lang_code: str = "en") -> str:
    """Translate *text* into the language given by *target_lang_code*.

    Parameters
    ----------
    text : str
        Source text to translate.
    target_lang_code : str
        ISO 639-1 code of the target language; codes missing from
        ``LANG_MAP`` fall back to English (``eng_Latn``).
    src_lang_code : str, optional
        ISO 639-1 code of the source language (default ``"en"``). NLLB
        prepends a source-language tag during encoding; previously
        ``src_lang`` was never set, so non-English input was silently
        encoded as English.

    Returns
    -------
    str
        The translated text with special tokens stripped.
    """
    # Map to NLLB-specific tags (default to English on unknown codes).
    tgt_nllb = LANG_MAP.get(target_lang_code, "eng_Latn")
    trans_tokenizer.src_lang = LANG_MAP.get(src_lang_code, "eng_Latn")

    # truncation=True guards against inputs longer than the model's limit
    # (consistent with the tokenizer call in detect_language).
    inputs = trans_tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        output = trans_model.generate(
            **inputs,
            # Force the decoder to open with the target-language tag so
            # generation happens in the requested language.
            forced_bos_token_id=trans_tokenizer.convert_tokens_to_ids(tgt_nllb),
            max_length=256,
        )
    return trans_tokenizer.decode(output[0], skip_special_tokens=True)