Tang-xiaoxiao committed on
Commit 502dbff · verified · 1 Parent(s): ed49b93

Delete modeling_m3d_lamed.py

Files changed (1)
  1. modeling_m3d_lamed.py +0 -2105
modeling_m3d_lamed.py DELETED
@@ -1,2105 +0,0 @@
1
- from __future__ import annotations
2
- from typing import Union
3
- from transformers import Phi3Config, Phi3Model, Phi3ForCausalLM
4
- from transformers.modeling_outputs import CausalLMOutputWithPast
5
- from transformers.generation.utils import GenerateOutput
6
- from .configuration_m3d_lamed import LamedPhi3Config
7
- from abc import ABC, abstractmethod
8
- from torch import Tensor
9
- import math
10
- from typing import Any, Dict, List
11
- import torch
12
- import torch.nn as nn
13
- from typing import Optional, Tuple, Type
14
- from monai.networks.blocks import PatchEmbed
15
- import numpy as np
16
- import torch.nn.functional as F
17
-
18
- from einops import rearrange
19
- from einops.layers.torch import Rearrange
20
- from collections.abc import Sequence
21
- from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
22
- from monai.networks.blocks.transformerblock import TransformerBlock
23
- from monai.networks.nets import ViT
24
-
25
-
26
- class BinaryDiceLoss(nn.Module):
27
- def __init__(self, smooth=1, p=2, reduction='mean'):
28
- super(BinaryDiceLoss, self).__init__()
29
- self.smooth = smooth
30
- self.p = p
31
- self.reduction = reduction
32
-
33
- def forward(self, predict, target):
34
- predict = torch.sigmoid(predict)
35
- target_ = target.clone().float()
36
- target_[target == -1] = 0
37
- assert predict.shape[0] == target.shape[0], "predict & target batch size don't match\n" + str(predict.shape) + '\n' + str(target.shape)
38
- predict = predict.contiguous().view(predict.shape[0], -1)
39
- target_ = target_.contiguous().view(target_.shape[0], -1)
40
-
41
- num = torch.sum(torch.mul(predict, target_), dim=1)
42
- den = torch.sum(predict, dim=1) + torch.sum(target_, dim=1) + self.smooth
43
-
44
- dice_score = 2*num / den
45
- dice_loss = 1 - dice_score
46
-
47
- # dice_loss_avg = dice_loss[target[:,0]!=-1].sum() / dice_loss[target[:,0]!=-1].shape[0]
48
- dice_loss_avg = dice_loss.sum() / dice_loss.shape[0]
49
-
50
- return dice_loss_avg
51
-
52
- class BCELoss(nn.Module):
53
- def __init__(self):
54
- super(BCELoss, self).__init__()
55
- self.criterion = nn.BCEWithLogitsLoss()
56
-
57
- def forward(self, predict, target):
58
- assert predict.shape == target.shape, 'predict & target shape do not match\n' + str(predict.shape) + '\n' + str(target.shape)
59
- target_ = target.clone()
60
- target_[target == -1] = 0
61
-
62
- ce_loss = self.criterion(predict, target_.float())
63
-
64
- return ce_loss
65
-
66
-
67
-
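A minimal usage sketch (not part of the original file) for the two loss modules above, with dummy logits and binary targets; a -1 label is the "ignore" value both losses zero out:

import torch

logits = torch.randn(2, 1, 8, 8, 8)            # raw (pre-sigmoid) mask predictions
target = torch.randint(0, 2, (2, 1, 8, 8, 8))  # binary ground truth; -1 would mark ignored entries
dice = BinaryDiceLoss()(logits, target)        # sigmoid + per-sample Dice, averaged over the batch
bce = BCELoss()(logits, target)                # BCEWithLogitsLoss with -1 labels set to 0
print(dice.item(), bce.item())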
68
- class LayerNorm2d(nn.Module):
69
- def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
70
- super().__init__()
71
- self.weight = nn.Parameter(torch.ones(num_channels))
72
- self.bias = nn.Parameter(torch.zeros(num_channels))
73
- self.eps = eps
74
-
75
- def forward(self, x: torch.Tensor) -> torch.Tensor:
76
- u = x.mean(1, keepdim=True)
77
- s = (x - u).pow(2).mean(1, keepdim=True)
78
- x = (x - u) / torch.sqrt(s + self.eps)
79
- x = self.weight[:, None, None] * x + self.bias[:, None, None]
80
- return x
81
-
82
-
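LayerNorm2d normalizes over the channel dimension at every spatial position. A small sanity check (illustrative only; the weight/bias broadcasting assumes a 4-D B x C x H x W input):

import torch
import torch.nn.functional as F

x = torch.randn(2, 16, 7, 7)
ln = LayerNorm2d(16)
ref = F.layer_norm(x.permute(0, 2, 3, 1), (16,), eps=1e-6).permute(0, 3, 1, 2)
print(torch.allclose(ln(x), ref, atol=1e-5))   # True: identical per-location channel statistics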
83
- class MLPBlock(nn.Module):
84
- def __init__(
85
- self,
86
- embedding_dim: int,
87
- mlp_dim: int,
88
- act: Type[nn.Module] = nn.GELU,
89
- ) -> None:
90
- super().__init__()
91
- self.lin1 = nn.Linear(embedding_dim, mlp_dim)
92
- self.lin2 = nn.Linear(mlp_dim, embedding_dim)
93
- self.act = act()
94
-
95
- def forward(self, x: torch.Tensor) -> torch.Tensor:
96
- return self.lin2(self.act(self.lin1(x)))
97
-
98
-
99
- class TwoWayTransformer(nn.Module):
100
- def __init__(
101
- self,
102
- depth: int,
103
- embedding_dim: int,
104
- num_heads: int,
105
- mlp_dim: int,
106
- activation: Type[nn.Module] = nn.ReLU,
107
- attention_downsample_rate: int = 2,
108
- ) -> None:
109
- """
110
- A transformer decoder that attends to an input image using
111
- queries whose positional embedding is supplied.
112
-
113
- Args:
114
- depth (int): number of layers in the transformer
115
- embedding_dim (int): the channel dimension for the input embeddings
116
- num_heads (int): the number of heads for multihead attention. Must
117
- divide embedding_dim
118
- mlp_dim (int): the channel dimension internal to the MLP block
119
- activation (nn.Module): the activation to use in the MLP block
120
- """
121
- super().__init__()
122
- self.depth = depth
123
- self.embedding_dim = embedding_dim
124
- self.num_heads = num_heads
125
- self.mlp_dim = mlp_dim
126
- self.layers = nn.ModuleList()
127
-
128
- for i in range(depth):
129
- self.layers.append(
130
- TwoWayAttentionBlock(
131
- embedding_dim=embedding_dim,
132
- num_heads=num_heads,
133
- mlp_dim=mlp_dim,
134
- activation=activation,
135
- attention_downsample_rate=attention_downsample_rate,
136
- skip_first_layer_pe=(i == 0),
137
- )
138
- )
139
-
140
- self.final_attn_token_to_image = Attention(
141
- embedding_dim, num_heads, downsample_rate=attention_downsample_rate
142
- )
143
- self.norm_final_attn = nn.LayerNorm(embedding_dim)
144
-
145
- def forward(
146
- self,
147
- image_embedding: Tensor,
148
- image_pe: Tensor,
149
- point_embedding: Tensor,
150
- ) -> Tuple[Tensor, Tensor]:
151
- """
152
- Args:
153
- image_embedding (torch.Tensor): image to attend to. Should be shape
154
- B x embedding_dim x h x w x d for any h, w, and d.
155
- image_pe (torch.Tensor): the positional encoding to add to the image. Must
156
- have the same shape as image_embedding.
157
- point_embedding (torch.Tensor): the embedding to add to the query points.
158
- Must have shape B x N_points x embedding_dim for any N_points.
159
-
160
- Returns:
161
- torch.Tensor: the processed point_embedding
162
- torch.Tensor: the processed image_embedding
163
- """
164
- # BxCxHxWxD -> BxHWDxC == B x N_image_tokens x C
165
- bs, c, h, w, d = image_embedding.shape
166
- image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
167
- image_pe = image_pe.flatten(2).permute(0, 2, 1)
168
-
169
- # Prepare queries
170
- queries = point_embedding
171
- keys = image_embedding
172
-
173
- # Apply transformer blocks and final layernorm
174
- for layer in self.layers:
175
- queries, keys = layer(
176
- queries=queries,
177
- keys=keys,
178
- query_pe=point_embedding,
179
- key_pe=image_pe,
180
- )
181
-
182
- # Apply the final attention layer from the points to the image
183
- q = queries + point_embedding
184
- k = keys + image_pe
185
- attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
186
- queries = queries + attn_out
187
- queries = self.norm_final_attn(queries)
188
-
189
- return queries, keys
190
-
191
-
192
- class TwoWayAttentionBlock(nn.Module):
193
- def __init__(
194
- self,
195
- embedding_dim: int,
196
- num_heads: int,
197
- mlp_dim: int = 2048,
198
- activation: Type[nn.Module] = nn.ReLU,
199
- attention_downsample_rate: int = 2,
200
- skip_first_layer_pe: bool = False,
201
- ) -> None:
202
- """
203
- A transformer block with four layers: (1) self-attention of sparse
204
- inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
205
- block on sparse inputs, and (4) cross attention of dense inputs to sparse
206
- inputs.
207
-
208
- Arguments:
209
- embedding_dim (int): the channel dimension of the embeddings
210
- num_heads (int): the number of heads in the attention layers
211
- mlp_dim (int): the hidden dimension of the mlp block
212
- activation (nn.Module): the activation of the mlp block
213
- skip_first_layer_pe (bool): skip the PE on the first layer
214
- """
215
- super().__init__()
216
- self.self_attn = Attention(embedding_dim, num_heads)
217
- self.norm1 = nn.LayerNorm(embedding_dim)
218
-
219
- self.cross_attn_token_to_image = Attention(
220
- embedding_dim, num_heads, downsample_rate=attention_downsample_rate
221
- )
222
- self.norm2 = nn.LayerNorm(embedding_dim)
223
-
224
- self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
225
- self.norm3 = nn.LayerNorm(embedding_dim)
226
-
227
- self.norm4 = nn.LayerNorm(embedding_dim)
228
- self.cross_attn_image_to_token = Attention(
229
- embedding_dim, num_heads, downsample_rate=attention_downsample_rate
230
- )
231
-
232
- self.skip_first_layer_pe = skip_first_layer_pe
233
-
234
- def forward(
235
- self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
236
- ) -> Tuple[Tensor, Tensor]:
237
- # Self attention block
238
- if self.skip_first_layer_pe:
239
- queries = self.self_attn(q=queries, k=queries, v=queries)
240
- else:
241
- q = queries + query_pe
242
- attn_out = self.self_attn(q=q, k=q, v=queries)
243
- queries = queries + attn_out
244
- queries = self.norm1(queries)
245
-
246
- # Cross attention block, tokens attending to image embedding
247
- q = queries + query_pe
248
- k = keys + key_pe
249
- attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
250
- queries = queries + attn_out
251
- queries = self.norm2(queries)
252
-
253
- # MLP block
254
- mlp_out = self.mlp(queries)
255
- queries = queries + mlp_out
256
- queries = self.norm3(queries)
257
-
258
- # Cross attention block, image embedding attending to tokens
259
- q = queries + query_pe
260
- k = keys + key_pe
261
- attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
262
- keys = keys + attn_out
263
- keys = self.norm4(keys)
264
-
265
- return queries, keys
266
-
267
-
268
- class Attention(nn.Module):
269
- """
270
- An attention layer that allows for downscaling the size of the embedding
271
- after projection to queries, keys, and values.
272
- """
273
-
274
- def __init__(
275
- self,
276
- embedding_dim: int,
277
- num_heads: int,
278
- downsample_rate: int = 1,
279
- ) -> None:
280
- super().__init__()
281
- self.embedding_dim = embedding_dim
282
- self.internal_dim = embedding_dim // downsample_rate
283
- self.num_heads = num_heads
284
- assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
285
-
286
- self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
287
- self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
288
- self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
289
- self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
290
-
291
- def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
292
- b, n, c = x.shape
293
- x = x.reshape(b, n, num_heads, c // num_heads)
294
- return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
295
-
296
- def _recombine_heads(self, x: Tensor) -> Tensor:
297
- b, n_heads, n_tokens, c_per_head = x.shape
298
- x = x.transpose(1, 2)
299
- return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
300
-
301
- def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
302
- # Input projections
303
- q = self.q_proj(q)
304
- k = self.k_proj(k)
305
- v = self.v_proj(v)
306
-
307
- # Separate into heads
308
- q = self._separate_heads(q, self.num_heads)
309
- k = self._separate_heads(k, self.num_heads)
310
- v = self._separate_heads(v, self.num_heads)
311
-
312
- # Attention
313
- _, _, _, c_per_head = q.shape
314
- attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
315
- attn = attn / math.sqrt(c_per_head)
316
- attn = torch.softmax(attn, dim=-1)
317
-
318
- # Get output
319
- out = attn @ v
320
- out = self._recombine_heads(out)
321
- out = self.out_proj(out)
322
-
323
- return out
324
-
325
-
326
-
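A shape-only sketch (illustrative sizes, not from the original file) of the two-way decoder above: queries keep their token count, keys come back as the flattened h*w*d image tokens:

import torch

dec = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)
img = torch.randn(1, 256, 8, 8, 4)       # B x C x h x w x d image embedding
img_pe = torch.randn(1, 256, 8, 8, 4)    # positional encoding, same shape
tokens = torch.randn(1, 6, 256)          # B x N_tokens x C prompt/output tokens
queries, keys = dec(img, img_pe, tokens)
print(queries.shape, keys.shape)         # torch.Size([1, 6, 256]) torch.Size([1, 256, 256])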
327
- # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
328
- class ImageEncoderViT(nn.Module):
329
- def __init__(
330
- self,
331
- img_size: int = 1024,
332
- patch_size: int = 16,
333
- in_chans: int = 1,
334
- embed_dim: int = 768,
335
- depth: int = 12,
336
- num_heads: int = 12,
337
- mlp_ratio: float = 4.0,
338
- out_chans: int = 256,
339
- qkv_bias: bool = True,
340
- norm_layer: Type[nn.Module] = nn.LayerNorm,
341
- act_layer: Type[nn.Module] = nn.GELU,
342
- use_abs_pos: bool = True,
343
- use_rel_pos: bool = False,
344
- rel_pos_zero_init: bool = True,
345
- window_size: int = 0,
346
- global_attn_indexes: Tuple[int, ...] = (),
347
- ) -> None:
348
- """
349
- Args:
350
- img_size (int): Input image size.
351
- patch_size (int): Patch size.
352
- in_chans (int): Number of input image channels.
353
- embed_dim (int): Patch embedding dimension.
354
- depth (int): Depth of ViT.
355
- num_heads (int): Number of attention heads in each ViT block.
356
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
357
- qkv_bias (bool): If True, add a learnable bias to query, key, value.
358
- norm_layer (nn.Module): Normalization layer.
359
- act_layer (nn.Module): Activation layer.
360
- use_abs_pos (bool): If True, use absolute positional embeddings.
361
- use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
362
- rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
363
- window_size (int): Window size for window attention blocks.
364
- global_attn_indexes (list): Indexes for blocks using global attention.
365
- """
366
- super().__init__()
367
- self.img_size = img_size
368
-
369
- # self.patch_embed = PatchEmbed(
370
- # kernel_size=(patch_size, patch_size),
371
- # stride=(patch_size, patch_size),
372
- # in_chans=in_chans,
373
- # embed_dim=embed_dim,
374
- # )
375
-
376
- self.patch_embed = PatchEmbed(
377
- patch_size=patch_size,
378
- in_chans=in_chans,
379
- embed_dim=embed_dim,
380
- spatial_dims=3,
381
- )
382
-
383
- self.pos_embed: Optional[nn.Parameter] = None
384
- if use_abs_pos:
385
- # Initialize absolute positional embedding with pretrain image size.
386
- self.pos_embed = nn.Parameter(
387
- torch.zeros(1, img_size // patch_size, img_size // patch_size, img_size // patch_size, embed_dim)
388
- )
389
-
390
- self.blocks = nn.ModuleList()
391
- for i in range(depth):
392
- block = Block(
393
- dim=embed_dim,
394
- num_heads=num_heads,
395
- mlp_ratio=mlp_ratio,
396
- qkv_bias=qkv_bias,
397
- norm_layer=norm_layer,
398
- act_layer=act_layer,
399
- use_rel_pos=use_rel_pos,
400
- rel_pos_zero_init=rel_pos_zero_init,
401
- window_size=window_size if i not in global_attn_indexes else 0,
402
- input_size=(img_size // patch_size, img_size // patch_size),
403
- )
404
- self.blocks.append(block)
405
-
406
- self.neck = nn.Sequential(
407
- nn.Conv2d(
408
- embed_dim,
409
- out_chans,
410
- kernel_size=1,
411
- bias=False,
412
- ),
413
- LayerNorm2d(out_chans),
414
- nn.Conv2d(
415
- out_chans,
416
- out_chans,
417
- kernel_size=3,
418
- padding=1,
419
- bias=False,
420
- ),
421
- LayerNorm2d(out_chans),
422
- )
423
-
424
- def forward(self, x: torch.Tensor) -> torch.Tensor:
425
- x = self.patch_embed(x)
426
- print('patch embedded shape: ', x.shape) # embedded: [8, 768, 6, 6, 6]
427
- if self.pos_embed is not None:
428
- x = x + self.pos_embed
429
-
430
- for blk in self.blocks:
431
- x = blk(x)
432
-
433
- x = self.neck(x.permute(0, 3, 1, 2))
434
-
435
- return x
436
-
437
-
438
- class Block(nn.Module):
439
- """Transformer blocks with support of window attention and residual propagation blocks"""
440
-
441
- def __init__(
442
- self,
443
- dim: int,
444
- num_heads: int,
445
- mlp_ratio: float = 4.0,
446
- qkv_bias: bool = True,
447
- norm_layer: Type[nn.Module] = nn.LayerNorm,
448
- act_layer: Type[nn.Module] = nn.GELU,
449
- use_rel_pos: bool = False,
450
- rel_pos_zero_init: bool = True,
451
- window_size: int = 0,
452
- input_size: Optional[Tuple[int, int]] = None,
453
- ) -> None:
454
- """
455
- Args:
456
- dim (int): Number of input channels.
457
- num_heads (int): Number of attention heads in each ViT block.
458
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
459
- qkv_bias (bool): If True, add a learnable bias to query, key, value.
460
- norm_layer (nn.Module): Normalization layer.
461
- act_layer (nn.Module): Activation layer.
462
- use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
463
- rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
464
- window_size (int): Window size for window attention blocks. If it equals 0, then
465
- use global attention.
466
- input_size (tuple(int, int) or None): Input resolution for calculating the relative
467
- positional parameter size.
468
- """
469
- super().__init__()
470
- self.norm1 = norm_layer(dim)
471
- self.attn = Attention2(
472
- dim,
473
- num_heads=num_heads,
474
- qkv_bias=qkv_bias,
475
- use_rel_pos=use_rel_pos,
476
- rel_pos_zero_init=rel_pos_zero_init,
477
- input_size=input_size if window_size == 0 else (window_size, window_size),
478
- )
479
-
480
- self.norm2 = norm_layer(dim)
481
- self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
482
-
483
- self.window_size = window_size
484
-
485
- def forward(self, x: torch.Tensor) -> torch.Tensor:
486
- shortcut = x
487
- x = self.norm1(x)
488
- # Window partition
489
- if self.window_size > 0:
490
- H, W = x.shape[1], x.shape[2]
491
- x, pad_hw = window_partition(x, self.window_size)
492
-
493
- x = self.attn(x)
494
- # Reverse window partition
495
- if self.window_size > 0:
496
- x = window_unpartition(x, self.window_size, pad_hw, (H, W))
497
-
498
- x = shortcut + x
499
- x = x + self.mlp(self.norm2(x))
500
-
501
- return x
502
-
503
-
504
- class Attention2(nn.Module):
505
- """Multi-head Attention block with relative position embeddings."""
506
-
507
- def __init__(
508
- self,
509
- dim: int,
510
- num_heads: int = 8,
511
- qkv_bias: bool = True,
512
- use_rel_pos: bool = False,
513
- rel_pos_zero_init: bool = True,
514
- input_size: Optional[Tuple[int, int]] = None,
515
- ) -> None:
516
- """
517
- Args:
518
- dim (int): Number of input channels.
519
- num_heads (int): Number of attention heads.
520
- qkv_bias (bool): If True, add a learnable bias to query, key, value.
521
- rel_pos (bool): If True, add relative positional embeddings to the attention map.
522
- rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
523
- input_size (tuple(int, int) or None): Input resolution for calculating the relative
524
- positional parameter size.
525
- """
526
- super().__init__()
527
- self.num_heads = num_heads
528
- head_dim = dim // num_heads
529
- self.scale = head_dim ** -0.5
530
-
531
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
532
- self.proj = nn.Linear(dim, dim)
533
-
534
- self.use_rel_pos = use_rel_pos
535
- if self.use_rel_pos:
536
- assert (
537
- input_size is not None
538
- ), "Input size must be provided if using relative positional encoding."
539
- # initialize relative positional embeddings
540
- self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
541
- self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
542
-
543
- def forward(self, x: torch.Tensor) -> torch.Tensor:
544
- B, H, W, _ = x.shape
545
- # qkv with shape (3, B, nHead, H * W, C)
546
- qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
547
- # q, k, v with shape (B * nHead, H * W, C)
548
- q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
549
-
550
- attn = (q * self.scale) @ k.transpose(-2, -1)
551
-
552
- if self.use_rel_pos:
553
- attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
554
-
555
- attn = attn.softmax(dim=-1)
556
- x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
557
- x = self.proj(x)
558
-
559
- return x
560
-
561
-
562
- def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
563
- """
564
- Partition into non-overlapping windows with padding if needed.
565
- Args:
566
- x (tensor): input tokens with [B, H, W, C].
567
- window_size (int): window size.
568
-
569
- Returns:
570
- windows: windows after partition with [B * num_windows, window_size, window_size, C].
571
- (Hp, Wp): padded height and width before partition
572
- """
573
- B, H, W, C = x.shape
574
-
575
- pad_h = (window_size - H % window_size) % window_size
576
- pad_w = (window_size - W % window_size) % window_size
577
- if pad_h > 0 or pad_w > 0:
578
- x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
579
- Hp, Wp = H + pad_h, W + pad_w
580
-
581
- x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
582
- windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
583
- return windows, (Hp, Wp)
584
-
585
-
586
- def window_unpartition(
587
- windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
588
- ) -> torch.Tensor:
589
- """
590
- Window unpartition into original sequences and removing padding.
591
- Args:
592
- windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
593
- window_size (int): window size.
594
- pad_hw (Tuple): padded height and width (Hp, Wp).
595
- hw (Tuple): original height and width (H, W) before padding.
596
-
597
- Returns:
598
- x: unpartitioned sequences with [B, H, W, C].
599
- """
600
- Hp, Wp = pad_hw
601
- H, W = hw
602
- B = windows.shape[0] // (Hp * Wp // window_size // window_size)
603
- x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
604
- x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
605
-
606
- if Hp > H or Wp > W:
607
- x = x[:, :H, :W, :].contiguous()
608
- return x
609
-
610
-
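A quick round trip (illustrative only) through the 2-D windowing helpers: a 14x14 token grid is padded to 16x16, split into 8x8 windows, and recovered exactly:

import torch

x = torch.randn(2, 14, 14, 96)                              # B x H x W x C
windows, (Hp, Wp) = window_partition(x, window_size=8)
print(windows.shape, Hp, Wp)                                # torch.Size([8, 8, 8, 96]) 16 16
print(torch.equal(x, window_unpartition(windows, 8, (Hp, Wp), (14, 14))))  # True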
611
- def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
612
- """
613
- Get relative positional embeddings according to the relative positions of
614
- query and key sizes.
615
- Args:
616
- q_size (int): size of query q.
617
- k_size (int): size of key k.
618
- rel_pos (Tensor): relative position embeddings (L, C).
619
-
620
- Returns:
621
- Extracted positional embeddings according to relative positions.
622
- """
623
- max_rel_dist = int(2 * max(q_size, k_size) - 1)
624
- # Interpolate rel pos if needed.
625
- if rel_pos.shape[0] != max_rel_dist:
626
- # Interpolate rel pos.
627
- rel_pos_resized = F.interpolate(
628
- rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
629
- size=max_rel_dist,
630
- mode="linear",
631
- )
632
- rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
633
- else:
634
- rel_pos_resized = rel_pos
635
-
636
- # Scale the coords with short length if shapes for q and k are different.
637
- q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
638
- k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
639
- relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
640
-
641
- return rel_pos_resized[relative_coords.long()]
642
-
643
-
644
- def add_decomposed_rel_pos(
645
- attn: torch.Tensor,
646
- q: torch.Tensor,
647
- rel_pos_h: torch.Tensor,
648
- rel_pos_w: torch.Tensor,
649
- q_size: Tuple[int, int],
650
- k_size: Tuple[int, int],
651
- ) -> torch.Tensor:
652
- """
653
- Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
654
- https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
655
- Args:
656
- attn (Tensor): attention map.
657
- q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
658
- rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
659
- rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
660
- q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
661
- k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
662
-
663
- Returns:
664
- attn (Tensor): attention map with added relative positional embeddings.
665
- """
666
- q_h, q_w = q_size
667
- k_h, k_w = k_size
668
- Rh = get_rel_pos(q_h, k_h, rel_pos_h)
669
- Rw = get_rel_pos(q_w, k_w, rel_pos_w)
670
-
671
- B, _, dim = q.shape
672
- r_q = q.reshape(B, q_h, q_w, dim)
673
- rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
674
- rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
675
-
676
- attn = (
677
- attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
678
- ).view(B, q_h * q_w, k_h * k_w)
679
-
680
- return attn
681
-
682
-
683
- class PromptEncoder(nn.Module):
684
- def __init__(
685
- self,
686
- embed_dim: int,
687
- image_embedding_size: Tuple[int, int, int],
688
- input_image_size: Tuple[int, int, int],
689
- mask_in_chans: int,
690
- activation: Type[nn.Module] = nn.GELU,
691
- ) -> None:
692
- """
693
- Encodes prompts for input to SAM's mask decoder.
694
-
695
- Arguments:
696
- embed_dim (int): The prompts' embedding dimension
697
- image_embedding_size (tuple(int, int, int)): The spatial size of the
698
- image embedding, as (H, W, D).
699
- input_image_size (tuple(int, int, int)): The padded size of the image as input
700
- to the image encoder, as (H, W, D).
701
- mask_in_chans (int): The number of hidden channels used for
702
- encoding input masks.
703
- activation (nn.Module): The activation to use when encoding
704
- input masks.
705
- """
706
- super().__init__()
707
- self.embed_dim = embed_dim
708
- self.input_image_size = input_image_size
709
- self.image_embedding_size = image_embedding_size
710
- self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
711
-
712
- self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
713
- point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
714
- self.point_embeddings = nn.ModuleList(point_embeddings)
715
- self.not_a_point_embed = nn.Embedding(1, embed_dim)
716
-
717
- self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1], 4 * image_embedding_size[2])
718
- self.mask_downscaling = nn.Sequential(
719
- nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
720
- LayerNorm2d(mask_in_chans // 4),
721
- activation(),
722
- nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
723
- LayerNorm2d(mask_in_chans),
724
- activation(),
725
- nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
726
- )
727
- self.no_mask_embed = nn.Embedding(1, embed_dim)
728
-
729
- def get_dense_pe(self) -> torch.Tensor:
730
- """
731
- Returns the positional encoding used to encode point prompts,
732
- applied to a dense set of points the shape of the image encoding.
733
-
734
- Returns:
735
- torch.Tensor: Positional encoding with shape
736
- 1x(embed_dim)x(embedding_h)x(embedding_w)x(embedding_d)
737
- """
738
- return self.pe_layer(self.image_embedding_size).unsqueeze(0)
739
-
740
- def _embed_points(
741
- self,
742
- points: torch.Tensor,
743
- labels: torch.Tensor,
744
- pad: bool,
745
- ) -> torch.Tensor:
746
- """Embeds point prompts."""
747
- points = points + 0.5 # Shift to center of pixel
748
- if pad:
749
- padding_point = torch.zeros((points.shape[0], 1, 3), device=points.device)
750
- padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
751
- points = torch.cat([points, padding_point], dim=1)
752
- labels = torch.cat([labels, padding_label], dim=1)
753
- point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
754
- point_embedding[labels == -1] = 0.0
755
- point_embedding[labels == -1] += self.not_a_point_embed.weight
756
- point_embedding[labels == 0] += self.point_embeddings[0].weight
757
- point_embedding[labels == 1] += self.point_embeddings[1].weight
758
- return point_embedding
759
-
760
- def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
761
- """Embeds box prompts."""
762
- boxes = boxes + 0.5 # Shift to center of pixel
763
- coords = boxes.reshape(-1, 2, 3)
764
- corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
765
- corner_embedding[:, 0, :] += self.point_embeddings[2].weight
766
- corner_embedding[:, 1, :] += self.point_embeddings[3].weight
767
- return corner_embedding
768
-
769
- def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
770
- """Embeds mask inputs."""
771
- mask_embedding = self.mask_downscaling(masks)
772
- return mask_embedding
773
-
774
- def _get_batch_size(
775
- self,
776
- points: Optional[Tuple[torch.Tensor, torch.Tensor]],
777
- boxes: Optional[torch.Tensor],
778
- masks: Optional[torch.Tensor],
779
- text_embedding: Optional[torch.Tensor],
780
- ) -> int:
781
- """
782
- Gets the batch size of the output given the batch size of the input prompts.
783
- """
784
- if points is not None:
785
- return points[0].shape[0]
786
- elif boxes is not None:
787
- return boxes.shape[0]
788
- elif masks is not None:
789
- return masks.shape[0]
790
- elif text_embedding is not None:
791
- return text_embedding.shape[0]
792
- else:
793
- return 1
794
-
795
- def _get_device(self) -> torch.device:
796
- return self.point_embeddings[0].weight.device
797
-
798
- def forward(
799
- self,
800
- points: Optional[Tuple[torch.Tensor, torch.Tensor]],
801
- boxes: Optional[torch.Tensor],
802
- masks: Optional[torch.Tensor],
803
- text_embedding: Optional[torch.Tensor],
804
- ) -> Tuple[torch.Tensor, torch.Tensor]:
805
- """
806
- Embeds different types of prompts, returning both sparse and dense
807
- embeddings.
808
-
809
- Arguments:
810
- points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
811
- and labels to embed.
812
- boxes (torch.Tensor or none): boxes to embed
813
- masks (torch.Tensor or none): masks to embed
814
- text_embedding (torch.Tensor or none): text prompt embedding, with shape (B, 768)
815
-
816
- Returns:
817
- torch.Tensor: sparse embeddings for the points and boxes, with shape
818
- BxNx(embed_dim), where N is determined by the number of input points
819
- and boxes.
820
- torch.Tensor: dense embeddings for the masks, in the shape
821
- Bx(embed_dim)x(embed_H)x(embed_W)x(embed_D)
822
- """
823
- # print('prompt encoder here...')
824
-
825
- bs = self._get_batch_size(points, boxes, masks, text_embedding)
826
- sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device(),
827
- dtype=self.point_embeddings[0].weight.dtype)
828
- # print('sparse_embeddings ', sparse_embeddings.shape)
829
- if points is not None:
830
- coords, labels = points
831
- point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
832
- sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
833
-
834
- if boxes is not None:
835
- box_embeddings = self._embed_boxes(boxes)
836
- sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
837
-
838
- if text_embedding is not None:
839
- sparse_embeddings = torch.cat([sparse_embeddings, text_embedding.unsqueeze(dim=1)], dim=1)
840
-
841
- # print('box_embeddings ', box_embeddings.shape)
842
- # print('sparse_embeddings after box/point/text', sparse_embeddings.shape)
843
-
844
- if masks is not None:
845
- dense_embeddings = self._embed_masks(masks)
846
- else:
847
- dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1, 1).expand(
848
- bs, -1, int(self.image_embedding_size[0]), int(self.image_embedding_size[1]),
849
- int(self.image_embedding_size[2])
850
- )
851
- return sparse_embeddings, dense_embeddings
852
-
853
-
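A minimal sketch (illustrative sizes) of the text-prompt path, the only branch SegVol below exercises; with no points, boxes or masks, the dense output is just the broadcast no_mask embedding:

import torch

pe = PromptEncoder(embed_dim=768,
                   image_embedding_size=(8, 8, 4),    # feature-map size (h, w, d)
                   input_image_size=(128, 128, 64),   # padded input size (H, W, D)
                   mask_in_chans=16)
text = torch.randn(2, 768)                            # e.g. a CLIP-style text embedding
sparse, dense = pe(points=None, boxes=None, masks=None, text_embedding=text)
print(sparse.shape, dense.shape)                      # torch.Size([2, 1, 768]) torch.Size([2, 768, 8, 8, 4])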
854
- class PositionEmbeddingRandom(nn.Module):
855
- """
856
- Positional encoding using random spatial frequencies.
857
- """
858
-
859
- def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
860
- super().__init__()
861
- if scale is None or scale <= 0.0:
862
- scale = 1.0
863
- self.register_buffer(
864
- "positional_encoding_gaussian_matrix",
865
- scale * torch.randn((3, num_pos_feats)),
866
- )
867
-
868
- def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
869
- """Positionally encode points that are normalized to [0,1]."""
870
- # assuming coords are in the [0, 1]^3 cube and have d_1 x ... x d_n x 3 shape
871
- coords = 2 * coords - 1
872
- coords = coords @ self.positional_encoding_gaussian_matrix
873
- coords = 2 * np.pi * coords
874
- # outputs d_1 x ... x d_n x C shape
875
- return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
876
-
877
- def forward(self, size: Tuple[int, int, int]) -> torch.Tensor:
878
- """Generate positional encoding for a grid of the specified size."""
879
- h, w, d = size
880
- device: Any = self.positional_encoding_gaussian_matrix.device
881
- dtype = self.positional_encoding_gaussian_matrix.dtype
882
- grid = torch.ones((h, w, d), device=device, dtype=dtype)
883
- y_embed = grid.cumsum(dim=0) - 0.5
884
- x_embed = grid.cumsum(dim=1) - 0.5
885
- z_embed = grid.cumsum(dim=2) - 0.5
886
- y_embed = y_embed / h
887
- x_embed = x_embed / w
888
- z_embed = z_embed / d
889
-
890
- pe = self._pe_encoding(torch.stack([x_embed, y_embed, z_embed], dim=-1))
891
- return pe.permute(3, 0, 1, 2) # C x H x W x D
892
-
893
- def forward_with_coords(
894
- self, coords_input: torch.Tensor, image_size: Tuple[int, int]
895
- ) -> torch.Tensor:
896
- """Positionally encode points that are not normalized to [0,1]."""
897
- coords = coords_input.clone()
898
- coords[:, :, 0] = coords[:, :, 0] / image_size[1]
899
- coords[:, :, 1] = coords[:, :, 1] / image_size[0]
900
- coords[:, :, 2] = coords[:, :, 2] / image_size[2]
901
- return self._pe_encoding(coords.to(torch.float)) # B x N x C
902
-
903
-
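A small shape check (illustrative only): num_pos_feats=128 gives a 256-channel encoding (sin and cos halves) for both dense grids and raw point coordinates:

import torch

pe_layer = PositionEmbeddingRandom(num_pos_feats=128)
print(pe_layer((8, 8, 4)).shape)                                 # torch.Size([256, 8, 8, 4])
coords = torch.rand(2, 5, 3) * 64                                # 5 points per volume, in voxels
print(pe_layer.forward_with_coords(coords, (64, 64, 64)).shape)  # torch.Size([2, 5, 256])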
904
- class MaskDecoder(nn.Module):
905
- def __init__(
906
- self,
907
- *,
908
- image_encoder_type: str,
909
- transformer_dim: int,
910
- transformer: nn.Module,
911
- num_multimask_outputs: int = 3,
912
- activation: Type[nn.Module] = nn.GELU,
913
- iou_head_depth: int = 3,
914
- iou_head_hidden_dim: int = 256,
915
- image_size,
916
- patch_size,
917
- ) -> None:
918
- """
919
- Predicts masks given an image and prompt embeddings, using a
920
- transformer architecture.
921
-
922
- Arguments:
923
- transformer_dim (int): the channel dimension of the transformer
924
- transformer (nn.Module): the transformer used to predict masks
925
- num_multimask_outputs (int): the number of masks to predict
926
- when disambiguating masks
927
- activation (nn.Module): the type of activation to use when
928
- upscaling masks
929
- iou_head_depth (int): the depth of the MLP used to predict
930
- mask quality
931
- iou_head_hidden_dim (int): the hidden dimension of the MLP
932
- used to predict mask quality
933
- """
934
- super().__init__()
935
- self.transformer_dim = transformer_dim
936
- self.transformer = transformer
937
-
938
- self.num_multimask_outputs = num_multimask_outputs
939
-
940
- self.iou_token = nn.Embedding(1, transformer_dim)
941
- self.num_mask_tokens = num_multimask_outputs + 1
942
- self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
943
-
944
- if image_encoder_type == 'swin_vit':
945
- self.feat_shape = image_size / patch_size
946
- self.output_upscaling = nn.Sequential(
947
- nn.ConvTranspose3d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
948
- nn.LayerNorm(
949
- (transformer_dim // 4, int(self.feat_shape[0]), int(self.feat_shape[1]), int(self.feat_shape[2]))),
950
- # swin
951
- activation(),
952
- nn.ConvTranspose3d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), # swin
953
- # nn.Conv3d(transformer_dim // 4, transformer_dim // 8, kernel_size=3, stride=1, padding=1), # vit
954
- activation(),
955
- )
956
- else:
957
- self.feat_shape = image_size / patch_size * 2
958
- self.output_upscaling = nn.Sequential(
959
- nn.ConvTranspose3d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
960
- nn.LayerNorm(
961
- (transformer_dim // 4, int(self.feat_shape[0]), int(self.feat_shape[1]), int(self.feat_shape[2]))),
962
- # vit
963
- activation(),
964
- nn.ConvTranspose3d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
965
- # nn.Conv3d(transformer_dim // 4, transformer_dim // 8, kernel_size=3, stride=1, padding=1),
966
- activation(),
967
- )
968
- self.output_hypernetworks_mlps = nn.ModuleList(
969
- [
970
- MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
971
- for i in range(self.num_mask_tokens)
972
- ]
973
- )
974
-
975
- self.iou_prediction_head = MLP(
976
- transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
977
- )
978
-
979
- self.txt_align_upscaled_embedding = nn.Linear(768, 96)
980
-
981
- def forward(
982
- self,
983
- image_embeddings: torch.Tensor,
984
- text_embedding: Optional[torch.Tensor],
985
- image_pe: torch.Tensor,
986
- sparse_prompt_embeddings: torch.Tensor,
987
- dense_prompt_embeddings: torch.Tensor,
988
- multimask_output: bool,
989
- ) -> Tuple[torch.Tensor, torch.Tensor]:
990
- """
991
- Predict masks given image and prompt embeddings.
992
-
993
- Arguments:
994
- image_embeddings (torch.Tensor): the embeddings from the image encoder
995
- image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
996
- sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
997
- dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
998
- multimask_output (bool): Whether to return multiple masks or a single
999
- mask.
1000
-
1001
- Returns:
1002
- torch.Tensor: batched predicted masks
1003
- torch.Tensor: batched predictions of mask quality
1004
- """
1005
- # print('--------------decoder here--------------')
1006
- masks, iou_pred = self.predict_masks(
1007
- image_embeddings=image_embeddings,
1008
- text_embedding=text_embedding,
1009
- image_pe=image_pe,
1010
- sparse_prompt_embeddings=sparse_prompt_embeddings,
1011
- dense_prompt_embeddings=dense_prompt_embeddings,
1012
- )
1013
-
1014
- # Select the correct mask or masks for output
1015
- if multimask_output:
1016
- mask_slice = slice(1, None)
1017
- else:
1018
- mask_slice = slice(0, 1)
1019
- masks = masks[:, mask_slice, :, :, :]
1020
- iou_pred = iou_pred[:, mask_slice]
1021
-
1022
- # Prepare output
1023
- return masks, iou_pred
1024
-
1025
- def predict_masks(
1026
- self,
1027
- image_embeddings: torch.Tensor,
1028
- text_embedding: torch.Tensor,
1029
- image_pe: torch.Tensor,
1030
- sparse_prompt_embeddings: torch.Tensor,
1031
- dense_prompt_embeddings: torch.Tensor,
1032
- ) -> Tuple[torch.Tensor, torch.Tensor]:
1033
- """Predicts masks. See 'forward' for more details."""
1034
- # Concatenate output tokens
1035
- output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
1036
- output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
1037
- tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) # [2, 7=(5+2), 256]
1038
- # Expand per-image data in batch direction to be per-mask
1039
- if image_embeddings.shape[0] != tokens.shape[0]:
1040
- src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
1041
- else:
1042
- src = image_embeddings
1043
-
1044
- src = src + dense_prompt_embeddings
1045
- pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
1046
- b, c, h, w, d = src.shape
1047
-
1048
- # Run the transformer
1049
- hs, src = self.transformer(src, pos_src, tokens)
1050
- iou_token_out = hs[:, 0, :]
1051
- mask_tokens_out = hs[:, 1: (1 + self.num_mask_tokens), :]
1052
-
1053
- # Upscale mask embeddings and predict masks using the mask tokens
1054
- src = src.transpose(1, 2).view(b, c, h, w, d)
1055
- # print('src ', src.shape) # vit:[B, 768, 12, 12, 6], swin: [B, 6, 6, 3]
1056
- upscaled_embedding = self.output_upscaling(src)
1057
- hyper_in_list: List[torch.Tensor] = []
1058
- for i in range(self.num_mask_tokens):
1059
- hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
1060
- hyper_in = torch.stack(hyper_in_list, dim=1)
1061
- b, c, h, w, d = upscaled_embedding.shape
1062
- # print('hyper_in ', hyper_in.shape) # [2, 4, 96]
1063
- # print('upscaled_embedding ', upscaled_embedding.shape) # [2, 96, 24, 24, 12]*
1064
- masks = (hyper_in @ upscaled_embedding.view(b, c, h * w * d)).view(b, -1, h, w, d)
1065
- # print('masks here ', masks.shape) # [2, 4, 24, 24, 12]
1066
-
1067
- if text_embedding is not None:
1068
- # text_embedding: B x 768, upscaled_embedding: B x c x h x w x d => B x 1 x h x w x d
1069
- text_embedding_down = self.txt_align_upscaled_embedding(text_embedding).unsqueeze(dim=1)
1070
- upscaled_embedding = upscaled_embedding.view(b, c, h * w * d)
1071
- # print('text_embedding_down ', text_embedding_down.shape) # [2, 1, 96]
1072
- # text_embedding_norm = F.normalize(text_embedding_down, dim=-1)
1073
- # upscaled_embedding_norm = F.normalize(upscaled_embedding, dim=1)
1074
- # sim = (text_embedding_norm @ upscaled_embedding_norm).view(b, -1, h, w, d)
1075
- # print(text_embedding_down.shape, upscaled_embedding.shape)
1076
- sim = (text_embedding_down @ upscaled_embedding).view(b, -1, h, w, d)
1077
- # print('sim ', sim.shape) # [B, 1, 24, 24, 12]
1078
- sim = sim.repeat(1, masks.shape[1], 1, 1, 1)
1079
- # print('sim after', sim.shape) # [B, 4, 24, 24, 12]
1080
- masks = masks + sim
1081
- # Generate mask quality predictions
1082
- iou_pred = self.iou_prediction_head(iou_token_out)
1083
-
1084
- return masks, iou_pred
1085
-
1086
-
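The text branch above projects the 768-d text embedding into the 96-channel upscaled-mask space and takes an inner product with every voxel feature; the resulting similarity map is added to each predicted mask. The same shape arithmetic with plain tensors (hypothetical sizes matching the commented shapes):

import torch
import torch.nn as nn

b, c, h, w, d = 2, 96, 24, 24, 12
upscaled = torch.randn(b, c, h, w, d)         # stand-in for the output of output_upscaling
text = torch.randn(b, 768)
proj = nn.Linear(768, 96)                     # mirrors txt_align_upscaled_embedding
text_down = proj(text).unsqueeze(1)           # B x 1 x 96
sim = (text_down @ upscaled.view(b, c, h * w * d)).view(b, 1, h, w, d)
print(sim.shape)                              # torch.Size([2, 1, 24, 24, 12]); repeated and added to the masks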
1087
- # Lightly adapted from
1088
- # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
1089
- class MLP(nn.Module):
1090
- def __init__(
1091
- self,
1092
- input_dim: int,
1093
- hidden_dim: int,
1094
- output_dim: int,
1095
- num_layers: int,
1096
- sigmoid_output: bool = False,
1097
- ) -> None:
1098
- super().__init__()
1099
- self.num_layers = num_layers
1100
- h = [hidden_dim] * (num_layers - 1)
1101
- self.layers = nn.ModuleList(
1102
- nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
1103
- )
1104
- self.sigmoid_output = sigmoid_output
1105
-
1106
- def forward(self, x):
1107
- for i, layer in enumerate(self.layers):
1108
- x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
1109
- if self.sigmoid_output:
1110
- x = F.sigmoid(x)
1111
- return x
1112
-
1113
-
1114
- class Sam(nn.Module):
1115
- mask_threshold: float = 0.0
1116
- image_format: str = "RGB"
1117
-
1118
- def __init__(
1119
- self,
1120
- image_encoder: ImageEncoderViT,
1121
- prompt_encoder: PromptEncoder,
1122
- mask_decoder: MaskDecoder,
1123
- pixel_mean: List[float] = [123.675, 116.28, 103.53],
1124
- pixel_std: List[float] = [58.395, 57.12, 57.375],
1125
- ) -> None:
1126
- """
1127
- SAM predicts object masks from an image and input prompts.
1128
-
1129
- Arguments:
1130
- image_encoder (ImageEncoderViT): The backbone used to encode the
1131
- image into image embeddings that allow for efficient mask prediction.
1132
- prompt_encoder (PromptEncoder): Encodes various types of input prompts.
1133
- mask_decoder (MaskDecoder): Predicts masks from the image embeddings
1134
- and encoded prompts.
1135
- pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
1136
- pixel_std (list(float)): Std values for normalizing pixels in the input image.
1137
- """
1138
- super().__init__()
1139
- self.image_encoder = image_encoder
1140
- self.prompt_encoder = prompt_encoder
1141
- self.mask_decoder = mask_decoder
1142
- self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
1143
- self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
1144
-
1145
- @property
1146
- def device(self) -> Any:
1147
- return self.pixel_mean.device
1148
-
1149
- @torch.no_grad()
1150
- def forward(
1151
- self,
1152
- batched_input: List[Dict[str, Any]],
1153
- multimask_output: bool,
1154
- ) -> List[Dict[str, torch.Tensor]]:
1155
- """
1156
- Predicts masks end-to-end from provided images and prompts.
1157
- If prompts are not known in advance, using SamPredictor is
1158
- recommended over calling the model directly.
1159
-
1160
- Arguments:
1161
- batched_input (list(dict)): A list over input images, each a
1162
- dictionary with the following keys. A prompt key can be
1163
- excluded if it is not present.
1164
- 'image': The image as a torch tensor in 3xHxW format,
1165
- already transformed for input to the model.
1166
- 'original_size': (tuple(int, int)) The original size of
1167
- the image before transformation, as (H, W).
1168
- 'point_coords': (torch.Tensor) Batched point prompts for
1169
- this image, with shape BxNx2. Already transformed to the
1170
- input frame of the model.
1171
- 'point_labels': (torch.Tensor) Batched labels for point prompts,
1172
- with shape BxN.
1173
- 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
1174
- Already transformed to the input frame of the model.
1175
- 'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
1176
- in the form Bx1xHxW.
1177
- multimask_output (bool): Whether the model should predict multiple
1178
- disambiguating masks, or return a single mask.
1179
-
1180
- Returns:
1181
- (list(dict)): A list over input images, where each element is
1182
- as dictionary with the following keys.
1183
- 'masks': (torch.Tensor) Batched binary mask predictions,
1184
- with shape BxCxHxW, where B is the number of input prompts,
1185
- C is determined by multimask_output, and (H, W) is the
1186
- original size of the image.
1187
- 'iou_predictions': (torch.Tensor) The model's predictions
1188
- of mask quality, in shape BxC.
1189
- 'low_res_logits': (torch.Tensor) Low resolution logits with
1190
- shape BxCxHxW, where H=W=256. Can be passed as mask input
1191
- to subsequent iterations of prediction.
1192
- """
1193
- input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
1194
- image_embeddings = self.image_encoder(input_images)
1195
-
1196
- outputs = []
1197
- for image_record, curr_embedding in zip(batched_input, image_embeddings):
1198
- if "point_coords" in image_record:
1199
- points = (image_record["point_coords"], image_record["point_labels"])
1200
- else:
1201
- points = None
1202
- sparse_embeddings, dense_embeddings = self.prompt_encoder(
1203
- points=points,
1204
- boxes=image_record.get("boxes", None),
1205
- masks=image_record.get("mask_inputs", None),
1206
- )
1207
- low_res_masks, iou_predictions = self.mask_decoder(
1208
- image_embeddings=curr_embedding.unsqueeze(0),
1209
- image_pe=self.prompt_encoder.get_dense_pe(),
1210
- sparse_prompt_embeddings=sparse_embeddings,
1211
- dense_prompt_embeddings=dense_embeddings,
1212
- multimask_output=multimask_output,
1213
- )
1214
- masks = self.postprocess_masks(
1215
- low_res_masks,
1216
- input_size=image_record["image"].shape[-2:],
1217
- original_size=image_record["original_size"],
1218
- )
1219
- masks = masks > self.mask_threshold
1220
- outputs.append(
1221
- {
1222
- "masks": masks,
1223
- "iou_predictions": iou_predictions,
1224
- "low_res_logits": low_res_masks,
1225
- }
1226
- )
1227
- return outputs
1228
-
1229
- def postprocess_masks(
1230
- self,
1231
- masks: torch.Tensor,
1232
- input_size: Tuple[int, ...],
1233
- original_size: Tuple[int, ...],
1234
- ) -> torch.Tensor:
1235
- """
1236
- Remove padding and upscale masks to the original image size.
1237
-
1238
- Arguments:
1239
- masks (torch.Tensor): Batched masks from the mask_decoder,
1240
- in BxCxHxW format.
1241
- input_size (tuple(int, int)): The size of the image input to the
1242
- model, in (H, W) format. Used to remove padding.
1243
- original_size (tuple(int, int)): The original size of the image
1244
- before resizing for input to the model, in (H, W) format.
1245
-
1246
- Returns:
1247
- (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
1248
- is given by original_size.
1249
- """
1250
- masks = F.interpolate(
1251
- masks,
1252
- (self.image_encoder.img_size, self.image_encoder.img_size),
1253
- mode="bilinear",
1254
- align_corners=False,
1255
- )
1256
- masks = masks[..., : input_size[0], : input_size[1]]
1257
- masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
1258
- return masks
1259
-
1260
- def preprocess(self, x: torch.Tensor) -> torch.Tensor:
1261
- """Normalize pixel values and pad to a square input."""
1262
- # Normalize colors
1263
- # TODO
1264
- x = (x - self.pixel_mean) / self.pixel_std
1265
-
1266
- # Pad
1267
- h, w = x.shape[-2:]
1268
- padh = self.image_encoder.img_size - h
1269
- padw = self.image_encoder.img_size - w
1270
- x = F.pad(x, (0, padw, 0, padh))
1271
- return x
1272
-
1273
-
1274
- """
1275
- Examples::
1276
- # for 3D single channel input with size (96,96,96), 4-channel output and feature size of 48.
1277
- >>> net = SwinUNETR(img_size=(96,96,96), in_channels=1, out_channels=4, feature_size=48)
1278
- # for 3D 4-channel input with size (128,128,128), 3-channel output and (2,4,2,2) layers in each stage.
1279
- >>> net = SwinUNETR(img_size=(128,128,128), in_channels=4, out_channels=3, depths=(2,4,2,2))
1280
- # for 2D single channel input with size (96,96), 2-channel output and gradient checkpointing.
1281
- >>> net = SwinUNETR(img_size=(96,96), in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2)
1282
- """
1283
-
1284
-
1285
- def build_sam_vit_3d(args, checkpoint=None):
1286
- print('build_sam_vit_3d...')
1287
- return _build_sam(
1288
- image_encoder_type='vit',
1289
- embed_dim=768,
1290
- patch_size=args.patch_size,
1291
- checkpoint=checkpoint,
1292
- image_size=args.image_size,
1293
- )
1294
-
1295
-
1296
- sam_model_registry = {
1297
- "vit": build_sam_vit_3d,
1298
- }
1299
-
1300
-
1301
- def _build_sam(
1302
- image_encoder_type,
1303
- embed_dim,
1304
- patch_size,
1305
- checkpoint,
1306
- image_size,
1307
- ):
1308
- mlp_dim = 3072
1309
- num_layers = 12
1310
- num_heads = 12
1311
- pos_embed = 'perceptron'
1312
- dropout_rate = 0.0
1313
-
1314
- image_encoder = ViT(
1315
- in_channels=1,
1316
- img_size=image_size,
1317
- patch_size=patch_size,
1318
- hidden_size=embed_dim,
1319
- mlp_dim=mlp_dim,
1320
- num_layers=num_layers,
1321
- num_heads=num_heads,
1322
- pos_embed=pos_embed,
1323
- classification=False,
1324
- dropout_rate=dropout_rate,
1325
- )
1326
- image_embedding_size = [int(item) for item in (np.array(image_size) / np.array(patch_size))]
1327
-
1328
- if checkpoint is not None:
1329
- with open(checkpoint, "rb") as f:
1330
- state_dict = torch.load(f, map_location='cpu')['state_dict']
1331
- encoder_dict = {k.replace('model.encoder.', ''): v for k, v in state_dict.items() if 'model.encoder.' in k}
1332
- image_encoder.load_state_dict(encoder_dict)
1333
- print(f'===> image_encoder.load_param: {checkpoint}')
1334
- sam = Sam(
1335
- image_encoder=image_encoder,
1336
- prompt_encoder=PromptEncoder(
1337
- embed_dim=embed_dim,
1338
- image_embedding_size=image_embedding_size,
1339
- input_image_size=image_size,
1340
- mask_in_chans=16,
1341
- ),
1342
- mask_decoder=MaskDecoder(
1343
- image_encoder_type=image_encoder_type,
1344
- num_multimask_outputs=3,
1345
- transformer=TwoWayTransformer(
1346
- depth=2,
1347
- embedding_dim=embed_dim,
1348
- mlp_dim=2048,
1349
- num_heads=8,
1350
- ),
1351
- transformer_dim=embed_dim,
1352
- iou_head_depth=3,
1353
- iou_head_hidden_dim=256,
1354
- image_size=np.array(image_size),
1355
- patch_size=np.array(patch_size),
1356
- ),
1357
- pixel_mean=[123.675, 116.28, 103.53],
1358
- pixel_std=[58.395, 57.12, 57.375],
1359
- )
1360
- sam.eval()
1361
- return sam
1362
-
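A construction sketch (not from the original file; sizes are illustrative, and it assumes a MONAI version whose ViT still accepts the pos_embed keyword):

from types import SimpleNamespace
import torch

args = SimpleNamespace(image_size=(32, 256, 256), patch_size=(4, 16, 16))
sam = sam_model_registry['vit'](args=args, checkpoint=None)
vol = torch.randn(1, 1, 32, 256, 256)         # B x 1 x D x H x W volume
feats, _ = sam.image_encoder(vol)             # MONAI ViT returns (tokens, hidden states)
print(feats.shape)                            # torch.Size([1, 2048, 768]) -> 8*16*16 patch tokens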
1363
- class SegVol(nn.Module):
1364
- def __init__(self,
1365
- image_encoder,
1366
- mask_decoder,
1367
- prompt_encoder,
1368
- roi_size,
1369
- patch_size,
1370
- ):
1371
- super().__init__()
1372
- self.image_encoder = image_encoder
1373
- self.mask_decoder = mask_decoder
1374
- self.prompt_encoder = prompt_encoder
1375
- self.feat_shape = np.array(roi_size)/np.array(patch_size)
1376
-
1377
- def forward(self, image, text_emb=None, text=None, boxes=None, points=None):
1378
- bs = image.shape[0]
1379
- img_shape = (image.shape[2], image.shape[3], image.shape[4])
1380
- image_embedding, _ = self.image_encoder(image)
1381
-
1382
- image_embedding = image_embedding.transpose(1, 2).view(bs, -1,
1383
- int(self.feat_shape[0]), int(self.feat_shape[1]), int(self.feat_shape[2]))
1384
-
1385
- logits = self.forward_decoder(image_embedding, img_shape, text_emb=text_emb, text=text, boxes=boxes, points=points)
1386
-
1387
- return logits
1388
-
1389
- def forward_decoder(self, image_embedding, img_shape, text_emb=None, text=None, boxes=None, points=None):
1390
- text_embedding = text_emb
1391
- sparse_embeddings, dense_embeddings = self.prompt_encoder(
1392
- points=None,
1393
- boxes=None,
1394
- masks=None,
1395
- text_embedding=text_embedding,
1396
- )
1397
-
1398
- dense_pe = self.prompt_encoder.get_dense_pe()
1399
-
1400
- low_res_masks, _ = self.mask_decoder(
1401
- image_embeddings=image_embedding,
1402
- text_embedding = text_embedding,
1403
- image_pe=dense_pe,
1404
- sparse_prompt_embeddings=sparse_embeddings,
1405
- dense_prompt_embeddings=dense_embeddings,
1406
- multimask_output=False,
1407
- )
1408
- logits = F.interpolate(low_res_masks, size=img_shape, mode='trilinear', align_corners=False)
1409
-
1410
- return logits
1411
-
1412
-
1413
- def build_segmentation_module(config, **kwargs):
1414
- segmentation_module = getattr(config, 'segmentation_module')
1415
- if 'segvol' in segmentation_module.lower():
1416
- sam_model = sam_model_registry['vit'](args=config, checkpoint=None)
1417
- seg_model = SegVol(
1418
- image_encoder=sam_model.image_encoder,
1419
- mask_decoder=sam_model.mask_decoder,
1420
- prompt_encoder=sam_model.prompt_encoder,
1421
- roi_size=config.image_size,
1422
- patch_size=config.patch_size,
1423
- )
1424
- return seg_model
1425
- else:
1426
- raise ValueError(f'Unknown segmentation module: {segmentation_module}')
1427
-
1428
-
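An end-to-end sketch of the segmentation module (illustrative config holding only the fields these builders actually read; same MONAI pos_embed caveat as above):

from types import SimpleNamespace
import torch

config = SimpleNamespace(segmentation_module='segvol',
                         image_size=(32, 256, 256), patch_size=(4, 16, 16))
seg = build_segmentation_module(config)
vol = torch.randn(1, 1, 32, 256, 256)
text_emb = torch.randn(1, 768)                # text prompt embedding for the prompt encoder
logits = seg(vol, text_emb=text_emb)
print(logits.shape)                           # torch.Size([1, 1, 32, 256, 256]) voxel-wise mask logits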
1429
- class IdentityMap(nn.Module):
1430
- def __init__(self):
1431
- super().__init__()
1432
-
1433
- def forward(self, x, *args, **kwargs):
1434
- return x
1435
-
1436
- @property
1437
- def config(self):
1438
- return {"mm_projector_type": 'identity'}
1439
-
1440
-
1441
-
1442
- class SpatialPoolingProjector(nn.Module):
1443
- def __init__(self, image_size, patch_size, in_dim, out_dim, layer_type, layer_num, pooling_type='spatial', pooling_size=2):
1444
- super().__init__()
1445
- self.in_dim = in_dim
1446
- self.pooling_size = pooling_size
1447
-
1448
- self.num_patches_pre = [img // pch for img, pch in zip(image_size, patch_size)]
1449
- self.num_patches_post = [num // pooling_size for num in self.num_patches_pre]
1450
-
1451
- if layer_type == 'linear':
1452
- depth = int(layer_num)
1453
- modules = [nn.Linear(in_dim, out_dim)]
1454
- for _ in range(1, depth):
1455
- modules.append(nn.Linear(out_dim, out_dim))
1456
- self.projector = nn.Sequential(*modules)
1457
- elif layer_type == 'mlp':
1458
- depth = int(layer_num)
1459
- modules = [nn.Linear(in_dim, out_dim)]
1460
- for _ in range(1, depth):
1461
- modules.append(nn.GELU())
1462
- modules.append(nn.Linear(out_dim, out_dim))
1463
- self.projector = nn.Sequential(*modules)
1464
- else:
1465
- print("Projector error!")
1466
-
1467
- self.pooling_type = pooling_type
1468
-
1469
- def forward(self, x):
1470
- B = x.shape[0] # B*N*D
1471
-
1472
- if self.pooling_type == 'spatial':
1473
- to_3d = Rearrange("b (p1 p2 p3) d -> b d p1 p2 p3", b=B, d=self.in_dim, p1=self.num_patches_pre[0], p2=self.num_patches_pre[1], p3=self.num_patches_pre[2])
1474
- x = to_3d(x)
1475
- x = F.avg_pool3d(x, kernel_size=self.pooling_size, stride=self.pooling_size)
1476
- to_seq = Rearrange("b d p1 p2 p3 -> b (p1 p2 p3) d", b=B, d=self.in_dim, p1=self.num_patches_post[0], p2=self.num_patches_post[1], p3=self.num_patches_post[2])
1477
- x = to_seq(x)
1478
- elif self.pooling_type == 'sequence':
1479
- x = x.permute(0, 2, 1) #b d n
1480
- x = F.avg_pool1d(x, kernel_size=self.pooling_size**3, stride=self.pooling_size**3)
1481
- x = x.permute(0, 2, 1) #b n d
1482
-
1483
- x = rearrange(x, "b n d -> (b n) d")
1484
- x = self.projector(x)
1485
- x = rearrange(x, "(b n) d -> b n d", b=B)
1486
-
1487
- return x
1488
-
1489
- @property
1490
- def proj_out_num(self):
1491
- num = 1
1492
- for n in self.num_patches_post:
1493
- num *= n
1494
- return num
1495
-
1496
-
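A shape sketch (illustrative sizes; out_dim would normally be the LLM hidden size): 2048 ViT tokens are spatially pooled down to proj_out_num = 256 tokens before projection:

import torch

proj = SpatialPoolingProjector(image_size=(32, 256, 256), patch_size=(4, 16, 16),
                               in_dim=768, out_dim=3072,
                               layer_type='mlp', layer_num=2,
                               pooling_type='spatial', pooling_size=2)
vision_tokens = torch.randn(2, 8 * 16 * 16, 768)     # B x N x D tokens from the 3-D ViT
print(proj(vision_tokens).shape, proj.proj_out_num)  # torch.Size([2, 256, 3072]) 256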
1497
- class Minigpt(nn.Module):
1498
- def __init__(self, config=None):
1499
- super(Minigpt, self).__init__()
1500
- # c*4 is the input size, and c is the output size for the linear layer
1501
- inc, ouc = config.mm_hidden_size, config.hidden_size
1502
- self.linear = nn.Linear(inc * 4, ouc)
1503
-
1504
- def forward(self, x):
1505
- # x is the input tensor with shape [b, num_tokens, c]
1506
- b, num_tokens, c = x.shape
1507
-
1508
- # Check if num_tokens is divisible by 4
1509
- if num_tokens % 4 != 0:
1510
- raise ValueError("num_tokens must be divisible by 4")
1511
-
1512
- # Reshape x to [b, num_tokens/4, c*4]
1513
- x = x.view(b, num_tokens // 4, c * 4)
1514
-
1515
- # Apply the linear transformation
1516
- x = self.linear(x)
1517
- return x
1518
-
1519
-
1520
- class Vanilla(nn.Module):
1521
- def __init__(self, config=None):
1522
- super(Vanilla, self).__init__()
1523
- # the linear layer merges 4 tokens: input size is mm_hidden_size * 4, output size is hidden_size
1524
- inc, ouc = config.mm_hidden_size, config.hidden_size
1525
- self.linear = nn.Linear(inc * 4, ouc)
1526
-
1527
- def forward(self, x):
1528
- b, num_tokens, c = x.shape
1529
-
1530
- # Check if num_tokens is divisible by 4
1531
- if num_tokens % 4 != 0:
1532
- raise ValueError("num_tokens must be divisible by 4")
1533
-
1534
- # First, reshape to [b, num_tokens//4, 4, c]
1535
- x = x.view(b, num_tokens // 4, 4, c)
1536
-
1537
- # Then, permute to interleave the tokens
1538
- x = x.permute(0, 1, 3, 2).contiguous()
1539
-
1540
- # Finally, reshape to [b, num_tokens//4, c*4] to interleave features of 4 tokens
1541
- x = x.view(b, num_tokens // 4, c * 4)
1542
-
1543
- # Apply the linear transformation
1544
- x = self.linear(x)
1545
- return x
1546
-
1547
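A small sketch of the 4-token merge performed by Vanilla (Minigpt differs only in skipping the interleaving permute); the config is a stand-in SimpleNamespace with hypothetical sizes.

```python
import torch
from types import SimpleNamespace

cfg = SimpleNamespace(mm_hidden_size=768, hidden_size=3072)  # hypothetical sizes
merger = Vanilla(cfg)

x = torch.randn(2, 256, 768)   # [batch, num_tokens, mm_hidden_size]
y = merger(x)

# Each group of 4 consecutive tokens is interleaved feature-wise and projected,
# so the sequence length shrinks by a factor of 4.
print(y.shape)                 # torch.Size([2, 64, 3072])
```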
-
1548
- def build_mm_projector(config, delay_load=False, **kwargs):
1549
- projector_type = getattr(config, 'mm_projector_type')
1550
-
1551
- if projector_type == 'linear':
1552
- return nn.Linear(config.mm_hidden_size, config.hidden_size)
1553
-
1554
-
1555
- elif projector_type == 'spp':
1556
- return SpatialPoolingProjector(image_size=config.image_size,
1557
- patch_size=config.patch_size,
1558
- in_dim=config.mm_hidden_size,
1559
- out_dim=config.hidden_size,
1560
- layer_type=config.proj_layer_type,
1561
- layer_num=config.proj_layer_num,
1562
- pooling_type=config.proj_pooling_type,
1563
- pooling_size=config.proj_pooling_size)
1564
-
1565
-
1566
- elif projector_type == 'identity':
1567
- return IdentityMap()
1568
- else:
1569
- raise ValueError(f'Unknown projector type: {projector_type}')
1570
-
1571
-
1572
-
1573
- class myViT(nn.Module):
1574
- """
1575
- Vision Transformer (ViT), based on: "Dosovitskiy et al.,
1576
- An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
1577
-
1578
- ViT supports TorchScript, but only with PyTorch 1.8 or later.
1579
- """
1580
-
1581
- def __init__(
1582
- self,
1583
- in_channels: int,
1584
- img_size: Sequence[int] | int,
1585
- patch_size: Sequence[int] | int,
1586
- hidden_size: int = 768,
1587
- mlp_dim: int = 3072,
1588
- num_layers: int = 12,
1589
- num_heads: int = 12,
1590
- pos_embed: str = "conv",
1591
- classification: bool = False,
1592
- num_classes: int = 2,
1593
- dropout_rate: float = 0.0,
1594
- spatial_dims: int = 3,
1595
- post_activation="Tanh",
1596
- qkv_bias: bool = False,
1597
- save_attn: bool = False,
1598
- ) -> None:
1599
- """
1600
- Args:
1601
- in_channels (int): dimension of input channels.
1602
- img_size (Union[Sequence[int], int]): dimension of input image.
1603
- patch_size (Union[Sequence[int], int]): dimension of patch size.
1604
- hidden_size (int, optional): dimension of hidden layer. Defaults to 768.
1605
- mlp_dim (int, optional): dimension of feedforward layer. Defaults to 3072.
1606
- num_layers (int, optional): number of transformer blocks. Defaults to 12.
1607
- num_heads (int, optional): number of attention heads. Defaults to 12.
1608
- pos_embed (str, optional): position embedding layer type. Defaults to "conv".
1609
- classification (bool, optional): bool argument to determine if classification is used. Defaults to False.
1610
- num_classes (int, optional): number of classes if classification is used. Defaults to 2.
1611
- dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
1612
- spatial_dims (int, optional): number of spatial dimensions. Defaults to 3.
1613
- post_activation (str, optional): add a final activation function to the classification head
1614
- when `classification` is True. Default to "Tanh" for `nn.Tanh()`.
1615
- Set to other values to remove this function.
1616
- qkv_bias (bool, optional): apply bias to the qkv linear layer in self attention block. Defaults to False.
1617
- save_attn (bool, optional): to make accessible the attention in self attention block. Defaults to False.
1618
-
1619
- Examples::
1620
-
1621
- # for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone
1622
- >>> net = ViT(in_channels=1, img_size=(96,96,96), pos_embed='conv')
1623
-
1624
- # for 3-channel with image size of (128,128,128), 24 layers and classification backbone
1625
- >>> net = ViT(in_channels=3, img_size=(128,128,128), pos_embed='conv', classification=True)
1626
-
1627
- # for 3-channel with image size of (224,224), 12 layers and classification backbone
1628
- >>> net = ViT(in_channels=3, img_size=(224,224), pos_embed='conv', classification=True, spatial_dims=2)
1629
-
1630
- """
1631
-
1632
- super().__init__()
1633
-
1634
- if not (0 <= dropout_rate <= 1):
1635
- raise ValueError("dropout_rate should be between 0 and 1.")
1636
-
1637
- if hidden_size % num_heads != 0:
1638
- raise ValueError("hidden_size should be divisible by num_heads.")
1639
- self.hidden_size = hidden_size
1640
- self.classification = classification
1641
- self.patch_embedding = PatchEmbeddingBlock(
1642
- in_channels=in_channels,
1643
- img_size=img_size,
1644
- patch_size=patch_size,
1645
- hidden_size=hidden_size,
1646
- num_heads=num_heads,
1647
- pos_embed=pos_embed,
1648
- dropout_rate=dropout_rate,
1649
- spatial_dims=spatial_dims,
1650
- )
1651
- self.blocks = nn.ModuleList(
1652
- [
1653
- TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate, qkv_bias, save_attn)
1654
- for i in range(num_layers)
1655
- ]
1656
- )
1657
- self.norm = nn.LayerNorm(hidden_size)
1658
- if self.classification:
1659
- self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
1660
- # if post_activation == "Tanh":
1661
- # self.classification_head = nn.Sequential(nn.Linear(hidden_size, num_classes), nn.Tanh())
1662
- # else:
1663
- # self.classification_head = nn.Linear(hidden_size, num_classes) # type: ignore
1664
-
1665
- def forward(self, x):
1666
- x = self.patch_embedding(x)
1667
- if hasattr(self, "cls_token"):
1668
- cls_token = self.cls_token.expand(x.shape[0], -1, -1)
1669
- x = torch.cat((cls_token, x), dim=1)
1670
- hidden_states_out = []
1671
- for blk in self.blocks:
1672
- x = blk(x)
1673
- hidden_states_out.append(x)
1674
- x = self.norm(x)
1675
- # if hasattr(self, "classification_head"):
1676
- # x = self.classification_head(x[:, 0])
1677
- return x, hidden_states_out
1678
-
1679
-
1680
- class ViT3DTower(nn.Module):
1681
- def __init__(self, config):
1682
- super().__init__()
1683
- self.config = config
1684
- self.select_layer = config.vision_select_layer
1685
- self.select_feature = config.vision_select_feature
1686
-
1687
- self.vision_tower = myViT(
1688
- in_channels=self.config.image_channel,
1689
- img_size=self.config.image_size,
1690
- patch_size=self.config.patch_size,
1691
- pos_embed="perceptron",
1692
- spatial_dims=len(self.config.patch_size),
1693
- classification=True,
1694
- )
1695
-
1696
- def forward(self, images):
1697
- last_feature, hidden_states = self.vision_tower(images)
1698
- if self.select_layer == -1:
1699
- image_features = last_feature
1700
- elif self.select_layer < -1:
1701
- image_features = hidden_states[self.select_layer]
1702
- else:
1703
- raise ValueError(f'Unexpected select layer: {self.select_layer}')
1704
-
1705
- if self.select_feature == 'patch':
1706
- image_features = image_features[:, 1:]
1707
- elif self.select_feature == 'cls_patch':
1708
- image_features = image_features
1709
- else:
1710
- raise ValueError(f'Unexpected select feature: {self.select_feature}')
1711
-
1712
- return image_features
1713
-
1714
- @property
1715
- def dtype(self):
1716
- return next(self.vision_tower.parameters()).dtype
1717
-
1718
- @property
1719
- def device(self):
1720
- return next(self.vision_tower.parameters()).device
1721
-
1722
- @property
1723
- def hidden_size(self):
1724
- return self.vision_tower.hidden_size
1725
-
1726
-
1727
- def build_vision_tower(config, **kwargs):
1728
- vision_tower = getattr(config, 'vision_tower', None)
1729
- if 'vit3d' in vision_tower.lower():
1730
- return ViT3DTower(config, **kwargs)
1731
- else:
1732
- raise ValueError(f'Unknown vision tower: {vision_tower}')
1733
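A hedged sketch of instantiating the vision tower through the factory above; every config field value here is an illustrative assumption (any object exposing these attributes works).

```python
import torch
from types import SimpleNamespace

cfg = SimpleNamespace(
    vision_tower='vit3d',
    image_channel=1,
    image_size=(32, 256, 256),      # D x H x W, must be divisible by patch_size
    patch_size=(4, 16, 16),
    vision_select_layer=-1,         # take the final transformer output
    vision_select_feature='patch',  # drop the cls token
)

tower = build_vision_tower(cfg)
volume = torch.randn(1, 1, 32, 256, 256)  # B x C x D x H x W
feats = tower(volume)
print(feats.shape)  # torch.Size([1, 2048, 768]) -> 8*16*16 patch tokens of width 768
```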
-
1734
- class LamedMetaModel:
1735
- def __init__(self, config):
1736
- super(LamedMetaModel, self).__init__(config)
1737
-
1738
- self.config = config
1739
- self.seg_enable = False
1740
-
1741
- if hasattr(config, "vision_tower"):
1742
- self.vision_tower = build_vision_tower(config)
1743
- self.mm_projector = build_mm_projector(config)
1744
-
1745
- if hasattr(config, "segmentation_module") and config.segmentation_module is not None:
1746
- self.seg_enable = True
1747
- self.seg_module = build_segmentation_module(config)
1748
-
1749
- self.seg_projector = nn.Sequential(
1750
- nn.Linear(config.hidden_size, config.hidden_size),
1751
- nn.ReLU(inplace=True),
1752
- nn.Linear(config.hidden_size, config.mm_hidden_size),
1753
- nn.Dropout(0.1),
1754
- )
1755
-
1756
- self.dice_loss = BinaryDiceLoss()
1757
- self.bce_loss = BCELoss()
1758
-
1759
- def get_vision_tower(self):
1760
- vision_tower = getattr(self, 'vision_tower', None)
1761
- return vision_tower
1762
-
1763
- def initialize_vision_modules(self, model_args):
1764
- self.config.image_channel = model_args.image_channel
1765
- self.config.image_size = model_args.image_size
1766
- self.config.patch_size = model_args.patch_size
1767
-
1768
- self.config.vision_tower = model_args.vision_tower
1769
- self.config.vision_select_layer = model_args.vision_select_layer
1770
- self.config.vision_select_feature = model_args.vision_select_feature
1771
-
1772
- self.config.mm_projector_type = model_args.mm_projector_type
1773
- self.config.proj_layer_type = model_args.proj_layer_type
1774
- self.config.proj_layer_num = model_args.proj_layer_num
1775
- self.config.proj_pooling_type = model_args.proj_pooling_type
1776
- self.config.proj_pooling_size = model_args.proj_pooling_size
1777
-
1778
- # vision tower
1779
- if self.get_vision_tower() is None:
1780
- self.vision_tower = build_vision_tower(self.config)
1781
- # If you have a more robust vision encoder, try freezing the vision tower by requires_grad_(False)
1782
-
1783
-
1784
- if model_args.pretrain_vision_model is not None:
1785
- vision_model_weights = torch.load(model_args.pretrain_vision_model, map_location='cpu')
1786
- self.vision_tower.vision_tower.load_state_dict(vision_model_weights, strict=True)
1787
-
1788
- self.config.mm_hidden_size = self.vision_tower.hidden_size
1789
-
1790
- # mm_projector
1791
- if getattr(self, 'mm_projector', None) is None:
1792
- self.mm_projector = build_mm_projector(self.config)
1793
-
1794
- if model_args.pretrain_mm_mlp_adapter is not None:
1795
- mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
1796
- def get_w(weights, keyword):
1797
- return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
1798
- self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'), strict=True)
1799
-
1800
- def initialize_seg_modules(self, model_args):
1801
- self.config.segmentation_module = model_args.segmentation_module
1802
-
1803
- # segmentation_module
1804
- if getattr(self, 'segmentation_module', None) is None:
1805
- self.seg_module = build_segmentation_module(self.config)
1806
- self.seg_projector = nn.Sequential(
1807
- nn.Linear(self.config.hidden_size, self.config.hidden_size),
1808
- nn.ReLU(inplace=True),
1809
- nn.Linear(self.config.hidden_size, self.config.mm_hidden_size),
1810
- nn.Dropout(0.1),
1811
- )
1812
- self.seg_enable = True
1813
-
1814
- if model_args.pretrain_seg_module is not None:
1815
- seg_module_weights = torch.load(model_args.pretrain_seg_module, map_location='cpu')
1816
- self.seg_module.load_state_dict(seg_module_weights, strict=True)
1817
-
1818
- self.dice_loss = BinaryDiceLoss()
1819
- self.bce_loss = BCELoss()
1820
-
1821
- class LamedMetaForCausalLM(ABC):
1822
- @abstractmethod
1823
- def get_model(self):
1824
- pass
1825
-
1826
- def get_vision_tower(self):
1827
- return self.get_model().get_vision_tower()
1828
-
1829
- def encode_images(self, images):
1830
- image_features = self.get_model().get_vision_tower()(images)
1831
- image_features = self.get_model().mm_projector(image_features)
1832
- return image_features
1833
-
1834
- def prepare_inputs_for_multimodal(
1835
- self, input_ids, position_ids, attention_mask, past_key_values, labels,
1836
- images,
1837
- ):
1838
- vision_tower = self.get_vision_tower()
1839
- if vision_tower is None or images is None or input_ids.shape[1] == 1:
1840
- return input_ids, position_ids, attention_mask, past_key_values, None, labels
1841
- else:
1842
- image_features = self.encode_images(images)
1843
- inputs_embeds = self.get_model().embed_tokens(input_ids)
1844
- inputs_embeds = torch.cat(
1845
- (inputs_embeds[:, :1, :], image_features, inputs_embeds[:, (image_features.shape[1] + 1):, :]), dim=1)
1846
- return None, position_ids, attention_mask, past_key_values, inputs_embeds, labels
1847
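The splice above assumes the prompt keeps one leading token (e.g. BOS) followed by exactly as many placeholder tokens as the projector emits; a plain-tensor sketch of the same concatenation (hypothetical sizes) makes the layout explicit.

```python
import torch

B, n_img, D = 1, 256, 3072                          # illustrative sizes
text_embeds = torch.randn(B, 1 + n_img + 40, D)     # BOS + 256 placeholders + 40 text tokens
image_feats = torch.randn(B, n_img, D)              # output of the mm projector

# Keep the leading token, replace the placeholder span with the image features,
# then append the remaining text embeddings; the sequence length is unchanged.
spliced = torch.cat(
    (text_embeds[:, :1, :], image_feats, text_embeds[:, n_img + 1:, :]), dim=1)
print(spliced.shape)                                # torch.Size([1, 297, 3072])
```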
-
1848
- def initialize_vision_tokenizer(self, model_args, tokenizer):
1849
- num_new_tokens = model_args.num_new_tokens
1850
-
1851
- self.resize_token_embeddings(len(tokenizer))
1852
-
1853
- if num_new_tokens > 0:
1854
- input_embeddings = self.get_input_embeddings().weight.data
1855
- output_embeddings = self.get_output_embeddings().weight.data
1856
-
1857
- input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
1858
- dim=0, keepdim=True)
1859
- output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
1860
- dim=0, keepdim=True)
1861
-
1862
- input_embeddings[-num_new_tokens:] = input_embeddings_avg
1863
- output_embeddings[-num_new_tokens:] = output_embeddings_avg
1864
-
1865
- if model_args.tune_mm_mlp_adapter:
1866
- for p in self.get_input_embeddings().parameters():
1867
- p.requires_grad = True
1868
- for p in self.get_output_embeddings().parameters():
1869
- p.requires_grad = False
1870
- else:
1871
- # we add 4 new tokens
1872
- # if the new tokens appear in the input, train the input embeddings
1873
- for p in self.get_input_embeddings().parameters():
1874
- p.requires_grad = True
1875
- # if the new tokens need to be predicted, train the output embeddings
1876
- for p in self.get_output_embeddings().parameters():
1877
- p.requires_grad = True
1878
-
1879
- if model_args.pretrain_mm_mlp_adapter:
1880
- mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
1881
- embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
1882
-
1883
- if input_embeddings.shape == embed_tokens_weight.shape:
1884
- input_embeddings = embed_tokens_weight
1885
- elif embed_tokens_weight.shape[0] == num_new_tokens:
1886
- input_embeddings[-num_new_tokens:] = embed_tokens_weight
1887
- else:
1888
- raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
1889
-
1890
-
1891
-
1892
- class LamedPhi3Model(LamedMetaModel, Phi3Model):
1893
- config_class = LamedPhi3Config
1894
- def __init__(self, config: Phi3Config):
1895
- super(LamedPhi3Model, self).__init__(config)
1896
-
1897
-
1898
- class LamedPhi3ForCausalLM(LamedMetaForCausalLM, Phi3ForCausalLM):
1899
- config_class = LamedPhi3Config
1900
-
1901
- def __init__(self, config):
1902
- super(LamedPhi3ForCausalLM, self).__init__(config)
1903
- self.model = LamedPhi3Model(config)
1904
- self.vocab_size = config.vocab_size
1905
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1906
-
1907
- # Initialize weights and apply final processing
1908
- self.post_init()
1909
-
1910
- def get_model(self):
1911
- return self.model
1912
-
1913
- def forward(
1914
- self,
1915
- images: Optional[torch.FloatTensor] = None,
1916
- input_ids: torch.LongTensor = None,
1917
- labels: Optional[torch.LongTensor] = None,
1918
- attention_mask: Optional[torch.Tensor] = None,
1919
- segs: Optional[torch.FloatTensor] = None,
1920
-
1921
- position_ids: Optional[torch.LongTensor] = None,
1922
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1923
- inputs_embeds: Optional[torch.FloatTensor] = None,
1924
- use_cache: Optional[bool] = None,
1925
- output_attentions: Optional[bool] = None,
1926
- output_hidden_states: Optional[bool] = None,
1927
- return_dict: Optional[bool] = None,
1928
- cache_position: Optional[torch.LongTensor] = None,
1929
- **kwargs,  # added so that extra keyword arguments from the generation utilities are accepted
1930
- ) -> Union[Tuple, CausalLMOutputWithPast]:
1931
-
1932
- input_ids_pre = input_ids
1933
-
1934
- if inputs_embeds is None:
1935
- (
1936
- input_ids,
1937
- position_ids,
1938
- attention_mask,
1939
- past_key_values,
1940
- inputs_embeds,
1941
- labels
1942
- ) = self.prepare_inputs_for_multimodal(
1943
- input_ids,
1944
- position_ids,
1945
- attention_mask,
1946
- past_key_values,
1947
- labels,
1948
- images,
1949
- )
1950
-
1951
- try:
1952
- seg_ids = torch.nonzero(torch.sum(segs, dim=(1, 2, 3, 4))).flatten().tolist()
1953
- except Exception:
1954
- seg_ids = []
1955
-
1956
- if self.get_model().seg_enable and seg_ids:
1957
- outputs = super().forward(
1958
- input_ids=input_ids,
1959
- inputs_embeds=inputs_embeds,
1960
- attention_mask=attention_mask,
1961
- labels=labels,
1962
- output_hidden_states=True,
1963
-
1964
- position_ids=position_ids,
1965
- past_key_values=past_key_values,
1966
- use_cache=use_cache,
1967
- output_attentions=output_attentions,
1968
- return_dict=return_dict
1969
- )
1970
-
1971
- output_hidden_states = outputs.hidden_states
1972
-
1973
- last_hidden_state = output_hidden_states[-1]
1974
-
1975
- seg_token_mask = input_ids_pre[:, 1:] == self.config.seg_token_id
1976
- seg_token_mask = torch.cat(
1977
- [
1978
- seg_token_mask,
1979
- torch.zeros((seg_token_mask.shape[0], 1), dtype=seg_token_mask.dtype, device=seg_token_mask.device),
1980
- ],
1981
- dim=1,
1982
- )
1983
-
1984
- seg_prompts = []
1985
- for i in seg_ids:
1986
- if torch.sum(seg_token_mask[i]) == 1:
1987
- seg_token = last_hidden_state[i][seg_token_mask[i]]
1988
- seg_prompt = self.get_model().seg_projector(seg_token)
1989
- elif torch.sum(seg_token_mask[i]) > 1:
1990
- seg_tokens = last_hidden_state[i][seg_token_mask[i]]
1991
- seg_token = torch.mean(seg_tokens, dim=0, keepdim=True)
1992
- seg_prompt = self.get_model().seg_projector(seg_token)
1993
- else:
1994
- seg_prompt = torch.zeros([1, self.config.mm_hidden_size], dtype=last_hidden_state.dtype,
1995
- device=last_hidden_state.device)
1996
- seg_prompts.append(seg_prompt)
1997
-
1998
- seg_prompts = torch.cat(seg_prompts, dim=0)
1999
- logits = self.get_model().seg_module(images[seg_ids], text_emb=seg_prompts)
2000
- loss_dice = self.get_model().dice_loss(logits, segs[seg_ids])
2001
- loss_bce = self.get_model().bce_loss(logits, segs[seg_ids])
2002
- seg_loss = loss_dice + loss_bce
2003
- outputs.loss = outputs.loss + seg_loss
2004
- return outputs
2005
- else:
2006
- return super().forward(
2007
- input_ids=input_ids,
2008
- attention_mask=attention_mask,
2009
- position_ids=position_ids,
2010
- past_key_values=past_key_values,
2011
- inputs_embeds=inputs_embeds,
2012
- labels=labels,
2013
- use_cache=use_cache,
2014
- output_attentions=output_attentions,
2015
- output_hidden_states=output_hidden_states,
2016
- return_dict=return_dict
2017
- )
2018
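When a segmentation token is present, the mask logits are scored with the Dice and BCE losses defined at the top of the file and the result is added to the language-modelling loss; a tiny sketch with dummy tensors (shapes are illustrative) shows the combination.

```python
import torch

logits = torch.randn(2, 1, 32, 256, 256)                    # mask logits from seg_module
masks = torch.randint(0, 2, (2, 1, 32, 256, 256)).float()   # binary ground-truth masks

seg_loss = BinaryDiceLoss()(logits, masks) + BCELoss()(logits, masks)
# total loss = language-modelling loss + seg_loss
print(seg_loss.item())
```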
-
2019
-
2020
- @torch.no_grad()
2021
- def generate(
2022
- self,
2023
- images: Optional[torch.Tensor] = None,
2024
- inputs: Optional[torch.Tensor] = None,
2025
- seg_enable: bool = False,
2026
- **kwargs,
2027
- ) -> Union[GenerateOutput, torch.LongTensor, Any]:
2028
- position_ids = kwargs.pop("position_ids", None)
2029
- attention_mask = kwargs.pop("attention_mask", None)
2030
- if "inputs_embeds" in kwargs:
2031
- raise NotImplementedError("`inputs_embeds` is not supported")
2032
-
2033
- if images is not None:
2034
- (
2035
- inputs,
2036
- position_ids,
2037
- attention_mask,
2038
- _,
2039
- inputs_embeds,
2040
- _
2041
- ) = self.prepare_inputs_for_multimodal(
2042
- inputs,
2043
- position_ids,
2044
- attention_mask,
2045
- None,
2046
- None,
2047
- images,
2048
- )
2049
- else:
2050
- inputs_embeds = self.get_model().embed_tokens(inputs)
2051
-
2052
- if seg_enable:
2053
- outputs = super().generate(
2054
- inputs_embeds=inputs_embeds,
2055
- output_hidden_states=True,
2056
- return_dict_in_generate=True,
2057
- **kwargs
2058
- )
2059
-
2060
- output_hidden_states = outputs.hidden_states
2061
- output_ids = outputs.sequences
2062
-
2063
- seg_token_mask = output_ids[:, 1:] == self.config.seg_token_id
2064
-
2065
- last_tensors = [step_states[-1] for step_states in output_hidden_states]
2066
- last_hidden_state = torch.cat(last_tensors[1:], dim=1)
2067
-
2068
- seg_prompts = []
2069
- noseg_ids = []
2070
- for i in range(len(seg_token_mask)):
2071
- if torch.sum(seg_token_mask[i]) == 1:
2072
- seg_token = last_hidden_state[i][seg_token_mask[i]]
2073
- seg_prompt = self.get_model().seg_projector(seg_token)
2074
- elif torch.sum(seg_token_mask[i]) > 1:
2075
- seg_tokens = last_hidden_state[i][seg_token_mask[i]]
2076
- seg_token = torch.mean(seg_tokens, dim=0, keepdim=True)
2077
- seg_prompt = self.get_model().seg_projector(seg_token)
2078
- else:
2079
- noseg_ids.append(i)
2080
- seg_prompt = torch.zeros([1, self.config.mm_hidden_size], dtype=last_hidden_state.dtype,
2081
- device=last_hidden_state.device)
2082
- seg_prompts.append(seg_prompt)
2083
-
2084
- seg_prompts = torch.cat(seg_prompts, dim=0)
2085
- logits = self.get_model().seg_module(images, seg_prompts)
2086
- logits[noseg_ids] = -torch.inf
2087
-
2088
- return output_ids, logits
2089
- else:
2090
- output_ids = super().generate(
2091
- inputs_embeds=inputs_embeds,
2092
- **kwargs
2093
- )
2094
- return output_ids
2095
-
2096
-
2097
- def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
2098
- inputs_embeds=None, **kwargs):
2099
- images = kwargs.pop("images", None)
2100
- inputs = super().prepare_inputs_for_generation(
2101
- input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
2102
- )
2103
- if images is not None:
2104
- inputs['images'] = images
2105
- return inputs
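Finally, a hedged end-to-end sketch of calling `generate` with segmentation enabled. The checkpoint id, the `<im_patch>` placeholder string, and the dummy volume are assumptions for illustration; they must match how the released model was actually trained and preprocessed.

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

path = "GoodBaiBai88/M3D-LaMed-Phi-3-4B"   # hypothetical checkpoint id
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    path, torch_dtype=torch.bfloat16, trust_remote_code=True).eval().cuda()

# Assumed convention: the prompt starts with one placeholder token per visual token.
num_img_tokens = model.get_model().mm_projector.proj_out_num
prompt = "<im_patch>" * num_img_tokens + "Describe the main finding in this CT volume."
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
image = torch.randn(1, 1, 32, 256, 256, dtype=torch.bfloat16).cuda()  # dummy volume

output_ids, seg_logits = model.generate(
    image, input_ids, seg_enable=True, max_new_tokens=128, do_sample=False)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
print(seg_logits.shape)   # per-voxel mask logits from the segmentation head
```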