QQ2S3R commited on
Commit
f900cab
·
verified ·
1 Parent(s): b5c7e53

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +149 -212
app.py CHANGED
@@ -17,18 +17,19 @@ matplotlib.use('Agg') # 修复后台线程问题
17
  from rdkit import Chem
18
  from rdkit.Chem import Draw, AllChem, MolFromSmiles
19
  from io import BytesIO
20
- import traceback
21
 
22
  # 配置日志
23
- logging.basicConfig(level=logging.INFO,
24
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 
 
25
  logger = logging.getLogger(__name__)
26
 
27
- # GPU内存优化
28
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
29
- logger.info("设置GPU内存优化参数: PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128")
30
 
31
- # 解压模型文件
32
  if not os.path.exists("best_model-B-6000-185.pth"):
33
  logger.info("开始解压模型文件...")
34
  try:
@@ -36,7 +37,7 @@ if not os.path.exists("best_model-B-6000-185.pth"):
36
  zip_ref.extractall(".")
37
  logger.info("模型文件解压完成!")
38
  except Exception as e:
39
- logger.error(f"解压模型文件失败: {str(e)}")
40
  raise
41
 
42
  # 导入模型工具
@@ -44,7 +45,7 @@ try:
44
  from model_utils import EnhancedGAT, smiles_to_graph, visualize_single_molecule
45
  logger.info("成功导入 model_utils 模块")
46
  except ImportError as e:
47
- logger.error(f"导入 model_utils 失败: {str(e)}")
48
  raise
49
 
50
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -56,263 +57,199 @@ def load_models():
56
  model_info = {
57
  "Elastic": ("models/best_model-E-500-68.pth", 2),
58
  "Plastic": ("models/best_model-P-5000-180.pth", 2),
59
- "Brittle": ("models/best_model-B-6000-185.pth", 2)
60
  }
61
  models = {}
62
  for name, (pth_path, output_dim) in model_info.items():
63
  logger.info(f"正在加载 {name} 模型: {pth_path}")
64
  if not os.path.exists(pth_path):
65
- logger.error(f"模型文件不存在: {pth_path}")
66
- possible_files = [
67
- pth_path,
68
- pth_path.lower(),
69
- pth_path.upper(),
70
- pth_path.replace("-", "_"),
71
- pth_path.replace("_", "-")
72
- ]
73
- found = False
74
- for file in possible_files:
75
- if os.path.exists(file):
76
- logger.warning(f"使用替代文件: {file}")
77
- pth_path = file
78
- found = True
79
  break
80
- if not found:
81
  raise FileNotFoundError(f"模型文件 {pth_path} 不存在")
 
82
  try:
83
- model = EnhancedGAT(input_dim=12, hidden_dim=512,
84
- output_dim=output_dim, num_heads=8)
85
- logger.info(f"加载模型权重: {pth_path}")
86
- try:
87
- state = torch.load(pth_path, map_location=device, weights_only=False)
88
- except Exception as e:
89
- logger.warning(f"weights_only=False 加载失败: {str(e)}")
90
- try:
91
- import numpy as np, torch.serialization
92
- torch.serialization.add_safe_globals(
93
- [getattr(np, '_core', np).multiarray.scalar])
94
- state = torch.load(pth_path, map_location=device, weights_only=True)
95
- except:
96
- state = torch.load(pth_path, map_location=device)
97
- if "model_state_dict" in state:
98
- state_dict = state["model_state_dict"]
99
- else:
100
- state_dict = state
101
- model.load_state_dict(state_dict)
102
- model.eval().to(device)
103
- models[name] = model
104
- logger.info(f"{name} 模型加载成功!")
105
- except Exception as e:
106
- logger.error(f"加载 {name} 模型失败: {str(e)}")
107
- raise
108
  return models
109
 
110
- logger.info("开始加载所有模型...")
111
  models = load_models()
112
- logger.info("所有模型加载完成!")
113
 
114
- def predict_all(smiles):
115
- logger.info(f"收到预测请求: SMILES = {smiles}")
116
- try:
117
- atom_features, (rows, cols, edge_attr), mol = smiles_to_graph(smiles)
118
- x = torch.tensor(atom_features, dtype=torch.float)
119
- edge_index = torch.tensor(np.column_stack((rows, cols)).T, dtype=torch.long)
120
- edge_attr = torch.tensor(edge_attr, dtype=torch.float).unsqueeze(1)
121
- data = PyGData(x=x,
122
- edge_index=edge_index,
123
- edge_attr=edge_attr,
124
- smiles=[smiles],
125
- batch=torch.zeros(x.size(0), dtype=torch.long))
126
- results = []
127
- for name in ["Elastic", "Plastic", "Brittle"]:
128
- logger.info(f"使用 {name} 模型进行预测...")
129
- try:
130
- buf, pred = visualize_single_molecule(models[name], data, device, name)
131
- if buf:
132
- buf.seek(0)
133
- img = Image.open(buf)
134
- pred_text = f"{name} Result: {'1' if pred == 1 else '0'}"
135
- results.append((pred_text, img))
136
- else:
137
- results.append((f"{name} 预测失败: 未生成图像", None))
138
- except Exception as e:
139
- results.append((f"{name} 预测过程中发生错误: {str(e)}", None))
140
- return (results[0][0], results[0][1],
141
- results[1][0], results[1][1],
142
- results[2][0], results[2][1])
143
- except Exception as e:
144
- logger.error(f"预测过程中发生严重错误: {str(e)}")
145
- return ("预测失败", None, "预测失败", None, "预测失败", None)
146
 
147
- # 原子和键类型选项
148
  ATOM_TYPES = ["C", "N", "O", "S", "P", "F", "Cl", "Br", "I", "H"]
149
  BOND_TYPES = ["单键", "双键", "三键"]
150
 
151
- # 初始化分子结构
152
  def init_molecule():
153
  return {"atoms": [], "bonds": []}
154
 
155
- # 添加原子
156
- def add_atom(molecule, atom_type):
157
- molecule["atoms"].append({"id": len(molecule["atoms"]), "type": atom_type})
158
- return molecule
159
-
160
- # 添加键
161
- def add_bond(molecule, atom1_id, atom2_id, bond_type):
162
- if atom1_id >= len(molecule["atoms"]) or atom2_id >= len(molecule["atoms"]):
163
- return molecule
164
- for bond in molecule["bonds"]:
165
- if {bond["atom1"], bond["atom2"]} == {atom1_id, atom2_id}:
166
- return molecule
167
- molecule["bonds"].append({"atom1": atom1_id, "atom2": atom2_id, "type": bond_type})
168
- return molecule
169
-
170
- # 从 JSON 构建 SMILES
171
- def generate_smiles(molecule_json):
172
  try:
173
  mol = Chem.RWMol()
174
- atom_map = {}
175
- for atom in molecule_json["atoms"]:
176
  idx = mol.AddAtom(Chem.Atom(atom["type"]))
177
- atom_map[atom["id"]] = idx
178
- for bond in molecule_json["bonds"]:
179
- t = {"单键": Chem.BondType.SINGLE,
180
- "双键": Chem.BondType.DOUBLE,
181
- "三键": Chem.BondType.TRIPLE}[bond["type"]]
182
- mol.AddBond(atom_map[bond["atom1"]], atom_map[bond["atom2"]], t)
183
  mol.UpdatePropertyCache()
184
  Chem.SanitizeMol(mol)
185
  return Chem.MolToSmiles(mol)
186
  except Exception as e:
187
- logger.error(f"生成SMILES失败: {e}")
188
- return None
189
 
190
- # 可视化分子
191
- def visualize_molecule(molecule_json):
192
- try:
193
- smiles = generate_smiles(molecule_json)
194
- if not smiles:
195
- return None
196
- mol = MolFromSmiles(smiles)
197
- if mol is None:
198
- return None
199
- AllChem.Compute2DCoords(mol)
200
- img = Draw.MolToImage(mol, size=(300, 300))
201
- buf = BytesIO()
202
- img.save(buf, format="PNG")
203
- buf.seek(0)
204
- return buf
205
- except Exception as e:
206
- logger.error(f"可视化分子失败: {str(e)}")
207
  return None
 
 
 
 
 
 
208
 
209
- # 更新下拉菜单
210
- def update_atom_dropdowns(molecule):
211
- choices = [f"{a['id']}: {a['type']}" for a in molecule["atoms"]]
212
  return gr.update(choices=choices, value=None), gr.update(choices=choices, value=None)
213
 
214
- # 更新列表
215
- def update_atoms_list(molecule):
216
- return [[a["id"], a["type"]] for a in molecule["atoms"]]
217
 
218
- def update_bonds_list(molecule):
219
  out = []
220
- for b in molecule["bonds"]:
221
- t1 = next(a["type"] for a in molecule["atoms"] if a["id"]==b["atom1"])
222
- t2 = next(a["type"] for a in molecule["atoms"] if a["id"]==b["atom2"])
223
  out.append([f"{b['atom1']}: {t1}", f"{b['atom2']}: {t2}", b["type"]])
224
  return out
225
 
226
  with gr.Blocks(title="CrystalGAT") as demo:
227
  gr.Markdown("# CrystalGAT")
228
- gr.Markdown("输入SMILES或构建分子,预测弹性/塑性/脆性并可视化")
229
 
230
- # SMILES 输入
231
- with gr.Tab("SMILES输入"):
232
  smiles_input = gr.Textbox(label="SMILES", placeholder="例如: CCO")
233
- submit_btn1 = gr.Button("预测", variant="primary")
234
 
235
- # 构建分子
236
  with gr.Tab("构建分子"):
237
- molecule_state = gr.State(init_molecule())
238
- status_msg = gr.Textbox(label="状态", interactive=False, value="请添加原子开始")
239
 
240
  with gr.Row():
241
- atom_select = gr.Dropdown(label="选择原子类型", choices=ATOM_TYPES, value="C")
242
  add_atom_btn = gr.Button("添加原子")
243
- atoms_list = gr.Dataframe(headers=["ID","原子类型"], datatype=["number","str"], interactive=False)
244
 
245
  with gr.Row():
246
- atom1_select = gr.Dropdown(label="第一个原子", choices=[], value=None)
247
- atom2_select = gr.Dropdown(label="第二个原子", choices=[], value=None)
248
- bond_select = gr.Dropdown(label="键类型", choices=BOND_TYPES, value="单键")
249
  add_bond_btn = gr.Button("添加键")
250
- bonds_list = gr.Dataframe(headers=["原子1","原子2","键类型"], datatype=["str","str","str"], interactive=False)
251
 
252
  clear_btn = gr.Button("清除所有")
253
- generate_btn = gr.Button("生成分子")
254
- molecule_smiles = gr.Textbox(label="SMILES结果", interactive=False)
255
- # 关键点:将 type="pil",才能显示 PIL 图像
256
- molecule_img = gr.Image(type="pil", label="分子预", interactive=False)
257
-
258
- submit_btn2 = gr.Button("使用此分子预测", variant="primary")
259
-
260
- # 事件绑定
261
- add_atom_btn.click(fn=lambda at, mol: add_atom(mol, at),
262
- inputs=[atom_select, molecule_state],
263
- outputs=molecule_state) \
264
- .then(fn=update_atoms_list, inputs=molecule_state, outputs=atoms_list) \
265
- .then(fn=update_atom_dropdowns, inputs=molecule_state,
266
- outputs=[atom1_select, atom2_select]) \
267
- .then(fn=lambda: "原子添加成功", outputs=status_msg)
268
-
269
- add_bond_btn.click(fn=lambda a1, a2, b, mol: add_bond(
270
- mol,
271
- int(a1.split(":")[0]) if a1 else -1,
272
- int(a2.split(":")[0]) if a2 else -1,
273
- b),
274
- inputs=[atom1_select, atom2_select, bond_select, molecule_state],
275
- outputs=molecule_state) \
276
- .then(fn=update_bonds_list, inputs=molecule_state, outputs=bonds_list) \
277
- .then(fn=lambda: "键添加/更新成功", outputs=status_msg)
278
-
279
- clear_btn.click(fn=init_molecule, outputs=molecule_state) \
280
- .then(fn=lambda: ([], []), outputs=[atoms_list, bonds_list]) \
281
- .then(fn=lambda: (gr.update(choices=[], value=None),
282
- gr.update(choices=[], value=None)),
283
- outputs=[atom1_select, atom2_select]) \
284
- .then(fn=lambda: "已清除所有", outputs=status_msg)
285
-
286
- generate_btn.click(fn=generate_smiles,
287
- inputs=molecule_state,
288
- outputs=molecule_smiles) \
289
- .then(fn=visualize_molecule,
290
- inputs=molecule_state,
291
- outputs=molecule_img) \
292
- .then(fn=lambda: "分子生成完成", outputs=status_msg)
293
-
294
- # 预测结果展示
295
  with gr.Row():
296
- elastic_text = gr.Text(label="Elastic")
297
- elastic_img = gr.Image(type="pil", label="Elastic 可视化")
298
  with gr.Row():
299
- plastic_text = gr.Text(label="Plastic")
300
- plastic_img = gr.Image(type="pil", label="Plastic 可视化")
301
  with gr.Row():
302
- brittle_text = gr.Text(label="Brittle")
303
- brittle_img = gr.Image(type="pil", label="Brittle 可视化")
304
-
305
- submit_btn1.click(fn=predict_all,
306
- inputs=smiles_input,
307
- outputs=[elastic_text, elastic_img,
308
- plastic_text, plastic_img,
309
- brittle_text, brittle_img])
310
- submit_btn2.click(fn=lambda s: predict_all(s) if s else (
311
- "请输入SMILES", None, "", None, "", None),
312
- inputs=molecule_smiles,
313
- outputs=[elastic_text, elastic_img,
314
- plastic_text, plastic_img,
315
- brittle_text, brittle_img])
316
 
317
  if __name__ == "__main__":
318
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
17
  from rdkit import Chem
18
  from rdkit.Chem import Draw, AllChem, MolFromSmiles
19
  from io import BytesIO
 
20
 
21
  # 配置日志
22
+ logging.basicConfig(
23
+ level=logging.INFO,
24
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
25
+ )
26
  logger = logging.getLogger(__name__)
27
 
28
+ # GPU 内存优化
29
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
30
+ logger.info("设置 GPU 内存优化参数: PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128")
31
 
32
+ # 解压模型文件(如果尚未解压)
33
  if not os.path.exists("best_model-B-6000-185.pth"):
34
  logger.info("开始解压模型文件...")
35
  try:
 
37
  zip_ref.extractall(".")
38
  logger.info("模型文件解压完成!")
39
  except Exception as e:
40
+ logger.error(f"解压模型文件失败: {e}")
41
  raise
42
 
43
  # 导入模型工具
 
45
  from model_utils import EnhancedGAT, smiles_to_graph, visualize_single_molecule
46
  logger.info("成功导入 model_utils 模块")
47
  except ImportError as e:
48
+ logger.error(f"导入 model_utils 失败: {e}")
49
  raise
50
 
51
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
57
  model_info = {
58
  "Elastic": ("models/best_model-E-500-68.pth", 2),
59
  "Plastic": ("models/best_model-P-5000-180.pth", 2),
60
+ "Brittle": ("models/best_model-B-6000-185.pth", 2),
61
  }
62
  models = {}
63
  for name, (pth_path, output_dim) in model_info.items():
64
  logger.info(f"正在加载 {name} 模型: {pth_path}")
65
  if not os.path.exists(pth_path):
66
+ # 尝试其他变体
67
+ for alt in [pth_path.lower(), pth_path.upper(), pth_path.replace("-", "_"), pth_path.replace("_", "-")]:
68
+ if os.path.exists(alt):
69
+ logger.warning(f"使用替代模型文件: {alt}")
70
+ pth_path = alt
 
 
 
 
 
 
 
 
 
71
  break
72
+ else:
73
  raise FileNotFoundError(f"模型文件 {pth_path} 不存在")
74
+ model = EnhancedGAT(input_dim=12, hidden_dim=512, output_dim=output_dim, num_heads=8)
75
  try:
76
+ state = torch.load(pth_path, map_location=device, weights_only=False)
77
+ except Exception:
78
+ # 备用加载方式
79
+ state = torch.load(pth_path, map_location=device)
80
+ # 支持两种格式
81
+ state_dict = state.get("model_state_dict", state)
82
+ model.load_state_dict(state_dict)
83
+ model.eval().to(device)
84
+ models[name] = model
85
+ logger.info(f"{name} 模型加载成功")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  return models
87
 
 
88
  models = load_models()
 
89
 
90
+ def predict_all(smiles: str):
91
+ """对 Elastic, Plastic, Brittle 三个模型做预测,返回文本与 PIL 图像对象。"""
92
+ atom_features, (rows, cols, edge_attr), mol = smiles_to_graph(smiles)
93
+ x = torch.tensor(atom_features, dtype=torch.float)
94
+ edge_index = torch.tensor(np.vstack((rows, cols)), dtype=torch.long)
95
+ edge_attr = torch.tensor(edge_attr, dtype=torch.float).unsqueeze(1)
96
+ data = PyGData(
97
+ x=x,
98
+ edge_index=edge_index,
99
+ edge_attr=edge_attr,
100
+ smiles=[smiles],
101
+ batch=torch.zeros(x.size(0), dtype=torch.long),
102
+ )
103
+
104
+ outputs = []
105
+ for name in ["Elastic", "Plastic", "Brittle"]:
106
+ buf, pred = visualize_single_molecule(models[name], data, device, name)
107
+ if buf is not None:
108
+ buf.seek(0)
109
+ img = Image.open(buf)
110
+ text = f"{name} Result: {int(pred)}"
111
+ outputs.append((text, img))
112
+ else:
113
+ outputs.append((f"{name} 预测失败", None))
114
+ # 拆包为 6 个输出
115
+ return (*outputs[0], *outputs[1], *outputs[2])
116
+
117
+ # 分子构建相关函数
 
 
 
 
118
 
 
119
  ATOM_TYPES = ["C", "N", "O", "S", "P", "F", "Cl", "Br", "I", "H"]
120
  BOND_TYPES = ["单键", "双键", "三键"]
121
 
 
122
  def init_molecule():
123
  return {"atoms": [], "bonds": []}
124
 
125
+ def add_atom(mol_json, atom_type):
126
+ mol_json["atoms"].append({"id": len(mol_json["atoms"]), "type": atom_type})
127
+ return mol_json
128
+
129
+ def add_bond(mol_json, atom1_sel, atom2_sel, bond_type):
130
+ if not atom1_sel or not atom2_sel:
131
+ return mol_json
132
+ a1 = int(atom1_sel.split(":")[0])
133
+ a2 = int(atom2_sel.split(":")[0])
134
+ # 避免重复
135
+ for b in mol_json["bonds"]:
136
+ if set([b["atom1"], b["atom2"]]) == set([a1, a2]):
137
+ return mol_json
138
+ mol_json["bonds"].append({"atom1": a1, "atom2": a2, "type": bond_type})
139
+ return mol_json
140
+
141
+ def generate_smiles(mol_json):
142
  try:
143
  mol = Chem.RWMol()
144
+ id_map = {}
145
+ for atom in mol_json["atoms"]:
146
  idx = mol.AddAtom(Chem.Atom(atom["type"]))
147
+ id_map[atom["id"]] = idx
148
+ for bond in mol_json["bonds"]:
149
+ bt = {"单键": Chem.BondType.SINGLE,
150
+ "双键": Chem.BondType.DOUBLE,
151
+ "三键": Chem.BondType.TRIPLE}[bond["type"]]
152
+ mol.AddBond(id_map[bond["atom1"]], id_map[bond["atom2"]], bt)
153
  mol.UpdatePropertyCache()
154
  Chem.SanitizeMol(mol)
155
  return Chem.MolToSmiles(mol)
156
  except Exception as e:
157
+ logger.error(f"生成 SMILES 失败: {e}")
158
+ return ""
159
 
160
+ def visualize_molecule(mol_json):
161
+ """直接返回 PIL.Image.Image 对象,或 None"""
162
+ smiles = generate_smiles(mol_json)
163
+ if not smiles:
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  return None
165
+ mol = MolFromSmiles(smiles)
166
+ if mol is None:
167
+ return None
168
+ AllChem.Compute2DCoords(mol)
169
+ img = Draw.MolToImage(mol, size=(300, 300))
170
+ return img
171
 
172
+ def update_atom_dropdowns(mol_json):
173
+ choices = [f"{a['id']}: {a['type']}" for a in mol_json["atoms"]]
 
174
  return gr.update(choices=choices, value=None), gr.update(choices=choices, value=None)
175
 
176
+ def update_atoms_list(mol_json):
177
+ return [[a["id"], a["type"]] for a in mol_json["atoms"]]
 
178
 
179
+ def update_bonds_list(mol_json):
180
  out = []
181
+ for b in mol_json["bonds"]:
182
+ t1 = next(a["type"] for a in mol_json["atoms"] if a["id"] == b["atom1"])
183
+ t2 = next(a["type"] for a in mol_json["atoms"] if a["id"] == b["atom2"])
184
  out.append([f"{b['atom1']}: {t1}", f"{b['atom2']}: {t2}", b["type"]])
185
  return out
186
 
187
  with gr.Blocks(title="CrystalGAT") as demo:
188
  gr.Markdown("# CrystalGAT")
189
+ gr.Markdown("输入 SMILES 构建分子,预测弹性/塑性/脆性 并可视化注意力权重")
190
 
191
+ with gr.Tab("SMILES 输入"):
 
192
  smiles_input = gr.Textbox(label="SMILES", placeholder="例如: CCO")
193
+ predict_btn1 = gr.Button("预测", variant="primary")
194
 
 
195
  with gr.Tab("构建分子"):
196
+ state = gr.State(init_molecule())
197
+ status = gr.Textbox(label="状态", interactive=False, value="请添加原子开始")
198
 
199
  with gr.Row():
200
+ atom_type = gr.Dropdown(label="选择原子类型", choices=ATOM_TYPES, value="C")
201
  add_atom_btn = gr.Button("添加原子")
202
+ atom_table = gr.Dataframe(headers=["ID", "原子类型"], datatype=["number","str"], interactive=False)
203
 
204
  with gr.Row():
205
+ atom1 = gr.Dropdown(label="第一个原子", choices=[], value=None)
206
+ atom2 = gr.Dropdown(label="第二个原子", choices=[], value=None)
207
+ bond_type = gr.Dropdown(label="键类型", choices=BOND_TYPES, value="单键")
208
  add_bond_btn = gr.Button("添加键")
209
+ bond_table = gr.Dataframe(headers=["原子1", "原子2", "键类型"], datatype=["str","str","str"], interactive=False)
210
 
211
  clear_btn = gr.Button("清除所有")
212
+ gen_btn = gr.Button("生成分子")
213
+ smiles_out = gr.Textbox(label="SMILES 结果", interactive=False)
214
+ mol_img = gr.Image(type="pil", label="分子预览")
215
+ predict_btn2 = gr.Button("使用此分子预", variant="primary")
216
+
217
+ add_atom_btn.click(fn=add_atom, inputs=[state, atom_type], outputs=state) \
218
+ .then(fn=update_atoms_list, inputs=state, outputs=atom_table) \
219
+ .then(fn=update_atom_dropdowns, inputs=state, outputs=[atom1, atom2]) \
220
+ .then(fn=lambda: "原子添加成功", outputs=status)
221
+
222
+ add_bond_btn.click(fn=add_bond, inputs=[state, atom1, atom2, bond_type], outputs=state) \
223
+ .then(fn=update_bonds_list, inputs=state, outputs=bond_table) \
224
+ .then(fn=lambda: "键添加/更新成功", outputs=status)
225
+
226
+ clear_btn.click(fn=lambda: init_molecule(), outputs=state) \
227
+ .then(fn=lambda: ([], []), outputs=[atom_table, bond_table]) \
228
+ .then(fn=lambda: (gr.update(choices=[], value=None),
229
+ gr.update(choices=[], value=None)),
230
+ outputs=[atom1, atom2]) \
231
+ .then(fn=lambda: "已清除所有", outputs=status)
232
+
233
+ gen_btn.click(fn=generate_smiles, inputs=state, outputs=smiles_out) \
234
+ .then(fn=visualize_molecule, inputs=state, outputs=mol_img) \
235
+ .then(fn=lambda: "分子生成完成", outputs=status)
236
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
  with gr.Row():
238
+ e_txt = gr.Text(label="Elastic")
239
+ e_img = gr.Image(type="pil", label="Elastic 可视化")
240
  with gr.Row():
241
+ p_txt = gr.Text(label="Plastic")
242
+ p_img = gr.Image(type="pil", label="Plastic 可视化")
243
  with gr.Row():
244
+ b_txt = gr.Text(label="Brittle")
245
+ b_img = gr.Image(type="pil", label="Brittle 可视化")
246
+
247
+ predict_btn1.click(fn=predict_all,
248
+ inputs=smiles_input,
249
+ outputs=[e_txt, e_img, p_txt, p_img, b_txt, b_img])
250
+ predict_btn2.click(fn=lambda s: predict_all(s) if s else ("请输入SMILES", None, "", None, "", None),
251
+ inputs=smiles_out,
252
+ outputs=[e_txt, e_img, p_txt, p_img, b_txt, b_img])
 
 
 
 
 
253
 
254
  if __name__ == "__main__":
255
  demo.launch(server_name="0.0.0.0", server_port=7860)