QQ2S3R commited on
Commit
b5e3512
·
verified ·
1 Parent(s): 29442d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +231 -67
app.py CHANGED
@@ -20,8 +20,8 @@ from rdkit.Chem import AllChem
20
  from rdkit.Chem import MolFromSmiles
21
  from io import BytesIO
22
  import traceback
23
- import tempfile
24
  import base64
 
25
 
26
  # 配置日志
27
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
@@ -200,58 +200,236 @@ def predict_all(smiles):
200
  logger.error(error_msg)
201
  return error_msg, None, error_msg, None, error_msg, None
202
 
203
- # 新增函数:将绘制的分结构转换为SMILES
204
- def draw_to_smiles(image_data):
205
- """
206
- 将绘制的分子图像转换为SMILES字符串
207
- :param image_data: 包含分子图像数据的base64字符串
208
- :return: (SMILES字符串, 分子图像)
209
- """
210
- if not image_data:
211
- return None, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
 
 
 
 
 
 
 
 
 
213
  try:
214
- # 将base64图像数据转换为PIL图像
215
- image_bytes = base64.b64decode(image_data.split(",")[1])
216
- image = Image.open(BytesIO(image_bytes))
217
 
218
- # 保存临时文件
219
- with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
220
- image.save(tmp.name)
221
- tmp_path = tmp.name
222
 
223
- # 使用RDKit从图像读取分子
224
- mol = MolFromSmiles("") # 创建一个空分子
225
- # 注意:实际应用中这里应该使用OCR或图像识别技术
226
- # 这里简化处理,提示用户输入SMILES
227
 
228
- # 创建分子图像
 
229
  img_buffer = BytesIO()
230
- image.save(img_buffer, format="PNG")
231
  img_buffer.seek(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
 
233
- # 提示用户输入SMILES
234
- return None, img_buffer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
 
236
- except Exception as e:
237
- error_msg = f"分子转换失败: {str(e)}"
238
- logger.error(error_msg)
239
- return None, None
240
 
241
- # 新增函数:处理绘制输入预测流程
242
- def predict_from_drawing(smiles):
243
- """
244
- 处理用户绘制的分子结构
245
- :param smiles: 用户输入的SMILES字符串
246
- :return: 预测结果 (6个输出组件)
247
- """
248
- # 使用预测函数
249
- return predict_all(smiles)
250
 
251
  # 使用TabbedInterface组织两种输入方式
252
  with gr.Blocks(title="CrystalGAT") as demo:
253
  gr.Markdown("# CrystalGAT")
254
- gr.Markdown("输入SMILES字符串或绘制分子结构,CrystalGAT将预测其弹性、塑性和脆性分类并可视化注意力权重")
255
 
256
  with gr.Tab("SMILES输入"):
257
  gr.Markdown("### 输入SMILES字符串")
@@ -271,27 +449,19 @@ with gr.Blocks(title="CrystalGAT") as demo:
271
  )
272
  submit_btn1 = gr.Button("预测", variant="primary")
273
 
274
- with gr.Tab("绘制分子"):
275
- gr.Markdown("### 绘制分子结构")
276
- gr.Markdown("使用下方的绘图工具绘制分结构,然后在SMILES框中输入对应的SMILES字符串")
277
-
278
- # 使用标准绘图组件
279
- drawing_input = gr.ImageEditor(
280
- label="绘制分子结构",
281
- type="pil",
282
- interactive=True,
283
- height=300
284
- )
285
 
286
- gr.Markdown("### 输入绘制的分子对应的SMILES")
287
- drawing_smiles = gr.Textbox(
288
- label="SMILES",
289
- placeholder="输入绘制的分子对应的SMILES字符串",
290
- interactive=True
291
- )
292
 
293
- drawing_display = gr.Image(label="分子结构览", interactive=False)
294
- submit_btn2 = gr.Button("预测", variant="primary")
 
295
 
