---
license: mit
base_model:
- google/vit-large-patch16-384
pipeline_tag: image-classification
library_name: timm
tags:
- biology
---
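# rexologue/vit_large_384_for_trees

This repository hosts a `vit_large_patch16_384` image classifier for timm, fine-tuned from `google/vit-large-patch16-384` on the 92 plant classes listed below.

## Labels

Class names are lowercase, underscore-separated scientific names with the genus first; `populus` is the only genus-level class. The index-to-name mapping used by the model is stored in `labels.json` in this repository.
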
- abies_sibirica
- acer_campestre
- acer_ginnala
- acer_negundo
- acer_platanoides
- acer_pseudoplatanus
- acer_tataricum
- aesculus_hippocastanum
- alnus_alnobetula_fruticosa
- alnus_glutinosa
- alnus_incana
- arctostaphylos_uva-ursi
- berberis_vulgaris
- betula_nana
- betula_pendula
- betula_pubescens
- calluna_vulgaris
- cornus_alba
- cornus_mas
- cornus_sanguinea
- cornus_suecica
- cotoneaster_lucidus
- cotoneaster_melanocarpus
- daphne_mezereum
- elaeagnus_angustifolia
- euonymus_europaeus
- euonymus_verrucosus
- fraxinus_excelsior
- fraxinus_pennsylvanica
- genista_tinctoria
- hippophae_rhamnoides
- hypericum_maculatum
- hypericum_perforatum
- juglans_mandshurica
- juniperus_communis
- larix_sibirica
- ligustrum_vulgare
- lonicera_caerulea
- lonicera_nigra
- lonicera_tatarica
- lonicera_xylosteum
- physocarpus_opulifolius
- picea_abies
- picea_obovata
- pinus_sibirica
- pinus_sylvestris
- populus
- populus_alba
- populus_nigra
- populus_tremula
- potentilla_argentea
- potentilla_erecta
- potentilla_intermedia
- potentilla_norvegica
- potentilla_paradoxa
- potentilla_reptans
- potentilla_supina
- quercus_robur
- ribes_nigrum
- ribes_rubrum
- ribes_uva-crispa
- rosa_acicularis
- rosa_majalis
- rosa_rugosa
- rubus_arcticus
- rubus_caesius
- rubus_chamaemorus
- rubus_idaeus
- rubus_nessensis
- rubus_saxatilis
- salix_alba
- salix_caprea
- salix_cinerea
- salix_gmelinii
- salix_myrsinifolia
- salix_pentandra
- salix_triandra
- salix_viminalis
- sorbaria_sorbifolia
- sorbus_aucuparia
- spiraea_salicifolia
- symphoricarpos_albus
- tilia_cordata
- ulmus_glabra
- ulmus_laevis
- ulmus_pumila
- vaccinium_myrtillus
- vaccinium_oxycoccos
- vaccinium_uliginosum
- vaccinium_vitis-idaea
- viburnum_lantana
- viburnum_opulus
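
Label ids are convenient as keys but not as display text. The helper below is a minimal sketch that turns an id into a conventional scientific-name spelling by capitalizing the genus and replacing underscores with spaces. `display_name` is a hypothetical helper, not part of this repository, and it assumes every label follows the genus-first, underscore-separated convention used in the list above.

```python
def display_name(label: str) -> str:
    """Turn a label id such as 'acer_platanoides' into 'Acer platanoides'.

    Hypothetical helper: assumes labels are lowercase, underscore-separated,
    genus first. Hyphens inside epithets (e.g. 'uva-ursi') are kept as-is,
    and genus-only labels such as 'populus' become just 'Populus'.
    """
    genus, *rest = label.split("_")
    return " ".join([genus.capitalize(), *rest])


if __name__ == "__main__":
    for lab in ["acer_platanoides", "arctostaphylos_uva-ursi", "populus"]:
        print(lab, "->", display_name(lab))
```

Extra parts after the species epithet (as in `alnus_alnobetula_fruticosa`) are simply joined with spaces; no rank marker is inferred.
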
## Usage
```python
import json

import timm
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data.transforms_factory import create_transform

REPO = "rexologue/vit_large_384_for_trees"
MODEL_NAME = "vit_large_patch16_384"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# 1) Labels: labels.json may be a {"0": "name", ...} dict or a plain list
labels_path = hf_hub_download(REPO, filename="labels.json")
with open(labels_path, "r", encoding="utf-8") as f:
    raw = json.load(f)
labels = [raw[str(i)] for i in range(len(raw))] if isinstance(raw, dict) else list(raw)

# 2) Weights
ckpt_path = hf_hub_download(REPO, filename="pytorch_model.bin")
state = torch.load(ckpt_path, map_location="cpu")
# Strip the "module." prefix that DistributedDataParallel adds, if present
if any(k.startswith("module.") for k in state):
    state = {k.removeprefix("module."): v for k, v in state.items()}

# 3) Model: same architecture, classifier head sized to the label set
model = timm.create_model(MODEL_NAME, num_classes=len(labels), pretrained=False)
model.load_state_dict(state, strict=True)
model.to(DEVICE).eval()

# 4) Preprocessing: ViT-L/16 at 384x384, bicubic resize, ImageNet mean/std
transform = create_transform(
    input_size=(3, 384, 384),
    interpolation="bicubic",
    mean=IMAGENET_DEFAULT_MEAN,
    std=IMAGENET_DEFAULT_STD,
)

# 5) Inference on a single image
img = Image.open("your_image.jpg").convert("RGB")
x = transform(img).unsqueeze(0).to(DEVICE)  # add batch dim: (1, 3, 384, 384)
with torch.no_grad():
    logits = model(x)
probs = torch.softmax(logits, dim=1)[0].cpu()
topk = probs.topk(k=min(5, len(labels)))
print([(labels[i], p) for p, i in zip(topk.values.tolist(), topk.indices.tolist())])
```
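
Several classes share a genus (for example, eight `salix_*` and seven `potentilla_*` entries), so when the model is unsure between close relatives the per-species probabilities can be spread thin. One optional post-processing step, not prescribed by this model card, is to sum probabilities per genus. The sketch below assumes the genus is the part of each label before the first underscore; `genus_probs` is a hypothetical helper name.

```python
from collections import defaultdict


def genus_probs(labels: list[str], probs: list[float]) -> list[tuple[str, float]]:
    """Sum per-class probabilities by genus (the label text before the first '_').

    Hypothetical helper: assumes genus-first label ids. Returns
    (genus, probability) pairs sorted from most to least likely.
    """
    if len(labels) != len(probs):
        raise ValueError("labels and probs must have the same length")
    totals: dict[str, float] = defaultdict(float)
    for label, p in zip(labels, probs):
        totals[label.split("_", 1)[0]] += p
    return sorted(totals.items(), key=lambda kv: kv[1], reverse=True)


if __name__ == "__main__":
    # Toy values; with the usage snippet you would pass labels and probs.tolist()
    demo_labels = ["salix_alba", "salix_caprea", "populus_alba"]
    demo_probs = [0.3, 0.3, 0.4]
    print(genus_probs(demo_labels, demo_probs)[:5])
```

With the usage snippet above, `genus_probs(labels, probs.tolist())[:5]` would give the five most likely genera. The genus-level class `populus` is grouped with the `populus_*` species automatically.
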