Update SigLino siglino-70M (full content push)
Files changed:
- README.md +23 -9
- config.json +5 -5
- configuration_siglino.py +91 -0
- image_processing_siglino.py +121 -0
- image_processor.py +2 -2
- modeling_siglino.py +339 -0
- preprocessor_config.json +2 -2
- utils.py +13 -13
README.md
CHANGED
@@ -6,17 +6,19 @@ tags:
   - image-feature-extraction
 ---
 
-# 
+# SigLino-70M
 
 **Accepted at CVPR 2026**
 
-[](https://sofianchay.github.io/
+[](https://sofianchay.github.io/siglino/)
 [](https://arxiv.org/abs/2512.20157)
-[](https://github.com/tiiuae/
+[](https://github.com/tiiuae/siglino)
 
+This work stems from the **CVPR 2026 AMoE paper**, which designs and applies distillation into a Mixture-of-Experts (MoE) vision architecture. We have chosen the name **SigLino** for better clarity (SigLIP2 + DINOv3).
 
+Dense variant of SigLino. 70M parameters.
+
+Part of the [SigLino model family](https://huggingface.co/collections/tiiuae/siglino-vision-foundation-models).
 
 ## Usage
 
@@ -25,7 +27,7 @@ import torch
 from PIL import Image
 from transformers import AutoModel, AutoImageProcessor
 
-model_id = "tiiuae/
+model_id = "tiiuae/siglino-70M"
 model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to("cuda", dtype=torch.bfloat16)
 processor = AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True)
 
@@ -36,8 +38,8 @@ inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
 with torch.no_grad():
     outputs = model(**inputs)
 
-# Options: '
-patch_features = outputs["patch_features"]["
+# Options: 'siglino' (512d), 'siglip2' (1152d), 'dinov3' (1024d)
+patch_features = outputs["patch_features"]["siglino"]  # (Batch, Tokens, 512)
 summary_features = outputs["summary_features"]["siglip2"]  # (Batch, 1152)
 ```
 
@@ -53,11 +55,23 @@ summary_features = outputs["summary_features"]["siglip2"]  # (Batch, 1152)
 | Patch Size | 16x16 |
 | Teachers | DINOv3, SigLIP2 |
 
+## Results (512x512, ensemble features)
+
+| Task | Metric | Score |
+|------|--------|-------|
+| kNN (ImageNet) | Acc | 81.7 |
+| kNN (6-dataset avg) | Acc | 86.2 |
+| Zero-shot cls (ImageNet) | Acc | 71.2 |
+| Flickr30K I2T | R@1 | 90.5 |
+| MSCOCO I2T | R@1 | 65.4 |
+| Pascal VOC (1024) | mIoU | 84.8 |
+| Cityscapes (1024) | mIoU | 61.6 |
+
 ## Citation
 
 ```bibtex
 @article{chaybouti2025amoe,
-  title={
+  title={AMoE: Agglomerative Mixture-of-Experts Vision Foundation Models},
   author={Chaybouti, Sofian and Narayan, Sanath and Dahou, Yasser and Le Khac, Phuc H. and Singh, Ankit and Huynh, Ngoc Dung and Para, Wamiq Reyaz and Kuehne, Hilde and Hacid, Hakim},
   journal={arXiv preprint arXiv:2512.20157},
   year={2025}
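The flat `patch_features` above can be folded back into a dense map. A minimal sketch, assuming the processor's `spatial_shapes` entry holds the patch-grid (height, width) per image and that padded patch positions sit at the end of the token axis (both follow from `image_processing_siglino.py` below); `to_feature_map` is a hypothetical helper, not part of the repo:

```python
import einops

def to_feature_map(patch_features, spatial_shapes, index=0):
    # Fold the flat patch tokens of one image back into a (C, H, W) grid.
    h, w = spatial_shapes[index].tolist()     # patch-grid height / width
    tokens = patch_features[index, : h * w]   # drop trailing padded positions
    return einops.rearrange(tokens, "(h w) c -> c h w", h=h, w=w)

feature_map = to_feature_map(patch_features, inputs["spatial_shapes"])  # (512, H, W)
```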
config.json
CHANGED
@@ -1,12 +1,12 @@
 {
   "activation": "silu",
   "architectures": [
-    "
+    "SigLinoModel"
   ],
   "auto_map": {
-    "AutoConfig": "
-    "AutoImageProcessor": "
-    "AutoModel": "
+    "AutoConfig": "configuration_siglino.SigLinoConfig",
+    "AutoImageProcessor": "image_processing_siglino.SigLinoImageProcessor",
+    "AutoModel": "modeling_siglino.SigLinoModel"
   },
   "channel_size": 3,
   "dim": 512,
@@ -16,7 +16,7 @@
   "first_n_layers_dense": 12,
   "head_dim": 64,
   "max_seq_len": 8192,
-  "model_type": "
+  "model_type": "siglino",
   "moe_args": {
     "activation": "silu",
     "num_experts": 1,
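The `auto_map` entries are what let the generic Auto classes pick up the custom code pushed in this commit. A short sketch of the resolution, assuming the repo id `tiiuae/siglino-70M` from the README:

```python
from transformers import AutoConfig, AutoModel

# trust_remote_code imports configuration_siglino.py / modeling_siglino.py
# from the checkpoint repo instead of using a built-in architecture.
config = AutoConfig.from_pretrained("tiiuae/siglino-70M", trust_remote_code=True)
print(type(config).__name__)  # SigLinoConfig
print(config.model_type)      # "siglino"

model = AutoModel.from_pretrained("tiiuae/siglino-70M", trust_remote_code=True)
print(type(model).__name__)   # SigLinoModel
```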
configuration_siglino.py
ADDED
@@ -0,0 +1,91 @@

from transformers import PretrainedConfig
from typing import Optional, List, Union, Dict, Tuple

class SigLinoConfig(PretrainedConfig):
    """
    Configuration class to store the configuration of a `SigLinoModel`.
    """
    model_type = "siglino"

    def __init__(
        self,
        dim: int = 768,
        n_layers: int = 18,
        n_heads: int = 12,
        head_dim: Optional[int] = 128,
        n_kv_heads: Optional[int] = 4,
        # MoE configuration
        moe_dim: int = 768,
        moe_args: Optional[Dict] = None,
        # Dense FFN configuration
        first_n_layers_dense: int = 0,
        ffn_dim: Optional[int] = None,
        activation: str = "silu",
        # Vision settings
        channel_size: int = 3,
        spatial_patch_size: int = 16,
        temporal_patch_size: int = 1,
        # RoPE settings
        enable_3d_rope: bool = True,
        rope_theta: float = 100000.0,
        rope_min_freqs: float = 1.0,
        rope_max_freqs: float = 20.0,
        max_seq_len: int = 8192,
        # Normalization
        norm_eps: float = 1e-5,
        use_qk_norm: bool = True,
        use_tok_norm: bool = True,
        parameterized_norm: bool = True,
        # Distillation settings
        n_storage_tokens: int = 4,
        teachers: Tuple[str, ...] = ("siglip2", "dinov3"),
        teachers_dim: Tuple[int, ...] = (1152, 1024),
        # FlexAttention
        use_flex_attn: bool = True,
        **kwargs,
    ):
        self.dim = dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.n_kv_heads = n_kv_heads

        self.moe_dim = moe_dim
        # Default MoEArgs matching configs.py
        self.moe_args = moe_args if moe_args is not None else {
            "num_experts": 16,
            "num_shared_experts": 1,
            "top_k": 3,
            "score_before_experts": False,
            "route_norm": True,
            "route_scale": 0.8633,
            "activation": "relu2",
            "score_func": "sigmoid",
        }

        self.first_n_layers_dense = first_n_layers_dense
        self.ffn_dim = ffn_dim
        self.activation = activation

        self.channel_size = channel_size
        self.spatial_patch_size = spatial_patch_size
        self.temporal_patch_size = temporal_patch_size

        self.enable_3d_rope = enable_3d_rope
        self.rope_theta = rope_theta
        self.rope_min_freqs = rope_min_freqs
        self.rope_max_freqs = rope_max_freqs
        self.max_seq_len = max_seq_len

        self.norm_eps = norm_eps
        self.use_qk_norm = use_qk_norm
        self.use_tok_norm = use_tok_norm
        self.parameterized_norm = parameterized_norm

        self.n_storage_tokens = n_storage_tokens
        self.teachers = teachers
        self.teachers_dim = teachers_dim

        self.use_flex_attn = use_flex_attn

        super().__init__(**kwargs)
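The defaults above correspond to a larger MoE variant; `config.json` in this repo overrides them for the dense 70M model. A minimal sketch of constructing the config by hand, using only values visible in the `config.json` diff (everything else stays at the class defaults), assuming the file is on the import path:

```python
from configuration_siglino import SigLinoConfig

config = SigLinoConfig(
    dim=512,
    head_dim=64,
    first_n_layers_dense=12,
    max_seq_len=8192,
    moe_args={"num_experts": 1, "activation": "silu"},  # as in config.json
)
print(config.model_type)  # "siglino"
print(config.teachers)    # ("siglip2", "dinov3")
```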
image_processing_siglino.py
ADDED
@@ -0,0 +1,121 @@

import numpy as np
import torch
from PIL import Image
from typing import List, Optional, Union, Dict
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.utils import logging

# Local imports of the shared preprocessing helpers
# (smart_resize, convert_image_to_patches and pad_along_first_dim live in image_processor.py)
from .image_processor import smart_resize, convert_image_to_patches, pad_along_first_dim

logger = logging.get_logger(__name__)

class SigLinoImageProcessor(BaseImageProcessor):
    model_input_names = ["pixel_values", "padding_mask", "spatial_shapes"]

    def __init__(
        self,
        patch_size: int = 16,
        min_pixels: int = 128 * 128,
        max_pixels: int = 256 * 256,
        image_mean: Optional[List[float]] = None,
        image_std: Optional[List[float]] = None,
        do_resize: bool = True,
        do_rescale: bool = True,
        do_normalize: bool = True,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.patch_size = patch_size
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.image_mean = image_mean if image_mean is not None else [0.5, 0.5, 0.5]
        self.image_std = image_std if image_std is not None else [0.5, 0.5, 0.5]
        self.do_resize = do_resize
        self.do_rescale = do_rescale
        self.do_normalize = do_normalize

    def preprocess_single(self, image: Image.Image) -> Dict:
        """Standard preprocessing for a single PIL image."""
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)

        image = image.convert("RGB")
        width, height = image.size  # PIL uses (W, H)

        # 1. Smart Resize
        if self.do_resize:
            resized_height, resized_width = smart_resize(
                height, width,
                factor=self.patch_size,
                min_pixels=self.min_pixels,
                max_pixels=self.max_pixels,
            )
            image = image.resize((resized_width, resized_height), Image.BICUBIC)
        else:
            resized_height, resized_width = height, width

        image_np = np.array(image).astype(np.float32)

        # 2. Rescale
        if self.do_rescale:
            image_np = image_np / 255.0

        # 3. Normalize
        if self.do_normalize:
            mean = np.array(self.image_mean, dtype=np.float32)
            std = np.array(self.image_std, dtype=np.float32)
            image_np = (image_np - mean) / std

        spatial_shape = (resized_height // self.patch_size, resized_width // self.patch_size)

        # Convert to tensor and patchify
        img_tensor = torch.from_numpy(image_np)
        patches = convert_image_to_patches(img_tensor, self.patch_size)

        return {
            "patches": patches,
            "spatial_shape": spatial_shape
        }

    def preprocess(
        self,
        images: Union[Image.Image, List[Image.Image]],
        max_num_patches: int = 256,
        return_tensors: Optional[str] = "pt",
        **kwargs
    ) -> BatchFeature:
        """Main entry point for the transformers image processor."""
        if not isinstance(images, (list, tuple)):
            images = [images]

        results = [self.preprocess_single(img) for img in images]

        batched_pixels = []
        batched_masks = []
        batched_shapes = []

        for res in results:
            patches = res["patches"]
            shape = res["spatial_shape"]

            # Padding logic
            patches_padded, mask = pad_along_first_dim(
                patches,
                max_num_patches,
                pad_value=0.0
            )

            batched_pixels.append(patches_padded)
            batched_masks.append(mask)
            batched_shapes.append(list(shape))

        data = {
            "pixel_values": torch.stack(batched_pixels),
            "padding_mask": torch.stack(batched_masks),
            "spatial_shapes": torch.tensor(batched_shapes)
        }

        return BatchFeature(data=data, tensor_type=return_tensors)
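A minimal sketch of the processor's output contract, assuming `convert_image_to_patches` flattens each 16x16 RGB patch into a 768-value row (the size `img_projector` in `modeling_siglino.py` expects) and loading the processor through the Hub as in the README:

```python
from PIL import Image
from transformers import AutoImageProcessor

processor = AutoImageProcessor.from_pretrained("tiiuae/siglino-70M", trust_remote_code=True)
image = Image.new("RGB", (640, 480))  # dummy image

batch = processor(image, max_num_patches=256)
print(batch["pixel_values"].shape)  # (1, 256, 768): flattened patches, zero-padded to max_num_patches
print(batch["padding_mask"].shape)  # (1, 256): 1 for real patches, 0 for padding
print(batch["spatial_shapes"])      # (1, 2): patch-grid (height, width) after smart_resize
```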
image_processor.py
CHANGED
@@ -70,8 +70,8 @@ def pad_along_first_dim(
     return array, mask
 
 
-class 
-    """Image processor for 
+class SigLinoImageProcessor:
+    """Image processor for SigLino model.
     """
 
     def __init__(
modeling_siglino.py
ADDED
@@ -0,0 +1,339 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
import einops as E
from typing import Optional, Dict, Union, Tuple
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput

# Relative imports from the other modules shipped with this checkpoint
from .configuration_siglino import SigLinoConfig
from .attention import Attention, create_attention_mask
from .moe import MoE, FeedForward
from .rope import (
    precompute_freqs_cis,
    precompute_golden_freqs_cis,
    apply_golden_freqs_cis_to_visual_pos,
)

class PytorchGELUTanh(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.gelu(x, approximate="tanh")

class Siglip2MLP(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.activation_fn = PytorchGELUTanh()
        self.fc1 = nn.Linear(hidden_size, intermediate_size)
        self.fc2 = nn.Linear(intermediate_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states

class Siglip2MultiheadAttentionPoolingHead(nn.Module):
    def __init__(self, hidden_size: int, num_attention_heads: int, output_dim: int):
        super().__init__()
        self.probe = nn.Parameter(torch.randn(1, 1, hidden_size))
        self.attention = nn.MultiheadAttention(hidden_size, num_attention_heads, batch_first=True)
        self.layernorm = nn.LayerNorm(hidden_size, eps=1e-5)
        self.mlp = Siglip2MLP(hidden_size, 4304)
        self.num_heads = num_attention_heads

    def forward(self, hidden_state: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor:
        batch_size = hidden_state.shape[0]
        probe = self.probe.repeat(batch_size, 1, 1)

        if attention_mask is not None:
            # Mask expansion logic kept from the original model.py
            # Note: This uses einops and specific expansion for MHA
            def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int | None = None):
                bsz, src_len = mask.size()
                tgt_len = tgt_len if tgt_len is not None else src_len
                expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
                inverted_mask = torch.tensor(1.0, dtype=dtype, device=mask.device) - expanded_mask
                return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)

            attention_mask = E.rearrange(attention_mask, "(b s) -> b s", b=batch_size)
            target_len, source_len = probe.shape[1], hidden_state.shape[1]
            attention_mask = _expand_mask(attention_mask, hidden_state.dtype, target_len)
            attention_mask = attention_mask.repeat(1, self.num_heads, target_len, 1)
            attention_mask = attention_mask.reshape(-1, target_len, source_len)

        hidden_state = self.attention(probe, hidden_state, hidden_state, attn_mask=attention_mask)[0]
        residual = hidden_state
        hidden_state = self.layernorm(hidden_state)
        hidden_state = residual + self.mlp(hidden_state)
        return hidden_state[:, 0]

class Adapter(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, bias: bool = True):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, out_dim)
        self.norm = nn.LayerNorm(out_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(out_dim, out_dim, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.norm(x)
        x = self.act(x)
        x = self.fc2(x)
        return x

class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, config: SigLinoConfig):
        super().__init__()
        self.dim = config.dim
        self.parameterized_norm = getattr(config, 'parameterized_norm', True)
        if self.parameterized_norm:
            self.attention_norm = nn.RMSNorm(config.dim, eps=config.norm_eps)
            self.ffn_norm = nn.RMSNorm(config.dim, eps=config.norm_eps)

        self.attention = Attention(
            dim=config.dim,
            n_heads=config.n_heads,
            n_kv_heads=config.n_kv_heads,
            head_dim=config.head_dim,
            use_qk_norm=config.use_qk_norm,
            enable_3d_rope=config.enable_3d_rope,
            use_flex_attn=config.use_flex_attn,
            use_sink_attn=True,
        )

        # Handle MoE initialization from config dict
        moe_args = config.moe_args
        if isinstance(moe_args, dict):
            from .moe import MoEArgs
            moe_args = MoEArgs(**moe_args)

        first_n_dense = getattr(config, 'first_n_layers_dense', 0)
        use_dense = layer_id < first_n_dense
        if use_dense:
            ffn_hidden = getattr(config, 'ffn_dim', None) or config.moe_dim
            activation = getattr(config, 'activation', 'silu')
            self.feed_forward = FeedForward(config.dim, ffn_hidden, activation=activation)
            self.moe_enabled = False
        elif moe_args and moe_args.num_experts > 0:
            self.moe = MoE(moe_args, dim=config.dim, hidden_dim=config.moe_dim)
            self.moe_enabled = True
        else:
            self.feed_forward = FeedForward(config.dim, config.moe_dim)
            self.moe_enabled = False

        self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5

    def forward(self, x, freqs_cis, freqs_cis_2d=None, pos_thw=None, attention_masks=None, compile=False):
        if self.parameterized_norm:
            x_norm = self.attention_norm(x)
        else:
            x_norm = F.rms_norm(x, (x.size(-1),))
        h = x + self.attention(
            x_norm,
            freqs_cis,
            freqs_cis_2d,
            pos_thw,
            attention_masks=attention_masks,
            compile=compile,
        )
        h_norm = self.ffn_norm(h) if self.parameterized_norm else F.rms_norm(h, (h.size(-1),))
        out = h + self.moe(h_norm) if self.moe_enabled else h + self.feed_forward(h_norm)
        return out

class SigLinoPreTrainedModel(PreTrainedModel):
    config_class = SigLinoConfig
    base_model_prefix = "siglino"
    main_input_name = "pixel_values"
    _no_split_modules = ["TransformerBlock"]

    def _init_weights(self, module):
        # Weight initialization is handled by the internal init_weights call in __init__
        pass

    def _apply(self, fn):
        # Prevent casting complex RoPE buffers (freqs_cis) to real dtypes on model.to(bf16/fp16)
        complex_buffers = {}
        for name, buf in list(self.named_buffers(recurse=False)):
            if buf is not None and buf.is_complex():
                complex_buffers[name] = buf
                del self._buffers[name]

        ret = super()._apply(fn)

        for name, buf in complex_buffers.items():
            dummy = torch.tensor([0.0], device=buf.device)
            res = fn(dummy)

            if not res.is_complex():
                new_buf = buf.to(device=res.device)
            else:
                new_buf = fn(buf)

            persistent = name not in self._non_persistent_buffers_set
            self.register_buffer(name, new_buf, persistent=persistent)

        return ret


class SigLinoModel(SigLinoPreTrainedModel):
    def __init__(self, config: SigLinoConfig):
        super().__init__(config)
        self.config = config
        self.n_layers = config.n_layers
        self.patch_size = config.spatial_patch_size
        self.n_storage_tokens = config.n_storage_tokens

        # Patch embedding
        self.n_pixels_per_patch = config.temporal_patch_size * config.spatial_patch_size ** 2
        self.img_projector = nn.Linear(
            self.n_pixels_per_patch * config.channel_size,
            config.dim,
            bias=False,
        )

        self.cls_token = nn.Parameter(torch.empty(1, 1, config.dim))
        if self.n_storage_tokens > 0:
            self.storage_tokens = nn.Parameter(torch.empty(1, self.n_storage_tokens, config.dim))

        # RoPE
        head_dim = config.head_dim or config.dim // config.n_heads
        d = head_dim // 2
        self.register_buffer("freqs_cis_golden", self._precompute_golden_freqs_cis(d, config))
        self.register_buffer("freqs_cis", self._precompute_freqs_cis(d, config), persistent=False)

        self.layers = nn.ModuleList([TransformerBlock(i, config) for i in range(config.n_layers)])
        self.norm = nn.RMSNorm(config.dim, eps=config.norm_eps)

        # Teacher adapters
        teachers_dict = dict(zip(config.teachers, config.teachers_dim))
        dinov3_dim = teachers_dict.get("dinov3", 1280)
        siglip2_dim = teachers_dict.get("siglip2", 1152)

        self.dinov3_adapter = Adapter(config.dim, dinov3_dim, bias=False)
        self.siglip2_adapter = Adapter(config.dim, siglip2_dim, bias=False)
        self.layer_norm_dinov3 = nn.LayerNorm(dinov3_dim)
        self.siglip2_multihead_attention_pooling_head = Siglip2MultiheadAttentionPoolingHead(
            siglip2_dim, 16, siglip2_dim
        )

        self.post_init()

    def _precompute_freqs_cis(self, head_dim: int, config: SigLinoConfig) -> torch.Tensor:
        return precompute_freqs_cis(head_dim, config.max_seq_len, config.rope_theta)

    def _precompute_golden_freqs_cis(self, head_dim: int, config: SigLinoConfig) -> torch.Tensor:
        return precompute_golden_freqs_cis(
            config.n_heads, head_dim, config.rope_min_freqs, config.rope_max_freqs
        )

    def _get_thw_pos(self, batch_size, num_patches, spatial_shapes, device):
        N = batch_size
        R = 1 + self.n_storage_tokens
        S = R + num_patches
        tpos = torch.zeros((N, S), dtype=torch.float32, device=device)
        hpos = torch.zeros((N, S), dtype=torch.float32, device=device)
        wpos = torch.zeros((N, S), dtype=torch.float32, device=device)

        for n in range(N):
            H, W = spatial_shapes[n].tolist()
            h_coords = torch.arange(H, device=device).float()
            w_coords = torch.arange(W, device=device).float()
            xlim, ylim = (W / H) ** 0.5, (H / W) ** 0.5
            h_norm = -ylim + 2 * ylim * h_coords / max(H - 1, 1)
            w_norm = -xlim + 2 * xlim * w_coords / max(W - 1, 1)

            # Vectorized fill for patches
            h_grid, w_grid = torch.meshgrid(h_norm, w_norm, indexing='ij')
            hpos[n, R:R+H*W] = h_grid.reshape(-1)
            wpos[n, R:R+H*W] = w_grid.reshape(-1)

            hpos[n, :R], wpos[n, :R] = float('nan'), float('nan')

        return torch.stack([tpos, hpos, wpos], dim=0)

    def forward(
        self,
        pixel_values: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        spatial_shapes: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        compile: bool = True,
    ) -> Union[Dict, Tuple]:
        N, L, _ = pixel_values.shape
        device = pixel_values.device
        R = 1 + self.n_storage_tokens

        if padding_mask is None:
            padding_mask = torch.ones((N, L), dtype=pixel_values.dtype, device=device)

        h_NLD = self.img_projector(pixel_values)
        cls_expanded = self.cls_token.expand(N, -1, -1)
        if self.n_storage_tokens > 0:
            reg_expanded = self.storage_tokens.expand(N, -1, -1)
            h_NSD = torch.cat([cls_expanded, reg_expanded, h_NLD], dim=1)
        else:
            h_NSD = torch.cat([cls_expanded, h_NLD], dim=1)

        S = h_NSD.shape[1]
        cls_reg_mask = torch.ones((N, R), dtype=padding_mask.dtype, device=device)
        full_mask = torch.cat([cls_reg_mask, padding_mask], dim=1)

        # FlexAttention Mask
        def mask_mod(b, h, q_idx, kv_idx):
            return full_mask.bool()[b, q_idx] & full_mask.bool()[b, kv_idx]

        block_mask = create_attention_mask(mask_mod, N, None, S, S)

        # RoPE
        thw_pos = self._get_thw_pos(N, L, spatial_shapes, device)
        pos_thw = E.rearrange(thw_pos, "p n s -> n s p").to(dtype=torch.float32)
        patch_mask_2d = torch.zeros((N, S), dtype=torch.bool, device=device)
        patch_mask_2d[:, R:] = padding_mask.bool()
        pos_thw[:, :, 1:] = pos_thw[:, :, 1:].masked_fill(~patch_mask_2d.unsqueeze(-1), float("nan"))

        freqs_cis_golden = apply_golden_freqs_cis_to_visual_pos(
            self.freqs_cis_golden.to(dtype=pos_thw.dtype), pos_thw[:, :, 1:]
        )

        all_hidden_states = () if output_hidden_states else None
        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (h_NSD,)
            h_NSD = layer(h_NSD, self.freqs_cis, freqs_cis_2d=freqs_cis_golden,
                          pos_thw=pos_thw, attention_masks=block_mask, compile=compile)

        h_NSD = self.norm(h_NSD)

        # Feature Extraction & Adapters
        cls_feats = h_NSD[:, 0]
        patch_feats = h_NSD[:, R:]

        student_patch_dinov3 = self.dinov3_adapter(patch_feats)
        student_patch_siglip = self.siglip2_adapter(patch_feats)
        student_cls_dinov3 = self.dinov3_adapter(cls_feats)

        h_sig = self.siglip2_adapter(h_NSD)
        siglip_attn_mask = full_mask.reshape(-1)
        student_summary_siglip = self.siglip2_multihead_attention_pooling_head(h_sig, siglip_attn_mask)

        output = {
            "last_hidden_state": h_NSD,
            "patch_features": {
                "dinov3": student_patch_dinov3,
                "siglip2": student_patch_siglip,
                "siglino": patch_feats,
            },
            "summary_features": {
                "dinov3": student_cls_dinov3,
                "siglip2": student_summary_siglip,
                "siglino": cls_feats,
            },
            "hidden_states": all_hidden_states,
        }

        if not return_dict:
            return tuple(v for v in output.values() if v is not None)
        return output
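A minimal sketch of the forward contract, assuming `model` and `batch` were obtained as in the README and image-processor snippets above; it only inspects the dictionary that `forward` returns:

```python
import torch

# `model` / `batch` as produced in the earlier snippets (README usage, SigLinoImageProcessor).
batch = {k: v.to("cuda") for k, v in batch.items()}
batch["pixel_values"] = batch["pixel_values"].to(torch.bfloat16)

with torch.no_grad():
    out = model(**batch)

# One cls token plus n_storage_tokens register tokens sit in front of the patch tokens.
print(out["last_hidden_state"].shape)            # (B, 1 + n_storage_tokens + L, dim)
print(sorted(out["patch_features"]))             # ['dinov3', 'siglino', 'siglip2']
print(out["summary_features"]["siglip2"].shape)  # (B, 1152), pooled by the SigLIP2 head
```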
preprocessor_config.json
CHANGED
@@ -1,6 +1,6 @@
 {
   "auto_map": {
-    "AutoImageProcessor": "
+    "AutoImageProcessor": "image_processing_siglino.SigLinoImageProcessor"
   },
   "do_normalize": true,
   "do_rescale": true,
@@ -10,7 +10,7 @@
     0.5,
     0.5
   ],
-  "image_processor_type": "
+  "image_processor_type": "SigLinoImageProcessor",
   "image_std": [
     0.5,
     0.5,
utils.py
CHANGED
@@ -9,21 +9,21 @@ from PIL import Image
 from typing import Union, List
 import os
 
-from .model import 
-from .configs import 
-from .image_processor import 
+from .model import SigLino
+from .configs import SigLinoArgs, siglino_configs
+from .image_processor import SigLinoImageProcessor
 
 
 
-def 
+def load_siglino_model(
     checkpoint_path: str,
-    config_name: str = "
+    config_name: str = "siglino-0.3B",
     device: Union[str, torch.device] = "cuda",
     dtype: torch.dtype | None = None,
     **kwargs,
-) -> tuple[
+) -> tuple[SigLino, SigLinoImageProcessor]:
     """
-    Load a 
+    Load a SigLino model from a checkpoint.
 
     Args:
         checkpoint_path: Path to the model checkpoint
@@ -35,13 +35,13 @@ def load_amoe_model(
         Tuple of (model, image_processor)
     """
     # Get configuration
-    if config_name in 
-        args = 
+    if config_name in siglino_configs:
+        args = siglino_configs[config_name]
     else:
-        raise ValueError(f"Unknown config: {config_name}. Available: {list(
+        raise ValueError(f"Unknown config: {config_name}. Available: {list(siglino_configs.keys())}")
 
     # Create model
-    model = 
+    model = SigLino(args)
 
     # Standard PyTorch checkpoint
     state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
@@ -55,7 +55,7 @@ def load_amoe_model(
     model.eval()
 
     # Create image processor
-    image_processor = 
+    image_processor = SigLinoImageProcessor(patch_size=args.spatial_patch_size, **kwargs)
 
     return model, image_processor
 
@@ -178,7 +178,7 @@ def load_amoe_model(
 FEATURE_DIM_DICT = {
     "dinov3": 1024,
     "siglip2": 1152,
-    "
+    "siglino": 768,  # Model dimension
 }
 
 PATCH_SIZE = 16
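utils.py keeps a raw-checkpoint loading path next to the Hub integration. A minimal sketch of using it, with a placeholder checkpoint path and the default config name; valid names come from `siglino_configs` in `configs.py`, which is not shown in this commit:

```python
import torch
from utils import load_siglino_model  # within the repo package this resolves via relative imports

# Placeholder checkpoint path for illustration only.
model, processor = load_siglino_model(
    checkpoint_path="checkpoints/siglino.pt",
    config_name="siglino-0.3B",
    device="cuda",
    dtype=torch.bfloat16,
)
```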