import pandas as pd
from datasets import Dataset, DatasetDict
from transformers import (
    MBartForConditionalGeneration,
    MBart50TokenizerFast,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
)

# ======================
# CONFIG
# ======================
MODEL_NAME = "facebook/mbart-large-50-many-to-many-mmt"
OUTPUT_DIR = "models/mbart-transliteration"
MAX_INPUT_LENGTH = 128
MAX_TARGET_LENGTH = 128
BATCH_SIZE = 4        # CPU-safe
EPOCHS = 1            # increase for real runs
LEARNING_RATE = 5e-5
SRC_LANG = "en_XX"    # source: English (Latin script)
TGT_LANG = "hi_IN"    # target: Hindi (Devanagari script)

# ======================
# LOAD DATA
# ======================
def load_data():
    """Read the CSV splits and wrap them in a DatasetDict."""
    data_files = {
        "train": "data/train.csv",
        "validation": "data/val.csv",
        "test": "data/test.csv",
    }
    dataset_dict = {}
    for split, path in data_files.items():
        df = pd.read_csv(path)
        # Each split must provide the two required columns.
        assert "source" in df.columns, f"{path} is missing a 'source' column"
        assert "target" in df.columns, f"{path} is missing a 'target' column"
        dataset_dict[split] = Dataset.from_pandas(df)
    return DatasetDict(dataset_dict)

# ======================
# PREPROCESS
# ======================
def preprocess_function(examples):
    # mBART needs the language codes set on the tokenizer before every
    # tokenization pass so the correct language tokens are added.
    tokenizer.src_lang = SRC_LANG
    tokenizer.tgt_lang = TGT_LANG

    inputs = examples["source"]
    targets = examples["target"]

    # No fixed-length padding here: DataCollatorForSeq2Seq pads each batch
    # dynamically and fills padded label positions with -100 so padding is
    # ignored by the loss.
    model_inputs = tokenizer(
        inputs,
        max_length=MAX_INPUT_LENGTH,
        truncation=True,
    )
    labels = tokenizer(
        text_target=targets,
        max_length=MAX_TARGET_LENGTH,
        truncation=True,
    )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# ======================
# TRAIN
# ======================
def main():
    print("Loading tokenizer and model...")
    # preprocess_function reads the tokenizer from module scope.
    global tokenizer
    tokenizer = MBart50TokenizerFast.from_pretrained(MODEL_NAME)
    model = MBartForConditionalGeneration.from_pretrained(
        MODEL_NAME, low_cpu_mem_usage=True
    )

    print("Loading datasets...")
    raw_datasets = load_data()

    print("Tokenizing datasets...")
    tokenized_datasets = raw_datasets.map(
        preprocess_function,
        batched=True,
        remove_columns=raw_datasets["train"].column_names,
    )

    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=model,
    )

    training_args = Seq2SeqTrainingArguments(
        output_dir=OUTPUT_DIR,
        eval_strategy="epoch",
        learning_rate=LEARNING_RATE,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        num_train_epochs=EPOCHS,
        weight_decay=0.01,
        save_total_limit=1,
        save_strategy="epoch",
        predict_with_generate=True,
        logging_steps=10,
        load_best_model_at_end=True,
        report_to="none",
        fp16=False,  # keep off for CPU training
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["validation"],
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    print("Training started...")
    trainer.train()

    print("Saving model...")
    trainer.save_model(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)
    print(f"Training complete. Model saved to `{OUTPUT_DIR}`")

# ======================
if __name__ == "__main__":
    main()
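
# ======================
# INFERENCE SKETCH (illustrative)
# ======================
# A minimal, hedged example of how the checkpoint saved to OUTPUT_DIR could
# be used for generation; it is not called by the training script above.
# The name `transliterate` is an assumption for illustration. Move this
# above the __main__ guard if you want to import it from another module.
def transliterate(text: str, model_dir: str = OUTPUT_DIR) -> str:
    tok = MBart50TokenizerFast.from_pretrained(model_dir)
    mdl = MBartForConditionalGeneration.from_pretrained(model_dir)
    tok.src_lang = SRC_LANG
    encoded = tok(text, return_tensors="pt")
    # mBART-50 expects the target language code to be forced as the first
    # generated token.
    generated = mdl.generate(
        **encoded,
        forced_bos_token_id=tok.lang_code_to_id[TGT_LANG],
        max_length=MAX_TARGET_LENGTH,
        num_beams=4,
    )
    return tok.batch_decode(generated, skip_special_tokens=True)[0]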