Zhe-Zhang committed on
Commit
c2ae6a3
·
verified ·
1 Parent(s): 26a8e26

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -101
app.py CHANGED
@@ -1,102 +1,102 @@
1
- import numpy as np
2
- import torch
3
- import torch.nn as nn
4
- import hashlib
5
- import joblib
6
- from collections import Counter
7
- import gradio as gr
8
-
9
- # --- utils (from the notebook) ---
10
- def ngrams(sentence, n=1, lc=True):
11
- ngram_l = []
12
- sentence = sentence.lower()
13
- for i in range(len(sentence) - n + 1):
14
- ngram = sentence[i:i+n]
15
- ngram_l.append(ngram)
16
- return ngram_l
17
-
18
- def all_ngrams(sentence, max_ngram=3, lc=True):
19
- all_ngram_list = []
20
- for i in range(1, max_ngram + 1):
21
- all_ngram_list += [ngrams(sentence, n=i, lc=lc)]
22
- return all_ngram_list
23
-
24
- MAX_CHARS = 521
25
- MAX_BIGRAMS = 1031
26
- MAX_TRIGRAMS = 1031
27
- MAXES = [MAX_CHARS, MAX_BIGRAMS, MAX_TRIGRAMS]
28
-
29
- def reproducible_hash(string):
30
- h = hashlib.md5(string.encode("utf-8"), usedforsecurity=False)
31
- return int.from_bytes(h.digest()[0:8], 'big', signed=True)
32
-
33
- def hash_ngrams(ngrams, modulos):
34
- hash_codes = []
35
- for ngram_list, modulo in zip(ngrams, modulos):
36
- codes = [(reproducible_hash(x) % modulo) for x in ngram_list]
37
- hash_codes.append(codes)
38
- return hash_codes
39
-
40
- def calc_rel_freq(codes):
41
- cnt = Counter(codes)
42
- total = sum(cnt.values())
43
- for k in cnt:
44
- cnt[k] /= total
45
- return cnt
46
-
47
- MAX_SHIFT = []
48
- for i in range(len(MAXES)):
49
- MAX_SHIFT += [sum(MAXES[:i])]
50
-
51
- def shift_keys(dicts, MAX_SHIFT):
52
- new_dict = {}
53
- for i, ngrams_d in enumerate(dicts):
54
- for k, v in ngrams_d.items():
55
- new_dict[k + MAX_SHIFT[i]] = v
56
- return new_dict
57
-
58
- def build_freq_dict(sentence, MAXES=MAXES, MAX_SHIFT=MAX_SHIFT):
59
- hngrams = hash_ngrams(all_ngrams(sentence), MAXES)
60
- fhcodes = map(calc_rel_freq, hngrams)
61
- return shift_keys(fhcodes, MAX_SHIFT)
62
-
63
- # --- load models ---
64
- clf = joblib.load("nld.joblib")
65
- vectorizer = joblib.load("nld_vectorizer.joblib")
66
- idx2lang = joblib.load("nld_lang_codes.joblib")
67
-
68
- input_dim = len(vectorizer.vocabulary_)
69
- nbr_classes = len(idx2lang)
70
-
71
- model = nn.Sequential(
72
- nn.Linear(input_dim, 50),
73
- nn.ReLU(),
74
- nn.Linear(50, nbr_classes)
75
- )
76
- model.load_state_dict(torch.load("nld.pth", map_location="cpu"))
77
- model.eval()
78
-
79
- # --- prediction function ---
80
- def detect_lang(src_sentence):
81
- src_sentence = [src_sentence]
82
- X_test = vectorizer.transform(map(build_freq_dict, src_sentence))
83
- if hasattr(X_test, "toarray"):
84
- X_test = X_test.toarray()
85
- Y_logits = model(torch.Tensor(X_test))
86
- pred_languages = torch.argmax(Y_logits, dim=-1).tolist()
87
- return list(map(idx2lang.get, pred_languages))[0]
88
-
89
- # --- Gradio UI ---
90
- with gr.Blocks(title="Antons language detector") as demo:
91
- gr.Markdown("# Antons language detector")
92
- with gr.Row():
93
- with gr.Column():
94
- src_sentence = gr.Textbox(
95
- label="Text", placeholder="Write your text...")
96
- with gr.Column():
97
- tgt_sentence = gr.Textbox(
98
- label="Language", placeholder="Language will show here...")
99
- btn = gr.Button("Guess the language!")
100
- btn.click(fn=detect_lang, inputs=[src_sentence], outputs=[tgt_sentence])
101
-
102
  demo.launch()
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import hashlib
5
+ import joblib
6
+ from collections import Counter
7
+ import gradio as gr
8
+
9
+ # --- utils (from the notebook) ---
10
def ngrams(sentence, n=1, lc=True):
    """Return the list of overlapping character n-grams of *sentence*.

    Args:
        sentence: Input text.
        n: Length of each n-gram (default 1 = single characters).
        lc: If True (the default), lowercase the text before extraction.

    Returns:
        List of substrings of length ``n`` (empty when the sentence is
        shorter than ``n``).
    """
    # Bug fix: the original lowercased unconditionally, so the ``lc``
    # flag had no effect. Default behavior (lc=True) is unchanged.
    if lc:
        sentence = sentence.lower()
    return [sentence[i:i + n] for i in range(len(sentence) - n + 1)]
