File size: 3,042 Bytes
955c212
 
 
 
 
 
325ea21
 
955c212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
325ea21
955c212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
749f9f2
 
325ea21
 
955c212
 
325ea21
955c212
 
 
 
325ea21
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import numpy as np
import torch
import torch.nn as nn
import hashlib
import joblib
from collections import Counter
import gradio as gr

def ngrams(sentence, n=1, lc=True):
    """Return the list of character n-grams of *sentence*.

    Args:
        sentence: input string.
        n: gram length in characters (default 1, i.e. single chars).
        lc: lowercase the sentence first when True.

    Returns:
        List of overlapping substrings of length n; empty when
        len(sentence) < n.
    """
    if lc:
        sentence = sentence.lower()
    # Slide a window of width n over the string; the range is empty when
    # the sentence is shorter than n, so no special-casing is needed.
    return [sentence[i:i + n] for i in range(len(sentence) - n + 1)]

def all_ngrams(sentence, max_ngram=3, lc=True):
    """Return [unigrams, bigrams, ..., max_ngram-grams] for *sentence*.

    Args:
        sentence: input string.
        max_ngram: highest gram order to produce (default 3).
        lc: lowercase before extracting grams (passed to ngrams()).

    Returns:
        A list of max_ngram lists, one per gram order, in ascending order.
    """
    return [ngrams(sentence, n=i, lc=lc) for i in range(1, max_ngram + 1)]

# Hash-bucket counts (feature-space size) per n-gram order: unigrams,
# bigrams, trigrams.
# NOTE(review): 521 and 1031 are prime — presumably chosen so the modulo
# spreads hash codes evenly across buckets; these must match the sizes
# the vectorizer/model were trained with — confirm before changing.
MAX_CHARS = 521
MAX_BIGRAMS = 1031
MAX_TRIGRAMS = 1031
MAXES = [MAX_CHARS, MAX_BIGRAMS, MAX_TRIGRAMS]

def reproducible_hash(string):
    """Hash *string* to a stable signed 64-bit integer.

    MD5 is used purely as a fast, stable digest — not for security — so
    the value is reproducible across processes and Python versions,
    unlike the builtin hash() which is salted per process.
    """
    digest = hashlib.md5(string.encode("utf-8"), usedforsecurity=False).digest()
    # Fold the first 8 digest bytes into one signed big-endian integer.
    return int.from_bytes(digest[:8], byteorder='big', signed=True)

def hash_ngrams(ngrams, modulos):
    """Map each per-order n-gram list to hash-bucket indices.

    Args:
        ngrams: list of gram lists, one per n-gram order.
        modulos: parallel list of bucket counts, one modulus per order.

    Returns:
        A list of lists of bucket indices in [0, mod) for each order.
    """
    return [
        [reproducible_hash(gram) % mod for gram in grams]
        for grams, mod in zip(ngrams, modulos)
    ]

def calc_rel_freq(codes):
    """Return a Counter mapping each code to its relative frequency.

    Args:
        codes: iterable of hashable items (hash-bucket indices here).

    Returns:
        Counter whose values sum to 1.0 (empty Counter for empty input).
    """
    counts = Counter(codes)
    total = sum(counts.values())
    # Empty input yields an empty Counter: the comprehension body never
    # runs, so total == 0 causes no division error.
    return Counter({code: occurrences / total for code, occurrences in counts.items()})
    
# Key offsets per n-gram order: exclusive prefix sums of MAXES,
# i.e. [0, 521, 1552], so each order occupies a disjoint index range
# when the per-order frequency dicts are merged by shift_keys().
MAX_SHIFT = []
for i in range(len(MAXES)):
    MAX_SHIFT += [sum(MAXES[:i])]
    
def shift_keys(dicts, MAX_SHIFT):
    """Merge per-order frequency dicts into one dict with disjoint keys.

    The i-th dict's keys are offset by MAX_SHIFT[i] so that bucket ids
    from different n-gram orders cannot collide in the merged result.

    Args:
        dicts: iterable of {bucket_id: value} mappings, one per order.
        MAX_SHIFT: parallel list of integer offsets.

    Returns:
        A single {shifted_bucket_id: value} dict.
    """
    return {
        key + MAX_SHIFT[order]: value
        for order, freqs in enumerate(dicts)
        for key, value in freqs.items()
    }

def build_freq_dict(sentence, MAXES=MAXES, MAX_SHIFT=MAX_SHIFT):
    """Turn *sentence* into a sparse {feature_index: rel_freq} dict.

    Pipeline: char 1-3 grams -> hashed bucket ids (mod MAXES) ->
    per-order relative frequencies -> one merged dict with the keys of
    each order offset by MAX_SHIFT so they occupy disjoint ranges.
    """
    bucket_ids = hash_ngrams(all_ngrams(sentence), MAXES)
    rel_freqs = [calc_rel_freq(ids) for ids in bucket_ids]
    return shift_keys(rel_freqs, MAX_SHIFT)

# --- load models ---
# NOTE(review): joblib.load unpickles arbitrary objects and can execute
# code — only load trusted artifact files.
clf = joblib.load("nld.joblib")  # presumably a fitted sklearn classifier; not used below — verify it is still needed
vectorizer = joblib.load("nld_vectorizer.joblib")  # exposes .vocabulary_ and .transform (DictVectorizer-style)
idx2lang = joblib.load("nld_lang_codes.joblib")  # maps class index -> language label

# Network dimensions are derived from the artifacts so they match the
# checkpoint in nld.pth.
input_dim = len(vectorizer.vocabulary_)
nbr_classes = len(idx2lang)

# One-hidden-layer MLP over the hashed n-gram frequency features.
model = nn.Sequential(
    nn.Linear(input_dim, 50),
    nn.ReLU(),
    nn.Linear(50, nbr_classes)
)
model.load_state_dict(torch.load("nld.pth", map_location="cpu"))
model.eval()  # inference mode (disables dropout/batch-norm updates)

# --- prediction function ---
def detect_lang(src_sentence):
    """Predict the language label of a single input string.

    Args:
        src_sentence: the text to classify.

    Returns:
        The language label from idx2lang for the argmax class (None if
        the predicted index is missing from idx2lang).
    """
    features = [build_freq_dict(src_sentence)]
    X_test = vectorizer.transform(features)
    # Some vectorizers return a scipy sparse matrix; densify for torch.
    if hasattr(X_test, "toarray"):
        X_test = X_test.toarray()
    # Inference only: skip autograd bookkeeping on the forward pass.
    with torch.no_grad():
        Y_logits = model(torch.Tensor(X_test))
    pred_idx = torch.argmax(Y_logits, dim=-1).tolist()[0]
    return idx2lang.get(pred_idx)
   
# --- Gradio UI ---
# Two-column layout: text input on the left, predicted language on the
# right, with a button that triggers detect_lang.
with gr.Blocks(title="A language detector") as demo:
    gr.Markdown("# A language detector")
    with gr.Row():
        with gr.Column():
            # Free-form text to classify.
            src_sentence = gr.Textbox(
                label="Text", placeholder="Write your text...")
        with gr.Column():
            # Read-only in practice: filled by detect_lang on click.
            tgt_sentence = gr.Textbox(
                label="Language", placeholder="Language will show here...")
    btn = gr.Button("Guess the language!")
    btn.click(fn=detect_lang, inputs=[src_sentence], outputs=[tgt_sentence])

demo.launch()