from __future__ import annotations

import os

import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

from prompting import clean_gold_sql, get_schema_text, build_prompt

# =====================================================
# SETTINGS
# =====================================================
BASE_MODEL = "Salesforce/codet5-base"
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
OUT_DIR = os.path.join(PROJECT_ROOT, "checkpoints", "sft_adapter_codet5")
TRAIN_SPLIT = "train[:7000]"
EPOCHS = 10
LR = 2e-4
PER_DEVICE_BATCH = 2  # CodeT5 is a bigger model, so use a smaller per-device batch
GRAD_ACCUM = 4
MAX_INPUT = 512
MAX_OUTPUT = 160

# =====================================================
# DEVICE
# =====================================================
os.environ["TOKENIZERS_PARALLELISM"] = "false"
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print("Using device:", device)

# =====================================================
# TOKENIZER
# =====================================================
print("Loading tokenizer/model:", BASE_MODEL)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
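# Note: the codet5-base checkpoint ships a RoBERTa-style tokenizer that already
# defines a pad token, so the fallback above is only a safeguard.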

# =====================================================
# PREPROCESS FUNCTION
# =====================================================
def preprocess_function(example):
    question = example["question"]
    db_id = example["db_id"]
    gold_sql = clean_gold_sql(example["query"])
    schema_text = get_schema_text(db_id)
    prompt = build_prompt(question, db_id, schema_text=schema_text, training_sql=None)
    model_inputs = tokenizer(
        prompt,
        max_length=MAX_INPUT,
        truncation=True,
        padding="max_length",
    )
    labels = tokenizer(
        gold_sql,
        max_length=MAX_OUTPUT,
        truncation=True,
        padding="max_length",
    )["input_ids"]
    # Replace label padding with -100 so the cross-entropy loss ignores those positions.
    labels = [(tok if tok != tokenizer.pad_token_id else -100) for tok in labels]
    model_inputs["labels"] = labels
    return model_inputs
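# Inputs and labels are padded to fixed lengths inside preprocess_function, so the
# DataCollatorForSeq2Seq defined later mostly just batches the already-padded tensors.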

# =====================================================
# DATASET
# =====================================================
print("Loading Spider subset:", TRAIN_SPLIT)
dataset = load_dataset("spider", split=TRAIN_SPLIT)
dataset = dataset.train_test_split(test_size=0.1, seed=42)
train_ds = dataset["train"]
eval_ds = dataset["test"]

print("Tokenizing dataset...")
train_tok = train_ds.map(
    preprocess_function,
    batched=False,
    num_proc=1,
    remove_columns=train_ds.column_names,
    load_from_cache_file=False,
)
eval_tok = eval_ds.map(
    preprocess_function,
    batched=False,
    num_proc=1,
    remove_columns=eval_ds.column_names,
    load_from_cache_file=False,
)
print("Train dataset size:", len(train_tok))
print("Eval dataset size:", len(eval_tok))

# =====================================================
# MODEL + LoRA (CodeT5)
# =====================================================
base_model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL)
base_model.config.use_cache = False  # required when gradient checkpointing is enabled
base_model.gradient_checkpointing_enable()

# CodeT5 uses T5-style attention blocks, so LoRA targets the "q" and "v" projections.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="SEQ_2_SEQ_LM",
    target_modules=["q", "v"],
)
model = get_peft_model(base_model, lora_config)
model.to(device)
model.print_trainable_parameters()
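# With r=16 on only the q/v projections, the trainable-parameter report should show
# a small fraction (on the order of 1%) of the ~220M-parameter base model.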

# =====================================================
# TRAINER
# =====================================================
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    padding=True,
)

args = Seq2SeqTrainingArguments(
    output_dir=os.path.join(PROJECT_ROOT, "checkpoints", "sft_runs_codet5"),
    num_train_epochs=EPOCHS,
    learning_rate=LR,
    per_device_train_batch_size=PER_DEVICE_BATCH,
    per_device_eval_batch_size=PER_DEVICE_BATCH,
    gradient_accumulation_steps=GRAD_ACCUM,
    dataloader_num_workers=0,
    dataloader_pin_memory=False,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,
    logging_steps=50,
    report_to=[],
    fp16=False,  # mixed precision stays off for this MPS/CPU setup
    bf16=False,
    predict_with_generate=True,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=train_tok,
    eval_dataset=eval_tok,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

# =====================================================
# TRAIN
# =====================================================
trainer.train()

# =====================================================
# SAVE (SAFE PEFT SAVE)
# =====================================================
print("Saving LoRA adapter to:", OUT_DIR)
os.makedirs(OUT_DIR, exist_ok=True)

# use the trainer's model directly: it is the PEFT-wrapped model we want to save
peft_model = trainer.model

# move to CPU before saving (works around an MPS saving issue on macOS)
peft_model = peft_model.to("cpu")

# save the adapter weights only, plus the tokenizer
peft_model.save_pretrained(OUT_DIR)
tokenizer.save_pretrained(OUT_DIR)

print("DONE ✔ CodeT5 SFT finished")