17
+
18
def all_ngrams(sentence, max_ngram=3, lc=True):
    """Collect n-gram lists for every order from 1 through *max_ngram*.

    Returns a list of ``max_ngram`` lists: [unigrams, bigrams, ...],
    each produced by :func:`ngrams` with the given ``lc`` setting.
    """
    return [ngrams(sentence, n=order, lc=lc) for order in range(1, max_ngram + 1)]
23
+
24
# Bucket counts (hash moduli) for each n-gram order. All three values are
# prime — presumably chosen to spread hash codes evenly; TODO confirm they
# match the sizes used when the vectorizer/model were trained.
MAX_CHARS = 521
MAX_BIGRAMS = 1031
MAX_TRIGRAMS = 1031
# Order matters: index 0 = unigrams, 1 = bigrams, 2 = trigrams.
MAXES = [MAX_CHARS, MAX_BIGRAMS, MAX_TRIGRAMS]
28
+
29
def reproducible_hash(string):
    """Deterministic signed 64-bit hash of *string*, stable across runs.

    Built on MD5 so the value does not depend on PYTHONHASHSEED the way
    the builtin ``hash`` does. Explicitly flagged as a non-security use.
    """
    digest = hashlib.md5(string.encode("utf-8"), usedforsecurity=False).digest()
    return int.from_bytes(digest[:8], byteorder='big', signed=True)
32
+
33
def hash_ngrams(ngrams, modulos):
    """Map each per-order n-gram list to bucket indices.

    ``ngrams`` is a list of n-gram lists (one per order) and ``modulos``
    the matching bucket counts; each gram is hashed with
    :func:`reproducible_hash` and folded into ``range(modulo)``.
    Returns a list of code lists parallel to the input.
    """
    return [
        [reproducible_hash(gram) % modulo for gram in gram_list]
        for gram_list, modulo in zip(ngrams, modulos)
    ]
39
+
40
def calc_rel_freq(codes):
    """Return a Counter mapping each code in *codes* to its relative frequency.

    Counts are normalized in place so the values sum to 1.0 (an empty
    input yields an empty Counter).
    """
    freqs = Counter(codes)
    total = sum(freqs.values())
    for code, count in freqs.items():
        freqs[code] = count / total
    return freqs
