Zhe-Zhang committed on
Commit
50bcdf5
·
verified ·
1 Parent(s): 5630e19

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -61
app.py CHANGED
@@ -1,25 +1,21 @@
1
- import numpy as np
2
  import torch
3
  import torch.nn as nn
4
- import hashlib
5
  import joblib
 
6
  from collections import Counter
 
7
  import gradio as gr
8
 
9
- # --- utils (from the notebook) ---
10
  def ngrams(sentence, n=1, lc=True):
11
- ngram_l = []
12
  sentence = sentence.lower()
13
- for i in range(len(sentence) - n + 1):
14
- ngram = sentence[i:i+n]
15
- ngram_l.append(ngram)
16
- return ngram_l
17
 
18
  def all_ngrams(sentence, max_ngram=3, lc=True):
19
- all_ngram_list = []
20
  for i in range(1, max_ngram + 1):
21
- all_ngram_list += [ngrams(sentence, n=i, lc=lc)]
22
- return all_ngram_list
23
 
24
  MAX_CHARS = 521
25
  MAX_BIGRAMS = 1031
@@ -28,78 +24,73 @@ MAXES = [MAX_CHARS, MAX_BIGRAMS, MAX_TRIGRAMS]
28
 
29
  def reproducible_hash(string):
30
  h = hashlib.md5(string.encode("utf-8"), usedforsecurity=False)
31
- return int.from_bytes(h.digest()[0:8], 'big', signed=True)
32
 
33
  def hash_ngrams(ngrams, modulos):
34
- hash_codes = []
35
  for ngram_list, modulo in zip(ngrams, modulos):
36
  codes = [(reproducible_hash(x) % modulo) for x in ngram_list]
37
- hash_codes.append(codes)
38
- return hash_codes
39
 
40
  def calc_rel_freq(codes):
41
  cnt = Counter(codes)
42
- total = sum(cnt.values())
43
- for k in cnt:
44
- cnt[k] /= total
45
- return cnt
46
-
47
- MAX_SHIFT = []
48
- for i in range(len(MAXES)):
49
- MAX_SHIFT += [sum(MAXES[:i])]
50
-
51
- def shift_keys(dicts, MAX_SHIFT):
52
- new_dict = {}
53
- for i, ngrams_d in enumerate(dicts):
54
- for k, v in ngrams_d.items():
55
- new_dict[k + MAX_SHIFT[i]] = v
56
- return new_dict
57
-
58
- def build_freq_dict(sentence, MAXES=MAXES, MAX_SHIFT=MAX_SHIFT):
59
  hngrams = hash_ngrams(all_ngrams(sentence), MAXES)
60
- fhcodes = map(calc_rel_freq, hngrams)
61
- return shift_keys(fhcodes, MAX_SHIFT)
62
 
63
- # --- load models ---
64
- clf = joblib.load("nld.joblib")
65
  vectorizer = joblib.load("nld_vectorizer.joblib")
66
  idx2lang = joblib.load("nld_lang_codes.joblib")
67
 
68
  input_dim = len(vectorizer.vocabulary_)
69
- nbr_classes = len(idx2lang)
70
 
71
  model = nn.Sequential(
72
  nn.Linear(input_dim, 50),
73
  nn.ReLU(),
74
- nn.Linear(50, nbr_classes)
75
  )
76
  model.load_state_dict(torch.load("nld.pth", map_location="cpu"))
77
  model.eval()
78
 
79
- # --- prediction function ---
80
- def detect_lang(src_sentence):
81
- # 直接对单条文本提取特征,用列表推导式(与notebook一致)
82
- test_feat_dicts = [build_freq_dict(src_sentence)]
83
- # 转换为模型输入
84
- X_test = vectorizer.transform(test_feat_dicts)
85
- # 后续处理不变
86
- if hasattr(X_test, "toarray"):
87
- X_test = X_test.toarray()
88
- Y_logits = model(torch.Tensor(X_test))
89
- pred_languages = torch.argmax(Y_logits, dim=-1).tolist()
90
- return list(map(idx2lang.get, pred_languages))[0]
91
-
92
- # --- Gradio UI ---
93
- with gr.Blocks(title="Antons language detector") as demo:
94
- gr.Markdown("# Antons language detector")
95
  with gr.Row():
