{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "intro-overview",
   "metadata": {},
   "source": [
    "# Batch translation with a fine-tuned mBART model\n",
    "\n",
    "Reads English sentences from a text file (one per line), translates them in batches with a locally stored fine-tuned mBART checkpoint, and writes the source/translation pairs to a CSV file.\n",
    "\n",
    "Run the cells top to bottom on a fresh kernel; the language tokens, batch size and file paths are set in the parameters cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2230ec1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv  # CSV writer for the output file\n",
    "import warnings  # non-fatal configuration warnings\n",
    "\n",
    "import torch  # tensors and device handling\n",
    "from tqdm import tqdm  # progress bar for the batching loop\n",
    "from transformers import MBart50TokenizerFast, MBartForConditionalGeneration  # mBART-50 tokenizer and model classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7552da07",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the tokenizer and model from a local fine-tuned checkpoint directory\n",
    "model_path = \"./combined_training/en_tgj_combined_model\"  # change if the checkpoint lives elsewhere\n",
    "tokenizer = MBart50TokenizerFast.from_pretrained(model_path)\n",
    "model = MBartForConditionalGeneration.from_pretrained(model_path)\n",
    "model.eval()  # evaluation mode (disables dropout)\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")  # prefer GPU when available\n",
    "# Trailing ';' keeps the full module repr out of the cell output\n",
    "model.to(device);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "750545cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parameters for tokenization, generation and file I/O\n",
    "src_lang_token = \"en_XX\"  # source language code prepended to every input sentence\n",
    "tgt_lang_token = \"\"  # target language code used as the forced BOS token -- TODO: confirm, it is empty here\n",
    "batch_size = 16  # number of sentences per batch\n",
    "max_length = 128  # maximum token length for tokenization and generation\n",
    "input_path = \"./sentences01.txt\"  # one English sentence per line\n",
    "output_path = \"./output_entgj_combined01.csv\"  # CSV of original/translated pairs\n",
    "\n",
    "# Resolve the forced BOS id once here instead of on every batch, and make a\n",
    "# missing/unknown target token visible instead of silently generating with it.\n",
    "forced_bos_token_id = tokenizer.convert_tokens_to_ids(tgt_lang_token)\n",
    "if not tgt_lang_token or forced_bos_token_id in (None, tokenizer.unk_token_id):\n",
    "    warnings.warn(\n",
    "        f\"tgt_lang_token {tgt_lang_token!r} is empty or not in the tokenizer vocabulary; \"\n",
    "        \"the forced BOS token will not be a valid target-language code.\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03415b52",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read English sentences (one per line), stripping whitespace and dropping blank lines\n",
    "with open(input_path, \"r\", encoding=\"utf-8\") as f:\n",
    "    english_sentences = [stripped for line in f if (stripped := line.strip())]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0927f85",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepend the source language code to each sentence.\n",
    "# NOTE(review): MBart50TokenizerFast can also add the source code itself via `tokenizer.src_lang`;\n",
    "# confirm this textual prefix matches how the model was fine-tuned before changing it.\n",
    "prefixed_sentences = [f\"{src_lang_token} {s}\" for s in english_sentences]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86b13113",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Translate in batches. The result list is (re)created here so that re-running this\n",
    "# cell never appends to translations left over from an earlier run.\n",
    "translated_sentences = []\n",
    "for start in tqdm(range(0, len(prefixed_sentences), batch_size), desc=\"Batch Translating\"):\n",
    "    batch = prefixed_sentences[start:start + batch_size]\n",
    "\n",
    "    # Tokenize the batch (padded, truncated to max_length) and move tensors to the model device\n",
    "    inputs = tokenizer(batch, return_tensors=\"pt\", padding=True, truncation=True, max_length=max_length).to(device)\n",
    "\n",
    "    with torch.no_grad():  # no gradients needed for inference\n",
    "        generated_tokens = model.generate(\n",
    "            **inputs,  # input_ids, attention_mask, etc.\n",
    "            forced_bos_token_id=forced_bos_token_id,  # resolved once in the parameters cell\n",
    "            max_length=max_length,  # cap the generated length\n",
    "            num_beams=5,  # beam search decoding\n",
    "            early_stopping=True,  # stop once beams finish\n",
    "        )\n",
    "\n",
    "    # Decode token ids to text (special tokens dropped) and collect the batch results\n",
    "    translated_sentences.extend(tokenizer.batch_decode(generated_tokens, skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d9a12d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write aligned original/translated pairs to a CSV file.\n",
    "# zip() would silently truncate on a length mismatch, so check alignment first.\n",
    "assert len(translated_sentences) == len(english_sentences), (\n",
    "    \"translations are out of sync with the input sentences; re-run the translation cell\"\n",
    ")\n",
    "with open(output_path, \"w\", encoding=\"utf-8\", newline=\"\") as f:\n",
    "    writer = csv.writer(f)\n",
    "    writer.writerow([\"original\", \"translated\"])  # header row\n",
    "    writer.writerows(zip(english_sentences, translated_sentences))  # one row per sentence pair"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ptorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}