Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import joblib | |
| import hashlib | |
| from collections import Counter | |
| import numpy as np | |
| import gradio as gr | |
# --- utils ---
def ngrams(sentence: str, n: int = 1, lc: bool = True) -> list[str]:
    """Return all character n-grams of *sentence*, in order.

    Args:
        sentence: Input text.
        n: Length of each n-gram.
        lc: When True (default), lowercase the text first.
            BUG FIX: the original lowercased unconditionally and ignored
            this flag, so ``lc=False`` had no effect.

    Returns:
        List of length-``n`` substrings; empty when ``len(sentence) < n``.
    """
    if lc:
        sentence = sentence.lower()
    return [sentence[i:i + n] for i in range(len(sentence) - n + 1)]
def all_ngrams(sentence, max_ngram=3, lc=True):
    """Collect character n-gram lists for every order 1..max_ngram.

    Returns a list of lists: element ``i - 1`` holds the i-grams of
    *sentence* (the ``lc`` flag is forwarded to ``ngrams``).
    """
    return [ngrams(sentence, n=order, lc=lc) for order in range(1, max_ngram + 1)]
# Hash-bucket counts (modulos) for each n-gram order. 521 and 1031 are prime,
# which spreads hashed codes across buckets; presumably these match the sizes
# used when the vectorizer was fitted offline — TODO confirm against training code.
MAX_CHARS = 521      # buckets for 1-grams (single characters)
MAX_BIGRAMS = 1031   # buckets for 2-grams
MAX_TRIGRAMS = 1031  # buckets for 3-grams
MAXES = [MAX_CHARS, MAX_BIGRAMS, MAX_TRIGRAMS]
def reproducible_hash(string):
    """Deterministic 64-bit signed hash of *string*.

    Stable across processes and runs, unlike builtin ``hash()`` which is
    randomized per interpreter start. Uses MD5 purely as a spreading
    function (``usedforsecurity=False``), taking the first 8 digest bytes
    as a big-endian signed integer.
    """
    digest = hashlib.md5(string.encode("utf-8"), usedforsecurity=False).digest()
    return int.from_bytes(digest[:8], "big", signed=True)
def hash_ngrams(ngrams, modulos):
    """Map each n-gram list to bucket IDs via ``reproducible_hash`` mod table size.

    Args:
        ngrams: One list of n-gram strings per order (as from ``all_ngrams``).
        modulos: Parallel list of bucket counts (``MAXES``).

    Returns:
        One list of integer bucket codes per order.
    """
    return [
        [reproducible_hash(gram) % modulo for gram in gram_list]
        for gram_list, modulo in zip(ngrams, modulos)
    ]
def calc_rel_freq(codes):
    """Return ``{code: relative frequency}`` for the items in *codes*.

    Empty input yields an empty dict; the ``or 1`` keeps the divisor
    nonzero in that case.
    """
    counts = Counter(codes)
    total = sum(counts.values()) or 1
    freqs = {}
    for code, occurrences in counts.items():
        freqs[code] = occurrences / total
    return freqs
# Cumulative offsets so each n-gram order's codes occupy a disjoint feature-ID
# range: [0, MAXES[0], MAXES[0] + MAXES[1], ...] (running sum of bucket counts).
MAX_SHIFT = [0]
_running = 0
for _bucket_count in MAXES[:-1]:
    _running += _bucket_count
    MAX_SHIFT.append(_running)
def shift_keys(dicts, shift_list):
    """Merge per-order frequency dicts into one, offsetting each dict's keys.

    Args:
        dicts: Sequence of ``{code: freq}`` dicts, one per n-gram order.
        shift_list: Parallel offsets (``MAX_SHIFT``) added to every key so
            the orders land in disjoint global feature-ID ranges.

    Returns:
        A single merged dict keyed by shifted feature ID.
    """
    merged = {}
    for mapping, offset in zip(dicts, shift_list):
        for key, value in mapping.items():
            merged[key + offset] = value
    return merged
def build_freq_dict(sentence):
    """Featurize *sentence* into a sparse ``{feature_id: relative frequency}`` dict.

    Pipeline: char 1-3 grams -> hashed bucket codes (mod ``MAXES``) ->
    per-order relative frequencies -> keys shifted by ``MAX_SHIFT`` into
    disjoint global ID ranges.
    """
    hashed = hash_ngrams(all_ngrams(sentence), MAXES)
    per_order_freqs = [calc_rel_freq(codes) for codes in hashed]
    return shift_keys(per_order_freqs, MAX_SHIFT)
# --- load artifacts ---
# Feature vectorizer fitted offline (exposes .vocabulary_ and .transform on
# dicts — presumably an sklearn DictVectorizer; TODO confirm) and the
# model-index -> language-code mapping, both serialized with joblib.
vectorizer = joblib.load("nld_vectorizer.joblib")
idx2lang = joblib.load("nld_lang_codes.joblib")
input_dim = len(vectorizer.vocabulary_)  # one input unit per vectorizer feature
num_classes = len(idx2lang)              # one logit per known language
# Small MLP classifier: features -> 50 hidden units (ReLU) -> language logits.
model = nn.Sequential(
    nn.Linear(input_dim, 50),
    nn.ReLU(),
    nn.Linear(50, num_classes)
)
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted
# checkpoint files. CPU map_location keeps inference GPU-free.
model.load_state_dict(torch.load("nld (1).pth", map_location="cpu"))
model.eval()  # inference mode: fixed behavior for any dropout/norm layers
# --- prediction ---
def detect_lang(text: str) -> str:
    """Predict the language of *text* with the hashed n-gram MLP classifier.

    Args:
        text: Raw input string; featurized via character 1-3 gram hashing
            (``build_freq_dict``) before classification.

    Returns:
        The entry of ``idx2lang`` for the highest-scoring class, or an empty
        string for empty/whitespace-only input — previously such input
        produced an all-zero feature vector and an arbitrary prediction.
    """
    if not text or not text.strip():
        return ""
    feat_dict = build_freq_dict(text)
    X = vectorizer.transform([feat_dict])
    # The vectorizer may return a scipy sparse matrix; densify for torch.
    if hasattr(X, "toarray"):
        X = X.toarray()
    X = torch.from_numpy(X.astype("float32"))
    with torch.no_grad():  # inference only — no autograd bookkeeping
        logits = model(X)
    pred_idx = torch.argmax(logits, dim=-1).item()
    return idx2lang[pred_idx]
# --- UI ---
# Two-column Gradio layout: text input + button on the left, read-only
# prediction box on the right.
with gr.Blocks(title="Language Detector") as demo:
    gr.Markdown("# Language Detector")
    with gr.Row():
        with gr.Column():
            src_text = gr.Textbox(label="Enter text", placeholder="Type here...")
            btn = gr.Button("Detect Language")
        with gr.Column():
            out_lang = gr.Textbox(label="Predicted language", interactive=False)
    # Button click runs detect_lang on the input and writes the result box.
    btn.click(fn=detect_lang, inputs=src_text, outputs=out_lang)
demo.launch()