Upload 5 files
Browse files- scripts/KIFS_filtering_script.ipynb +470 -0
- scripts/batch_translation.ipynb +137 -0
- scripts/corpus_stats.ipynb +232 -0
- scripts/finetuning_script.ipynb +442 -0
- scripts/intrinsic_evaluation.ipynb +141 -0
scripts/KIFS_filtering_script.ipynb
ADDED
|
@@ -0,0 +1,470 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"id": "311e31e2",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"# Import pandas for DataFrame manipulation\n",
|
| 11 |
+
"import pandas as pd\n",
|
| 12 |
+
"# Import numpy for numerical operations\n",
|
| 13 |
+
"import numpy as np\n",
|
| 14 |
+
"# Import torch for tensor operations and device handling\n",
|
| 15 |
+
"import torch\n",
|
| 16 |
+
"# Import MBART model and tokenizer from Hugging Face Transformers\n",
|
| 17 |
+
"from transformers import MBartForConditionalGeneration, MBart50TokenizerFast\n",
|
| 18 |
+
"# Import cosine similarity for comparing embeddings\n",
|
| 19 |
+
"from sklearn.metrics.pairwise import cosine_similarity\n",
|
| 20 |
+
"# Import tqdm to show progress bars for loops\n",
|
| 21 |
+
"from tqdm import tqdm\n",
|
| 22 |
+
"# Import regex utilities for tokenization and cleaning\n",
|
| 23 |
+
"import re"
|
| 24 |
+
]
|
| 25 |
+
},
|
| 26 |
+
{
|
| 27 |
+
"cell_type": "code",
|
| 28 |
+
"execution_count": null,
|
| 29 |
+
"id": "3363ac62",
|
| 30 |
+
"metadata": {},
|
| 31 |
+
"outputs": [],
|
| 32 |
+
"source": [
|
| 33 |
+
"# --- Configuration ---\n",
|
| 34 |
+
"MODEL_NAME = \"your/model/name\"\n",
|
| 35 |
+
"SRC_LANG_CODE = \"src_lang_code\"\n",
|
| 36 |
+
"TGT_LANG_CODE = \"tgt_lang_code\"\n",
|
| 37 |
+
"CORPUS_FILE = \"your/corpus/here.csv\"\n",
|
| 38 |
+
"DICT_FILE = \"your/bilingual/dictionary/here.csv\""
|
| 39 |
+
]
|
| 40 |
+
},
|
| 41 |
+
{
|
| 42 |
+
"cell_type": "code",
|
| 43 |
+
"execution_count": null,
|
| 44 |
+
"id": "d5d67ae6",
|
| 45 |
+
"metadata": {},
|
| 46 |
+
"outputs": [],
|
| 47 |
+
"source": [
|
| 48 |
+
"# Hyperparameters for the Knowledge Score (KS_i)\n",
|
| 49 |
+
"# You would tune these based on empirical performance\n",
|
| 50 |
+
"ALPHA = 0.1\n",
|
| 51 |
+
"BETA = 0.3\n",
|
| 52 |
+
"GAMMA = 0.6\n",
|
| 53 |
+
"PERCENTILE_THRESHOLD = 70 # Filter threshold: keep pairs above this percentile"
|
| 54 |
+
]
|
| 55 |
+
},
|
| 56 |
+
{
|
| 57 |
+
"cell_type": "code",
|
| 58 |
+
"execution_count": null,
|
| 59 |
+
"id": "f5fb7924",
|
| 60 |
+
"metadata": {},
|
| 61 |
+
"outputs": [],
|
| 62 |
+
"source": [
|
| 63 |
+
"def preprocess_text(text):\n",
|
| 64 |
+
" \"\"\"\n",
|
| 65 |
+
" Safely preprocesses text by handling NaN, non-string values,\n",
|
| 66 |
+
" and performing normalization steps.\n",
|
| 67 |
+
" \"\"\"\n",
|
| 68 |
+
" if not isinstance(text, str):\n",
|
| 69 |
+
" return \"\"\n",
|
| 70 |
+
" text = text.strip().lower()\n",
|
| 71 |
+
" text = re.sub(r\"\\s+\", \" \", text) # Collapse multiple spaces\n",
|
| 72 |
+
" text = re.sub(r\"[^a-zA-Z0-9\\s']\", \"\", text) # Remove unwanted symbols (keep alphanumerics and apostrophes)\n",
|
| 73 |
+
" return text\n",
|
| 74 |
+
"\n",
|
| 75 |
+
"\n",
|
| 76 |
+
"def load_data(corpus_file, dict_file):\n",
|
| 77 |
+
" \"\"\"Loads, cleans, and prepares the parallel corpus and bilingual dictionary.\"\"\"\n",
|
| 78 |
+
"\n",
|
| 79 |
+
" # --- Load the CSVs safely ---\n",
|
| 80 |
+
" try:\n",
|
| 81 |
+
" raw_corpus = pd.read_csv(corpus_file)\n",
|
| 82 |
+
" word_dictionary = pd.read_csv(dict_file)\n",
|
| 83 |
+
" except Exception as e:\n",
|
| 84 |
+
" raise ValueError(f\"Error loading files: {e}\")\n",
|
| 85 |
+
"\n",
|
| 86 |
+
" # --- Ensure expected columns exist ---\n",
|
| 87 |
+
" required_corpus_cols = {'English', 'Tagin'}\n",
|
| 88 |
+
" required_dict_cols = {'English', 'Tagin'}\n",
|
| 89 |
+
"\n",
|
| 90 |
+
" if not required_corpus_cols.issubset(raw_corpus.columns):\n",
|
| 91 |
+
" raise ValueError(f\"Corpus file must contain columns: {required_corpus_cols}\")\n",
|
| 92 |
+
" if not required_dict_cols.issubset(word_dictionary.columns):\n",
|
| 93 |
+
" raise ValueError(f\"Dictionary file must contain columns: {required_dict_cols}\")\n",
|
| 94 |
+
"\n",
|
| 95 |
+
" # --- Drop rows with all NaN values ---\n",
|
| 96 |
+
" raw_corpus = raw_corpus.dropna(how='all')\n",
|
| 97 |
+
"\n",
|
| 98 |
+
" # --- Fill NaN cells with empty strings ---\n",
|
| 99 |
+
" raw_corpus = raw_corpus.fillna(\"\")\n",
|
| 100 |
+
"\n",
|
| 101 |
+
" # --- Apply text preprocessing ---\n",
|
| 102 |
+
" raw_corpus[\"English\"] = raw_corpus[\"English\"].apply(preprocess_text)\n",
|
| 103 |
+
" raw_corpus[\"Tagin\"] = raw_corpus[\"Tagin\"].apply(preprocess_text)\n",
|
| 104 |
+
"\n",
|
| 105 |
+
" # --- Clean dictionary entries ---\n",
|
| 106 |
+
" word_dictionary[\"English\"] = word_dictionary[\"English\"].apply(preprocess_text)\n",
|
| 107 |
+
" word_dictionary[\"Tagin\"] = word_dictionary[\"Tagin\"].apply(preprocess_text)\n",
|
| 108 |
+
"\n",
|
| 109 |
+
" # --- Convert dictionary to mapping ---\n",
|
| 110 |
+
" word_dictionary = word_dictionary.set_index('English')['Tagin'].to_dict()\n",
|
| 111 |
+
"\n",
|
| 112 |
+
" # --- Remove empty rows after cleaning ---\n",
|
| 113 |
+
" raw_corpus = raw_corpus[\n",
|
| 114 |
+
" (raw_corpus[\"English\"].str.strip() != \"\") &\n",
|
| 115 |
+
" (raw_corpus[\"Tagin\"].str.strip() != \"\")\n",
|
| 116 |
+
" ].reset_index(drop=True)\n",
|
| 117 |
+
"\n",
|
| 118 |
+
" print(f\"Loaded {len(raw_corpus)} sentence pairs and {len(word_dictionary)} dictionary entries.\")\n",
|
| 119 |
+
" return raw_corpus, word_dictionary"
|
| 120 |
+
]
|
| 121 |
+
},
|
| 122 |
+
{
|
| 123 |
+
"cell_type": "code",
|
| 124 |
+
"execution_count": null,
|
| 125 |
+
"id": "772824d1",
|
| 126 |
+
"metadata": {},
|
| 127 |
+
"outputs": [],
|
| 128 |
+
"source": [
|
| 129 |
+
"raw_corpus, word_dictionary = load_data(CORPUS_FILE, DICT_FILE)  # capture the return value instead of discarding it"
|
| 130 |
+
]
|
| 131 |
+
},
|
| 132 |
+
{
|
| 133 |
+
"cell_type": "code",
|
| 134 |
+
"execution_count": null,
|
| 135 |
+
"id": "7322656f",
|
| 136 |
+
"metadata": {},
|
| 137 |
+
"outputs": [],
|
| 138 |
+
"source": [
|
| 139 |
+
"# Function for Step 2: Perplexity (PPL)\n",
|
| 140 |
+
"@torch.no_grad()\n",
|
| 141 |
+
"def calculate_perplexity(sentence, model, tokenizer, device):\n",
|
| 142 |
+
" \"\"\"Computes perplexity of a sentence using the given LM.\"\"\"\n",
|
| 143 |
+
" try:\n",
|
| 144 |
+
" # Tokenize and format for mBART-50 (e.g., [lang_code] X [eos])\n",
|
| 145 |
+
" # We'll treat this as a generation task from the source language to itself\n",
|
| 146 |
+
" # to get log probabilities for the language modeling loss.\n",
|
| 147 |
+
" input_ids = tokenizer(\n",
|
| 148 |
+
" sentence,\n",
|
| 149 |
+
" return_tensors=\"pt\",\n",
|
| 150 |
+
" max_length=512,\n",
|
| 151 |
+
" truncation=True\n",
|
| 152 |
+
" ).input_ids.to(device)\n",
|
| 153 |
+
" \n",
|
| 154 |
+
" # Set the source language\n",
|
| 155 |
+
"        tokenizer.src_lang = SRC_LANG_CODE  # BUG(review): must be set BEFORE the tokenizer(...) call above; here it has no effect on the already-tokenized input_ids\n",
|
| 156 |
+
" \n",
|
| 157 |
+
" # The labels for perplexity are the input tokens themselves, shifted.\n",
|
| 158 |
+
" # This is essentially a language modeling task.\n",
|
| 159 |
+
" labels = input_ids.clone()\n",
|
| 160 |
+
" \n",
|
| 161 |
+
" # Use -100 to ignore the loss for special tokens (like the language code token)\n",
|
| 162 |
+
" labels[:, 0] = -100\n",
|
| 163 |
+
"\n",
|
| 164 |
+
" outputs = model(input_ids=input_ids, labels=labels)\n",
|
| 165 |
+
" neg_log_likelihood = outputs.loss\n",
|
| 166 |
+
" \n",
|
| 167 |
+
" # Perplexity is exp(average negative log-likelihood)\n",
|
| 168 |
+
" # The 'outputs.loss' from the Transformers library is already the average NLL per token.\n",
|
| 169 |
+
" ppl = torch.exp(neg_log_likelihood).item()\n",
|
| 170 |
+
" return ppl\n",
|
| 171 |
+
" except Exception as e:\n",
|
| 172 |
+
" print(f\"Error calculating PPL for: '{sentence}'. Error: {e}\")\n",
|
| 173 |
+
" return float('inf') # Return a very high PPL for errors/bad sentences\n",
|
| 174 |
+
"\n"
|
| 175 |
+
]
|
| 176 |
+
},
|
| 177 |
+
{
|
| 178 |
+
"cell_type": "code",
|
| 179 |
+
"execution_count": null,
|
| 180 |
+
"id": "231f19a8",
|
| 181 |
+
"metadata": {},
|
| 182 |
+
"outputs": [],
|
| 183 |
+
"source": [
|
| 184 |
+
"def normalize_inverse_ppl(ppl_scores, epsilon=1e-6):\n",
|
| 185 |
+
" \"\"\"\n",
|
| 186 |
+
" Safely normalizes inverse perplexity (1/PPL_i) to [0, 1].\n",
|
| 187 |
+
" \n",
|
| 188 |
+
" Handles edge cases where PPL scores are constant, contain inf/nan, or are invalid.\n",
|
| 189 |
+
" \"\"\"\n",
|
| 190 |
+
" ppl_scores = np.array(ppl_scores, dtype=np.float64)\n",
|
| 191 |
+
"\n",
|
| 192 |
+
" # Replace infinities or NaNs with large finite numbers for stability\n",
|
| 193 |
+
" ppl_scores = np.nan_to_num(ppl_scores, nan=np.inf, posinf=np.inf, neginf=np.inf)\n",
|
| 194 |
+
"\n",
|
| 195 |
+
" # Compute inverse PPL (fluency measure)\n",
|
| 196 |
+
" inv_ppl = 1.0 / (ppl_scores + epsilon)\n",
|
| 197 |
+
"\n",
|
| 198 |
+
" # Remove any remaining NaNs/Infs from inverse scores\n",
|
| 199 |
+
" inv_ppl = np.nan_to_num(inv_ppl, nan=0.0, posinf=0.0, neginf=0.0)\n",
|
| 200 |
+
"\n",
|
| 201 |
+
" inv_min = np.min(inv_ppl)\n",
|
| 202 |
+
" inv_max = np.max(inv_ppl)\n",
|
| 203 |
+
"\n",
|
| 204 |
+
" # Handle zero-range case: all scores are the same\n",
|
| 205 |
+
" if np.isclose(inv_max, inv_min) or np.isnan(inv_max - inv_min):\n",
|
| 206 |
+
" return np.zeros_like(inv_ppl)\n",
|
| 207 |
+
"\n",
|
| 208 |
+
" # Normal min–max scaling\n",
|
| 209 |
+
" inv_ppl_norm = (inv_ppl - inv_min) / (inv_max - inv_min)\n",
|
| 210 |
+
" inv_ppl_norm = np.clip(inv_ppl_norm, 0.0, 1.0)\n",
|
| 211 |
+
"\n",
|
| 212 |
+
" return inv_ppl_norm\n",
|
| 213 |
+
"\n"
|
| 214 |
+
]
|
| 215 |
+
},
|
| 216 |
+
{
|
| 217 |
+
"cell_type": "code",
|
| 218 |
+
"execution_count": null,
|
| 219 |
+
"id": "ad791177",
|
| 220 |
+
"metadata": {},
|
| 221 |
+
"outputs": [],
|
| 222 |
+
"source": [
|
| 223 |
+
"# Function for Step 3: Semantic Similarity (Sim)\n",
|
| 224 |
+
"\n",
|
| 225 |
+
"def calculate_semantic_similarity(s_i, t_i, model, tokenizer, device):\n",
|
| 226 |
+
" \"\"\"\n",
|
| 227 |
+
" Computes Cosine Similarity between source and target sentence embeddings \n",
|
| 228 |
+
" and normalizes the result to the range [0, 1].\n",
|
| 229 |
+
" \"\"\"\n",
|
| 230 |
+
" try:\n",
|
| 231 |
+
" def get_embedding(sentence, lang_code):\n",
|
| 232 |
+
" tokenizer.src_lang = lang_code\n",
|
| 233 |
+
" inputs = tokenizer(\n",
|
| 234 |
+
" sentence,\n",
|
| 235 |
+
" return_tensors=\"pt\",\n",
|
| 236 |
+
" max_length=512,\n",
|
| 237 |
+
" truncation=True,\n",
|
| 238 |
+
" padding=True\n",
|
| 239 |
+
" ).to(device)\n",
|
| 240 |
+
" \n",
|
| 241 |
+
" with torch.no_grad():\n",
|
| 242 |
+
" encoder_output = model.model.encoder(**inputs).last_hidden_state\n",
|
| 243 |
+
" \n",
|
| 244 |
+
" mean_embedding = encoder_output[:, 1:-1, :].mean(dim=1).squeeze() \n",
|
| 245 |
+
" \n",
|
| 246 |
+
" return mean_embedding.cpu().detach().numpy().reshape(1, -1)\n",
|
| 247 |
+
"\n",
|
| 248 |
+
" emb_s = get_embedding(s_i, SRC_LANG_CODE) \n",
|
| 249 |
+
" emb_t = get_embedding(t_i, TGT_LANG_CODE)\n",
|
| 250 |
+
"\n",
|
| 251 |
+
" sim_raw = cosine_similarity(emb_s, emb_t)[0][0]\n",
|
| 252 |
+
" \n",
|
| 253 |
+
" sim_normalized = (sim_raw + 1) / 2\n",
|
| 254 |
+
" sim_normalized = max(0.0, min(1.0, sim_normalized))\n",
|
| 255 |
+
" \n",
|
| 256 |
+
" return sim_normalized\n",
|
| 257 |
+
" \n",
|
| 258 |
+
" except Exception as e:\n",
|
| 259 |
+
" # print(f\"Error calculating Sim for: '{s_i}' and '{t_i}'. Error: {e}\")\n",
|
| 260 |
+
" return 0.0"
|
| 261 |
+
]
|
| 262 |
+
},
|
| 263 |
+
{
|
| 264 |
+
"cell_type": "code",
|
| 265 |
+
"execution_count": null,
|
| 266 |
+
"id": "87eebd8b",
|
| 267 |
+
"metadata": {},
|
| 268 |
+
"outputs": [],
|
| 269 |
+
"source": [
|
| 270 |
+
"# Function for Step 4: Lexical Match (Lex) # header describing the block\n",
|
| 271 |
+
"# blank line preserved for readability\n",
|
| 272 |
+
"# Define a function that computes a lexical match score based on a bilingual dictionary\n",
|
| 273 |
+
"def calculate_lexical_match(s_i, t_i, word_dictionary):\n",
|
| 274 |
+
" # Docstring start: describe purpose and formula for lex score\n",
|
| 275 |
+
" \"\"\"\n",
|
| 276 |
+
" Computes a dictionary-based lexical match score prioritizing phrase matches.\n",
|
| 277 |
+
" Score = (Count of source words covered by successfully translated phrases) / (Total words in source sentence)\n",
|
| 278 |
+
" \"\"\" # docstring end\n",
|
| 279 |
+
" # Helper: normalize text to ease phrase matching (lowercase, token boundaries)\n",
|
| 280 |
+
" def normalize_text(text):\n",
|
| 281 |
+
" # Simple tokenization: lowercase, remove non-word characters, and join back for easy phrase matching\n",
|
| 282 |
+
" return \" \" + \" \".join(re.findall(r'\\b\\w+\\b', text.lower())) + \" \" # pad with spaces for boundary-safe matching\n",
|
| 283 |
+
" # Normalize the source sentence for phrase lookups\n",
|
| 284 |
+
" s_normalized = normalize_text(s_i)\n",
|
| 285 |
+
" # Token set of the target sentence for quick membership tests\n",
|
| 286 |
+
" t_tokens = set(re.findall(r'\\b\\w+\\b', t_i.lower()))\n",
|
| 287 |
+
" # blank line preserved for readability\n",
|
| 288 |
+
" # Sort dictionary keys by length (descending) to prioritize phrase matches over single words\n",
|
| 289 |
+
" tagin_phrases = sorted(word_dictionary.keys(), key=len, reverse=True)\n",
|
| 290 |
+
" # blank line preserved for readability\n",
|
| 291 |
+
" # Extract source word tokens and compute total count\n",
|
| 292 |
+
" source_words = re.findall(r'\\b\\w+\\b', s_i.lower())\n",
|
| 293 |
+
" total_source_words = len(source_words)\n",
|
| 294 |
+
" # Initialize covered word counter\n",
|
| 295 |
+
" covered_word_count = 0\n",
|
| 296 |
+
" # If the source sentence is empty, return 0.0 immediately\n",
|
| 297 |
+
" if total_source_words == 0:\n",
|
| 298 |
+
" return 0.0\n",
|
| 299 |
+
" # Track indices of covered words if needed (not used further but kept for clarity)\n",
|
| 300 |
+
" covered_indices = set()\n",
|
| 301 |
+
" # Iterate over dictionary phrases (longest-first) to find matches in the source\n",
|
| 302 |
+
" for phrase in tagin_phrases:\n",
|
| 303 |
+
" # Skip empty dictionary entries\n",
|
| 304 |
+
" if not phrase:\n",
|
| 305 |
+
" continue\n",
|
| 306 |
+
" # Normalize the phrase for safe matching\n",
|
| 307 |
+
" norm_phrase = normalize_text(phrase)\n",
|
| 308 |
+
" # If the normalized phrase exists in the normalized source text, proceed\n",
|
| 309 |
+
" if norm_phrase in s_normalized:\n",
|
| 310 |
+
" # Get expected translation from the dictionary (lowercased)\n",
|
| 311 |
+
" expected_translation = word_dictionary[phrase].lower()\n",
|
| 312 |
+
" # Tokenize the expected translation into words\n",
|
| 313 |
+
" translation_words = re.findall(r'\\b\\w+\\b', expected_translation)\n",
|
| 314 |
+
" # Check whether all translated words appear in the target sentence tokens\n",
|
| 315 |
+
" is_translation_present = all(word in t_tokens for word in translation_words)\n",
|
| 316 |
+
" # If the translation words are present in the target, count the phrase as covered\n",
|
| 317 |
+
" if is_translation_present:\n",
|
| 318 |
+
" # Search for possibly multiple occurrences of the phrase in the source\n",
|
| 319 |
+
" start = 0\n",
|
| 320 |
+
" while True:\n",
|
| 321 |
+
" # Find next occurrence starting from 'start' index\n",
|
| 322 |
+
" start_index = s_normalized.find(norm_phrase, start)\n",
|
| 323 |
+
" # If no more occurrences, break the loop\n",
|
| 324 |
+
" if start_index == -1:\n",
|
| 325 |
+
" break\n",
|
| 326 |
+
" # Count how many words are in the matched phrase\n",
|
| 327 |
+
" phrase_word_count = len(re.findall(r'\\b\\w+\\b', phrase))\n",
|
| 328 |
+
" # Add the phrase's word count to the covered total\n",
|
| 329 |
+
" covered_word_count += phrase_word_count\n",
|
| 330 |
+
" # Advance the search start position past the current match\n",
|
| 331 |
+
" start = start_index + len(norm_phrase)\n",
|
| 332 |
+
" # end while loop for occurrences\n",
|
| 333 |
+
" # After checking all phrases, compute lex score as covered words / total source words (capped at 1.0)\n",
|
| 334 |
+
" lex_score = min(1.0, covered_word_count / total_source_words)\n",
|
| 335 |
+
" # Return lexical match score between 0 and 1\n",
|
| 336 |
+
" return lex_score"
|
| 337 |
+
]
|
| 338 |
+
},
|
| 339 |
+
{
|
| 340 |
+
"cell_type": "code",
|
| 341 |
+
"execution_count": null,
|
| 342 |
+
"id": "46f310e4",
|
| 343 |
+
"metadata": {},
|
| 344 |
+
"outputs": [],
|
| 345 |
+
"source": [
|
| 346 |
+
"# Main Algorithm Implementation\n",
|
| 347 |
+
"def knowledge_based_filtering(raw_corpus, word_dictionary, alpha, beta, gamma, percentile_threshold):\n",
|
| 348 |
+
" # 1. Load Model and Tokenizer (Step 1 of the algorithm's loop)\n",
|
| 349 |
+
" device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 350 |
+
" print(f\"Loading mBART-tgj-base model to {device}...\")\n",
|
| 351 |
+
" model = MBartForConditionalGeneration.from_pretrained(MODEL_NAME).to(device)\n",
|
| 352 |
+
" tokenizer = MBart50TokenizerFast.from_pretrained(MODEL_NAME)\n",
|
| 353 |
+
" \n",
|
| 354 |
+
" # Ensure source/target language codes are in the tokenizer vocabulary\n",
|
| 355 |
+
" if SRC_LANG_CODE not in tokenizer.vocab or TGT_LANG_CODE not in tokenizer.vocab:\n",
|
| 356 |
+
" print(f\"Warning: Language codes {SRC_LANG_CODE} or {TGT_LANG_CODE} not found in base mBART-tgj-base vocab.\")\n",
|
| 357 |
+
" print(\"Using placeholder language codes. Results may not be accurate.\")\n",
|
| 358 |
+
"\n",
|
| 359 |
+
" results = []\n",
|
| 360 |
+
"\n",
|
| 361 |
+
" # 2. Iterate through the corpus (Lines 1-6)\n",
|
| 362 |
+
" print(\"Processing corpus to calculate scores...\")\n",
|
| 363 |
+
" for index, row in tqdm(raw_corpus.iterrows(), total=len(raw_corpus), desc=\"Calculating KS\"):\n",
|
| 364 |
+
" s_i = row['English']\n",
|
| 365 |
+
" t_i = row['Tagin']\n",
|
| 366 |
+
"\n",
|
| 367 |
+
" # Line 2: Compute PPL_i (Lower is better)\n",
|
| 368 |
+
" pp= calculate_perplexity(s_i, model, tokenizer, device)\n",
|
| 369 |
+
"        PPL_i = 1.0 / (pp + 1e-6)  # FIX(review): normalize_inverse_ppl on a single scalar always returns 0 (min == max), zeroing the PPL term of KS_i; use raw inverse PPL here, or collect all PPLs and min-max normalize across the whole corpus after the loop\n",
|
| 370 |
+
" # PPL_i = normalize_inverse_ppl(row[\"Perplexity\"])\n",
|
| 371 |
+
" \n",
|
| 372 |
+
" # Line 3: Compute Sim_i (Higher is better)\n",
|
| 373 |
+
" Sim_i = calculate_semantic_similarity(s_i, t_i, model, tokenizer, device)\n",
|
| 374 |
+
" \n",
|
| 375 |
+
" # Line 4: Check Lex_i (Higher is better)\n",
|
| 376 |
+
" Lex_i = calculate_lexical_match(s_i, t_i, word_dictionary)\n",
|
| 377 |
+
" \n",
|
| 378 |
+
" # Line 5: Derive Knowledge Score (KS_i)\n",
|
| 379 |
+
" # Note: We use 1/PPL_i because PPL_i is an inverse quality metric (lower PPL is higher quality)\n",
|
| 380 |
+
" # while Sim and Lex are direct quality metrics (higher is better).\n",
|
| 381 |
+
" # We add a small epsilon to avoid division by zero, though a PPL of 0 is practically impossible.\n",
|
| 382 |
+
" # PPL_i_inv = 1.0 / (PPL_i + 1e-6)\n",
|
| 383 |
+
" # -----IMPORTANT------\n",
|
| 384 |
+
" \n",
|
| 385 |
+
" KS_i = alpha * PPL_i + beta * Sim_i + gamma * Lex_i\n",
|
| 386 |
+
" \n",
|
| 387 |
+
" results.append({\n",
|
| 388 |
+
" 'src_lang': s_i,\n",
|
| 389 |
+
" 'tgt_lang': t_i,\n",
|
| 390 |
+
" 'PPL_i': PPL_i,\n",
|
| 391 |
+
" 'Sim_i': Sim_i,\n",
|
| 392 |
+
" 'Lex_i': Lex_i,\n",
|
| 393 |
+
"            'raw_PPL': pp,  # was a duplicate 'PPL_i' key, which silently overwrote the first entry\n",
|
| 394 |
+
" 'KS_i': KS_i\n",
|
| 395 |
+
" })\n",
|
| 396 |
+
"\n",
|
| 397 |
+
" # Convert results to DataFrame for filtering\n",
|
| 398 |
+
" scored_corpus = pd.DataFrame(results)\n",
|
| 399 |
+
"\n",
|
| 400 |
+
" # 3. Determine Threshold and Filter (Lines 7-9)\n",
|
| 401 |
+
"    # Line 7: Find the percentile_threshold-th percentile of Knowledge Scores\n",
|
| 402 |
+
" tau_K = np.percentile(scored_corpus['KS_i'], percentile_threshold)\n",
|
| 403 |
+
"    print(f\"\\n{percentile_threshold}th Percentile Knowledge Score (τ_K): {tau_K:.4f}\")\n",
|
| 404 |
+
" \n",
|
| 405 |
+
" # Line 8: Filter the corpus\n",
|
| 406 |
+
" D_filtered = scored_corpus[scored_corpus['KS_i'] >= tau_K].copy()\n",
|
| 407 |
+
" \n",
|
| 408 |
+
" # Final cleanup of columns and return\n",
|
| 409 |
+
" D_filtered = D_filtered[['src_lang', 'tgt_lang', 'KS_i']]\n",
|
| 410 |
+
" print(f\"Raw corpus size: {len(raw_corpus)}\")\n",
|
| 411 |
+
" print(f\"Filtered corpus size (KS_i >= τ_K): {len(D_filtered)}\")\n",
|
| 412 |
+
" \n",
|
| 413 |
+
" return D_filtered"
|
| 414 |
+
]
|
| 415 |
+
},
|
| 416 |
+
{
|
| 417 |
+
"cell_type": "code",
|
| 418 |
+
"execution_count": null,
|
| 419 |
+
"id": "2b2ce69d",
|
| 420 |
+
"metadata": {},
|
| 421 |
+
"outputs": [],
|
| 422 |
+
"source": [
|
| 423 |
+
"# --- Execution --- # script entry and high-level execution steps\n",
|
| 424 |
+
"# blank line preserved for readability\n",
|
| 425 |
+
"# Guard to ensure code only runs when executed as a script, not on import\n",
|
| 426 |
+
"if __name__ == '__main__':\n",
|
| 427 |
+
" # 1. Load data # load and preprocess corpus and dictionary files\n",
|
| 428 |
+
" raw_corpus, word_dictionary = load_data(CORPUS_FILE, DICT_FILE)\n",
|
| 429 |
+
" # blank line preserved for readability\n",
|
| 430 |
+
" # 2. Run the filtering algorithm # compute KS_i and filter by percentile\n",
|
| 431 |
+
" filtered_corpus = knowledge_based_filtering(\n",
|
| 432 |
+
" raw_corpus, # pass the preprocessed corpus DataFrame\n",
|
| 433 |
+
" word_dictionary, # pass the dictionary mapping\n",
|
| 434 |
+
" ALPHA, BETA, GAMMA, # weighting hyperparameters for KS_i\n",
|
| 435 |
+
" PERCENTILE_THRESHOLD # percentile cutoff for filtering\n",
|
| 436 |
+
" )\n",
|
| 437 |
+
" # blank line preserved for readability\n",
|
| 438 |
+
" # Save the filtered corpus to CSV for downstream use\n",
|
| 439 |
+
"    filtered_corpus.to_csv(f\"tgj_corpus_filtered_{PERCENTILE_THRESHOLD}th.csv\", index=False)\n",
							|
| 440        |
+
"    # blank line preserved for readability\n",
							|
| 441        |
+
"    # Notify user of completion and where results were saved\n",
							|
| 442        |
+
"    print(f\"\\nFiltering complete. Results saved to tgj_corpus_filtered_{PERCENTILE_THRESHOLD}th.csv\")\n",
|
| 443 |
+
" # Show a short preview of the filtered corpus\n",
|
| 444 |
+
" print(\"\\nFiltered Corpus Head:\")\n",
|
| 445 |
+
" print(filtered_corpus) # print DataFrame to stdout"
|
| 446 |
+
]
|
| 447 |
+
}
|
| 448 |
+
],
|
| 449 |
+
"metadata": {
|
| 450 |
+
"kernelspec": {
|
| 451 |
+
"display_name": "ptorch",
|
| 452 |
+
"language": "python",
|
| 453 |
+
"name": "python3"
|
| 454 |
+
},
|
| 455 |
+
"language_info": {
|
| 456 |
+
"codemirror_mode": {
|
| 457 |
+
"name": "ipython",
|
| 458 |
+
"version": 3
|
| 459 |
+
},
|
| 460 |
+
"file_extension": ".py",
|
| 461 |
+
"mimetype": "text/x-python",
|
| 462 |
+
"name": "python",
|
| 463 |
+
"nbconvert_exporter": "python",
|
| 464 |
+
"pygments_lexer": "ipython3",
|
| 465 |
+
"version": "3.12.11"
|
| 466 |
+
}
|
| 467 |
+
},
|
| 468 |
+
"nbformat": 4,
|
| 469 |
+
"nbformat_minor": 5
|
| 470 |
+
}
|
scripts/batch_translation.ipynb
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"id": "2230ec1b",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"from transformers import MBartForConditionalGeneration, MBart50TokenizerFast # MBART model and tokenizer classes\n",
|
| 11 |
+
"from tqdm import tqdm # progress bar for loops\n",
|
| 12 |
+
"import torch # PyTorch for tensors and device handling\n",
|
| 13 |
+
"import csv # CSV writer for output"
|
| 14 |
+
]
|
| 15 |
+
},
|
| 16 |
+
{
|
| 17 |
+
"cell_type": "code",
|
| 18 |
+
"execution_count": null,
|
| 19 |
+
"id": "7552da07",
|
| 20 |
+
"metadata": {},
|
| 21 |
+
"outputs": [],
|
| 22 |
+
"source": [
|
| 23 |
+
"# Load tokenizer and model (local path to your fine-tuned model)\n",
|
| 24 |
+
"model_path = \"./combined_training/en_tgj_combined_model\" # path to fine-tuned model directory (change if needed)\n",
|
| 25 |
+
"tokenizer = MBart50TokenizerFast.from_pretrained(model_path) # load tokenizer from model path\n",
|
| 26 |
+
"model = MBartForConditionalGeneration.from_pretrained(model_path) # load model weights and config\n",
|
| 27 |
+
"model.eval() # set model to evaluation mode (disables dropout)\n",
|
| 28 |
+
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\") # prefer GPU if available\n",
|
| 29 |
+
"model.to(device) # move model to selected device"
|
| 30 |
+
]
|
| 31 |
+
},
|
| 32 |
+
{
|
| 33 |
+
"cell_type": "code",
|
| 34 |
+
"execution_count": null,
|
| 35 |
+
"id": "750545cc",
|
| 36 |
+
"metadata": {},
|
| 37 |
+
"outputs": [],
|
| 38 |
+
"source": [
|
| 39 |
+
"# Parameters for tokenization and generation\n",
|
| 40 |
+
"src_lang_token = \"en_XX\" # MBART source language token to prepend\n",
|
| 41 |
+
"tgt_lang_token = \"<tgn_IN>\" # target language token / forced BOS for generation\n",
|
| 42 |
+
"batch_size = 16 # number of sentences per batch\n",
|
| 43 |
+
"max_length = 128 # maximum token length for tokenization and generation"
|
| 44 |
+
]
|
| 45 |
+
},
|
| 46 |
+
{
|
| 47 |
+
"cell_type": "code",
|
| 48 |
+
"execution_count": null,
|
| 49 |
+
"id": "03415b52",
|
| 50 |
+
"metadata": {},
|
| 51 |
+
"outputs": [],
|
| 52 |
+
"source": [
|
| 53 |
+
"# Read English sentences from a text file (one sentence per line)\n",
|
| 54 |
+
"with open(\"./sentences01.txt\", \"r\", encoding=\"utf-8\") as f: # input file path\n",
|
| 55 |
+
" english_sentences = [line.strip() for line in f if line.strip()] # strip and ignore empty lines"
|
| 56 |
+
]
|
| 57 |
+
},
|
| 58 |
+
{
|
| 59 |
+
"cell_type": "code",
|
| 60 |
+
"execution_count": null,
|
| 61 |
+
"id": "c0927f85",
|
| 62 |
+
"metadata": {},
|
| 63 |
+
"outputs": [],
|
| 64 |
+
"source": [
|
| 65 |
+
"# Set the MBART source language on the tokenizer (it inserts the language token itself)\n",
							|
"tokenizer.src_lang = src_lang_token  # NOTE(review): prepending the token as raw text mis-tokenizes it; setting src_lang is the documented MBart50 usage\n",
"prefixed_sentences = list(english_sentences)  # name kept unchanged for the batching loop below\n",
|
| 67 |
+
"\n",
|
| 68 |
+
"# Prepare a list to collect generated translations\n",
|
| 69 |
+
"translated_sentences = [] # will hold output strings"
|
| 70 |
+
]
|
| 71 |
+
},
|
| 72 |
+
{
|
| 73 |
+
"cell_type": "code",
|
| 74 |
+
"execution_count": null,
|
| 75 |
+
"id": "86b13113",
|
| 76 |
+
"metadata": {},
|
| 77 |
+
"outputs": [],
|
| 78 |
+
"source": [
|
| 79 |
+
"# Iterate through sentences in batches and generate translations\n",
|
| 80 |
+
"for i in tqdm(range(0, len(prefixed_sentences), batch_size), desc=\"Batch Translating\"): # batching loop\n",
|
| 81 |
+
" batch = prefixed_sentences[i:i+batch_size] # take a slice for this batch\n",
|
| 82 |
+
"\n",
|
| 83 |
+
" # Tokenize the batch and move tensors to the model device\n",
|
| 84 |
+
" inputs = tokenizer(batch, return_tensors=\"pt\", padding=True, truncation=True, max_length=max_length).to(device)\n",
|
| 85 |
+
"\n",
|
| 86 |
+
" with torch.no_grad(): # disable gradients for inference to save memory\n",
|
| 87 |
+
" generated_tokens = model.generate(\n",
|
| 88 |
+
" **inputs, # pass input_ids, attention_mask, etc.\n",
|
| 89 |
+
" forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang_token), # ensure generation uses target language token\n",
|
| 90 |
+
" max_length=max_length, # cap the generated length\n",
|
| 91 |
+
" num_beams=5, # beam search for higher-quality decoding\n",
|
| 92 |
+
" early_stopping=True, # stop once beams finish\n",
|
| 93 |
+
" )\n",
|
| 94 |
+
"\n",
|
| 95 |
+
" # Decode token IDs to text and collect results\n",
|
| 96 |
+
" outputs = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) # convert ids to strings\n",
|
| 97 |
+
" translated_sentences.extend(outputs) # append batch outputs to final list"
|
| 98 |
+
]
|
| 99 |
+
},
|
| 100 |
+
{
|
| 101 |
+
"cell_type": "code",
|
| 102 |
+
"execution_count": null,
|
| 103 |
+
"id": "6d9a12d2",
|
| 104 |
+
"metadata": {},
|
| 105 |
+
"outputs": [],
|
| 106 |
+
"source": [
|
| 107 |
+
"# Write aligned original and translated sentences to a CSV file\n",
|
| 108 |
+
"with open(\"./output_entgj_combined01.csv\", \"w\", encoding=\"utf-8\", newline=\"\") as f: # output file path\n",
|
| 109 |
+
" writer = csv.writer(f) # CSV writer object\n",
|
| 110 |
+
" writer.writerow([\"original\", \"translated\"]) # write header row\n",
|
| 111 |
+
" for src, tgt in zip(english_sentences, translated_sentences): # iterate aligned pairs\n",
|
| 112 |
+
" writer.writerow([src, tgt]) # write each pair as a row"
|
| 113 |
+
]
|
| 114 |
+
}
|
| 115 |
+
],
|
| 116 |
+
"metadata": {
|
| 117 |
+
"kernelspec": {
|
| 118 |
+
"display_name": "ptorch",
|
| 119 |
+
"language": "python",
|
| 120 |
+
"name": "python3"
|
| 121 |
+
},
|
| 122 |
+
"language_info": {
|
| 123 |
+
"codemirror_mode": {
|
| 124 |
+
"name": "ipython",
|
| 125 |
+
"version": 3
|
| 126 |
+
},
|
| 127 |
+
"file_extension": ".py",
|
| 128 |
+
"mimetype": "text/x-python",
|
| 129 |
+
"name": "python",
|
| 130 |
+
"nbconvert_exporter": "python",
|
| 131 |
+
"pygments_lexer": "ipython3",
|
| 132 |
+
"version": "3.12.11"
|
| 133 |
+
}
|
| 134 |
+
},
|
| 135 |
+
"nbformat": 4,
|
| 136 |
+
"nbformat_minor": 5
|
| 137 |
+
}
|
scripts/corpus_stats.ipynb
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"id": "54834b8c",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"import pandas as pd # import pandas and alias as pd for DataFrame operations\n",
|
| 11 |
+
"# blank line kept for readability\n",
|
| 12 |
+
"# Load your CSV # comment indicating next line loads the CSV into a DataFrame\n",
|
| 13 |
+
"df = pd.read_csv(\"filtered_corpus_here.csv\") # read the CSV file into variable df\n",
|
| 14 |
+
"# blank line kept for readability\n",
|
| 15 |
+
"# Normalize whitespace so \"hello world\" and \"hello world\" match # explain normalization intent\n",
|
| 16 |
+
"df['src_norm'] = df['src_lang'].str.strip().str.replace(r'\\s+', ' ', regex=True) # strip edges and collapse multiple spaces in source column\n",
|
| 17 |
+
"df['tgt_norm'] = df['tgt_lang'].str.strip().str.replace(r'\\s+', ' ', regex=True) # same normalization for target column\n",
|
| 18 |
+
"# blank line kept for readability\n",
|
| 19 |
+
"# Drop duplicates based on combined src+tgt # remove identical source-target pairs after normalization\n",
|
| 20 |
+
"df_unique = df.drop_duplicates(subset=['src_norm', 'tgt_norm'], keep='first') # keep first occurrence of duplicate pairs\n",
|
| 21 |
+
"# blank line kept for readability\n",
|
| 22 |
+
"# Remove helper columns # drop intermediate normalization columns before saving\n",
|
| 23 |
+
"df_unique = df_unique.drop(columns=['src_norm', 'tgt_norm']) # remove the temporary normalized columns\n",
|
| 24 |
+
"# blank line kept for readability\n",
|
| 25 |
+
"# Save result # write the deduplicated DataFrame to a new CSV file\n",
|
| 26 |
+
"df_unique.to_csv(\"filtered_corpus_here_removedDuplicates.csv\", index=False) # save without row index\n",
|
| 27 |
+
"# blank line kept for readability\n",
|
| 28 |
+
"print(\"Done. Original rows:\", len(df)) # print the original number of rows\n",
|
| 29 |
+
"print(\"New rows:\", len(df_unique)) # print the number of rows after deduplication\n",
|
| 30 |
+
"print(\"Removed:\", len(df) - len(df_unique)) # print how many rows were removed\n",
|
| 31 |
+
"# end of cell"
|
| 32 |
+
]
|
| 33 |
+
},
|
| 34 |
+
{
|
| 35 |
+
"cell_type": "code",
|
| 36 |
+
"execution_count": null,
|
| 37 |
+
"id": "e73f3d91",
|
| 38 |
+
"metadata": {},
|
| 39 |
+
"outputs": [],
|
| 40 |
+
"source": [
|
| 41 |
+
"import pandas as pd # pandas for DataFrame operations\n",
|
| 42 |
+
"import numpy as np # numpy for numeric utilities (unused but commonly imported)\n",
|
| 43 |
+
"import string # string constants, used to build punctuation translator\n",
|
| 44 |
+
"import os # os module for file/path operations (imported for potential use)\n",
|
| 45 |
+
"# blank line for readability\n",
|
| 46 |
+
"# --- Configuration --- # configuration section start\n",
|
| 47 |
+
"# NOTE: This script assumes 'sample_corpus.csv' exists to generate the filtered file. # informational note\n",
|
| 48 |
+
"FILEPATH_RAW = 'filtered_corpus_here_removedDuplicates.csv' # path to the deduplicated raw corpus\n",
|
| 49 |
+
"FILEPATH_FILTERED = 'filtered_corpus_top_20.csv' # path to write/read the top filtered corpus\n",
|
| 50 |
+
"SOURCE_LANG_COL = 'src_lang' # column name for source language text\n",
|
| 51 |
+
"TARGET_LANG_COL = 'tgt_lang' # column name for target language text\n",
|
| 52 |
+
"SCORE_COL = 'KS_i' # column name that stores the Knowledge Score\n",
|
| 53 |
+
"# Filtering constants moved from calculate_threshold.py # note about origin of constant\n",
|
| 54 |
+
"TARGET_PERCENTILE = 80 # percentile threshold to filter top N% by score\n",
|
| 55 |
+
"OUTPUT_FILENAME = FILEPATH_FILTERED # Ensure output filename is the same # output target file variable\n",
|
| 56 |
+
"# --- End Configuration --- # configuration section end\n",
|
| 57 |
+
"# blank line for readability\n",
|
| 58 |
+
"def calculate_knowledge_threshold(filepath, percentile): # function to compute tau_K and filtered DF\n",
|
| 59 |
+
" \"\"\" # docstring start\n",
|
| 60 |
+
" Reads a CSV file, calculates the specified percentile of the Knowledge Score # description line 1\n",
|
| 61 |
+
" column, and returns the filtered corpus and the threshold (tau_K). # description line 2\n",
|
| 62 |
+
" \"\"\" # docstring end\n",
|
| 63 |
+
" try: # attempt to load and process the file\n",
|
| 64 |
+
" # 1. Load the data # step 1 comment\n",
|
| 65 |
+
" df = pd.read_csv(filepath) # load CSV into DataFrame df\n",
|
| 66 |
+
" # blank line for readability\n",
|
| 67 |
+
" # 1.5. Robust column check and numeric conversion # validate columns and convert types\n",
|
| 68 |
+
" if SCORE_COL not in df.columns: # check that score column exists\n",
|
| 69 |
+
" print(f\"Error: The CSV file must contain a column named '{SCORE_COL}'. Found columns: {list(df.columns)}\") # informative error\n",
|
| 70 |
+
" return pd.DataFrame(), None # return empty DF and None threshold on error\n",
|
| 71 |
+
" # blank line for readability\n",
|
| 72 |
+
" df_initial_size = len(df) # store initial row count for warnings\n",
|
| 73 |
+
" # Convert the score column to numeric, coercing errors # ensure numeric dtype for quantile\n",
|
| 74 |
+
" df[SCORE_COL] = pd.to_numeric(df[SCORE_COL], errors='coerce') # coerce invalids to NaN\n",
|
| 75 |
+
" df.dropna(subset=[SCORE_COL], inplace=True) # drop rows where score could not be parsed\n",
|
| 76 |
+
" # blank line for readability\n",
|
| 77 |
+
" if len(df) < df_initial_size: # warn if rows were dropped\n",
|
| 78 |
+
" print(f\"Warning: Dropped {df_initial_size - len(df)} rows with non-numeric scores.\") # warn about dropped rows\n",
|
| 79 |
+
" # blank line for readability\n",
|
| 80 |
+
" # 2. Calculate the threshold (tau_K) # compute percentile threshold\n",
|
| 81 |
+
" tau_K = df[SCORE_COL].quantile(percentile / 100, interpolation='linear') # compute numeric threshold\n",
|
| 82 |
+
" # blank line for readability\n",
|
| 83 |
+
" # 3. Apply the threshold to construct the filtered corpus (D_filtered) # filter rows >= tau_K\n",
|
| 84 |
+
" D_filtered = df[df[SCORE_COL] >= tau_K].copy() # select high-score rows and copy to new DF\n",
|
| 85 |
+
" # blank line for readability\n",
|
| 86 |
+
" return D_filtered, tau_K # return filtered DF and threshold\n",
|
| 87 |
+
" except FileNotFoundError: # handle missing file error\n",
|
| 88 |
+
" print(f\"Error: The file '{filepath}' was not found. Please ensure it exists.\") # print helpful message\n",
|
| 89 |
+
" return pd.DataFrame(), None # return empty DF and None threshold when file missing\n",
|
| 90 |
+
" except Exception as e: # catch-all for other errors\n",
|
| 91 |
+
" print(f\"An unexpected error occurred during filtering: {e}\") # print exception details\n",
|
| 92 |
+
" return pd.DataFrame(), None # return safe defaults on error\n",
|
| 93 |
+
"# blank line for readability\n",
|
| 94 |
+
"def tokenize_and_clean(text_series): # function to tokenize text series and clean tokens\n",
|
| 95 |
+
" \"\"\" # docstring start\n",
|
| 96 |
+
" Tokenizes text by splitting on whitespace, then cleans tokens by removing # description line 1\n",
|
| 97 |
+
" punctuation and converting to lowercase for accurate token counting and vocabulary size. # description line 2\n",
|
| 98 |
+
" \"\"\" # docstring end\n",
|
| 99 |
+
" all_tokens = [] # accumulator for all tokens across sentences\n",
|
| 100 |
+
" # blank line for readability\n",
|
| 101 |
+
" # Simple preprocessing: remove punctuation and lowercase # describe translator creation\n",
|
| 102 |
+
" translator = str.maketrans(string.punctuation, ' ' * len(string.punctuation)) # map punctuation to spaces\n",
|
| 103 |
+
" # blank line for readability\n",
|
| 104 |
+
" for text in text_series.astype(str): # iterate rows as strings\n",
|
| 105 |
+
" clean_text = text.translate(translator).lower() # remove punctuation and lowercase the text\n",
|
| 106 |
+
" tokens = clean_text.split() # split on whitespace into tokens\n",
|
| 107 |
+
" all_tokens.extend(tokens) # add tokens to accumulator\n",
|
| 108 |
+
" # loop continues for next sentence\n",
|
| 109 |
+
" return all_tokens # return the flattened token list\n",
|
| 110 |
+
"# blank line for readability\n",
|
| 111 |
+
"def calculate_corpus_metrics(filepath): # function to compute corpus-level metrics\n",
|
| 112 |
+
" \"\"\"Calculates all required corpus statistics for a given file.\"\"\" # single-line docstring\n",
|
| 113 |
+
" try: # try to read the file into a DataFrame\n",
|
| 114 |
+
" df = pd.read_csv(filepath) # load corpus file\n",
|
| 115 |
+
" except FileNotFoundError: # handle missing file\n",
|
| 116 |
+
" print(f\"Error: The file '{filepath}' was not found. Please ensure it exists.\") # user-friendly message\n",
|
| 117 |
+
" return None # return None to indicate failure\n",
|
| 118 |
+
" except pd.errors.EmptyDataError: # handle empty file error\n",
|
| 119 |
+
" print(f\"Error: The file '{filepath}' is empty.\") # inform user file has no data\n",
|
| 120 |
+
" return None # return None to indicate failure\n",
|
| 121 |
+
" # blank line for readability\n",
|
| 122 |
+
" # 1. Metric: Sentence Pairs # compute total number of sentence pairs\n",
|
| 123 |
+
" sentence_pairs = len(df) # number of rows equals sentence pairs\n",
|
| 124 |
+
" # blank line for readability\n",
|
| 125 |
+
" # 2. Metric: Tokens, Avg. Sentence Length, Vocabulary Size # prepare metrics container\n",
|
| 126 |
+
" metrics = {} # dict to hold source/target metrics\n",
|
| 127 |
+
" # blank line for readability\n",
|
| 128 |
+
" # Ensure column existence before proceeding # validate presence of expected columns\n",
|
| 129 |
+
" for col in [SOURCE_LANG_COL, TARGET_LANG_COL]: # iterate expected column names\n",
|
| 130 |
+
" if col not in df.columns: # raise if missing\n",
|
| 131 |
+
" raise KeyError(f\"Column '{col}' not found in the corpus file.\") # explicit error to surface missing columns\n",
|
| 132 |
+
" # blank line for readability\n",
|
| 133 |
+
" for col, tag in [(SOURCE_LANG_COL, 'English'), (TARGET_LANG_COL, 'Target')]: # compute metrics for both sides\n",
|
| 134 |
+
" # Tokenization and Cleaning # comment for next steps\n",
|
| 135 |
+
" tokens = tokenize_and_clean(df[col]) # get token list for this column\n",
|
| 136 |
+
" total_tokens = len(tokens) # total token count across all sentences\n",
|
| 137 |
+
" # blank line for readability\n",
|
| 138 |
+
" # Vocabulary Size # compute unique token count\n",
|
| 139 |
+
" vocab_size = len(set(tokens)) # size of unique token set\n",
|
| 140 |
+
" # blank line for readability\n",
|
| 141 |
+
" # Avg. Sentence Length (Tokens / Sentence Pairs) # compute average tokens per sentence\n",
|
| 142 |
+
" avg_sentence_length = total_tokens / sentence_pairs if sentence_pairs > 0 else 0 # guard division by zero\n",
|
| 143 |
+
" # blank line for readability\n",
|
| 144 |
+
" metrics[tag] = { # store computed metrics under tag key\n",
|
| 145 |
+
" 'tokens': total_tokens, # total tokens count\n",
|
| 146 |
+
" 'avg_len': avg_sentence_length, # average sentence length in tokens\n",
|
| 147 |
+
" 'vocab_size': vocab_size # vocabulary size after preprocessing\n",
|
| 148 |
+
" }\n",
|
| 149 |
+
" return sentence_pairs, metrics # return computed metrics\n",
|
| 150 |
+
"# blank line for readability\n",
|
| 151 |
+
"def format_and_print_results(sentence_pairs, metrics): # pretty-print the metrics in table form\n",
|
| 152 |
+
" \"\"\"Formats the calculated metrics into the requested table structure.\"\"\" # docstring\n",
|
| 153 |
+
" # blank line for readability\n",
|
| 154 |
+
" src_data = metrics['English'] # metrics for source/English side\n",
|
| 155 |
+
" tgt_data = metrics['Target'] # metrics for target side\n",
|
| 156 |
+
" # blank line for readability\n",
|
| 157 |
+
" print(\"\\n\" + \"=\"*80) # print top border\n",
|
| 158 |
+
" print(f\" FILTERED CORPUS METRICS ({FILEPATH_FILTERED})\") # title with filename\n",
|
| 159 |
+
" print(\"=\"*80) # print border again\n",
|
| 160 |
+
" # blank line for readability\n",
|
| 161 |
+
" # Table Header # print header row labels\n",
|
| 162 |
+
" print(f\"| {'Metric':<30} | {'Source (English)':>20} | {'Target (Target)':>20} | {'Notes':<10} |\") # header formatting\n",
|
| 163 |
+
" print(\"-\" * 80) # separator line\n",
|
| 164 |
+
" # blank line for readability\n",
|
| 165 |
+
" # Row: Sentence Pairs # print sentence pair counts\n",
|
| 166 |
+
" print(f\"| {'Sentence Pairs':<30} | {sentence_pairs:>20,} | {sentence_pairs:>20,} | {'--':<10} |\") # counts for both columns\n",
|
| 167 |
+
" # blank line for readability\n",
|
| 168 |
+
" # Row: Tokens (Formatted for M/K display based on size) # prepare token display strings\n",
|
| 169 |
+
" src_tokens_display = f\"{src_data['tokens']:,}\" # formatted source token count\n",
|
| 170 |
+
" tgt_tokens_display = f\"{tgt_data['tokens']:,}\" # formatted target token count\n",
|
| 171 |
+
" # blank line for readability\n",
|
| 172 |
+
" print(f\"| {'Tokens':<30} | {src_tokens_display:>20} | {tgt_tokens_display:>20} | {'Actual Count':<10} |\") # print tokens row\n",
|
| 173 |
+
" # blank line for readability\n",
|
| 174 |
+
" # Row: Avg. Sentence Length # print average sentence lengths\n",
|
| 175 |
+
" print(f\"| {'Avg. Sentence Length':<30} | {src_data['avg_len']:>20.2f} | {tgt_data['avg_len']:>20.2f} | {'Tokens/Pair':<10} |\") # two-decimal precision\n",
|
| 176 |
+
" # blank line for readability\n",
|
| 177 |
+
" # Row: Vocabulary Size # print vocabulary sizes\n",
|
| 178 |
+
" print(f\"| {'Vocabulary Size':<30} | {src_data['vocab_size']:>20,} | {tgt_data['vocab_size']:>20,} | {'After Preprocessing':<10} |\") # vocab counts\n",
|
| 179 |
+
" # blank line for readability\n",
|
| 180 |
+
" # Row: OOV Rate (Placeholder since test set is needed) # OOV requires a test set\n",
|
| 181 |
+
" print(f\"| {'OOV Rate':<30} | {'--':>20} | {'--':>20} | {'Requires Test Set':<10} |\") # placeholder output\n",
|
| 182 |
+
" # blank line for readability\n",
|
| 183 |
+
" print(\"-\" * 80) # bottom separator\n",
|
| 184 |
+
"# blank line for readability\n",
|
| 185 |
+
"# blank line for readability\n",
|
| 186 |
+
"if __name__ == \"__main__\": # script entrypoint guard\n",
|
| 187 |
+
" # blank line for readability\n",
|
| 188 |
+
" print(\"NOTE: Running filtering logic to ensure 'filtered_corpus_top_20.csv' is up-to-date for metrics.\") # informational message\n",
|
| 189 |
+
" # blank line for readability\n",
|
| 190 |
+
" # 1. Run the filtering logic (now self-contained) # compute filtered DF and threshold\n",
|
| 191 |
+
" filtered_df, threshold = calculate_knowledge_threshold(FILEPATH_RAW, TARGET_PERCENTILE) # call to compute top percentile\n",
|
| 192 |
+
" # blank line for readability\n",
|
| 193 |
+
" if filtered_df.empty: # check if filtering produced results\n",
|
| 194 |
+
" print(\"Could not generate filtered corpus. Cannot proceed with metrics calculation.\") # failure message\n",
|
| 195 |
+
" else: # if we have filtered data, continue\n",
|
| 196 |
+
" # 2. Save the filtered corpus (important for the metrics script to read it) # save step\n",
|
| 197 |
+
" filtered_df.to_csv(OUTPUT_FILENAME, index=False) # write filtered DF to output file\n",
|
| 198 |
+
" print(f\"Filtered corpus saved to '{OUTPUT_FILENAME}'.\") # confirmation message\n",
|
| 199 |
+
" # blank line for readability\n",
|
| 200 |
+
" # 3. Now run the actual metrics calculation on the saved filtered file # compute metrics\n",
|
| 201 |
+
" sentence_pairs, metrics = calculate_corpus_metrics(FILEPATH_FILTERED) # call metrics function\n",
|
| 202 |
+
" # blank line for readability\n",
|
| 203 |
+
" if metrics: # if metrics computed successfully\n",
|
| 204 |
+
" format_and_print_results(sentence_pairs, metrics) # print the formatted table\n",
|
| 205 |
+
" else: # metrics call failed\n",
|
| 206 |
+
" print(\"Metrics calculation failed.\") # failure message\n",
|
| 207 |
+
"# end of cell"
|
| 208 |
+
]
|
| 209 |
+
}
|
| 210 |
+
],
|
| 211 |
+
"metadata": {
|
| 212 |
+
"kernelspec": {
|
| 213 |
+
"display_name": "ptorch",
|
| 214 |
+
"language": "python",
|
| 215 |
+
"name": "python3"
|
| 216 |
+
},
|
| 217 |
+
"language_info": {
|
| 218 |
+
"codemirror_mode": {
|
| 219 |
+
"name": "ipython",
|
| 220 |
+
"version": 3
|
| 221 |
+
},
|
| 222 |
+
"file_extension": ".py",
|
| 223 |
+
"mimetype": "text/x-python",
|
| 224 |
+
"name": "python",
|
| 225 |
+
"nbconvert_exporter": "python",
|
| 226 |
+
"pygments_lexer": "ipython3",
|
| 227 |
+
"version": "3.12.11"
|
| 228 |
+
}
|
| 229 |
+
},
|
| 230 |
+
"nbformat": 4,
|
| 231 |
+
"nbformat_minor": 5
|
| 232 |
+
}
|
scripts/finetuning_script.ipynb
ADDED
|
@@ -0,0 +1,442 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"id": "b58b67f3",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"# Import model and training classes from Hugging Face Transformers\n",
|
| 11 |
+
"from transformers import MBartForConditionalGeneration, MBart50TokenizerFast, Seq2SeqTrainer, Seq2SeqTrainingArguments\n",
|
| 12 |
+
"# Import dataset utilities from the datasets library\n",
|
| 13 |
+
"from datasets import load_dataset, Dataset\n",
|
| 14 |
+
"# Import pandas for CSV/DF handling\n",
|
| 15 |
+
"import pandas as pd\n",
|
| 16 |
+
"# Import torch for device checks and tensors\n",
|
| 17 |
+
"import torch\n",
|
| 18 |
+
"# Import evaluate to load evaluation metrics\n",
|
| 19 |
+
"import evaluate\n",
|
| 20 |
+
"# Import numpy for numeric array manipulation\n",
|
| 21 |
+
"import numpy as np"
|
| 22 |
+
]
|
| 23 |
+
},
|
| 24 |
+
{
|
| 25 |
+
"cell_type": "code",
|
| 26 |
+
"execution_count": null,
|
| 27 |
+
"id": "349f59d3",
|
| 28 |
+
"metadata": {},
|
| 29 |
+
"outputs": [],
|
| 30 |
+
"source": [
|
| 31 |
+
"# Path to the main corpus CSV file (expects two columns: source and target)\n",
|
| 32 |
+
"data_path = \"./your/main/corpus.csv\" # Two columns: 'en' and 't'\n",
|
| 33 |
+
"# Read the CSV into a pandas DataFrame\n",
|
| 34 |
+
"df = pd.read_csv(data_path)"
|
| 35 |
+
]
|
| 36 |
+
},
|
| 37 |
+
{
|
| 38 |
+
"cell_type": "code",
|
| 39 |
+
"execution_count": null,
|
| 40 |
+
"id": "f9232b43",
|
| 41 |
+
"metadata": {},
|
| 42 |
+
"outputs": [],
|
| 43 |
+
"source": [
|
| 44 |
+
"# Display the first 3 rows of the DataFrame to inspect loaded data\n",
|
| 45 |
+
"df.head(3)"
|
| 46 |
+
]
|
| 47 |
+
},
|
| 48 |
+
{
|
| 49 |
+
"cell_type": "code",
|
| 50 |
+
"execution_count": null,
|
| 51 |
+
"id": "88f11bdd",
|
| 52 |
+
"metadata": {},
|
| 53 |
+
"outputs": [],
|
| 54 |
+
"source": [
|
| 55 |
+
"# Ensure the dataframe has correct column names and no list-type values\n",
|
| 56 |
+
"def ensure_text_columns(df):\n",
|
| 57 |
+
" # If the DataFrame uses 'Src_lang'/'Tgt_lang', rename them to 'src'/'tgt'\n",
|
| 58 |
+
" if 'Src_lang' in df.columns and 'Tgt_lang' in df.columns:\n",
|
| 59 |
+
" df = df.rename(columns={\"Src_lang\": \"src\", \"Tgt_lang\": \"tgt\"})\n",
|
| 60 |
+
" # blank line preserved for readability\n",
|
| 61 |
+
" # Ensure all values are strings to avoid list/object types during tokenization\n",
|
| 62 |
+
" df['src'] = df['src'].astype(str)\n",
|
| 63 |
+
" df['tgt'] = df['tgt'].astype(str)\n",
|
| 64 |
+
" # blank line preserved for readability\n",
|
| 65 |
+
" return df # return the normalized DataFrame\n",
|
| 66 |
+
"# Apply the helper to the loaded DataFrame\n",
|
| 67 |
+
"df = ensure_text_columns(df)"
|
| 68 |
+
]
|
| 69 |
+
},
|
| 70 |
+
{
|
| 71 |
+
"cell_type": "code",
|
| 72 |
+
"execution_count": null,
|
| 73 |
+
"id": "9c3fdfa7",
|
| 74 |
+
"metadata": {},
|
| 75 |
+
"outputs": [],
|
| 76 |
+
"source": [
|
| 77 |
+
"# Re-inspect the DataFrame after normalization\n",
|
| 78 |
+
"df.head(3)"
|
| 79 |
+
]
|
| 80 |
+
},
|
| 81 |
+
{
|
| 82 |
+
"cell_type": "code",
|
| 83 |
+
"execution_count": null,
|
| 84 |
+
"id": "3aaec3a8",
|
| 85 |
+
"metadata": {},
|
| 86 |
+
"outputs": [],
|
| 87 |
+
"source": [
|
| 88 |
+
"# Add language prefix tokens that will be prepended to source/target sentences\n",
|
| 89 |
+
"prefix_src = \"src_lang_code\" # placeholder source language token\n",
|
| 90 |
+
"prefix_tgt = \"tgt_lang_code\" # placeholder target language token"
|
| 91 |
+
]
|
| 92 |
+
},
|
| 93 |
+
{
|
| 94 |
+
"cell_type": "code",
|
| 95 |
+
"execution_count": null,
|
| 96 |
+
"id": "06bbfc98",
|
| 97 |
+
"metadata": {},
|
| 98 |
+
"outputs": [],
|
| 99 |
+
"source": [
|
| 100 |
+
"# Preprocessing function that adds language prefix tokens to each example\n",
|
| 101 |
+
"def preprocess(example):\n",
|
| 102 |
+
" # change the prefix_src and prefix_tgt to change the translation direction\n",
|
| 103 |
+
" return {\n",
|
| 104 |
+
" \"translation\": {\n",
|
| 105 |
+
" \"src\": f\"{prefix_src} {example['src']}\", # prepend source prefix\n",
|
| 106 |
+
" \"tgt\": f\"{prefix_tgt} {example['tgt']}\" # prepend target prefix\n",
|
| 107 |
+
" }\n",
|
| 108 |
+
" }"
|
| 109 |
+
]
|
| 110 |
+
},
|
| 111 |
+
{
|
| 112 |
+
"cell_type": "code",
|
| 113 |
+
"execution_count": null,
|
| 114 |
+
"id": "ad52beb7",
|
| 115 |
+
"metadata": {},
|
| 116 |
+
"outputs": [],
|
| 117 |
+
"source": [
|
| 118 |
+
"# Rename columns (no-op here but kept for clarity) and apply preprocessing to create a Dataset\n",
|
| 119 |
+
"df = df.rename(columns={\"src\": \"src\", \"tgt\": \"tgt\"}) # explicit rename placeholder\n",
|
| 120 |
+
"# Convert pandas DataFrame to a Hugging Face Dataset\n",
|
| 121 |
+
"dataset = Dataset.from_pandas(df)\n",
|
| 122 |
+
"# Apply preprocessing function to each example in the dataset\n",
|
| 123 |
+
"dataset = dataset.map(preprocess)"
|
| 124 |
+
]
|
| 125 |
+
},
|
| 126 |
+
{
|
| 127 |
+
"cell_type": "code",
|
| 128 |
+
"execution_count": null,
|
| 129 |
+
"id": "fd30423f",
|
| 130 |
+
"metadata": {},
|
| 131 |
+
"outputs": [],
|
| 132 |
+
"source": [
|
| 133 |
+
"# Split the Dataset into training and validation sets\n",
|
| 134 |
+
"split_dataset = dataset.train_test_split(test_size=0.1, seed=42) # 10% for validation\n",
|
| 135 |
+
"# Extract train and validation Dataset objects\n",
|
| 136 |
+
"train_data = split_dataset[\"train\"]\n",
|
| 137 |
+
"val_data = split_dataset[\"test\"]"
|
| 138 |
+
]
|
| 139 |
+
},
|
| 140 |
+
{
|
| 141 |
+
"cell_type": "code",
|
| 142 |
+
"execution_count": null,
|
| 143 |
+
"id": "951b1f86",
|
| 144 |
+
"metadata": {},
|
| 145 |
+
"outputs": [],
|
| 146 |
+
"source": [
|
| 147 |
+
"# Save processed train/validation splits to CSV files for later use\n",
|
| 148 |
+
"train_data.to_csv(\"train_set.csv\", index=False) # write training set\n",
|
| 149 |
+
"val_data.to_csv(\"val_set.csv\", index=False) # write validation set\n",
|
| 150 |
+
"print(\"Train and validation data saved successfully:\") # confirmation\n",
|
| 151 |
+
"print(f\"Train size: {len(train_data)}\") # show train count\n",
|
| 152 |
+
"print(f\"Validation size: {len(val_data)}\") # show validation count"
|
| 153 |
+
]
|
| 154 |
+
},
|
| 155 |
+
{
|
| 156 |
+
"cell_type": "code",
|
| 157 |
+
"execution_count": null,
|
| 158 |
+
"id": "3cbc12e7",
|
| 159 |
+
"metadata": {},
|
| 160 |
+
"outputs": [],
|
| 161 |
+
"source": [
|
| 162 |
+
"# Load the MBART-50 tokenizer (fast implementation)\n",
|
| 163 |
+
"tokenizer = MBart50TokenizerFast.from_pretrained(\"facebook/mbart-large-50-many-to-many-mmt\")\n",
|
| 164 |
+
"# Add any custom special tokens required (e.g., our target language code)\n",
|
| 165 |
+
"tokenizer.add_special_tokens({'additional_special_tokens': [\"tgt_lang_code\"]})\n",
|
| 166 |
+
"# Register the new lang token in the tokenizer's lang_code mapping\n",
|
| 167 |
+
"tokenizer.lang_code_to_id[\"tgt_lang_code\"] = len(tokenizer.lang_code_to_id)\n",
|
| 168 |
+
"# Rebuild reverse mapping from id to lang code (useful later)\n",
|
| 169 |
+
"tokenizer.id_to_lang_code = {v: k for k, v in tokenizer.lang_code_to_id.items()}"
|
| 170 |
+
]
|
| 171 |
+
},
|
| 172 |
+
{
|
| 173 |
+
"cell_type": "code",
|
| 174 |
+
"execution_count": null,
|
| 175 |
+
"id": "fc507095",
|
| 176 |
+
"metadata": {},
|
| 177 |
+
"outputs": [],
|
| 178 |
+
"source": [
|
| 179 |
+
"# Load the pretrained MBART model for conditional generation\n",
|
| 180 |
+
"model = MBartForConditionalGeneration.from_pretrained(\"facebook/mbart-large-50-many-to-many-mmt\")\n",
|
| 181 |
+
"# Resize model token embeddings to account for any newly added tokens in the tokenizer\n",
|
| 182 |
+
"model.resize_token_embeddings(len(tokenizer))"
|
| 183 |
+
]
|
| 184 |
+
},
|
| 185 |
+
{
|
| 186 |
+
"cell_type": "code",
|
| 187 |
+
"execution_count": null,
|
| 188 |
+
"id": "1ebd78be",
|
| 189 |
+
"metadata": {},
|
| 190 |
+
"outputs": [],
|
| 191 |
+
"source": [
|
| 192 |
+
"# --- Step 4: Tokenize data --- # tokenization settings and helper\n",
|
| 193 |
+
"# Maximum tokenization length for inputs/targets\n",
|
| 194 |
+
"max_length = 128\n",
|
| 195 |
+
"# blank line for readability\n",
|
| 196 |
+
"# Tokenization function applied to dataset examples\n",
|
| 197 |
+
"def tokenize_function(examples):\n",
|
| 198 |
+
" # Tokenize source with padding/truncation to max_length\n",
|
| 199 |
+
" inputs = tokenizer(examples[\"translation\"][\"src\"], padding=\"max_length\", truncation=True, max_length=max_length)\n",
|
| 200 |
+
" # Tokenize target similarly\n",
|
| 201 |
+
" targets = tokenizer(examples[\"translation\"][\"tgt\"], padding=\"max_length\", truncation=True, max_length=max_length)\n",
|
| 202 |
+
" # Use tokenized target input_ids as labels for seq2seq training\n",
|
| 203 |
+
" inputs[\"labels\"] = targets[\"input_ids\"]\n",
|
| 204 |
+
" return inputs"
|
| 205 |
+
]
|
| 206 |
+
},
|
| 207 |
+
{
|
| 208 |
+
"cell_type": "code",
|
| 209 |
+
"execution_count": null,
|
| 210 |
+
"id": "bb922c07",
|
| 211 |
+
"metadata": {},
|
| 212 |
+
"outputs": [],
|
| 213 |
+
"source": [
|
| 214 |
+
"# Tokenize the dataset using the helper function defined above\n",
|
| 215 |
+
"train_dataset = train_data.map(tokenize_function)\n",
|
| 216 |
+
"val_dataset = val_data.map(tokenize_function)"
|
| 217 |
+
]
|
| 218 |
+
},
|
| 219 |
+
{
|
| 220 |
+
"cell_type": "code",
|
| 221 |
+
"execution_count": null,
|
| 222 |
+
"id": "213966dc",
|
| 223 |
+
"metadata": {},
|
| 224 |
+
"outputs": [],
|
| 225 |
+
"source": [
|
| 226 |
+
"# Import evaluation utilities (repeated import is safe inside notebook but already imported above)\n",
|
| 227 |
+
"import evaluate\n",
|
| 228 |
+
"# numpy imported earlier; this duplicate import is harmless\n",
|
| 229 |
+
"import numpy as np\n",
|
| 230 |
+
"# blank line for readability\n",
|
| 231 |
+
"# Load metric implementations once to reuse inside compute_metrics\n",
|
| 232 |
+
"bleu_metric = evaluate.load(\"bleu\")\n",
|
| 233 |
+
"meteor_metric = evaluate.load(\"meteor\")\n",
|
| 234 |
+
"ter_metric = evaluate.load(\"ter\")\n",
|
| 235 |
+
"chrf_metric = evaluate.load(\"chrf\")\n",
|
| 236 |
+
"# blank line for readability\n",
|
| 237 |
+
"# Function used by Trainer to compute evaluation metrics from model outputs\n",
|
| 238 |
+
"def compute_metrics(eval_preds):\n",
|
| 239 |
+
" # eval_preds is a tuple (predictions, labels)\n",
|
| 240 |
+
" preds, labels = eval_preds\n",
|
| 241 |
+
" # blank line for readability\n",
|
| 242 |
+
" # Decode predictions from token ids to strings\n",
|
| 243 |
+
" decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)\n",
|
| 244 |
+
" # blank line for readability\n",
|
| 245 |
+
" # Replace masked label tokens (-100) with pad token id so they decode properly\n",
|
| 246 |
+
" labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n",
|
| 247 |
+
" # blank line for readability\n",
|
| 248 |
+
" # Decode label ids to strings\n",
|
| 249 |
+
" decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
|
| 250 |
+
" # blank line for readability\n",
|
| 251 |
+
" # Clean whitespace from decoded strings\n",
|
| 252 |
+
" decoded_preds = [p.strip() for p in decoded_preds]\n",
|
| 253 |
+
" decoded_labels = [[l.strip()] for l in decoded_labels] # convert to list-of-lists for metrics\n",
|
| 254 |
+
" # blank line for readability\n",
|
| 255 |
+
" # Compute each metric using the decoded predictions and references\n",
|
| 256 |
+
" bleu = bleu_metric.compute(predictions=decoded_preds, references=decoded_labels)\n",
|
| 257 |
+
" meteor = meteor_metric.compute(predictions=decoded_preds, references=decoded_labels)\n",
|
| 258 |
+
" ter = ter_metric.compute(predictions=decoded_preds, references=decoded_labels)\n",
|
| 259 |
+
" chrf = chrf_metric.compute(predictions=decoded_preds, references=decoded_labels)\n",
|
| 260 |
+
" # blank line for readability\n",
|
| 261 |
+
" # BLEU implementations may return different keys; try common ones\n",
|
| 262 |
+
" bleu_score = bleu.get(\"score\", bleu.get(\"bleu\"))\n",
|
| 263 |
+
" # blank line for readability\n",
|
| 264 |
+
" return {\n",
|
| 265 |
+
" \"ChrF\": chrf[\"score\"], # MAIN METRIC\n",
|
| 266 |
+
" \"BLEU\": bleu_score,\n",
|
| 267 |
+
" \"METEOR\": meteor[\"meteor\"],\n",
|
| 268 |
+
" \"TER\": ter[\"score\"]\n",
|
| 269 |
+
" }\n",
|
| 270 |
+
"# end of cell"
|
| 271 |
+
]
|
| 272 |
+
},
|
| 273 |
+
{
|
| 274 |
+
"cell_type": "code",
|
| 275 |
+
"execution_count": null,
|
| 276 |
+
"id": "824252e3",
|
| 277 |
+
"metadata": {},
|
| 278 |
+
"outputs": [],
|
| 279 |
+
"source": [
|
| 280 |
+
"# Configure Seq2Seq training arguments for the Hugging Face Trainer\n",
|
| 281 |
+
"training_args = Seq2SeqTrainingArguments(\n",
|
| 282 |
+
" output_dir=\"./your/model/checkpoints\", # directory for checkpoints and outputs\n",
|
| 283 |
+
" per_device_train_batch_size=8, # batch size per device for training\n",
|
| 284 |
+
" per_device_eval_batch_size=8, # batch size per device for evaluation\n",
|
| 285 |
+
" gradient_accumulation_steps=4, # effective batch size = 8*4 = 32\n",
|
| 286 |
+
" learning_rate=3e-5, # initial learning rate\n",
|
| 287 |
+
" weight_decay=0.01, # weight decay for optimizer\n",
|
| 288 |
+
" num_train_epochs=3, # number of training epochs\n",
|
| 289 |
+
" warmup_steps=1000, # number of warmup steps for scheduler\n",
|
| 290 |
+
" lr_scheduler_type=\"cosine\", # learning rate scheduler type\n",
|
| 291 |
+
" fp16=torch.cuda.is_available(), # enable fp16 if CUDA is available\n",
|
| 292 |
+
" evaluation_strategy=\"steps\", # evaluate every X steps\n",
|
| 293 |
+
" eval_steps=2000, # evaluation interval in steps\n",
|
| 294 |
+
" save_strategy=\"steps\", # save checkpoints every X steps\n",
|
| 295 |
+
" save_steps=2000, # checkpoint saving interval\n",
|
| 296 |
+
" load_best_model_at_end=True, # keep the best model according to metric\n",
|
| 297 |
+
" metric_for_best_model=\"ChrF\", # metric used to select best model\n",
|
| 298 |
+
" greater_is_better=True, # higher metric value is better\n",
|
| 299 |
+
" save_total_limit=5, # limit number of saved checkpoints\n",
|
| 300 |
+
" predict_with_generate=True, # use generate() for predictions during eval\n",
|
| 301 |
+
" generation_max_length=128, # max length when generating predictions\n",
|
| 302 |
+
" generation_num_beams=4, # number of beams for generation\n",
|
| 303 |
+
" logging_dir=\"./logs\", # tensorboard/logging dir\n",
|
| 304 |
+
" logging_steps=200, # logging interval\n",
|
| 305 |
+
" seed=42, # random seed for reproducibility\n",
|
| 306 |
+
" report_to=\"none\", # disable reporting to external services\n",
|
| 307 |
+
")\n",
|
| 308 |
+
"# end of training_args cell"
|
| 309 |
+
]
|
| 310 |
+
},
|
| 311 |
+
{
|
| 312 |
+
"cell_type": "code",
|
| 313 |
+
"execution_count": null,
|
| 314 |
+
"id": "dc6cb1bd",
|
| 315 |
+
"metadata": {},
|
| 316 |
+
"outputs": [],
|
| 317 |
+
"source": [
|
| 318 |
+
"# Import a data collator that pads to longest sequence in the batch for seq2seq models\n",
|
| 319 |
+
"from transformers import DataCollatorForSeq2Seq\n",
|
| 320 |
+
"# blank line for readability\n",
|
| 321 |
+
"# Create the data collator which will dynamically pad batch examples\n",
|
| 322 |
+
"data_collator = DataCollatorForSeq2Seq(\n",
|
| 323 |
+
" tokenizer,\n",
|
| 324 |
+
" model=model,\n",
|
| 325 |
+
" padding=\"longest\", # pad to the longest sequence in the batch\n",
|
| 326 |
+
")\n",
|
| 327 |
+
"# end of data_collator cell"
|
| 328 |
+
]
|
| 329 |
+
},
|
| 330 |
+
{
|
| 331 |
+
"cell_type": "code",
|
| 332 |
+
"execution_count": null,
|
| 333 |
+
"id": "9508fc22",
|
| 334 |
+
"metadata": {},
|
| 335 |
+
"outputs": [],
|
| 336 |
+
"source": [
|
| 337 |
+
"# Create the Seq2SeqTrainer wrapper which handles training/evaluation loops\n",
|
| 338 |
+
"trainer = Seq2SeqTrainer(\n",
|
| 339 |
+
" model=model, # the model to train\n",
|
| 340 |
+
" args=training_args, # training configuration\n",
|
| 341 |
+
" train_dataset=train_dataset, # training data\n",
|
| 342 |
+
" eval_dataset=val_dataset, # evaluation data\n",
|
| 343 |
+
" processing_class=tokenizer, # tokenizer/processor used for the model\n",
|
| 344 |
+
" data_collator=data_collator, # handles padding in batches\n",
|
| 345 |
+
" compute_metrics=compute_metrics, # metrics callback for evaluation\n",
|
| 346 |
+
")\n",
|
| 347 |
+
"# end of trainer creation cell"
|
| 348 |
+
]
|
| 349 |
+
},
|
| 350 |
+
{
|
| 351 |
+
"cell_type": "code",
|
| 352 |
+
"execution_count": null,
|
| 353 |
+
"id": "d5aea8ad",
|
| 354 |
+
"metadata": {},
|
| 355 |
+
"outputs": [],
|
| 356 |
+
"source": [
|
| 357 |
+
"# Start training. This runs the main training loop according to training_args\n",
|
| 358 |
+
"trainer.train()"
|
| 359 |
+
]
|
| 360 |
+
},
|
| 361 |
+
{
|
| 362 |
+
"cell_type": "code",
|
| 363 |
+
"execution_count": null,
|
| 364 |
+
"id": "462c0da0",
|
| 365 |
+
"metadata": {},
|
| 366 |
+
"outputs": [],
|
| 367 |
+
"source": [
|
| 368 |
+
"# Evaluate the trained model on the validation set and print returned metrics\n",
|
| 369 |
+
"eval_results = trainer.evaluate()\n",
|
| 370 |
+
"print(eval_results)"
|
| 371 |
+
]
|
| 372 |
+
},
|
| 373 |
+
{
|
| 374 |
+
"cell_type": "code",
|
| 375 |
+
"execution_count": null,
|
| 376 |
+
"id": "59cfe87b",
|
| 377 |
+
"metadata": {},
|
| 378 |
+
"outputs": [],
|
| 379 |
+
"source": [
|
| 380 |
+
"# Show eval_results variable (already printed above) in a notebook cell to display its value\n",
|
| 381 |
+
"eval_results"
|
| 382 |
+
]
|
| 383 |
+
},
|
| 384 |
+
{
|
| 385 |
+
"cell_type": "code",
|
| 386 |
+
"execution_count": null,
|
| 387 |
+
"id": "d629dc4b",
|
| 388 |
+
"metadata": {},
|
| 389 |
+
"outputs": [],
|
| 390 |
+
"source": [
|
| 391 |
+
"# Save the fine-tuned model weights and tokenizer to a directory\n",
|
| 392 |
+
"model.save_pretrained(\"./your/model/name\") # saves model config and weights\n",
|
| 393 |
+
"tokenizer.save_pretrained(\"./your/model/name\") # saves tokenizer files"
|
| 394 |
+
]
|
| 395 |
+
},
|
| 396 |
+
{
|
| 397 |
+
"cell_type": "code",
|
| 398 |
+
"execution_count": null,
|
| 399 |
+
"id": "0641551e",
|
| 400 |
+
"metadata": {},
|
| 401 |
+
"outputs": [],
|
| 402 |
+
"source": [
|
| 403 |
+
"# Load pipeline utilities for quick inference\n",
|
| 404 |
+
"import torch\n",
|
| 405 |
+
"from transformers import pipeline\n",
|
| 406 |
+
"# blank line for readability\n",
|
| 407 |
+
"# Create a translation pipeline pointing at the saved model directory\n",
|
| 408 |
+
"pipeline = pipeline(\n",
|
| 409 |
+
" task=\"translation\", # pipeline task\n",
|
| 410 |
+
" model=\"./your/model/name\", # path to saved model\n",
|
| 411 |
+
" device=0, # device id (0 for first GPU); set to -1 for CPU\n",
|
| 412 |
+
" torch_dtype=torch.float16, # use float16 if model and device support it\n",
|
| 413 |
+
" src_lang=\"src_lang_code\", # source language code token\n",
|
| 414 |
+
" tgt_lang=\"tgt_lang_code\", # target language code token\n",
|
| 415 |
+
")\n",
|
| 416 |
+
"# Run the pipeline on a sample sentence and print the translation\n",
|
| 417 |
+
"print(pipeline(\"I like singing\"))"
|
| 418 |
+
]
|
| 419 |
+
}
|
| 420 |
+
],
|
| 421 |
+
"metadata": {
|
| 422 |
+
"kernelspec": {
|
| 423 |
+
"display_name": "ptorch",
|
| 424 |
+
"language": "python",
|
| 425 |
+
"name": "python3"
|
| 426 |
+
},
|
| 427 |
+
"language_info": {
|
| 428 |
+
"codemirror_mode": {
|
| 429 |
+
"name": "ipython",
|
| 430 |
+
"version": 3
|
| 431 |
+
},
|
| 432 |
+
"file_extension": ".py",
|
| 433 |
+
"mimetype": "text/x-python",
|
| 434 |
+
"name": "python",
|
| 435 |
+
"nbconvert_exporter": "python",
|
| 436 |
+
"pygments_lexer": "ipython3",
|
| 437 |
+
"version": "3.12.11"
|
| 438 |
+
}
|
| 439 |
+
},
|
| 440 |
+
"nbformat": 4,
|
| 441 |
+
"nbformat_minor": 5
|
| 442 |
+
}
|
scripts/intrinsic_evaluation.ipynb
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"id": "cbf05abe",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"import pandas as pd\n",
|
| 11 |
+
"import numpy as np\n",
|
| 12 |
+
"import string\n",
|
| 13 |
+
"import re\n",
|
| 14 |
+
"\n",
|
| 15 |
+
"# --- Configuration ---\n",
|
| 16 |
+
"FILEPATH_RAW = 'Path/to/the/original/corpus.csv'\n",
|
| 17 |
+
"FILEPATH_FILTERED = 'Path/to/the/pre-filtered/corpus.csv' # Path to the pre-filtered corpus\n",
|
| 18 |
+
"SCORE_COL = 'KS_i'\n",
|
| 19 |
+
"# --- End Configuration ---\n",
|
| 20 |
+
"\n",
|
| 21 |
+
"def tokenize_and_clean(text_series):\n",
|
| 22 |
+
" \"\"\"Tokenizes text, converts to lowercase, and removes punctuation.\"\"\"\n",
|
| 23 |
+
" # Combine all text into a single string\n",
|
| 24 |
+
" full_text = \" \".join(text_series.astype(str))\n",
|
| 25 |
+
" # Remove punctuation\n",
|
| 26 |
+
" full_text = full_text.lower().translate(str.maketrans('', '', string.punctuation))\n",
|
| 27 |
+
" # Simple split by whitespace\n",
|
| 28 |
+
" tokens = full_text.split()\n",
|
| 29 |
+
" return tokens\n",
|
| 30 |
+
"\n",
|
| 31 |
+
"def calculate_ttr(tokens):\n",
|
| 32 |
+
" \"\"\"Calculates Type-Token Ratio (Lexical Diversity).\"\"\"\n",
|
| 33 |
+
" if not tokens:\n",
|
| 34 |
+
" return 0.0\n",
|
| 35 |
+
" types = set(tokens)\n",
|
| 36 |
+
" return len(types) / len(tokens)\n",
|
| 37 |
+
"\n",
|
| 38 |
+
"def intrinsic_evaluation(filepath_raw, filepath_filtered, score_col):\n",
|
| 39 |
+
" \"\"\"\n",
|
| 40 |
+
" Performs intrinsic evaluation metrics comparing the pre-filtered corpus \n",
|
| 41 |
+
" against the raw corpus.\n",
|
| 42 |
+
" \"\"\"\n",
|
| 43 |
+
" try:\n",
|
| 44 |
+
" # 1. Load and prepare the data\n",
|
| 45 |
+
" df_raw = pd.read_csv(filepath_raw)\n",
|
| 46 |
+
" df_filtered = pd.read_csv(filepath_filtered)\n",
|
| 47 |
+
"\n",
|
| 48 |
+
" # 1.5. Robust cleaning and preparation for both dataframes\n",
|
| 49 |
+
" for df, name in [(df_raw, \"Raw\"), (df_filtered, \"Filtered\")]:\n",
|
| 50 |
+
" if score_col not in df.columns:\n",
|
| 51 |
+
" print(f\"Error: The {name} CSV file must contain a column named '{score_col}'.\")\n",
|
| 52 |
+
" return\n",
|
| 53 |
+
" # Convert score column to numeric and drop rows where score or text columns are missing\n",
|
| 54 |
+
" df[score_col] = pd.to_numeric(df[score_col], errors='coerce')\n",
|
| 55 |
+
" df.dropna(subset=[score_col, 'src_lang', 'tgt_lang'], inplace=True)\n",
|
| 56 |
+
" \n",
|
| 57 |
+
" if len(df_raw) == 0 or len(df_filtered) == 0:\n",
|
| 58 |
+
" print(\"Error: One or both corpora are empty after cleaning.\")\n",
|
| 59 |
+
" return\n",
|
| 60 |
+
"\n",
|
| 61 |
+
" # 2. Calculate Intrinsic Metrics\n",
|
| 62 |
+
" \n",
|
| 63 |
+
" # --- A. Knowledge Score Averages ---\n",
|
| 64 |
+
" raw_avg_ks = df_raw[score_col].mean()\n",
|
| 65 |
+
" filtered_avg_ks = df_filtered[score_col].mean()\n",
|
| 66 |
+
" \n",
|
| 67 |
+
" # --- B. Lexical Diversity (Type-Token Ratio) ---\n",
|
| 68 |
+
" \n",
|
| 69 |
+
" # Raw Corpus TTR\n",
|
| 70 |
+
" raw_src_tokens = tokenize_and_clean(df_raw['src_lang'])\n",
|
| 71 |
+
" raw_tgt_tokens = tokenize_and_clean(df_raw['tgt_lang'])\n",
|
| 72 |
+
" raw_src_ttr = calculate_ttr(raw_src_tokens)\n",
|
| 73 |
+
" raw_tgt_ttr = calculate_ttr(raw_tgt_tokens)\n",
|
| 74 |
+
" \n",
|
| 75 |
+
" # Filtered Corpus TTR\n",
|
| 76 |
+
" filtered_src_tokens = tokenize_and_clean(df_filtered['src_lang'])\n",
|
| 77 |
+
" filtered_tgt_tokens = tokenize_and_clean(df_filtered['tgt_lang'])\n",
|
| 78 |
+
" filtered_src_ttr = calculate_ttr(filtered_src_tokens)\n",
|
| 79 |
+
" filtered_tgt_ttr = calculate_ttr(filtered_tgt_tokens)\n",
|
| 80 |
+
" \n",
|
| 81 |
+
" # 3. Print Results\n",
|
| 82 |
+
" print(\"\\n\" + \"=\"*60)\n",
|
| 83 |
+
" print(\"INTRINSIC CORPUS QUALITY EVALUATION\")\n",
|
| 84 |
+
" print(\"=\"*60)\n",
|
| 85 |
+
" retention_rate = len(df_filtered)/len(df_raw)*100\n",
|
| 86 |
+
" print(f\"Corpus Sizes: Raw={len(df_raw)} | Filtered={len(df_filtered)} (Retention: {retention_rate:.2f}%)\")\n",
|
| 87 |
+
" \n",
|
| 88 |
+
" # --- Average KS_i Comparison ---\n",
|
| 89 |
+
" print(\"\\n--- 1. Average Knowledge Score (KS_i) ---\")\n",
|
| 90 |
+
" print(f\"| {'Metric':<25} | {'Raw Corpus':>15} | {'Filtered Corpus':>15} |\")\n",
|
| 91 |
+
" print(f\"| {'Average KS_i':<25} | {raw_avg_ks:>15.4f} | {filtered_avg_ks:>15.4f} |\")\n",
|
| 92 |
+
" \n",
|
| 93 |
+
" ks_increase_percent = ((filtered_avg_ks - raw_avg_ks) / raw_avg_ks) * 100\n",
|
| 94 |
+
" print(\"\\n**Conclusion:** The Average KS_i increased by {0:.2f}% after filtering. The increase in mean score confirms that the KS_i metric successfully concentrated high-quality data.\".format(ks_increase_percent))\n",
|
| 95 |
+
" \n",
|
| 96 |
+
" # --- TTR Comparison ---\n",
|
| 97 |
+
" print(\"\\n--- 2. Lexical Diversity (Type-Token Ratio) ---\")\n",
|
| 98 |
+
" print(f\"| {'Metric':<25} | {'Raw Corpus':>15} | {'Filtered Corpus':>15} |\")\n",
|
| 99 |
+
" print(f\"| {'Source (src_lang) TTR':<25} | {raw_src_ttr:>15.4f} | {filtered_src_ttr:>15.4f} |\")\n",
|
| 100 |
+
" print(f\"| {'Target (tgt_lang) TTR':<25} | {raw_tgt_ttr:>15.4f} | {filtered_tgt_ttr:>15.4f} |\")\n",
|
| 101 |
+
" \n",
|
| 102 |
+
" # --- Diversity Conclusion ---\n",
|
| 103 |
+
" src_ttr_change = (filtered_src_ttr - raw_src_ttr) / raw_src_ttr * 100\n",
|
| 104 |
+
" tgt_ttr_change = (filtered_tgt_ttr - raw_tgt_ttr) / raw_tgt_ttr * 100\n",
|
| 105 |
+
"\n",
|
| 106 |
+
" print(f\"\\n**Conclusion:** Diversity (TTR) changed by {src_ttr_change:.2f}% (Source) and {tgt_ttr_change:.2f}% (Target).\")\n",
|
| 107 |
+
" print(\"A positive or minimal negative change in TTR suggests the filter successfully isolated quality data without sacrificing vital vocabulary coverage.\")\n",
|
| 108 |
+
"\n",
|
| 109 |
+
"\n",
|
| 110 |
+
" except FileNotFoundError as e:\n",
|
| 111 |
+
" print(f\"Error: A required file was not found: {e}. Ensure both '{filepath_raw}' and '{filepath_filtered}' exist.\")\n",
|
| 112 |
+
" except Exception as e:\n",
|
| 113 |
+
" print(f\"An unexpected error occurred: {e}\")\n",
|
| 114 |
+
"\n",
|
| 115 |
+
"if __name__ == \"__main__\":\n",
|
| 116 |
+
" intrinsic_evaluation(FILEPATH_RAW, FILEPATH_FILTERED, SCORE_COL)"
|
| 117 |
+
]
|
| 118 |
+
}
|
| 119 |
+
],
|
| 120 |
+
"metadata": {
|
| 121 |
+
"kernelspec": {
|
| 122 |
+
"display_name": "ptorch",
|
| 123 |
+
"language": "python",
|
| 124 |
+
"name": "python3"
|
| 125 |
+
},
|
| 126 |
+
"language_info": {
|
| 127 |
+
"codemirror_mode": {
|
| 128 |
+
"name": "ipython",
|
| 129 |
+
"version": 3
|
| 130 |
+
},
|
| 131 |
+
"file_extension": ".py",
|
| 132 |
+
"mimetype": "text/x-python",
|
| 133 |
+
"name": "python",
|
| 134 |
+
"nbconvert_exporter": "python",
|
| 135 |
+
"pygments_lexer": "ipython3",
|
| 136 |
+
"version": "3.12.11"
|
| 137 |
+
}
|
| 138 |
+
},
|
| 139 |
+
"nbformat": 4,
|
| 140 |
+
"nbformat_minor": 5
|
| 141 |
+
}
|