296
  # 输出区域 (两种输入方式共享)
297
  gr.Markdown("## 预测结果")
@@ -313,16 +483,10 @@ with gr.Blocks(title="CrystalGAT") as demo:
313
  outputs=[elastic_text, elastic_img, plastic_text, plastic_img, brittle_text, brittle_img]
314
  )
315
 
316
- # 绘图输入路径
317
- drawing_input.change(
318
- fn=draw_to_smiles,
319
- inputs=drawing_input,
320
- outputs=[drawing_smiles, drawing_display]
321
- )
322
-
323
  submit_btn2.click(
324
- fn=predict_from_drawing,
325
- inputs=drawing_smiles,
326
  outputs=[elastic_text, elastic_img, plastic_text, plastic_img, brittle_text, brittle_img]
327
  )
328
 
 
20
  from rdkit.Chem import MolFromSmiles
21
  from io import BytesIO
22
  import traceback
 
23
  import base64
24
+ import json
25
 
26
  # 配置日志
27
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 
200
  logger.error(error_msg)
201
  return error_msg, None, error_msg, None, error_msg, None
202
 
203
+ # 和键类型选项
204
+ ATOM_TYPES = ["C", "N", "O", "S", "P", "F", "Cl", "Br", "I", "H"]
205
+ BOND_TYPES = ["单键", "双键", "三键"]
206
+
207
+ # 初始化分子结构
208
+ def init_molecule():
209
+ return {
210
+ "atoms": [],
211
+ "bonds": []
212
+ }
213
+
214
+ # 添加原子
215
+ def add_atom(molecule, atom_type, x, y):
216
+ molecule["atoms"].append({
217
+ "id": len(molecule["atoms"]),
218
+ "type": atom_type,
219
+ "x": x,
220
+ "y": y
221
+ })
222
+ return molecule
223
+
224
+ # 添加键
225
+ def add_bond(molecule, atom1_id, atom2_id, bond_type):
226
+ # 检查原子是否存在
227
+ if atom1_id >= len(molecule["atoms"]) or atom2_id >= len(molecule["atoms"]):
228
+ return molecule
229
+
230
+ # 检查是否已存在键
231
+ for bond in molecule["bonds"]:
232
+ if (bond["atom1"] == atom1_id and bond["atom2"] == atom2_id) or \
233
+ (bond["atom1"] == atom2_id and bond["atom2"] == atom1_id):
234
+ return molecule
235
+
236
+ molecule["bonds"].append({
237
+ "atom1": atom1_id,
238
+ "atom2": atom2_id,
239
+ "type": bond_type
240
+ })
241
+ return molecule
242
+
243
+ # 从JSON结构生成SMILES
244
+ def generate_smiles(molecule_json):
245
+ try:
246
+ # 创建空分子
247
+ mol = Chem.RWMol()
248
+
249
+ # 添加原子
250
+ atom_map = {}
251
+ for atom in molecule_json["atoms"]:
252
+ new_atom = Chem.Atom(atom["type"])
253
+ idx = mol.AddAtom(new_atom)
254
+ atom_map[atom["id"]] = idx
255
+
256
+ # 添加键
257
+ for bond in molecule_json["bonds"]:
258
+ start_atom = atom_map[bond["atom1"]]
259
+ end_atom = atom_map[bond["atom2"]]
260
+
261
+ # 确定键类型
262
+ bond_type_mapping = {
263
+ "单键": Chem.BondType.SINGLE,
264
+ "双键": Chem.BondType.DOUBLE,
265
+ "三键": Chem.BondType.TRIPLE
266
+ }
267
+ bond_type = bond_type_mapping.get(bond["type"], Chem.BondType.SINGLE)
268
+
269
+ mol.AddBond(start_atom, end_atom, bond_type)
270
+
271
+ # 清理分子
272
+ mol.UpdatePropertyCache()
273
+ Chem.SanitizeMol(mol)
274
 
