File size: 5,825 Bytes
54c7199 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import os
import re
import gradio as gr
import fasttext
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
# ----------------------------
# Download required models
# ----------------------------
# hf_hub_download fetches each file into the local Hugging Face cache and
# returns its local filesystem path; the three paths are consumed by IndicLID.
print("Downloading IndicLID models...")
FTN_PATH, FTR_PATH, BERT_PATH = (
    hf_hub_download(repo_id=repo, filename=fname)
    for repo, fname in (
        ("ai4bharat/IndicLID-FTN", "model_baseline_roman.bin"),
        ("ai4bharat/IndicLID-FTR", "model_baseline_roman.bin"),
        # NOTE(review): "basline" looks misspelled but is presumably the
        # upstream filename -- verify against the repo before changing it.
        ("ai4bharat/IndicLID-BERT", "basline_nn_simple.pt"),
    )
)
print("Download complete.")
# ----------------------------
# Data helper for BERT batching
# ----------------------------
class IndicBERT_Data(Dataset):
    """Map-style dataset yielding ``(index, text)`` pairs.

    The caller-assigned index travels with each text through DataLoader
    batching, so predictions can be written back to the right position.
    """

    def __init__(self, indices, X):
        # Materialise both inputs (texts first, then indices) so they can be
        # indexed and measured.
        self.x, self.i = list(X), list(indices)

    def __len__(self):
        # The dataset length follows the number of texts.
        return len(self.x)

    def __getitem__(self, idx):
        index, text = self.i[idx], self.x[idx]
        return index, text
# ----------------------------
# IndicLID Class
# ----------------------------
class IndicLID:
    """Ensemble language identifier for native- and roman-script Indic text.

    ``batch_predict`` routes each text by its share of ASCII letters
    (see ``char_percent_check``):

    * ratio <= ``input_threshold``: native-script fastText model (FTN);
    * ratio >  ``input_threshold``: roman-script fastText model (FTR); any FTR
      prediction whose top score is not above ``roman_lid_threshold`` is
      re-scored by the BERT classifier.

    Each prediction is a dict with keys ``"text"``, ``"label"``, ``"score"``
    and ``"model"`` (``"FTN"``, ``"FTR"`` or ``"BERT"``).
    """

    # fastText labels carry a "__label__" prefix; this many chars are dropped.
    _FT_LABEL_PREFIX_LEN = len("__label__")

    def __init__(self, input_threshold=0.5, roman_lid_threshold=0.6):
        """Load the FTN, FTR and BERT models from the module-level paths.

        Args:
            input_threshold: ASCII-letter ratio above which a text is routed
                to the roman-script pipeline.
            roman_lid_threshold: FTR score that must be exceeded for an FTR
                prediction to be accepted without the BERT fallback.
        """
        import inspect

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.FTN = fasttext.load_model(FTN_PATH)
        self.FTR = fasttext.load_model(FTR_PATH)

        # The checkpoint is a whole pickled model object (it is called and
        # has ``.logits`` below), not a state_dict. torch>=2.6 defaults to
        # weights_only=True, which refuses to unpickle it, so opt out when the
        # keyword exists (it is absent on very old torch releases).
        # SECURITY: weights_only=False executes arbitrary pickle code; this is
        # only acceptable because BERT_PATH comes from ai4bharat/IndicLID-BERT.
        load_kwargs = {"map_location": self.device}
        if "weights_only" in inspect.signature(torch.load).parameters:
            load_kwargs["weights_only"] = False
        self.BERT = torch.load(BERT_PATH, **load_kwargs)
        self.BERT.eval()

        self.tokenizer = AutoTokenizer.from_pretrained("ai4bharat/IndicBERTv2-MLM-only")
        self.input_threshold = input_threshold
        self.model_threshold = roman_lid_threshold
        # Official label map: BERT class index -> language/script label.
        # NOTE(review): ISO 15924 code for Tamil is "Taml"; "tam_Tamil" is kept
        # as-is -- confirm against upstream IndicLID before changing it.
        self.label_map_reverse = {
            0: 'asm_Latn', 1: 'ben_Latn', 2: 'brx_Latn', 3: 'guj_Latn', 4: 'hin_Latn',
            5: 'kan_Latn', 6: 'kas_Latn', 7: 'kok_Latn', 8: 'mai_Latn', 9: 'mal_Latn',
            10: 'mni_Latn', 11: 'mar_Latn', 12: 'nep_Latn', 13: 'ori_Latn', 14: 'pan_Latn',
            15: 'san_Latn', 16: 'snd_Latn', 17: 'tam_Latn', 18: 'tel_Latn', 19: 'urd_Latn',
            20: 'eng_Latn', 21: 'other', 22: 'asm_Beng', 23: 'ben_Beng', 24: 'brx_Deva',
            25: 'doi_Deva', 26: 'guj_Gujr', 27: 'hin_Deva', 28: 'kan_Knda', 29: 'kas_Arab',
            30: 'kas_Deva', 31: 'kok_Deva', 32: 'mai_Deva', 33: 'mal_Mlym', 34: 'mni_Beng',
            35: 'mni_Meti', 36: 'mar_Deva', 37: 'nep_Deva', 38: 'ori_Orya', 39: 'pan_Guru',
            40: 'san_Deva', 41: 'sat_Olch', 42: 'snd_Arab', 43: 'tam_Tamil', 44: 'tel_Telu',
            45: 'urd_Arab',
        }

    def char_percent_check(self, text):
        """Return the fraction of alphabetic characters in *text* that are ASCII letters.

        Only characters with ``str.isalpha()`` true count toward the
        denominator (digits, punctuation, whitespace and combining marks are
        ignored). Returns 0 when *text* has no alphabetic characters.
        """
        alpha = [c for c in text if c.isalpha()]
        if not alpha:
            return 0
        # Among alphabetic characters, the ASCII ones are exactly [A-Za-z].
        return sum(c.isascii() for c in alpha) / len(alpha)

    def _fasttext_predict(self, model, data):
        """Yield ``(idx, text, label, score)`` for each ``(idx, text)`` in *data*.

        fastText's ``predict`` raises ValueError on input containing a newline,
        so newlines are replaced by spaces for the model only; the original
        text is yielded unchanged.
        """
        labels, scores = model.predict([txt.replace("\n", " ") for _, txt in data])
        for (idx, txt), lbls, scrs in zip(data, labels, scores):
            yield idx, txt, lbls[0][self._FT_LABEL_PREFIX_LEN:], float(scrs[0])

    def native_inference(self, data, out_dict):
        """Label native-script ``(idx, text)`` pairs with FTN, writing into *out_dict*.

        Returns *out_dict* (mutated in place).
        """
        if not data:
            return out_dict
        for idx, txt, label, score in self._fasttext_predict(self.FTN, data):
            out_dict[idx] = {"text": txt, "label": label, "score": score, "model": "FTN"}
        return out_dict

    def ftr_inference(self, data, out_dict, batch_size):
        """Label roman-script ``(idx, text)`` pairs with FTR, falling back to BERT.

        Predictions with score above ``self.model_threshold`` are accepted;
        the rest are sent to ``bert_inference`` in batches of *batch_size*.
        Returns *out_dict* (mutated in place).
        """
        if not data:
            return out_dict
        bert_inputs = []
        for idx, txt, label, score in self._fasttext_predict(self.FTR, data):
            if score > self.model_threshold:
                out_dict[idx] = {"text": txt, "label": label, "score": score, "model": "FTR"}
            else:
                bert_inputs.append((idx, txt))
        return self.bert_inference(bert_inputs, out_dict, batch_size)

    def bert_inference(self, data, out_dict, batch_size):
        """Label ``(idx, text)`` pairs with the BERT classifier, writing into *out_dict*.

        The score is the softmax probability of the predicted class.
        Returns *out_dict* (mutated in place).
        """
        if not data:
            return out_dict
        ds = IndicBERT_Data([x[0] for x in data], [x[1] for x in data])
        dl = DataLoader(ds, batch_size=batch_size)
        with torch.no_grad():
            for idxs, texts in dl:
                enc = self.tokenizer(list(texts), return_tensors="pt", padding=True,
                                     truncation=True, max_length=512).to(self.device)
                probs = torch.softmax(self.BERT(**enc).logits, dim=1)
                # max over softmax picks the same class as argmax over logits.
                scores, preds = probs.max(dim=1)
                for i, txt, score, label_idx in zip(
                    idxs.tolist(), texts, scores.tolist(), preds.tolist()
                ):
                    out_dict[i] = {
                        "text": txt,
                        "label": self.label_map_reverse[label_idx],
                        "score": score,
                        "model": "BERT",
                    }
        return out_dict

    def batch_predict(self, texts, batch_size=8):
        """Predict the language of every string in *texts*.

        Returns a list of prediction dicts in the same order as *texts*.
        """
        native, roman = [], []
        for i, t in enumerate(texts):
            if self.char_percent_check(t) > self.input_threshold:
                roman.append((i, t))
            else:
                native.append((i, t))
        out_dict = self.native_inference(native, {})
        out_dict = self.ftr_inference(roman, out_dict, batch_size)
        return [out_dict[i] for i in sorted(out_dict)]
