XiaSheng commited on
Commit
0dcbf0b
·
verified ·
1 Parent(s): 537c716

Initial upload of FreeChunk model with custom code

Browse files
Files changed (3) hide show
  1. configuration.py +157 -0
  2. encoder.py +15 -32
  3. freechunker.py +769 -0
configuration.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """FreeChunker configuration: Modified from XLM-RoBERTa configuration"""
17
+
18
+ from collections import OrderedDict
19
+ from typing import Mapping
20
+
21
+ from transformers.configuration_utils import PretrainedConfig
22
+ from transformers.onnx import OnnxConfig
23
+ from transformers.utils import logging
24
+
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
class FreeChunkerConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`FreeChunkerModel`] or a [`TFFreeChunkerModel`]. It
    is used to instantiate a XLM-RoBERTa model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the FreeChunker
    [FacebookAI/xlm-roberta-base](https://huggingface.co/FacebookAI/xlm-roberta-base) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 30522):
            Vocabulary size of the XLM-RoBERTa model. Defines the number of different tokens that can be represented by
            the `inputs_ids` passed when calling [`FreeChunkerModel`] or [`TFFreeChunkerModel`].
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (`int`, *optional*, defaults to 512):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        type_vocab_size (`int`, *optional*, defaults to 2):
            The vocabulary size of the `token_type_ids` passed when calling [`FreeChunkerModel`] or
            [`TFFreeChunkerModel`].
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
            Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
            positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
            [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
            For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
            with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
        is_decoder (`bool`, *optional*, defaults to `False`):
            Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        classifier_dropout (`float`, *optional*):
            The dropout ratio for the classification head.

    Examples:

    ```python
    >>> from transformers import FreeChunkerConfig, FreeChunkerModel

    >>> # Initializing a XLM-RoBERTa FacebookAI/xlm-roberta-base style configuration
    >>> configuration = FreeChunkerConfig()

    >>> # Initializing a model (with random weights) from the FacebookAI/xlm-roberta-base style configuration
    >>> model = FreeChunkerModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    # NOTE(review): kept as "xlm-roberta" — presumably so AutoConfig/checkpoint mappings
    # resolve this config to the XLM-R architecture family; confirm before changing.
    model_type = "xlm-roberta"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        position_embedding_type="absolute",
        use_cache=True,
        classifier_dropout=None,
        **kwargs,
    ):
        # Special-token ids are forwarded to PretrainedConfig so that generation and
        # padding utilities pick them up.
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.use_cache = use_cache
        self.classifier_dropout = classifier_dropout
139
+
140
+
141
# Copied from transformers.models.roberta.configuration_roberta.RobertaOnnxConfig with Roberta->FreeChunker
class FreeChunkerOnnxConfig(OnnxConfig):
    """ONNX export configuration: declares the dynamic input axes for export.

    Note: the original class was misspelled `FreeChunekrOnnxConfig`, which made the
    `__all__` entry below (`"FreeChunkerOnnxConfig"`) point at a nonexistent name and
    broke `from <module> import *`. The class is renamed to match `__all__`; an alias
    preserves the old misspelled name for any existing callers.
    """

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        # Multiple-choice tasks carry an extra "choice" axis between batch and sequence.
        if self.task == "multiple-choice":
            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
        else:
            dynamic_axis = {0: "batch", 1: "sequence"}
        return OrderedDict(
            [
                ("input_ids", dynamic_axis),
                ("attention_mask", dynamic_axis),
            ]
        )


# Backward-compatible alias for the original misspelled class name.
FreeChunekrOnnxConfig = FreeChunkerOnnxConfig


__all__ = ["FreeChunkerConfig", "FreeChunkerOnnxConfig"]
encoder.py CHANGED
@@ -9,25 +9,26 @@ import numpy as np
9
  import pickle
10
  import os
11
  from typing import List, Tuple, Union
12
- from .sentenizer import Sentenceizer
13
- from .modeling_freechunker import FreeChunkerModel
14
- from .aggregator import TextAggregator
15
 
16
  class UnifiedEncoder:
17
  """
18
  Unified text encoder, supporting text sentence splitting and encoding for multiple models
19
  """
20
 
21
- def __init__(self, model_name: str, local_model_path: str = None):
22
  """
23
  Initialize unified text encoder
24
 
25
  Args:
26
- model_name (str): Model name (e.g. 'bge-m3', 'jina', 'nomic')
27
- local_model_path (str, optional): Local model path for loading FreeChunker weights.
28
- If None, tries to load from current directory or Hugging Face.
29
  """
30
  self.model_name = model_name
 
31
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'mps')
32
 
33
  # Initialize text aggregator
@@ -37,39 +38,21 @@ class UnifiedEncoder:
37
  print(f"Using local model path: {local_model_path}")
38
  print(f"Using device: {self.device}")
39
 
40
- # If local_model_path is not provided, assume current directory or let from_pretrained handle it
41
- if local_model_path is None:
42
- local_model_path = "."
43
-
44
- try:
45
- self.model = FreeChunkerModel.from_pretrained(local_model_path)
46
- except Exception as e:
47
- print(f"Failed to load model from {local_model_path}: {e}")
48
- print("Trying to load as a fresh model or from HF hub if applicable...")
49
- # Fallback or re-raise
50
- raise e
51
-
52
  self.model.to(self.device)
53
  self.model.eval()
54
 
55
  # Select model and preprocessor based on model name
56
- # Predefined model mapping: name -> (local_path, HF_model_ID)
57
- # Note: Local paths are environment specific, so we primarily rely on HF IDs or passed arguments
58
  model_configs = {
59
- 'bge-m3': ('/share/home/ecnuzwx/UnifiedRAG/cache/models--BAAI--bge-m3', 'BAAI/bge-m3'),
60
- 'nomic-embed-text-v1.5': ('/share/home/ecnuzwx/UnifiedRAG/cache/models--nomic-ai--nomic-embed-text-v1.5', 'nomic-ai/nomic-embed-text-v1.5'),
61
- 'jina': ('/share/home/ecnuzwx/UnifiedRAG/cache/models--jinaai--jina-embeddings-v2-small-en', 'jinaai/jina-embeddings-v2-small-en')
62
  }
63
 
64
  if model_name in model_configs:
65
- local_path, hf_id = model_configs[model_name]
66
- # Prioritize local path if it exists, otherwise use HF ID
67
- if os.path.exists(local_path):
68
- target_model = local_path
69
- else:
70
- target_model = hf_id
71
-
72
- self.sentenceizer = Sentenceizer(model_name=target_model)
73
  else:
74
  # Try using model_name directly as path or ID
75
  print(f"Unknown predefined model name: {model_name}, trying to load directly...")
 
9
  import pickle
10
  import os
11
  from typing import List, Tuple, Union
12
+ from sentenizer import Sentenceizer
13
+ from freechunker import FreeChunkerModel
14
+ from aggregator import TextAggregator
15
 
16
  class UnifiedEncoder:
17
  """
18
  Unified text encoder, supporting text sentence splitting and encoding for multiple models
19
  """
20
 
21
+ def __init__(self, model_name: str, local_model_path: str = None, granularities: List[int] = None):
22
  """
23
  Initialize unified text encoder
24
 
25
  Args:
26
+ model_name (str): Model name
27
+ local_model_path (str, optional): Local model path for loading fine-tuned weights
28
+ granularities (List[int], optional): Granularities for chunking
29
  """
30
  self.model_name = model_name
31
+ self.granularities = granularities
32
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'mps')
33
 
34
  # Initialize text aggregator
 
38
  print(f"Using local model path: {local_model_path}")
39
  print(f"Using device: {self.device}")
40
 
41
+ self.model = FreeChunkerModel.from_pretrained(local_model_path)
 
 
 
 
 
 
 
 
 
 
 
42
  self.model.to(self.device)
43
  self.model.eval()
44
 
45
  # Select model and preprocessor based on model name
46
+ # Predefined model mapping: name -> HF_model_ID
 
47
  model_configs = {
48
+ 'bge-m3': 'BAAI/bge-m3',
49
+ 'nomic-embed-text-v1.5': 'nomic-ai/nomic-embed-text-v1.5',
50
+ 'jina': 'jinaai/jina-embeddings-v2-small-en'
51
  }
52
 
53
  if model_name in model_configs:
54
+ hf_id = model_configs[model_name]
55
+ self.sentenceizer = Sentenceizer(model_name=hf_id)
 
 
 
 
 
 
56
  else:
57
  # Try using model_name directly as path or ID
58
  print(f"Unknown predefined model name: {model_name}, trying to load directly...")
freechunker.py ADDED
@@ -0,0 +1,769 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2019 Facebook AI Research and the HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """FreeChunker model: Modified from PyTorch XLM-RoBERTa model."""
17
+ from utils import generate_shifted_matrix
18
+ import math
19
+ from typing import Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from packaging import version
24
+ from torch import nn
25
+ from transformers.activations import ACT2FN
26
+ from transformers.modeling_outputs import (
27
+ BaseModelOutputWithPoolingAndCrossAttentions
28
+ )
29
+ from transformers.modeling_utils import PreTrainedModel
30
+ from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
31
+ from transformers.utils import (
32
+ add_code_sample_docstrings,
33
+ add_start_docstrings,
34
+ add_start_docstrings_to_model_forward,
35
+ get_torch_version,
36
+ logging
37
+ )
38
+ from configuration import FreeChunkerConfig
39
+
40
+
41
+ logger = logging.get_logger(__name__)
42
+
43
+ _CHECKPOINT_FOR_DOC = "FacebookAI/xlm-roberta-base"
44
+ _CONFIG_FOR_DOC = "FreeChunkerConfig"
45
+
46
+
47
# Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings with Roberta->FreeChunker
class FreeChunkerEmbeddings(nn.Module):
    """
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing:
    position ids start at `padding_idx + 1` and padded tokens keep the padding
    position, RoBERTa-style.
    """

    # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__
    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        self.register_buffer(
            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
        )

        # End copy
        self.padding_idx = config.pad_token_id
        # Deliberately re-created with padding_idx (overriding the BERT-style module
        # above) so padded positions map to the zeroed padding embedding row.
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
        )

    def forward(
        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None
    ):
        """Build token + token-type (+ absolute position) embeddings, then LayerNorm and dropout.

        Exactly one of `input_ids` / `inputs_embeds` must be provided.
        """
        if position_ids is None:
            if input_ids is not None:
                # Create the position ids from the input token ids. Any padded tokens remain padded.
                position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx)
            else:
                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)

        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        # NOTE: the original had a second `if position_ids is None:` fallback here; it was
        # unreachable because position_ids is always assigned above, so it was removed.

        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
        """
        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids
        starting at `padding_idx + 1`.

        Args:
            inputs_embeds: torch.Tensor

        Returns: torch.Tensor
        """
        input_shape = inputs_embeds.size()[:-1]
        sequence_length = input_shape[1]

        position_ids = torch.arange(
            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
        )
        return position_ids.unsqueeze(0).expand(input_shape)
131
+
132
+
133
# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->FreeChunker
class FreeChunkerSelfAttention(nn.Module):
    """Cross-stream multi-head attention.

    Queries are projected from the first stream (`hidden_states`); keys and values
    are projected from the second stream (`hidden_states2`).
    """

    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type in ("relative_key", "relative_key_query"):
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape (batch, seq, all_head_size) -> (batch, heads, seq, head_size)."""
        split_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        return x.view(split_shape).permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        hidden_states2: torch.Tensor,  # Second input stream, required parameter
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Attend from `hidden_states` (queries) onto `hidden_states2` (keys/values)."""
        q = self.transpose_for_scores(self.query(hidden_states))
        k = self.transpose_for_scores(self.key(hidden_states2))
        v = self.transpose_for_scores(self.value(hidden_states2))

        # Raw (unscaled) dot-product scores.
        scores = torch.matmul(q, k.transpose(-1, -2))

        if self.position_embedding_type in ("relative_key", "relative_key_query"):
            q_len, k_len = q.shape[2], k.shape[2]
            # Query positions are all pinned to 0; key positions count up 0,1,2,...
            left = torch.zeros(q_len, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            right = torch.arange(k_len, dtype=torch.long, device=hidden_states.device).view(1, -1)
            rel = self.distance_embedding(left - right + self.max_position_embeddings - 1)
            rel = rel.to(dtype=q.dtype)  # fp16 compatibility

            # relative_key adds only the query-side term; relative_key_query adds both.
            scores = scores + torch.einsum("bhld,lrd->bhlr", q, rel)
            if self.position_embedding_type == "relative_key_query":
                scores = scores + torch.einsum("bhrd,lrd->bhlr", k, rel)

        scores = scores / math.sqrt(self.attention_head_size)

        if attention_mask is not None:
            scores = scores + attention_mask

        # Normalize to probabilities, then apply attention dropout.
        probs = self.dropout(nn.functional.softmax(scores, dim=-1))

        if head_mask is not None:
            probs = probs * head_mask

        ctx = torch.matmul(probs, v).permute(0, 2, 1, 3).contiguous()
        ctx = ctx.view(ctx.size()[:-2] + (self.all_head_size,))

        return (ctx, probs) if output_attentions else (ctx,)
227
+
228
+
229
# Copied from transformers.models.roberta.modeling_roberta.RobertaSdpaSelfAttention with Roberta->FreeChunker
class FreeChunkerSdpaSelfAttention(FreeChunkerSelfAttention):
    """SDPA-accelerated variant of FreeChunkerSelfAttention.

    Uses torch.nn.functional.scaled_dot_product_attention on the fast path and
    falls back to the eager parent implementation whenever SDPA cannot express
    the request (non-absolute position embeddings, requested attention outputs,
    or a head mask).
    """

    def __init__(self, config, position_embedding_type=None):
        super().__init__(config, position_embedding_type=position_embedding_type)
        self.dropout_prob = config.attention_probs_dropout_prob
        # torch < 2.2.0 requires contiguous q/k/v with a custom attn_mask on CUDA; see forward().
        self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")

    def forward(
        self,
        hidden_states: torch.Tensor,
        hidden_states2: torch.Tensor,  # Second input stream, required parameter
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Cross-stream attention: queries from `hidden_states`, keys/values from `hidden_states2`."""
        # If relative positional encoding, output attentions, or head mask are present, fallback to parent implementation
        if (self.position_embedding_type != "absolute" or
            output_attentions or
            head_mask is not None):
            return super().forward(
                hidden_states,
                hidden_states2,
                attention_mask,
                head_mask,
                output_attentions,
            )

        # Use optimized implementation of SDPA
        bsz, tgt_len, _ = hidden_states.size()

        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states2))
        value_layer = self.transpose_for_scores(self.value(hidden_states2))

        # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
        # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
        # Reference: https://github.com/pytorch/pytorch/issues/112577
        if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
            query_layer = query_layer.contiguous()
            key_layer = key_layer.contiguous()
            value_layer = value_layer.contiguous()

        # Attention dropout is only applied in training mode.
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_layer,
            key_layer,
            value_layer,
            attn_mask=attention_mask,
            dropout_p=self.dropout_prob if self.training else 0.0,
            is_causal=False,  # For customized tasks, causal mask is not used
        )

        # (batch, heads, seq, head_size) -> (batch, seq, all_head_size)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)

        outputs = (attn_output,)
        return outputs
285
+
286
+
287
# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfOutput with Roberta->FreeChunker
class FreeChunkerSelfOutput(nn.Module):
    """Post-attention projection: dense + dropout, then residual add and LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Project `hidden_states`, then LayerNorm(projected + `input_tensor`)."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
300
+
301
+
302
# Dispatch table mapping config._attn_implementation ("eager" | "sdpa") to the
# self-attention class used by FreeChunkerAttention. Name kept from the upstream
# XLM-RoBERTa source this file was adapted from.
XLM_ROBERTA_SELF_ATTENTION_CLASSES = {
    "eager": FreeChunkerSelfAttention,
    "sdpa": FreeChunkerSdpaSelfAttention,
}
306
+
307
+
308
# Copied from transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->FreeChunker
class FreeChunkerAttention(nn.Module):
    """Full attention sub-block: cross-stream self-attention plus residual output projection."""

    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        # Implementation (eager vs. SDPA) selected by config._attn_implementation.
        self.self = XLM_ROBERTA_SELF_ATTENTION_CLASSES[config._attn_implementation](
            config, position_embedding_type=position_embedding_type
        )
        self.output = FreeChunkerSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads in-place; no-op for an empty list."""
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        hidden_states2: torch.Tensor,  # Second input stream, required parameter
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Run cross-stream attention; the residual connection comes from `hidden_states` (the query stream)."""
        self_outputs = self.self(
            hidden_states,
            hidden_states2,  # Pass second input stream
            attention_mask,
            head_mask,
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs
354
+
355
+
356
# Copied from transformers.models.roberta.modeling_roberta.RobertaIntermediate with Roberta->FreeChunker
class FreeChunkerIntermediate(nn.Module):
    """Feed-forward expansion: hidden -> intermediate size with a configurable activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # Activation may be given as a name (looked up in ACT2FN) or as a callable.
        self.intermediate_act_fn = (
            ACT2FN[config.hidden_act] if isinstance(config.hidden_act, str) else config.hidden_act
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the up-projection followed by the activation."""
        return self.intermediate_act_fn(self.dense(hidden_states))
370
+
371
+
372
# Copied from transformers.models.roberta.modeling_roberta.RobertaOutput with Roberta->FreeChunker
class FreeChunkerOutput(nn.Module):
    """Feed-forward down-projection: intermediate -> hidden, dropout, residual add, LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Project `hidden_states` back to hidden size, then LayerNorm(projected + `input_tensor`)."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
385
+
386
+
387
# Copied from transformers.models.roberta.modeling_roberta.RobertaLayer with Roberta->FreeChunker
class FreeChunkerLayer(nn.Module):
    """One transformer layer: cross-stream attention followed by the feed-forward block."""

    def __init__(self, config):
        super().__init__()
        # NOTE(review): stored for the standard chunked feed-forward API, but forward()
        # below calls feed_forward_chunk directly (no apply_chunking_to_forward) — so
        # these two attributes are currently unused.
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = FreeChunkerAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            if not self.is_decoder:
                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
            # NOTE(review): crossattention is instantiated here but never invoked in
            # forward() below — confirm whether decoder cross-attention is still needed.
            self.crossattention = FreeChunkerAttention(config, position_embedding_type="absolute")
        self.intermediate = FreeChunkerIntermediate(config)
        self.output = FreeChunkerOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        hidden_states2: torch.Tensor,  # Second input stream, required parameter
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Run attention (queries from `hidden_states`, keys/values from `hidden_states2`),
        then the feed-forward block. Returns (layer_output, [attention_probs])."""
        attention_outputs = self.attention(
            hidden_states,
            hidden_states2,  # Pass second input stream
            attention_mask,
            head_mask,
            output_attentions,
        )
        attention_output = attention_outputs[0]

        outputs = attention_outputs[1:]  # add self attentions if we output attention weights

        layer_output = self.feed_forward_chunk(attention_output)
        outputs = (layer_output,) + outputs

        return outputs

    def feed_forward_chunk(self, attention_output):
        # Intermediate expansion + down-projection with residual over attention_output.
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
431
+
432
+
433
# Copied from transformers.models.roberta.modeling_roberta.RobertaEncoder with Roberta->FreeChunker
class FreeChunkerEncoder(nn.Module):
    """Stack of FreeChunkerLayer modules threaded over two input streams.

    Only the primary stream is updated between layers; ``hidden_states2`` is fed
    unchanged to every layer.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(FreeChunkerLayer(config) for _ in range(config.num_hidden_layers))
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        hidden_states2: torch.Tensor,  # second input stream, required
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:

        checkpointing = self.gradient_checkpointing and self.training

        for idx, block in enumerate(self.layer):
            mask_for_layer = None if head_mask is None else head_mask[idx]

            if checkpointing:
                # Recompute activations in backward to save memory.
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    hidden_states2,
                    attention_mask,
                    mask_for_layer,
                )
            else:
                layer_outputs = block(
                    hidden_states,
                    hidden_states2,
                    attention_mask,
                    mask_for_layer,
                )

            # First tuple entry is the updated primary stream.
            hidden_states = layer_outputs[0]

        return hidden_states
478
+
479
+
480
# Copied from transformers.models.roberta.modeling_roberta.RobertaPooler with Roberta->FreeChunker
class FreeChunkerPooler(nn.Module):
    """Pools a sequence by projecting the first token's hidden state through tanh."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # "Pooling" here means taking the hidden state of the first token only.
        first_token = hidden_states[:, 0]
        return self.activation(self.dense(first_token))
494
+
495
+
496
# Copied from transformers.models.roberta.modeling_roberta.RobertaPreTrainedModel with Roberta->FreeChunker
class FreeChunkerPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = FreeChunkerConfig
    base_model_prefix = "roberta"
    supports_gradient_checkpointing = True
    _no_split_modules = ["FreeChunkerEmbeddings", "FreeChunkerSelfAttention", "FreeChunkerSdpaSelfAttention"]
    _supports_sdpa = True

    # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
    def _init_weights(self, module):
        """Initialize the weights"""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            # Keep the padding embedding at exactly zero.
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()
525
+
526
+
527
# Docstring templates consumed by the `add_start_docstrings*` decorators on
# FreeChunkerModel below; `{0}` in the inputs docstring is substituted with the
# expected tensor shape string.
XLM_ROBERTA_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`FreeChunkerConfig`]): Model configuration class with all the parameters of the
            model. Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

XLM_ROBERTA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
591
+
592
+
593
@add_start_docstrings(
    "The bare XLM-RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
    XLM_ROBERTA_START_DOCSTRING,
)
# Copied from transformers.models.roberta.modeling_roberta.RobertaModel with Roberta->FreeChunker, ROBERTA->XLM_ROBERTA
class FreeChunkerModel(FreeChunkerPreTrainedModel):
    """

    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in [Attention is
    all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
    to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
    `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
    """

    _no_split_modules = ["FreeChunkerEmbeddings", "FreeChunkerLayer"]

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config
        # NOTE(review): vocab_size is forced to 2 — forward() only ever feeds a
        # fixed all-zero input_ids row, so the word-embedding table is a dummy.
        # Confirm this matches the checkpoint's embedding shape.
        self.config.vocab_size = 2
        self.embeddings = FreeChunkerEmbeddings(self.config)
        self.encoder = FreeChunkerEncoder(config)

        self.pooler = FreeChunkerPooler(config) if add_pooling_layer else None

        self.attn_implementation = config._attn_implementation
        self.position_embedding_type = config.position_embedding_type

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Returns the (dummy, size-2 vocab) word-embedding table.
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(XLM_ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPoolingAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        # NOTE(review): inputs_embeds is effectively required — it is dereferenced
        # unconditionally below — despite the None default; passing None raises
        # AttributeError. Verify callers always supply it.
        inputs_embeds=None,
        labels=None,
        # NOTE(review): loss_weights is a bool *flag*; when True it is reassigned
        # to a weight tensor derived from shift_matrix further down.
        loss_weights: bool = False,
        input_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        granularities: Optional[list] = None
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        # NOTE(review): despite the annotation, this returns a plain dict with
        # keys "loss", "embedding", "shift_matrix".

        # Get input device
        input_device = inputs_embeds.device

        # Dimension adaptation: if input dimension is less than hidden_size, pad with 0
        original_hidden_size = inputs_embeds.shape[-1]
        target_hidden_size = self.config.hidden_size  # 1024

        if original_hidden_size < target_hidden_size:
            # Calculate number of dimensions to pad
            padding_size = target_hidden_size - original_hidden_size
            # Pad with 0 on the last dimension
            padding = torch.zeros(inputs_embeds.shape[:-1] + (padding_size,),
                                  device=input_device, dtype=inputs_embeds.dtype)
            inputs_embeds = torch.cat([inputs_embeds, padding], dim=-1)

        # Adjust max_power based on sequence length
        sequence_length = inputs_embeds.shape[1]

        # generate_shifted_matrix is defined elsewhere in this file; presumably it
        # encodes the candidate chunk layout — TODO confirm its output shape
        # convention (indexed below as [batch, ?, num_queries]).
        shifted_matrix = generate_shifted_matrix(sequence_length, device=input_device, granularities=granularities)

        # Generate attention mask: additive mask with 0.0 where shifted_matrix == 1.0
        # and -inf everywhere else, broadcast over the head dimension.
        encoder_attention_mask = shifted_matrix.transpose(1, 2)
        encoder_attention_mask = torch.where(encoder_attention_mask == 1.0, 0.0, float('-inf'))[:, None, :, :]

        # Fixed input IDs and position IDs: one all-zero query row per mask column
        # (any caller-supplied input_ids value is deliberately discarded).
        input_ids = torch.tensor([[0] * shifted_matrix.shape[2]], device=input_device)
        position_ids = torch.tensor([[0] * shifted_matrix.shape[2]], device=input_device)

        # Embedding layer processing
        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=None,
        )

        # Set second input stream (overrides any caller-supplied encoder_hidden_states)
        encoder_hidden_states = inputs_embeds

        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        # Encoder processing: queries come from the dummy embeddings, keys/values
        # from the padded inputs_embeds (second stream).
        sequence_output = self.encoder(
            embedding_output,
            hidden_states2=encoder_hidden_states,  # Second input stream
            attention_mask=encoder_attention_mask,  # Use generated mask
            head_mask=head_mask,
        )

        if original_hidden_size < target_hidden_size:

            # Drop the zero-padded tail added above.
            sequence_output = sequence_output[..., :original_hidden_size]
            # Also truncate inputs_embeds back to original size to match dimensions of sequence_output
            inputs_embeds = inputs_embeds[..., :original_hidden_size]

        # shift_matrix: 2-D (squeezed) view of the mask used for loss weighting and
        # returned to the caller.
        shift_matrix = shifted_matrix.transpose(1, 2).squeeze(0)
        # Loss calculation
        loss = None
        if labels is not None:
            emb = sequence_output.view(-1, sequence_output.shape[-1])
            lab = labels.view(-1, labels.shape[-1])
            # target == 1 everywhere: CosineEmbeddingLoss pulls emb and lab together.
            target = torch.ones(emb.size(0), device=emb.device)

            # If weights are requested, use weighted cosine loss
            if loss_weights:
                # Per-query weight = number of tokens covered by that query's mask row.
                loss_weights = shift_matrix.sum(dim=1).to(emb.device)

                # Calculate unweighted cosine loss
                cos_loss_fn = torch.nn.CosineEmbeddingLoss(reduction='none')
                individual_losses = cos_loss_fn(emb, lab, target)

                # Apply weights and calculate weighted average
                weighted_losses = individual_losses * loss_weights
                loss = weighted_losses.sum() / loss_weights.sum()
            else:
                # Use standard cosine loss
                cos_loss = torch.nn.CosineEmbeddingLoss()
                loss = cos_loss(emb, lab, target)

        # Output embedding = original token embeddings concatenated (along the
        # sequence dimension) with the chunk embeddings, L2-normalized.
        embedding = torch.cat([inputs_embeds, sequence_output], dim=1)
        embedding = torch.nn.functional.normalize(embedding, p=2, dim=-1)

        return {
            "loss": loss,
            "embedding": embedding.squeeze(0),
            "shift_matrix": shift_matrix
        }
748
+
749
# Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids
def create_position_ids_from_input_ids(input_ids, padding_idx):
    """Map each non-padding token to its running position, offset by ``padding_idx``.

    Padding tokens keep ``padding_idx`` as their position id; the k-th real token
    in a row gets ``padding_idx + k`` (1-based). Modified from fairseq's
    ``utils.make_positions``.

    Args:
        input_ids: torch.Tensor of token ids, shape (batch, seq_len).
        padding_idx: id of the padding token.

    Returns: torch.Tensor of position ids, same shape as ``input_ids``.
    """
    # The casts below are deliberately balanced to work with both ONNX export and XLA.
    not_padding = input_ids.ne(padding_idx).int()
    running_count = torch.cumsum(not_padding, dim=1).type_as(not_padding) * not_padding
    return running_count.long() + padding_idx
764
+
765
+
766
# Explicit public API of this module.
__all__ = [
    "FreeChunkerModel",
    "FreeChunkerPreTrainedModel",
]