| import tensorflow as tf |
| import os |
| import pandas as pd |
| from tqdm import tqdm |
| from datasets import Dataset |
| from transformers import TFMT5ForConditionalGeneration, MT5Tokenizer, DataCollatorForSeq2Seq |
| from tensorflow.keras.optimizers import Adam |
|
|
| |
| train_flag = False |
|
|
| |
| checkpoint_dir = 'model_checkpoints' |
| MAX_LENGTH = 2000 |
| |
| tokenizer = MT5Tokenizer.from_pretrained("google/mt5-small") |
|
|
| |
class CustomModelCheckpoint(tf.keras.callbacks.Callback):
    """Keras callback that saves the model in HF ``save_pretrained`` format.

    Writes a ``checkpoint-<epoch>`` directory under ``checkpoint_dir`` every
    ``save_freq`` epochs so training can later resume from the newest one.
    """

    def __init__(self, checkpoint_dir, save_freq=1):
        super(CustomModelCheckpoint, self).__init__()
        self.checkpoint_dir = checkpoint_dir  # root folder for checkpoints
        self.save_freq = save_freq            # save every N epochs

    def on_epoch_end(self, epoch, logs=None):
        """Save a checkpoint when the 1-based epoch number hits ``save_freq``."""
        if (epoch + 1) % self.save_freq == 0:
            path = os.path.join(self.checkpoint_dir, f"checkpoint-{epoch + 1}")
            # exist_ok=True replaces the race-prone exists()/makedirs() pair.
            os.makedirs(path, exist_ok=True)
            self.model.save_pretrained(path)
|
|
| |
| latest_checkpoint = None |
| if os.path.exists(checkpoint_dir): |
| checkpoints = [os.path.join(checkpoint_dir, d) for d in os.listdir(checkpoint_dir)] |
| checkpoints = [d for d in checkpoints if os.path.isdir(d)] |
| if checkpoints: |
| latest_checkpoint = max(checkpoints, key=os.path.getmtime) |
|
|
| if latest_checkpoint: |
| print("Resuming...") |
| model = TFMT5ForConditionalGeneration.from_pretrained(latest_checkpoint) |
| else: |
| model = TFMT5ForConditionalGeneration.from_pretrained("google/mt5-small") |
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
# NOTE(review): read_pickle deserializes arbitrary pickled objects — this
# file must come from a trusted source.
df = pd.read_pickle("srt_scapper/Quran/fin_data.nas")
df.columns = ["Text", "Expected"]

# Wrap the frame as a HF Dataset and shuffle it reproducibly.
dataset = Dataset.from_pandas(df).shuffle(seed=42)
|
|
| |
def preprocess_function(examples):
    """Tokenize source/target text pairs for seq2seq fine-tuning.

    Args:
        examples: a batched-``map`` dict with "Text" (model inputs) and
            "Expected" (target strings), each a list of strings.

    Returns:
        Tokenizer output dict for the inputs, with a "labels" key holding
        the tokenized targets in which padding positions are replaced by
        -100 so the loss ignores them.
    """
    padding = "max_length"
    inputs = examples["Text"]
    targets = examples["Expected"]
    # Use the module-level MAX_LENGTH instead of re-hard-coding 2000 here,
    # so the length limit is defined in exactly one place.
    model_inputs = tokenizer(inputs, max_length=MAX_LENGTH, padding=padding, truncation=True)
    labels = tokenizer(targets, max_length=MAX_LENGTH, padding=padding, truncation=True)
    # Bug fix: labels padded with pad_token_id contribute to the training
    # loss — HF models only mask label positions equal to -100.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [tok if tok != pad_id else -100 for tok in seq]
        for seq in labels["input_ids"]
    ]
    return model_inputs


train_dataset = dataset.map(preprocess_function, batched=True, desc="Running tokenizer")
|
|
# Batch collator. label_pad_token_id=-100 (the library default) keeps padded
# label positions out of the loss; the original passed tokenizer.pad_token_id,
# which would make the model train on padding tokens.
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=-100,
    pad_to_multiple_of=64,
    return_tensors="tf"
)

# Convert the tokenized HF dataset into a batched, shuffled tf.data pipeline.
tf_train_dataset = model.prepare_tf_dataset(
    train_dataset,
    collate_fn=data_collator,
    batch_size=4,
    shuffle=True
)

# No explicit loss: transformers TF models compute their own internal loss.
model.compile(optimizer=Adam(3e-5))
|
|
| |
# Stop early if training loss fails to improve for 3 consecutive epochs.
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)

# exist_ok=True replaces the race-prone exists()/makedirs() pair.
os.makedirs(checkpoint_dir, exist_ok=True)

# Save a resumable checkpoint after every epoch.
model_checkpoint = CustomModelCheckpoint(
    checkpoint_dir,
    save_freq=1
)
|
|
| |
# Train; EarlyStopping can halt the run and CustomModelCheckpoint snapshots
# the model after each epoch.
training_callbacks = [early_stopping, model_checkpoint]
model.fit(
    tf_train_dataset,
    epochs=10,
    callbacks=training_callbacks
)

# Persist the final weights alongside the per-epoch checkpoints.
final_dir = os.path.join(checkpoint_dir, 'final_model')
model.save_pretrained(final_dir)