{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "311e31e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import pandas for DataFrame manipulation\n",
    "import pandas as pd\n",
    "# Import numpy for numerical operations\n",
    "import numpy as np\n",
    "# Import torch for tensor operations and device handling\n",
    "import torch\n",
    "# Import MBART model and tokenizer from Hugging Face Transformers\n",
    "from transformers import MBartForConditionalGeneration, MBart50TokenizerFast\n",
    "# Import cosine similarity for comparing embeddings\n",
    "from sklearn.metrics.pairwise import cosine_similarity\n",
    "# Import tqdm to show progress bars for loops\n",
    "from tqdm import tqdm\n",
    "# Import regex utilities for tokenization and cleaning\n",
    "import re"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3363ac62",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Configuration ---\n",
    "MODEL_NAME = \"your/model/name\"\n",
    "SRC_LANG_CODE = \"src_lang_code\"\n",
    "TGT_LANG_CODE = \"tgt_lang_code\"\n",
    "CORPUS_FILE = \"your/corpus/here.csv\"\n",
    "DICT_FILE = \"your/bilingual/dictionary/here.csv\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5d67ae6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters for the Knowledge Score (KS_i)\n",
    "# You would tune these based on empirical performance\n",
    "ALPHA = 0.1\n",
    "BETA = 0.3\n",
    "GAMMA = 0.6\n",
    "PERCENTILE_THRESHOLD = 70 # Filter threshold: keep pairs above this percentile"
   ]
  },
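  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c4e21aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged sanity check (illustration only, not part of the pipeline): each component\n",
    "# score lies in [0, 1] and the weights above sum to 1.0, so the Knowledge Score\n",
    "# KS_i = ALPHA*PPL + BETA*Sim + GAMMA*Lex is also bounded in [0, 1].\n",
    "# The component values below are made up.\n",
    "demo_ppl_norm, demo_sim, demo_lex = 0.8, 0.5, 0.25\n",
    "demo_ks = ALPHA * demo_ppl_norm + BETA * demo_sim + GAMMA * demo_lex\n",
    "assert 0.0 <= demo_ks <= 1.0\n",
    "print(f\"Example KS_i: {demo_ks:.3f}\")"
   ]
  },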
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5fb7924",
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess_text(text):\n",
    "    \"\"\"\n",
    "    Safely preprocesses text by handling NaN, non-string values,\n",
    "    and performing normalization steps.\n",
    "    \"\"\"\n",
    "    if not isinstance(text, str):\n",
    "        return \"\"\n",
    "    text = text.strip().lower()\n",
    "    text = re.sub(r\"\\s+\", \" \", text)           # Collapse multiple spaces\n",
    "    text = re.sub(r\"[^a-zA-Z0-9\\s']\", \"\", text)  # Remove unwanted symbols (keep alphanumerics and apostrophes)\n",
    "    return text\n",
    "\n",
    "\n",
    "def load_data(corpus_file, dict_file):\n",
    "    \"\"\"Loads, cleans, and prepares the parallel corpus and bilingual dictionary.\"\"\"\n",
    "\n",
    "    # --- Load the CSVs safely ---\n",
    "    try:\n",
    "        raw_corpus = pd.read_csv(corpus_file)\n",
    "        word_dictionary = pd.read_csv(dict_file)\n",
    "    except Exception as e:\n",
    "        raise ValueError(f\"Error loading files: {e}\")\n",
    "\n",
    "    # --- Ensure expected columns exist ---\n",
    "    required_corpus_cols = {'English', 'Tagin'}\n",
    "    required_dict_cols = {'English', 'Tagin'}\n",
    "\n",
    "    if not required_corpus_cols.issubset(raw_corpus.columns):\n",
    "        raise ValueError(f\"Corpus file must contain columns: {required_corpus_cols}\")\n",
    "    if not required_dict_cols.issubset(word_dictionary.columns):\n",
    "        raise ValueError(f\"Dictionary file must contain columns: {required_dict_cols}\")\n",
    "\n",
    "    # --- Drop rows with all NaN values ---\n",
    "    raw_corpus = raw_corpus.dropna(how='all')\n",
    "\n",
    "    # --- Fill NaN cells with empty strings ---\n",
    "    raw_corpus = raw_corpus.fillna(\"\")\n",
    "\n",
    "    # --- Apply text preprocessing ---\n",
    "    raw_corpus[\"English\"] = raw_corpus[\"English\"].apply(preprocess_text)\n",
    "    raw_corpus[\"Tagin\"] = raw_corpus[\"Tagin\"].apply(preprocess_text)\n",
    "\n",
    "    # --- Clean dictionary entries ---\n",
    "    word_dictionary[\"English\"] = word_dictionary[\"English\"].apply(preprocess_text)\n",
    "    word_dictionary[\"Tagin\"] = word_dictionary[\"Tagin\"].apply(preprocess_text)\n",
    "\n",
    "    # --- Convert dictionary to mapping ---\n",
    "    word_dictionary = word_dictionary.set_index('English')['Tagin'].to_dict()\n",
    "\n",
    "    # --- Remove empty rows after cleaning ---\n",
    "    raw_corpus = raw_corpus[\n",
    "        (raw_corpus[\"English\"].str.strip() != \"\") &\n",
    "        (raw_corpus[\"Tagin\"].str.strip() != \"\")\n",
    "    ].reset_index(drop=True)\n",
    "\n",
    "    print(f\"Loaded {len(raw_corpus)} sentence pairs and {len(word_dictionary)} dictionary entries.\")\n",
    "    return raw_corpus, word_dictionary"
   ]
  },
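  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f8a0d2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick illustration on made-up input (safe to delete): preprocess_text lowercases,\n",
    "# collapses whitespace, strips symbols other than alphanumerics and apostrophes,\n",
    "# and returns \"\" for non-string values such as NaN.\n",
    "print(repr(preprocess_text(\"  Hello,   WORLD! It's  fine. \")))  # -> \"hello world it's fine\"\n",
    "print(repr(preprocess_text(float(\"nan\"))))  # -> ''"
   ]
  },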
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "772824d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "load_data(CORPUS_FILE,DICT_FILE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7322656f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function for Step 2: Perplexity (PPL)\n",
    "@torch.no_grad()\n",
    "def calculate_perplexity(sentence, model, tokenizer, device):\n",
    "    \"\"\"Computes perplexity of a sentence using the given LM.\"\"\"\n",
    "    try:\n",
    "        # Tokenize and format for mBART-50 (e.g., [lang_code] X [eos])\n",
    "        # We'll treat this as a generation task from the source language to itself\n",
    "        # to get log probabilities for the language modeling loss.\n",
    "        input_ids = tokenizer(\n",
    "            sentence,\n",
    "            return_tensors=\"pt\",\n",
    "            max_length=512,\n",
    "            truncation=True\n",
    "        ).input_ids.to(device)\n",
    "        \n",
    "        # Set the source language\n",
    "        tokenizer.src_lang = SRC_LANG_CODE\n",
    "        \n",
    "        # The labels for perplexity are the input tokens themselves, shifted.\n",
    "        # This is essentially a language modeling task.\n",
    "        labels = input_ids.clone()\n",
    "        \n",
    "        # Use -100 to ignore the loss for special tokens (like the language code token)\n",
    "        labels[:, 0] = -100\n",
    "\n",
    "        outputs = model(input_ids=input_ids, labels=labels)\n",
    "        neg_log_likelihood = outputs.loss\n",
    "        \n",
    "        # Perplexity is exp(average negative log-likelihood)\n",
    "        # The 'outputs.loss' from the Transformers library is already the average NLL per token.\n",
    "        ppl = torch.exp(neg_log_likelihood).item()\n",
    "        return ppl\n",
    "    except Exception as e:\n",
    "        print(f\"Error calculating PPL for: '{sentence}'. Error: {e}\")\n",
    "        return float('inf') # Return a very high PPL for errors/bad sentences\n",
    "\n"
   ]
  },
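  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d9a4f66",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy arithmetic check (no model required): outputs.loss above is the mean\n",
    "# negative log-likelihood per token, so perplexity is just its exponential.\n",
    "mean_nll = 2.0  # made-up value for illustration\n",
    "print(f\"PPL = exp({mean_nll}) = {np.exp(mean_nll):.2f}\")  # -> about 7.39"
   ]
  },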
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "231f19a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def normalize_inverse_ppl(ppl_scores, epsilon=1e-6):\n",
    "    \"\"\"\n",
    "    Safely normalizes inverse perplexity (1/PPL_i) to [0, 1].\n",
    "    \n",
    "    Handles edge cases where PPL scores are constant, contain inf/nan, or are invalid.\n",
    "    \"\"\"\n",
    "    ppl_scores = np.array(ppl_scores, dtype=np.float64)\n",
    "\n",
    "    # Replace infinities or NaNs with large finite numbers for stability\n",
    "    ppl_scores = np.nan_to_num(ppl_scores, nan=np.inf, posinf=np.inf, neginf=np.inf)\n",
    "\n",
    "    # Compute inverse PPL (fluency measure)\n",
    "    inv_ppl = 1.0 / (ppl_scores + epsilon)\n",
    "\n",
    "    # Remove any remaining NaNs/Infs from inverse scores\n",
    "    inv_ppl = np.nan_to_num(inv_ppl, nan=0.0, posinf=0.0, neginf=0.0)\n",
    "\n",
    "    inv_min = np.min(inv_ppl)\n",
    "    inv_max = np.max(inv_ppl)\n",
    "\n",
    "    # Handle zero-range case: all scores are the same\n",
    "    if np.isclose(inv_max, inv_min) or np.isnan(inv_max - inv_min):\n",
    "        return np.zeros_like(inv_ppl)\n",
    "\n",
    "    # Normal min–max scaling\n",
    "    inv_ppl_norm = (inv_ppl - inv_min) / (inv_max - inv_min)\n",
    "    inv_ppl_norm = np.clip(inv_ppl_norm, 0.0, 1.0)\n",
    "\n",
    "    return inv_ppl_norm\n",
    "\n"
   ]
  },
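  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b7d3c10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged sanity check on toy values (not real perplexities): lower PPL should map\n",
    "# to a higher normalized score, inf collapses to 0, and a constant input returns\n",
    "# all zeros instead of dividing by a zero range.\n",
    "print(normalize_inverse_ppl([5.0, 50.0, 500.0, float(\"inf\")]))  # descending, last -> 0\n",
    "print(normalize_inverse_ppl([10.0, 10.0, 10.0]))  # constant -> all zeros"
   ]
  },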
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad791177",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function for Step 3: Semantic Similarity (Sim)\n",
    "\n",
    "def calculate_semantic_similarity(s_i, t_i, model, tokenizer, device):\n",
    "    \"\"\"\n",
    "    Computes Cosine Similarity between source and target sentence embeddings \n",
    "    and normalizes the result to the range [0, 1].\n",
    "    \"\"\"\n",
    "    try:\n",
    "        def get_embedding(sentence, lang_code):\n",
    "            tokenizer.src_lang = lang_code\n",
    "            inputs = tokenizer(\n",
    "                sentence,\n",
    "                return_tensors=\"pt\",\n",
    "                max_length=512,\n",
    "                truncation=True,\n",
    "                padding=True\n",
    "            ).to(device)\n",
    "            \n",
    "            with torch.no_grad():\n",
    "                encoder_output = model.model.encoder(**inputs).last_hidden_state\n",
    "            \n",
    "            mean_embedding = encoder_output[:, 1:-1, :].mean(dim=1).squeeze() \n",
    "            \n",
    "            return mean_embedding.cpu().detach().numpy().reshape(1, -1)\n",
    "\n",
    "        emb_s = get_embedding(s_i, SRC_LANG_CODE) \n",
    "        emb_t = get_embedding(t_i, TGT_LANG_CODE)\n",
    "\n",
    "        sim_raw = cosine_similarity(emb_s, emb_t)[0][0]\n",
    "        \n",
    "        sim_normalized = (sim_raw + 1) / 2\n",
    "        sim_normalized = max(0.0, min(1.0, sim_normalized))\n",
    "        \n",
    "        return sim_normalized\n",
    "        \n",
    "    except Exception as e:\n",
    "        # print(f\"Error calculating Sim for: '{s_i}' and '{t_i}'. Error: {e}\")\n",
    "        return 0.0"
   ]
  },
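  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0ffee42",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy illustration with made-up vectors (no model needed): cosine similarity lies\n",
    "# in [-1, 1], and the (sim + 1) / 2 mapping above rescales it to [0, 1].\n",
    "a = np.array([[1.0, 0.0]])\n",
    "b = np.array([[-1.0, 0.0]])\n",
    "raw = cosine_similarity(a, b)[0][0]  # -1.0 for opposite vectors\n",
    "print(raw, (raw + 1) / 2)  # -> -1.0 0.0"
   ]
  },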
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87eebd8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function for Step 4: Lexical Match (Lex)  # header describing the block\n",
    "# blank line preserved for readability\n",
    "# Define a function that computes a lexical match score based on a bilingual dictionary\n",
    "def calculate_lexical_match(s_i, t_i, word_dictionary):\n",
    "    # Docstring start: describe purpose and formula for lex score\n",
    "    \"\"\"\n",
    "    Computes a dictionary-based lexical match score prioritizing phrase matches.\n",
    "    Score = (Count of source words covered by successfully translated phrases) / (Total words in source sentence)\n",
    "    \"\"\"  # docstring end\n",
    "    # Helper: normalize text to ease phrase matching (lowercase, token boundaries)\n",
    "    def normalize_text(text):\n",
    "        # Simple tokenization: lowercase, remove non-word characters, and join back for easy phrase matching\n",
    "        return \" \" + \" \".join(re.findall(r'\\b\\w+\\b', text.lower())) + \" \"  # pad with spaces for boundary-safe matching\n",
    "    # Normalize the source sentence for phrase lookups\n",
    "    s_normalized = normalize_text(s_i)\n",
    "    # Token set of the target sentence for quick membership tests\n",
    "    t_tokens = set(re.findall(r'\\b\\w+\\b', t_i.lower()))\n",
    "    # blank line preserved for readability\n",
    "    # Sort dictionary keys by length (descending) to prioritize phrase matches over single words\n",
    "    tagin_phrases = sorted(word_dictionary.keys(), key=len, reverse=True)\n",
    "    # blank line preserved for readability\n",
    "    # Extract source word tokens and compute total count\n",
    "    source_words = re.findall(r'\\b\\w+\\b', s_i.lower())\n",
    "    total_source_words = len(source_words)\n",
    "    # Initialize covered word counter\n",
    "    covered_word_count = 0\n",
    "    # If the source sentence is empty, return 0.0 immediately\n",
    "    if total_source_words == 0:\n",
    "        return 0.0\n",
    "    # Track indices of covered words if needed (not used further but kept for clarity)\n",
    "    covered_indices = set()\n",
    "    # Iterate over dictionary phrases (longest-first) to find matches in the source\n",
    "    for phrase in tagin_phrases:\n",
    "        # Skip empty dictionary entries\n",
    "        if not phrase:\n",
    "            continue\n",
    "        # Normalize the phrase for safe matching\n",
    "        norm_phrase = normalize_text(phrase)\n",
    "        # If the normalized phrase exists in the normalized source text, proceed\n",
    "        if norm_phrase in s_normalized:\n",
    "            # Get expected translation from the dictionary (lowercased)\n",
    "            expected_translation = word_dictionary[phrase].lower()\n",
    "            # Tokenize the expected translation into words\n",
    "            translation_words = re.findall(r'\\b\\w+\\b', expected_translation)\n",
    "            # Check whether all translated words appear in the target sentence tokens\n",
    "            is_translation_present = all(word in t_tokens for word in translation_words)\n",
    "            # If the translation words are present in the target, count the phrase as covered\n",
    "            if is_translation_present:\n",
    "                # Search for possibly multiple occurrences of the phrase in the source\n",
    "                start = 0\n",
    "                while True:\n",
    "                    # Find next occurrence starting from 'start' index\n",
    "                    start_index = s_normalized.find(norm_phrase, start)\n",
    "                    # If no more occurrences, break the loop\n",
    "                    if start_index == -1:\n",
    "                        break\n",
    "                    # Count how many words are in the matched phrase\n",
    "                    phrase_word_count = len(re.findall(r'\\b\\w+\\b', phrase))\n",
    "                    # Add the phrase's word count to the covered total\n",
    "                    covered_word_count += phrase_word_count\n",
    "                    # Advance the search start position past the current match\n",
    "                    start = start_index + len(norm_phrase)\n",
    "                # end while loop for occurrences\n",
    "    # After checking all phrases, compute lex score as covered words / total source words (capped at 1.0)\n",
    "    lex_score = min(1.0, covered_word_count / total_source_words)\n",
    "    # Return lexical match score between 0 and 1\n",
    "    return lex_score"
   ]
  },
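  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e55a9b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy illustration with an invented two-entry dictionary (not real Tagin data):\n",
    "# \"good morning\" is a phrase match whose translation appears in the target,\n",
    "# covering 2 of the 3 source words, so the score is 2/3.\n",
    "demo_dict = {\"good morning\": \"aro alv\", \"friend\": \"nyame\"}\n",
    "print(calculate_lexical_match(\"good morning everyone\", \"aro alv\", demo_dict))"
   ]
  },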
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46f310e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Main Algorithm Implementation\n",
    "def knowledge_based_filtering(raw_corpus, word_dictionary, alpha, beta, gamma, percentile_threshold):\n",
    "    # 1. Load Model and Tokenizer (Step 1 of the algorithm's loop)\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    print(f\"Loading mBART-tgj-base model to {device}...\")\n",
    "    model = MBartForConditionalGeneration.from_pretrained(MODEL_NAME).to(device)\n",
    "    tokenizer = MBart50TokenizerFast.from_pretrained(MODEL_NAME)\n",
    "    \n",
    "    # Ensure source/target language codes are in the tokenizer vocabulary\n",
    "    if SRC_LANG_CODE not in tokenizer.vocab or TGT_LANG_CODE not in tokenizer.vocab:\n",
    "        print(f\"Warning: Language codes {SRC_LANG_CODE} or {TGT_LANG_CODE} not found in base mBART-tgj-base vocab.\")\n",
    "        print(\"Using placeholder language codes. Results may not be accurate.\")\n",
    "\n",
    "    results = []\n",
    "\n",
    "    # 2. Iterate through the corpus (Lines 1-6)\n",
    "    print(\"Processing corpus to calculate scores...\")\n",
    "    for index, row in tqdm(raw_corpus.iterrows(), total=len(raw_corpus), desc=\"Calculating KS\"):\n",
    "        s_i = row['English']\n",
    "        t_i = row['Tagin']\n",
    "\n",
    "        # Line 2: Compute PPL_i (Lower is better)\n",
    "        pp= calculate_perplexity(s_i, model, tokenizer, device)\n",
    "        PPL_i=  normalize_inverse_ppl(pp, epsilon=1e-6)\n",
    "        # PPL_i = normalize_inverse_ppl(row[\"Perplexity\"])\n",
    "        \n",
    "        # Line 3: Compute Sim_i (Higher is better)\n",
    "        Sim_i = calculate_semantic_similarity(s_i, t_i, model, tokenizer, device)\n",
    "        \n",
    "        # Line 4: Check Lex_i (Higher is better)\n",
    "        Lex_i = calculate_lexical_match(s_i, t_i, word_dictionary)\n",
    "        \n",
    "        # Line 5: Derive Knowledge Score (KS_i)\n",
    "        # Note: We use 1/PPL_i because PPL_i is an inverse quality metric (lower PPL is higher quality)\n",
    "        # while Sim and Lex are direct quality metrics (higher is better).\n",
    "        # We add a small epsilon to avoid division by zero, though a PPL of 0 is practically impossible.\n",
    "        # PPL_i_inv = 1.0 / (PPL_i + 1e-6)\n",
    "        # -----IMPORTANT------\n",
    "        \n",
    "        KS_i = alpha * PPL_i + beta * Sim_i + gamma * Lex_i\n",
    "        \n",
    "        results.append({\n",
    "            'src_lang': s_i,\n",
    "            'tgt_lang': t_i,\n",
    "            'PPL_i': PPL_i,\n",
    "            'Sim_i': Sim_i,\n",
    "            'Lex_i': Lex_i,\n",
    "            'PPL_i': PPL_i,\n",
    "            'KS_i': KS_i\n",
    "        })\n",
    "\n",
    "    # Convert results to DataFrame for filtering\n",
    "    scored_corpus = pd.DataFrame(results)\n",
    "\n",
    "    # 3. Determine Threshold and Filter (Lines 7-9)\n",
    "    # Line 7: Find the 80th percentile of Knowledge Scores\n",
    "    tau_K = np.percentile(scored_corpus['KS_i'], percentile_threshold)\n",
    "    print(f\"\\n50th Percentile Knowledge Score (τ_K): {tau_K:.4f}\")\n",
    "    \n",
    "    # Line 8: Filter the corpus\n",
    "    D_filtered = scored_corpus[scored_corpus['KS_i'] >= tau_K].copy()\n",
    "    \n",
    "    # Final cleanup of columns and return\n",
    "    D_filtered = D_filtered[['src_lang', 'tgt_lang', 'KS_i']]\n",
    "    print(f\"Raw corpus size: {len(raw_corpus)}\")\n",
    "    print(f\"Filtered corpus size (KS_i >= τ_K): {len(D_filtered)}\")\n",
    "    \n",
    "    return D_filtered"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b2ce69d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Execution ---  # script entry and high-level execution steps\n",
    "# blank line preserved for readability\n",
    "# Guard to ensure code only runs when executed as a script, not on import\n",
    "if __name__ == '__main__':\n",
    "    # 1. Load data  # load and preprocess corpus and dictionary files\n",
    "    raw_corpus, word_dictionary = load_data(CORPUS_FILE, DICT_FILE)\n",
    "    # blank line preserved for readability\n",
    "    # 2. Run the filtering algorithm  # compute KS_i and filter by percentile\n",
    "    filtered_corpus = knowledge_based_filtering(\n",
    "        raw_corpus,  # pass the preprocessed corpus DataFrame\n",
    "        word_dictionary,  # pass the dictionary mapping\n",
    "        ALPHA, BETA, GAMMA,  # weighting hyperparameters for KS_i\n",
    "        PERCENTILE_THRESHOLD  # percentile cutoff for filtering\n",
    "    )\n",
    "    # blank line preserved for readability\n",
    "    # Save the filtered corpus to CSV for downstream use\n",
    "    filtered_corpus.to_csv(\"tgj_corpus_filtered_70th.csv\", index=False)\n",
    "    # blank line preserved for readability\n",
    "    # Notify user of completion and where results were saved\n",
    "    print(\"\\nFiltering complete. Results saved to tgj_corpus_filtered_70th.csv\")\n",
    "    # Show a short preview of the filtered corpus\n",
    "    print(\"\\nFiltered Corpus Head:\")\n",
    "    print(filtered_corpus)  # print DataFrame to stdout"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ptorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}