---
license: cc-by-nc-nd-4.0
language:
  - en
pipeline_tag: image-feature-extraction
library_name: timm
metrics:
  - accuracy
---

Using Patho3dMatrix to extract features from pathology images

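```python
import torch
import timm
from PIL import Image
from torchvision import transforms
from safetensors.torch import load_file

# ImageNet mean/std normalization statistics
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

if __name__ == '__main__':
    # Initialize the Patho3dMatrix foundation model (ViT-L/14 with the DINOv2 architecture).
    # pretrained=False: the weights come from the local safetensors file below.
    # num_classes=0: drop the classification head so the model returns pooled features.
    patho3dmatrix = timm.create_model(
        "vit_large_patch14_dinov2.lvd142m",
        pretrained=False,
        dynamic_img_size=True,
        num_classes=0,
    )

    # Load the safetensors weights; strict=True raises if any key is missing or unexpected
    patho3dmatrix_weights_path = 'pytorch_model.safetensors'
    state_dict = load_file(patho3dmatrix_weights_path, device='cpu')
    msg = patho3dmatrix.load_state_dict(state_dict, strict=True)
    print(msg)
    print('weights loaded successfully')

    # Set device: use a GPU if one is available, otherwise fall back to the CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    patho3dmatrix = patho3dmatrix.to(device)
    patho3dmatrix.eval()

    # Image preprocessing: resize to 224x224 (a multiple of the 14-pixel patch size),
    # convert to a tensor, and normalize
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=MEAN, std=STD),
    ])

    # Encode one image
    img_path = 'test.png'
    img = Image.open(img_path).convert('RGB')
    img_tensor = transform(img).unsqueeze(0).to(device)  # add a batch dimension: (1, 3, 224, 224)

    with torch.no_grad():
        feat = patho3dmatrix(img_tensor)

    print('feature shape:', feat.shape)  # (1, 1024) for the ViT-L backbone
```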
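To encode many tiles (for example, all patches cut from one slide), the single-image example can be extended to mini-batches. The sketch below assumes `patho3dmatrix`, `transform`, and `device` are set up as in the example above; `extract_features` and its `batch_size` parameter are hypothetical names for illustration, not part of the Patho3dMatrix release.

```python
from typing import Callable, List

import torch
from PIL import Image


def extract_features(
    model: torch.nn.Module,
    images: List[Image.Image],
    transform: Callable[[Image.Image], torch.Tensor],
    device: torch.device,
    batch_size: int = 32,
) -> torch.Tensor:
    """Encode PIL images in mini-batches and return an (N, D) tensor on the CPU.

    Hypothetical helper: `model` is assumed to map a (B, 3, H, W) batch to
    (B, D) features, as the timm model created with num_classes=0 does.
    """
    if not images:
        raise ValueError('images must be a non-empty list')
    model.eval()
    chunks = []
    with torch.no_grad():
        for start in range(0, len(images), batch_size):
            batch = images[start:start + batch_size]
            # Same per-image preprocessing as the single-image example, stacked into one batch
            x = torch.stack([transform(img.convert('RGB')) for img in batch]).to(device)
            # Move features back to the CPU so GPU memory does not grow with N
            chunks.append(model(x).cpu())
    return torch.cat(chunks, dim=0)
```

For example, `feats = extract_features(patho3dmatrix, tiles, transform, device, batch_size=64)` returns one feature row per tile, in input order. A larger `batch_size` usually speeds up inference until GPU memory runs out; the right value depends on your hardware.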

Evaluation Pipeline