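# Web file-viewer residue captured along with this file (not part of the program):
#   "Zhe-Zhang's picture" / "Update app.py" / "666e224 verified" / "raw" /
#   "history blame" / "2.71 kB"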
import torch
import torch.nn as nn
import joblib
import hashlib
from collections import Counter
import numpy as np
import gradio as gr
# --- utils ---
def ngrams(sentence, n=1, lc=True):
    """Return every contiguous character n-gram of ``sentence``.

    Args:
        sentence: Text to slice into substrings.
        n: Length of each n-gram, in characters.
        lc: When true (the default), lower-case ``sentence`` before slicing.

    Returns:
        The substrings of length ``n`` in left-to-right order; an empty list
        when ``sentence`` has fewer than ``n`` characters.
    """
    # Honour the flag: the text used to be lower-cased unconditionally, which
    # made ``lc=False`` a silent no-op.
    if lc:
        sentence = sentence.lower()
    return [sentence[i:i + n] for i in range(len(sentence) - n + 1)]
def all_ngrams(sentence, max_ngram=3, lc=True):
    """Collect the n-gram lists of ``sentence`` for n = 1 .. ``max_ngram``.

    Args:
        sentence: Text to slice.
        max_ngram: Largest n-gram length to produce.
        lc: Forwarded unchanged to ``ngrams``.

    Returns:
        A list of lists; the element at index ``k`` is the result of
        ``ngrams`` for length ``k + 1``.
    """
    return [
        ngrams(sentence, n=size, lc=lc)
        for size in range(1, max_ngram + 1)
    ]
# Hash-bucket counts, one per n-gram order (index 0 = single characters,
# 1 = bigrams, 2 = trigrams); build_freq_dict passes MAXES as the modulos
# used when hashing each order's n-grams.
MAX_CHARS, MAX_BIGRAMS, MAX_TRIGRAMS = 521, 1031, 1031
MAXES = [MAX_CHARS, MAX_BIGRAMS, MAX_TRIGRAMS]
def reproducible_hash(string):
    """Map ``string`` to a signed 64-bit integer derived from its MD5 digest.

    Unlike the built-in ``hash()``, whose results for ``str`` are salted per
    process, this value is the same on every run.

    Args:
        string: Text to hash; it is encoded as UTF-8 first.

    Returns:
        The first eight digest bytes read as a big-endian signed integer.
    """
    # MD5 is used only for bucketing, hence usedforsecurity=False.
    digest = hashlib.md5(string.encode("utf-8"), usedforsecurity=False).digest()
    return int.from_bytes(digest[:8], byteorder="big", signed=True)
def hash_ngrams(ngrams, modulos):
    """Hash every n-gram into a bucket index, one modulus per n-gram list.

    Args:
        ngrams: Sequence of n-gram lists (one list per n-gram order).
        modulos: Bucket counts, paired positionally with ``ngrams``; pairing
            stops at the shorter of the two sequences.

    Returns:
        A list of lists of bucket codes, each code being
        ``reproducible_hash(ngram) % modulo``.
    """
    return [
        [reproducible_hash(gram) % bucket_count for gram in grams]
        for grams, bucket_count in zip(ngrams, modulos)
    ]
def calc_rel_freq(codes):
    """Return the relative frequency of each distinct value in ``codes``.

    Args:
        codes: Iterable of hashable values.

    Returns:
        A dict mapping each distinct value to its count divided by the total
        number of values; an empty dict when ``codes`` is empty.
    """
    counts = Counter(codes)
    denominator = sum(counts.values())
    if not denominator:
        # Nothing was counted; keep the divisor non-zero.
        denominator = 1
    return {code: count / denominator for code, count in counts.items()}
# Starting offset of each n-gram order inside the combined feature index
# space: the first order starts at 0 and every later order starts at the sum
# of the bucket counts of the orders before it.  Built with a comprehension so
# no loop variable is left behind as a module-level name.
MAX_SHIFT = [0] + [sum(MAXES[:order]) for order in range(1, len(MAXES))]
def shift_keys(dicts, shift_list):
    """Merge several integer-keyed dicts into one, offsetting each one's keys.

    Args:
        dicts: Sequence of dicts with integer keys.
        shift_list: Offsets; ``shift_list[i]`` is added to every key of
            ``dicts[i]``.

    Returns:
        A single dict holding all shifted entries.  If two shifted keys
        collide, the entry from the later dict wins.
    """
    return {
        key + shift_list[position]: value
        for position, mapping in enumerate(dicts)
        for key, value in mapping.items()
    }
def build_freq_dict(sentence):
    """Turn ``sentence`` into a sparse feature dict of hashed n-gram frequencies.

    The sentence is split into n-grams with ``all_ngrams``, each order is
    hashed with its bucket count from ``MAXES``, relative frequencies are
    computed per order, and the per-order dicts are merged with the offsets in
    ``MAX_SHIFT``.

    Args:
        sentence: Text to featurize.

    Returns:
        A dict mapping shifted bucket index to relative frequency.
    """
    codes_per_order = hash_ngrams(all_ngrams(sentence), MAXES)
    freqs_per_order = [calc_rel_freq(codes) for codes in codes_per_order]
    return shift_keys(freqs_per_order, MAX_SHIFT)
# --- load artifacts ---
# NOTE(review): joblib.load and torch.load (when not restricted to weights
# only) deserialize pickled data, which can execute arbitrary code; only ship
# artifact files that come from a trusted source.
# Feature vectorizer; this file uses only its `vocabulary_` attribute and its
# `transform` method (presumably a fitted scikit-learn DictVectorizer — TODO
# confirm).
vectorizer = joblib.load("nld_vectorizer.joblib")
# Lookup from predicted class index to the label returned by detect_lang.
idx2lang = joblib.load("nld_lang_codes.joblib")
input_dim = len(vectorizer.vocabulary_)  # one input unit per vocabulary entry
num_classes = len(idx2lang)  # one output logit per entry in idx2lang
# The layer layout has to match the state dict loaded below.
model = nn.Sequential(
    nn.Linear(input_dim, 50),
    nn.ReLU(),
    nn.Linear(50, num_classes)
)
# map_location="cpu" places the loaded tensors on the CPU.
model.load_state_dict(torch.load("nld (1).pth", map_location="cpu"))
model.eval()  # put the module in evaluation mode before serving predictions
# --- prediction ---
def detect_lang(text: str):
    """Predict the language of ``text`` with the module-level model.

    The text is converted to a hashed n-gram frequency dict, vectorized with
    ``vectorizer``, run through ``model`` without gradient tracking, and the
    index of the largest logit is looked up in ``idx2lang``.

    Args:
        text: Text typed by the user.

    Returns:
        The ``idx2lang`` entry for the highest-scoring class, or an empty
        string when ``text`` is ``None``, empty, or only whitespace.
    """
    # Without this guard ``None`` crashes while building n-grams, and blank
    # input carries no language signal, so any prediction would be arbitrary.
    if text is None or not text.strip():
        return ""

    feat_dict = build_freq_dict(text)
    X = vectorizer.transform([feat_dict])
    # Densify when the vectorizer hands back a sparse matrix.
    if hasattr(X, "toarray"):
        X = X.toarray()
    X = torch.from_numpy(X.astype("float32"))
    with torch.no_grad():
        logits = model(X)
    pred_idx = torch.argmax(logits, dim=-1).item()
    return idx2lang[pred_idx]
# --- UI ---
# NOTE(review): the nesting below is reconstructed from a copy of this file
# whose indentation was lost — confirm the layout against the running app.
with gr.Blocks(title="Language Detector") as demo:
    gr.Markdown("# Language Detector")
    with gr.Row():
        # First column: the input box and the button that triggers detection.
        with gr.Column():
            src_text = gr.Textbox(label="Enter text", placeholder="Type here...")
            btn = gr.Button("Detect Language")
        # Second column: non-interactive box that displays the prediction.
        with gr.Column():
            out_lang = gr.Textbox(label="Predicted language", interactive=False)
    # On click, pass the input box's value to detect_lang and show its result.
    btn.click(fn=detect_lang, inputs=src_text, outputs=out_lang)
# Runs at import time: start serving the interface.
demo.launch()