275
+ # 生成SMILES
276
+ smiles = Chem.MolToSmiles(mol)
277
+ return smiles
278
+ except Exception as e:
279
+ logger.error(f"生成SMILES失败: {str(e)}")
280
+ return None
281
+
282
+ # 可视化分子结构
283
+ def visualize_molecule(molecule_json):
284
  try:
285
+ smiles = generate_smiles(molecule_json)
286
+ if not smiles:
287
+ return None
288
 
289
+ mol = Chem.MolFromSmiles(smiles)
290
+ if mol is None:
291
+ return None
 
292
 
293
+ # 生成2D坐标
294
+ AllChem.Compute2DCoords(mol)
 
 
295
 
296
+ # 创建图像
297
+ img = Draw.MolToImage(mol, size=(300, 300))
298
  img_buffer = BytesIO()
299
+ img.save(img_buffer, format="PNG")
300
  img_buffer.seek(0)
301
+ return img_buffer
302
+ except Exception as e:
303
+ logger.error(f"可视化分子失败: {str(e)}")
304
+ return None
305
+
306
+ # 分子构建界面
307
+ def build_molecule_interface():
308
+ # 初始化分子
309
+ molecule = init_molecule()
310
+
311
+ with gr.Blocks() as interface:
312
+ gr.Markdown("### 构建分子")
313
+
314
+ # 原子选择
315
+ with gr.Row():
316
+ atom_select = gr.Dropdown(
317
+ label="选择原子类型",
318
+ choices=ATOM_TYPES,
319
+ value="C"
320
+ )
321
+ bond_select = gr.Dropdown(
322
+ label="选择键类型",
323
+ choices=BOND_TYPES,
324
+ value="单键"
325
+ )
326
+
327
+ # 画布
328
+ canvas = gr.Image(
329
+ label="分子结构",
330
+ interactive=True,
331
+ tool="select",
332
+ height=400
333
+ )
334
+
335
+ # 状态显示
336
+ status = gr.Textbox(label="状态", interactive=False)
337
+
338
+ # 操作按钮
339
+ with gr.Row():
340
+ add_atom_btn = gr.Button("添加原子")
341
+ add_bond_btn = gr.Button("添加键")
342
+ clear_btn = gr.Button("清除")
343
+ generate_btn = gr.Button("生成分子")
344
+
345
+ # 分子预览
346
+ molecule_img = gr.Image(label="分子预览", interactive=False)
347
+ molecule_smiles = gr.Textbox(label="生成的SMILES", interactive=False)
348
+
349
+ # 存储分子结构的隐藏状态
350
+ molecule_state = gr.State(molecule)
351
 
