QQ2S3R commited on
Commit
8ab481c
·
verified ·
1 Parent(s): b23630b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +112 -340
app.py CHANGED
@@ -15,13 +15,9 @@ from torch_geometric.data import Data as PyGData
15
  import matplotlib
16
  matplotlib.use('Agg') # 修复后台线程问题
17
  from rdkit import Chem
18
- from rdkit.Chem import Draw
19
- from rdkit.Chem import AllChem
20
- from rdkit.Chem import MolFromSmiles
21
  from io import BytesIO
22
  import traceback
23
- import base64
24
- import json
25
 
26
  # 配置日志
27
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
@@ -62,15 +58,10 @@ def load_models():
62
  "Brittle": ("models/best_model-B-6000-185.pth", 2)
63
  }
64
  models = {}
65
-
66
  for name, (pth_path, output_dim) in model_info.items():
67
  logger.info(f"正在加载 {name} 模型: {pth_path}")
68
-
69
- # 检查模型文件是否存在
70
  if not os.path.exists(pth_path):
71
  logger.error(f"模型文件不存在: {pth_path}")
72
-
73
- # 尝试可能的文件名变体
74
  possible_files = [
75
  pth_path,
76
  pth_path.lower(),
@@ -78,7 +69,6 @@ def load_models():
78
  pth_path.replace("-", "_"),
79
  pth_path.replace("_", "-")
80
  ]
81
-
82
  found = False
83
  for file in possible_files:
84
  if os.path.exists(file):
@@ -86,119 +76,70 @@ def load_models():
86
  pth_path = file
87
  found = True
88
  break
89
-
90
  if not found:
91
- logger.error(f"找不到任何匹配的模型文件: {pth_path}")
92
  raise FileNotFoundError(f"模型文件 {pth_path} 不存在")
93
-
94
  try:
95
- # 修复模型初始化参数
96
  model = EnhancedGAT(input_dim=12, hidden_dim=512, output_dim=output_dim, num_heads=8)
97
-
98
- # 加载模型状态 - 解决 PyTorch 2.6+ 的安全问题
99
  logger.info(f"加载模型权重: {pth_path}")
100
-
101
- # 方法 1: 禁用 weights_only
102
  try:
103
- # 尝试使用 weights_only=False 加载
104
  state = torch.load(pth_path, map_location=device, weights_only=False)
105
- logger.info("使用 weights_only=False 成功加载模型")
106
  except Exception as e:
107
- logger.warning(f"使用 weights_only=False 加载失败: {str(e)}")
108
- logger.info("尝试使用 weights_only=True 并添加安全全局变量")
109
-
110
- # 方法 2: 添加安全全局变量
111
  try:
112
- # 导入必要的模块
113
- import numpy as np
114
- import torch.serialization
115
-
116
- # 添加安全全局变量
117
  torch.serialization.add_safe_globals([getattr(np, '_core', np).multiarray.scalar])
118
  state = torch.load(pth_path, map_location=device, weights_only=True)
119
- logger.info("使用 weights_only=True 和安全全局变量成功加载模型")
120
  except:
121
- # 最后尝试原始方式
122
- logger.warning("安全方式加载失败,尝试原始加载方式")
123
  state = torch.load(pth_path, map_location=device)
124
-
125
- # 检查状态字典键名
126
  if "model_state_dict" in state:
127
  state_dict = state["model_state_dict"]
128
- logger.info("使用 'model_state_dict' 加载模型")
129
  else:
130
- state_dict = state # 直接使用整个文件
131
- logger.info("使用整个状态字典加载模型")
132
-
133
- # 加载模型参数
134
  model.load_state_dict(state_dict)
135
  model.eval().to(device)
136
  models[name] = model
137
  logger.info(f"{name} 模型加载成功!")
138
-
139
  except Exception as e:
140
  logger.error(f"加载 {name} 模型失败: {str(e)}")
141
  raise
142
-
143
  return models
144
 
145
  logger.info("开始加载所有模型...")
