# src/train.py
import argparse

from datasets import Dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)


def train_model():
    """
    Fine-tunes a pre-trained NLLB model on a parallel dataset.
    """
    parser = argparse.ArgumentParser(description="Fine-tune a translation model.")
    parser.add_argument("--model_checkpoint", type=str, default="facebook/nllb-200-distilled-600M")
    parser.add_argument("--source_lang", type=str, required=True,
                        help="Source language code (e.g., 'ne')")
    parser.add_argument("--target_lang", type=str, default="en")
    parser.add_argument("--source_lang_tokenizer", type=str, required=True,
                        help="NLLB language code for the tokenizer (e.g., 'npi_Deva')")
    parser.add_argument("--train_file_source", type=str, required=True,
                        help="Path to the source language training file")
    parser.add_argument("--train_file_target", type=str, required=True,
                        help="Path to the target language training file")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Directory to save the fine-tuned model")
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=8)
    args = parser.parse_args()

    # --- 1. Configuration ---
    MODEL_CHECKPOINT = args.model_checkpoint
    SOURCE_LANG = args.source_lang
    TARGET_LANG = args.target_lang
    MODEL_OUTPUT_DIR = args.output_dir

    # --- 2. Load Tokenizer and Model ---
    print("Loading tokenizer and model...")
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_CHECKPOINT,
        src_lang=args.source_lang_tokenizer,
        tgt_lang="eng_Latn",
    )
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_CHECKPOINT)

    # --- 3. Load and Preprocess Data (Memory-Efficiently) ---
    print("Loading and preprocessing data...")

    def generate_examples():
        # Stream the two parallel files line by line so the full corpus
        # never has to be held in memory at once.
        with open(args.train_file_source, "r", encoding="utf-8") as f_src, \
             open(args.train_file_target, "r", encoding="utf-8") as f_tgt:
            for src_line, tgt_line in zip(f_src, f_tgt):
                yield {"translation": {SOURCE_LANG: src_line.strip(), TARGET_LANG: tgt_line.strip()}}

    dataset = Dataset.from_generator(generate_examples)
    split_datasets = dataset.train_test_split(train_size=0.95, seed=42)
    split_datasets["validation"] = split_datasets.pop("test")

    def preprocess_function(examples):
        inputs = [ex[SOURCE_LANG] for ex in examples["translation"]]
        targets = [ex[TARGET_LANG] for ex in examples["translation"]]
        model_inputs = tokenizer(inputs, text_target=targets, max_length=128, truncation=True)
        return model_inputs

    tokenized_datasets = split_datasets.map(
        preprocess_function,
        batched=True,
        remove_columns=split_datasets["train"].column_names,
    )

    # --- 4. Set Up Training Arguments ---
    print("Setting up training arguments...")
    training_args = Seq2SeqTrainingArguments(
        output_dir=MODEL_OUTPUT_DIR,
        eval_strategy="epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        weight_decay=0.01,
        save_total_limit=3,
        num_train_epochs=args.epochs,
        predict_with_generate=True,
        fp16=False,  # Set to True if you have a compatible GPU
    )

    # --- 5. Create the Trainer ---
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["validation"],
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    # --- 6. Start Training ---
    print("\n--- Starting model fine-tuning ---")
    trainer.train()
    print("--- Training complete ---")
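    # Optional: a minimal sketch of a final validation pass before saving.
    # Without a compute_metrics function, Trainer.evaluate() reports only
    # the validation loss (no BLEU); uncomment to log it.
    # eval_metrics = trainer.evaluate()
    # print(eval_metrics)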
    # --- 7. Save the Final Model ---
    print(f"Saving final model to {MODEL_OUTPUT_DIR}")
    trainer.save_model()
    print("Model saved successfully!")


if __name__ == "__main__":
    train_model()
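
# Example invocation, assuming hypothetical line-aligned parallel files
# data/train.ne and data/train.en (one sentence per line); all other
# arguments keep their defaults:
#
#   python src/train.py \
#       --source_lang ne \
#       --source_lang_tokenizer npi_Deva \
#       --train_file_source data/train.ne \
#       --train_file_target data/train.en \
#       --output_dir models/nllb-ne-en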