---
license: mit
base_model:
- google/vit-large-patch16-384
pipeline_tag: image-classification
library_name: timm
tags:
- biology
---

# rexologue/vit_large_384_for_trees

This repository hosts a `vit_large_patch16_384` classifier fine-tuned to classify images into the 92 plant taxa listed under Labels.

## Labels

- abies_sibirica
- acer_campestre
- acer_ginnala
- acer_negundo
- acer_platanoides
- acer_pseudoplatanus
- acer_tataricum
- aesculus_hippocastanum
- alnus_alnobetula_fruticosa
- alnus_glutinosa
- alnus_incana
- arctostaphylos_uva-ursi
- berberis_vulgaris
- betula_nana
- betula_pendula
- betula_pubescens
- calluna_vulgaris
- cornus_alba
- cornus_mas
- cornus_sanguinea
- cornus_suecica
- cotoneaster_lucidus
- cotoneaster_melanocarpus
- daphne_mezereum
- elaeagnus_angustifolia
- euonymus_europaeus
- euonymus_verrucosus
- fraxinus_excelsior
- fraxinus_pennsylvanica
- genista_tinctoria
- hippophae_rhamnoides
- hypericum_maculatum
- hypericum_perforatum
- juglans_mandshurica
- juniperus_communis
- larix_sibirica
- ligustrum_vulgare
- lonicera_caerulea
- lonicera_nigra
- lonicera_tatarica
- lonicera_xylosteum
- physocarpus_opulifolius
- picea_abies
- picea_obovata
- pinus_sibirica
- pinus_sylvestris
- populus
- populus_alba
- populus_nigra
- populus_tremula
- potentilla_argentea
- potentilla_erecta
- potentilla_intermedia
- potentilla_norvegica
- potentilla_paradoxa
- potentilla_reptans
- potentilla_supina
- quercus_robur
- ribes_nigrum
- ribes_rubrum
- ribes_uva-crispa
- rosa_acicularis
- rosa_majalis
- rosa_rugosa
- rubus_arcticus
- rubus_caesius
- rubus_chamaemorus
- rubus_idaeus
- rubus_nessensis
- rubus_saxatilis
- salix_alba
- salix_caprea
- salix_cinerea
- salix_gmelinii
- salix_myrsinifolia
- salix_pentandra
- salix_triandra
- salix_viminalis
- sorbaria_sorbifolia
- sorbus_aucuparia
- spiraea_salicifolia
- symphoricarpos_albus
- tilia_cordata
- ulmus_glabra
- ulmus_laevis
- ulmus_pumila
- vaccinium_myrtillus
- vaccinium_oxycoccos
- vaccinium_uliginosum
- vaccinium_vitis-idaea
- viburnum_lantana
- viburnum_opulus
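
The label ids above are lowercase Latin names with underscores between the parts of the name. If you want human-readable names in your output, something like the following may help. This is only a sketch: `pretty_label` and `genus_of` are hypothetical helpers that are not part of this repository, and they assume the first underscore-separated token is always the genus.

```python
def pretty_label(label: str) -> str:
    """Turn a label id such as 'acer_platanoides' into 'Acer platanoides'."""
    # Underscores separate the parts of the name; hyphens (e.g. 'uva-ursi') are kept.
    return label.replace("_", " ").capitalize()


def genus_of(label: str) -> str:
    """Return the first underscore-separated token (assumed to be the genus)."""
    return label.split("_", 1)[0]


for label in ["acer_platanoides", "arctostaphylos_uva-ursi", "populus"]:
    print(f"{label} -> {pretty_label(label)} (genus: {genus_of(label)})")
```

Note that `populus` is a genus-level class, so `pretty_label` and `genus_of` both reduce to the single token for it.
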

## Usage

```python
import json, torch, timm
from huggingface_hub import hf_hub_download
from timm.data.transforms_factory import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from PIL import Image

REPO = "rexologue/vit_large_384_for_trees"
MODEL_NAME = "vit_large_patch16_384"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# 1) labels: labels.json may be an {"index": "name"} dict or a plain list
labels_path = hf_hub_download(REPO, filename="labels.json")
with open(labels_path, "r", encoding="utf-8") as f:
    raw = json.load(f)
labels = [raw[str(i)] for i in range(len(raw))] if isinstance(raw, dict) else list(raw)

# 2) weights
ckpt_path = hf_hub_download(REPO, filename="pytorch_model.bin")
state = torch.load(ckpt_path, map_location="cpu")
if any(k.startswith("module.") for k in state):  # DDP fix: strip the leading "module."
    state = {(k[len("module."):] if k.startswith("module.") else k): v for k, v in state.items()}

# 3) model
model = timm.create_model(MODEL_NAME, num_classes=len(labels), pretrained=False)
model.load_state_dict(state, strict=True)
model.to(DEVICE).eval()

# 4) preprocessing (ViT-L/16 @ 384 w/ ImageNet mean/std + bicubic)
transform = create_transform(
    input_size=(3, 384, 384),
    interpolation="bicubic",
    mean=IMAGENET_DEFAULT_MEAN,
    std=IMAGENET_DEFAULT_STD,
)

# 5) run: print the top-5 (label, probability) pairs
img = Image.open("your_image.jpg").convert("RGB")
x = transform(img).unsqueeze(0).to(DEVICE)
with torch.no_grad():
    logits = model(x)
probs = torch.softmax(logits, dim=1)[0].cpu()
topk = probs.topk(k=min(5, len(labels)))
print([(labels[i], p) for i, p in zip(topk.indices.tolist(), topk.values.tolist())])
```
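
Steps 1 and 2 of the snippet above tolerate two quirks: `labels.json` may be either a list or a dict keyed by stringified class index, and checkpoint keys may carry a `module.` prefix from DistributedDataParallel. The sketch below isolates both normalizations so they can be checked without downloading anything. `normalize_labels` and `strip_prefix` are hypothetical helper names, and the inputs are toy data, not the real files from this repository.

```python
def normalize_labels(raw):
    """Return class names ordered by class index.

    Accepts a list of names or a dict keyed by stringified index,
    e.g. {"0": "abies_sibirica", "1": "acer_campestre"}.
    """
    if isinstance(raw, dict):
        return [raw[str(i)] for i in range(len(raw))]
    return list(raw)


def strip_prefix(state, prefix="module."):
    """Drop a leading prefix from every key that has it; other keys are kept."""
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state.items()
    }


# Toy inputs for illustration only.
print(normalize_labels({"1": "acer_campestre", "0": "abies_sibirica"}))
print(strip_prefix({"module.head.weight": 1, "module.head.bias": 2}))
```

Slicing off the prefix only when a key starts with it avoids touching keys that merely contain `module.` somewhere in the middle.
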