# ----------------------------
# Gradio UI
# ----------------------------
# Single shared model instance, loaded once at import time.
lid_model = IndicLID()


def detect(text_block):
    """Identify the language of each non-blank line in *text_block*.

    Lines are stripped of surrounding whitespace; blank lines are skipped.
    Returns the list of prediction dicts from ``lid_model.batch_predict``
    (keys ``"text"``, ``"label"``, ``"score"``, ``"model"``), or an empty list
    when the input is empty, whitespace-only, or ``None``.
    """
    if not text_block:
        # Guard against an empty or missing (None) textbox value.
        return []
    # Strip each line once, keeping only lines with content.
    lines = [s for s in (line.strip() for line in text_block.splitlines()) if s]
    if not lines:
        return []
    return lid_model.batch_predict(lines)
# Build the UI. Components are laid out in the order they are created inside
# this Blocks context, so the statement order below is significant.
with gr.Blocks(title="IndicLID by AI4Bharat") as demo:
    gr.Markdown("## IndicLID (AI4Bharat) — Full Ensemble\nDetects Indian languages in native & roman scripts.")
    # Multi-line input; ``detect`` treats each non-blank line as a separate sample.
    inp = gr.Textbox(lines=8, label="Enter one sentence per line")
    # Displays the list of per-line prediction dicts returned by ``detect``.
    out = gr.JSON(label="Predictions")
    btn = gr.Button("Detect Language")
    # Clicking the button runs ``detect`` on the textbox contents.
    btn.click(fn=detect, inputs=inp, outputs=out)

if __name__ == "__main__":
    # Start the Gradio server only when this file is run as a script.
    demo.launch()
|