"""
This script sets up a HuggingFace-based training pipeline for an automated
bug-fixing model built on CodeT5. It is designed to be more robust and
flexible than the original version.

Key improvements:
- Uses argparse for configuration, making it easy to change settings
  via the command line.
- Adds checks to ensure data files exist.
- Implements a compute_metrics function for better model evaluation.
- Optimizes data preprocessing with dynamic padding.
- Saves the best-performing model based on evaluation metrics.
- Checks for GPU availability.
"""

import os
import argparse
import numpy as np
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq,
)
from datasets import load_dataset, DatasetDict
from typing import Dict
from evaluate import load

# ========== ARGUMENT PARSING ==========
def parse_args():
    """Parses command-line arguments for the training script."""
    parser = argparse.ArgumentParser(description="Fine-tune a Seq2Seq model for code repair.")
    parser.add_argument("--model_name", type=str, default="Salesforce/codet5p-220m",
                        help="Pre-trained model name from HuggingFace.")
    parser.add_argument("--output_dir", type=str, default="./aifixcode-model",
                        help="Directory to save the trained model.")
    parser.add_argument("--train_path", type=str, default="./data/train.json",
                        help="Path to the training data JSON file.")
    parser.add_argument("--val_path", type=str, default="./data/val.json",
                        help="Path to the validation data JSON file.")
    parser.add_argument("--epochs", type=int, default=3,
                        help="Number of training epochs.")
    parser.add_argument("--learning_rate", type=float, default=5e-5,
                        help="Learning rate for the optimizer.")
    parser.add_argument("--per_device_train_batch_size", type=int, default=4,
                        help="Batch size per device for training.")
    parser.add_argument("--per_device_eval_batch_size", type=int, default=4,
                        help="Batch size per device for evaluation.")
    parser.add_argument("--push_to_hub", action="store_true",
                        help="Whether to push the model to the Hugging Face Hub.")
    parser.add_argument("--hub_model_id", type=str, default="khulnasoft/aifixcode-model",
                        help="Hugging Face Hub model ID to push to.")
    return parser.parse_args()

# ========== DATA LOADING ==========
def load_json_dataset(train_path: str, val_path: str) -> DatasetDict:
    """Loads and returns a dataset dictionary from JSON files."""
    if not os.path.exists(train_path) or not os.path.exists(val_path):
        raise FileNotFoundError(f"One or both data files not found: {train_path}, {val_path}")
    
    print("Loading dataset...")
    dataset = DatasetDict({
        "train": load_dataset("json", data_files=train_path, split="train"),
        "validation": load_dataset("json", data_files=val_path, split="train")
    })
    return dataset
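
# The loader above expects each JSON file to hold records with an "input"
# (buggy code) and an "output" (fixed code) field, matching the keys used by
# preprocess_function below. An illustrative record (contents made up):
#
#   {"input": "def add(a, b):\n    return a - b",
#    "output": "def add(a, b):\n    return a + b"}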

# ========== DATA PREPROCESSING ==========
def preprocess_function(examples: Dict[str, list], tokenizer) -> Dict[str, list]:
    """Tokenizes a batch of input and target code.

    Sequences are only truncated here; padding is deferred to the data
    collator, which pads each batch dynamically. This is more
    memory-efficient than padding every sequence to a fixed max length.
    """
    inputs = examples["input"]
    targets = examples["output"]

    model_inputs = tokenizer(inputs, text_target=targets, max_length=512, truncation=True)
    return model_inputs
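
# Note on dynamic padding (illustrative): with the DataCollatorForSeq2Seq used
# below, each batch is padded only to its own longest sequence. A batch with
# lengths [40, 52, 61] is padded to 61 tokens, not to the global max of 512.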

# ========== METRIC CALCULATION ==========
def build_compute_metrics(tokenizer):
    """Builds a compute_metrics closure bound to the given tokenizer.

    The tokenizer is needed to decode predictions and labels, so it is
    captured here rather than accessed as a global. The metrics are loaded
    once instead of on every evaluation call.
    """
    bleu_metric = load("bleu")
    rouge_metric = load("rouge")

    def compute_metrics(eval_pred):
        """Computes BLEU and ROUGE metrics for model evaluation."""
        predictions, labels = eval_pred
        if isinstance(predictions, tuple):
            predictions = predictions[0]

        # Replace -100 in labels, since the tokenizer cannot decode it
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

        decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        bleu_result = bleu_metric.compute(predictions=decoded_preds, references=decoded_labels)
        rouge_result = rouge_metric.compute(predictions=decoded_preds, references=decoded_labels)

        return {
            "bleu": bleu_result["bleu"],
            "rouge1": rouge_result["rouge1"],
            "rouge2": rouge_result["rouge2"],
            "rougeL": rouge_result["rougeL"],
        }

    return compute_metrics

# ========== MAIN EXECUTION BLOCK ==========
def main():
    """Main function to set up and run the training pipeline."""
    args = parse_args()
    
    # Check for GPU availability
    if not torch.cuda.is_available():
        print("Warning: A GPU is not available. Training will be very slow on CPU.")

    # Load model and tokenizer
    print(f"Loading model '{args.model_name}' and tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
    
    # Load and preprocess dataset
    try:
        dataset = load_json_dataset(args.train_path, args.val_path)
    except FileNotFoundError as e:
        print(e)
        return
        
    print("Tokenizing dataset...")
    tokenized_dataset = dataset.map(
        lambda examples: preprocess_function(examples, tokenizer),
        batched=True,
        remove_columns=dataset["train"].column_names
    )
    
    # Training arguments setup
    print("Setting up trainer...")
    training_args = Seq2SeqTrainingArguments(
        output_dir=os.path.join(args.output_dir, "checkpoints"),
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        num_train_epochs=args.epochs,
        weight_decay=0.01,
        logging_dir=os.path.join(args.output_dir, "logs"),
        logging_strategy="epoch",
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id if args.push_to_hub else None,
        hub_strategy="every_save",
        predict_with_generate=True,  # Generate sequences during eval so BLEU/ROUGE can be computed
        load_best_model_at_end=True,  # Reload the best checkpoint when training finishes
        metric_for_best_model="rougeL",  # Metric used to pick the best checkpoint
        greater_is_better=True,
        report_to="tensorboard"
    )
    
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
    
    # Initialize the trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["validation"],
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=build_compute_metrics(tokenizer)
    )
    
    print("Starting training...")
    trainer.train()
    
    # Save final model
    print("Saving final model...")
    final_model_dir = os.path.join(args.output_dir, "final")
    trainer.save_model(final_model_dir)
    tokenizer.save_pretrained(final_model_dir)
    print("Training complete and model saved!")

if __name__ == "__main__":
    main()
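
# Minimal inference sketch (illustrative, not executed by this script). It
# assumes a previous run saved a model to ./aifixcode-model/final, the default
# output location above:
#
#   from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
#
#   tokenizer = AutoTokenizer.from_pretrained("./aifixcode-model/final")
#   model = AutoModelForSeq2SeqLM.from_pretrained("./aifixcode-model/final")
#
#   buggy_code = "def add(a, b):\n    return a - b"
#   inputs = tokenizer(buggy_code, return_tensors="pt", truncation=True, max_length=512)
#   output_ids = model.generate(**inputs, max_length=512)
#   print(tokenizer.decode(output_ids[0], skip_special_tokens=True))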