---
library_name: keras-hub
---
### Model Overview
# Swin Transformer
Instantiates the Swin Transformer architecture.
## Model Details
The Swin Transformer (Shifted Window Transformer) is a hierarchical vision transformer whose representation is computed with shifted windows. The shifted windowing scheme brings greater efficiency by limiting self-attention computation to non-overlapping local windows while also allowing for cross-window connection.
This hierarchical architecture has the flexibility to model at various scales and has linear computational complexity with respect to image size. These qualities make Swin Transformer compatible with a broad range of vision tasks, including image classification, object detection, and semantic segmentation.
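The windowing scheme described above can be sketched with plain NumPy. This is only an illustration of the idea, not the KerasHub implementation: the helper `window_partition` and the toy 8x8 feature map are hypothetical, and the attention mask that the paper applies to windows wrapped around by the cyclic shift is omitted.

```python
import numpy as np


def window_partition(x, window_size):
    """Split a (H, W, C) feature map into non-overlapping square windows.

    Returns an array of shape (num_windows, window_size * window_size, C).
    Hypothetical helper for illustration only.
    """
    h, w, c = x.shape
    x = x.reshape(h // window_size, window_size, w // window_size, window_size, c)
    # Put the two window-grid axes first, then flatten the tokens of each window.
    x = x.transpose(0, 2, 1, 3, 4)
    return x.reshape(-1, window_size * window_size, c)


# Toy 8x8 feature map with one channel; each value is the token's flat index.
feature_map = np.arange(8 * 8).reshape(8, 8, 1)
window_size = 4

# Regular partition: self-attention would run independently inside each window.
regular = window_partition(feature_map, window_size)

# Shifted partition: cyclically roll the map by half a window, then partition.
shift = window_size // 2
rolled = np.roll(feature_map, shift=(-shift, -shift), axis=(0, 1))
shifted = window_partition(rolled, window_size)

print(regular.shape)     # (4, 16, 1): four windows of 16 tokens each
print(shifted[0, :, 0])  # tokens drawn from all four regular windows
```

The first shifted window mixes tokens that sat in four different regular windows, which is how alternating the two partitions lets information flow across window boundaries.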
### Reference
* [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030)
Unlike traditional Vision Transformers (ViT), which compute attention globally across all patches (resulting in quadratic complexity relative to image size), Swin Transformer computes self-attention within local non-overlapping windows. By shifting the window partition between consecutive layers, the model achieves cross-window connections, maintaining linear computational complexity while enabling robust global context modeling.
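To make the quadratic-versus-linear contrast concrete, the sketch below evaluates the two complexity formulas given in the referenced paper (softmax cost omitted): global attention costs `4hwC^2 + 2(hw)^2 C` and windowed attention costs `4hwC^2 + 2M^2 hwC`, for an `h x w` token grid, channel width `C`, and window size `M`. The function names and the chosen grid sizes are illustrative assumptions.

```python
def global_msa_flops(h, w, c):
    """Cost of global multi-head self-attention over an h x w token grid."""
    return 4 * h * w * c**2 + 2 * (h * w) ** 2 * c


def window_msa_flops(h, w, c, m):
    """Cost of self-attention restricted to non-overlapping m x m windows."""
    return 4 * h * w * c**2 + 2 * m**2 * h * w * c


# A 56x56 grid with C=96 matches a 224x224 image split into 4x4 patches.
for side in (56, 112, 224):
    g = global_msa_flops(side, side, 96)
    l = window_msa_flops(side, side, 96, 7)
    print(f"{side}x{side} tokens: global={g:.3e}, windowed={l:.3e}, ratio={g / l:.1f}")
```

Doubling the grid side multiplies the windowed cost by exactly four (linear in the number of tokens), while the global cost grows faster because of the `(hw)^2` term.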
### Links
* [Swin Transformer Quickstart Notebook](https://www.kaggle.com/code/prasadsachin/swin-transformer-quickstart-keras-hub)
* [Swin Transformer API Documentation](https://keras.io/keras_hub/api/models/swin_transformer/)
* [KerasHub Beginner Guide](https://keras.io/guides/keras_hub/getting_started/)
* [KerasHub Model Publishing Guide](https://keras.io/guides/keras_hub/upload/)
## Installation
Keras and KerasHub can be installed with:
```bash
pip install -U -q keras-hub
pip install -U -q keras
```
JAX, TensorFlow, and PyTorch come preinstalled in Kaggle Notebooks. For instructions on installing them in another environment, see the [Keras Getting Started](https://keras.io/getting_started/) page.
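Keras 3 picks its backend from the `KERAS_BACKEND` environment variable, which has to be set before Keras is imported for the first time. A minimal sketch, assuming JAX is the backend you want:

```python
import os

# Set the backend before the first `import keras` / `import keras_hub`.
# Valid values are "jax", "tensorflow", and "torch"; "jax" here is just an example.
os.environ.setdefault("KERAS_BACKEND", "jax")

print(os.environ["KERAS_BACKEND"])
```

Place this at the top of the script or notebook, ahead of any Keras import; changing the variable afterwards has no effect on an already imported Keras.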
## Presets
The following model checkpoints are provided by the Keras team. Weights have been ported from [Hugging Face Hub](https://huggingface.co/microsoft).
| Preset name | Parameters | Description |
| :--- | :--- | :--- |
| **swin_tiny_patch4_window7_224** | 28.29M | Tiny Swin Transformer model pre-trained on ImageNet-1k at a 224x224 resolution |
| **swin_small_patch4_window7_224** | 49.61M | Small Swin Transformer model pre-trained on ImageNet-1k at a 224x224 resolution |
| **swin_base_patch4_window7_224** | 87.77M | Base Swin Transformer model pre-trained on ImageNet-1k at a 224x224 resolution |
| **swin_base_patch4_window12_384** | 87.90M | Base Swin Transformer model pre-trained on ImageNet-1k at a 384x384 resolution |
| **swin_large_patch4_window7_224** | 196.53M | Large Swin Transformer model pre-trained on ImageNet-1k at a 224x224 resolution |
| **swin_large_patch4_window12_384** | 196.74M | Large Swin Transformer model pre-trained on ImageNet-1k at a 384x384 resolution |
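The preset names encode the patch size, window size, and input resolution. As a rough sketch of how these relate, the snippet below computes the feature-map shape at each of the four stages, assuming the standard configuration from the paper: a 4x4 patch embedding, then a patch-merging step before each later stage that halves the spatial side and doubles the channel width. The helper `stage_shapes` is hypothetical and is not part of KerasHub; the embedding widths (96 for tiny/small, 128 for base, 192 for large) are the paper's values and have not been checked against the ported presets.

```python
def stage_shapes(image_size, patch_size=4, embed_dim=96, num_stages=4):
    """Return the (height, width, channels) of the feature map at each stage.

    Hypothetical helper assuming the paper's standard Swin layout.
    """
    side = image_size // patch_size
    shapes = []
    for stage in range(num_stages):
        shapes.append((side, side, embed_dim * 2**stage))
        side //= 2  # patch merging halves the spatial side before the next stage
    return shapes


# Tiny model at 224x224 (window size 7).
print(stage_shapes(224, embed_dim=96))

# Base model at 384x384 (window size 12).
print(stage_shapes(384, embed_dim=128))
```

Under these assumptions the last stage is 7x7 for a 224x224 input and 12x12 for a 384x384 input, which matches the window sizes paired with each resolution in the preset names.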
## Example Use
```python
import numpy as np
import keras_hub

# Pretrained Swin Transformer backbone
model = keras_hub.models.SwinTransformerBackbone.from_preset(
    "swin_tiny_patch4_window7_224"
)
input_data = np.random.uniform(0, 1, size=(2, 224, 224, 3))
model(input_data)

# Randomly initialized Swin Transformer backbone with custom config
model = keras_hub.models.SwinTransformerBackbone(
    image_shape=(224, 224, 3),
    embed_dim=96,
    depths=(2, 2, 6, 2),
    num_heads=(3, 6, 12, 24),
    window_size=7,
)
model(input_data)

# Use Swin Transformer for image classification task
classifier = keras_hub.models.SwinTransformerImageClassifier.from_preset(
    "swin_tiny_patch4_window7_224",
    num_classes=1000,
)

# Use Hugging Face presets directly for on-the-fly conversion
classifier = keras_hub.models.SwinTransformerImageClassifier.from_preset(
    "hf://microsoft/swin-tiny-patch4-window7-224"
)
```
## Example Usage
```python
import numpy as np
import keras_hub

# Top-5 ImageNet class decoding.
model = keras_hub.models.SwinTransformerImageClassifier.from_preset(
    "swin_tiny_patch4_window7_224"
)
images = np.random.randint(0, 256, size=(1, 384, 384, 3), dtype="uint8")
logits = model.predict(images, verbose=0)
print(keras_hub.utils.decode_imagenet_predictions(logits, top=5)[0])
```
## Example Usage with Hugging Face URI
```python
import numpy as np
import keras_hub

# Top-5 ImageNet class decoding.
model = keras_hub.models.SwinTransformerImageClassifier.from_preset(
    "hf://keras/swin_tiny_patch4_window7_224"
)
images = np.random.randint(0, 256, size=(1, 384, 384, 3), dtype="uint8")
logits = model.predict(images, verbose=0)
print(keras_hub.utils.decode_imagenet_predictions(logits, top=5)[0])
```