| | import os |
| | import sys |
| | import torch |
| | from transformers import AutoModelForCausalLM, AutoTokenizer, TextDataset, DataCollatorForLanguageModeling, Trainer, TrainingArguments |
| |
|
class GPTAssistant:
    """Load a causal language model and its tokenizer, and fine-tune it on a text file.

    Failures while loading the model/tokenizer or building the training dataset
    are reported on stderr and terminate the process with exit status 1.
    """

    def __init__(self, model_name="/Users/migueldeguzman/Desktop/gpt2xl_algos/falcon-1b/v1/"):
        """Load the tokenizer and model from ``model_name``.

        Args:
            model_name: Hub id or local checkpoint directory passed to
                ``from_pretrained`` for both the tokenizer and the model.

        Exits the process with status 1 if either object fails to load.
        """
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            # Use the same trust_remote_code setting as the tokenizer: checkpoints
            # that ship custom modelling code (the default path suggests a Falcon
            # checkpoint) cannot be loaded as a model without it.
            self.model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
        except Exception as e:
            print(f"Error initializing the model or tokenizer: {e}", file=sys.stderr)
            sys.exit(1)

    @staticmethod
    def _compute_warmup_steps(num_examples, batch_size, grad_accum_steps, epochs, warmup_fraction=0.1):
        """Return the number of LR warmup steps, measured in optimizer updates.

        The total number of updates is approximated the way ``Trainer`` counts
        them on a single process:
        batches per epoch = ceil(num_examples / batch_size);
        updates per epoch = batches // grad_accum_steps (at least 1);
        total updates = ceil(epochs * updates per epoch).
        The result is ``warmup_fraction`` of that total, rounded to the nearest
        integer. An empty dataset yields 0.
        """
        import math

        if num_examples <= 0:
            return 0
        batches_per_epoch = (num_examples + batch_size - 1) // batch_size
        updates_per_epoch = max(batches_per_epoch // grad_accum_steps, 1)
        total_updates = math.ceil(epochs * updates_per_epoch)
        return round(total_updates * warmup_fraction)

    def fine_tune(self, answer_file_path, model_output_dir, epochs=1.0):
        """Fine-tune ``self.model`` on a plain-text file and save the result.

        Args:
            answer_file_path: Path of the training text file, read via
                ``TextDataset`` with a block size of 128 tokens.
            model_output_dir: Directory used for Trainer checkpoints and for the
                final ``save_pretrained`` of the model and tokenizer.
            epochs: Number of training epochs (fractional values are accepted).

        Exits the process with status 1 if the dataset cannot be built or
        contains no training examples.
        """
        block_size = 128
        batch_size = 4
        grad_accum_steps = 8

        try:
            train_dataset = TextDataset(
                tokenizer=self.tokenizer,
                file_path=answer_file_path,
                block_size=block_size,
            )
        except Exception as e:
            print(f"Error loading training dataset: {e}", file=sys.stderr)
            sys.exit(1)

        if len(train_dataset) == 0:
            # Fail fast with a clear message instead of starting Trainer on an
            # empty dataset (presumably the file is shorter than one block).
            print(
                f"Error loading training dataset: {answer_file_path} produced no "
                f"{block_size}-token training blocks",
                file=sys.stderr,
            )
            sys.exit(1)

        # mlm=False selects causal-LM labels instead of masked-LM token masking.
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,
        )

        # Warmup is counted in optimizer updates, not dataset examples: each
        # update consumes batch_size * grad_accum_steps examples. Counting
        # examples made warmup longer than the entire run.
        warmup_steps = self._compute_warmup_steps(
            len(train_dataset), batch_size, grad_accum_steps, epochs
        )

        training_args = TrainingArguments(
            output_dir=model_output_dir,
            overwrite_output_dir=True,
            num_train_epochs=epochs,
            per_device_train_batch_size=batch_size,
            save_steps=10_000,
            save_total_limit=2,
            weight_decay=0.005,
            gradient_accumulation_steps=grad_accum_steps,
            learning_rate=3e-6,
            lr_scheduler_type='cosine',
            warmup_steps=warmup_steps,
        )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            data_collator=data_collator,
            train_dataset=train_dataset,
        )

        trainer.train()
        self.model.save_pretrained(model_output_dir)
        self.tokenizer.save_pretrained(model_output_dir)
| |
|
def main():
    """Script entry point: fine-tune the default GPTAssistant model on one text file.

    The training text and the output directory both live under the v2 folder.
    """
    output_dir = "/Users/migueldeguzman/Desktop/gpt2xl_algos/falcon-1b/v2/"
    # NOTE(review): the extension is ".text", not ".txt" — confirm the file name.
    training_text = "/Users/migueldeguzman/Desktop/gpt2xl_algos/falcon-1b/v2/shadow_integration.text"

    GPTAssistant().fine_tune(training_text, output_dir)


if __name__ == "__main__":
    main()
| |
|