import sys

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    TextDataset,
    Trainer,
    TrainingArguments,
)


class GPTAssistant:
    def __init__(self, model_name="/Users/migueldeguzman/Desktop/gpt2xl_algos/phi-1.5/v7/"):
        try:
            # trust_remote_code=True is required for checkpoints that ship custom
            # modeling code (as phi-1.5 originally did); pass it for the model as
            # well as the tokenizer so both load consistently.
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            self.model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
        except Exception as e:
            print(f"Error initializing the model or tokenizer: {e}")
            sys.exit(1)

    def fine_tune(self, answer_file_path, model_output_dir, epochs=1.0):
        # Build a block-wise language-modeling dataset from the training file.
        # Note: TextDataset is deprecated in recent transformers releases; it
        # still works, but see the commented sketch below for the replacement.
        try:
            train_dataset = TextDataset(
                tokenizer=self.tokenizer,
                file_path=answer_file_path,
                block_size=128
            )
        except Exception as e:
            print(f"Error loading training dataset: {e}")
            sys.exit(1)

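        # A minimal sketch of the non-deprecated path, assuming the separate
        # `datasets` library is installed (it is not used elsewhere in this
        # script). It tokenizes line by line rather than concatenating text
        # into fixed blocks, so it is equivalent in spirit, not byte-for-byte,
        # to TextDataset above:
        #
        #   from datasets import load_dataset
        #   raw = load_dataset("text", data_files=answer_file_path)["train"]
        #   train_dataset = raw.map(
        #       lambda batch: self.tokenizer(batch["text"], truncation=True, max_length=128),
        #       batched=True,
        #       remove_columns=["text"],
        #   )
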
        # mlm=False selects causal language modeling: the collator copies the
        # inputs as labels instead of masking tokens.
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False
        )

        # One optimizer step consumes per_device_train_batch_size *
        # gradient_accumulation_steps examples, so divide by the effective
        # batch size; counting raw examples would overestimate total_steps.
        effective_batch_size = 4 * 8  # matches the TrainingArguments below
        total_steps = int(len(train_dataset) / effective_batch_size * epochs)
        warmup_steps = int(0.1 * total_steps)  # warm up over the first 10% of steps

        training_args = TrainingArguments(
            output_dir=model_output_dir,
            overwrite_output_dir=True,
            num_train_epochs=epochs,
            per_device_train_batch_size=4,
            save_steps=10_000,
            save_total_limit=2,
            weight_decay=0.005,
            gradient_accumulation_steps=8,
            learning_rate=1.5e-5,
            lr_scheduler_type='cosine',
            warmup_steps=warmup_steps
        )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            data_collator=data_collator,
            train_dataset=train_dataset
        )

        trainer.train()

        # Save the fine-tuned weights and tokenizer so the output directory is
        # a self-contained checkpoint loadable with from_pretrained().
        self.model.save_pretrained(model_output_dir)
        self.tokenizer.save_pretrained(model_output_dir)


def main():
    text_file_path = "/Users/migueldeguzman/Desktop/gpt2xl_algos/phi-1.5/v8/q&a_test_v1-3.text"
    model_output_dir = "/Users/migueldeguzman/Desktop/gpt2xl_algos/phi-1.5/v8/"

    assistant = GPTAssistant()
    assistant.fine_tune(text_file_path, model_output_dir)


if __name__ == "__main__":
    main()
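
# After training, the checkpoint written to model_output_dir can be reloaded
# the same way the base model was, e.g. GPTAssistant(model_name=model_output_dir).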