Xianfish9 commited on
Commit
ec1b79a
·
verified ·
1 Parent(s): 1eeed0f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -0
app.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #Adam_lr7e-05_weightdecay0.0001_epochs3480.pth
2
+ import gradio as gr
3
+ import torch
4
+ import numpy as np
5
+ import os
6
+ import re
7
+
8
+ # --- 依赖导入 ---
9
+ # 从你的代码库中导入必要的模块
10
+ # 这要求你的文件结构是正确的 (例如: /Feature_extraction_algorithms/PSTAAP.py)
11
+ from model import CAFN
12
+ from Feature_extraction_algorithms.PSTAAP import PSTAAP_feature
13
+ from Feature_extraction_algorithms.Physicochemical import PC_feature
14
+
15
+ # --- 1. 模型加载 ---
16
+ # 确保 'your_model_name.pth' 和你上传的文件名完全一致
17
+ MODEL_PATH = "Adam_lr7e-05_weightdecay0.0001_epochs3480.pth" # <--- 在这里修改成你的 .pth 文件名
18
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
19
+
20
+ def load_model(model_path):
21
+ model = CAFN().to(device)
22
+ if os.path.exists(model_path):
23
+ model.load_state_dict(torch.load(model_path, map_location=device))
24
+ model.eval() # 设置为评估模式
25
+ print("模型加载成功!")
26
+ return model
27
+ else:
28
+ print(f"错误:在路径 {model_path} 未找到模型文件")
29
+ return None
30
+
31
+ model = load_model(MODEL_PATH)
32
+
33
+ # --- 2. 特征提取函数 ---
34
+ # 这个函数直接改编自你的 dataProcess.py
35
+ def extract_features_from_seq(sequence_list, test_PSTAAP=True):
36
+ """
37
+ 接收一个包含序列的列表,返回模型所需的两个特征张量 x1 和 x2。
38
+ """
39
+ # 提取 PC_feature (对应 x2)
40
+ data2 = PC_feature(sequence_list)
41
+
42
+ # 提取 PSTAAP_feature (对应 x1)
43
+ N = len(sequence_list)
44
+ empty_list_array = [[] for _ in range(N)]
45
+ data = np.array(empty_list_array, dtype=object)
46
+ feature = PSTAAP_feature(sequence_list, test_PSTAAP)
47
+ data = np.hstack((data, feature))
48
+
49
+ # 返回 NumPy 数组
50
+ return data.astype(np.float32), data2.astype(np.float32)
51
+
52
+ # --- 3. 核心预测函数 ---
53
+ # Gradio 界面会调用这个函数
54
+ def predict(sequence_input):
55
+ if model is None:
56
+ return {"错误": "模型未能加载,请检查后台日志"}
57
+
58
+ # 输入验证
59
+ if not sequence_input or not isinstance(sequence_input, str):
60
+ return {"错误": "请输入有效的生物序列"}
61
+
62
+ # 将输入的字符串处理成符合规范的格式
63
+ # .strip() 去除首尾空格, .upper() 转换为大写 (如果需要)
64
+ cleaned_sequence = sequence_input.strip().upper()
65
+
66
+ # 将单个序列放入列表中,因为特征提取函数期望一个列表
67
+ sequence_list = [cleaned_sequence]
68
+
69
+ # a. 调用特征提取
70
+ try:
71
+ x1_np, x2_np = extract_features_from_seq(sequence_list, test_PSTAAP=True)
72
+ except Exception as e:
73
+ # 如果特征提取失败,向用户显示错误
74
+ return {f"特征提取失败": str(e)}
75
+
76
+ # b. 将 NumPy 数组转换为 PyTorch 张量
77
+ # 特征提取函数应该已经为单个序列返回了正确的形状 (1, ...),所以不需要 .unsqueeze()
78
+ tensor_x1 = torch.tensor(x1_np).to(device)
79
+ tensor_x2 = torch.tensor(x2_np).to(device)
80
+
81
+ # c. 进行预测
82
+ with torch.no_grad():
83
+ outputs = model(tensor_x1, tensor_x2)
84
+
85
+ # d. 处理输出
86
+ # 你的模型输出是4个类别。我们用 sigmoid 来获取每个类别的概率
87
+ probabilities = torch.sigmoid(outputs).squeeze().cpu().numpy()
88
+
89
+ # e. 格式化成字典,方便在界面上显示
90
+ # 根据你的 make_ylabel 函数,这四个类别分别对应 a, c, m, s
91
+ labels = ["类别 A (a)", "类别 C (c)", "类别 M (m)", "类别 S (s)"]
92
+ result = {label: float(prob) for label, prob in zip(labels, probabilities)}
93
+
94
+ return result
95
+
96
+ # --- 4. 创建并启动 Gradio 界面 ---
97
+ demo = gr.Interface(
98
+ fn=predict,
99
+ inputs=gr.Textbox(
100
+ lines=7,
101
+ label="输入生物序列 (Input Sequence)",
102
+ placeholder="请在这里粘贴你的序列..."
103
+ ),
104
+ outputs=gr.Label(num_top_classes=4, label="预测概率 (Prediction Probabilities)"),
105
+ title="CAFN 模型部署:多标签序列分类器",
106
+ description="输入一个生物序列,模型将预测它属于四个类别 (A, C, M, S) 中每一个的概率。",
107
+ examples=[
108
+ ["PLEPIPIVAAAAA"],
109
+ ["GMWSGGGGISGSLIIVIRAELGVPSGMMILGYLN"],
110
+ ]
111
+ )
112
+
113
+ # 启动应用
114
+ demo.launch()