iclr2025-anonymous committed · Commit 546bb84 · verified · Parent: 9a5ac74

Update README.md
---
tags:
- image-feature-extraction
- cell representation
- histology
- medical imaging
- self-supervised learning
- vision transformer
- foundation model
license: mit
---
12

# Model card for LEMON

`LEMON` is an open-source foundation model for single-cell histology images. The model is a Vision Transformer (ViT-S/8) trained with self-supervised learning on a dataset of 10 million histology cell images sampled from 10,000 TCGA slides.
It is described in detail in its [OpenReview paper](https://openreview.net/pdf?id=JAalsmy7bZ).

`LEMON` can be used to extract robust features from single-cell histology images for downstream applications such as gene expression prediction or cell type classification.

## How to use it to extract features

The code below can be used to run inference. `LEMON` expects images of size 40x40 extracted at 0.25 microns per pixel (40X magnification).

```python
import torch
from pathlib import Path
from torchvision.transforms import ToPILImage
from model import prepare_transform, get_vit_feature_extractor

device = "cpu"
model_name = "vits8"
target_cell_size = 40
weight_path = Path("lemon.pth.tar")
stats_path = Path("mean_std.json")

# Model
transform = prepare_transform(stats_path, size=target_cell_size)
model = get_vit_feature_extractor(weight_path, model_name, img_size=target_cell_size)
model.eval()
model.to(device)

# Data: a random image stands in for a real 40x40 cell crop.
image = torch.rand(3, target_cell_size, target_cell_size)
image = ToPILImage()(image)

# Inference. Note: CPU autocast supports bfloat16; float16 autocast on CPU
# may raise an error depending on the PyTorch version.
with torch.autocast(device_type=device, dtype=torch.bfloat16):
    with torch.inference_mode():
        features = model(transform(image).unsqueeze(0).to(device))

assert features.shape == (1, 384)
```
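
In practice the 40x40 inputs come from larger slide regions. The sketch below shows one way to crop cell-centered patches given detected cell coordinates; `crop_cells` is a hypothetical helper, not part of this repository, and the coordinates would come from an external cell detector.

```python
import torch

def crop_cells(region: torch.Tensor, centers, size: int = 40) -> torch.Tensor:
    """Crop fixed-size patches centered on cell coordinates.

    region: (3, H, W) tensor read at 0.25 microns per pixel.
    centers: iterable of (y, x) pixel coordinates of cell centroids.
    """
    half = size // 2
    patches = []
    for y, x in centers:
        # Skip cells whose patch would fall outside the region.
        if y - half < 0 or x - half < 0:
            continue
        if y + half > region.shape[1] or x + half > region.shape[2]:
            continue
        patches.append(region[:, y - half:y + half, x - half:x + half])
    return torch.stack(patches)

region = torch.rand(3, 200, 200)
# The last centroid is too close to the border and is skipped.
batch = crop_cells(region, [(50, 60), (120, 80), (5, 5)])
```

Each row of `batch` can then be converted to a PIL image, normalized with `prepare_transform`, and passed through the model as in the inference snippet above.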