ToxiPredict: Multi-Task GNN for Toxicophore Prediction

Uncertainty-aware multi-task graph neural network trained on the Tox21 dataset for predicting toxicity across 10 biological endpoints, with 5-fold cross-validated performance.

Model Description

Property       Value
Architecture   MultiTaskGNN_ResGATv2_JK_VN
Input          Molecular graphs (SMILES → 45-dim node features, 11-dim edge features)
Output         10 binary toxicity predictions + uncertainty weights
Parameters     12 learnable homoscedastic uncertainty log-variance parameters
Training Data  Tox21 (6264 training compounds after scaffold split)
Validation     5-fold Bemis-Murcko scaffold cross-validation
Framework      PyTorch 2.10 + PyTorch Geometric 2.6

Performance

5-Fold Cross-Validation: 0.7856 ± 0.0394 Mean AUC

The model was evaluated with a Bemis-Murcko scaffold split, which assigns all molecules sharing a scaffold to the same fold, so no scaffold in a held-out fold is seen during training. This gives a more realistic estimate of generalization to novel chemical scaffolds than a random split would.
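The grouping idea behind a scaffold split can be sketched in plain Python. This is a minimal illustration, not the card's actual splitting code: the function name `scaffold_kfold`, the greedy balancing strategy, and the example scaffold strings are all assumptions. In practice the scaffold keys would be computed from SMILES with a cheminformatics toolkit; here they are taken as given.

```python
from collections import defaultdict


def scaffold_kfold(scaffolds, n_folds=5):
    """Assign molecule indices to folds so that each scaffold stays in one fold.

    `scaffolds` holds one scaffold key per molecule (for example a
    Bemis-Murcko scaffold SMILES). Computing the keys is out of scope here.
    """
    # Group molecule indices by scaffold key.
    groups = defaultdict(list)
    for idx, key in enumerate(scaffolds):
        groups[key].append(idx)

    # Place the largest groups first, always into the currently smallest
    # fold, which keeps fold sizes roughly balanced.
    folds = [[] for _ in range(n_folds)]
    for key in sorted(groups, key=lambda k: (-len(groups[k]), k)):
        smallest = min(range(n_folds), key=lambda f: len(folds[f]))
        folds[smallest].extend(groups[key])
    return folds


# Hypothetical scaffold keys for eight molecules (not taken from Tox21).
example = ["c1ccccc1", "c1ccccc1", "c1ccncc1", "C1CCCCC1",
           "c1ccccc1", "C1CCCCC1", "c1ccncc1", "c1ccoc1"]
folds = scaffold_kfold(example, n_folds=3)
```

Each fold then serves once as the held-out set while the remaining folds are used for training.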

Task            Type
NR-AR           Nuclear Receptor
NR-AhR          Nuclear Receptor
NR-Aromatase    Nuclear Receptor
NR-ER           Nuclear Receptor
NR-PPAR-gamma   Nuclear Receptor
SR-ARE          Stress Response
SR-ATAD5        Stress Response
SR-HSE          Stress Response
SR-MMP          Stress Response
SR-p53          Stress Response

Architecture Details

The model extends a standard GAT with four key innovations:

  1. Residual GATv2 Convolutions: Two-layer GATv2 with residual connections and 4 attention heads per layer, providing dynamic attention mechanisms that adapt to molecular substructures.
  2. JumpingKnowledge (JK) Aggregation: Concatenates intermediate layer representations before prediction, preserving both local and global structural information.
  3. Virtual Node: A learned virtual node connected to all atoms enables global molecular context propagation across the graph.
  4. Homoscedastic Uncertainty Weighting: Learnable per-task log-variance parameters dynamically balance gradient contributions during multi-task training.
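To make item 4 concrete, here is a framework-free sketch of one commonly used form of homoscedastic uncertainty weighting, where each task i has a log-variance s_i and the combined loss is the sum of exp(-s_i) * L_i + s_i. Whether this model uses exactly this form is an assumption; the function name and the example numbers are hypothetical.

```python
import math


def uncertainty_weighted_loss(task_losses, log_vars):
    """Combine per-task losses using log-variances s_i = log(sigma_i^2).

    Implements  sum_i exp(-s_i) * L_i + s_i  (one common form; assumed here).
    """
    if len(task_losses) != len(log_vars):
        raise ValueError("need one log-variance per task")
    total = 0.0
    for loss, s in zip(task_losses, log_vars):
        precision = math.exp(-s)  # 1 / sigma_i^2, acts as the task weight
        # The additive s term penalizes large variances, so the optimizer
        # cannot simply down-weight every task to zero.
        total += precision * loss + s
    return total


# Hypothetical per-task BCE values and log-variances, for illustration only.
losses = [0.70, 0.20, 0.45]
equal = uncertainty_weighted_loss(losses, [0.0, 0.0, 0.0])         # reduces to a plain sum
downweighted = uncertainty_weighted_loss(losses, [1.0, 0.0, 0.0])  # first task scaled by exp(-1)
```

In an actual training loop the log-variances would be learnable tensors optimized jointly with the network weights, so the task weights adapt as training progresses.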

Usage

import json

import torch
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download

# Download the model weights
model_path = hf_hub_download(
    repo_id="Arko007/toxipredict-gnn-models",
    filename="model.safetensors"
)
state_dict = load_file(model_path)

# Download and load the model config
config_path = hf_hub_download(
    repo_id="Arko007/toxipredict-gnn-models",
    filename="model_config.json"
)
with open(config_path) as f:
    config = json.load(f)

# Initialize the model with the same architecture.
# NOTE: MultiTaskGNN_ResGATv2_JK_VN is not provided by any of the imports
# above; define it or import it from the project's source code first.
model = MultiTaskGNN_ResGATv2_JK_VN(
    node_dim=config["node_dim"],
    edge_dim=config["edge_dim"],
    hidden_dim=config["hidden_dim"],
    num_tasks=config["num_tasks"],
    dropout=config["dropout"]
)
model.load_state_dict(state_dict)
model.eval()  # disable dropout for inference
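Once the model produces outputs, they still need to be mapped to named endpoints. The sketch below assumes, without confirmation from this card, that the forward pass returns one raw logit per task in the order of the task table above, and that a 0.5 probability threshold is appropriate. `TASKS`, `logits_to_predictions`, and the example logits are hypothetical.

```python
import math

# Task order is an assumption: it follows the task table in this card.
TASKS = ["NR-AR", "NR-AhR", "NR-Aromatase", "NR-ER", "NR-PPAR-gamma",
         "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53"]


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def logits_to_predictions(logits, threshold=0.5):
    """Map one molecule's per-task logits to {task: (probability, is_toxic)}."""
    if len(logits) != len(TASKS):
        raise ValueError(f"expected {len(TASKS)} logits, got {len(logits)}")
    result = {}
    for task, z in zip(TASKS, logits):
        prob = sigmoid(z)
        result[task] = (prob, prob >= threshold)
    return result


# Hypothetical logits for a single molecule (one output row as a Python list).
preds = logits_to_predictions([-2.0, 0.3, -1.1, 0.0, -3.2, 1.4, -0.7, -2.5, 2.1, -0.2])
```

Because Tox21 labels are imbalanced, a per-task threshold tuned on validation data may be a better choice than a fixed 0.5.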

Training Details

  • Optimizer: Adam (lr=1e-3, weight_decay=1e-5)
  • Batch Size: 64
  • Max Epochs: 200 (early stopping patience=20)
  • Loss: Homoscedastic uncertainty-weighted binary cross-entropy
  • Hardware: NVIDIA T4 GPU (Kaggle)
  • Training Time: ~12 minutes per run
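The interaction between the epoch cap and the early-stopping patience listed above can be sketched as follows. This is an illustration only: the monitored quantity (a validation score where higher is better), the function name, and the example curve are assumptions rather than details taken from the training code.

```python
def best_epoch_with_early_stopping(val_scores, patience=20, max_epochs=200):
    """Return (best_epoch, stop_epoch) for per-epoch validation scores.

    Higher scores are better (e.g. mean validation AUC). Epochs are 1-indexed.
    """
    best_score = float("-inf")
    best_epoch = 0
    stop_epoch = 0
    for epoch, score in enumerate(val_scores[:max_epochs], start=1):
        stop_epoch = epoch
        if score > best_score:
            # In a real loop, the model checkpoint would be saved here.
            best_score, best_epoch = score, epoch
        elif epoch - best_epoch >= patience:
            break  # no improvement for `patience` consecutive epochs
    return best_epoch, stop_epoch


# Hypothetical validation curve: improves for 3 epochs, then plateaus.
scores = [0.70, 0.74, 0.78] + [0.77] * 50
best, stopped = best_epoch_with_early_stopping(scores, patience=20)
```

With these settings, training halts 20 epochs after the last improvement, and the checkpoint from the best epoch is the one to keep.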

License

MIT

Model size: 522k params (Safetensors, tensor type F32)