IsmatS commited on
Commit
6c0d4ed
·
1 Parent(s): 8b7950c
Files changed (1) hide show
  1. model.py +272 -0
model.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["WANDB_DISABLED"] = "true"
3
+
4
+ import pandas as pd
5
+ import numpy as np
6
+ from datasets import Dataset, DatasetDict, Features, Sequence, Value
7
+ from transformers import (
8
+ AutoTokenizer,
9
+ AutoModelForTokenClassification,
10
+ DataCollatorForTokenClassification,
11
+ TrainingArguments,
12
+ Trainer
13
+ )
14
+ from seqeval.metrics import f1_score, precision_score, recall_score
15
+ import torch
16
+ import json
17
+ import ast
18
+ from typing import List, Dict, Tuple
19
+
20
class AzerbaijaniNERPipeline:
    """End-to-end fine-tuning pipeline for Azerbaijani named-entity recognition.

    Loads a parquet dataset whose rows hold string-serialized token lists and
    integer NER-tag lists, cleans and splits it, fine-tunes a HuggingFace
    token-classification model, and evaluates it with seqeval metrics.
    """

    # Number of distinct NER tag ids; valid tags lie in [0, NUM_LABELS).
    # Tags outside this range are coerced to 0 during cleaning.
    NUM_LABELS = 25

    def __init__(self, model_name="bert-base-multilingual-cased", output_dir="az-ner-model"):
        """Create the pipeline.

        Args:
            model_name: HuggingFace model id used for both tokenizer and model.
            output_dir: Directory for checkpoints, logs, and the final model.
        """
        self.model_name = model_name
        self.output_dir = output_dir
        # exist_ok avoids the check-then-create race of the naive pattern.
        os.makedirs(self.output_dir, exist_ok=True)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.initialize_label_mappings()

    def initialize_label_mappings(self) -> None:
        """Build label<->id mappings; labels are the digit strings "0".."24"."""
        self.label2id = {str(i): i for i in range(self.NUM_LABELS)}
        self.id2label = {v: k for k, v in self.label2id.items()}

    def parse_list_string(self, s: str) -> List:
        """Safely parse a string representation of a list.

        Returns [] for non-strings (including NaN), for strings that are not
        valid Python literals, and for literals that are not lists.
        """
        # A non-str (NaN from pandas, None, a real list, ...) can never be
        # literal_eval'd here; this also subsumes the old pd.isna() check.
        if not isinstance(s, str):
            return []
        try:
            result = ast.literal_eval(s)
        except (ValueError, SyntaxError):
            # Malformed literal — treat as "no data" rather than crashing.
            return []
        return result if isinstance(result, list) else []

    def clean_and_validate_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """Parse, validate, and normalize token/tag rows.

        Rows whose tokens or tags fail to parse, are empty, or have mismatched
        lengths are dropped. Non-numeric or out-of-range tags are coerced to 0
        (presumably the "O"/outside label — TODO confirm against the tag set).

        Returns:
            A DataFrame with columns 'tokens' (list[str]) and 'ner_tags'
            (list[int]) containing only the valid rows.
        """
        print("Cleaning and validating data...")

        def process_row(row):
            # Returns a clean {'tokens', 'ner_tags'} dict, or None to drop the row.
            try:
                tokens = self.parse_list_string(row['tokens'])
                ner_tags = self.parse_list_string(row['ner_tags'])

                # Drop empty rows and rows with token/tag length mismatch.
                if not tokens or not ner_tags or len(tokens) != len(ner_tags):
                    return None

                # Coerce each tag to a valid int id; anything else becomes 0.
                # str(tag).isdigit() also rejects negatives (e.g. "-1").
                ner_tags = [
                    int(tag)
                    if isinstance(tag, (int, str))
                    and str(tag).isdigit()
                    and int(tag) < self.NUM_LABELS
                    else 0
                    for tag in ner_tags
                ]

                return {
                    'tokens': tokens,
                    'ner_tags': ner_tags,
                }
            except Exception:
                # Any unexpected row-level failure just drops the row;
                # the pipeline is deliberately best-effort here.
                return None

        processed_data = []
        skipped_rows = 0

        for _, row in df.iterrows():
            processed_row = process_row(row)
            if processed_row is not None:
                processed_data.append(processed_row)
            else:
                skipped_rows += 1

        print(f"Skipped {skipped_rows} invalid rows")
        print(f"Processed {len(processed_data)} valid rows")

        return pd.DataFrame(processed_data)

    def create_features(self) -> "Features":
        """Return explicit HF `datasets` feature schema for the cleaned data."""
        return Features({
            'tokens': Sequence(Value('string')),
            'ner_tags': Sequence(Value('int64'))
        })

    def load_dataset(self, parquet_path: str) -> "DatasetDict":
        """Load, clean, and split the dataset into train/validation/test.

        Split is 80/10/10 with a fixed seed for reproducibility. Also prints
        split sizes, one sample, and the training-set label distribution.
        """
        print(f"Loading dataset from {parquet_path}...")

        df = pd.read_parquet(parquet_path)
        print(f"Initial dataset size: {len(df)} rows")

        processed_df = self.clean_and_validate_data(df)

        # Explicit features avoid type inference surprises from pandas.
        dataset = Dataset.from_pandas(
            processed_df,
            features=self.create_features(),
            preserve_index=False
        )

        # 80% train, then split the remaining 20% in half -> 10% val, 10% test.
        train_test = dataset.train_test_split(test_size=0.2, seed=42)
        test_valid = train_test['test'].train_test_split(test_size=0.5, seed=42)

        dataset_dict = DatasetDict({
            'train': train_test['train'],
            'validation': test_valid['train'],
            'test': test_valid['test']
        })

        print("\nDataset splits:")
        for split, ds in dataset_dict.items():
            print(f"{split} set size: {len(ds)} examples")

        print("\nSample from training set:")
        sample = dataset_dict['train'][0]
        print(f"Tokens: {sample['tokens']}")
        print(f"Tags: {sample['ner_tags']}")

        print("\nLabel distribution in training set:")
        all_labels = []
        for example in dataset_dict['train']:
            all_labels.extend(example['ner_tags'])
        label_counts = pd.Series(all_labels).value_counts().sort_index()
        for label, count in label_counts.items():
            print(f"Label {label}: {count} occurrences")

        return dataset_dict

    def tokenize_and_align_labels(self, examples):
        """Tokenize word lists and align word-level labels to subword tokens.

        Only the first subword of each word keeps its label; continuation
        subwords and special tokens get -100 so the loss ignores them.
        """
        tokenized_inputs = self.tokenizer(
            examples["tokens"],
            truncation=True,
            is_split_into_words=True,
            max_length=512,
            padding="max_length"
        )

        labels = []
        for i, label in enumerate(examples["ner_tags"]):
            word_ids = tokenized_inputs.word_ids(batch_index=i)
            previous_word_idx = None
            label_ids = []

            for word_idx in word_ids:
                if word_idx is None:
                    # Special token ([CLS], [SEP], padding) — ignore in loss.
                    label_ids.append(-100)
                elif word_idx != previous_word_idx:
                    # First subword of a new word keeps the word's label.
                    label_ids.append(int(label[word_idx]))
                else:
                    # Continuation subword — ignore in loss.
                    label_ids.append(-100)
                previous_word_idx = word_idx

            labels.append(label_ids)

        tokenized_inputs["labels"] = labels
        return tokenized_inputs

    def compute_metrics(self, eval_preds):
        """Compute seqeval precision/recall/F1 from Trainer predictions.

        NOTE(review): labels here are bare digit strings ("0".."24"), not
        IOB-prefixed tags; seqeval's entity-level scoring assumes an IOB-style
        scheme — confirm this yields the intended metric.
        """
        predictions, labels = eval_preds
        predictions = np.argmax(predictions, axis=2)

        # Strip positions with label -100 (special/continuation tokens).
        true_predictions = [
            [str(p) for (p, l) in zip(prediction, label) if l != -100]
            for prediction, label in zip(predictions, labels)
        ]
        true_labels = [
            [str(l) for (p, l) in zip(prediction, label) if l != -100]
            for prediction, label in zip(predictions, labels)
        ]

        return {
            "precision": precision_score(true_labels, true_predictions),
            "recall": recall_score(true_labels, true_predictions),
            "f1": f1_score(true_labels, true_predictions)
        }

    def train(self, dataset_dict: "DatasetDict"):
        """Fine-tune the token-classification model and save it.

        Args:
            dataset_dict: Splits produced by `load_dataset`.

        Returns:
            The fitted `Trainer` (best checkpoint loaded at end).
        """
        print("Initializing model...")
        model = AutoModelForTokenClassification.from_pretrained(
            self.model_name,
            num_labels=len(self.label2id),
            id2label=self.id2label,
            label2id=self.label2id
        )

        print("Preparing datasets...")
        # Drop the raw columns so only model inputs remain.
        tokenized_datasets = dataset_dict.map(
            self.tokenize_and_align_labels,
            batched=True,
            remove_columns=dataset_dict["train"].column_names
        )

        training_args = TrainingArguments(
            output_dir=self.output_dir,
            evaluation_strategy="steps",
            eval_steps=100,
            learning_rate=2e-5,
            per_device_train_batch_size=16,
            per_device_eval_batch_size=16,
            num_train_epochs=5,
            weight_decay=0.01,
            push_to_hub=False,
            load_best_model_at_end=True,
            metric_for_best_model="f1",
            logging_dir=os.path.join(self.output_dir, 'logs'),
            logging_steps=50,
            report_to="none"  # Disable wandb logging
        )

        print("Initializing trainer...")
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_datasets["train"],
            eval_dataset=tokenized_datasets["validation"],
            tokenizer=self.tokenizer,
            data_collator=DataCollatorForTokenClassification(self.tokenizer),
            compute_metrics=self.compute_metrics
        )

        print("Starting training...")
        trainer.train()

        print("Saving model...")
        trainer.save_model(self.output_dir)

        return trainer
249
+
250
def main():
    """Script entry point: build the pipeline, train, and report test metrics."""
    # Construct the pipeline with its default model and output directory.
    ner_pipeline = AzerbaijaniNERPipeline()

    # Load, clean, and split the raw parquet data.
    splits = ner_pipeline.load_dataset("train-00000-of-00001.parquet")

    # Fine-tune on train/validation.
    trainer = ner_pipeline.train(splits)

    # Score the held-out test split with the trained model.
    print("Performing final evaluation...")
    test_split = splits["test"]
    tokenized_test = test_split.map(
        ner_pipeline.tokenize_and_align_labels,
        batched=True,
        remove_columns=test_split.column_names,
    )
    test_results = trainer.evaluate(tokenized_test)
    print("\nFinal Test Results:", json.dumps(test_results, indent=2))


if __name__ == "__main__":
    main()