khulnasoft committed · Commit bf2f259 · verified · 1 Parent(s): ac5d95d

Update aifixcode_trainer.py

Files changed (1):
  1. aifixcode_trainer.py +161 -70
aifixcode_trainer.py CHANGED
@@ -1,86 +1,177 @@
- ### aifixcode_trainer.py
-
  """
- This script sets up a simple HuggingFace-based training + inference pipeline
- for bug-fixing AI using a CodeT5 model and supports continual training.
- You can upload this script to HuggingFace Space or Hub repo.
  """

  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Trainer, TrainingArguments, DataCollatorForSeq2Seq
  from datasets import load_dataset, DatasetDict
- import torch
- import os

- # ========== CONFIG ==========
- MODEL_NAME = "Salesforce/codet5p-220m"
- MODEL_OUT_DIR = "./aifixcode-model"
- TRAIN_DATASET_PATH = "./data/train.json"
- VAL_DATASET_PATH = "./data/val.json"

- # ========== LOAD MODEL + TOKENIZER ==========
- print("Loading model and tokenizer...")
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
- model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
-
- # ========== LOAD DATASET ==========
- print("Loading dataset...")
- def load_json_dataset(train_path, val_path):
      dataset = DatasetDict({
-         "train": load_dataset("json", data_files=train_path)["train"],
-         "validation": load_dataset("json", data_files=val_path)["train"]
      })
      return dataset

- dataset = load_json_dataset(TRAIN_DATASET_PATH, VAL_DATASET_PATH)
-
- # ========== PREPROCESS ==========
- print("Tokenizing dataset...")
- def preprocess(example):
-     input_code = example["input"]
-     target_code = example["output"]
-     model_inputs = tokenizer(input_code, truncation=True, padding="max_length", max_length=512)
-     labels = tokenizer(target_code, truncation=True, padding="max_length", max_length=512)
-     model_inputs["labels"] = labels["input_ids"]
      return model_inputs

- encoded_dataset = dataset.map(preprocess, batched=True)
-
- # ========== TRAINING SETUP ==========
- print("Setting up trainer...")
- training_args = TrainingArguments(
-     output_dir=MODEL_OUT_DIR,
-     evaluation_strategy="epoch",
-     save_strategy="epoch",
-     learning_rate=5e-5,
-     per_device_train_batch_size=4,
-     per_device_eval_batch_size=4,
-     num_train_epochs=3,
-     weight_decay=0.01,
-     logging_dir="./logs",
-     logging_strategy="epoch",
-     push_to_hub=True,
-     hub_model_id="khulnasoft/aifixcode-model",
-     hub_strategy="every_save"
- )
-
- data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
-
- trainer = Trainer(
-     model=model,
-     args=training_args,
-     train_dataset=encoded_dataset["train"],
-     eval_dataset=encoded_dataset["validation"],
-     tokenizer=tokenizer,
-     data_collator=data_collator
- )

- # ========== TRAIN ==========
- print("Starting training...")
- trainer.train()

- # ========== SAVE FINAL MODEL ==========
- print("Saving model...")
- trainer.save_model(MODEL_OUT_DIR)
- tokenizer.save_pretrained(MODEL_OUT_DIR)

- print("Training complete and model saved!")

  """
+ This script sets up a HuggingFace-based training and inference pipeline
+ for bug-fixing AI using a CodeT5 model. It is designed to be more
+ robust and flexible than the original.
+
+ Key improvements:
+ - Uses argparse for configuration, making it easy to change settings
+   via the command line.
+ - Adds checks to ensure data files exist.
+ - Implements a compute_metrics function for better model evaluation.
+ - Optimizes data preprocessing with dynamic padding.
+ - Saves the best-performing model based on evaluation metrics.
+ - Checks for GPU availability.
  """

+ import os
+ import argparse
+ import torch
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Trainer, TrainingArguments, DataCollatorForSeq2Seq
  from datasets import load_dataset, DatasetDict
+ from typing import Dict
+ from evaluate import load

+ # ========== ARGUMENT PARSING ==========
+ def parse_args():
+     """Parses command-line arguments for the training script."""
+     parser = argparse.ArgumentParser(description="Fine-tune a Seq2Seq model for code repair.")
+     parser.add_argument("--model_name", type=str, default="Salesforce/codet5p-220m",
+                         help="Pre-trained model name from HuggingFace.")
+     parser.add_argument("--output_dir", type=str, default="./aifixcode-model",
+                         help="Directory to save the trained model.")
+     parser.add_argument("--train_path", type=str, default="./data/train.json",
+                         help="Path to the training data JSON file.")
+     parser.add_argument("--val_path", type=str, default="./data/val.json",
+                         help="Path to the validation data JSON file.")
+     parser.add_argument("--epochs", type=int, default=3,
+                         help="Number of training epochs.")
+     parser.add_argument("--learning_rate", type=float, default=5e-5,
+                         help="Learning rate for the optimizer.")
+     parser.add_argument("--per_device_train_batch_size", type=int, default=4,
+                         help="Batch size per device for training.")
+     parser.add_argument("--per_device_eval_batch_size", type=int, default=4,
+                         help="Batch size per device for evaluation.")
+     parser.add_argument("--push_to_hub", action="store_true",
+                         help="Whether to push the model to the Hugging Face Hub.")
+     parser.add_argument("--hub_model_id", type=str, default="khulnasoft/aifixcode-model",
+                         help="Hugging Face Hub model ID to push to.")
+     return parser.parse_args()

+ # ========== DATA LOADING ==========
+ def load_json_dataset(train_path: str, val_path: str) -> DatasetDict:
+     """Loads and returns a dataset dictionary from JSON files."""
+     if not os.path.exists(train_path) or not os.path.exists(val_path):
+         raise FileNotFoundError(f"One or both data files not found: {train_path}, {val_path}")
+
+     print("Loading dataset...")
      dataset = DatasetDict({
+         "train": load_dataset("json", data_files=train_path, split="train"),
+         "validation": load_dataset("json", data_files=val_path, split="train")
      })
      return dataset
 
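+ # Each JSON record must provide the "input" (buggy code) and "output"
+ # (fixed code) fields consumed by preprocess_function below, e.g.
+ # {"input": "<buggy code>", "output": "<fixed code>"}.
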
+ # ========== DATA PREPROCESSING ==========
+ def preprocess_function(examples: Dict[str, list], tokenizer) -> Dict[str, list]:
+     """Tokenizes a batch of input and target code.
+
+     Sequences are only truncated here; padding is deferred to the
+     DataCollatorForSeq2Seq, which pads each batch dynamically. This is
+     more memory-efficient than padding everything to a fixed max length.
+     """
+     inputs = examples["input"]
+     targets = examples["output"]
+
+     model_inputs = tokenizer(inputs, text_target=targets, max_length=512, truncation=True)
      return model_inputs
 
+ # ========== METRIC CALCULATION ==========
+ def preprocess_logits_for_metrics(logits, labels):
+     """Reduces raw logits to predicted token ids during evaluation, so the
+     full vocabulary-sized logit tensors are never accumulated in memory.
+     Note this yields teacher-forced argmax predictions; for metrics on
+     truly generated sequences, use Seq2SeqTrainer with
+     predict_with_generate=True instead.
+     """
+     if isinstance(logits, tuple):
+         logits = logits[0]
+     return logits.argmax(dim=-1)
+
+ def compute_metrics(eval_pred, tokenizer):
+     """Computes BLEU and ROUGE metrics for model evaluation."""
+     bleu_metric = load("bleu")
+     rouge_metric = load("rouge")
+
+     predictions, labels = eval_pred
+
+     # Replace -100 in labels, since -100 cannot be decoded
+     labels = [[item if item != -100 else tokenizer.pad_token_id for item in row] for row in labels]
+
+     decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
+     decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
+
+     # Compute BLEU score
+     bleu_result = bleu_metric.compute(predictions=decoded_preds, references=decoded_labels)
+
+     # Compute ROUGE score
+     rouge_result = rouge_metric.compute(predictions=decoded_preds, references=decoded_labels)
+
+     return {
+         "bleu": bleu_result["bleu"],
+         "rouge1": rouge_result["rouge1"],
+         "rouge2": rouge_result["rouge2"],
+         "rougeL": rouge_result["rougeL"],
+     }

+ # ========== MAIN EXECUTION BLOCK ==========
+ def main():
+     """Main function to set up and run the training pipeline."""
+     args = parse_args()
+
+     # Check for GPU availability
+     if not torch.cuda.is_available():
+         print("Warning: A GPU is not available. Training will be very slow on CPU.")

+     # Load model and tokenizer
+     print(f"Loading model '{args.model_name}' and tokenizer...")
+     tokenizer = AutoTokenizer.from_pretrained(args.model_name)
+     model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
+
+     # Load and preprocess dataset
+     try:
+         dataset = load_json_dataset(args.train_path, args.val_path)
+     except FileNotFoundError as e:
+         print(e)
+         return
+
+     print("Tokenizing dataset...")
+     tokenized_dataset = dataset.map(
+         lambda examples: preprocess_function(examples, tokenizer),
+         batched=True,
+         remove_columns=dataset["train"].column_names
+     )
+
+     # Training arguments setup
+     print("Setting up trainer...")
+     training_args = TrainingArguments(
+         output_dir=os.path.join(args.output_dir, "checkpoints"),
+         evaluation_strategy="epoch",
+         save_strategy="epoch",
+         learning_rate=args.learning_rate,
+         per_device_train_batch_size=args.per_device_train_batch_size,
+         per_device_eval_batch_size=args.per_device_eval_batch_size,
+         num_train_epochs=args.epochs,
+         weight_decay=0.01,
+         logging_dir=os.path.join(args.output_dir, "logs"),
+         logging_strategy="epoch",
+         push_to_hub=args.push_to_hub,
+         hub_model_id=args.hub_model_id if args.push_to_hub else None,
+         hub_strategy="every_save",
+         load_best_model_at_end=True,  # Reload the best checkpoint when training ends
+         metric_for_best_model="rougeL",  # Rank checkpoints by eval ROUGE-L
+         greater_is_better=True,
+         report_to="tensorboard"
+     )
+
+     data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
+
+     # Initialize the trainer
+     trainer = Trainer(
+         model=model,
+         args=training_args,
+         train_dataset=tokenized_dataset["train"],
+         eval_dataset=tokenized_dataset["validation"],
+         tokenizer=tokenizer,
+         data_collator=data_collator,
+         # Bind the tokenizer explicitly; a bare compute_metrics reference
+         # would otherwise rely on an undefined global
+         compute_metrics=lambda eval_pred: compute_metrics(eval_pred, tokenizer),
+         preprocess_logits_for_metrics=preprocess_logits_for_metrics
+     )
+
+     print("Starting training...")
+     trainer.train()
+
+     # Save final model
+     print("Saving final model...")
+     final_model_dir = os.path.join(args.output_dir, "final")
+     trainer.save_model(final_model_dir)
+     tokenizer.save_pretrained(final_model_dir)
+     print("Training complete and model saved!")

+ if __name__ == "__main__":
+     main()
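
Usage note (not part of the commit): a minimal smoke-test sketch in Python, assuming aifixcode_trainer.py sits in the current directory with transformers, datasets, evaluate, and torch installed. It writes tiny train/val files in the JSON-lines shape that load_json_dataset and preprocess_function expect (records with "input" and "output" fields; the sample contents are made up for illustration), then launches a one-epoch run via the script's own CLI flags.

import json
import os
import subprocess

# Create minimal data files in the layout the script's defaults point at.
os.makedirs("data", exist_ok=True)
sample = {
    "input": "def add(a, b):\n    return a - b",   # buggy code (illustrative)
    "output": "def add(a, b):\n    return a + b",  # fixed code (illustrative)
}
for path in ("data/train.json", "data/val.json"):
    with open(path, "w") as f:
        f.write(json.dumps(sample) + "\n")

# One short epoch; add --push_to_hub to publish checkpoints to the Hub.
subprocess.run(
    ["python", "aifixcode_trainer.py",
     "--epochs", "1",
     "--per_device_train_batch_size", "1",
     "--per_device_eval_batch_size", "1"],
    check=True,
)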