---
tags:
- image-classification
- timm
library_name: timm
license: apache-2.0
---
# Model card for kat_small_patch16_224.vitft

KAT model trained on ImageNet-1k (1.28 million images, 1,000 classes) at resolution 224x224. It was first introduced in the paper *Kolmogorov–Arnold Transformer*.

## Model description
KAT replaces the channel mixer (the MLP) in each transformer block with a Group-Rational Kolmogorov–Arnold Network (GR-KAN).

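As a rough illustration of the idea (not the repository's implementation), the sketch below shares one learnable rational function `P(x)/Q(x)` across each channel group and follows it with a linear mix; the module name, parameterization, and safe-denominator form are assumptions made for exposition only.

```python
import torch
import torch.nn as nn


class GroupRationalMixer(nn.Module):
    """Illustrative GR-KAN-style channel mixer (NOT the official katransformer code)."""

    def __init__(self, dim: int, groups: int = 4, p_order: int = 5, q_order: int = 4):
        super().__init__()
        assert dim % groups == 0, "dim must be divisible by groups"
        self.groups = groups
        # One coefficient set per group (assumed parameterization).
        self.p = nn.Parameter(torch.randn(groups, p_order + 1) * 0.1)  # numerator P
        self.q = nn.Parameter(torch.randn(groups, q_order) * 0.1)      # denominator Q
        self.proj = nn.Linear(dim, dim)  # channel mixing in place of the MLP

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (B, N, dim)
        B, N, D = x.shape
        xg = x.view(B, N, self.groups, D // self.groups)

        # P(x) via Horner's rule, coefficients shared within each group.
        num = torch.zeros_like(xg)
        for k in range(self.p.shape[1]):
            num = num * xg + self.p[:, k].view(1, 1, -1, 1)

        # Q(x) = 1 + |b1*x + ... + bm*x^m| keeps the denominator nonzero
        # (the "safe" form used by Pade-style rational activations).
        den = torch.zeros_like(xg)
        for k in range(self.q.shape[1]):
            den = den * xg + self.q[:, k].view(1, 1, -1, 1)
        den = 1.0 + (den * xg).abs()

        return self.proj((num / den).view(B, N, D))


# Quick shape check: output matches input shape, like a standard MLP mixer.
print(GroupRationalMixer(dim=64)(torch.randn(2, 16, 64)).shape)  # (2, 16, 64)
```

See `katransformer.py` in the repository for the actual implementation.
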
## Usage
The model definition is in `katransformer.py` at https://github.com/Adamdad/kat; import it before creating the model so that the KAT architectures are registered with `timm`.

```python
from urllib.request import urlopen
from PIL import Image
import timm
import torch
import katransformer  # registers the KAT models with timm

img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))

# Run on GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = timm.create_model("hf_hub:adamdad/kat_small_patch16_224.vitft", pretrained=True)
model = model.to(device)
model = model.eval()

# Get model-specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(img).unsqueeze(0).to(device))  # unsqueeze single image into batch of 1

top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1) * 100, k=5)
print(top5_probabilities)
print(top5_class_indices)
```
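
The printed indices are raw ImageNet-1k class ids. Continuing the snippet above, one way to map them to human-readable names is to fetch a label list; the URL below (the PyTorch hub example class list) is an assumption, not an asset of this model.

```python
# Map the top-5 class indices to ImageNet-1k label names.
# NOTE: this label-file URL is an assumption, not part of this model card's assets.
labels = urlopen(
    'https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt'
).read().decode('utf-8').splitlines()

for prob, idx in zip(top5_probabilities[0], top5_class_indices[0]):
    print(f'{labels[idx.item()]}: {prob.item():.2f}%')
```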

## Bibtex
```bibtex
@misc{yang2024compositional,
      title={Kolmogorov–Arnold Transformer},
      author={Xingyi Yang and Xinchao Wang},
      year={2024},
      eprint={XXXX},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```