352
+ # 原子位置存储
353
+ last_click_pos = gr.State((0, 0))
354
+
355
+ # 事件处理
356
+ canvas.select(
357
+ fn=lambda evt: {"x": evt.index[0], "y": evt.index[1]},
358
+ outputs=last_click_pos
359
+ )
360
+
361
+ add_atom_btn.click(
362
+ fn=lambda atom, pos, mol: add_atom(mol, atom, pos[0], pos[1]),
363
+ inputs=[atom_select, last_click_pos, molecule_state],
364
+ outputs=molecule_state
365
+ ).then(
366
+ fn=lambda mol: f"添加原子成功! 当前原子数: {len(mol['atoms'])}, 键数: {len(mol['bonds'])}",
367
+ inputs=molecule_state,
368
+ outputs=status
369
+ )
370
+
371
+ # 添加键需要选择两个原子
372
+ atom1_state = gr.State(-1)
373
+ atom2_state = gr.State(-1)
374
+
375
+ canvas.select(
376
+ fn=lambda evt, mol: {"atom_id": find_nearest_atom(mol, evt.index[0], evt.index[1])},
377
+ inputs=molecule_state,
378
+ outputs=atom1_state
379
+ )
380
+
381
+ canvas.select(
382
+ fn=lambda evt, mol: {"atom_id": find_nearest_atom(mol, evt.index[0], evt.index[1])},
383
+ inputs=molecule_state,
384
+ outputs=atom2_state
385
+ )
386
+
387
+ add_bond_btn.click(
388
+ fn=lambda bond, atom1, atom2, mol: add_bond(mol, atom1, atom2, bond),
389
+ inputs=[bond_select, atom1_state, atom2_state, molecule_state],
390
+ outputs=molecule_state
391
+ ).then(
392
+ fn=lambda mol: f"添加键成功! 当前原子数: {len(mol['atoms'])}, 键数: {len(mol['bonds'])}",
393
+ inputs=molecule_state,
394
+ outputs=status
395
+ )
396
+
397
+ clear_btn.click(
398
+ fn=init_molecule,
399
+ outputs=molecule_state
400
+ ).then(
401
+ fn=lambda: "已清除所有原子和键",
402
+ outputs=status
403
+ )
404
+
405
+ generate_btn.click(
406
+ fn=generate_smiles,
407
+ inputs=molecule_state,
408
+ outputs=molecule_smiles
409
+ ).then(
410
+ fn=visualize_molecule,
411
+ inputs=molecule_state,
412
+ outputs=molecule_img
413
+ ).then(
414
+ fn=lambda mol: f"分子生成完成! 原子数: {len(mol['atoms'])}, 键数: {len(mol['bonds'])}",
415
+ inputs=molecule_state,
416
+ outputs=status
417
+ )
418
 
419
+ return interface
 
 
 
420
 
421
+ # 查找最近原子
422
+ def find_nearest_atom(molecule, x, y, radius=20):
423
+ for atom in molecule["atoms"]:
424
+ dist = ((atom["x"] - x)**2 + (atom["y"] - y)**2)**0.5
425
+ if dist <= radius:
426
+ return atom["id"]
427
+ return -1
 
 
428
 
429
  # 使用TabbedInterface组织两种输入方式
430
  with gr.Blocks(title="CrystalGAT") as demo:
431
  gr.Markdown("# CrystalGAT")
432
+ gr.Markdown("输入SMILES字符串或构建分子结构,CrystalGAT将预测其弹性、塑性和脆性分类并可视化注意力权重")
433
 
434
  with gr.Tab("SMILES输入"):
435
  gr.Markdown("### 输入SMILES字符串")
 
449
  )
450
  submit_btn1 = gr.Button("预测", variant="primary")
451
 
452
+ with gr.Tab("构建分子"):
453
+ gr.Markdown("### 构建分子结构")
454
+ gr.Markdown("1. 选择原类型和键类型")
455
+ gr.Markdown("2. 点击'添加原子'按钮后在画布上点击放置原子")
456
+ gr.Markdown("3. 点击'添加键'按钮后依次点击两个原子连接它们")
457
+ gr.Markdown("4. 完成后点击'生成分子'按钮")
 
 
 
 
 
458
 
459
+ # 分子构建界面
460
+ build_interface = build_molecule_interface()
 
 
 
 
461
 
462
+ #测按钮
463
+ submit_btn2 = gr.Button("使用此分子进行预测", variant="primary")
464
+ build_smiles = gr.Textbox(visible=False) # 隐藏的SMILES存储
465
 
466
  # 输出区域 (两种输入方式共享)
467
  gr.Markdown("## 预测结果")
 
483
  outputs=[elastic_text, elastic_img, plastic_text, plastic_img, brittle_text, brittle_img]
484
  )
485
 
486
+ # 分子构建预测路径
 
 
 
 
 
 
487
  submit_btn2.click(
488
+ fn=lambda smiles: predict_all(smiles) if smiles else ("请输入有效的SMILES", None, "", None, "", None),
489
+ inputs=build_smiles,
490
  outputs=[elastic_text, elastic_img, plastic_text, plastic_img, brittle_text, brittle_img]
491
  )
492