146
- try:
147
- models = load_models()
148
- logger.info("所有模型加载完成!")
149
- except Exception as e:
150
- logger.error(f"模型加载过程中发生错误: {str(e)}")
151
- raise
152
 
153
  def predict_all(smiles):
154
  logger.info(f"收到预测请求: SMILES = {smiles}")
155
  try:
156
- # 转换SMILES为图数据
157
- logger.info("转换SMILES为图数据...")
158
  atom_features, (rows, cols, edge_attr), mol = smiles_to_graph(smiles)
159
-
160
- # 创建PyG数据对象
161
- logger.info("创建PyG数据对象...")
162
  x = torch.tensor(atom_features, dtype=torch.float)
163
  edge_index = torch.tensor(np.column_stack((rows, cols)).T, dtype=torch.long)
164
  edge_attr = torch.tensor(edge_attr, dtype=torch.float).unsqueeze(1)
165
- data = PyGData(
166
- x=x,
167
- edge_index=edge_index,
168
- edge_attr=edge_attr,
169
- smiles=[smiles],
170
- batch=torch.zeros(x.size(0), dtype=torch.long)
171
- )
172
-
173
  results = []
174
- # 对每个模型进行预测
175
  for name in ["Elastic", "Plastic", "Brittle"]:
176
  logger.info(f"使用 {name} 模型进行预测...")
177
  try:
178
  buf, pred = visualize_single_molecule(models[name], data, device, name)
179
- # 修复图像处理
180
  if buf:
181
- buf.seek(0) # 重置缓冲区位置
182
  img = Image.open(buf)
183
  pred_text = f"{name} Result: {'1' if pred == 1 else '0'}"
184
- logger.info(f"{name} 预测结果: {pred}")
185
  results.append((pred_text, img))
186
  else:
187
- error_msg = f"{name} 预测失败: 未生成图像"
188
- logger.error(error_msg)
189
- results.append((error_msg, None))
190
-
191
  except Exception as e:
192
- error_msg = f"{name} 预测过程中发生错误: {str(e)}"
193
- logger.error(error_msg)
194
- results.append((error_msg, None))
195
-
196
- return results[0][0], results[0][1], results[1][0], results[1][1], results[2][0], results[2][1]
197
-
198
  except Exception as e:
199
- error_msg = f"预测过程中发生严重错误: {str(e)}"
200
- logger.error(error_msg)
201
- return error_msg, None, error_msg, None, error_msg, None
202
 
203
  # 原子和键类型选项
204
  ATOM_TYPES = ["C", "N", "O", "S", "P", "F", "Cl", "Br", "I", "H"]
@@ -206,318 +147,149 @@ BOND_TYPES = ["单键", "双键", "三键"]
206
 
207
  # 初始化分子结构
208
  def init_molecule():
209
- return {
210
- "atoms": [],
211
- "bonds": []
212
- }
213
 
214
  # 添加原子
215
  def add_atom(molecule, atom_type):
216
- molecule["atoms"].append({
217
- "id": len(molecule["atoms"]),
218
- "type": atom_type
219
- })
220
  return molecule
221
 
222
  # 添加键
223
  def add_bond(molecule, atom1_id, atom2_id, bond_type):
224
- # 检查原子是否存在
225
  if atom1_id >= len(molecule["atoms"]) or atom2_id >= len(molecule["atoms"]):
226
  return molecule
227
-
228
- # 检查是否已存在键
229
  for bond in molecule["bonds"]:
230
- if (bond["atom1"] == atom1_id and bond["atom2"] == atom2_id) or \
231
- (bond["atom1"] == atom2_id and bond["atom2"] == atom1_id):
232
  return molecule
233
-
234
- molecule["bonds"].append({
235
- "atom1": atom1_id,
236
- "atom2": atom2_id,
237
- "type": bond_type
238
- })
239
  return molecule
240
 
241
- # 从JSON生成SMILES
242
  def generate_smiles(molecule_json):
