Spaces:
Paused
Paused
Upload folder using huggingface_hub
Browse files- requirements.txt +1 -1
- run_transformers_training.py +44 -158
- update_space.py +12 -4
requirements.txt
CHANGED
|
@@ -4,7 +4,6 @@ bitsandbytes>=0.41.0
|
|
| 4 |
datasets>=2.15.0
|
| 5 |
einops>=0.7.0
|
| 6 |
filelock>=3.13.1
|
| 7 |
-
flash-attn==2.5.2
|
| 8 |
gradio>=5.17.0
|
| 9 |
huggingface-hub>=0.19.0
|
| 10 |
matplotlib>=3.7.0
|
|
@@ -23,3 +22,4 @@ tqdm>=4.65.0
|
|
| 23 |
transformers>=4.36.0
|
| 24 |
typing-extensions>=4.8.0
|
| 25 |
unsloth>=2024.3
|
|
|
|
|
|
| 4 |
datasets>=2.15.0
|
| 5 |
einops>=0.7.0
|
| 6 |
filelock>=3.13.1
|
|
|
|
| 7 |
gradio>=5.17.0
|
| 8 |
huggingface-hub>=0.19.0
|
| 9 |
matplotlib>=3.7.0
|
|
|
|
| 22 |
transformers>=4.36.0
|
| 23 |
typing-extensions>=4.8.0
|
| 24 |
unsloth>=2024.3
|
| 25 |
+
flash-attn==2.5.2
|
run_transformers_training.py
CHANGED
|
@@ -158,38 +158,13 @@ def load_model_and_tokenizer(config):
|
|
| 158 |
|
| 159 |
logger.info("Using Unsloth optimizations with pre-quantized model")
|
| 160 |
|
| 161 |
-
# Check for flash attention
|
| 162 |
-
use_flash_attention = config.get("use_flash_attention", True)
|
| 163 |
-
if use_flash_attention and not find_spec("flash_attn"):
|
| 164 |
-
logger.warning("flash-attn not found. Will continue without flash attention.")
|
| 165 |
-
logger.warning("To use flash attention, install with: pip install flash-attn --no-build-isolation")
|
| 166 |
-
use_flash_attention = False
|
| 167 |
-
|
| 168 |
# First detect if we have a GPU
|
| 169 |
if torch.cuda.is_available():
|
| 170 |
gpu_count = torch.cuda.device_count()
|
| 171 |
-
logger.info(f"
|
| 172 |
-
|
| 173 |
-
# Log GPU info
|
| 174 |
-
for i in range(gpu_count):
|
| 175 |
-
logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}")
|
| 176 |
-
logger.info(f"Memory: {torch.cuda.get_device_properties(i).total_memory / 1024**3:.2f} GB")
|
| 177 |
-
|
| 178 |
-
# Create an optimized device map for better balance
|
| 179 |
-
if gpu_count > 1:
|
| 180 |
-
logger.info(f"Creating balanced device map for {gpu_count} GPUs")
|
| 181 |
-
# Use auto mapping but with memory tracking
|
| 182 |
-
device_map = "auto"
|
| 183 |
-
# Set max memory for better balancing
|
| 184 |
-
max_memory = {i: f"{int(torch.cuda.get_device_properties(i).total_memory * 0.85 / 1024**3)}GiB" for i in range(gpu_count)}
|
| 185 |
-
logger.info(f"Max memory settings: {max_memory}")
|
| 186 |
-
else:
|
| 187 |
-
device_map = "auto"
|
| 188 |
-
max_memory = None
|
| 189 |
else:
|
| 190 |
-
logger.warning("No CUDA
|
| 191 |
-
|
| 192 |
-
max_memory = None
|
| 193 |
|
| 194 |
# Set default dtype for better numerics
|
| 195 |
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
|
|
@@ -205,6 +180,13 @@ def load_model_and_tokenizer(config):
|
|
| 205 |
dtype = None
|
| 206 |
logger.info("Using default precision (CPU)")
|
| 207 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
# Load model with proper error handling for out-of-memory
|
| 209 |
try:
|
| 210 |
# Improved memory settings for multi-GPU setup
|
|
@@ -300,6 +282,16 @@ def load_dataset_with_mapping(dataset_config):
|
|
| 300 |
else:
|
| 301 |
logger.info(f"Dataset has all required fields: {required_fields}")
|
| 302 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
# Log a few samples for verification
|
| 304 |
if len(dataset) > 0:
|
| 305 |
sample_indices = range(min(5, len(dataset)))
|
|
@@ -524,54 +516,15 @@ class LoggingCallback(TrainerCallback):
|
|
| 524 |
log_info(f"=== Training started at {time.strftime('%Y-%m-%d %H:%M:%S')} ===")
|
| 525 |
log_info(f"Model parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")
|
| 526 |
|
| 527 |
-
#
|
| 528 |
-
|
| 529 |
-
self.verify_sequence = dataset_config.get("validation", {}).get("verify_sequence_integrity", False)
|
| 530 |
-
if self.verify_sequence:
|
| 531 |
-
log_info("Sequence integrity verification enabled during training")
|
| 532 |
-
|
| 533 |
-
# Save actual samples for later verification
|
| 534 |
-
if trainer and hasattr(trainer, 'train_dataset') and trainer.train_dataset is not None:
|
| 535 |
-
# Get some reference samples from the beginning of the dataset defensively
|
| 536 |
-
self.sample_indices = []
|
| 537 |
-
self.sequence_samples = []
|
| 538 |
-
|
| 539 |
-
max_samples = min(5, len(trainer.train_dataset))
|
| 540 |
-
for i in range(max_samples):
|
| 541 |
-
try:
|
| 542 |
-
if i < len(trainer.train_dataset):
|
| 543 |
-
self.sample_indices.append(i)
|
| 544 |
-
self.sequence_samples.append(trainer.train_dataset[i])
|
| 545 |
-
except Exception as e:
|
| 546 |
-
log_info(f"Warning: Error capturing reference sample at index {i}: {e}")
|
| 547 |
-
|
| 548 |
-
if self.sequence_samples:
|
| 549 |
-
log_info(f"Captured {len(self.sequence_samples)} reference samples for sequence integrity verification")
|
| 550 |
-
|
| 551 |
-
# Log sample prompt numbers for debugging
|
| 552 |
-
sample_prompt_numbers = []
|
| 553 |
-
for s in self.sequence_samples:
|
| 554 |
-
if isinstance(s, dict) and 'prompt_number' in s and s['prompt_number'] is not None:
|
| 555 |
-
sample_prompt_numbers.append(s.get('prompt_number'))
|
| 556 |
-
|
| 557 |
-
if sample_prompt_numbers:
|
| 558 |
-
log_info(f"Reference sample prompt numbers: {sample_prompt_numbers}")
|
| 559 |
-
if sample_prompt_numbers == list(range(1, len(sample_prompt_numbers) + 1)):
|
| 560 |
-
log_info("Prompt numbers are sequential (1-indexed) - sequence integrity confirmed")
|
| 561 |
-
else:
|
| 562 |
-
log_info("Prompt numbers are not in expected sequence - will verify during training")
|
| 563 |
-
else:
|
| 564 |
-
log_info("Warning: No reference samples were captured")
|
| 565 |
-
else:
|
| 566 |
-
log_info("Warning: Could not capture reference samples - verification will be limited")
|
| 567 |
-
except Exception as e:
|
| 568 |
-
log_info(f"Warning: Could not set up sequence integrity verification: {e}")
|
| 569 |
-
self.verify_sequence = False
|
| 570 |
|
| 571 |
log_info("=== Training is starting ===")
|
| 572 |
|
| 573 |
# Log important training parameters for visibility
|
| 574 |
total_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * NUM_GPUS
|
|
|
|
|
|
|
| 575 |
log_info(f"Batch size: {args.per_device_train_batch_size} × {args.gradient_accumulation_steps} steps × {NUM_GPUS} GPUs = {total_batch_size} total")
|
| 576 |
log_info(f"Learning rate: {args.learning_rate}")
|
| 577 |
log_info(f"Epochs: {args.num_train_epochs}")
|
|
@@ -585,90 +538,12 @@ class LoggingCallback(TrainerCallback):
|
|
| 585 |
memory_info.append(f"GPU {i}: {allocated:.1f}MB (max: {max_mem:.1f}MB)")
|
| 586 |
|
| 587 |
log_info(f"Initial memory usage - {', '.join(memory_info)}")
|
| 588 |
-
|
| 589 |
def on_step_end(self, args, state, control, **kwargs):
|
| 590 |
# Log every 50 steps or every 5 minutes, whichever comes first
|
| 591 |
current_time = time.time()
|
| 592 |
|
| 593 |
-
#
|
| 594 |
-
if self.verify_sequence is True and state.global_step % 100 == 0 and self.sequence_samples:
|
| 595 |
-
try:
|
| 596 |
-
# Get a batch of data without disturbing the training
|
| 597 |
-
train_dataloader = trainer.get_train_dataloader()
|
| 598 |
-
if train_dataloader is None:
|
| 599 |
-
log_info("Warning: Could not get train dataloader for verification")
|
| 600 |
-
else:
|
| 601 |
-
batch_iterator = iter(train_dataloader)
|
| 602 |
-
if batch_iterator is None:
|
| 603 |
-
log_info("Warning: Could not get batch iterator for verification")
|
| 604 |
-
else:
|
| 605 |
-
try:
|
| 606 |
-
batch = next(batch_iterator)
|
| 607 |
-
if batch is None:
|
| 608 |
-
log_info("Warning: Could not get batch for verification")
|
| 609 |
-
elif 'input_ids' in batch and 'labels' in batch:
|
| 610 |
-
log_info("Verifying data sequence integrity...")
|
| 611 |
-
|
| 612 |
-
# Check if we can access some of our reference samples
|
| 613 |
-
if not hasattr(trainer, 'train_dataset') or trainer.train_dataset is None:
|
| 614 |
-
log_info("Warning: Train dataset is not available")
|
| 615 |
-
else:
|
| 616 |
-
# Get current samples defensively
|
| 617 |
-
current_samples = []
|
| 618 |
-
current_indices = list(range(min(3, len(trainer.train_dataset))))
|
| 619 |
-
|
| 620 |
-
for idx in current_indices:
|
| 621 |
-
try:
|
| 622 |
-
if idx < len(trainer.train_dataset):
|
| 623 |
-
current_samples.append(trainer.train_dataset[idx])
|
| 624 |
-
except Exception as e:
|
| 625 |
-
log_info(f"Warning: Error accessing dataset at index {idx}: {e}")
|
| 626 |
-
|
| 627 |
-
# Only proceed if we have samples to compare
|
| 628 |
-
if current_samples and self.sequence_samples:
|
| 629 |
-
# Compare current samples with our reference samples from training start
|
| 630 |
-
is_sequence_maintained = True
|
| 631 |
-
|
| 632 |
-
for i, (orig_idx, orig_sample) in enumerate(zip(self.sample_indices, self.sequence_samples)):
|
| 633 |
-
# Check if sample index is valid
|
| 634 |
-
if i < len(current_samples):
|
| 635 |
-
current_sample = current_samples[i]
|
| 636 |
-
|
| 637 |
-
# Compare prompt numbers if available - this is our primary check now
|
| 638 |
-
if ('prompt_number' in orig_sample and
|
| 639 |
-
'prompt_number' in current_sample and
|
| 640 |
-
orig_sample['prompt_number'] is not None and
|
| 641 |
-
current_sample['prompt_number'] is not None):
|
| 642 |
-
|
| 643 |
-
if orig_sample['prompt_number'] != current_sample['prompt_number']:
|
| 644 |
-
log_info(f"WARNING: Sequence integrity compromised! Sample {i} prompt number changed from {orig_sample['prompt_number']} to {current_sample['prompt_number']}")
|
| 645 |
-
is_sequence_maintained = False
|
| 646 |
-
else:
|
| 647 |
-
# This is now our primary verification
|
| 648 |
-
log_info(f"Prompt number match confirmed for sample {i}: {orig_sample['prompt_number']}")
|
| 649 |
-
|
| 650 |
-
# Also compare article_id as a backup check
|
| 651 |
-
elif ('article_id' in orig_sample and
|
| 652 |
-
'article_id' in current_sample and
|
| 653 |
-
orig_sample['article_id'] is not None and
|
| 654 |
-
current_sample['article_id'] is not None):
|
| 655 |
-
|
| 656 |
-
if orig_sample['article_id'] != current_sample['article_id']:
|
| 657 |
-
log_info(f"WARNING: Sequence integrity compromised! Sample {i} article_id changed from {orig_sample['article_id']} to {current_sample['article_id']}")
|
| 658 |
-
is_sequence_maintained = False
|
| 659 |
-
|
| 660 |
-
if is_sequence_maintained:
|
| 661 |
-
log_info("Data sequence integrity check: OK - prompt numbers preserved")
|
| 662 |
-
else:
|
| 663 |
-
log_info("CRITICAL WARNING: Data sequence integrity check FAILED!")
|
| 664 |
-
else:
|
| 665 |
-
log_info("Warning: Not enough samples available for sequence verification")
|
| 666 |
-
except StopIteration:
|
| 667 |
-
log_info("Warning: No batches available in the dataloader")
|
| 668 |
-
except Exception as e:
|
| 669 |
-
log_info(f"Warning: Error iterating through dataloader: {e}")
|
| 670 |
-
except Exception as e:
|
| 671 |
-
log_info(f"Warning: Couldn't verify sequence integrity: {e}")
|
| 672 |
|
| 673 |
# Log progress at regular intervals
|
| 674 |
if (state.global_step % 50 == 0) or (current_time - self.last_log_time > 300):
|
|
@@ -708,13 +583,6 @@ def check_dependencies():
|
|
| 708 |
if not peft_available:
|
| 709 |
missing_packages.append("peft>=0.9.0")
|
| 710 |
|
| 711 |
-
# Optional packages - don't add to missing list, just log
|
| 712 |
-
if find_spec("flash_attn"):
|
| 713 |
-
logger.info("flash-attn found. Flash attention will be used for faster training.")
|
| 714 |
-
else:
|
| 715 |
-
logger.warning("flash-attn not found. Training will work but may be slower.")
|
| 716 |
-
logger.warning("To use flash attention, install with: pip install flash-attn --no-build-isolation")
|
| 717 |
-
|
| 718 |
# If critical packages are missing, exit with instructions
|
| 719 |
if missing_packages:
|
| 720 |
logger.error("Critical dependencies missing:")
|
|
@@ -723,6 +591,13 @@ def check_dependencies():
|
|
| 723 |
logger.error("Please ensure the space has these packages in requirements.txt")
|
| 724 |
return False
|
| 725 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 726 |
return True
|
| 727 |
|
| 728 |
def main():
|
|
@@ -934,6 +809,17 @@ def main():
|
|
| 934 |
|
| 935 |
# Log our approach clearly
|
| 936 |
log_info("Using SequentialSampler to guarantee dataset order is preserved based on prompt_number")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 937 |
log_info("Dataset is pre-processed with prompt_number field indicating the correct sequence")
|
| 938 |
|
| 939 |
# Calculate batch size based on device availability
|
|
|
|
| 158 |
|
| 159 |
logger.info("Using Unsloth optimizations with pre-quantized model")
|
| 160 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
# First detect if we have a GPU
|
| 162 |
if torch.cuda.is_available():
|
| 163 |
gpu_count = torch.cuda.device_count()
|
| 164 |
+
logger.info(f"Found {gpu_count} CUDA devices")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
else:
|
| 166 |
+
logger.warning("No CUDA devices detected. Training will be slow on CPU!")
|
| 167 |
+
gpu_count = 0
|
|
|
|
| 168 |
|
| 169 |
# Set default dtype for better numerics
|
| 170 |
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
|
|
|
|
| 180 |
dtype = None
|
| 181 |
logger.info("Using default precision (CPU)")
|
| 182 |
|
| 183 |
+
# Check for flash attention as the last dependency check
|
| 184 |
+
use_flash_attention = config.get("use_flash_attention", True)
|
| 185 |
+
if use_flash_attention and not find_spec("flash_attn"):
|
| 186 |
+
logger.warning("flash-attn not found. Will continue without flash attention.")
|
| 187 |
+
logger.warning("To use flash attention, install with: pip install flash-attn --no-build-isolation")
|
| 188 |
+
use_flash_attention = False
|
| 189 |
+
|
| 190 |
# Load model with proper error handling for out-of-memory
|
| 191 |
try:
|
| 192 |
# Improved memory settings for multi-GPU setup
|
|
|
|
| 282 |
else:
|
| 283 |
logger.info(f"Dataset has all required fields: {required_fields}")
|
| 284 |
|
| 285 |
+
# Verify that column order matches our expectation
|
| 286 |
+
expected_order = ["prompt_number", "article_id", "conversations"]
|
| 287 |
+
actual_order = dataset.column_names
|
| 288 |
+
|
| 289 |
+
if actual_order == expected_order:
|
| 290 |
+
logger.info("Dataset column order matches expected order (prompt_number, article_id, conversations)")
|
| 291 |
+
else:
|
| 292 |
+
logger.warning(f"Dataset column order ({', '.join(actual_order)}) differs from expected order ({', '.join(expected_order)})")
|
| 293 |
+
logger.warning("This should not affect processing but is noted for debugging purposes")
|
| 294 |
+
|
| 295 |
# Log a few samples for verification
|
| 296 |
if len(dataset) > 0:
|
| 297 |
sample_indices = range(min(5, len(dataset)))
|
|
|
|
| 516 |
log_info(f"=== Training started at {time.strftime('%Y-%m-%d %H:%M:%S')} ===")
|
| 517 |
log_info(f"Model parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")
|
| 518 |
|
| 519 |
+
# Disable sequence verification
|
| 520 |
+
self.verify_sequence = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 521 |
|
| 522 |
log_info("=== Training is starting ===")
|
| 523 |
|
| 524 |
# Log important training parameters for visibility
|
| 525 |
total_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * NUM_GPUS
|
| 526 |
+
total_steps = int(len(dataset) / (args.per_device_train_batch_size * NUM_GPUS * args.gradient_accumulation_steps) * args.num_train_epochs)
|
| 527 |
+
log_info(f"Training plan: {len(dataset)} examples over {args.num_train_epochs} epochs ≈ {total_steps} steps")
|
| 528 |
log_info(f"Batch size: {args.per_device_train_batch_size} × {args.gradient_accumulation_steps} steps × {NUM_GPUS} GPUs = {total_batch_size} total")
|
| 529 |
log_info(f"Learning rate: {args.learning_rate}")
|
| 530 |
log_info(f"Epochs: {args.num_train_epochs}")
|
|
|
|
| 538 |
memory_info.append(f"GPU {i}: {allocated:.1f}MB (max: {max_mem:.1f}MB)")
|
| 539 |
|
| 540 |
log_info(f"Initial memory usage - {', '.join(memory_info)}")
|
| 541 |
+
|
| 542 |
def on_step_end(self, args, state, control, **kwargs):
|
| 543 |
# Log every 50 steps or every 5 minutes, whichever comes first
|
| 544 |
current_time = time.time()
|
| 545 |
|
| 546 |
+
# Sequence verification removed
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 547 |
|
| 548 |
# Log progress at regular intervals
|
| 549 |
if (state.global_step % 50 == 0) or (current_time - self.last_log_time > 300):
|
|
|
|
| 583 |
if not peft_available:
|
| 584 |
missing_packages.append("peft>=0.9.0")
|
| 585 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 586 |
# If critical packages are missing, exit with instructions
|
| 587 |
if missing_packages:
|
| 588 |
logger.error("Critical dependencies missing:")
|
|
|
|
| 591 |
logger.error("Please ensure the space has these packages in requirements.txt")
|
| 592 |
return False
|
| 593 |
|
| 594 |
+
# Optional packages - moved to the end
|
| 595 |
+
if find_spec("flash_attn"):
|
| 596 |
+
logger.info("flash-attn found. Flash attention will be used for faster training.")
|
| 597 |
+
else:
|
| 598 |
+
logger.warning("flash-attn not found. Training will work but may be slower.")
|
| 599 |
+
logger.warning("To use flash attention, install with: pip install flash-attn --no-build-isolation")
|
| 600 |
+
|
| 601 |
return True
|
| 602 |
|
| 603 |
def main():
|
|
|
|
| 809 |
|
| 810 |
# Log our approach clearly
|
| 811 |
log_info("Using SequentialSampler to guarantee dataset order is preserved based on prompt_number")
|
| 812 |
+
|
| 813 |
+
# Verify column order
|
| 814 |
+
expected_order = ["prompt_number", "article_id", "conversations"]
|
| 815 |
+
if hasattr(dataset, 'column_names'):
|
| 816 |
+
actual_order = dataset.column_names
|
| 817 |
+
if actual_order == expected_order:
|
| 818 |
+
log_info(f"Confirmed dataset columns are in expected order: {', '.join(expected_order)}")
|
| 819 |
+
else:
|
| 820 |
+
log_info(f"Note: Dataset columns ({', '.join(actual_order)}) are not in expected order ({', '.join(expected_order)})")
|
| 821 |
+
log_info("This is handled correctly by field-based access, but noting for clarity")
|
| 822 |
+
|
| 823 |
log_info("Dataset is pre-processed with prompt_number field indicating the correct sequence")
|
| 824 |
|
| 825 |
# Calculate batch size based on device availability
|
update_space.py
CHANGED
|
@@ -121,17 +121,25 @@ def update_requirements():
|
|
| 121 |
# Add new requirements
|
| 122 |
updated_requirements = existing_requirements.union(required_packages)
|
| 123 |
|
| 124 |
-
# Write updated requirements with torch first
|
| 125 |
with open(req_path, 'w') as f:
|
| 126 |
# Ensure torch is first
|
| 127 |
torch_req = next((req for req in updated_requirements if req.startswith("torch")), "torch>=2.0.0")
|
| 128 |
f.write(f"{torch_req}\n")
|
| 129 |
|
| 130 |
-
#
|
| 131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
f.write(f"{req}\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
|
| 134 |
-
logger.info("Updated requirements.txt with
|
| 135 |
|
| 136 |
def create_space(username, space_name):
|
| 137 |
"""Create or get a Hugging Face Space."""
|
|
|
|
| 121 |
# Add new requirements
|
| 122 |
updated_requirements = existing_requirements.union(required_packages)
|
| 123 |
|
| 124 |
+
# Write updated requirements with torch first and flash-attn last
|
| 125 |
with open(req_path, 'w') as f:
|
| 126 |
# Ensure torch is first
|
| 127 |
torch_req = next((req for req in updated_requirements if req.startswith("torch")), "torch>=2.0.0")
|
| 128 |
f.write(f"{torch_req}\n")
|
| 129 |
|
| 130 |
+
# Extract flash-attn to add it last
|
| 131 |
+
flash_attn_req = next((req for req in updated_requirements if req.startswith("flash-attn")), None)
|
| 132 |
+
|
| 133 |
+
# Write all other requirements (excluding torch and flash-attn)
|
| 134 |
+
for req in sorted(r for r in updated_requirements
|
| 135 |
+
if not r.startswith("torch") and not r.startswith("flash-attn")):
|
| 136 |
f.write(f"{req}\n")
|
| 137 |
+
|
| 138 |
+
# Add flash-attn as the very last package
|
| 139 |
+
if flash_attn_req:
|
| 140 |
+
f.write(f"{flash_attn_req}\n")
|
| 141 |
|
| 142 |
+
logger.info("Updated requirements.txt with torch listed first and flash-attn listed last")
|
| 143 |
|
| 144 |
def create_space(username, space_name):
|
| 145 |
"""Create or get a Hugging Face Space."""
|