Jia0603 committed on
Commit
955c212
·
verified ·
1 Parent(s): 1fafd7d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -8
app.py CHANGED
@@ -1,16 +1,104 @@
 
 
 
 
 
 
1
  import gradio as gr
2
 
3
- def copy(src_sentence):
4
- return src_sentence
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- with gr.Blocks(title="Copier") as demo:
7
- gr.Markdown("# Copier")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  with gr.Row():
9
  with gr.Column():
10
- src_sentence = gr.Textbox(label="Source text", placeholder="Write your text...")
 
11
  with gr.Column():
12
- tgt_sentence = gr.Textbox(label="Copy", placeholder="Copy will show here...")
13
- btn = gr.Button("Copy!")
14
- btn.click(fn=copy, inputs=[src_sentence], outputs=[tgt_sentence])
 
15
 
16
  demo.launch()
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import hashlib
5
+ import joblib
6
+ from collections import Counter
7
  import gradio as gr
8
 
9
def ngrams(sentence, n=1, lc=True):
    """Return the overlapping character n-grams of ``sentence``.

    Args:
        sentence: Text to slice into windows of characters.
        n: Width of each window, in characters.
        lc: When true, lowercase the text before slicing.

    Returns:
        A list of every length-``n`` substring in left-to-right order. The
        list is empty when the (possibly lowercased) text is shorter than
        ``n``.
    """
    if lc:
        sentence = sentence.lower()
    # The length is taken after lowercasing, as lowercasing can change it.
    return [sentence[i:i + n] for i in range(len(sentence) - n + 1)]
18
+
19
def all_ngrams(sentence, max_ngram=3, lc=True):
    """Return the character n-grams of ``sentence`` for n = 1..``max_ngram``.

    Args:
        sentence: Text to extract n-grams from.
        max_ngram: Largest n-gram width to produce (inclusive).
        lc: Forwarded to ``ngrams``; lowercase the text when true.

    Returns:
        A list with one entry per width, in increasing order of n; entry
        ``k`` is the list returned by ``ngrams(sentence, n=k + 1, lc=lc)``.
    """
    return [ngrams(sentence, n=size, lc=lc) for size in range(1, max_ngram + 1)]
24
+
25
# Modulus applied to the hash codes of each n-gram order (single characters,
# bigrams, trigrams); it bounds how many distinct codes each order can yield.
# NOTE(review): presumably these must match the values used when the saved
# vectorizer and network were built -- confirm before changing them.
MAX_CHARS = 521
MAX_BIGRAMS = 1031
MAX_TRIGRAMS = 1031
MAXES = [MAX_CHARS, MAX_BIGRAMS, MAX_TRIGRAMS]
29
+
30
def reproducible_hash(string):
    """Return a deterministic signed 64-bit hash of ``string``.

    The value is the first 8 bytes of the MD5 digest of the UTF-8 encoded
    text, read as a big-endian signed integer, so it depends only on the
    text itself.

    Args:
        string: Text to hash.

    Returns:
        An int in the range ``[-2**63, 2**63)``.
    """
    # We are using MD5 for speed not security.
    digest = hashlib.md5(string.encode("utf-8"), usedforsecurity=False).digest()
    return int.from_bytes(digest[:8], "big", signed=True)
34
+
35
def hash_ngrams(ngrams, modulos):
    """Hash every n-gram and reduce it modulo its group's modulus.

    Args:
        ngrams: Iterable of n-gram groups (each an iterable of strings).
        modulos: Iterable of integer moduli, paired positionally with the
            groups. Pairing stops at the shorter of the two iterables.

    Returns:
        A list with one list of integer codes per paired group. Python's
        ``%`` yields a non-negative result for a positive modulus, so each
        code lies in ``[0, mod)`` even though the raw hash may be negative.
    """
    return [
        [reproducible_hash(gram) % mod for gram in grams]
        for grams, mod in zip(ngrams, modulos)
    ]
41
+
42
def calc_rel_freq(codes):
    """Return the relative frequency of each distinct item in ``codes``.

    Args:
        codes: Iterable of hashable items.

    Returns:
        A ``Counter`` mapping each distinct item to its count divided by the
        total number of items. It is empty when ``codes`` is empty.
    """
    counts = Counter(codes)
    total = sum(counts.values())
    # An empty input yields an empty mapping, so no division by zero occurs.
    return Counter({code: count / total for code, count in counts.items()})
48
 