243
  try:
244
- # 创建空分子
245
  mol = Chem.RWMol()
246
-
247
- # 添加原子
248
  atom_map = {}
249
  for atom in molecule_json["atoms"]:
250
- new_atom = Chem.Atom(atom["type"])
251
- idx = mol.AddAtom(new_atom)
252
  atom_map[atom["id"]] = idx
253
-
254
- # 添加键
255
  for bond in molecule_json["bonds"]:
256
- start_atom = atom_map[bond["atom1"]]
257
- end_atom = atom_map[bond["atom2"]]
258
-
259
- # 确定键类型
260
- bond_type_mapping = {
261
- "单键": Chem.BondType.SINGLE,
262
- "双键": Chem.BondType.DOUBLE,
263
- "三键": Chem.BondType.TRIPLE
264
- }
265
- bond_type = bond_type_mapping.get(bond["type"], Chem.BondType.SINGLE)
266
-
267
- mol.AddBond(start_atom, end_atom, bond_type)
268
-
269
- # 清理分子
270
  mol.UpdatePropertyCache()
271
  Chem.SanitizeMol(mol)
272
-
273
- # 生成SMILES
274
- smiles = Chem.MolToSmiles(mol)
275
- return smiles
276
  except Exception as e:
277
- logger.error(f"生成SMILES失败: {str(e)}")
278
  return None
279
 
280
- # 可视化分子结构
281
  def visualize_molecule(molecule_json):
282
- try:
283
- smiles = generate_smiles(molecule_json)
284
- if not smiles:
285
- return None
286
-
287
- mol = Chem.MolFromSmiles(smiles)
288
- if mol is None:
289
- return None
290
-
291
- # 生成2D坐标
292
- AllChem.Compute2DCoords(mol)
293
-
294
- # 创建图像
295
- img = Draw.MolToImage(mol, size=(300, 300))
296
- img_buffer = BytesIO()
297
- img.save(img_buffer, format="PNG")
298
- img_buffer.seek(0)
299
- return img_buffer
300
- except Exception as e:
301
- logger.error(f"可视化分子失败: {str(e)}")
302
  return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
303
 
304
- # 使用TabbedInterface组织两种输入方式
305
  with gr.Blocks(title="CrystalGAT") as demo:
306
  gr.Markdown("# CrystalGAT")
307
- gr.Markdown("输入SMILES字符串或构建分子结构CrystalGAT将预测弹性塑性脆性分类并可视化注意力权重")
308
 
309
  with gr.Tab("SMILES输入"):
310
- gr.Markdown("### 输入SMILES字符串")
311
- smiles_input = gr.Textbox(
312
- label="SMILES",
313
- placeholder="输入SMILES字符串,例如: CCO",
314
- interactive=True
315
- )
316
- gr.Examples(
317
- examples=[
318
- ["CCO", "乙醇"],
319
- ["C1=CC=CC=C1", "苯"],
320
- ["CCOC(=O)C", "乙酸乙酯"]
321
- ],
322
- inputs=smiles_input,
323
- label="示例"
324
- )
325
  submit_btn1 = gr.Button("预测", variant="primary")
326
 
327
  with gr.Tab("构建分子"):
328
- gr.Markdown("### 构建分子结构")
329
- gr.Markdown("1. 添加原子:选择原子类型并点击'添加原子'按钮")
330
- gr.Markdown("2. 添加键:选择两个原子和键类型,然后点击'添加键'按钮")
331
- gr.Markdown("3. 完成后点击'生成分子'按钮预览并预测")
332
-
333
- # 初始化分子状态
334
  molecule_state = gr.State(init_molecule())
 
335
 
336
  with gr.Row():
