Update app.py
Browse files
app.py
CHANGED
|
@@ -110,7 +110,7 @@ def load_models():
|
|
| 110 |
import torch.serialization
|
| 111 |
|
| 112 |
# 添加安全全局变量
|
| 113 |
-
torch.serialization.add_safe_globals([getattr(np.
|
| 114 |
state = torch.load(pth_path, map_location=device, weights_only=True)
|
| 115 |
logger.info("使用 weights_only=True 和安全全局变量成功加载模型")
|
| 116 |
except:
|
|
@@ -261,24 +261,6 @@ def predict_from_drawing(mol_data):
|
|
| 261 |
# 使用预测函数
|
| 262 |
return predict_all(smiles)
|
| 263 |
|
| 264 |
-
# 创建绘图输入界面
|
| 265 |
-
drawing_input = gr.Sketchpad(
|
| 266 |
-
label="绘制分子结构",
|
| 267 |
-
type="mol",
|
| 268 |
-
height=300,
|
| 269 |
-
brush_radius=10
|
| 270 |
-
)
|
| 271 |
-
|
| 272 |
-
# 创建输出组件
|
| 273 |
-
outputs = [
|
| 274 |
-
gr.Text(label="Elastic"),
|
| 275 |
-
gr.Image(type="pil", label="Elastic attention visualization"),
|
| 276 |
-
gr.Text(label="Plastic"),
|
| 277 |
-
gr.Image(type="pil", label="Plastic attention visualization"),
|
| 278 |
-
gr.Text(label="Brittle"),
|
| 279 |
-
gr.Image(type="pil", label="Brittle attention visualization")
|
| 280 |
-
]
|
| 281 |
-
|
| 282 |
# 使用TabbedInterface组织两种输入方式
|
| 283 |
with gr.Blocks(title="CrystalGAT") as demo:
|
| 284 |
gr.Markdown("# CrystalGAT")
|
|
@@ -305,6 +287,7 @@ with gr.Blocks(title="CrystalGAT") as demo:
|
|
| 305 |
with gr.Tab("绘制分子"):
|
| 306 |
gr.Markdown("### 绘制分子结构")
|
| 307 |
gr.Markdown("使用下方的绘图工具绘制分子结构,然后点击预测按钮")
|
|
|
|
| 308 |
drawing_display = gr.Image(label="分子结构预览", interactive=False)
|
| 309 |
drawing_output = gr.Text(label="生成的SMILES")
|
| 310 |
submit_btn2 = gr.Button("预测", variant="primary")
|
|
@@ -331,7 +314,7 @@ with gr.Blocks(title="CrystalGAT") as demo:
|
|
| 331 |
|
| 332 |
# 绘图输入路径
|
| 333 |
drawing_input.change(
|
| 334 |
-
fn=draw_to_smiles,
|
| 335 |
inputs=drawing_input,
|
| 336 |
outputs=[drawing_output, drawing_display]
|
| 337 |
)
|
|
|
|
| 110 |
import torch.serialization
|
| 111 |
|
| 112 |
# 添加安全全局变量
|
| 113 |
+
torch.serialization.add_safe_globals([getattr(np._core.multiarray, 'scalar')]) # 修复警告
|
| 114 |
state = torch.load(pth_path, map_location=device, weights_only=True)
|
| 115 |
logger.info("使用 weights_only=True 和安全全局变量成功加载模型")
|
| 116 |
except:
|
|
|
|
| 261 |
# 使用预测函数
|
| 262 |
return predict_all(smiles)
|
| 263 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
# 使用TabbedInterface组织两种输入方式
|
| 265 |
with gr.Blocks(title="CrystalGAT") as demo:
|
| 266 |
gr.Markdown("# CrystalGAT")
|
|
|
|
| 287 |
with gr.Tab("绘制分子"):
|
| 288 |
gr.Markdown("### 绘制分子结构")
|
| 289 |
gr.Markdown("使用下方的绘图工具绘制分子结构,然后点击预测按钮")
|
| 290 |
+
drawing_input = gr.MoleculeEditor(label="绘制分子结构")
|
| 291 |
drawing_display = gr.Image(label="分子结构预览", interactive=False)
|
| 292 |
drawing_output = gr.Text(label="生成的SMILES")
|
| 293 |
submit_btn2 = gr.Button("预测", variant="primary")
|
|
|
|
| 314 |
|
| 315 |
# 绘图输入路径
|
| 316 |
drawing_input.change(
|
| 317 |
+
fn=lambda mol: draw_to_smiles(mol) if mol else (None, None),
|
| 318 |
inputs=drawing_input,
|
| 319 |
outputs=[drawing_output, drawing_display]
|
| 320 |
)
|