---
library_name: transformers
tags:
  - image-segmentation
  - pathology
  - dpt
pipeline_tag: image-segmentation
---

Lung Structure Segmentation (DPT)

Semantic segmentation of lung structures (blood vessels and airways) in histopathology images.

  • Encoder (frozen): H-optimus-0 ViT backbone (pretrained on histopathology data).
  • Decoder (trained): custom DPT head with multi-scale feature fusion.
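
To make the encoder/decoder split concrete, here is a minimal NumPy sketch of the general DPT idea: patch tokens from several encoder layers are reshaped into feature maps at different resolutions, projected to a common width, and fused from coarse to fine before a per-pixel classifier. It is not this model's actual head. The 14 px patch size (16 × 16 token grid), the channel widths, the three classes, the nearest-neighbour resizing, and the random weights standing in for learned ones are all illustrative assumptions.

```python
import numpy as np

# Toy sizes (assumptions for illustration only): a 224 px tile split into
# 14 px patches gives a 16 x 16 token grid; real channel widths are larger.
GRID, DIM, FUSE_CH, N_CLASSES, OUT_PX = 16, 32, 8, 3, 224
rng = np.random.default_rng(0)


def tokens_to_map(tokens, grid):
    """Reshape (grid*grid, D) row-major patch tokens into a (D, grid, grid) map."""
    return tokens.T.reshape(-1, grid, grid)


def resize_nearest(fmap, size):
    """Nearest-neighbour resize of a (C, H, W) map to (C, size, size)."""
    _, h, w = fmap.shape
    rows = np.arange(size) * h // size
    cols = np.arange(size) * w // size
    return fmap[:, rows][:, :, cols]


def project(fmap, weight):
    """1x1 'convolution': mix channels with a (C_out, C_in) matrix."""
    return np.einsum("oc,chw->ohw", weight, fmap)


def dpt_like_head(layer_tokens):
    """Fuse tokens from four encoder layers into dense class logits."""
    # Reassemble: shallow layers go to fine maps, deep layers to coarse maps.
    sizes = [GRID * 4, GRID * 2, GRID, GRID // 2]
    maps = []
    for tokens, size in zip(layer_tokens, sizes):
        fmap = resize_nearest(tokens_to_map(tokens, GRID), size)
        w = rng.standard_normal((FUSE_CH, fmap.shape[0])) * 0.1  # stand-in for learned weights
        maps.append(project(fmap, w))
    # Fusion: start at the coarsest map, upsample it, add the next finer map.
    fused = maps[-1]
    for finer in reversed(maps[:-1]):
        fused = resize_nearest(fused, finer.shape[1]) + finer
    # Prediction: per-pixel class logits, resized to the input resolution.
    w_cls = rng.standard_normal((N_CLASSES, FUSE_CH)) * 0.1
    return resize_nearest(project(fused, w_cls), OUT_PX)


# Frozen-encoder stand-in: fixed token arrays from four layers.
layer_tokens = [rng.standard_normal((GRID * GRID, DIM)) for _ in range(4)]
logits = dpt_like_head(layer_tokens)
print(logits.shape)  # (3, 224, 224)
```

In the real model only the head's weights are learned; the encoder tokens would come from H-optimus-0 with gradients disabled, which the fixed `layer_tokens` arrays stand in for here.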

Usage

The model expects pixel_values as a normalized float tensor of shape (B, 3, H, W). Normalize with the ImageNet mean/std, the same statistics applied at training time.
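
If tiles are preprocessed outside torchvision (for example, by a slide-reading pipeline that yields NumPy arrays), the sketch below reproduces what ToTensor followed by Normalize does, so the two paths can be checked against each other. The function name `to_pixel_values` is hypothetical, and resizing is left out: the array is assumed to already be 224 × 224.

```python
import numpy as np

# ImageNet statistics, as listed in the usage example.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def to_pixel_values(rgb_uint8):
    """Convert an (H, W, 3) uint8 RGB array into a (1, 3, H, W) normalized float32 array.

    Mirrors ToTensor() (scale to [0, 1], HWC -> CHW) followed by Normalize().
    """
    x = rgb_uint8.astype(np.float32) / 255.0   # ToTensor: [0, 255] -> [0, 1]
    x = (x - IMAGENET_MEAN) / IMAGENET_STD     # per-channel normalization
    return x.transpose(2, 0, 1)[None]          # HWC -> CHW, add batch dim


tile = np.full((224, 224, 3), 255, dtype=np.uint8)  # a pure-white tile
pv = to_pixel_values(tile)
print(pv.shape)  # (1, 3, 224, 224)
```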

Input tile: 224 × 224 pixels at 1.5 MPP (microns per pixel).
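
At 1.5 MPP, a 224 × 224 tile spans 224 × 1.5 = 336 µm per side. Slides scanned at a finer resolution therefore need a larger region read and then downsampled to 224 px. The hypothetical helper below computes that region size; the 0.5 and 0.25 MPP values are example scanner resolutions, not requirements of this model.

```python
def source_tile_px(source_mpp, target_px=224, target_mpp=1.5):
    """Pixels to read at `source_mpp` so the region covers the same tissue
    as a `target_px` tile at `target_mpp` (microns per pixel)."""
    field_um = target_px * target_mpp  # physical edge length in microns
    return round(field_um / source_mpp)


print(224 * 1.5)            # 336.0 (microns per tile side)
print(source_tile_px(0.5))  # 672: read 672 px at 0.5 MPP, then resize to 224
print(source_tile_px(0.25)) # 1344
```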

import torch
from PIL import Image
from torchvision.transforms import Compose, Normalize, Resize, ToTensor
from transformers import AutoModel

# trust_remote_code=True is needed because the model class ships with the repository.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModel.from_pretrained(
    "RendeiroLab/MetPredict-lung-structure-segmentation",
    trust_remote_code=True,
).to(device).eval()

# Preprocessing: scale to [0, 1], resize to 224 x 224, normalize with ImageNet stats.
transform = Compose([
    ToTensor(),
    Resize((224, 224)),
    Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

img = Image.open("tile.png").convert("RGB")            # tile at 1.5 MPP
pixel_values = transform(img).unsqueeze(0).to(device)  # (1, 3, 224, 224)

with torch.inference_mode():
    out = model(pixel_values)
logits = out.logits                                    # (1, n_classes, H, W)
pred = logits.argmax(dim=1)                            # (1, H, W) class index per pixel
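
The predicted class map can be summarized per class. The sketch below, assuming the logits are moved to NumPy with `logits.cpu().numpy()` and share the 224 × 224 grid of the input, counts pixels per class and converts counts to tissue area at 1.5 MPP (each pixel covers 1.5 × 1.5 = 2.25 µm²). This card does not state which index is background, vessel, or airway, so the toy example uses unnamed indices, and `class_areas` is a hypothetical helper.

```python
import numpy as np

MPP = 1.5  # microns per pixel of the model input (assumed to hold for the output too)


def class_areas(logits, mpp=MPP):
    """Per-class pixel counts and areas (µm²) from (1, n_classes, H, W) logits."""
    pred = logits.argmax(axis=1)[0]  # (H, W) class index per pixel
    n_classes = logits.shape[1]
    counts = np.bincount(pred.ravel(), minlength=n_classes)
    return counts, counts * mpp ** 2  # each pixel covers mpp x mpp µm


# Toy logits standing in for real model output; class meanings are unknown here.
toy = np.zeros((1, 3, 4, 4), dtype=np.float32)
toy[0, 1, :2, :] = 5.0  # top half -> class 1
toy[0, 2, 3, 3] = 5.0   # one pixel -> class 2
counts, areas = class_areas(toy)
print(counts)  # [7 8 1]
print(areas)
```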