# EnglishToHindiTranslationLongContext / Retraining_pipeline.py
# nashit93 — "First commit with model" (e3937dd)
import tensorflow as tf
import os
import pandas as pd
from tqdm import tqdm
from datasets import Dataset
from transformers import TFMT5ForConditionalGeneration, MT5Tokenizer, DataCollatorForSeq2Seq
from tensorflow.keras.optimizers import Adam
# Flag to control the training.
# NOTE(review): this flag is declared but never consulted anywhere else in
# the script — confirm whether training should actually be gated on it.
train_flag = False # Set this to False if you don't want to train
# Root directory for per-epoch checkpoints (created later if missing).
checkpoint_dir = 'model_checkpoints'
# Maximum token length used for both source and target sequences.
MAX_LENGTH = 2000
# Set up tokenizer and model (model itself is loaded further below).
tokenizer = MT5Tokenizer.from_pretrained("google/mt5-small")
# Custom callback for saving the model
class CustomModelCheckpoint(tf.keras.callbacks.Callback):
    """Keras callback that saves the model in Hugging Face format.

    The stock ModelCheckpoint callback writes Keras-native weights; this
    one calls ``save_pretrained`` so checkpoints can be reloaded with
    ``TFMT5ForConditionalGeneration.from_pretrained``.
    """

    def __init__(self, checkpoint_dir, save_freq=1):
        """
        Args:
            checkpoint_dir: Root directory to place checkpoint subfolders in.
            save_freq: Save every ``save_freq`` epochs (1 = every epoch).
        """
        super().__init__()
        self.checkpoint_dir = checkpoint_dir
        self.save_freq = save_freq

    def on_epoch_end(self, epoch, logs=None):
        # `epoch` is 0-indexed; save when the 1-indexed epoch number is a
        # multiple of save_freq.
        if (epoch + 1) % self.save_freq == 0:
            path = os.path.join(self.checkpoint_dir, f"checkpoint-{epoch + 1}")
            # exist_ok=True avoids the check-then-create race of the
            # original os.path.exists / os.makedirs pair.
            os.makedirs(path, exist_ok=True)
            self.model.save_pretrained(path)
# Load or initialize the model: resume from the most recently modified
# checkpoint subdirectory when one exists, otherwise start from the base
# pretrained weights.
latest_checkpoint = None
if os.path.exists(checkpoint_dir):
    entries = (os.path.join(checkpoint_dir, name) for name in os.listdir(checkpoint_dir))
    checkpoint_dirs = [entry for entry in entries if os.path.isdir(entry)]
    if checkpoint_dirs:
        # Most recent mtime wins — assumed to be the latest training state.
        latest_checkpoint = max(checkpoint_dirs, key=os.path.getmtime)

if latest_checkpoint:
    print("Resuming...")
    model = TFMT5ForConditionalGeneration.from_pretrained(latest_checkpoint)
else:
    model = TFMT5ForConditionalGeneration.from_pretrained("google/mt5-small")
# Load the prepared parallel corpus; the pickled DataFrame has two columns
# which are renamed to Text (English source) / Expected (Hindi target).
# NOTE(review): pd.read_pickle unpickles arbitrary objects and can execute
# code — only load this file from a trusted source.
df = pd.read_pickle("srt_scapper/Quran/fin_data.nas")
df.columns = ["Text", "Expected"]

# Convert the DataFrame to a Hugging Face Dataset and shuffle with a fixed
# seed so the ordering is reproducible across runs.
dataset = Dataset.from_pandas(df)
dataset = dataset.shuffle(seed=42)
# Tokenization and data preparation.
def preprocess_function(examples):
    """Tokenize a batch of (Text, Expected) pairs for seq2seq training.

    Both source and target are padded to MAX_LENGTH and truncated, and the
    target token ids are attached as ``labels``.

    Args:
        examples: Batched dataset rows with "Text" and "Expected" columns.

    Returns:
        Tokenizer output dict with an added "labels" key.
    """
    # Reuse the module-level MAX_LENGTH constant instead of the duplicated
    # magic literal 2000 the original carried here.
    padding = "max_length"
    model_inputs = tokenizer(
        examples["Text"],
        max_length=MAX_LENGTH,
        padding=padding,
        truncation=True,
    )
    labels = tokenizer(
        examples["Expected"],
        max_length=MAX_LENGTH,
        padding=padding,
        truncation=True,
    )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
# Run the tokenizer over the whole dataset up front.
train_dataset = dataset.map(preprocess_function, batched=True, desc="Running tokenizer")

# Batch-time padding collator.
# NOTE(review): label_pad_token_id is set to the tokenizer's pad id, so
# padded label positions contribute to the loss; the library default for
# seq2seq training is -100 (ignored by the loss) — confirm this override
# is intentional.
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=tokenizer.pad_token_id,
    pad_to_multiple_of=64,
    return_tensors="tf",
)

# Wrap the tokenized dataset as a tf.data pipeline for model.fit.
tf_train_dataset = model.prepare_tf_dataset(
    train_dataset,
    collate_fn=data_collator,
    batch_size=4,  # reduce if memory issues occur
    shuffle=True,
)
# Compile the model with a small constant learning rate (typical for
# mT5 fine-tuning).
model.compile(optimizer=Adam(3e-5))

# Stop early when the training loss has not improved for 3 epochs.
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)

# Ensure the checkpoint root exists; exist_ok=True replaces the original
# racy os.path.exists / os.makedirs pair.
os.makedirs(checkpoint_dir, exist_ok=True)

# Save a save_pretrained-format checkpoint after every epoch.
model_checkpoint = CustomModelCheckpoint(
    checkpoint_dir,
    save_freq=1  # Save after every epoch
)
# Fit the model, but only when training is enabled: the original script
# defined train_flag (documented as controlling training) and then never
# consulted it, so training ran unconditionally.
if train_flag:
    model.fit(
        tf_train_dataset,
        epochs=10,
        callbacks=[early_stopping, model_checkpoint],
    )
    # Save the final model alongside the per-epoch checkpoints.
    model.save_pretrained(os.path.join(checkpoint_dir, 'final_model'))