96
  with gr.Column():
97
- src_sentence = gr.Textbox(
98
- label="Text", placeholder="Write your text...")
99
  with gr.Column():
100
- tgt_sentence = gr.Textbox(
101
- label="Language", placeholder="Language will show here...")
102
- btn = gr.Button("Guess the language!")
103
- btn.click(fn=detect_lang, inputs=[src_sentence], outputs=[tgt_sentence])
104
 
105
- demo.launch()
 
 
1
  import torch
2
  import torch.nn as nn
 
3
  import joblib
4
+ import hashlib
5
  from collections import Counter
6
+ import numpy as np
7
  import gradio as gr
8
 
9
+ # --- utils ---
10
def ngrams(sentence, n=1, lc=True):
    """Return every contiguous character n-gram of ``sentence``.

    Args:
        sentence: Text to slice into n-grams.
        n: Length, in characters, of each n-gram.
        lc: When true (the default), lowercase ``sentence`` before slicing.

    Returns:
        The substrings of length ``n`` in order of their start position.
        The list is empty when ``sentence`` is shorter than ``n``.
    """
    # The flag used to be ignored and the text was always lowercased; the
    # default (True) keeps that behaviour for every existing caller.
    if lc:
        sentence = sentence.lower()
    return [sentence[i:i + n] for i in range(len(sentence) - n + 1)]
 
 
 
13
 
14
def all_ngrams(sentence, max_ngram=3, lc=True):
    """Collect the character n-grams of ``sentence`` for n = 1..``max_ngram``.

    Args:
        sentence: Text to slice.
        max_ngram: Largest n-gram length to produce (inclusive).
        lc: Forwarded to ``ngrams`` as its ``lc`` argument.

    Returns:
        A list with one entry per n-gram length, in increasing order of
        length; each entry is the list returned by ``ngrams`` for that length.
    """
    return [ngrams(sentence, n=size, lc=lc) for size in range(1, max_ngram + 1)]
19
 
20
# Sizes that end up in MAXES, which build_freq_dict passes to hash_ngrams as
# the per-n-gram-length modulos (number of hash buckets for each length).
MAX_CHARS = 521
MAX_BIGRAMS = 1031
 
24
 
25
def reproducible_hash(string):
    """Map ``string`` to a signed 64-bit integer that is the same on every run.

    The value is the first eight bytes of the MD5 digest of the UTF-8 encoded
    text, read as a big-endian signed integer.

    Args:
        string: Text to hash.

    Returns:
        An int in the signed 64-bit range.
    """
    # MD5 serves purely as a deterministic bucketing function here, hence
    # usedforsecurity=False (keyword available from Python 3.9).
    digest = hashlib.md5(string.encode("utf-8"), usedforsecurity=False).digest()
    return int.from_bytes(digest[:8], "big", signed=True)
28
 
29
def hash_ngrams(ngrams, modulos):
    """Hash every n-gram into a bucket index, one modulo per n-gram list.

    Args:
        ngrams: Iterable of n-gram lists (one list per n-gram length).
        modulos: Iterable of bucket counts, paired positionally with
            ``ngrams``.

    Returns:
        A list of lists of bucket indices, parallel to ``ngrams``. Pairing is
        done with ``zip``, so surplus entries on either side are dropped.
    """
    # Python's % takes the sign of the divisor, so a negative hash still
    # yields an index in [0, modulo) for a positive modulo.
    return [
        [reproducible_hash(gram) % modulo for gram in ngram_list]
        for ngram_list, modulo in zip(ngrams, modulos)
    ]
35
 
36
def calc_rel_freq(codes):
    """Return the relative frequency of each distinct value in ``codes``.

    Args:
        codes: Iterable of hashable values (here: hashed n-gram indices).

    Returns:
        A plain dict mapping each distinct value to ``count / total``; an
        empty dict when ``codes`` is empty.
    """
    counts = Counter(codes)
    total = sum(counts.values())
    if not total:
        # Nothing was counted, so there is nothing to normalise.
        return {}
    return {code: count / total for code, count in counts.items()}
