matin-ebrahimkhani committed
Commit f5e403b · verified · 1 Parent(s): 4d0138d

Upload the model

README.md ADDED
@@ -0,0 +1,170 @@
+ # BertForTokenClassificationWithFourO
+
+ [![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/USERNAME/MODEL_NAME)
+
+ A specialized token classification model built on BERT with a custom classifier for Persian text spacing and formatting tasks.
+
+ ## Model Description
+
+ This model is built on a BERT architecture with a custom token classification head called FourOClassifier. It is designed specifically to correct or insert the proper spacing characters in Persian text.
+
+ ### Task
+
+ The model performs token classification to detect where spacing characters should be inserted in Persian text. It can operate in two modes:
+ - **Spacing Mode**: Uses pure model predictions to insert spaces
+ - **Correction Mode**: Combines model predictions with the spacing already present in the text
+
+ ### Model Architecture
+
+ The model is based on the BERT architecture with a custom classifier head (FourOClassifier) that includes (see the condensed sketch below):
+ - Dense layer with ReLU activation
+ - Dropout for regularization
+ - Batch normalization
+ - Output projection layer
+
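+ The head is implemented in `modeling_custom.py`, shipped in this repository; condensed, it looks like this:
+
+ ```python
+ import torch.nn as nn
+
+ class FourOClassifier(nn.Module):
+     def __init__(self, clf_hidden_size, num_labels):
+         super().__init__()
+         self.dense = nn.Linear(clf_hidden_size, clf_hidden_size)
+         self.activation = nn.ReLU()
+         self.dropout = nn.Dropout(p=0.1)
+         self.batch_norm = nn.BatchNorm1d(clf_hidden_size)
+         self.output_layer = nn.Linear(clf_hidden_size, num_labels)
+
+     def forward(self, clf_input):
+         x = self.dropout(self.activation(self.dense(clf_input)))
+         # BatchNorm1d expects (N, C, L), so permute around the normalization
+         x = self.batch_norm(x.permute(0, 2, 1)).permute(0, 2, 1)
+         return self.output_layer(x)
+ ```
+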
+ ## Usage
+
+ ### Installation
+
+ ```bash
+ pip install transformers torch
+ ```
+
+ ### Basic Usage
+
+ ```python
+ from transformers import AutoTokenizer
+ from modeling_custom import BertForTokenClassificationWithFourO
+ from run import ModelPipeline
+
+ model_path = "USERNAME/MODEL_NAME"
+
+ # Load the tokenizer and model directly if you want raw access
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
+ model = BertForTokenClassificationWithFourO.from_pretrained(model_path)
+ model.eval()
+
+ # For end-to-end processing, the bundled pipeline handles tokenization,
+ # chunking, inference, and text reconstruction
+ pipeline = ModelPipeline(model_path)
+
+ text = "این متن نمونه فارسی بدون فاصله گذاری مناسب است"
+ result = pipeline.process_text(text, mode="space")
+ print(result)
+ ```
+
+ ### Command-line Usage
+
+ You can also use the provided command-line interface:
+
+ ```bash
+ python run.py --text "متن فارسی شما در اینجا" --mode space
+ ```
+
+ Or process a file:
+
+ ```bash
+ python run.py --file input.txt --output result.txt --mode correct
+ ```
+
+ The repository includes a sample `input.txt` file that you can use to test the model.
+
+ ## Parameters
+
+ - `mode`:
+   - `space`: Uses model predictions to add spaces
+   - `correct`: Combines model predictions with the original text's spacing (recommended for texts that already have some correct spacing); see the blending sketch below
+
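+ Under the hood, `correct` mode blends the model's logits with one-hot labels recovered from the spacing already present in the input (`alpha = 0.5` in `run.py`). A minimal sketch of the blending step, with toy shapes:
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ # Toy shapes: 1 sequence, 4 characters, 3 label classes
+ logits = torch.randn(1, 4, 3)                   # model output
+ existing_labels = torch.tensor([[0, 1, 0, 2]])  # spacing found in the input
+
+ alpha = 0.5  # weight given to the spacing already present in the text
+ one_hot = F.one_hot(existing_labels, num_classes=3).float()
+ combined = logits * (1 - alpha) + one_hot * alpha
+ predictions = torch.argmax(F.softmax(combined, dim=-1), dim=-1)
+ ```
+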
+ ## Evaluation
+
+ The model achieves excellent performance in both operating modes:
+
+ ### Spacing Mode Evaluation
+
+ ```
+ ╒═════════╤═════════════╤══════════╤════════════╤════════════╕
+ │ Label   │   Precision │   Recall │   Accuracy │   F1 Score │
+ ╞═════════╪═════════════╪══════════╪════════════╪════════════╡
+ │ 0       │    0.994663 │ 0.997324 │   0.997324 │   0.995992 │
+ ├─────────┼─────────────┼──────────┼────────────┼────────────┤
+ │ 1       │    0.989546 │ 0.987828 │   0.987828 │   0.988686 │
+ ├─────────┼─────────────┼──────────┼────────────┼────────────┤
+ │ 2       │    0.913413 │ 0.932125 │   0.932125 │   0.922674 │
+ ├─────────┼─────────────┼──────────┼────────────┼────────────┤
+ │ Average │    0.965874 │ 0.972426 │   0.972426 │   0.969117 │
+ ╘═════════╧═════════════╧══════════╧════════════╧════════════╛
+ ```
+
+ ### Correction Mode Evaluation
+
+ ```
+ ╒═════════╤═════════════╤══════════╤════════════╤════════════╕
+ │ Label   │   Precision │   Recall │   Accuracy │   F1 Score │
+ ╞═════════╪═════════════╪══════════╪════════════╪════════════╡
+ │ 0       │    0.995932 │ 0.998386 │   0.998386 │   0.997157 │
+ ├─────────┼─────────────┼──────────┼────────────┼────────────┤
+ │ 1       │    0.992917 │ 0.992227 │   0.992227 │   0.992572 │
+ ├─────────┼─────────────┼──────────┼────────────┼────────────┤
+ │ 2       │    0.944612 │ 0.959428 │   0.959428 │   0.951962 │
+ ├─────────┼─────────────┼──────────┼────────────┼────────────┤
+ │ Average │    0.97782  │ 0.983347 │   0.983347 │   0.980564 │
+ ╘═════════╧═════════════╧══════════╧════════════╧════════════╛
+ ```
+
+ Note that correction mode achieves slightly better results by combining model predictions with the spacing already present in the text.
+
+ ### Label Meaning
+
+ - Label 0: No spacing needed after the character
+ - Label 1: Regular space character needed
+ - Label 2: ZWNJ character (U+200C) needed
+
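+ The labels are produced and consumed by the bundled `Labeler` class (its defaults match the initialization shown above). A quick round-trip sketch:
+
+ ```python
+ from labeler import Labeler
+
+ labeler = Labeler()
+ # "می‌شود" contains a ZWNJ after "می"; label_text strips it and records
+ # label 2 on the preceding character.
+ chars, labels = labeler.label_text("می‌شود", corpus_type=Labeler.WHOLE_RAW)
+ # chars[0] == ['م', 'ی', 'ش', 'و', 'د'], labels[0] == [0, 2, 0, 0, 0]
+
+ # text_generator re-inserts the spacing characters from the labels
+ restored = labeler.text_generator(chars[0], labels[0], corpus_type=Labeler.WHOLE_RAW)
+ assert restored == "می‌شود"
+ ```
+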
+ ## Use Cases
+
+ This model is particularly useful for:
+ - Correcting Persian text with improper spacing
+ - Normalizing text drawn from different sources
+ - Improving text readability for downstream NLP tasks
+ - Preprocessing Persian text for search engines or text analysis
+
+ ## Training
+
+ The model was trained on [DATASET_NAME], a corpus of Persian text with proper spacing annotations.
+
+ Training hyperparameters:
+ - Learning rate: [VALUE]
+ - Batch size: [VALUE]
+ - Training steps: [VALUE]
+ - [OTHER PARAMETERS]
+
+ ## Limitations
+
+ - The model is designed specifically for Persian text
+ - Performance may vary on specialized domains or technical texts
+ - Very long texts are processed in 512-token chunks (handled automatically by `run.py`)
+ - Tuned for execution on CUDA devices; falls back to CPU when CUDA is unavailable
+ - [ANY OTHER LIMITATIONS]
+
+ ## Citation
+
+ ```
+ [CITATION_INFO]
+ ```
+
+ ## License
+
+ [LICENSE_INFO]
+
+ ## Contact
+
+ For questions or feedback, please contact [CONTACT_INFO].
__init__.py ADDED
@@ -0,0 +1 @@
+ from .modeling_custom import BertForTokenClassificationWithFourO, FourOClassifier
__pycache__/labeler.cpython-313.pyc ADDED
Binary file (6.3 kB).
 
__pycache__/modeling_custom.cpython-313.pyc ADDED
Binary file (4.81 kB).
 
__pycache__/run.cpython-313.pyc ADDED
Binary file (8.29 kB).
 
config.json ADDED
@@ -0,0 +1,41 @@
+ {
+   "_name_or_path": "bert-base-multilingual-uncased",
+   "architectures": [
+     "BertForTokenClassificationWithFourO"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "classifier_dropout": null,
+   "directionality": "bidi",
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 768,
+   "id2label": {
+     "0": "LABEL_0",
+     "1": "LABEL_1",
+     "2": "LABEL_2"
+   },
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "label2id": {
+     "LABEL_0": 0,
+     "LABEL_1": 1,
+     "LABEL_2": 2
+   },
+   "layer_norm_eps": 1e-12,
+   "max_position_embeddings": 512,
+   "model_type": "bert",
+   "num_attention_heads": 12,
+   "num_hidden_layers": 12,
+   "pad_token_id": 0,
+   "pooler_fc_size": 768,
+   "pooler_num_attention_heads": 12,
+   "pooler_num_fc_layers": 3,
+   "pooler_size_per_head": 128,
+   "pooler_type": "first_token_transform",
+   "position_embedding_type": "absolute",
+   "torch_dtype": "float32",
+   "transformers_version": "4.44.0",
+   "type_vocab_size": 2,
+   "use_cache": true,
+   "vocab_size": 105879
+ }
input.txt ADDED
@@ -0,0 +1 @@
+ میدانید که ادبیات فارسی گنجینه‌ای از آثار پر مغز و زیباست. خواندن متون فاخر فارسی نه تنها موجب افزایش دانش‌ بلکه باعث غنی‌تر شدن روح انسان می‌شود.بسیاری از اشعار ونثرهای کهن، مانند آثارحافظ‌سعدی،و فردوسی،باید با دقت خوانده شوند تا معنای کامل آنهارا دریافت کنیم.درک معنای ژرف چنینآثاری بدون داشتن دانش کافی از زبان و فرهنگ فارسی‌ امکانپذیر نیست.بنابراین، آموزش درست به کودکان و نوجوانان نقشحیاتی در زنده نگه داشتن این میراث فرهنگی دارد.فضای مجازی امروز نقش عمده‌ای در انتشار آثار فارسی دارد،اما بایدبه شیوه‌ای صحیح ازآن بهره ببریم تا از تحریف یا اشتباه در انتقال متون جلوگیری شود.
labeler.py ADDED
@@ -0,0 +1,179 @@
+ import re
+ import sys
+
+ import numpy as np
+
+ # Display whole numpy arrays without truncation when printing
+ np.set_printoptions(threshold=sys.maxsize)
+
+
+ class Labeler:
+     # Corpus types as class constants
+     WHOLE_RAW = 'whole_raw'
+     SENTS_RAW = 'sents_raw'
+
+     def __init__(self, tags=(1, 2),
+                  regexes=(r'[^\S\r\n\v\f]', r'\u200c'),
+                  chars=(" ", "‌"),
+                  class_count=2):
+         self._tags = tags
+         self._regexes = regexes
+         self._class_chars = chars
+         self.class_count = class_count
+
+         self.data = None
+         self.labels = []
+         self.corpus_type = None
+
+     def _sent_labeler(self, sent: str):
+         """Label a single sentence and return characters and labels.
+
+         Args:
+             sent: The sentence to be labeled
+
+         Returns:
+             A tuple of (characters, labels)
+         """
+         labels = [0] * len(sent)
+         # The sentence as a list of characters for the output
+         characters = list(sent)
+         # Indices of spacing characters to be deleted
+         deletable = []
+
+         for i in range(self.class_count):
+             # Find every occurrence of the current class's spacing character
+             for match in re.finditer(self._regexes[i], sent):
+                 idx = match.start()
+                 # Tag the character *before* the spacing character (guard
+                 # against a match at position 0, which would otherwise wrap
+                 # around and tag the last character)
+                 if idx > 0:
+                     labels[idx - 1] = self._tags[i]
+                 deletable.append(idx)
+
+         # Delete in descending index order to avoid index shifting
+         for idx in sorted(deletable, reverse=True):
+             characters.pop(idx)
+             labels.pop(idx)
+
+         return characters, labels
+
+     def _text_labeler(self):
+         """Label the whole text and return characters and labels."""
+         labels = [0] * len(self.data)
+         characters = list(self.data)
+         deletable = []
+
+         for i in range(self.class_count):
+             for match in re.finditer(self._regexes[i], self.data):
+                 idx = match.start()
+                 # Tag the character before the match (see _sent_labeler)
+                 if idx > 0:
+                     labels[idx - 1] = self._tags[i]
+                 deletable.append(idx)
+
+         # Delete in descending index order to avoid index shifting
+         for idx in sorted(deletable, reverse=True):
+             del characters[idx]
+             del labels[idx]
+
+         return characters, labels
+
+     def _labeler(self):
+         """Label the data and return characters and labels."""
+         result_chars = []
+         result_labels = []
+
+         if self.corpus_type == self.SENTS_RAW:
+             # Label each sentence individually
+             for sent in self.data:
+                 characters, labels = self._sent_labeler(sent)
+                 result_chars.append(characters)
+                 result_labels.append(labels)
+         elif self.corpus_type == self.WHOLE_RAW:
+             # Label the entire text at once, wrapping the results in lists
+             # to keep the return structure consistent with SENTS_RAW
+             characters, labels = self._text_labeler()
+             result_chars = [characters]
+             result_labels = [labels]
+
+         return result_chars, result_labels
+
+     def label_text(self, textinput, corpus_type):
+         """Label text and return characters and labels.
+
+         Args:
+             textinput: Either a string or a list of strings to label
+             corpus_type: Either Labeler.WHOLE_RAW or Labeler.SENTS_RAW
+
+         Returns:
+             A tuple of (characters, labels)
+         """
+         # Validate that the corpus type matches the input's data type
+         if corpus_type == self.WHOLE_RAW and isinstance(textinput, str):
+             self.data = textinput
+             self.corpus_type = corpus_type
+         elif corpus_type == self.SENTS_RAW and isinstance(textinput, list):
+             self.data = textinput
+             self.corpus_type = corpus_type
+         else:
+             raise ValueError(f"Invalid input: expected {corpus_type} with a compatible data type")
+
+         return self._labeler()
+
+     def _text_generator(self, chars, labels):
+         """Generate text with spacing characters re-inserted.
+
+         Args:
+             chars: A list of characters
+             labels: A list of labels for those characters
+
+         Returns:
+             A string with class characters inserted according to the labels
+         """
+         result = []
+         for char, label in zip(chars, labels):
+             # Always emit the character itself
+             result.append(char)
+
+             # Append the spacing character the label calls for, if any
+             if label != 0:
+                 for i in range(self.class_count):
+                     if label == self._tags[i]:
+                         result.append(self._class_chars[i])
+                         break
+
+         return ''.join(result)
+
+     def text_generator(self, chars, labels, corpus_type):
+         """Generate text with spacing characters re-inserted.
+
+         Args:
+             chars: Either a list of characters or a list of lists of characters
+             labels: Either a list of labels or a list of lists of labels
+             corpus_type: Either Labeler.WHOLE_RAW or Labeler.SENTS_RAW
+
+         Returns:
+             Either a string or a list of strings with class characters inserted
+         """
+         if corpus_type == self.SENTS_RAW:
+             # Process each sentence separately
+             return [self._text_generator(sent_chars, sent_labels)
+                     for sent_chars, sent_labels in zip(chars, labels)]
+         elif corpus_type == self.WHOLE_RAW:
+             # Process the whole text at once
+             return self._text_generator(chars, labels)
+         else:
+             raise ValueError(f"Unrecognized corpus_type: {corpus_type}")
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7af3dfbb0383158d574242df30addc291152cf5dce6948c96118460d87362666
+ size 669471252
modeling_custom.py ADDED
@@ -0,0 +1,59 @@
+ import torch
+ import torch.nn as nn
+ from transformers import BertForTokenClassification
+
+
+ class FourOClassifier(nn.Module):
+     def __init__(self, clf_hidden_size, num_labels):
+         super().__init__()
+         self.dense = nn.Linear(clf_hidden_size, clf_hidden_size)
+         self.activation = nn.ReLU()
+         self.dropout = nn.Dropout(p=0.1)
+         self.batch_norm = nn.BatchNorm1d(clf_hidden_size)
+         self.output_layer = nn.Linear(clf_hidden_size, num_labels)
+
+     def forward(self, clf_input):
+         x = self.dense(clf_input)
+         x = self.activation(x)
+         x = self.dropout(x)
+         # BatchNorm1d expects (N, C, L), so move the hidden dimension into
+         # the channel position and back
+         x = self.batch_norm(x.permute(0, 2, 1)).permute(0, 2, 1)
+         x = self.output_layer(x)
+         return x
+
+
+ class BertForTokenClassificationWithFourO(BertForTokenClassification):
+     def __init__(self, config):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+         self.classifier = FourOClassifier(config.hidden_size, config.num_labels)
+         self.init_weights()
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+         model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+         model.check_classifier_initialization()
+         return model
+
+     def check_classifier_initialization(self):
+         # Heuristic sanity check: freshly (randomly) initialized weights
+         # have a near-zero mean and a small but non-degenerate std
+         def is_randomly_initialized(tensor):
+             return torch.abs(tensor.mean()) < 1e-3 < tensor.std() < 1e-1
+
+         classifier_weights = [
+             self.classifier.dense.weight,
+             self.classifier.dense.bias,
+             self.classifier.output_layer.weight,
+             self.classifier.output_layer.bias,
+         ]
+         if any(is_randomly_initialized(w) for w in classifier_weights):
+             print("Warning: classifier head weights look randomly initialized; "
+                   "they may not have been loaded from the checkpoint.")
+
+     def freeze_bert(self):
+         """Freezes the BERT layers to prevent their parameters from being updated during training."""
+         for param in self.bert.parameters():
+             param.requires_grad = False
+         print("BERT layers frozen.")
+
+     def unfreeze_bert(self):
+         """Unfreezes the BERT layers to allow their parameters to be updated during training."""
+         for param in self.bert.parameters():
+             param.requires_grad = True
+         print("BERT layers unfrozen.")
result.txt ADDED
@@ -0,0 +1 @@
+ می‌دانید که ادبیات فارسی گنجینه‌ای از آثار پرمغز و زیباست. خواندن متون فاخر فارسی نه‌تنها موجب افزایش دانش بلکه باعث غنی‌تر شدن روح انسان می‌شود. بسیاری از اشعار و نثرهای کهن، مانند آثار حافظ سعدی، و فردوسی، باید با دقت خوانده شوند تا معنای کامل آنها را دریافت کنیم. درک معنای ژرف چنین آثاری بدون داشتن دانش کافی از زبان و فرهنگ فارسی امکان‌پذیر نیست. بنابراین، آموزش درست به کودکان و نوجوانان نقش حیاتی در زنده نگه داشتن این میراث فرهنگی دارد. فضای مجازی امروز نقش عمده‌ای در انتشار آثار فارسی دارد، اما باید به شیوه‌ای صحیح از آن بهره ببریم تا از تحریف یا اشتباه در انتقال متون جلوگیری شود.
run.py ADDED
@@ -0,0 +1,214 @@
+ import os
+ import json
+ import torch
+ import argparse
+ from transformers import AutoTokenizer
+ from modeling_custom import BertForTokenClassificationWithFourO
+ from labeler import Labeler
+
+
+ class ModelPipeline:
+     """
+     Pipeline for text processing using the BertForTokenClassificationWithFourO model.
+     Handles preprocessing, inference, and postprocessing.
+     """
+
+     def __init__(self, model_path=None):
+         """
+         Initialize the pipeline with model and tokenizer.
+
+         Args:
+             model_path: Path to the model directory. Defaults to the directory
+                 containing this script.
+         """
+         # Use the directory containing this script if no path is specified
+         if model_path is None:
+             model_path = os.path.dirname(os.path.abspath(__file__))
+
+         # Load tokenizer and model
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         print(f"Using device: {self.device}")
+
+         self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+         self.model = BertForTokenClassificationWithFourO.from_pretrained(model_path)
+         self.model.to(self.device)
+         self.model.eval()
+
+         # Load config for any custom settings
+         config_path = os.path.join(model_path, "config.json")
+         with open(config_path, 'r') as f:
+             self.config = json.load(f)
+
+         # Initialize labeler for preprocessing and postprocessing
+         self.labeler = Labeler(tags=(1, 2),
+                                regexes=(r'[^\S\r\n\v\f]', r'\u200c'),
+                                chars=(" ", "‌"),
+                                class_count=2)
+
+     def _tokenize(self, chars, chunk_size=512):
+         input_ids = []
+         attention_masks = []
+
+         # Encode character by character; encode() returns [CLS, token, SEP],
+         # so [1] picks the token id for the single character. 101 and 102
+         # are this tokenizer's [CLS] and [SEP] ids (see tokenizer_config.json).
+         ids_ = [101] + [self.tokenizer.encode(char)[1] for char in chars] + [102]
+         for i in range(0, len(ids_), chunk_size):
+             chunked_ids = ids_[i:i + chunk_size]
+             attention_mask = [1] * len(chunked_ids)
+
+             if len(chunked_ids) != chunk_size:
+                 attention_mask += [0] * (chunk_size - len(chunked_ids))  # pad the attention mask
+                 chunked_ids += [0] * (chunk_size - len(chunked_ids))  # pad the last chunk to chunk_size
+
+             input_ids.append(chunked_ids)
+             attention_masks.append(attention_mask)
+
+         return input_ids, attention_masks
+
+     def preprocess(self, text):
+         # Strip existing spacing characters and record them as labels
+         chars, labels = self.labeler.label_text(text, corpus_type=Labeler.WHOLE_RAW)
+         input_ids, attention_mask = self._tokenize(chars[0], chunk_size=512)
+         input_ids = torch.tensor(input_ids)
+         attention_mask = torch.tensor(attention_mask)
+         # Pad the labels to line up with the [CLS] and [SEP] positions
+         labels = [0] + labels[0] + [0]
+         return {"input_ids": input_ids, "attention_mask": attention_mask}, chars, labels
+
+     def predict(self, encoded_inputs):
+         """
+         Run model inference on preprocessed inputs.
+
+         Args:
+             encoded_inputs: Tokenized inputs from the preprocess method
+
+         Returns:
+             Model predictions and logits
+         """
+         # Move input tensors to the same device as the model
+         input_ids = encoded_inputs["input_ids"].to(self.device)
+         attention_mask = encoded_inputs["attention_mask"].to(self.device)
+
+         with torch.no_grad():
+             outputs = self.model(
+                 input_ids=input_ids,
+                 attention_mask=attention_mask
+             )
+
+         # Get the predicted class for each token
+         predictions = torch.argmax(outputs.logits, dim=-1)
+         return predictions, outputs.logits
+
+     def postprocess_spacing(self, predictions, chars):
+         # Move predictions back to CPU and flatten the chunks
+         predictions_cpu = predictions.cpu()
+         predictions_flat = [label for sample in predictions_cpu.tolist() for label in sample]
+         # The surrounding ' ' placeholders stand in for [CLS] and [SEP]
+         return self.labeler.text_generator([' '] + chars[0] + [' '],
+                                            predictions_flat[:len(chars[0]) + 2],
+                                            corpus_type=Labeler.WHOLE_RAW).strip()
+
+     def postprocess_correcting(self, logits, chars, labels, alpha):
+         # Pad or truncate the labels to match the model's sequence length.
+         # Note: for inputs longer than one chunk, only the first chunk's
+         # labels line up with the logits here.
+         sequence_length = logits.size(1)
+         if len(labels) < sequence_length:
+             labels.extend([0] * (sequence_length - len(labels)))  # pad with 0
+         else:
+             labels = labels[:sequence_length]  # truncate if longer
+
+         # Convert labels to one-hot encoding
+         num_classes = logits.size(-1)
+         user_labels_tensor = torch.tensor([labels], device=self.device)
+         user_labels_one_hot = torch.nn.functional.one_hot(user_labels_tensor, num_classes=num_classes).float()
+
+         # Expand dimensions to match the logits shape if needed
+         if logits.dim() > user_labels_one_hot.dim():
+             user_labels_one_hot = user_labels_one_hot.unsqueeze(0)
+
+         # Blend model logits with the labels recovered from the input text
+         combined_logits = logits * (1 - alpha) + user_labels_one_hot * alpha
+
+         # Softmax, then take the most probable class per position
+         combined_probs = torch.nn.functional.softmax(combined_logits, dim=-1)
+         final_predictions = torch.argmax(combined_probs, dim=-1)
+
+         # Move to CPU and flatten the chunks
+         predictions_flat = [label for sample in final_predictions.cpu().tolist() for label in sample]
+
+         # Generate text with the combined predictions
+         return self.labeler.text_generator([' '] + chars[0] + [' '],
+                                            predictions_flat[:len(chars[0]) + 2],
+                                            corpus_type=Labeler.WHOLE_RAW).strip()
+
+     def process_text(self, text, mode):
+         """
+         Process text through the entire pipeline.
+
+         Args:
+             text: Input text as a string
+             mode: Processing mode ('space' or 'correct')
+
+         Returns:
+             Processed text with spacing applied
+         """
+         encoded_inputs, chars, labels = self.preprocess(text)
+         predictions, logits = self.predict(encoded_inputs)
+
+         if mode == 'space':
+             result = self.postprocess_spacing(predictions, chars)
+         elif mode == 'correct':
+             result = self.postprocess_correcting(logits, chars, labels, 0.5)
+         else:
+             raise ValueError(f"Unrecognized mode: {mode}")
+         return result
+
+
+ def main():
+     """Command-line interface for the model pipeline."""
+     parser = argparse.ArgumentParser(description="Process text using the token classification model")
+     parser.add_argument("--text", type=str, help="Text to process")
+     parser.add_argument("--file", type=str, help="Path to a file containing text to process")
+     parser.add_argument("--output", type=str, help="Path to the output file")
+     parser.add_argument("--mode", type=str, choices=['space', 'correct'], default='space',
+                         help="Processing mode: 'space' uses model outputs only, 'correct' combines model results "
+                              "with the original text's spacing")
+     parser.add_argument("--model_path", type=str, default=None, help="Path to the model directory")
+     args = parser.parse_args()
+
+     # Initialize pipeline
+     pipeline = ModelPipeline(args.model_path)
+
+     # Get input text
+     if args.text:
+         input_text = args.text
+     elif args.file:
+         with open(args.file, 'r', encoding='utf-8') as f:
+             input_text = f.read()
+     else:
+         print("Please provide either --text or --file")
+         return
+
+     # Process the text
+     result = pipeline.process_text(input_text, args.mode)
+
+     # Output results
+     if args.output:
+         with open(args.output, 'w', encoding='utf-8') as f:
+             if isinstance(result, list):
+                 f.write('\n'.join(result))
+             else:
+                 f.write(result)
+     else:
+         print("\nProcessed Text:")
+         if isinstance(result, list):
+             for item in result:
+                 print(item)
+         else:
+             print(result)
+
+
+ if __name__ == "__main__":
+     main()
special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "cls_token": "[CLS]",
+   "mask_token": "[MASK]",
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "unk_token": "[UNK]"
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,55 @@
+ {
+   "added_tokens_decoder": {
+     "0": {
+       "content": "[PAD]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "100": {
+       "content": "[UNK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "101": {
+       "content": "[CLS]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "102": {
+       "content": "[SEP]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "103": {
+       "content": "[MASK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "clean_up_tokenization_spaces": true,
+   "cls_token": "[CLS]",
+   "do_lower_case": true,
+   "mask_token": "[MASK]",
+   "model_max_length": 512,
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "strip_accents": null,
+   "tokenize_chinese_chars": true,
+   "tokenizer_class": "BertTokenizer",
+   "unk_token": "[UNK]"
+ }
training_args.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c82f1de4729fbd371473ac1e44782c7a63067513492c68ebacd386520759be1d
+ size 5176
vocab.txt ADDED
The diff for this file is too large to render. See raw diff