import pandas as pd
from datasets import Dataset, DatasetDict
from transformers import (
MBartForConditionalGeneration,
MBart50TokenizerFast,
Seq2SeqTrainingArguments,
Seq2SeqTrainer,
DataCollatorForSeq2Seq,
)
# ======================
# CONFIG
# ======================
MODEL_NAME = "facebook/mbart-large-50-many-to-many-mmt"
OUTPUT_DIR = "models/mbart-transliteration"
MAX_INPUT_LENGTH = 128
MAX_TARGET_LENGTH = 128
BATCH_SIZE = 4 # CPU-safe
EPOCHS = 1 # Increase later
LEARNING_RATE = 5e-5
SRC_LANG = "en_XX"
TGT_LANG = "hi_IN" # Hindi
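# mBART-50 language codes pair language and region/script (en_XX, hi_IN)
# and must match codes in the tokenizer's supported-language list.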
# ======================
# LOAD DATA
# ======================
def load_data():
data_files = {
"train": "data/train.csv",
"validation": "data/val.csv",
"test": "data/test.csv",
}
dataset_dict = {}
for split, path in data_files.items():
df = pd.read_csv(path)
        # Required columns for every split
        assert "source" in df.columns, f"{path} is missing a 'source' column"
        assert "target" in df.columns, f"{path} is missing a 'target' column"
dataset_dict[split] = Dataset.from_pandas(df)
return DatasetDict(dataset_dict)
# ======================
# PREPROCESS
# ======================
def preprocess_function(examples):
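    # With batched=True, `examples` is a dict mapping column names to lists.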
    # Must be set on every call (critical for mBART): tokenization depends
    # on the tokenizer's current source/target language state.
tokenizer.src_lang = SRC_LANG
tokenizer.tgt_lang = TGT_LANG
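    # mBART-50 encodes both sides as "[lang_code] tokens </s>", so these
    # settings control which language code is prefixed to each sequence.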
inputs = examples["source"]
targets = examples["target"]
    # No padding here: padding="max_length" would leave pad_token_id in the
    # labels and compute loss on padding. DataCollatorForSeq2Seq pads each
    # batch dynamically instead.
    model_inputs = tokenizer(
        inputs,
        max_length=MAX_INPUT_LENGTH,
        truncation=True,
    )
    labels = tokenizer(
        text_target=targets,
        max_length=MAX_TARGET_LENGTH,
        truncation=True,
    )
model_inputs["labels"] = labels["input_ids"]
return model_inputs
# ======================
# TRAIN
# ======================
def main():
print("Loading tokenizer and model...")
    global tokenizer  # preprocess_function reads the tokenizer from module scope
tokenizer = MBart50TokenizerFast.from_pretrained(MODEL_NAME)
model = MBartForConditionalGeneration.from_pretrained(MODEL_NAME, low_cpu_mem_usage=True)
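    # mBART-50 generation needs the target language id forced as the first
    # generated token; without this, predict_with_generate can decode into
    # the wrong language during evaluation.
    model.generation_config.forced_bos_token_id = tokenizer.convert_tokens_to_ids(TGT_LANG)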
print("Loading datasets...")
raw_datasets = load_data()
print("Tokenizing datasets...")
tokenized_datasets = raw_datasets.map(
preprocess_function,
batched=True,
remove_columns=raw_datasets["train"].column_names,
)
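    # DataCollatorForSeq2Seq pads each batch dynamically and pads labels
    # with -100 (its default label_pad_token_id), which the loss ignores.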
data_collator = DataCollatorForSeq2Seq(
tokenizer=tokenizer,
model=model,
)
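    # load_best_model_at_end requires save_strategy to match eval_strategy
    # (both "epoch" here); the Trainer rejects mismatched strategies.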
training_args = Seq2SeqTrainingArguments(
output_dir=OUTPUT_DIR,
eval_strategy="epoch",
learning_rate=LEARNING_RATE,
per_device_train_batch_size=BATCH_SIZE,
per_device_eval_batch_size=BATCH_SIZE,
num_train_epochs=EPOCHS,
weight_decay=0.01,
save_total_limit=1,
save_strategy="epoch",
predict_with_generate=True,
logging_steps=10,
load_best_model_at_end=True,
report_to="none",
fp16=False, # CPU safe
)
trainer = Seq2SeqTrainer(
model=model,
args=training_args,
train_dataset=tokenized_datasets["train"],
eval_dataset=tokenized_datasets["validation"],
tokenizer=tokenizer,
data_collator=data_collator,
)
print("Training started...")
trainer.train()
print("Saving model...")
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print(f"Training complete. Model saved to `{OUTPUT_DIR}`")
# ======================
if __name__ == "__main__":
main()