ceperaltab committed on
Commit
6c0db32
·
verified ·
1 Parent(s): e9f43ee

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. requirements_train.txt +8 -0
  2. run.py +18 -0
  3. run.sh +3 -0
  4. train.py +106 -0
requirements_train.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ datasets
4
+ peft
5
+ bitsandbytes
6
+ trl
7
+ accelerate
8
+ scipy
run.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess
2
+ import sys
3
+ import os
4
+
5
def install_dependencies():
    """Install everything listed in requirements_train.txt into this interpreter.

    Propagates subprocess.CalledProcessError when pip exits non-zero.
    """
    print("Installing dependencies...")
    # -v keeps pip verbose so the user can watch install progress.
    pip_cmd = [sys.executable, "-m", "pip", "install", "-v", "-r", "requirements_train.txt"]
    subprocess.check_call(pip_cmd)
9
+
10
def main():
    """Bootstrap entry point: install requirements, then hand off to training."""
    install_dependencies()
    print("Dependencies installed. Starting training...")
    # train.py pulls in torch/transformers/etc., so it must not be imported
    # until install_dependencies() has run — keep the import local.
    import train
    train.main()


if __name__ == "__main__":
    main()
run.sh ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ #!/bin/bash
2
+ pip install -r requirements_train.txt
3
+ python train.py
train.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from datasets import load_dataset
4
+ from transformers import (
5
+ AutoModelForCausalLM,
6
+ AutoTokenizer,
7
+ BitsAndBytesConfig,
8
+ TrainingArguments,
9
+ )
10
+ from peft import LoraConfig
11
+ from trl import SFTTrainer
12
+
13
+ # --- CONFIGURATION ---
14
+ # Base model: Using a quantized Llama 3 or Mistral is recommended for consumer GPUs.
15
+ # Ensure you have access to the model on Hugging Face (might need login).
16
+ MODEL_NAME = "meta-llama/Meta-Llama-3-8B"
17
+ DATASET_NAME = "ceperaltab/elixir-golden-dataset"
18
+ OUTPUT_DIR = "elixir-model-adapter"
19
+
20
def main():
    """Fine-tune MODEL_NAME on DATASET_NAME with 4-bit QLoRA and save the adapter.

    Pipeline: load dataset -> load quantized base model + tokenizer ->
    assemble SFTTrainer (LoRA config, chat formatting, training args) ->
    train -> save the adapter to OUTPUT_DIR (and push to the Hub).
    """
    print(f"Loading dataset from {DATASET_NAME}...")
    # 1. Load Dataset directly from the HF Hub.
    try:
        dataset = load_dataset(DATASET_NAME, split="train")
    except Exception as e:
        # Deliberate best-effort: report and bail out rather than dump a
        # traceback (e.g. no network access, or a gated/private dataset).
        print(f"Error loading dataset: {e}")
        return

    model, tokenizer = _load_model_and_tokenizer()
    trainer = _build_trainer(model, tokenizer, dataset)

    print("Starting training...")
    trainer.train()

    print(f"Saving model to {OUTPUT_DIR}...")
    trainer.save_model(OUTPUT_DIR)
    print("Done!")


def _load_model_and_tokenizer():
    """Load the 4-bit quantized base model and its tokenizer."""
    # Quantization config: NF4 4-bit weights with fp16 compute keeps an 8B
    # model within consumer-GPU memory.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )

    print(f"Loading base model: {MODEL_NAME}...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    # Llama-style tokenizers ship without a pad token; reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"  # Critical for fp16 training
    return model, tokenizer


def _build_trainer(model, tokenizer, dataset):
    """Assemble the SFTTrainer: LoRA config, chat formatting, training args."""
    # LoRA (parameter-efficient fine-tuning) applied to every attention and
    # MLP projection matrix.
    peft_config = LoraConfig(
        lora_alpha=16,
        lora_dropout=0.1,
        r=64,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
    )

    def formatting_prompts_func(examples):
        """Render each chat in examples['messages'] with the model's chat template.

        Returns a list of plain strings; SFTTrainer tokenizes them later.
        """
        output_texts = []
        for messages in examples['messages']:
            # Produces the model's native prompt markup, e.g.
            # <|begin_of_text|><|start_header_id|>user... for Llama 3.
            text = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=False
            )
            output_texts.append(text)
        return output_texts

    print("Starting SFTTrainer setup...")
    # NOTE(review): `tokenizer=` and `max_seq_length=` were removed from
    # SFTTrainer in newer trl releases (replaced by `processing_class` and
    # SFTConfig.max_seq_length); requirements_train.txt does not pin trl —
    # confirm the installed version still accepts these arguments.
    return SFTTrainer(
        model=model,
        train_dataset=dataset,
        peft_config=peft_config,
        formatting_func=formatting_prompts_func,
        max_seq_length=2048,
        tokenizer=tokenizer,
        args=TrainingArguments(
            output_dir=OUTPUT_DIR,
            per_device_train_batch_size=2,
            gradient_accumulation_steps=4,  # Simulate larger batch size
            learning_rate=2e-4,
            logging_steps=10,
            num_train_epochs=1,
            optim="paged_adamw_32bit",
            fp16=True,
            group_by_length=True,
            save_strategy="epoch",
            report_to="none",  # Change to "wandb" if desired
            push_to_hub=True,
            hub_model_id=f"ceperaltab/{OUTPUT_DIR}",  # Pushes to your namespace
        ),
    )


if __name__ == "__main__":
    main()