49
# Offset for each n-gram order: the sum of the moduli of all preceding orders
# (the first offset is 0). Adding it to a group's hash codes gives each order
# its own, non-overlapping key range.
MAX_SHIFT = [sum(MAXES[:i]) for i in range(len(MAXES))]
52
+
53
def shift_keys(dicts, MAX_SHIFT):
    """Merge several mappings into one, offsetting the keys of each.

    Args:
        dicts: Iterable of mappings with numeric keys.
        MAX_SHIFT: Sequence of offsets; the keys of the i-th mapping are
            increased by ``MAX_SHIFT[i]``.

    Returns:
        A single dict containing every shifted key with its original value.
        If two shifted keys coincide, the later mapping's value wins.

    Raises:
        IndexError: If ``dicts`` has more mappings than ``MAX_SHIFT`` has
            offsets.
    """
    merged = {}
    for position, group in enumerate(dicts):
        offset = MAX_SHIFT[position]
        for key, value in group.items():
            merged[key + offset] = value
    return merged
59
+
60
def build_freq_dict(sentence, MAXES=MAXES, MAX_SHIFT=MAX_SHIFT):
    """Turn ``sentence`` into a dict of hashed n-gram relative frequencies.

    The text is split into character n-grams of each order, each n-gram is
    hashed and reduced modulo that order's entry in ``MAXES``, the codes are
    converted to relative frequencies per order, and the per-order mappings
    are merged after offsetting their keys by ``MAX_SHIFT``.

    Args:
        sentence: Text to featurize.
        MAXES: Per-order hash moduli; defaults to the module-level list.
        MAX_SHIFT: Per-order key offsets; defaults to the module-level list.

    Returns:
        A dict mapping shifted hash codes to relative frequencies. It is
        empty when ``sentence`` is empty.
    """
    hashed_groups = hash_ngrams(all_ngrams(sentence), MAXES)
    return shift_keys(map(calc_rel_freq, hashed_groups), MAX_SHIFT)
64
+
65
+ # --- load models ---
66
# NOTE(security): joblib.load and torch.load deserialize with pickle, which
# can execute arbitrary code. Only load artifacts shipped with this app.
# NOTE(review): `clf` is loaded but not used by the code visible in this
# file -- confirm whether it is still needed.
clf = joblib.load("nld.joblib")
vectorizer = joblib.load("nld_vectorizer.joblib")
idx2lang = joblib.load("nld_lang_codes.joblib")

# Layer sizes are derived from the loaded artifacts so the network matches
# the vectorizer's feature count and the number of language codes.
input_dim = len(vectorizer.vocabulary_)
nbr_classes = len(idx2lang)

model = nn.Sequential(
    nn.Linear(input_dim, 50),
    nn.ReLU(),
    nn.Linear(50, nbr_classes),
)
model.load_state_dict(torch.load("nld.pth", map_location="cpu"))
model.eval()
80
+
81
+ # --- prediction function ---
82
def detect_lang(src_sentence):
    """Predict the language of ``src_sentence``.

    Args:
        src_sentence: Text entered by the user.

    Returns:
        The entry of ``idx2lang`` for the highest-scoring class, or an empty
        string when the input is empty or only whitespace.
    """
    # Without text there are no n-grams, so the network would be fed an
    # all-zero vector and report an arbitrary language; return nothing instead.
    if not src_sentence or not src_sentence.strip():
        return ""

    # Pass a real list rather than a lazy map to the vectorizer.
    features = vectorizer.transform([build_freq_dict(src_sentence)])
    if hasattr(features, "toarray"):
        features = features.toarray()

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(torch.as_tensor(features, dtype=torch.float32))

    predicted = torch.argmax(logits, dim=-1).tolist()[0]
    return idx2lang.get(predicted)
90
+
91
+ # --- Gradio UI ---
92
# Two side-by-side text boxes (input on the left, prediction on the right)
# with a button that runs the detector on the input text.
with gr.Blocks(title="Antons language detector") as demo:
    gr.Markdown("# Antons language detector")
    with gr.Row():
        with gr.Column():
            src_sentence = gr.Textbox(
                label="Text", placeholder="Write your text...")
        with gr.Column():
            tgt_sentence = gr.Textbox(
                label="Language", placeholder="Language will show here...")
    btn = gr.Button("Guess the language!")
    btn.click(fn=detect_lang, inputs=[src_sentence], outputs=[tgt_sentence])

# Start the server only when run as a script, so importing this module
# (e.g. from a test) does not launch the app.
if __name__ == "__main__":
    demo.launch()