Zhe-Zhang commited on
Commit
32d1bb5
·
verified ·
1 Parent(s): a51206f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -55
app.py CHANGED
@@ -1,21 +1,25 @@
 
1
  import torch
2
  import torch.nn as nn
3
- import joblib
4
  import hashlib
 
5
  from collections import Counter
6
  import gradio as gr
7
 
8
- # ========== utils ==========
9
  def ngrams(sentence, n=1, lc=True):
10
- if lc:
11
- sentence = sentence.lower()
12
- return [sentence[i:i+n] for i in range(len(sentence) - n + 1)]
 
 
 
13
 
14
  def all_ngrams(sentence, max_ngram=3, lc=True):
15
- result = []
16
  for i in range(1, max_ngram + 1):
17
- result.append(ngrams(sentence, n=i, lc=lc))
18
- return result
19
 
20
  MAX_CHARS = 521
21
  MAX_BIGRAMS = 1031
@@ -24,70 +28,75 @@ MAXES = [MAX_CHARS, MAX_BIGRAMS, MAX_TRIGRAMS]
24
 
25
  def reproducible_hash(string):
26
  h = hashlib.md5(string.encode("utf-8"), usedforsecurity=False)
27
- return int.from_bytes(h.digest()[0:8], "big", signed=True)
28
 
29
  def hash_ngrams(ngrams, modulos):
30
- out = []
31
  for ngram_list, modulo in zip(ngrams, modulos):
32
  codes = [(reproducible_hash(x) % modulo) for x in ngram_list]
33
- out.append(codes)
34
- return out
35
 
36
  def calc_rel_freq(codes):
37
  cnt = Counter(codes)
38
- total = sum(cnt.values()) or 1
39
- return {k: v / total for k, v in cnt.items()}
40
-
41
- MAX_SHIFT = [0]
42
- for i in range(1, len(MAXES)):
43
- MAX_SHIFT.append(sum(MAXES[:i]))
44
-
45
- def shift_keys(dicts, shift_list):
46
- new = {}
47
- for i, d in enumerate(dicts):
48
- for k, v in d.items():
49
- new[k + shift_list[i]] = v
50
- return new
51
-
52
- def build_freq_dict(sentence):
 
 
53
  hngrams = hash_ngrams(all_ngrams(sentence), MAXES)
54
- freqs = list(map(calc_rel_freq, hngrams))
55
- return shift_keys(freqs, MAX_SHIFT)
56
 
57
- # ========== load artifacts ==========
 
58
  vectorizer = joblib.load("nld_vectorizer.joblib")
59
  idx2lang = joblib.load("nld_lang_codes.joblib")
60
 
61
- input_dim = len(vectorizer.feature_names_) # 确保和训练时一致
62
- num_classes = len(idx2lang)
63
 
64
  model = nn.Sequential(
65
  nn.Linear(input_dim, 50),
66
  nn.ReLU(),
67
- nn.Linear(50, num_classes)
68
  )
69
-
70
- state_dict = torch.load("nld.pth", map_location="cpu")
71
- model.load_state_dict(state_dict)
72
  model.eval()
73
 
74
- # ========== prediction ==========
75
- def detect_lang(text: str):
76
- feat_dict = build_freq_dict(text)
77
- X = vectorizer.transform([feat_dict])
78
- X_tensor = torch.from_numpy(X.toarray().astype("float32"))
79
- with torch.no_grad():
80
- logits = model(X_tensor)
81
- pred_idx = torch.argmax(logits, dim=1).item()
82
- return idx2lang[pred_idx]
83
-
84
- # ========== Gradio UI ==========
85
- with gr.Blocks(title="Language Detector") as demo:
86
- gr.Markdown("## Language Detector")
87
  with gr.Row():
88
- text_in = gr.Textbox(label="Input text", placeholder="Type something...")
89
- text_out = gr.Textbox(label="Predicted language", interactive=False)
90
- btn = gr.Button("Detect")
91
- btn.click(fn=detect_lang, inputs=text_in, outputs=text_out)
92
-
93
- demo.launch()
 
 
 
 
 
1
+ import numpy as np
2
  import torch
3
  import torch.nn as nn
 
4
  import hashlib
5
+ import joblib
6
  from collections import Counter
7
  import gradio as gr
8
 
9
+ # --- utils (from the notebook) ---
10
def ngrams(sentence, n=1, lc=True):
    """Return all overlapping character n-grams of *sentence*.

    Args:
        sentence: Input text.
        n: Length of each n-gram.
        lc: If True, lowercase the text first. (Bug fix: this flag was
            accepted but ignored — the text was always lowercased.)

    Returns:
        List of substrings of length *n*; empty when len(sentence) < n.
    """
    if lc:
        sentence = sentence.lower()
    return [sentence[i:i + n] for i in range(len(sentence) - n + 1)]
17
 
18
def all_ngrams(sentence, max_ngram=3, lc=True):
    """Collect the character n-gram lists of *sentence* for every order
    from 1 up to and including *max_ngram*.

    Returns a list of lists: one n-gram list per order.
    """
    return [ngrams(sentence, n=order, lc=lc) for order in range(1, max_ngram + 1)]
