DinoBloom: A Foundation Model for Generalizable Cell Embeddings in Hematology
Short description
DinoBloom is a Vision Transformer (ViT) built on DINOv2 (Meta AI) and trained on images of single cells from peripheral blood and bone marrow. The models in this repository can be used to extract features that serve as inputs for a variety of prediction models. The project was developed by Koch et al.; more information can be found in their GitHub repository and in the accompanying paper. This repository is a fork of the authors' original HuggingFace repository.
Model Versions
DinoBloom is available in four sizes:
| Model | Feature Dim | Parameters | Checkpoint |
|---|---|---|---|
| DinoBloom-S | 384 | 22M | pytorch_model_s.bin |
| DinoBloom-B | 768 | 86M | pytorch_model_b.bin |
| DinoBloom-L | 1024 | 304M | pytorch_model_l.bin |
| DinoBloom-G | 1536 | 1136M | pytorch_model_g.bin |
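The table can be mirrored in code as a small lookup, which makes it easy to check that an extracted embedding has the width expected for a given variant. This is a minimal sketch; the names `DINOBLOOM_VARIANTS`, `expected_feature_dim` and `check_features_shape` are hypothetical and not part of this repository.

```python
# Hypothetical lookup mirroring the Model Versions table
DINOBLOOM_VARIANTS = {
    "s": {"name": "DinoBloom-S", "feature_dim": 384, "params_m": 22, "checkpoint": "pytorch_model_s.bin"},
    "b": {"name": "DinoBloom-B", "feature_dim": 768, "params_m": 86, "checkpoint": "pytorch_model_b.bin"},
    "l": {"name": "DinoBloom-L", "feature_dim": 1024, "params_m": 304, "checkpoint": "pytorch_model_l.bin"},
    "g": {"name": "DinoBloom-G", "feature_dim": 1536, "params_m": 1136, "checkpoint": "pytorch_model_g.bin"},
}


def expected_feature_dim(variant: str) -> int:
    """Return the embedding width for a variant key ("s", "b", "l" or "g")."""
    try:
        return DINOBLOOM_VARIANTS[variant.lower()]["feature_dim"]
    except KeyError:
        raise ValueError(f"Unknown DinoBloom variant: {variant!r}") from None


def check_features_shape(shape, variant: str) -> None:
    """Raise ValueError if a [n, feature_dim] shape does not match the chosen variant."""
    dim = expected_feature_dim(variant)
    if len(shape) != 2 or shape[1] != dim:
        raise ValueError(f"Expected [n, {dim}], got {list(shape)}")
```

After running the feature extraction example, `check_features_shape(features.shape, "b")` would pass silently for DinoBloom-B, since `torch.Size` behaves like a tuple.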
Long description
DinoBloom, the first foundation model for single-cell images in hematology, uses a tailored DINOv2 pipeline. It is trained on an extensive collection of 13 diverse, publicly available datasets of peripheral blood and bone marrow smears, the largest open-source cohort in hematology to date, comprising over 380,000 white blood cell images. To assess its generalization capability, it is evaluated on an external dataset with a challenging domain shift. The model outperforms existing medical and non-medical vision models by a large margin in (i) linear probing and k-nearest neighbor evaluations for cell-type classification on blood and bone marrow smears and (ii) weakly supervised multiple instance learning for acute myeloid leukemia subtyping. The family of four DinoBloom models (small, base, large, and giant) can be adapted for a wide range of downstream applications, serves as a strong baseline for classification problems, and facilitates the assessment of batch effects in new datasets.
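A k-nearest neighbor evaluation classifies a cell by a majority vote among the labels of its most similar labeled embeddings. The sketch below assumes cosine similarity, a plain majority vote and `k=3`; these are illustrative choices, not necessarily the protocol used in the paper, and `knn_predict` is a hypothetical name.

```python
from collections import Counter

import numpy as np


def knn_predict(train_feats, train_labels, query_feats, k=3):
    """Majority-vote k-NN on L2-normalized embeddings (cosine similarity)."""
    train = np.asarray(train_feats, dtype=np.float64)
    query = np.asarray(query_feats, dtype=np.float64)
    # L2-normalize so that a dot product equals cosine similarity
    train = train / np.linalg.norm(train, axis=1, keepdims=True)
    query = query / np.linalg.norm(query, axis=1, keepdims=True)
    sims = query @ train.T  # shape: [n_query, n_train]
    # Indices of the k most similar labeled embeddings, most similar first
    top_k = np.argsort(-sims, axis=1)[:, :k]
    preds = []
    for neighbors in top_k:
        votes = Counter(train_labels[i] for i in neighbors)
        # On a tie, Counter keeps first-seen order, so the closest neighbor's label wins
        preds.append(votes.most_common(1)[0][0])
    return preds
```

In practice, `train_feats` and `query_feats` would be DinoBloom embeddings of labeled and unlabeled cells, with shape [n, feature_dim].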
Installation
Create and activate the conda environment with all dependencies:

```bash
conda env create -f environment.yaml
conda activate virtual-human-chc-dinobloom
```
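After activating the environment, a quick import check can confirm that the packages used in the feature extraction example are available. The package list is an assumption taken from the imports of that example, not from the contents of `environment.yaml`, and `missing_packages` is a hypothetical helper.

```python
import importlib.util

# Assumed list: top-level modules imported by the feature extraction example
REQUIRED = ["torch", "torchvision", "huggingface_hub", "PIL"]


def missing_packages(names=REQUIRED):
    """Return the names of top-level modules that cannot be found in the active environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    print("All packages found" if not missing else f"Missing: {', '.join(missing)}")
```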
Metadata
Input
- Description: Batch of single-cell images
- Input format: tensor
- Shape: [batch_size, C, 224, 224], where batch_size is the number of images and C is the number of channels (3, RGB)
- Data format: float
- Example: see input/001.bmp
- Preprocessing:
  - Resize the image to 224x224
  - Convert it to a float tensor with values in [0, 1]
  - Normalize each channel with the ImageNet mean and standard deviation (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
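The conversion and normalization steps can be made concrete with NumPy. Resizing is assumed to have happened already (for example with Pillow), so this sketch only reproduces what `transforms.ToTensor()` and `transforms.Normalize(...)` do in the feature extraction example: scale 8-bit values to [0, 1], standardize each channel, and move channels first. `to_model_input` is a hypothetical name.

```python
import numpy as np

# ImageNet statistics, as used in the feature extraction example
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def to_model_input(rgb_uint8):
    """Convert an already resized (224, 224, 3) uint8 RGB array to a [1, 3, 224, 224] float32 array."""
    img = np.asarray(rgb_uint8)
    if img.shape != (224, 224, 3) or img.dtype != np.uint8:
        raise ValueError(f"Expected a (224, 224, 3) uint8 array, got {img.shape} {img.dtype}")
    x = img.astype(np.float32) / 255.0  # scale to [0, 1], like ToTensor
    x = (x - MEAN) / STD                # per-channel normalization, like Normalize
    x = np.transpose(x, (2, 0, 1))      # HWC -> CHW
    return x[np.newaxis, ...]           # add the batch dimension
```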
Model
- Modality: Hematology single cell images
- Scale: Per image
- Description: The model uses a Vision Transformer architecture based on fine-tuned DINOv2 models to map a single-cell image to a continuous embedding.
- Training data: The model is trained on 13 diverse datasets comprising cell images from peripheral blood and bone marrow smears. See section References.
- Publication: https://papers.miccai.org/miccai-2024/230-Paper3584.html
Output
- Description: Each image is represented by a multidimensional vector whose size depends on the model version.
- Output format: tensor
- Shape: [n, feature_dim], where n is the number of images and feature_dim is the feature dimension of the chosen model version (see Model Versions)
- Data format: float
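For patient-level tasks such as the acute myeloid leukemia subtyping mentioned in the long description, the [n, feature_dim] output for one patient can be treated as a bag of instances and pooled into a single vector. The sketch below shows attention-based multiple instance learning pooling with random, untrained weights and random stand-in features; it only illustrates the shapes involved and is not the aggregation model used in the DinoBloom paper. `attention_pool`, `V` and `w` are hypothetical names.

```python
import numpy as np


def attention_pool(features, V, w):
    """Pool a bag of instance embeddings [n, d] into one [d] vector.

    Scores follow a_i = w^T tanh(V h_i); weights are the softmax of the scores.
    """
    h = np.asarray(features, dtype=np.float64)
    scores = np.tanh(h @ V.T) @ w    # shape: [n]
    scores = scores - scores.max()   # subtract the max for numerical stability
    weights = np.exp(scores) / np.exp(scores).sum()
    return weights @ h, weights      # pooled [d], weights [n]


rng = np.random.default_rng(0)
n_cells, feature_dim, hidden = 500, 768, 128            # e.g. 500 cells embedded with DinoBloom-B
bag = rng.normal(size=(n_cells, feature_dim))           # stand-in for real DinoBloom features
V = rng.normal(scale=0.01, size=(hidden, feature_dim))  # untrained, illustrative weights
w = rng.normal(scale=0.01, size=hidden)
pooled, weights = attention_pool(bag, V, w)
```

In a real MIL model, `V` and `w` would be learned jointly with a classifier on `pooled` from patient-level labels.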
Example
Feature extraction example
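```python
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Choose the model size: "s", "b", "l" or "g"
variant = "b"
# Maps each variant to its DINOv2 backbone and its feature dimension
variant_config = {
    "s": ("dinov2_vits14", 384),
    "b": ("dinov2_vitb14", 768),
    "l": ("dinov2_vitl14", 1024),
    "g": ("dinov2_vitg14", 1536),
}
dinov2_model, embed_dim = variant_config[variant]

# Build the DINOv2 architecture from torch.hub
model = torch.hub.load("facebookresearch/dinov2", dinov2_model)

# Download the DinoBloom weights for the chosen variant from the Hugging Face Hub
ckpt_path = hf_hub_download(
    repo_id="virtual-human-chc/DinoBloom",
    filename=f"pytorch_model_{variant}.bin",
)
ckpt = torch.load(ckpt_path, map_location="cpu")

# 224x224 inputs with 14x14 patches give (224 // 14) ** 2 = 256 patches plus 1 class token;
# pos_embed is replaced by a 257-token parameter that load_state_dict fills from the checkpoint
num_tokens = 1 + (224 // 14) ** 2
model.pos_embed = nn.Parameter(torch.zeros(1, num_tokens, embed_dim))
model.load_state_dict(ckpt, strict=True)
model.to(device)
model.eval()

# Preprocessing: resize to 224x224, convert to a float tensor in [0, 1], normalize with ImageNet statistics
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load the example image; convert("RGB") guarantees 3 channels (e.g. for grayscale or RGBA files)
img = Image.open("input/001.bmp").convert("RGB")
img_tensor = transform(img).unsqueeze(0).to(device)  # shape: [1, 3, 224, 224]

with torch.no_grad():
    features = model(img_tensor)  # shape: [1, embed_dim]

print(f"Features shape: {features.shape}")
```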
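The `pos_embed` replacement in the example follows from the patch grid: a ViT with 14x14 patches splits a 224x224 image into 16 x 16 = 256 patches, and DINOv2 adds one class token, giving 257 positional embeddings. The torch.hub DINOv2 models default to 518x518 inputs (37 x 37 + 1 = 1370 tokens; worth verifying against the installed version), which is why a 257-token parameter is swapped in before loading the DinoBloom weights. The helper below, `num_vit_tokens`, is a hypothetical name that computes this count.

```python
def num_vit_tokens(image_size: int, patch_size: int = 14, cls_tokens: int = 1) -> int:
    """Number of positional embeddings for a square image: patch grid plus class token(s)."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be a multiple of patch_size")
    grid = image_size // patch_size
    return cls_tokens + grid * grid
```

For DinoBloom inputs, `num_vit_tokens(224)` gives 257, matching `num_tokens` in the example.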
References
DinoBloom builds upon:
Datasets:
Copyright
Code derived from https://github.com/marrlab/DinoBloom is licensed under the Apache License 2.0 (see the LICENSE file for details). All other code is licensed under the MIT license, Copyright (c) 2025 Maksim Pavlov.