yuki-imajuku committed on
Commit
1a0768d
·
1 Parent(s): 0a50292

update to transformers v5

Browse files
Files changed (3) hide show
  1. README.md +28 -12
  2. config.json +1 -1
  3. modeling_metom.py +104 -82
README.md CHANGED
@@ -27,7 +27,7 @@ The final evaluation on the test subset yielded a micro accuracy of 0.9722 and a
27
  Please see also [Google Colab Notebook](https://colab.research.google.com/drive/1jFMZENoTjjum3qlBxV0Q5dTxmpCvqlpf?usp=sharing).
28
  1. Install dependencies (Not required on Google Colab)
29
  ```sh
30
- python -m pip install einops torch torchvision transformers
31
 
32
  # Optional (This is also required on Google Colab if you want to use FlashAttention-2)
33
  pip install flash-attn --no-build-isolation
@@ -40,7 +40,7 @@ from io import BytesIO
40
  from PIL import Image
41
  import requests
42
  import torch
43
- from transformers import AutoModel, AutoProcessor
44
 
45
  repo_name = "SakanaAI/Metom"
46
  device = "cuda"
@@ -49,13 +49,21 @@ torch_dtype = torch.float32 # This can also set `torch.float16` or `torch.bfloa
49
  def get_image(image_url: str) -> Image.Image:
50
  return Image.open(BytesIO(requests.get(image_url).content)).convert("RGB")
51
 
52
- processor = AutoProcessor.from_pretrained(repo_name, trust_remote_code=True)
53
  model = AutoModel.from_pretrained(
54
  repo_name,
55
- torch_dtype=torch_dtype,
56
- _attn_implementation="eager", # This can also set `"sdpa"` or `"flash_attention_2"`
57
  trust_remote_code=True
58
  ).to(device=device)
 
 
 
 
 
 
 
 
59
 
60
  image1 = get_image("https://huggingface.co/SakanaAI/Metom/resolve/main/examples/example1_4E00.jpg") # An example image
61
  image_array1 = processor(images=image1, return_tensors="pt")["pixel_values"].to(device=device, dtype=torch_dtype)
@@ -70,7 +78,7 @@ with torch.inference_mode():
70
  print(model.get_topk_labels(image_array2)) # Returns top-k prediction labels (label only)
71
  # [['定', '芝', '乏', '淀', '実'], ['倉', '衾', '斜', '会', '急']]
72
  print(model.get_topk_labels(image_array2, k=3, return_probs=True)) # Returns prediction top-k labels (label with probability)
73
- # [[('定', 0.9979104399681091), ('芝', 0.0002953427319880575), ('乏', 0.00012814522779081017)], [('倉', 0.9862521290779114), ('衾', 0.0005956474924460053), ('斜', 0.00039981433656066656)]]
74
  ```
75
 
76
  ## Citation
@@ -98,7 +106,7 @@ with torch.inference_mode():
98
  [Google Colab Notebook](https://colab.research.google.com/drive/1jFMZENoTjjum3qlBxV0Q5dTxmpCvqlpf?usp=sharing)もご確認ください。
99
  1. 依存ライブラリをインストールする (Google Colabを使う場合は不要)
100
  ```sh
101
- python -m pip install einops torch torchvision transformers
102
 
103
  # 任意 (FlashAttention-2を使いたい場合はGoogle Colabを使う時でも必要)
104
  pip install flash-attn --no-build-isolation
@@ -111,7 +119,7 @@ from io import BytesIO
111
  from PIL import Image
112
  import requests
113
  import torch
114
- from transformers import AutoModel, AutoProcessor
115
 
116
  repo_name = "SakanaAI/Metom"
117
  device = "cuda"
@@ -120,13 +128,21 @@ torch_dtype = torch.float32 # `torch.float16` や `torch.bfloat16` も指定可
120
  def get_image(image_url: str) -> Image.Image:
121
  return Image.open(BytesIO(requests.get(image_url).content)).convert("RGB")
122
 
123
- processor = AutoProcessor.from_pretrained(repo_name, trust_remote_code=True)
124
  model = AutoModel.from_pretrained(
125
  repo_name,
126
- torch_dtype=torch_dtype,
127
- _attn_implementation="eager", # `"sdpa"` `"flash_attention_2"` 指定可能
128
  trust_remote_code=True
129
  ).to(device=device)
 
 
 
 
 
 
 
 
130
 
131
  image1 = get_image("https://huggingface.co/SakanaAI/Metom/resolve/main/examples/example1_4E00.jpg") # 画像例
132
  image_array1 = processor(images=image1, return_tensors="pt")["pixel_values"].to(device=device, dtype=torch_dtype)
@@ -141,7 +157,7 @@ with torch.inference_mode():
141
  print(model.get_topk_labels(image_array2)) # 上位k件の予測ラベルを返す (ラベルのみ)
142
  # [['定', '芝', '乏', '淀', '実'], ['倉', '衾', '斜', '会', '急']]
143
  print(model.get_topk_labels(image_array2, k=3, return_probs=True)) # 上位k件の予測ラベルを返す (ラベルと確率)
144
- # [[('定', 0.9979104399681091), ('芝', 0.0002953427319880575), ('乏', 0.00012814522779081017)], [('倉', 0.9862521290779114), ('衾', 0.0005956474924460053), ('斜', 0.00039981433656066656)]]
145
  ```
146
 
147
  ## 引用
 
27
  Please see also [Google Colab Notebook](https://colab.research.google.com/drive/1jFMZENoTjjum3qlBxV0Q5dTxmpCvqlpf?usp=sharing).
28
  1. Install dependencies (Not required on Google Colab)
29
  ```sh
30
+ python -m pip install einops torch torchvision "transformers>=5.1.0"
31
 
32
  # Optional (This is also required on Google Colab if you want to use FlashAttention-2)
33
  pip install flash-attn --no-build-isolation
 
40
  from PIL import Image
41
  import requests
42
  import torch
43
+ from transformers import AutoImageProcessor, AutoModel
44
 
45
  repo_name = "SakanaAI/Metom"
46
  device = "cuda"
 
49
  def get_image(image_url: str) -> Image.Image:
50
  return Image.open(BytesIO(requests.get(image_url).content)).convert("RGB")
51
 
52
+ processor = AutoImageProcessor.from_pretrained(repo_name)
53
  model = AutoModel.from_pretrained(
54
  repo_name,
55
+ dtype=torch_dtype,
56
+ attn_implementation="sdpa", # This can also set `"eager"`, `"flash_attention_2"` or other methods supported in transformers v5 (https://huggingface.co/docs/transformers/main/en/attention_interface)
57
  trust_remote_code=True
58
  ).to(device=device)
59
+ # We still support transformers v4
60
+ # model = AutoModel.from_pretrained(
61
+ # repo_name,
62
+ # torch_dtype=torch_dtype,
63
+ # _attn_implementation="sdpa", # This can also set `"eager"` or `"flash_attention_2"`
64
+ # trust_remote_code=True,
65
+ # revision="transformers-v4", # Use this revision
66
+ # ).to(device=device)
67
 
68
  image1 = get_image("https://huggingface.co/SakanaAI/Metom/resolve/main/examples/example1_4E00.jpg") # An example image
69
  image_array1 = processor(images=image1, return_tensors="pt")["pixel_values"].to(device=device, dtype=torch_dtype)
 
78
  print(model.get_topk_labels(image_array2)) # Returns top-k prediction labels (label only)
79
  # [['定', '芝', '乏', '淀', '実'], ['倉', '衾', '斜', '会', '急']]
80
  print(model.get_topk_labels(image_array2, k=3, return_probs=True)) # Returns prediction top-k labels (label with probability)
81
+ # [[('定', 0.9979110360145569), ('芝', 0.0002953446237370372), ('乏', 0.0001281465229112655)], [('倉', 0.9862518906593323), ('衾', 0.0005956498789601028), ('斜', 0.000399815384298563)]]
82
  ```
83
 
84
  ## Citation
 
106
  [Google Colab Notebook](https://colab.research.google.com/drive/1jFMZENoTjjum3qlBxV0Q5dTxmpCvqlpf?usp=sharing)もご確認ください。
107
  1. 依存ライブラリをインストールする (Google Colabを使う場合は不要)
108
  ```sh
109
+ python -m pip install einops torch torchvision "transformers>=5.1.0"
110
 
111
  # 任意 (FlashAttention-2を使いたい場合はGoogle Colabを使う時でも必要)
112
  pip install flash-attn --no-build-isolation
 
119
  from PIL import Image
120
  import requests
121
  import torch
122
+ from transformers import AutoImageProcessor, AutoModel
123
 
124
  repo_name = "SakanaAI/Metom"
125
  device = "cuda"
 
128
  def get_image(image_url: str) -> Image.Image:
129
  return Image.open(BytesIO(requests.get(image_url).content)).convert("RGB")
130
 
131
+ processor = AutoImageProcessor.from_pretrained(repo_name)
132
  model = AutoModel.from_pretrained(
133
  repo_name,
134
+ dtype=torch_dtype,
135
+ attn_implementation="sdpa", # `"eager"`, `"flash_attention_2"` および transformers v5 でサポートされている Attention backends を指定可能 (https://huggingface.co/docs/transformers/main/en/attention_interface)
136
  trust_remote_code=True
137
  ).to(device=device)
138
+ # transformers v4 もサポート
139
+ # model = AutoModel.from_pretrained(
140
+ # repo_name,
141
+ # torch_dtype=torch_dtype,
142
+ # _attn_implementation="sdpa", # `"eager"` や `"flash_attention_2"` も指定可能
143
+ # trust_remote_code=True,
144
+ # revision="transformers-v4", # この revision を使用
145
+ # ).to(device=device)
146
 
147
  image1 = get_image("https://huggingface.co/SakanaAI/Metom/resolve/main/examples/example1_4E00.jpg") # 画像例
148
  image_array1 = processor(images=image1, return_tensors="pt")["pixel_values"].to(device=device, dtype=torch_dtype)
 
157
  print(model.get_topk_labels(image_array2)) # 上位k件の予測ラベルを返す (ラベルのみ)
158
  # [['定', '芝', '乏', '淀', '実'], ['倉', '衾', '斜', '会', '急']]
159
  print(model.get_topk_labels(image_array2, k=3, return_probs=True)) # 上位k件の予測ラベルを返す (ラベルと確率)
160
+ # [[('定', 0.9979110360145569), ('芝', 0.0002953446237370372), ('乏', 0.0001281465229112655)], [('倉', 0.9862518906593323), ('衾', 0.0005956498789601028), ('斜', 0.000399815384298563)]]
161
  ```
162
 
163
  ## 引用
config.json CHANGED
@@ -2720,5 +2720,5 @@
2720
  "model_type": "metom",
2721
  "patch_size": 16,
2722
  "pool": "cls",
2723
- "transformers_version": "4.46.2"
2724
  }
 
2720
  "model_type": "metom",
2721
  "patch_size": 16,
2722
  "pool": "cls",
2723
+ "transformers_version": "5.1.0"
2724
  }
modeling_metom.py CHANGED
@@ -1,27 +1,49 @@
1
  # This file is a modified version of the Vision Transformer - Pytorch implementation
2
  # https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py
3
- from typing import List, Union, Tuple
4
 
5
  from einops import rearrange, repeat
6
  from einops.layers.torch import Rearrange
7
  import torch
8
  from torch import nn
9
  from transformers import PreTrainedModel
 
 
10
 
11
  from .configuration_metom import MetomConfig
12
 
13
-
14
- try:
15
- from flash_attn import flash_attn_func
16
- FLASH_ATTENTION_2_AVAILABLE = True
17
- except ImportError:
18
- FLASH_ATTENTION_2_AVAILABLE = False
19
 
20
 
21
  def size_pair(t):
22
  return t if isinstance(t, tuple) else (t, t)
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  class MetomFeedForward(nn.Module):
26
  def __init__(self, dim, hidden_dim, dropout):
27
  super().__init__()
@@ -31,7 +53,7 @@ class MetomFeedForward(nn.Module):
31
  nn.GELU(),
32
  nn.Dropout(dropout),
33
  nn.Linear(hidden_dim, dim),
34
- nn.Dropout(dropout)
35
  )
36
 
37
  def forward(self, x):
@@ -39,98 +61,99 @@ class MetomFeedForward(nn.Module):
39
 
40
 
41
  class MetomAttention(nn.Module):
42
- def __init__(self, dim, heads, dim_head, dropout):
43
  super().__init__()
44
- inner_dim = dim_head * heads
45
- project_out = not (heads == 1 and dim_head == dim)
46
- self.heads = heads
47
- self.scale = dim_head ** -0.5
48
- self.norm = nn.LayerNorm(dim)
49
- self.attend = nn.Softmax(dim = -1)
50
- self.dropout = nn.Dropout(dropout)
51
- self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
 
52
  self.to_out = nn.Sequential(
53
- nn.Linear(inner_dim, dim),
54
- nn.Dropout(dropout)
55
  ) if project_out else nn.Identity()
56
 
57
- def forward(self, x):
58
- x = self.norm(x)
59
- qkv = self.to_qkv(x).chunk(3, dim = -1)
60
- q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h = self.heads), qkv)
61
- dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
62
- attn = self.attend(dots)
63
- attn = self.dropout(attn)
64
- out = torch.matmul(attn, v)
65
- out = rearrange(out, "b h n d -> b n (h d)")
66
- return self.to_out(out)
67
-
68
-
69
- class MetomSdpaAttention(MetomAttention):
70
- def forward(self, x):
71
- x = self.norm(x)
72
- qkv = self.to_qkv(x).chunk(3, dim = -1)
73
- q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h = self.heads), qkv)
74
- out = nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout.p if self.training else 0.0)
75
- out = rearrange(out, "b h n d -> b n (h d)")
76
- return self.to_out(out)
77
-
78
-
79
- class MetomFlashAttention2(MetomAttention):
80
- def forward(self, x):
81
  x = self.norm(x)
82
- qkv = self.to_qkv(x).chunk(3, dim = -1)
83
- q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h = self.heads), qkv)
84
- out = flash_attn_func(q, k, v, dropout_p=self.dropout.p if self.training else 0.0)
85
- out = rearrange(out, "b h n d -> b n (h d)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  return self.to_out(out)
87
 
88
 
89
  class MetomTransformer(nn.Module):
90
- def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout, _attn_implementation = "eager"):
91
  super().__init__()
92
- if _attn_implementation == "flash_attention_2":
93
- assert FLASH_ATTENTION_2_AVAILABLE, "FlashAttention-2 is not available. Please install `flash-attn`."
94
- attn_cls = (
95
- MetomAttention if _attn_implementation == "eager" else
96
- MetomSdpaAttention if _attn_implementation == "sdpa" else
97
- MetomFlashAttention2 if _attn_implementation == "flash_attention_2" else
98
- MetomAttention
99
- )
100
- self.norm = nn.LayerNorm(dim)
101
  self.layers = nn.ModuleList([])
102
- for _ in range(depth):
103
- self.layers.append(nn.ModuleList([
104
- attn_cls(dim, heads = heads, dim_head = dim_head, dropout = dropout),
105
- MetomFeedForward(dim, mlp_dim, dropout = dropout)
106
- ]))
107
-
108
- def forward(self, x):
 
 
 
 
109
  for attn, ff in self.layers:
110
- x = attn(x) + x
111
  x = ff(x) + x
112
  return self.norm(x)
113
 
114
 
115
  class MetomModel(PreTrainedModel):
116
  config_class = MetomConfig
 
 
 
117
  _supports_flash_attn_2 = True
118
  _supports_sdpa = True
 
119
 
120
  def __init__(self, config: MetomConfig):
121
  super().__init__(config)
122
  image_height, image_width = size_pair(config.image_size)
123
  patch_height, patch_width = size_pair(config.patch_size)
124
- assert image_height % patch_height == 0 and image_width % patch_width == 0, "Image dimensions must be divisible by the patch size."
 
 
125
 
126
  num_patches = (image_height // patch_height) * (image_width // patch_width)
127
  patch_dim = config.channels * patch_height * patch_width
128
  assert config.pool in {"cls", "mean"}, "pool type must be either cls (cls token) or mean (mean pooling)"
129
  assert len(config.labels) > 0, "labels must be composed of at least one label"
130
- assert config._attn_implementation in {"eager", "sdpa", "flash_attention_2"}, "Attention implementation must be either eager, sdpa or flash_attention_2"
131
 
132
  self.to_patch_embedding = nn.Sequential(
133
- Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
134
  nn.LayerNorm(patch_dim),
135
  nn.Linear(patch_dim, config.dim),
136
  nn.LayerNorm(config.dim),
@@ -138,36 +161,35 @@ class MetomModel(PreTrainedModel):
138
  self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, config.dim))
139
  self.cls_token = nn.Parameter(torch.randn(1, 1, config.dim))
140
  self.dropout = nn.Dropout(config.emb_dropout)
141
- self.transformer = MetomTransformer(
142
- config.dim, config.depth, config.heads, config.dim_head, config.mlp_dim, config.dropout, config._attn_implementation
143
- )
144
  self.pool = config.pool
145
  self.to_latent = nn.Identity()
146
  self.mlp_head = nn.Linear(config.dim, len(config.labels))
147
  self.labels = config.labels
 
148
 
149
- def forward(self, processed_image):
150
- x = self.to_patch_embedding(processed_image)
151
  b, n, _ = x.shape
152
- cls_tokens = repeat(self.cls_token, "1 1 d -> b 1 d", b = b)
153
  x = torch.cat((cls_tokens, x), dim=1)
154
- x += self.pos_embedding[:, :(n + 1)]
155
  x = self.dropout(x)
156
- x = self.transformer(x)
157
- x = x.mean(dim = 1) if self.pool == "mean" else x[:, 0]
158
  x = self.to_latent(x)
159
  return self.mlp_head(x)
160
 
161
- def get_predictions(self, processed_image: torch.Tensor) -> List[str]:
162
- logits = self(processed_image)
163
  indices = torch.argmax(logits, dim=-1)
164
  return [self.labels[i] for i in indices]
165
 
166
  def get_topk_labels(
167
- self, processed_image: torch.Tensor, k: int = 5, return_probs: bool = False
168
- ) -> Union[List[List[str]], List[List[Tuple[str, float]]]]:
169
  assert 0 < k <= len(self.labels), "k must be a positive integer less than or equal to the number of labels"
170
- logits = self(processed_image)
171
  probs = torch.softmax(logits, dim=-1)
172
  topk_probs, topk_indices = torch.topk(probs, k, dim=-1)
173
  topk_labels = [[self.labels[i] for i in ti] for ti in topk_indices]
 
1
  # This file is a modified version of the Vision Transformer - Pytorch implementation
2
  # https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py
3
+ from collections.abc import Callable
4
 
5
  from einops import rearrange, repeat
6
  from einops.layers.torch import Rearrange
7
  import torch
8
  from torch import nn
9
  from transformers import PreTrainedModel
10
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
11
+ from transformers.utils import logging
12
 
13
  from .configuration_metom import MetomConfig
14
 
15
+ logger = logging.get_logger(__name__)
 
 
 
 
 
16
 
17
 
18
  def size_pair(t):
19
  return t if isinstance(t, tuple) else (t, t)
20
 
21
 
22
+ def metom_eager_attention_forward(
23
+ module: nn.Module,
24
+ query: torch.Tensor,
25
+ key: torch.Tensor,
26
+ value: torch.Tensor,
27
+ attention_mask: torch.Tensor | None,
28
+ scaling: float | None = None,
29
+ dropout: float = 0.0,
30
+ **kwargs,
31
+ ) -> tuple[torch.Tensor, torch.Tensor]:
32
+ if scaling is None:
33
+ scaling = query.size(-1) ** -0.5
34
+
35
+ attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
36
+ if attention_mask is not None:
37
+ attention_mask = attention_mask[:, :, :, : key.shape[-2]]
38
+ attn_weights = attn_weights + attention_mask
39
+
40
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
41
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
42
+ attn_output = torch.matmul(attn_weights, value)
43
+ attn_output = attn_output.transpose(1, 2).contiguous()
44
+ return attn_output, attn_weights
45
+
46
+
47
  class MetomFeedForward(nn.Module):
48
  def __init__(self, dim, hidden_dim, dropout):
49
  super().__init__()
 
53
  nn.GELU(),
54
  nn.Dropout(dropout),
55
  nn.Linear(hidden_dim, dim),
56
+ nn.Dropout(dropout),
57
  )
58
 
59
  def forward(self, x):
 
61
 
62
 
63
  class MetomAttention(nn.Module):
64
+ def __init__(self, config: MetomConfig):
65
  super().__init__()
66
+ inner_dim = config.dim_head * config.heads
67
+ project_out = not (config.heads == 1 and config.dim_head == config.dim)
68
+
69
+ self.config = config
70
+ self.heads = config.heads
71
+ self.scale = config.dim_head ** -0.5
72
+ self.norm = nn.LayerNorm(config.dim)
73
+ self.dropout = nn.Dropout(config.dropout)
74
+ self.to_qkv = nn.Linear(config.dim, inner_dim * 3, bias=False)
75
  self.to_out = nn.Sequential(
76
+ nn.Linear(inner_dim, config.dim),
77
+ nn.Dropout(config.dropout),
78
  ) if project_out else nn.Identity()
79
 
80
+ def forward(self, x: torch.Tensor, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  x = self.norm(x)
82
+ qkv = self.to_qkv(x).chunk(3, dim=-1)
83
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)
84
+
85
+ attn_implementation = self.config._attn_implementation or "eager"
86
+ if attn_implementation == "flex_attention":
87
+ if self.training and self.dropout.p > 0:
88
+ logger.warning_once(
89
+ "`flex_attention` does not support attention dropout during training. Falling back to `eager`."
90
+ )
91
+ attn_implementation = "eager"
92
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
93
+ attn_implementation,
94
+ metom_eager_attention_forward,
95
+ )
96
+ out, _ = attention_interface(
97
+ self,
98
+ q,
99
+ k,
100
+ v,
101
+ None,
102
+ is_causal=False,
103
+ scaling=self.scale,
104
+ dropout=0.0 if not self.training else self.dropout.p,
105
+ **kwargs,
106
+ )
107
+ out = rearrange(out, "b n h d -> b n (h d)")
108
  return self.to_out(out)
109
 
110
 
111
  class MetomTransformer(nn.Module):
112
+ def __init__(self, config: MetomConfig):
113
  super().__init__()
114
+ self.norm = nn.LayerNorm(config.dim)
 
 
 
 
 
 
 
 
115
  self.layers = nn.ModuleList([])
116
+ for _ in range(config.depth):
117
+ self.layers.append(
118
+ nn.ModuleList(
119
+ [
120
+ MetomAttention(config),
121
+ MetomFeedForward(config.dim, config.mlp_dim, dropout=config.dropout),
122
+ ]
123
+ )
124
+ )
125
+
126
+ def forward(self, x: torch.Tensor, **kwargs):
127
  for attn, ff in self.layers:
128
+ x = attn(x, **kwargs) + x
129
  x = ff(x) + x
130
  return self.norm(x)
131
 
132
 
133
  class MetomModel(PreTrainedModel):
134
  config_class = MetomConfig
135
+ main_input_name = "pixel_values"
136
+ _supports_attention_backend = True
137
+ _supports_flash_attn = True
138
  _supports_flash_attn_2 = True
139
  _supports_sdpa = True
140
+ _supports_flex_attn = True
141
 
142
  def __init__(self, config: MetomConfig):
143
  super().__init__(config)
144
  image_height, image_width = size_pair(config.image_size)
145
  patch_height, patch_width = size_pair(config.patch_size)
146
+ assert image_height % patch_height == 0 and image_width % patch_width == 0, (
147
+ "Image dimensions must be divisible by the patch size."
148
+ )
149
 
150
  num_patches = (image_height // patch_height) * (image_width // patch_width)
151
  patch_dim = config.channels * patch_height * patch_width
152
  assert config.pool in {"cls", "mean"}, "pool type must be either cls (cls token) or mean (mean pooling)"
153
  assert len(config.labels) > 0, "labels must be composed of at least one label"
 
154
 
155
  self.to_patch_embedding = nn.Sequential(
156
+ Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=patch_height, p2=patch_width),
157
  nn.LayerNorm(patch_dim),
158
  nn.Linear(patch_dim, config.dim),
159
  nn.LayerNorm(config.dim),
 
161
  self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, config.dim))
162
  self.cls_token = nn.Parameter(torch.randn(1, 1, config.dim))
163
  self.dropout = nn.Dropout(config.emb_dropout)
164
+ self.transformer = MetomTransformer(config)
 
 
165
  self.pool = config.pool
166
  self.to_latent = nn.Identity()
167
  self.mlp_head = nn.Linear(config.dim, len(config.labels))
168
  self.labels = config.labels
169
+ self.post_init()
170
 
171
+ def forward(self, pixel_values: torch.Tensor, **kwargs):
172
+ x = self.to_patch_embedding(pixel_values)
173
  b, n, _ = x.shape
174
+ cls_tokens = repeat(self.cls_token, "1 1 d -> b 1 d", b=b)
175
  x = torch.cat((cls_tokens, x), dim=1)
176
+ x += self.pos_embedding[:, : (n + 1)]
177
  x = self.dropout(x)
178
+ x = self.transformer(x, **kwargs)
179
+ x = x.mean(dim=1) if self.pool == "mean" else x[:, 0]
180
  x = self.to_latent(x)
181
  return self.mlp_head(x)
182
 
183
+ def get_predictions(self, pixel_values: torch.Tensor) -> list[str]:
184
+ logits = self(pixel_values=pixel_values)
185
  indices = torch.argmax(logits, dim=-1)
186
  return [self.labels[i] for i in indices]
187
 
188
  def get_topk_labels(
189
+ self, pixel_values: torch.Tensor, k: int = 5, return_probs: bool = False
190
+ ) -> list[list[str]] | list[list[tuple[str, float]]]:
191
  assert 0 < k <= len(self.labels), "k must be a positive integer less than or equal to the number of labels"
192
+ logits = self(pixel_values=pixel_values)
193
  probs = torch.softmax(logits, dim=-1)
194
  topk_probs, topk_indices = torch.topk(probs, k, dim=-1)
195
  topk_labels = [[self.labels[i] for i in ti] for ti in topk_indices]