.gitattributes CHANGED
@@ -33,4 +33,3 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
- eval_viz.png filter=lfs diff=lfs merge=lfs -text
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
README.md CHANGED
@@ -1,5 +1,6 @@
1
  ---
2
  license: apache-2.0
 
3
  tags:
4
  - image-classification
5
  - multi-label-classification
@@ -44,87 +45,63 @@ backbone fine-tuned end-to-end with a single linear projection head.
44
  | Precision | bfloat16 (backbone) / float32 (projection + loss) |
45
  | Hardware | 2× GPU, ThreadPoolExecutor + NCCL all-reduce |
46
 
47
- ![eval_viz](./eval_viz.png)
48
-
49
  ## Usage
50
 
51
- ### 1. Install dependencies
52
 
53
- ```bash
54
- pip install -r requirements.txt
55
- ```
56
 
57
- Or manually:
58
-
59
- ```bash
60
- pip install torch torchvision safetensors Pillow requests \
61
- python-multipart fastapi uvicorn jinja2 aiofiles
62
- ```
63
 
64
- ### 2. Download model files
 
65
 
66
- ```bash
67
- huggingface-cli download lodestones/taggerine \
68
- tagger_proto.safetensors \
69
- tagger_vocab_with_categories_and_alias_updated.json \
70
- tagger_ui_server.py \
71
- inference_tagger_standalone.py \
72
- --local-dir .
73
  ```
74
 
75
- > **Note:** `tagger_proto.safetensors` is ~5.3 GB. Make sure you have enough disk space.
76
-
77
- ### 3. Download the `tagger_ui/` templates folder
78
-
79
- The server requires the `tagger_ui/templates/` directory to be present alongside `tagger_ui_server.py`:
80
 
81
  ```bash
82
- huggingface-cli download lodestones/taggerine \
83
- --include "tagger_ui/**" \
84
- --local-dir .
85
- ```
 
 
86
 
87
- ### 4. Run the Web UI
 
88
 
89
- ```bash
90
- python tagger_ui_server.py \
91
- --checkpoint tagger_proto.safetensors \
92
- --vocab tagger_vocab_with_categories_and_alias_updated.json \
93
- --port 7860
94
- # → open http://localhost:7860
95
  ```
96
 
97
- **CPU-only machine?** Add `--device cpu` (inference will be slower):
98
 
99
  ```bash
 
 
100
  python tagger_ui_server.py \
101
  --checkpoint tagger_proto.safetensors \
102
- --vocab tagger_vocab_with_categories_and_alias_updated.json \
103
- --device cpu \
104
  --port 7860
105
- ```
106
-
107
- ### Standalone CLI inference (no server)
108
-
109
- ```bash
110
- python inference_tagger_standalone.py \
111
- --checkpoint tagger_proto.safetensors \
112
- --vocab tagger_vocab_with_categories_and_alias_updated.json \
113
- --images photo.jpg \
114
- --topk 30
115
  ```
116
 
117
  ## Files
118
 
119
  | File | Description |
120
  |---|---|
121
- | `tagger_proto.safetensors` | Model weights (bfloat16) |
122
- | `tagger_vocab_with_categories_and_alias_updated.json` | `{"idx2tag": [...], "tag2category": {...}}` — 74 625 tags with category metadata |
123
- | `tagger_vocab_with_categories.json` | Same without alias data |
124
- | `tagger_vocab.json` | Minimal vocab — `{"idx2tag": [...]}` only |
125
- | `inference_tagger_standalone.py` | Self-contained CLI inference script (no `transformers` dep) |
126
  | `tagger_ui_server.py` | FastAPI + Jinja2 web UI server |
127
- | `requirements.txt` | Python dependencies |
128
 
129
  ## Tag Vocabulary
130
 
@@ -141,7 +118,8 @@ Minimum tag frequency threshold: **50** occurrences across the combined dataset.
141
 
142
  ## Limitations
143
 
144
- - Evaluated on booru-style illustrations and furry art; performance on photographic images or other art works to some extend.
 
145
  - The vocabulary reflects the biases of e621 and Danbooru annotation practices.
146
 
147
  ## License
 
1
  ---
2
  license: apache-2.0
3
+
4
  tags:
5
  - image-classification
6
  - multi-label-classification
 
45
  | Precision | bfloat16 (backbone) / float32 (projection + loss) |
46
  | Hardware | 2× GPU, ThreadPoolExecutor + NCCL all-reduce |
47
 
 
 
48
  ## Usage
49
 
50
+ ### Standalone (no `transformers` dependency)
51
 
52
+ ```python
53
+ from inference_tagger_standalone import Tagger
 
54
 
55
+ tagger = Tagger(
56
+ checkpoint_path="tagger_proto.safetensors",
57
+ vocab_path="tagger_vocab_with_categories.json",
58
+ device="cuda",
59
+ )
 
60
 
61
+ tags = tagger.predict("photo.jpg", topk=40)
62
+ # → [("solo", 0.98), ("anthro", 0.95), ...]
63
 
64
+ # or threshold-based
65
+ tags = tagger.predict("https://example.com/image.jpg", threshold=0.35)
 
 
 
 
 
66
  ```
67
 
68
+ ### CLI
 
 
 
 
69
 
70
  ```bash
71
+ # top-30 tags, pretty output
72
+ python inference_tagger_standalone.py \
73
+ --checkpoint tagger_proto.safetensors \
74
+ --vocab tagger_vocab_with_categories.json \
75
+ --images photo.jpg https://example.com/image.jpg \
76
+ --topk 30
77
 
78
+ # comma-separated string (pipe into diffusion trainer)
79
+ python inference_tagger_standalone.py ... --format tags
80
 
81
+ # JSON
82
+ python inference_tagger_standalone.py ... --format json
 
 
 
 
83
  ```
84
 
85
+ ### Web UI
86
 
87
  ```bash
88
+ pip install fastapi uvicorn jinja2 aiofiles
89
+
90
  python tagger_ui_server.py \
91
  --checkpoint tagger_proto.safetensors \
92
+ --vocab tagger_vocab_with_categories.json \
 
93
  --port 7860
94
+ # → open http://localhost:7860
 
 
 
 
 
 
 
 
 
95
  ```
96
 
97
  ## Files
98
 
99
  | File | Description |
100
  |---|---|
101
+ | `*.safetensors` | Model weights (bfloat16) |
102
+ | `tagger_vocab_with_categories.json` | `{"idx2tag": [...]}` — 74 625 tag strings ordered by training frequency |
103
+ | `inference_tagger_standalone.py` | Self-contained inference script (no `transformers` dep) |
 
 
104
  | `tagger_ui_server.py` | FastAPI + Jinja2 web UI server |
 
105
 
106
  ## Tag Vocabulary
107
 
 
118
 
119
  ## Limitations
120
 
121
+ - Evaluated on booru-style illustrations and furry art; performance on photographic
122
+ images or other art styles is untested.
123
  - The vocabulary reflects the biases of e621 and Danbooru annotation practices.
124
 
125
  ## License
eval_viz.png DELETED

Git LFS Details

  • SHA256: 8cc88c49e85e69897c4ab5b7abb3e9a5400036fcf26060e3b5ff75e62931b177
  • Pointer size: 131 Bytes
  • Size of remote file: 270 kB
inference_tagger_standalone.py CHANGED
@@ -64,19 +64,17 @@ from safetensors.torch import load_file
64
  # All hyperparameters match facebook/dinov3-vith16plus-pretrain-lvd1689m
65
  # =============================================================================
66
 
67
- D_MODEL = 1280
68
- N_HEADS = 20
69
- HEAD_DIM = D_MODEL // N_HEADS # 64
70
- N_LAYERS = 32
71
- D_FFN = 5120
72
  N_REGISTERS = 4
73
- PATCH_SIZE = 16
74
- ROPE_THETA = 100.0
75
- ROPE_RESCALE = 2.0
76
- LN_EPS = 1e-5
77
- LAYERSCALE = 1.0
78
-
79
- FEATURE_DIM = (1 + N_REGISTERS) * D_MODEL # 6400
80
 
