File size: 2,706 Bytes
c2ae6a3
 
 
50bcdf5
c2ae6a3
50bcdf5
c2ae6a3
 
50bcdf5
c2ae6a3
 
50bcdf5
c2ae6a3
 
50bcdf5
c2ae6a3
50bcdf5
 
c2ae6a3
 
 
 
 
 
 
 
50bcdf5
c2ae6a3
b8828ec
50bcdf5
b8828ec
c2ae6a3
50bcdf5
 
c2ae6a3
 
 
50bcdf5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2ae6a3
50bcdf5
 
c2ae6a3
50bcdf5
b8828ec
 
c2ae6a3
 
50bcdf5
c2ae6a3
 
 
 
50bcdf5
c2ae6a3
666e224
c2ae6a3
 
50bcdf5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2ae6a3
b8828ec
50bcdf5
 
b8828ec
50bcdf5
 
b8828ec
50bcdf5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
import torch.nn as nn
import joblib
import hashlib
from collections import Counter
import numpy as np
import gradio as gr

# --- utils ---
def ngrams(sentence, n=1, lc=True):
    """Return the list of character n-grams of *sentence*.

    Args:
        sentence: Input string.
        n: Length of each n-gram (default 1 = single characters).
        lc: If True (default), lowercase the sentence before extracting
            n-grams.

    Returns:
        List of overlapping substrings of length ``n``; empty when the
        sentence is shorter than ``n``.
    """
    # Bug fix: the original lowercased unconditionally, ignoring `lc`.
    # Default behavior (lc=True) is unchanged.
    if lc:
        sentence = sentence.lower()
    return [sentence[i:i + n] for i in range(len(sentence) - n + 1)]

def all_ngrams(sentence, max_ngram=3, lc=True):
    """Collect character n-gram lists for every order 1..max_ngram.

    Returns a list of ``max_ngram`` lists: [unigrams, bigrams, ...].
    """
    return [ngrams(sentence, n=order, lc=lc) for order in range(1, max_ngram + 1)]

# Hash-bucket counts (modulos) per n-gram order: unigrams, bigrams,
# trigrams.  521 and 1031 are primes — presumably chosen so the modulo
# spreads hash codes evenly (TODO confirm intent).  These must match the
# values used when the vectorizer was fitted.
MAX_CHARS = 521
MAX_BIGRAMS = 1031
MAX_TRIGRAMS = 1031
MAXES = [MAX_CHARS, MAX_BIGRAMS, MAX_TRIGRAMS]

def reproducible_hash(string):
    """Deterministic signed 64-bit hash of *string*.

    Stable across interpreter runs (unlike builtin ``hash``, which is
    randomized by PYTHONHASHSEED).  MD5 is used purely as a mixing
    function, not for security (hence ``usedforsecurity=False``).
    """
    digest = hashlib.md5(string.encode("utf-8"), usedforsecurity=False).digest()
    return int.from_bytes(digest[:8], byteorder="big", signed=True)

def hash_ngrams(ngrams, modulos):
    """Map each per-order n-gram list to bucket codes.

    Each n-gram is hashed with ``reproducible_hash`` and reduced modulo
    the bucket count for its order.  Returns one list of codes per
    input n-gram list.
    """
    return [
        [reproducible_hash(gram) % modulo for gram in gram_list]
        for gram_list, modulo in zip(ngrams, modulos)
    ]

def calc_rel_freq(codes):
    """Return a dict mapping each distinct code to its relative frequency.

    An empty input yields an empty dict; the denominator guard avoids
    division by zero in that case.
    """
    tally = Counter(codes)
    denominator = sum(tally.values())
    if not denominator:
        denominator = 1
    return {code: count / denominator for code, count in tally.items()}

# Key offsets giving each n-gram order its own disjoint id range:
# [0, MAX_CHARS, MAX_CHARS + MAX_BIGRAMS].
MAX_SHIFT = [sum(MAXES[:order]) for order in range(len(MAXES))]

def shift_keys(dicts, shift_list):
    """Merge per-order frequency dicts into one, offsetting keys.

    The i-th dict's keys are shifted by ``shift_list[i]`` so that codes
    from different n-gram orders cannot collide in the merged dict.
    """
    merged = {}
    for index, freq_dict in enumerate(dicts):
        offset = shift_list[index]
        merged |= {key + offset: value for key, value in freq_dict.items()}
    return merged

def build_freq_dict(sentence):
    """Build the sparse feature dict for *sentence*.

    Pipeline: extract 1–3-gram lists, hash them into per-order bucket
    codes, convert each code list to relative frequencies, then merge
    the per-order dicts with disjoint key offsets.
    """
    bucket_codes = hash_ngrams(all_ngrams(sentence), MAXES)
    per_order_freqs = [calc_rel_freq(codes) for codes in bucket_codes]
    return shift_keys(per_order_freqs, MAX_SHIFT)

# --- load artifacts ---
# Fitted feature vectorizer (has a `vocabulary_` mapping and `transform`;
# presumably sklearn's DictVectorizer — TODO confirm) and the
# index -> language-code mapping used to decode predictions.
vectorizer = joblib.load("nld_vectorizer.joblib")
idx2lang = joblib.load("nld_lang_codes.joblib")

# Network dimensions are derived from the artifacts so they stay in sync
# with whatever the model was trained on.
input_dim = len(vectorizer.vocabulary_)
num_classes = len(idx2lang)

# One-hidden-layer MLP classifier.  The architecture must match the
# checkpoint exactly or load_state_dict will fail.
model = nn.Sequential(
    nn.Linear(input_dim, 50),
    nn.ReLU(),
    nn.Linear(50, num_classes)
)
# NOTE(review): "nld (1).pth" looks like a browser-download duplicate
# filename — consider renaming the checkpoint artifact.
model.load_state_dict(torch.load("nld (1).pth", map_location="cpu"))
model.eval()  # inference mode: disables training-only layer behavior

# --- prediction ---
def detect_lang(text: str):
    """Predict the language code for *text* with the loaded MLP.

    Builds the hashed n-gram feature dict, vectorizes it, and returns
    the label from ``idx2lang`` with the highest logit.
    """
    features = build_freq_dict(text)
    matrix = vectorizer.transform([features])
    # A sparse matrix must be densified before becoming a tensor.
    if hasattr(matrix, "toarray"):
        matrix = matrix.toarray()
    batch = torch.from_numpy(matrix.astype("float32"))

    with torch.no_grad():
        logits = model(batch)
    best_idx = int(torch.argmax(logits, dim=-1))
    return idx2lang[best_idx]

# --- UI ---
# Two-column Gradio app: text input and trigger button on the left,
# read-only prediction box on the right.
with gr.Blocks(title="Language Detector") as demo:
    gr.Markdown("# Language Detector")
    with gr.Row():
        with gr.Column():
            src_text = gr.Textbox(label="Enter text", placeholder="Type here...")
            btn = gr.Button("Detect Language")
        with gr.Column():
            out_lang = gr.Textbox(label="Predicted language", interactive=False)
    # Clicking the button runs detect_lang on the input textbox and
    # writes the result into the output textbox.
    btn.click(fn=detect_lang, inputs=src_text, outputs=out_lang)

# Starts the Gradio server; blocks at module import/run time.
demo.launch()