hari7261 committed on
Commit ccdeea5 · verified · 1 Parent(s): fac7607

Delete training.py

Files changed (1)
  1. training.py +0 -98
training.py DELETED
@@ -1,98 +0,0 @@
- from datasets import load_dataset
- from transformers import (
-     AutoTokenizer,
-     AutoModelForCausalLM,
-     TrainingArguments,
-     Trainer,
-     DataCollatorForLanguageModeling
- )
- from peft import LoraConfig, get_peft_model
- from huggingface_hub import login
- import os
-
- # ====== HF Login (optional if pushing to Hub) ======
- hf_token = os.getenv("HF_TOKEN")  # set this as env var or hardcode
- if hf_token:
-     login(token=hf_token)
-
- # ====== 1. Config ======
- model_name = "mistralai/Mistral-7B-v0.1"
- dataset_path = "tech_domains.jsonl"  # local file or HF dataset
- output_dir = "./TechChat"
- max_seq_length = 512
-
- # ====== 2. Load Dataset ======
- dataset = load_dataset("json", data_files=dataset_path)
-
- # ====== 3. Tokenizer ======
- tokenizer = AutoTokenizer.from_pretrained(model_name)
- tokenizer.pad_token = tokenizer.eos_token
-
- def tokenize(batch):
-     # map() is called with batched=True, so each field is a list of strings;
-     # join instruction and output per example before tokenizing.
-     texts = [ins + "\n" + out for ins, out in zip(batch["instruction"], batch["output"])]
-     tokens = tokenizer(
-         texts,
-         truncation=True,
-         padding="max_length",
-         max_length=max_seq_length
-     )
-     tokens["labels"] = [ids.copy() for ids in tokens["input_ids"]]
-     return tokens
-
- dataset = dataset.map(tokenize, batched=True, remove_columns=dataset["train"].column_names)
-
- # ====== 4. Load Base Model ======
- model = AutoModelForCausalLM.from_pretrained(model_name)
-
- # ====== 5. Apply LoRA ======
- lora_config = LoraConfig(
-     r=8,
-     lora_alpha=16,
-     target_modules=["q_proj", "v_proj"],
-     lora_dropout=0.1,
-     bias="none",
-     task_type="CAUSAL_LM"
- )
- model = get_peft_model(model, lora_config)
-
- # ====== 6. Data Collator ======
- data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
-
- # ====== 7. Training Args ======
- args = TrainingArguments(
-     output_dir="./lora_tmp",
-     per_device_train_batch_size=2,
-     gradient_accumulation_steps=4,
-     warmup_steps=50,
-     max_steps=1000,
-     learning_rate=2e-4,
-     fp16=True,
-     logging_steps=10,
-     save_strategy="no"  # We'll save after merging
- )
-
- # ====== 8. Trainer ======
- trainer = Trainer(
-     model=model,
-     args=args,
-     train_dataset=dataset["train"],
-     data_collator=data_collator
- )
-
- # ====== 9. Train ======
- trainer.train()
-
- # ====== 10. Merge LoRA into full model ======
- print("Merging LoRA weights into the base model...")
- model = model.merge_and_unload()  # fold the LoRA adapters back into the base weights
-
- # ====== 11. Save Full Model ======
- model.save_pretrained(output_dir)
- tokenizer.save_pretrained(output_dir)
-
- print(f"✅ Full model saved at {output_dir}")
-
- # ====== 12. (Optional) Push to Hugging Face Hub ======
- # Uncomment to push
- # model.push_to_hub("hari7261/TechChat", use_temp_dir=False)
- # tokenizer.push_to_hub("hari7261/TechChat", use_temp_dir=False)
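
For reference, a merged checkpoint saved this way can normally be reloaded with the standard transformers API. A minimal sketch, assuming the ./TechChat directory produced by the script above; the prompt string and generation settings are made-up examples:

from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the merged model and tokenizer from the script's output_dir.
tokenizer = AutoTokenizer.from_pretrained("./TechChat")
model = AutoModelForCausalLM.from_pretrained("./TechChat")

# Hypothetical instruction-style prompt, mirroring the "instruction\noutput" layout used at training time.
inputs = tokenizer("Explain what a REST API is.\n", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))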