File size: 3,820 Bytes
5e0532d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os

# --- IMPORT PRIORITY: when Unsloth is used it has to be imported before
# transformers/trl/peft (see the imports that follow). torch is imported first
# only because it is needed to probe for CUDA. ---
# torch is a hard requirement (train_ora() calls torch.cuda.is_available()), so
# it is imported outside the guard: a missing torch should fail loudly here
# instead of being mistaken for "Unsloth not installed".
import torch

# True only when a CUDA device is present AND Unsloth imported cleanly.
HAS_UNSLOTH = False
if torch.cuda.is_available():
    try:
        # Force unsloth import first on GPU
        from unsloth import FastLanguageModel
    except ImportError as exc:
        # Unsloth is optional; report why it is being skipped instead of
        # silently switching to the standard PEFT path.
        print(f"ORA Trainer: Unsloth unavailable ({exc}); falling back to standard PEFT.")
    else:
        HAS_UNSLOTH = True
        print("ORA Trainer: Unsloth imported successfully.")

# Now safe to import others
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
)
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig

# --- Training configuration ---
# Directory that receives the trained adapter and the tokenizer files.
OUTPUT_DIR = "important/finetuning/models/ora_adapter"
# Sequence-length cap handed to the model loader and to the SFT configuration.
MAX_SEQ_LENGTH = 2048
# Base model checkpoint loaded by the standard (non-Unsloth) PEFT path.
MODEL_NAME = "unsloth/Llama-3.2-1B-Instruct"

def train_ora(
    max_steps: int = 5,
    data_path: str = "important/curated_data/final_ora_dataset.jsonl",
) -> None:
    """Fine-tune a LoRA adapter on the ORA dataset and save it to ``OUTPUT_DIR``.

    Two loading paths exist:

    * CUDA available and Unsloth imported: a 4-bit Unsloth checkpoint is
      loaded and wrapped with ``FastLanguageModel.get_peft_model``.
    * Otherwise: ``MODEL_NAME`` is loaded in float32 on the CPU and wrapped
      with a standard PEFT ``LoraConfig``.

    Both paths apply rank-16 LoRA to the attention projections and then train
    with TRL's ``SFTTrainer`` on the ``"text"`` field of the JSONL dataset.

    Args:
        max_steps: Number of optimizer steps to train for.
        data_path: Path to the JSONL training file. Defaults to the
            consolidated ORA dataset location.

    Returns:
        None. If ``data_path`` does not exist, an error is printed and the
        function returns before any model is loaded.
    """
    has_cuda = torch.cuda.is_available()
    print(f"ORA Trainer: CUDA Detected = {has_cuda}")

    # Fail fast: verify the dataset before the (slow, memory-hungry) model
    # load so a missing file does not cost a full model download first.
    if not os.path.exists(data_path):
        print(f"Error: Dataset {data_path} not found. Run consolidation first!")
        return

    if has_cuda and HAS_UNSLOTH:
        # --- MODE: GPU (Unsloth/QLoRA) ---
        print("ORA Trainer: Using GPU + Unsloth (Standard for Google Colab)")

        # Pre-quantized 4-bit checkpoint; this path does not use MODEL_NAME.
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name="unsloth/Llama-3.2-1B-Instruct-bnb-4bit",
            max_seq_length=MAX_SEQ_LENGTH,
            load_in_4bit=True,
        )

        model = FastLanguageModel.get_peft_model(
            model,
            r=16,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
            lora_alpha=32,
            lora_dropout=0,
            bias="none",
            use_gradient_checkpointing="unsloth",
            random_state=3407,
        )
    else:
        # --- MODE: CPU (Standard PEFT) ---
        # NOTE(review): this branch also runs when CUDA is present but Unsloth
        # is missing. The training arguments below are still keyed on
        # has_cuda, so the Trainer may move the model to the GPU despite the
        # "CPU" message -- confirm which behaviour is intended.
        print("ORA Trainer: Using CPU + Standard PEFT (Local Hardware Mode)")
        # Imported lazily so peft is only required when this path is taken.
        from peft import LoraConfig, get_peft_model, TaskType

        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        # Only fall back to EOS when the checkpoint ships no pad token, so a
        # dedicated pad token is not clobbered (the Unsloth path leaves the
        # tokenizer untouched as well).
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float32,
            device_map="cpu",
            low_cpu_mem_usage=True,
        )

        lora_config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type=TaskType.CAUSAL_LM,
        )
        model = get_peft_model(model, lora_config)

    # Common Dataset Loading
    dataset = load_dataset("json", data_files=data_path, split="train")

    # SFTConfig (replaces TrainingArguments + extra SFT args)
    training_args = SFTConfig(
        output_dir=OUTPUT_DIR,
        per_device_train_batch_size=2 if has_cuda else 1,
        gradient_accumulation_steps=4,
        learning_rate=2e-4,
        max_steps=max_steps,
        logging_steps=1,
        save_strategy="no",  # the adapter is saved explicitly after training
        use_cpu=not has_cuda,
        report_to="none",
        max_length=MAX_SEQ_LENGTH,
        dataset_text_field="text",
        dataset_num_proc=2,  # Limit processes to avoid pickling errors
    )

    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        processing_class=tokenizer,
        args=training_args,
    )

    print(f"ORA Trainer: Starting training ({max_steps} steps)...")
    trainer.train()

    print(f"ORA Trainer: Saving adapter to {OUTPUT_DIR}...")
    model.save_pretrained(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)
    print("ORA Trainer: Training complete.")

if __name__ == "__main__":
    # Script entry point: run a longer job than the function's 5-step default.
    # The run can be interrupted manually if it takes too long.
    script_steps = 100
    train_ora(max_steps=script_steps)