23
 
24
  MAX_CHARS = 521
25
  MAX_BIGRAMS = 1031
 
28
 
29
def reproducible_hash(string):
    """Deterministic signed 64-bit hash of *string*.

    Uses MD5 (non-security context) so the value is stable across
    processes and Python versions, unlike the builtin ``hash``.
    """
    digest = hashlib.md5(string.encode("utf-8"), usedforsecurity=False).digest()
    return int.from_bytes(digest[:8], "big", signed=True)
32
 
33
def hash_ngrams(ngrams, modulos):
    """Map each n-gram to a bucket index via hashing.

    *ngrams* is a list of n-gram lists (one per order) and *modulos*
    supplies the bucket count for the matching order. Returns one list
    of bucket indices per order.
    """
    return [
        [reproducible_hash(gram) % modulo for gram in gram_list]
        for gram_list, modulo in zip(ngrams, modulos)
    ]
39
 
40
def calc_rel_freq(codes):
    """Map each hash code in *codes* to its relative frequency.

    Returns a Counter of ``{code: count / total}``; empty input yields
    an empty Counter (no division is performed).
    """
    counts = Counter(codes)
    total = sum(counts.values())
    return Counter({code: count / total for code, count in counts.items()})
46
+
47
# Cumulative key offsets so the hashed vocabularies of each n-gram order
# occupy disjoint index ranges once merged into a single feature dict.
MAX_SHIFT = [sum(MAXES[:order]) for order in range(len(MAXES))]
50
+
51
def shift_keys(dicts, MAX_SHIFT):
    """Merge per-order frequency dicts into one flat dict, offsetting the
    keys of the i-th dict by ``MAX_SHIFT[i]`` so they cannot collide.
    """
    return {
        key + MAX_SHIFT[order]: freq
        for order, freq_dict in enumerate(dicts)
        for key, freq in freq_dict.items()
    }
57
+
58
def build_freq_dict(sentence, MAXES=MAXES, MAX_SHIFT=MAX_SHIFT):
    """Turn *sentence* into a sparse ``{feature_index: rel_freq}`` mapping:
    hash its 1..3-grams into buckets, compute per-order relative
    frequencies, then merge the orders into disjoint key ranges.
    """
    per_order_codes = hash_ngrams(all_ngrams(sentence), MAXES)
    return shift_keys((calc_rel_freq(codes) for codes in per_order_codes), MAX_SHIFT)
62
 
63
# --- load models ---
# NOTE(review): `clf` is loaded but never used below — presumably a leftover
# scikit-learn classifier superseded by the torch model; confirm before removing.
clf = joblib.load("nld.joblib")
# Vectorizer mapping {feature_index: freq} dicts to fixed-width vectors;
# exposes .vocabulary_ and .transform (DictVectorizer-style — TODO confirm).
vectorizer = joblib.load("nld_vectorizer.joblib")
# Mapping from class index to language code (used via .get in detect_lang).
idx2lang = joblib.load("nld_lang_codes.joblib")

# Derive layer sizes from the artifacts so they match training exactly.
input_dim = len(vectorizer.vocabulary_)
nbr_classes = len(idx2lang)

# Small MLP classifier: input features -> 50 hidden units -> class logits.
model = nn.Sequential(
    nn.Linear(input_dim, 50),
    nn.ReLU(),
    nn.Linear(50, nbr_classes)
)
# Load trained weights onto CPU (inference-only deployment).
model.load_state_dict(torch.load("nld.pth", map_location="cpu"))
model.eval()
78
 
79
# --- prediction function ---
def detect_lang(src_sentence):
    """Predict the language code of *src_sentence*.

    Builds the hashed n-gram frequency features, vectorizes them, runs
    the MLP, and maps the argmax class index back to a language code.
    """
    features = vectorizer.transform(map(build_freq_dict, [src_sentence]))
    # Sparse vectorizer output must be densified before torch can use it.
    if hasattr(features, "toarray"):
        features = features.toarray()
    logits = model(torch.Tensor(features))
    predicted = torch.argmax(logits, dim=-1).tolist()
    return list(map(idx2lang.get, predicted))[0]
88
+
89
# --- Gradio UI ---
# Two-column layout: free-text input on the left, predicted language on
# the right; a single button wires the click to detect_lang.
with gr.Blocks(title="language detector") as demo:
    gr.Markdown("# language detector")
    with gr.Row():
        with gr.Column():
            # User-supplied text to classify.
            src_sentence = gr.Textbox(
                label="Text", placeholder="Write your text...")
        with gr.Column():
            # Output box; filled with the predicted language code.
            tgt_sentence = gr.Textbox(
                label="Language", placeholder="Language will show here...")
    btn = gr.Button("Guess the language!")
    btn.click(fn=detect_lang, inputs=[src_sentence], outputs=[tgt_sentence])

# Start the Gradio server (blocking call).
demo.launch()