---
license: apache-2.0
language:
- en
pipeline_tag: tabular-regression
tags:
- VAE
- bioinformatics
- TCGA
- ccRCC
- KIRC
- cancer
---
# Pretrained Models
This directory contains pretrained VAE and reconstruction network models produced during WP3 of the EVENFLOW EU project.
Each model was trained independently on a pre-processed version of the TCGA bulk RNA-Seq dataset for either KIRC or BRCA (see the data availability link in the respective section).
## Available Models
### KIRC (Kidney Renal Clear Cell Carcinoma)
**Location**: `KIRC/`
*Data availability:* [Zenodo](https://doi.org/10.5281/zenodo.17987300)
**Model Files**:
- `20250321_VAE_idim8516_md512_feat256mse_relu.pth` - VAE weights
- `network_reconstruction.pth` - Reconstruction network weights
- `network_dims.csv` - Network architecture specifications
**Model Specifications**:
- Input dimension: 8,516 genes
- VAE architecture:
  - Middle dimension: 512
  - Latent dimension: 256
  - Loss function: MSE
  - Activation: ReLU
- Reconstruction network: [8954, 3512, 824, 3731, 8954]
- Training: Beta-VAE with 3 cycles, 600 epochs total
### BRCA (Breast Invasive Carcinoma)
**Location**: `BRCA/`
*Data availability:* [Zenodo](https://doi.org/10.5281/zenodo.17986123)
**Model Files**:
- `20251209_VAE_idim8954_md1024_feat512mse_relu.pth` - VAE weights
- `network_reconstruction.pth` - Reconstruction network weights
- `network_dims.csv` - Network architecture specifications
**Model Specifications**:
- Input dimension: 8,954 genes
- VAE architecture:
  - Middle dimension: 1,024
  - Latent dimension: 512
  - Loss function: MSE
  - Activation: ReLU
- Reconstruction network: [8954, 3104, 790, 4027, 8954]
- Training: Beta-VAE with 3 cycles, 600 epochs total
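
The two VAE weight filenames listed above appear to encode the input, middle, and latent dimensions (`idim`, `md`, `feat`). The following is a minimal sketch of reading those values back from a filename. The helper `dims_from_filename` and the naming pattern are assumptions inferred from these two filenames, not a documented convention of the repository.

```python
import re

# Assumed pattern, inferred from the two filenames in this README:
# ..._idim<input>_md<middle>_feat<latent><loss>_<activation>.pth
_DIMS_PATTERN = re.compile(
    r"idim(?P<input_dim>\d+)_md(?P<mid_dim>\d+)_feat(?P<latent_dim>\d+)"
)


def dims_from_filename(filename: str) -> dict:
    """Hypothetical helper: extract VAE dimensions from a weights filename."""
    match = _DIMS_PATTERN.search(filename)
    if match is None:
        raise ValueError(f"No dimensions found in filename: {filename}")
    # Convert each captured group from string to int.
    return {name: int(value) for name, value in match.groupdict().items()}


kirc_dims = dims_from_filename("20250321_VAE_idim8516_md512_feat256mse_relu.pth")
brca_dims = dims_from_filename("20251209_VAE_idim8954_md1024_feat512mse_relu.pth")
print(kirc_dims)
print(brca_dims)
```

The parsed values can serve as a sanity check against the model specifications listed for each cancer type.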
## Usage
### Loading Models in Python
The `VAE` and `NetworkReconstruction` classes used below are defined in [renalprog](https://www.github.com/gprolcastelo/renalprog).
```python
import json
from pathlib import Path

import huggingface_hub as hf
import pandas as pd
import torch
from renalprog.modeling.train import VAE, NetworkReconstruction

# Configuration
cancer_type = "KIRC"  # or "BRCA"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ============================================================================
# Load VAE Model
# ============================================================================
# Download VAE config
vae_config_path = hf.hf_hub_download(
    repo_id="gprolcastelo/evenflow_models",
    filename=f"{cancer_type}/config.json"
)

# Load configuration
with open(vae_config_path, "r") as f:
    vae_config = json.load(f)
print(f"VAE Configuration: {vae_config}")

# Download VAE model weights
if cancer_type == "KIRC":
    vae_filename = "KIRC/20250321_VAE_idim8516_md512_feat256mse_relu.pth"
elif cancer_type == "BRCA":
    vae_filename = "BRCA/20251209_VAE_idim8954_md1024_feat512mse_relu.pth"
else:
    raise ValueError(f"Unknown cancer type: {cancer_type}")
vae_model_path = hf.hf_hub_download(
    repo_id="gprolcastelo/evenflow_models",
    filename=vae_filename
)

# Initialize and load VAE
model_vae = VAE(
    input_dim=vae_config["INPUT_DIM"],
    mid_dim=vae_config["MID_DIM"],
    features=vae_config["LATENT_DIM"]
).to(device)
checkpoint_vae = torch.load(vae_model_path, map_location=device, weights_only=False)
model_vae.load_state_dict(checkpoint_vae)
model_vae.eval()
print(f"VAE model loaded successfully from {cancer_type}")

# ============================================================================
# Load Reconstruction Network
# ============================================================================
# Download network dimensions
network_dims_path = hf.hf_hub_download(
    repo_id="gprolcastelo/evenflow_models",
    filename=f"{cancer_type}/network_dims.csv"
)

# Load network dimensions
network_dims = pd.read_csv(network_dims_path)
layer_dims = network_dims.values.tolist()[0]
print(f"Reconstruction Network dimensions: {layer_dims}")

# Download reconstruction network weights
recnet_model_path = hf.hf_hub_download(
    repo_id="gprolcastelo/evenflow_models",
    filename=f"{cancer_type}/network_reconstruction.pth"
)

# Initialize and load Reconstruction Network
model_recnet = NetworkReconstruction(layer_dims=layer_dims).to(device)
checkpoint_recnet = torch.load(recnet_model_path, map_location=device, weights_only=False)
model_recnet.load_state_dict(checkpoint_recnet)
model_recnet.eval()
print(f"Reconstruction Network loaded successfully from {cancer_type}")

# ============================================================================
# Use the models
# ============================================================================
# Example: Apply VAE to your data
# your_data = torch.tensor(your_data_array).float().to(device)
# with torch.no_grad():
#     vae_output = model_vae(your_data)
#     recnet_output = model_recnet(vae_output)
```
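
Before passing your own data to the models, it can help to confirm that the matrix has the gene count each model expects. The following is a minimal sketch. The helper `check_expression_matrix` is hypothetical, and it assumes rows are samples and columns are genes. It checks only the shape. It cannot verify that the gene order or pre-processing matches the training data.

```python
import numpy as np

# Expected gene counts, taken from the model specifications in this README.
EXPECTED_GENES = {"KIRC": 8516, "BRCA": 8954}


def check_expression_matrix(data: np.ndarray, cancer_type: str) -> np.ndarray:
    """Hypothetical helper: validate a samples x genes matrix before inference."""
    if cancer_type not in EXPECTED_GENES:
        raise ValueError(f"Unknown cancer type: {cancer_type}")
    if data.ndim != 2:
        raise ValueError(f"Expected a 2-D array (samples x genes), got {data.ndim}-D")
    n_genes = data.shape[1]
    expected = EXPECTED_GENES[cancer_type]
    if n_genes != expected:
        raise ValueError(
            f"{cancer_type} model expects {expected} genes, got {n_genes}"
        )
    # Cast to float32, matching the .float() conversion used when building tensors.
    return data.astype(np.float32)


# Usage with a placeholder matrix of 4 samples (all zeros, for illustration only).
dummy = np.zeros((4, 8516))
checked = check_expression_matrix(dummy, "KIRC")
print(checked.shape, checked.dtype)
```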
## Citation
> **⚠️ Warning**
> This citation is temporary. It will be updated when a pre-print is released.
If you use these pretrained models, please cite:
```bibtex
@software{renalprog2024,
  title = {RenalProg: A Deep Learning Framework for Kidney Cancer Progression Modeling},
  author = {Guillermo Prol-Castelo and Elina Syrri and Nikolaos Manginas and Vasileos Manginas and Nikos Katzouris and Davide Cirillo and George Paliouras and Alfonso Valencia},
  year = {2025},
  url = {https://github.com/gprolcastelo/renalprog},
  note = {Preprint in preparation}
}
```
## Training Details
These models were trained using:
- Random seed: 2023
- Train/test split: 80/20
- Optimizer: Adam
- Learning rate: 1e-4
- Batch size: 8
- Beta annealing (for VAE): 3 cycles with 0.5 ratio
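
The beta annealing setting above (3 cycles, 0.5 ratio, 600 epochs total) can be illustrated with a cyclical schedule. The following is a minimal sketch, not the renalprog implementation. It assumes that beta ramps linearly from 0 to a maximum during the first `ratio` fraction of each cycle and then stays at the maximum. The function name `cyclical_beta` and the default `beta_max=1.0` are also assumptions.

```python
def cyclical_beta(
    epoch: int,
    total_epochs: int = 600,
    n_cycles: int = 3,
    ratio: float = 0.5,
    beta_max: float = 1.0,  # assumed maximum; not stated in this README
) -> float:
    """Return the KL weight (beta) for a given epoch under cyclical annealing."""
    cycle_length = total_epochs / n_cycles  # 200 epochs per cycle with the defaults
    # Position within the current cycle, in [0, 1).
    position = (epoch % cycle_length) / cycle_length
    if position < ratio:
        # Linear ramp-up phase (assumed shape of the ramp).
        return beta_max * position / ratio
    # Plateau phase: beta held at its maximum until the cycle ends.
    return beta_max


# Beta at a few epochs across the first two cycles.
for epoch in (0, 50, 100, 199, 200, 250):
    print(epoch, cyclical_beta(epoch))
```

Under these assumptions each cycle lasts 200 epochs: beta rises over the first 100 epochs, is held for the remaining 100, and resets to 0 at the start of the next cycle.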
## Model Performance
**KIRC Model**:
- Reconstruction loss (test): ~1.1
**BRCA Model**:
- Reconstruction loss (test): ~0.9
## License
These pretrained models are provided under the Apache 2.0 license.
## Contact
For questions about the pretrained models, please:
1. Check the [documentation](https://gprolcastelo.github.io/renalprog/)
2. Open an issue on [GitHub](https://github.com/gprolcastelo/renalprog/issues)
3. Contact the authors
---
**Last Updated**: December 2025
**Version**: 1.0.0-alpha