Kamaljeyaram07 committed on
Commit
b931e4a
·
verified ·
1 Parent(s): 3286cbc

Update language.py

Browse files
Files changed (1) hide show
  1. language.py +22 -16
language.py CHANGED
@@ -1,30 +1,36 @@
1
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
  import torch
 
3
 
4
  MODEL_NAME = "papluca/xlm-roberta-base-language-detection"
5
 
6
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
7
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
8
 
9
# Languages supported by the app: maps the two-letter codes emitted by the
# detector model (ISO 639-1) to three-letter codes (presumably ISO 639-2/T —
# confirm against whatever consumes these values downstream).
LANG_MAP = {
    "en": "eng",  # English
    "hi": "hin",  # Hindi
    "ta": "tam",  # Tamil
    "te": "tel",  # Telugu
    "kn": "kan",  # Kannada
    "ml": "mal",  # Malayalam
    "mr": "mar",  # Marathi
    "bn": "ben",  # Bengali
}
19
 
20
def detect_language(text: str) -> str:
    """Classify the language of *text* with the XLM-R detection model.

    Runs the sequence-classification head, takes the argmax label, and
    returns it if it is one of the supported languages in LANG_MAP;
    any other label falls back to English ("en").
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True)
    # Inference only — no gradients needed.
    with torch.no_grad():
        scores = model(**encoded).logits
    best_id = scores.argmax(dim=1).item()
    label = model.config.id2label[best_id]
    # Default to English if the predicted language is not supported.
    return label if label in LANG_MAP else "en"
 
1
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
  import torch
3
+ import re
4
 
5
  MODEL_NAME = "papluca/xlm-roberta-base-language-detection"
6
 
7
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
8
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
9
 
10
# Unicode ranges for Indic scripts.
# Pre-compiled once at import time; used by detect_language for
# script-based hard overrides before falling back to the ML model.
TAMIL_RANGE = re.compile(r"[\u0B80-\u0BFF]")       # Tamil block
TELUGU_RANGE = re.compile(r"[\u0C00-\u0C7F]")      # Telugu block
KANNADA_RANGE = re.compile(r"[\u0C80-\u0CFF]")     # Kannada block
MALAYALAM_RANGE = re.compile(r"[\u0D00-\u0D7F]")   # Malayalam block
DEVANAGARI_RANGE = re.compile(r"[\u0900-\u097F]")  # Devanagari (Hindi, Marathi, ...)
 
 
 
 
16
 
17
def detect_language(text: str) -> str:
    """Return a two-letter language code for *text*.

    Detection strategy, in priority order:
      1. Empty or whitespace-only input short-circuits to "en" (the
         file's historical fallback language) so the tokenizer/model
         are never invoked on degenerate input, which would otherwise
         yield an arbitrary label.
      2. Script-based hard overrides: if the text contains characters
         from a known Indic Unicode block, return that language
         immediately.  First matching script wins, checked in the
         order Tamil, Telugu, Kannada, Malayalam, Devanagari.
         NOTE(review): Devanagari is reported as "hi", so Marathi text
         also maps to "hi" — confirm this collision is acceptable.
      3. Otherwise fall back to the XLM-R sequence classifier and
         return its argmax label verbatim.
    """
    # Robustness: guard degenerate input before touching the model.
    if not text or not text.strip():
        return "en"

    # 🔒 HARD OVERRIDES (MOST IMPORTANT): script presence beats the model.
    script_overrides = (
        (TAMIL_RANGE, "ta"),
        (TELUGU_RANGE, "te"),
        (KANNADA_RANGE, "kn"),
        (MALAYALAM_RANGE, "ml"),
        (DEVANAGARI_RANGE, "hi"),
    )
    for pattern, code in script_overrides:
        if pattern.search(text):
            return code

    # Fallback to ML detection.
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():  # inference only, no gradients
        logits = model(**inputs).logits
    lang_id = torch.argmax(logits, dim=1).item()
    return model.config.id2label[lang_id]