337
- # 原子选择
338
- with gr.Column():
339
- gr.Markdown("### 添加原子")
340
- atom_select = gr.Dropdown(
341
- label="选择原子类型",
342
- choices=ATOM_TYPES,
343
- value="C"
344
- )
345
- add_atom_btn = gr.Button("添加原子")
346
- atoms_list = gr.Dataframe(
347
- label="原子列表",
348
- headers=["ID", "原子类型"],
349
- datatype=["number", "str"],
350
- interactive=False
351
- )
352
-
353
- # 键选择
354
- with gr.Column():
355
- gr.Markdown("### 添加键")
356
- # 创建原子选项列表
357
- atom_options = gr.State([])
358
-
359
- atom1_select = gr.Dropdown(
360
- label="选择第一个原子",
361
- choices=[],
362
- interactive=True
363
- )
364
- atom2_select = gr.Dropdown(
365
- label="选择第二个原子",
366
- choices=[],
367
- interactive=True
368
- )
369
- bond_select = gr.Dropdown(
370
- label="选择键类型",
371
- choices=BOND_TYPES,
372
- value="单键"
373
- )
374
- add_bond_btn = gr.Button("添加键")
375
- bonds_list = gr.Dataframe(
376
- label="键列表",
377
- headers=["原子1", "原子2", "键类型"],
378
- datatype=["str", "str", "str"],
379
- interactive=False
380
- )
381
 
382
- # 操作按钮
383
  with gr.Row():
384
- clear_btn = gr.Button("清除所有")
385
- generate_btn = gr.Button("生成分子")
 
 
 
386
 
387
- # 分子预览
 
 
388
  molecule_img = gr.Image(label="分子预览", interactive=False)
