import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq
)
from peft import LoraConfig, get_peft_model, TaskType
from datasets import load_dataset
import os
# Model Configuration
MODEL_NAME = "Salesforce/codet5-base"
MAX_LENGTH = 128
TRAIN_BATCH_SIZE = 2
EVAL_BATCH_SIZE = 2
LEARNING_RATE = 1e-4
NUM_EPOCHS = 3
TRAIN_SIZE = 5000
VAL_SIZE = 500
CHECKPOINT_DIR = "./codet5-sql-finetuned"
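# These hyperparameters are deliberately conservative: CodeT5-base has roughly
# 220M parameters, and with LoRA only a small fraction of them are trained, so
# batch size 2, a 128-token context, and a 5k-example subset of WikiSQL keep a
# CPU-only run tractable.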
def preprocess(example):
    """Flatten a WikiSQL example into an input prompt and a target SQL string."""
    question = example["question"]
    table_headers = ", ".join(example["table"]["header"])
    sql_query = example["sql"]["human_readable"]
    return {
        "input_text": f"### Table columns:\n{table_headers}\n### Question:\n{question}\n### SQL:",
        "target_text": sql_query
    }
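# For the first WikiSQL training example, the prompt built above looks roughly
# like this (illustrative, reconstructed from the dataset schema):
#
#   ### Table columns:
#   Player, No., Nationality, Position, Years in Toronto, School/Club Team
#   ### Question:
#   What is terrence ross' nationality
#   ### SQL:
#
# with a target such as:
#   SELECT Nationality FROM table WHERE Player = terrence ross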
def main():
    # Set up device (informational only -- training is forced onto CPU below)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load and preprocess dataset
    print("Loading dataset...")
    try:
        dataset = load_dataset("wikisql")
    except Exception as e:
        print(f"Error loading dataset: {str(e)}")
        print("Trying with trust_remote_code=True...")
        # wikisql is a script-based dataset, so recent datasets releases
        # require explicit opt-in to run its loading script
        dataset = load_dataset("wikisql", trust_remote_code=True)

    train_dataset = dataset["train"].select(range(TRAIN_SIZE))
    val_dataset = dataset["validation"].select(range(VAL_SIZE))

    print("Preprocessing datasets...")
    processed_train = train_dataset.map(preprocess, remove_columns=train_dataset.column_names)
    processed_val = val_dataset.map(preprocess, remove_columns=val_dataset.column_names)
    # Load model and tokenizer
    print("Loading model and tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

    # Add LoRA adapters to the T5 attention and feed-forward projections
    # (q/k/v/o are the attention matrices, wi/wo the feed-forward layers)
    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        lora_dropout=0.1,
        bias="none",
        task_type=TaskType.SEQ_2_SEQ_LM,
        target_modules=["q", "v", "k", "o", "wi", "wo"]
    )
    model = get_peft_model(model, lora_config)
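    # Sanity check that the adapters attached where expected: with r=8 the
    # trainable weights should be only a small fraction of the base model
    # (on the order of 1-2%)
    model.print_trainable_parameters()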
    def tokenize_function(examples):
        inputs = tokenizer(
            examples["input_text"],
            padding="max_length",
            truncation=True,
            max_length=MAX_LENGTH
        )
        targets = tokenizer(
            examples["target_text"],
            padding="max_length",
            truncation=True,
            max_length=MAX_LENGTH
        )
        # Replace padding token ids in the labels with -100 so the loss
        # function ignores padded positions instead of learning to emit pads
        inputs["labels"] = [
            [(token if token != tokenizer.pad_token_id else -100) for token in seq]
            for seq in targets["input_ids"]
        ]
        return inputs
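    # Note: newer transformers releases also accept
    # tokenizer(examples["input_text"], text_target=examples["target_text"], ...)
    # which builds the inputs and labels in a single call.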
| print("Tokenizing datasets...") | |
| tokenized_train = processed_train.map( | |
| tokenize_function, | |
| remove_columns=processed_train.column_names, | |
| batched=True | |
| ) | |
| tokenized_val = processed_val.map( | |
| tokenize_function, | |
| remove_columns=processed_val.column_names, | |
| batched=True | |
| ) | |
    # Training arguments - simplified for stability
    training_args = Seq2SeqTrainingArguments(
        output_dir=CHECKPOINT_DIR,
        per_device_train_batch_size=TRAIN_BATCH_SIZE,
        per_device_eval_batch_size=EVAL_BATCH_SIZE,
        num_train_epochs=NUM_EPOCHS,
        learning_rate=LEARNING_RATE,
        logging_dir=os.path.join(CHECKPOINT_DIR, "logs"),
        logging_steps=10,
        save_strategy="epoch",  # checkpoint at the end of every epoch
        save_total_limit=2,
        predict_with_generate=True,
        no_cuda=True,  # Force CPU training
        fp16=False,  # Disable mixed precision training since we're on CPU
        report_to="none"  # Disable wandb logging
    )
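    # Note: recent transformers releases deprecate no_cuda in favor of
    # use_cpu=True; keep whichever matches the installed version.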
    # Data collator (pads any remaining ragged batches and, because the model
    # is passed in, prepares decoder_input_ids from the labels)
    data_collator = DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        padding=True
    )

    # Initialize trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_val,
        data_collator=data_collator,
    )
    try:
        print("\nStarting training...")
        print("You can stop training at any time by pressing Ctrl+C")
        print("Training will automatically save checkpoints after each epoch")

        # Check for existing checkpoints (only numeric checkpoint-<step> dirs,
        # so a stray checkpoint-error directory cannot break the sort)
        last_checkpoint = None
        if os.path.exists(CHECKPOINT_DIR):
            checkpoints = [
                d for d in os.listdir(CHECKPOINT_DIR)
                if d.startswith('checkpoint-') and d.split('-')[1].isdigit()
            ]
            if checkpoints:
                last_checkpoint = os.path.join(
                    CHECKPOINT_DIR,
                    sorted(checkpoints, key=lambda x: int(x.split('-')[1]))[-1]
                )
                print(f"\nFound checkpoint: {last_checkpoint}")
                print("Training will resume from this checkpoint.")

        # Start or resume training
        trainer.train(resume_from_checkpoint=last_checkpoint)

        # Save the final model
        trainer.save_model("./final-model")
        print("\nTraining completed successfully!")
        print("Final model saved to: ./final-model")
    except KeyboardInterrupt:
        print("\nTraining interrupted by user!")
        print("Progress is saved in the latest checkpoint.")
        print("To resume, just run the script again.")
    except Exception as e:
        print(f"\nAn error occurred during training: {str(e)}")
        if os.path.exists(CHECKPOINT_DIR):
            error_checkpoint = os.path.join(CHECKPOINT_DIR, "checkpoint-error")
            trainer.save_model(error_checkpoint)
            print(f"Saved error checkpoint to: {error_checkpoint}")
if __name__ == "__main__":
    main()
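# Inference sketch (illustrative, assuming training finished and the LoRA
# adapter was saved to ./final-model; adjust paths to your run):
#
#   from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
#   from peft import PeftModel
#
#   tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-base")
#   base = AutoModelForSeq2SeqLM.from_pretrained("Salesforce/codet5-base")
#   model = PeftModel.from_pretrained(base, "./final-model")
#
#   prompt = "### Table columns:\ncity, population\n### Question:\nWhich city has the largest population?\n### SQL:"
#   input_ids = tokenizer(prompt, return_tensors="pt").input_ids
#   output = model.generate(input_ids, max_length=128)
#   print(tokenizer.decode(output[0], skip_special_tokens=True))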