#!/usr/bin/env python3
"""
Single-GPU LoRA Fine-Tuning Script for Humigence
=================================================

This script provides a robust single-GPU LoRA fine-tuning solution that mirrors
the behavior of the fixed reference script, generalized to all models supported
by Humigence.

Key Features:
- ✅ Single-GPU training (safe default)
- ✅ bf16 precision where supported
- ✅ Proper gradient flow (no loss=None errors)
- ✅ PEFT/LoRA integration with architecture-specific target modules
- ✅ Gradient checkpointing enabled
- ✅ Support for LLaMA, Mistral, Phi-2, and other models
- ✅ Comprehensive error handling and validation

Usage:
    # Via Humigence CLI
    humigence train-lora --model meta-llama/Meta-Llama-3-8B-Instruct --dataset wikitext-2-raw-v1 --output-dir ./out_lora

    # Direct execution
    python3 cli/train_lora_single.py --model meta-llama/Meta-Llama-3-8B-Instruct --dataset wikitext-2-raw-v1 --output-dir ./out_lora

    # With accelerate (recommended)
    accelerate launch --num_processes=1 cli/train_lora_single.py --model meta-llama/Meta-Llama-3-8B-Instruct --dataset wikitext-2-raw-v1 --output-dir ./out_lora

Tested Models:
- ✅ meta-llama/Meta-Llama-3-8B-Instruct
- ✅ mistralai/Mistral-7B-Instruct-v0.1
- ✅ microsoft/Phi-2
- ✅ TinyLlama/TinyLlama-1.1B-Chat-v1.0
- ✅ Qwen/Qwen1.5-0.5B

Validation:
After training, validate your adapters with:
    python3 -c "
    from transformers import AutoTokenizer, AutoModelForCausalLM
    from peft import PeftModel
    tokenizer = AutoTokenizer.from_pretrained('./out_lora')
    model = AutoModelForCausalLM.from_pretrained('meta-llama/Meta-Llama-3-8B-Instruct')
    model = PeftModel.from_pretrained(model, './out_lora')
    print('✅ Adapters loaded successfully!')
    "
"""

import argparse
import logging
import os
import sys
from pathlib import Path
from typing import Any, Dict, List
import json
import time

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TimeElapsedColumn

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
console = Console()


class LoRATrainer(Trainer):
    """
    Custom trainer that ensures proper gradient flow for LoRA models.
    This is the key fix that prevents loss=None errors.
    """
    
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """
        Compute loss ensuring gradients flow properly.
        This is the critical fix that ensures loss.requires_grad = True
        """
        # Get model outputs
        outputs = model(**inputs)
        
        # Check if model returned loss
        if hasattr(outputs, 'loss') and outputs.loss is not None:
            loss = outputs.loss
        else:
            # Manual loss computation if model didn't return loss
            logits = outputs.logits
            labels = inputs.get("labels")
            
            if labels is not None:
                # Shift logits and labels for causal LM
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
                
                # Compute cross-entropy loss
                loss_fct = torch.nn.CrossEntropyLoss()
                loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            else:
                # No labels and no model-provided loss: there is nothing to
                # optimize, so fail fast instead of training against a
                # meaningless dummy value
                raise ValueError("Batch contains no labels and the model returned no loss.")

        # Sanity check: a loss detached from the autograd graph would make
        # training a silent no-op. Only check when gradients are enabled,
        # since this method also runs under torch.no_grad() during evaluation.
        if torch.is_grad_enabled() and not loss.requires_grad:
            raise RuntimeError(
                "Loss does not require gradients; check that the LoRA adapters "
                "are trainable and that enable_input_require_grads() was called."
            )

        return (loss, outputs) if return_outputs else loss
    
    def evaluation_loop(self, dataloader, description, prediction_loss_only=None, ignore_keys=None, metric_key_prefix="eval"):
        """
        Thin wrapper around the default evaluation loop: explicitly put the
        wrapped model into eval mode, then delegate to the parent implementation.
        """
        model = self._wrap_model(self.model, training=False, dataloader=dataloader)
        model.eval()

        return super().evaluation_loop(dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)

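# How the causal-LM shift in compute_loss lines up predictions and targets
# (a worked example): for input_ids [t0, t1, t2, t3], shift_logits keeps the
# predictions at positions 0..2 and shift_labels keeps tokens t1..t3, so
# position i is scored against token i+1, the standard next-token objective.
# Label positions set to -100 are skipped by torch.nn.CrossEntropyLoss (its
# default ignore_index).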

def get_model_target_modules(model_name: str) -> List[str]:
    """
    Get the correct LoRA target modules for different model architectures.
    
    Args:
        model_name: Name or path of the model
        
    Returns:
        List of target module names for LoRA
    """
    model_name_lower = model_name.lower()
    
    # LLaMA family (including Llama-3, CodeLlama, etc.)
    if any(x in model_name_lower for x in ["llama", "codellama", "vicuna", "alpaca"]):
        return ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    
    # Mistral family
    elif any(x in model_name_lower for x in ["mistral", "mixtral"]):
        return ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    
    # Phi family (match "phi" only; matching "microsoft" would also misroute
    # other Microsoft models such as DialoGPT, which uses the GPT-2 layout)
    elif "phi" in model_name_lower:
        return ["q_proj", "k_proj", "v_proj", "dense"]
    
    # GPT family
    elif any(x in model_name_lower for x in ["gpt", "openai"]):
        return ["c_attn", "c_proj"]
    
    # Qwen family
    elif any(x in model_name_lower for x in ["qwen"]):
        return ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    
    # TinyLlama
    elif "tinyllama" in model_name_lower:
        return ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    
    # Default fallback for unknown models
    else:
        logger.warning(f"Unknown model architecture for {model_name}, using default target modules")
        return ["q_proj", "k_proj", "v_proj", "o_proj"]

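# For an architecture not covered above, candidate LoRA target names can be
# listed from the model itself (a quick sketch, not part of the training flow):
#
#     from transformers import AutoModelForCausalLM
#     m = AutoModelForCausalLM.from_pretrained(model_name)
#     print(sorted({name.split(".")[-1] for name, mod in m.named_modules()
#                   if isinstance(mod, torch.nn.Linear)}))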

def prepare_dataset(tokenizer, dataset_name: str = "wikitext", dataset_config: str = "wikitext-2-raw-v1", block_size: int = 512):
    """
    Prepare the dataset with proper tokenization and labeling.
    This mirrors the working dataset preparation from the fixed script.
    """
    logger.info(f"Loading dataset: {dataset_name}/{dataset_config}")
    
    # Load dataset - handle both Hugging Face datasets and local files
    if dataset_name == "jsonl":
        # Load local JSONL file
        from datasets import Dataset
        
        logger.info(f"Loading local JSONL file: {dataset_config}")
        data = []
        with open(dataset_config, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():
                    data.append(json.loads(line))
        
        # Convert to Hugging Face dataset format
        train_dataset = Dataset.from_list(data)
        # Create validation split (use first 10% for validation)
        val_size = max(1, len(train_dataset) // 10)
        val_dataset = train_dataset.select(range(val_size))
        train_dataset = train_dataset.select(range(val_size, len(train_dataset)))
        
        # Wrap in the expected format
        dataset = {"train": train_dataset, "validation": val_dataset}
    else:
        # Load Hugging Face dataset
        dataset = load_dataset(dataset_name, dataset_config)
    
    def tokenize_function(examples):
        """Tokenize the dataset."""
        # Handle different column structures
        if "text" in examples:
            text_column = "text"
        elif "instruction" in examples and "output" in examples:
            # For instruction-following datasets, combine instruction and output
            # Create combined text from instruction and output
            if "input" in examples and examples["input"]:
                combined_text = [f"Instruction: {inst}\nInput: {inp}\nOutput: {out}" 
                               for inst, inp, out in zip(examples["instruction"], examples["input"], examples["output"])]
            else:
                combined_text = [f"Instruction: {inst}\nOutput: {out}" 
                               for inst, out in zip(examples["instruction"], examples["output"])]
            examples["text"] = combined_text
            text_column = "text"
        else:
            # Try to find any text-like column
            text_columns = [col for col in examples.keys() if col in ["content", "body", "message", "prompt"]]
            if text_columns:
                text_column = text_columns[0]
            else:
                # Use the first column as text
                text_column = list(examples.keys())[0]
        
        # Ensure we have a list of strings, not nested lists
        texts = examples[text_column]
        if isinstance(texts[0], list):
            # Flatten the list of lists
            texts = [item for sublist in texts for item in sublist]
        
        return tokenizer(
            texts,
            truncation=True,
            padding=False,  # Don't pad here, let data collator handle it
            max_length=block_size,
            return_tensors=None,
        )
    
    # Tokenize dataset
    if dataset_name == "jsonl":
        # For local datasets, tokenize each split separately
        tokenized_train = dataset["train"].map(
            tokenize_function,
            batched=True,
            remove_columns=dataset["train"].column_names,
            desc="Tokenizing train dataset",
        )
        tokenized_validation = dataset["validation"].map(
            tokenize_function,
            batched=True,
            remove_columns=dataset["validation"].column_names,
            desc="Tokenizing validation dataset",
        )
        tokenized_dataset = {"train": tokenized_train, "validation": tokenized_validation}
    else:
        # For Hugging Face datasets, use the standard approach
        tokenized_dataset = dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=dataset["train"].column_names,
            desc="Tokenizing dataset",
        )
    
    # Remove any leftover text column after tokenization so the data collator
    # only sees token fields (remove_columns raises if the column is absent,
    # hence the membership checks)
    if dataset_name == "jsonl":
        for split in ("train", "validation"):
            if "text" in tokenized_dataset[split].column_names:
                tokenized_dataset[split] = tokenized_dataset[split].remove_columns(["text"])
    else:
        # For Hugging Face datasets, remove the text column if it exists
        if "text" in tokenized_dataset["train"].column_names:
            tokenized_dataset = tokenized_dataset.remove_columns(["text"])
    
    def group_texts(examples):
        """Group texts into fixed-length blocks."""
        # For local datasets, we need to handle the text differently
        if dataset_name == "jsonl":
            # Each example should already be tokenized
            # We need to ensure all sequences are the same length (block_size)
            result = {}
            
            # Process each sequence to ensure consistent length
            for k in examples.keys():
                if k in ["input_ids", "attention_mask"]:
                    # Pad or truncate to block_size
                    processed_sequences = []
                    for seq in examples[k]:
                        if len(seq) > block_size:
                            # Truncate if too long
                            processed_sequences.append(seq[:block_size])
                        elif len(seq) < block_size:
                            # Pad if too short
                            if k == "input_ids":
                                # Pad with pad_token_id
                                pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
                                padded_seq = seq + [pad_token_id] * (block_size - len(seq))
                            else:  # attention_mask
                                # Pad with 0s
                                padded_seq = seq + [0] * (block_size - len(seq))
                            processed_sequences.append(padded_seq)
                        else:
                            # Already correct length
                            processed_sequences.append(seq)
                    result[k] = processed_sequences
                else:
                    # Keep other columns as is
                    result[k] = examples[k]
            
            # Create labels for causal LM, masking padded positions with -100
            # so the cross-entropy loss ignores them
            result["labels"] = [
                [tok if mask == 1 else -100 for tok, mask in zip(ids, attn)]
                for ids, attn in zip(result["input_ids"], result["attention_mask"])
            ]
            return result
        else:
            # For Hugging Face datasets, use the original logic
            # Concatenate all texts - handle both lists and strings
            concatenated_examples = {}
            for k in examples.keys():
                if isinstance(examples[k][0], list):
                    # If it's already a list of lists, concatenate
                    concatenated_examples[k] = sum(examples[k], [])
                else:
                    # If it's a list of strings, just use as is
                    concatenated_examples[k] = examples[k]
            
            # Create blocks
            total_length = len(concatenated_examples[list(examples.keys())[0]])
            total_length = (total_length // block_size) * block_size
            
            # Split by chunks of max_len
            result = {
                k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
                for k, t in concatenated_examples.items()
            }
            
            # Create labels (same as input_ids for causal LM)
            result["labels"] = result["input_ids"].copy()
            
            return result
    
    # Group texts into blocks
    if dataset_name == "jsonl":
        # For local datasets, group each split separately
        lm_train = tokenized_dataset["train"].map(
            group_texts,
            batched=True,
            desc="Grouping train texts into blocks",
        )
        lm_validation = tokenized_dataset["validation"].map(
            group_texts,
            batched=True,
            desc="Grouping validation texts into blocks",
        )
        lm_dataset = {"train": lm_train, "validation": lm_validation}
    else:
        # For Hugging Face datasets, use the standard approach
        lm_dataset = tokenized_dataset.map(
            group_texts,
            batched=True,
            desc="Grouping texts into blocks",
        )
    
    return lm_dataset

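# Example of the grouping arithmetic above: with block_size=512 and 1,300
# concatenated tokens, total_length is floored to (1300 // 512) * 512 = 1024,
# yielding two 512-token blocks; the trailing 276 tokens are dropped.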

def display_training_summary(metrics: dict, model_name: str, dataset_name: str, dataset_config: str, output_dir: str):
    """
    Display a beautiful, comprehensive training summary.
    """
    from rich.table import Table
    from rich.panel import Panel
    from rich.text import Text
    from rich.columns import Columns
    
    # Create main summary table
    summary_table = Table(title="🎉 LoRA Training Complete!", show_header=True, header_style="bold magenta")
    summary_table.add_column("Metric", style="cyan", no_wrap=True)
    summary_table.add_column("Value", style="green")
    
    # Add key metrics
    summary_table.add_row("Model", model_name)
    summary_table.add_row("Dataset", f"{dataset_name}/{dataset_config}")
    summary_table.add_row("Output Directory", output_dir)
    summary_table.add_row("", "")  # Empty row for spacing
    
    # Training metrics
    train_loss = metrics.get("train_loss", "N/A")
    eval_loss = metrics.get("eval_loss", "N/A")
    total_steps = metrics.get("total_steps", "N/A")
    epochs = metrics.get("epoch", "N/A")
    
    summary_table.add_row("Final Train Loss", f"{train_loss:.4f}" if isinstance(train_loss, (int, float)) else str(train_loss))
    summary_table.add_row("Final Eval Loss", f"{eval_loss:.4f}" if isinstance(eval_loss, (int, float)) else str(eval_loss))
    summary_table.add_row("Total Steps", str(total_steps))
    summary_table.add_row("Epochs", f"{epochs:.2f}" if isinstance(epochs, (int, float)) else str(epochs))
    summary_table.add_row("", "")  # Empty row for spacing
    
    # Performance metrics
    runtime = metrics.get("train_runtime", "N/A")
    samples_per_sec = metrics.get("train_samples_per_second", "N/A")
    steps_per_sec = metrics.get("train_steps_per_second", "N/A")
    
    if isinstance(runtime, (int, float)):
        hours = int(runtime // 3600)
        minutes = int((runtime % 3600) // 60)
        seconds = int(runtime % 60)
        runtime_str = f"{hours:02d}:{minutes:02d}:{seconds:02d}"
    else:
        runtime_str = str(runtime)
    
    summary_table.add_row("Training Time", runtime_str)
    summary_table.add_row("Samples/sec", f"{samples_per_sec:.2f}" if isinstance(samples_per_sec, (int, float)) else str(samples_per_sec))
    summary_table.add_row("Steps/sec", f"{steps_per_sec:.3f}" if isinstance(steps_per_sec, (int, float)) else str(steps_per_sec))
    
    # Create performance panel (guard each format spec: these metrics may be
    # the string "N/A" when unavailable, which would break f-string formatting)
    performance_text = Text()
    performance_text.append("🚀 Performance Summary\n", style="bold blue")
    performance_text.append(f"• Training completed in {runtime_str}\n", style="white")
    if isinstance(samples_per_sec, (int, float)):
        performance_text.append(f"• Processed {samples_per_sec:.1f} samples/second\n", style="white")
    if isinstance(steps_per_sec, (int, float)):
        performance_text.append(f"• Achieved {steps_per_sec:.3f} steps/second\n", style="white")
    if isinstance(train_loss, (int, float)):
        performance_text.append(f"• Final train loss: {train_loss:.4f}\n", style="white")

    # Add evaluation metrics if available
    if isinstance(eval_loss, (int, float)):
        eval_runtime = metrics.get("eval_runtime", "N/A")
        eval_samples_per_sec = metrics.get("eval_samples_per_second", "N/A")

        performance_text.append(f"• Final eval loss: {eval_loss:.4f}\n", style="white")
        if isinstance(eval_runtime, (int, float)):
            performance_text.append(f"• Eval time: {eval_runtime:.2f}s\n", style="white")
        if isinstance(eval_samples_per_sec, (int, float)):
            performance_text.append(f"• Eval speed: {eval_samples_per_sec:.1f} samples/sec\n", style="white")
    
    performance_panel = Panel(performance_text, title="📊 Performance", border_style="blue")
    
    # Create next steps panel
    next_steps_text = Text()
    next_steps_text.append("🎯 Next Steps\n", style="bold green")
    next_steps_text.append("• Your LoRA adapters are saved in the output directory\n", style="white")
    next_steps_text.append("• Use the model for inference or further fine-tuning\n", style="white")
    next_steps_text.append("• Check training_summary.json for detailed metrics\n", style="white")
    next_steps_text.append("• Consider running evaluation on a held-out test set\n", style="white")

    next_steps_panel = Panel(next_steps_text, title="🔮 Next Steps", border_style="green")
    
    # Display everything
    console.print("\n")
    console.print(summary_table)
    console.print("\n")
    
    # Create columns for panels
    columns = Columns([performance_panel, next_steps_panel], equal=True, expand=True)
    console.print(columns)
    
    # Final success message
    console.print("\n[bold green]๐ŸŽ‰ LoRA Training Successfully Completed! ๐ŸŽ‰[/bold green]")
    console.print(f"[blue]๐Ÿ“ Model saved to: [bold]{output_dir}[/bold][/blue]")
    console.print(f"[blue]๐Ÿ“Š Training metrics: [bold]{metrics}[/bold][/blue]")
    console.print("\n[bold cyan]Thank you for using Humigence! ๐Ÿš€[/bold cyan]\n")


def validate_model_and_dataset(model_name: str, dataset_name: str, dataset_config: str) -> bool:
    """
    Validate that the model and dataset are accessible.

    Args:
        model_name: Name or path of the model
        dataset_name: Name of the dataset ("jsonl" for a local file)
        dataset_config: Dataset configuration, or the file path when
            dataset_name is "jsonl"

    Returns:
        True if validation passes, False otherwise
    """
    try:
        console.print(f"[blue]🔍 Validating model: {model_name}[/blue]")

        # Test tokenizer loading
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Test dataset loading
        console.print(f"[blue]🔍 Validating dataset: {dataset_name}/{dataset_config}[/blue]")

        if dataset_name == "jsonl":
            # For local JSONL files, dataset_config holds the file path
            if not os.path.exists(dataset_config):
                console.print(f"[red]❌ Local dataset file not found: {dataset_config}[/red]")
                return False
        else:
            # For Hugging Face datasets, load the requested dataset/config
            load_dataset(dataset_name, dataset_config)

        console.print("[green]✅ Model and dataset validation passed![/green]")
        return True

    except Exception as e:
        console.print(f"[red]❌ Validation failed: {e}[/red]")
        return False

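# Expected JSONL record shapes (illustrative; these mirror the columns handled
# by tokenize_function in prepare_dataset):
#
#     {"text": "plain training text ..."}
#     {"instruction": "Summarize ...", "input": "optional context", "output": "..."}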

def train_lora_single_gpu(
    model_name: str,
    dataset_name: str = "wikitext",
    dataset_config: str = "wikitext-2-raw-v1",
    output_dir: str = "./out_lora",
    max_steps: int = 1000,
    batch_size: int = 4,
    grad_accum: int = 4,
    learning_rate: float = 2e-4,
    block_size: int = 512,
    lora_r: int = 16,
    lora_alpha: int = 32,
    lora_dropout: float = 0.05,
    warmup_steps: int = 100,
    logging_steps: int = 10,
    save_steps: int = 200,
    eval_steps: int = 200,
    save_total_limit: int = 2,
    **kwargs
) -> Dict[str, Any]:
    """
    Main training function for single-GPU LoRA fine-tuning.
    
    Args:
        model_name: Name or path of the model to fine-tune
        dataset_name: Name of the dataset (e.g., "wikitext")
        dataset_config: Dataset configuration (e.g., "wikitext-2-raw-v1")
        output_dir: Directory to save the trained model
        max_steps: Maximum number of training steps
        batch_size: Per-device batch size
        grad_accum: Gradient accumulation steps
        learning_rate: Learning rate
        block_size: Block size for text grouping
        lora_r: LoRA rank
        lora_alpha: LoRA alpha
        lora_dropout: LoRA dropout
        warmup_steps: Number of warmup steps
        logging_steps: Logging frequency
        save_steps: Save frequency
        eval_steps: Evaluation frequency
        save_total_limit: Maximum number of checkpoints to keep
        
    Returns:
        Dictionary with training results
    """
    
    # Validate inputs
    if not validate_model_and_dataset(model_name, dataset_name, dataset_config):
        return {"status": "error", "error": "Model or dataset validation failed"}
    
    try:
        console.print(f"[bold green]๐Ÿš€ Starting LoRA fine-tuning[/bold green]")
        console.print(f"[blue]Model: {model_name}[/blue]")
        console.print(f"[blue]Dataset: {dataset_name}/{dataset_config}[/blue]")
        console.print(f"[blue]Output: {output_dir}[/blue]")
        
        # Load tokenizer
        console.print("[blue]๐Ÿ“ Loading tokenizer...[/blue]")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        
        # Load model
        console.print("[blue]๐Ÿค– Loading model...[/blue]")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        
        # Get model-specific target modules
        target_modules = get_model_target_modules(model_name)
        console.print(f"[blue]๐ŸŽฏ Using target modules: {target_modules}[/blue]")
        
        # Configure LoRA
        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            target_modules=target_modules,
            bias="none",
            task_type="CAUSAL_LM",
        )
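        # LoRA scales each adapter update by lora_alpha / lora_r (2.0 with the
        # defaults 32/16), so raising r without adjusting alpha also changes
        # the effective update magnitude.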
        
        # Apply LoRA
        console.print("[blue]๐Ÿ”ง Applying LoRA configuration...[/blue]")
        model = get_peft_model(model, lora_config)
        
        # CRITICAL: Enable input gradients for PEFT models
        # This is essential for gradient checkpointing to work with LoRA
        model.enable_input_require_grads()
        
        # Print trainable parameters
        model.print_trainable_parameters()
        
        # Prepare dataset
        console.print("[blue]๐Ÿ“Š Preparing dataset...[/blue]")
        dataset = prepare_dataset(tokenizer, dataset_name, dataset_config, block_size)
        
        # Data collator - this is crucial for proper batching
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=False,  # We're doing causal LM, not masked LM
            pad_to_multiple_of=8,  # For efficiency
        )
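        # With mlm=False the collator pads each batch dynamically (to a
        # multiple of 8) and rebuilds labels from input_ids, setting
        # pad_token_id positions to -100 so padding is ignored by the loss.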
        
        # Training arguments
        training_args = TrainingArguments(
            output_dir=output_dir,
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=grad_accum,
            max_steps=max_steps,
            learning_rate=learning_rate,
            warmup_steps=warmup_steps,
            logging_steps=logging_steps,
            save_steps=save_steps,
            eval_steps=eval_steps,
            eval_strategy="steps",
            bf16=True,
            gradient_checkpointing=True,
            # Use non-reentrant checkpointing to avoid gradient issues
            gradient_checkpointing_kwargs={"use_reentrant": False},
            save_total_limit=save_total_limit,
            report_to="none",
            # Memory optimizations
            dataloader_drop_last=True,
            remove_unused_columns=False,
        )
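        # Effective batch size per optimizer step is
        # per_device_train_batch_size * gradient_accumulation_steps
        # (4 * 4 = 16 sequences of block_size tokens with the defaults).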
        
        # Initialize trainer with custom trainer class
        trainer = LoRATrainer(
            model=model,
            args=training_args,
            train_dataset=dataset["train"],
            eval_dataset=dataset["validation"],
            data_collator=data_collator,
            tokenizer=tokenizer,
        )
        
        # Test the setup before training
        console.print("[blue]๐Ÿงช Testing model setup...[/blue]")
        test_batch = next(iter(trainer.get_train_dataloader()))
        console.print(f"[blue]Test batch keys: {list(test_batch.keys())}[/blue]")
        console.print(f"[blue]Test batch shapes: {[(k, v.shape) for k, v in test_batch.items()]}[/blue]")
        
        # Test forward pass (the dataloader yields CPU tensors, so move the
        # batch to the model's device first)
        model.eval()
        device = next(model.parameters()).device
        test_batch = {k: v.to(device) for k, v in test_batch.items()}
        with torch.no_grad():
            test_outputs = model(**test_batch)
            if hasattr(test_outputs, 'loss') and test_outputs.loss is not None:
                console.print(f"[green]✅ Test loss: {test_outputs.loss.item():.4f}[/green]")
            else:
                console.print("[yellow]⚠️ No loss in test outputs![/yellow]")
        model.train()
        
        # Start training
        console.print("[bold green]๐Ÿƒ Starting training...[/bold green]")
        start_time = time.time()
        
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
            TimeElapsedColumn(),
            console=console,
        ) as progress:
            task = progress.add_task("Training...", total=max_steps)

            # trainer.train() blocks until it finishes; step-level progress is
            # reported by the Trainer's own logging, so this bar shows elapsed
            # time and is marked complete once training returns
            training_result = trainer.train()
            progress.update(task, completed=max_steps)
        
        # Save the final model
        console.print("[blue]๐Ÿ’พ Saving final model...[/blue]")
        trainer.save_model()
        tokenizer.save_pretrained(output_dir)
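        # With a PEFT-wrapped model, save_model() writes only the LoRA adapter
        # (adapter_config.json + adapter weights), not the base model; reload
        # it via PeftModel.from_pretrained as shown in the module docstring.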
        
        # Get final evaluation metrics
        console.print("[blue]๐Ÿ“Š Running final evaluation...[/blue]")
        eval_metrics = trainer.evaluate()
        
        # Calculate final metrics
        end_time = time.time()
        training_time = end_time - start_time
        
        final_metrics = {
            "train_loss": training_result.training_loss,
            "train_runtime": training_result.metrics.get("train_runtime", training_time),
            "train_samples_per_second": training_result.metrics.get("train_samples_per_second", 0),
            "train_steps_per_second": training_result.metrics.get("train_steps_per_second", 0),
            "total_steps": training_result.global_step,
            "epochs": training_result.metrics.get("epoch", 0),
            # Add evaluation metrics from final evaluation
            "eval_loss": eval_metrics.get("eval_loss", "N/A"),
            "eval_runtime": eval_metrics.get("eval_runtime", "N/A"),
            "eval_samples_per_second": eval_metrics.get("eval_samples_per_second", "N/A"),
            "eval_steps_per_second": eval_metrics.get("eval_steps_per_second", "N/A"),
            "model_name": model_name,
            "dataset": f"{dataset_name}/{dataset_config}",
            "output_dir": output_dir,
        }
        
        # Save training summary
        summary_path = Path(output_dir) / "training_summary.json"
        with open(summary_path, "w") as f:
            json.dump(final_metrics, f, indent=2)
        
        # Display beautiful training summary
        display_training_summary(final_metrics, model_name, dataset_name, dataset_config, output_dir)
        
        return {
            "status": "success",
            "metrics": final_metrics,
            "output_dir": output_dir,
            "model_path": output_dir,
        }
        
    except Exception as e:
        console.print(f"[bold red]โŒ Training failed: {str(e)}[/bold red]")
        logger.exception("Training failed with exception:")
        return {
            "status": "error",
            "error": str(e),
            "output_dir": output_dir,
        }

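# Programmatic use (a minimal sketch; arguments as defined above):
#
#     result = train_lora_single_gpu(
#         model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
#         dataset_name="wikitext",
#         dataset_config="wikitext-2-raw-v1",
#         output_dir="./out_lora_smoke",
#         max_steps=50,
#     )
#     assert result["status"] == "success", result.get("error")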

def main():
    """Main entry point for the script."""
    parser = argparse.ArgumentParser(
        description="Single-GPU LoRA Fine-Tuning for Humigence",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__
    )
    
    # Required arguments
    parser.add_argument("--model", type=str, required=True, help="Model name or path")
    parser.add_argument("--output-dir", type=str, required=True, help="Output directory")
    
    # Dataset arguments
    parser.add_argument("--dataset", type=str, default="wikitext", help="Dataset name")
    parser.add_argument("--dataset-config", type=str, default="wikitext-2-raw-v1", help="Dataset configuration")
    
    # Training arguments
    parser.add_argument("--max-steps", type=int, default=1000, help="Maximum training steps")
    parser.add_argument("--batch-size", type=int, default=4, help="Per-device batch size")
    parser.add_argument("--grad-accum", type=int, default=4, help="Gradient accumulation steps")
    parser.add_argument("--learning-rate", type=float, default=2e-4, help="Learning rate")
    parser.add_argument("--block-size", type=int, default=512, help="Block size for text grouping")
    
    # LoRA arguments
    parser.add_argument("--lora-r", type=int, default=16, help="LoRA rank")
    parser.add_argument("--lora-alpha", type=int, default=32, help="LoRA alpha")
    parser.add_argument("--lora-dropout", type=float, default=0.05, help="LoRA dropout")
    
    # Other arguments
    parser.add_argument("--warmup-steps", type=int, default=100, help="Number of warmup steps")
    parser.add_argument("--logging-steps", type=int, default=10, help="Logging frequency")
    parser.add_argument("--save-steps", type=int, default=200, help="Save frequency")
    parser.add_argument("--eval-steps", type=int, default=200, help="Evaluation frequency")
    parser.add_argument("--save-total-limit", type=int, default=2, help="Maximum number of checkpoints to keep")
    
    args = parser.parse_args()
    
    # Create output directory
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    
    # Run training
    result = train_lora_single_gpu(
        model_name=args.model,
        dataset_name=args.dataset,
        dataset_config=args.dataset_config,
        output_dir=args.output_dir,
        max_steps=args.max_steps,
        batch_size=args.batch_size,
        grad_accum=args.grad_accum,
        learning_rate=args.learning_rate,
        block_size=args.block_size,
        lora_r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        warmup_steps=args.warmup_steps,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,
        save_total_limit=args.save_total_limit,
    )
    
    # Exit with appropriate code
    if result["status"] == "success":
        console.print("[bold green]๐ŸŽ‰ Training completed successfully![/bold green]")
        sys.exit(0)
    else:
        console.print(f"[bold red]๐Ÿ’ฅ Training failed: {result.get('error', 'Unknown error')}[/bold red]")
        sys.exit(1)


if __name__ == "__main__":
    main()