Zhe-Zhang commited on
Commit
51a3f9c
·
verified ·
1 Parent(s): c0a30f7

upload 5 files

Browse files
Files changed (1) hide show
  1. app.py +102 -0
app.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import hashlib
5
+ import joblib
6
+ from collections import Counter
7
+ import gradio as gr
8
+
9
+ # --- utils (from the notebook) ---
10
+ def ngrams(sentence, n=1, lc=True):
11
+ ngram_l = []
12
+ sentence = sentence.lower()
13
+ for i in range(len(sentence) - n + 1):
14
+ ngram = sentence[i:i+n]
15
+ ngram_l.append(ngram)
16
+ return ngram_l
17
+
18
+ def all_ngrams(sentence, max_ngram=3, lc=True):
19
+ all_ngram_list = []
20
+ for i in range(1, max_ngram + 1):
21
+ all_ngram_list += [ngrams(sentence, n=i, lc=lc)]
22
+ return all_ngram_list
23
+
24
+ MAX_CHARS = 521
25
+ MAX_BIGRAMS = 1031
26
+ MAX_TRIGRAMS = 1031
27
+ MAXES = [MAX_CHARS, MAX_BIGRAMS, MAX_TRIGRAMS]
28
+
29
+ def reproducible_hash(string):
30
+ h = hashlib.md5(string.encode("utf-8"), usedforsecurity=False)
31
+ return int.from_bytes(h.digest()[0:8], 'big', signed=True)
32
+
33
+ def hash_ngrams(ngrams, modulos):
34
+ hash_codes = []
35
+ for ngram_list, modulo in zip(ngrams, modulos):
36
+ codes = [(reproducible_hash(x) % modulo) for x in ngram_list]
37
+ hash_codes.append(codes)
38
+ return hash_codes
39
+
40
+ def calc_rel_freq(codes):
41
+ cnt = Counter(codes)
42
+ total = sum(cnt.values())
43
+ for k in cnt:
44
+ cnt[k] /= total
45
+ return cnt
46
+
47
+ MAX_SHIFT = []
48
+ for i in range(len(MAXES)):
49
+ MAX_SHIFT += [sum(MAXES[:i])]
50
+
51
+ def shift_keys(dicts, MAX_SHIFT):
52
+ new_dict = {}
53
+ for i, ngrams_d in enumerate(dicts):
54
+ for k, v in ngrams_d.items():
55
+ new_dict[k + MAX_SHIFT[i]] = v
56
+ return new_dict
57
+
58
+ def build_freq_dict(sentence, MAXES=MAXES, MAX_SHIFT=MAX_SHIFT):
59
+ hngrams = hash_ngrams(all_ngrams(sentence), MAXES)
60
+ fhcodes = map(calc_rel_freq, hngrams)
61
+ return shift_keys(fhcodes, MAX_SHIFT)
62
+
63
+ # --- load models ---
64
+ clf = joblib.load("nld.joblib")
65
+ vectorizer = joblib.load("nld_vectorizer.joblib")
66
+ idx2lang = joblib.load("nld_lang_codes.joblib")
67
+
68
+ input_dim = len(vectorizer.vocabulary_)
69
+ nbr_classes = len(idx2lang)
70
+
71
+ model = nn.Sequential(
72
+ nn.Linear(input_dim, 50),
73
+ nn.ReLU(),
74
+ nn.Linear(50, nbr_classes)
75
+ )
76
+ model.load_state_dict(torch.load("nld.pth", map_location="cpu"))
77
+ model.eval()
78
+
79
+ # --- prediction function ---
80
+ def detect_lang(src_sentence):
81
+ src_sentence = [src_sentence]
82
+ X_test = vectorizer.transform(map(build_freq_dict, src_sentence))
83
+ if hasattr(X_test, "toarray"):
84
+ X_test = X_test.toarray()
85
+ Y_logits = model(torch.Tensor(X_test))
86
+ pred_languages = torch.argmax(Y_logits, dim=-1).tolist()
87
+ return list(map(idx2lang.get, pred_languages))[0]
88
+
89
+ # --- Gradio UI ---
90
+ with gr.Blocks(title="Antons language detector") as demo:
91
+ gr.Markdown("# Antons language detector")
92
+ with gr.Row():
93
+ with gr.Column():
94
+ src_sentence = gr.Textbox(
95
+ label="Text", placeholder="Write your text...")
96
+ with gr.Column():
97
+ tgt_sentence = gr.Textbox(
98
+ label="Language", placeholder="Language will show here...")
99
+ btn = gr.Button("Guess the language!")
100
+ btn.click(fn=detect_lang, inputs=[src_sentence], outputs=[tgt_sentence])
101
+
102
+ demo.launch()