AndreasThinks committed on
Commit
63a20c4
·
verified ·
1 Parent(s): a037b2e

Upload train_mistral.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_mistral.py +37 -4
train_mistral.py CHANGED
@@ -1,5 +1,5 @@
1
  # /// script
2
- # dependencies = ["trl>=0.12.0", "peft>=0.7.0", "trackio", "torch>=2.0.0", "transformers>=4.40.0", "accelerate>=0.20.0"]
3
  # ///
4
 
5
  """Fine-tune Mistral-7B-Instruct-v0.3 on NATO doctrine dataset."""
@@ -7,10 +7,41 @@
7
  from datasets import load_dataset
8
  from peft import LoraConfig
9
  from trl import SFTTrainer, SFTConfig
 
 
10
  import trackio
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  # Load dataset from HF Hub
13
- print("Loading NATO doctrine dataset...")
14
  dataset = load_dataset("AndreasThinks/nato-doctrine-sft", split="train")
15
  dataset_test = load_dataset("AndreasThinks/nato-doctrine-sft", split="test")
16
 
@@ -74,10 +105,11 @@ training_args = SFTConfig(
74
  seed=42,
75
  )
76
 
77
- # Initialize trainer
78
  print("\n✓ Initializing SFT trainer...")
79
  trainer = SFTTrainer(
80
- model="mistralai/Mistral-7B-Instruct-v0.3",
 
81
  train_dataset=dataset,
82
  eval_dataset=dataset_test,
83
  peft_config=peft_config,
@@ -101,4 +133,5 @@ trainer.push_to_hub()
101
 
102
  print("\n✅ Fine-tuning complete!")
103
  print(f" Model: https://huggingface.co/AndreasThinks/mistral-7b-nato-doctrine")
 
104
  print(f" Trackio: Check your dashboard for metrics")
 
1
  # /// script
2
+ # dependencies = ["trl>=0.12.0", "peft>=0.7.0", "trackio", "torch>=2.0.0", "transformers>=4.40.0", "accelerate>=0.20.0", "bitsandbytes>=0.41.0"]
3
  # ///
4
 
5
  """Fine-tune Mistral-7B-Instruct-v0.3 on NATO doctrine dataset."""
 
7
  from datasets import load_dataset
8
  from peft import LoraConfig
9
  from trl import SFTTrainer, SFTConfig
10
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
11
+ import torch
12
  import trackio
13
 
14
+ # Model ID
15
+ model_id = "mistralai/Mistral-7B-Instruct-v0.3"
16
+
17
+ # Load tokenizer
18
+ print("Loading tokenizer...")
19
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
20
+ tokenizer.pad_token = tokenizer.eos_token
21
+ tokenizer.padding_side = "right"
22
+
23
+ # Load model with 4-bit quantization
24
+ print("Loading model with 4-bit quantization...")
25
+ bnb_config = BitsAndBytesConfig(
26
+ load_in_4bit=True,
27
+ bnb_4bit_use_double_quant=True,
28
+ bnb_4bit_quant_type="nf4",
29
+ bnb_4bit_compute_dtype=torch.bfloat16
30
+ )
31
+
32
+ model = AutoModelForCausalLM.from_pretrained(
33
+ model_id,
34
+ quantization_config=bnb_config,
35
+ device_map="auto",
36
+ trust_remote_code=True
37
+ )
38
+ model.config.use_cache = False
39
+ model.gradient_checkpointing_enable()
40
+
41
+ print(f"✓ Model loaded: {model_id}")
42
+
43
  # Load dataset from HF Hub
44
+ print("\nLoading NATO doctrine dataset...")
45
  dataset = load_dataset("AndreasThinks/nato-doctrine-sft", split="train")
46
  dataset_test = load_dataset("AndreasThinks/nato-doctrine-sft", split="test")
47
 
 
105
  seed=42,
106
  )
107
 
108
+ # Initialize trainer with loaded model and tokenizer
109
  print("\n✓ Initializing SFT trainer...")
110
  trainer = SFTTrainer(
111
+ model=model,
112
+ tokenizer=tokenizer,
113
  train_dataset=dataset,
114
  eval_dataset=dataset_test,
115
  peft_config=peft_config,
 
133
 
134
  print("\n✅ Fine-tuning complete!")
135
  print(f" Model: https://huggingface.co/AndreasThinks/mistral-7b-nato-doctrine")
136
+ print(f" Base: mistralai/Mistral-7B-Instruct-v0.3")
137
  print(f" Trackio: Check your dashboard for metrics")