---
license: mit
language: en
library_name: transformers
tags:
- unet
- film
- computer-vision
- image-segmentation
- medical-imaging
- pytorch
---

# FILMUnet2D

This model is a 2D U-Net with FiLM conditioning for ultrasound multi-organ segmentation.

## Installation

Make sure you have `transformers` and `torch` installed.

```bash
pip install transformers torch
```

## Usage

You can load the model and processor using the `Auto` classes from `transformers`. Since this repository contains custom code, make sure to pass `trust_remote_code=True`.

```python
import torch
from transformers import AutoModel, AutoImageProcessor
from PIL import Image

# 1. Load model and processor
repo_id = "AImageLab-Zip/US_FiLMUNet"

processor = AutoImageProcessor.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
model.eval()

# 2. Load and preprocess your image
# The processor handles resizing, letterboxing, and normalization.
image = Image.open("path/to/your/image.png").convert("RGB")
inputs = processor(images=image, return_tensors="pt")

# 3. Prepare conditioning input
# This should be an integer tensor representing the organ ID.
# Replace `4` with the appropriate ID for your use case.
organ_id = torch.tensor([4])

# 4. Run inference
with torch.no_grad():
    outputs = model(**inputs, organ_id=organ_id)

# 5. Post-process the output to get the final segmentation mask
# The processor can convert the logits to a binary mask, automatically handling
# the removal of letterbox padding and resizing to the original image dimensions.
mask = processor.post_process_semantic_segmentation(
    outputs,
    inputs,
    threshold=0.7,
    return_as_pil=True
)[0]

# 6. Save the result
mask.save("output_mask.png")

print("Segmentation mask saved to output_mask.png")
```
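
If you prefer to post-process the raw logits yourself rather than using the processor, the core step is a sigmoid followed by thresholding. A minimal sketch with a dummy tensor standing in for the model output (the real output shape and any letterbox cropping depend on this repository's code, so treat this as illustrative only):

```python
import torch

# Dummy single-channel logits standing in for the model output
# (shape: batch x 1 x H x W); replace with the real logits tensor.
logits = torch.randn(1, 1, 256, 256)

# Sigmoid turns logits into per-pixel probabilities; thresholding at 0.7
# matches the threshold used with the processor above.
probs = torch.sigmoid(logits)
mask = (probs > 0.7).squeeze(1).to(torch.uint8)  # batch x H x W, values in {0, 1}
```

Note that this skips the letterbox-padding removal and resizing that the processor performs, so the mask stays in model-input coordinates.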

### Model Details

- **Architecture:** U-Net with FiLM layers for conditional segmentation.
- **Conditioning:** The model's output is conditioned on an `organ_id` input.
- **Input:** RGB images.
- **Output:** A single-channel segmentation mask.
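
For intuition, FiLM (Feature-wise Linear Modulation) conditions a convolutional feature map by predicting a per-channel scale and shift from a conditioning vector. The sketch below shows the generic mechanism paired with an organ-ID embedding; it is an illustration of the technique, not this repository's exact implementation:

```python
import torch
import torch.nn as nn

class FiLM(nn.Module):
    """Generic FiLM layer: scales and shifts each channel of a feature map
    using parameters predicted from a conditioning vector."""

    def __init__(self, cond_dim: int, num_channels: int):
        super().__init__()
        self.proj = nn.Linear(cond_dim, 2 * num_channels)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # gamma and beta each have shape (batch, num_channels)
        gamma, beta = self.proj(cond).chunk(2, dim=-1)
        return gamma[:, :, None, None] * x + beta[:, :, None, None]

# Example: condition a 64-channel feature map on an 8-way organ-ID embedding.
embed = nn.Embedding(8, 16)              # 8 organs, 16-dim conditioning vector
film = FiLM(cond_dim=16, num_channels=64)

features = torch.randn(2, 64, 32, 32)    # batch of 2 feature maps
organ_id = torch.tensor([4, 4])
out = film(features, embed(organ_id))    # same shape as `features`
```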

### Configuration

The model configuration can be accessed via `model.config`. Key parameters include:

- `in_channels`: Number of input channels (default: 3).
- `num_classes`: Number of output classes (default: 1).
- `n_organs`: The number of different organs the model was trained to condition on.
- `depth`: The depth of the U-Net.
- `size`: The base number of filters in the first layer.

### Organ IDs

The `organ_id` passed to the model corresponds to the following mapping:

```python
organ_to_class_dict = {
    "appendix": 0,
    "breast": 1,
    "breast_luminal": 1,
    "cardiac": 2,
    "thyroid": 3,
    "fetal": 4,
    "kidney": 5,
    "liver": 6,
    "testicle": 7,
}
```
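
To avoid hard-coding numeric IDs, you can derive the conditioning tensor from an organ name using the mapping above. A small sketch (the helper name is ours, not part of the repository):

```python
import torch

# Mapping reproduced from the model card above.
organ_to_class_dict = {
    "appendix": 0,
    "breast": 1,
    "breast_luminal": 1,
    "cardiac": 2,
    "thyroid": 3,
    "fetal": 4,
    "kidney": 5,
    "liver": 6,
    "testicle": 7,
}

def organ_id_tensor(name: str) -> torch.Tensor:
    """Return the 1-element integer tensor the model expects as `organ_id`."""
    return torch.tensor([organ_to_class_dict[name]])

organ_id = organ_id_tensor("fetal")  # tensor([4])
```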

### Alternative Versions

This repository contains multiple versions of the model located in subfolders. You can load a specific version by using the `subfolder` parameter.

#### 4-Stage U-Net

This version has a U-Net depth of 4.

```python
from transformers import AutoModel

model_4_stages = AutoModel.from_pretrained(
    "AImageLab-Zip/US_FiLMUNet",
    subfolder="unet_4_stages",
    trust_remote_code=True
)
```