---
license: apache-2.0
---
A High-Performance, Reliable Multimodal-Multiomic Whole-Slide AI Model Generates Genome-wide Spatial Gene Expression from Histopathology Images

# Code and environment setup
The source code and environment setup instructions are available at https://github.com/99wzj/Coladan/
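
# Inference demo
The example below predicts genome-wide gene expression for one slide from its H&E image patches, then compares the prediction with the measured expression.
```python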
# Coladan inference demo: predict genome-wide gene expression for one slide from
# its H&E image patches, then compare the prediction with the measured expression.
import pickle

import h5py
import numpy as np
import pandas as pd
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoConfig

from modeling_coladan import Coladan

REPO_ID = "WangZj99/Coladan"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# --- load model ---
cfg = AutoConfig.from_pretrained(REPO_ID, trust_remote_code=True)
model = Coladan(cfg).to(device)

state_dict_path = hf_hub_download(REPO_ID, filename="state_dict.pt")
# map_location lets a checkpoint saved on GPU load on a CPU-only machine;
# weights_only=True refuses to unpickle arbitrary Python objects from the download.
state_dict = torch.load(state_dict_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
# torch.inference_mode() below only disables autograd; eval() is what switches
# layers such as dropout/normalization to their inference behavior.
model.eval()


# --- load data ---
demo_path = hf_hub_download(REPO_ID, filename="10X_demo.pkl")

# NOTE: pickle.load can execute arbitrary code embedded in the file;
# only load demo files from a source you trust.
with open(demo_path, "rb") as f:
    loaded_data = pickle.load(f)

# H&E patches are passed to the model as stored in the pickle (not moved to `device` here).
patches = loaded_data["he_image"]
coords = loaded_data["coordinates"].to(device)
# Genes to predict as an int64 tensor (length 16000 for 16000 genes).
# NOTE(review): left on the CPU; move it to `device` if the model reports a device mismatch.
predict_genes = torch.tensor(loaded_data["predict_gene"].to_numpy(), dtype=torch.long)


# --- predict from H&E; only one slide per call is supported ---
# bfloat16 autocast only on a CUDA device that supports it; otherwise run in float32
# (a hard-coded device_type="cuda" autocast is silently disabled on CPU anyway).
use_bf16 = device.type == "cuda" and torch.cuda.is_bf16_supported()
autocast_ctx = torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=use_bf16)
with autocast_ctx, torch.inference_mode():
    predict_matrix = model.predict_gene_from_image(patches, coords, predict_genes)


# --- extra: Pearson and Spearman correlation between measured and predicted expression ---
gene_order = predict_genes.cpu().numpy()
# Select the measured expression columns in the same gene order as the prediction.
ori_matrix_df = loaded_data["expression_matrix"][gene_order]
ori_matrix = ori_matrix_df.to_numpy(dtype=np.float32)  # shape: [N_cells, 16000]
# The prediction may be bfloat16 under autocast, which numpy cannot represent;
# cast to float32 before converting.
predict_matrix = predict_matrix.to(dtype=torch.float32).cpu().numpy()

# NOTE(review): the model is imported from `modeling_coladan` but this helper from
# `coladan`; confirm which module provides it in your environment.
from coladan import calculate_pearson_and_spearman

mean_pearson, mean_spearman = calculate_pearson_and_spearman(ori=ori_matrix, predict=predict_matrix)
print(f"mean Pearson : {mean_pearson:.4f}")
print(f"mean Spearman : {mean_spearman:.4f}")
```