Upload folder using huggingface_hub

run_transformers_training.py  CHANGED  (+329 -62)
@@ -10,6 +10,7 @@ import logging
 from datetime import datetime
 import time
 import warnings
+import traceback
 from importlib.util import find_spec
 import multiprocessing
 import torch
@@ -31,64 +32,36 @@ if CUDA_AVAILABLE:
         # Method already set, which is fine
         print("Multiprocessing start method already set")
 
-#
-
-
-
-logging.basicConfig(
-    level=logging.INFO,
-    format="%(asctime)s - %(levelname)s - %(message)s",
-    handlers=[logging.StreamHandler(sys.stdout)]
-)
-logger = logging.getLogger(__name__)
-
-# Set other loggers to WARNING to reduce noise and ensure our logs are visible
-logging.getLogger("transformers").setLevel(logging.WARNING)
-logging.getLogger("datasets").setLevel(logging.WARNING)
-logging.getLogger("accelerate").setLevel(logging.WARNING)
-logging.getLogger("torch").setLevel(logging.WARNING)
-logging.getLogger("bitsandbytes").setLevel(logging.WARNING)
-
-# Import Unsloth first, before other ML imports
-try:
-    from unsloth import FastLanguageModel
-    from unsloth.chat_templates import get_chat_template
-    unsloth_available = True
-    logger.info("Unsloth successfully imported")
-except ImportError:
-    unsloth_available = False
-    logger.warning("Unsloth not available. Please install with: pip install unsloth")
+# Import order is important: unsloth should be imported before transformers
+# Check for libraries without importing them
+unsloth_available = find_spec("unsloth") is not None
+if unsloth_available:
+    import unsloth
 
-#
-try:
+# Import torch first, then transformers if available
+import torch
+transformers_available = find_spec("transformers") is not None
+if transformers_available:
     import transformers
-    from transformers import (
-        AutoTokenizer,
-        TrainingArguments,
-        Trainer,
-        TrainerCallback,
-        set_seed,
-        BitsAndBytesConfig
-    )
-    logger.info(f"Transformers version: {transformers.__version__}")
-except ImportError:
-    logger.error("Transformers not available. This is a critical dependency.")
+    from transformers import AutoTokenizer, TrainingArguments, Trainer, set_seed
+    from torch.utils.data import DataLoader
 
-# Check availability of libraries
 peft_available = find_spec("peft") is not None
 if peft_available:
     import peft
-    logger.info(f"PEFT version: {peft.__version__}")
-else:
-    logger.warning("PEFT not available. Parameter-efficient fine-tuning will not be used.")
 
-#
-
+# Only import HF datasets if available
+datasets_available = find_spec("datasets") is not None
+if datasets_available:
     from datasets import load_dataset
-
-
-
+
+# Set up the logger
+logger = logging.getLogger(__name__)
+log_handler = logging.StreamHandler()
+log_format = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
+log_handler.setFormatter(log_format)
+logger.addHandler(log_handler)
+logger.setLevel(logging.INFO)
 
 # Define a clean logging function for HF Space compatibility
 def log_info(message):
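The rewritten import block above probes for optional packages with importlib.util.find_spec instead of wrapping every import in try/except, so nothing heavy is imported unless it is actually installed, and unsloth can still be imported ahead of transformers. A minimal standalone sketch of the same pattern (the package names are only examples):

```python
from importlib.util import find_spec

# Probe for optional dependencies without importing them.
unsloth_available = find_spec("unsloth") is not None
peft_available = find_spec("peft") is not None

# Import only what is actually present.
if unsloth_available:
    import unsloth  # imported first so its patches apply before transformers loads

if peft_available:
    import peft
    print(f"PEFT version: {peft.__version__}")
```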
@@ -243,6 +216,17 @@ def load_model_and_tokenizer(config):
     chat_template = get_config_value(tokenizer_config, "chat_template", None)
     padding_side = get_config_value(tokenizer_config, "padding_side", "right")
 
+    # Check for flash attention
+    use_flash_attention = get_config_value(config, "use_flash_attention", False)
+    flash_attention_available = False
+    try:
+        import flash_attn
+        flash_attention_available = True
+        log_info(f"Flash Attention detected (version: {flash_attn.__version__})")
+    except ImportError:
+        if use_flash_attention:
+            log_info("Flash Attention requested but not available")
+
     log_info(f"Loading model: {model_name} (revision: {model_revision})")
     log_info(f"Max sequence length: {max_seq_length}")
 
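Every setting in this hunk is read through get_config_value(container, key, default), and later hunks pass dotted keys such as "training.gradient_checkpointing". The helper itself is defined elsewhere in the script; the sketch below is only a plausible reconstruction of what such a helper does, not the actual implementation:

```python
def get_config_value(config, key, default=None):
    """Fetch a possibly nested config value, falling back to a default.

    Supports dotted paths like "training.gradient_checkpointing".
    """
    value = config
    for part in key.split("."):
        if not isinstance(value, dict) or part not in value:
            return default
        value = value[part]
    return value

# Example usage:
cfg = {"training": {"gradient_checkpointing": True}}
assert get_config_value(cfg, "training.gradient_checkpointing", False) is True
assert get_config_value(cfg, "tokenizer.padding_side", "right") == "right"
```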
@@ -257,7 +241,7 @@ def load_model_and_tokenizer(config):
             dtype=get_config_value(config, "torch_dtype", "bfloat16"),
             revision=model_revision,
             trust_remote_code=trust_remote_code,
-            use_flash_attention_2=
+            use_flash_attention_2=use_flash_attention and flash_attention_available
         )
 
         # Configure tokenizer settings
@@ -294,11 +278,23 @@ def load_model_and_tokenizer(config):
             max_seq_length=max_seq_length,
             modules_to_save=None
         )
+
+        if use_flash_attention and flash_attention_available:
+            log_info("🚀 Using Flash Attention for faster training")
+        elif use_flash_attention and not flash_attention_available:
+            log_info("⚠️ Flash Attention requested but not available - using standard attention")
+
     else:
         # Standard HuggingFace loading
         log_info("Using standard HuggingFace model loading (Unsloth not available or disabled)")
         from transformers import AutoModelForCausalLM, AutoTokenizer
 
+        # Check if flash attention should be enabled in config
+        use_attn_implementation = None
+        if use_flash_attention and flash_attention_available:
+            use_attn_implementation = "flash_attention_2"
+            log_info("🚀 Using Flash Attention for faster training")
+
         # Load tokenizer first
         tokenizer = AutoTokenizer.from_pretrained(
             model_name,
@@ -327,7 +323,8 @@ def load_model_and_tokenizer(config):
             trust_remote_code=trust_remote_code,
             revision=model_revision,
             torch_dtype=torch.bfloat16 if get_config_value(config, "torch_dtype", "bfloat16") == "bfloat16" else torch.float16,
-            device_map="auto" if CUDA_AVAILABLE else None
+            device_map="auto" if CUDA_AVAILABLE else None,
+            attn_implementation=use_attn_implementation
         )
 
         # Apply PEFT/LoRA if enabled but using standard loading
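In the standard (non-Unsloth) branch, Flash Attention is requested through the attn_implementation argument of from_pretrained rather than a separate flag. A self-contained sketch of that selection logic; the model id is a placeholder, flash_attention_2 additionally needs a CUDA GPU and fp16/bf16 weights, and device_map="auto" needs accelerate installed:

```python
import torch
from transformers import AutoModelForCausalLM

# Fall back to a standard attention implementation when flash-attn is missing.
try:
    import flash_attn  # noqa: F401
    attn_impl = "flash_attention_2"
except ImportError:
    attn_impl = "sdpa"  # PyTorch scaled-dot-product attention as a safe default

model = AutoModelForCausalLM.from_pretrained(
    "your-org/your-causal-lm",  # placeholder model id
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    attn_implementation=attn_impl,
    device_map="auto" if torch.cuda.is_available() else None,
)
```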
@@ -760,6 +757,63 @@ class LoggingCallback(TrainerCallback):
         """Called at the beginning of a step"""
         pass
 
+def install_flash_attention():
+    """
+    Attempt to install Flash Attention for improved performance.
+    Returns True if installation was successful, False otherwise.
+    """
+    log_info("Attempting to install Flash Attention...")
+
+    # Check for CUDA before attempting installation
+    if not CUDA_AVAILABLE:
+        log_info("❌ Cannot install Flash Attention: CUDA not available")
+        return False
+
+    try:
+        # Check CUDA version to determine correct installation command
+        cuda_version = torch.version.cuda
+        if cuda_version is None:
+            log_info("❌ Cannot determine CUDA version for Flash Attention installation")
+            return False
+
+        import subprocess
+
+        # Use --no-build-isolation for better compatibility
+        install_cmd = [
+            sys.executable,
+            "-m",
+            "pip",
+            "install",
+            "flash-attn",
+            "--no-build-isolation"
+        ]
+
+        log_info(f"Running: {' '.join(install_cmd)}")
+        result = subprocess.run(
+            install_cmd,
+            capture_output=True,
+            text=True,
+            check=False
+        )
+
+        if result.returncode == 0:
+            log_info("✅ Flash Attention installed successfully!")
+            # Attempt to import to verify installation
+            try:
+                import flash_attn
+                log_info(f"✅ Flash Attention version {flash_attn.__version__} is now available")
+                return True
+            except ImportError:
+                log_info("⚠️ Flash Attention installed but import failed")
+                return False
+        else:
+            log_info(f"❌ Flash Attention installation failed with error: {result.stderr}")
+            return False
+
+    except Exception as e:
+        log_info(f"❌ Error installing Flash Attention: {str(e)}")
+        return False
+
 def check_dependencies():
     """
     Check for required and optional dependencies, ensuring proper versions and import order.
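install_flash_attention() shells out to pip via subprocess and verifies the wheel by importing it afterwards; the actual call site is added to check_dependencies() further down. A sketch of how such a helper is meant to be driven, assuming the function and CUDA_AVAILABLE from this script are already in scope:

```python
from importlib.util import find_spec

if CUDA_AVAILABLE and find_spec("flash_attn") is None:
    if install_flash_attention():
        import flash_attn
        print(f"flash-attn {flash_attn.__version__} is ready")
    else:
        print("Proceeding without Flash Attention (standard attention will be used)")
```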
@@ -785,6 +839,7 @@ def check_dependencies():
     missing_packages = []
     package_versions = {}
     order_issues = []
+    missing_optional = []
 
     # Check required packages
     log_info("Checking required dependencies...")
@@ -822,6 +877,7 @@
                 log_info(f"✅ {package} - {feature} available")
             except ImportError:
                 log_info(f"⚠️ {package} - {feature} not available")
+                missing_optional.append(package)
 
     # Check import order for optimal performance
     if "transformers" in package_versions and "unsloth" in package_versions:
@@ -835,11 +891,19 @@
                 order_issue = "⚠️ For optimal performance, import unsloth before transformers"
                 order_issues.append(order_issue)
                 log_info(order_issue)
+                log_info("This might cause performance issues but won't prevent training")
             else:
                 log_info("✅ Import order: unsloth before transformers (optimal)")
         except (ValueError, IndexError) as e:
             log_info(f"⚠️ Could not verify import order: {str(e)}")
 
+    # Try to install missing optional packages
+    if "flash_attn" in missing_optional and CUDA_AVAILABLE:
+        log_info("\nFlash Attention is missing but would improve performance.")
+        install_result = install_flash_attention()
+        if install_result:
+            missing_optional.remove("flash_attn")
+
     # Report missing required packages
     if missing_packages:
         log_info("\n❌ Critical dependencies missing:")
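The import-order warning exists because unsloth patches transformers at import time, so importing it first matters even though the wrong order still trains. A standalone sketch of the ordering test that main() performs later in this diff; it relies on dict insertion order, which sys.modules preserves on Python 3.7+:

```python
import sys

def imported_before(first: str, second: str) -> bool:
    """True if `first` was imported before `second` in this process."""
    names = list(sys.modules.keys())
    return (
        first in names
        and second in names
        and names.index(first) < names.index(second)
    )

if imported_before("transformers", "unsloth"):
    print("⚠️ transformers was imported before unsloth; Unsloth's patches may not fully apply")
```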
@@ -990,10 +1054,22 @@ def setup_environment(args):
         os.environ["PYTORCH_CUDA_ALLOC_CONF"] = f"max_split_size_mb:128,expandable_segments:True"
         log_info(f"Set CUDA memory allocation limit to expandable with max_split_size_mb:128")
 
-    # Check dependencies
+    # Check dependencies and install optional ones if needed
    if not check_dependencies():
        raise RuntimeError("Critical dependencies missing")
 
+    # Check if flash attention was successfully installed
+    flash_attention_available = False
+    try:
+        import flash_attn
+        flash_attention_available = True
+        log_info(f"Flash Attention will be used (version: {flash_attn.__version__})")
+        # Update config to use flash attention
+        if "use_flash_attention" not in transformers_config:
+            transformers_config["use_flash_attention"] = True
+    except ImportError:
+        log_info("Flash Attention not available, will use standard attention mechanism")
+
     return transformers_config, seed
 
 def setup_model_and_tokenizer(config):
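setup_environment() tunes the CUDA caching allocator through the PYTORCH_CUDA_ALLOC_CONF environment variable; the setting only takes effect if it is in place before the first CUDA allocation. A minimal sketch of that ordering:

```python
import os

# Must be set before torch performs any CUDA allocation.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"

import torch

if torch.cuda.is_available():
    x = torch.zeros(1, device="cuda")  # first allocation picks up the allocator config
```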
@@ -1001,21 +1077,206 @@ def setup_model_and_tokenizer(config):
     Load and configure the model and tokenizer.
 
     Args:
-        config: Complete configuration dictionary
+        config (dict): Complete configuration dictionary
 
     Returns:
         tuple: (model, tokenizer) - The loaded model and tokenizer
     """
-
-
-    if model is
-
-
-
-
-
+    # Extract model configuration
+    model_config = get_config_value(config, "model", {})
+    model_name = get_config_value(model_config, "name", "unsloth/phi-4-unsloth-bnb-4bit")
+    use_fast_tokenizer = get_config_value(model_config, "use_fast_tokenizer", True)
+    trust_remote_code = get_config_value(model_config, "trust_remote_code", True)
+    model_revision = get_config_value(config, "model_revision", "main")
+
+    # Detect if model is already pre-quantized (includes '4bit', 'bnb', or 'int4' in name)
+    is_prequantized = any(q in model_name.lower() for q in ['4bit', 'bnb', 'int4', 'quant'])
+    if is_prequantized:
+        log_info("⚠️ Detected pre-quantized model. No additional quantization will be applied.")
+
+    # Unsloth configuration
+    unsloth_config = get_config_value(config, "unsloth", {})
+    unsloth_enabled = get_config_value(unsloth_config, "enabled", True)
+
+    # Tokenizer configuration
+    tokenizer_config = get_config_value(config, "tokenizer", {})
+    max_seq_length = min(
+        get_config_value(tokenizer_config, "max_seq_length", 2048),
+        4096  # Maximum supported by most models
+    )
+    add_eos_token = get_config_value(tokenizer_config, "add_eos_token", True)
+    chat_template = get_config_value(tokenizer_config, "chat_template", None)
+    padding_side = get_config_value(tokenizer_config, "padding_side", "right")
+
+    # Check for flash attention
+    use_flash_attention = get_config_value(config, "use_flash_attention", False)
+    flash_attention_available = False
+    try:
+        import flash_attn
+        flash_attention_available = True
+        log_info(f"Flash Attention detected (version: {flash_attn.__version__})")
+    except ImportError:
+        if use_flash_attention:
+            log_info("Flash Attention requested but not available")
+
+    log_info(f"Loading model: {model_name} (revision: {model_revision})")
+    log_info(f"Max sequence length: {max_seq_length}")
 
+    try:
+        if unsloth_enabled and unsloth_available:
+            log_info("Using Unsloth for LoRA fine-tuning")
+            if is_prequantized:
+                log_info("Using pre-quantized model - no additional quantization will be applied")
+            else:
+                log_info("Using 4-bit quantization for efficient training")
+
+            # Load using Unsloth
+            from unsloth import FastLanguageModel
+            model, tokenizer = FastLanguageModel.from_pretrained(
+                model_name=model_name,
+                max_seq_length=max_seq_length,
+                dtype=get_config_value(config, "torch_dtype", "bfloat16"),
+                revision=model_revision,
+                trust_remote_code=trust_remote_code,
+                use_flash_attention_2=use_flash_attention and flash_attention_available
+            )
+
+            # Configure tokenizer settings
+            tokenizer.padding_side = padding_side
+            if add_eos_token and tokenizer.eos_token is None:
+                log_info("Setting EOS token")
+                tokenizer.add_special_tokens({"eos_token": "</s>"})
+
+            # Set chat template if specified
+            if chat_template:
+                log_info(f"Setting chat template: {chat_template}")
+                if hasattr(tokenizer, "chat_template"):
+                    tokenizer.chat_template = chat_template
+                else:
+                    log_info("Tokenizer does not support chat templates, using default formatting")
+
+            # Apply LoRA
+            lora_r = get_config_value(unsloth_config, "r", 16)
+            lora_alpha = get_config_value(unsloth_config, "alpha", 32)
+            lora_dropout = get_config_value(unsloth_config, "dropout", 0)
+            target_modules = get_config_value(unsloth_config, "target_modules",
+                ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"])
+
+            log_info(f"Applying LoRA with r={lora_r}, alpha={lora_alpha}, dropout={lora_dropout}")
+            model = FastLanguageModel.get_peft_model(
+                model,
+                r=lora_r,
+                target_modules=target_modules,
+                lora_alpha=lora_alpha,
+                lora_dropout=lora_dropout,
+                bias="none",
+                use_gradient_checkpointing=get_config_value(config, "training.gradient_checkpointing", True),
+                random_state=0,
+                max_seq_length=max_seq_length,
+                modules_to_save=None
+            )
+
+            if use_flash_attention and flash_attention_available:
+                log_info("🚀 Using Flash Attention for faster training")
+            elif use_flash_attention and not flash_attention_available:
+                log_info("⚠️ Flash Attention requested but not available - using standard attention")
+
+        else:
+            # Standard HuggingFace loading
+            log_info("Using standard HuggingFace model loading (Unsloth not available or disabled)")
+            from transformers import AutoModelForCausalLM, AutoTokenizer
+
+            # Check if flash attention should be enabled in config
+            use_attn_implementation = None
+            if use_flash_attention and flash_attention_available:
+                use_attn_implementation = "flash_attention_2"
+                log_info("🚀 Using Flash Attention for faster training")
+
+            # Load tokenizer first
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_name,
+                trust_remote_code=trust_remote_code,
+                use_fast=use_fast_tokenizer,
+                revision=model_revision,
+                padding_side=padding_side
+            )
+
+            # Configure tokenizer settings
+            if add_eos_token and tokenizer.eos_token is None:
+                log_info("Setting EOS token")
+                tokenizer.add_special_tokens({"eos_token": "</s>"})
+
+            # Set chat template if specified
+            if chat_template:
+                log_info(f"Setting chat template: {chat_template}")
+                if hasattr(tokenizer, "chat_template"):
+                    tokenizer.chat_template = chat_template
+                else:
+                    log_info("Tokenizer does not support chat templates, using default formatting")
+
+            # Only apply quantization config if model is not already pre-quantized
+            quantization_config = None
+            if not is_prequantized and CUDA_AVAILABLE:
+                try:
+                    from transformers import BitsAndBytesConfig
+                    log_info("Using 4-bit quantization (BitsAndBytes) for efficient training")
+                    quantization_config = BitsAndBytesConfig(
+                        load_in_4bit=True,
+                        bnb_4bit_quant_type="nf4",
+                        bnb_4bit_compute_dtype=torch.float16,
+                        bnb_4bit_use_double_quant=True
+                    )
+                except ImportError:
+                    log_info("BitsAndBytes not available - quantization disabled")
+
+            # Now load model with updated tokenizer
+            model = AutoModelForCausalLM.from_pretrained(
+                model_name,
+                trust_remote_code=trust_remote_code,
+                revision=model_revision,
+                torch_dtype=torch.bfloat16 if get_config_value(config, "torch_dtype", "bfloat16") == "bfloat16" else torch.float16,
+                device_map="auto" if CUDA_AVAILABLE else None,
+                attn_implementation=use_attn_implementation,
+                quantization_config=quantization_config
+            )
+
+            # Apply PEFT/LoRA if enabled but using standard loading
+            if peft_available and get_config_value(unsloth_config, "enabled", True):
+                log_info("Applying standard PEFT/LoRA configuration")
+                from peft import LoraConfig, get_peft_model
+
+                lora_r = get_config_value(unsloth_config, "r", 16)
+                lora_alpha = get_config_value(unsloth_config, "alpha", 32)
+                lora_dropout = get_config_value(unsloth_config, "dropout", 0)
+                target_modules = get_config_value(unsloth_config, "target_modules",
+                    ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"])
+
+                log_info(f"Applying LoRA with r={lora_r}, alpha={lora_alpha}, dropout={lora_dropout}")
+                lora_config = LoraConfig(
+                    r=lora_r,
+                    lora_alpha=lora_alpha,
+                    target_modules=target_modules,
+                    lora_dropout=lora_dropout,
+                    bias="none",
+                    task_type="CAUSAL_LM"
+                )
+                model = get_peft_model(model, lora_config)
+
+        # Print model summary
+        log_info(f"Model loaded successfully: {model.__class__.__name__}")
+        if hasattr(model, "print_trainable_parameters"):
+            model.print_trainable_parameters()
+        else:
+            total_params = sum(p.numel() for p in model.parameters())
+            trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+            log_info(f"Model has {total_params:,} parameters, {trainable_params:,} trainable ({trainable_params/total_params:.2%})")
+
+        return model, tokenizer
+
+    except Exception as e:
+        log_info(f"Error loading model: {str(e)}")
+        traceback.print_exc()
+        return None, None
 
 def setup_dataset_and_collator(config, tokenizer):
     """
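The expanded setup_model_and_tokenizer() pulls everything from one nested config dict and returns (None, None) on failure instead of raising. A hypothetical minimal config covering the keys the function reads; the real schema lives in the Space's YAML configuration, so treat the exact shape and values as an assumption:

```python
config = {
    "model": {
        "name": "unsloth/phi-4-unsloth-bnb-4bit",
        "use_fast_tokenizer": True,
        "trust_remote_code": True,
    },
    "model_revision": "main",
    "use_flash_attention": True,
    "unsloth": {"enabled": True, "r": 16, "alpha": 32, "dropout": 0},
    "tokenizer": {
        "max_seq_length": 2048,
        "add_eos_token": True,
        "chat_template": None,
        "padding_side": "right",
    },
    "training": {"gradient_checkpointing": True},
}

model, tokenizer = setup_model_and_tokenizer(config)
if model is None:
    raise RuntimeError("Model loading failed; see the log output above for the cause")
```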
@@ -1229,6 +1490,12 @@ def main():
     logger.info("Starting training process")
 
     try:
+        # Check for potential import order issue and warn early
+        if "transformers" in sys.modules and "unsloth" in sys.modules:
+            if list(sys.modules.keys()).index("transformers") < list(sys.modules.keys()).index("unsloth"):
+                log_info("⚠️ Warning: transformers was imported before unsloth. This may affect performance.")
+                log_info("   For optimal performance in future runs, import unsloth first.")
+
         # Parse command line arguments
         args = parse_args()
 