389
- molecule_smiles = gr.Textbox(label="生成的SMILES", interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
390
 
391
- # 预测按钮
392
- submit_btn2 = gr.Button("使用此分子进行预测", variant="primary")
 
 
 
 
 
 
 
393
 
394
- # 输出区域 (两种输入方式共享)
395
- gr.Markdown("## 预测结果")
396
  with gr.Row():
397
  elastic_text = gr.Text(label="Elastic")
398
- elastic_img = gr.Image(type="pil", label="Elastic attention visualization")
399
  with gr.Row():
400
  plastic_text = gr.Text(label="Plastic")
401
- plastic_img = gr.Image(type="pil", label="Plastic attention visualization")
402
  with gr.Row():
403
  brittle_text = gr.Text(label="Brittle")
404
- brittle_img = gr.Image(type="pil", label="Brittle attention visualization")
405
-
406
- # 更新原子选项列表
407
- def update_atom_options(molecule):
408
- # 创建原子选项列表
409
- options = [f"{i}: {atom['type']}" for i, atom in enumerate(molecule["atoms"])]
410
- return options
411
 
412
- # 更新原子列表显示
413
- def update_atoms_list(molecule):
414
- atoms_data = [[i, atom["type"]] for i, atom in enumerate(molecule["atoms"])]
415
- return atoms_data
416
-
417
- # 更新键列表显示
418
- def update_bonds_list(molecule):
419
- bonds_data = []
420
- for bond in molecule["bonds"]:
421
- atom1_type = molecule["atoms"][bond["atom1"]]["type"]
422
- atom2_type = molecule["atoms"][bond["atom2"]]["type"]
423
- bonds_data.append([
424
- f"{bond['atom1']}: {atom1_type}",
425
- f"{bond['atom2']}: {atom2_type}",
426
- bond["type"]
427
- ])
428
- return bonds_data
429
-
430
- # 事件处理 - 添加原子
431
- add_atom_btn.click(
432
- fn=lambda atom, mol: add_atom(mol, atom),
433
- inputs=[atom_select, molecule_state],
434
- outputs=molecule_state
435
- ).then(
436
- fn=update_atoms_list,
437
- inputs=molecule_state,
438
- outputs=atoms_list
439
- ).then(
440
- fn=update_atom_options,
441
- inputs=molecule_state,
442
- outputs=atom_options
443
- ).then(
444
- lambda options: [
445
- gr.Dropdown.update(choices=options, value=options[0] if options else None),
446
- gr.Dropdown.update(choices=options, value=options[0] if options else None)
447
- ],
448
- inputs=atom_options,
449
- outputs=[atom1_select, atom2_select]
450
- )
451
-
452
- # 事件处理 - 添加键
453
- add_bond_btn.click(
454
- fn=lambda atom1, atom2, bond, mol: add_bond(mol, int(atom1.split(":")[0]), int(atom2.split(":")[0]), bond),
455
- inputs=[atom1_select, atom2_select, bond_select, molecule_state],
456
- outputs=molecule_state
457
- ).then(
458
- fn=update_bonds_list,
459
- inputs=molecule_state,
460
- outputs=bonds_list
461
- )
462
-
463
- # 事件处理 - 清除所有
464
- clear_btn.click(
465
- fn=init_molecule,
466
- outputs=molecule_state
467
- ).then(
468
- fn=lambda: [],
469
- outputs=atoms_list
470
- ).then(
471
- fn=lambda: [],
472
- outputs=bonds_list
473
- ).then(
474
- fn=lambda: [],
475
- outputs=atom_options
476
- ).then(
477
- lambda: [
478
- gr.Dropdown.update(choices=[], value=None),
479
- gr.Dropdown.update(choices=[], value=None)
480
- ],
481
- outputs=[atom1_select, atom2_select]
482
- ).then(
483
- fn=lambda: None,
484
- outputs=molecule_img
485
- ).then(
486
- fn=lambda: "",
487
- outputs=molecule_smiles
488
- )
489
-
490
- # 事件处理 - 生成分子
491
- generate_btn.click(
492
- fn=generate_smiles,
493
- inputs=molecule_state,
494
- outputs=molecule_smiles
495
- ).then(
496
- fn=visualize_molecule,
497
- inputs=molecule_state,
498
- outputs=molecule_img
499
- )
500
-
501
- # 设置交互
502
- # SMILES输入路径
503
- submit_btn1.click(
504
- fn=predict_all,
505
- inputs=smiles_input,
506
- outputs=[elastic_text, elastic_img, plastic_text, plastic_img, brittle_text, brittle_img]
507
- )
508
-
509
- # 分子构建预测路径
510
- submit_btn2.click(
511
- fn=lambda smiles: predict_all(smiles) if smiles else ("请输入有效的SMILES", None, "", None, "", None),
512
- inputs=molecule_smiles,
513
- outputs=[elastic_text, elastic_img, plastic_text, plastic_img, brittle_text, brittle_img]
514
- )
515
 
516
  if __name__ == "__main__":
517
- logger.info("启动Gradio应用...")
518
- try:
519
- demo.launch(server_name="0.0.0.0", server_port=7860)
520
- except Exception as e:
521
- logger.error(f"启动Gradio应用失败: {str(e)}")
522
- logger.error(traceback.format_exc())
523
- raise
 
15
  import matplotlib
16
  matplotlib.use('Agg') # 修复后台线程问题
17
  from rdkit import Chem
18
+ from rdkit.Chem import Draw, AllChem, MolFromSmiles
 
 
19
  from io import BytesIO
20
  import traceback
 
 
21
 
22
  # 配置日志
23
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 
58
  "Brittle": ("models/best_model-B-6000-185.pth", 2)
59
  }
60
  models = {}
 
61
  for name, (pth_path, output_dim) in model_info.items():
62
  logger.info(f"正在加载 {name} 模型: {pth_path}")
 
 
63
  if not os.path.exists(pth_path):
64
  logger.error(f"模型文件不存在: {pth_path}")
 
 
65
  possible_files = [
66
  pth_path,
67
  pth_path.lower(),
 
69
  pth_path.replace("-", "_"),
70
  pth_path.replace("_", "-")
71
  ]
 
72
  found = False
73
  for file in possible_files:
74
  if os.path.exists(file):
 
76
  pth_path = file
77
  found = True
78
  break
 
79
  if not found:
 
80
  raise FileNotFoundError(f"模型文件 {pth_path} 不存在")
 
81
  try:
 
