File size: 5,152 Bytes
ad0be11 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "2230ec1b",
"metadata": {},
"outputs": [],
"source": [
"from transformers import MBartForConditionalGeneration, MBart50TokenizerFast # MBART model and tokenizer classes\n",
"from tqdm.auto import tqdm # auto-selects the notebook widget bar (falls back to console) instead of flooding cell output\n",
"import torch # PyTorch for tensors and device handling\n",
"import csv # CSV writer for output"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7552da07",
"metadata": {},
"outputs": [],
"source": [
"# Load the fine-tuned MBART model and its tokenizer from a local directory\n",
"model_path = \"./combined_training/en_tgj_combined_model\" # path to fine-tuned model directory (change if needed)\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\") # use the GPU when one is present\n",
"\n",
"tokenizer = MBart50TokenizerFast.from_pretrained(model_path) # tokenizer saved alongside the model\n",
"model = MBartForConditionalGeneration.from_pretrained(model_path).to(device) # weights + config, moved to device\n",
"model.eval() # inference mode: disables dropout layers"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "750545cc",
"metadata": {},
"outputs": [],
"source": [
"# Parameters for tokenization and generation\n",
"src_lang_token = \"en_XX\" # MBART source language token to prepend\n",
"tgt_lang_token = \"<tgn_IN>\" # target language token / forced BOS for generation\n",
"batch_size = 16 # number of sentences per batch\n",
"max_length = 128 # maximum token length for tokenization and generation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "03415b52",
"metadata": {},
"outputs": [],
"source": [
"# Read English sentences from a text file (one sentence per line)\n",
"english_sentences = [] # only lines with actual content\n",
"with open(\"./sentences01.txt\", \"r\", encoding=\"utf-8\") as f: # input file path\n",
"    for raw_line in f:\n",
"        cleaned = raw_line.strip() # drop surrounding whitespace/newline\n",
"        if cleaned: # skip blank lines\n",
"            english_sentences.append(cleaned)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0927f85",
"metadata": {},
"outputs": [],
"source": [
"# Configure the source language instead of prepending the code as raw text:\n",
"# MBart50TokenizerFast inserts the en_XX language id itself based on `src_lang`,\n",
"# so a literal \"en_XX \" prefix would be subword-tokenized into the sentence on top\n",
"# of the id the tokenizer already adds, corrupting the model input.\n",
"# NOTE(review): if the fine-tuning data was built with the same manual text prefix,\n",
"# keep inference consistent with training instead of applying this fix.\n",
"tokenizer.src_lang = src_lang_token # documented mBART-50 way to mark the source language\n",
"prefixed_sentences = list(english_sentences) # sentences to translate (name kept for downstream cells)\n",
"\n",
"# Prepare a list to collect generated translations\n",
"translated_sentences = [] # will hold output strings"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "86b13113",
"metadata": {},
"outputs": [],
"source": [
"# Resolve the forced-BOS id once (it is loop-invariant) and fail fast if unknown:\n",
"# convert_tokens_to_ids silently returns the <unk> id for out-of-vocabulary tokens,\n",
"# which would make generate() decode garbage with no error raised.\n",
"tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang_token)\n",
"if tgt_lang_id == tokenizer.unk_token_id:\n",
"    raise ValueError(f\"Target token {tgt_lang_token!r} is not in the tokenizer vocabulary\")\n",
"\n",
"# Iterate through sentences in batches and generate translations\n",
"for i in tqdm(range(0, len(prefixed_sentences), batch_size), desc=\"Batch Translating\"): # batching loop\n",
"    batch = prefixed_sentences[i:i+batch_size] # take a slice for this batch\n",
"\n",
"    # Tokenize the batch and move tensors to the model device\n",
"    inputs = tokenizer(batch, return_tensors=\"pt\", padding=True, truncation=True, max_length=max_length).to(device)\n",
"\n",
"    with torch.no_grad(): # disable gradients for inference to save memory\n",
"        generated_tokens = model.generate(\n",
"            **inputs, # pass input_ids, attention_mask, etc.\n",
"            forced_bos_token_id=tgt_lang_id, # generation must start with the target language token\n",
"            max_length=max_length, # cap the generated length\n",
"            num_beams=5, # beam search for higher-quality decoding\n",
"            early_stopping=True, # stop once all beams finish\n",
"        )\n",
"\n",
"    # Decode token IDs to text and collect results\n",
"    outputs = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) # convert ids to strings\n",
"    translated_sentences.extend(outputs) # append batch outputs to final list"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6d9a12d2",
"metadata": {},
"outputs": [],
"source": [
"# Write aligned original and translated sentences to a CSV file\n",
"# Guard: zip() would silently truncate the output if the two lists diverged in length,\n",
"# producing a misaligned file with no warning.\n",
"assert len(english_sentences) == len(translated_sentences), \"source/translation count mismatch\"\n",
"\n",
"with open(\"./output_entgj_combined01.csv\", \"w\", encoding=\"utf-8\", newline=\"\") as f: # output file path\n",
"    writer = csv.writer(f) # CSV writer object\n",
"    writer.writerow([\"original\", \"translated\"]) # header row\n",
"    writer.writerows(zip(english_sentences, translated_sentences)) # one aligned pair per row"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "ptorch",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
|