miyuki2026 committed on
Commit
113e3f4
·
1 Parent(s): df0647f
examples/tutorials/lora_unsloth/step_2_train_model.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import os
from pathlib import Path
import platform

# Resolve the project root.  On Windows/macOS the developer environment
# supplies it via project_settings; elsewhere derive it from this file's
# own location (three levels up: examples/tutorials/lora_unsloth -> root).
if platform.system() in ("Windows", "Darwin"):
    from project_settings import project_path
else:
    # Anchor to __file__ instead of the current working directory: the
    # original abspath("../../../") only pointed at the repo root when the
    # script happened to be launched from this script's directory.
    project_path = os.path.abspath(
        os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../")
    )
    project_path = Path(project_path)
14
+ # from unsloth import FastLanguageModel
15
+ # from trl import SFTTrainer, SFTConfig
16
+ from datasets import load_dataset
17
+ # import torch
18
+
19
+
20
def get_args():
    """Parse command-line options for the LoRA SFT tutorial script.

    Returns:
        argparse.Namespace with model name, dataset location/caching
        options and the dataloader worker count.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        default="unsloth/Qwen3-8B-unsloth-bnb-4bit",
        type=str
    )
    parser.add_argument(
        "--dataset_path",
        default="miyuki2026/tutorials",
        type=str
    )
    parser.add_argument("--dataset_name", default=None, type=str)
    parser.add_argument("--dataset_split", default=None, type=str)
    parser.add_argument(
        "--dataset_cache_dir",
        default=(project_path / "hub_datasets").as_posix(),
        type=str
    )
    # NOTE(review): forwarded to load_dataset(streaming=...); any non-empty
    # string (even "false") is truthy there — leave unset to disable.
    parser.add_argument("--dataset_streaming", default=None, type=str)

    parser.add_argument(
        "--num_workers",
        # Multiprocessing with datasets is fragile on Windows, so use no
        # workers there; `or 1` guards against os.cpu_count() returning None.
        default=None if platform.system() == "Windows" else (os.cpu_count() or 1) // 2,
        # Was type=str, which made CLI-supplied values strings while the
        # default is an int; num_proc expects an int.
        type=int
    )
    args = parser.parse_args()
    return args
48
+
49
+
50
def convert_to_qwen_format(example):
    """Convert a batch of ShareGPT-style rows into Qwen chat conversations.

    Written for datasets.Dataset.map(batched=True), so ``example`` holds
    column lists, e.g.:
    {"conversation_id": [...], "category": [...],
     "conversation": [[{"human": "...", "assistant": "..."}, ...], ...],
     "dataset": [...]}

    :param example: batched columns; each item of example["conversation"]
        is a list of {"human": ..., "assistant": ...} turns.
    :return: {"conversations": [...]} where each element is one
        [user-message, assistant-message] pair, contents stripped of
        surrounding whitespace.
    """
    conversations = []
    for conversation in example["conversation"]:
        for turn in conversation:
            conversations.append([
                {"role": "user", "content": turn["human"].strip()},
                {"role": "assistant", "content": turn["assistant"].strip()},
            ])
    # Removed leftover debug statements: print(result) and exit(0), which
    # dumped the first batch and then terminated the entire training run.
    return {"conversations": conversations}
67
+
68
+
69
def main():
    """Load the quantized base model, fetch the dataset and reshape it
    into chat-template text ready for SFT.

    NOTE(review): ``FastLanguageModel`` is used below but its import
    (``from unsloth import FastLanguageModel``) is commented out at the
    top of the file, so this function raises NameError as committed —
    presumably work in progress; confirm before running.
    """
    args = get_args()

    # Load the 4-bit quantized base model plus its tokenizer.
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model_name,
        max_seq_length=2048,
        device_map="auto",
        dtype=None,
        load_in_4bit=True,
        load_in_8bit=False,
        full_finetuning=False
    )

    # model = FastLanguageModel.get_peft_model(
    #     model,
    #     r=32,  # Choose any number > 0! Suggested 8, 16, 32, 64, 128
    #     target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
    #                     "gate_proj", "up_proj", "down_proj", ],
    #     lora_alpha=32,  # Best to choose alpha = rank or rank*2
    #     lora_dropout=0,  # Supports any, but = 0 is optimized
    #     bias="none",  # Supports any, but = "none" is optimized
    #     # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    #     use_gradient_checkpointing="unsloth",  # True or "unsloth" for very long context
    #     random_state=3407,
    #     use_rslora=False,  # rank stabilized LoRA
    #     loftq_config=None,  # LoftQ
    # )
    # print(model)

    def format_func(example):
        # Render each conversation through the model's chat template so the
        # trainer sees a single plain-text "text" column per example.
        formatted_texts = []
        for conv in example['conversations']:
            formatted_texts.append(
                tokenizer.apply_chat_template(
                    conv,
                    tokenize=False,  # return plain text for training; True would return tensors
                    add_generation_prompt=False,  # keep off during training; set True for inference
                )
            )

        return {"text": formatted_texts}

    # NOTE(review): data_dir is hard-coded to "keywords" here even though
    # most other dataset settings come from args — confirm this is intended.
    dataset_dict = load_dataset(
        path=args.dataset_path,
        name=args.dataset_name,
        data_dir="keywords",
        # data_dir="psychology",
        split=args.dataset_split,
        cache_dir=args.dataset_cache_dir,
        # num_proc=args.num_workers if not args.dataset_streaming else None,
        streaming=args.dataset_streaming,  # NOTE(review): str from argparse — any non-empty value is truthy
    )
    print(dataset_dict)
    train_dataset = dataset_dict["train"]

    # First pass: flatten raw rows into role-tagged "conversations".
    train_dataset = train_dataset.map(
        convert_to_qwen_format,
        batched=True,
        remove_columns=train_dataset.column_names
    )
    print(train_dataset)

    # Second pass: apply the chat template, producing the "text" column.
    train_dataset = train_dataset.map(
        format_func,
        batched=True,
        remove_columns=train_dataset.column_names
    )
    print(train_dataset)

    return
139
+
140
+
141
# Script entry point.
if __name__ == "__main__":
    main()