82
  model = EnhancedGAT(input_dim=12, hidden_dim=512, output_dim=output_dim, num_heads=8)
 
 
83
  logger.info(f"加载模型权重: {pth_path}")
 
 
84
  try:
 
85
  state = torch.load(pth_path, map_location=device, weights_only=False)
 
86
  except Exception as e:
87
+ logger.warning(f"weights_only=False 加载失败: {str(e)}")
 
 
 
88
  try:
89
+ import numpy as np, torch.serialization
 
 
 
 
90
  torch.serialization.add_safe_globals([getattr(np, '_core', np).multiarray.scalar])
91
  state = torch.load(pth_path, map_location=device, weights_only=True)
 
92
  except:
 
 
93
  state = torch.load(pth_path, map_location=device)
 
 
94
  if "model_state_dict" in state:
95
  state_dict = state["model_state_dict"]
 
96
  else:
97
+ state_dict = state
 
 
 
98
  model.load_state_dict(state_dict)
99
  model.eval().to(device)
100
  models[name] = model
101
  logger.info(f"{name} 模型加载成功!")
 
102
  except Exception as e:
103
  logger.error(f"加载 {name} 模型失败: {str(e)}")
104
  raise
 
105
  return models
106
 
107
  logger.info("开始加载所有模型...")
108
+ models = load_models()
109
+ logger.info("所有模型加载完成!")
 
 
 
 
110
 
111
  def predict_all(smiles):
112
  logger.info(f"收到预测请求: SMILES = {smiles}")
113
  try:
 
 
114
  atom_features, (rows, cols, edge_attr), mol = smiles_to_graph(smiles)
 
 
 
115
  x = torch.tensor(atom_features, dtype=torch.float)
116
  edge_index = torch.tensor(np.column_stack((rows, cols)).T, dtype=torch.long)
117
  edge_attr = torch.tensor(edge_attr, dtype=torch.float).unsqueeze(1)
118
+ data = PyGData(x=x,
119
+ edge_index=edge_index,
120
+ edge_attr=edge_attr,
121
+ smiles=[smiles],
122
+ batch=torch.zeros(x.size(0), dtype=torch.long))
 
 
 
123
  results = []
 
124
  for name in ["Elastic", "Plastic", "Brittle"]:
125
  logger.info(f"使用 {name} 模型进行预测...")
126
  try:
127
  buf, pred = visualize_single_molecule(models[name], data, device, name)
 
128
  if buf:
129
+ buf.seek(0)
130
  img = Image.open(buf)
131
  pred_text = f"{name} Result: {'1' if pred == 1 else '0'}"
 
132
  results.append((pred_text, img))
133
  else:
134
+ results.append((f"{name} 预测失败: 未生成图像", None))
 
 
 
135
  except Exception as e:
136
+ results.append((f"{name} 预测过程中发生错误: {str(e)}", None))
137
+ return (results[0][0], results[0][1],
138
+ results[1][0], results[1][1],
139
+ results[2][0], results[2][1])
 
 
140
  except Exception as e:
141
+ logger.error(f"预测过程中发生严重错误: {str(e)}")
142
+ return ("预测失败", None, "预测失败", None, "预测失败", None)
 
143
 
144
  # 原子和键类型选项
145
  ATOM_TYPES = ["C", "N", "O", "S", "P", "F", "Cl", "Br", "I", "H"]
 
147
 
148
  # 初始化分子结构
149
  def init_molecule():
150
+ return {"atoms": [], "bonds": []}
 
 
 
151
 
152
  # 添加原子
153
  def add_atom(molecule, atom_type):
154
+ molecule["atoms"].append({"id": len(molecule["atoms"]), "type": atom_type})
 
 
 
155
  return molecule
156
 
157
  # 添加键
158
  def add_bond(molecule, atom1_id, atom2_id, bond_type):
 
159
  if atom1_id >= len(molecule["atoms"]) or atom2_id >= len(molecule["atoms"]):
