Zhe-Zhang commited on
Commit
96316d5
·
verified ·
1 Parent(s): b8828ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -7
app.py CHANGED
@@ -77,14 +77,18 @@ model.load_state_dict(torch.load("nld.pth", map_location="cpu"))
77
  model.eval()
78
 
79
  # --- prediction function ---
80
- # 修改预测函数为sklearn版本
81
  def detect_lang(src_sentence):
82
- # 特征提取逻辑不变
83
- X_test = vectorizer.transform([build_freq_dict(src_sentence)])
84
- # 使用sklearn模型预测
85
- pred_idx = clf.predict(X_test)[0]
86
- return idx2lang[pred_idx]
87
-
 
 
 
 
 
88
  # --- Gradio UI ---
89
  with gr.Blocks(title="Antons language detector") as demo:
90
  gr.Markdown("# Antons language detector")
 
77
  model.eval()
78
 
79
  # --- prediction function ---
 
80
  def detect_lang(src_sentence):
81
+ # 直接对单条文本提取特征,用列表推导式(与notebook一致)
82
+ test_feat_dicts = [build_freq_dict(src_sentence)]
83
+ # 转换为模型输入
84
+ X_test = vectorizer.transform(test_feat_dicts)
85
+ # 后续处理不变
86
+ if hasattr(X_test, "toarray"):
87
+ X_test = X_test.toarray()
88
+ Y_logits = model(torch.Tensor(X_test))
89
+ pred_languages = torch.argmax(Y_logits, dim=-1).tolist()
90
+ return list(map(idx2lang.get, pred_languages))[0]
91
+
92
  # --- Gradio UI ---
93
  with gr.Blocks(title="Antons language detector") as demo:
94
  gr.Markdown("# Antons language detector")