Upload run_cloud_training.py with huggingface_hub
run_cloud_training.py (CHANGED, +54 -13)
@@ -407,6 +407,10 @@ def load_model_safely(model_name, max_seq_length, dtype=None, use_flash_attention
     """
     logger.info(f"Loading model: {model_name}")

+    # Explicitly disable xformers and flash attention in environment
+    os.environ["XFORMERS_DISABLED"] = "1"
+    os.environ["TRANSFORMERS_NO_FLASH_ATTENTION"] = "1"
+
     # Create BitsAndBytesConfig for 4-bit quantization
     from transformers import BitsAndBytesConfig
     bnb_config = BitsAndBytesConfig(
@@ -416,14 +420,9 @@ def load_model_safely(model_name, max_seq_length, dtype=None, use_flash_attention
         bnb_4bit_use_double_quant=True
     )

-    #
-    attn_implementation = "
-
-    if use_flash_attention and flash_attention_available:
-        logger.info("Using Flash Attention for faster training")
-        attn_implementation = "flash_attention_2"
-    else:
-        logger.info("Using standard attention mechanism (sdpa)")
+    # Force eager implementation to avoid BMGHK format issues
+    attn_implementation = "eager"  # Use eager implementation to avoid BMGHK format issues
+    logger.info(f"Forcing eager attention implementation to avoid BMGHK format issues")

     # Try loading with unsloth
     try:
@@ -436,6 +435,12 @@ def load_model_safely(model_name, max_seq_length, dtype=None, use_flash_attention
             attn_implementation=attn_implementation
         )
         logger.info("Model loaded successfully with unsloth")
+
+        # Explicitly set attention implementation in model config
+        if hasattr(model, 'config'):
+            model.config.attn_implementation = attn_implementation
+            logger.info(f"Explicitly set model config attention implementation to {attn_implementation}")
+
         return model, tokenizer

     except Exception as e:
@@ -444,6 +449,10 @@ def load_model_safely(model_name, max_seq_length, dtype=None, use_flash_attention

         # Fallback to standard HF loading
         config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+
+        # Set attention implementation in config
+        config.attn_implementation = attn_implementation
+
         tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

         model = AutoModelForCausalLM.from_pretrained(
@@ -464,6 +473,32 @@ def train(config_path, dataset_name, output_dir):
     load_dotenv()
     config = load_config(config_path)

+    # Explicitly disable xformers and flash attention in environment
+    os.environ["XFORMERS_DISABLED"] = "1"
+    os.environ["TRANSFORMERS_NO_FLASH_ATTENTION"] = "1"
+    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+
+    # Try to unload xformers if it's loaded
+    if 'xformers' in sys.modules:
+        logger.info("Removing xformers from sys.modules")
+        del sys.modules['xformers']
+
+    # Patch torch.nn.functional to avoid memory_efficient_attention
+    try:
+        import torch.nn.functional as F
+        if hasattr(F, 'scaled_dot_product_attention'):
+            logger.info("Patching torch.nn.functional.scaled_dot_product_attention")
+            original_sdpa = F.scaled_dot_product_attention
+
+            def safe_sdpa(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
+                # Force disable memory efficient attention
+                logger.info("Using safe scaled_dot_product_attention (no xformers)")
+                return original_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale)
+
+            F.scaled_dot_product_attention = safe_sdpa
+    except Exception as e:
+        logger.warning(f"Failed to patch scaled_dot_product_attention: {e}")
+
     # Extract configs
     model_config = config.get("model_config", {})
     training_config = config.get("training_config", {})
@@ -513,11 +548,11 @@ def train(config_path, dataset_name, output_dir):
         target_modules=lora_config.get("target_modules", ["q_proj", "k_proj", "v_proj", "o_proj"])
     )

-    #
-    use_flash_attention =
+    # Force eager attention implementation
+    use_flash_attention = False  # Override to force eager implementation

     # Initialize model with our safe loading function
-    logger.info("Loading pre-quantized model")
+    logger.info("Loading pre-quantized model with eager attention")
     dtype = torch.float16 if hardware_config.get("fp16", True) else None
     model, tokenizer = load_model_safely(model_name, max_seq_length, dtype, use_flash_attention)

@@ -531,7 +566,10 @@ def train(config_path, dataset_name, output_dir):
         from peft import get_peft_model
         model = get_peft_model(model, lora_config_obj)
         logger.info("Successfully applied LoRA with standard PEFT")
-
+
+        # Explicitly set attention implementation in model config again after PEFT
+        model.config.attn_implementation = "eager"
+
     # No need to format the dataset - it's already pre-tokenized
     logger.info("Using dataset with flexible tokenization handling")
     logger.info("Will use pre-tokenized data if available, or tokenize strings as fallback")
@@ -627,10 +665,13 @@ if __name__ == "__main__":
     parser.add_argument("--output_dir", type=str, default=None,
                         help="Output directory for the fine-tuned model")
     parser.add_argument("--use_flash_attention", action="store_true",
-                        help="Use Flash Attention if available")
+                        help="Use Flash Attention if available (NOT RECOMMENDED)")

     args = parser.parse_args()

+    # Override flash attention setting to force eager implementation
+    args.use_flash_attention = False
+
     # Run training - Research phase only
     try:
         output_path = train(args.config, args.dataset, args.output_dir)
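For reference, a minimal standalone sketch of the loading pattern this commit enforces: a 4-bit quantized model loaded with attention pinned to the eager implementation and xformers/flash attention disabled via the environment. This is an illustration, not part of the committed script; the model name is a placeholder, and the BitsAndBytesConfig arguments other than bnb_4bit_use_double_quant are typical values rather than ones taken from this diff.

# Sketch only: mirrors the eager-attention loading path, with assumed values noted below.
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Same environment switches the script sets before loading the model
os.environ["XFORMERS_DISABLED"] = "1"
os.environ["TRANSFORMERS_NO_FLASH_ATTENTION"] = "1"

# 4-bit quantization config; only bnb_4bit_use_double_quant is shown in the diff,
# the other arguments are common defaults assumed for this example
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

model_name = "your-org/your-model"  # placeholder, not a model from this repo
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    attn_implementation="eager",  # avoid flash / memory-efficient attention kernels
    device_map="auto",
    trust_remote_code=True,
)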