Spaces:
Sleeping
Sleeping
Update language.py
Browse files- language.py +22 -16
language.py
CHANGED
|
@@ -1,30 +1,36 @@
|
|
| 1 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 2 |
import torch
|
|
|
|
| 3 |
|
| 4 |
MODEL_NAME = "papluca/xlm-roberta-base-language-detection"
|
| 5 |
|
| 6 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
| 7 |
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
|
| 8 |
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
"ml": "mal",
|
| 16 |
-
"mr": "mar",
|
| 17 |
-
"bn": "ben",
|
| 18 |
-
}
|
| 19 |
|
| 20 |
def detect_language(text: str) -> str:
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
|
|
|
|
|
|
| 23 |
with torch.no_grad():
|
| 24 |
logits = model(**inputs).logits
|
| 25 |
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
# Default to English if not supported
|
| 30 |
-
return lang if lang in LANG_MAP else "en"
|
|
|
|
| 1 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 2 |
import torch
|
| 3 |
+
import re
|
| 4 |
|
| 5 |
# Checkpoint: XLM-RoBERTa base fine-tuned for language identification.
MODEL_NAME = "papluca/xlm-roberta-base-language-detection"

# Both artifacts are loaded once at import time and shared by every
# call to detect_language() below.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
|
| 9 |
|
| 10 |
+
# Compiled single-character classes over the Unicode blocks of the Indic
# scripts we special-case. One match anywhere in a string is enough to
# identify the script unambiguously.
TAMIL_RANGE = re.compile(r"[\u0B80-\u0BFF]")        # Tamil block
TELUGU_RANGE = re.compile(r"[\u0C00-\u0C7F]")       # Telugu block
KANNADA_RANGE = re.compile(r"[\u0C80-\u0CFF]")      # Kannada block
MALAYALAM_RANGE = re.compile(r"[\u0D00-\u0D7F]")    # Malayalam block
DEVANAGARI_RANGE = re.compile(r"[\u0900-\u097F]")   # Devanagari (Hindi et al.)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
def detect_language(text: str) -> str:
    """Return an ISO-639-1 language code for *text*.

    Script-based hard overrides run first: if any character of *text*
    falls in a known Indic Unicode block, that script decides the
    language outright. Only otherwise does the XLM-R classifier run.
    """
    # Hard overrides (most important) — checked in this fixed order.
    overrides = (
        (TAMIL_RANGE, "ta"),
        (TELUGU_RANGE, "te"),
        (KANNADA_RANGE, "kn"),
        (MALAYALAM_RANGE, "ml"),
        (DEVANAGARI_RANGE, "hi"),
    )
    for pattern, code in overrides:
        if pattern.search(text):
            return code

    # Fallback to ML detection.
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        logits = model(**inputs).logits

    lang_id = torch.argmax(logits, dim=1).item()
    return model.config.id2label[lang_id]
|
|
|
|
|
|
|
|
|