You need to agree to share your contact information to access this model.

This repository is publicly accessible, but you have to accept the conditions to access its files and content.

Log in or sign up to review the conditions and access this model's content.

A High-Performance, Reliable Multimodal-Multiomic Whole Slide AI Model generates Genome-wide Spatial Gene Expression from Histopathology Images

Code repository and environment setup

https://github.com/99wzj/Coladan/

Inference demo

import h5py
import torch
from transformers import AutoConfig

import pickle
import numpy as np
import pandas as pd
from modeling_coladan import Coladan 

# Use the GPU when one is available; every later step follows this choice.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Load model: build the architecture from the Hub config, then load the
# pretrained weights. NOTE: trust_remote_code=True allows code shipped in the
# "WangZj99/Coladan" repository to run; only use it with a repo you trust.
from huggingface_hub import hf_hub_download

cfg = AutoConfig.from_pretrained("WangZj99/Coladan", trust_remote_code=True)
model = Coladan(cfg)
model = model.to(device)

state_dict_path = hf_hub_download(
    "WangZj99/Coladan",
    filename="state_dict.pt",
)
# map_location places the checkpoint tensors on the selected device, so a
# checkpoint saved from a GPU can still be loaded on a CPU-only machine.
model.load_state_dict(torch.load(state_dict_path, map_location=device))
# Put the model in evaluation mode for inference: torch.inference_mode() used
# below disables autograd but does not change module training behaviour
# (e.g. dropout, if the model contains any).
model.eval()


# Load the demo slide data from the Hub.
demo_path = hf_hub_download(
    repo_id="WangZj99/Coladan",
    filename="10X_demo.pkl",
)

# SECURITY NOTE: unpickling can execute arbitrary code embedded in the file;
# only load pickles from a source you trust.
with open(demo_path, "rb") as fh:
    loaded_data = pickle.load(fh)

# Fields of the demo dict consumed by the prediction step below.
patches, coords = loaded_data['he_image'], loaded_data['coordinates']
predict_genes = loaded_data['predict_gene']  # length 16000 for 16000 genes

# Predict gene expression from the H&E image; only one slide per call is supported.
# bfloat16 autocast follows the selected device and is enabled only on CUDA.
# On CPU the model runs in default precision, with no CUDA-autocast warning.
with torch.autocast(
    device_type=device.type,
    dtype=torch.bfloat16,
    enabled=device.type == "cuda",
), torch.inference_mode():
    # Assumes coords is a torch tensor (it is moved with .to()) — TODO confirm.
    coords = coords.to(device)
    # NOTE(review): predict_genes is left on the CPU while coords is moved to
    # `device`; presumably predict_gene_from_image handles placement — confirm.
    predict_genes = torch.tensor(loaded_data['predict_gene'].to_numpy(), dtype=torch.long)
    predict_matrix = model.predict_gene_from_image(patches, coords, predict_genes)
    

# Extra: compare the original and predicted expression matrices using the
# mean Pearson and Spearman correlations from calculate_pearson_and_spearman.
from coladan import calculate_pearson_and_spearman

# Select the ground-truth columns in the same gene order used for prediction.
gene_order = predict_genes.cpu().numpy()
ori_matrix = loaded_data['expression_matrix'][gene_order].to_numpy(dtype=np.float32)  # shape: [N_cells, 16000]
# The prediction may be bfloat16 under autocast, which NumPy cannot represent,
# so cast to float32 before converting.
predict_matrix = predict_matrix.float().cpu().numpy()

mean_pearson, mean_spearman = calculate_pearson_and_spearman(
    ori=ori_matrix,
    predict=predict_matrix,
)
print(f"mean Pearson  : {mean_pearson:.4f}")
print(f"mean Spearman : {mean_spearman:.4f}")
Downloads last month
-
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support