81
 
82
  # ---------------------------------------------------------------------------
@@ -85,23 +83,25 @@ FEATURE_DIM = (1 + N_REGISTERS) * D_MODEL # 6400
85
 
86
  @lru_cache(maxsize=32)
87
  def _patch_coords_cached(h: int, w: int, device_str: str) -> torch.Tensor:
 
88
  device = torch.device(device_str)
89
  cy = torch.arange(0.5, h, dtype=torch.float32, device=device) / h
90
  cx = torch.arange(0.5, w, dtype=torch.float32, device=device) / w
91
  coords = torch.stack(torch.meshgrid(cy, cx, indexing="ij"), dim=-1).flatten(0, 1)
92
- coords = 2.0 * coords - 1.0
93
  coords = coords * ROPE_RESCALE
94
  return coords # [h*w, 2]
95
 
96
 
97
  def _build_rope(h_patches: int, w_patches: int,
98
  dtype: torch.dtype, device: torch.device):
99
- coords = _patch_coords_cached(h_patches, w_patches, str(device))
 
100
  inv_freq = 1.0 / (ROPE_THETA ** torch.arange(
101
- 0, 1, 4 / HEAD_DIM, dtype=torch.float32, device=device))
102
- angles = 2 * math.pi * coords[:, :, None] * inv_freq[None, None, :]
103
- angles = angles.flatten(1, 2).tile(2)
104
- cos = torch.cos(angles).to(dtype).unsqueeze(0).unsqueeze(0)
105
  sin = torch.sin(angles).to(dtype).unsqueeze(0).unsqueeze(0)
106
  return cos, sin
107
 
@@ -113,6 +113,7 @@ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
113
 
114
  def _apply_rope(q: torch.Tensor, k: torch.Tensor,
115
  cos: torch.Tensor, sin: torch.Tensor):
 
116
  n_pre = 1 + N_REGISTERS
117
  q_pre, q_pat = q[..., :n_pre, :], q[..., n_pre:, :]
118
  k_pre, k_pat = k[..., :n_pre, :], k[..., n_pre:, :]
