SofianChay committed on
Commit 7e6ec35 · verified · 1 Parent(s): 0568bad

Update SigLino siglino-70M (full content push)

README.md CHANGED
@@ -6,17 +6,19 @@ tags:
   - image-feature-extraction
 ---
 
-# AMoE-Dense-S
+# SigLino-70M
 
 **Accepted at CVPR 2026**
 
-[![Project Website](https://img.shields.io/badge/Project-Website-blue)](https://sofianchay.github.io/amoe/)
+[![Project Website](https://img.shields.io/badge/Project-Website-blue)](https://sofianchay.github.io/siglino/)
 [![arXiv](https://img.shields.io/badge/arXiv-2512.20157-b31b1b.svg)](https://arxiv.org/abs/2512.20157)
-[![GitHub](https://img.shields.io/badge/GitHub-Code-black)](https://github.com/tiiuae/amoe)
+[![GitHub](https://img.shields.io/badge/GitHub-Code-black)](https://github.com/tiiuae/siglino)
 
-Small dense variant of AMoE. 0.07B parameters.
+This model stems from the **CVPR 2026 AMoE paper**, which distills multiple teachers into a Mixture-of-Experts (MoE) vision architecture. We chose the name **SigLino** (SigLIP2 + DINOv3) for clarity.
 
-Part of the [AMoE model family](https://huggingface.co/collections/tiiuae/amoe-agglomerative-moe-vision-foundation-models).
+Dense variant of SigLino, with 70M parameters.
+
+Part of the [SigLino model family](https://huggingface.co/collections/tiiuae/siglino-vision-foundation-models).
 
 ## Usage
 
@@ -25,7 +27,7 @@ import torch
 from PIL import Image
 from transformers import AutoModel, AutoImageProcessor
 
-model_id = "tiiuae/amoe-dense-S"
+model_id = "tiiuae/siglino-70M"
 model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to("cuda", dtype=torch.bfloat16)
 processor = AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True)
 
@@ -36,8 +38,8 @@ inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
 with torch.no_grad():
     outputs = model(**inputs)
 
-# Options: 'amoe' (512d), 'siglip2' (1152d), 'dinov3' (1024d)
-patch_features = outputs["patch_features"]["amoe"]  # (Batch, Tokens, 512)
+# Options: 'siglino' (512d), 'siglip2' (1152d), 'dinov3' (1024d)
+patch_features = outputs["patch_features"]["siglino"]  # (Batch, Tokens, 512)
 summary_features = outputs["summary_features"]["siglip2"]  # (Batch, 1152)
 ```
 
@@ -53,11 +55,23 @@ summary_features = outputs["summary_features"]["siglip2"]  # (Batch, 1152)
 | Patch Size | 16x16 |
 | Teachers | DINOv3, SigLIP2 |
 
+## Results (512x512, ensemble features)
+
+| Task | Metric | Score |
+|------|--------|-------|
+| kNN (ImageNet) | Acc | 81.7 |
+| kNN (6-dataset avg) | Acc | 86.2 |
+| Zero-shot cls (ImageNet) | Acc | 71.2 |
+| Flickr30K I2T | R@1 | 90.5 |
+| MSCOCO I2T | R@1 | 65.4 |
+| Pascal VOC (1024) | mIoU | 84.8 |
+| Cityscapes (1024) | mIoU | 61.6 |
+
 ## Citation
 
 ```bibtex
 @article{chaybouti2025amoe,
-  title={AMOE: Agglomerative Mixture-of-Experts Vision Foundation Models},
+  title={AMoE: Agglomerative Mixture-of-Experts Vision Foundation Models},
   author={Chaybouti, Sofian and Narayan, Sanath and Dahou, Yasser and Le Khac, Phuc H. and Singh, Ankit and Huynh, Ngoc Dung and Para, Wamiq Reyaz and Kuehne, Hilde and Hacid, Hakim},
   journal={arXiv preprint arXiv:2512.20157},
   year={2025}
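
For dense tasks, the flat token sequence from the usage snippet can be folded back into a spatial grid. A minimal sketch, assuming the processor outputs defined later in this diff (`padding_mask` marking real patches, `spatial_shapes` holding per-image (height, width) in patch units):

```python
# Hypothetical follow-up to the README usage snippet above.
H, W = inputs["spatial_shapes"][0].tolist()                    # patch-grid height/width
valid = inputs["padding_mask"][0].bool().to(patch_features.device)  # True for real patches
feats = patch_features[0][valid]                               # (H*W, 512), padding dropped
feature_map = feats.reshape(H, W, -1).permute(2, 0, 1)         # (512, H, W) dense feature map
```
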
config.json CHANGED
@@ -1,12 +1,12 @@
 {
   "activation": "silu",
   "architectures": [
-    "AMOEModel"
+    "SigLinoModel"
   ],
   "auto_map": {
-    "AutoConfig": "configuration_amoe.AMOEConfig",
-    "AutoImageProcessor": "image_processing_amoe.AMOEImageProcessor",
-    "AutoModel": "modeling_amoe.AMOEModel"
+    "AutoConfig": "configuration_siglino.SigLinoConfig",
+    "AutoImageProcessor": "image_processing_siglino.SigLinoImageProcessor",
+    "AutoModel": "modeling_siglino.SigLinoModel"
   },
   "channel_size": 3,
   "dim": 512,
@@ -16,7 +16,7 @@
   "first_n_layers_dense": 12,
   "head_dim": 64,
   "max_seq_len": 8192,
-  "model_type": "amoe",
+  "model_type": "siglino",
   "moe_args": {
     "activation": "silu",
     "num_experts": 1,
configuration_siglino.py ADDED
@@ -0,0 +1,91 @@
+from transformers import PretrainedConfig
+from typing import Optional, List, Union, Dict, Tuple
+
+class SigLinoConfig(PretrainedConfig):
+    """
+    Configuration class to store the configuration of a `SigLinoModel`.
+    """
+    model_type = "siglino"
+
+    def __init__(
+        self,
+        dim: int = 768,
+        n_layers: int = 18,
+        n_heads: int = 12,
+        head_dim: Optional[int] = 128,
+        n_kv_heads: Optional[int] = 4,
+        # MoE configuration
+        moe_dim: int = 768,
+        moe_args: Optional[Dict] = None,
+        # Dense FFN configuration
+        first_n_layers_dense: int = 0,
+        ffn_dim: Optional[int] = None,
+        activation: str = "silu",
+        # Vision settings
+        channel_size: int = 3,
+        spatial_patch_size: int = 16,
+        temporal_patch_size: int = 1,
+        # RoPE settings
+        enable_3d_rope: bool = True,
+        rope_theta: float = 100000.0,
+        rope_min_freqs: float = 1.0,
+        rope_max_freqs: float = 20.0,
+        max_seq_len: int = 8192,
+        # Normalization
+        norm_eps: float = 1e-5,
+        use_qk_norm: bool = True,
+        use_tok_norm: bool = True,
+        parameterized_norm: bool = True,
+        # Distillation settings
+        n_storage_tokens: int = 4,
+        teachers: Tuple[str, ...] = ("siglip2", "dinov3"),
+        teachers_dim: Tuple[int, ...] = (1152, 1024),
+        # FlexAttention
+        use_flex_attn: bool = True,
+        **kwargs,
+    ):
+        self.dim = dim
+        self.n_layers = n_layers
+        self.n_heads = n_heads
+        self.head_dim = head_dim
+        self.n_kv_heads = n_kv_heads
+
+        self.moe_dim = moe_dim
+        # Default MoEArgs matching configs.py
+        self.moe_args = moe_args if moe_args is not None else {
+            "num_experts": 16,
+            "num_shared_experts": 1,
+            "top_k": 3,
+            "score_before_experts": False,
+            "route_norm": True,
+            "route_scale": 0.8633,
+            "activation": "relu2",
+            "score_func": "sigmoid",
+        }
+
+        self.first_n_layers_dense = first_n_layers_dense
+        self.ffn_dim = ffn_dim
+        self.activation = activation
+
+        self.channel_size = channel_size
+        self.spatial_patch_size = spatial_patch_size
+        self.temporal_patch_size = temporal_patch_size
+
+        self.enable_3d_rope = enable_3d_rope
+        self.rope_theta = rope_theta
+        self.rope_min_freqs = rope_min_freqs
+        self.rope_max_freqs = rope_max_freqs
+        self.max_seq_len = max_seq_len
+
+        self.norm_eps = norm_eps
+        self.use_qk_norm = use_qk_norm
+        self.use_tok_norm = use_tok_norm
+        self.parameterized_norm = parameterized_norm
+
+        self.n_storage_tokens = n_storage_tokens
+        self.teachers = teachers
+        self.teachers_dim = teachers_dim
+
+        self.use_flex_attn = use_flex_attn
+
+        super().__init__(**kwargs)
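
Because `config.json` maps `AutoConfig` to `configuration_siglino.SigLinoConfig`, the configuration can be inspected without building the model. A minimal sketch (printed values come from the checkpoint, not necessarily the class defaults above):

```python
from transformers import AutoConfig

# trust_remote_code pulls configuration_siglino.py from the repo
config = AutoConfig.from_pretrained("tiiuae/siglino-70M", trust_remote_code=True)
print(config.model_type)                                # "siglino"
print(dict(zip(config.teachers, config.teachers_dim)))  # teacher -> feature dim
```
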
image_processing_siglino.py ADDED
@@ -0,0 +1,121 @@
+import numpy as np
+import torch
+from PIL import Image
+from typing import List, Optional, Union, Dict
+from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
+from transformers.utils import logging
+
+# Local imports: smart_resize, convert_image_to_patches and pad_along_first_dim
+# are defined in image_processor.py in this repo
+from .image_processor import smart_resize, convert_image_to_patches, pad_along_first_dim
+
+logger = logging.get_logger(__name__)
+
+class SigLinoImageProcessor(BaseImageProcessor):
+    model_input_names = ["pixel_values", "padding_mask", "spatial_shapes"]
+
+    def __init__(
+        self,
+        patch_size: int = 16,
+        min_pixels: int = 128 * 128,
+        max_pixels: int = 256 * 256,
+        image_mean: Optional[List[float]] = None,
+        image_std: Optional[List[float]] = None,
+        do_resize: bool = True,
+        do_rescale: bool = True,
+        do_normalize: bool = True,
+        **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.patch_size = patch_size
+        self.min_pixels = min_pixels
+        self.max_pixels = max_pixels
+        self.image_mean = image_mean if image_mean is not None else [0.5, 0.5, 0.5]
+        self.image_std = image_std if image_std is not None else [0.5, 0.5, 0.5]
+        self.do_resize = do_resize
+        self.do_rescale = do_rescale
+        self.do_normalize = do_normalize
+
+    def preprocess_single(self, image: Image.Image) -> Dict:
+        """Standard preprocessing for a single PIL image."""
+        if not isinstance(image, Image.Image):
+            image = Image.fromarray(image)
+
+        image = image.convert("RGB")
+        width, height = image.size  # PIL uses (W, H)
+
+        # 1. Smart resize to a patch-aligned resolution
+        if self.do_resize:
+            resized_height, resized_width = smart_resize(
+                height, width,
+                factor=self.patch_size,
+                min_pixels=self.min_pixels,
+                max_pixels=self.max_pixels,
+            )
+            image = image.resize((resized_width, resized_height), Image.BICUBIC)
+        else:
+            resized_height, resized_width = height, width
+
+        image_np = np.array(image).astype(np.float32)
+
+        # 2. Rescale
+        if self.do_rescale:
+            image_np = image_np / 255.0
+
+        # 3. Normalize
+        if self.do_normalize:
+            mean = np.array(self.image_mean, dtype=np.float32)
+            std = np.array(self.image_std, dtype=np.float32)
+            image_np = (image_np - mean) / std
+
+        spatial_shape = (resized_height // self.patch_size, resized_width // self.patch_size)
+
+        # Convert to tensor and patchify
+        img_tensor = torch.from_numpy(image_np)
+        patches = convert_image_to_patches(img_tensor, self.patch_size)
+
+        return {
+            "patches": patches,
+            "spatial_shape": spatial_shape
+        }
+
+    def preprocess(
+        self,
+        images: Union[Image.Image, List[Image.Image]],
+        max_num_patches: int = 256,
+        return_tensors: Optional[str] = "pt",
+        **kwargs
+    ) -> BatchFeature:
+        """Main entry point for the transformers image processor."""
+        if not isinstance(images, (list, tuple)):
+            images = [images]
+
+        results = [self.preprocess_single(img) for img in images]
+
+        batched_pixels = []
+        batched_masks = []
+        batched_shapes = []
+
+        for res in results:
+            patches = res["patches"]
+            shape = res["spatial_shape"]
+
+            # Pad every image to the same number of patches
+            patches_padded, mask = pad_along_first_dim(
+                patches,
+                max_num_patches,
+                pad_value=0.0
+            )
+
+            batched_pixels.append(patches_padded)
+            batched_masks.append(mask)
+            batched_shapes.append(list(shape))
+
+        data = {
+            "pixel_values": torch.stack(batched_pixels),
+            "padding_mask": torch.stack(batched_masks),
+            "spatial_shapes": torch.tensor(batched_shapes)
+        }
+
+        return BatchFeature(data=data, tensor_type=return_tensors)
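
Because `smart_resize` snaps each image to a patch-multiple resolution inside the `min_pixels`/`max_pixels` budget, and `preprocess` pads every image to `max_num_patches`, mixed-size inputs batch cleanly. A quick sketch with synthetic images (shapes are illustrative; the flattened patch dim is 16*16*3 assuming `convert_image_to_patches` flattens each patch):

```python
import numpy as np
from PIL import Image

processor = SigLinoImageProcessor(patch_size=16)
images = [
    Image.fromarray(np.zeros((160, 224, 3), dtype=np.uint8)),  # landscape
    Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)),  # large square
]
batch = processor.preprocess(images, max_num_patches=256)
print(batch["pixel_values"].shape)    # (2, 256, 768) under the assumption above
print(batch["spatial_shapes"])        # per-image (H, W) in patch units
print(batch["padding_mask"].sum(-1))  # real patches per image, e.g. 140 and 256
```
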
image_processor.py CHANGED
@@ -70,8 +70,8 @@ def pad_along_first_dim(
     return array, mask
 
 
-class AMOEImageProcessor:
-    """Image processor for AMOE model.
+class SigLinoImageProcessor:
+    """Image processor for SigLino model.
     """
 
     def __init__(
modeling_siglino.py ADDED
@@ -0,0 +1,339 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import einops as E
+from typing import Optional, Dict, Union, Tuple
+from transformers import PreTrainedModel
+from transformers.modeling_outputs import BaseModelOutput
+
+# Relative imports from the local files shipped with this repo
+from .configuration_siglino import SigLinoConfig
+from .attention import Attention, create_attention_mask
+from .moe import MoE, FeedForward
+from .rope import (
+    precompute_freqs_cis,
+    precompute_golden_freqs_cis,
+    apply_golden_freqs_cis_to_visual_pos,
+)
+
+class PytorchGELUTanh(nn.Module):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return F.gelu(x, approximate="tanh")
+
+class Siglip2MLP(nn.Module):
+    def __init__(self, hidden_size: int, intermediate_size: int):
+        super().__init__()
+        self.activation_fn = PytorchGELUTanh()
+        self.fc1 = nn.Linear(hidden_size, intermediate_size)
+        self.fc2 = nn.Linear(intermediate_size, hidden_size)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.fc1(hidden_states)
+        hidden_states = self.activation_fn(hidden_states)
+        hidden_states = self.fc2(hidden_states)
+        return hidden_states
+
+class Siglip2MultiheadAttentionPoolingHead(nn.Module):
+    def __init__(self, hidden_size: int, num_attention_heads: int, output_dim: int):
+        super().__init__()
+        self.probe = nn.Parameter(torch.randn(1, 1, hidden_size))
+        self.attention = nn.MultiheadAttention(hidden_size, num_attention_heads, batch_first=True)
+        self.layernorm = nn.LayerNorm(hidden_size, eps=1e-5)
+        self.mlp = Siglip2MLP(hidden_size, 4304)
+        self.num_heads = num_attention_heads
+
+    def forward(self, hidden_state: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor:
+        batch_size = hidden_state.shape[0]
+        probe = self.probe.repeat(batch_size, 1, 1)
+
+        if attention_mask is not None:
+            # Mask expansion logic kept from the original model.py: expands a
+            # (batch, seq) mask into the additive form nn.MultiheadAttention expects
+            def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int | None = None):
+                bsz, src_len = mask.size()
+                tgt_len = tgt_len if tgt_len is not None else src_len
+                expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+                inverted_mask = torch.tensor(1.0, dtype=dtype, device=mask.device) - expanded_mask
+                return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+
+            attention_mask = E.rearrange(attention_mask, "(b s) -> b s", b=batch_size)
+            target_len, source_len = probe.shape[1], hidden_state.shape[1]
+            attention_mask = _expand_mask(attention_mask, hidden_state.dtype, target_len)
+            attention_mask = attention_mask.repeat(1, self.num_heads, target_len, 1)
+            attention_mask = attention_mask.reshape(-1, target_len, source_len)
+
+        hidden_state = self.attention(probe, hidden_state, hidden_state, attn_mask=attention_mask)[0]
+        residual = hidden_state
+        hidden_state = self.layernorm(hidden_state)
+        hidden_state = residual + self.mlp(hidden_state)
+        return hidden_state[:, 0]
+
+class Adapter(nn.Module):
+    def __init__(self, in_dim: int, out_dim: int, bias: bool = True):
+        super().__init__()
+        self.fc1 = nn.Linear(in_dim, out_dim)
+        self.norm = nn.LayerNorm(out_dim)
+        self.act = nn.GELU()
+        self.fc2 = nn.Linear(out_dim, out_dim, bias=bias)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.fc1(x)
+        x = self.norm(x)
+        x = self.act(x)
+        x = self.fc2(x)
+        return x
+
+class TransformerBlock(nn.Module):
+    def __init__(self, layer_id: int, config: SigLinoConfig):
+        super().__init__()
+        self.dim = config.dim
+        self.parameterized_norm = getattr(config, 'parameterized_norm', True)
+        if self.parameterized_norm:
+            self.attention_norm = nn.RMSNorm(config.dim, eps=config.norm_eps)
+            self.ffn_norm = nn.RMSNorm(config.dim, eps=config.norm_eps)
+
+        self.attention = Attention(
+            dim=config.dim,
+            n_heads=config.n_heads,
+            n_kv_heads=config.n_kv_heads,
+            head_dim=config.head_dim,
+            use_qk_norm=config.use_qk_norm,
+            enable_3d_rope=config.enable_3d_rope,
+            use_flex_attn=config.use_flex_attn,
+            use_sink_attn=True,
+        )
+
+        # Handle MoE initialization from config dict
+        moe_args = config.moe_args
+        if isinstance(moe_args, dict):
+            from .moe import MoEArgs
+            moe_args = MoEArgs(**moe_args)
+
+        # First `first_n_layers_dense` layers use a dense FFN instead of MoE
+        first_n_dense = getattr(config, 'first_n_layers_dense', 0)
+        use_dense = layer_id < first_n_dense
+        if use_dense:
+            ffn_hidden = getattr(config, 'ffn_dim', None) or config.moe_dim
+            activation = getattr(config, 'activation', 'silu')
+            self.feed_forward = FeedForward(config.dim, ffn_hidden, activation=activation)
+            self.moe_enabled = False
+        elif moe_args and moe_args.num_experts > 0:
+            self.moe = MoE(moe_args, dim=config.dim, hidden_dim=config.moe_dim)
+            self.moe_enabled = True
+        else:
+            self.feed_forward = FeedForward(config.dim, config.moe_dim)
+            self.moe_enabled = False
+
+        self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5
+
+    def forward(self, x, freqs_cis, freqs_cis_2d=None, pos_thw=None, attention_masks=None, compile=False):
+        if self.parameterized_norm:
+            x_norm = self.attention_norm(x)
+        else:
+            x_norm = F.rms_norm(x, (x.size(-1),))
+        h = x + self.attention(
+            x_norm,
+            freqs_cis,
+            freqs_cis_2d,
+            pos_thw,
+            attention_masks=attention_masks,
+            compile=compile,
+        )
+        h_norm = self.ffn_norm(h) if self.parameterized_norm else F.rms_norm(h, (h.size(-1),))
+        out = (h + self.moe(h_norm)) if self.moe_enabled else (h + self.feed_forward(h_norm))
+        return out
+
+class SigLinoPreTrainedModel(PreTrainedModel):
+    config_class = SigLinoConfig
+    base_model_prefix = "siglino"
+    main_input_name = "pixel_values"
+    _no_split_modules = ["TransformerBlock"]
+
+    def _init_weights(self, module):
+        # Weight initialization is handled by the internal init_weights call in __init__
+        pass
+
+    def _apply(self, fn):
+        # Prevent casting complex RoPE buffers (freqs_cis) to real dtypes on model.to(bf16/fp16)
+        complex_buffers = {}
+        for name, buf in list(self.named_buffers(recurse=False)):
+            if buf is not None and buf.is_complex():
+                complex_buffers[name] = buf
+                del self._buffers[name]
+
+        ret = super()._apply(fn)
+
+        for name, buf in complex_buffers.items():
+            dummy = torch.tensor([0.0], device=buf.device)
+            res = fn(dummy)
+
+            if not res.is_complex():
+                # fn would cast to a real dtype: keep complex values, follow the device only
+                new_buf = buf.to(device=res.device)
+            else:
+                new_buf = fn(buf)
+
+            persistent = name not in self._non_persistent_buffers_set
+            self.register_buffer(name, new_buf, persistent=persistent)
+
+        return ret
+
+
+class SigLinoModel(SigLinoPreTrainedModel):
+    def __init__(self, config: SigLinoConfig):
+        super().__init__(config)
+        self.config = config
+        self.n_layers = config.n_layers
+        self.patch_size = config.spatial_patch_size
+        self.n_storage_tokens = config.n_storage_tokens
+
+        # Patch embedding
+        self.n_pixels_per_patch = config.temporal_patch_size * config.spatial_patch_size ** 2
+        self.img_projector = nn.Linear(
+            self.n_pixels_per_patch * config.channel_size,
+            config.dim,
+            bias=False,
+        )
+
+        self.cls_token = nn.Parameter(torch.empty(1, 1, config.dim))
+        if self.n_storage_tokens > 0:
+            self.storage_tokens = nn.Parameter(torch.empty(1, self.n_storage_tokens, config.dim))
+
+        # RoPE
+        head_dim = config.head_dim or config.dim // config.n_heads
+        d = head_dim // 2
+        self.register_buffer("freqs_cis_golden", self._precompute_golden_freqs_cis(d, config))
+        self.register_buffer("freqs_cis", self._precompute_freqs_cis(d, config), persistent=False)
+
+        self.layers = nn.ModuleList([TransformerBlock(i, config) for i in range(config.n_layers)])
+        self.norm = nn.RMSNorm(config.dim, eps=config.norm_eps)
+
+        # Teacher adapters
+        teachers_dict = dict(zip(config.teachers, config.teachers_dim))
+        dinov3_dim = teachers_dict.get("dinov3", 1280)
+        siglip2_dim = teachers_dict.get("siglip2", 1152)
+
+        self.dinov3_adapter = Adapter(config.dim, dinov3_dim, bias=False)
+        self.siglip2_adapter = Adapter(config.dim, siglip2_dim, bias=False)
+        self.layer_norm_dinov3 = nn.LayerNorm(dinov3_dim)
+        self.siglip2_multihead_attention_pooling_head = Siglip2MultiheadAttentionPoolingHead(
+            siglip2_dim, 16, siglip2_dim
+        )
+
+        self.post_init()
+
+    def _precompute_freqs_cis(self, head_dim: int, config: SigLinoConfig) -> torch.Tensor:
+        return precompute_freqs_cis(head_dim, config.max_seq_len, config.rope_theta)
+
+    def _precompute_golden_freqs_cis(self, head_dim: int, config: SigLinoConfig) -> torch.Tensor:
+        return precompute_golden_freqs_cis(
+            config.n_heads, head_dim, config.rope_min_freqs, config.rope_max_freqs
+        )
+
+    def _get_thw_pos(self, batch_size, num_patches, spatial_shapes, device):
+        N = batch_size
+        R = 1 + self.n_storage_tokens
+        S = R + num_patches
+        tpos = torch.zeros((N, S), dtype=torch.float32, device=device)
+        hpos = torch.zeros((N, S), dtype=torch.float32, device=device)
+        wpos = torch.zeros((N, S), dtype=torch.float32, device=device)
+
+        for n in range(N):
+            H, W = spatial_shapes[n].tolist()
+            h_coords = torch.arange(H, device=device).float()
+            w_coords = torch.arange(W, device=device).float()
+            xlim, ylim = (W / H) ** 0.5, (H / W) ** 0.5
+            h_norm = -ylim + 2 * ylim * h_coords / max(H - 1, 1)
+            w_norm = -xlim + 2 * xlim * w_coords / max(W - 1, 1)
+
+            # Vectorized fill for patches
+            h_grid, w_grid = torch.meshgrid(h_norm, w_norm, indexing='ij')
+            hpos[n, R:R+H*W] = h_grid.reshape(-1)
+            wpos[n, R:R+H*W] = w_grid.reshape(-1)
+
+            hpos[n, :R], wpos[n, :R] = float('nan'), float('nan')
+
+        return torch.stack([tpos, hpos, wpos], dim=0)
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        padding_mask: Optional[torch.Tensor] = None,
+        spatial_shapes: Optional[torch.Tensor] = None,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+        compile: bool = True,
+    ) -> Union[Dict, Tuple]:
+        N, L, _ = pixel_values.shape
+        device = pixel_values.device
+        R = 1 + self.n_storage_tokens
+
+        if padding_mask is None:
+            padding_mask = torch.ones((N, L), dtype=pixel_values.dtype, device=device)
+
+        h_NLD = self.img_projector(pixel_values)
+        cls_expanded = self.cls_token.expand(N, -1, -1)
+        if self.n_storage_tokens > 0:
+            reg_expanded = self.storage_tokens.expand(N, -1, -1)
+            h_NSD = torch.cat([cls_expanded, reg_expanded, h_NLD], dim=1)
+        else:
+            h_NSD = torch.cat([cls_expanded, h_NLD], dim=1)
+
+        S = h_NSD.shape[1]
+        cls_reg_mask = torch.ones((N, R), dtype=padding_mask.dtype, device=device)
+        full_mask = torch.cat([cls_reg_mask, padding_mask], dim=1)
+
+        # FlexAttention mask: attend only between real (non-padding) tokens
+        def mask_mod(b, h, q_idx, kv_idx):
+            return full_mask.bool()[b, q_idx] & full_mask.bool()[b, kv_idx]
+
+        block_mask = create_attention_mask(mask_mod, N, None, S, S)
+
+        # RoPE
+        thw_pos = self._get_thw_pos(N, L, spatial_shapes, device)
+        pos_thw = E.rearrange(thw_pos, "p n s -> n s p").to(dtype=torch.float32)
+        patch_mask_2d = torch.zeros((N, S), dtype=torch.bool, device=device)
+        patch_mask_2d[:, R:] = padding_mask.bool()
+        pos_thw[:, :, 1:] = pos_thw[:, :, 1:].masked_fill(~patch_mask_2d.unsqueeze(-1), float("nan"))
+
+        freqs_cis_golden = apply_golden_freqs_cis_to_visual_pos(
+            self.freqs_cis_golden.to(dtype=pos_thw.dtype), pos_thw[:, :, 1:]
+        )
+
+        all_hidden_states = () if output_hidden_states else None
+        for layer in self.layers:
+            if output_hidden_states:
+                all_hidden_states += (h_NSD,)
+            h_NSD = layer(h_NSD, self.freqs_cis, freqs_cis_2d=freqs_cis_golden,
+                          pos_thw=pos_thw, attention_masks=block_mask, compile=compile)
+
+        h_NSD = self.norm(h_NSD)
+
+        # Feature extraction & adapters
+        cls_feats = h_NSD[:, 0]
+        patch_feats = h_NSD[:, R:]
+
+        student_patch_dinov3 = self.dinov3_adapter(patch_feats)
+        student_patch_siglip = self.siglip2_adapter(patch_feats)
+        student_cls_dinov3 = self.dinov3_adapter(cls_feats)
+
+        h_sig = self.siglip2_adapter(h_NSD)
+        siglip_attn_mask = full_mask.reshape(-1)
+        student_summary_siglip = self.siglip2_multihead_attention_pooling_head(h_sig, siglip_attn_mask)
+
+        output = {
+            "last_hidden_state": h_NSD,
+            "patch_features": {
+                "dinov3": student_patch_dinov3,
+                "siglip2": student_patch_siglip,
+                "siglino": patch_feats,
+            },
+            "summary_features": {
+                "dinov3": student_cls_dinov3,
+                "siglip2": student_summary_siglip,
+                "siglino": cls_feats,
+            },
+            "hidden_states": all_hidden_states,
+        }
+
+        if not return_dict:
+            return tuple(v for v in output.values() if v is not None)
+        return output
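
The `_apply` override above is what lets the README's `model.to("cuda", dtype=torch.bfloat16)` work: complex RoPE buffers follow the device move but keep their complex dtype. A small sketch of the guaranteed behaviour (assuming the checkpoint loads as in the README):

```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("tiiuae/siglino-70M", trust_remote_code=True)
model = model.to("cuda", dtype=torch.bfloat16)

assert model.freqs_cis.is_complex()               # RoPE table survives the cast
assert model.norm.weight.dtype == torch.bfloat16  # real weights are cast as usual
```
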
preprocessor_config.json CHANGED
@@ -1,6 +1,6 @@
 {
   "auto_map": {
-    "AutoImageProcessor": "image_processing_amoe.AMOEImageProcessor"
+    "AutoImageProcessor": "image_processing_siglino.SigLinoImageProcessor"
   },
   "do_normalize": true,
   "do_rescale": true,
@@ -10,7 +10,7 @@
     0.5,
     0.5
   ],
-  "image_processor_type": "AMOEImageProcessor",
+  "image_processor_type": "SigLinoImageProcessor",
   "image_std": [
     0.5,
     0.5,
utils.py CHANGED
@@ -9,21 +9,21 @@ from PIL import Image
 from typing import Union, List
 import os
 
-from .model import AMOE
-from .configs import AMOEArgs, amoe_configs
-from .image_processor import AMOEImageProcessor
+from .model import SigLino
+from .configs import SigLinoArgs, siglino_configs
+from .image_processor import SigLinoImageProcessor
 
 
 
-def load_amoe_model(
+def load_siglino_model(
     checkpoint_path: str,
-    config_name: str = "18-layers-distillation",
+    config_name: str = "siglino-0.3B",
     device: Union[str, torch.device] = "cuda",
     dtype: torch.dtype | None = None,
     **kwargs,
-) -> tuple[AMOE, AMOEImageProcessor]:
+) -> tuple[SigLino, SigLinoImageProcessor]:
     """
-    Load a AMOE model from a checkpoint.
+    Load a SigLino model from a checkpoint.
 
     Args:
         checkpoint_path: Path to the model checkpoint
@@ -35,13 +35,13 @@ def load_amoe_model(
         Tuple of (model, image_processor)
     """
     # Get configuration
-    if config_name in amoe_configs:
-        args = amoe_configs[config_name]
+    if config_name in siglino_configs:
+        args = siglino_configs[config_name]
    else:
-        raise ValueError(f"Unknown config: {config_name}. Available: {list(amoe_configs.keys())}")
+        raise ValueError(f"Unknown config: {config_name}. Available: {list(siglino_configs.keys())}")
 
     # Create model
-    model = AMOE(args)
+    model = SigLino(args)
 
     # Standard PyTorch checkpoint
     state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
@@ -55,7 +55,7 @@ def load_amoe_model(
     model.eval()
 
     # Create image processor
-    image_processor = AMOEImageProcessor(patch_size=args.spatial_patch_size, **kwargs)
+    image_processor = SigLinoImageProcessor(patch_size=args.spatial_patch_size, **kwargs)
 
     return model, image_processor
 
@@ -178,7 +178,7 @@ def load_amoe_model(
 FEATURE_DIM_DICT = {
     "dinov3": 1024,
     "siglip2": 1152,
-    "amoe": 768,  # Model dimension
+    "siglino": 768,  # Model dimension
 }
 
 PATCH_SIZE = 16
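
For raw PyTorch checkpoints outside the `transformers` auto-class path, `load_siglino_model` is the entry point. A hedged sketch (the checkpoint path is a placeholder; `"siglino-0.3B"` is simply the default config name in this diff):

```python
import torch

model, image_processor = load_siglino_model(
    "checkpoints/siglino.pth",      # placeholder path
    config_name="siglino-0.3B",     # default shown in the diff above
    device="cuda",
    dtype=torch.bfloat16,
)
```
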