danielaivanova committed on
Commit
dff7e68
·
verified ·
1 Parent(s): a23d1f4

Upload folder using huggingface_hub

Browse files
__pycache__/inference.cpython-312.pyc ADDED
Binary file (2.34 kB). View file
 
__pycache__/model.cpython-312.pyc ADDED
Binary file (3.26 kB). View file
 
config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_name": "facebook/dinov3-vits16-pretrain-lvd1689m",
3
+ "num_classes": 145,
4
+ "temperature": 1.3,
5
+ "architecture": "DinoV3LinearMultiLinear",
6
+ "architecture_description": "Frozen DinoV3-ViT-S16 backbone + 3-layer MLP head (384 -> 256 -> 128 -> 145 classes)"
7
+ }
demo.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
id2label.json ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "0": "Acantharia",
3
+ "1": "Acanthoica_quattrospina",
4
+ "2": "Akashiwo_sanguinea",
5
+ "3": "Alexandrium_spp",
6
+ "4": "Amphidinium_crassum",
7
+ "5": "Amphidinium_sphenoides",
8
+ "6": "Apedinella",
9
+ "7": "Askenasia",
10
+ "8": "Asterompalus_flabellatus",
11
+ "9": "Asteromphalus_sarcophagus",
12
+ "10": "Azadinium_caudata",
13
+ "11": "Bacillaria",
14
+ "12": "Bacteriastrum",
15
+ "13": "Balanion",
16
+ "14": "Bead",
17
+ "15": "Brockmanniella",
18
+ "16": "Calciopappus",
19
+ "17": "Calyptrosphaera",
20
+ "18": "Carchesium",
21
+ "19": "Cerautulina_pelagica_chain",
22
+ "20": "Cerautulina_pelagica_single_double",
23
+ "21": "Chaetoceros_curvisetus_debelis",
24
+ "22": "Chaetoceros_didymus",
25
+ "23": "Chaetoceros_didymus_single",
26
+ "24": "Chaetoceros_morphotype1",
27
+ "25": "Chaetoceros_peruvianum",
28
+ "26": "Chaetoceros_socialis_type",
29
+ "27": "Chaetoceros_spp",
30
+ "28": "Chaetoceros_tenuissimus",
31
+ "29": "Chrysochromulina",
32
+ "30": "Codonellopsis",
33
+ "31": "Corethron",
34
+ "32": "Coscinodiscus_granii",
35
+ "33": "Cryptophyta",
36
+ "34": "Cylindrotheca_closterium",
37
+ "35": "Delphineis",
38
+ "36": "Detonula_pumila",
39
+ "37": "Detritus",
40
+ "38": "Dictyocha_fibula",
41
+ "39": "Dictyocha_speculum",
42
+ "40": "Dinobyron",
43
+ "41": "Dinophysis_acuminata_complex",
44
+ "42": "Dinophysis_acuta",
45
+ "43": "Dinophysis_caudata",
46
+ "44": "Dinophysis_caudata_var_diegensis",
47
+ "45": "Dinophysis_fortii",
48
+ "46": "Dinophysis_tripos",
49
+ "47": "Diploneis_crabro",
50
+ "48": "Diplopsalis",
51
+ "49": "Ditylum_brightwellii",
52
+ "50": "Entomoneis",
53
+ "51": "Erythropsidium",
54
+ "52": "Eucampia_",
55
+ "53": "Eutintinnus",
56
+ "54": "Eutriptiella",
57
+ "55": "Faecal_pellet",
58
+ "56": "Favella",
59
+ "57": "Flagellate_clump",
60
+ "58": "Flagellate_heart_shape",
61
+ "59": "Flagellate_morphotype1",
62
+ "60": "Flagellate_morphotype2",
63
+ "61": "Gonyaulax_spinifera",
64
+ "62": "Gonyaulax_verior",
65
+ "63": "Guinardia_delicatula_chain",
66
+ "64": "Guinardia_delicatula_single",
67
+ "65": "Guinardia_delicatula_single_double",
68
+ "66": "Guinardia_flaccida",
69
+ "67": "Guinardia_striata",
70
+ "68": "Gymnodiniales_morphotype1",
71
+ "69": "Gymnodinium_catenatum",
72
+ "70": "Gyrodinium_falcatum",
73
+ "71": "Gyrodinium_spirale_type",
74
+ "72": "Halosphaera",
75
+ "73": "Haslea_wawrikae",
76
+ "74": "Helicotheca_tamesis",
77
+ "75": "Heterocapsa_azadinium",
78
+ "76": "Heterocapsa_rotundata",
79
+ "77": "Heterocapsa_type",
80
+ "78": "Karenia_mikimotoi",
81
+ "79": "Karlodinium",
82
+ "80": "Katodinium",
83
+ "81": "Laboea_strobila",
84
+ "82": "Lauderia_annulata",
85
+ "83": "Leegaardiella_sol",
86
+ "84": "Lessardia",
87
+ "85": "Mesodinium_rubrum",
88
+ "86": "Mesodinium_small",
89
+ "87": "Mesoporos",
90
+ "88": "Meuniera_membranacea",
91
+ "89": "Meuniera_membranacea_single",
92
+ "90": "Nauplii",
93
+ "91": "Navicula_transitans_var_derasa",
94
+ "92": "Navicula_transitrans_var_derasa_f_delicatula",
95
+ "93": "Odontella_mobiliensis",
96
+ "94": "Oxytoxum_gracile",
97
+ "95": "Paralia_sulcata",
98
+ "96": "Phaeocystis",
99
+ "97": "Phalachroma_rotundatum",
100
+ "98": "Plagiolemma_distortum",
101
+ "99": "Planktoniella_sol",
102
+ "100": "Pleurosigma",
103
+ "101": "Podosira_stelligera",
104
+ "102": "Polykrikos",
105
+ "103": "Proboscia_truncata",
106
+ "104": "Prorocentrum_cordatum",
107
+ "105": "Prorocentrum_cordatum_minimum",
108
+ "106": "Prorocentrum_dentatum",
109
+ "107": "Prorocentrum_gracile",
110
+ "108": "Prorocentrum_micans",
111
+ "109": "Protoperidinium",
112
+ "110": "Protoperidinium_bipes",
113
+ "111": "Protoperidinium_steinii",
114
+ "112": "Psammodictyon_panduriforme",
115
+ "113": "Pseudchattonella_farcimen_round",
116
+ "114": "Pseudo-nitzschia_chain_double",
117
+ "115": "Pseudo-nitzschia_chain_multiple",
118
+ "116": "Pseudo-nitzschia_single",
119
+ "117": "Pseudochattonella",
120
+ "118": "Pseudochattonella_covering_Dictyocha",
121
+ "119": "Pseudochattonella_farcimen_oblong",
122
+ "120": "Pseudosolenia_calcar-avis",
123
+ "121": "Pterosperma",
124
+ "122": "Radiolaria_lithomelissa",
125
+ "123": "Rotifera",
126
+ "124": "Scrippsiella",
127
+ "125": "Stenosomella",
128
+ "126": "Stephanopyxsis",
129
+ "127": "Strombidium_ciliate",
130
+ "128": "Thalassionema_nitzschioides_double",
131
+ "129": "Thalassionema_nitzschioides_multiple",
132
+ "130": "Thalassionema_nitzschioides_single",
133
+ "131": "Thalassiosira_gravida_double",
134
+ "132": "Thalassiosira_gravida_rotula",
135
+ "133": "Thalassiosira_gravida_single",
136
+ "134": "Tiarina_fusus",
137
+ "135": "Tintinnopsis",
138
+ "136": "Tontonia",
139
+ "137": "Torodinium",
140
+ "138": "Torodinium_teredo",
141
+ "139": "Tripos_furca",
142
+ "140": "Tripos_fusus",
143
+ "141": "Tripos_horridus",
144
+ "142": "Tripos_lineatus",
145
+ "143": "Tripos_muelleri",
146
+ "144": "Undet_small"
147
+ }
images/synthetic_00004.png ADDED
images/synthetic_00005.png ADDED
inference.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModel, AutoImageProcessor
3
+ from model import DinoV3LinearMultiLinear
4
+
5
def load_model(weights_path, device="cuda", config_path="config.json", labels_path="id2label.json"):
    """
    Load the pre-trained classifier, its image processor and the label map.

    Args:
        weights_path: Path to the saved checkpoint (.pt file). The checkpoint
            must be a mapping that contains a "model_state_dict" entry.
        device: Device to load model on ('cuda' or 'cpu')
        config_path: Path to the JSON config that provides "model_name" and
            "num_classes". Relative paths are resolved against the current
            working directory (default: "config.json").
        labels_path: Path to the JSON file that maps class indices to label
            names (default: "id2label.json").

    Returns:
        model: Loaded DinoV3LinearMultiLinear model in eval mode, on `device`
        processor: Image processor for preprocessing input images
        id2label: Mapping loaded from `labels_path`; because it comes from
            JSON, its keys are strings (e.g. "0"), not integers
    """
    import json

    # Load config
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    # Load backbone; the head's input width is taken from the backbone config
    backbone = AutoModel.from_pretrained(config["model_name"])
    hidden_size = backbone.config.hidden_size

    # Instantiate classifier head on top of the (frozen) backbone
    model = DinoV3LinearMultiLinear(
        backbone=backbone,
        num_classes=config["num_classes"],
        hidden_size=hidden_size,
        freeze_backbone=True,
    )

    # Load trained weights.
    # NOTE(review): depending on the installed torch version's `weights_only`
    # default, torch.load may unpickle arbitrary objects -- only load
    # checkpoints that come from a trusted source.
    checkpoint = torch.load(weights_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()

    # Load image processor matching the backbone
    processor = AutoImageProcessor.from_pretrained(config["model_name"])

    # Load labels
    with open(labels_path, "r", encoding="utf-8") as f:
        id2label = json.load(f)

    return model, processor, id2label
46
+
47
+
48
def probs_to_labels(probs, id2label):
    """
    Translate per-sample class probabilities into label names.

    Args:
        probs: Tensor whose dim 1 indexes the classes; the arg-max along that
            dimension is taken for every row.
        id2label: Mapping from the class index written as a string
            (e.g. "3") to the label name.

    Returns:
        List with one label name per row of `probs`.
    """
    labels = []
    for winner in probs.argmax(dim=1):
        # The label map is keyed by strings, so stringify the integer index.
        key = str(winner.item())
        labels.append(id2label[key])
    return labels
model.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import AutoModel
4
+
5
class DinoV3LinearMultiLinear(nn.Module):
    """
    Image classifier: a backbone followed by three stacked linear layers.

    The head maps hidden_size -> 256 -> 128 -> num_classes. No activation is
    applied between the linear layers, so the trained weights must be used
    with exactly this composition.
    """

    def __init__(self, backbone: nn.Module, hidden_size: int, num_classes: int, freeze_backbone: bool = True):
        """
        Args:
            backbone: Feature extractor. It is called as
                `backbone(pixel_values=...)` and its output must expose
                `last_hidden_state`.
            hidden_size: Width of the backbone features fed to the head.
            num_classes: Number of output classes.
            freeze_backbone: If True, disable gradients for all backbone
                parameters and keep the backbone in eval mode.
        """
        super().__init__()
        self.backbone = backbone
        self.num_classes = num_classes
        # Remembered so that train() can keep a frozen backbone in eval mode.
        self.freeze_backbone = freeze_backbone
        if freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False
            self.backbone.eval()
        # three linear layers like in the original syke-pic model
        # hidden size -> 256 -> 128 -> num_classes
        self.linear1 = nn.Linear(hidden_size, 256)
        self.linear2 = nn.Linear(256, 128)
        self.linear3 = nn.Linear(128, self.num_classes)

    def train(self, mode: bool = True):
        """
        Set training mode, keeping a frozen backbone in eval mode.

        Without this override, calling `model.train()` would switch the
        frozen backbone back to train mode as well, undoing the `eval()`
        call made in `__init__`.
        """
        super().train(mode)
        if self.freeze_backbone:
            self.backbone.eval()
        return self

    def print_num_trainable_parameters(self):
        """Print the total number of parameters that require gradients."""
        print(f"Number of trainable parameters: {sum(p.numel() for p in self.parameters() if p.requires_grad)}")

    def forward(self, pixel_values):
        """
        Compute class logits for a batch of preprocessed images.

        Args:
            pixel_values: Preprocessed image tensor passed to the backbone.

        Returns:
            logits: Unnormalised class scores, one row per input sample.
        """
        outputs = self.backbone(pixel_values=pixel_values)
        last_hidden = outputs.last_hidden_state
        # Use the token at sequence index 0 as the image feature
        # (presumably the class token -- confirm against the backbone).
        cls = last_hidden[:, 0]
        logits = self.linear3(self.linear2(self.linear1(cls)))
        return logits

    def predict(self, pixel_values, temperature=1.3):
        """
        Generate probability predictions for a batch of images.

        Runs without gradient tracking, so the result cannot be
        back-propagated through; use `forward` for training.

        Args:
            pixel_values: Preprocessed image tensor (batch_size, 3, H, W)
            temperature: Temperature for softmax calibration (default 1.3).
                Must be strictly positive.

        Returns:
            probs: Probability distribution over classes (shape: [batch_size, num_classes])

        Raises:
            ValueError: If `temperature` is not strictly positive.
        """
        # A zero temperature would divide by zero and yield NaN probabilities;
        # a negative one would invert the class ranking.
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        with torch.no_grad():
            logits = self(pixel_values)
            probs = torch.softmax(logits / temperature, dim=1)
        return probs
readme.md ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ language:
4
+ - en
5
+ tags:
6
+ - image-classification
7
+ - plankton
8
+ - dinov3
9
+ - biology
10
+ - marine
11
+ datasets:
12
+ - ifcb
13
+ ---
14
+
15
+ # DINO Plankton Classifier
16
+ The model is trained on PML IFCB data consisting of 145 plankton classes.
17
+
18
+ # Inference
19
+ Use the provided inference script (`inference.py`). See `demo.ipynb` for an example that predicts the classes of two synthetic samples.
requirements.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==1.10.1
2
+ bitsandbytes==0.48.2
3
+ datasets==2.21.0
4
+ diffusers==0.35.1
5
+ evaluate==0.4.5
6
+ fastapi==0.116.1
7
+ ffmpy==0.6.1
8
+ gradio==5.45.0
9
+ gradio_client==1.13.0
10
+ safehttpx==0.1.6
11
+ safetensors==0.6.2
12
+ tokenizers==0.22.0
13
+ torch==2.8.0+cu126
14
+ torchaudio==2.8.0+cu126
15
+ torchvision==0.23.0+cu126
16
+ transformers==4.57.0.dev0
weights.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c4929239fb6342ff28b2ea9ac0d71124fa71479fb3dfbec8df216dd00dcdc48a
3
+ size 88279131