@@ -122,7 +123,7 @@ def _apply_rope(q: torch.Tensor, k: torch.Tensor,
122
 
123
 
124
  # ---------------------------------------------------------------------------
125
- # Transformer blocks
126
  # ---------------------------------------------------------------------------
127
 
128
  class _Attention(nn.Module):
@@ -133,7 +134,7 @@ class _Attention(nn.Module):
133
  self.v_proj = nn.Linear(D_MODEL, D_MODEL, bias=True)
134
  self.o_proj = nn.Linear(D_MODEL, D_MODEL, bias=True)
135
 
136
- def forward(self, x, cos, sin):
137
  B, S, _ = x.shape
138
  q = self.q_proj(x).view(B, S, N_HEADS, HEAD_DIM).transpose(1, 2)
139
  k = self.k_proj(x).view(B, S, N_HEADS, HEAD_DIM).transpose(1, 2)
@@ -147,273 +148,125 @@ class _GatedMLP(nn.Module):
147
  def __init__(self):
148
  super().__init__()
149
  self.gate_proj = nn.Linear(D_MODEL, D_FFN, bias=True)
150
- self.up_proj = nn.Linear(D_MODEL, D_FFN, bias=True)
151
- self.down_proj = nn.Linear(D_FFN, D_MODEL, bias=True)
152
 
153
- def forward(self, x):
154
  return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
155
 
156
 
157
  class _Block(nn.Module):
158
  def __init__(self):
159
  super().__init__()
160
- self.norm1 = nn.LayerNorm(D_MODEL, eps=LN_EPS)
161
- self.attention = _Attention()
162
  self.layer_scale1 = nn.Parameter(torch.full((D_MODEL,), LAYERSCALE))
163
- self.norm2 = nn.LayerNorm(D_MODEL, eps=LN_EPS)
164
- self.mlp = _GatedMLP()
165
  self.layer_scale2 = nn.Parameter(torch.full((D_MODEL,), LAYERSCALE))
166
 
167
- def forward(self, x, cos, sin):
168
  x = x + self.attention(self.norm1(x), cos, sin) * self.layer_scale1
169
  x = x + self.mlp(self.norm2(x)) * self.layer_scale2
170
  return x
171
 
172
 
173
- class _Embeddings(nn.Module):
174
- def __init__(self):
175
- super().__init__()
176
- # zeros() rather than empty() so a forgotten checkpoint key fails
177
- # predictably instead of producing undefined outputs.
178
- self.cls_token = nn.Parameter(torch.zeros(1, 1, D_MODEL))
179
- self.mask_token = nn.Parameter(torch.zeros(1, 1, D_MODEL))
180
- self.register_tokens = nn.Parameter(torch.zeros(1, N_REGISTERS, D_MODEL))
181
- self.patch_embeddings = nn.Conv2d(
182
- 3, D_MODEL, kernel_size=PATCH_SIZE, stride=PATCH_SIZE)
183
-
184
- def forward(self, pixel_values):
185
- B = pixel_values.shape[0]
186
- dtype = self.patch_embeddings.weight.dtype
187
- patches = self.patch_embeddings(
188
- pixel_values.to(dtype)).flatten(2).transpose(1, 2)
189
- cls = self.cls_token.expand(B, -1, -1)
190
- regs = self.register_tokens.expand(B, -1, -1)
191
- return torch.cat([cls, regs, patches], dim=1)
192
-
193
 
194
  class DINOv3ViTH(nn.Module):
195
  """DINOv3 ViT-H/16+ backbone.
196
 
197
- Token layout: [CLS, reg_0..reg_3, patch_0..patch_N].
198
  Returns last_hidden_state [B, 1+R+P, D_MODEL].
 
 
 
 
199
  """
200
 
201
  def __init__(self):
202
  super().__init__()
 
203
  self.embeddings = _Embeddings()
204
  self.layer = nn.ModuleList([_Block() for _ in range(N_LAYERS)])
205
- self.norm = nn.LayerNorm(D_MODEL, eps=LN_EPS)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
- def forward(self, pixel_values):
208
- _, _, H, W = pixel_values.shape
209
- x = self.embeddings(pixel_values)
210
  h_p, w_p = H // PATCH_SIZE, W // PATCH_SIZE
211
  cos, sin = _build_rope(h_p, w_p, x.dtype, pixel_values.device)
212
- for block in self.layer:
213
- x = block(x, cos, sin)
214
- return self.norm(x)
215
 
216
- def get_image_tokens(self, pixel_values):
217
- """Return patch tokens only (no CLS/registers) as [B, h_p*w_p, D_MODEL]
218
- and the spatial grid dimensions (h_p, w_p)."""
219
- _, _, H, W = pixel_values.shape
220
- h_p, w_p = H // PATCH_SIZE, W // PATCH_SIZE
221
- x = self.embeddings(pixel_values)
222
- cos, sin = _build_rope(h_p, w_p, x.dtype, pixel_values.device)
223
  for block in self.layer:
224
  x = block(x, cos, sin)
225
- x = self.norm(x)
226
- # token layout: [CLS, reg_0..reg_R-1, patch_0..patch_N]
227
- patch_tokens = x[:, 1 + N_REGISTERS:, :] # [B, h_p*w_p, D_MODEL]
228
- return patch_tokens, h_p, w_p
229
-
230
 
231
- # =============================================================================
232
- # Head — auto-detected from the checkpoint
233
- # =============================================================================
234
 
235
- class _LowRankHead(nn.Module):
236
- """Two-matrix low-rank projection head.
237
 
238
- features (in_dim)
239
- Linear(in_dim, rank, bias=?)
240
- Linear(rank, num_tags, bias=?)
 
241
  """
242
 
243
- def __init__(self, in_dim: int, rank: int, num_tags: int,
244
- down_bias: bool, up_bias: bool):
245
  super().__init__()
246
- self.proj_down = nn.Linear(in_dim, rank, bias=down_bias)
247
- self.proj_up = nn.Linear(rank, num_tags, bias=up_bias)
248
-
249
- def forward(self, x):
250
- return self.proj_up(self.proj_down(x))
251
 
252
-
253
- def _build_head_from_checkpoint(
254
- head_sd: dict,
255
- in_dim: int,
256
- num_tags: int,
257
- ) -> tuple[nn.Module, dict]:
258
- """Inspect head_sd and build a matching Module.
259
-
260
- Supports two layouts, in order of preference:
261
- 1. Single linear — any ``*.weight`` with shape [num_tags, in_dim]
262
- 2. Low-rank pair (2 mats) — one ``*.weight`` [rank, in_dim] plus
263
- one ``*.weight`` [num_tags, rank]
264
-
265
- Returns (module, remapped_state_dict) where the remapped state dict
266
- matches the module's own key names so strict loading works.
267
- """
268
- weights_2d = [(k, v) for k, v in head_sd.items()
269
- if k.endswith(".weight") and v.ndim == 2]
270
-
271
- # --- Case 1: single dense linear ---------------------------------------
272
- singles = [(k, v) for k, v in weights_2d
273
- if tuple(v.shape) == (num_tags, in_dim)]
274
- if len(weights_2d) <= 2 and len(singles) == 1:
275
- wkey, wval = singles[0]
276
- base = wkey[:-len(".weight")]
277
- bias_key = base + ".bias"
278
- has_bias = bias_key in head_sd
279
- module = nn.Linear(in_dim, num_tags, bias=has_bias)
280
- remapped = {"weight": wval}
281
- if has_bias:
282
- remapped["bias"] = head_sd[bias_key]
283
- # Sanity check: no extra keys we don't understand
284
- expected_src = {wkey} | ({bias_key} if has_bias else set())
285
- extra = set(head_sd) - expected_src
286
- if extra:
287
- raise RuntimeError(
288
- f"Head has single-linear shape but extra unknown keys: {sorted(extra)}")
289
- return module, remapped
290
-
291
- # --- Case 2: low-rank pair ---------------------------------------------
292
- down = None # (key, tensor) with shape [rank, in_dim]
293
- up = None # (key, tensor) with shape [num_tags, rank]
294
- for k, v in weights_2d:
295
- if v.shape[1] == in_dim and v.shape[0] != num_tags:
296
- down = (k, v)
297
- elif v.shape[0] == num_tags and v.shape[1] != in_dim:
298
- up = (k, v)
299
-
300
- if down is not None and up is not None:
301
- rank_down = down[1].shape[0]
302
- rank_up = up[1].shape[1]
303
- if rank_down != rank_up:
304
- raise RuntimeError(
305
- f"Low-rank head: inner dims disagree "
306
- f"(down out={rank_down}, up in={rank_up})")
307
-
308
- down_key, down_w = down
309
- up_key, up_w = up
310
- down_base = down_key[:-len(".weight")]
311
- up_base = up_key[:-len(".weight")]
312
- down_bias_key = down_base + ".bias"
313
- up_bias_key = up_base + ".bias"
314
- has_down_bias = down_bias_key in head_sd
315
- has_up_bias = up_bias_key in head_sd
316
-
317
- module = _LowRankHead(
318
- in_dim=in_dim,
319
- rank=rank_down,
320
- num_tags=num_tags,
321
- down_bias=has_down_bias,
322
- up_bias=has_up_bias,
323
- )
324
- remapped = {
325
- "proj_down.weight": down_w,
326
- "proj_up.weight": up_w,
327
- }
328
- if has_down_bias:
329
- remapped["proj_down.bias"] = head_sd[down_bias_key]
330
- if has_up_bias:
331
- remapped["proj_up.bias"] = head_sd[up_bias_key]
332
-
333
- # Sanity check
334
- expected_src = {down_key, up_key}
335
- if has_down_bias:
336
- expected_src.add(down_bias_key)
337
- if has_up_bias:
338
- expected_src.add(up_bias_key)
339
- extra = set(head_sd) - expected_src
340
- if extra:
341
- raise RuntimeError(
342
- f"Low-rank head detected but checkpoint has extra unknown "
343
- f"head keys: {sorted(extra)}")
344
-
345
- print(f"[Tagger] Detected low-rank head: "
346
- f"in_dim={in_dim}, rank={rank_down}, num_tags={num_tags} "
347
- f"(down_bias={has_down_bias}, up_bias={has_up_bias})")
348
- return module, remapped
349
-
350
- raise RuntimeError(
351
- "Could not infer head architecture from checkpoint. "
352
- f"Non-backbone keys found: {sorted(head_sd.keys())}"
353
- )
354
 
355
 
356
  # =============================================================================
357
- # Tagger wrapper module
358
  # =============================================================================
359
 
360
  class DINOv3Tagger(nn.Module):
361
- """Backbone + head. The head is attached after the checkpoint is
362
- inspected (so we can build the right shape)."""
363
-
364
- def __init__(self):
365
- super().__init__()
366
- self.backbone = DINOv3ViTH()
367
- self.head: nn.Module | None = None # attached by Tagger
368
-
369
- def forward(self, pixel_values):
370
- hidden = self.backbone(pixel_values)
371
- cls = hidden[:, 0, :]
372
- regs = hidden[:, 1: 1 + N_REGISTERS, :].flatten(1)
373
- features = torch.cat([cls, regs], dim=-1).float() # fp32 for head
374
- return self.head(features)
375
-
376
-
377
- # =============================================================================
378
- # Checkpoint loading helpers
379
- # =============================================================================
380
 
381
- def _split_and_clean_state_dict(sd: dict) -> tuple[dict, dict]:
382
- """Split full state dict into (backbone_sd, head_sd), stripping the
383
- ``backbone.`` prefix and applying the remaps needed to match
384
- ``DINOv3ViTH``'s parameter layout:
385
-
386
- 1. ``backbone.model.layer.N.*`` → ``layer.N.*``
387
- (the checkpoint has an HF-style intermediate ``model`` wrapper
388
- that our flat backbone class does not)
389
- 2. ``...layer_scale{1,2}.lambda1`` → ``...layer_scale{1,2}``
390
- (HF stores layer_scale as a sub-module with a ``lambda1``
391
- parameter; we use a plain ``nn.Parameter``)
392
- 3. Drop any ``rope_embeddings`` buffers (recomputed on the fly)
393
  """
394
- backbone_sd: dict = {}
395
- head_sd: dict = {}
396
- for k, v in sd.items():
397
- if k.startswith("backbone."):
398
- nk = k[len("backbone."):]
399
- # Remap (1): strip intermediate "model." before "layer."
400
- if nk.startswith("model.layer."):
401
- nk = nk[len("model."):]
402
- backbone_sd[nk] = v
403
- else:
404
- head_sd[k] = v
405
-
406
- # Remap (2): layer.N.layer_scale{1,2}.lambda1 → layer.N.layer_scale{1,2}
407
- for k in list(backbone_sd.keys()):
408
- if ".layer_scale" in k and k.endswith(".lambda1"):
409
- backbone_sd[k[:-len(".lambda1")]] = backbone_sd.pop(k)
410
 
411
- # Remap (3): drop rope buffers (recomputed on the fly)
412
- for k in list(backbone_sd.keys()):
413
- if "rope_embeddings" in k:
414
- backbone_sd.pop(k)
415
 
416
- return backbone_sd, head_sd
 
 
 
 
 
417
 
418
 
419
  # =============================================================================
@@ -421,7 +274,7 @@ def _split_and_clean_state_dict(sd: dict) -> tuple[dict, dict]:
421
  # =============================================================================
422
 
423
  _IMAGENET_MEAN = [0.485, 0.456, 0.406]
424
- _IMAGENET_STD = [0.229, 0.224, 0.225]
425
 
426
 
427
  def _snap(x: int, m: int) -> int:
@@ -438,22 +291,12 @@ def _open_image(source) -> Image.Image:
438
 
439
 
440
  def preprocess_image(source, max_size: int = 1024) -> torch.Tensor:
441
- """Load and preprocess an image → [1, 3, H, W] float32, ImageNet-normalised.
442
-
443
- Aspect ratio is preserved: a single scale factor is chosen so that the
444
- long edge fits inside max_size after snapping to a PATCH_SIZE multiple.
445
- """
446
  img = _open_image(source)
447
  w, h = img.size
448
-
449
- # Target long-edge (snapped to patch multiple).
450
- long_edge = max(w, h)
451
- target_long = _snap(min(long_edge, max_size), PATCH_SIZE)
452
- scale = target_long / long_edge
453
-
454
- new_w = _snap(max(PATCH_SIZE, round(w * scale)), PATCH_SIZE)
455
- new_h = _snap(max(PATCH_SIZE, round(h * scale)), PATCH_SIZE)
456
-
457
  return v2.Compose([
458
  v2.Resize((new_h, new_w), interpolation=v2.InterpolationMode.LANCZOS),
459
  v2.ToImage(),
@@ -472,15 +315,13 @@ class Tagger:
472
  Parameters
473
  ----------
474
  checkpoint_path : str
475
- Path to a .safetensors or .pt/.pth checkpoint.
476
  vocab_path : str
477
- Path to tagger_vocab.json or tagger_vocab_with_categories.json
478
- (either must contain an ``idx2tag`` list).
479
  device : str
480
- "cuda", "cuda:0", "cpu", ...
481
  dtype : torch.dtype
482
- Backbone precision. bfloat16 recommended on Ampere+, float16 for
483
- older GPUs, float32 for CPU. The head always runs in fp32.
484
  max_size : int
485
  Long-edge cap in pixels before feeding to the model.
486
  """
@@ -493,13 +334,8 @@ class Tagger:
493
  dtype: torch.dtype = torch.bfloat16,
494
  max_size: int = 1024,
495
  ):
496
- want_cuda = device.startswith("cuda")
497
- if want_cuda and not torch.cuda.is_available():
498
- print("[Tagger] CUDA not available, falling back to CPU")
499
- device = "cpu"
500
- dtype = torch.float32
501
- self.device = torch.device(device)
502
- self.dtype = dtype
503
  self.max_size = max_size
504
 
505
  with open(vocab_path) as f:
@@ -508,112 +344,36 @@ class Tagger:
508
  self.num_tags = len(self.idx2tag)
509
  print(f"[Tagger] Vocabulary: {self.num_tags:,} tags")
510
 
511
- # --- Load checkpoint to CPU first so we can inspect shapes ---------
 
512
  print(f"[Tagger] Loading checkpoint: {checkpoint_path}")
513
  if checkpoint_path.endswith((".safetensors", ".sft")):
514
- sd = load_file(checkpoint_path, device="cpu")
515
  else:
516
- sd = torch.load(checkpoint_path, map_location="cpu")
517
-
518
- backbone_sd, head_sd = _split_and_clean_state_dict(sd)
519
-
520
- if not head_sd:
521
- raise RuntimeError(
522
- "Checkpoint contains no non-backbone keys — cannot build head.")
523
-
524
- # --- Build model, inferring head shape from the checkpoint --------
525
- self.model = DINOv3Tagger()
526
- head_module, head_sd_remapped = _build_head_from_checkpoint(
527
- head_sd, in_dim=FEATURE_DIM, num_tags=self.num_tags,
528
- )
529
- self.model.head = head_module
530
-
531
- # --- Strict load — mismatches raise instead of silently passing ----
532
- self.model.backbone.load_state_dict(backbone_sd, strict=True)
533
- self.model.head.load_state_dict(head_sd_remapped, strict=True)
534
-
535
- # --- Move to device. Backbone → bf16/fp16; head stays fp32. --------
536
- self.model.backbone = self.model.backbone.to(
537
- device=self.device, dtype=dtype)
538
- self.model.head = self.model.head.to(
539
- device=self.device, dtype=torch.float32)
540
- self.model.eval()
541
- print(f"[Tagger] Ready on {self.device} (backbone={dtype}, head=fp32)")
542
-
543
- @torch.no_grad()
544
- def embed_pca(
545
- self,
546
- image,
547
- n_components: int = 3,
548
- max_size: int | None = None,
549
- ) -> "Image.Image":
550
- """Run PCA on the patch-token features of *image* and return a
551
- false-colour RGB PIL image where R/G/B channels correspond to the
552
- first three principal components, each normalised to [0, 255].
553
-
554
- Parameters
555
- ----------
556
- image :
557
- Local path, URL, or PIL.Image.Image.
558
- n_components :
559
- Number of PCA components (must be 3 for RGB output).
560
- max_size :
561
- Long-edge cap in pixels (defaults to ``self.max_size``).
562
- """
563
- if n_components != 3:
564
- raise ValueError("n_components must be 3 for false-colour RGB output")
565
- if max_size is None:
566
- max_size = self.max_size
567
-
568
- if isinstance(image, Image.Image):
569
- img = image.convert("RGB")
570
- w, h = img.size
571
- scale = min(1.0, max_size / max(w, h))
572
- new_w = _snap(round(w * scale), PATCH_SIZE)
573
- new_h = _snap(round(h * scale), PATCH_SIZE)
574
- pv = v2.Compose([
575
- v2.Resize((new_h, new_w), interpolation=v2.InterpolationMode.LANCZOS),
576
- v2.ToImage(),
577
- v2.ToDtype(torch.float32, scale=True),
578
- v2.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
579
- ])(img).unsqueeze(0).to(self.device)
580
- else:
581
- pv = preprocess_image(image, max_size=max_size).to(self.device)
582
-
583
- with torch.autocast(device_type=self.device.type, dtype=self.dtype):
584
- patch_tokens, h_p, w_p = self.model.backbone.get_image_tokens(pv)
585
-
586
- # patch_tokens: [1, h_p*w_p, D_MODEL] → [N, D]
587
- tokens = patch_tokens[0].float() # fp32 for PCA
588
 
589
- # Centre
590
- mean = tokens.mean(dim=0, keepdim=True)
591
- tokens_c = tokens - mean
 
 
592
 
593
- # PCA via SVD (economy)
594
- _, _, Vt = torch.linalg.svd(tokens_c, full_matrices=False)
595
- components = Vt[:n_components] # [3, D]
596
- projected = tokens_c @ components.T # [N, 3]
597
-
598
- # Normalise each component to [0, 1]
599
- lo = projected.min(dim=0).values
600
- hi = projected.max(dim=0).values
601
- projected = (projected - lo) / (hi - lo + 1e-8)
602
-
603
- # Reshape to spatial grid and convert to uint8 PIL image
604
- rgb = projected.reshape(h_p, w_p, 3).cpu().numpy()
605
- rgb_uint8 = (rgb * 255).clip(0, 255).astype("uint8")
606
- return Image.fromarray(rgb_uint8, mode="RGB")
607
 
608
  @torch.no_grad()
609
  def predict(self, image, topk: int | None = 30,
610
  threshold: float | None = None) -> list[tuple[str, float]]:
611
- """Tag a single image (local path or URL)."""
 
612
  if topk is None and threshold is None:
613
  topk = 30
614
 
615
  pv = preprocess_image(image, max_size=self.max_size).to(self.device)
616
- logits = self.model(pv)[0]
 
617
  scores = torch.sigmoid(logits.float())
618
 
619
  if topk is not None:
@@ -621,18 +381,17 @@ class Tagger:
621
  else:
622
  assert threshold is not None
623
  indices = (scores >= threshold).nonzero(as_tuple=True)[0]
624
- values = scores[indices]
625
- order = values.argsort(descending=True)
626
  indices, values = indices[order], values[order]
627
 
628
- return [(self.idx2tag[i], float(v))
629
- for i, v in zip(indices.tolist(), values.tolist())]
630
 
631
  @torch.no_grad()
632
  def predict_batch(self, images, topk: int | None = 30,
633
- threshold: float | None = None):
634
- return [self.predict(img, topk=topk, threshold=threshold)
635
- for img in images]
636
 
637
 
638
  # =============================================================================
@@ -640,20 +399,17 @@ class Tagger:
640
  # =============================================================================
641
 
642
  def _fmt_pretty(path: str, results) -> str:
643
- lines = [f"\n{'─' * 60}", f" {path}", f"{'─' * 60}"]
644
  for rank, (tag, score) in enumerate(results, 1):
645
  bar = "█" * int(score * 20)
646
- lines.append(f" {rank:>3}. {score:.3f} {bar:<20} {tag}")
647
  return "\n".join(lines)
648
 
649
-
650
  def _fmt_tags(results) -> str:
651
  return ", ".join(tag for tag, _ in results)
652
 
653
-
654
  def _fmt_json(path: str, results) -> dict:
655
- return {"file": path,
656
- "tags": [{"tag": t, "score": round(s, 4)} for t, s in results]}
657
 
658
 
659
  # =============================================================================
@@ -662,40 +418,28 @@ def _fmt_json(path: str, results) -> dict:
662
 
663
  def main():
664
  parser = argparse.ArgumentParser(
665
- description="DINOv3 ViT-H/16+ tagger inference (standalone)",
666
  formatter_class=argparse.RawDescriptionHelpFormatter,
667
  )
668
- parser.add_argument("--checkpoint", required=True,
669
- help="Path to .safetensors or .pt checkpoint")
670
- parser.add_argument("--vocab", required=True,
671
- help="Path to tagger_vocab*.json")
672
- parser.add_argument("--images", nargs="+", required=True,
673
- help="Image paths and/or http(s) URLs")
674
- parser.add_argument("--device", default="cuda",
675
- help="Device: cuda, cuda:0, cpu (default: cuda)")
676
  parser.add_argument("--max-size", type=int, default=1024,
677
- help="Long-edge cap in pixels (default: 1024)")
678
 
679
  mode = parser.add_mutually_exclusive_group()
680
- mode.add_argument("--topk", type=int, default=30,
681
- help="Return top-k tags (default: 30)")
682
- mode.add_argument("--threshold", type=float,
683
- help="Return all tags with score >= threshold")
684
 
685
  parser.add_argument("--format", choices=["pretty", "tags", "json"],
686
  default="pretty", help="Output format (default: pretty)")
687
  args = parser.parse_args()
688
 
689
- tagger = Tagger(
690
- checkpoint_path=args.checkpoint,
691
- vocab_path=args.vocab,
692
- device=args.device,
693
- max_size=args.max_size,
694
- )
695
 
696
- topk, threshold = (
697
- (None, args.threshold) if args.threshold else (args.topk, None)
698
- )
699
  json_out = []
700
 
701
  for src in args.images:
@@ -704,16 +448,13 @@ def main():
704
  print(f"[warning] File not found: {src}", file=sys.stderr)
705
  continue
706
  results = tagger.predict(src, topk=topk, threshold=threshold)
707
- if args.format == "pretty":
708
- print(_fmt_pretty(src, results))
709
- elif args.format == "tags":
710
- print(_fmt_tags(results))
711
- elif args.format == "json":
712
- json_out.append(_fmt_json(src, results))
713
 
714
  if args.format == "json":
715
  print(json.dumps(json_out, indent=2, ensure_ascii=False))
716
 
717
 
718
  if __name__ == "__main__":
719
- main()
 
64
  # All hyperparameters match facebook/dinov3-vith16plus-pretrain-lvd1689m
65
  # =============================================================================
66
 
67
+ D_MODEL = 1280
68
+ N_HEADS = 20
69
+ HEAD_DIM = D_MODEL // N_HEADS # 64
70
+ N_LAYERS = 32
71
+ D_FFN = 5120
72
  N_REGISTERS = 4
73
+ PATCH_SIZE = 16
74
+ ROPE_THETA = 100.0
75
+ ROPE_RESCALE = 2.0 # pos_embed_rescale applied at inference
76
+ LN_EPS = 1e-5
77
+ LAYERSCALE = 1.0
 
 
78
 
79
 
80
  # ---------------------------------------------------------------------------
 
83
 
84
  @lru_cache(maxsize=32)
85
  def _patch_coords_cached(h: int, w: int, device_str: str) -> torch.Tensor:
86
+ """Normalised [-1,+1] patch-centre coordinates (float32, cached)."""
87
  device = torch.device(device_str)
88
  cy = torch.arange(0.5, h, dtype=torch.float32, device=device) / h
89
  cx = torch.arange(0.5, w, dtype=torch.float32, device=device) / w
90
  coords = torch.stack(torch.meshgrid(cy, cx, indexing="ij"), dim=-1).flatten(0, 1)
91
+ coords = 2.0 * coords - 1.0 # [0,1] → [-1,+1]
92
  coords = coords * ROPE_RESCALE
93
  return coords # [h*w, 2]
94
 
95
 
96
  def _build_rope(h_patches: int, w_patches: int,
97
  dtype: torch.dtype, device: torch.device):
98
+ """Return (cos, sin) of shape [1, 1, h*w, HEAD_DIM] for broadcasting."""
99
+ coords = _patch_coords_cached(h_patches, w_patches, str(device)) # [P, 2]
100
  inv_freq = 1.0 / (ROPE_THETA ** torch.arange(
101
+ 0, 1, 4 / HEAD_DIM, dtype=torch.float32, device=device)) # [D/4]
102
+ angles = 2 * math.pi * coords[:, :, None] * inv_freq[None, None, :] # [P, 2, D/4]
103
+ angles = angles.flatten(1, 2).tile(2) # [P, D]
104
+ cos = torch.cos(angles).to(dtype).unsqueeze(0).unsqueeze(0) # [1,1,P,D]
105
  sin = torch.sin(angles).to(dtype).unsqueeze(0).unsqueeze(0)
106
  return cos, sin
107
 
 
113
 
114
  def _apply_rope(q: torch.Tensor, k: torch.Tensor,
115
  cos: torch.Tensor, sin: torch.Tensor):
116
+ """Apply RoPE only to patch tokens (skip CLS + register prefix)."""
117
  n_pre = 1 + N_REGISTERS
118
  q_pre, q_pat = q[..., :n_pre, :], q[..., n_pre:, :]
119
  k_pre, k_pat = k[..., :n_pre, :], k[..., n_pre:, :]
 
123
 
124
 
125
  # ---------------------------------------------------------------------------
126
+ # Building blocks
127
  # ---------------------------------------------------------------------------
128
 
129
  class _Attention(nn.Module):
 
134
  self.v_proj = nn.Linear(D_MODEL, D_MODEL, bias=True)
135
  self.o_proj = nn.Linear(D_MODEL, D_MODEL, bias=True)
136
 
137
+ def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
138
  B, S, _ = x.shape
139
  q = self.q_proj(x).view(B, S, N_HEADS, HEAD_DIM).transpose(1, 2)
140
  k = self.k_proj(x).view(B, S, N_HEADS, HEAD_DIM).transpose(1, 2)
 
148
  def __init__(self):
149
  super().__init__()
150
  self.gate_proj = nn.Linear(D_MODEL, D_FFN, bias=True)
151
+ self.up_proj = nn.Linear(D_MODEL, D_FFN, bias=True)
152
+ self.down_proj = nn.Linear(D_FFN, D_MODEL, bias=True)
153
 
154
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
155
  return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
156
 
157
 
158
  class _Block(nn.Module):
159
  def __init__(self):
160
  super().__init__()
161
+ self.norm1 = nn.LayerNorm(D_MODEL, eps=LN_EPS)
162
+ self.attention = _Attention()
163
  self.layer_scale1 = nn.Parameter(torch.full((D_MODEL,), LAYERSCALE))
164
+ self.norm2 = nn.LayerNorm(D_MODEL, eps=LN_EPS)
165
+ self.mlp = _GatedMLP()
166
  self.layer_scale2 = nn.Parameter(torch.full((D_MODEL,), LAYERSCALE))
167
 
168
+ def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
169
  x = x + self.attention(self.norm1(x), cos, sin) * self.layer_scale1
170
  x = x + self.mlp(self.norm2(x)) * self.layer_scale2
171
  return x
172
 
173
 
174
+ # ---------------------------------------------------------------------------
175
+ # Full backbone
176
+ # ---------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
  class DINOv3ViTH(nn.Module):
179
  """DINOv3 ViT-H/16+ backbone.
180
 
181
+ Accepts any H, W that are multiples of 16.
182
  Returns last_hidden_state [B, 1+R+P, D_MODEL].
183
+ Token layout: [CLS, reg_0..reg_3, patch_0..patch_N].
184
+
185
+ State-dict keys are intentionally identical to the HuggingFace
186
+ transformers layout so .safetensors checkpoints load without remapping.
187
  """
188
 
189
  def __init__(self):
190
  super().__init__()
191
+ # These names must match HF exactly
192
  self.embeddings = _Embeddings()
193
  self.layer = nn.ModuleList([_Block() for _ in range(N_LAYERS)])
194
+ self.norm = nn.LayerNorm(D_MODEL, eps=LN_EPS)
195
+
196
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
197
+ strict, missing_keys, unexpected_keys, error_msgs):
198
+ # HF stores layer_scale as a sub-module with a "lambda1" parameter;
199
+ # we store it as a plain Parameter directly on _Block.
200
+ # Remap "layer.i.layer_scale{1,2}.lambda1" → "layer.i.layer_scale{1,2}"
201
+ for k in list(state_dict.keys()):
202
+ if k.startswith(prefix) and ".layer_scale" in k and k.endswith(".lambda1"):
203
+ new_k = k[:-len(".lambda1")]
204
+ state_dict[new_k] = state_dict.pop(k)
205
+ # Drop rope_embeddings buffer (computed on-the-fly)
206
+ for k in list(state_dict.keys()):
207
+ if k.startswith(prefix) and "rope_embeddings" in k:
208
+ state_dict.pop(k)
209
+ super()._load_from_state_dict(
210
+ state_dict, prefix, local_metadata, strict,
211
+ missing_keys, unexpected_keys, error_msgs)
212
+
213
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
214
+ B, _, H, W = pixel_values.shape
215
+ x = self.embeddings(pixel_values) # [B, 1+R+P, D]
216
 
 
 
 
217
  h_p, w_p = H // PATCH_SIZE, W // PATCH_SIZE
218
  cos, sin = _build_rope(h_p, w_p, x.dtype, pixel_values.device)
 
 
 
219
 
 
 
 
 
 
 
 
220
  for block in self.layer:
221
  x = block(x, cos, sin)
 
 
 
 
 
222
 
223
+ return self.norm(x)
 
 
224
 
 
 
225
 
226
+ class _Embeddings(nn.Module):
227
+ """Patch + CLS + register token embeddings.
228
+ Key names match HF: embeddings.cls_token, embeddings.register_tokens,
229
+ embeddings.patch_embeddings.{weight,bias}.
230
  """
231
 
232
+ def __init__(self):
 
233
  super().__init__()
234
+ self.cls_token = nn.Parameter(torch.empty(1, 1, D_MODEL))
235
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, D_MODEL)) # unused at inference
236
+ self.register_tokens = nn.Parameter(torch.empty(1, N_REGISTERS, D_MODEL))
237
+ self.patch_embeddings = nn.Conv2d(3, D_MODEL, kernel_size=PATCH_SIZE, stride=PATCH_SIZE)
 
238
 
239
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
240
+ B = pixel_values.shape[0]
241
+ dtype = self.patch_embeddings.weight.dtype
242
+ patches = self.patch_embeddings(pixel_values.to(dtype)).flatten(2).transpose(1, 2)
243
+ cls = self.cls_token.expand(B, -1, -1)
244
+ regs = self.register_tokens.expand(B, -1, -1)
245
+ return torch.cat([cls, regs, patches], dim=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
 
247
 
248
  # =============================================================================
249
+ # Tagger head
250
  # =============================================================================
251
 
252
  class DINOv3Tagger(nn.Module):
253
+ """DINOv3 ViT-H/16+ backbone + linear projection head.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
 
255
+ features = concat(CLS, reg_0..reg_3) [B, (1+R)*D]
256
+ projection: Linear [B, num_tags]
 
 
 
 
 
 
 
 
 
 
257
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
 
259
+ def __init__(self, num_tags: int, projection_bias: bool = False):
260
+ super().__init__()
261
+ self.backbone = DINOv3ViTH()
262
+ self.projection = nn.Linear((1 + N_REGISTERS) * D_MODEL, num_tags, bias=projection_bias)
263
 
264
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
265
+ hidden = self.backbone(pixel_values) # [B, S, D]
266
+ cls = hidden[:, 0, :] # [B, D]
267
+ regs = hidden[:, 1: 1 + N_REGISTERS, :].flatten(1) # [B, R*D]
268
+ features = torch.cat([cls, regs], dim=-1) # [B, (1+R)*D]
269
+ return self.projection(features.float()) # fp32 for stability
270
 
271
 
272
  # =============================================================================
 
274
  # =============================================================================
275
 
276
  _IMAGENET_MEAN = [0.485, 0.456, 0.406]
277
+ _IMAGENET_STD = [0.229, 0.224, 0.225]
278
 
279
 
280
  def _snap(x: int, m: int) -> int:
 
291
 
292
 
293
  def preprocess_image(source, max_size: int = 1024) -> torch.Tensor:
294
+ """Load and preprocess an image → [1, 3, H, W] float32, ImageNet-normalised."""
 
 
 
 
295
  img = _open_image(source)
296
  w, h = img.size
297
+ scale = min(1.0, max_size / max(w, h))
298
+ new_w = _snap(round(w * scale), PATCH_SIZE)
299
+ new_h = _snap(round(h * scale), PATCH_SIZE)
 
 
 
 
 
 
300
  return v2.Compose([
301
  v2.Resize((new_h, new_w), interpolation=v2.InterpolationMode.LANCZOS),
302
  v2.ToImage(),
 
315
  Parameters
316
  ----------
317
  checkpoint_path : str
318
+ Path to a .safetensors or .pth checkpoint saved by TaggerTrainer.
319
  vocab_path : str
320
+ Path to tagger_vocab.json ({"idx2tag": [...]}).
 
321
  device : str
322
+ "cuda", "cuda:0", "cpu", etc.
323
  dtype : torch.dtype
324
+ bfloat16 recommended on Ampere+; float16 for older GPUs; float32 for CPU.
 
325
  max_size : int
326
  Long-edge cap in pixels before feeding to the model.
327
  """
 
334
  dtype: torch.dtype = torch.bfloat16,
335
  max_size: int = 1024,
336
  ):
337
+ self.device = torch.device(device if torch.cuda.is_available() or device == "cpu" else "cpu")
338
+ self.dtype = dtype
 
 
 
 
 
339
  self.max_size = max_size
340
 
341
  with open(vocab_path) as f:
 
344
  self.num_tags = len(self.idx2tag)
345
  print(f"[Tagger] Vocabulary: {self.num_tags:,} tags")
346
 
347
+ self.model = DINOv3Tagger(num_tags=self.num_tags)
348
+
349
  print(f"[Tagger] Loading checkpoint: {checkpoint_path}")
350
  if checkpoint_path.endswith((".safetensors", ".sft")):
351
+ sd = load_file(checkpoint_path, device=str(self.device))
352
  else:
353
+ sd = torch.load(checkpoint_path, map_location=str(self.device))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
354
 
355
+ missing, unexpected = self.model.load_state_dict(sd, strict=False, assign=True)
356
+ if missing:
357
+ print(f"[Tagger] Missing keys ({len(missing)}): {missing[:5]}{'...' if len(missing) > 5 else ''}")
358
+ if unexpected:
359
+ print(f"[Tagger] Unexpected keys ({len(unexpected)}): {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")
360
 
361
+ self.model.backbone = self.model.backbone.to(dtype=dtype)
362
+ self.model = self.model.to(self.device)
363
+ self.model.eval()
364
+ print(f"[Tagger] Ready on {self.device} ({dtype})")
 
 
 
 
 
 
 
 
 
 
365
 
366
  @torch.no_grad()
367
  def predict(self, image, topk: int | None = 30,
368
  threshold: float | None = None) -> list[tuple[str, float]]:
369
+ """Tag a single image (local path or URL).
370
+ Specify either topk OR threshold. Returns [(tag, score), ...] desc."""
371
  if topk is None and threshold is None:
372
  topk = 30
373
 
374
  pv = preprocess_image(image, max_size=self.max_size).to(self.device)
375
+ with torch.autocast(device_type=self.device.type, dtype=self.dtype):
376
+ logits = self.model(pv)[0]
377
  scores = torch.sigmoid(logits.float())
378
 
379
  if topk is not None:
 
381
  else:
382
  assert threshold is not None
383
  indices = (scores >= threshold).nonzero(as_tuple=True)[0]
384
+ values = scores[indices]
385
+ order = values.argsort(descending=True)
386
  indices, values = indices[order], values[order]
387
 
388
+ return [(self.idx2tag[i], float(v)) for i, v in zip(indices.tolist(), values.tolist())]
 
389
 
390
  @torch.no_grad()
391
  def predict_batch(self, images, topk: int | None = 30,
392
+ threshold: float | None = None) -> list[list[tuple[str, float]]]:
393
+ """Tag multiple images (processed individually for mixed resolutions)."""
394
+ return [self.predict(img, topk=topk, threshold=threshold) for img in images]
395
 
396
 
397
  # =============================================================================
 
399
  # =============================================================================
400
 
401
  def _fmt_pretty(path: str, results) -> str:
402
+ lines = [f"\n{'─' * 60}", f" {path}", f"{'─' * 60}"]
403
  for rank, (tag, score) in enumerate(results, 1):
404
  bar = "█" * int(score * 20)
405
+ lines.append(f" {rank:>3}. {score:.3f} {bar:<20} {tag}")
406
  return "\n".join(lines)
407
 
 
408
  def _fmt_tags(results) -> str:
409
  return ", ".join(tag for tag, _ in results)
410
 
 
411
  def _fmt_json(path: str, results) -> dict:
412
+ return {"file": path, "tags": [{"tag": t, "score": round(s, 4)} for t, s in results]}
 
413
 
414
 
415
  # =============================================================================
 
418
 
419
  def main():
420
  parser = argparse.ArgumentParser(
421
+ description="DINOv3 ViT-H/16+ tagger inference (standalone, no transformers dep)",
422
  formatter_class=argparse.RawDescriptionHelpFormatter,
423
  )
424
+ parser.add_argument("--checkpoint", required=True, help="Path to .safetensors or .pth checkpoint")
425
+ parser.add_argument("--vocab", required=True, help="Path to tagger_vocab.json")
426
+ parser.add_argument("--images", nargs="+", required=True, help="Image paths and/or http(s) URLs")
427
+ parser.add_argument("--device", default="cuda", help="Device: cuda, cuda:0, cpu, … (default: cuda)")
 
 
 
 
428
  parser.add_argument("--max-size", type=int, default=1024,
429
+ help="Long-edge cap in pixels, multiple of 16 (default: 1024)")
430
 
431
  mode = parser.add_mutually_exclusive_group()
432
+ mode.add_argument("--topk", type=int, default=30, help="Return top-k tags (default: 30)")
433
+ mode.add_argument("--threshold", type=float, help="Return all tags with score >= threshold")
 
 
434
 
435
  parser.add_argument("--format", choices=["pretty", "tags", "json"],
436
  default="pretty", help="Output format (default: pretty)")
437
  args = parser.parse_args()
438
 
439
+ tagger = Tagger(checkpoint_path=args.checkpoint, vocab_path=args.vocab,
440
+ device=args.device, max_size=args.max_size)
 
 
 
 
441
 
442
+ topk, threshold = (None, args.threshold) if args.threshold else (args.topk, None)
 
 
443
  json_out = []
444
 
445
  for src in args.images:
 
448
  print(f"[warning] File not found: {src}", file=sys.stderr)
449
  continue
450
  results = tagger.predict(src, topk=topk, threshold=threshold)
451
+ if args.format == "pretty": print(_fmt_pretty(src, results))
452
+ elif args.format == "tags": print(_fmt_tags(results))
453
+ elif args.format == "json": json_out.append(_fmt_json(src, results))
 
 
 
454
 
455
  if args.format == "json":
456
  print(json.dumps(json_out, indent=2, ensure_ascii=False))
457
 
458
 
459
  if __name__ == "__main__":
460
+ main()
requirements.txt DELETED
@@ -1,11 +0,0 @@
1
- packaging
2
- safetensors
3
- requests
4
- Pillow
5
- torch
6
- torchvision
7
- python-multipart
8
- fastapi>=0.121.0
9
- uvicorn
10
- jinja2
11
- aiofiles
 
 
 
 
 
 
 
 
 
 
 
 
tagger_proto.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:6ef471936e144eb75a21b37e664f7499e0139a6643170a59473cdde8d1d4c238
3
- size 5272842048
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:20fd8c1cad2fa5653c632d847397f0491faa662d33abde1e9239041bc95b8b6c
3
+ size 5272838400
tagger_ui/templates/index.html CHANGED
@@ -202,36 +202,6 @@
202
  .tag-pill:hover { opacity: .8; }
203
  .tag-pill .score { font-size: .66rem; opacity: .7; }
204
  .tag-pill.hidden { display: none; }
205
-
206
- /* ---- PCA panel ---- */
207
- .preview-wrap { flex-wrap: wrap; }
208
- .preview-col { flex: 1 1 0; min-width: 0; }
209
- .pca-col {
210
- flex: 1 1 0; min-width: 0;
211
- display: flex; flex-direction: column; gap: .5rem;
212
- }
213
- .pca-label {
214
- font-size: .72rem; color: var(--muted); text-align: center;
215
- letter-spacing: .04em; text-transform: uppercase;
216
- }
217
- #pca-img {
218
- border-radius: var(--radius); width: 100%; max-height: 420px;
219
- object-fit: contain; border: 1px solid var(--border);
220
- display: block; image-rendering: pixelated;
221
- }
222
- #pca-spinner {
223
- display: none; width: 18px; height: 18px; margin: auto;
224
- border: 3px solid var(--border); border-top-color: var(--accent);
225
- border-radius: 50%; animation: spin .7s linear infinite;
226
- }
227
- .pca-toggle {
228
- background: var(--bg); border: 1px solid var(--border);
229
- border-radius: 6px; color: var(--muted); cursor: pointer;
230
- font-size: .75rem; padding: .3rem .7rem; align-self: center;
231
- transition: border-color .15s, color .15s;
232
- }
233
- .pca-toggle:hover { border-color: var(--accent); color: var(--text); }
234
- .pca-toggle.active { border-color: var(--accent); color: #a78bfa; }
235
  </style>
236
  </head>
237
  <body>
@@ -269,22 +239,12 @@
269
 
270
  <div id="results-area">
271
 
272
- <!-- image + PCA side by side -->
273
  <div class="preview-wrap">
274
- <div class="preview-col">
275
  <img id="preview-img" src="" alt="preview" />
276
  <div class="img-meta" id="img-meta"></div>
277
  </div>
278
- <div class="pca-col" id="pca-col" style="display:none">
279
- <div class="pca-label">PCA · patch features (R=PC1, G=PC2, B=PC3)</div>
280
- <div id="pca-spinner"></div>
281
- <img id="pca-img" src="" alt="PCA" style="display:none" />
282
- </div>
283
- </div>
284
-
285
- <!-- PCA toggle -->
286
- <div style="display:flex;justify-content:flex-end;margin-bottom:.6rem">
287
- <button class="pca-toggle" id="pca-toggle" onclick="togglePca()">Show PCA</button>
288
  </div>
289
 
290
  <!-- global copy bar -->
@@ -338,48 +298,6 @@
338
  if (el) el.value = Math.max(1, Math.min(99, parseInt(pct) || 1));
339
  }
340
 
341
- // ---- PCA state ----
342
- let _pcaEnabled = false;
343
- let _lastPcaRequest = null; // { type: 'url'|'file', url?: string, file?: File }
344
-
345
- function togglePca() {
346
- _pcaEnabled = !_pcaEnabled;
347
- const btn = document.getElementById('pca-toggle');
348
- btn.textContent = _pcaEnabled ? 'Hide PCA' : 'Show PCA';
349
- btn.classList.toggle('active', _pcaEnabled);
350
- document.getElementById('pca-col').style.display = _pcaEnabled ? 'flex' : 'none';
351
- if (_pcaEnabled && _lastPcaRequest) runPca(_lastPcaRequest);
352
- }
353
-
354
- function runPca(req) {
355
- const spinner = document.getElementById('pca-spinner');
356
- const img = document.getElementById('pca-img');
357
- spinner.style.display = 'block';
358
- img.style.display = 'none';
359
-
360
- const maxSize = document.getElementById('maxsize-input').value;
361
- let fetchPromise;
362
- if (req.type === 'url') {
363
- fetchPromise = fetch(
364
- `/pca/url?max_size=${maxSize}&url=${encodeURIComponent(req.url)}`,
365
- { method: 'POST' }
366
- );
367
- } else {
368
- const fd = new FormData();
369
- fd.append('file', req.file);
370
- fetchPromise = fetch(`/pca/upload?max_size=${maxSize}`, { method: 'POST', body: fd });
371
- }
372
-
373
- fetchPromise
374
- .then(r => r.ok ? r.blob() : Promise.reject('PCA failed'))
375
- .then(blob => {
376
- img.src = URL.createObjectURL(blob);
377
- img.style.display = 'block';
378
- })
379
- .catch(() => { img.style.display = 'none'; })
380
- .finally(() => { spinner.style.display = 'none'; });
381
- }
382
-
383
  // ---- drag & drop ----
384
  const dz = document.getElementById('drop-zone');
385
  dz.addEventListener('dragover', e => { e.preventDefault(); dz.classList.add('drag-over'); });
@@ -396,8 +314,6 @@
396
  const url = document.getElementById('url-input').value.trim();
397
  if (!url) return;
398
  setPreview(url, url);
399
- _lastPcaRequest = { type: 'url', url };
400
- if (_pcaEnabled) runPca(_lastPcaRequest);
401
  submitFetch(`/tag/url?max_size=${document.getElementById('maxsize-input').value}&url=${encodeURIComponent(url)}`,
402
  { method: 'POST' });
403
  }
@@ -409,8 +325,6 @@
409
  const reader = new FileReader();
410
  reader.onload = e => setPreview(e.target.result, file.name);
411
  reader.readAsDataURL(file);
412
- _lastPcaRequest = { type: 'file', file };
413
- if (_pcaEnabled) runPca(_lastPcaRequest);
414
  submitFetch(`/tag/upload?max_size=${maxSize}`, { method: 'POST', body: fd });
415
  }
416
 
 
202
  .tag-pill:hover { opacity: .8; }
203
  .tag-pill .score { font-size: .66rem; opacity: .7; }
204
  .tag-pill.hidden { display: none; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  </style>
206
  </head>
207
  <body>
 
239
 
240
  <div id="results-area">
241
 
242
+ <!-- image full-width on top -->
243
  <div class="preview-wrap">
244
+ <div style="width:100%">
245
  <img id="preview-img" src="" alt="preview" />
246
  <div class="img-meta" id="img-meta"></div>
247
  </div>
 
 
 
 
 
 
 
 
 
 
248
  </div>
249
 
250
  <!-- global copy bar -->
 
298
  if (el) el.value = Math.max(1, Math.min(99, parseInt(pct) || 1));
299
  }
300
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
301
  // ---- drag & drop ----
302
  const dz = document.getElementById('drop-zone');
303
  dz.addEventListener('dragover', e => { e.preventDefault(); dz.classList.add('drag-over'); });
 
314
  const url = document.getElementById('url-input').value.trim();
315
  if (!url) return;
316
  setPreview(url, url);
 
 
317
  submitFetch(`/tag/url?max_size=${document.getElementById('maxsize-input').value}&url=${encodeURIComponent(url)}`,
318
  { method: 'POST' });
319
  }
 
325
  const reader = new FileReader();
326
  reader.onload = e => setPreview(e.target.result, file.name);
327
  reader.readAsDataURL(file);
 
 
328
  submitFetch(`/tag/upload?max_size=${maxSize}`, { method: 'POST', body: fd });
329
  }
330
 
tagger_ui_server.py CHANGED
@@ -48,7 +48,7 @@ CATEGORY_META: dict[int, dict] = {
48
  3: {"name": "contributor", "color": "#a78bfa"}, # raw 2
49
  4: {"name": "copyright", "color": "#fb923c"}, # raw 3
50
  5: {"name": "character", "color": "#60a5fa"}, # raw 4
51
- 6: {"name": "species", "color": "#facc15"}, # raw 5
52
  7: {"name": "disambiguation", "color": "#94a3b8"}, # raw 6
53
  8: {"name": "meta", "color": "#e2e8f0"}, # raw 7
54
  9: {"name": "lore", "color": "#f87171"}, # raw 8
@@ -113,46 +113,6 @@ async def tag_upload(
113
  return _run_tagger(img, max_size, floor)
114
 
115
 
116
- # ---------------------------------------------------------------------------
117
- # PCA endpoints
118
- # ---------------------------------------------------------------------------
119
-
120
- @app.post("/pca/url")
121
- async def pca_url(
122
- url: str = Query(...),
123
- max_size: int = Query(default=1024),
124
- ):
125
- from fastapi.responses import Response
126
- assert _tagger is not None
127
- try:
128
- from inference_tagger_standalone import _open_image
129
- img = _open_image(url)
130
- except Exception as e:
131
- raise HTTPException(status_code=400, detail=f"Could not fetch image: {e}")
132
- pca_img = _tagger.embed_pca(img, max_size=max_size)
133
- buf = io.BytesIO()
134
- pca_img.save(buf, format="PNG")
135
- return Response(content=buf.getvalue(), media_type="image/png")
136
-
137
-
138
- @app.post("/pca/upload")
139
- async def pca_upload(
140
- file: UploadFile = File(...),
141
- max_size: int = Query(default=1024),
142
- ):
143
- from fastapi.responses import Response
144
- assert _tagger is not None
145
- try:
146
- data = await file.read()
147
- img = Image.open(io.BytesIO(data)).convert("RGB")
148
- except Exception as e:
149
- raise HTTPException(status_code=400, detail=f"Could not read image: {e}")
150
- pca_img = _tagger.embed_pca(img, max_size=max_size)
151
- buf = io.BytesIO()
152
- pca_img.save(buf, format="PNG")
153
- return Response(content=buf.getvalue(), media_type="image/png")
154
-
155
-
156
  # ---------------------------------------------------------------------------
157
  # Inference helper
158
  # ---------------------------------------------------------------------------
 
48
  3: {"name": "contributor", "color": "#a78bfa"}, # raw 2
49
  4: {"name": "copyright", "color": "#fb923c"}, # raw 3
50
  5: {"name": "character", "color": "#60a5fa"}, # raw 4
51
+ 6: {"name": "species/meta", "color": "#facc15"}, # raw 5
52
  7: {"name": "disambiguation", "color": "#94a3b8"}, # raw 6
53
  8: {"name": "meta", "color": "#e2e8f0"}, # raw 7
54
  9: {"name": "lore", "color": "#f87171"}, # raw 8
 
113
  return _run_tagger(img, max_size, floor)
114
 
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  # ---------------------------------------------------------------------------
117
  # Inference helper
118
  # ---------------------------------------------------------------------------
tagger_vocab_with_categories.json CHANGED
The diff for this file is too large to render. See raw diff
 
tagger_vocab_with_categories_and_alias.json DELETED
The diff for this file is too large to render. See raw diff
 
tagger_vocab_with_categories_and_alias_updated.json DELETED
The diff for this file is too large to render. See raw diff