# Fine-tune mBART-50 (many-to-many MMT) for English -> Hindi transliteration
# using Hugging Face Seq2SeqTrainer. CPU-safe defaults; data read from CSVs.
| import os | |
| import pandas as pd | |
| import torch | |
| from datasets import Dataset, DatasetDict | |
| from transformers import ( | |
| MBartForConditionalGeneration, | |
| MBart50TokenizerFast, | |
| Seq2SeqTrainingArguments, | |
| Seq2SeqTrainer, | |
| DataCollatorForSeq2Seq, | |
| ) | |
# ======================
# CONFIG
# ======================
MODEL_NAME = "facebook/mbart-large-50-many-to-many-mmt"  # multilingual seq2seq checkpoint
OUTPUT_DIR = "models/mbart-transliteration"  # where trainer checkpoints + final model land
MAX_INPUT_LENGTH = 128  # token budget for source sequences (longer inputs truncated)
MAX_TARGET_LENGTH = 128  # token budget for target sequences
BATCH_SIZE = 4  # CPU-safe
EPOCHS = 1  # Increase later
LEARNING_RATE = 5e-5
SRC_LANG = "en_XX"  # mBART-50 language code for English
TGT_LANG = "hi_IN"  # Hindi
# ======================
# LOAD DATA
# ======================
def load_data(data_dir="data"):
    """Load the train/validation/test CSV splits into a DatasetDict.

    Each CSV must contain "source" and "target" columns.

    Args:
        data_dir: Directory holding train.csv, val.csv and test.csv.
            Defaults to "data", matching the original hard-coded layout.

    Returns:
        datasets.DatasetDict with "train", "validation" and "test" splits.

    Raises:
        ValueError: if a split file lacks a required column.
    """
    data_files = {
        "train": os.path.join(data_dir, "train.csv"),
        "validation": os.path.join(data_dir, "val.csv"),
        "test": os.path.join(data_dir, "test.csv"),
    }
    dataset_dict = {}
    for split, path in data_files.items():
        df = pd.read_csv(path)
        # Validate explicitly: `assert` is stripped under `python -O`,
        # which would let malformed CSVs slip through silently.
        missing = {"source", "target"} - set(df.columns)
        if missing:
            raise ValueError(
                f"{path} is missing required column(s): {sorted(missing)}"
            )
        dataset_dict[split] = Dataset.from_pandas(df)
    return DatasetDict(dataset_dict)
# ======================
# PREPROCESS
# ======================
def preprocess_function(examples):
    """Tokenize a batch of source/target pairs for mBART fine-tuning.

    The language codes are re-applied on every invocation so the tokenizer
    always prepends the correct language tokens — critical for mBART, since
    Dataset.map may call this from contexts where prior state is unreliable.
    """
    # mBART requires the language pair to be set before each tokenization pass.
    tokenizer.src_lang = SRC_LANG
    tokenizer.tgt_lang = TGT_LANG

    encoded = tokenizer(
        examples["source"],
        max_length=MAX_INPUT_LENGTH,
        truncation=True,
        padding="max_length",
    )
    target_encoding = tokenizer(
        text_target=examples["target"],
        max_length=MAX_TARGET_LENGTH,
        truncation=True,
        padding="max_length",
    )
    # The target token ids become the training labels.
    encoded["labels"] = target_encoding["input_ids"]
    return encoded
# ======================
# TRAIN
# ======================
def main():
    """Fine-tune mBART-50 on the CSV splits and save model + tokenizer."""
    print("Loading tokenizer and model...")
    # `tokenizer` is global so preprocess_function (invoked via Dataset.map)
    # can reach it without changing the map callback's signature.
    global tokenizer
    tokenizer = MBart50TokenizerFast.from_pretrained(MODEL_NAME)
    # low_cpu_mem_usage streams weights to lower peak RAM on load
    # (NOTE(review): presumably requires `accelerate` — confirm environment).
    model = MBartForConditionalGeneration.from_pretrained(MODEL_NAME, low_cpu_mem_usage=True)
    print("Loading datasets...")
    raw_datasets = load_data()
    print("Tokenizing datasets...")
    # Drop the raw "source"/"target" columns so only tokenized fields remain.
    tokenized_datasets = raw_datasets.map(
        preprocess_function,
        batched=True,
        remove_columns=raw_datasets["train"].column_names,
    )
    # Dynamically pads batches and prepares decoder_input_ids from labels.
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=model,
    )
    training_args = Seq2SeqTrainingArguments(
        output_dir=OUTPUT_DIR,
        # NOTE(review): `eval_strategy` is the newer kwarg name (older
        # transformers versions use `evaluation_strategy`) — confirm version.
        eval_strategy="epoch",
        learning_rate=LEARNING_RATE,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        num_train_epochs=EPOCHS,
        weight_decay=0.01,
        save_total_limit=1,
        # Save/eval strategies must match for load_best_model_at_end to work.
        save_strategy="epoch",
        predict_with_generate=True,
        logging_steps=10,
        load_best_model_at_end=True,
        report_to="none",
        fp16=False,  # CPU safe
    )
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["validation"],
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    print("Training started...")
    trainer.train()
    print("Saving model...")
    # Persist both model and tokenizer so OUTPUT_DIR is self-contained.
    trainer.save_model(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)
    print(f"Training complete. Model saved to `{OUTPUT_DIR}`")
# ======================
# Script entry point: run training only when executed directly.
if __name__ == "__main__":
    main()