---
language: en
license: mit
tags:
- graph-ml
- bioinformatics
- precision-medicine
- explainable-ai
- reinforcement-learning
datasets:
- FuhaiLiAiLab/Target-QA
library_name: transformers
pipeline_tag: text-generation
model-index:
- name: GALAX
  results:
  - task:
      type: text-generation
      name: Target Prioritization
    dataset:
      name: Target-QA
      type: FuhaiLiAiLab/Target-QA
    metrics:
    - type: precision
      value: 0.5472
    - type: recall
      value: 0.5332
    - type: hit@10
      value: 0.8815
    - type: hit@5
      value: 0.9249
---

# GALAX: Graph-Augmented Language Model for Explainable Reinforcement-Guided Subgraph Reasoning in Precision Medicine

<div align="center">
  <img src="https://github.com/FuhaiLiAiLab/GALAX/blob/main/Figures/GALAX-logo.png?raw=true" width="40%" alt="GALAX" />
</div>

<div align="center" style="line-height: 1;">
  <!-- GitHub -->
  <a href="https://github.com/FuhaiLiAiLab/GALAX" target="_blank" style="margin: 2px;">
    <img alt="GitHub" src="https://img.shields.io/badge/GitHub-GALAX%20Code-181717?logo=github&logoColor=white" style="display: inline-block; vertical-align: middle;"/>
  </a>
  <!-- Hugging Face Model -->
  <a href="https://huggingface.co/FuhaiLiAiLab/GALAX" target="_blank" style="margin: 2px;">
    <img alt="Hugging Face Model" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-GALAX%20Model-ffc107?color=ffc107&logoColor=white" style="display: inline-block; vertical-align: middle;"/>
  </a>
  <!-- Hugging Face Dataset -->
  <a href="https://huggingface.co/datasets/FuhaiLiAiLab/Target-QA" target="_blank" style="margin: 2px;">
    <img alt="Hugging Face Dataset" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Target--QA%20Dataset-ff6f61?color=ffc107&logoColor=white" style="display: inline-block; vertical-align: middle;"/>
  </a>
</div>

<div align="center" style="line-height: 1;">
  <!-- arXiv -->
  <a href="https://arxiv.org/abs/2509.20935" target="_blank" style="margin: 2px;">
    <img alt="arXiv" src="https://img.shields.io/badge/arXiv-GALAX%20Paper-b31b1b?logo=arxiv&logoColor=white" style="display: inline-block; vertical-align: middle;"/>
  </a>
  <!-- License -->
  <a href="LICENSE" style="margin: 2px;">
    <img alt="License" src="https://img.shields.io/badge/License-MIT-0a4d92?logo=open-source-initiative&logoColor=white" style="display: inline-block; vertical-align: middle;"/>
  </a>
</div>

---

## 🧩 Model Overview

**GALAX** is a graph-augmented language model designed for explainable target prioritization in precision medicine. It combines three key components:

- **LLaMA3-8B-Instruct** as the language backbone, further adapted on the BioMedGraphica corpus and fine-tuned on Target-QA.
- A **Graph Attention Network (GAT)** pretrained on integrated multi-omics data and BioMedGraphica knowledge graphs.
- A **reinforcement-guided subgraph generator** that enables interpretable reasoning by constructing biologically meaningful subgraphs from multi-omics and knowledge-graph signals.
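
For readers unfamiliar with the graph component, the following is a minimal plain-Python sketch of one single-head graph attention layer in the style of the original GAT formulation: a shared linear projection, LeakyReLU-scored attention over each node and its neighbors, and a softmax-weighted sum. It is an illustration only, not the GALAX encoder; the `gat_layer` helper and the toy gene graph are hypothetical, and multi-head attention, the output nonlinearity, and pretraining are omitted.

```python
import math
from typing import Dict, List

Vector = List[float]
Matrix = List[List[float]]


def matvec(w: Matrix, x: Vector) -> Vector:
    """Multiply an (out_dim x in_dim) matrix by an in_dim vector."""
    return [sum(wk * xk for wk, xk in zip(row, x)) for row in w]


def leaky_relu(x: float, slope: float = 0.2) -> float:
    return x if x > 0 else slope * x


def gat_layer(
    h: Dict[str, Vector],
    neighbors: Dict[str, List[str]],
    w: Matrix,
    a: Vector,
) -> Dict[str, Vector]:
    """Hypothetical single-head graph attention layer (output nonlinearity omitted)."""
    # Shared linear projection z_i = W h_i for every node.
    z = {node: matvec(w, feat) for node, feat in h.items()}
    out: Dict[str, Vector] = {}
    for i in h:
        # Attend over the node itself plus its neighbors.
        nbrs = [i] + [j for j in neighbors.get(i, []) if j != i]
        # e_ij = LeakyReLU(a^T [z_i || z_j]); list `+` concatenates z_i and z_j.
        scores = [leaky_relu(sum(ak * ck for ak, ck in zip(a, z[i] + z[j]))) for j in nbrs]
        # Softmax over the neighborhood, shifted by the max for numerical stability.
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        alphas = [e / total for e in exps]
        # h'_i = sum_j alpha_ij * z_j
        out[i] = [sum(al * z[j][k] for al, j in zip(alphas, nbrs)) for k in range(len(z[i]))]
    return out


if __name__ == "__main__":
    # Hypothetical toy graph: two connected genes and one isolated gene.
    h = {"TP53": [1.0, 0.0], "MDM2": [0.0, 1.0], "EGFR": [1.0, 1.0]}
    neighbors = {"TP53": ["MDM2"], "MDM2": ["TP53"]}
    w = [[1.0, 0.0], [0.0, 1.0]]  # identity projection
    a = [0.0, 0.0, 0.0, 0.0]      # zero attention vector gives uniform attention
    print(gat_layer(h, neighbors, w, a))
    # {'TP53': [0.5, 0.5], 'MDM2': [0.5, 0.5], 'EGFR': [1.0, 1.0]}
```

With a zero attention vector every neighbor gets equal weight, so connected nodes average their projected features while the isolated node keeps its own; a learned `a` would instead shift weight toward more informative neighbors.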

By jointly leveraging **multi-omics features**, **protein–protein interactions**, and **disease–target associations**, GALAX provides an interpretable framework for **CRISPR target prioritization** across diverse cancer cell lines. To support benchmarking and reproducibility, we also introduce the **[Target-QA dataset](https://huggingface.co/datasets/FuhaiLiAiLab/Target-QA)**.

---
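
## 🚀 How to Use

The language model loads through `transformers`; the pretrained graph foundation model is stored as a separate PyTorch checkpoint, `best_combined_model.pt`, in the same repository.

```python
import os

import torch
from huggingface_hub import snapshot_download
from transformers import AutoModelForCausalLM, AutoTokenizer

# 1. Load the GALAX language model (QA-tuned LLaMA3-8B-Instruct backbone)
model_id = "FuhaiLiAiLab/GALAX"
tokenizer = AutoTokenizer.from_pretrained(model_id)
lm_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",   # requires `accelerate`; places weights on available GPUs/CPU
    torch_dtype="auto",  # keep the dtype stored in the checkpoint
)

# 2. Load the graph foundation model checkpoint from the same repository
repo_path = snapshot_download(model_id)  # resolves to the local Hugging Face cache
combined_model_path = os.path.join(repo_path, "best_combined_model.pt")
device = "cuda" if torch.cuda.is_available() else "cpu"

# PyTorch >= 2.6 defaults to weights_only=True, which rejects checkpoints holding
# arbitrary pickled objects (e.g. a full nn.Module). Disable it only for trusted files;
# if the file stores a full module, the GALAX model classes (GitHub repo) must be importable.
best_combined_model = torch.load(
    combined_model_path, map_location=device, weights_only=False
)
```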

---

## ⚙️ Experimental Setup

- **Backbone LM:** LLaMA3-8B-Instruct (QA-tuned).
- **Graph Encoder:** BioBERT-v1.1 embeddings + GAT with edge masking.
- **Training:** Adam optimizer on 2× NVIDIA H100 (80 GB).
- **Top features per omics modality:** K = 10.
- **Subgraph rollout depth:** L = 5; candidate nodes: η = 20.
- **Evaluation:** Precision, Recall, F1, Jaccard, Hit@5, Hit@10.
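
The setup above fixes K, L, and η for the subgraph generator. As a rough illustration of how such parameters can bound a subgraph, the sketch below seeds a subgraph with the top-K features of each omics modality and expands it for L steps, adding at most η new neighbors per step. It is a stand-in under explicit assumptions, not the GALAX generator: the learned reinforcement-learning policy is replaced by a fixed per-node `score` dictionary with greedy selection, and the helper names and toy data are hypothetical.

```python
import heapq
from typing import Dict, List, Set

# Defaults mirror the Experimental Setup values (K = 10, L = 5, eta = 20).
TOP_K, ROLLOUT_DEPTH, ETA = 10, 5, 20


def top_k_per_modality(omics: Dict[str, Dict[str, float]], k: int = TOP_K) -> Set[str]:
    """Seed nodes: the k highest-scoring features within each omics modality."""
    seeds: Set[str] = set()
    for features in omics.values():
        seeds.update(heapq.nlargest(k, features, key=features.get))
    return seeds


def rollout_subgraph(
    seeds: Set[str],
    adjacency: Dict[str, List[str]],
    score: Dict[str, float],
    depth: int = ROLLOUT_DEPTH,
    eta: int = ETA,
) -> Set[str]:
    """Greedy depth-limited expansion: at most `eta` new nodes per step for `depth` steps.

    `score` is a fixed stand-in for a learned policy (hypothetical, not GALAX's RL agent).
    """
    nodes = set(seeds)
    frontier = set(seeds)
    for _ in range(depth):
        # Candidate nodes: unvisited neighbors of the current frontier.
        candidates = {
            nbr for node in frontier for nbr in adjacency.get(node, []) if nbr not in nodes
        }
        if not candidates:
            break  # nothing left to expand
        # Keep the eta highest-scoring candidates; unscored nodes default to 0.
        chosen = heapq.nlargest(eta, candidates, key=lambda n: score.get(n, 0.0))
        nodes.update(chosen)
        frontier = set(chosen)  # only newly added nodes are expanded next
    return nodes


if __name__ == "__main__":
    # Hypothetical toy inputs with small K, L and eta so the result is easy to follow.
    omics = {"mRNA": {"A": 0.9, "B": 0.1}, "CNV": {"C": 0.8, "D": 0.2}}
    adjacency = {"A": ["E", "F"], "C": ["G"], "E": ["H"], "F": ["I"]}
    score = {"E": 0.9, "F": 0.1, "G": 0.5, "H": 0.7, "I": 0.3}
    seeds = top_k_per_modality(omics, k=1)
    print(sorted(rollout_subgraph(seeds, adjacency, score, depth=2, eta=2)))
    # ['A', 'C', 'E', 'G', 'H']
```

Expanding only the newly chosen nodes keeps each step local to the previous one, so L bounds the path length from any seed and η bounds how many nodes each step can add.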
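
---

## 📊 Results

GALAX achieves the best overall precision, recall, Hit@10, and Hit@5 among all baselines and ablation variants. It also leads on every per-dataset metric except BRCA Hit@10, where G-Retriever + pre-GAT scores higher (0.8667 vs. 0.8500), and BRCA Hit@5, where it ties L3-FT(QA) + Omics (0.8889).

- **Overall Precision:** 0.5472
- **Overall Recall:** 0.5332
- **Overall Hit@10:** 0.8815
- **Overall Hit@5:** 0.9249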
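
Precision and recall here compare a predicted set of targets with a ground-truth set. The sketch below computes the set-based metrics named in the Experimental Setup (precision, recall, F1, Jaccard) for a single question and macro-averages them over several. It assumes each question pairs a predicted target list with a gold target list and that overall scores are per-question averages; both are assumptions, since this card does not spell out the aggregation. Hit@5 and Hit@10 are left out because the card does not define how they are computed, and the gene lists are made-up examples.

```python
from statistics import mean
from typing import Dict, Iterable, List, Tuple


def set_metrics(predicted: Iterable[str], truth: Iterable[str]) -> Dict[str, float]:
    """Set-overlap metrics between predicted and ground-truth target genes."""
    pred, gold = set(predicted), set(truth)
    tp = len(pred & gold)  # correctly predicted targets
    precision = tp / len(pred) if pred else 0.0
    recall = tp / len(gold) if gold else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    union = pred | gold
    jaccard = tp / len(union) if union else 0.0
    return {"precision": precision, "recall": recall, "f1": f1, "jaccard": jaccard}


def macro_average(pairs: List[Tuple[List[str], List[str]]]) -> Dict[str, float]:
    """Average each metric over questions (assumed aggregation, see lead-in)."""
    per_question = [set_metrics(pred, gold) for pred, gold in pairs]
    return {name: mean(m[name] for m in per_question) for name in per_question[0]}


if __name__ == "__main__":
    # Hypothetical prediction for one question: 2 of 4 predicted genes are correct.
    print(set_metrics(["KRAS", "EGFR", "MYC", "TP53"], ["KRAS", "EGFR", "BRAF"]))
```

For the example above, precision is 2/4 = 0.5, recall is 2/3, F1 is 4/7 (about 0.571), and Jaccard is 2/5 = 0.4.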

**Table 1. Precision and Recall across datasets**

| Model                   | Overall Precision ↑ | Overall Recall ↑ | LUAD Precision ↑ | LUAD Recall ↑ | BRCA Precision ↑ | BRCA Recall ↑ |
|-------------------------|---------------------|------------------|------------------|---------------|------------------|---------------|
| M2T                     | 0.0016 | 0.0011 | 0.0020 | 0.0014 | 0.0000 | 0.0000 |
| GAT                     | 0.0006 ± 0.0000 | 0.0006 ± 0.0000 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0033 ± 0.0000 | 0.0033 ± 0.0000 |
| L3 + Omics              | 0.0071 ± 0.0032 | 0.0013 ± 0.0002 | 0.0079 ± 0.0137 | 0.0005 ± 0.0008 | 0.0020 ± 0.0035 | 0.0017 ± 0.0029 |
| L3 + Omics + KG         | 0.0125 ± 0.0032 | 0.0029 ± 0.0003 | 0.0014 ± 0.0025 | 0.0010 ± 0.0017 | 0.0073 ± 0.0068 | 0.0033 ± 0.0029 |
| L3-FT(Med) + Omics      | 0.0179 ± 0.0045 | 0.0133 ± 0.0064 | 0.0091 ± 0.0018 | 0.0105 ± 0.0044 | 0.0110 ± 0.0086 | 0.0106 ± 0.0075 |
| L3-FT(Med) + Omics + KG | 0.0158 ± 0.0030 | 0.0058 ± 0.0011 | 0.0081 ± 0.0071 | 0.0024 ± 0.0017 | 0.0149 ± 0.0057 | 0.0050 ± 0.0000 |
| L3-FT(QA) + Omics       | 0.5250 ± 0.0282 | 0.4959 ± 0.0435 | 0.5201 ± 0.0408 | 0.4905 ± 0.0532 | 0.5074 ± 0.0498 | 0.4856 ± 0.0570 |
| L3-FT(QA) + Omics + KG  | 0.5185 ± 0.0240 | 0.4908 ± 0.0402 | 0.5214 ± 0.0242 | 0.4952 ± 0.0432 | 0.4856 ± 0.0395 | 0.4656 ± 0.0436 |
| G-Retriever + pre-GAT   | 0.4763 ± 0.0004 | 0.3929 ± 0.0063 | 0.4642 ± 0.0181 | 0.3881 ± 0.0264 | 0.4414 ± 0.0099 | 0.3772 ± 0.0010 |
| **GALAX**               | **0.5472 ± 0.0053** | **0.5332 ± 0.0031** | **0.5345 ± 0.0185** | **0.5157 ± 0.0043** | **0.5608 ± 0.0031** | **0.5533 ± 0.0033** |

**Table 2. Hit@10 and Hit@5 across datasets**

| Model                   | Overall Hit@10 ↑ | Overall Hit@5 ↑ | LUAD Hit@10 ↑ | LUAD Hit@5 ↑ | BRCA Hit@10 ↑ | BRCA Hit@5 ↑ |
|-------------------------|------------------|-----------------|---------------|--------------|---------------|--------------|
| M2T                     | 0.0029 | 0.0000 | 0.0000 | 0.0000 | 0.0000 | 0.0000 |
| GAT                     | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 |
| L3 + Omics              | 0.0021 ± 0.0037 | 0.0032 ± 0.0055 | 0.0048 ± 0.0082 | 0.0095 ± 0.0165 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 |
| L3 + Omics + KG         | 0.0122 ± 0.0033 | 0.0085 ± 0.0037 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0056 ± 0.0096 | 0.0111 ± 0.0192 |
| L3-FT(Med) + Omics      | 0.0122 ± 0.0072 | 0.0116 ± 0.0097 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0111 ± 0.0192 | 0.0000 ± 0.0000 |
| L3-FT(Med) + Omics + KG | 0.0132 ± 0.0040 | 0.0106 ± 0.0048 | 0.0048 ± 0.0082 | 0.0095 ± 0.0165 | 0.0111 ± 0.0192 | 0.0000 ± 0.0000 |
| L3-FT(QA) + Omics       | 0.8693 ± 0.0157 | 0.8889 ± 0.0168 | 0.8667 ± 0.0218 | 0.8476 ± 0.0165 | 0.8389 ± 0.0096 | 0.8889 ± 0.0509 |
| L3-FT(QA) + Omics + KG  | 0.8529 ± 0.0153 | 0.8794 ± 0.0114 | 0.8048 ± 0.0541 | 0.7905 ± 0.0436 | 0.8222 ± 0.0347 | 0.8778 ± 0.0192 |
| G-Retriever + pre-GAT   | 0.8550 ± 0.0046 | 0.8804 ± 0.0037 | 0.8524 ± 0.0165 | 0.8857 ± 0.0000 | **0.8667 ± 0.0000** | 0.8667 ± 0.0000 |
| **GALAX**               | **0.8815 ± 0.0033** | **0.9249 ± 0.0048** | **0.8810 ± 0.0082** | **0.9238 ± 0.0436** | 0.8500 ± 0.0441 | **0.8889 ± 0.0839** |

---

## 🔬 Intended Uses

- **Research use only**
- Benchmarking **graph-language foundation models** on target prioritization
- Target prioritization in **cancer biology**

---

## 📜 Citation

If you use this model, please cite:

```bibtex
@article{zhang2025galax,
  title   = {GALAX: Graph-Augmented Language Model for Explainable Reinforcement-Guided Subgraph Reasoning in Precision Medicine},
  author  = {Zhang, Heming and Huang, Di and Li, Wenyu and Province, Michael and Chen, Yixin and Payne, Philip and Li, Fuhai},
  journal = {arXiv preprint arXiv:2509.20935},
  year    = {2025},
  doi     = {10.48550/arXiv.2509.20935},
  url     = {https://arxiv.org/abs/2509.20935}
}
```