Zhe-Zhang committed on
Commit
c3bf0e9
·
verified ·
1 Parent(s): 0da632d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -22
app.py CHANGED
@@ -3,18 +3,18 @@ import torch.nn as nn
3
  import joblib
4
  import hashlib
5
  from collections import Counter
6
- import numpy as np
7
  import gradio as gr
8
 
9
- # --- utils ---
10
  def ngrams(sentence, n=1, lc=True):
11
- sentence = sentence.lower()
 
12
  return [sentence[i:i+n] for i in range(len(sentence) - n + 1)]
13
 
14
  def all_ngrams(sentence, max_ngram=3, lc=True):
15
  result = []
16
  for i in range(1, max_ngram + 1):
17
- result += [ngrams(sentence, n=i, lc=lc)]
18
  return result
19
 
20
  MAX_CHARS = 521
@@ -54,11 +54,11 @@ def build_freq_dict(sentence):
54
  freqs = list(map(calc_rel_freq, hngrams))
55
  return shift_keys(freqs, MAX_SHIFT)
56
 
57
- # --- load artifacts ---
58
  vectorizer = joblib.load("nld_vectorizer.joblib")
59
  idx2lang = joblib.load("nld_lang_codes.joblib")
60
 
61
- input_dim = len(vectorizer.vocabulary_)
62
  num_classes = len(idx2lang)
63
 
64
  model = nn.Sequential(
@@ -66,31 +66,28 @@ model = nn.Sequential(
66
  nn.ReLU(),
67
  nn.Linear(50, num_classes)
68
  )
69
- model.load_state_dict(torch.load("nld (1).pth", map_location="cpu"))
 
 
70
  model.eval()
71
 
72
- # --- prediction ---
73
  def detect_lang(text: str):
74
  feat_dict = build_freq_dict(text)
75
  X = vectorizer.transform([feat_dict])
76
- if hasattr(X, "toarray"):
77
- X = X.toarray()
78
- X = torch.from_numpy(X.astype("float32"))
79
-
80
  with torch.no_grad():
81
- logits = model(X)
82
- pred_idx = torch.argmax(logits, dim=-1).item()
83
  return idx2lang[pred_idx]
84
 
85
- # --- UI ---
86
  with gr.Blocks(title="Language Detector") as demo:
87
- gr.Markdown("# Language Detector")
88
  with gr.Row():
89
- with gr.Column():
90
- src_text = gr.Textbox(label="Enter text", placeholder="Type here...")
91
- btn = gr.Button("Detect Language")
92
- with gr.Column():
93
- out_lang = gr.Textbox(label="Predicted language", interactive=False)
94
- btn.click(fn=detect_lang, inputs=src_text, outputs=out_lang)
95
 
96
  demo.launch()
 
3
  import joblib
4
  import hashlib
5
  from collections import Counter
 
6
  import gradio as gr
7
 
8
+ # ========== utils ==========
9
  def ngrams(sentence, n=1, lc=True):
10
+ if lc:
11
+ sentence = sentence.lower()
12
  return [sentence[i:i+n] for i in range(len(sentence) - n + 1)]
13
 
14
  def all_ngrams(sentence, max_ngram=3, lc=True):
15
  result = []
16
  for i in range(1, max_ngram + 1):
17
+ result.append(ngrams(sentence, n=i, lc=lc))
18
  return result
19
 
20
  MAX_CHARS = 521
 
54
  freqs = list(map(calc_rel_freq, hngrams))
55
  return shift_keys(freqs, MAX_SHIFT)
56
 
57
+ # ========== load artifacts ==========
58
  vectorizer = joblib.load("nld_vectorizer.joblib")
59
  idx2lang = joblib.load("nld_lang_codes.joblib")
60
 
61
+ input_dim = len(vectorizer.feature_names_) # 确保和训练时一致
62
  num_classes = len(idx2lang)
63
 
64
  model = nn.Sequential(
 
66
  nn.ReLU(),
67
  nn.Linear(50, num_classes)
68
  )
69
+
70
+ state_dict = torch.load("nld.pth", map_location="cpu")
71
+ model.load_state_dict(state_dict)
72
  model.eval()
73
 
74
+ # ========== prediction ==========
75
  def detect_lang(text: str):
76
  feat_dict = build_freq_dict(text)
77
  X = vectorizer.transform([feat_dict])
78
+ X_tensor = torch.from_numpy(X.toarray().astype("float32"))
 
 
 
79
  with torch.no_grad():
80
+ logits = model(X_tensor)
81
+ pred_idx = torch.argmax(logits, dim=1).item()
82
  return idx2lang[pred_idx]
83
 
84
+ # ========== Gradio UI ==========
85
  with gr.Blocks(title="Language Detector") as demo:
86
+ gr.Markdown("## Language Detector")
87
  with gr.Row():
88
+ text_in = gr.Textbox(label="Input text", placeholder="Type something...")
89
+ text_out = gr.Textbox(label="Predicted language", interactive=False)
90
+ btn = gr.Button("Detect")
91
+ btn.click(fn=detect_lang, inputs=text_in, outputs=text_out)
 
 
92
 
93
  demo.launch()