VaibhavHD committed on
Commit a952b69 · verified · 1 Parent(s): 5352ede

Update train_lora.py

Files changed (1):
  1. train_lora.py +180 -50
train_lora.py CHANGED
@@ -1,50 +1,180 @@
- import os, json, torch, wandb
- from transformers import (AutoModelForCausalLM, AutoTokenizer, Trainer,
-                           TrainingArguments, DataCollatorForLanguageModeling)
- from datasets import load_dataset
- from peft import LoraConfig, get_peft_model
- from huggingface_hub import HfApi
-
- HF_TOKEN = os.getenv("HF_TOKEN")
- WANDB_API_KEY = os.getenv("WANDB_API_KEY")
- wandb.login(key=WANDB_API_KEY)
-
- model_name = "deepseek-ai/deepseek-coder-1.3b-base"
- dataset = load_dataset("westenfelder/NL2SH-ALFA")
-
- tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
- def tok_fn(b): return tok([f"{n} => {bsh}" for n,bsh in zip(b['nl'],b['bash'])],
-                           truncation=True,padding="max_length",max_length=512)
- train, test = dataset["train"].map(tok_fn,batched=True), dataset["test"].map(tok_fn,batched=True)
-
- m = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16,
-                                          low_cpu_mem_usage=True, device_map="auto",
-                                          trust_remote_code=True)
- m.config.use_cache=False
- for p in m.parameters(): p.requires_grad=False
-
- cfg=LoraConfig(r=8,lora_alpha=16,target_modules=["q_proj","v_proj","k_proj","o_proj",
-                                                  "gate_proj","down_proj","up_proj"],
-                lora_dropout=0.05,bias="none",task_type="CAUSAL_LM")
- m=get_peft_model(m,cfg)
- coll=DataCollatorForLanguageModeling(tokenizer=tok,mlm=False)
-
- args=TrainingArguments(output_dir="./out",num_train_epochs=1,per_device_train_batch_size=1,
-                        gradient_accumulation_steps=8,learning_rate=2e-4,fp16=True,
-                        save_strategy="epoch",logging_steps=25,report_to=["wandb"])
- t=Trainer(model=m,args=args,train_dataset=train,eval_dataset=test,data_collator=coll)
- wandb.init(project="deepseek-qlora-monthly",name="deepseek-lite-run")
- t.train()
-
- metrics=t.evaluate(); acc=1-metrics.get("eval_loss",1)
- with open("out/metrics.json","w") as f: json.dump(metrics,f)
- wandb.log({"accuracy":acc})
- print(f" Eval accuracy {acc:.4f}")
-
- ad="out/lora_adapters"; os.makedirs(ad,exist_ok=True)
- m.save_pretrained(ad); tok.save_pretrained(ad)
- artifact=wandb.Artifact("deepseek-lora-adapters","model"); artifact.add_dir(ad); wandb.log_artifact(artifact)
-
- api=HfApi(token=HF_TOKEN)
- api.upload_folder(folder_path=ad,repo_id="your-username/deepseek-lora-monthly",path_in_repo=".")
- print("✅ Uploaded to HF Hub")
+ #!/usr/bin/env python3
+ """
+ train_lora.py
+ - Fine-tune DeepSeek 1.3B with LoRA (QLoRA-ish setup)
+ - Save adapters using safe_serialization=True -> adapter_model.safetensors
+ - Upload adapter folder to Hugging Face Hub (VaibhavHD/deepseek-lora-monthly)
+ - Log metrics/artifact to Weights & Biases
+ """
+
+ import os
+ import json
+ import wandb
+ import torch
+ from huggingface_hub import HfApi
+ from datasets import load_dataset
+ from transformers import (
+     AutoTokenizer, AutoModelForCausalLM,
+     TrainingArguments, Trainer, DataCollatorForLanguageModeling
+ )
+ from peft import LoraConfig, get_peft_model
+
+ # -----------------------------
+ # Config (edit if needed)
+ # -----------------------------
+ HF_REPO = "VaibhavHD/deepseek-lora-monthly"  # your HF model repo
+ MODEL_NAME = "deepseek-ai/deepseek-coder-1.3b-base"
+ OUT_DIR = "out"
+ ADAPTER_DIR = os.path.join(OUT_DIR, "lora_adapters")
+
+ # env secrets expected:
+ HF_TOKEN = os.getenv("HF_TOKEN")
+ WANDB_API_KEY = os.getenv("WANDB_API_KEY")
+
+ if WANDB_API_KEY:
+     wandb.login(key=WANDB_API_KEY)
+ else:
+     print("⚠️ WANDB_API_KEY not found in env; continuing without W&B logging.")
+
+ # -----------------------------
+ # Load dataset
+ # -----------------------------
+ print("Loading dataset...")
+ dataset = load_dataset("westenfelder/NL2SH-ALFA")
+
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
+
+ def tokenize_fn(batch):
+     texts = [f"{nl} => {bash}" for nl, bash in zip(batch["nl"], batch["bash"])]
+     return tokenizer(texts, truncation=True, padding="max_length", max_length=512)
+
+ train = dataset["train"].map(tokenize_fn, batched=True)
+ test = dataset["test"].map(tokenize_fn, batched=True)
+
+ # Optional small subset for fast runs (uncomment to use)
+ # train = train.shuffle(seed=42).select(range(200))
+ # test = test.shuffle(seed=42).select(range(20))
+
+ # -----------------------------
+ # Load base model (half precision)
+ # -----------------------------
+ print("Loading base model (may take a moment)...")
+ model = AutoModelForCausalLM.from_pretrained(
+     MODEL_NAME,
+     torch_dtype=torch.float16,
+     low_cpu_mem_usage=True,
+     device_map="auto",
+     trust_remote_code=True
+ )
+
+ # avoid caching issues
+ model.config.use_cache = False
+ for p in model.parameters():
+     p.requires_grad = False
+
+ # -----------------------------
+ # Attach LoRA
+ # -----------------------------
+ print("Attaching LoRA adapters...")
+ lora_config = LoraConfig(
+     r=8,
+     lora_alpha=16,
+     target_modules=[
+         "q_proj", "v_proj", "k_proj", "o_proj",
+         "gate_proj", "down_proj", "up_proj"
+     ],
+     lora_dropout=0.05,
+     bias="none",
+     task_type="CAUSAL_LM",
+ )
+ model = get_peft_model(model, lora_config)
+
+ # -----------------------------
+ # Data collator + training args
+ # -----------------------------
+ data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+
+ training_args = TrainingArguments(
+     output_dir=OUT_DIR,
+     num_train_epochs=1,
+     per_device_train_batch_size=1,
+     gradient_accumulation_steps=8,
+     learning_rate=2e-4,
+     fp16=True,
+     save_strategy="epoch",
+     logging_steps=25,
+     report_to=["wandb"] if WANDB_API_KEY else [],
+ )
+
+ trainer = Trainer(
+     model=model,
+     args=training_args,
+     train_dataset=train,
+     eval_dataset=test,
+     data_collator=data_collator,
+ )
+
+ # -----------------------------
+ # Run training
+ # -----------------------------
+ print("Starting training...")
+ if WANDB_API_KEY:
+     wandb.init(project="deepseek-qlora-monthly", name="deepseek-lite-run")
+
+ trainer.train()
+
+ # -----------------------------
+ # Evaluate and save metrics
+ # -----------------------------
+ print("Evaluating...")
+ metrics = trainer.evaluate()
+ # compute simple "accuracy-like" metric from loss (replace with real metric if you have one)
+ new_acc = 1.0 - metrics.get("eval_loss", 1.0)
+ print(f"Eval metrics: {metrics}")
+ print(f"Pseudo-accuracy (1 - eval_loss): {new_acc:.6f}")
+
+ os.makedirs(ADAPTER_DIR, exist_ok=True)
+ metrics_path = os.path.join(OUT_DIR, "metrics.json")
+ with open(metrics_path, "w") as f:
+     json.dump(metrics, f)
+
+ if WANDB_API_KEY:
+     wandb.log({"accuracy": new_acc})
+     # log artifact
+     artifact = wandb.Artifact(
+         name="deepseek-lora-adapters",
+         type="model",
+         description="LoRA adapters saved with safe_serialization"
+     )
+
+ # -----------------------------
+ # Save adapters using safe_serialization
+ # -----------------------------
+ print("Saving adapters with safe_serialization=True (produces .safetensors)...")
+ model.save_pretrained(ADAPTER_DIR, safe_serialization=True)
+ tokenizer.save_pretrained(ADAPTER_DIR)
+
+ # add to wandb artifact directory
+ if WANDB_API_KEY:
+     artifact.add_dir(ADAPTER_DIR)
+     wandb.log_artifact(artifact, aliases=["latest"])
+
+ print(f"Adapters saved to: {ADAPTER_DIR}")
+ print("Files in adapter dir:", os.listdir(ADAPTER_DIR))
+
+ # -----------------------------
+ # Upload to Hugging Face model repo
+ # -----------------------------
+ if HF_TOKEN:
+     print(f"Uploading adapter folder to Hugging Face repo: {HF_REPO}")
+     api = HfApi()
+     # upload_folder will overwrite same filenames in the repo
+     api.upload_folder(
+         folder_path=ADAPTER_DIR,
+         path_in_repo=".",
+         repo_id=HF_REPO,
+         token=HF_TOKEN
+     )
+     print("✅ Upload complete.")
+ else:
+     print("⚠️ HF_TOKEN not set. Skipping upload to Hugging Face Hub.")
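A caveat on the evaluation step above: for a causal LM, eval_loss is the mean per-token cross-entropy, so the script's pseudo-accuracy 1 - eval_loss is not a real accuracy and goes negative whenever the loss exceeds 1. If a single number derived from the loss alone is wanted, perplexity is the conventional choice. A minimal sketch (the helper name is illustrative, not part of the commit):

import math

def perplexity(metrics: dict) -> float:
    # eval_loss is mean per-token cross-entropy; perplexity = exp(loss)
    return math.exp(metrics["eval_loss"])

# usage after evaluation:
#   metrics = trainer.evaluate()
#   print(f"eval perplexity: {perplexity(metrics):.2f}")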
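For a metric that actually measures the NL-to-Bash task, one option is exact-match accuracy over greedy generations. This is a sketch under stated assumptions: it reuses the "nl => bash" prompt format from tokenize_fn, assumes the raw nl/bash columns are still present on the eval split, and takes only the first generated line as the prediction; the function name is hypothetical.

import torch

def exact_match_accuracy(model, tokenizer, dataset, n=50, max_new_tokens=64):
    """Greedy-decode a command for the first n examples and exact-match it
    against the reference bash string."""
    model.eval()
    k = min(n, len(dataset))
    hits = 0
    for ex in dataset.select(range(k)):
        prompt = f"{ex['nl']} => "  # same format used in tokenize_fn
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            out = model.generate(**inputs, max_new_tokens=max_new_tokens,
                                 do_sample=False,
                                 pad_token_id=tokenizer.eos_token_id)
        # keep only the newly generated tokens, then the first line of text
        gen = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:],
                               skip_special_tokens=True).strip()
        pred = gen.splitlines()[0].strip() if gen else ""
        if pred == ex["bash"].strip():
            hits += 1
    return hits / k

# e.g. acc = exact_match_accuracy(model, tokenizer, test); wandb.log({"exact_match": acc})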
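Downstream, the uploaded folder is a PEFT adapter, so consumers load it on top of the same base model rather than as a standalone checkpoint. A minimal inference sketch, assuming the HF_REPO id from the script and its "nl => " prompt convention (the example prompt is made up):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained(
    "deepseek-ai/deepseek-coder-1.3b-base",
    torch_dtype=torch.float16, device_map="auto", trust_remote_code=True,
)
# apply the LoRA adapters published by this commit's script
model = PeftModel.from_pretrained(base, "VaibhavHD/deepseek-lora-monthly")
tokenizer = AutoTokenizer.from_pretrained("VaibhavHD/deepseek-lora-monthly")

prompt = "list all files larger than 10MB => "  # hypothetical example query
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=64, do_sample=False)
print(tokenizer.decode(out[0], skip_special_tokens=True))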