"""
This script sets up a HuggingFace-based training and inference pipeline
for bug-fixing AI using a CodeT5 model. It is designed to be more
robust and flexible than the original.
Key improvements:
- Uses argparse for configuration, making it easy to change settings
  via the command line.
- Adds checks to ensure data files exist.
- Implements a compute_metrics function for better model evaluation.
- Optimizes data preprocessing with dynamic padding.
- Saves the best-performing model based on evaluation metrics.
- Checks for GPU availability.
"""
import os
import argparse
import torch
from transformers import (AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainer,
                          Seq2SeqTrainingArguments, DataCollatorForSeq2Seq)
from datasets import load_dataset, DatasetDict
from typing import Dict
from evaluate import load

# ========== ARGUMENT PARSING ==========
def parse_args():
    """Parses command-line arguments for the training script."""
    parser = argparse.ArgumentParser(description="Fine-tune a Seq2Seq model for code repair.")
    parser.add_argument("--model_name", type=str, default="Salesforce/codet5p-220m",
                        help="Pre-trained model name from HuggingFace.")
    parser.add_argument("--output_dir", type=str, default="./aifixcode-model",
                        help="Directory to save the trained model.")
    parser.add_argument("--train_path", type=str, default="./data/train.json",
                        help="Path to the training data JSON file.")
    parser.add_argument("--val_path", type=str, default="./data/val.json",
                        help="Path to the validation data JSON file.")
    parser.add_argument("--epochs", type=int, default=3,
                        help="Number of training epochs.")
    parser.add_argument("--learning_rate", type=float, default=5e-5,
                        help="Learning rate for the optimizer.")
    parser.add_argument("--per_device_train_batch_size", type=int, default=4,
                        help="Batch size per device for training.")
    parser.add_argument("--per_device_eval_batch_size", type=int, default=4,
                        help="Batch size per device for evaluation.")
    parser.add_argument("--push_to_hub", action="store_true",
                        help="Whether to push the model to the Hugging Face Hub.")
    parser.add_argument("--hub_model_id", type=str, default="khulnasoft/aifixcode-model",
                        help="Hugging Face Hub model ID to push to.")
    return parser.parse_args()

# ========== DATA LOADING ==========
def load_json_dataset(train_path: str, val_path: str) -> DatasetDict:
    """Loads and returns a dataset dictionary from JSON files."""
    if not os.path.exists(train_path) or not os.path.exists(val_path):
        raise FileNotFoundError(f"One or both data files not found: {train_path}, {val_path}")
    print("Loading dataset...")
    dataset = DatasetDict({
        "train": load_dataset("json", data_files=train_path, split="train"),
        "validation": load_dataset("json", data_files=val_path, split="train"),
    })
    return dataset
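
# Each JSON file is assumed to contain records with "input" (buggy code) and
# "output" (fixed code) string fields, as consumed by preprocess_function below.
# Illustrative record:
#   {"input": "def add(a, b): return a - b", "output": "def add(a, b): return a + b"}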

# ========== DATA PREPROCESSING ==========
def preprocess_function(examples: Dict[str, list], tokenizer) -> Dict[str, list]:
    """Tokenizes a batch of input and target code.

    Padding is deliberately left to the data collator (dynamic padding),
    which is more memory-efficient than padding every sequence to a
    fixed max length here.
    """
    inputs = examples["input"]
    targets = examples["output"]
    # Tokenizing with text_target also produces the "labels" field.
    model_inputs = tokenizer(inputs, text_target=targets, max_length=512, truncation=True)
    return model_inputs

# ========== METRIC CALCULATION ==========
def build_compute_metrics(tokenizer):
    """Returns a compute_metrics function that computes BLEU and ROUGE.

    Built as a closure so the tokenizer is in scope (it was undefined here
    before) and the metrics are loaded once instead of on every evaluation.
    """
    bleu_metric = load("bleu")
    rouge_metric = load("rouge")

    def compute_metrics(eval_pred):
        predictions, labels = eval_pred
        # Replace -100 in labels, as it cannot be decoded
        labels = [[item if item != -100 else tokenizer.pad_token_id for item in row]
                  for row in labels]
        decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
        bleu_result = bleu_metric.compute(predictions=decoded_preds, references=decoded_labels)
        rouge_result = rouge_metric.compute(predictions=decoded_preds, references=decoded_labels)
        return {
            "bleu": bleu_result["bleu"],
            "rouge1": rouge_result["rouge1"],
            "rouge2": rouge_result["rouge2"],
            "rougeL": rouge_result["rougeL"],
        }

    return compute_metrics

# ========== MAIN EXECUTION BLOCK ==========
def main():
    """Main function to set up and run the training pipeline."""
    args = parse_args()

    # Check for GPU availability
    if not torch.cuda.is_available():
        print("Warning: A GPU is not available. Training will be very slow on CPU.")

    # Load model and tokenizer
    print(f"Loading model '{args.model_name}' and tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)

    # Load and preprocess dataset
    try:
        dataset = load_json_dataset(args.train_path, args.val_path)
    except FileNotFoundError as e:
        print(e)
        return

    print("Tokenizing dataset...")
    tokenized_dataset = dataset.map(
        lambda examples: preprocess_function(examples, tokenizer),
        batched=True,
        remove_columns=dataset["train"].column_names,
    )

    # Training arguments setup
    print("Setting up trainer...")
    training_args = Seq2SeqTrainingArguments(
        output_dir=os.path.join(args.output_dir, "checkpoints"),
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        num_train_epochs=args.epochs,
        weight_decay=0.01,
        logging_dir=os.path.join(args.output_dir, "logs"),
        logging_strategy="epoch",
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id if args.push_to_hub else None,
        hub_strategy="every_save",
        predict_with_generate=True,  # Generate during evaluation so metrics receive token ids, not logits
        generation_max_length=512,  # Match the preprocessing max_length
        load_best_model_at_end=True,  # Reload the best checkpoint when training finishes
        metric_for_best_model="rougeL",  # Metric used to pick the best checkpoint
        greater_is_better=True,
        report_to="tensorboard",
    )
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

    # Initialize the trainer (Seq2SeqTrainer is required for predict_with_generate)
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["validation"],
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=build_compute_metrics(tokenizer),
    )
print("Starting training...")
trainer.train()
# Save final model
print("Saving final model...")
final_model_dir = os.path.join(args.output_dir, "final")
trainer.save_model(final_model_dir)
tokenizer.save_pretrained(final_model_dir)
print("Training complete and model saved!")


if __name__ == "__main__":
    main()
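
# --- Illustrative inference sketch (not executed by this script) ---
# The module docstring mentions an inference pipeline; a minimal way to use the
# saved model, assuming the default --output_dir, would look like:
#
#   from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
#   tok = AutoTokenizer.from_pretrained("./aifixcode-model/final")
#   mdl = AutoModelForSeq2SeqLM.from_pretrained("./aifixcode-model/final")
#   ids = tok("def add(a, b): return a - b", return_tensors="pt").input_ids
#   print(tok.decode(mdl.generate(ids, max_length=512)[0], skip_special_tokens=True))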