XiaSheng committed on
Commit d347397 · verified · 1 Parent(s): 044e8e5

Initial upload of the FreeChunker model with custom code

README.md ADDED
@@ -0,0 +1,94 @@

# FreeChunker-Jina

FreeChunker is a training-free, cross-granularity chunking framework that dynamically chunks text to improve retrieval performance. This repository contains the **FreeChunker** model initialized with **jinaai/jina-embeddings-v2-small-en** embeddings.

## Features

- **Dynamic Chunking**: Automatically groups sentences into semantically coherent chunks.
- **Optimized for RAG**: Improves retrieval-augmented generation by providing better context segments.
- **Backbone**: Built on top of `jinaai/jina-embeddings-v2-small-en` sentence embeddings.

## Requirements

```bash
pip install torch transformers sentence-transformers numpy
```

## Usage

The `UnifiedEncoder` class (in `encoder.py`) provides an end-to-end interface for encoding and retrieval.

### Using UnifiedEncoder

```python
from encoder import UnifiedEncoder

# Initialize the encoder
# local_model_path="." assumes you are in the directory containing model.safetensors
encoder = UnifiedEncoder(model_name="jina", local_model_path=".")

# Input text
text = """
Your long text goes here. FreeChunker will split this text into sentences,
generate embeddings using Jina, and then group them into semantic chunks.
It handles long documents effectively.
"""

# Build the vector store (chunks and encodes the text)
encoder.build_vector_store(text)

# Query
query = "How does FreeChunker work?"
results = encoder.query(query, top_k=3, aggregation_mode='post')

print("Results:", results)
```
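
With `aggregation_mode='post'` (as above), `query` merges the retrieved, overlapping segments into a single string using the `TextAggregator`. If you prefer the raw ranked segments, `aggregation_mode='none'` returns `(text, score)` pairs instead; a minimal sketch:

```python
# Retrieve the top-3 segments without aggregation
raw_results = encoder.query(query, top_k=3, aggregation_mode='none')

for segment, score in raw_results:
    print(f"{score:.3f}  {segment[:80]}")
```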

### Manual Pipeline

If you prefer to use the components separately:

1. **Split and Encode**: Use `Sentenceizer` (wrapping `jinaai/jina-embeddings-v2-small-en`) to get sentence embeddings.
2. **FreeChunker**: Pass the embeddings to `FreeChunkerModel`.
3. **Process**: Use the output `shift_matrix` to group sentences (see the sketch after the code block below).

```python
from sentenizer import Sentenceizer
from modeling_freechunker import FreeChunkerModel
import torch

# 1. Set up the Sentenceizer with the backbone
sentenceizer = Sentenceizer(model_name="jinaai/jina-embeddings-v2-small-en")

# 2. Load the FreeChunker model
model = FreeChunkerModel.from_pretrained(".", trust_remote_code=True)
model.eval()

# 3. Process the text
text = "Your text..."
sentences, embeddings = sentenceizer.split_and_encode(text)

# 4. Forward pass through FreeChunker
inputs_embeds = torch.tensor(embeddings).unsqueeze(0)  # Batch size 1
with torch.no_grad():
    outputs = model(inputs_embeds=inputs_embeds)

# outputs['embedding'] contains the refined embeddings
# outputs['shift_matrix'] contains the chunking information
```
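
`shift_matrix` is a 0/1 mask of shape `[num_chunks, seq_len]`, where `seq_len` is the number of input sentence embeddings: row *i* marks which sentences belong to chunk *i*. A minimal sketch of turning it into chunk texts, mirroring what `UnifiedEncoder` does internally (it assumes `sentences` and `outputs` come from the snippet above):

```python
shift_matrix = outputs['shift_matrix']  # [num_chunks, seq_len]

chunks = []
for row in shift_matrix:
    # Indices of the sentences assigned to this chunk
    idx = (row == 1).nonzero(as_tuple=True)[0].tolist()
    idx = [i for i in idx if i < len(sentences)]
    if idx:
        chunks.append(" ".join(sentences[i] for i in idx))

for i, chunk in enumerate(chunks):
    print(f"Chunk {i}: {chunk}")
```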

## Files

- `model.safetensors`: The FreeChunker model weights.
- `encoder.py`: High-level interface (`UnifiedEncoder`) for end-to-end usage.
- `sentenizer.py`: Helper for text splitting and backbone embedding.
- `aggregator.py`: Helper for aggregating retrieved results.
- `configuration_freechunker.py` & `modeling_freechunker.py`: Model definition.
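
Because `config.json` registers the custom classes via `auto_map`, the model can also be loaded through the `transformers` Auto classes; a sketch (the path below is a placeholder for this repository, and `trust_remote_code=True` is required for custom code):

```python
from transformers import AutoConfig, AutoModel

repo = "."  # path to this repository (or its Hub id)
config = AutoConfig.from_pretrained(repo, trust_remote_code=True)
model = AutoModel.from_pretrained(repo, trust_remote_code=True)
```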

## Citation

If you use this model in your research, please cite:

```
Zhang W, Jiang Y H, Wu Y. FreeChunker: A Cross-Granularity Chunking Framework[J]. arXiv preprint arXiv:2510.20356, 2025.
```
aggregator.py ADDED
@@ -0,0 +1,205 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Text Aggregator - Precise text segment aggregation based on sentence position markers
5
+
6
+ Main functions:
7
+ 1. Detect overlaps between text segments based on 【Begin-x】【End-y】 markers
8
+ 2. Automatically merge and reconstruct based on original order when overlapping
9
+ 3. Retain the highest scoring segments
10
+ """
11
+
12
+ import re
13
+ from typing import List, Tuple
14
+
15
+ class TextAggregator:
16
+ """
17
+ Text aggregator for merging retrieved text segments
18
+ Implements splitting, deduplication, sorting, and reconstruction of text segments based on 【Begin-x】【End-x】 markers
19
+ """
20
+
21
+ def __init__(self):
22
+ """
23
+ Initialize text aggregator
24
+ """
25
+ pass
26
+
27
+ def _extract_segments_from_text(self, text: str) -> List[Tuple[int, str]]:
28
+ """
29
+ Extract all 【Begin-x】...【End-x】 segments from text
30
+
31
+ Args:
32
+ text: Text containing position markers
33
+
34
+ Returns:
35
+ List[Tuple[int, str]]: List of (begin_index, segment_text)
36
+ """
37
+ segments = []
38
+ # Match 【Begin-x】...【End-x】 pattern
39
+ pattern = r'【Begin-(\d+)】(.*?)【End-\1】'
40
+ matches = re.findall(pattern, text, re.DOTALL)
41
+
42
+ for match in matches:
43
+ begin_idx = int(match[0])
44
+ segment_content = match[1]
45
+ full_segment = f"【Begin-{begin_idx}】{segment_content}【End-{begin_idx}】"
46
+ segments.append((begin_idx, full_segment))
47
+
48
+ return segments
49
+
50
+ def _remove_boundary_markers(self, text: str) -> str:
51
+ """
52
+ Remove all boundary markers from text, keeping only content
53
+
54
+ Args:
55
+ text: Text containing boundary markers
56
+
57
+ Returns:
58
+ str: Text with boundary markers removed
59
+ """
60
+ # Remove 【Begin-x】 and 【End-x】 markers
61
+ clean_text = re.sub(r'【Begin-\d+】|【End-\d+】', '', text)
62
+ return clean_text.strip()
63
+
64
+
65
+
66
+ def aggregate_segments(self, segments: List[str]) -> str:
67
+ """
68
+ Aggregate text segments: split, deduplicate, sort, reconstruct
69
+
70
+ Args:
71
+ segments: List of text segments
72
+
73
+ Returns:
74
+ str: Aggregated text string
75
+ """
76
+ if not segments:
77
+ return ""
78
+
79
+ # Step 1: Extract segments from all input texts
80
+ all_segments = {} # {begin_index: segment_text}
81
+
82
+ for text in segments:
83
+ extracted = self._extract_segments_from_text(text)
84
+ for begin_idx, segment in extracted:
85
+ # Deduplication: Keep only one segment for the same begin_index
86
+ if begin_idx not in all_segments:
87
+ all_segments[begin_idx] = segment
88
+
89
+ # Step 2: Sort by begin_index
90
+ sorted_segments = sorted(all_segments.items())
91
+
92
+ # Step 3: Reconstruct text
93
+ if not sorted_segments:
94
+ return ""
95
+
96
+ # Build continuous text
97
+ result_text = ""
98
+ prev_end = -1
99
+
100
+ for begin_idx, segment in sorted_segments:
101
+ # If not continuous, add ellipsis
102
+ if prev_end != -1 and begin_idx != prev_end + 1:
103
+ result_text += "..."
104
+
105
+ # Add content of current segment (remove boundary markers)
106
+ content = self._remove_boundary_markers(segment)
107
+ result_text += content
108
+
109
+ prev_end = begin_idx
110
+
111
+ return result_text
112
+
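+ # Example of aggregate_segments (illustrative): overlapping retrieved segments are
+ # merged once each, in ascending order of their Begin index, and "..." is inserted
+ # wherever the indices are not contiguous, e.g.
+ #   aggregate_segments([
+ #       "【Begin-1】A【End-1】【Begin-2】B【End-2】",
+ #       "【Begin-2】B【End-2】【Begin-3】C【End-3】",
+ #       "【Begin-5】E【End-5】",
+ #   ])  ->  "ABC...E"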
113
+ def aggregate_segments_complete(self, segments: List[str]) -> str:
114
+ """
115
+ Completely aggregate all text segments
116
+
117
+ Args:
118
+ segments: List of text segments
119
+
120
+ Returns:
121
+ str: Aggregated text string
122
+ """
123
+ return self.aggregate_segments(segments)
124
+
125
+
126
+
127
+
128
+ def demo():
129
+ """Demo function - Show text splitting, deduplication, sorting, and reconstruction based on position markers"""
130
+ print("=== Text Aggregator Demo (Completely Rewritten Version) ===\n")
131
+
132
+ # Create aggregator
133
+ aggregator = TextAggregator()
134
+
135
+ # Test data - Format according to user example
136
+ test_segments = [
137
+ "【Begin-1】sdfsdf【End-1】【Begin-2】sdfsdf【End-2】",
138
+ "【Begin-2】sdfsdf【End-2】【Begin-3】sdfsdf【End-3】",
139
+ "【Begin-5】sdfsdf【End-5】【Begin-6】sdfsdf【End-6】"
140
+ ]
141
+
142
+ print("Original input segments:")
143
+ for i, text in enumerate(test_segments, 1):
144
+ print(f"{i}. {text}")
145
+
146
+ print("\n=== Step 1: Extract segments from each text ===")
147
+ all_extracted = {}
148
+ for i, text in enumerate(test_segments, 1):
149
+ extracted = aggregator._extract_segments_from_text(text)
150
+ print(f"Segments extracted from text {i}: {extracted}")
151
+ for begin_idx, segment in extracted:
152
+ if begin_idx not in all_extracted:
153
+ all_extracted[begin_idx] = segment
154
+ print(f" Add segment: Begin-{begin_idx}")
155
+ else:
156
+ print(f" Skip duplicate segment: Begin-{begin_idx}")
157
+
158
+ print(f"\nAll segments after deduplication: {list(all_extracted.keys())}")
159
+
160
+ print("\n=== Step 2: Sort by Begin marker ===")
161
+ sorted_segments = sorted(all_extracted.items())
162
+ print("Sorted segments:")
163
+ for begin_idx, segment in sorted_segments:
164
+ print(f" Begin-{begin_idx}: {segment}")
165
+
166
+ print("\n=== Step 3: Reconstruct text (remove boundary markers, add ellipsis) ===")
167
+ result = aggregator.aggregate_segments(test_segments)
168
+ print(f"Final result: {result}")
169
+
170
+ print("\n=== Full Test Cases ===")
171
+
172
+ # More complex test cases
173
+ complex_segments = [
174
+ "【Begin-1】First sentence【End-1】【Begin-2】Second sentence【End-2】【Begin-3】Third sentence【End-3】",
175
+ "【Begin-2】Second sentence【End-2】【Begin-3】Third sentence【End-3】【Begin-4】Fourth sentence【End-4】",
176
+ "【Begin-6】Sixth sentence【End-6】【Begin-7】Seventh sentence【End-7】",
177
+ "【Begin-4】Fourth sentence【End-4】【Begin-5】Fifth sentence【End-5】"
178
+ ]
179
+
180
+ print("\nComplex test input:")
181
+ for i, text in enumerate(complex_segments, 1):
182
+ print(f"{i}. {text}")
183
+
184
+ complex_result = aggregator.aggregate_segments(complex_segments)
185
+ print(f"\nComplex test result: {complex_result}")
186
+
187
+ print("\n=== Boundary Case Tests ===")
188
+
189
+ # Test empty input
190
+ empty_result = aggregator.aggregate_segments([])
191
+ print(f"Empty input result: {empty_result}")
192
+
193
+ # Test single segment
194
+ single_result = aggregator.aggregate_segments(["【Begin-1】Single segment【End-1】"])
195
+ print(f"Single segment result: {single_result}")
196
+
197
+ # Test text without markers (should return empty)
198
+ no_marker_result = aggregator.aggregate_segments(["Normal text without markers"])
199
+ print(f"Text without markers result: {no_marker_result}")
200
+
201
+ print("\n=== Demo Completed ===")
202
+
203
+
204
+ if __name__ == "__main__":
205
+ demo()
config.json ADDED
@@ -0,0 +1,32 @@
1
+ {
2
+ "architectures": [
3
+ "FreeChunkerModel"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "bos_token_id": 0,
7
+ "classifier_dropout": null,
8
+ "dtype": "float32",
9
+ "eos_token_id": 2,
10
+ "hidden_act": "gelu",
11
+ "hidden_dropout_prob": 0.1,
12
+ "hidden_size": 1024,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 4096,
15
+ "layer_norm_eps": 1e-05,
16
+ "max_position_embeddings": 8194,
17
+ "model_type": "xlm-roberta",
18
+ "num_attention_heads": 16,
19
+ "num_hidden_layers": 24,
20
+ "output_past": true,
21
+ "pad_token_id": 1,
22
+ "position_embedding_type": "absolute",
23
+ "transformers_version": "4.56.1",
24
+ "type_vocab_size": 1,
25
+ "use_cache": true,
26
+ "vocab_size": 2,
27
+ "max_power": 4,
28
+ "auto_map": {
29
+ "AutoConfig": "configuration_freechunker.FreeChunkerConfig",
30
+ "AutoModel": "modeling_freechunker.FreeChunkerModel"
31
+ }
32
+ }
configuration_freechunker.py ADDED
@@ -0,0 +1,157 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """FreeChunker configuration: Modified from XLM-RoBERTa configuration"""
17
+
18
+ from collections import OrderedDict
19
+ from typing import Mapping
20
+
21
+ from transformers.configuration_utils import PretrainedConfig
22
+ from transformers.onnx import OnnxConfig
23
+ from transformers.utils import logging
24
+
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+ class FreeChunkerConfig(PretrainedConfig):
30
+ r"""
31
+ This is the configuration class to store the configuration of a [`FreeChunkerModel`] or a [`TFFreeChunkerModel`]. It
32
+ is used to instantiate a XLM-RoBERTa model according to the specified arguments, defining the model architecture.
33
+ Instantiating a configuration with the defaults will yield a configuration similar to that of the XLM-RoBERTa
34
+ [FacebookAI/xlm-roberta-base](https://huggingface.co/FacebookAI/xlm-roberta-base) architecture.
35
+
36
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
37
+ documentation from [`PretrainedConfig`] for more information.
38
+
39
+
40
+ Args:
41
+ vocab_size (`int`, *optional*, defaults to 30522):
42
+ Vocabulary size of the XLM-RoBERTa model. Defines the number of different tokens that can be represented by
43
+ the `input_ids` passed when calling [`FreeChunkerModel`] or [`TFFreeChunkerModel`].
44
+ hidden_size (`int`, *optional*, defaults to 768):
45
+ Dimensionality of the encoder layers and the pooler layer.
46
+ num_hidden_layers (`int`, *optional*, defaults to 12):
47
+ Number of hidden layers in the Transformer encoder.
48
+ num_attention_heads (`int`, *optional*, defaults to 12):
49
+ Number of attention heads for each attention layer in the Transformer encoder.
50
+ intermediate_size (`int`, *optional*, defaults to 3072):
51
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
52
+ hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
53
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
54
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
55
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
56
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
57
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
58
+ The dropout ratio for the attention probabilities.
59
+ max_position_embeddings (`int`, *optional*, defaults to 512):
60
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
61
+ just in case (e.g., 512 or 1024 or 2048).
62
+ type_vocab_size (`int`, *optional*, defaults to 2):
63
+ The vocabulary size of the `token_type_ids` passed when calling [`FreeChunkerModel`] or
64
+ [`TFFreeChunkerModel`].
65
+ initializer_range (`float`, *optional*, defaults to 0.02):
66
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
67
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
68
+ The epsilon used by the layer normalization layers.
69
+ position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
70
+ Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
71
+ positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
72
+ [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
73
+ For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
74
+ with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
75
+ is_decoder (`bool`, *optional*, defaults to `False`):
76
+ Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
77
+ use_cache (`bool`, *optional*, defaults to `True`):
78
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
79
+ relevant if `config.is_decoder=True`.
80
+ classifier_dropout (`float`, *optional*):
81
+ The dropout ratio for the classification head.
82
+
83
+ Examples:
84
+
85
+ ```python
86
+ >>> from transformers import FreeChunekrConfig, FreeChunekrModel
87
+
88
+ >>> # Initializing a XLM-RoBERTa FacebookAI/xlm-roberta-base style configuration
89
+ >>> configuration = FreeChunekrConfig()
90
+
91
+ >>> # Initializing a model (with random weights) from the FacebookAI/xlm-roberta-base style configuration
92
+ >>> model = FreeChunekrModel(configuration)
93
+
94
+ >>> # Accessing the model configuration
95
+ >>> configuration = model.config
96
+ ```"""
97
+
98
+ model_type = "xlm-roberta"
99
+
100
+ def __init__(
101
+ self,
102
+ vocab_size=30522,
103
+ hidden_size=768,
104
+ num_hidden_layers=12,
105
+ num_attention_heads=12,
106
+ intermediate_size=3072,
107
+ hidden_act="gelu",
108
+ hidden_dropout_prob=0.1,
109
+ attention_probs_dropout_prob=0.1,
110
+ max_position_embeddings=512,
111
+ type_vocab_size=2,
112
+ initializer_range=0.02,
113
+ layer_norm_eps=1e-12,
114
+ pad_token_id=1,
115
+ bos_token_id=0,
116
+ eos_token_id=2,
117
+ position_embedding_type="absolute",
118
+ use_cache=True,
119
+ classifier_dropout=None,
120
+ **kwargs,
121
+ ):
122
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
123
+
124
+ self.vocab_size = vocab_size
125
+ self.hidden_size = hidden_size
126
+ self.num_hidden_layers = num_hidden_layers
127
+ self.num_attention_heads = num_attention_heads
128
+ self.hidden_act = hidden_act
129
+ self.intermediate_size = intermediate_size
130
+ self.hidden_dropout_prob = hidden_dropout_prob
131
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
132
+ self.max_position_embeddings = max_position_embeddings
133
+ self.type_vocab_size = type_vocab_size
134
+ self.initializer_range = initializer_range
135
+ self.layer_norm_eps = layer_norm_eps
136
+ self.position_embedding_type = position_embedding_type
137
+ self.use_cache = use_cache
138
+ self.classifier_dropout = classifier_dropout
139
+
140
+
141
+ # Copied from transformers.models.roberta.configuration_roberta.RobertaOnnxConfig with Roberta->FreeChunker
+ class FreeChunkerOnnxConfig(OnnxConfig):
143
+ @property
144
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
145
+ if self.task == "multiple-choice":
146
+ dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
147
+ else:
148
+ dynamic_axis = {0: "batch", 1: "sequence"}
149
+ return OrderedDict(
150
+ [
151
+ ("input_ids", dynamic_axis),
152
+ ("attention_mask", dynamic_axis),
153
+ ]
154
+ )
155
+
156
+
157
+ __all__ = ["FreeChunkerConfig", "FreeChunkerOnnxConfig"]
encoder.py ADDED
@@ -0,0 +1,257 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ UnifiedEncoder - Unified text encoder
4
+ Integrates sentence splitting and multiple encoding models into a unified interface
5
+ """
6
+
7
+ import torch
8
+ import numpy as np
9
+ import pickle
10
+ import os
11
+ from typing import List, Tuple, Union
12
+ try:
+     from .sentenizer import Sentenceizer
+     from .modeling_freechunker import FreeChunkerModel
+     from .aggregator import TextAggregator
+ except ImportError:  # running directly from the repo directory (see README usage)
+     from sentenizer import Sentenceizer
+     from modeling_freechunker import FreeChunkerModel
+     from aggregator import TextAggregator
15
+
16
+ class UnifiedEncoder:
17
+ """
18
+ Unified text encoder, supporting text sentence splitting and encoding for multiple models
19
+ """
20
+
21
+ def __init__(self, model_name: str, local_model_path: str = None):
22
+ """
23
+ Initialize unified text encoder
24
+
25
+ Args:
26
+ model_name (str): Model name (e.g. 'bge-m3', 'jina', 'nomic')
27
+ local_model_path (str, optional): Local model path for loading FreeChunker weights.
28
+ If None, tries to load from current directory or Hugging Face.
29
+ """
30
+ self.model_name = model_name
31
+ self.device = torch.device('cuda' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu'))
32
+
33
+ # Initialize text aggregator
34
+ self.aggregator = TextAggregator()
35
+
36
+ print(f"Initializing unified text encoder, model: {model_name}")
37
+ print(f"Using local model path: {local_model_path}")
38
+ print(f"Using device: {self.device}")
39
+
40
+ # If local_model_path is not provided, assume current directory or let from_pretrained handle it
41
+ if local_model_path is None:
42
+ local_model_path = "."
43
+
44
+ try:
45
+ self.model = FreeChunkerModel.from_pretrained(local_model_path)
46
+ except Exception as e:
47
+ print(f"Failed to load model from {local_model_path}: {e}")
48
+ print("Trying to load as a fresh model or from HF hub if applicable...")
49
+ # Fallback or re-raise
50
+ raise e
51
+
52
+ self.model.to(self.device)
53
+ self.model.eval()
54
+
55
+ # Select model and preprocessor based on model name
56
+ # Predefined model mapping: name -> (local_path, HF_model_ID)
57
+ # Note: Local paths are environment specific, so we primarily rely on HF IDs or passed arguments
58
+ model_configs = {
59
+ 'bge-m3': ('/share/home/ecnuzwx/UnifiedRAG/cache/models--BAAI--bge-m3', 'BAAI/bge-m3'),
60
+ 'nomic-embed-text-v1.5': ('/share/home/ecnuzwx/UnifiedRAG/cache/models--nomic-ai--nomic-embed-text-v1.5', 'nomic-ai/nomic-embed-text-v1.5'),
61
+ 'jina': ('/share/home/ecnuzwx/UnifiedRAG/cache/models--jinaai--jina-embeddings-v2-small-en', 'jinaai/jina-embeddings-v2-small-en')
62
+ }
63
+
64
+ if model_name in model_configs:
65
+ local_path, hf_id = model_configs[model_name]
66
+ # Prioritize local path if it exists, otherwise use HF ID
67
+ if os.path.exists(local_path):
68
+ target_model = local_path
69
+ else:
70
+ target_model = hf_id
71
+
72
+ self.sentenceizer = Sentenceizer(model_name=target_model)
73
+ else:
74
+ # Try using model_name directly as path or ID
75
+ print(f"Unknown predefined model name: {model_name}, trying to load directly...")
76
+ self.sentenceizer = Sentenceizer(model_name=model_name)
77
+
78
+ print("Unified text encoder initialized!")
79
+
80
+ def encode(self, text: str, show_progress: bool = True) -> Tuple[List[str], np.ndarray, List[List[str]]]:
81
+ """
82
+ Split text and encode, return results grouped by shift_matrix
83
+
84
+ Args:
85
+ text (str): Input text
86
+ show_progress (bool): Whether to show progress
87
+
88
+ Returns:
89
+ Tuple[List[str], np.ndarray, List[List[str]]]: (Original sentence list, encoded vector array, grouped sentence list by shift_matrix)
90
+ """
91
+ with torch.no_grad():
92
+ sentences, input_embeddings = self.sentenceizer.split_and_encode(text, show_progress=show_progress)
93
+
94
+ if len(sentences) == 0:
95
+ return sentences, np.array([]), []
96
+ if isinstance(input_embeddings, np.ndarray):
97
+ input_embeddings = torch.from_numpy(input_embeddings)
98
+ input_embeddings = input_embeddings.to(self.device)
99
+ inputs_embeds = input_embeddings.unsqueeze(0)
100
+ outputs = self.model(inputs_embeds=inputs_embeds)
101
+ final_embeddings = outputs['embedding']
102
+ shift_matrix = outputs['shift_matrix']
103
+
104
+ # Group sentences using shift_matrix
105
+ sentences = [f"【Begin-{num}】" + sentence + f"【End-{num}】" for num, sentence in enumerate(sentences)]
106
+ grouped_sentences = self._group_sentences_by_shift_matrix(sentences, shift_matrix)
107
+ result_embeddings = final_embeddings.cpu().numpy()
108
+
109
+ return sentences, result_embeddings, grouped_sentences
110
+
111
+ def _group_sentences_by_shift_matrix(self, sentences: List[str], shift_matrix: torch.Tensor) -> List[List[str]]:
112
+ """
113
+ Group sentences according to shift_matrix (Optimized version)
114
+
115
+ Args:
116
+ sentences (List[str]): Original sentence list
117
+ shift_matrix (torch.Tensor): Mask matrix with shape [num_chunks, seq_len]
118
+
119
+ Returns:
120
+ List[List[str]]: List of sentences grouped by shift_matrix
121
+ """
122
+
123
+ grouped_sentences = []
124
+ num_chunks, seq_len = shift_matrix.shape
125
+
126
+ for chunk_idx in range(num_chunks):
127
+ chunk_mask = shift_matrix[chunk_idx] # [seq_len]
128
+
129
+ # Use vectorized operation to get all indices that are 1
130
+ valid_indices = (chunk_mask == 1).nonzero(as_tuple=True)[0].cpu().numpy()
131
+
132
+ # Select only indices within the sentence list range
133
+ valid_indices = valid_indices[valid_indices < len(sentences)]
134
+
135
+ if len(valid_indices) > 0:
136
+ # Get sentences directly by index
137
+ chunk_sentences = [sentences[idx] for idx in valid_indices]
138
+ grouped_sentences.append(chunk_sentences)
139
+
140
+ return grouped_sentences
141
+
142
+ def build_vector_store(self, text: str, show_progress: bool = True):
143
+ """
144
+ Build vector store based on long text
145
+
146
+ Args:
147
+ text (str): Long text
148
+ show_progress (bool): Whether to show progress
149
+ """
150
+
151
+ sentences, embeddings, grouped_sentences = self.encode(text, show_progress)
152
+
153
+ # grouped_texts = [" ".join(group) if isinstance(group, list) else str(group) for group in grouped_sentences]
154
+
155
+ grouped_texts = sentences + [" ".join(group) if isinstance(group, list) else str(group) for group in grouped_sentences]
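+ # Note: grouped_texts holds both granularities (every marker-wrapped sentence plus each chunk-level text), so a query can match either a single sentence or a whole chunk.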
156
+
157
+ self.vector_store = {
158
+ 'sentences': sentences, # Keep original sentences for debugging
159
+ 'embeddings': embeddings, # embeddings correspond to grouped_sentences
160
+ 'grouped_sentences': grouped_sentences, # Original grouping structure
161
+ 'grouped_texts': grouped_texts # Text for retrieval
162
+ }
163
+
164
+ if show_progress:
165
+ print(f"Vector store built: {len(sentences)} original sentences, {len(grouped_sentences)} groups, {len(embeddings)} embedding vectors")
166
+ print(f"Vector store verification: embeddings.shape={embeddings.shape}, grouped_texts count={len(grouped_texts)}\n")
167
+
168
+ def query(self, query: str, top_k: int = 5, aggregation_mode: str = 'post', tokenizer=None) -> Union[List[Tuple[str, float]], str]:
169
+ """
170
+ Query vector store
171
+
172
+ Args:
173
+ query (str): Query text
174
+ top_k (int): Return top k most similar results
175
+ aggregation_mode (str): Aggregation mode
176
+ - 'none': No aggregation, return top_k results directly [(text, score), ...]
177
+ - 'post': Post-aggregation mode, return aggregated text string
178
+
179
+ Returns:
180
+ Union[List[Tuple[str, float]], str]:
181
+ - If aggregation_mode='none', return [(sentence, similarity_score), ...]
182
+ - If aggregation_mode='post', return aggregated string
183
+ """
184
+ if not hasattr(self, 'vector_store'):
185
+ raise ValueError("Vector store not built, please call build_vector_store method first")
186
+
187
+ # Encode query text
188
+ query_embeddings = self.sentenceizer.encode([query])
189
+ query_embedding = query_embeddings[0]
190
+
191
+ # Calculate cosine similarity
192
+ similarities = np.dot(self.vector_store['embeddings'], query_embedding)
193
+
194
+ # Sort (descending)
195
+ sorted_indices = np.argsort(similarities)[::-1]
196
+
197
+ if aggregation_mode == 'none':
198
+ return self._get_direct_results(sorted_indices, similarities, top_k)
199
+ elif aggregation_mode == 'post':
200
+ return self._post_aggregation(sorted_indices, similarities, top_k, tokenizer=tokenizer)
201
+ else:
202
+ print(f"Warning: Unknown aggregation_mode '{aggregation_mode}', falling back to 'none'")
203
+ return self._get_direct_results(sorted_indices, similarities, top_k)
204
+
205
+ def _get_direct_results(self, sorted_indices: np.ndarray, similarities: np.ndarray, top_k: int) -> List[Tuple[str, float]]:
206
+
207
+ available_count = len(self.vector_store['grouped_texts'])
208
+ actual_top_k = min(top_k, available_count)
209
+ top_indices = sorted_indices[:actual_top_k]
210
+
211
+ results = []
212
+ for idx in top_indices:
213
+ if idx < len(self.vector_store['grouped_texts']):
214
+ grouped_text = self.vector_store['grouped_texts'][idx]
215
+ score = similarities[idx]
216
+ results.append((grouped_text, float(score)))
217
+
218
+ return results
219
+
220
+ def _post_aggregation(self, sorted_indices: np.ndarray, similarities: np.ndarray, top_k: int, tokenizer=None) -> str:
221
+
222
+ # Get top_k results first
223
+ direct_results = self._get_direct_results(sorted_indices, similarities, top_k)
224
+
225
+ # Extract text parts for aggregation
226
+ texts = [text for text, score in direct_results]
227
+
228
+ aggregated_texts = self.aggregator.aggregate_segments(texts)
229
+
230
+
231
+ return aggregated_texts
232
+
233
+
234
+ def load_vector_store(self, file_path: str):
235
+ """
236
+ Load vector store from file
237
+
238
+ Args:
239
+ file_path (str): Vector store file path
240
+ """
241
+ if not os.path.exists(file_path):
242
+ raise FileNotFoundError(f"Vector store file not found: {file_path}")
243
+
244
+ with open(file_path, 'rb') as f:
245
+ self.vector_store = pickle.load(f)
246
+
247
+ print(f"Vector store loaded from {file_path}")
248
+ print(f"Vector store info: {len(self.vector_store['grouped_texts'])} groups, embedding dimension: {self.vector_store['embeddings'].shape}")
249
+
250
+ def has_vector_store(self) -> bool:
251
+ """
252
+ Check if vector store is built or loaded
253
+
254
+ Returns:
255
+ bool: Whether a vector store is available
256
+ """
257
+ return hasattr(self, 'vector_store') and self.vector_store is not None
final_loss_curve.png ADDED
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aca47fe33b4f8d4b507ac46c60817fc9287a1b81d63c0ad06559196d64c9a30d
3
+ size 1247063776
modeling_freechunker.py ADDED
@@ -0,0 +1,768 @@
1
+ # coding=utf-8
2
+ # Copyright 2019 Facebook AI Research and the HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """FreeChunker model: Modified from PyTorch XLM-RoBERTa model."""
17
+ from .utils import generate_shifted_matrix
18
+ import math
19
+ from typing import Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from packaging import version
24
+ from torch import nn
25
+ from transformers.activations import ACT2FN
26
+ from transformers.modeling_outputs import (
27
+ BaseModelOutputWithPoolingAndCrossAttentions
28
+ )
29
+ from transformers.modeling_utils import PreTrainedModel
30
+ from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
31
+ from transformers.utils import (
32
+ add_code_sample_docstrings,
33
+ add_start_docstrings,
34
+ add_start_docstrings_to_model_forward,
35
+ get_torch_version,
36
+ logging
37
+ )
38
+ from .configuration_freechunker import FreeChunkerConfig
39
+
40
+
41
+ logger = logging.get_logger(__name__)
42
+
43
+ _CHECKPOINT_FOR_DOC = "FacebookAI/xlm-roberta-base"
44
+ _CONFIG_FOR_DOC = "FreeChunkerConfig"
45
+
46
+
47
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings with Roberta->FreeChunker
48
+ class FreeChunkerEmbeddings(nn.Module):
49
+ """
50
+ Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
51
+ """
52
+
53
+ # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__
54
+ def __init__(self, config):
55
+ super().__init__()
56
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
57
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
58
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
59
+
60
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
61
+ # any TensorFlow checkpoint file
62
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
63
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
64
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
65
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
66
+ self.register_buffer(
67
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
68
+ )
69
+ self.register_buffer(
70
+ "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
71
+ )
72
+
73
+ # End copy
74
+ self.padding_idx = config.pad_token_id
75
+ self.position_embeddings = nn.Embedding(
76
+ config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
77
+ )
78
+
79
+ def forward(
80
+ self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None
81
+ ):
82
+ if position_ids is None:
83
+ if input_ids is not None:
84
+ # Create the position ids from the input token ids. Any padded tokens remain padded.
85
+ position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx)
86
+ else:
87
+ position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
88
+
89
+ if input_ids is not None:
90
+ input_shape = input_ids.size()
91
+ else:
92
+ input_shape = inputs_embeds.size()[:-1]
93
+
94
+ seq_length = input_shape[1]
95
+
96
+ if position_ids is None:
97
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=self.position_ids.device)
98
+ position_ids = position_ids.unsqueeze(0).expand(input_shape)
99
+
100
+ if token_type_ids is None:
101
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
102
+
103
+ if inputs_embeds is None:
104
+ inputs_embeds = self.word_embeddings(input_ids)
105
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
106
+
107
+ embeddings = inputs_embeds + token_type_embeddings
108
+ if self.position_embedding_type == "absolute":
109
+ position_embeddings = self.position_embeddings(position_ids)
110
+ embeddings += position_embeddings
111
+ embeddings = self.LayerNorm(embeddings)
112
+ embeddings = self.dropout(embeddings)
113
+ return embeddings
114
+
115
+ def create_position_ids_from_inputs_embeds(self, inputs_embeds):
116
+ """
117
+ We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
118
+
119
+ Args:
120
+ inputs_embeds: torch.Tensor
121
+
122
+ Returns: torch.Tensor
123
+ """
124
+ input_shape = inputs_embeds.size()[:-1]
125
+ sequence_length = input_shape[1]
126
+
127
+ position_ids = torch.arange(
128
+ self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
129
+ )
130
+ return position_ids.unsqueeze(0).expand(input_shape)
131
+
132
+
133
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->FreeChunker
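+ # Note: unlike the upstream Roberta self-attention, this layer takes two input streams:
+ # queries are computed from `hidden_states`, while keys and values are computed from
+ # `hidden_states2` (see `forward` below).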
134
+ class FreeChunkerSelfAttention(nn.Module):
135
+ def __init__(self, config, position_embedding_type=None):
136
+ super().__init__()
137
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
138
+ raise ValueError(
139
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
140
+ f"heads ({config.num_attention_heads})"
141
+ )
142
+
143
+ self.num_attention_heads = config.num_attention_heads
144
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
145
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
146
+
147
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
148
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
149
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
150
+
151
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
152
+ self.position_embedding_type = position_embedding_type or getattr(
153
+ config, "position_embedding_type", "absolute"
154
+ )
155
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
156
+ self.max_position_embeddings = config.max_position_embeddings
157
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
158
+
159
+ self.is_decoder = config.is_decoder
160
+
161
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
162
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
163
+ x = x.view(new_x_shape)
164
+ return x.permute(0, 2, 1, 3)
165
+
166
+ def forward(
167
+ self,
168
+ hidden_states: torch.Tensor,
169
+ hidden_states2: torch.Tensor, # Second input stream, required parameter
170
+ attention_mask: Optional[torch.FloatTensor] = None,
171
+ head_mask: Optional[torch.FloatTensor] = None,
172
+ output_attentions: Optional[bool] = False,
173
+ ) -> Tuple[torch.Tensor]:
174
+ # Query comes from hidden_states
175
+ mixed_query_layer = self.query(hidden_states)
176
+ query_layer = self.transpose_for_scores(mixed_query_layer)
177
+
178
+ # Key and Value come from hidden_states2
179
+ key_layer = self.transpose_for_scores(self.key(hidden_states2))
180
+ value_layer = self.transpose_for_scores(self.value(hidden_states2))
181
+
182
+ # Calculate attention scores
183
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
184
+
185
+ # Modified positional encoding handling
186
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
187
+ query_length, key_length = query_layer.shape[2], key_layer.shape[2]
188
+
189
+ # hidden_states positions are all the first position (0, 0, 0, ...)
190
+ position_ids_l = torch.zeros(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
191
+ # hidden_states2 uses normal incremental position sequence (0, 1, 2, 3, ...)
192
+ position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
193
+ distance = position_ids_l - position_ids_r
194
+
195
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
196
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
197
+
198
+ if self.position_embedding_type == "relative_key":
199
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
200
+ attention_scores = attention_scores + relative_position_scores
201
+ elif self.position_embedding_type == "relative_key_query":
202
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
203
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
204
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
205
+
206
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
207
+
208
+ if attention_mask is not None:
209
+ attention_scores = attention_scores + attention_mask
210
+
211
+ # Normalize to probabilities
212
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
213
+ attention_probs = self.dropout(attention_probs)
214
+
215
+ # Apply head mask
216
+ if head_mask is not None:
217
+ attention_probs = attention_probs * head_mask
218
+
219
+ # Calculate context
220
+ context_layer = torch.matmul(attention_probs, value_layer)
221
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
222
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
223
+ context_layer = context_layer.view(new_context_layer_shape)
224
+
225
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
226
+ return outputs
227
+
228
+
229
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaSdpaSelfAttention with Roberta->FreeChunker
230
+ class FreeChunkerSdpaSelfAttention(FreeChunkerSelfAttention):
231
+ def __init__(self, config, position_embedding_type=None):
232
+ super().__init__(config, position_embedding_type=position_embedding_type)
233
+ self.dropout_prob = config.attention_probs_dropout_prob
234
+ self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")
235
+
236
+ def forward(
237
+ self,
238
+ hidden_states: torch.Tensor,
239
+ hidden_states2: torch.Tensor, # Second input stream, required parameter
240
+ attention_mask: Optional[torch.Tensor] = None,
241
+ head_mask: Optional[torch.FloatTensor] = None,
242
+ output_attentions: Optional[bool] = False,
243
+ ) -> Tuple[torch.Tensor]:
244
+ # If relative positional encoding, output attentions, or head mask are present, fallback to parent implementation
245
+ if (self.position_embedding_type != "absolute" or
246
+ output_attentions or
247
+ head_mask is not None):
248
+ return super().forward(
249
+ hidden_states,
250
+ hidden_states2,
251
+ attention_mask,
252
+ head_mask,
253
+ output_attentions,
254
+ )
255
+
256
+ # Use optimized implementation of SDPA
257
+ bsz, tgt_len, _ = hidden_states.size()
258
+
259
+ query_layer = self.transpose_for_scores(self.query(hidden_states))
260
+ key_layer = self.transpose_for_scores(self.key(hidden_states2))
261
+ value_layer = self.transpose_for_scores(self.value(hidden_states2))
262
+
263
+ # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
264
+ # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
265
+ # Reference: https://github.com/pytorch/pytorch/issues/112577
266
+ if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
267
+ query_layer = query_layer.contiguous()
268
+ key_layer = key_layer.contiguous()
269
+ value_layer = value_layer.contiguous()
270
+
271
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
272
+ query_layer,
273
+ key_layer,
274
+ value_layer,
275
+ attn_mask=attention_mask,
276
+ dropout_p=self.dropout_prob if self.training else 0.0,
277
+ is_causal=False, # For customized tasks, causal mask is not used
278
+ )
279
+
280
+ attn_output = attn_output.transpose(1, 2)
281
+ attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)
282
+
283
+ outputs = (attn_output,)
284
+ return outputs
285
+
286
+
287
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaSelfOutput with Roberta->FreeChunker
288
+ class FreeChunkerSelfOutput(nn.Module):
289
+ def __init__(self, config):
290
+ super().__init__()
291
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
292
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
293
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
294
+
295
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
296
+ hidden_states = self.dense(hidden_states)
297
+ hidden_states = self.dropout(hidden_states)
298
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
299
+ return hidden_states
300
+
301
+
302
+ XLM_ROBERTA_SELF_ATTENTION_CLASSES = {
303
+ "eager": FreeChunkerSelfAttention,
304
+ "sdpa": FreeChunkerSdpaSelfAttention,
305
+ }
306
+
307
+
308
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->FreeChunker
309
+ class FreeChunkerAttention(nn.Module):
310
+ def __init__(self, config, position_embedding_type=None):
311
+ super().__init__()
312
+ self.self = XLM_ROBERTA_SELF_ATTENTION_CLASSES[config._attn_implementation](
313
+ config, position_embedding_type=position_embedding_type
314
+ )
315
+ self.output = FreeChunkerSelfOutput(config)
316
+ self.pruned_heads = set()
317
+
318
+ def prune_heads(self, heads):
319
+ if len(heads) == 0:
320
+ return
321
+ heads, index = find_pruneable_heads_and_indices(
322
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
323
+ )
324
+
325
+ # Prune linear layers
326
+ self.self.query = prune_linear_layer(self.self.query, index)
327
+ self.self.key = prune_linear_layer(self.self.key, index)
328
+ self.self.value = prune_linear_layer(self.self.value, index)
329
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
330
+
331
+ # Update hyper params and store pruned heads
332
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
333
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
334
+ self.pruned_heads = self.pruned_heads.union(heads)
335
+
336
+ def forward(
337
+ self,
338
+ hidden_states: torch.Tensor,
339
+ hidden_states2: torch.Tensor, # Second input stream, required parameter
340
+ attention_mask: Optional[torch.FloatTensor] = None,
341
+ head_mask: Optional[torch.FloatTensor] = None,
342
+ output_attentions: Optional[bool] = False,
343
+ ) -> Tuple[torch.Tensor]:
344
+ self_outputs = self.self(
345
+ hidden_states,
346
+ hidden_states2, # Pass second input stream
347
+ attention_mask,
348
+ head_mask,
349
+ output_attentions,
350
+ )
351
+ attention_output = self.output(self_outputs[0], hidden_states)
352
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
353
+ return outputs
354
+
355
+
356
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaIntermediate with Roberta->FreeChunker
357
+ class FreeChunkerIntermediate(nn.Module):
358
+ def __init__(self, config):
359
+ super().__init__()
360
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
361
+ if isinstance(config.hidden_act, str):
362
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
363
+ else:
364
+ self.intermediate_act_fn = config.hidden_act
365
+
366
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
367
+ hidden_states = self.dense(hidden_states)
368
+ hidden_states = self.intermediate_act_fn(hidden_states)
369
+ return hidden_states
370
+
371
+
372
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaOutput with Roberta->FreeChunker
373
+ class FreeChunkerOutput(nn.Module):
374
+ def __init__(self, config):
375
+ super().__init__()
376
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
377
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
378
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
379
+
380
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
381
+ hidden_states = self.dense(hidden_states)
382
+ hidden_states = self.dropout(hidden_states)
383
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
384
+ return hidden_states
385
+
386
+
387
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaLayer with Roberta->FreeChunker
388
+ class FreeChunkerLayer(nn.Module):
389
+ def __init__(self, config):
390
+ super().__init__()
391
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
392
+ self.seq_len_dim = 1
393
+ self.attention = FreeChunkerAttention(config)
394
+ self.is_decoder = config.is_decoder
395
+ self.add_cross_attention = config.add_cross_attention
396
+ if self.add_cross_attention:
397
+ if not self.is_decoder:
398
+ raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
399
+ self.crossattention = FreeChunkerAttention(config, position_embedding_type="absolute")
400
+ self.intermediate = FreeChunkerIntermediate(config)
401
+ self.output = FreeChunkerOutput(config)
402
+
403
+ def forward(
404
+ self,
405
+ hidden_states: torch.Tensor,
406
+ hidden_states2: torch.Tensor, # Second input stream, required parameter
407
+ attention_mask: Optional[torch.FloatTensor] = None,
408
+ head_mask: Optional[torch.FloatTensor] = None,
409
+ output_attentions: Optional[bool] = False,
410
+ ) -> Tuple[torch.Tensor]:
411
+ attention_outputs = self.attention(
412
+ hidden_states,
413
+ hidden_states2, # Pass second input stream
414
+ attention_mask,
415
+ head_mask,
416
+ output_attentions,
417
+ )
418
+ attention_output = attention_outputs[0]
419
+
420
+ outputs = attention_outputs[1:] # add self attentions if we output attention weights
421
+
422
+ layer_output = self.feed_forward_chunk(attention_output)
423
+ outputs = (layer_output,) + outputs
424
+
425
+ return outputs
426
+
427
+ def feed_forward_chunk(self, attention_output):
428
+ intermediate_output = self.intermediate(attention_output)
429
+ layer_output = self.output(intermediate_output, attention_output)
430
+ return layer_output
431
+
432
+
433
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaEncoder with Roberta->FreeChunker
434
+ class FreeChunkerEncoder(nn.Module):
435
+ def __init__(self, config):
436
+ super().__init__()
437
+ self.config = config
438
+ self.layer = nn.ModuleList([FreeChunkerLayer(config) for _ in range(config.num_hidden_layers)])
439
+ self.gradient_checkpointing = False
440
+
441
+ def forward(
442
+ self,
443
+ hidden_states: torch.Tensor,
444
+ hidden_states2: torch.Tensor, # Second input stream, required parameter
445
+ attention_mask: Optional[torch.FloatTensor] = None,
446
+ head_mask: Optional[torch.FloatTensor] = None,
447
+ ) -> torch.Tensor:
448
+
449
+ for i, layer_module in enumerate(self.layer):
450
+ layer_head_mask = head_mask[i] if head_mask is not None else None
451
+
452
+ if self.gradient_checkpointing and self.training:
453
+
454
+ def create_custom_forward(module):
455
+ def custom_forward(*inputs):
456
+ return module(*inputs)
457
+
458
+ return custom_forward
459
+
460
+ layer_outputs = torch.utils.checkpoint.checkpoint(
461
+ create_custom_forward(layer_module),
462
+ hidden_states,
463
+ hidden_states2, # Pass second input stream
464
+ attention_mask,
465
+ layer_head_mask,
466
+ )
467
+ else:
468
+ layer_outputs = layer_module(
469
+ hidden_states,
470
+ hidden_states2, # Pass second input stream
471
+ attention_mask,
472
+ layer_head_mask,
473
+ )
474
+
475
+ hidden_states = layer_outputs[0]
476
+
477
+ return hidden_states
478
+
479
+
480
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaPooler with Roberta->FreeChunker
481
+ class FreeChunkerPooler(nn.Module):
482
+ def __init__(self, config):
483
+ super().__init__()
484
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
485
+ self.activation = nn.Tanh()
486
+
487
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
488
+ # We "pool" the model by simply taking the hidden state corresponding
489
+ # to the first token.
490
+ first_token_tensor = hidden_states[:, 0]
491
+ pooled_output = self.dense(first_token_tensor)
492
+ pooled_output = self.activation(pooled_output)
493
+ return pooled_output
494
+
495
+
496
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaPreTrainedModel with Roberta->FreeChunker
497
+ class FreeChunkerPreTrainedModel(PreTrainedModel):
498
+ """
499
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
500
+ models.
501
+ """
502
+
503
+ config_class = FreeChunkerConfig
504
+ base_model_prefix = "roberta"
505
+ supports_gradient_checkpointing = True
506
+ _no_split_modules = ["FreeChunkerEmbeddings", "FreeChunkerSelfAttention", "FreeChunkerSdpaSelfAttention"]
507
+ _supports_sdpa = True
508
+
509
+ # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
510
+ def _init_weights(self, module):
511
+ """Initialize the weights"""
512
+ if isinstance(module, nn.Linear):
513
+ # Slightly different from the TF version which uses truncated_normal for initialization
514
+ # cf https://github.com/pytorch/pytorch/pull/5617
515
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
516
+ if module.bias is not None:
517
+ module.bias.data.zero_()
518
+ elif isinstance(module, nn.Embedding):
519
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
520
+ if module.padding_idx is not None:
521
+ module.weight.data[module.padding_idx].zero_()
522
+ elif isinstance(module, nn.LayerNorm):
523
+ module.bias.data.zero_()
524
+ module.weight.data.fill_(1.0)
525
+
526
+
527
+ XLM_ROBERTA_START_DOCSTRING = r"""
528
+
529
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
530
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
531
+ etc.)
532
+
533
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
534
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
535
+ and behavior.
536
+
537
+ Parameters:
538
+ config ([`FreeChunkerConfig`]): Model configuration class with all the parameters of the
539
+ model. Initializing with a config file does not load the weights associated with the model, only the
540
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
541
+ """
542
+
543
+ XLM_ROBERTA_INPUTS_DOCSTRING = r"""
544
+ Args:
545
+ input_ids (`torch.LongTensor` of shape `({0})`):
546
+ Indices of input sequence tokens in the vocabulary.
547
+
548
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
549
+ [`PreTrainedTokenizer.__call__`] for details.
550
+
551
+ [What are input IDs?](../glossary#input-ids)
552
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
553
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
554
+
555
+ - 1 for tokens that are **not masked**,
556
+ - 0 for tokens that are **masked**.
557
+
558
+ [What are attention masks?](../glossary#attention-mask)
559
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
560
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
561
+ 1]`:
562
+
563
+ - 0 corresponds to a *sentence A* token,
564
+ - 1 corresponds to a *sentence B* token.
565
+
566
+ [What are token type IDs?](../glossary#token-type-ids)
567
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
568
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
569
+ config.max_position_embeddings - 1]`.
570
+
571
+ [What are position IDs?](../glossary#position-ids)
572
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
573
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
574
+
575
+ - 1 indicates the head is **not masked**,
576
+ - 0 indicates the head is **masked**.
577
+
578
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
579
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
580
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
581
+ model's internal embedding lookup matrix.
582
+ output_attentions (`bool`, *optional*):
583
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
584
+ tensors for more detail.
585
+ output_hidden_states (`bool`, *optional*):
586
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
587
+ more detail.
588
+ return_dict (`bool`, *optional*):
589
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
590
+ """
591
+
592
+
593
+ @add_start_docstrings(
594
+ "The bare XLM-RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
595
+ XLM_ROBERTA_START_DOCSTRING,
596
+ )
597
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaModel with Roberta->FreeChunker, ROBERTA->XLM_ROBERTA
598
+ class FreeChunkerModel(FreeChunkerPreTrainedModel):
599
+ """
600
+
601
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
602
+ cross-attention is added between the self-attention layers, following the architecture described in [Attention is
603
+ all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
604
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
605
+
606
+ To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
607
+ to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` argument and
608
+ `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
609
+ """
610
+
611
+ _no_split_modules = ["FreeChunkerEmbeddings", "FreeChunkerLayer"]
612
+
613
+ def __init__(self, config, add_pooling_layer=True):
614
+ super().__init__(config)
615
+ self.config = config
616
+ self.config.vocab_size = 2
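+ # The forward pass below only feeds placeholder (all-zero) input_ids for the chunk-query
+ # stream, so a minimal vocabulary is presumably sufficient here.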
617
+ self.embeddings = FreeChunkerEmbeddings(self.config)
618
+ self.encoder = FreeChunkerEncoder(config)
619
+
620
+ self.pooler = FreeChunkerPooler(config) if add_pooling_layer else None
621
+
622
+ self.attn_implementation = config._attn_implementation
623
+ self.position_embedding_type = config.position_embedding_type
624
+
625
+ # Initialize weights and apply final processing
626
+ self.post_init()
627
+
628
+ def get_input_embeddings(self):
629
+ return self.embeddings.word_embeddings
630
+
631
+ def set_input_embeddings(self, value):
632
+ self.embeddings.word_embeddings = value
633
+
634
+ def _prune_heads(self, heads_to_prune):
635
+ """
636
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
637
+ class PreTrainedModel
638
+ """
639
+ for layer, heads in heads_to_prune.items():
640
+ self.encoder.layer[layer].attention.prune_heads(heads)
641
+
642
+ @add_start_docstrings_to_model_forward(XLM_ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
643
+ @add_code_sample_docstrings(
644
+ checkpoint=_CHECKPOINT_FOR_DOC,
645
+ output_type=BaseModelOutputWithPoolingAndCrossAttentions,
646
+ config_class=_CONFIG_FOR_DOC,
647
+ )
648
+ def forward(
649
+ self,
650
+ inputs_embeds=None,
651
+ labels=None,
652
+ loss_weights: bool = False,
653
+ input_ids: Optional[torch.Tensor] = None,
654
+ head_mask: Optional[torch.Tensor] = None,
655
+ encoder_hidden_states: Optional[torch.Tensor] = None,
656
+ encoder_attention_mask: Optional[torch.Tensor] = None
657
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
658
+
659
+ # Get input device
660
+ input_device = inputs_embeds.device
661
+
662
+ # Dimension adaptation: if input dimension is less than 1024, pad with 0
663
+ original_hidden_size = inputs_embeds.shape[-1]
664
+ target_hidden_size = self.config.hidden_size # 1024
665
+
666
+ if original_hidden_size < target_hidden_size:
667
+ # Calculate number of dimensions to pad
668
+ padding_size = target_hidden_size - original_hidden_size
669
+ # Pad with 0 on the last dimension
670
+ padding = torch.zeros(inputs_embeds.shape[:-1] + (padding_size,),
671
+ device=input_device, dtype=inputs_embeds.dtype)
672
+ inputs_embeds = torch.cat([inputs_embeds, padding], dim=-1)
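+ # e.g., a 512-dimensional backbone embedding (as produced by jinaai/jina-embeddings-v2-small-en)
+ # is zero-padded here from [batch, seq_len, 512] to [batch, seq_len, 1024] and truncated back
+ # to 512 dimensions after the encoder pass below.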
673
+
674
+ # Determine sequence length and build the sliding-window (shift) matrix
675
+ sequence_length = inputs_embeds.shape[1]
676
+
677
+ shifted_matrix = generate_shifted_matrix(sequence_length, device=input_device)
678
+
679
+ # Generate attention mask
680
+ encoder_attention_mask = shifted_matrix.transpose(1, 2)
681
+ encoder_attention_mask = torch.where(encoder_attention_mask == 1.0, 0.0, float('-inf'))[:, None, :, :]
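+ # Resulting additive mask has shape [1, 1, num_windows, num_sentences]: each candidate
+ # chunk (window) query can only attend to the sentence embeddings inside its window.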
682
+
683
+ # Fixed input IDs and position IDs
684
+ input_ids = torch.tensor([[0] * shifted_matrix.shape[2]], device=input_device)
685
+ position_ids = torch.tensor([[0] * shifted_matrix.shape[2]], device=input_device)
686
+
687
+ # Embedding layer processing
688
+ embedding_output = self.embeddings(
689
+ input_ids=input_ids,
690
+ position_ids=position_ids,
691
+ token_type_ids=None,
692
+ )
693
+
694
+ # Set second input stream
695
+ encoder_hidden_states = inputs_embeds
696
+
697
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
698
+
699
+ # Encoder processing
700
+ sequence_output = self.encoder(
701
+ embedding_output,
702
+ hidden_states2=encoder_hidden_states, # Second input stream
703
+ attention_mask=encoder_attention_mask, # Use generated mask
704
+ head_mask=head_mask,
705
+ )
706
+
707
+ if original_hidden_size < target_hidden_size:
708
+
709
+ sequence_output = sequence_output[..., :original_hidden_size]
710
+ # Also truncate inputs_embeds back to original size to match dimensions of sequence_output
711
+ inputs_embeds = inputs_embeds[..., :original_hidden_size]
712
+
713
+ shift_matrix = shifted_matrix.transpose(1, 2).squeeze(0)
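+ # shift_matrix: [num_windows, num_sentences] 0/1 indicator of which sentences belong to
+ # each candidate chunk (this is the layout consumed by the helpers in utils.py).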
714
+ # Loss calculation
715
+ loss = None
716
+ if labels is not None:
717
+ emb = sequence_output.view(-1, sequence_output.shape[-1])
718
+ lab = labels.view(-1, labels.shape[-1])
719
+ target = torch.ones(emb.size(0), device=emb.device)
720
+
721
+ # If weighting is enabled, use a weighted cosine loss
722
+ if loss_weights:
723
+ # Derive per-chunk weights from the shift matrix (number of sentences per chunk)
724
+ loss_weights = shift_matrix.sum(dim=1).to(emb.device)
725
+
726
+ # Calculate unweighted cosine loss
727
+ cos_loss_fn = torch.nn.CosineEmbeddingLoss(reduction='none')
728
+ individual_losses = cos_loss_fn(emb, lab, target)
729
+
730
+ # Apply weights and calculate weighted average
731
+ weighted_losses = individual_losses * loss_weights
732
+ loss = weighted_losses.sum() / loss_weights.sum()
733
+ else:
734
+ # Use standard cosine loss
735
+ cos_loss = torch.nn.CosineEmbeddingLoss()
736
+ loss = cos_loss(emb, lab, target)
737
+
738
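+ # The returned embedding concatenates the (possibly truncated) input sentence embeddings with
+ # the refined chunk embeddings along the sequence dimension, then L2-normalizes each vector.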
+ embedding = torch.cat([inputs_embeds, sequence_output], dim=1)
739
+ embedding = torch.nn.functional.normalize(embedding, p=2, dim=-1)
740
+ # embedding = torch.nn.functional.normalize(sequence_output, p=2, dim=-1)
741
+
742
+ return {
743
+ "loss": loss,
744
+ "embedding": embedding.squeeze(0),
745
+ "shift_matrix": shift_matrix
746
+ }
747
+
748
+ # Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids
749
+ def create_position_ids_from_input_ids(input_ids, padding_idx):
750
+ """
751
+ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
752
+ are ignored. This is modified from fairseq's `utils.make_positions`.
753
+
754
+ Args:
755
+ input_ids (torch.Tensor): Input token IDs.
+ padding_idx (int): Index of the padding token.
756
+
757
+ Returns: torch.Tensor
758
+ """
759
+ # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
760
+ mask = input_ids.ne(padding_idx).int()
761
+ incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask)) * mask
762
+ return incremental_indices.long() + padding_idx
763
+
764
+
765
+ __all__ = [
766
+ "FreeChunkerModel",
767
+ "FreeChunkerPreTrainedModel",
768
+ ]
sentenizer.py ADDED
@@ -0,0 +1,276 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Sentenceizer - Universal sentence splitter + vector encoder
4
+ Length-constrained text splitting tool that protects special formats (abbreviations, numbers, URLs, emails) while still allowing splits inside quotes/brackets
5
+ """
6
+
7
+ import numpy as np
8
+ from typing import List, Tuple, Union, Optional
9
+ from sentence_transformers import SentenceTransformer
10
+ from transformers import AutoTokenizer
11
+
12
+ # --- Integrated TraditionalChunking ---
13
+
14
+ def setup_tokenizer(model_name="xlm-roberta-base"):
15
+ """Setup tokenizer"""
16
+ try:
17
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
18
+ except Exception as e:
19
+ print(f"Warning: Could not load tokenizer for {model_name}: {e}. Falling back to bert-base-uncased")
20
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
21
+ return tokenizer
22
+
23
+ def fixed_size_chunking(text: str, tokenizer=None, chunk_size: int = 256, overlap: int = 0) -> List[str]:
24
+ """
25
+ Fixed-size chunking based on token count (Strict truncation)
26
+
27
+ Args:
28
+ text: Text to chunk
29
+ tokenizer: Tokenizer
30
+ chunk_size: Token count per chunk
31
+ overlap: Overlapping token count
32
+ """
33
+ if tokenizer is None:
34
+ tokenizer = setup_tokenizer()
35
+
36
+ # Encode the entire text, do not add special tokens to keep it clean
37
+ tokens = tokenizer.encode(text, add_special_tokens=False)
38
+ total_tokens = len(tokens)
39
+
40
+ chunks = []
41
+
42
+ # Calculate step size
43
+ step = chunk_size - overlap
44
+ if step <= 0:
45
+ step = 1 # Prevent infinite loop, theoretically overlap should be smaller than chunk_size
46
+
47
+ for i in range(0, total_tokens, step):
48
+ # Truncate tokens for current chunk
49
+ chunk_tokens = tokens[i : i + chunk_size]
50
+
51
+ # Decode back to text
52
+ chunk_text = tokenizer.decode(chunk_tokens, skip_special_tokens=True)
53
+
54
+ if chunk_text.strip():
55
+ chunks.append(chunk_text.strip())
56
+
57
+ return chunks
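+ # Example (sketch): with chunk_size=256 and overlap=32 the step is 224 tokens, so windows
+ # start at token offsets 0, 224, 448, ... and each chunk repeats the final 32 tokens of the
+ # preceding one.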
58
+
59
+ def traditional_chunking(text, tokenizer=None, chunk_size=256, overlap=0):
60
+ """
61
+ Fixed-size chunking based on tokens
62
+
63
+ Args:
64
+ text: Text to chunk
65
+ tokenizer: Tokenizer
66
+ chunk_size: Token count per chunk
67
+ overlap: Overlapping token count
68
+ """
69
+ return fixed_size_chunking(text, tokenizer, chunk_size, overlap)
70
+
71
+ class TraditionalChunking:
72
+ def __init__(self, model_name_or_path=None, tokenizer=None, chunk_size=256, overlap=0):
73
+ if tokenizer is not None:
74
+ self.tokenizer = tokenizer
75
+ elif model_name_or_path is not None:
76
+ self.tokenizer = setup_tokenizer(model_name_or_path)
77
+ else:
78
+ self.tokenizer = setup_tokenizer()
79
+ self.chunk_size = chunk_size
80
+ self.overlap = overlap
81
+
82
+ def chunk(self, text):
83
+ return traditional_chunking(text, self.tokenizer, self.chunk_size, self.overlap)
84
+
85
+ # --- End TraditionalChunking ---
86
+
87
+
88
+ class Sentenceizer:
89
+ """
90
+ Universal sentence splitter and encoder with length constraints, protecting special formats
91
+ """
92
+
93
+ def __init__(self, model_name: Optional[str] = None):
94
+ """
95
+ Initialize Sentenceizer
96
+
97
+ Args:
98
+ model_name (str, optional): SentenceTransformer model name
99
+ If None, no encoding model is loaded
100
+ """
101
+ # Initialize chunker with model_name if available, otherwise default
102
+ self.chunker = TraditionalChunking(model_name_or_path=model_name if model_name else "xlm-roberta-base", chunk_size=256, overlap=0)
103
+
104
+ self.model = None
105
+ self.model_name = model_name
106
+ if model_name:
107
+ print(f"Loading sentence transformer model: {model_name}")
108
+ self.model = SentenceTransformer(model_name, trust_remote_code=True)
109
+ self.model.eval()
110
+ print(f"Model loaded successfully. Embedding dimension: {self.model.get_sentence_embedding_dimension()}")
111
+
112
+ def split(self, text: str) -> List[str]:
113
+ """
114
+ Split text into a list of segments using the token-based fixed-size chunker
115
+
116
+ Args:
117
+ text (str): Input text
118
+
119
+ Returns:
120
+ List[str]: List of sentences
121
+ """
122
+ if not text.strip():
123
+ return []
124
+
125
+ return self.chunker.chunk(text)
126
+
127
+ def split_with_positions(self, text: str) -> List[Tuple[str, int, int]]:
128
+ """
129
+ Split text and return sentences with their positions in the original text
130
+
131
+ Args:
132
+ text (str): Input text
133
+
134
+ Returns:
135
+ List[Tuple[str, int, int]]: List of (sentence, start_position, end_position)
136
+ """
137
+ sentences = self.split(text)
138
+ sentences_with_pos = []
139
+
140
+ start_pos = 0
141
+ for sentence in sentences:
142
+ # Find sentence position in original text
143
+ pos = text.find(sentence, start_pos)
144
+ if pos != -1:
145
+ sentences_with_pos.append((sentence, pos, pos + len(sentence)))
146
+ start_pos = pos + len(sentence)
147
+ else:
148
+ # If not found (possibly due to merging or splitting), use estimated position
149
+ sentences_with_pos.append((sentence, start_pos, start_pos + len(sentence)))
150
+ start_pos += len(sentence)
151
+
152
+ return sentences_with_pos
153
+
154
+ def encode(self, text: Union[str, List[str]], show_progress: bool = False) -> np.ndarray:
155
+ """
156
+ Encode text
157
+
158
+ Args:
159
+ text (Union[str, List[str]]): Input text, can be a single string or list of strings
160
+ If it's a string, sentence splitting will be performed first
161
+ show_progress (bool): Whether to show progress bar
162
+
163
+ Returns:
164
+ np.ndarray: Encoded vector array with shape (n_sentences, embedding_dim)
165
+
166
+ Raises:
167
+ ValueError: If no model is loaded
168
+ """
169
+ if self.model is None:
170
+ raise ValueError("No model loaded. Please initialize with a model_name.")
171
+
172
+ # If input is string, perform sentence splitting first
173
+ if isinstance(text, str):
174
+ sentences = self.split(text)
175
+ else:
176
+ sentences = text
177
+
178
+ if not sentences:
179
+ return np.array([])
180
+
181
+ # Use the sentence transformer for encoding with a small batch size (4) to limit memory use
182
+ embeddings = self.model.encode(
183
+ sentences,
184
+ show_progress_bar=show_progress,
185
+ convert_to_numpy=True,
186
+ batch_size=4
187
+ )
188
+
189
+ return embeddings
190
+
191
+ def split_and_encode(self, text: str, show_progress: bool = True) -> Tuple[List[str], np.ndarray]:
192
+ """
193
+ Split text and encode
194
+
195
+ Args:
196
+ text (str): Input text
197
+ show_progress (bool): Whether to show progress bar
198
+
199
+ Returns:
200
+ Tuple[List[str], np.ndarray]: (sentence list, encoded vector array)
201
+ """
202
+ sentences = self.split(text)
203
+ embeddings = self.encode(sentences, show_progress=show_progress)
204
+ return sentences, embeddings
205
+
206
+ @property
207
+ def embedding_dimension(self) -> int:
208
+ """Get embedding dimension"""
209
+ if self.model is None:
210
+ raise ValueError("No model loaded.")
211
+ return self.model.get_sentence_embedding_dimension()
212
+
213
+ def test_sentenceizer():
214
+ """Test universal sentence splitting functionality and protection mechanisms"""
215
+
216
+ print("=== Testing Universal Sentence Splitting and Protection Mechanisms ===")
217
+
218
+ # Initialize without an encoding model (splitting only)
219
+ sentenceizer = Sentenceizer()
220
+
221
+ test_cases = [
222
+ # Basic sentence splitting test
223
+ "This is the first sentence. This is the second sentence! This is the third sentence?",
224
+
225
+ # Quote sentence splitting test (should be able to split)
226
+ 'He said "Hello there. How are you? I hope you are well." Then he left.',
227
+
228
+ # Abbreviation protection test (should not split at abbreviations)
229
+ "Dr. Smith is here. Mr. Jones left at 3 p.m. today. The U.S. economy is growing.",
230
+
231
+ # Number protection test (should not split within numbers)
232
+ "The temperature is 36.5 degrees. The price is $19.99. Version 2.1.3 was released.",
233
+
234
+ # Ellipsis protection test (should not split at ellipsis)
235
+ "This is incomplete... But this continues the thought. Another sentence follows.",
236
+
237
+ # URL protection test (should not split within URLs)
238
+ "Visit https://www.example.com for more info. The website www.test.org has details.",
239
+
240
+ # Email protection test (should not split within emails)
241
+ "Contact me at john.doe@example.com for questions. Send reports to admin@company.org please.",
242
+
243
+ # Date and time protection test
244
+ "The meeting is on 12/25/2023. We start at 3:30 p.m. today. See you then.",
245
+
246
+ # Non-English text test
247
+ "这是第一个句子。这是第二个句子!这是第三个句子?",
248
+
249
+ # Mixed text test
250
+ "This is English. 这是中文。Mix of both languages!",
251
+
252
+ # Complex mixed test
253
+ "访问 https://www.baidu.com 获取信息。联系邮箱是 test@163.com。价格为 ¥99.99 元。",
254
+
255
+ # Long sentence test (should be split)
256
+ "This is a very long sentence that should be split into multiple parts because it exceeds the maximum length limit that we have set for individual sentences in our system, and we need to handle this properly.",
257
+
258
+ # Sentences starting with numbers
259
+ "Today is sunny. 123 people attended the meeting. Everyone was happy.",
260
+
261
+ # Sentences starting with special characters
262
+ "First sentence here. \"Quoted sentence comes next.\" Final sentence ends it.",
263
+ ]
264
+
265
+ for i, text in enumerate(test_cases, 1):
266
+ print(f"\n--- Test Case {i} ---")
267
+ print(f"Original: {text}")
268
+
269
+ sentences = sentenceizer.split(text)
270
+ print(f"Split Result ({len(sentences)} sentences):")
271
+ for j, sentence in enumerate(sentences, 1):
272
+ print(f" {j}. ({len(sentence)} chars) {sentence}")
273
+
274
+
275
+ if __name__ == "__main__":
276
+ test_sentenceizer()
training_losses.json ADDED
The diff for this file is too large to render. See raw diff
 
utils.py ADDED
@@ -0,0 +1,235 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Utility Functions
4
+ """
5
+
6
+ import torch
7
+ import numpy as np
9
+
10
+ def generate_shifted_matrix(n, device=None):
11
+
12
+ matrix_columns = []
13
+ granularities = [2, 4]
14
+
15
+ for granularity in granularities:
16
+ if granularity > n:
17
+ continue
18
+
19
+ # Calculate step size for this granularity
20
+ step_size = max(1, granularity // 2)
21
+ max_start = n - granularity
22
+
23
+ for start in range(0, max_start + 1, step_size):
24
+ column = torch.zeros(n, dtype=torch.int, device=device)
25
+ column[start:start + granularity] = 1
26
+ matrix_columns.append(column)
27
+
28
+ # If the last position is not covered, add a mask at the end
29
+ if max_start >= 0 and (max_start % step_size) != 0:
30
+ column = torch.zeros(n, dtype=torch.int, device=device)
31
+ column[-granularity:] = 1
32
+ matrix_columns.append(column)
33
+
34
+ if not matrix_columns:
35
+ column = torch.ones(n, dtype=torch.int, device=device)
36
+ matrix_columns.append(column)
37
+
38
+ result = torch.stack(matrix_columns, dim=1).unsqueeze(0).expand(1, -1, -1)
39
+ return result
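+ # Example (sketch): for n = 4 the granularities [2, 4] yield the window columns
+ # [1,1,0,0], [0,1,1,0], [0,0,1,1] (width 2, stride 1) and [1,1,1,1] (width 4), so
+ # generate_shifted_matrix(4) returns a tensor of shape [1, 4, 4] whose columns are these masks.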
40
+
41
+ def create_attention_mask(shift_matrix: torch.Tensor) -> torch.Tensor:
42
+ """
43
+ Create attention mask from shift matrix
44
+
45
+ Args:
46
+ shift_matrix (torch.Tensor): shift matrix, shape [num_chunks, seq_len]
47
+
48
+ Returns:
49
+ torch.Tensor: attention mask, shape [1, 1, seq_len, num_chunks]
50
+ """
51
+ # Transpose and create attention mask
52
+ attention_mask = shift_matrix.transpose(0, 1) # [seq_len, num_chunks]
53
+ attention_mask = torch.where(attention_mask == 1.0, 0.0, float('-inf'))
54
+
55
+ # Add dimensions to match expected shape of attention
56
+ attention_mask = attention_mask.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, num_chunks]
57
+
58
+ return attention_mask
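+ # Example (sketch): a shift matrix [[1,1,0],[0,1,1]] (2 chunks over 3 sentences) becomes a
+ # [1, 1, 3, 2] additive mask with 0 for positions inside a chunk and -inf everywhere else.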
59
+
60
+ def normalize_embeddings(embeddings: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
61
+ """
62
+ L2 normalize embeddings
63
+
64
+ Args:
65
+ embeddings (torch.Tensor): Embeddings
66
+ eps (float): Small value to prevent division by zero
67
+
68
+ Returns:
69
+ torch.Tensor: Normalized embeddings
70
+ """
71
+ norm = torch.norm(embeddings, dim=-1, keepdim=True)
72
+ return embeddings / (norm + eps)
73
+
74
+ def cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
75
+ """
76
+ Calculate cosine similarity
77
+
78
+ Args:
79
+ a (torch.Tensor): Vector A
80
+ b (torch.Tensor): Vector B
81
+
82
+ Returns:
83
+ torch.Tensor: Cosine similarity
84
+ """
85
+ a_norm = normalize_embeddings(a)
86
+ b_norm = normalize_embeddings(b)
87
+ return torch.sum(a_norm * b_norm, dim=-1)
88
+
89
+ def batch_cosine_similarity(embeddings1: torch.Tensor, embeddings2: torch.Tensor) -> torch.Tensor:
90
+ """
91
+ Calculate batch cosine similarity
92
+
93
+ Args:
94
+ embeddings1 (torch.Tensor): Embeddings group 1, shape [N, dim]
95
+ embeddings2 (torch.Tensor): Embeddings group 2, shape [M, dim]
96
+
97
+ Returns:
98
+ torch.Tensor: Similarity matrix, shape [N, M]
99
+ """
100
+ embeddings1_norm = normalize_embeddings(embeddings1)
101
+ embeddings2_norm = normalize_embeddings(embeddings2)
102
+
103
+ return torch.matmul(embeddings1_norm, embeddings2_norm.transpose(0, 1))
104
+
105
+ def split_embeddings_by_shift_matrix(embeddings: torch.Tensor, shift_matrix: torch.Tensor) -> list:
106
+ """
107
+ Split embeddings based on shift matrix
108
+
109
+ Args:
110
+ embeddings (torch.Tensor): Embeddings, shape [seq_len, hidden_dim]
111
+ shift_matrix (torch.Tensor): shift matrix, shape [num_chunks, seq_len]
112
+
113
+ Returns:
114
+ list: List of split embeddings
115
+ """
116
+ split_embeddings = []
117
+ num_chunks, seq_len = shift_matrix.shape
118
+
119
+ for chunk_idx in range(num_chunks):
120
+ mask = shift_matrix[chunk_idx] # [seq_len]
121
+ indices = torch.nonzero(mask, as_tuple=True)[0] # Get indices of non-zero positions
122
+
123
+ if len(indices) > 0:
124
+ chunk_embeddings = embeddings[indices] # [chunk_size, hidden_dim]
125
+ split_embeddings.append(chunk_embeddings)
126
+
127
+ return split_embeddings
128
+
129
+ def pool_embeddings(embeddings: torch.Tensor, method: str = 'mean') -> torch.Tensor:
130
+ """
131
+ Pool embeddings
132
+
133
+ Args:
134
+ embeddings (torch.Tensor): Embeddings, shape [seq_len, hidden_dim]
135
+ method (str): Pooling method, optional 'mean', 'max', 'first', 'last'
136
+
137
+ Returns:
138
+ torch.Tensor: Pooled vector, shape [hidden_dim]
139
+ """
140
+ if method == 'mean':
141
+ return torch.mean(embeddings, dim=0)
142
+ elif method == 'max':
143
+ return torch.max(embeddings, dim=0)[0]
144
+ elif method == 'first':
145
+ return embeddings[0]
146
+ elif method == 'last':
147
+ return embeddings[-1]
148
+ else:
149
+ raise ValueError(f"Unknown pooling method: {method}")
150
+
151
+ def aggregate_chunk_embeddings(split_embeddings: list, method: str = 'mean') -> torch.Tensor:
152
+ """
153
+ Aggregate chunk embeddings
154
+
155
+ Args:
156
+ split_embeddings (list): List of split embeddings
157
+ method (str): Aggregation method
158
+
159
+ Returns:
160
+ torch.Tensor: Aggregated embeddings, shape [num_chunks, hidden_dim]
161
+ """
162
+ if not split_embeddings:
163
+ return torch.tensor([])
164
+
165
+ aggregated = []
166
+ for chunk_embeddings in split_embeddings:
167
+ pooled = pool_embeddings(chunk_embeddings, method)
168
+ aggregated.append(pooled)
169
+
170
+ return torch.stack(aggregated)
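+ # Example (sketch): aggregate_chunk_embeddings(
+ #     split_embeddings_by_shift_matrix(sentence_embeddings, shift_matrix), method='mean')
+ # yields one mean-pooled vector per chunk, shape [num_chunks, hidden_dim].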
171
+
172
+ def safe_tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
173
+ """
174
+ Safely convert tensor to numpy array
175
+
176
+ Args:
177
+ tensor (torch.Tensor): Input tensor
178
+
179
+ Returns:
180
+ np.ndarray: Numpy array
181
+ """
182
+ if tensor.requires_grad:
183
+ tensor = tensor.detach()
184
+
185
+ if tensor.is_cuda:
186
+ tensor = tensor.cpu()
187
+
188
+ return tensor.numpy()
189
+
190
+ def ensure_tensor_on_device(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
191
+ """
192
+ Ensure tensor is on specified device
193
+
194
+ Args:
195
+ tensor (torch.Tensor): Input tensor
196
+ device (torch.device): Target device
197
+
198
+ Returns:
199
+ torch.Tensor: Tensor on target device
200
+ """
201
+ if tensor.device != device:
202
+ tensor = tensor.to(device)
203
+ return tensor
204
+
205
+ def get_available_device() -> torch.device:
206
+ """
207
+ Get available device
208
+
209
+ Returns:
210
+ torch.device: Available device
211
+ """
212
+ if torch.cuda.is_available():
213
+ return torch.device('cuda')
214
+ elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
215
+ return torch.device('mps')
216
+ else:
217
+ return torch.device('cpu')
218
+
219
+ def print_tensor_info(tensor: torch.Tensor, name: str = "tensor"):
220
+ """
221
+ Print tensor info
222
+
223
+ Args:
224
+ tensor (torch.Tensor): Input tensor
225
+ name (str): Tensor name
226
+ """
227
+ print(f"{name}:")
228
+ print(f" Shape: {tensor.shape}")
229
+ print(f" Data Type: {tensor.dtype}")
230
+ print(f" Device: {tensor.device}")
231
+ print(f" Requires Grad: {tensor.requires_grad}")
232
+ if tensor.numel() > 0:
233
+ print(f" Value Range: [{tensor.min().item():.6f}, {tensor.max().item():.6f}]")
234
+ print(f" Mean: {tensor.mean().item():.6f}")
235
+ print(f" Std Dev: {tensor.std().item():.6f}")