Xianfish9 commited on
Commit
ded76e3
·
verified ·
1 Parent(s): 7616287

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -9
app.py CHANGED
@@ -59,30 +59,37 @@ def extract_features_from_seq(sequence_list):
59
  # --- 4. 核心预测函数 ---
60
  def predict(sequence_input):
61
  if model is None:
62
- return {"错误": "模型未能加载或初始化失败,请检查后台日志"}
 
63
 
64
  if not sequence_input or not isinstance(sequence_input, str):
65
- return {"错误": "请输入有效的生物序列"}
 
66
 
67
  cleaned_sequence = sequence_input.strip().upper()
68
  sequence_list = [cleaned_sequence]
69
 
70
- try:
71
- # !!! 在这里调用了上面的函数 !!!
72
- x1_np, x2_np = extract_features_from_seq(sequence_list)
73
- except Exception as e:
74
- # 如果特征提取失败(包括 NameError),会在这里捕获
75
- return {f"特征提取失败": str(e)}
76
-
77
  tensor_x1 = torch.tensor(x1_np).to(device)
78
  tensor_x2 = torch.tensor(x2_np).to(device)
79
 
 
80
  with torch.no_grad():
81
  outputs = model(tensor_x1, tensor_x2)
82
 
 
83
  probabilities = torch.sigmoid(outputs).squeeze().cpu().numpy()
84
 
 
85
  labels = ["类别 A (a)", "类别 C (c)", "类别 M (m)", "类别 S (s)"]
 
 
 
 
86
  result = {label: float(prob) for label, prob in zip(labels, probabilities)}
87
 
88
  return result
 
59
  # --- 4. 核心预测函数 ---
60
  def predict(sequence_input):
61
  if model is None:
62
+ # 如果模型加载失败,可以提前抛出错误
63
+ raise gr.Error("模型未能加载或初始化失败,请检查后台日志。")
64
 
65
  if not sequence_input or not isinstance(sequence_input, str):
66
+ # 对于无效输入,也直接抛出错误
67
+ raise gr.Error("请输入有效的生物序列。")
68
 
69
  cleaned_sequence = sequence_input.strip().upper()
70
  sequence_list = [cleaned_sequence]
71
 
72
+ # !!! 移除这里的 try...except !!!
73
+ # 让任何可能发生的错误自然地被Gradio捕获
74
+ x1_np, x2_np = extract_features_from_seq(sequence_list)
75
+
76
+ # NumPy 数组转换为 PyTorch 张量
 
 
77
  tensor_x1 = torch.tensor(x1_np).to(device)
78
  tensor_x2 = torch.tensor(x2_np).to(device)
79
 
80
+ # 模型预测
81
  with torch.no_grad():
82
  outputs = model(tensor_x1, tensor_x2)
83
 
84
+ # 计算概率
85
  probabilities = torch.sigmoid(outputs).squeeze().cpu().numpy()
86
 
87
+ # 准备输出结果
88
  labels = ["类别 A (a)", "类别 C (c)", "类别 M (m)", "类别 S (s)"]
89
+ # 确保即使只有一个序列,结果也能正确处理
90
+ if probabilities.ndim == 0: # 如果只有一个输出
91
+ probabilities = [probabilities]
92
+
93
  result = {label: float(prob) for label, prob in zip(labels, probabilities)}
94
 
95
  return result