160
  return molecule
 
 
161
  for bond in molecule["bonds"]:
162
+ if {bond["atom1"], bond["atom2"]} == {atom1_id, atom2_id}:
 
163
  return molecule
164
+ molecule["bonds"].append({"atom1": atom1_id, "atom2": atom2_id, "type": bond_type})
 
 
 
 
 
165
  return molecule
166
 
167
+ # 从 JSON SMILES
168
  def generate_smiles(molecule_json):
169
  try:
 
170
  mol = Chem.RWMol()
 
 
171
  atom_map = {}
172
  for atom in molecule_json["atoms"]:
173
+ idx = mol.AddAtom(Chem.Atom(atom["type"]))
 
174
  atom_map[atom["id"]] = idx
 
 
175
  for bond in molecule_json["bonds"]:
176
+ t = {"单键": Chem.BondType.SINGLE,
177
+ "双键": Chem.BondType.DOUBLE,
178
+ "三键": Chem.BondType.TRIPLE}[bond["type"]]
179
+ mol.AddBond(atom_map[bond["atom1"]], atom_map[bond["atom2"]], t)
 
 
 
 
 
 
 
 
 
 
180
  mol.UpdatePropertyCache()
181
  Chem.SanitizeMol(mol)
182
+ return Chem.MolToSmiles(mol)
 
 
 
183
  except Exception as e:
184
+ logger.error(f"生成SMILES失败: {e}")
185
  return None
186
 
187
+ # 可视化分子
188
  def visualize_molecule(molecule_json):
189
+ smiles = generate_smiles(molecule_json)
190
+ if not smiles:
191
+ return None
192
+ mol = MolFromSmiles(smiles)
193
+ if mol is None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  return None
195
+ AllChem.Compute2DCoords(mol)
196
+ img = Draw.MolToImage(mol, size=(300, 300))
197
+ buf = BytesIO()
198
+ img.save(buf, format="PNG")
199
+ buf.seek(0)
200
+ return buf
201
+
202
+ # 更新下拉菜单选项(改为 gr.update)
203
+ def update_atom_dropdowns(molecule):
204
+ choices = [f"{a['id']}: {a['type']}" for a in molecule["atoms"]]
205
+ return gr.update(choices=choices, value=None), gr.update(choices=choices, value=None)
206
+
207
+ # 更新原子列表
208
+ def update_atoms_list(molecule):
209
+ return [[a["id"], a["type"]] for a in molecule["atoms"]]
210
+
211
+ # 更新键列表
212
+ def update_bonds_list(molecule):
213
+ out = []
214
+ for b in molecule["bonds"]:
215
+ t1 = next(a["type"] for a in molecule["atoms"] if a["id"]==b["atom1"])
216
+ t2 = next(a["type"] for a in molecule["atoms"] if a["id"]==b["atom2"])
217
+ out.append([f"{b['atom1']}: {t1}", f"{b['atom2']}: {t2}", b["type"]])
218
+ return out
219
 
 
220
  with gr.Blocks(title="CrystalGAT") as demo:
221
  gr.Markdown("# CrystalGAT")
222
+ gr.Markdown("输入SMILES或构建分子,预测弹性/塑性/脆性并可视化")
223
 
224
  with gr.Tab("SMILES输入"):
225
+ smiles_input = gr.Textbox(label="SMILES", placeholder="例如: CCO")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  submit_btn1 = gr.Button("预测", variant="primary")
227
 
228
  with gr.Tab("构建分子"):
 
 
 
 
 
 
229
  molecule_state = gr.State(init_molecule())
230
+ status_msg = gr.Textbox(label="状态", interactive=False, value="请添加原子开始")
231
 
232
  with gr.Row():
