DannyJun committed
Commit 7691d80 · verified · 1 Parent(s): c3acb98

Upload modeling_sprvla.py with huggingface_hub

Files changed (1): modeling_sprvla.py (+2124, -0)
modeling_sprvla.py ADDED
@@ -0,0 +1,2124 @@
import math
from copy import deepcopy
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union, Dict, Any, Sequence, Callable

import torch
from torch import nn
from torch.nn import functional as F
from contextlib import nullcontext

from transformers.models.auto import AutoModelForCausalLM, AutoModelForImageTextToText
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation import GenerationMixin
from transformers.generation.configuration_utils import GenerationConfig
from transformers.generation.utils import GenerateOutput
from transformers.integrations import use_kernel_forward_from_hub
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_flash_attention_utils import _flash_attention_forward, FlashAttentionKwargs
from transformers import GradientCheckpointingLayer
from transformers.modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPast,
    BaseModelOutputWithPooling,
    CausalLMOutputWithPast,
)
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import (
    ModelOutput,
    can_return_tuple,
    is_torch_flex_attn_available,
    logging,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
)

from .configuration_sprvla import SPRVLAConfig, SPRVLAVitConfig, SPRVLAAdapterConfig, SPRVLALlmConfig

import re
import numpy as np
from transformers import Qwen2Tokenizer


if is_torch_flex_attn_available():
    from torch.nn.attention.flex_attention import BlockMask

    from transformers.integrations.flex_attention import make_flex_block_causal_mask


logger = logging.get_logger(__name__)

SPRVLA_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
    heads, etc.).

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
    usage and behavior.

    Parameters:
        config ([`SPRVLAConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


NUM_RE = re.compile(r'[+-]?(?:\d+(?:\.\d+)?|\.\d+)(?:[eE][+-]?\d+)?$')
DEPTH_RE = re.compile(r'<DEPTH_START>(.*?)<DEPTH_END>', re.DOTALL)
# One-level-nested [...] matcher: outer block that may contain inner [ ... ] lists
OUTER_BLOCK_RE = re.compile(r'\[(?:[^\[\]]|\[[^\[\]]*\])+\]')

def _is_number(s: str) -> bool:
    return bool(NUM_RE.match(s))

def _has_non_ascii(s: str) -> bool:
    return any(ord(ch) > 127 for ch in s)

def _to_number(s: str):
    """Parse string number to int when possible, else float."""
    v = float(s)
    return int(v) if v.is_integer() else v

def extract_depth_string(text: str, include_tags: bool = False) -> list[str]:
    """
    Return all occurrences of depth strings.
    If include_tags=True, each item is '<DEPTH_START>...<DEPTH_END>';
    otherwise each item is just the inner '...'.
    """
    matches = list(DEPTH_RE.finditer(text))
    if include_tags:
        return [m.group(0) for m in matches]
    return [m.group(1) for m in matches]
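
# Illustrative usage (hypothetical strings, not outputs from this model):
#   extract_depth_string("d: <DEPTH_START>12 34 56<DEPTH_END>")        -> ["12 34 56"]
#   extract_depth_string("d: <DEPTH_START>12 34 56<DEPTH_END>", True)  -> ["<DEPTH_START>12 34 56<DEPTH_END>"]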

def extract_trace_lists(
    text: str,
    point_len: int | None = 2,   # e.g., 2 for [x,y], 3 for [x,y,z]; None = any length ≥ 1
    min_points: int = 1,
) -> list[list[list[float]]]:
    """
    Extract *numeric* lists-of-lists like [[140,225],[130,212],...].
    Returns a list of traces; each trace is a list of points (lists of numbers).

    Heuristic:
      - Find outer [ ... ] blocks that may contain inner lists
      - Keep blocks where every inner list is fully numeric
      - Enforce per-point length (point_len) and a minimum number of points (min_points)
    """
    traces: list[list[list[float]]] = []

    # Find outer blocks that can contain nested lists
    for block in OUTER_BLOCK_RE.findall(text):
        inner_strs = re.findall(r'\[([^\[\]]+)\]', block)  # contents of each inner [...]
        if len(inner_strs) < min_points:
            continue

        rows: list[list[float]] = []
        ok = True
        for row in inner_strs:
            parts = [p.strip().strip('"').strip("'") for p in row.split(',')]
            if point_len is not None and len(parts) != point_len:
                ok = False
                break
            if not all(_is_number(p) for p in parts):
                ok = False
                break
            rows.append([_to_number(p) for p in parts])

        if ok:
            traces.append(rows)

    return traces
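
# Illustrative usage (a hypothetical model output, not from this repo):
#   extract_trace_lists("move along [[140, 225], [130, 212], [118, 201]]")
#   -> [[[140, 225], [130, 212], [118, 201]]]
# With the default point_len=2, a block such as [[1, 2, 3]] is rejected.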

def extract_action_token_lists(
    text: str,
    only_len: int | None = None,        # e.g., 7 if you expect 7-D actions
    require_non_ascii: bool = True,     # set False if your tokens can be pure ASCII
) -> list[list[str]]:
    """
    Extract all [ ... ] groups split by commas, discard numeric lists,
    and return token lists (quotes stripped, whitespace trimmed).
    """
    lists = []
    # Match NON-nested bracketed groups: [ ... ] without inner [ or ]
    for inner in re.findall(r'\[([^\[\]]+)\]', text):
        parts = [p.strip().strip('"').strip("'") for p in inner.split(',')]

        if only_len is not None and len(parts) != only_len:
            continue

        # If *all* items are numeric -> not action tokens (like coordinates)
        if all(_is_number(p) for p in parts):
            continue

        # Optionally require at least one non-ASCII char across tokens (helps exclude plain words/numbers)
        if require_non_ascii and not any(_has_non_ascii(p) for p in parts):
            continue

        lists.append(parts)

    return lists
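
# Illustrative usage (hypothetical action tokens; the real ones depend on the tokenizer):
#   extract_action_token_lists("act: [<ACT_1>, <ACT_2>]", require_non_ascii=False)
#   -> [['<ACT_1>', '<ACT_2>']]
# Purely numeric groups such as "[140, 225]" are skipped as coordinate-like.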


@dataclass
class SPRVLACausalLMOutputWithPast(ModelOutput):
    """
    Base class for SPRVLA causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        image_hidden_states (`torch.FloatTensor`, *optional*):
            A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
            image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[torch.FloatTensor] = None


@dataclass
class SPRVLAModelOutputWithPast(BaseModelOutputWithPast):
    """
    Base class for SPRVLA outputs, with hidden states and attentions.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        image_hidden_states (`torch.FloatTensor`, *optional*):
            A `torch.FloatTensor` of size `(batch_num_patches, hidden_size)`.
            image_hidden_states of the model produced by the vision backbone.
    """

    image_hidden_states: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None


class SPRVLAPreTrainedModel(PreTrainedModel):
    config_class = SPRVLALlmConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["SPRVLADecoderLayer", "SPRVLAPostNormDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = False
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, (nn.Linear,)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, SPRVLAEmbedding):
            module.embedding.data.normal_(mean=0.0, std=std)
            module.new_embedding.data.normal_(mean=0.0, std=std)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, SPRVLARMSNorm):
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()


class ViTMLP(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, hidden_act: str, device: Union[str, torch.device] = None):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=True, device=device)
        self.act = ACT2FN[hidden_act]
        self.w2 = nn.Linear(hidden_dim, dim, bias=True, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(self.act(self.w1(x)))


class ViTMultiHeadDotProductAttention(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_key_value_heads: int,
        head_dim: int,
        use_bias: bool = True,
        input_dim: Optional[int] = None,
        float32_attention: bool = True,
        attention_dropout: float = 0.0,
        residual_dropout: float = 0.0,
        device: Union[str, torch.device] = None,
        attn_implementation: str = "eager",
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.num_key_value_heads = num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.attn_implementation = attn_implementation
        self.is_causal = False

        input_dim = input_dim or hidden_size

        self.wq = nn.Linear(
            input_dim,
            self.num_heads * self.head_dim,
            bias=use_bias,
            device=device,
        )
        self.wk = nn.Linear(
            input_dim,
            self.num_key_value_heads * self.head_dim,
            bias=use_bias,
            device=device,
        )
        self.wv = nn.Linear(
            input_dim,
            self.num_key_value_heads * self.head_dim,
            bias=use_bias,
            device=device,
        )
        self.wo = nn.Linear(
            self.num_heads * self.head_dim,
            self.hidden_size,
        )
        self.float32_attention = float32_attention
        self.attention_dropout = attention_dropout
        self.residual_dropout = nn.Dropout(residual_dropout)

    def _split_heads(self, hidden_states, num_heads) -> torch.Tensor:
        return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim))

    def _merge_heads(self, hidden_states) -> torch.Tensor:
        return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,))

    def forward(
        self,
        inputs_q: torch.Tensor,
        inputs_kv: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:

        if inputs_kv is not None:
            inputs_k = inputs_kv
            inputs_v = inputs_kv
        else:
            inputs_k = inputs_q
            inputs_v = inputs_q

        xq, xk, xv = self.wq(inputs_q), self.wk(inputs_k), self.wv(inputs_v)

        xq = self._split_heads(xq, self.num_heads)
        xk = self._split_heads(xk, self.num_key_value_heads)
        xv = self._split_heads(xv, self.num_key_value_heads)

        if self.num_heads != self.num_key_value_heads:
            xk = xk.repeat_interleave(self.num_key_value_groups, dim=2, output_size=self.num_heads)
            xv = xv.repeat_interleave(self.num_key_value_groups, dim=2, output_size=self.num_heads)

        og_dtype = xq.dtype

        if self.float32_attention:
            xq = xq.to(torch.float)
            xk = xk.to(torch.float)
            xv = xv.to(torch.float)
        elif self.attn_implementation == "sdpa" and not torch.is_autocast_enabled():
            xv = xv.to(torch.float)

        dropout_p = 0.0 if not self.training else self.attention_dropout

        if self.attn_implementation == "eager":
            attn_weights = torch.einsum("...qhd,...khd->...hqk", xq / math.sqrt(xq.size(-1)), xk)
            attn_weights = F.softmax(attn_weights, dim=-1)
            attn_weights = F.dropout(
                attn_weights,
                p=dropout_p,
                training=self.training,
            )
            attn_output = torch.einsum("...hqk,...khd->...qhd", attn_weights.to(xv.dtype), xv)

        elif self.attn_implementation == "sdpa":
            # Flash kernels need half precision and no explicit mask; otherwise fall back
            # to the mem-efficient/math backends.
            flash_ok = (
                attn_mask is None
                and xq.dtype in (torch.float16, torch.bfloat16)
                and xk.dtype == xq.dtype
                and xv.dtype == xq.dtype
            )

            sdp_ctx = (
                torch.backends.cuda.sdp_kernel(
                    enable_flash=flash_ok,
                    enable_mem_efficient=True,
                    enable_math=True,
                    enable_cudnn=True,
                )
                if hasattr(torch.backends.cuda, "sdp_kernel")
                else nullcontext()
            )
            with sdp_ctx:
                attn_output = F.scaled_dot_product_attention(
                    xq.transpose(1, 2).contiguous(),
                    xk.transpose(1, 2).contiguous(),
                    xv.transpose(1, 2).contiguous(),
                    attn_mask=attn_mask,
                    is_causal=False,
                    dropout_p=dropout_p,
                ).transpose(1, 2)

        elif self.attn_implementation == "flash_attention_2":
            assert not self.float32_attention
            # Downcast in case we are running with fp32 hidden states
            attn_output = _flash_attention_forward(
                xq.transpose(1, 2).to(torch.bfloat16),
                xk.transpose(1, 2).to(torch.bfloat16),
                xv.transpose(1, 2).to(torch.bfloat16),
                attention_mask=None,
                query_length=inputs_q.shape[1],
                is_causal=False,
                dropout=dropout_p,
            )
        else:
            raise ValueError(f"Attention implementation {self.attn_implementation} not supported")

        attn_output = attn_output.to(og_dtype)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.wo(attn_output)
        attn_output = self.residual_dropout(attn_output)

        return attn_output


class SPRVLAVisionBlock(nn.Module):

    def __init__(self, config: SPRVLAVitConfig, device: Union[str, torch.device] = None):
        super().__init__()
        self.attention = ViTMultiHeadDotProductAttention(
            hidden_size=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_key_value_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            float32_attention=config.float32_attention,
            attention_dropout=config.attention_dropout,
            residual_dropout=config.residual_dropout,
            device=device,
            attn_implementation=config._attn_implementation,
        )
        self.feed_forward = ViTMLP(config.hidden_size, config.intermediate_size, config.hidden_act, device=device)
        self.attention_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device)
        self.ffn_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attention(self.attention_norm(x))
        x = x + self.feed_forward(self.ffn_norm(x))
        return x


class SPRVLAVisionBlockCollection(nn.Module):

    def __init__(self, config: SPRVLAVitConfig, device: Union[str, torch.device] = None):
        super().__init__()
        self.config = config
        self.resblocks = nn.ModuleList([
            SPRVLAVisionBlock(config, device) for _ in range(config.num_hidden_layers)
        ])

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        hidden_states = []
        for r in self.resblocks:
            x = r(x)
            hidden_states.append(x)
        return hidden_states


def _expand_token(token, batch_size: int):
    return token.view(1, 1, -1).expand(batch_size, -1, -1)


class SPRVLAVisionTransformer(nn.Module):

    def __init__(self, config: SPRVLAVitConfig, device: Union[str, torch.device] = None):
        super().__init__()
        self.config = config

        self.scale = config.hidden_size ** -0.5

        # optional CLS
        self.num_prefix_tokens: int = 1 if config.use_cls_token else 0
        if config.use_cls_token:
            self.class_embedding = nn.Parameter(
                torch.zeros(config.hidden_size, device=device)
            )

        # positional embeddings
        self.positional_embedding = nn.Parameter(
            torch.zeros(config.image_num_pos, config.hidden_size, device=device),
        )

        image_patch_size = config.image_patch_size
        self.patch_embedding = nn.Linear(
            image_patch_size * image_patch_size * 3,
            config.hidden_size,
            bias=config.patch_bias,
            device=device,
        )

        # optional pre-LN
        self.pre_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device) \
            if config.pre_layernorm else None

        self.transformer = SPRVLAVisionBlockCollection(config, device)

    def add_pos_emb(self, x: torch.Tensor, patch_num: Tuple[int, int]) -> torch.Tensor:
        pos_emb = self.positional_embedding
        if self.config.use_cls_token:
            cls_pos, pos_emb = pos_emb[:1], pos_emb[1:]  # split out CLS

        pos_emb = pos_emb.reshape(
            (int(math.sqrt(pos_emb.shape[0])), int(math.sqrt(pos_emb.shape[0])), pos_emb.shape[1])
        )

        (patch_num_0, patch_num_1) = patch_num

        if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1:
            # Derived from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
            # antialias: default True in jax.image.resize
            pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2)
            pos_emb = F.interpolate(
                pos_emb, size=(patch_num_0, patch_num_1), mode="bicubic", align_corners=False, antialias=True,
            )
            pos_emb = pos_emb.permute(0, 2, 3, 1).squeeze(0)

        pos_emb = pos_emb.reshape(-1, pos_emb.shape[-1])

        if self.config.use_cls_token:
            x = x + torch.cat([cls_pos[None, :, :], pos_emb[None, :, :]], dim=1).to(x.dtype)
        else:
            x = x + pos_emb[None, :, :].to(x.dtype)

        return x

    def forward(self, x: torch.Tensor, patch_num: Optional[Tuple[int, int]] = None) -> List[torch.Tensor]:
        """
        :param x: (batch_size, num_patch, n_pixels)
        """
        if patch_num is None:
            patch_num = self.config.image_num_patch

        B, N, D = x.shape

        x = self.patch_embedding(x)

        if self.config.use_cls_token:
            x = torch.cat([_expand_token(self.class_embedding, x.size(0)).to(x.dtype), x], dim=1)

        # class embeddings and positional embeddings
        x = self.add_pos_emb(x, patch_num)

        if self.pre_ln is not None:
            x = self.pre_ln(x)

        hidden_states = self.transformer(x)
        return hidden_states


class ImageProjectorMLP(nn.Module):

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        hidden_act: str,
        device: Union[str, torch.device] = None,
    ):
        super().__init__()
        self.w1 = nn.Linear(input_dim, hidden_dim, bias=False, device=device)
        self.w2 = nn.Linear(hidden_dim, output_dim, bias=False, device=device)
        self.w3 = nn.Linear(input_dim, hidden_dim, bias=False, device=device)
        self.act = ACT2FN[hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(self.act(self.w1(x)) * self.w3(x))


class SPRVLAVisionBackbone(nn.Module):
    def __init__(self, vit_config: SPRVLAVitConfig, adapter_config: SPRVLAAdapterConfig):
        super().__init__()
        self.vit_config = vit_config
        self.adapter_config = adapter_config

        self.vit_layers = []
        for layer in adapter_config.vit_layers:
            if layer >= 0:
                self.vit_layers.append(layer)
            else:
                self.vit_layers.append(layer + vit_config.num_hidden_layers)

        last_layer_needed = max(self.vit_layers) + 1
        if last_layer_needed < vit_config.num_hidden_layers:
            # Drop trailing ViT layers whose outputs are never read.
            new_vit_config = deepcopy(vit_config)
            new_vit_config.num_hidden_layers = last_layer_needed
            self.image_vit = SPRVLAVisionTransformer(new_vit_config)
        else:
            self.image_vit = SPRVLAVisionTransformer(vit_config)

        self.num_prefix_tokens: int = self.image_vit.num_prefix_tokens

        pool_dim = vit_config.hidden_size * len(adapter_config.vit_layers)

        # optional pad_embed
        self.pad_embed = None
        if adapter_config.image_padding_embed == "pad_and_partial_pad":
            self.pad_embed = nn.Parameter(torch.zeros((2, pool_dim)))

        self.image_pooling_2d = ViTMultiHeadDotProductAttention(
            hidden_size=adapter_config.hidden_size,
            num_heads=adapter_config.num_attention_heads,
            num_key_value_heads=adapter_config.num_key_value_heads,
            head_dim=adapter_config.head_dim,
            input_dim=pool_dim,
            float32_attention=adapter_config.float32_attention,
            attention_dropout=adapter_config.attention_dropout,
            residual_dropout=adapter_config.residual_dropout,
            attn_implementation=adapter_config._attn_implementation,
        )
        self.image_projector = ImageProjectorMLP(
            adapter_config.hidden_size,
            adapter_config.intermediate_size,
            adapter_config.text_hidden_size,
            adapter_config.hidden_act,
        )
        self.image_feature_dropout = nn.Dropout(adapter_config.image_feature_dropout)

    def encode_image(self, images: torch.Tensor) -> torch.Tensor:
        """
        :param images: (batch_size, num_crops, num_patch, n_pixels)
        """
        B, T, N, D = images.shape
        images = images.view(B * T, N, D)
        image_features = self.image_vit(images)

        features = []
        for layer in self.vit_layers:
            features.append(image_features[layer])
        image_features = torch.cat(features, dim=-1)

        if self.num_prefix_tokens > 0:
            image_features = image_features[:, 1:]
        image_features = image_features.view(B, T, N, -1)
        return image_features

    @property
    def dtype(self) -> torch.dtype:
        return self.image_vit.patch_embedding.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.image_vit.patch_embedding.weight.device

    def forward(
        self,
        images: torch.Tensor,
        pooled_patches_idx: torch.Tensor,
        image_masks: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:

        # image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim)
        batch_size, num_image = images.shape[:2]
        images = images.to(device=self.device, dtype=self.dtype)
        image_features = self.encode_image(images)

        # optional padding embeddings
        if self.pad_embed is not None and image_masks is not None:
            image_masks = image_masks.to(device=self.device)
            all_pad = (image_masks == 0).to(image_features.dtype)
            partial = torch.logical_and(image_masks < 1, ~(image_masks == 0)).to(image_features.dtype)
            image_features = image_features + self.pad_embed[0][None, None, None, :] * all_pad[..., None] \
                + self.pad_embed[1][None, None, None, :] * partial[..., None]

        image_features = self.image_feature_dropout(image_features)
        dim = image_features.shape[-1]

        valid = pooled_patches_idx >= 0
        valid_token = torch.any(valid, -1)

        # Use `pooled_patches_idx` to arrange the features for image pooling
        batch_idx = torch.arange(pooled_patches_idx.shape[0], dtype=torch.long, device=pooled_patches_idx.device)
        batch_idx = torch.tile(batch_idx.view(batch_size, 1, 1), [1, pooled_patches_idx.shape[1], pooled_patches_idx.shape[2]])

        # Now [batch, num_high_res_features, pool_dim, dim]
        to_pool = image_features.reshape(batch_size, -1, dim)[batch_idx, torch.clip(pooled_patches_idx, 0)]
        to_pool = to_pool * valid.to(self.dtype)[:, :, :, None]
        to_pool = to_pool.reshape([-1, pooled_patches_idx.shape[-1], dim])

        query = to_pool.mean(-2, keepdim=True)
        pooled_features = self.image_pooling_2d(query, to_pool)
        pooled_features = pooled_features.reshape([batch_size, -1, pooled_features.shape[-1]])

        # MLP layer to map the feature.
        pooled_features = self.image_projector(pooled_features)
        return pooled_features.view(-1, pooled_features.shape[-1])[valid_token.flatten()]
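
# Shape sketch of the pooling path above (inferred from the code for orientation, not
# guaranteed by the config):
#   image_features:     (B, num_crops, num_patch, vit_hidden * len(vit_layers))
#   pooled_patches_idx: (B, num_pooled_tokens, pool_size), indexing the flattened
#                       (num_crops * num_patch) axis; -1 marks unused slots
#   to_pool:            (B * num_pooled_tokens, pool_size, dim); its mean over pool_size
#                       is the attention query, so each pooled token attends over its own
#                       group of patches
#   returned features:  (num_valid_pooled_tokens, text_hidden_size) after the projector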


# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
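
# Shape sketch for the default unsqueeze_dim=1, as used by SPRVLAAttention below:
#   q, k:     (batch, num_heads, seq_len, head_dim)
#   cos, sin: (batch, seq_len, head_dim) -> unsqueezed to (batch, 1, seq_len, head_dim),
# so the same per-position rotation broadcasts across all attention heads.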


# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding
class SPRVLARotaryEmbedding(nn.Module):

    def __init__(self, config: SPRVLALlmConfig, device: Union[str, torch.device] = None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


@use_kernel_forward_from_hub("RMSNorm")
class SPRVLARMSNorm(nn.Module):

    def __init__(
        self,
        size: int,
        eps: float = 1e-6,
        device: Union[str, torch.device] = None,
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(size, device=device))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        with torch.autocast(enabled=False, device_type=x.device.type):
            og_dtype = x.dtype
            x = x.to(torch.float32)
            variance = x.pow(2).mean(-1, keepdim=True)
            x = x * torch.rsqrt(variance + self.eps)
            x = x.to(og_dtype)

        return self.weight * x

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.eps}"


# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


class SPRVLAAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    # copied from transformers.models.llama.modeling_llama.LlamaAttention.__init__ with Llama->SPRVLA
    def __init__(self, config: SPRVLALlmConfig, layer_idx: Optional[int] = None) -> None:
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.head_dim = config.head_dim
        self.scaling = self.head_dim**-0.5
        self.is_causal = True

        if (config.head_dim * config.num_attention_heads) != config.hidden_size:
            raise ValueError(
                f"hidden_size must equal head_dim * num_attention_heads (got `hidden_size`: {config.hidden_size},"
                f" `head_dim`: {config.head_dim} and `num_attention_heads`: {config.num_attention_heads})."
            )

        # Fused QKV projection: the output is split into (q, k, v) chunks of these widths.
        self.fused_dims = (
            config.hidden_size,
            config.head_dim * config.num_key_value_heads,
            config.head_dim * config.num_key_value_heads,
        )
        self.att_proj = nn.Linear(
            config.hidden_size,
            sum(self.fused_dims),
            bias=config.qkv_bias,
        )

        # Layer norms.
        self.k_norm: Optional[SPRVLARMSNorm] = None
        self.q_norm: Optional[SPRVLARMSNorm] = None
        self.qk_norm_type: Optional[str] = None
        if config.use_qk_norm:
            # "qwen3"-style QK-norm normalizes per head (head_dim); otherwise the full
            # projection width is normalized before splitting into heads.
            k_norm_size = (
                config.head_dim
                if config.qk_norm_type == "qwen3" else
                config.num_key_value_heads * config.head_dim
            )
            self.k_norm = SPRVLARMSNorm(k_norm_size, eps=config.layer_norm_eps)
            q_norm_size = (
                config.head_dim
                if config.qk_norm_type == "qwen3" else
                config.num_attention_heads * config.head_dim
            )
            self.q_norm = SPRVLARMSNorm(q_norm_size, eps=config.layer_norm_eps)
            self.qk_norm_type = config.qk_norm_type

        self.attention_dropout = config.attention_dropout

        self.attn_out = nn.Linear(
            config.hidden_size,
            config.hidden_size,
            bias=False,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        qkv = self.att_proj(hidden_states)
        query_states, key_states, value_states = qkv.split(self.fused_dims, dim=-1)
        value_states = value_states.view(hidden_shape)

        # Optionally apply layer norm to keys and queries.
        if self.q_norm is not None and self.k_norm is not None and self.qk_norm_type != "qwen3":
            query_states = self.q_norm(query_states)
            key_states = self.k_norm(key_states)

        query_states = query_states.view(hidden_shape)
        key_states = key_states.view(hidden_shape)
        if self.q_norm is not None and self.k_norm is not None and self.qk_norm_type == "qwen3":
            query_states = self.q_norm(query_states)
            key_states = self.k_norm(key_states)
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.attn_out(attn_output)

        return attn_output, attn_weights


class LanguageModelMLP(nn.Module):

    def __init__(
        self,
        input_dim: int,
        intermediate_size: int,
        hidden_act: str,
        device: Union[str, torch.device] = None,
    ):
        super().__init__()
        self.ff_proj = nn.Linear(input_dim, intermediate_size * 2, bias=False, device=device)
        self.ff_out = nn.Linear(intermediate_size, input_dim, bias=False, device=device)
        self.act = ACT2FN[hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.ff_proj(x)
        x, gate = x.chunk(2, dim=-1)
        x = self.act(gate) * x
        x = self.ff_out(x)
        return x
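
# The fused ff_proj packs the gated-MLP up- and gate-projections into one matmul;
# chunking its output is equivalent (up to weight layout) to two separate Linear layers:
#   up, gate = ff_proj(x).chunk(2, dim=-1); out = ff_out(act(gate) * up)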


class SPRVLADecoderLayer(GradientCheckpointingLayer):

    def __init__(
        self,
        config: SPRVLALlmConfig,
        layer_idx: Optional[int] = None,
        device: Union[str, torch.device] = None,
    ):
        super().__init__()
        self.config = config

        self.self_attn = SPRVLAAttention(config, layer_idx)
        self.attn_norm = SPRVLARMSNorm(
            config.hidden_size, eps=config.layer_norm_eps, device=device)
        self.dropout = nn.Dropout(config.residual_dropout)
        self.mlp = LanguageModelMLP(
            config.hidden_size, config.intermediate_size, config.hidden_act, device=device)
        self.ff_norm = SPRVLARMSNorm(
            config.hidden_size, eps=config.layer_norm_eps, device=device)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, sequence_length)` where padding elements are indicated by 0.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence.
            position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
                with `head_dim` being the embedding dimension of each attention head.
            kwargs (`dict`, *optional*):
                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
                into the model
        """

        residual = hidden_states
        hidden_states = self.attn_norm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
        )

        hidden_states = residual + self.dropout(hidden_states)

        # Fully Connected
        residual = hidden_states
        hidden_states = self.ff_norm(hidden_states)
        hidden_states = self.mlp(hidden_states)

        hidden_states = residual + self.dropout(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs


class SPRVLAPostNormDecoderLayer(SPRVLADecoderLayer):
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Same arguments as `SPRVLADecoderLayer.forward`. The only difference is that the layer norms are
        applied *after* the attention and MLP blocks (post-norm) instead of before them (pre-norm).
        """

        residual = hidden_states

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
        )
        hidden_states = self.attn_norm(hidden_states)

        hidden_states = residual + self.dropout(hidden_states)

        # Fully Connected
        residual = hidden_states
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.ff_norm(hidden_states)

        hidden_states = residual + self.dropout(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs


class SPRVLAEmbedding(nn.Module):
    def __init__(
        self,
        num_embeddings: int,
        num_new_embeddings: int,
        features: int,
        device: Union[str, torch.device] = None,
    ):
        super().__init__()
        self.embedding = nn.Parameter(
            torch.zeros(num_embeddings, features, device=device),
        )
        self.new_embedding = nn.Parameter(
            torch.zeros(num_new_embeddings, features, device=device),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.embedding(x, torch.cat([self.embedding, self.new_embedding], dim=0))
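
# Lookup sketch (hypothetical sizes): with num_embeddings=8 and num_new_embeddings=2,
# ids 0-7 resolve to `embedding` and ids 8-9 to `new_embedding`, because the two tables
# are concatenated along dim 0 before the F.embedding lookup. Keeping the added vocabulary
# in its own parameter tensor allows it to be initialized or optimized separately.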


SPRVLA_TEXT_ONLY_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            Two formats are allowed:
            - a [`~cache_utils.Cache`] instance, see our
              [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
              shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`. This is also known as the legacy
              cache format.

            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
            legacy cache format will be returned.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`CausalLMOutputWithPast`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
            the complete sequence length.
"""
1255
+
1256
+
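+ # Illustrative sketch (not part of the model API) of the cache semantics documented
+ # above: the first call consumes the full prompt, later calls pass only the newest
+ # token together with the returned cache. `model` is assumed to be any causal LM
+ # built from this file; nothing below runs on import.
+ def _sketch_incremental_decoding(model, input_ids):
+     from transformers.cache_utils import DynamicCache
+
+     out = model(input_ids=input_ids, past_key_values=DynamicCache(), use_cache=True)
+     next_token = out.logits[:, -1:].argmax(dim=-1)
+     # Decode step: only the new token is fed; `cache_position` is inferred from the
+     # cache length inside `forward`.
+     return model(input_ids=next_token, past_key_values=out.past_key_values, use_cache=True)
+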
1257
+ @add_start_docstrings(
1258
+ "The bare SPRVLA text-only model outputting raw hidden-states without any specific head on top.",
1259
+ SPRVLA_START_DOCSTRING,
1260
+ )
1261
+ class SPRVLALlm(SPRVLAPreTrainedModel):
1262
+ def __init__(self, config: SPRVLALlmConfig):
1263
+ super().__init__(config)
1264
+ self.config = config
1265
+ if config.additional_vocab_size is not None:
1266
+ self.wte = SPRVLAEmbedding(
1267
+ config.vocab_size,
1268
+ config.additional_vocab_size,
1269
+ config.hidden_size,
1270
+ )
1271
+ else:
1272
+ self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
1273
+ self.emb_drop = nn.Dropout(config.embedding_dropout)
1274
+ decoder_layer = SPRVLAPostNormDecoderLayer if config.norm_after else SPRVLADecoderLayer
1275
+ self.blocks = nn.ModuleList(
1276
+ [decoder_layer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1277
+ )
1278
+ self.ln_f = SPRVLARMSNorm(config.hidden_size, eps=config.layer_norm_eps)
1279
+ self.rotary_emb = SPRVLARotaryEmbedding(config)
1280
+ self.gradient_checkpointing = False
1281
+
1282
+ # Initialize weights and apply final processing
1283
+ self.post_init()
1284
+
1285
+ def get_input_embeddings(self) -> torch.nn.Module:
1286
+ return self.wte
1287
+
1288
+ def set_input_embeddings(self, value: torch.nn.Module) -> None:
1289
+ self.wte = value
1290
+
1291
+ @can_return_tuple
1292
+ def forward(
1293
+ self,
1294
+ input_ids: Optional[torch.LongTensor] = None,
1295
+ attention_mask: Optional[torch.Tensor] = None,
1296
+ position_ids: Optional[torch.LongTensor] = None,
1297
+ past_key_values: Optional[Cache] = None,
1298
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1299
+ use_cache: Optional[bool] = None,
1300
+ output_attentions: Optional[bool] = None,
1301
+ output_hidden_states: Optional[bool] = None,
1302
+ cache_position: Optional[torch.LongTensor] = None,
1303
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
1304
+ ) -> BaseModelOutputWithPast:
1305
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1306
+ output_hidden_states = (
1307
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1308
+ )
1309
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1310
+
1311
+ if (input_ids is None) ^ (inputs_embeds is not None):
1312
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
1313
+
1314
+ if self.gradient_checkpointing and self.training and use_cache:
1315
+ logger.warning_once(
1316
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
1317
+ )
1318
+ use_cache = False
1319
+
1320
+ # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
1321
+ if not isinstance(past_key_values, (type(None), Cache)):
1322
+ raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
1323
+
1324
+ if inputs_embeds is None:
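+ # -1 is used upstream as a placeholder id (e.g. for masked positions); zero it
+ # out so the embedding lookup stays within the vocabulary range.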
1325
+ input_ids = input_ids * (input_ids != -1).to(input_ids.dtype)
1326
+ inputs_embeds = self.wte(input_ids)
1327
+
1328
+ if use_cache and past_key_values is None:
1329
+ past_key_values = DynamicCache()
1330
+
1331
+ if cache_position is None:
1332
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1333
+ cache_position = torch.arange(
1334
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
1335
+ )
1336
+
1337
+ if position_ids is None:
1338
+ position_ids = cache_position.unsqueeze(0)
1339
+
1340
+ causal_mask = self._update_causal_mask(
1341
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
1342
+ )
1343
+
1344
+ hidden_states = inputs_embeds
1345
+
1346
+ # create position embeddings to be shared across the decoder layers
1347
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
1348
+
1349
+ # decoder layers
1350
+ all_hidden_states = () if output_hidden_states else None
1351
+ all_self_attns = () if output_attentions else None
1352
+
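+ # The slice is a no-op by default, but allows running a truncated stack if
+ # `config.num_hidden_layers` is lowered after initialization.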
1353
+ for decoder_block in self.blocks[: self.config.num_hidden_layers]:
1354
+ if output_hidden_states:
1355
+ all_hidden_states += (hidden_states,)
1356
+
1357
+ layer_outputs = decoder_block(
1358
+ hidden_states,
1359
+ attention_mask=causal_mask,
1360
+ position_ids=position_ids,
1361
+ past_key_value=past_key_values,
1362
+ output_attentions=output_attentions,
1363
+ use_cache=use_cache,
1364
+ cache_position=cache_position,
1365
+ position_embeddings=position_embeddings,
1366
+ **flash_attn_kwargs,
1367
+ )
1368
+
1369
+ hidden_states = layer_outputs[0]
1370
+
1371
+ if output_attentions:
1372
+ all_self_attns += (layer_outputs[1],)
1373
+
1374
+ hidden_states = self.ln_f(hidden_states)
1375
+
1376
+ # add hidden states from the last decoder layer
1377
+ if output_hidden_states:
1378
+ all_hidden_states += (hidden_states,)
1379
+
1380
+ return BaseModelOutputWithPast(
1381
+ last_hidden_state=hidden_states,
1382
+ past_key_values=past_key_values if use_cache else None,
1383
+ hidden_states=all_hidden_states,
1384
+ attentions=all_self_attns,
1385
+ )
1386
+
1387
+ def _update_causal_mask(
1388
+ self,
1389
+ attention_mask: Union[torch.Tensor, "BlockMask"],
1390
+ input_tensor: torch.Tensor,
1391
+ cache_position: torch.Tensor,
1392
+ past_key_values: Cache,
1393
+ output_attentions: bool = False,
1394
+ ):
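+ # Flash Attention 2 consumes the 2D padding mask directly, and only needs one when
+ # some positions are actually padded; otherwise `None` lets it run fully causal.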
1395
+ if self.config._attn_implementation == "flash_attention_2":
1396
+ if attention_mask is not None and (attention_mask == 0.0).any():
1397
+ return attention_mask
1398
+ return None
1399
+ if self.config._attn_implementation == "flex_attention":
1400
+ if isinstance(attention_mask, torch.Tensor):
1401
+ attention_mask = make_flex_block_causal_mask(attention_mask)
1402
+ return attention_mask
1403
+
1404
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
1405
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
1406
+ # to infer the attention mask.
1407
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1408
+ using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
1409
+
1410
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
1411
+ if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
1412
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
1413
+ attention_mask,
1414
+ inputs_embeds=input_tensor,
1415
+ past_key_values_length=past_seen_tokens,
1416
+ is_training=self.training,
1417
+ ):
1418
+ return None
1419
+
1420
+ dtype = input_tensor.dtype
1421
+ sequence_length = input_tensor.shape[1]
1422
+ if using_compilable_cache:
1423
+ target_length = past_key_values.get_max_cache_shape()
1424
+ else:
1425
+ target_length = (
1426
+ attention_mask.shape[-1]
1427
+ if isinstance(attention_mask, torch.Tensor)
1428
+ else past_seen_tokens + sequence_length + 1
1429
+ )
1430
+
1431
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
1432
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
1433
+ attention_mask,
1434
+ sequence_length=sequence_length,
1435
+ target_length=target_length,
1436
+ dtype=dtype,
1437
+ cache_position=cache_position,
1438
+ batch_size=input_tensor.shape[0],
1439
+ )
1440
+
1441
+ if (
1442
+ self.config._attn_implementation == "sdpa"
1443
+ and attention_mask is not None
1444
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
1445
+ and not output_attentions
1446
+ ):
1447
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1448
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1449
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1450
+ min_dtype = torch.finfo(dtype).min
1451
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
1452
+
1453
+ return causal_mask
1454
+
1455
+ @staticmethod
1456
+ def _prepare_4d_causal_attention_mask_with_cache_position(
1457
+ attention_mask: torch.Tensor,
1458
+ sequence_length: int,
1459
+ target_length: int,
1460
+ dtype: torch.dtype,
1461
+ cache_position: torch.Tensor,
1462
+ batch_size: int,
1463
+ **kwargs,
1464
+ ):
1465
+ """
1466
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
1467
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
1468
+
1469
+ Args:
1470
+ attention_mask (`torch.Tensor`):
1471
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
1472
+ `(batch_size, 1, query_length, key_value_length)`.
1473
+ sequence_length (`int`):
1474
+ The sequence length being processed.
1475
+ target_length (`int`):
1476
+ The target length: when generating with static cache, the mask should be as long as the static cache,
1477
+ to account for the 0 padding, the part of the cache that is not filled yet.
1478
+ dtype (`torch.dtype`):
1479
+ The dtype to use for the 4D attention mask.
1480
+ cache_position (`torch.Tensor`):
1481
+ Indices depicting the position of the input sequence tokens in the sequence.
1482
+ batch_size (`int`):
1483
+ Batch size.
1484
+ """
1485
+ if attention_mask is not None and attention_mask.dim() == 4:
1486
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
1487
+ causal_mask = attention_mask
1488
+ else:
1489
+ min_dtype = torch.finfo(dtype).min
1490
+ causal_mask = torch.full(
1491
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
1492
+ )
1493
+ if sequence_length != 1:
1494
+ causal_mask = torch.triu(causal_mask, diagonal=1)
1495
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
1496
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
1497
+ if attention_mask is not None:
1498
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
1499
+ mask_length = attention_mask.shape[-1]
1500
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
1501
+ causal_mask.device
1502
+ )
1503
+ padding_mask = padding_mask == 0
1504
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
1505
+ padding_mask, min_dtype
1506
+ )
1507
+
1508
+ return causal_mask
1509
+
1510
+
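+ # Minimal sketch (illustrative only) of the 4D mask built by
+ # `_prepare_4d_causal_attention_mask_with_cache_position`: two query tokens in cache
+ # slots 2 and 3 of a 4-slot cache may attend to every slot at or before their own
+ # position; later slots are filled with the dtype minimum.
+ def _sketch_causal_mask():
+     cache_position = torch.tensor([2, 3])
+     mask = SPRVLALlm._prepare_4d_causal_attention_mask_with_cache_position(
+         attention_mask=None,
+         sequence_length=2,
+         target_length=4,
+         dtype=torch.float32,
+         cache_position=cache_position,
+         batch_size=1,
+     )
+     return mask  # shape (1, 1, 2, 4); row i is min-filled at slots > cache_position[i]
+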
1511
+ @add_start_docstrings(
1512
+ "The SPRVLA text-only model which consists of a language model + lm head.",
1513
+ SPRVLA_START_DOCSTRING,
1514
+ )
1515
+ class SPRVLAForCausalLM(SPRVLAPreTrainedModel, GenerationMixin):
1516
+ _tied_weights_keys = [] # Weights are not tied
1517
+ _tp_plan = {"lm_head": "colwise_rep"}
1518
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
1519
+ base_model_prefix = "model"
1520
+
1521
+ def __init__(self, config: SPRVLALlmConfig):
1522
+ super().__init__(config)
1523
+ self.model = SPRVLALlm(config)
1524
+ self.vocab_size = config.vocab_size
1525
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1526
+
1527
+ # Initialize weights and apply final processing
1528
+ self.post_init()
1529
+
1530
+ def get_input_embeddings(self) -> torch.nn.Module:
1531
+ return self.model.wte
1532
+
1533
+ def set_input_embeddings(self, value: torch.nn.Module) -> None:
1534
+ self.model.wte = value
1535
+
1536
+ def get_output_embeddings(self):
1537
+ return self.lm_head
1538
+
1539
+ def set_output_embeddings(self, value: torch.nn.Module) -> None:
1540
+ self.lm_head = value
1541
+
1542
+ def set_decoder(self, decoder: torch.nn.Module) -> None:
1543
+ self.model = decoder
1544
+
1545
+ def get_decoder(self) -> torch.nn.Module:
1546
+ return self.model
1547
+
1548
+ @can_return_tuple
1549
+ @add_start_docstrings_to_model_forward(SPRVLA_TEXT_ONLY_INPUTS_DOCSTRING)
1550
+ def forward(
1551
+ self,
1552
+ input_ids: Optional[torch.LongTensor] = None,
1553
+ attention_mask: Optional[torch.Tensor] = None,
1554
+ position_ids: Optional[torch.LongTensor] = None,
1555
+ past_key_values: Optional[Cache] = None,
1556
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1557
+ labels: Optional[torch.LongTensor] = None,
1558
+ use_cache: Optional[bool] = None,
1559
+ output_attentions: Optional[bool] = None,
1560
+ output_hidden_states: Optional[bool] = None,
1561
+ cache_position: Optional[torch.LongTensor] = None,
1562
+ logits_to_keep: Union[int, torch.Tensor] = 0,
1563
+ **kwargs,
1564
+ ) -> CausalLMOutputWithPast:
1565
+ r"""
1566
+ ```python
1567
+ >>> from transformers import AutoTokenizer, SPRVLAForCausalLM
1568
+
1569
+ >>> model = SPRVLAForCausalLM.from_pretrained("...")
1570
+ >>> tokenizer = AutoTokenizer.from_pretrained("...")
1571
+
1572
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1573
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1574
+
1575
+ >>> # Generate
1576
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1577
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1578
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1579
+ ```"""
1580
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1581
+ output_hidden_states = (
1582
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1583
+ )
1584
+
1585
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1586
+ outputs: BaseModelOutputWithPast = self.model(
1587
+ input_ids=input_ids,
1588
+ attention_mask=attention_mask,
1589
+ position_ids=position_ids,
1590
+ past_key_values=past_key_values,
1591
+ inputs_embeds=inputs_embeds,
1592
+ use_cache=use_cache,
1593
+ output_attentions=output_attentions,
1594
+ output_hidden_states=output_hidden_states,
1595
+ cache_position=cache_position,
1596
+ **kwargs,
1597
+ )
1598
+
1599
+ hidden_states = outputs.last_hidden_state
1600
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1601
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
1602
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
1603
+
1604
+ loss = None
1605
+ if labels is not None:
1606
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
1607
+
1608
+ return CausalLMOutputWithPast(
1609
+ loss=loss,
1610
+ logits=logits,
1611
+ past_key_values=outputs.past_key_values,
1612
+ hidden_states=outputs.hidden_states,
1613
+ attentions=outputs.attentions,
1614
+ )
1615
+
1616
+
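+ # Hypothetical usage note: during greedy decoding only the last position's logits
+ # are needed, so `logits_to_keep=1` avoids materializing the full
+ # (batch, seq_len, vocab) tensor, e.g.:
+ #     out = model(input_ids=ids, logits_to_keep=1)
+ #     next_id = out.logits[:, -1].argmax(dim=-1)
+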
1617
+ SPRVLA_INPUTS_DOCSTRING = r"""
1618
+ Args:
1619
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1620
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1621
+ it.
1622
+
1623
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1624
+ [`PreTrainedTokenizer.__call__`] for details.
1625
+
1626
+ [What are input IDs?](../glossary#input-ids)
1627
+ images (`torch.FloatTensor` of shape `(batch_size, n_crops, 27*27, 3*14*14)`, *optional*):
1628
+ The input image crops, with pixel values scaled to [0, 1] and normalized with the SigLIP2 mean/std.
1629
+
1630
+ Each crop contains 27x27 = 729 patches, each flattened to 14*14*3 = 588 pixel values.
1631
+ image_masks (`torch.FloatTensor` of shape `(batch_size, n_crops, n_patches, n_features)`, *optional*):
1632
+ Image masks showing what percent of each patch is padding.
1633
+ pooled_patches_idx (`torch.LongTensor` of shape `(batch_size, n_image_tokens, n_pooled_patches)`):
1634
+ For each image-patch token in `input_ids`, the indices of the patches in `images`
1635
+ to pool for that token; an index of -1
1636
+ means the patch should be ignored.
1637
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1638
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1639
+
1640
+ - 1 for tokens that are **not masked**,
1641
+ - 0 for tokens that are **masked**.
1642
+
1643
+ [What are attention masks?](../glossary#attention-mask)
1644
+
1645
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1646
+ [`PreTrainedTokenizer.__call__`] for details.
1647
+
1648
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
1649
+ `past_key_values`).
1650
+
1651
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
1652
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
1653
+ information on the default strategy.
1657
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1658
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1659
+ config.n_positions - 1]`.
1660
+
1661
+ [What are position IDs?](../glossary#position-ids)
1662
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
1663
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
1664
+ blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
1665
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
1666
+
1667
+ Two formats are allowed:
1668
+ - a [`~cache_utils.Cache`] instance, see our
1669
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
1670
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1671
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
1672
+ cache format.
1673
+
1674
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
1675
+ legacy cache format will be returned.
1676
+
1677
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
1678
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
1679
+ of shape `(batch_size, sequence_length)`.
1680
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1681
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1682
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1683
+ model's internal embedding lookup matrix.
1684
+ use_cache (`bool`, *optional*):
1685
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1686
+ `past_key_values`).
1687
+ output_attentions (`bool`, *optional*):
1688
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1689
+ tensors for more detail.
1690
+ output_hidden_states (`bool`, *optional*):
1691
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1692
+ more detail.
1693
+ return_dict (`bool`, *optional*):
1694
+ Whether or not to return a [`SPRVLACausalLMOutputWithPast`] instead of a plain tuple.
1695
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
1696
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
1697
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
1698
+ the complete sequence length.
1699
+ """
1700
+
1701
+
1702
+ @add_start_docstrings(
1703
+ "The bare SPRVLA model outputting raw hidden-states without any specific head on top.",
1704
+ SPRVLA_START_DOCSTRING,
1705
+ )
1706
+ class SPRVLAModel(SPRVLAPreTrainedModel):
1707
+ _checkpoint_conversion_mapping = {}
1708
+
1709
+ def __init__(self, config: SPRVLAConfig):
1710
+ super().__init__(config)
1711
+ self.transformer: SPRVLALlm = SPRVLALlm(config.llm_config)
1712
+ self.vision_backbone: Optional[SPRVLAVisionBackbone] = None
1713
+ if config.vit_config is not None and config.adapter_config is not None:
1714
+ self.vision_backbone = SPRVLAVisionBackbone(config.vit_config, config.adapter_config)
1715
+
1716
+ # Initialize weights and apply final processing
1717
+ self.post_init()
1718
+
1719
+ def get_input_embeddings(self) -> torch.nn.Module:
1720
+ return self.transformer.wte
1721
+
1722
+ def set_input_embeddings(self, value: torch.nn.Module) -> None:
1723
+ self.transformer.wte = value
1724
+
1725
+ @property
1726
+ def device(self) -> torch.device:
1727
+ return self.transformer.ln_f.weight.device
1728
+
1729
+ def build_input_embeddings(
1730
+ self,
1731
+ input_ids: torch.LongTensor,
1732
+ images: Optional[torch.FloatTensor] = None, # image inputs
1733
+ image_masks: Optional[torch.Tensor] = None,
1734
+ pooled_patches_idx: Optional[torch.LongTensor] = None,
1735
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
1736
+
1737
+ # Get embeddings of input.
1738
+ # shape: (batch_size, seq_len, d_model)
1739
+ input_ids = input_ids * (input_ids != -1).to(input_ids.dtype)
1740
+ x = self.transformer.wte(input_ids)
1741
+
1742
+ image_features: Optional[torch.FloatTensor] = None
1743
+ if images is not None:
1744
+ image_features = self.vision_backbone(images, pooled_patches_idx)
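+ # Scatter the pooled vision features into the embedding rows whose token id is
+ # `image_patch_id`; the assert checks the one-to-one correspondence.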
1745
+ is_image_patch = input_ids.view(-1) == self.config.image_patch_id
1746
+ assert is_image_patch.sum() == len(image_features)
1747
+ x.view(-1, x.shape[-1])[is_image_patch] += image_features
1748
+
1749
+ # shape: (batch_size, seq_len, d_model)
1750
+ x = self.transformer.emb_drop(x) # type: ignore
1751
+
1752
+ return x, image_features
1753
+
1754
+ @can_return_tuple
1755
+ def forward(
1756
+ self,
1757
+ input_ids: Optional[torch.LongTensor] = None,
1758
+ images: Optional[torch.FloatTensor] = None,
1759
+ image_masks: Optional[torch.Tensor] = None,
1760
+ pooled_patches_idx: Optional[torch.Tensor] = None,
1761
+ attention_mask: Optional[torch.Tensor] = None,
1762
+ position_ids: Optional[torch.Tensor] = None,
1763
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1764
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1765
+ use_cache: Optional[bool] = None,
1766
+ output_attentions: Optional[bool] = None,
1767
+ output_hidden_states: Optional[bool] = None,
1768
+ cache_position: Optional[torch.LongTensor] = None,
1769
+ ) -> Union[Tuple, SPRVLAModelOutputWithPast]:
1770
+
1771
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1772
+ output_hidden_states = (
1773
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1774
+ )
1775
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1776
+
1777
+ if (input_ids is None) ^ (inputs_embeds is not None):
1778
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
1779
+
1780
+ if images is not None and inputs_embeds is not None:
1781
+ raise ValueError(
1782
+ "You cannot specify both images and inputs_embeds at the same time."
1783
+ )
1784
+
1785
+ if inputs_embeds is None:
1786
+ inputs_embeds, image_features = self.build_input_embeddings(
1787
+ input_ids, images, image_masks, pooled_patches_idx)
1788
+
1789
+ outputs = self.transformer(
1790
+ attention_mask=attention_mask,
1791
+ position_ids=position_ids,
1792
+ past_key_values=past_key_values,
1793
+ inputs_embeds=inputs_embeds,
1794
+ use_cache=use_cache,
1795
+ output_attentions=output_attentions,
1796
+ output_hidden_states=output_hidden_states,
1797
+ cache_position=cache_position,
1798
+ )
1799
+
1800
+ return SPRVLAModelOutputWithPast(
1801
+ last_hidden_state=outputs.last_hidden_state,
1802
+ past_key_values=outputs.past_key_values,
1803
+ hidden_states=outputs.hidden_states,
1804
+ attentions=outputs.attentions,
1805
+ image_hidden_states=image_features if images is not None else None,
1806
+ )
1807
+
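+ # Toy sketch (illustrative only) of the scatter idiom used in
+ # `SPRVLAModel.build_input_embeddings`: feature rows are added, in order, to the
+ # embedding rows whose token id equals the (hypothetical) image patch id 99.
+ def _sketch_patch_scatter():
+     x = torch.zeros(1, 4, 2)                    # (batch, seq_len, hidden)
+     input_ids = torch.tensor([[7, 99, 99, 7]])  # 99 stands in for image_patch_id
+     feats = torch.ones(2, 2)                    # one feature row per patch token
+     is_patch = input_ids.view(-1) == 99
+     x.view(-1, x.shape[-1])[is_patch] += feats  # in-place add into rows 1 and 2
+     return x
+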
1808
+ @add_start_docstrings(
1809
+ "The SPRVLA model which consists of a vision backbone and a language model + lm head.",
1810
+ SPRVLA_START_DOCSTRING,
1811
+ )
1812
+ class SPRVLAForActionReasoning(SPRVLAPreTrainedModel, GenerationMixin):
1813
+ _checkpoint_conversion_mapping = {}
1814
+ _tied_weights_keys = [] # Weights are not tied
1815
+ config_class = SPRVLAConfig
1816
+
1817
+ def __init__(self, config: SPRVLAConfig):
1818
+ super().__init__(config)
1819
+
1820
+ self.model = SPRVLAModel(config)
1821
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1822
+ self.vocab_size = config.vocab_size
1823
+
1824
+ # Initialize weights and apply final processing
1825
+ self.post_init()
1826
+
1827
+ # --- Action parsing / de-tokenization setup ---
1828
+ # Stats dict expected under config.norm_stats (per-dataset key). If missing, default to empty.
1829
+ self.norm_stats = getattr(config, "norm_stats", None) or {}
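+ # Expected layout (hypothetical example):
+ #   {"<dataset>": {"action": {"q01": [...], "q99": [...], "mask": [...]}}}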
1830
+ # Number of discretization bins used for action tokens, defaults to 256.
1831
+ self.n_action_bins = getattr(config, "n_action_bins", 256)
1832
+ # Precompute bin centers in [-1, 1] for inverse token to value mapping.
1833
+ self.bins = np.linspace(-1.0, 1.0, self.n_action_bins)
1834
+ self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0
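+ # 256 edges give 255 centers; `parse_action` maps token ids onto these centers.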
1835
+ # Lazily constructed tokenizer for converting token strings to ids
1836
+ self._qwen_tokenizer = None
1837
+
1838
+ def get_input_embeddings(self) -> torch.nn.Module:
1839
+ return self.model.transformer.wte
1840
+
1841
+ def set_input_embeddings(self, value: torch.nn.Module) -> None:
1842
+ self.model.transformer.wte = value
1843
+
1844
+ def get_output_embeddings(self):
1845
+ return self.lm_head
1846
+
1847
+ def set_output_embeddings(self, value: torch.nn.Module) -> None:
1848
+ self.lm_head = value
1849
+
1850
+ # Make modules available through the conditional class for BC
1851
+ @property
1852
+ def language_model(self) -> torch.nn.Module:
1853
+ return self.model.transformer
1854
+
1855
+ @property
1856
+ def vision_backbone(self) -> torch.nn.Module:
1857
+ return self.model.vision_backbone
1858
+
1859
+ @can_return_tuple
1860
+ @add_start_docstrings_to_model_forward(SPRVLA_INPUTS_DOCSTRING)
1861
+ def forward(
1862
+ self,
1863
+ input_ids: Optional[torch.LongTensor] = None,
1864
+ images: Optional[torch.Tensor] = None,
1865
+ image_masks: Optional[torch.Tensor] = None,
1866
+ pooled_patches_idx: Optional[torch.Tensor] = None,
1867
+ attention_mask: Optional[torch.Tensor] = None,
1868
+ position_ids: Optional[torch.LongTensor] = None,
1869
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1870
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1871
+ labels: Optional[torch.LongTensor] = None,
1872
+ use_cache: Optional[bool] = None,
1873
+ output_attentions: Optional[bool] = None,
1874
+ output_hidden_states: Optional[bool] = None,
1875
+ cache_position: Optional[torch.LongTensor] = None,
1876
+ logits_to_keep: Union[int, torch.Tensor] = 0,
1877
+ **kwargs,
1878
+ ) -> Union[Tuple, SPRVLACausalLMOutputWithPast]:
1879
+ r"""
1880
+ ```python
1881
+ >>> from PIL import Image
1882
+ >>> import requests
1883
+ >>> from transformers import AutoProcessor, SPRVLAForActionReasoning
1884
+
1885
+ >>> model = SPRVLAForActionReasoning.from_pretrained("...")
1886
+ >>> processor = AutoProcessor.from_pretrained("...")
1887
+
1888
+ >>> prompt = "What's the content of the image?"
1889
+ >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
1890
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1891
+
1892
+ >>> inputs = processor(images=image, text=prompt, apply_chat_template=True, return_tensors="pt")
1893
+
1894
+ >>> # Generate
1895
+ >>> generated_ids = model.generate(**inputs, max_new_tokens=15)
1896
+ >>> generated_tokens = generated_ids[:, inputs['input_ids'].size(1):]
1897
+ >>> processor.batch_decode(generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1898
+ "The image features a busy city street with a stop sign prominently displayed"
1899
+ ```"""
1900
+ outputs = self.model(
1901
+ input_ids=input_ids,
1902
+ images=images,
1903
+ image_masks=image_masks,
1904
+ pooled_patches_idx=pooled_patches_idx,
1905
+ attention_mask=attention_mask,
1906
+ position_ids=position_ids,
1907
+ past_key_values=past_key_values,
1908
+ inputs_embeds=inputs_embeds,
1909
+ use_cache=use_cache,
1910
+ output_attentions=output_attentions,
1911
+ output_hidden_states=output_hidden_states,
1912
+ cache_position=cache_position,
1913
+ )
1914
+
1915
+ hidden_states = outputs.last_hidden_state
1916
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
1917
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
1918
+
1919
+ loss = None
1920
+ if labels is not None:
1921
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.vocab_size)
1922
+
1923
+ return SPRVLACausalLMOutputWithPast(
1924
+ loss=loss,
1925
+ logits=logits,
1926
+ past_key_values=outputs.past_key_values,
1927
+ hidden_states=outputs.hidden_states,
1928
+ attentions=outputs.attentions,
1929
+ image_hidden_states=outputs.image_hidden_states,
1930
+ )
1931
+
1932
+ # ===== Utilities for action parsing / un-normalization =====
1933
+ def _check_unnorm_key(self, unnorm_key: Optional[str]) -> str:
1934
+ """Validate and resolve which dataset key to use from self.norm_stats."""
1935
+ if not self.norm_stats:
1936
+ raise ValueError("No norm_stats found in config; cannot unnormalize actions.")
1937
+ if unnorm_key is None:
1938
+ if len(self.norm_stats) != 1:
1939
+ raise ValueError(
1940
+ f"Model has multiple dataset stats; please pass `unnorm_key` from {list(self.norm_stats.keys())}"
1941
+ )
1942
+ return next(iter(self.norm_stats.keys()))
1943
+ if unnorm_key not in self.norm_stats:
1944
+ raise ValueError(f"`unnorm_key`={unnorm_key!r} not in {list(self.norm_stats.keys())}")
1945
+ return unnorm_key
1946
+
1947
+ def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:
1948
+ """Return action dimensionality from q01 stats length for the dataset key."""
1949
+ key = self._check_unnorm_key(unnorm_key)
1950
+ return len(self.norm_stats[key]["action"]["q01"])
1951
+
1952
+ def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]:
1953
+ """Return the full action stats dict for a given dataset key."""
1954
+ key = self._check_unnorm_key(unnorm_key)
1955
+ return self.norm_stats[key]["action"]
1956
+
1957
+ @torch.no_grad()
1958
+ def parse_action(self, text: str, unnorm_key: Optional[str] = None) -> list:
1959
+ """
1960
+ Parse a generated text to extract one 1×D action token list, decode to continuous values,
1961
+ and unnormalize using dataset-specific stats from `config.norm_stats`.
1962
+
1963
+ This follows the pipeline used in `experiments/robot/libero/main_libero_10_evaluation.py`:
1964
+ - Find bracketed token lists following the phrase "the action that the robot should take is" (case-insensitive),
1965
+ falling back to any bracketed list in the text.
1966
+ - Convert token strings → ids via Qwen2Tokenizer.
1967
+ - Map ids → discretized bin indices using: `discretized = vocab_size - token_id - 1` (clipped to bins)
1968
+ - Convert bins → normalized actions in [-1, 1] using precomputed `bin_centers`.
1969
+ - Unnormalize with q01/q99 and optional `mask` from norm_stats.
1970
+
1971
+ Returns:
1972
+ List[List[float]]: unnormalized action vectors, each of length D.
1973
+ """
1974
+ # Resolve action dimension and stats
1975
+ action_dim = self.get_action_dim(unnorm_key)
1976
+ stats = self.get_action_stats(unnorm_key)
1977
+ q01 = np.asarray(stats["q01"], dtype=np.float32)
1978
+ q99 = np.asarray(stats["q99"], dtype=np.float32)
1979
+ mask = np.asarray(stats.get("mask", np.ones_like(q01, dtype=bool)), dtype=bool)
1980
+ # the gripper dimension is kept in its normalized form (no unnormalization)
1981
+ mask[-1] = False
1982
+
1983
+ # Lazily load the tokenizer (shared across calls)
1984
+ if self._qwen_tokenizer is None:
1985
+ self._qwen_tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen2-7B")
1986
+
1987
+ token_lists = extract_action_token_lists(text, only_len=action_dim)
1988
+ action_lists = []
1989
+
1990
+ # Decode every extracted token list (temporal aggregation, if any, is left to the caller)
1991
+ for tokens in token_lists:
1992
+
1993
+ # Convert tokens → ids (replace None with vocab_size to avoid negatives)
1994
+ ids = self._qwen_tokenizer.convert_tokens_to_ids(tokens)
1995
+ ids = [self._qwen_tokenizer.vocab_size if i is None else int(i) for i in ids]
1996
+ ids = np.asarray(ids, dtype=np.int64)
1997
+
1998
+ # ids → discretized bin indices → normalized actions in [-1, 1]
1999
+ discretized = self._qwen_tokenizer.vocab_size - ids
2000
+ discretized = np.clip(discretized - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)
2001
+ normalized = self.bin_centers[discretized]
2002
+
2003
+ # Unnormalize using per-dimension statistics
2004
+ unnorm = 0.5 * (normalized + 1.0) * (q99 - q01) + q01
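+ # i.e. an affine map sending -1 -> q01 and +1 -> q99 per dimension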
2005
+ actions = np.where(mask, unnorm, normalized)
2006
+
2007
+ action_lists.append([float(x) for x in actions])
2008
+
2009
+ # Return a list of unnormalized action vectors
2010
+ return action_lists
2011
+
2012
+ @torch.no_grad()
2013
+ def parse_trace(self, text: str) -> list:
2014
+ return extract_trace_lists(text, point_len=2, min_points=1)
2015
+
2016
+ @torch.no_grad()
2017
+ def parse_depth(self, text: str) -> list:
2018
+ return extract_depth_string(text, include_tags=True)
2019
+
2021
+ def prepare_inputs_for_generation(
2022
+ self,
2023
+ input_ids: torch.LongTensor,
2024
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
2025
+ inputs_embeds: Optional[torch.FloatTensor] = None,
2026
+ images: Optional[torch.FloatTensor] = None,
2027
+ image_masks: Optional[torch.Tensor] = None,
2028
+ pooled_patches_idx: Optional[torch.Tensor] = None,
2029
+ attention_mask: Optional[torch.Tensor] = None,
2030
+ cache_position: Optional[torch.LongTensor] = None,
2031
+ logits_to_keep: Optional[Union[int, torch.Tensor]] = None,
2032
+ **kwargs,
2033
+ ):
2034
+
2035
+ model_inputs = super().prepare_inputs_for_generation(
2036
+ input_ids,
2037
+ past_key_values=past_key_values,
2038
+ inputs_embeds=inputs_embeds,
2039
+ attention_mask=attention_mask,
2040
+ cache_position=cache_position,
2041
+ logits_to_keep=logits_to_keep,
2042
+ **kwargs,
2043
+ )
2044
+
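+ # Vision inputs are only needed for the prefill step (cache_position starts at 0);
+ # afterwards the image tokens already live in the KV cache.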
2045
+ if cache_position[0] == 0:
2046
+ model_inputs["images"] = images
2047
+ model_inputs["pooled_patches_idx"] = pooled_patches_idx
2048
+ model_inputs["image_masks"] = image_masks
2049
+
2050
+ return model_inputs
2051
+
2052
+ def _update_model_kwargs_for_generation(
2053
+ self,
2054
+ outputs: ModelOutput,
2055
+ model_kwargs: Dict[str, Any],
2056
+ is_encoder_decoder: bool = False,
2057
+ num_new_tokens: int = 1,
2058
+ ) -> Dict[str, Any]:
2059
+ if model_kwargs["use_cache"] and "images" in model_kwargs:
2060
+ # After the first step, no longer pass the images into forward since the image tokens
2061
+ # are already cached
2062
+ for k in ["images", "image_masks", "pooled_patches_idx"]:
2063
+ del model_kwargs[k]
2064
+ return super()._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder, num_new_tokens)
2065
+
2066
+ @staticmethod
2067
+ def _prepare_4d_causal_attention_mask_with_cache_position(
2068
+ attention_mask: torch.Tensor,
2069
+ sequence_length: int,
2070
+ target_length: int,
2071
+ dtype: torch.dtype,
2072
+ cache_position: torch.Tensor,
2073
+ batch_size: int,
2074
+ **kwargs,
2075
+ ):
2076
+ """
2077
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
2078
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
2079
+
2080
+ Args:
2081
+ attention_mask (`torch.Tensor`):
2082
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
2083
+ `(batch_size, 1, query_length, key_value_length)`.
2084
+ sequence_length (`int`):
2085
+ The sequence length being processed.
2086
+ target_length (`int`):
2087
+ The target length: when generating with static cache, the mask should be as long as the static cache,
2088
+ to account for the 0 padding, i.e. the part of the cache that is not filled yet.
2089
+ dtype (`torch.dtype`):
2090
+ The dtype to use for the 4D attention mask.
2091
+ cache_position (`torch.Tensor`):
2092
+ Indices depicting the position of the input sequence tokens in the sequence.
2093
+ batch_size (`int`):
2094
+ Batch size.
2095
+ """
2096
+ if attention_mask is not None and attention_mask.dim() == 4:
2097
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
2098
+ causal_mask = attention_mask
2099
+ else:
2100
+ min_dtype = torch.finfo(dtype).min
2101
+ causal_mask = torch.full(
2102
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
2103
+ )
2104
+ if sequence_length != 1:
2105
+ causal_mask = torch.triu(causal_mask, diagonal=1)
2106
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
2107
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
2108
+ if attention_mask is not None:
2109
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
2110
+ mask_length = attention_mask.shape[-1]
2111
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
2112
+ causal_mask.device
2113
+ )
2114
+ padding_mask = padding_mask == 0
2115
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
2116
+ padding_mask, min_dtype
2117
+ )
2118
+
2119
+ return causal_mask
2120
+
2121
+
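+ # Standalone sketch (illustrative only) of the token-to-bin decoding performed in
+ # `SPRVLAForActionReasoning.parse_action`; `vocab_size` must come from the tokenizer.
+ def _sketch_action_detokenize(token_ids, vocab_size, n_bins=256):
+     bins = np.linspace(-1.0, 1.0, n_bins)
+     centers = (bins[:-1] + bins[1:]) / 2.0
+     # token_id -> bin index, mirroring `vocab_size - id - 1` with clipping
+     discretized = np.clip(vocab_size - np.asarray(token_ids) - 1, 0, centers.shape[0] - 1)
+     return centers[discretized]  # normalized actions in [-1, 1]
+
+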
2122
+ # Always register for multi-modal features
2123
+ AutoModelForImageTextToText.register(SPRVLAConfig, SPRVLAForActionReasoning)
2124
+ AutoModelForCausalLM.register(SPRVLALlmConfig, SPRVLAForCausalLM)