---
tags:
- cytology foundation model
- vision transformer
- cervical screening
license: apache-2.0
---
## How to use UniCAS to extract features

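The code below shows how to extract features with `UniCAS`. The model expects 224×224-pixel RGB image patches extracted at 20× magnification. Normalize them with the mean and standard deviation used in the `transform` below. The model returns one 1024-dimensional embedding per image. The example loads the checkpoint from `UniCAS.pth` in the current working directory.

```python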
import functools

import timm
import torch
from torchvision import transforms

# Architecture hyper-parameters of the UniCAS ViT backbone. They must match the
# checkpoint: with strict=False below, missing/unexpected keys are only reported
# (shape mismatches on shared keys still raise).
params = {
    'patch_size': 16,
    'embed_dim': 1024,
    'depth': 24,
    'num_heads': 16,
    'init_values': 1e-05,  # LayerScale initial value
    # GluMlp splits fc1's output into a value half and a gate half, so the ratio
    # is doubled to give a gated hidden width of 2.671875 * embed_dim.
    'mlp_ratio': 2.671875 * 2,
    'mlp_layer': functools.partial(timm.layers.mlp.GluMlp, gate_last=False),
    'act_layer': torch.nn.SiLU,  # SiLU-gated (SwiGLU-style) MLP
    'no_embed_class': False,
    'img_size': 224,
    'num_classes': 0,  # no classification head: the model returns embeddings
    'in_chans': 3,
}

# Use the GPU when present; fall back to CPU so the example runs anywhere.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = timm.models.VisionTransformer(**params)
# map_location="cpu" lets a checkpoint saved from GPU load on CPU-only machines;
# weights_only=True refuses to unpickle arbitrary Python objects (torch >= 1.13).
state_dict = torch.load("UniCAS.pth", map_location="cpu", weights_only=True)
# strict=False tolerates key mismatches, so inspect the printed missing_keys /
# unexpected_keys to confirm the backbone weights were actually loaded.
print(model.load_state_dict(state_dict, strict=False))
model = model.eval().to(device)

# Preprocessing: PIL image -> float tensor in [0, 1] (ToTensor), then
# per-channel normalization with the statistics below.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])

# Random 224x224 RGB image standing in for a real 20x-magnification patch.
image = transforms.ToPILImage()(torch.rand(3, 224, 224))
batch = transform(image).unsqueeze(0)  # add batch dimension -> (1, 3, 224, 224)
with torch.no_grad():
    features = model(batch.to(device))
print(features.shape)  # torch.Size([1, 1024])
```