hari7261 committed on
Commit b1639a9 · verified · 1 Parent(s): 4cbf4c9

Create training.py

Files changed (1)
  1. training.py +98 -0
training.py ADDED
@@ -0,0 +1,98 @@
+ from datasets import load_dataset
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForCausalLM,
+     TrainingArguments,
+     Trainer,
+     DataCollatorForLanguageModeling
+ )
+ from peft import LoraConfig, get_peft_model  # adapters are merged later via merge_and_unload()
+ from huggingface_hub import login
+ import os
+
+ # ====== HF Login (optional if pushing to Hub) ======
+ hf_token = os.getenv("HF_TOKEN")  # set this as env var or hardcode
+ if hf_token:
+     login(token=hf_token)
+
+ # ====== 1. Config ======
+ model_name = "mistralai/Mistral-7B-v0.1"
+ dataset_path = "tech_domains.jsonl"  # local file or HF dataset
+ output_dir = "./TechChat"
+ max_seq_length = 512
+
+ # ====== 2. Load Dataset ======
+ dataset = load_dataset("json", data_files=dataset_path)
+
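+ # The JSONL file is expected to have "instruction" and "output" fields, since
+ # those are the keys tokenize() reads below. Illustrative record (hypothetical
+ # example, not taken from the actual dataset):
+ #   {"instruction": "Explain what a REST API is.", "output": "A REST API ..."}
+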
+ # ====== 3. Tokenizer ======
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ tokenizer.pad_token = tokenizer.eos_token  # Mistral's tokenizer has no pad token by default
+
+ def tokenize(batch):
+     # map() is called with batched=True, so each field arrives as a list of strings
+     texts = [
+         instruction + "\n" + output
+         for instruction, output in zip(batch["instruction"], batch["output"])
+     ]
+     tokens = tokenizer(
+         texts,
+         truncation=True,
+         padding="max_length",
+         max_length=max_seq_length
+     )
+     # Causal LM training: labels are a copy of the input ids
+     tokens["labels"] = [ids.copy() for ids in tokens["input_ids"]]
+     return tokens
+
+ dataset = dataset.map(tokenize, batched=True, remove_columns=dataset["train"].column_names)
+
+ # ====== 4. Load Base Model ======
+ model = AutoModelForCausalLM.from_pretrained(model_name)
+
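+ # Note: loading a 7B model in full fp32 precision needs roughly 28 GB of memory.
+ # Passing torch_dtype=torch.float16 (or a quantization config) to from_pretrained()
+ # is a common way to reduce this.
+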
+ # ====== 5. Apply LoRA ======
+ lora_config = LoraConfig(
+     r=8,
+     lora_alpha=16,
+     target_modules=["q_proj", "v_proj"],  # attention query/value projections in Mistral
+     lora_dropout=0.1,
+     bias="none",
+     task_type="CAUSAL_LM"
+ )
+ model = get_peft_model(model, lora_config)
+
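+ # Optional sanity check: PEFT models expose print_trainable_parameters(),
+ # which reports how small the trainable (LoRA) parameter count is.
+ # model.print_trainable_parameters()
+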
+ # ====== 6. Data Collator ======
+ data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+
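+ # With mlm=False this collator does standard causal-LM collation: it rebuilds
+ # labels from input_ids (masking pad positions with -100), so the labels set
+ # manually in tokenize() are effectively overwritten; keeping them is harmless.
+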
+ # ====== 7. Training Args ======
+ args = TrainingArguments(
+     output_dir="./lora_tmp",
+     per_device_train_batch_size=2,
+     gradient_accumulation_steps=4,
+     warmup_steps=50,
+     max_steps=1000,
+     learning_rate=2e-4,
+     fp16=True,
+     logging_steps=10,
+     save_strategy="no"  # We'll save after merging
+ )
+
+ # ====== 8. Trainer ======
+ trainer = Trainer(
+     model=model,
+     args=args,
+     train_dataset=dataset["train"],
+     data_collator=data_collator
+ )
+
+ # ====== 9. Train ======
+ trainer.train()
+
+ # ====== 10. Merge LoRA into full model ======
+ print("Merging LoRA weights into the base model...")
+ # merge_and_unload() folds the LoRA adapters into the base weights and returns
+ # a plain transformers model that can be saved and loaded without PEFT.
+ model = model.merge_and_unload()
+
+ # ====== 11. Save Full Model ======
+ model.save_pretrained(output_dir)
+ tokenizer.save_pretrained(output_dir)
+
+ print(f"✅ Full model saved at {output_dir}")
+
+ # ====== 12. (Optional) Push to Hugging Face Hub ======
+ # Uncomment to push
+ # model.push_to_hub("hari7261/TechChat", use_temp_dir=False)
+ # tokenizer.push_to_hub("hari7261/TechChat", use_temp_dir=False)
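+
+ # (Optional) quick inference check: a minimal sketch of loading the merged model
+ # back from output_dir and generating a reply. The prompt is only an illustrative
+ # example; uncomment to try it after training finishes.
+ # from transformers import AutoModelForCausalLM, AutoTokenizer
+ # tok = AutoTokenizer.from_pretrained(output_dir)
+ # mdl = AutoModelForCausalLM.from_pretrained(output_dir)
+ # inputs = tok("Explain what an API gateway does.\n", return_tensors="pt")
+ # out = mdl.generate(**inputs, max_new_tokens=100)
+ # print(tok.decode(out[0], skip_special_tokens=True))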