ToxiPredict: Multi-Task GNN for Toxicophore Prediction
Uncertainty-aware multi-task graph neural network trained on the Tox21 dataset for predicting toxicity across 10 biological endpoints, with 5-fold cross-validated performance.
Model Description
| Property | Value |
|---|---|
| Architecture | MultiTaskGNN_ResGATv2_JK_VN |
| Input | Molecular graphs (SMILES → 45-dim node, 11-dim edge features) |
| Output | 10 binary toxicity predictions + uncertainty weights |
| Parameters | 12 learnable homoscedastic uncertainty log-variance parameters |
| Training Data | Tox21 (6264 training compounds after scaffold split) |
| Validation | 5-fold Bemis-Murcko scaffold cross-validation |
| Framework | PyTorch 2.10 + PyTorch Geometric 2.6 |
Performance
5-Fold Cross-Validation: 0.7856 ± 0.0394 Mean AUC
The model was evaluated with a Bemis-Murcko scaffold split, which places all molecules sharing a scaffold in the same fold so that validation scaffolds are never seen during training. This gives a more realistic estimate of generalization to novel chemical scaffolds than a random split.
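The fold assignment behind a scaffold split can be sketched as follows. This is a minimal illustration, not the card's actual splitting code: the helper names `murcko_scaffold` and `scaffold_kfold` and the greedy "largest group into smallest fold" strategy are assumptions. Only `murcko_scaffold` needs RDKit; the grouping logic is plain Python.

```python
from collections import defaultdict
from typing import Dict, List, Sequence


def murcko_scaffold(smiles: str) -> str:
    """Return the Bemis-Murcko scaffold SMILES for one molecule (needs RDKit)."""
    # Imported lazily so the fold-assignment logic below runs without RDKit.
    from rdkit.Chem.Scaffolds import MurckoScaffold

    return MurckoScaffold.MurckoScaffoldSmiles(smiles=smiles, includeChirality=False)


def scaffold_kfold(scaffolds: Sequence[str], k: int = 5) -> List[List[int]]:
    """Assign molecule indices to k folds so that no scaffold spans two folds.

    Hypothetical helper: the greedy balancing used here is an assumption,
    not necessarily how the reported folds were built.
    """
    # Group molecule indices by their scaffold string.
    groups: Dict[str, List[int]] = defaultdict(list)
    for idx, scaffold in enumerate(scaffolds):
        groups[scaffold].append(idx)

    # Place the largest scaffold groups first, each into the currently smallest fold.
    ordered = sorted(groups.values(), key=lambda g: (-len(g), g[0]))
    folds: List[List[int]] = [[] for _ in range(k)]
    for group in ordered:
        smallest = min(range(k), key=lambda i: len(folds[i]))
        folds[smallest].extend(group)
    return folds
```

With RDKit installed, the scaffold keys could be computed as `[murcko_scaffold(s) for s in smiles_list]` and passed to `scaffold_kfold`; each fold then serves once as the validation set.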
| Task | Type |
|---|---|
| NR-AR | Nuclear Receptor |
| NR-AhR | Nuclear Receptor |
| NR-Aromatase | Nuclear Receptor |
| NR-ER | Nuclear Receptor |
| NR-PPAR-gamma | Nuclear Receptor |
| SR-ARE | Stress Response |
| SR-ATAD5 | Stress Response |
| SR-HSE | Stress Response |
| SR-MMP | Stress Response |
| SR-p53 | Stress Response |
Architecture Details
The model extends standard GAT with four components:
- Residual GATv2 Convolutions: Two-layer GATv2 with residual connections and 4 attention heads per layer, providing dynamic attention mechanisms that adapt to molecular substructures.
- JumpingKnowledge (JK) Aggregation: Concatenates intermediate layer representations before prediction, preserving both local and global structural information.
- Virtual Node: A learned virtual node connected to all atoms enables global molecular context propagation across the graph.
- Homoscedastic Uncertainty Weighting: Learnable per-task log-variance parameters dynamically balance gradient contributions during multi-task training.
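The uncertainty weighting in the last bullet can be sketched in PyTorch as below. This is one common parameterisation in the spirit of Kendall et al. (each task loss scaled by `exp(-s_t)` plus a regularising `s_t` term), and is an assumption about the form rather than the model's confirmed loss. The class name `UncertaintyWeightedBCE` and the label mask (for compounds with no measurement on an endpoint) are also illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class UncertaintyWeightedBCE(nn.Module):
    """Sketch of homoscedastic uncertainty weighting over binary tasks."""

    def __init__(self, num_tasks: int):
        super().__init__()
        # One learnable log-variance s_t per task; zero means an initial weight of 1.
        self.log_vars = nn.Parameter(torch.zeros(num_tasks))

    def forward(self, logits: torch.Tensor, targets: torch.Tensor,
                mask: torch.Tensor) -> torch.Tensor:
        # logits, targets, mask: float tensors of shape (batch, num_tasks).
        # mask is 1.0 where a label exists and 0.0 where it is missing.
        per_elem = F.binary_cross_entropy_with_logits(
            logits, targets, reduction="none"
        )
        # Mean BCE per task over labelled entries only (avoid division by zero).
        counts = mask.sum(dim=0).clamp(min=1.0)
        per_task = (per_elem * mask).sum(dim=0) / counts
        # exp(-s_t) down-weights noisy tasks; + s_t keeps s_t from growing unbounded.
        weighted = torch.exp(-self.log_vars) * per_task + self.log_vars
        return weighted.sum()
```

Because `log_vars` is an `nn.Parameter`, passing the criterion's parameters to the optimizer alongside the model's lets the task weights be learned jointly with the network.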
Usage
```python
import json

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# MultiTaskGNN_ResGATv2_JK_VN is not provided by any of the packages above;
# its class definition must be imported or defined before this code runs.

# Download the model weights
model_path = hf_hub_download(
    repo_id="Arko007/toxipredict-gnn-models",
    filename="model.safetensors",
)
state_dict = load_file(model_path)

# Download and load the config
config_path = hf_hub_download(
    repo_id="Arko007/toxipredict-gnn-models",
    filename="model_config.json",
)
with open(config_path) as f:
    config = json.load(f)

# Initialize the model with the same architecture, then load the weights
model = MultiTaskGNN_ResGATv2_JK_VN(
    node_dim=config["node_dim"],
    edge_dim=config["edge_dim"],
    hidden_dim=config["hidden_dim"],
    num_tasks=config["num_tasks"],
    dropout=config["dropout"],
)
model.load_state_dict(state_dict)
model.eval()  # switch to inference mode (disables dropout)
```
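Once the model is loaded, its outputs still need to be mapped to endpoint names. The sketch below assumes the model returns one raw logit per endpoint, in the order of the task table above; neither the output format nor the ordering is confirmed by this card, so check both against the training code. `TASK_NAMES` and `logits_to_report` are hypothetical names.

```python
import math
from typing import Dict, Sequence

# Endpoint names from the task table; the ordering is an assumption.
TASK_NAMES = [
    "NR-AR", "NR-AhR", "NR-Aromatase", "NR-ER", "NR-PPAR-gamma",
    "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53",
]


def _sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def logits_to_report(logits: Sequence[float],
                     threshold: float = 0.5) -> Dict[str, dict]:
    """Turn one molecule's per-task logits into named probabilities and flags."""
    if len(logits) != len(TASK_NAMES):
        raise ValueError(
            f"expected {len(TASK_NAMES)} logits, got {len(logits)}"
        )
    report = {}
    for name, z in zip(TASK_NAMES, logits):
        p = _sigmoid(float(z))
        report[name] = {"probability": p, "toxic": p >= threshold}
    return report
```

For a PyTorch output row, something like `logits_to_report(row.tolist())` would apply. The 0.5 threshold is only a default; on an imbalanced dataset such as Tox21 a per-task threshold may be more appropriate.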
Training Details
- Optimizer: Adam (lr=1e-3, weight_decay=1e-5)
- Batch Size: 64
- Max Epochs: 200 (early stopping patience=20)
- Loss: Homoscedastic uncertainty-weighted binary cross-entropy
- Hardware: NVIDIA T4 GPU (Kaggle)
- Training Time: ~12 minutes per run
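The early-stopping rule listed above (up to 200 epochs, patience of 20) can be sketched as a small helper. The card does not say which quantity is monitored; this sketch assumes a validation metric where higher is better, such as mean AUC. The `EarlyStopping` class and the `min_delta` parameter are illustrative.

```python
class EarlyStopping:
    """Stop training once the monitored metric has not improved for `patience` epochs."""

    def __init__(self, patience: int = 20, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta  # smallest change that counts as an improvement
        self.best = float("-inf")
        self.best_epoch = -1
        self.bad_epochs = 0

    def step(self, metric: float, epoch: int) -> bool:
        """Record one validation score; return True when training should stop."""
        if metric > self.best + self.min_delta:
            # New best: remember it and reset the patience counter.
            self.best = metric
            self.best_epoch = epoch
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


def run_with_early_stopping(val_scores, max_epochs: int = 200, patience: int = 20):
    """Replay a sequence of validation scores; return (last_epoch, best_epoch)."""
    stopper = EarlyStopping(patience=patience)
    last_epoch = -1
    for epoch, score in enumerate(val_scores):
        if epoch >= max_epochs:
            break
        last_epoch = epoch
        if stopper.step(score, epoch):
            break
    return last_epoch, stopper.best_epoch
```

In a real training loop, `step` would be called once per epoch with the validation score, and the weights from `best_epoch` would typically be the ones kept.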
References
- Tox21 Challenge: https://tripod.nih.gov/tox21/challenge/
- GATv2: Brody et al., ICLR 2022
- JumpingKnowledge: Xu et al., ICML 2018
- Homoscedastic Uncertainty: Kendall et al., NeurIPS 2017
License
MIT