Zhe-Zhang commited on
Commit
6ae01eb
·
verified ·
1 Parent(s): 32d1bb5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -1
app.py CHANGED
@@ -77,6 +77,7 @@ model.load_state_dict(torch.load("nld.pth", map_location="cpu"))
77
  model.eval()
78
 
79
  # --- prediction function ---
 
80
  def detect_lang(src_sentence):
81
  src_sentence = [src_sentence]
82
  X_test = vectorizer.transform(map(build_freq_dict, src_sentence))
@@ -85,7 +86,14 @@ def detect_lang(src_sentence):
85
  Y_logits = model(torch.Tensor(X_test))
86
  pred_languages = torch.argmax(Y_logits, dim=-1).tolist()
87
  return list(map(idx2lang.get, pred_languages))[0]
88
-
 
 
 
 
 
 
 
89
  # --- Gradio UI ---
90
  with gr.Blocks(title="language detector") as demo:
91
  gr.Markdown("# language detector")
 
77
  model.eval()
78
 
79
  # --- prediction function ---
80
+ '''
81
  def detect_lang(src_sentence):
82
  src_sentence = [src_sentence]
83
  X_test = vectorizer.transform(map(build_freq_dict, src_sentence))
 
86
  Y_logits = model(torch.Tensor(X_test))
87
  pred_languages = torch.argmax(Y_logits, dim=-1).tolist()
88
  return list(map(idx2lang.get, pred_languages))[0]
89
+ '''
90
+ # sklearn
91
+ def detect_lang(src_sentence):
92
+ X_test = vectorizer.transform([build_freq_dict(src_sentence)])
93
+ # predict using sklearn
94
+ pred_idx = clf.predict(X_test)[0]
95
+ return idx2lang[pred_idx]
96
+
97
  # --- Gradio UI ---
98
  with gr.Blocks(title="language detector") as demo:
99
  gr.Markdown("# language detector")