from __future__ import annotations

import os

import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

from prompting import clean_gold_sql, get_schema_text, build_prompt
# =====================================================
# SETTINGS
# =====================================================
BASE_MODEL = os.environ.get("BASE_MODEL", "facebook/bart-base")
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
OUT_DIR = os.path.join(PROJECT_ROOT, "checkpoints", "sft_best_bart_2")

TRAIN_SPLIT = "train[:7000]"
EPOCHS = 12
LR = 3e-4
PER_DEVICE_BATCH = 16
GRAD_ACCUM = 4
MAX_INPUT = 512
MAX_OUTPUT = 128
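# Effective batch size: PER_DEVICE_BATCH * GRAD_ACCUM = 16 * 4 = 64 sequences
# per optimizer step (per device).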
# =====================================================
# DEVICE
# =====================================================
os.environ["TOKENIZERS_PARALLELISM"] = "false"

device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else ("cuda" if torch.cuda.is_available() else "cpu")
)
print("Using device:", device)
# =====================================================
# TOKENIZER
# =====================================================
print("Loading tokenizer/model:", BASE_MODEL)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# BART ships with a pad token; this fallback only matters if BASE_MODEL is
# swapped for a checkpoint without one (e.g. a GPT-style tokenizer).
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token
# =====================================================
# PREPROCESS FUNCTION
# =====================================================
def preprocess_function(example):
    question = example["question"]
    db_id = example["db_id"]
    gold_sql = clean_gold_sql(example["query"])

    # ---- Build prompt ----
    schema_text = get_schema_text(db_id)
    prompt = build_prompt(question, db_id, schema_text=schema_text, training_sql=None)

    model_inputs = tokenizer(
        prompt,
        max_length=MAX_INPUT,
        truncation=True,
        padding="max_length",
    )

    # ---- Target SQL ----
    labels = tokenizer(
        gold_sql,
        max_length=MAX_OUTPUT,
        truncation=True,
        padding="max_length",
    )["input_ids"]

    # IMPORTANT: mask padding out of the loss. -100 is the ignore index used
    # by PyTorch cross-entropy and by Hugging Face label handling.
    labels = [
        (tok if tok != tokenizer.pad_token_id else -100)
        for tok in labels
    ]

    model_inputs["labels"] = labels
    return model_inputs
# =====================================================
# DATASET
# =====================================================
print("Loading Spider subset:", TRAIN_SPLIT)
dataset = load_dataset("spider", split=TRAIN_SPLIT)
dataset = dataset.train_test_split(test_size=0.1, seed=42)
train_ds = dataset["train"]
eval_ds = dataset["test"]
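# Optional sanity check (a debugging sketch; the DEBUG_PREPROCESS variable is
# an assumption, not part of the original pipeline): run one example through
# preprocessing and confirm the label mask looks right.
if os.environ.get("DEBUG_PREPROCESS"):
    sample = preprocess_function(train_ds[0])
    print("Non-pad input tokens:", sum(t != tokenizer.pad_token_id for t in sample["input_ids"]))
    print("Supervised label tokens:", sum(t != -100 for t in sample["labels"]))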
| print("Tokenizing dataset (single process, stable)...") | |
| train_tok = train_ds.map( | |
| preprocess_function, | |
| batched=False, | |
| num_proc=1, | |
| remove_columns=train_ds.column_names, | |
| load_from_cache_file=False, | |
| ) | |
| eval_tok = eval_ds.map( | |
| preprocess_function, | |
| batched=False, | |
| num_proc=1, | |
| remove_columns=eval_ds.column_names, | |
| load_from_cache_file=False, | |
| ) | |
| print("Train dataset size:", len(train_tok)) | |
| print("Eval dataset size:", len(eval_tok)) | |
# =====================================================
# MODEL + LoRA
# =====================================================
base_model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL)
base_model.config.use_cache = False  # disable the generation cache during training

# UPGRADE 1: expanded LoRA capacity
lora_config = LoraConfig(
    r=16,             # higher rank gives the adapter more learning capacity
    lora_alpha=32,    # alpha is typically set to 2x the rank
    lora_dropout=0.1,
    bias="none",
    task_type="SEQ_2_SEQ_LM",
    # Target all attention and feed-forward projections in BART
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"],
)
model = get_peft_model(base_model, lora_config)
model.to(device)
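# Sanity check: PEFT's print_trainable_parameters reports how few weights the
# LoRA adapter actually trains relative to the full model.
model.print_trainable_parameters()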
# =====================================================
# TRAINER
# =====================================================
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    padding=True,
)
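# Note: this collator pads dynamically and pads labels with -100 by default,
# so the fixed-length padding in preprocess_function is redundant (though
# harmless); dropping it there would shorten batches and speed up training.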
args = Seq2SeqTrainingArguments(
    output_dir=os.path.join(PROJECT_ROOT, "checkpoints", "sft_bart_runs"),
    num_train_epochs=EPOCHS,
    learning_rate=LR,
    per_device_train_batch_size=PER_DEVICE_BATCH,
    per_device_eval_batch_size=PER_DEVICE_BATCH,
    gradient_accumulation_steps=GRAD_ACCUM,
    dataloader_num_workers=0,
    dataloader_pin_memory=False,
    # UPGRADES 2 & 3: better optimization and generalization
    warmup_ratio=0.05,           # ramp the learning rate up slowly
    weight_decay=0.01,           # regularize against over-reliance on single tokens
    label_smoothing_factor=0.1,  # prevent overconfidence in SQL token matching
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    logging_steps=50,
    report_to=[],
    fp16=False,
    bf16=False,
    predict_with_generate=True,
)
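# Without a compute_metrics function, predict_with_generate only adds
# generation time during eval; checkpoint selection rests on eval_loss alone.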
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=train_tok,
    eval_dataset=eval_tok,
    tokenizer=tokenizer,
    data_collator=data_collator,
)
# =====================================================
# TRAIN
# =====================================================
trainer.train()
# =====================================================
# SAVE BEST MODEL
# =====================================================
print("Saving best BART LoRA adapter to:", OUT_DIR)
os.makedirs(OUT_DIR, exist_ok=True)
trainer.model.save_pretrained(OUT_DIR)
tokenizer.save_pretrained(OUT_DIR)
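# save_pretrained on a PEFT model writes only the adapter weights, so reloading
# needs the base checkpoint too (a loading sketch, paths as defined above):
#   from peft import PeftModel
#   base = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL)
#   model = PeftModel.from_pretrained(base, OUT_DIR)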
| print("DONE β SFT BART finished") |