Initial upload of FreeChunk model with custom code
Browse files- README.md +94 -0
- aggregator.py +205 -0
- config.json +32 -0
- configuration_freechunker.py +157 -0
- encoder.py +257 -0
- final_loss_curve.png +0 -0
- model.safetensors +3 -0
- modeling_freechunker.py +768 -0
- sentenizer.py +276 -0
- training_losses.json +0 -0
- utils.py +235 -0
README.md
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# FreeChunker-Jina
|
| 2 |
+
|
| 3 |
+
FreeChunker is a training-free embedding optimization method that dynamically chunks text to improve retrieval performance. This repository contains the **FreeChunker** model initialized with **jinaai/jina-embeddings-v2-small-en** embeddings.
|
| 4 |
+
|
| 5 |
+
## Features
|
| 6 |
+
|
| 7 |
+
- **Dynamic Chunking**: Automatically groups sentences into semantically coherent chunks.
|
| 8 |
+
- **Optimized for RAG**: Improves retrieval augmented generation by providing better context segments.
|
| 9 |
+
- **Backbone**: Built on top of `jinaai/jina-embeddings-v2-small-en` sentence embeddings.
|
| 10 |
+
|
| 11 |
+
## Requirements
|
| 12 |
+
|
| 13 |
+
```bash
|
| 14 |
+
pip install torch transformers sentence-transformers numpy
|
| 15 |
+
```
|
| 16 |
+
|
| 17 |
+
## Usage
|
| 18 |
+
|
| 19 |
+
You can use the provided `UnifiedEncoder` class (in `encoder.py`) to easily use the model for encoding and retrieval.
|
| 20 |
+
|
| 21 |
+
### Using UnifiedEncoder
|
| 22 |
+
|
| 23 |
+
```python
|
| 24 |
+
from encoder import UnifiedEncoder
|
| 25 |
+
|
| 26 |
+
# Initialize the encoder
|
| 27 |
+
# local_model_path="." assumes you are in the directory containing model.safetensors
|
| 28 |
+
encoder = UnifiedEncoder(model_name="jina", local_model_path=".")
|
| 29 |
+
|
| 30 |
+
# Input text
|
| 31 |
+
text = """
|
| 32 |
+
Your long text goes here. FreeChunker will split this text into sentences,
|
| 33 |
+
generate embeddings using Jina, and then group them into semantic chunks.
|
| 34 |
+
It handles long documents effectively.
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
# Build vector store (chunks and encodes the text)
|
| 38 |
+
encoder.build_vector_store(text)
|
| 39 |
+
|
| 40 |
+
# Query
|
| 41 |
+
query = "How does FreeChunker work?"
|
| 42 |
+
results = encoder.query(query, top_k=3, aggregation_mode='post')
|
| 43 |
+
|
| 44 |
+
print("Results:", results)
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
### Manual Pipeline
|
| 48 |
+
|
| 49 |
+
If you prefer to use the components separately:
|
| 50 |
+
|
| 51 |
+
1. **Split and Encode**: Use `Sentenceizer` (wrapping `jinaai/jina-embeddings-v2-small-en`) to get sentence embeddings.
|
| 52 |
+
2. **FreeChunker**: Pass embeddings to `FreeChunkerModel`.
|
| 53 |
+
3. **Process**: Use the output `shift_matrix` to group sentences.
|
| 54 |
+
|
| 55 |
+
```python
|
| 56 |
+
from sentenizer import Sentenceizer
|
| 57 |
+
from modeling_freechunker import FreeChunkerModel
|
| 58 |
+
import torch
|
| 59 |
+
|
| 60 |
+
# 1. Setup Sentenceizer with Backbone
|
| 61 |
+
sentenceizer = Sentenceizer(model_name="jinaai/jina-embeddings-v2-small-en")
|
| 62 |
+
|
| 63 |
+
# 2. Load FreeChunker Model
|
| 64 |
+
model = FreeChunkerModel.from_pretrained(".", trust_remote_code=True)
|
| 65 |
+
model.eval()
|
| 66 |
+
|
| 67 |
+
# 3. Process Text
|
| 68 |
+
text = "Your text..."
|
| 69 |
+
sentences, embeddings = sentenceizer.split_and_encode(text)
|
| 70 |
+
|
| 71 |
+
# 4. Forward pass through FreeChunker
|
| 72 |
+
inputs_embeds = torch.tensor(embeddings).unsqueeze(0) # Batch size 1
|
| 73 |
+
with torch.no_grad():
|
| 74 |
+
outputs = model(inputs_embeds=inputs_embeds)
|
| 75 |
+
|
| 76 |
+
# outputs['embedding'] contains refined embeddings
|
| 77 |
+
# outputs['shift_matrix'] contains chunking information
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
## Files
|
| 81 |
+
|
| 82 |
+
- `model.safetensors`: The FreeChunker model weights.
|
| 83 |
+
- `encoder.py`: High-level interface (`UnifiedEncoder`) for end-to-end usage.
|
| 84 |
+
- `sentenizer.py`: Helper for text splitting and backbone embedding.
|
| 85 |
+
- `aggregator.py`: Helper for aggregating retrieved results.
|
| 86 |
+
- `configuration_freechunker.py` & `modeling_freechunker.py`: Model definition.
|
| 87 |
+
|
| 88 |
+
## Citation
|
| 89 |
+
|
| 90 |
+
If you use this model in your research, please cite:
|
| 91 |
+
|
| 92 |
+
```
|
| 93 |
+
Zhang W, Jiang Y H, Wu Y. FreeChunker: A Cross-Granularity Chunking Framework[J]. arXiv preprint arXiv:2510.20356, 2025.
|
| 94 |
+
```
|
aggregator.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
Text Aggregator - Precise text segment aggregation based on sentence position markers
|
| 5 |
+
|
| 6 |
+
Main functions:
|
| 7 |
+
1. Detect overlaps between text segments based on 【Begin-x】【End-y】 markers
|
| 8 |
+
2. Automatically merge and reconstruct based on original order when overlapping
|
| 9 |
+
3. Retain the highest scoring segments
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import re
|
| 13 |
+
from typing import List, Tuple
|
| 14 |
+
|
| 15 |
+
class TextAggregator:
|
| 16 |
+
"""
|
| 17 |
+
Text aggregator for merging retrieved text segments
|
| 18 |
+
Implements splitting, deduplication, sorting, and reconstruction of text segments based on 【Begin-x】【End-x】 markers
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
def __init__(self):
|
| 22 |
+
"""
|
| 23 |
+
Initialize text aggregator
|
| 24 |
+
"""
|
| 25 |
+
pass
|
| 26 |
+
|
| 27 |
+
def _extract_segments_from_text(self, text: str) -> List[Tuple[int, str]]:
|
| 28 |
+
"""
|
| 29 |
+
Extract all 【Begin-x】...【End-x】 segments from text
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
text: Text containing position markers
|
| 33 |
+
|
| 34 |
+
Returns:
|
| 35 |
+
List[Tuple[int, str]]: List of (begin_index, segment_text)
|
| 36 |
+
"""
|
| 37 |
+
segments = []
|
| 38 |
+
# Match 【Begin-x】...【End-x】 pattern
|
| 39 |
+
pattern = r'【Begin-(\d+)】(.*?)【End-\1】'
|
| 40 |
+
matches = re.findall(pattern, text, re.DOTALL)
|
| 41 |
+
|
| 42 |
+
for match in matches:
|
| 43 |
+
begin_idx = int(match[0])
|
| 44 |
+
segment_content = match[1]
|
| 45 |
+
full_segment = f"【Begin-{begin_idx}】{segment_content}【End-{begin_idx}】"
|
| 46 |
+
segments.append((begin_idx, full_segment))
|
| 47 |
+
|
| 48 |
+
return segments
|
| 49 |
+
|
| 50 |
+
def _remove_boundary_markers(self, text: str) -> str:
|
| 51 |
+
"""
|
| 52 |
+
Remove all boundary markers from text, keeping only content
|
| 53 |
+
|
| 54 |
+
Args:
|
| 55 |
+
text: Text containing boundary markers
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
str: Text with boundary markers removed
|
| 59 |
+
"""
|
| 60 |
+
# Remove 【Begin-x】 and 【End-x】 markers
|
| 61 |
+
clean_text = re.sub(r'【Begin-\d+】|【End-\d+】', '', text)
|
| 62 |
+
return clean_text.strip()
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def aggregate_segments(self, segments: List[str]) -> str:
|
| 67 |
+
"""
|
| 68 |
+
Aggregate text segments: split, deduplicate, sort, reconstruct
|
| 69 |
+
|
| 70 |
+
Args:
|
| 71 |
+
segments: List of text segments
|
| 72 |
+
|
| 73 |
+
Returns:
|
| 74 |
+
str: Aggregated text string
|
| 75 |
+
"""
|
| 76 |
+
if not segments:
|
| 77 |
+
return ""
|
| 78 |
+
|
| 79 |
+
# Step 1: Extract segments from all input texts
|
| 80 |
+
all_segments = {} # {begin_index: segment_text}
|
| 81 |
+
|
| 82 |
+
for text in segments:
|
| 83 |
+
extracted = self._extract_segments_from_text(text)
|
| 84 |
+
for begin_idx, segment in extracted:
|
| 85 |
+
# Deduplication: Keep only one segment for the same begin_index
|
| 86 |
+
if begin_idx not in all_segments:
|
| 87 |
+
all_segments[begin_idx] = segment
|
| 88 |
+
|
| 89 |
+
# Step 2: Sort by begin_index
|
| 90 |
+
sorted_segments = sorted(all_segments.items())
|
| 91 |
+
|
| 92 |
+
# Step 3: Reconstruct text
|
| 93 |
+
if not sorted_segments:
|
| 94 |
+
return []
|
| 95 |
+
|
| 96 |
+
# Build continuous text
|
| 97 |
+
result_text = ""
|
| 98 |
+
prev_end = -1
|
| 99 |
+
|
| 100 |
+
for begin_idx, segment in sorted_segments:
|
| 101 |
+
# If not continuous, add ellipsis
|
| 102 |
+
if prev_end != -1 and begin_idx != prev_end + 1:
|
| 103 |
+
result_text += "..."
|
| 104 |
+
|
| 105 |
+
# Add content of current segment (remove boundary markers)
|
| 106 |
+
content = self._remove_boundary_markers(segment)
|
| 107 |
+
result_text += content
|
| 108 |
+
|
| 109 |
+
prev_end = begin_idx
|
| 110 |
+
|
| 111 |
+
return result_text
|
| 112 |
+
|
| 113 |
+
def aggregate_segments_complete(self, segments: List[str]) -> str:
|
| 114 |
+
"""
|
| 115 |
+
Completely aggregate all text segments
|
| 116 |
+
|
| 117 |
+
Args:
|
| 118 |
+
segments: List of text segments
|
| 119 |
+
|
| 120 |
+
Returns:
|
| 121 |
+
str: Aggregated text string
|
| 122 |
+
"""
|
| 123 |
+
return self.aggregate_segments(segments)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def demo():
|
| 129 |
+
"""Demo function - Show text splitting, deduplication, sorting, and reconstruction based on position markers"""
|
| 130 |
+
print("=== Text Aggregator Demo (Completely Rewritten Version) ===\n")
|
| 131 |
+
|
| 132 |
+
# Create aggregator
|
| 133 |
+
aggregator = TextAggregator()
|
| 134 |
+
|
| 135 |
+
# Test data - Format according to user example
|
| 136 |
+
test_segments = [
|
| 137 |
+
"【Begin-1】sdfsdf【End-1】【Begin-2】sdfsdf【End-2】",
|
| 138 |
+
"【Begin-2】sdfsdf【End-2】【Begin-3】sdfsdf【End-3】",
|
| 139 |
+
"【Begin-5】sdfsdf【End-5】【Begin-6】sdfsdf【End-6】"
|
| 140 |
+
]
|
| 141 |
+
|
| 142 |
+
print("Original input segments:")
|
| 143 |
+
for i, text in enumerate(test_segments, 1):
|
| 144 |
+
print(f"{i}. {text}")
|
| 145 |
+
|
| 146 |
+
print("\n=== Step 1: Extract segments from each text ===")
|
| 147 |
+
all_extracted = {}
|
| 148 |
+
for i, text in enumerate(test_segments, 1):
|
| 149 |
+
extracted = aggregator._extract_segments_from_text(text)
|
| 150 |
+
print(f"Segments extracted from text {i}: {extracted}")
|
| 151 |
+
for begin_idx, segment in extracted:
|
| 152 |
+
if begin_idx not in all_extracted:
|
| 153 |
+
all_extracted[begin_idx] = segment
|
| 154 |
+
print(f" Add segment: Begin-{begin_idx}")
|
| 155 |
+
else:
|
| 156 |
+
print(f" Skip duplicate segment: Begin-{begin_idx}")
|
| 157 |
+
|
| 158 |
+
print(f"\nAll segments after deduplication: {list(all_extracted.keys())}")
|
| 159 |
+
|
| 160 |
+
print("\n=== Step 2: Sort by Begin marker ===")
|
| 161 |
+
sorted_segments = sorted(all_extracted.items())
|
| 162 |
+
print("Sorted segments:")
|
| 163 |
+
for begin_idx, segment in sorted_segments:
|
| 164 |
+
print(f" Begin-{begin_idx}: {segment}")
|
| 165 |
+
|
| 166 |
+
print("\n=== Step 3: Reconstruct text (remove boundary markers, add ellipsis) ===")
|
| 167 |
+
result = aggregator.aggregate_segments(test_segments)
|
| 168 |
+
print(f"Final result: {result}")
|
| 169 |
+
|
| 170 |
+
print("\n=== Full Test Cases ===")
|
| 171 |
+
|
| 172 |
+
# More complex test cases
|
| 173 |
+
complex_segments = [
|
| 174 |
+
"【Begin-1】First sentence【End-1】【Begin-2】Second sentence【End-2】【Begin-3】Third sentence【End-3】",
|
| 175 |
+
"【Begin-2】Second sentence【End-2】【Begin-3】Third sentence【End-3】【Begin-4】Fourth sentence【End-4】",
|
| 176 |
+
"【Begin-6】Sixth sentence【End-6】【Begin-7】Seventh sentence【End-7】",
|
| 177 |
+
"【Begin-4】Fourth sentence【End-4】【Begin-5】Fifth sentence【End-5】"
|
| 178 |
+
]
|
| 179 |
+
|
| 180 |
+
print("\nComplex test input:")
|
| 181 |
+
for i, text in enumerate(complex_segments, 1):
|
| 182 |
+
print(f"{i}. {text}")
|
| 183 |
+
|
| 184 |
+
complex_result = aggregator.aggregate_segments(complex_segments)
|
| 185 |
+
print(f"\nComplex test result: {complex_result}")
|
| 186 |
+
|
| 187 |
+
print("\n=== Boundary Case Tests ===")
|
| 188 |
+
|
| 189 |
+
# Test empty input
|
| 190 |
+
empty_result = aggregator.aggregate_segments([])
|
| 191 |
+
print(f"Empty input result: {empty_result}")
|
| 192 |
+
|
| 193 |
+
# Test single segment
|
| 194 |
+
single_result = aggregator.aggregate_segments(["【Begin-1】Single segment【End-1】"])
|
| 195 |
+
print(f"Single segment result: {single_result}")
|
| 196 |
+
|
| 197 |
+
# Test text without markers (should return empty)
|
| 198 |
+
no_marker_result = aggregator.aggregate_segments(["Normal text without markers"])
|
| 199 |
+
print(f"Text without markers result: {no_marker_result}")
|
| 200 |
+
|
| 201 |
+
print("\n=== Demo Completed ===")
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
if __name__ == "__main__":
|
| 205 |
+
demo()
|
config.json
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"FreeChunkerModel"
|
| 4 |
+
],
|
| 5 |
+
"attention_probs_dropout_prob": 0.1,
|
| 6 |
+
"bos_token_id": 0,
|
| 7 |
+
"classifier_dropout": null,
|
| 8 |
+
"dtype": "float32",
|
| 9 |
+
"eos_token_id": 2,
|
| 10 |
+
"hidden_act": "gelu",
|
| 11 |
+
"hidden_dropout_prob": 0.1,
|
| 12 |
+
"hidden_size": 1024,
|
| 13 |
+
"initializer_range": 0.02,
|
| 14 |
+
"intermediate_size": 4096,
|
| 15 |
+
"layer_norm_eps": 1e-05,
|
| 16 |
+
"max_position_embeddings": 8194,
|
| 17 |
+
"model_type": "xlm-roberta",
|
| 18 |
+
"num_attention_heads": 16,
|
| 19 |
+
"num_hidden_layers": 24,
|
| 20 |
+
"output_past": true,
|
| 21 |
+
"pad_token_id": 1,
|
| 22 |
+
"position_embedding_type": "absolute",
|
| 23 |
+
"transformers_version": "4.56.1",
|
| 24 |
+
"type_vocab_size": 1,
|
| 25 |
+
"use_cache": true,
|
| 26 |
+
"vocab_size": 2,
|
| 27 |
+
"max_power": 4,
|
| 28 |
+
"auto_map": {
|
| 29 |
+
"AutoConfig": "configuration_freechunker.FreeChunkerConfig",
|
| 30 |
+
"AutoModel": "modeling_freechunker.FreeChunkerModel"
|
| 31 |
+
}
|
| 32 |
+
}
|
configuration_freechunker.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
| 3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""FreeChunker configuration: Modified from XLM-RoBERTa configuration"""
|
| 17 |
+
|
| 18 |
+
from collections import OrderedDict
|
| 19 |
+
from typing import Mapping
|
| 20 |
+
|
| 21 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 22 |
+
from transformers.onnx import OnnxConfig
|
| 23 |
+
from transformers.utils import logging
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
logger = logging.get_logger(__name__)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class FreeChunkerConfig(PretrainedConfig):
|
| 30 |
+
r"""
|
| 31 |
+
This is the configuration class to store the configuration of a [`FreeChunkerModel`] or a [`TFFreeChunkerModel`]. It
|
| 32 |
+
is used to instantiate a XLM-RoBERTa model according to the specified arguments, defining the model architecture.
|
| 33 |
+
Instantiating a configuration with the defaults will yield a similar configuration to that of the FreeChunker
|
| 34 |
+
[FacebookAI/xlm-roberta-base](https://huggingface.co/FacebookAI/xlm-roberta-base) architecture.
|
| 35 |
+
|
| 36 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 37 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
vocab_size (`int`, *optional*, defaults to 30522):
|
| 42 |
+
Vocabulary size of the XLM-RoBERTa model. Defines the number of different tokens that can be represented by
|
| 43 |
+
the `inputs_ids` passed when calling [`FreeChunekrModel`] or [`TFFreeChunekrModel`].
|
| 44 |
+
hidden_size (`int`, *optional*, defaults to 768):
|
| 45 |
+
Dimensionality of the encoder layers and the pooler layer.
|
| 46 |
+
num_hidden_layers (`int`, *optional*, defaults to 12):
|
| 47 |
+
Number of hidden layers in the Transformer encoder.
|
| 48 |
+
num_attention_heads (`int`, *optional*, defaults to 12):
|
| 49 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 50 |
+
intermediate_size (`int`, *optional*, defaults to 3072):
|
| 51 |
+
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
|
| 52 |
+
hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
|
| 53 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
| 54 |
+
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
| 55 |
+
hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
|
| 56 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
| 57 |
+
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
|
| 58 |
+
The dropout ratio for the attention probabilities.
|
| 59 |
+
max_position_embeddings (`int`, *optional*, defaults to 512):
|
| 60 |
+
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
| 61 |
+
just in case (e.g., 512 or 1024 or 2048).
|
| 62 |
+
type_vocab_size (`int`, *optional*, defaults to 2):
|
| 63 |
+
The vocabulary size of the `token_type_ids` passed when calling [`FreeChunekrModel`] or
|
| 64 |
+
[`TFFreeChunekrModel`].
|
| 65 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 66 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 67 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
| 68 |
+
The epsilon used by the layer normalization layers.
|
| 69 |
+
position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
|
| 70 |
+
Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
|
| 71 |
+
positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
|
| 72 |
+
[Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
|
| 73 |
+
For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
|
| 74 |
+
with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
|
| 75 |
+
is_decoder (`bool`, *optional*, defaults to `False`):
|
| 76 |
+
Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
|
| 77 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
| 78 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
| 79 |
+
relevant if `config.is_decoder=True`.
|
| 80 |
+
classifier_dropout (`float`, *optional*):
|
| 81 |
+
The dropout ratio for the classification head.
|
| 82 |
+
|
| 83 |
+
Examples:
|
| 84 |
+
|
| 85 |
+
```python
|
| 86 |
+
>>> from transformers import FreeChunekrConfig, FreeChunekrModel
|
| 87 |
+
|
| 88 |
+
>>> # Initializing a XLM-RoBERTa FacebookAI/xlm-roberta-base style configuration
|
| 89 |
+
>>> configuration = FreeChunekrConfig()
|
| 90 |
+
|
| 91 |
+
>>> # Initializing a model (with random weights) from the FacebookAI/xlm-roberta-base style configuration
|
| 92 |
+
>>> model = FreeChunekrModel(configuration)
|
| 93 |
+
|
| 94 |
+
>>> # Accessing the model configuration
|
| 95 |
+
>>> configuration = model.config
|
| 96 |
+
```"""
|
| 97 |
+
|
| 98 |
+
model_type = "xlm-roberta"
|
| 99 |
+
|
| 100 |
+
def __init__(
|
| 101 |
+
self,
|
| 102 |
+
vocab_size=30522,
|
| 103 |
+
hidden_size=768,
|
| 104 |
+
num_hidden_layers=12,
|
| 105 |
+
num_attention_heads=12,
|
| 106 |
+
intermediate_size=3072,
|
| 107 |
+
hidden_act="gelu",
|
| 108 |
+
hidden_dropout_prob=0.1,
|
| 109 |
+
attention_probs_dropout_prob=0.1,
|
| 110 |
+
max_position_embeddings=512,
|
| 111 |
+
type_vocab_size=2,
|
| 112 |
+
initializer_range=0.02,
|
| 113 |
+
layer_norm_eps=1e-12,
|
| 114 |
+
pad_token_id=1,
|
| 115 |
+
bos_token_id=0,
|
| 116 |
+
eos_token_id=2,
|
| 117 |
+
position_embedding_type="absolute",
|
| 118 |
+
use_cache=True,
|
| 119 |
+
classifier_dropout=None,
|
| 120 |
+
**kwargs,
|
| 121 |
+
):
|
| 122 |
+
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
| 123 |
+
|
| 124 |
+
self.vocab_size = vocab_size
|
| 125 |
+
self.hidden_size = hidden_size
|
| 126 |
+
self.num_hidden_layers = num_hidden_layers
|
| 127 |
+
self.num_attention_heads = num_attention_heads
|
| 128 |
+
self.hidden_act = hidden_act
|
| 129 |
+
self.intermediate_size = intermediate_size
|
| 130 |
+
self.hidden_dropout_prob = hidden_dropout_prob
|
| 131 |
+
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
| 132 |
+
self.max_position_embeddings = max_position_embeddings
|
| 133 |
+
self.type_vocab_size = type_vocab_size
|
| 134 |
+
self.initializer_range = initializer_range
|
| 135 |
+
self.layer_norm_eps = layer_norm_eps
|
| 136 |
+
self.position_embedding_type = position_embedding_type
|
| 137 |
+
self.use_cache = use_cache
|
| 138 |
+
self.classifier_dropout = classifier_dropout
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
# Copied from transformers.models.roberta.configuration_roberta.RobertaOnnxConfig with Roberta->FreeChunekr
|
| 142 |
+
class FreeChunekrOnnxConfig(OnnxConfig):
|
| 143 |
+
@property
|
| 144 |
+
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
| 145 |
+
if self.task == "multiple-choice":
|
| 146 |
+
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
|
| 147 |
+
else:
|
| 148 |
+
dynamic_axis = {0: "batch", 1: "sequence"}
|
| 149 |
+
return OrderedDict(
|
| 150 |
+
[
|
| 151 |
+
("input_ids", dynamic_axis),
|
| 152 |
+
("attention_mask", dynamic_axis),
|
| 153 |
+
]
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
__all__ = ["FreeChunkerConfig", "FreeChunkerOnnxConfig"]
|
encoder.py
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
UnifiedEncoder - Unified text encoder
|
| 4 |
+
Integrates sentence splitting and multiple encoding models into a unified interface
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pickle
|
| 10 |
+
import os
|
| 11 |
+
from typing import List, Tuple, Union
|
| 12 |
+
from .sentenizer import Sentenceizer
|
| 13 |
+
from .modeling_freechunker import FreeChunkerModel
|
| 14 |
+
from .aggregator import TextAggregator
|
| 15 |
+
|
| 16 |
+
class UnifiedEncoder:
|
| 17 |
+
"""
|
| 18 |
+
Unified text encoder, supporting text sentence splitting and encoding for multiple models
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
def __init__(self, model_name: str, local_model_path: str = None):
|
| 22 |
+
"""
|
| 23 |
+
Initialize unified text encoder
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
model_name (str): Model name (e.g. 'bge-m3', 'jina', 'nomic')
|
| 27 |
+
local_model_path (str, optional): Local model path for loading FreeChunker weights.
|
| 28 |
+
If None, tries to load from current directory or Hugging Face.
|
| 29 |
+
"""
|
| 30 |
+
self.model_name = model_name
|
| 31 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'mps')
|
| 32 |
+
|
| 33 |
+
# Initialize text aggregator
|
| 34 |
+
self.aggregator = TextAggregator()
|
| 35 |
+
|
| 36 |
+
print(f"Initializing unified text encoder, model: {model_name}")
|
| 37 |
+
print(f"Using local model path: {local_model_path}")
|
| 38 |
+
print(f"Using device: {self.device}")
|
| 39 |
+
|
| 40 |
+
# If local_model_path is not provided, assume current directory or let from_pretrained handle it
|
| 41 |
+
if local_model_path is None:
|
| 42 |
+
local_model_path = "."
|
| 43 |
+
|
| 44 |
+
try:
|
| 45 |
+
self.model = FreeChunkerModel.from_pretrained(local_model_path)
|
| 46 |
+
except Exception as e:
|
| 47 |
+
print(f"Failed to load model from {local_model_path}: {e}")
|
| 48 |
+
print("Trying to load as a fresh model or from HF hub if applicable...")
|
| 49 |
+
# Fallback or re-raise
|
| 50 |
+
raise e
|
| 51 |
+
|
| 52 |
+
self.model.to(self.device)
|
| 53 |
+
self.model.eval()
|
| 54 |
+
|
| 55 |
+
# Select model and preprocessor based on model name
|
| 56 |
+
# Predefined model mapping: name -> (local_path, HF_model_ID)
|
| 57 |
+
# Note: Local paths are environment specific, so we primarily rely on HF IDs or passed arguments
|
| 58 |
+
model_configs = {
|
| 59 |
+
'bge-m3': ('/share/home/ecnuzwx/UnifiedRAG/cache/models--BAAI--bge-m3', 'BAAI/bge-m3'),
|
| 60 |
+
'nomic-embed-text-v1.5': ('/share/home/ecnuzwx/UnifiedRAG/cache/models--nomic-ai--nomic-embed-text-v1.5', 'nomic-ai/nomic-embed-text-v1.5'),
|
| 61 |
+
'jina': ('/share/home/ecnuzwx/UnifiedRAG/cache/models--jinaai--jina-embeddings-v2-small-en', 'jinaai/jina-embeddings-v2-small-en')
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
if model_name in model_configs:
|
| 65 |
+
local_path, hf_id = model_configs[model_name]
|
| 66 |
+
# Prioritize local path if it exists, otherwise use HF ID
|
| 67 |
+
if os.path.exists(local_path):
|
| 68 |
+
target_model = local_path
|
| 69 |
+
else:
|
| 70 |
+
target_model = hf_id
|
| 71 |
+
|
| 72 |
+
self.sentenceizer = Sentenceizer(model_name=target_model)
|
| 73 |
+
else:
|
| 74 |
+
# Try using model_name directly as path or ID
|
| 75 |
+
print(f"Unknown predefined model name: {model_name}, trying to load directly...")
|
| 76 |
+
self.sentenceizer = Sentenceizer(model_name=model_name)
|
| 77 |
+
|
| 78 |
+
print("Unified text encoder initialized!")
|
| 79 |
+
|
| 80 |
+
def encode(self, text: str, show_progress: bool = True) -> Tuple[List[str], np.ndarray, List[List[str]]]:
|
| 81 |
+
"""
|
| 82 |
+
Split text and encode, return results grouped by shift_matrix
|
| 83 |
+
|
| 84 |
+
Args:
|
| 85 |
+
text (str): Input text
|
| 86 |
+
show_progress (bool): Whether to show progress
|
| 87 |
+
|
| 88 |
+
Returns:
|
| 89 |
+
Tuple[List[str], np.ndarray, List[List[str]]]: (Original sentence list, encoded vector array, grouped sentence list by shift_matrix)
|
| 90 |
+
"""
|
| 91 |
+
with torch.no_grad():
|
| 92 |
+
sentences, input_embeddings = self.sentenceizer.split_and_encode(text, show_progress=show_progress)
|
| 93 |
+
|
| 94 |
+
if len(sentences) == 0:
|
| 95 |
+
return sentences, np.array([]), []
|
| 96 |
+
if isinstance(input_embeddings, np.ndarray):
|
| 97 |
+
input_embeddings = torch.from_numpy(input_embeddings)
|
| 98 |
+
input_embeddings = input_embeddings.to(self.device)
|
| 99 |
+
inputs_embeds = input_embeddings.unsqueeze(0)
|
| 100 |
+
outputs = self.model(inputs_embeds=inputs_embeds)
|
| 101 |
+
final_embeddings = outputs['embedding']
|
| 102 |
+
shift_matrix = outputs['shift_matrix']
|
| 103 |
+
|
| 104 |
+
# Group sentences using shift_matrix
|
| 105 |
+
sentences = [f"【Begin-{num}】" + sentence + f"【End-{num}】" for num, sentence in enumerate(sentences)]
|
| 106 |
+
grouped_sentences = self._group_sentences_by_shift_matrix(sentences, shift_matrix)
|
| 107 |
+
result_embeddings = final_embeddings.cpu().numpy()
|
| 108 |
+
|
| 109 |
+
return sentences, result_embeddings, grouped_sentences
|
| 110 |
+
|
| 111 |
+
def _group_sentences_by_shift_matrix(self, sentences: List[str], shift_matrix: torch.Tensor) -> List[List[str]]:
|
| 112 |
+
"""
|
| 113 |
+
Group sentences according to shift_matrix (Optimized version)
|
| 114 |
+
|
| 115 |
+
Args:
|
| 116 |
+
sentences (List[str]): Original sentence list
|
| 117 |
+
shift_matrix (torch.Tensor): Mask matrix with shape [num_chunks, seq_len]
|
| 118 |
+
|
| 119 |
+
Returns:
|
| 120 |
+
List[List[str]]: List of sentences grouped by shift_matrix
|
| 121 |
+
"""
|
| 122 |
+
|
| 123 |
+
grouped_sentences = []
|
| 124 |
+
num_chunks, seq_len = shift_matrix.shape
|
| 125 |
+
|
| 126 |
+
for chunk_idx in range(num_chunks):
|
| 127 |
+
chunk_mask = shift_matrix[chunk_idx] # [seq_len]
|
| 128 |
+
|
| 129 |
+
# Use vectorized operation to get all indices that are 1
|
| 130 |
+
valid_indices = (chunk_mask == 1).nonzero(as_tuple=True)[0].cpu().numpy()
|
| 131 |
+
|
| 132 |
+
# Select only indices within the sentence list range
|
| 133 |
+
valid_indices = valid_indices[valid_indices < len(sentences)]
|
| 134 |
+
|
| 135 |
+
if len(valid_indices) > 0:
|
| 136 |
+
# Get sentences directly by index
|
| 137 |
+
chunk_sentences = [sentences[idx] for idx in valid_indices]
|
| 138 |
+
grouped_sentences.append(chunk_sentences)
|
| 139 |
+
|
| 140 |
+
return grouped_sentences
|
| 141 |
+
|
| 142 |
+
def build_vector_store(self, text: str, show_progress: bool = True):
|
| 143 |
+
"""
|
| 144 |
+
Build vector store based on long text
|
| 145 |
+
|
| 146 |
+
Args:
|
| 147 |
+
text (str): Long text
|
| 148 |
+
show_progress (bool): Whether to show progress
|
| 149 |
+
"""
|
| 150 |
+
|
| 151 |
+
sentences, embeddings, grouped_sentences = self.encode(text, show_progress)
|
| 152 |
+
|
| 153 |
+
# grouped_texts = [" ".join(group) if isinstance(group, list) else str(group) for group in grouped_sentences]
|
| 154 |
+
|
| 155 |
+
grouped_texts = sentences + [" ".join(group) if isinstance(group, list) else str(group) for group in grouped_sentences]
|
| 156 |
+
|
| 157 |
+
self.vector_store = {
|
| 158 |
+
'sentences': sentences, # Keep original sentences for debugging
|
| 159 |
+
'embeddings': embeddings, # embeddings correspond to grouped_sentences
|
| 160 |
+
'grouped_sentences': grouped_sentences, # Original grouping structure
|
| 161 |
+
'grouped_texts': grouped_texts # Text for retrieval
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
if show_progress:
|
| 165 |
+
print(f"Vector store built: {len(sentences)} original sentences, {len(grouped_sentences)} groups, {len(embeddings)} embedding vectors")
|
| 166 |
+
print(f"Vector store verification: embeddings.shape={embeddings.shape}, grouped_texts count={len(grouped_texts)}\n")
|
| 167 |
+
|
| 168 |
+
def query(self, query: str, top_k: int = 5, aggregation_mode: str = 'post', tokenizer=None) -> Union[List[Tuple[str, float]], str]:
|
| 169 |
+
"""
|
| 170 |
+
Query vector store
|
| 171 |
+
|
| 172 |
+
Args:
|
| 173 |
+
query (str): Query text
|
| 174 |
+
top_k (int): Return top k most similar results
|
| 175 |
+
aggregation_mode (str): Aggregation mode
|
| 176 |
+
- 'none': No aggregation, return top_k results directly [(text, score), ...]
|
| 177 |
+
- 'post': Post-aggregation mode, return aggregated text string
|
| 178 |
+
|
| 179 |
+
Returns:
|
| 180 |
+
Union[List[Tuple[str, float]], str]:
|
| 181 |
+
- If aggregation_mode='none', return [(sentence, similarity_score), ...]
|
| 182 |
+
- If aggregation_mode='post', return aggregated string
|
| 183 |
+
"""
|
| 184 |
+
if not hasattr(self, 'vector_store'):
|
| 185 |
+
raise ValueError("Vector store not built, please call build_vector_store method first")
|
| 186 |
+
|
| 187 |
+
# Encode query text
|
| 188 |
+
query_embeddings = self.sentenceizer.encode([query])
|
| 189 |
+
query_embedding = query_embeddings[0]
|
| 190 |
+
|
| 191 |
+
# Calculate cosine similarity
|
| 192 |
+
similarities = np.dot(self.vector_store['embeddings'], query_embedding)
|
| 193 |
+
|
| 194 |
+
# Sort (descending)
|
| 195 |
+
sorted_indices = np.argsort(similarities)[::-1]
|
| 196 |
+
|
| 197 |
+
if aggregation_mode == 'none':
|
| 198 |
+
return self._get_direct_results(sorted_indices, similarities, top_k)
|
| 199 |
+
elif aggregation_mode == 'post':
|
| 200 |
+
return self._post_aggregation(sorted_indices, similarities, top_k, tokenizer=tokenizer)
|
| 201 |
+
else:
|
| 202 |
+
print(f"Warning: Unknown aggregation_mode '{aggregation_mode}', falling back to 'none'")
|
| 203 |
+
return self._get_direct_results(sorted_indices, similarities, top_k)
|
| 204 |
+
|
| 205 |
+
def _get_direct_results(self, sorted_indices: np.ndarray, similarities: np.ndarray, top_k: int) -> List[Tuple[str, float]]:
|
| 206 |
+
|
| 207 |
+
available_count = len(self.vector_store['grouped_texts'])
|
| 208 |
+
actual_top_k = min(top_k, available_count)
|
| 209 |
+
top_indices = sorted_indices[:actual_top_k]
|
| 210 |
+
|
| 211 |
+
results = []
|
| 212 |
+
for idx in top_indices:
|
| 213 |
+
if idx < len(self.vector_store['grouped_texts']):
|
| 214 |
+
grouped_text = self.vector_store['grouped_texts'][idx]
|
| 215 |
+
score = similarities[idx]
|
| 216 |
+
results.append((grouped_text, float(score)))
|
| 217 |
+
|
| 218 |
+
return results
|
| 219 |
+
|
| 220 |
+
def _post_aggregation(self, sorted_indices: np.ndarray, similarities: np.ndarray, top_k: int, tokenizer=None) -> List[Tuple[str, float]]:
|
| 221 |
+
|
| 222 |
+
# Get top_k results first
|
| 223 |
+
direct_results = self._get_direct_results(sorted_indices, similarities, top_k)
|
| 224 |
+
|
| 225 |
+
# Extract text parts for aggregation
|
| 226 |
+
texts = [text for text, score in direct_results]
|
| 227 |
+
|
| 228 |
+
aggregated_texts = self.aggregator.aggregate_segments(texts)
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
return aggregated_texts
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def load_vector_store(self, file_path: str):
|
| 235 |
+
"""
|
| 236 |
+
Load vector store from file
|
| 237 |
+
|
| 238 |
+
Args:
|
| 239 |
+
file_path (str): Vector store file path
|
| 240 |
+
"""
|
| 241 |
+
if not os.path.exists(file_path):
|
| 242 |
+
raise FileNotFoundError(f"Vector store file not found: {file_path}")
|
| 243 |
+
|
| 244 |
+
with open(file_path, 'rb') as f:
|
| 245 |
+
self.vector_store = pickle.load(f)
|
| 246 |
+
|
| 247 |
+
print(f"Vector store loaded from {file_path}")
|
| 248 |
+
print(f"Vector store info: {len(self.vector_store['grouped_texts'])} groups, embedding dimension: {self.vector_store['embeddings'].shape}")
|
| 249 |
+
|
| 250 |
+
def has_vector_store(self) -> bool:
|
| 251 |
+
"""
|
| 252 |
+
Check if vector store is built or loaded
|
| 253 |
+
|
| 254 |
+
Returns:
|
| 255 |
+
bool: Whether a vector store is available
|
| 256 |
+
"""
|
| 257 |
+
return hasattr(self, 'vector_store') and self.vector_store is not None
|
final_loss_curve.png
ADDED
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:aca47fe33b4f8d4b507ac46c60817fc9287a1b81d63c0ad06559196d64c9a30d
|
| 3 |
+
size 1247063776
|
modeling_freechunker.py
ADDED
|
@@ -0,0 +1,768 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2019 Facebook AI Research and the HuggingFace Inc. team.
|
| 3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""FreeChunker model: Modified from PyTorch XLM-RoBERTa model."""
|
| 17 |
+
from .utils import generate_shifted_matrix
|
| 18 |
+
import math
|
| 19 |
+
from typing import Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.utils.checkpoint
|
| 23 |
+
from packaging import version
|
| 24 |
+
from torch import nn
|
| 25 |
+
from transformers.activations import ACT2FN
|
| 26 |
+
from transformers.modeling_outputs import (
|
| 27 |
+
BaseModelOutputWithPoolingAndCrossAttentions
|
| 28 |
+
)
|
| 29 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 30 |
+
from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
| 31 |
+
from transformers.utils import (
|
| 32 |
+
add_code_sample_docstrings,
|
| 33 |
+
add_start_docstrings,
|
| 34 |
+
add_start_docstrings_to_model_forward,
|
| 35 |
+
get_torch_version,
|
| 36 |
+
logging
|
| 37 |
+
)
|
| 38 |
+
from .configuration_freechunker import FreeChunkerConfig
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
logger = logging.get_logger(__name__)
|
| 42 |
+
|
| 43 |
+
_CHECKPOINT_FOR_DOC = "FacebookAI/xlm-roberta-base"
|
| 44 |
+
_CONFIG_FOR_DOC = "FreeChunkerConfig"
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings with Roberta->FreeChunker
|
| 48 |
+
class FreeChunkerEmbeddings(nn.Module):
|
| 49 |
+
"""
|
| 50 |
+
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__
|
| 54 |
+
def __init__(self, config):
|
| 55 |
+
super().__init__()
|
| 56 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
| 57 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
| 58 |
+
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
|
| 59 |
+
|
| 60 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
| 61 |
+
# any TensorFlow checkpoint file
|
| 62 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 63 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 64 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
| 65 |
+
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
| 66 |
+
self.register_buffer(
|
| 67 |
+
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
|
| 68 |
+
)
|
| 69 |
+
self.register_buffer(
|
| 70 |
+
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
# End copy
|
| 74 |
+
self.padding_idx = config.pad_token_id
|
| 75 |
+
self.position_embeddings = nn.Embedding(
|
| 76 |
+
config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
def forward(
|
| 80 |
+
self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None
|
| 81 |
+
):
|
| 82 |
+
if position_ids is None:
|
| 83 |
+
if input_ids is not None:
|
| 84 |
+
# Create the position ids from the input token ids. Any padded tokens remain padded.
|
| 85 |
+
position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx)
|
| 86 |
+
else:
|
| 87 |
+
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
|
| 88 |
+
|
| 89 |
+
if input_ids is not None:
|
| 90 |
+
input_shape = input_ids.size()
|
| 91 |
+
else:
|
| 92 |
+
input_shape = inputs_embeds.size()[:-1]
|
| 93 |
+
|
| 94 |
+
seq_length = input_shape[1]
|
| 95 |
+
|
| 96 |
+
if position_ids is None:
|
| 97 |
+
position_ids = torch.arange(seq_length, dtype=torch.long, device=self.position_ids.device)
|
| 98 |
+
position_ids = position_ids.unsqueeze(0).expand(input_shape)
|
| 99 |
+
|
| 100 |
+
if token_type_ids is None:
|
| 101 |
+
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
|
| 102 |
+
|
| 103 |
+
if inputs_embeds is None:
|
| 104 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
| 105 |
+
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
| 106 |
+
|
| 107 |
+
embeddings = inputs_embeds + token_type_embeddings
|
| 108 |
+
if self.position_embedding_type == "absolute":
|
| 109 |
+
position_embeddings = self.position_embeddings(position_ids)
|
| 110 |
+
embeddings += position_embeddings
|
| 111 |
+
embeddings = self.LayerNorm(embeddings)
|
| 112 |
+
embeddings = self.dropout(embeddings)
|
| 113 |
+
return embeddings
|
| 114 |
+
|
| 115 |
+
def create_position_ids_from_inputs_embeds(self, inputs_embeds):
|
| 116 |
+
"""
|
| 117 |
+
We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
|
| 118 |
+
|
| 119 |
+
Args:
|
| 120 |
+
inputs_embeds: torch.Tensor
|
| 121 |
+
|
| 122 |
+
Returns: torch.Tensor
|
| 123 |
+
"""
|
| 124 |
+
input_shape = inputs_embeds.size()[:-1]
|
| 125 |
+
sequence_length = input_shape[1]
|
| 126 |
+
|
| 127 |
+
position_ids = torch.arange(
|
| 128 |
+
self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
|
| 129 |
+
)
|
| 130 |
+
return position_ids.unsqueeze(0).expand(input_shape)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->FreeChunker
|
| 134 |
+
class FreeChunkerSelfAttention(nn.Module):
|
| 135 |
+
def __init__(self, config, position_embedding_type=None):
|
| 136 |
+
super().__init__()
|
| 137 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
| 138 |
+
raise ValueError(
|
| 139 |
+
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
|
| 140 |
+
f"heads ({config.num_attention_heads})"
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
self.num_attention_heads = config.num_attention_heads
|
| 144 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
| 145 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
| 146 |
+
|
| 147 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
| 148 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
| 149 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
| 150 |
+
|
| 151 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
| 152 |
+
self.position_embedding_type = position_embedding_type or getattr(
|
| 153 |
+
config, "position_embedding_type", "absolute"
|
| 154 |
+
)
|
| 155 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
| 156 |
+
self.max_position_embeddings = config.max_position_embeddings
|
| 157 |
+
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
| 158 |
+
|
| 159 |
+
self.is_decoder = config.is_decoder
|
| 160 |
+
|
| 161 |
+
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
| 162 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
| 163 |
+
x = x.view(new_x_shape)
|
| 164 |
+
return x.permute(0, 2, 1, 3)
|
| 165 |
+
|
| 166 |
+
def forward(
|
| 167 |
+
self,
|
| 168 |
+
hidden_states: torch.Tensor,
|
| 169 |
+
hidden_states2: torch.Tensor, # Second input stream, required parameter
|
| 170 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 171 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 172 |
+
output_attentions: Optional[bool] = False,
|
| 173 |
+
) -> Tuple[torch.Tensor]:
|
| 174 |
+
# Query comes from hidden_states
|
| 175 |
+
mixed_query_layer = self.query(hidden_states)
|
| 176 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
| 177 |
+
|
| 178 |
+
# Key and Value come from hidden_states2
|
| 179 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states2))
|
| 180 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states2))
|
| 181 |
+
|
| 182 |
+
# Calculate attention scores
|
| 183 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
| 184 |
+
|
| 185 |
+
# Modified positional encoding handling
|
| 186 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
| 187 |
+
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
|
| 188 |
+
|
| 189 |
+
# hidden_states positions are all the first position (0, 0, 0, ...)
|
| 190 |
+
position_ids_l = torch.zeros(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
| 191 |
+
# hidden_states2 uses normal incremental position sequence (0, 1, 2, 3, ...)
|
| 192 |
+
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
| 193 |
+
distance = position_ids_l - position_ids_r
|
| 194 |
+
|
| 195 |
+
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
| 196 |
+
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
| 197 |
+
|
| 198 |
+
if self.position_embedding_type == "relative_key":
|
| 199 |
+
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
| 200 |
+
attention_scores = attention_scores + relative_position_scores
|
| 201 |
+
elif self.position_embedding_type == "relative_key_query":
|
| 202 |
+
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
| 203 |
+
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
| 204 |
+
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
| 205 |
+
|
| 206 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
| 207 |
+
|
| 208 |
+
if attention_mask is not None:
|
| 209 |
+
attention_scores = attention_scores + attention_mask
|
| 210 |
+
|
| 211 |
+
# Normalize to probabilities
|
| 212 |
+
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
| 213 |
+
attention_probs = self.dropout(attention_probs)
|
| 214 |
+
|
| 215 |
+
# Apply head mask
|
| 216 |
+
if head_mask is not None:
|
| 217 |
+
attention_probs = attention_probs * head_mask
|
| 218 |
+
|
| 219 |
+
# Calculate context
|
| 220 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
| 221 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
| 222 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
| 223 |
+
context_layer = context_layer.view(new_context_layer_shape)
|
| 224 |
+
|
| 225 |
+
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
| 226 |
+
return outputs
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
# Copied from transformers.models.roberta.modeling_roberta.RobertaSdpaSelfAttention with Roberta->FreeChunker
|
| 230 |
+
class FreeChunkerSdpaSelfAttention(FreeChunkerSelfAttention):
|
| 231 |
+
def __init__(self, config, position_embedding_type=None):
|
| 232 |
+
super().__init__(config, position_embedding_type=position_embedding_type)
|
| 233 |
+
self.dropout_prob = config.attention_probs_dropout_prob
|
| 234 |
+
self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")
|
| 235 |
+
|
| 236 |
+
def forward(
|
| 237 |
+
self,
|
| 238 |
+
hidden_states: torch.Tensor,
|
| 239 |
+
hidden_states2: torch.Tensor, # Second input stream, required parameter
|
| 240 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 241 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 242 |
+
output_attentions: Optional[bool] = False,
|
| 243 |
+
) -> Tuple[torch.Tensor]:
|
| 244 |
+
# If relative positional encoding, output attentions, or head mask are present, fallback to parent implementation
|
| 245 |
+
if (self.position_embedding_type != "absolute" or
|
| 246 |
+
output_attentions or
|
| 247 |
+
head_mask is not None):
|
| 248 |
+
return super().forward(
|
| 249 |
+
hidden_states,
|
| 250 |
+
hidden_states2,
|
| 251 |
+
attention_mask,
|
| 252 |
+
head_mask,
|
| 253 |
+
output_attentions,
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
# Use optimized implementation of SDPA
|
| 257 |
+
bsz, tgt_len, _ = hidden_states.size()
|
| 258 |
+
|
| 259 |
+
query_layer = self.transpose_for_scores(self.query(hidden_states))
|
| 260 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states2))
|
| 261 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states2))
|
| 262 |
+
|
| 263 |
+
# SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
|
| 264 |
+
# attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
|
| 265 |
+
# Reference: https://github.com/pytorch/pytorch/issues/112577
|
| 266 |
+
if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
|
| 267 |
+
query_layer = query_layer.contiguous()
|
| 268 |
+
key_layer = key_layer.contiguous()
|
| 269 |
+
value_layer = value_layer.contiguous()
|
| 270 |
+
|
| 271 |
+
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
| 272 |
+
query_layer,
|
| 273 |
+
key_layer,
|
| 274 |
+
value_layer,
|
| 275 |
+
attn_mask=attention_mask,
|
| 276 |
+
dropout_p=self.dropout_prob if self.training else 0.0,
|
| 277 |
+
is_causal=False, # For customized tasks, causal mask is not used
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
attn_output = attn_output.transpose(1, 2)
|
| 281 |
+
attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)
|
| 282 |
+
|
| 283 |
+
outputs = (attn_output,)
|
| 284 |
+
return outputs
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfOutput with Roberta->FreeChunker
|
| 288 |
+
class FreeChunkerSelfOutput(nn.Module):
|
| 289 |
+
def __init__(self, config):
|
| 290 |
+
super().__init__()
|
| 291 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 292 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 293 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 294 |
+
|
| 295 |
+
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
| 296 |
+
hidden_states = self.dense(hidden_states)
|
| 297 |
+
hidden_states = self.dropout(hidden_states)
|
| 298 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
| 299 |
+
return hidden_states
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
XLM_ROBERTA_SELF_ATTENTION_CLASSES = {
|
| 303 |
+
"eager": FreeChunkerSelfAttention,
|
| 304 |
+
"sdpa": FreeChunkerSdpaSelfAttention,
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
# Copied from transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->FreeChunker
|
| 309 |
+
class FreeChunkerAttention(nn.Module):
|
| 310 |
+
def __init__(self, config, position_embedding_type=None):
|
| 311 |
+
super().__init__()
|
| 312 |
+
self.self = XLM_ROBERTA_SELF_ATTENTION_CLASSES[config._attn_implementation](
|
| 313 |
+
config, position_embedding_type=position_embedding_type
|
| 314 |
+
)
|
| 315 |
+
self.output = FreeChunkerSelfOutput(config)
|
| 316 |
+
self.pruned_heads = set()
|
| 317 |
+
|
| 318 |
+
def prune_heads(self, heads):
|
| 319 |
+
if len(heads) == 0:
|
| 320 |
+
return
|
| 321 |
+
heads, index = find_pruneable_heads_and_indices(
|
| 322 |
+
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
# Prune linear layers
|
| 326 |
+
self.self.query = prune_linear_layer(self.self.query, index)
|
| 327 |
+
self.self.key = prune_linear_layer(self.self.key, index)
|
| 328 |
+
self.self.value = prune_linear_layer(self.self.value, index)
|
| 329 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
| 330 |
+
|
| 331 |
+
# Update hyper params and store pruned heads
|
| 332 |
+
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
| 333 |
+
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
| 334 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
| 335 |
+
|
| 336 |
+
def forward(
|
| 337 |
+
self,
|
| 338 |
+
hidden_states: torch.Tensor,
|
| 339 |
+
hidden_states2: torch.Tensor, # Second input stream, required parameter
|
| 340 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 341 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 342 |
+
output_attentions: Optional[bool] = False,
|
| 343 |
+
) -> Tuple[torch.Tensor]:
|
| 344 |
+
self_outputs = self.self(
|
| 345 |
+
hidden_states,
|
| 346 |
+
hidden_states2, # Pass second input stream
|
| 347 |
+
attention_mask,
|
| 348 |
+
head_mask,
|
| 349 |
+
output_attentions,
|
| 350 |
+
)
|
| 351 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
| 352 |
+
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
| 353 |
+
return outputs
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
# Copied from transformers.models.roberta.modeling_roberta.RobertaIntermediate with Roberta->FreeChunker
|
| 357 |
+
class FreeChunkerIntermediate(nn.Module):
|
| 358 |
+
def __init__(self, config):
|
| 359 |
+
super().__init__()
|
| 360 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
| 361 |
+
if isinstance(config.hidden_act, str):
|
| 362 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
| 363 |
+
else:
|
| 364 |
+
self.intermediate_act_fn = config.hidden_act
|
| 365 |
+
|
| 366 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 367 |
+
hidden_states = self.dense(hidden_states)
|
| 368 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
| 369 |
+
return hidden_states
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
# Copied from transformers.models.roberta.modeling_roberta.RobertaOutput with Roberta->FreeChunker
|
| 373 |
+
class FreeChunkerOutput(nn.Module):
|
| 374 |
+
def __init__(self, config):
|
| 375 |
+
super().__init__()
|
| 376 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
| 377 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 378 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 379 |
+
|
| 380 |
+
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
| 381 |
+
hidden_states = self.dense(hidden_states)
|
| 382 |
+
hidden_states = self.dropout(hidden_states)
|
| 383 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
| 384 |
+
return hidden_states
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
# Copied from transformers.models.roberta.modeling_roberta.RobertaLayer with Roberta->FreeChunker
|
| 388 |
+
class FreeChunkerLayer(nn.Module):
|
| 389 |
+
def __init__(self, config):
|
| 390 |
+
super().__init__()
|
| 391 |
+
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
| 392 |
+
self.seq_len_dim = 1
|
| 393 |
+
self.attention = FreeChunkerAttention(config)
|
| 394 |
+
self.is_decoder = config.is_decoder
|
| 395 |
+
self.add_cross_attention = config.add_cross_attention
|
| 396 |
+
if self.add_cross_attention:
|
| 397 |
+
if not self.is_decoder:
|
| 398 |
+
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
|
| 399 |
+
self.crossattention = FreeChunkerAttention(config, position_embedding_type="absolute")
|
| 400 |
+
self.intermediate = FreeChunkerIntermediate(config)
|
| 401 |
+
self.output = FreeChunkerOutput(config)
|
| 402 |
+
|
| 403 |
+
def forward(
|
| 404 |
+
self,
|
| 405 |
+
hidden_states: torch.Tensor,
|
| 406 |
+
hidden_states2: torch.Tensor, # Second input stream, required parameter
|
| 407 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 408 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 409 |
+
output_attentions: Optional[bool] = False,
|
| 410 |
+
) -> Tuple[torch.Tensor]:
|
| 411 |
+
attention_outputs = self.attention(
|
| 412 |
+
hidden_states,
|
| 413 |
+
hidden_states2, # Pass second input stream
|
| 414 |
+
attention_mask,
|
| 415 |
+
head_mask,
|
| 416 |
+
output_attentions,
|
| 417 |
+
)
|
| 418 |
+
attention_output = attention_outputs[0]
|
| 419 |
+
|
| 420 |
+
outputs = attention_outputs[1:] # add self attentions if we output attention weights
|
| 421 |
+
|
| 422 |
+
layer_output = self.feed_forward_chunk(attention_output)
|
| 423 |
+
outputs = (layer_output,) + outputs
|
| 424 |
+
|
| 425 |
+
return outputs
|
| 426 |
+
|
| 427 |
+
def feed_forward_chunk(self, attention_output):
|
| 428 |
+
intermediate_output = self.intermediate(attention_output)
|
| 429 |
+
layer_output = self.output(intermediate_output, attention_output)
|
| 430 |
+
return layer_output
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
# Copied from transformers.models.roberta.modeling_roberta.RobertaEncoder with Roberta->FreeChunker
|
| 434 |
+
class FreeChunkerEncoder(nn.Module):
|
| 435 |
+
def __init__(self, config):
|
| 436 |
+
super().__init__()
|
| 437 |
+
self.config = config
|
| 438 |
+
self.layer = nn.ModuleList([FreeChunkerLayer(config) for _ in range(config.num_hidden_layers)])
|
| 439 |
+
self.gradient_checkpointing = False
|
| 440 |
+
|
| 441 |
+
def forward(
|
| 442 |
+
self,
|
| 443 |
+
hidden_states: torch.Tensor,
|
| 444 |
+
hidden_states2: torch.Tensor, # Second input stream, required parameter
|
| 445 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 446 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 447 |
+
) -> torch.Tensor:
|
| 448 |
+
|
| 449 |
+
for i, layer_module in enumerate(self.layer):
|
| 450 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
| 451 |
+
|
| 452 |
+
if self.gradient_checkpointing and self.training:
|
| 453 |
+
|
| 454 |
+
def create_custom_forward(module):
|
| 455 |
+
def custom_forward(*inputs):
|
| 456 |
+
return module(*inputs)
|
| 457 |
+
|
| 458 |
+
return custom_forward
|
| 459 |
+
|
| 460 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
| 461 |
+
create_custom_forward(layer_module),
|
| 462 |
+
hidden_states,
|
| 463 |
+
hidden_states2, # Pass second input stream
|
| 464 |
+
attention_mask,
|
| 465 |
+
layer_head_mask,
|
| 466 |
+
)
|
| 467 |
+
else:
|
| 468 |
+
layer_outputs = layer_module(
|
| 469 |
+
hidden_states,
|
| 470 |
+
hidden_states2, # Pass second input stream
|
| 471 |
+
attention_mask,
|
| 472 |
+
layer_head_mask,
|
| 473 |
+
)
|
| 474 |
+
|
| 475 |
+
hidden_states = layer_outputs[0]
|
| 476 |
+
|
| 477 |
+
return hidden_states
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
# Copied from transformers.models.roberta.modeling_roberta.RobertaPooler with Roberta->FreeChunker
|
| 481 |
+
class FreeChunkerPooler(nn.Module):
|
| 482 |
+
def __init__(self, config):
|
| 483 |
+
super().__init__()
|
| 484 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 485 |
+
self.activation = nn.Tanh()
|
| 486 |
+
|
| 487 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 488 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
| 489 |
+
# to the first token.
|
| 490 |
+
first_token_tensor = hidden_states[:, 0]
|
| 491 |
+
pooled_output = self.dense(first_token_tensor)
|
| 492 |
+
pooled_output = self.activation(pooled_output)
|
| 493 |
+
return pooled_output
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
# Copied from transformers.models.roberta.modeling_roberta.RobertaPreTrainedModel with Roberta->FreeChunker
|
| 497 |
+
class FreeChunkerPreTrainedModel(PreTrainedModel):
|
| 498 |
+
"""
|
| 499 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 500 |
+
models.
|
| 501 |
+
"""
|
| 502 |
+
|
| 503 |
+
config_class = FreeChunkerConfig
|
| 504 |
+
base_model_prefix = "roberta"
|
| 505 |
+
supports_gradient_checkpointing = True
|
| 506 |
+
_no_split_modules = ["FreeChunkerEmbeddings", "FreeChunkerSelfAttention", "FreeChunkerSdpaSelfAttention"]
|
| 507 |
+
_supports_sdpa = True
|
| 508 |
+
|
| 509 |
+
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
|
| 510 |
+
def _init_weights(self, module):
|
| 511 |
+
"""Initialize the weights"""
|
| 512 |
+
if isinstance(module, nn.Linear):
|
| 513 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
| 514 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
| 515 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 516 |
+
if module.bias is not None:
|
| 517 |
+
module.bias.data.zero_()
|
| 518 |
+
elif isinstance(module, nn.Embedding):
|
| 519 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 520 |
+
if module.padding_idx is not None:
|
| 521 |
+
module.weight.data[module.padding_idx].zero_()
|
| 522 |
+
elif isinstance(module, nn.LayerNorm):
|
| 523 |
+
module.bias.data.zero_()
|
| 524 |
+
module.weight.data.fill_(1.0)
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
XLM_ROBERTA_START_DOCSTRING = r"""
|
| 528 |
+
|
| 529 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 530 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 531 |
+
etc.)
|
| 532 |
+
|
| 533 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 534 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 535 |
+
and behavior.
|
| 536 |
+
|
| 537 |
+
Parameters:
|
| 538 |
+
config ([`FreeChunkerConfig`]): Model configuration class with all the parameters of the
|
| 539 |
+
model. Initializing with a config file does not load the weights associated with the model, only the
|
| 540 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 541 |
+
"""
|
| 542 |
+
|
| 543 |
+
XLM_ROBERTA_INPUTS_DOCSTRING = r"""
|
| 544 |
+
Args:
|
| 545 |
+
input_ids (`torch.LongTensor` of shape `({0})`):
|
| 546 |
+
Indices of input sequence tokens in the vocabulary.
|
| 547 |
+
|
| 548 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 549 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 550 |
+
|
| 551 |
+
[What are input IDs?](../glossary#input-ids)
|
| 552 |
+
attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
|
| 553 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 554 |
+
|
| 555 |
+
- 1 for tokens that are **not masked**,
|
| 556 |
+
- 0 for tokens that are **masked**.
|
| 557 |
+
|
| 558 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 559 |
+
token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
|
| 560 |
+
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
|
| 561 |
+
1]`:
|
| 562 |
+
|
| 563 |
+
- 0 corresponds to a *sentence A* token,
|
| 564 |
+
- 1 corresponds to a *sentence B* token.
|
| 565 |
+
|
| 566 |
+
[What are token type IDs?](../glossary#token-type-ids)
|
| 567 |
+
position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
|
| 568 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 569 |
+
config.max_position_embeddings - 1]`.
|
| 570 |
+
|
| 571 |
+
[What are position IDs?](../glossary#position-ids)
|
| 572 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
| 573 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
| 574 |
+
|
| 575 |
+
- 1 indicates the head is **not masked**,
|
| 576 |
+
- 0 indicates the head is **masked**.
|
| 577 |
+
|
| 578 |
+
inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
|
| 579 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
| 580 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
| 581 |
+
model's internal embedding lookup matrix.
|
| 582 |
+
output_attentions (`bool`, *optional*):
|
| 583 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 584 |
+
tensors for more detail.
|
| 585 |
+
output_hidden_states (`bool`, *optional*):
|
| 586 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 587 |
+
more detail.
|
| 588 |
+
return_dict (`bool`, *optional*):
|
| 589 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 590 |
+
"""
|
| 591 |
+
|
| 592 |
+
|
| 593 |
+
@add_start_docstrings(
|
| 594 |
+
"The bare XLM-RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
|
| 595 |
+
XLM_ROBERTA_START_DOCSTRING,
|
| 596 |
+
)
|
| 597 |
+
# Copied from transformers.models.roberta.modeling_roberta.RobertaModel with Roberta->FreeChunker, ROBERTA->XLM_ROBERTA
|
| 598 |
+
class FreeChunkerModel(FreeChunkerPreTrainedModel):
|
| 599 |
+
"""
|
| 600 |
+
|
| 601 |
+
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
| 602 |
+
cross-attention is added between the self-attention layers, following the architecture described in [Attention is
|
| 603 |
+
all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
|
| 604 |
+
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
| 605 |
+
|
| 606 |
+
To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
|
| 607 |
+
to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
|
| 608 |
+
`add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
|
| 609 |
+
"""
|
| 610 |
+
|
| 611 |
+
_no_split_modules = ["FreeChunkerEmbeddings", "FreeChunkerLayer"]
|
| 612 |
+
|
| 613 |
+
def __init__(self, config, add_pooling_layer=True):
|
| 614 |
+
super().__init__(config)
|
| 615 |
+
self.config = config
|
| 616 |
+
self.config.vocab_size = 2
|
| 617 |
+
self.embeddings = FreeChunkerEmbeddings(self.config)
|
| 618 |
+
self.encoder = FreeChunkerEncoder(config)
|
| 619 |
+
|
| 620 |
+
self.pooler = FreeChunkerPooler(config) if add_pooling_layer else None
|
| 621 |
+
|
| 622 |
+
self.attn_implementation = config._attn_implementation
|
| 623 |
+
self.position_embedding_type = config.position_embedding_type
|
| 624 |
+
|
| 625 |
+
# Initialize weights and apply final processing
|
| 626 |
+
self.post_init()
|
| 627 |
+
|
| 628 |
+
def get_input_embeddings(self):
|
| 629 |
+
return self.embeddings.word_embeddings
|
| 630 |
+
|
| 631 |
+
def set_input_embeddings(self, value):
|
| 632 |
+
self.embeddings.word_embeddings = value
|
| 633 |
+
|
| 634 |
+
def _prune_heads(self, heads_to_prune):
|
| 635 |
+
"""
|
| 636 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
| 637 |
+
class PreTrainedModel
|
| 638 |
+
"""
|
| 639 |
+
for layer, heads in heads_to_prune.items():
|
| 640 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
| 641 |
+
|
| 642 |
+
@add_start_docstrings_to_model_forward(XLM_ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 643 |
+
@add_code_sample_docstrings(
|
| 644 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 645 |
+
output_type=BaseModelOutputWithPoolingAndCrossAttentions,
|
| 646 |
+
config_class=_CONFIG_FOR_DOC,
|
| 647 |
+
)
|
| 648 |
+
def forward(
|
| 649 |
+
self,
|
| 650 |
+
inputs_embeds=None,
|
| 651 |
+
labels=None,
|
| 652 |
+
loss_weights: bool = False,
|
| 653 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 654 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 655 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 656 |
+
encoder_attention_mask: Optional[torch.Tensor] = None
|
| 657 |
+
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
|
| 658 |
+
|
| 659 |
+
# Get input device
|
| 660 |
+
input_device = inputs_embeds.device
|
| 661 |
+
|
| 662 |
+
# Dimension adaptation: if input dimension is less than 1024, pad with 0
|
| 663 |
+
original_hidden_size = inputs_embeds.shape[-1]
|
| 664 |
+
target_hidden_size = self.config.hidden_size # 1024
|
| 665 |
+
|
| 666 |
+
if original_hidden_size < target_hidden_size:
|
| 667 |
+
# Calculate number of dimensions to pad
|
| 668 |
+
padding_size = target_hidden_size - original_hidden_size
|
| 669 |
+
# Pad with 0 on the last dimension
|
| 670 |
+
padding = torch.zeros(inputs_embeds.shape[:-1] + (padding_size,),
|
| 671 |
+
device=input_device, dtype=inputs_embeds.dtype)
|
| 672 |
+
inputs_embeds = torch.cat([inputs_embeds, padding], dim=-1)
|
| 673 |
+
|
| 674 |
+
# Adjust max_power based on sequence length
|
| 675 |
+
sequence_length = inputs_embeds.shape[1]
|
| 676 |
+
|
| 677 |
+
shifted_matrix = generate_shifted_matrix(sequence_length, device=input_device)
|
| 678 |
+
|
| 679 |
+
# Generate attention mask
|
| 680 |
+
encoder_attention_mask = shifted_matrix.transpose(1, 2)
|
| 681 |
+
encoder_attention_mask = torch.where(encoder_attention_mask == 1.0, 0.0, float('-inf'))[:, None, :, :]
|
| 682 |
+
|
| 683 |
+
# Fixed input IDs and position IDs
|
| 684 |
+
input_ids = torch.tensor([[0] * shifted_matrix.shape[2]], device=input_device)
|
| 685 |
+
position_ids = torch.tensor([[0] * shifted_matrix.shape[2]], device=input_device)
|
| 686 |
+
|
| 687 |
+
# Embedding layer processing
|
| 688 |
+
embedding_output = self.embeddings(
|
| 689 |
+
input_ids=input_ids,
|
| 690 |
+
position_ids=position_ids,
|
| 691 |
+
token_type_ids=None,
|
| 692 |
+
)
|
| 693 |
+
|
| 694 |
+
# Set second input stream
|
| 695 |
+
encoder_hidden_states = inputs_embeds
|
| 696 |
+
|
| 697 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
| 698 |
+
|
| 699 |
+
# Encoder processing
|
| 700 |
+
sequence_output = self.encoder(
|
| 701 |
+
embedding_output,
|
| 702 |
+
hidden_states2=encoder_hidden_states, # Second input stream
|
| 703 |
+
attention_mask=encoder_attention_mask, # Use generated mask
|
| 704 |
+
head_mask=head_mask,
|
| 705 |
+
)
|
| 706 |
+
|
| 707 |
+
if original_hidden_size < target_hidden_size:
|
| 708 |
+
|
| 709 |
+
sequence_output = sequence_output[..., :original_hidden_size]
|
| 710 |
+
# Also truncate inputs_embeds back to original size to match dimensions of sequence_output
|
| 711 |
+
inputs_embeds = inputs_embeds[..., :original_hidden_size]
|
| 712 |
+
|
| 713 |
+
shift_matrix = shifted_matrix.transpose(1, 2).squeeze(0)
|
| 714 |
+
# Loss calculation
|
| 715 |
+
loss = None
|
| 716 |
+
if labels is not None:
|
| 717 |
+
emb = sequence_output.view(-1, sequence_output.shape[-1])
|
| 718 |
+
lab = labels.view(-1, labels.shape[-1])
|
| 719 |
+
target = torch.ones(emb.size(0), device=emb.device)
|
| 720 |
+
|
| 721 |
+
# If weights are provided, use weighted cosine loss
|
| 722 |
+
if loss_weights:
|
| 723 |
+
# Validate weight dimensions
|
| 724 |
+
loss_weights = shift_matrix.sum(dim=1).to(emb.device)
|
| 725 |
+
|
| 726 |
+
# Calculate unweighted cosine loss
|
| 727 |
+
cos_loss_fn = torch.nn.CosineEmbeddingLoss(reduction='none')
|
| 728 |
+
individual_losses = cos_loss_fn(emb, lab, target)
|
| 729 |
+
|
| 730 |
+
# Apply weights and calculate weighted average
|
| 731 |
+
weighted_losses = individual_losses * loss_weights
|
| 732 |
+
loss = weighted_losses.sum() / loss_weights.sum()
|
| 733 |
+
else:
|
| 734 |
+
# Use standard cosine loss
|
| 735 |
+
cos_loss = torch.nn.CosineEmbeddingLoss()
|
| 736 |
+
loss = cos_loss(emb, lab, target)
|
| 737 |
+
|
| 738 |
+
embedding = torch.cat([inputs_embeds, sequence_output], dim=1)
|
| 739 |
+
embedding = torch.nn.functional.normalize(embedding, p=2, dim=-1)
|
| 740 |
+
# embedding = torch.nn.functional.normalize(sequence_output, p=2, dim=-1)
|
| 741 |
+
|
| 742 |
+
return {
|
| 743 |
+
"loss": loss,
|
| 744 |
+
"embedding": embedding.squeeze(0),
|
| 745 |
+
"shift_matrix": shift_matrix
|
| 746 |
+
}
|
| 747 |
+
|
| 748 |
+
# Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids
|
| 749 |
+
def create_position_ids_from_input_ids(input_ids, padding_idx):
|
| 750 |
+
"""
|
| 751 |
+
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
|
| 752 |
+
are ignored. This is modified from fairseq's `utils.make_positions`.
|
| 753 |
+
|
| 754 |
+
Args:
|
| 755 |
+
x: torch.Tensor x:
|
| 756 |
+
|
| 757 |
+
Returns: torch.Tensor
|
| 758 |
+
"""
|
| 759 |
+
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
|
| 760 |
+
mask = input_ids.ne(padding_idx).int()
|
| 761 |
+
incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask)) * mask
|
| 762 |
+
return incremental_indices.long() + padding_idx
|
| 763 |
+
|
| 764 |
+
|
| 765 |
+
__all__ = [
|
| 766 |
+
"FreeChunkerModel",
|
| 767 |
+
"FreeChunkerPreTrainedModel",
|
| 768 |
+
]
|
sentenizer.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Sentenceizer - Universal sentence splitter + vector encoder
|
| 4 |
+
Length-constrained sentence splitting tool that protects special formats but not quotes/brackets
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
from typing import List, Tuple, Union, Optional
|
| 9 |
+
from sentence_transformers import SentenceTransformer
|
| 10 |
+
from transformers import AutoTokenizer
|
| 11 |
+
|
| 12 |
+
# --- Integrated TraditionalChunking ---
|
| 13 |
+
|
| 14 |
+
def setup_tokenizer(model_name="xlm-roberta-base"):
|
| 15 |
+
"""Setup tokenizer"""
|
| 16 |
+
try:
|
| 17 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 18 |
+
except Exception as e:
|
| 19 |
+
print(f"Warning: Could not load tokenizer for {model_name}: {e}. Falling back to bert-base-uncased")
|
| 20 |
+
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
| 21 |
+
return tokenizer
|
| 22 |
+
|
| 23 |
+
def fixed_size_chunking(text: str, tokenizer=None, chunk_size: int = 256, overlap: int = 0) -> List[str]:
|
| 24 |
+
"""
|
| 25 |
+
Fixed-size chunking based on token count (Strict truncation)
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
text: Text to chunk
|
| 29 |
+
tokenizer: Tokenizer
|
| 30 |
+
chunk_size: Token count per chunk
|
| 31 |
+
overlap: Overlapping token count
|
| 32 |
+
"""
|
| 33 |
+
if tokenizer is None:
|
| 34 |
+
tokenizer = setup_tokenizer()
|
| 35 |
+
|
| 36 |
+
# Encode the entire text, do not add special tokens to keep it clean
|
| 37 |
+
tokens = tokenizer.encode(text, add_special_tokens=False)
|
| 38 |
+
total_tokens = len(tokens)
|
| 39 |
+
|
| 40 |
+
chunks = []
|
| 41 |
+
|
| 42 |
+
# Calculate step size
|
| 43 |
+
step = chunk_size - overlap
|
| 44 |
+
if step <= 0:
|
| 45 |
+
step = 1 # Prevent infinite loop, theoretically overlap should be smaller than chunk_size
|
| 46 |
+
|
| 47 |
+
for i in range(0, total_tokens, step):
|
| 48 |
+
# Truncate tokens for current chunk
|
| 49 |
+
chunk_tokens = tokens[i : i + chunk_size]
|
| 50 |
+
|
| 51 |
+
# Decode back to text
|
| 52 |
+
chunk_text = tokenizer.decode(chunk_tokens, skip_special_tokens=True)
|
| 53 |
+
|
| 54 |
+
if chunk_text.strip():
|
| 55 |
+
chunks.append(chunk_text.strip())
|
| 56 |
+
|
| 57 |
+
return chunks
|
| 58 |
+
|
| 59 |
+
def traditional_chunking(text, tokenizer=None, chunk_size=256, overlap=0):
|
| 60 |
+
"""
|
| 61 |
+
Fixed-size chunking based on tokens
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
text: Text to chunk
|
| 65 |
+
tokenizer: Tokenizer
|
| 66 |
+
chunk_size: Token count per chunk
|
| 67 |
+
overlap: Overlapping token count
|
| 68 |
+
"""
|
| 69 |
+
return fixed_size_chunking(text, tokenizer, chunk_size, overlap)
|
| 70 |
+
|
| 71 |
+
class TraditionalChunking:
|
| 72 |
+
def __init__(self, model_name_or_path=None, tokenizer=None, chunk_size=256, overlap=0):
|
| 73 |
+
if tokenizer is not None:
|
| 74 |
+
self.tokenizer = tokenizer
|
| 75 |
+
elif model_name_or_path is not None:
|
| 76 |
+
self.tokenizer = setup_tokenizer(model_name_or_path)
|
| 77 |
+
else:
|
| 78 |
+
self.tokenizer = setup_tokenizer()
|
| 79 |
+
self.chunk_size = chunk_size
|
| 80 |
+
self.overlap = overlap
|
| 81 |
+
|
| 82 |
+
def chunk(self, text):
|
| 83 |
+
return traditional_chunking(text, self.tokenizer, self.chunk_size, self.overlap)
|
| 84 |
+
|
| 85 |
+
# --- End TraditionalChunking ---
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class Sentenceizer:
|
| 89 |
+
"""
|
| 90 |
+
Universal sentence splitter and encoder with length constraints, protecting special formats
|
| 91 |
+
"""
|
| 92 |
+
|
| 93 |
+
def __init__(self, model_name: Optional[str] = None):
|
| 94 |
+
"""
|
| 95 |
+
Initialize Sentenceizer
|
| 96 |
+
|
| 97 |
+
Args:
|
| 98 |
+
model_name (str, optional): SentenceTransformer model name
|
| 99 |
+
If None, no encoding model is loaded
|
| 100 |
+
"""
|
| 101 |
+
# Initialize chunker with model_name if available, otherwise default
|
| 102 |
+
self.chunker = TraditionalChunking(model_name_or_path=model_name if model_name else "xlm-roberta-base", chunk_size=256, overlap=0)
|
| 103 |
+
|
| 104 |
+
self.model = None
|
| 105 |
+
self.model_name = model_name
|
| 106 |
+
if model_name:
|
| 107 |
+
print(f"Loading sentence transformer model: {model_name}")
|
| 108 |
+
self.model = SentenceTransformer(model_name, trust_remote_code=True)
|
| 109 |
+
self.model.eval()
|
| 110 |
+
print(f"Model loaded successfully. Embedding dimension: {self.model.get_sentence_embedding_dimension()}")
|
| 111 |
+
|
| 112 |
+
def split(self, text: str) -> List[str]:
|
| 113 |
+
"""
|
| 114 |
+
Split text into sentence list using NLTK sent_tokenize, then merge short sentences
|
| 115 |
+
|
| 116 |
+
Args:
|
| 117 |
+
text (str): Input text
|
| 118 |
+
|
| 119 |
+
Returns:
|
| 120 |
+
List[str]: List of sentences
|
| 121 |
+
"""
|
| 122 |
+
if not text.strip():
|
| 123 |
+
return []
|
| 124 |
+
|
| 125 |
+
return self.chunker.chunk(text)
|
| 126 |
+
|
| 127 |
+
def split_with_positions(self, text: str) -> List[Tuple[str, int, int]]:
|
| 128 |
+
"""
|
| 129 |
+
Split text and return sentences with their positions in the original text
|
| 130 |
+
|
| 131 |
+
Args:
|
| 132 |
+
text (str): Input text
|
| 133 |
+
|
| 134 |
+
Returns:
|
| 135 |
+
List[Tuple[str, int, int]]: List of (sentence, start_position, end_position)
|
| 136 |
+
"""
|
| 137 |
+
sentences = self.split(text)
|
| 138 |
+
sentences_with_pos = []
|
| 139 |
+
|
| 140 |
+
start_pos = 0
|
| 141 |
+
for sentence in sentences:
|
| 142 |
+
# Find sentence position in original text
|
| 143 |
+
pos = text.find(sentence, start_pos)
|
| 144 |
+
if pos != -1:
|
| 145 |
+
sentences_with_pos.append((sentence, pos, pos + len(sentence)))
|
| 146 |
+
start_pos = pos + len(sentence)
|
| 147 |
+
else:
|
| 148 |
+
# If not found (possibly due to merging or splitting), use estimated position
|
| 149 |
+
sentences_with_pos.append((sentence, start_pos, start_pos + len(sentence)))
|
| 150 |
+
start_pos += len(sentence)
|
| 151 |
+
|
| 152 |
+
return sentences_with_pos
|
| 153 |
+
|
| 154 |
+
def encode(self, text: Union[str, List[str]], show_progress: bool = False) -> np.ndarray:
|
| 155 |
+
"""
|
| 156 |
+
Encode text
|
| 157 |
+
|
| 158 |
+
Args:
|
| 159 |
+
text (Union[str, List[str]]): Input text, can be a single string or list of strings
|
| 160 |
+
If it's a string, sentence splitting will be performed first
|
| 161 |
+
show_progress (bool): Whether to show progress bar
|
| 162 |
+
|
| 163 |
+
Returns:
|
| 164 |
+
np.ndarray: Encoded vector array with shape (n_sentences, embedding_dim)
|
| 165 |
+
|
| 166 |
+
Raises:
|
| 167 |
+
ValueError: If no model is loaded
|
| 168 |
+
"""
|
| 169 |
+
if self.model is None:
|
| 170 |
+
raise ValueError("No model loaded. Please initialize with a model_name.")
|
| 171 |
+
|
| 172 |
+
# If input is string, perform sentence splitting first
|
| 173 |
+
if isinstance(text, str):
|
| 174 |
+
sentences = self.split(text)
|
| 175 |
+
else:
|
| 176 |
+
sentences = text
|
| 177 |
+
|
| 178 |
+
if not sentences:
|
| 179 |
+
return np.array([])
|
| 180 |
+
|
| 181 |
+
# Use sentence transformer for encoding, limit max batch size to 64
|
| 182 |
+
embeddings = self.model.encode(
|
| 183 |
+
sentences,
|
| 184 |
+
show_progress_bar=show_progress,
|
| 185 |
+
convert_to_numpy=True,
|
| 186 |
+
batch_size=4
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
return embeddings
|
| 190 |
+
|
| 191 |
+
def split_and_encode(self, text: str, show_progress: bool = True) -> Tuple[List[str], np.ndarray]:
|
| 192 |
+
"""
|
| 193 |
+
Split text and encode
|
| 194 |
+
|
| 195 |
+
Args:
|
| 196 |
+
text (str): Input text
|
| 197 |
+
show_progress (bool): Whether to show progress bar
|
| 198 |
+
|
| 199 |
+
Returns:
|
| 200 |
+
Tuple[List[str], np.ndarray]: (sentence list, encoded vector array)
|
| 201 |
+
"""
|
| 202 |
+
sentences = self.split(text)
|
| 203 |
+
embeddings = self.encode(sentences, show_progress=show_progress)
|
| 204 |
+
return sentences, embeddings
|
| 205 |
+
|
| 206 |
+
@property
|
| 207 |
+
def embedding_dimension(self) -> int:
|
| 208 |
+
"""Get embedding dimension"""
|
| 209 |
+
if self.model is None:
|
| 210 |
+
raise ValueError("No model loaded.")
|
| 211 |
+
return self.model.get_sentence_embedding_dimension()
|
| 212 |
+
|
| 213 |
+
def test_sentenceizer():
|
| 214 |
+
"""Test universal sentence splitting functionality and protection mechanisms"""
|
| 215 |
+
|
| 216 |
+
print("=== Testing Universal Sentence Splitting and Protection Mechanisms ===")
|
| 217 |
+
|
| 218 |
+
# Use reasonable length constraints for testing
|
| 219 |
+
sentenceizer = Sentenceizer()
|
| 220 |
+
|
| 221 |
+
test_cases = [
|
| 222 |
+
# Basic sentence splitting test
|
| 223 |
+
"This is the first sentence. This is the second sentence! This is the third sentence?",
|
| 224 |
+
|
| 225 |
+
# Quote sentence splitting test (should be able to split)
|
| 226 |
+
'He said "Hello there. How are you? I hope you are well." Then he left.',
|
| 227 |
+
|
| 228 |
+
# Abbreviation protection test (should not split at abbreviations)
|
| 229 |
+
"Dr. Smith is here. Mr. Jones left at 3 p.m. today. The U.S. economy is growing.",
|
| 230 |
+
|
| 231 |
+
# Number protection test (should not split within numbers)
|
| 232 |
+
"The temperature is 36.5 degrees. The price is $19.99. Version 2.1.3 was released.",
|
| 233 |
+
|
| 234 |
+
# Ellipsis protection test (should not split at ellipsis)
|
| 235 |
+
"This is incomplete... But this continues the thought. Another sentence follows.",
|
| 236 |
+
|
| 237 |
+
# URL protection test (should not split within URLs)
|
| 238 |
+
"Visit https://www.example.com for more info. The website www.test.org has details.",
|
| 239 |
+
|
| 240 |
+
# Email protection test (should not split within emails)
|
| 241 |
+
"Contact me at john.doe@example.com for questions. Send reports to admin@company.org please.",
|
| 242 |
+
|
| 243 |
+
# Date and time protection test
|
| 244 |
+
"The meeting is on 12/25/2023. We start at 3:30 p.m. today. See you then.",
|
| 245 |
+
|
| 246 |
+
# Non-English text test
|
| 247 |
+
"这是第一个句子。这是第二个句子!这是第三个句子?",
|
| 248 |
+
|
| 249 |
+
# Mixed text test
|
| 250 |
+
"This is English. 这是中文。Mix of both languages!",
|
| 251 |
+
|
| 252 |
+
# Complex mixed test
|
| 253 |
+
"访问 https://www.baidu.com 获取信息。联系邮箱是 test@163.com。价格为 ¥99.99 元。",
|
| 254 |
+
|
| 255 |
+
# Long sentence test (should be split)
|
| 256 |
+
"This is a very long sentence that should be split into multiple parts because it exceeds the maximum length limit that we have set for individual sentences in our system, and we need to handle this properly.",
|
| 257 |
+
|
| 258 |
+
# Sentences starting with numbers
|
| 259 |
+
"Today is sunny. 123 people attended the meeting. Everyone was happy.",
|
| 260 |
+
|
| 261 |
+
# Sentences starting with special characters
|
| 262 |
+
"First sentence here. \"Quoted sentence comes next.\" Final sentence ends it.",
|
| 263 |
+
]
|
| 264 |
+
|
| 265 |
+
for i, text in enumerate(test_cases, 1):
|
| 266 |
+
print(f"\n--- Test Case {i} ---")
|
| 267 |
+
print(f"Original: {text}")
|
| 268 |
+
|
| 269 |
+
sentences = sentenceizer.split(text)
|
| 270 |
+
print(f"Split Result ({len(sentences)} sentences):")
|
| 271 |
+
for j, sentence in enumerate(sentences, 1):
|
| 272 |
+
print(f" {j}. ({len(sentence)} chars) {sentence}")
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
if __name__ == "__main__":
|
| 276 |
+
test_sentenceizer()
|
training_losses.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
utils.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Utility Functions
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
def generate_shifted_matrix(n, device=None):
|
| 11 |
+
|
| 12 |
+
matrix_columns = []
|
| 13 |
+
granularities = [2, 4]
|
| 14 |
+
|
| 15 |
+
for granularity in granularities:
|
| 16 |
+
if granularity > n:
|
| 17 |
+
continue
|
| 18 |
+
|
| 19 |
+
# Calculate step size for this granularity
|
| 20 |
+
step_size = max(1, granularity // 2)
|
| 21 |
+
max_start = n - granularity
|
| 22 |
+
|
| 23 |
+
for start in range(0, max_start + 1, step_size):
|
| 24 |
+
column = torch.zeros(n, dtype=torch.int, device=device)
|
| 25 |
+
column[start:start + granularity] = 1
|
| 26 |
+
matrix_columns.append(column)
|
| 27 |
+
|
| 28 |
+
# If the last position is not covered, add a mask at the end
|
| 29 |
+
if max_start >= 0 and (max_start % step_size) != 0:
|
| 30 |
+
column = torch.zeros(n, dtype=torch.int, device=device)
|
| 31 |
+
column[-granularity:] = 1
|
| 32 |
+
matrix_columns.append(column)
|
| 33 |
+
|
| 34 |
+
if not matrix_columns:
|
| 35 |
+
column = torch.ones(n, dtype=torch.int, device=device)
|
| 36 |
+
matrix_columns.append(column)
|
| 37 |
+
|
| 38 |
+
result = torch.stack(matrix_columns, dim=1).unsqueeze(0).expand(1, -1, -1)
|
| 39 |
+
return result
|
| 40 |
+
|
| 41 |
+
def create_attention_mask(shift_matrix: torch.Tensor) -> torch.Tensor:
|
| 42 |
+
"""
|
| 43 |
+
Create attention mask from shift matrix
|
| 44 |
+
|
| 45 |
+
Args:
|
| 46 |
+
shift_matrix (torch.Tensor): shift matrix, shape [num_chunks, seq_len]
|
| 47 |
+
|
| 48 |
+
Returns:
|
| 49 |
+
torch.Tensor: attention mask, shape [1, num_chunks, seq_len, seq_len]
|
| 50 |
+
"""
|
| 51 |
+
# Transpose and create attention mask
|
| 52 |
+
attention_mask = shift_matrix.transpose(0, 1) # [seq_len, num_chunks]
|
| 53 |
+
attention_mask = torch.where(attention_mask == 1.0, 0.0, float('-inf'))
|
| 54 |
+
|
| 55 |
+
# Add dimensions to match expected shape of attention
|
| 56 |
+
attention_mask = attention_mask.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, num_chunks]
|
| 57 |
+
|
| 58 |
+
return attention_mask
|
| 59 |
+
|
| 60 |
+
def normalize_embeddings(embeddings: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
|
| 61 |
+
"""
|
| 62 |
+
L2 normalize embeddings
|
| 63 |
+
|
| 64 |
+
Args:
|
| 65 |
+
embeddings (torch.Tensor): Embeddings
|
| 66 |
+
eps (float): Small value to prevent division by zero
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
torch.Tensor: Normalized embeddings
|
| 70 |
+
"""
|
| 71 |
+
norm = torch.norm(embeddings, dim=-1, keepdim=True)
|
| 72 |
+
return embeddings / (norm + eps)
|
| 73 |
+
|
| 74 |
+
def cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
| 75 |
+
"""
|
| 76 |
+
Calculate cosine similarity
|
| 77 |
+
|
| 78 |
+
Args:
|
| 79 |
+
a (torch.Tensor): Vector A
|
| 80 |
+
b (torch.Tensor): Vector B
|
| 81 |
+
|
| 82 |
+
Returns:
|
| 83 |
+
torch.Tensor: Cosine similarity
|
| 84 |
+
"""
|
| 85 |
+
a_norm = normalize_embeddings(a)
|
| 86 |
+
b_norm = normalize_embeddings(b)
|
| 87 |
+
return torch.sum(a_norm * b_norm, dim=-1)
|
| 88 |
+
|
| 89 |
+
def batch_cosine_similarity(embeddings1: torch.Tensor, embeddings2: torch.Tensor) -> torch.Tensor:
|
| 90 |
+
"""
|
| 91 |
+
Calculate batch cosine similarity
|
| 92 |
+
|
| 93 |
+
Args:
|
| 94 |
+
embeddings1 (torch.Tensor): Embeddings group 1, shape [N, dim]
|
| 95 |
+
embeddings2 (torch.Tensor): Embeddings group 2, shape [M, dim]
|
| 96 |
+
|
| 97 |
+
Returns:
|
| 98 |
+
torch.Tensor: Similarity matrix, shape [N, M]
|
| 99 |
+
"""
|
| 100 |
+
embeddings1_norm = normalize_embeddings(embeddings1)
|
| 101 |
+
embeddings2_norm = normalize_embeddings(embeddings2)
|
| 102 |
+
|
| 103 |
+
return torch.matmul(embeddings1_norm, embeddings2_norm.transpose(0, 1))
|
| 104 |
+
|
| 105 |
+
def split_embeddings_by_shift_matrix(embeddings: torch.Tensor, shift_matrix: torch.Tensor) -> list:
|
| 106 |
+
"""
|
| 107 |
+
Split embeddings based on shift matrix
|
| 108 |
+
|
| 109 |
+
Args:
|
| 110 |
+
embeddings (torch.Tensor): Embeddings, shape [seq_len, hidden_dim]
|
| 111 |
+
shift_matrix (torch.Tensor): shift matrix, shape [num_chunks, seq_len]
|
| 112 |
+
|
| 113 |
+
Returns:
|
| 114 |
+
list: List of split embeddings
|
| 115 |
+
"""
|
| 116 |
+
split_embeddings = []
|
| 117 |
+
num_chunks, seq_len = shift_matrix.shape
|
| 118 |
+
|
| 119 |
+
for chunk_idx in range(num_chunks):
|
| 120 |
+
mask = shift_matrix[chunk_idx] # [seq_len]
|
| 121 |
+
indices = torch.nonzero(mask, as_tuple=True)[0] # Get indices of non-zero positions
|
| 122 |
+
|
| 123 |
+
if len(indices) > 0:
|
| 124 |
+
chunk_embeddings = embeddings[indices] # [chunk_size, hidden_dim]
|
| 125 |
+
split_embeddings.append(chunk_embeddings)
|
| 126 |
+
|
| 127 |
+
return split_embeddings
|
| 128 |
+
|
| 129 |
+
def pool_embeddings(embeddings: torch.Tensor, method: str = 'mean') -> torch.Tensor:
|
| 130 |
+
"""
|
| 131 |
+
Pool embeddings
|
| 132 |
+
|
| 133 |
+
Args:
|
| 134 |
+
embeddings (torch.Tensor): Embeddings, shape [seq_len, hidden_dim]
|
| 135 |
+
method (str): Pooling method, optional 'mean', 'max', 'first', 'last'
|
| 136 |
+
|
| 137 |
+
Returns:
|
| 138 |
+
torch.Tensor: Pooled vector, shape [hidden_dim]
|
| 139 |
+
"""
|
| 140 |
+
if method == 'mean':
|
| 141 |
+
return torch.mean(embeddings, dim=0)
|
| 142 |
+
elif method == 'max':
|
| 143 |
+
return torch.max(embeddings, dim=0)[0]
|
| 144 |
+
elif method == 'first':
|
| 145 |
+
return embeddings[0]
|
| 146 |
+
elif method == 'last':
|
| 147 |
+
return embeddings[-1]
|
| 148 |
+
else:
|
| 149 |
+
raise ValueError(f"Unknown pooling method: {method}")
|
| 150 |
+
|
| 151 |
+
def aggregate_chunk_embeddings(split_embeddings: list, method: str = 'mean') -> torch.Tensor:
|
| 152 |
+
"""
|
| 153 |
+
Aggregate chunk embeddings
|
| 154 |
+
|
| 155 |
+
Args:
|
| 156 |
+
split_embeddings (list): List of split embeddings
|
| 157 |
+
method (str): Aggregation method
|
| 158 |
+
|
| 159 |
+
Returns:
|
| 160 |
+
torch.Tensor: Aggregated embeddings, shape [num_chunks, hidden_dim]
|
| 161 |
+
"""
|
| 162 |
+
if not split_embeddings:
|
| 163 |
+
return torch.tensor([])
|
| 164 |
+
|
| 165 |
+
aggregated = []
|
| 166 |
+
for chunk_embeddings in split_embeddings:
|
| 167 |
+
pooled = pool_embeddings(chunk_embeddings, method)
|
| 168 |
+
aggregated.append(pooled)
|
| 169 |
+
|
| 170 |
+
return torch.stack(aggregated)
|
| 171 |
+
|
| 172 |
+
def safe_tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
|
| 173 |
+
"""
|
| 174 |
+
Safely convert tensor to numpy array
|
| 175 |
+
|
| 176 |
+
Args:
|
| 177 |
+
tensor (torch.Tensor): Input tensor
|
| 178 |
+
|
| 179 |
+
Returns:
|
| 180 |
+
np.ndarray: Numpy array
|
| 181 |
+
"""
|
| 182 |
+
if tensor.requires_grad:
|
| 183 |
+
tensor = tensor.detach()
|
| 184 |
+
|
| 185 |
+
if tensor.is_cuda:
|
| 186 |
+
tensor = tensor.cpu()
|
| 187 |
+
|
| 188 |
+
return tensor.numpy()
|
| 189 |
+
|
| 190 |
+
def ensure_tensor_on_device(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
|
| 191 |
+
"""
|
| 192 |
+
Ensure tensor is on specified device
|
| 193 |
+
|
| 194 |
+
Args:
|
| 195 |
+
tensor (torch.Tensor): Input tensor
|
| 196 |
+
device (torch.device): Target device
|
| 197 |
+
|
| 198 |
+
Returns:
|
| 199 |
+
torch.Tensor: Tensor on target device
|
| 200 |
+
"""
|
| 201 |
+
if tensor.device != device:
|
| 202 |
+
tensor = tensor.to(device)
|
| 203 |
+
return tensor
|
| 204 |
+
|
| 205 |
+
def get_available_device() -> torch.device:
|
| 206 |
+
"""
|
| 207 |
+
Get available device
|
| 208 |
+
|
| 209 |
+
Returns:
|
| 210 |
+
torch.device: Available device
|
| 211 |
+
"""
|
| 212 |
+
if torch.cuda.is_available():
|
| 213 |
+
return torch.device('cuda')
|
| 214 |
+
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
| 215 |
+
return torch.device('mps')
|
| 216 |
+
else:
|
| 217 |
+
return torch.device('cpu')
|
| 218 |
+
|
| 219 |
+
def print_tensor_info(tensor: torch.Tensor, name: str = "tensor"):
|
| 220 |
+
"""
|
| 221 |
+
Print tensor info
|
| 222 |
+
|
| 223 |
+
Args:
|
| 224 |
+
tensor (torch.Tensor): Input tensor
|
| 225 |
+
name (str): Tensor name
|
| 226 |
+
"""
|
| 227 |
+
print(f"{name}:")
|
| 228 |
+
print(f" Shape: {tensor.shape}")
|
| 229 |
+
print(f" Data Type: {tensor.dtype}")
|
| 230 |
+
print(f" Device: {tensor.device}")
|
| 231 |
+
print(f" Requires Grad: {tensor.requires_grad}")
|
| 232 |
+
if tensor.numel() > 0:
|
| 233 |
+
print(f" Value Range: [{tensor.min().item():.6f}, {tensor.max().item():.6f}]")
|
| 234 |
+
print(f" Mean: {tensor.mean().item():.6f}")
|
| 235 |
+
print(f" Std Dev: {tensor.std().item():.6f}")
|