QQ2S3R commited on
Commit
36df9cc
·
verified ·
1 Parent(s): bae8e82

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +303 -224
app.py CHANGED
@@ -1,224 +1,303 @@
1
- # -*- coding: utf-8 -*-
2
- """
3
- Created on Tue Jul 8 15:00:01 2025
4
-
5
- @author: User
6
- """
7
- import os
8
- import zipfile
9
- import torch
10
- import numpy as np
11
- import logging
12
- from PIL import Image
13
- import gradio as gr
14
- from torch_geometric.data import Data as PyGData
15
- import matplotlib
16
- matplotlib.use('Agg') # 修复后台线程问题
17
-
18
- # 配置日志
19
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
20
- logger = logging.getLogger(__name__)
21
-
22
- # GPU内存优化
23
- os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
24
- logger.info("设置GPU内存优化参数: PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128")
25
-
26
- # 解压模型文件
27
- if not os.path.exists("best_model-B-6000-185.pth"):
28
- logger.info("开始解压模型文件...")
29
- try:
30
- with zipfile.ZipFile("models.zip", 'r') as zip_ref:
31
- zip_ref.extractall(".")
32
- logger.info("模型文件解压完成!")
33
- except Exception as e:
34
- logger.error(f"解压模型文件失败: {str(e)}")
35
- raise
36
-
37
- # 导入模型工具
38
- try:
39
- from model_utils import EnhancedGAT, smiles_to_graph, visualize_single_molecule
40
- logger.info("成功导入 model_utils 模块")
41
- except ImportError as e:
42
- logger.error(f"导入 model_utils 失败: {str(e)}")
43
- raise
44
-
45
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
46
- logger.info(f"使用设备: {device}")
47
- if torch.cuda.is_available():
48
- logger.info(f"GPU 信息: {torch.cuda.get_device_name(0)}")
49
-
50
- def load_models():
51
- model_info = {
52
- "Elastic": ("models/best_model-E-500-68.pth", 2),
53
- "Plastic": ("models/best_model-P-5000-180.pth", 2),
54
- "Brittle": ("models/best_model-B-6000-185.pth", 2)
55
- }
56
- models = {}
57
-
58
- for name, (pth_path, output_dim) in model_info.items():
59
- logger.info(f"正在加载 {name} 模型: {pth_path}")
60
-
61
- # 检查模型文件是否存在
62
- if not os.path.exists(pth_path):
63
- logger.error(f"模型文件不存在: {pth_path}")
64
-
65
- # 尝试可能的文件名变体
66
- possible_files = [
67
- pth_path,
68
- pth_path.lower(),
69
- pth_path.upper(),
70
- pth_path.replace("-", "_"),
71
- pth_path.replace("_", "-")
72
- ]
73
-
74
- found = False
75
- for file in possible_files:
76
- if os.path.exists(file):
77
- logger.warning(f"使用替代文件: {file}")
78
- pth_path = file
79
- found = True
80
- break
81
-
82
- if not found:
83
- logger.error(f"找不到任何匹配的模型文件: {pth_path}")
84
- raise FileNotFoundError(f"模型文件 {pth_path} 不存在")
85
-
86
- try:
87
- # 修复模型初始化参数
88
- model = EnhancedGAT(input_dim=12, hidden_dim=512, output_dim=output_dim, num_heads=8)
89
-
90
- # 加载模型状态 - 解决 PyTorch 2.6+ 的安全问题
91
- logger.info(f"加载模型权重: {pth_path}")
92
-
93
- # 方法 1: 禁用 weights_only
94
- try:
95
- # 尝试使用 weights_only=False 加载
96
- state = torch.load(pth_path, map_location=device, weights_only=False)
97
- logger.info("使用 weights_only=False 成功加载模型")
98
- except Exception as e:
99
- logger.warning(f"使用 weights_only=False 加载失败: {str(e)}")
100
- logger.info("尝试使用 weights_only=True 并添加安全全局变量")
101
-
102
- # 方法 2: 添加安全全局变量
103
- try:
104
- # 导入必要的模块
105
- import numpy as np
106
- import torch.serialization
107
-
108
- # 添加安全全局变量
109
- torch.serialization.add_safe_globals([getattr(np.core.multiarray, 'scalar')])
110
- state = torch.load(pth_path, map_location=device, weights_only=True)
111
- logger.info("使用 weights_only=True 和安全全局变量成功加载模型")
112
- except:
113
- # 最后尝试原始方式
114
- logger.warning("安全方式加载失败,尝试原始加载方式")
115
- state = torch.load(pth_path, map_location=device)
116
-
117
- # 检查状态字典键名
118
- if "model_state_dict" in state:
119
- state_dict = state["model_state_dict"]
120
- logger.info("使用 'model_state_dict' 加载模型")
121
- else:
122
- state_dict = state # 直接使用整个文件
123
- logger.info("使用整个状态字典加载模型")
124
-
125
- # 加载模型参数
126
- model.load_state_dict(state_dict)
127
- model.eval().to(device)
128
- models[name] = model
129
- logger.info(f"{name} 模型加载成功!")
130
-
131
- except Exception as e:
132
- logger.error(f"加载 {name} 模型失败: {str(e)}")
133
- raise
134
-
135
- return models
136
-
137
- logger.info("开始加载所有模型...")
138
- try:
139
- models = load_models()
140
- logger.info("所有模型加载完成!")
141
- except Exception as e:
142
- logger.error(f"模型加载过程中发生错误: {str(e)}")
143
- raise
144
-
145
- def predict_all(smiles):
146
- logger.info(f"收到预测请求: SMILES = {smiles}")
147
- try:
148
- # 转换SMILES为图数据
149
- logger.info("转换SMILES为图数据...")
150
- atom_features, (rows, cols, edge_attr), mol = smiles_to_graph(smiles)
151
-
152
- # 创建PyG数对象
153
- logger.info("创建PyG数据对象...")
154
- x = torch.tensor(atom_features, dtype=torch.float)
155
- edge_index = torch.tensor(np.column_stack((rows, cols)).T, dtype=torch.long)
156
- edge_attr = torch.tensor(edge_attr, dtype=torch.float).unsqueeze(1)
157
- data = PyGData(
158
- x=x,
159
- edge_index=edge_index,
160
- edge_attr=edge_attr,
161
- smiles=[smiles],
162
- batch=torch.zeros(x.size(0), dtype=torch.long)
163
- )
164
-
165
- results = []
166
- # 对每个模型进行预测
167
- for name in ["Elastic", "Plastic", "Brittle"]:
168
- logger.info(f"使用 {name} 模型进行预测...")
169
- try:
170
- buf, pred = visualize_single_molecule(models[name], data, device, name)
171
- # 修复图像处理
172
- if buf:
173
- buf.seek(0) # 重置缓冲区位置
174
- img = Image.open(buf)
175
- pred_text = f"{name} Result: {'1' if pred == 1 else '0'}"
176
- logger.info(f"{name} 预测结果: {pred}")
177
- results.append((pred_text, img))
178
- else:
179
- error_msg = f"{name} 预测失败: 未生成图像"
180
- logger.error(error_msg)
181
- results.append((error_msg, None))
182
-
183
- except Exception as e:
184
- error_msg = f"{name} 预测过程中发生错误: {str(e)}"
185
- logger.error(error_msg)
186
- results.append((error_msg, None))
187
-
188
- return results[0][0], results[0][1], results[1][0], results[1][1], results[2][0], results[2][1]
189
-
190
- except Exception as e:
191
- error_msg = f"预测过程中发生严重错误: {str(e)}"
192
- logger.error(error_msg)
193
- return error_msg, None, error_msg, None, error_msg, None
194
-
195
- # 修复输出组件
196
- outputs = [
197
- gr.Text(label="Elastic"),
198
- gr.Image(type="pil", label="Elastic attention visualization"),
199
- gr.Text(label="Plastic"),
200
- gr.Image(type="pil", label="Plastic attention visualization"),
201
- gr.Text(label="Brittle"),
202
- gr.Image(type="pil", label="Brittle attention visualization")
203
- ]
204
-
205
- demo = gr.Interface(
206
- fn=predict_all,
207
- inputs=gr.Textbox(label="SMILES", placeholder="Enter a SMILES string, for example: CCO"),
208
- outputs=outputs,
209
- title="CrystalGAT",
210
- description="Input the SMILES string of a molecule, and CrystalGAT will predict its elasticity, plasticity and brittleness classification and visualize the attention weights",
211
- examples=[
212
- ["CCO", "乙醇"],
213
- ["C1=CC=CC=C1", "苯"],
214
- ["CCOC(=O)C", "乙酸乙酯"]
215
- ]
216
- )
217
-
218
- if __name__ == "__main__":
219
- logger.info("启动Gradio应用...")
220
- try:
221
- demo.launch(server_name="0.0.0.0", server_port=7860)
222
- except Exception as e:
223
- logger.error(f"启动Gradio应用失败: {str(e)}")
224
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Tue Jul 8 15:00:01 2025
4
+ @author: User
5
+ """
6
+ import os
7
+ import zipfile
8
+ import torch
9
+ import numpy as np
10
+ import logging
11
+ from PIL import Image
12
+ import gradio as gr
13
+ from torch_geometric.data import Data as PyGData
14
+ import matplotlib
15
+ matplotlib.use('Agg') # 修复后台线程问题
16
+
17
+ # 新增导入
18
+ from rdkit import Chem
19
+ from rdkit.Chem import AllChem, Draw
20
+ from rdkit.Chem.Draw import rdMolDraw2D
21
+ import tempfile
22
+
23
+ # 配置日志
24
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
25
+ logger = logging.getLogger(__name__)
26
+
27
+ # GPU内存优化
28
+ if torch.cuda.is_available():
29
+ torch.cuda.empty_cache()
30
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
31
+ logger.info("设置GPU内存优化参数: PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128")
32
+ else:
33
+ logger.info("使用CPU运行")
34
+
35
+ # 解压模型文件
36
+ MODEL_FILES = ["best_model-E-500-68.pth", "best_model-P-5000-180.pth", "best_model-B-6000-185.pth"]
37
+ if not all(os.path.exists(f) for f in MODEL_FILES):
38
+ logger.info("开始解压模型文件...")
39
+ try:
40
+ with zipfile.ZipFile("models.zip", 'r') as zip_ref:
41
+ zip_ref.extractall(".")
42
+ logger.info("模型文件解压完成!")
43
+ except Exception as e:
44
+ logger.error(f"解压模型文件失败: {str(e)}")
45
+ raise
46
+
47
+ # 导入模型工具
48
+ try:
49
+ from model_utils import EnhancedGAT, smiles_to_graph, visualize_single_molecule
50
+ logger.info("成功导入 model_utils 模块")
51
+ except ImportError as e:
52
+ logger.error(f"导入 model_utils 失败: {str(e)}")
53
+ raise
54
+
55
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
56
+ logger.info(f"使用设备: {device}")
57
+ if torch.cuda.is_available():
58
+ logger.info(f"GPU 信息: {torch.cuda.get_device_name(0)}")
59
+
60
+ def load_models():
61
+ model_info = {
62
+ "Elastic": ("best_model-E-500-68.pth", 2),
63
+ "Plastic": ("best_model-P-5000-180.pth", 2),
64
+ "Brittle": ("best_model-B-6000-185.pth", 2)
65
+ }
66
+ models = {}
67
+
68
+ for name, (pth_path, output_dim) in model_info.items():
69
+ logger.info(f"正在加载 {name} 模型: {pth_path}")
70
+
71
+ if not os.path.exists(pth_path):
72
+ logger.error(f"模型文件不存在: {pth_path}")
73
+ raise FileNotFoundError(f"模型文件 {pth_path} 不存在")
74
+
75
+ try:
76
+ model = EnhancedGAT(input_dim=12, hidden_dim=512, output_dim=output_dim, num_heads=8)
77
+
78
+ # 简化模型加载
79
+ state_dict = torch.load(pth_path, map_location=device)
80
+ if "model_state_dict" in state_dict:
81
+ model.load_state_dict(state_dict["model_state_dict"])
82
+ else:
83
+ model.load_state_dict(state_dict)
84
+
85
+ model.eval().to(device)
86
+ models[name] = model
87
+ logger.info(f"{name} 模型加载成功!")
88
+ except Exception as e:
89
+ logger.error(f"加载 {name} 模型失败: {str(e)}")
90
+ raise
91
+
92
+ return models
93
+
94
+ logger.info("开始加载所有模型...")
95
+ try:
96
+ models = load_models()
97
+ logger.info("所有模型加载完成!")
98
+ except Exception as e:
99
+ logger.error(f"模型加载过程中发生错误: {str(e)}")
100
+ # 创建虚拟模型保持应用运行
101
+ dummy_model = EnhancedGAT(input_dim=12, hidden_dim=512, output_dim=2, num_heads=8).eval()
102
+ models = {name: dummy_model for name in ["Elastic", "Plastic", "Brittle"]}
103
+ logger.warning("使用虚拟模型继续运行,功能受限")
104
+
105
+ def predict_all(smiles):
106
+ logger.info(f"收到预测请求: SMILES = {smiles}")
107
+ try:
108
+ # 转换SMILES为图数据
109
+ logger.info("转换SMILES为图数据...")
110
+ atom_features, (rows, cols, edge_attr), mol = smiles_to_graph(smiles)
111
+
112
+ # 创建PyG数据对象
113
+ logger.info("创建PyG数据对象...")
114
+ x = torch.tensor(atom_features, dtype=torch.float)
115
+ edge_index = torch.tensor(np.column_stack((rows, cols)).T, dtype=torch.long)
116
+ edge_attr = torch.tensor(edge_attr, dtype=torch.float).unsqueeze(1)
117
+ data = PyGData(
118
+ x=x,
119
+ edge_index=edge_index,
120
+ edge_attr=edge_attr,
121
+ smiles=[smiles],
122
+ batch=torch.zeros(x.size(0), dtype=torch.long
123
+ )
124
+
125
+ results = []
126
+ # 对每个模型进行预测
127
+ for name in ["Elastic", "Plastic", "Brittle"]:
128
+ logger.info(f"使用 {name} 模型进行预测...")
129
+ try:
130
+ buf, pred = visualize_single_molecule(models[name], data, device, name)
131
+ if buf:
132
+ img = Image.open(buf)
133
+ pred_text = f"{name}: {'Positive' if pred == 1 else 'Negative'}"
134
+ results.append((pred_text, img))
135
+ else:
136
+ error_msg = f"{name} 预测失败: 未生成图像"
137
+ logger.error(error_msg)
138
+ results.append((error_msg, None))
139
+ except Exception as e:
140
+ error_msg = f"{name} 预测错误: {str(e)}"
141
+ logger.error(error_msg)
142
+ results.append((error_msg, None))
143
+
144
+ return results[0][0], results[0][1], results[1][0], results[1][1], results[2][0], results[2][1]
145
+ except Exception as e:
146
+ error_msg = f"预测过程错误: {str(e)}"
147
+ logger.error(error_msg)
148
+ return error_msg, None, error_msg, None, error_msg, None
149
+
150
+ # ===== 新增分子结构绘制功能 =====
151
+ def draw_molecule(smiles=None):
152
+ """根SMILES生成分子结构图像"""
153
+ try:
154
+ if not smiles:
155
+ return None
156
+
157
+ mol = Chem.MolFromSmiles(smiles)
158
+ if not mol:
159
+ return None
160
+
161
+ # 生成2D分子图像
162
+ drawer = rdMolDraw2D.MolDraw2DCairo(400, 300)
163
+ drawer.DrawMolecule(mol)
164
+ drawer.FinishDrawing()
165
+
166
+ # 转换为PIL图像
167
+ img_data = drawer.GetDrawingText()
168
+ return Image.open(tempfile.NamedTemporaryFile(delete=False, suffix='.png').write(img_data))
169
+ except Exception as e:
170
+ logger.error(f"分子绘制失败: {str(e)}")
171
+ return None
172
+
173
+ def predict_from_structure(mol_dict):
174
+ """从绘制的分子结构预测"""
175
+ try:
176
+ if not mol_dict or not mol_dict['atoms']:
177
+ return "请绘制分子结构", None, "请绘制分子结构", None, "请绘制分子结构", None
178
+
179
+ # 转换绘制结构为SMILES
180
+ mol = Chem.RWMol()
181
+ atom_map = {}
182
+
183
+ # 添加原子
184
+ for atom in mol_dict['atoms']:
185
+ new_atom = Chem.Atom(atom['atom_symbol'])
186
+ idx = mol.AddAtom(new_atom)
187
+ atom_map[atom['atom_index']] = idx
188
+
189
+ # 添加键
190
+ for bond in mol_dict['bonds']:
191
+ start = atom_map[bond['start_atom']]
192
+ end = atom_map[bond['end_atom']]
193
+ bond_type = Chem.BondType.values[bond['bond_type'] - 1] # 转换键类型
194
+ mol.AddBond(start, end, bond_type)
195
+
196
+ # 获取SMILES
197
+ smiles = Chem.MolToSmiles(mol)
198
+ logger.info(f"转换的SMILES: {smiles}")
199
+
200
+ # 进行预测
201
+ return predict_all(smiles)
202
+ except Exception as e:
203
+ error_msg = f"结构转换错误: {str(e)}"
204
+ logger.error(error_msg)
205
+ return error_msg, None, error_msg, None, error_msg, None
206
+
207
+ # ===== 创建多选项卡界面 =====
208
+ with gr.Blocks(title="CrystalGAT") as demo:
209
+ gr.Markdown("# CrystalGAT分子性质预测")
210
+ gr.Markdown("输入SMILES或绘制分子结构,预测弹性、塑性和脆性分类并可视化注意力权重")
211
+
212
+ with gr.Tab("SMILES输入"):
213
+ with gr.Row():
214
+ with gr.Column():
215
+ smiles_input = gr.Textbox(
216
+ label="SMILES",
217
+ placeholder="输入SMILES字符串,例如: CCO",
218
+ interactive=True
219
+ )
220
+ gr.Examples(
221
+ examples=[
222
+ ["CCO", "乙醇"],
223
+ ["C1=CC=CC=C1", "苯"],
224
+ ["CCOC(=O)C", "乙酸乙酯"]
225
+ ],
226
+ inputs=smiles_input
227
+ )
228
+ submit_btn = gr.Button("预测")
229
+ with gr.Column():
230
+ molecule_img = gr.Image(label="分子结构", interactive=False)
231
+ smiles_input.change(
232
+ fn=draw_molecule,
233
+ inputs=smiles_input,
234
+ outputs=molecule_img
235
+ )
236
+
237
+ with gr.Row():
238
+ with gr.Column():
239
+ elastic_text = gr.Text(label="弹性")
240
+ elastic_img = gr.Image(label="注意力可视化")
241
+ with gr.Column():
242
+ plastic_text = gr.Text(label="塑性")
243
+ plastic_img = gr.Image(label="注意力可视化")
244
+ with gr.Column():
245
+ brittle_text = gr.Text(label="脆性")
246
+ brittle_img = gr.Image(label="注意力可视化")
247
+
248
+ submit_btn.click(
249
+ fn=predict_all,
250
+ inputs=smiles_input,
251
+ outputs=[
252
+ elastic_text, elastic_img,
253
+ plastic_text, plastic_img,
254
+ brittle_text, brittle_img
255
+ ]
256
+ )
257
+
258
+ with gr.Tab("绘制分子结构"):
259
+ with gr.Row():
260
+ with gr.Column():
261
+ molecule_editor = gr.Molecule(
262
+ label="绘制分子结构",
263
+ type="sketch",
264
+ interactive=True
265
+ )
266
+ draw_submit = gr.Button("预测")
267
+ with gr.Column():
268
+ gr.Markdown("### 绘制说明")
269
+ gr.Markdown("1. 从右侧选择原子工具<br>2. 在画布上点击添加原子<br>3. 选择键工具连接原子<br>4. 点击预测按钮进行分析")
270
+
271
+ with gr.Row():
272
+ with gr.Column():
273
+ draw_elastic_text = gr.Text(label="弹性")
274
+ draw_elastic_img = gr.Image(label="注意力可视化")
275
+ with gr.Column():
276
+ draw_plastic_text = gr.Text(label="塑性")
277
+ draw_plastic_img = gr.Image(label="注意力可视化")
278
+ with gr.Column():
279
+ draw_brittle_text = gr.Text(label="脆性")
280
+ draw_brittle_img = gr.Image(label="注意力可视化")
281
+
282
+ draw_submit.click(
283
+ fn=predict_from_structure,
284
+ inputs=molecule_editor,
285
+ outputs=[
286
+ draw_elastic_text, draw_elastic_img,
287
+ draw_plastic_text, draw_plastic_img,
288
+ draw_brittle_text, draw_brittle_img
289
+ ]
290
+ )
291
+
292
+ if __name__ == "__main__":
293
+ logger.info("启动Gradio应用...")
294
+ try:
295
+ demo.launch(
296
+ server_name="0.0.0.0",
297
+ server_port=7860,
298
+ share=False,
299
+ show_error=True
300
+ )
301
+ except Exception as e:
302
+ logger.error(f"启动Gradio应用失败: {str(e)}")
303
+ raise