Spaces:
Sleeping
Sleeping
Upload run_cloud_training.py with huggingface_hub
Browse files- run_cloud_training.py +57 -193
run_cloud_training.py
CHANGED
|
@@ -401,147 +401,62 @@ def remove_training_marker():
|
|
| 401 |
os.remove("TRAINING_ACTIVE")
|
| 402 |
logger.info("Removed training active marker")
|
| 403 |
|
| 404 |
-
def load_model_safely(model_name, max_seq_length, dtype=None):
|
| 405 |
"""
|
| 406 |
-
Load the model
|
| 407 |
-
by trying different loading strategies.
|
| 408 |
"""
|
| 409 |
-
|
| 410 |
|
| 411 |
-
#
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 415 |
|
| 416 |
-
#
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
logger.info("Patching LLaMA attention implementation to avoid xformers")
|
| 425 |
-
|
| 426 |
-
# Store original implementation
|
| 427 |
-
if hasattr(llama_modeling.LlamaAttention, 'forward'):
|
| 428 |
-
llama_modeling._original_forward = llama_modeling.LlamaAttention.forward
|
| 429 |
-
|
| 430 |
-
# Define a new forward method that doesn't use xformers
|
| 431 |
-
def safe_attention_forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, output_attentions=False, use_cache=False):
|
| 432 |
-
logger.info("Using safe attention implementation (no xformers)")
|
| 433 |
-
|
| 434 |
-
# Force use_flash_attention to False
|
| 435 |
-
self._attn_implementation = "eager"
|
| 436 |
-
if hasattr(self, 'use_flash_attention'):
|
| 437 |
-
self.use_flash_attention = False
|
| 438 |
-
if hasattr(self, 'use_flash_attention_2'):
|
| 439 |
-
self.use_flash_attention_2 = False
|
| 440 |
-
|
| 441 |
-
# Call original implementation with flash attention disabled
|
| 442 |
-
return llama_modeling._original_forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache)
|
| 443 |
-
|
| 444 |
-
# Replace the forward method
|
| 445 |
-
llama_modeling.LlamaAttention.forward = safe_attention_forward
|
| 446 |
-
logger.info("Successfully patched LLaMA attention implementation")
|
| 447 |
-
except Exception as e:
|
| 448 |
-
logger.warning(f"Failed to patch attention implementation: {e}")
|
| 449 |
-
logger.info("Will try to proceed with standard loading")
|
| 450 |
|
|
|
|
| 451 |
try:
|
| 452 |
-
logger.info(
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
bnb_4bit_quant_type="nf4",
|
| 460 |
-
bnb_4bit_use_double_quant=True
|
| 461 |
)
|
|
|
|
|
|
|
| 462 |
|
| 463 |
-
# First try loading with unsloth but without flash attention
|
| 464 |
-
try:
|
| 465 |
-
logger.info("Loading model with unsloth optimizations")
|
| 466 |
-
# Don't pass any flash attention parameters to unsloth
|
| 467 |
-
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 468 |
-
model_name=model_name,
|
| 469 |
-
max_seq_length=max_seq_length,
|
| 470 |
-
dtype=dtype,
|
| 471 |
-
quantization_config=bnb_config,
|
| 472 |
-
attn_implementation="eager" # Force eager attention
|
| 473 |
-
)
|
| 474 |
-
logger.info("Model loaded successfully with unsloth")
|
| 475 |
-
|
| 476 |
-
# Explicitly disable flash attention in model config
|
| 477 |
-
if hasattr(model, 'config'):
|
| 478 |
-
if hasattr(model.config, 'attn_implementation'):
|
| 479 |
-
model.config.attn_implementation = "eager"
|
| 480 |
-
|
| 481 |
-
return model, tokenizer
|
| 482 |
-
|
| 483 |
-
except Exception as e:
|
| 484 |
-
logger.warning(f"Unsloth loading failed: {e}")
|
| 485 |
-
logger.info("Falling back to standard Hugging Face loading...")
|
| 486 |
-
|
| 487 |
-
# We'll try with HF loading
|
| 488 |
-
attn_params = {
|
| 489 |
-
"attn_implementation": "eager" # Always use eager
|
| 490 |
-
}
|
| 491 |
-
|
| 492 |
-
# Approach 1: Using attn_implementation parameter (newer method)
|
| 493 |
-
try:
|
| 494 |
-
logger.info(f"Trying HF loading with attention parameters: {attn_params}")
|
| 495 |
-
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
| 496 |
-
|
| 497 |
-
# Disable flash attention in config
|
| 498 |
-
if hasattr(config, 'attn_implementation'):
|
| 499 |
-
config.attn_implementation = "eager"
|
| 500 |
-
|
| 501 |
-
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
| 502 |
-
|
| 503 |
-
# The proper way to set attention implementation in newer transformers
|
| 504 |
-
model = AutoModelForCausalLM.from_pretrained(
|
| 505 |
-
model_name,
|
| 506 |
-
config=config,
|
| 507 |
-
device_map="auto",
|
| 508 |
-
torch_dtype=dtype or torch.float16,
|
| 509 |
-
quantization_config=bnb_config,
|
| 510 |
-
trust_remote_code=True,
|
| 511 |
-
**attn_params
|
| 512 |
-
)
|
| 513 |
-
logger.info(f"Model loaded successfully with HF using attention parameters: {attn_params}")
|
| 514 |
-
return model, tokenizer
|
| 515 |
-
|
| 516 |
-
except Exception as e:
|
| 517 |
-
logger.warning(f"HF loading with attn_implementation failed: {e}")
|
| 518 |
-
logger.info("Trying fallback method...")
|
| 519 |
-
|
| 520 |
-
# Approach 2: Complete fallback with minimal parameters
|
| 521 |
-
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
| 522 |
-
|
| 523 |
-
# Disable flash attention in config
|
| 524 |
-
if hasattr(config, 'attn_implementation'):
|
| 525 |
-
config.attn_implementation = "eager"
|
| 526 |
-
|
| 527 |
-
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
| 528 |
-
|
| 529 |
-
# Most basic loading without any attention parameters
|
| 530 |
-
model = AutoModelForCausalLM.from_pretrained(
|
| 531 |
-
model_name,
|
| 532 |
-
config=config,
|
| 533 |
-
device_map="auto",
|
| 534 |
-
torch_dtype=dtype or torch.float16,
|
| 535 |
-
quantization_config=bnb_config,
|
| 536 |
-
trust_remote_code=True,
|
| 537 |
-
attn_implementation="eager"
|
| 538 |
-
)
|
| 539 |
-
logger.info("Model loaded successfully with basic HF loading")
|
| 540 |
-
return model, tokenizer
|
| 541 |
-
|
| 542 |
except Exception as e:
|
| 543 |
-
logger.
|
| 544 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 545 |
|
| 546 |
def train(config_path, dataset_name, output_dir):
|
| 547 |
"""Main training function - RESEARCH TRAINING PHASE ONLY"""
|
|
@@ -556,50 +471,6 @@ def train(config_path, dataset_name, output_dir):
|
|
| 556 |
lora_config = config.get("lora_config", {})
|
| 557 |
dataset_config = config.get("dataset_config", {})
|
| 558 |
|
| 559 |
-
# Force disable flash attention and xformers
|
| 560 |
-
os.environ["TRANSFORMERS_NO_FLASH_ATTENTION"] = "1"
|
| 561 |
-
os.environ["XFORMERS_DISABLED"] = "1"
|
| 562 |
-
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
| 563 |
-
|
| 564 |
-
# Monkey patch torch.nn.functional to disable memory_efficient_attention
|
| 565 |
-
try:
|
| 566 |
-
import torch.nn.functional as F
|
| 567 |
-
if hasattr(F, 'scaled_dot_product_attention'):
|
| 568 |
-
logger.info("Monkey patching torch.nn.functional.scaled_dot_product_attention")
|
| 569 |
-
original_sdpa = F.scaled_dot_product_attention
|
| 570 |
-
|
| 571 |
-
def safe_sdpa(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
|
| 572 |
-
# Force disable memory efficient attention
|
| 573 |
-
logger.info("Using safe scaled_dot_product_attention (no xformers)")
|
| 574 |
-
return original_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale)
|
| 575 |
-
|
| 576 |
-
F.scaled_dot_product_attention = safe_sdpa
|
| 577 |
-
except Exception as e:
|
| 578 |
-
logger.warning(f"Failed to patch scaled_dot_product_attention: {e}")
|
| 579 |
-
|
| 580 |
-
# Completely remove xformers from sys.modules if it's loaded
|
| 581 |
-
for module_name in list(sys.modules.keys()):
|
| 582 |
-
if 'xformers' in module_name:
|
| 583 |
-
logger.info(f"Removing {module_name} from sys.modules")
|
| 584 |
-
del sys.modules[module_name]
|
| 585 |
-
|
| 586 |
-
# Update flash attention setting to always use eager
|
| 587 |
-
global flash_attention_available
|
| 588 |
-
flash_attention_available = False
|
| 589 |
-
logger.info("Flash Attention has been DISABLED globally")
|
| 590 |
-
|
| 591 |
-
# Update hardware config to ensure eager attention
|
| 592 |
-
hardware_config["attn_implementation"] = "eager"
|
| 593 |
-
|
| 594 |
-
# Verify this is training phase only
|
| 595 |
-
training_phase_only = dataset_config.get("training_phase_only", True)
|
| 596 |
-
if not training_phase_only:
|
| 597 |
-
logger.warning("This script is meant for research training phase only")
|
| 598 |
-
logger.warning("Setting training_phase_only=True")
|
| 599 |
-
|
| 600 |
-
# Verify dataset is pre-tokenized
|
| 601 |
-
logger.info("IMPORTANT: Using pre-tokenized dataset - No tokenization will be performed")
|
| 602 |
-
|
| 603 |
# Set the output directory
|
| 604 |
output_dir = output_dir or training_config.get("output_dir", "fine_tuned_model")
|
| 605 |
os.makedirs(output_dir, exist_ok=True)
|
|
@@ -628,8 +499,8 @@ def train(config_path, dataset_name, output_dir):
|
|
| 628 |
)
|
| 629 |
tokenizer.pad_token = tokenizer.eos_token
|
| 630 |
|
| 631 |
-
# Initialize model
|
| 632 |
-
logger.info("Initializing model
|
| 633 |
max_seq_length = training_config.get("max_seq_length", 2048)
|
| 634 |
|
| 635 |
# Create LoRA config directly
|
|
@@ -642,29 +513,21 @@ def train(config_path, dataset_name, output_dir):
|
|
| 642 |
target_modules=lora_config.get("target_modules", ["q_proj", "k_proj", "v_proj", "o_proj"])
|
| 643 |
)
|
| 644 |
|
|
|
|
|
|
|
|
|
|
| 645 |
# Initialize model with our safe loading function
|
| 646 |
-
logger.info("Loading pre-quantized model
|
| 647 |
dtype = torch.float16 if hardware_config.get("fp16", True) else None
|
| 648 |
-
|
| 649 |
-
# Force eager attention implementation
|
| 650 |
-
os.environ["TRANSFORMERS_NO_FLASH_ATTENTION"] = "1"
|
| 651 |
-
logger.info("Flash attention has been DISABLED globally via environment variable")
|
| 652 |
-
|
| 653 |
-
# Update hardware config to ensure eager attention
|
| 654 |
-
hardware_config["attn_implementation"] = "eager"
|
| 655 |
-
|
| 656 |
-
model, tokenizer = load_model_safely(model_name, max_seq_length, dtype)
|
| 657 |
|
| 658 |
# Disable generation capabilities for research training
|
| 659 |
logger.info("Disabling generation capabilities - Research training only")
|
| 660 |
model.config.is_decoder = False
|
| 661 |
model.config.task_specific_params = None
|
| 662 |
|
| 663 |
-
#
|
| 664 |
logger.info("Applying LoRA to model")
|
| 665 |
-
|
| 666 |
-
# Skip unsloth's method and go directly to PEFT
|
| 667 |
-
logger.info("Using standard PEFT method to apply LoRA")
|
| 668 |
from peft import get_peft_model
|
| 669 |
model = get_peft_model(model, lora_config_obj)
|
| 670 |
logger.info("Successfully applied LoRA with standard PEFT")
|
|
@@ -692,7 +555,6 @@ def train(config_path, dataset_name, output_dir):
|
|
| 692 |
logger.warning("No reporting backends available - training metrics won't be logged")
|
| 693 |
|
| 694 |
# Set up training arguments with correct parameters
|
| 695 |
-
# Extract only the valid parameters from hardware_config
|
| 696 |
training_args_dict = {
|
| 697 |
"output_dir": output_dir,
|
| 698 |
"num_train_epochs": training_config.get("num_train_epochs", 3),
|
|
@@ -764,6 +626,8 @@ if __name__ == "__main__":
|
|
| 764 |
help="Dataset name or path")
|
| 765 |
parser.add_argument("--output_dir", type=str, default=None,
|
| 766 |
help="Output directory for the fine-tuned model")
|
|
|
|
|
|
|
| 767 |
|
| 768 |
args = parser.parse_args()
|
| 769 |
|
|
|
|
| 401 |
os.remove("TRAINING_ACTIVE")
|
| 402 |
logger.info("Removed training active marker")
|
| 403 |
|
| 404 |
+
def load_model_safely(model_name, max_seq_length, dtype=None, use_flash_attention=False):
|
| 405 |
"""
|
| 406 |
+
Load the model with appropriate attention settings based on hardware capability
|
|
|
|
| 407 |
"""
|
| 408 |
+
logger.info(f"Loading model: {model_name}")
|
| 409 |
|
| 410 |
+
# Create BitsAndBytesConfig for 4-bit quantization
|
| 411 |
+
from transformers import BitsAndBytesConfig
|
| 412 |
+
bnb_config = BitsAndBytesConfig(
|
| 413 |
+
load_in_4bit=True,
|
| 414 |
+
bnb_4bit_compute_dtype=torch.float16,
|
| 415 |
+
bnb_4bit_quant_type="nf4",
|
| 416 |
+
bnb_4bit_use_double_quant=True
|
| 417 |
+
)
|
| 418 |
|
| 419 |
+
# Determine appropriate attention implementation
|
| 420 |
+
attn_implementation = "sdpa" # Default to PyTorch's scaled dot product attention
|
| 421 |
+
|
| 422 |
+
if use_flash_attention and flash_attention_available:
|
| 423 |
+
logger.info("Using Flash Attention for faster training")
|
| 424 |
+
attn_implementation = "flash_attention_2"
|
| 425 |
+
else:
|
| 426 |
+
logger.info("Using standard attention mechanism (sdpa)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 427 |
|
| 428 |
+
# Try loading with unsloth
|
| 429 |
try:
|
| 430 |
+
logger.info("Loading model with unsloth optimizations")
|
| 431 |
+
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 432 |
+
model_name=model_name,
|
| 433 |
+
max_seq_length=max_seq_length,
|
| 434 |
+
dtype=dtype,
|
| 435 |
+
quantization_config=bnb_config,
|
| 436 |
+
attn_implementation=attn_implementation
|
|
|
|
|
|
|
| 437 |
)
|
| 438 |
+
logger.info("Model loaded successfully with unsloth")
|
| 439 |
+
return model, tokenizer
|
| 440 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 441 |
except Exception as e:
|
| 442 |
+
logger.warning(f"Unsloth loading failed: {e}")
|
| 443 |
+
logger.info("Falling back to standard Hugging Face loading...")
|
| 444 |
+
|
| 445 |
+
# Fallback to standard HF loading
|
| 446 |
+
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
| 447 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
| 448 |
+
|
| 449 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 450 |
+
model_name,
|
| 451 |
+
config=config,
|
| 452 |
+
device_map="auto",
|
| 453 |
+
torch_dtype=dtype or torch.float16,
|
| 454 |
+
quantization_config=bnb_config,
|
| 455 |
+
trust_remote_code=True,
|
| 456 |
+
attn_implementation=attn_implementation
|
| 457 |
+
)
|
| 458 |
+
logger.info("Model loaded successfully with standard HF loading")
|
| 459 |
+
return model, tokenizer
|
| 460 |
|
| 461 |
def train(config_path, dataset_name, output_dir):
|
| 462 |
"""Main training function - RESEARCH TRAINING PHASE ONLY"""
|
|
|
|
| 471 |
lora_config = config.get("lora_config", {})
|
| 472 |
dataset_config = config.get("dataset_config", {})
|
| 473 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 474 |
# Set the output directory
|
| 475 |
output_dir = output_dir or training_config.get("output_dir", "fine_tuned_model")
|
| 476 |
os.makedirs(output_dir, exist_ok=True)
|
|
|
|
| 499 |
)
|
| 500 |
tokenizer.pad_token = tokenizer.eos_token
|
| 501 |
|
| 502 |
+
# Initialize model
|
| 503 |
+
logger.info("Initializing model (preserving 4-bit quantization)")
|
| 504 |
max_seq_length = training_config.get("max_seq_length", 2048)
|
| 505 |
|
| 506 |
# Create LoRA config directly
|
|
|
|
| 513 |
target_modules=lora_config.get("target_modules", ["q_proj", "k_proj", "v_proj", "o_proj"])
|
| 514 |
)
|
| 515 |
|
| 516 |
+
# Determine if we should use flash attention
|
| 517 |
+
use_flash_attention = hardware_config.get("use_flash_attention", False)
|
| 518 |
+
|
| 519 |
# Initialize model with our safe loading function
|
| 520 |
+
logger.info("Loading pre-quantized model")
|
| 521 |
dtype = torch.float16 if hardware_config.get("fp16", True) else None
|
| 522 |
+
model, tokenizer = load_model_safely(model_name, max_seq_length, dtype, use_flash_attention)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 523 |
|
| 524 |
# Disable generation capabilities for research training
|
| 525 |
logger.info("Disabling generation capabilities - Research training only")
|
| 526 |
model.config.is_decoder = False
|
| 527 |
model.config.task_specific_params = None
|
| 528 |
|
| 529 |
+
# Apply LoRA to model
|
| 530 |
logger.info("Applying LoRA to model")
|
|
|
|
|
|
|
|
|
|
| 531 |
from peft import get_peft_model
|
| 532 |
model = get_peft_model(model, lora_config_obj)
|
| 533 |
logger.info("Successfully applied LoRA with standard PEFT")
|
|
|
|
| 555 |
logger.warning("No reporting backends available - training metrics won't be logged")
|
| 556 |
|
| 557 |
# Set up training arguments with correct parameters
|
|
|
|
| 558 |
training_args_dict = {
|
| 559 |
"output_dir": output_dir,
|
| 560 |
"num_train_epochs": training_config.get("num_train_epochs", 3),
|
|
|
|
| 626 |
help="Dataset name or path")
|
| 627 |
parser.add_argument("--output_dir", type=str, default=None,
|
| 628 |
help="Output directory for the fine-tuned model")
|
| 629 |
+
parser.add_argument("--use_flash_attention", action="store_true",
|
| 630 |
+
help="Use Flash Attention if available")
|
| 631 |
|
| 632 |
args = parser.parse_args()
|
| 633 |
|