40
+
41
# Offset added to the hash codes of each n-gram length so that the lengths
# occupy disjoint key ranges: entry i is the sum of MAXES[:i].
MAX_SHIFT = [0]
for bucket_count in MAXES[:-1]:
    # Running total instead of re-summing the prefix for every entry.
    MAX_SHIFT.append(MAX_SHIFT[-1] + bucket_count)
44
+
45
def shift_keys(dicts, shift_list):
    """Merge several int-keyed dicts into one, offsetting each dict's keys.

    Args:
        dicts: Iterable of dicts whose keys support ``+`` with the offsets.
        shift_list: Sequence of offsets; ``shift_list[i]`` is added to every
            key of the i-th dict.

    Returns:
        A single dict holding every ``(key + offset, value)`` pair. If two
        shifted keys collide, the later dict's value wins.

    Raises:
        IndexError: If ``dicts`` yields more items than ``shift_list`` has
            offsets.
    """
    merged = {}
    for position, freq in enumerate(dicts):
        offset = shift_list[position]
        merged.update({key + offset: value for key, value in freq.items()})
    return merged
51
+
52
def build_freq_dict(sentence):
    """Turn ``sentence`` into a sparse feature dict of hashed n-gram frequencies.

    The n-grams from ``all_ngrams`` are hashed with the moduli in ``MAXES``,
    converted to relative frequencies per n-gram length, and merged into one
    dict with each length's keys offset by the matching ``MAX_SHIFT`` entry.

    Args:
        sentence: Text to featurise.

    Returns:
        A dict mapping shifted hash codes to relative frequencies.
    """
    hashed = hash_ngrams(all_ngrams(sentence), MAXES)
    return shift_keys([calc_rel_freq(codes) for codes in hashed], MAX_SHIFT)
56
 
57
# --- load artifacts ---
# NOTE(security): joblib.load and torch.load both unpickle their input, which
# can execute arbitrary code. Only load files that ship with this app; never
# point these calls at user-supplied paths.
vectorizer = joblib.load("nld_vectorizer.joblib")
idx2lang = joblib.load("nld_lang_codes.joblib")

# Layer sizes are derived from the loaded artifacts so that the network
# matches the state dict saved in "nld.pth".
input_dim = len(vectorizer.vocabulary_)
num_classes = len(idx2lang)

model = nn.Sequential(
    nn.Linear(input_dim, 50),
    nn.ReLU(),
    nn.Linear(50, num_classes),
)
model.load_state_dict(torch.load("nld.pth", map_location="cpu"))
model.eval()  # inference only
71
 
72
+ # --- prediction ---
73
def detect_lang(text: str):
    """Predict the language of ``text`` with the loaded classifier.

    Args:
        text: Input string from the UI.

    Returns:
        The ``idx2lang`` entry for the highest-scoring class, or an empty
        string when ``text`` is empty or only whitespace.
    """
    # With no characters to featurise the input vector would be all zeros and
    # the prediction would be decided by the bias terms alone, so report
    # nothing rather than an arbitrary language.
    if not text or not text.strip():
        return ""

    features = vectorizer.transform([build_freq_dict(text)])
    if hasattr(features, "toarray"):
        # The vectorizer may hand back a sparse matrix; densify it for torch.
        features = features.toarray()
    batch = torch.from_numpy(features.astype("float32"))

    with torch.no_grad():
        logits = model(batch)
    pred_idx = torch.argmax(logits, dim=-1).item()
    return idx2lang[pred_idx]
84
+
85
# --- UI ---
# NOTE(review): the widget nesting below is reconstructed from statement
# order only; confirm the button's column and the level of the click
# registration against the real app.py.
with gr.Blocks(title="Language Detector") as demo:
    gr.Markdown("# Language Detector")
    with gr.Row():
        with gr.Column():
            src_text = gr.Textbox(label="Enter text", placeholder="Type here...")
            btn = gr.Button("Detect Language")
        with gr.Column():
            out_lang = gr.Textbox(label="Predicted language", interactive=False)
    # Event listeners have to be registered inside the Blocks context.
    btn.click(fn=detect_lang, inputs=src_text, outputs=out_lang)

demo.launch()