233
+ atom_select = gr.Dropdown(label="选择原子类型", choices=ATOM_TYPES, value="C")
234
+ add_atom_btn = gr.Button("添加原子")
235
+ atoms_list = gr.Dataframe(headers=["ID","原子类型"], datatype=["number","str"], interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
 
 
237
  with gr.Row():
238
+ atom1_select = gr.Dropdown(label="第一个原子", choices=[], value=None)
239
+ atom2_select = gr.Dropdown(label="第二个原子", choices=[], value=None)
240
+ bond_select = gr.Dropdown(label="键类型", choices=BOND_TYPES, value="单键")
241
+ add_bond_btn = gr.Button("添加键")
242
+ bonds_list = gr.Dataframe(headers=["原子1","原子2","键类型"], datatype=["str","str","str"], interactive=False)
243
 
244
+ clear_btn = gr.Button("清除所有")
245
+ generate_btn = gr.Button("生成分子")
246
+ molecule_smiles = gr.Textbox(label="SMILES结果", interactive=False)
247
  molecule_img = gr.Image(label="分子预览", interactive=False)
248
+ submit_btn2 = gr.Button("使用此分子预测", variant="primary")
249
+
250
+ # 绑定事件
251
+ add_atom_btn.click(fn=lambda at, mol: add_atom(mol, at),
252
+ inputs=[atom_select, molecule_state],
253
+ outputs=molecule_state) \
254
+ .then(fn=update_atoms_list, inputs=molecule_state, outputs=atoms_list) \
255
+ .then(fn=update_atom_dropdowns, inputs=molecule_state, outputs=[atom1_select, atom2_select]) \
256
+ .then(fn=lambda: "原子添加成功", outputs=status_msg)
257
+
258
+ add_bond_btn.click(fn=lambda a1, a2, b, mol: (add_bond(mol,
259
+ int(a1.split(":")[0]),
260
+ int(a2.split(":")[0]), b)) if a1 and a2 else mol,
261
+ inputs=[atom1_select, atom2_select, bond_select, molecule_state],
262
+ outputs=molecule_state) \
263
+ .then(fn=update_bonds_list, inputs=molecule_state, outputs=bonds_list) \
264
+ .then(fn=lambda: "键添加/更新成功", outputs=status_msg)
265
 
266
+ clear_btn.click(fn=init_molecule, outputs=molecule_state) \
267
+ .then(fn=lambda: ([], []), outputs=[atoms_list, bonds_list]) \
268
+ .then(fn=lambda: (gr.update(choices=[], value=None), gr.update(choices=[], value=None)),
269
+ outputs=[atom1_select, atom2_select]) \
270
+ .then(fn=lambda: "已清除所有", outputs=status_msg)
271
+
272
+ generate_btn.click(fn=generate_smiles, inputs=molecule_state, outputs=molecule_smiles) \
273
+ .then(fn=visualize_molecule, inputs=molecule_state, outputs=molecule_img) \
274
+ .then(fn=lambda: "分子生成完成", outputs=status_msg)
275
 
276
+ # 输出区域
 
277
  with gr.Row():
278
  elastic_text = gr.Text(label="Elastic")
279
+ elastic_img = gr.Image(type="pil", label="Elastic 可视化")
280
  with gr.Row():
281
  plastic_text = gr.Text(label="Plastic")
282
+ plastic_img = gr.Image(type="pil", label="Plastic 可视化")
283
  with gr.Row():
284
  brittle_text = gr.Text(label="Brittle")
285
+ brittle_img = gr.Image(type="pil", label="Brittle 可视化")
 
 
 
 
 
 
286
 
287
+ submit_btn1.click(fn=predict_all,
288
+ inputs=smiles_input,
289
+ outputs=[elastic_text, elastic_img, plastic_text, plastic_img, brittle_text, brittle_img])
290
+ submit_btn2.click(fn=lambda s: predict_all(s) if s else ("请输入SMILES", None, "", None, "", None),
291
+ inputs=molecule_smiles,
292
+ outputs=[elastic_text, elastic_img, plastic_text, plastic_img, brittle_text, brittle_img])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
293
 
294
  if __name__ == "__main__":
295
+ demo.launch(server_name="0.0.0.0", server_port=7860)