xlm-roberta-large
Browse files- models/XLM-RoBERTa.ipynb +138 -70
- models/push_to_HF.py +2 -2
- models/xlm_roberta_large.ipynb +2 -2
models/XLM-RoBERTa.ipynb
CHANGED
|
@@ -112,18 +112,31 @@
|
|
| 112 |
},
|
| 113 |
"outputs": [],
|
| 114 |
"source": [
|
| 115 |
-
"
|
| 116 |
-
"import
|
| 117 |
-
"import
|
| 118 |
-
"import
|
| 119 |
-
"import
|
| 120 |
-
"
|
|
|
|
|
|
|
|
|
|
| 121 |
"from transformers import (\n",
|
| 122 |
-
" AutoTokenizer,
|
| 123 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
")\n",
|
| 125 |
-
"
|
| 126 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
]
|
| 128 |
},
|
| 129 |
{
|
|
@@ -160,7 +173,7 @@
|
|
| 160 |
],
|
| 161 |
"source": [
|
| 162 |
"# Log in to Hugging Face Hub\n",
|
| 163 |
-
"login(token=\"
|
| 164 |
]
|
| 165 |
},
|
| 166 |
{
|
|
@@ -171,8 +184,10 @@
|
|
| 171 |
},
|
| 172 |
"outputs": [],
|
| 173 |
"source": [
|
| 174 |
-
"# Disable unwanted
|
| 175 |
"os.environ[\"WANDB_DISABLED\"] = \"true\"\n",
|
|
|
|
|
|
|
| 176 |
"warnings.filterwarnings(\"ignore\")\n"
|
| 177 |
]
|
| 178 |
},
|
|
@@ -337,22 +352,24 @@
|
|
| 337 |
}
|
| 338 |
],
|
| 339 |
"source": [
|
| 340 |
-
"# Load the dataset\n",
|
| 341 |
"dataset = load_dataset(\"LocalDoc/azerbaijani-ner-dataset\")\n",
|
| 342 |
-
"print(dataset)\n",
|
| 343 |
"\n",
|
| 344 |
-
"# Preprocessing function
|
| 345 |
"def preprocess_example(example):\n",
|
| 346 |
" try:\n",
|
|
|
|
| 347 |
" example[\"tokens\"] = ast.literal_eval(example[\"tokens\"])\n",
|
| 348 |
" example[\"ner_tags\"] = list(map(int, ast.literal_eval(example[\"ner_tags\"])))\n",
|
| 349 |
" except (ValueError, SyntaxError) as e:\n",
|
|
|
|
| 350 |
" print(f\"Skipping malformed example: {example['index']} due to error: {e}\")\n",
|
| 351 |
" example[\"tokens\"] = []\n",
|
| 352 |
" example[\"ner_tags\"] = []\n",
|
| 353 |
" return example\n",
|
| 354 |
"\n",
|
| 355 |
-
"# Apply preprocessing\n",
|
| 356 |
"dataset = dataset.map(preprocess_example)\n"
|
| 357 |
]
|
| 358 |
},
|
|
@@ -507,32 +524,39 @@
|
|
| 507 |
}
|
| 508 |
],
|
| 509 |
"source": [
|
| 510 |
-
"# Initialize tokenizer\n",
|
| 511 |
"tokenizer = AutoTokenizer.from_pretrained(\"xlm-roberta-base\")\n",
|
| 512 |
"\n",
|
|
|
|
| 513 |
"def tokenize_and_align_labels(example):\n",
|
|
|
|
| 514 |
" tokenized_inputs = tokenizer(\n",
|
| 515 |
-
" example[\"tokens\"],\n",
|
| 516 |
-
" truncation=True,\n",
|
| 517 |
-
" is_split_into_words=True,\n",
|
| 518 |
-
" padding=\"max_length\",\n",
|
| 519 |
-
" max_length=128,\n",
|
| 520 |
" )\n",
|
| 521 |
-
"
|
| 522 |
-
"
|
| 523 |
-
"
|
|
|
|
|
|
|
|
|
|
| 524 |
" for word_idx in word_ids:\n",
|
| 525 |
" if word_idx is None:\n",
|
| 526 |
-
" labels.append(-100)\n",
|
| 527 |
" elif word_idx != previous_word_idx:\n",
|
|
|
|
| 528 |
" labels.append(example[\"ner_tags\"][word_idx] if word_idx < len(example[\"ner_tags\"]) else -100)\n",
|
| 529 |
" else:\n",
|
| 530 |
-
" labels.append(-100)\n",
|
| 531 |
-
" previous_word_idx = word_idx\n",
|
| 532 |
-
"
|
|
|
|
| 533 |
" return tokenized_inputs\n",
|
| 534 |
"\n",
|
| 535 |
-
"# Apply tokenization and alignment\n",
|
| 536 |
"tokenized_datasets = dataset.map(tokenize_and_align_labels, batched=False)\n"
|
| 537 |
]
|
| 538 |
},
|
|
@@ -575,9 +599,9 @@
|
|
| 575 |
}
|
| 576 |
],
|
| 577 |
"source": [
|
| 578 |
-
"# Create a 90-10 split for training and validation\n",
|
| 579 |
"tokenized_datasets = tokenized_datasets[\"train\"].train_test_split(test_size=0.1)\n",
|
| 580 |
-
"print(tokenized_datasets)
|
| 581 |
]
|
| 582 |
},
|
| 583 |
{
|
|
@@ -588,18 +612,33 @@
|
|
| 588 |
},
|
| 589 |
"outputs": [],
|
| 590 |
"source": [
|
|
|
|
| 591 |
"label_list = [\n",
|
| 592 |
-
" \"O\",
|
| 593 |
-
" \"B-
|
| 594 |
-
" \"B-
|
| 595 |
-
" \"
|
| 596 |
-
" \"
|
| 597 |
-
" \"B-
|
| 598 |
-
" \"
|
| 599 |
-
" \"B-
|
| 600 |
-
" \"B-
|
| 601 |
-
" \"B-
|
| 602 |
-
" \"B-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 603 |
"]\n"
|
| 604 |
]
|
| 605 |
},
|
|
@@ -662,13 +701,13 @@
|
|
| 662 |
}
|
| 663 |
],
|
| 664 |
"source": [
|
| 665 |
-
"# Initialize
|
| 666 |
"data_collator = DataCollatorForTokenClassification(tokenizer)\n",
|
| 667 |
"\n",
|
| 668 |
-
"# Load
|
| 669 |
"model = AutoModelForTokenClassification.from_pretrained(\n",
|
| 670 |
-
" \"xlm-roberta-
|
| 671 |
-
" num_labels=len(label_list)\n",
|
| 672 |
")\n"
|
| 673 |
]
|
| 674 |
},
|
|
@@ -680,18 +719,35 @@
|
|
| 680 |
},
|
| 681 |
"outputs": [],
|
| 682 |
"source": [
|
|
|
|
| 683 |
"def compute_metrics(p):\n",
|
| 684 |
-
" predictions, labels = p\n",
|
|
|
|
|
|
|
| 685 |
" predictions = np.argmax(predictions, axis=2)\n",
|
|
|
|
|
|
|
| 686 |
" true_labels = [[label_list[l] for l in label if l != -100] for label in labels]\n",
|
| 687 |
" true_predictions = [\n",
|
| 688 |
" [label_list[p] for (p, l) in zip(prediction, label) if l != -100]\n",
|
| 689 |
" for prediction, label in zip(predictions, labels)\n",
|
| 690 |
" ]\n",
|
|
|
|
|
|
|
| 691 |
" print(classification_report(true_labels, true_predictions))\n",
|
|
|
|
|
|
|
| 692 |
" return {\n",
|
|
|
|
|
|
|
| 693 |
" \"precision\": precision_score(true_labels, true_predictions),\n",
|
|
|
|
|
|
|
|
|
|
| 694 |
" \"recall\": recall_score(true_labels, true_predictions),\n",
|
|
|
|
|
|
|
|
|
|
| 695 |
" \"f1\": f1_score(true_labels, true_predictions),\n",
|
| 696 |
" }\n"
|
| 697 |
]
|
|
@@ -704,21 +760,22 @@
|
|
| 704 |
},
|
| 705 |
"outputs": [],
|
| 706 |
"source": [
|
|
|
|
| 707 |
"training_args = TrainingArguments(\n",
|
| 708 |
-
" output_dir=\"./results\",\n",
|
| 709 |
-
" evaluation_strategy=\"epoch\",
|
| 710 |
-
" save_strategy=\"epoch\",
|
| 711 |
-
" learning_rate=
|
| 712 |
-
" per_device_train_batch_size=
|
| 713 |
-
" per_device_eval_batch_size=
|
| 714 |
-
" num_train_epochs=
|
| 715 |
-
" weight_decay=0.
|
| 716 |
-
" fp16=True,\n",
|
| 717 |
-
" logging_dir='./logs',\n",
|
| 718 |
-
" save_total_limit=2,\n",
|
| 719 |
-
" load_best_model_at_end=True,
|
| 720 |
-
" metric_for_best_model=\"f1\",\n",
|
| 721 |
-
" report_to=\"none\"\n",
|
| 722 |
")\n"
|
| 723 |
]
|
| 724 |
},
|
|
@@ -730,15 +787,16 @@
|
|
| 730 |
},
|
| 731 |
"outputs": [],
|
| 732 |
"source": [
|
|
|
|
| 733 |
"trainer = Trainer(\n",
|
| 734 |
-
" model=model,\n",
|
| 735 |
-
" args=training_args,\n",
|
| 736 |
-
" train_dataset=tokenized_datasets[\"train\"],\n",
|
| 737 |
-
" eval_dataset=tokenized_datasets[\"test\"],\n",
|
| 738 |
-
" tokenizer=tokenizer,\n",
|
| 739 |
-
" data_collator=data_collator,\n",
|
| 740 |
-
" compute_metrics=compute_metrics,\n",
|
| 741 |
-
" callbacks=[EarlyStoppingCallback(early_stopping_patience=
|
| 742 |
")\n"
|
| 743 |
]
|
| 744 |
},
|
|
@@ -1036,8 +1094,13 @@
|
|
| 1036 |
}
|
| 1037 |
],
|
| 1038 |
"source": [
|
|
|
|
| 1039 |
"training_metrics = trainer.train()\n",
|
|
|
|
|
|
|
| 1040 |
"eval_results = trainer.evaluate()\n",
|
|
|
|
|
|
|
| 1041 |
"print(eval_results)\n"
|
| 1042 |
]
|
| 1043 |
},
|
|
@@ -1078,8 +1141,13 @@
|
|
| 1078 |
}
|
| 1079 |
],
|
| 1080 |
"source": [
|
|
|
|
| 1081 |
"save_directory = \"./XLM-RoBERTa\"\n",
|
|
|
|
|
|
|
| 1082 |
"model.save_pretrained(save_directory)\n",
|
|
|
|
|
|
|
| 1083 |
"tokenizer.save_pretrained(save_directory)\n"
|
| 1084 |
]
|
| 1085 |
},
|
|
|
|
| 112 |
},
|
| 113 |
"outputs": [],
|
| 114 |
"source": [
|
| 115 |
+
"# Standard library imports\n",
|
| 116 |
+
"import os # Provides functions for interacting with the operating system\n",
|
| 117 |
+
"import warnings # Used to handle or suppress warnings\n",
|
| 118 |
+
"import numpy as np # Essential for numerical operations and array manipulation\n",
|
| 119 |
+
"import torch # PyTorch library for tensor computations and model handling\n",
|
| 120 |
+
"import ast # Used for safe evaluation of strings to Python objects (e.g., parsing tokens)\n",
|
| 121 |
+
"\n",
|
| 122 |
+
"# Hugging Face and Transformers imports\n",
|
| 123 |
+
"from datasets import load_dataset # Loads datasets for model training and evaluation\n",
|
| 124 |
"from transformers import (\n",
|
| 125 |
+
" AutoTokenizer, # Initializes a tokenizer from a pre-trained model\n",
|
| 126 |
+
" DataCollatorForTokenClassification, # Handles padding and formatting of token classification data\n",
|
| 127 |
+
" TrainingArguments, # Defines training parameters like batch size and learning rate\n",
|
| 128 |
+
" Trainer, # High-level API for managing training and evaluation\n",
|
| 129 |
+
" AutoModelForTokenClassification, # Loads a pre-trained model for token classification tasks\n",
|
| 130 |
+
" get_linear_schedule_with_warmup, # Learning rate scheduler for gradual warm-up and linear decay\n",
|
| 131 |
+
" EarlyStoppingCallback # Callback to stop training if validation performance plateaus\n",
|
| 132 |
")\n",
|
| 133 |
+
"\n",
|
| 134 |
+
"# Hugging Face Hub\n",
|
| 135 |
+
"from huggingface_hub import login # Allows logging in to Hugging Face Hub to upload models\n",
|
| 136 |
+
"\n",
|
| 137 |
+
"# seqeval metrics for NER evaluation\n",
|
| 138 |
+
"from seqeval.metrics import precision_score, recall_score, f1_score, classification_report\n",
|
| 139 |
+
"# Provides precision, recall, F1-score, and classification report for evaluating NER model performance\n"
|
| 140 |
]
|
| 141 |
},
|
| 142 |
{
|
|
|
|
| 173 |
],
|
| 174 |
"source": [
|
| 175 |
"# Log in to Hugging Face Hub\n",
|
| 176 |
+
"login(token=\"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\")\n"
|
| 177 |
]
|
| 178 |
},
|
| 179 |
{
|
|
|
|
| 184 |
},
|
| 185 |
"outputs": [],
|
| 186 |
"source": [
|
| 187 |
+
"# Disable WandB (Weights & Biases) logging to avoid unwanted log outputs during training\n",
|
| 188 |
"os.environ[\"WANDB_DISABLED\"] = \"true\"\n",
|
| 189 |
+
"\n",
|
| 190 |
+
"# Suppress warning messages to keep output clean, especially during training and evaluation\n",
|
| 191 |
"warnings.filterwarnings(\"ignore\")\n"
|
| 192 |
]
|
| 193 |
},
|
|
|
|
| 352 |
}
|
| 353 |
],
|
| 354 |
"source": [
|
| 355 |
+
"# Load the Azerbaijani NER dataset from Hugging Face\n",
|
| 356 |
"dataset = load_dataset(\"LocalDoc/azerbaijani-ner-dataset\")\n",
|
| 357 |
+
"print(dataset) # Display dataset structure (e.g., train/validation splits)\n",
|
| 358 |
"\n",
|
| 359 |
+
"# Preprocessing function to format tokens and NER tags correctly\n",
|
| 360 |
"def preprocess_example(example):\n",
|
| 361 |
" try:\n",
|
| 362 |
+
" # Convert string of tokens to a list and parse NER tags to integers\n",
|
| 363 |
" example[\"tokens\"] = ast.literal_eval(example[\"tokens\"])\n",
|
| 364 |
" example[\"ner_tags\"] = list(map(int, ast.literal_eval(example[\"ner_tags\"])))\n",
|
| 365 |
" except (ValueError, SyntaxError) as e:\n",
|
| 366 |
+
" # Skip and log malformed examples, ensuring error resilience\n",
|
| 367 |
" print(f\"Skipping malformed example: {example['index']} due to error: {e}\")\n",
|
| 368 |
" example[\"tokens\"] = []\n",
|
| 369 |
" example[\"ner_tags\"] = []\n",
|
| 370 |
" return example\n",
|
| 371 |
"\n",
|
| 372 |
+
"# Apply preprocessing to each dataset entry, ensuring consistent formatting\n",
|
| 373 |
"dataset = dataset.map(preprocess_example)\n"
|
| 374 |
]
|
| 375 |
},
|
|
|
|
| 524 |
}
|
| 525 |
],
|
| 526 |
"source": [
|
| 527 |
+
"# Initialize the tokenizer for multilingual NER using XLM-RoBERTa\n",
|
| 528 |
"tokenizer = AutoTokenizer.from_pretrained(\"xlm-roberta-base\")\n",
|
| 529 |
"\n",
|
| 530 |
+
"# Function to tokenize input and align labels with tokenized words\n",
|
| 531 |
"def tokenize_and_align_labels(example):\n",
|
| 532 |
+
" # Tokenize the sentence while preserving word boundaries for correct NER tag alignment\n",
|
| 533 |
" tokenized_inputs = tokenizer(\n",
|
| 534 |
+
" example[\"tokens\"], # List of words (tokens) in the sentence\n",
|
| 535 |
+
" truncation=True, # Truncate sentences longer than max_length\n",
|
| 536 |
+
" is_split_into_words=True, # Specify that input is a list of words\n",
|
| 537 |
+
" padding=\"max_length\", # Pad to maximum sequence length\n",
|
| 538 |
+
" max_length=128, # Set the maximum sequence length to 128 tokens\n",
|
| 539 |
" )\n",
|
| 540 |
+
"\n",
|
| 541 |
+
" labels = [] # List to store aligned NER labels\n",
|
| 542 |
+
" word_ids = tokenized_inputs.word_ids() # Get word IDs for each token\n",
|
| 543 |
+
" previous_word_idx = None # Initialize previous word index for tracking\n",
|
| 544 |
+
"\n",
|
| 545 |
+
" # Loop through word indices to align NER tags with subword tokens\n",
|
| 546 |
" for word_idx in word_ids:\n",
|
| 547 |
" if word_idx is None:\n",
|
| 548 |
+
" labels.append(-100) # Set padding token labels to -100 (ignored in loss)\n",
|
| 549 |
" elif word_idx != previous_word_idx:\n",
|
| 550 |
+
" # Assign the label from example's NER tags if word index matches\n",
|
| 551 |
" labels.append(example[\"ner_tags\"][word_idx] if word_idx < len(example[\"ner_tags\"]) else -100)\n",
|
| 552 |
" else:\n",
|
| 553 |
+
" labels.append(-100) # Label subword tokens with -100 to avoid redundant labels\n",
|
| 554 |
+
" previous_word_idx = word_idx # Update previous word index\n",
|
| 555 |
+
"\n",
|
| 556 |
+
" tokenized_inputs[\"labels\"] = labels # Add labels to tokenized inputs\n",
|
| 557 |
" return tokenized_inputs\n",
|
| 558 |
"\n",
|
| 559 |
+
"# Apply tokenization and label alignment function to the dataset\n",
|
| 560 |
"tokenized_datasets = dataset.map(tokenize_and_align_labels, batched=False)\n"
|
| 561 |
]
|
| 562 |
},
|
|
|
|
| 599 |
}
|
| 600 |
],
|
| 601 |
"source": [
|
| 602 |
+
"# Create a 90-10 split of the dataset for training and validation\n",
|
| 603 |
"tokenized_datasets = tokenized_datasets[\"train\"].train_test_split(test_size=0.1)\n",
|
| 604 |
+
"print(tokenized_datasets) # Output structure of split datasets"
|
| 605 |
]
|
| 606 |
},
|
| 607 |
{
|
|
|
|
| 612 |
},
|
| 613 |
"outputs": [],
|
| 614 |
"source": [
|
| 615 |
+
"# Define a list of entity labels for NER tagging with B- (beginning) and I- (inside) markers\n",
|
| 616 |
"label_list = [\n",
|
| 617 |
+
" \"O\", # Outside of a named entity\n",
|
| 618 |
+
" \"B-PERSON\", \"I-PERSON\", # Person name (e.g., \"John\" in \"John Doe\")\n",
|
| 619 |
+
" \"B-LOCATION\", \"I-LOCATION\", # Geographical location (e.g., \"Paris\")\n",
|
| 620 |
+
" \"B-ORGANISATION\", \"I-ORGANISATION\", # Organization name (e.g., \"UNICEF\")\n",
|
| 621 |
+
" \"B-DATE\", \"I-DATE\", # Date entity (e.g., \"2024-11-05\")\n",
|
| 622 |
+
" \"B-TIME\", \"I-TIME\", # Time (e.g., \"12:00 PM\")\n",
|
| 623 |
+
" \"B-MONEY\", \"I-MONEY\", # Monetary values (e.g., \"$20\")\n",
|
| 624 |
+
" \"B-PERCENTAGE\", \"I-PERCENTAGE\", # Percentage values (e.g., \"20%\")\n",
|
| 625 |
+
" \"B-FACILITY\", \"I-FACILITY\", # Physical facilities (e.g., \"Airport\")\n",
|
| 626 |
+
" \"B-PRODUCT\", \"I-PRODUCT\", # Product names (e.g., \"iPhone\")\n",
|
| 627 |
+
" \"B-EVENT\", \"I-EVENT\", # Named events (e.g., \"Olympics\")\n",
|
| 628 |
+
" \"B-ART\", \"I-ART\", # Works of art (e.g., \"Mona Lisa\")\n",
|
| 629 |
+
" \"B-LAW\", \"I-LAW\", # Laws and legal documents (e.g., \"Article 50\")\n",
|
| 630 |
+
" \"B-LANGUAGE\", \"I-LANGUAGE\", # Languages (e.g., \"Azerbaijani\")\n",
|
| 631 |
+
" \"B-GPE\", \"I-GPE\", # Geopolitical entities (e.g., \"Europe\")\n",
|
| 632 |
+
" \"B-NORP\", \"I-NORP\", # Nationalities, religious groups, political groups\n",
|
| 633 |
+
" \"B-ORDINAL\", \"I-ORDINAL\", # Ordinal indicators (e.g., \"first\", \"second\")\n",
|
| 634 |
+
" \"B-CARDINAL\", \"I-CARDINAL\", # Cardinal numbers (e.g., \"three\")\n",
|
| 635 |
+
" \"B-DISEASE\", \"I-DISEASE\", # Diseases (e.g., \"COVID-19\")\n",
|
| 636 |
+
" \"B-CONTACT\", \"I-CONTACT\", # Contact info (e.g., email or phone number)\n",
|
| 637 |
+
" \"B-ADAGE\", \"I-ADAGE\", # Common sayings or adages\n",
|
| 638 |
+
" \"B-QUANTITY\", \"I-QUANTITY\", # Quantities (e.g., \"5 km\")\n",
|
| 639 |
+
" \"B-MISCELLANEOUS\", \"I-MISCELLANEOUS\", # Miscellaneous entities not fitting other categories\n",
|
| 640 |
+
" \"B-POSITION\", \"I-POSITION\", # Job titles or positions (e.g., \"CEO\")\n",
|
| 641 |
+
" \"B-PROJECT\", \"I-PROJECT\" # Project names (e.g., \"Project Apollo\")\n",
|
| 642 |
"]\n"
|
| 643 |
]
|
| 644 |
},
|
|
|
|
| 701 |
}
|
| 702 |
],
|
| 703 |
"source": [
|
| 704 |
+
"# Initialize a data collator to handle padding and formatting for token classification\n",
|
| 705 |
"data_collator = DataCollatorForTokenClassification(tokenizer)\n",
|
| 706 |
"\n",
|
| 707 |
+
"# Load a pre-trained model for token classification, adapted for NER tasks\n",
|
| 708 |
"model = AutoModelForTokenClassification.from_pretrained(\n",
|
| 709 |
+
" \"xlm-roberta-large\", # Base model (multilingual XLM-RoBERTa) for NER\n",
|
| 710 |
+
" num_labels=len(label_list) # Set the number of output labels to match NER categories\n",
|
| 711 |
")\n"
|
| 712 |
]
|
| 713 |
},
|
|
|
|
| 719 |
},
|
| 720 |
"outputs": [],
|
| 721 |
"source": [
|
| 722 |
+
"# Define a function to compute evaluation metrics for the model's predictions\n",
|
| 723 |
"def compute_metrics(p):\n",
|
| 724 |
+
" predictions, labels = p # Unpack predictions and true labels from the input\n",
|
| 725 |
+
"\n",
|
| 726 |
+
" # Convert logits to predicted label indices by taking the argmax along the last axis\n",
|
| 727 |
" predictions = np.argmax(predictions, axis=2)\n",
|
| 728 |
+
"\n",
|
| 729 |
+
" # Filter out special padding labels (-100) and convert indices to label names\n",
|
| 730 |
" true_labels = [[label_list[l] for l in label if l != -100] for label in labels]\n",
|
| 731 |
" true_predictions = [\n",
|
| 732 |
" [label_list[p] for (p, l) in zip(prediction, label) if l != -100]\n",
|
| 733 |
" for prediction, label in zip(predictions, labels)\n",
|
| 734 |
" ]\n",
|
| 735 |
+
"\n",
|
| 736 |
+
" # Print a detailed classification report for each label category\n",
|
| 737 |
" print(classification_report(true_labels, true_predictions))\n",
|
| 738 |
+
"\n",
|
| 739 |
+
" # Calculate and return key evaluation metrics\n",
|
| 740 |
" return {\n",
|
| 741 |
+
" # Precision measures the accuracy of predicted positive instances\n",
|
| 742 |
+
" # Important in NER to ensure entity predictions are correct and reduce false positives.\n",
|
| 743 |
" \"precision\": precision_score(true_labels, true_predictions),\n",
|
| 744 |
+
"\n",
|
| 745 |
+
" # Recall measures the model's ability to capture all relevant entities\n",
|
| 746 |
+
" # Essential in NER to ensure the model captures all entities, reducing false negatives.\n",
|
| 747 |
" \"recall\": recall_score(true_labels, true_predictions),\n",
|
| 748 |
+
"\n",
|
| 749 |
+
" # F1-score is the harmonic mean of precision and recall, balancing both metrics\n",
|
| 750 |
+
" # Useful in NER for providing an overall performance measure, especially when precision and recall are both important.\n",
|
| 751 |
" \"f1\": f1_score(true_labels, true_predictions),\n",
|
| 752 |
" }\n"
|
| 753 |
]
|
|
|
|
| 760 |
},
|
| 761 |
"outputs": [],
|
| 762 |
"source": [
|
| 763 |
+
"# Set up training arguments for model training, defining essential training configurations\n",
|
| 764 |
"training_args = TrainingArguments(\n",
|
| 765 |
+
" output_dir=\"./results\", # Directory to save model checkpoints and final outputs\n",
|
| 766 |
+
" evaluation_strategy=\"epoch\", # Evaluate model on the validation set at the end of each epoch\n",
|
| 767 |
+
" save_strategy=\"epoch\", # Save model checkpoints at the end of each epoch\n",
|
| 768 |
+
" learning_rate=2e-5, # Set a low learning rate to ensure stable training for fine-tuning\n",
|
| 769 |
+
" per_device_train_batch_size=128, # Number of examples per batch during training, balancing speed and memory\n",
|
| 770 |
+
" per_device_eval_batch_size=128, # Number of examples per batch during evaluation\n",
|
| 771 |
+
" num_train_epochs=12, # Number of full training passes over the dataset\n",
|
| 772 |
+
" weight_decay=0.005, # Regularization term to prevent overfitting by penalizing large weights\n",
|
| 773 |
+
" fp16=True, # Use 16-bit floating point for faster and memory-efficient training\n",
|
| 774 |
+
" logging_dir='./logs', # Directory to store training logs\n",
|
| 775 |
+
" save_total_limit=2, # Keep only the 2 latest model checkpoints to save storage space\n",
|
| 776 |
+
" load_best_model_at_end=True, # Load the best model based on metrics at the end of training\n",
|
| 777 |
+
" metric_for_best_model=\"f1\", # Use F1-score to determine the best model checkpoint\n",
|
| 778 |
+
" report_to=\"none\" # Disable reporting to external services (useful in local runs)\n",
|
| 779 |
")\n"
|
| 780 |
]
|
| 781 |
},
|
|
|
|
| 787 |
},
|
| 788 |
"outputs": [],
|
| 789 |
"source": [
|
| 790 |
+
"# Initialize the Trainer class to manage the training loop with all necessary components\n",
|
| 791 |
"trainer = Trainer(\n",
|
| 792 |
+
" model=model, # The pre-trained model to be fine-tuned\n",
|
| 793 |
+
" args=training_args, # Training configuration parameters defined in TrainingArguments\n",
|
| 794 |
+
" train_dataset=tokenized_datasets[\"train\"], # Tokenized training dataset\n",
|
| 795 |
+
" eval_dataset=tokenized_datasets[\"test\"], # Tokenized validation dataset\n",
|
| 796 |
+
" tokenizer=tokenizer, # Tokenizer used for processing input text\n",
|
| 797 |
+
" data_collator=data_collator, # Data collator for padding and batching during training\n",
|
| 798 |
+
" compute_metrics=compute_metrics, # Function to calculate evaluation metrics like precision, recall, F1\n",
|
| 799 |
+
" callbacks=[EarlyStoppingCallback(early_stopping_patience=5)] # Stop training early if validation metrics don't improve for 2 epochs\n",
|
| 800 |
")\n"
|
| 801 |
]
|
| 802 |
},
|
|
|
|
| 1094 |
}
|
| 1095 |
],
|
| 1096 |
"source": [
|
| 1097 |
+
"# Begin the training process and capture the training metrics\n",
|
| 1098 |
"training_metrics = trainer.train()\n",
|
| 1099 |
+
"\n",
|
| 1100 |
+
"# Evaluate the model on the validation set after training\n",
|
| 1101 |
"eval_results = trainer.evaluate()\n",
|
| 1102 |
+
"\n",
|
| 1103 |
+
"# Print evaluation results, including precision, recall, and F1-score\n",
|
| 1104 |
"print(eval_results)\n"
|
| 1105 |
]
|
| 1106 |
},
|
|
|
|
| 1141 |
}
|
| 1142 |
],
|
| 1143 |
"source": [
|
| 1144 |
+
"# Define the directory where the trained model and tokenizer will be saved\n",
|
| 1145 |
"save_directory = \"./XLM-RoBERTa\"\n",
|
| 1146 |
+
"\n",
|
| 1147 |
+
"# Save the trained model to the specified directory\n",
|
| 1148 |
"model.save_pretrained(save_directory)\n",
|
| 1149 |
+
"\n",
|
| 1150 |
+
"# Save the tokenizer to the same directory for compatibility with the model\n",
|
| 1151 |
"tokenizer.save_pretrained(save_directory)\n"
|
| 1152 |
]
|
| 1153 |
},
|
models/push_to_HF.py
CHANGED
|
@@ -10,8 +10,8 @@ hf_token = os.getenv("HUGGINGFACE_TOKEN")
|
|
| 10 |
login(token=hf_token)
|
| 11 |
|
| 12 |
# Define your repository ID
|
| 13 |
-
repo_id = "IsmatS/
|
| 14 |
|
| 15 |
# Initialize HfApi and upload the model folder
|
| 16 |
api = HfApi()
|
| 17 |
-
api.upload_folder(folder_path="./
|
|
|
|
| 10 |
login(token=hf_token)
|
| 11 |
|
| 12 |
# Define your repository ID
|
| 13 |
+
repo_id = "IsmatS/xlm_roberta_large_az_ner"
|
| 14 |
|
| 15 |
# Initialize HfApi and upload the model folder
|
| 16 |
api = HfApi()
|
| 17 |
+
api.upload_folder(folder_path="./xlm-roberta-large", path_in_repo="", repo_id=repo_id)
|
models/xlm_roberta_large.ipynb
CHANGED
|
@@ -2450,7 +2450,7 @@
|
|
| 2450 |
"colab": {
|
| 2451 |
"base_uri": "https://localhost:8080/"
|
| 2452 |
},
|
| 2453 |
-
"outputId": "
|
| 2454 |
},
|
| 2455 |
"execution_count": null,
|
| 2456 |
"outputs": [
|
|
@@ -2466,7 +2466,7 @@
|
|
| 2466 |
]
|
| 2467 |
},
|
| 2468 |
"metadata": {},
|
| 2469 |
-
"execution_count":
|
| 2470 |
}
|
| 2471 |
]
|
| 2472 |
},
|
|
|
|
| 2450 |
"colab": {
|
| 2451 |
"base_uri": "https://localhost:8080/"
|
| 2452 |
},
|
| 2453 |
+
"outputId": "d8184694-0ab9-44e4-9b4e-859cd2ea6188"
|
| 2454 |
},
|
| 2455 |
"execution_count": null,
|
| 2456 |
"outputs": [
|
|
|
|
| 2466 |
]
|
| 2467 |
},
|
| 2468 |
"metadata": {},
|
| 2469 |
+
"execution_count": 19
|
| 2470 |
}
|
| 2471 |
]
|
| 2472 |
},
|