#!/usr/bin/env python3
"""
Fine-tuning script for the SmolLM2-135M model using Unsloth.

This script demonstrates how to:
1. Install and configure Unsloth
2. Prepare and format training data
3. Configure and run the training process
4. Save and evaluate the model

To run this script:
1. Install dependencies: pip install -r requirements.txt
2. Run: python train.py
"""

import logging
import os
from datetime import datetime
from pathlib import Path
from typing import Union

# isort: off
from unsloth import FastLanguageModel, is_bfloat16_supported  # noqa: E402
from unsloth.chat_templates import get_chat_template  # noqa: E402

# isort: on
from datasets import (
    Dataset,
    DatasetDict,
    IterableDataset,
    IterableDatasetDict,
    load_dataset,
)
from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,  # Lives in transformers, not trl
    Trainer,
    TrainingArguments,
)
from trl import SFTTrainer

# Configuration
max_seq_length = 2048  # Any length works; Unsloth handles RoPE scaling internally
dtype = None  # None for auto-detection; float16 for Tesla T4/V100, bfloat16 for Ampere+
load_in_4bit = True  # Use 4-bit quantization to reduce memory usage
validation_split = 0.1  # 10% of data for validation
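
# Rough memory math (illustrative): 135M parameters at 4 bits is about
# 135M * 0.5 bytes ~= 68 MB of weights, versus roughly 270 MB in fp16/bf16,
# before counting activations, gradients, and optimizer state.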


# Setup logging
def setup_logging():
    """Configure logging for the training process."""
    # Create logs directory if it doesn't exist
    log_dir = Path("logs")
    log_dir.mkdir(exist_ok=True)

    # Create a unique log file name with timestamp
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = log_dir / f"training_{timestamp}.log"

    # Configure logging
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        handlers=[logging.FileHandler(log_file), logging.StreamHandler()],
    )
    logger = logging.getLogger(__name__)
    logger.info(f"Logging initialized. Log file: {log_file}")
    return logger


logger = setup_logging()


def install_dependencies():
    """Install required dependencies."""
    logger.info("Installing dependencies...")
    commands = [
        'pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"',
        "pip install --no-deps xformers trl peft accelerate bitsandbytes",
    ]
    for command in commands:
        # os.system does not raise on failure, so check the exit code explicitly
        if os.system(command) != 0:
            logger.error(f"Command failed: {command}")
            raise RuntimeError(f"Dependency installation failed: {command}")
    logger.info("Dependencies installed successfully")


def load_model() -> tuple[FastLanguageModel, AutoTokenizer]:
    """Load and configure the model."""
    logger.info("Loading model and tokenizer...")
    try:
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name="unsloth/SmolLM2-135M-Instruct-bnb-4bit",
            max_seq_length=max_seq_length,
            dtype=dtype,
            load_in_4bit=load_in_4bit,
        )
        logger.info("Base model loaded successfully")

        # Configure LoRA
        model = FastLanguageModel.get_peft_model(
            model,
            r=64,
            target_modules=[
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "gate_proj",
                "up_proj",
                "down_proj",
            ],
            lora_alpha=128,
            lora_dropout=0.05,
            bias="none",
            use_gradient_checkpointing="unsloth",
            random_state=3407,
            use_rslora=True,
            loftq_config=None,
        )
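        # Note (illustrative): with use_rslora=True the LoRA update is scaled
        # by lora_alpha / sqrt(r) = 128 / sqrt(64) = 16, rather than the
        # standard lora_alpha / r = 2, which tends to train more stably at
        # higher ranks like r=64.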
| logger.info("LoRA configuration applied successfully") | |
| return model, tokenizer | |
| except Exception as e: | |
| logger.error(f"Error loading model: {e}") | |
| raise | |


def load_and_format_dataset(
    tokenizer: AutoTokenizer,
) -> tuple[
    Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset], AutoTokenizer
]:
    """Load and format the training dataset."""
    logger.info("Loading and formatting dataset...")
    try:
        # Load the code-act dataset
        dataset = load_dataset("xingyaoww/code-act", split="codeact")
        logger.info(f"Dataset loaded successfully. Size: {len(dataset)} examples")

        # Split into train and validation sets
        dataset = dataset.train_test_split(test_size=validation_split, seed=3407)
        logger.info(
            f"Dataset split into train ({len(dataset['train'])} examples) "
            f"and validation ({len(dataset['test'])} examples) sets"
        )

        # Configure chat template
        tokenizer = get_chat_template(
            tokenizer,
            chat_template="chatml",  # Also supports zephyr, mistral, llama, alpaca, vicuna, vicuna_old, unsloth
            mapping={
                "role": "from",
                "content": "value",
                "user": "human",
                "assistant": "gpt",
            },  # ShareGPT style
            map_eos_token=True,  # Maps <|im_end|> to </s> instead
        )
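        # ShareGPT-style records store each turn as
        # {"from": "human" | "gpt", "value": "..."}; the mapping above tells
        # the chat template how to translate those keys into the role/content
        # pairs it expects.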
| logger.info("Chat template configured successfully") | |
| def formatting_prompts_func(examples): | |
| convos = examples["conversations"] | |
| texts = [ | |
| tokenizer.apply_chat_template( | |
| convo, tokenize=False, add_generation_prompt=False | |
| ) | |
| for convo in convos | |
| ] | |
| return {"text": texts} | |

        # Apply formatting to both train and validation sets
        dataset = DatasetDict(
            {
                "train": dataset["train"].map(formatting_prompts_func, batched=True),
                "validation": dataset["test"].map(
                    formatting_prompts_func, batched=True
                ),
            }
        )
        logger.info("Dataset formatting completed successfully")
        return dataset, tokenizer
    except Exception as e:
        logger.error(f"Error loading/formatting dataset: {e}")
        raise


def create_trainer(
    model: FastLanguageModel,
    tokenizer: AutoTokenizer,
    dataset: Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset],
) -> Trainer:
    """Create and configure the SFTTrainer."""
    logger.info("Creating trainer...")
    try:
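        # Effective batch size is per_device_train_batch_size *
        # gradient_accumulation_steps = 2 * 16 = 32 sequences per optimizer
        # step (per device).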
        trainer = SFTTrainer(
            model=model,
            tokenizer=tokenizer,
            train_dataset=dataset["train"],
            eval_dataset=dataset["validation"],
            dataset_text_field="text",  # Column produced by formatting_prompts_func
            dataset_num_proc=2,
            packing=False,
            args=TrainingArguments(
                per_device_train_batch_size=2,
                per_device_eval_batch_size=2,
                gradient_accumulation_steps=16,
                warmup_steps=100,
                max_steps=120,
                learning_rate=5e-5,
                fp16=not is_bfloat16_supported(),
                bf16=is_bfloat16_supported(),
                logging_steps=1,
                save_strategy="steps",
                save_steps=30,
                eval_strategy="steps",
                eval_steps=30,
                save_total_limit=2,
                optim="adamw_8bit",
                weight_decay=0.01,
                lr_scheduler_type="cosine_with_restarts",
                seed=3407,
                output_dir="outputs",
                gradient_checkpointing=True,
                load_best_model_at_end=True,
                metric_for_best_model="eval_loss",
                greater_is_better=False,
            ),
            data_collator=DataCollatorForLanguageModeling(
                tokenizer=tokenizer,
                mlm=False,  # Causal LM objective, not masked LM
                pad_to_multiple_of=8,
            ),
        )
        logger.info("Trainer created successfully")
        return trainer
    except Exception as e:
        logger.error(f"Error creating trainer: {e}")
        raise


def main():
    """Main training function."""
    try:
        logger.info("Starting training process...")

        # Install dependencies
        install_dependencies()

        # Load model and tokenizer
        model, tokenizer = load_model()

        # Load and prepare dataset
        dataset, tokenizer = load_and_format_dataset(tokenizer)

        # Create trainer
        trainer: Trainer = create_trainer(model, tokenizer, dataset)

        # Train
        logger.info("Starting training...")
        trainer.train()

        # Save model
        logger.info("Saving final model...")
        trainer.save_model("final_model")

        # Log final metrics
        final_metrics = trainer.state.log_history[-1]
        logger.info("Training completed!")
        logger.info(f"Final training loss: {final_metrics.get('loss', 'N/A')}")
        logger.info(f"Final validation loss: {final_metrics.get('eval_loss', 'N/A')}")
    except Exception as e:
        logger.error(f"Error in main training process: {e}")
        raise


if __name__ == "__main__":
    main()
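
# A minimal inference sketch (illustrative; assumes training has completed
# and LoRA adapters were saved to "final_model"). The message uses the
# ShareGPT-style from/value keys mapped by get_chat_template above:
#
#   from unsloth import FastLanguageModel
#   model, tokenizer = FastLanguageModel.from_pretrained(
#       model_name="final_model",
#       max_seq_length=max_seq_length,
#       load_in_4bit=True,
#   )
#   FastLanguageModel.for_inference(model)  # Enable Unsloth's fast inference path
#   inputs = tokenizer.apply_chat_template(
#       [{"from": "human", "value": "Write a hello world program in Python."}],
#       tokenize=True, add_generation_prompt=True, return_tensors="pt",
#   ).to(model.device)
#   outputs = model.generate(input_ids=inputs, max_new_tokens=128)
#   print(tokenizer.decode(outputs[0], skip_special_tokens=True))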