File size: 3,814 Bytes
b8ab4a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from transformers import TrainingArguments, Trainer
from datasets import load_dataset
import jsonlines
import os
import torch
from model import Transformer, ModelArgs
from tokenizer import Tokenizer

class MathDataset(torch.utils.data.Dataset):
    """Causal-LM dataset that merges GSM8K- and ProofNet-style JSONL records.

    Each record is formatted into a single prompt string, tokenized to a
    fixed ``max_length``, and returned with ``labels`` suitable for
    causal-LM training (padding positions masked with -100 so they are
    ignored by the loss).
    """

    def __init__(self, tokenizer, data_paths, max_length=512):
        """
        Args:
            tokenizer: callable tokenizer with a HuggingFace-style
                ``__call__(text, padding=..., truncation=..., max_length=...,
                return_tensors="pt")`` interface returning ``input_ids`` and
                ``attention_mask`` tensors.
            data_paths: iterable of paths to JSONL files to concatenate.
            max_length: fixed sequence length for padding/truncation.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = []

        # Load and combine examples from all source files into one list.
        for path in data_paths:
            with jsonlines.open(path) as reader:
                self.data.extend(list(reader))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        example = self.data[idx]

        # Format the input text according to the record's schema.
        if "proof_steps" in example:
            # ProofNet-style record: problem + solution + enumerated steps.
            text = f"Problem: {example['problem']}\nSolution: {example['solution']}\nProof Steps:\n"
            for step in example["proof_steps"]:
                text += f"- {step['text']}\n"
        else:
            # GSM8K-style record: question/answer pair.
            text = f"Question: {example['question']}\nAnswer: {example['answer']}"

        # Tokenize to a fixed length so examples batch without a collator.
        inputs = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )

        # Remove the singleton batch dimension added by return_tensors="pt".
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}

        # BUG FIX: labels must be a *copy* of input_ids (not an alias), and
        # padding positions must be set to -100 — the ignore index used by
        # cross-entropy in HF Trainer — so pad tokens do not contribute to
        # the training loss.
        labels = inputs["input_ids"].clone()
        labels[inputs["attention_mask"] == 0] = -100

        return {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "labels": labels
        }

def main():
    """Fine-tune the custom Transformer on combined GSM8K/ProofNet data.

    Builds the model and tokenizer, loads the processed JSONL datasets,
    holds out a small eval split (required because evaluation runs every
    ``eval_steps``), trains with the HF Trainer, and saves the weights.
    """
    # Initialize the custom model (architecture defined in model.py).
    model_args = ModelArgs(
        dim=512,
        n_layers=8,
        n_heads=8,
        vocab_size=50000,  # NOTE(review): must match the tokenizer's real vocab size — confirm
        max_seq_len=1024
    )
    model = Transformer(model_args)

    # Initialize the custom tokenizer (tokenizer.py).
    tokenizer = Tokenizer()

    # Causal-LM tokenizers often lack a pad token; reuse EOS for padding.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Training data lives next to this script in processed_data/.
    data_dir = os.path.join(os.path.dirname(__file__), "processed_data")
    data_paths = [
        os.path.join(data_dir, "gsm8k_processed.jsonl"),
        os.path.join(data_dir, "proofnet_processed.jsonl")
    ]

    # Create the combined dataset.
    dataset = MathDataset(
        tokenizer=tokenizer,
        data_paths=data_paths,
        max_length=1024  # matches model max_seq_len; long enough for proofs
    )

    # BUG FIX: evaluation_strategy="steps" together with
    # load_best_model_at_end=True requires an eval_dataset — the original
    # code passed none, so Trainer fails. Hold out ~5% (at least 1 example)
    # with a seeded split for reproducibility.
    eval_size = max(1, len(dataset) // 20)
    train_size = len(dataset) - eval_size
    train_dataset, eval_dataset = torch.utils.data.random_split(
        dataset,
        [train_size, eval_size],
        generator=torch.Generator().manual_seed(42),
    )

    # Define training arguments.
    training_args = TrainingArguments(
        output_dir="./math_expert_output",
        overwrite_output_dir=True,
        num_train_epochs=3,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,  # effective batch size 8 per device
        save_steps=1000,
        save_total_limit=2,
        logging_dir="./math_expert_logs",
        logging_steps=100,
        evaluation_strategy="steps",
        eval_steps=1000,
        load_best_model_at_end=True,  # best = lowest eval loss by default
        learning_rate=5e-5,
        warmup_steps=500,
        weight_decay=0.01,
        fp16=torch.cuda.is_available()  # mixed precision only on GPU
    )

    # Create the trainer with both train and eval splits.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
    )

    # Start training.
    print("Starting training with your custom model...")
    trainer.train()

    # Save the final weights and config for later loading.
    output_dir = "./math_expert_model"
    os.makedirs(output_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))
    model_args.save(os.path.join(output_dir, "config.json"))
    print(f"Model saved to {output_dir}")

# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()