46
+
47
# Cumulative key offsets so codes from different n-gram orders occupy
# disjoint index ranges: [0, MAX_CHARS, MAX_CHARS + MAX_BIGRAMS].
MAX_SHIFT = [sum(MAXES[:i]) for i in range(len(MAXES))]
50
+
51
def shift_keys(dicts, MAX_SHIFT):
    """Merge per-order frequency dicts into one dict with offset keys.

    The i-th dict's keys are shifted by ``MAX_SHIFT[i]`` so entries from
    different n-gram orders cannot collide in the merged mapping.
    """
    return {
        key + MAX_SHIFT[i]: value
        for i, freq_dict in enumerate(dicts)
        for key, value in freq_dict.items()
    }
57
+
58
def build_freq_dict(sentence, MAXES=MAXES, MAX_SHIFT=MAX_SHIFT):
    """Turn a sentence into a sparse {feature_index: rel_freq} mapping.

    Pipeline: extract 1-3 grams, hash each order into its modulo range,
    convert counts to relative frequencies, then shift indices so every
    n-gram order lands in its own slice of the feature space.
    """
    hashed = hash_ngrams(all_ngrams(sentence), MAXES)
    frequencies = [calc_rel_freq(codes) for codes in hashed]
    return shift_keys(frequencies, MAX_SHIFT)
62
+
63
# --- load models ---
# NOTE(review): ``clf`` is loaded but never referenced elsewhere in this
# file — confirm it is still needed before removing.
clf = joblib.load("nld.joblib")
# Presumably a fitted scikit-learn DictVectorizer (it exposes
# ``vocabulary_`` and ``transform`` below) — TODO confirm.
vectorizer = joblib.load("nld_vectorizer.joblib")
# Mapping from class index -> language code, used via .get() in detect_lang.
idx2lang = joblib.load("nld_lang_codes.joblib")

# Network dimensions are derived from the loaded artifacts so they stay
# in sync with the training-time feature space and label set.
input_dim = len(vectorizer.vocabulary_)
nbr_classes = len(idx2lang)

# Keep this architecture exactly as-is: the positional Sequential layout
# determines the state-dict keys ("0.weight", ...) expected by nld.pth.
model = nn.Sequential(
    nn.Linear(input_dim, 50),
    nn.ReLU(),
    nn.Linear(50, nbr_classes)
)
# NOTE(review): torch.load / joblib.load unpickle their files — only run
# with trusted model artifacts.
model.load_state_dict(torch.load("nld.pth", map_location="cpu"))
# Inference mode: disables dropout/batch-norm training behavior.
model.eval()
78
+
79
# --- prediction function ---
def detect_lang(src_sentence):
    """Predict the language of a single input string.

    Args:
        src_sentence: Text to classify.

    Returns:
        The language code from ``idx2lang`` for the argmax class, or
        None if the predicted index is missing from the mapping.
    """
    batch = [src_sentence]
    X_test = vectorizer.transform(map(build_freq_dict, batch))
    # Densify sparse vectorizer output before handing it to torch.
    if hasattr(X_test, "toarray"):
        X_test = X_test.toarray()
    # Fix: run inference under no_grad so no autograd graph is built
    # for a pure-prediction forward pass.
    with torch.no_grad():
        Y_logits = model(torch.Tensor(X_test))
    pred_languages = torch.argmax(Y_logits, dim=-1).tolist()
    return idx2lang.get(pred_languages[0])
88
+
89
# --- Gradio UI ---
# Typo fix: the window title previously read "ZheZhanglanguage detector"
# (missing space), inconsistent with the Markdown heading below.
with gr.Blocks(title="ZheZhang language detector") as demo:
    gr.Markdown("# ZheZhang language detector")
    with gr.Row():
        with gr.Column():
            src_sentence = gr.Textbox(
                label="Text", placeholder="Write your text...")
        with gr.Column():
            tgt_sentence = gr.Textbox(
                label="Language", placeholder="Language will show here...")
    btn = gr.Button("Guess the language!")
    btn.click(fn=detect_lang, inputs=[src_sentence], outputs=[tgt_sentence])

demo.launch()