AISkywalker committed on
Commit 38d8dc2 · verified · 1 Parent(s): fb31fe9

Upload 30 files

DS_LoRA/README.md ADDED
@@ -0,0 +1,202 @@
+ ---
+ base_model: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
+ library_name: peft
+ ---
+
+ # Model Card for Model ID
+
+ <!-- Provide a quick summary of what the model is/does. -->
+
+
+
+ ## Model Details
+
+ ### Model Description
+
+ <!-- Provide a longer summary of what this model is. -->
+
+
+
+ - **Developed by:** [More Information Needed]
+ - **Funded by [optional]:** [More Information Needed]
+ - **Shared by [optional]:** [More Information Needed]
+ - **Model type:** [More Information Needed]
+ - **Language(s) (NLP):** [More Information Needed]
+ - **License:** [More Information Needed]
+ - **Finetuned from model [optional]:** [More Information Needed]
+
+ ### Model Sources [optional]
+
+ <!-- Provide the basic links for the model. -->
+
+ - **Repository:** [More Information Needed]
+ - **Paper [optional]:** [More Information Needed]
+ - **Demo [optional]:** [More Information Needed]
+
+ ## Uses
+
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+ ### Direct Use
+
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+
+ [More Information Needed]
+
+ ### Downstream Use [optional]
+
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+ [More Information Needed]
+
+ ### Out-of-Scope Use
+
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+ [More Information Needed]
+
+ ## Bias, Risks, and Limitations
+
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+ [More Information Needed]
+
+ ### Recommendations
+
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
+
+ Users (both direct and downstream) should be made aware of the risks, biases, and limitations of the model. More information is needed for further recommendations.
+
+ ## How to Get Started with the Model
+
+ Use the code below to get started with the model.
+
+ [More Information Needed]
+
+ ## Training Details
+
+ ### Training Data
+
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
+
+ [More Information Needed]
+
+ ### Training Procedure
+
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
+
+ #### Preprocessing [optional]
+
+ [More Information Needed]
+
+
+ #### Training Hyperparameters
+
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
+
+ #### Speeds, Sizes, Times [optional]
+
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
+
+ [More Information Needed]
+
+ ## Evaluation
+
+ <!-- This section describes the evaluation protocols and provides the results. -->
+
+ ### Testing Data, Factors & Metrics
+
+ #### Testing Data
+
+ <!-- This should link to a Dataset Card if possible. -->
+
+ [More Information Needed]
+
+ #### Factors
+
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
+
+ [More Information Needed]
+
+ #### Metrics
+
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
+
+ [More Information Needed]
+
+ ### Results
+
+ [More Information Needed]
+
+ #### Summary
+
+
+
+ ## Model Examination [optional]
+
+ <!-- Relevant interpretability work for the model goes here -->
+
+ [More Information Needed]
+
+ ## Environmental Impact
+
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+ - **Hardware Type:** [More Information Needed]
+ - **Hours used:** [More Information Needed]
+ - **Cloud Provider:** [More Information Needed]
+ - **Compute Region:** [More Information Needed]
+ - **Carbon Emitted:** [More Information Needed]
+
+ ## Technical Specifications [optional]
+
+ ### Model Architecture and Objective
+
+ [More Information Needed]
+
+ ### Compute Infrastructure
+
+ [More Information Needed]
+
+ #### Hardware
+
+ [More Information Needed]
+
+ #### Software
+
+ [More Information Needed]
+
+ ## Citation [optional]
+
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+ **BibTeX:**
+
+ [More Information Needed]
+
+ **APA:**
+
+ [More Information Needed]
+
+ ## Glossary [optional]
+
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
+
+ [More Information Needed]
+
+ ## More Information [optional]
+
+ [More Information Needed]
+
+ ## Model Card Authors [optional]
+
+ [More Information Needed]
+
+ ## Model Card Contact
+
+ [More Information Needed]
+
+ ### Framework versions
+
+ - PEFT 0.15.0
DS_LoRA/adapter_config.json ADDED
@@ -0,0 +1,34 @@
+ {
+   "alpha_pattern": {},
+   "auto_mapping": null,
+   "base_model_name_or_path": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
+   "bias": "none",
+   "corda_config": null,
+   "eva_config": null,
+   "exclude_modules": null,
+   "fan_in_fan_out": false,
+   "inference_mode": true,
+   "init_lora_weights": true,
+   "layer_replication": null,
+   "layers_pattern": null,
+   "layers_to_transform": null,
+   "loftq_config": {},
+   "lora_alpha": 32,
+   "lora_bias": false,
+   "lora_dropout": 0.1,
+   "megatron_config": null,
+   "megatron_core": "megatron.core",
+   "modules_to_save": null,
+   "peft_type": "LORA",
+   "r": 8,
+   "rank_pattern": {},
+   "revision": null,
+   "target_modules": [
+     "q_proj",
+     "v_proj"
+   ],
+   "task_type": "CAUSAL_LM",
+   "trainable_token_indices": null,
+   "use_dora": false,
+   "use_rslora": false
+ }
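
The config above applies rank-8 LoRA updates (alpha 32, dropout 0.1) to the `q_proj` and `v_proj` attention projections of the DeepSeek-R1-Distill-Qwen-1.5B base. A minimal loading sketch, assuming this repository has been cloned so that `DS_LoRA/` is a local directory:

```python
# Minimal sketch: attach the DS_LoRA adapter to its base model.
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"  # from adapter_config.json above
base = AutoModelForCausalLM.from_pretrained(base_id)
tokenizer = AutoTokenizer.from_pretrained(base_id)

# "DS_LoRA" is the adapter directory added in this commit.
model = PeftModel.from_pretrained(base, "DS_LoRA")
model.eval()
```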
DS_LoRA/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1b03770158458f7f2108bad02721e94cba1a8935ebe9e81038129be01a5f03bc
+ size 4372840
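
The file above is a Git LFS pointer, not the weights themselves; the actual 4,372,840-byte safetensors blob is resolved from LFS at download time. One way to fetch it programmatically, with a placeholder repo id (the real id is not part of this diff):

```python
from huggingface_hub import hf_hub_download

# "<user>/<repo>" is a placeholder for this repository's Hub id.
path = hf_hub_download(
    repo_id="<user>/<repo>",
    filename="DS_LoRA/adapter_model.safetensors",
)
print(path)  # local cache path of the resolved weights file
```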
DS_RL_model/README.md ADDED
@@ -0,0 +1,202 @@
+ ---
+ base_model: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
+ library_name: peft
+ ---
+
+ # Model Card for Model ID
+
+ <!-- Unfilled auto-generated PEFT model card; its body is identical to DS_LoRA/README.md above. -->
+
+ ### Framework versions
+
+ - PEFT 0.15.1
DS_RL_model/adapter_config.json ADDED
@@ -0,0 +1,38 @@
+ {
+   "alpha_pattern": {},
+   "auto_mapping": {
+     "base_model_class": "PeftModelForCausalLM",
+     "parent_library": "peft.peft_model"
+   },
+   "base_model_name_or_path": null,
+   "bias": "none",
+   "corda_config": null,
+   "eva_config": null,
+   "exclude_modules": null,
+   "fan_in_fan_out": false,
+   "inference_mode": true,
+   "init_lora_weights": true,
+   "layer_replication": null,
+   "layers_pattern": null,
+   "layers_to_transform": null,
+   "loftq_config": {},
+   "lora_alpha": 16,
+   "lora_bias": false,
+   "lora_dropout": 0.0,
+   "megatron_config": null,
+   "megatron_core": "megatron.core",
+   "modules_to_save": null,
+   "peft_type": "LORA",
+   "r": 8,
+   "rank_pattern": {},
+   "revision": null,
+   "target_modules": [
+     "k_proj",
+     "q_proj",
+     "v_proj"
+   ],
+   "task_type": null,
+   "trainable_token_indices": null,
+   "use_dora": false,
+   "use_rslora": false
+ }
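
In this config `base_model_name_or_path` and `task_type` are `null` (the adapter was saved through `auto_mapping` as a generic `PeftModelForCausalLM`), so the base model cannot be inferred from the adapter alone. A minimal loading sketch, assuming the same DeepSeek base model as DS_LoRA above (an assumption, not stated in this file):

```python
# Minimal sketch: the base model must be supplied explicitly because
# base_model_name_or_path is null in this adapter's config.
from transformers import AutoModelForCausalLM
from peft import PeftModel

# Assumed base; substitute the checkpoint actually used in training.
base = AutoModelForCausalLM.from_pretrained(
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
)
model = PeftModel.from_pretrained(base, "DS_RL_model")  # r=8, alpha=16, k/q/v_proj
```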
DS_RL_model/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:92cda569f4ae4f21ecbd0790b623428f349855221041023c8f578c0b188d6ac2
+ size 5988672
Qwen_CoT_LoRA/README.md ADDED
@@ -0,0 +1,202 @@
+ ---
+ base_model: Qwen/Qwen2.5-0.5B-Instruct
+ library_name: peft
+ ---
+
+ # Model Card for Model ID
+
+ <!-- Unfilled auto-generated PEFT model card; its body is identical to DS_LoRA/README.md above. -->
+
+ ### Framework versions
+
+ - PEFT 0.15.1
Qwen_CoT_LoRA/adapter_config.json ADDED
@@ -0,0 +1,35 @@
+ {
+   "alpha_pattern": {},
+   "auto_mapping": null,
+   "base_model_name_or_path": null,
+   "bias": "none",
+   "corda_config": null,
+   "eva_config": null,
+   "exclude_modules": null,
+   "fan_in_fan_out": false,
+   "inference_mode": true,
+   "init_lora_weights": true,
+   "layer_replication": null,
+   "layers_pattern": null,
+   "layers_to_transform": null,
+   "loftq_config": {},
+   "lora_alpha": 32,
+   "lora_bias": false,
+   "lora_dropout": 0.1,
+   "megatron_config": null,
+   "megatron_core": "megatron.core",
+   "modules_to_save": null,
+   "peft_type": "LORA",
+   "r": 8,
+   "rank_pattern": {},
+   "revision": null,
+   "target_modules": [
+     "k_proj",
+     "v_proj",
+     "q_proj"
+   ],
+   "task_type": "CAUSAL_LM",
+   "trainable_token_indices": null,
+   "use_dora": false,
+   "use_rslora": false
+ }
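
`base_model_name_or_path` is again `null` here; the README front matter above names `Qwen/Qwen2.5-0.5B-Instruct` as the base. A short generation sketch under that assumption; the keyword-then-`<think>` prompt shape mirrors the raw data format in `code/GRPO.ipynb` and is likewise an assumption about how this adapter is meant to be prompted:

```python
# Sketch: generate lyrics with the CoT adapter on its assumed Qwen base.
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_id = "Qwen/Qwen2.5-0.5B-Instruct"  # from the README front matter above
tokenizer = AutoTokenizer.from_pretrained(base_id)
model = PeftModel.from_pretrained(
    AutoModelForCausalLM.from_pretrained(base_id), "Qwen_CoT_LoRA"
)

# Keywords followed by <think>, matching the raw training-file layout;
# the model is expected to continue with reasoning, then lyrics.
inputs = tokenizer("午夜,寒冬,心动<think>", return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=512)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```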
Qwen_CoT_LoRA/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dbb7d5e3fea8f9085b091a8e9b37028f06cd23297d9d02a1e7fbf747310e4c86
+ size 2970304
Qwen_LoRA/README.md ADDED
@@ -0,0 +1,202 @@
+ ---
+ base_model: /workspace/local_model
+ library_name: peft
+ ---
+
+ # Model Card for Model ID
+
+ <!-- Unfilled auto-generated PEFT model card; its body is identical to DS_LoRA/README.md above. -->
+
+ ### Framework versions
+
+ - PEFT 0.15.0
Qwen_LoRA/adapter_config.json ADDED
@@ -0,0 +1,34 @@
+ {
+   "alpha_pattern": {},
+   "auto_mapping": null,
+   "base_model_name_or_path": "/workspace/local_model",
+   "bias": "none",
+   "corda_config": null,
+   "eva_config": null,
+   "exclude_modules": null,
+   "fan_in_fan_out": false,
+   "inference_mode": true,
+   "init_lora_weights": true,
+   "layer_replication": null,
+   "layers_pattern": null,
+   "layers_to_transform": null,
+   "loftq_config": {},
+   "lora_alpha": 32,
+   "lora_bias": false,
+   "lora_dropout": 0.1,
+   "megatron_config": null,
+   "megatron_core": "megatron.core",
+   "modules_to_save": null,
+   "peft_type": "LORA",
+   "r": 8,
+   "rank_pattern": {},
+   "revision": null,
+   "target_modules": [
+     "v_proj",
+     "q_proj"
+   ],
+   "task_type": "CAUSAL_LM",
+   "trainable_token_indices": null,
+   "use_dora": false,
+   "use_rslora": false
+ }
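
Here `base_model_name_or_path` is `/workspace/local_model`, a path from the training machine that will not resolve anywhere else, so the base model must be passed in explicitly. A sketch assuming the local checkpoint was a Qwen2.5-0.5B-Instruct snapshot (an assumption; the config does not say):

```python
# Sketch: work around the stale /workspace/local_model path by
# constructing the base model directly. The Qwen2.5 id is assumed.
from transformers import AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = PeftModel.from_pretrained(base, "Qwen_LoRA")  # r=8 LoRA on q_proj/v_proj
```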
Qwen_LoRA/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:45708acd3a64f3f203e8eb44855bfc122a62008d16776b5c73672ec6e102eab2
+ size 2175168
code/GRPO.ipynb ADDED
@@ -0,0 +1,460 @@
+ {
+  "cells": [
+   {
+    "cell_type": "markdown",
+    "id": "e325d4b51fe4fad2",
+    "metadata": {},
+    "source": [
+     "# Load the model and the data"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 1,
+    "id": "initial_id",
+    "metadata": {
+     "ExecuteTime": {
+      "end_time": "2025-03-31T15:57:38.940490Z",
+      "start_time": "2025-03-31T15:57:29.198500Z"
+     },
+     "collapsed": true
+    },
+    "outputs": [
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       "E:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+       "  from .autonotebook import tqdm as notebook_tqdm\n"
+      ]
+     }
+    ],
+    "source": [
+     "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
+     "from trl import GRPOTrainer, GRPOConfig  # assumes the installed trl provides GRPOTrainer\n",
+     "from peft import PeftModel\n",
+     "from reward import compute_rewards\n",
+     "\n",
+     "def load_data(input_path):\n",
+     "    data = []\n",
+     "    with open(input_path, 'r', encoding='utf-8') as f:\n",
+     "        for line in f:\n",
+     "            line = line.strip()\n",
+     "            if not line:\n",
+     "                continue\n",
+     "            parts = line.split('<think>', 1)\n",
+     "            if len(parts) != 2:\n",
+     "                print(f\"Warning: skipping malformed line: {line}\")\n",
+     "                continue\n",
+     "            keywords_part, lyrics = parts[0], parts[1]\n",
+     "            keywords = [kw.strip() for kw in keywords_part.split(',')]\n",
+     "\n",
+     "            # Key point: the keywords become the prompt, the lyrics the completion\n",
+     "            data.append({\n",
+     "                # instruction prompt (kept in Chinese to match the training data) plus the joined keywords\n",
+     "                'prompt': \"根据以下关键词生成一首歌词,歌词中包含多个句子,确保句子通顺,诗意,格式正确.让我们一步一步的思考(思考过程包含在<think>和</think>之间):\" + \",\".join(keywords),\n",
+     "                'completion': \"<think>\" + lyrics,  # the lyrics, with the <think> marker restored, as the model output\n",
+     "                'keywords': keywords,  # raw keyword list, kept for the reward function\n",
+     "            })\n",
+     "\n",
+     "    print(f\"Loaded {len(data)} records\")\n",
+     "    return data\n"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 2,
+    "id": "dc98a63e850836a",
+    "metadata": {
+     "ExecuteTime": {
+      "end_time": "2025-03-31T15:57:38.963191Z",
+      "start_time": "2025-03-31T15:57:38.951488Z"
+     }
+    },
+    "outputs": [
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "Loaded 1000 records\n",
+       "First record: {'prompt': '根据以下关键词生成一首歌词,歌词中包含多个句子,确保句子通顺,诗意,格式正确.让我们一步一步的思考(思考过程包含在<think>和</think>之间):午夜,寒冬,心动', 'completion': '<think>嗯,用户让我根据“午夜,寒冬,心动”这三个关键词写一首诗。首先,我需要理解每个词带来的意象和情感。午夜通常给人一种寂静、神秘的感觉,可能带有孤独或反思的情绪。寒冬则联想到寒冷、冰雪,可能还有萧瑟的景象,但也可以有温暖的反差,比如炉火或者温暖的回忆。心动是情感的核心,可能涉及爱情、突然的情感波动,或者内心的温暖。接下来,我要考虑如何将这三个元素融合在一起。可能需要营造一个寒冷的冬夜环境,然后在这样的背景下引入心动的瞬间,形成对比。比如,用寒冷的景象来衬托内心的温暖或突然的情感变化。然后,结构方面,可能需要分几段,每段围绕一个关键词展开,但又要相互联系。比如,前两段描绘午夜和寒冬的景象,第三段引入心动的元素。或者交替出现,交织在一起。语言风格方面,用户可能希望有诗意和意象化的表达,避免直白。使用比喻、拟人等修辞手法,比如将心跳比作钟摆,或者用冰晶、霜花来表现寒冷。同时,注意押韵和节奏,使诗歌流畅。还需要考虑用户可能的深层需求。他们可能不仅仅想要一首描述这三个词的诗,而是希望表达在寒冷孤独的环境中突然感受到的温暖或爱情,强调对比和情感冲击。因此,诗里需要有转折点,从冷到暖,从静到动。可能会遇到的挑战是如何自然地将三个关键词串联起来,不显得生硬。可能需要找到一个共同的场景或隐喻,比如午夜寒冬中的相遇,或者内心的变化。同时,避免陈词滥调,用新颖的比喻,比如将霜花比作时针,或者心跳与钟摆的共振。最后,检查诗的整体意境是否统一,情感是否连贯,是否有足够的意象让读者产生共鸣。可能需要调整用词和结构,使每个部分都服务于主题,增强整体效果。</think>:《冰裂时刻》/石英血管里霜花正逆向生长/成为时针,成为所有表盘/反向旋转的证词。玻璃在窗框深处/爆出细纹,像候鸟迁徙时/骨骼深处迸发的裂响//壁炉吞下最后一块松木的年轮/灰烬在铁架上凝结成第六种晶体/这时钟摆突然停驻——/某种振频正穿过雪原/穿过二十年冻土下沉默的矿脉//冰层深处传来远古鲸群的应和/我们站在地磁偏转的切线上/听见彼此胸腔里/石英开始顺时针碎裂/春天正从裂缝中/涌出第一滴蓝', 'keywords': ['午夜', '寒冬', '心动']}\n"
+      ]
+     }
+    ],
+    "source": [
+     "dataset = load_data('../data/CoTdata.txt')\n",
+     "if dataset:\n",
+     "    print(\"First record:\", dataset[0])\n",
+     "else:\n",
+     "    print(\"No valid data was loaded.\")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 3,
+    "id": "17175326d7901595",
+    "metadata": {
+     "ExecuteTime": {
+      "end_time": "2025-03-31T15:57:45.180036Z",
+      "start_time": "2025-03-31T15:57:41.329675Z"
+     }
+    },
+    "outputs": [
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       "Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.\n"
+      ]
+     },
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "trainable params: 372,736 || all params: 1,777,460,736 || trainable%: 0.0210\n"
+      ]
+     },
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       "E:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\peft\\mapping_func.py:73: UserWarning: You are trying to modify a model with PEFT for a second time. If you want to reload the model with a different config, make sure to call `.unload()` before.\n",
+       "  warnings.warn(\n",
+       "E:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\peft\\tuners\\tuners_utils.py:167: UserWarning: Already found a `peft_config` attribute in the model. This will lead to having multiple adapters in the model. Make sure to know what you are doing!\n",
+       "  warnings.warn(\n"
+      ]
+     }
+    ],
+    "source": [
+     "base_model = AutoModelForCausalLM.from_pretrained(\"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\").to(\"cuda\")\n",
+     "# Load the LoRA adapter produced by supervised fine-tuning\n",
+     "model = PeftModel.from_pretrained(base_model, \"../3_26_LoRA\").to(\"cuda\")  # path to your LoRA\n",
+     "tokenizer = AutoTokenizer.from_pretrained(\"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\")\n",
+     "tokenizer.pad_token = tokenizer.eos_token\n",
+     "\n",
+     "# Alternative kept from an earlier run: LoraConfig(task_type=\"CAUSAL_LM\", r=16, lora_alpha=32, target_modules=\"all-linear\")\n",
+     "# Stack a second, trainable LoRA on top for GRPO (hence the PEFT warnings above)\n",
+     "from peft import LoraConfig, get_peft_model\n",
+     "target_modules = [\"q_proj\", \"k_proj\", \"v_proj\"]\n",
+     "lora_config = LoraConfig(\n",
+     "    r=2,  # rank (8-32 is also worth trying)\n",
+     "    lora_alpha=32,  # scaling factor (often set to 2*r)\n",
+     "    target_modules=target_modules,\n",
+     "    bias=\"none\",  # do not train bias terms\n",
+     ")\n",
+     "model = get_peft_model(model, lora_config)\n",
+     "model.print_trainable_parameters()"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "id": "d2a601c596644fc",
+    "metadata": {},
+    "source": [
+     "# Configure the training hyperparameters"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 4,
+    "id": "8ce28487acb606a2",
+    "metadata": {
+     "ExecuteTime": {
+      "end_time": "2025-03-31T15:57:48.223131Z",
+      "start_time": "2025-03-31T15:57:48.149696Z"
+     }
+    },
+    "outputs": [],
+    "source": [
+     "# GRPO hyperparameters (tune these as needed)\n",
+     "config = GRPOConfig(\n",
+     "    gradient_accumulation_steps=50,  # accumulate gradients over this many steps per optimizer update\n",
+     "    per_device_train_batch_size=2,  # samples per device per step\n",
+     "    epsilon=0.2,  # clip range used by GRPO\n",
+     "    beta=0.05,  # KL penalty coefficient\n",
+     "    num_train_epochs=1,  # total training epochs\n",
+     "    num_generations=2,  # group size for sampled completions\n",
+     "    learning_rate=1e-5,  # optimizer learning rate\n",
+     "    bf16=True,\n",
+     "    adam_beta1=0.9,\n",
+     "    adam_beta2=0.98,\n",
+     "    optim=\"adamw_8bit\",  # optimizer\n",
+     "    max_grad_norm=0.1,  # gradient clipping threshold\n",
+     "    save_steps=1000,  # save a checkpoint every N steps\n",
+     "    save_total_limit=2,  # keep at most this many checkpoints\n",
+     "    logging_steps=5,  # log every N steps\n",
+     "    output_dir=\"GRPO\",  # where checkpoints are written\n",
+     "    weight_decay=0.01,  # weight decay\n",
+     "    warmup_ratio=0.03,  # LR warm-up ratio\n",
+     "    max_prompt_length=256,\n",
+     "    max_completion_length=1024,  # maximum generated length\n",
+     "    report_to='tensorboard',\n",
+     ")\n",
+     "\n",
+     "# Alternative settings kept for reference:\n",
+     "# training_args = GRPOConfig(\n",
+     "#     output_dir=\"GRPO\",\n",
+     "#     learning_rate=2e-5,\n",
+     "#     per_device_train_batch_size=8,\n",
+     "#     gradient_accumulation_steps=2,\n",
+     "#     max_prompt_length=512,\n",
+     "#     max_completion_length=96,\n",
+     "#     num_generations=8,\n",
+     "#     optim=\"adamw_8bit\",\n",
+     "#     num_train_epochs=1,\n",
+     "#     bf16=True,\n",
+     "#     report_to=[\"wandb\"],\n",
+     "#     remove_unused_columns=False,\n",
+     "#     logging_steps=1,\n",
+     "# )"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "id": "793b3094cd98fed6",
+    "metadata": {},
+    "source": [
+     "# Train the model"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 5,
+    "id": "19094188d22e45c2",
+    "metadata": {
+     "ExecuteTime": {
+      "end_time": "2025-03-31T17:08:58.041584Z",
+      "start_time": "2025-03-31T15:57:50.316805Z"
+     }
+    },
+    "outputs": [
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       "No label_names provided for model class `PeftModel`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.\n"
+      ]
+     },
+     {
+      "data": {
+       "text/html": [
+        "<progress value='2' max='20'></progress> [2/20, Epoch 0.05/1]\n",
+        "<table><thead><tr><th>Step</th><th>Training Loss</th></tr></thead><tbody></tbody></table>"
+       ],
+       "text/plain": [
+        "<IPython.core.display.HTML object>"
+       ]
+      },
+      "metadata": {},
+      "output_type": "display_data"
+     },
+     {
+      "ename": "KeyboardInterrupt",
+      "evalue": "",
+      "output_type": "error",
+      "traceback": [
+       "---------------------------------------------------------------------------",
+       "KeyboardInterrupt                         Traceback (most recent call last)",
+       "Cell In[5], line 16\n      1 # (earlier commented-out Trainer/wandb variant)\n   (...)\n     10 trainer = GRPOTrainer(\n     11     model=model,\n     12     reward_funcs=[compute_rewards],\n     13     args=config,\n     14     train_dataset=dataset\n     15 )\n---> 16 trainer.train()\n",
+       "File transformers\\trainer.py:2245, in Trainer.train\n",
+       "File transformers\\trainer.py:2556, in Trainer._inner_training_loop\n",
+       "File transformers\\trainer.py:3712, in Trainer.training_step\n",
+       "File trl\\extras\\profiling.py:87, in profiling_decorator.<locals>.wrapper\n",
+       "File trl\\trainer\\grpo_trainer.py:647, in GRPOTrainer._prepare_inputs\n",
+       "File trl\\trainer\\grpo_trainer.py:719, in GRPOTrainer._generate_and_score_completions\n",
+       "File peft\\peft_model.py:823, in PeftModel.generate\n",
+       "File peft\\peft_model.py:1874, in PeftModelForCausalLM.generate\n",
+       "File torch\\utils\\_contextlib.py:116, in context_decorator.<locals>.decorate_context\n",
+       "File transformers\\generation\\utils.py:2326, in GenerationMixin.generate\n",
+       "File transformers\\generation\\utils.py:3289, in GenerationMixin._sample\n",
+       "File torch\\nn\\modules\\module.py:1739, in Module._wrapped_call_impl\n",
+       "File torch\\nn\\modules\\module.py:1750, in Module._call_impl\n",
283
+ "\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\transformers\\utils\\deprecation.py:172\u001b[39m, in \u001b[36mdeprecate_kwarg.<locals>.wrapper.<locals>.wrapped_func\u001b[39m\u001b[34m(*args, **kwargs)\u001b[39m\n\u001b[32m 168\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m minimum_action \u001b[38;5;129;01min\u001b[39;00m (Action.NOTIFY, Action.NOTIFY_ALWAYS) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torchdynamo_compiling():\n\u001b[32m 169\u001b[39m \u001b[38;5;66;03m# DeprecationWarning is ignored by default, so we use FutureWarning instead\u001b[39;00m\n\u001b[32m 170\u001b[39m warnings.warn(message, \u001b[38;5;167;01mFutureWarning\u001b[39;00m, stacklevel=\u001b[32m2\u001b[39m)\n\u001b[32m--> \u001b[39m\u001b[32m172\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
284
+ "\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\transformers\\models\\qwen2\\modeling_qwen2.py:855\u001b[39m, in \u001b[36mQwen2ForCausalLM.forward\u001b[39m\u001b[34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, logits_to_keep, **kwargs)\u001b[39m\n\u001b[32m 852\u001b[39m return_dict = return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m.config.use_return_dict\n\u001b[32m 854\u001b[39m \u001b[38;5;66;03m# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m855\u001b[39m outputs = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 856\u001b[39m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m=\u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 857\u001b[39m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m=\u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 858\u001b[39m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[43m=\u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 859\u001b[39m \u001b[43m \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[43m=\u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 860\u001b[39m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[43m=\u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 861\u001b[39m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[43m=\u001b[49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 862\u001b[39m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[43m=\u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 863\u001b[39m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m=\u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 864\u001b[39m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[43m=\u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 865\u001b[39m \u001b[43m \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[43m=\u001b[49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 866\u001b[39m \u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 867\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 869\u001b[39m hidden_states = outputs[\u001b[32m0\u001b[39m]\n\u001b[32m 870\u001b[39m \u001b[38;5;66;03m# Only compute necessary logits, and do not upcast them to float if we are not computing the loss\u001b[39;00m\n",
285
+ "\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1739\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1737\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m 1738\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1739\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
286
+ "\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1750\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1745\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m 1746\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m 1747\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m 1748\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m 1749\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1750\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1752\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1753\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n",
287
+ "\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\transformers\\models\\qwen2\\modeling_qwen2.py:579\u001b[39m, in \u001b[36mQwen2Model.forward\u001b[39m\u001b[34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **flash_attn_kwargs)\u001b[39m\n\u001b[32m 567\u001b[39m layer_outputs = \u001b[38;5;28mself\u001b[39m._gradient_checkpointing_func(\n\u001b[32m 568\u001b[39m decoder_layer.\u001b[34m__call__\u001b[39m,\n\u001b[32m 569\u001b[39m hidden_states,\n\u001b[32m (...)\u001b[39m\u001b[32m 576\u001b[39m position_embeddings,\n\u001b[32m 577\u001b[39m )\n\u001b[32m 578\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m579\u001b[39m layer_outputs = \u001b[43mdecoder_layer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 580\u001b[39m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 581\u001b[39m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m=\u001b[49m\u001b[43mcausal_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 582\u001b[39m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[43m=\u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 583\u001b[39m \u001b[43m \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[43m=\u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 584\u001b[39m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[43m=\u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 585\u001b[39m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[43m=\u001b[49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 586\u001b[39m \u001b[43m \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[43m=\u001b[49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 587\u001b[39m \u001b[43m \u001b[49m\u001b[43mposition_embeddings\u001b[49m\u001b[43m=\u001b[49m\u001b[43mposition_embeddings\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 588\u001b[39m \u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mflash_attn_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 589\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 591\u001b[39m hidden_states = layer_outputs[\u001b[32m0\u001b[39m]\n\u001b[32m 593\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m output_attentions:\n",
288
+ "\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1739\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1737\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m 1738\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1739\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
289
+ "\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1750\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1745\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m 1746\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m 1747\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m 1748\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m 1749\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1750\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1752\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1753\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n",
290
+ "\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\transformers\\models\\qwen2\\modeling_qwen2.py:276\u001b[39m, in \u001b[36mQwen2DecoderLayer.forward\u001b[39m\u001b[34m(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, position_embeddings, **kwargs)\u001b[39m\n\u001b[32m 274\u001b[39m residual = hidden_states\n\u001b[32m 275\u001b[39m hidden_states = \u001b[38;5;28mself\u001b[39m.post_attention_layernorm(hidden_states)\n\u001b[32m--> \u001b[39m\u001b[32m276\u001b[39m hidden_states = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mmlp\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 277\u001b[39m hidden_states = residual + hidden_states\n\u001b[32m 279\u001b[39m outputs = (hidden_states,)\n",
291
+ "\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1739\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1737\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m 1738\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1739\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
292
+ "\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1750\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1745\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m 1746\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m 1747\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m 1748\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m 1749\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1750\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1752\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1753\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n",
293
+ "\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\transformers\\models\\qwen2\\modeling_qwen2.py:57\u001b[39m, in \u001b[36mQwen2MLP.forward\u001b[39m\u001b[34m(self, x)\u001b[39m\n\u001b[32m 56\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[32m---> \u001b[39m\u001b[32m57\u001b[39m down_proj = \u001b[38;5;28mself\u001b[39m.down_proj(\u001b[38;5;28mself\u001b[39m.act_fn(\u001b[38;5;28mself\u001b[39m.gate_proj(x)) * \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mup_proj\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[32m 58\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m down_proj\n",
294
+ "\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1739\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1737\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m 1738\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1739\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
295
+ "\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1750\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1745\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m 1746\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m 1747\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m 1748\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m 1749\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1750\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1752\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1753\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n",
296
+ "\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\linear.py:125\u001b[39m, in \u001b[36mLinear.forward\u001b[39m\u001b[34m(self, input)\u001b[39m\n\u001b[32m 124\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) -> Tensor:\n\u001b[32m--> \u001b[39m\u001b[32m125\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[43m.\u001b[49m\u001b[43mlinear\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n",
297
+ "\u001b[31mKeyboardInterrupt\u001b[39m: "
298
+ ]
299
+ }
300
+ ],
+ "source": [
+ "# Trainer (earlier single-reward setup, kept for reference):\n",
+ "# trainer = GRPOTrainer(\n",
+ "#     model=model,\n",
+ "#     reward_funcs=[reward_len],\n",
+ "#     args=training_args,\n",
+ "#     train_dataset=dataset[\"train\"],\n",
+ "# )\n",
+ "# Train model\n",
+ "# wandb.init(project=\"GRPO\")\n",
+ "# trainer.train()\n",
+ "trainer = GRPOTrainer(\n",
+ "    model=model,\n",
+ "    reward_funcs=[compute_rewards],\n",
+ "    args=config,\n",
+ "    train_dataset=dataset\n",
+ ")\n",
+ "trainer.train()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f621c33533e55b00",
+ "metadata": {},
+ "source": [
+ "# Evaluation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d17d0e3eb9069545",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "from datetime import datetime\n",
+ "\n",
+ "def plot_training_metrics(losses, kls, avg_rewards, output_dir=\".\"):\n",
+ "    \"\"\"\n",
+ "    Plot and save the training metric curves.\n",
+ "    \n",
+ "    Args:\n",
+ "        losses: list of training losses\n",
+ "        kls: list of KL divergences\n",
+ "        avg_rewards: list of average rewards\n",
+ "        output_dir: output directory\n",
+ "    \"\"\"\n",
+ "    # Build a unique, timestamped file name\n",
+ "    timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
+ "    output_path = f\"{output_dir}/training_curves_{timestamp}.png\"\n",
+ "    \n",
+ "    # Create the canvas\n",
+ "    plt.figure(figsize=(15, 5), dpi=300)\n",
+ "    \n",
+ "    # 1. Loss curve\n",
+ "    plt.subplot(1, 3, 1)\n",
+ "    plt.plot(losses, label=\"Loss\", linewidth=1.5, color='blue')\n",
+ "    plt.title(\"Training Loss\", fontsize=10)\n",
+ "    plt.xlabel(\"Step\", fontsize=9)\n",
+ "    plt.ylabel(\"Loss\", fontsize=9)\n",
+ "    plt.grid(True, alpha=0.3)\n",
+ "    \n",
+ "    # 2. KL divergence curve\n",
+ "    plt.subplot(1, 3, 2)\n",
+ "    plt.plot(kls, label=\"KL Divergence\", linewidth=1.5, color='orange')\n",
+ "    plt.title(\"KL Divergence\", fontsize=10)\n",
+ "    plt.xlabel(\"Step\", fontsize=9)\n",
+ "    plt.ylabel(\"KL Divergence\", fontsize=9)\n",
+ "    plt.grid(True, alpha=0.3)\n",
+ "    \n",
+ "    # 3. Average reward curve\n",
+ "    plt.subplot(1, 3, 3)\n",
+ "    plt.plot(avg_rewards, label=\"Avg Reward\", linewidth=1.5, color='green')\n",
+ "    plt.title(\"Average Reward\", fontsize=10)\n",
+ "    plt.xlabel(\"Step\", fontsize=9)\n",
+ "    plt.ylabel(\"Reward\", fontsize=9)\n",
+ "    plt.grid(True, alpha=0.3)\n",
+ "    \n",
+ "    # Tighten the layout and save\n",
+ "    plt.tight_layout()\n",
+ "    plt.savefig(\n",
+ "        output_path,\n",
+ "        bbox_inches='tight',\n",
+ "        facecolor='white',\n",
+ "        dpi=300\n",
+ "    )\n",
+ "    plt.close()\n",
+ "    \n",
+ "    print(f\"Training curves saved to: {output_path}\")\n",
+ "\n",
+ "# Usage (once the lists below have been collected during training):\n",
+ "# losses = [...]        # training losses\n",
+ "# kls = [...]           # KL divergences\n",
+ "# avg_rewards = [...]   # average rewards\n",
+ "# plot_training_metrics(losses, kls, avg_rewards)\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b6a739a2f9d0a343",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class MetricsCallback(TrainerCallback):\n",
+ "    def __init__(self):\n",
+ "        super().__init__()\n",
+ "        self.metrics = {\n",
+ "            'loss': [],\n",
+ "            'kl_divergence': [],\n",
+ "            'avg_reward': []\n",
+ "        }\n",
+ "    \n",
+ "    def on_log(self, args, state, control, logs=None, **kwargs):\n",
+ "        if logs is not None:\n",
+ "            if 'loss' in logs:\n",
+ "                self.metrics['loss'].append(logs['loss'])\n",
+ "            if 'kl_divergence' in logs:\n",
+ "                self.metrics['kl_divergence'].append(logs['kl_divergence'])\n",
+ "            if 'rewards' in logs:  # assumed to be a list; store its mean\n",
+ "                avg_reward = sum(logs['rewards'])/len(logs['rewards'])\n",
+ "                self.metrics['avg_reward'].append(avg_reward)\n",
+ "    \n",
+ "    "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "14cee34aa3bb165",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "plot_training_metrics(metrics_callback.metrics['loss'], metrics_callback.metrics['kl_divergence'], metrics_callback.metrics['avg_reward'])"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 2
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython2",
+ "version": "3.11.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
code/LORA.py ADDED
@@ -0,0 +1,89 @@
+ from datasets import Dataset
+ from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
+ from peft import LoraConfig, get_peft_model, PeftModel
+
+ raw_data_path = ""  # replace with the path to your dataset
+ with open(raw_data_path, "r", encoding="utf-8") as f:
+     raw_lines = f.readlines()
+
+ def process_line(line):
+     segments = line.strip().split("/")
+     return "/".join(segments[:-1]) if len(segments) > 1 else line.strip()
+
+ processed_samples = [process_line(line) for line in raw_lines if line.strip()]
+ dataset = Dataset.from_dict({"text": processed_samples})
+
+ model_name = ""  # replace with the path to your model
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModelForCausalLM.from_pretrained(model_name)
+
+ lora_config = LoraConfig(
+     r=8,                                  # rank of the low-rank matrices; 8, 16 or 32 are common
+     lora_alpha=32,                        # scaling factor that controls the adapter's influence
+     target_modules=["q_proj", "v_proj"],  # modules to adapt, usually the attention projections
+     lora_dropout=0.1,                     # dropout on the adapter to curb overfitting
+     bias="none",                          # bias handling; usually left untrained
+     task_type="CAUSAL_LM"                 # task type for causal language models
+ )
+ model = get_peft_model(model, lora_config)
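+ # get_peft_model freezes the base weights and injects trainable rank-r adapters into the
+ # listed projections (effective update W + (lora_alpha/r) * B @ A). A quick sanity check:
+ # model.print_trainable_parameters()  # PEFT helper; typically well under 1% of params train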
+
+ def tokenize_function(examples):
+     # Fixed instruction prefix (in Chinese: "generate a lyric from the following keywords,
+     # separate sentences with '/', think step by step inside <think>...</think>")
+     prompt = "根据以下关键词生成一首歌词,歌词中包含多个句子,句子与句子之间使用/隔开,让我们一步一步的思考(思考过程包含在<think>和</think>之间):"
+
+     # Prepend the instruction to every raw text
+     modified_texts = [prompt + text for text in examples["text"]]
+
+     # Tokenize
+     tokenized = tokenizer(modified_texts, truncation=True, padding="max_length", max_length=256)
+
+     # Copy input_ids as labels (standard causal-LM fine-tuning)
+     tokenized["labels"] = tokenized["input_ids"].copy()
+
+     return tokenized
+
+ tokenized_dataset = dataset.map(tokenize_function, batched=True)
+
+ training_args = TrainingArguments(
+     output_dir="./lora",
+     num_train_epochs=8,
+     per_device_train_batch_size=10,
+     learning_rate=2e-5,
+     weight_decay=0.01,
+     logging_steps=10000,
+     save_steps=15000,
+     fp16=True,
+ )
+
+ trainer = Trainer(
+     model=model,
+     args=training_args,
+     train_dataset=tokenized_dataset,
+     tokenizer=tokenizer,
+ )
+
+ trainer.train()
+
+ # Inference example
+ generation_config = {
+     "max_new_tokens": 1024,
+     "temperature": 1.0,
+     "top_p": 0.9,
+     "top_k": 40,
+     "repetition_penalty": 1.2,
+     "do_sample": True,
+     "encoder_no_repeat_ngram_size": 4,
+ }
+ if True:  # toggle for the quick generation demo
+     prompt = "根据以下关键词生成一首歌词,歌词中包含多个句子,句子与句子之间使用/隔开,让我们一步一步的思考(思考过程包含在<think>和</think>之间):温柔,轮廓,洒脱:"
+     input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
+     outputs = model.generate(input_ids, **generation_config)
+     decoded = tokenizer.decode(outputs[0], skip_special_tokens=False)
+
+     print(decoded)
+
+ model.save_pretrained("")  # replace with the path where the adapter should be saved
code/LORA_with_CoT.py ADDED
@@ -0,0 +1,119 @@
+ import re
+ from datasets import Dataset
+ from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
+ from peft import LoraConfig, get_peft_model, PeftModel
+
+ model_name = "Qwen/Qwen2.5-0.5B-Instruct"
+
+ base_model = AutoModelForCausalLM.from_pretrained(
+     model_name,
+     torch_dtype="auto",
+     device_map="auto"
+ )
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+ # Load the data
+ raw_data_path = r"data/CoTdata.txt"
+ with open(raw_data_path, "r", encoding="utf-8") as f:
+     raw_lines = f.readlines()
+
+ # Parse each line into keywords, chain of thought, and the lyric itself
+ def process_line(line):
+     # the character class [::] matches both the ASCII and the full-width colon
+     pattern = r"^(.*?)<think>(.*?)</think>[::](.*)$"
+     match = re.match(pattern, line.strip())
+     if match:
+         keywords = match.group(1).strip()
+         cot = match.group(2).strip()
+         poem = match.group(3).strip()
+         # Build the training instance: the input gives the instruction and keywords,
+         # the output contains the full chain of thought followed by the answer
+         training_text = (
+             f"【输入】:根据以下关键词生成一首歌词,歌词中包含多个句子,确保句子通顺、诗意、格式正确。"
+             f"让我们一步一步的思考(思考过程包含在<think>和</think>之间):{keywords}\n\n"
+             f"【输出】:<think>{cot}</think>\n{poem}"
+         )
+         return training_text
+     else:
+         # Malformed line: report it and return None
+         print("Skipping malformed line:", line.strip())
+         return None
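+ # Each line of data/CoTdata.txt is expected to look like
+ #   风,雾,寂寞<think>...reasoning...</think>:第一句/第二句/第三句
+ # which is exactly what code/getCOT.py writes: group(1) is the keywords, group(2) the
+ # chain of thought, group(3) the '/'-separated lyric.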
+
+ # Parse every data line
+ processed_samples = []
+ for line in raw_lines:
+     result = process_line(line)
+     if result:
+         processed_samples.append(result)
+
+ # Build the Hugging Face dataset
+ dataset = Dataset.from_dict({"text": processed_samples})
+
+ # Load the base model plus the existing LoRA adapter
+ model = PeftModel.from_pretrained(base_model, r"D:\GoodMusicV3.0\3_24_LoRA").to("cuda")  # replace with your LoRA path
+ tokenizer.pad_token = tokenizer.eos_token
+
+ lora_config = LoraConfig(
+     r=8,                                            # rank of the low-rank matrices; 8, 16 or 32 are common
+     lora_alpha=32,                                  # scaling factor that controls the adapter's influence
+     target_modules=["q_proj", "k_proj", "v_proj"],  # modules to adapt, usually the attention projections
+     lora_dropout=0.1,                               # dropout on the adapter to curb overfitting
+     bias="none",                                    # usually left untrained
+     task_type="CAUSAL_LM"
+ )
+ model = get_peft_model(model, lora_config)
+ model.cuda()
+
+ # Tokenization: tokenize the text and build the labels
+ def tokenize_function(examples):
+     # the text here already contains both the input and the output
+     tokenized = tokenizer(examples["text"], truncation=True, padding="max_length", max_length=256)
+     tokenized["labels"] = tokenized["input_ids"].copy()
+     return tokenized
+
+ # Map the tokenizer over the dataset
+ tokenized_dataset = dataset.map(tokenize_function, batched=True)
+
+ # Training arguments
+ training_args = TrainingArguments(
+     output_dir="./lora",
+     num_train_epochs=1000,
+     per_device_train_batch_size=16,
+     learning_rate=2e-5,
+     weight_decay=0.01,
+     logging_steps=10000,
+     save_steps=15000,
+     fp16=True,
+ )
+
+ # Build the Trainer
+ trainer = Trainer(
+     model=model,
+     args=training_args,
+     train_dataset=tokenized_dataset,
+     tokenizer=tokenizer,
+ )
+
+ # Start training
+ trainer.train()
+
+ # Inference example
+ generation_config = {
+     "max_new_tokens": 1024,
+     "temperature": 1.0,
+     "top_p": 0.9,
+     "top_k": 40,
+     "repetition_penalty": 1.2,
+     "do_sample": True,
+     "encoder_no_repeat_ngram_size": 4,
+ }
+ if True:  # toggle for the quick generation demo
+     prompt = "根据以下关键词生成一首歌词,歌词中包含多个句子,句子与句子之间使用/隔开,让我们一步一步的思考(思考过程包含在<think>和</think>之间):温柔,轮廓,洒脱:"
+     input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
+     outputs = model.generate(input_ids, **generation_config)
+     decoded = tokenizer.decode(outputs[0], skip_special_tokens=False)
+
+     print(decoded)
+
+ # Save the adapter
+ model.save_pretrained("4_2_LoRA_3")
code/UI.py ADDED
@@ -0,0 +1,181 @@
+ import sys
+ from PyQt5.QtWidgets import QApplication, QWidget, QVBoxLayout, QPushButton, QTextEdit, QLineEdit, QListWidget, QLabel, QHBoxLayout, QListWidgetItem
+ import _MyModel
+ from PyQt5.QtGui import QColor, QPalette
+ from PyQt5.QtCore import Qt
+
+ class ChatSession:
+     """Holds the contents of a single conversation"""
+     def __init__(self, topic="新对话"):
+         self.topic = topic
+         self.messages = []  # chat history
+
+     def add_message(self, sender, text):
+         """Append a message (sender: 'user' or 'ai')"""
+         self.messages.append((sender, text))
+
+
+ class ChatGPTUI(QWidget):
+
+     def __init__(self, MyModel):
+         super().__init__()
+         self.model = MyModel
+         self.first_list_item = QListWidget()
+         self.setWindowTitle("ChatGPT 聊天界面")
+         self.setGeometry(200, 200, 800, 600)
+         self.setStyleSheet("background-color: #DCB272; color: white;")  # warm background
+         # self.setWindowFlags(Qt.FramelessWindowHint)  # frameless window
+         # Main layout
+         main_layout = QHBoxLayout(self)
+         # Left side:
+         left_layout = QVBoxLayout()
+         # "New chat" button
+         self.new_chat_button = QPushButton("新建对话")
+         self.new_chat_button.setStyleSheet("background-color: #0FA958; color: white; padding: 8px; border-radius: 5px;")
+         self.new_chat_button.clicked.connect(self.create_new_chat)
+         left_layout.addWidget(self.new_chat_button)
+         # Left side: conversation history list
+         self.history_list = QListWidget()
+         self.history_list.setStyleSheet("background-color: #E4DECE; color: black; border: none;")
+         self.history_list.itemClicked.connect(self.load_selected_chat)  # selecting a past conversation
+         left_layout.addWidget(self.history_list)
+
+         # Right side: chat area
+         right_layout = QVBoxLayout()
+
+         # Topic input box
+         self.topic_input = QLineEdit()
+         self.topic_input.setPlaceholderText("请输入对话主题...")
+         self.topic_input.setStyleSheet("background-color: #E4DECE; color: black; padding: 5px; border-radius: 5px;")
+
+         # Chat display area
+         self.chat_display = QTextEdit()
+         self.chat_display.setReadOnly(True)
+         self.chat_display.setStyleSheet("background-color: #E4DECE; color: black; border: none; padding: 10px;")
+         right_layout.addWidget(self.chat_display, 7)
+
+         # Input area (horizontal layout)
+         input_layout = QHBoxLayout()
+
+         # User input box
+         self.input_field = QLineEdit()
+         self.input_field.setPlaceholderText("输入消息...")
+         self.input_field.setStyleSheet("background-color: #E4DECE; color: black; padding: 5px; border-radius: 5px;")
+         input_layout.addWidget(self.input_field, 8)
+         self.input_field.returnPressed.connect(self.send_message)
+
+         # Send button
+         self.send_button = QPushButton("发送")
+         self.send_button.setStyleSheet("background-color: #DA8D6D; color: white; padding: 8px; border-radius: 5px;")
+         self.send_button.clicked.connect(self.send_message)
+         input_layout.addWidget(self.send_button, 2)
+
+         right_layout.addLayout(input_layout)
+
+         # Assemble the main layout
+         main_layout.addLayout(left_layout, 2)
+         main_layout.addLayout(right_layout, 8)
+
+         self.setLayout(main_layout)
+
+         # Initial conversation storage
+         self.chat_sessions = []  # all sessions
+         self.current_session = None
+         self.create_new_chat()  # create a default conversation on startup
+
+     def create_new_chat(self):
+         """Create a new conversation and add it to the history list"""
+         topic = self.topic_input.text().strip()
+         if not topic:
+             topic = "新对话"
+
+         new_session = ChatSession(topic)
+         self.chat_sessions.append(new_session)
+         self.current_session = new_session
+
+         # Refresh the history list on the left
+         self.add_chat_item(topic)
+         self.history_list.setCurrentRow(self.history_list.count() - 1)  # select the new conversation
+         self.chat_display.clear()
+
+     def load_selected_chat(self):
+         """Switch to the conversation the user selected"""
+         selected_index = self.history_list.currentRow()
+         if selected_index >= 0:
+             self.current_session = self.chat_sessions[selected_index]
+             self.display_chat_history()
+
+     def display_chat_history(self):
+         """Render the current session's chat history"""
+         self.chat_display.clear()
+         for sender, text in self.current_session.messages:
+             if sender == 'user':
+                 self.chat_display.append(f"<b><span style='color: #9b7438; font-family: 微软雅黑; font-size: 28px'>主题 : </span><span style='color: #1B2131; font-family: 微软雅黑; font-size: 28px'> {text}</span></b>")
+             else:
+                 self.chat_display.append(f"<b>{'用户' if sender == 'user' else 'ChatGPT'}:</b> {text}")
+
+     def send_message(self):
+         """Send the user's input"""
+         user_text = self.input_field.text().strip()
+         if user_text and self.current_session:
+             self.current_session.add_message("user", user_text)
+             self.chat_display.append(f"<b><span style='color: #9b7438; font-family: 微软雅黑; font-size: 28px'>主题 : </span><span style='color: #1B2131; font-family: 微软雅黑; font-size: 28px'> {user_text}</span></b>")
+             self.input_field.clear()
+
+             # Ask the model for a reply
+             ai_reply = self.get_ai_response(user_text)
+             self.receive_message(ai_reply)
+
+     def receive_message(self, text):
+         """Display the AI's reply"""
+         if self.current_session:
+             self.current_session.add_message("ai", text)
+             self.chat_display.append(f"<b>ChatGPT:</b> {text}")
+
+     def get_ai_response(self, user_input):
+         """Hook for an AI backend, e.g. the OpenAI API or a local model"""
+         output = self.model.predict(user_input)
+         return f"<span style='font-size: 20px;'>{output}</span>"
+
+     def add_chat_item(self, text):
+         """Add a history entry with a delete button"""
+         item_widget = QWidget()
+         item_layout = QHBoxLayout(item_widget)
+         item_layout.setContentsMargins(5, 2, 5, 2)
+
+         label = QLabel(text)
+         delete_button = QPushButton("×")
+         delete_button.setFixedSize(20, 20)
+         delete_button.setStyleSheet("background-color: #cc6666; color: white; border-radius: 10px;")
+
+         item_layout.addWidget(label)
+         item_layout.addWidget(delete_button)
+         item_layout.addStretch()
+
+         list_item = QListWidgetItem(self.history_list)
+         list_item.setSizeHint(item_widget.sizeHint())
+
+         self.history_list.addItem(list_item)
+         self.history_list.setItemWidget(list_item, item_widget)
+
+         # Wire up the delete action
+         delete_button.clicked.connect(lambda: self.remove_chat_item(list_item))
+         self.first_list_item = list_item
+
+     def remove_chat_item(self, item):
+         """Delete a history entry"""
+         row = self.history_list.row(item)
+         del self.chat_sessions[row]
+         self.history_list.takeItem(row)
+
+
+ # Run the PyQt5 application
+ if __name__ == "__main__":
+
+     app = QApplication(sys.argv)
+     window = ChatGPTUI(_MyModel.MyModel())  # ChatGPTUI requires a model instance
+     window.show()
+     window.remove_chat_item(window.first_list_item)
+     window.create_new_chat()
+     sys.exit(app.exec_())
code/_MyModel.py ADDED
@@ -0,0 +1,24 @@
+
+ from peft import LoraConfig, get_peft_model, PeftModel
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ class MyModel():
+     def __init__(self):
+         model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
+         lora_path = "DS_RL_model"
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+         model = AutoModelForCausalLM.from_pretrained(model_name)
+         self.model = PeftModel.from_pretrained(model, lora_path)
+         self.generation_config = {
+             "max_new_tokens": 2048,
+             "temperature": 0.9,
+             "top_p": 1.0,
+             "repetition_penalty": 1.2,
+         }
+
+     def predict(self, text):
+         # Chinese instruction prefix: "generate a lyric from the following keywords,
+         # separate sentences with '/', think step by step inside <think>...</think>"
+         prompt = "根据以下关键词生成一首歌词,歌词中包含多个句子,句子与句子之间使用/隔开,让我们一步一步的思考(思考过程包含在<think>和</think>之间):" + text
+         input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.model.device)
+         outputs = self.model.generate(input_ids, **self.generation_config)
+         decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=False)
+         return decoded
+ # example keywords: 诗,样子,天地
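+ # Minimal usage sketch (assumes the "DS_RL_model" adapter directory sits next to this file):
+ #   m = MyModel()
+ #   print(m.predict("诗,样子,天地"))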
code/__main__.py ADDED
@@ -0,0 +1,15 @@
+ import sys
+ import _MyModel
+ from UI import QApplication, ChatGPTUI
+ import os
+ # Point Qt at its platform plugins; PyQt5 inside a venv sometimes fails to locate them on Windows
+ os.environ["QT_QPA_PLATFORM_PLUGIN_PATH"] = r"D:\不会编程\Machine_Learning\class_project\project\.venv\Lib\site-packages\PyQt5\Qt5\plugins"
+
+
+ if __name__ == '__main__':
+     myModel = _MyModel.MyModel()
+     app = QApplication(sys.argv)
+     window = ChatGPTUI(myModel)
+     window.show()
+     window.remove_chat_item(window.first_list_item)
+     window.create_new_chat()
+     sys.exit(app.exec_())
code/__pycache__/UI.cpython-311.pyc ADDED
Binary file (12 kB).
code/__pycache__/_MyModel.cpython-311.pyc ADDED
Binary file (2.11 kB).
code/__pycache__/deepseek_vaule.cpython-311.pyc ADDED
Binary file (10 kB).
code/__pycache__/reward.cpython-311.pyc ADDED
Binary file (5.73 kB).
code/__pycache__/train_nessary.cpython-311.pyc ADDED
Binary file (8.72 kB).
code/data_process.py ADDED
@@ -0,0 +1,62 @@
+ import re
+
+ def contains_chinese(text):
+     """
+     The Unicode range \u4e00-\u9fff covers the common CJK ideographs
+     """
+     return re.search(r'[\u4e00-\u9fff]', text) is not None
+
+ def process_lyrics(text):
+     """
+     Clean the lyric text:
+     1. split on '/'
+     2. drop surrounding whitespace and empty lines
+     3. drop lines without Chinese characters (treated as English)
+     4. de-duplicate while preserving the original order
+     """
+     # Split on '/' to get the list of lyric lines
+     lyrics = text.split('/')
+     processed = []
+     seen = set()
+
+     for line in lyrics:
+         # Strip surrounding whitespace
+         line = line.strip()
+         # Skip empty lines
+         if not line:
+             continue
+         # No Chinese characters: treat as an English line and skip it
+         if not contains_chinese(line):
+             continue
+         if len(line) < 3:
+             continue
+         # De-duplicate: keep a line only the first time it appears
+         if line not in seen:
+             seen.add(line)
+             processed.append(line)
+
+     return processed
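+ # Example: process_lyrics("Hello/你好吗/你好吗/啦") returns ["你好吗"]: the English line,
+ # the duplicate, and the too-short line are all filtered out.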
+
+ def main():
+     input_filename = 'data/lyrics.txt'
+     output_filename = 'data/processed_data.txt'
+
+     # Read the raw data file; utf-8 is the safe choice here
+     with open(input_filename, 'r', encoding='utf-8') as f:
+         content = f.read()
+
+     # Clean the lyrics
+     processed = process_lyrics(content)
+
+     # Re-join the cleaned lines with '/'; one line per row would also work
+     output_content = '/'.join(processed)
+
+     # Write the result
+     with open(output_filename, 'w', encoding='utf-8') as f:
+         f.write(output_content)
+
+     print(f'Done; the result is saved in {output_filename}')
+
+
+ if __name__ == '__main__':
+     main()
code/deepseek_vaule.py ADDED
@@ -0,0 +1,187 @@
+ import os
+ import openai
+ from openai import APIError
+ from typing import Dict, List, Union
+
+ # Custom exceptions
+ class InsufficientBalanceError(Exception):
+     pass
+
+ class EvaluationError(Exception):
+     pass
+
+ # System prompt with the detailed scoring rubric (kept in Chinese because the judge
+ # model scores Chinese lyrics)
+ SYS_PROMPT = """你是一个专业的文本质量评估专家。请根据以下标准对文本进行评分(满分10分):
+ 1. 创意性(权重25%): 内容的原创性和新颖性
+ 2. 文采(权重25%): 语言表达的优美程度和修辞手法
+ 3. 格式(权重25%): 结构清晰度、可读性和符合要求的格式
+ 4. 长度(权重25%): 内容长度是否适中(50-300字为佳)
+ 5. 总分(根据四个维度进行加权计算)
+
+ 评分要求:
+ - 使用表格形式输出,得到得分表格.
+ - 每项评分保留1位小数
+ - 最后简要对目标文本的评价,而不是让你自己再写一个,切记
+ """
+
+ def evaluate_text_quality(
+     text: str,
+     api_key: str = None,
+     model: str = "deepseek-chat",
+     temperature: float = 0.3,
+     max_tokens: int = 300
+ ) -> Dict[str, Union[float, str]]:
+     # Resolve the API key
+     api_key = api_key or os.getenv("DEEPSEEK_API_KEY")
+     if not api_key:
+         raise ValueError("No DeepSeek API key provided")
+
+     # Build the client
+     client = openai.OpenAI(
+         api_key=api_key,
+         base_url="https://api.deepseek.com/v1"
+     )
+
+     try:
+         # Call the API
+         response = client.chat.completions.create(
+             model=model,
+             messages=[
+                 {"role": "system", "content": SYS_PROMPT},
+                 {"role": "user", "content": text}
+             ],
+             temperature=temperature,
+             max_tokens=max_tokens,
+             stream=False
+         )
+
+         # Parse the response
+         output = response.choices[0].message.content.strip()
+
+         # Extract the scores from the API output
+         return parse_evaluation_result(output)
+
+     except APIError as e:
+         if e.status_code == 402:  # 402 is taken to mean insufficient balance
+             raise InsufficientBalanceError("API balance exhausted; please top up") from e
+         else:
+             raise EvaluationError(f"API error [{e.status_code}]: {e.message}") from e
+     except Exception as e:
+         raise EvaluationError(f"Evaluation failed: {str(e)}") from e
+
+ def parse_evaluation_result(output: str) -> Dict[str, Union[float, str]]:
+     """
+     Parse the evaluation output; handles Chinese score tables
+     """
+     result = {
+         "scores": {
+             "creativity": 0.0,
+             "language": 0.0,
+             "format": 0.0,
+             "length": 0.0,
+             "total": 0.0
+         },
+         "evaluation": output  # fall back to the full output
+     }
+
+     # Table parsing
+     lines = [line.strip() for line in output.split('\n') if line.strip()]
+
+     for line in lines:
+         # Creativity score
+         if "创意性" in line:
+             result["scores"]["creativity"] = extract_score_from_line(line)
+         # Diction score
+         elif any(key in line for key in ["文采", "语言表达"]):
+             result["scores"]["language"] = extract_score_from_line(line)
+         # Format score
+         elif "格式" in line:
+             result["scores"]["format"] = extract_score_from_line(line)
+         # Length score
+         elif "长度" in line:
+             result["scores"]["length"] = extract_score_from_line(line)
+         # Total score
+         elif any(key in line for key in ["总分", "总计", "平均"]):
+             result["scores"]["total"] = extract_score_from_line(line)
+
+     # Extract the free-text comment (everything after "评价:" and the like)
+     evaluation_lines = []
+     found_evaluation = False
+     for line in lines:
+         if any(prefix in line for prefix in ["评价:", "评语:", "总结:"]):
+             found_evaluation = True
+             line = line.split(":", 1)[-1].strip()
+         if found_evaluation and line:
+             evaluation_lines.append(line)
+
+     if evaluation_lines:
+         result["evaluation"] = "\n".join(evaluation_lines)
+
+     return result
+
+ def extract_score_from_line(line: str) -> float:
+     """
+     Pull a numeric score out of a line; handles several table formats
+     """
+     try:
+         # Handle "| 创意性 | 8.5 |" style rows
+         if "|" in line:
+             parts = [p.strip() for p in line.split("|") if p.strip()]
+             for part in parts:
+                 if part.replace('.', '').isdigit():
+                     return float(part)
+
+         # Handle "创意性: 8.5" style rows (ASCII or full-width colon)
+         if ":" in line or ":" in line:
+             parts = line.split(":", 1) if ":" in line else line.split(":", 1)
+             num_part = parts[-1].strip()
+             for s in num_part.split():
+                 s = s.replace('/', '').replace('分', '')
+                 if s.replace('.', '').isdigit():
+                     return float(s)
+
+         # Fall back to scanning for any number
+         for word in line.split():
+             word = word.replace('分', '').replace('/', '')
+             if word.replace('.', '').isdigit():
+                 return float(word)
+
+     except (ValueError, IndexError):
+         pass
+
+     return 0.0
+
+
+ def print_evaluation_result(
+     evaluation: Dict[str, Union[float, str]],
+     show_details: bool = True,
+     score_only: bool = False
+ ) -> None:
+     """
+     Pretty-print an evaluation result
+
+     Args:
+         evaluation: the dict returned by evaluate_text_quality
+         show_details: whether to print the free-text comment
+         score_only: print scores only (takes precedence over show_details)
+     """
+     if not evaluation:
+         print("No valid evaluation result")
+         return
+
+     scores = evaluation.get("scores", {})
+     evaluation_text = evaluation.get("evaluation", "")
+
+     # Score summary
+     print("\n=== Text quality evaluation ===")
+     print(f"[creativity] {scores.get('creativity', 0.0):.1f}/10")
+     print(f"[diction]    {scores.get('language', 0.0):.1f}/10")
+     print(f"[format]     {scores.get('format', 0.0):.1f}/10")
+     print(f"[length]     {scores.get('length', 0.0):.1f}/10")
+     print("-" * 25)
+     print(f"[total]      {scores.get('total', 0.0):.1f}/10")
+
+     # Optionally print the detailed comment
+     if not score_only and show_details and evaluation_text:
+         print("\n=== Detailed comment ===")
+         print(evaluation_text)
code/getCOT.py ADDED
@@ -0,0 +1,149 @@
+ import os
+ import openai
+ import threading
+ from concurrent.futures import ThreadPoolExecutor
+ from openai import APIError
+
+ API_KEY = os.getenv("DEEPSEEK_API_KEY", "your_api_key")
+
+ class ThreadSafeWriter:
+     """Thread-safe file writer"""
+     def __init__(self, output_path: str):
+         self.file = open(output_path, 'a+', encoding='utf-8')
+         self.lock = threading.Lock()
+         self.counter = 0
+
+     def write_line(self, content: str):
+         with self.lock:
+             self.file.write(content + '\n')
+             self.file.flush()
+             self.counter += 1
+
+     def get_progress(self):
+         with self.lock:
+             return self.counter
+
+     def close(self):
+         self.file.close()
+
+ class DeepSeekBatchProcessor:
+     def __init__(self, max_workers: int = 100):
+         self.client = openai.OpenAI(
+             api_key=API_KEY,
+             base_url="https://api.deepseek.com/v1"
+         )
+         self.max_workers = max_workers
+         self.error_flag = threading.Event()
+         self.rate_limiter = threading.Semaphore(20)
+
+     def process_batch(self, batch, writer: ThreadSafeWriter):
+         """Process a batch, one thread per task"""
+         futures = []
+         with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
+             for line_num, line in batch:
+                 if self.error_flag.is_set():
+                     break
+                 futures.append(
+                     executor.submit(
+                         self._process_single_line,
+                         line_num,
+                         line,
+                         writer
+                     )
+                 )
+             for future in futures:
+                 future.result()
+
+     def _process_single_line(self, line_num: int, line: str, writer: ThreadSafeWriter):
+         if self.error_flag.is_set():
+             return
+
+         # Accept both the ASCII colon (:) and the full-width colon (:)
+         separator = None
+         if ':' in line:
+             separator = ':'
+         elif ':' in line:
+             separator = ':'
+
+         if not separator:
+             print(f"\nLine {line_num}: malformed")
+             writer.write_line(f"格式错误:{line}")
+             return
+
+         keywords_part, original_text = line.split(separator, 1)
+         # Keep only the keyword part (e.g. "风,雾,寂寞")
+         keywords = [kw.strip() for kw in keywords_part.split(",") if kw.strip()]
+         if not keywords:
+             keywords = ["无关键词"]
+
+         # Build the prompt: ask for a poem from the keywords
+         prompt = "请根据以下关键词写一首诗:" + ",".join(keywords)
+         messages = [{"role": "user", "content": prompt}]
+
+         retries = 0
+         while retries < 3 and not self.error_flag.is_set():
+             try:
+                 with self.rate_limiter:
+                     response = self.client.chat.completions.create(
+                         model="deepseek-reasoner",
+                         messages=messages,
+                         temperature=0.1
+                     )
+                 # Pull the reasoning and the poem out of the response
+                 reasoning_content = response.choices[0].message.reasoning_content.replace('\n', '').replace('\r', '')
+                 poem_original = response.choices[0].message.content.replace('\n', '/').replace('\r', '')
+                 # Assemble the final line: keywords<think>reasoning</think>:poem
+                 final_line = f"{','.join(keywords)}<think>{reasoning_content}</think>:{poem_original}"
+                 writer.write_line(final_line)
+                 progress = writer.get_progress()
+                 print(f"\rProcessed {progress} lines", end='')
+                 break
+
+             except APIError as e:
+                 if e.status_code == 402:
+                     print(f"\nLine {line_num} failed: API balance exhausted")
+                     self.error_flag.set()
+                     return
+                 elif e.status_code == 429:
+                     print(f"\nLine {line_num} rate-limited, retrying...")
+                     retries += 1
+                     if retries >= 3:
+                         print(f"\nLine {line_num}: retries exhausted")
+                 else:
+                     print(f"\nLine {line_num} API error [{e.status_code}]: {e.message}")
+                     return
+
+             except Exception as e:
+                 print(f"\nLine {line_num} raised: {str(e)}")
+                 retries += 1
+                 if retries >= 3:
+                     print(f"\nLine {line_num}: retries exhausted")
+
+         if retries >= 3 and not self.error_flag.is_set():
+             writer.write_line(f"处理失败:{line}")
+
+ def process_first_1000_lines(input_path: str, output_path: str, max_workers: int = 100):
+     """Read only the first 1000 lines and process them with the thread pool"""
+     processor = DeepSeekBatchProcessor(max_workers)
+     writer = ThreadSafeWriter(output_path)
+     batch = []
+     try:
+         with open(input_path, 'r', encoding='utf-8') as f:
+             for line_num, line in enumerate(f, 1):
+                 if not line.strip():
+                     continue
+                 batch.append((line_num, line.strip()))
+                 if line_num >= 1000:
+                     break
+
+         total = len(batch)
+         print(f"Total items: {total}")
+         processor.process_batch(batch, writer)
+         print("\nDone!")
+     finally:
+         writer.close()
+
+ if __name__ == '__main__':
+     input_file = "data/DSdata.txt"
+     output_file = "data/CoTdata.txt"
+     process_first_1000_lines(input_file, output_file, max_workers=100)
code/reward.py ADDED
@@ -0,0 +1,147 @@
+ import re
+ import numpy as np
+ from typing import List, Dict, Union, Optional
+ from sentence_transformers import SentenceTransformer, util
+ from multiprocessing import Pool, cpu_count
+
+ # Initialize the SentenceTransformer once at module level and move it to the GPU
+ embedder = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2').to("cuda")
+
+
+ def compute_rewards(
+     completions: List[str],
+     min_len: Union[int, List[int]] = 100,
+     max_len: Union[int, List[int]] = 300,
+     weights: Union[tuple, List[tuple]] = (0.25, 0.25, 0.25, 0.25),
+     return_components: bool = False,
+     **kwargs
+ ) -> Union[List[float], Dict[str, List[float]]]:
+     """Reward computation, parallelized across samples"""
+     keywords = kwargs["keywords"]
+     n_samples = len(completions)
+
+     min_len = _to_list(min_len, n_samples)
+     max_len = _to_list(max_len, n_samples)
+     weights = _to_list(weights, n_samples)
+
+     # Compute the sub-rewards in parallel
+     with Pool(cpu_count()) as pool:
+         length_rewards = pool.starmap(_length_reward, zip(completions, min_len, max_len))
+         format_rewards = pool.map(_format_reward, completions)
+         keyword_rewards = _batch_keyword_reward(completions, keywords)  # this one runs on the GPU
+         language_rewards = pool.map(_language_reward, completions)
+
+     # Weighted sum of the sub-rewards
+     total_rewards = [
+         w[0] * lr + w[1] * fr + w[2] * kr + w[3] * lang_r
+         for w, lr, fr, kr, lang_r in zip(weights, length_rewards, format_rewards, keyword_rewards, language_rewards)
+     ]
+
+     if return_components:
+         return {
+             "rewards": total_rewards,
+             "length_rewards": length_rewards,
+             "format_rewards": format_rewards,
+             "keyword_rewards": keyword_rewards,
+             "language_rewards": language_rewards,
+         }
+     return total_rewards
+
+
+ # -------------- Parallel helpers --------------
+ def _to_list(val: Union[any, List[any]], n: int) -> List[any]:
+     """Broadcast a scalar to a per-sample list"""
+     return val if isinstance(val, list) else [val] * n
+
+
+ def _length_reward(text: str, min_len: int, max_len: int) -> float:
+     """Per-sample length reward"""
+     original = text.split("</think>:", 1)[1].strip() if "</think>:" in text else text.strip()
+     length = len(original)
+
+     if length < min_len:
+         return length / min_len + 1  # grows linearly from 1 to 2 as length approaches min_len
+     elif length > max_len:
+         return max_len / length + 1  # decays from 2 toward 1 as length grows past max_len
+     return 2.0
67
+
68
+
69
+ def _format_reward(text: str) -> float:
70
+ """单样本格式奖励"""
71
+ if "<think>" not in text or "</think>:" not in text:
72
+ return -2.0
73
+ think_content = text.split("<think>")[1].split("</think>")[0].strip()
74
+ return 2.0 if think_content else -2.0
75
+
76
+
77
+ def _batch_keyword_reward(texts: List[str], keywords_list: List[List[str]]) -> List[float]:
78
+ """批量关键词匹配(优化:使用 GPU 并行计算)"""
79
+ originals = [text.split("</think>:", 1)[1].strip() if "</think>:" in text else text.strip() for text in texts]
80
+ valid_indices = [i for i, orig in enumerate(originals) if orig and keywords_list[i]]
81
+
82
+ if not valid_indices:
83
+ return [0.8 if not kw else -2.0 for kw in keywords_list] # 无关键词时默认0.8
84
+
85
+ valid_originals = [originals[i] for i in valid_indices]
86
+ valid_keywords = [keywords_list[i] for i in valid_indices]
87
+
88
+ # 让计算在 GPU 上执行
89
+ original_embs = embedder.encode(valid_originals, convert_to_tensor=True)
90
+ keyword_embs = [embedder.encode(kw, convert_to_tensor=True) for kw in valid_keywords]
91
+
92
+ similarities = [
93
+ util.pytorch_cos_sim(orig_emb, kw_emb).mean().item()
94
+ for orig_emb, kw_emb in zip(original_embs, keyword_embs)
95
+ ]
96
+
97
+ # 分配奖励
98
+ rewards = []
99
+ sim_idx = 0
100
+ for i, kw in enumerate(keywords_list):
101
+ if i in valid_indices:
102
+ sim = similarities[sim_idx]
103
+ rewards.append(2.0 if sim >= 0.6 else (1.2 if sim >= 0.4 else 0.8))
104
+ sim_idx += 1
105
+ else:
106
+ rewards.append(0.8 if not kw else -2.0)
107
+ return rewards
108
+
109
+
110
+ def _language_reward(text: str) -> float:
111
+ """单样本语言奖励"""
112
+ chinese_chars = len(re.findall(r'[\u4e00-\u9fff]', text))
113
+ ratio = chinese_chars / max(1, len(text))
114
+
115
+ if ratio >= 0.9:
116
+ return 2.0
117
+ elif ratio >= 0.7:
118
+ return 1.4
119
+ return 0.7
120
+
121
+
122
+ # ------------ 运行示例 ------------
123
+ if __name__ == "__main__":
124
+ samples = [
125
+ "科技<think>技术创新是关键</think>:人工智能在医疗领域的应用正在改变诊断方式。",
126
+ "无效样本<think></think>:无意义内容",
127
+ "经济<think>宏观经济分析</think>:全球供应链重构对发展中国家影响深远。"
128
+ ]
129
+ keywords = [
130
+ ["科技", "人工智能"],
131
+ [], # 空关键词
132
+ ["经济", "供应链"]
133
+ ]
134
+
135
+ # 并行计算
136
+ rewards = compute_rewards(
137
+ completions=samples,
138
+ keywords=keywords,
139
+ min_len=[50, 10, 80],
140
+ return_components=True
141
+ )
142
+
143
+ print("总奖励:", rewards["rewards"])
144
+ print("长度奖励:", rewards["length_rewards"])
145
+ print("格式奖励:", rewards["format_rewards"])
146
+ print("关键词奖励:", rewards["keyword_rewards"])
147
+ print("语言奖励:", rewards["language_rewards"])
code/test.ipynb ADDED
@@ -0,0 +1,144 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "initial_id",
+    "metadata": {
+     "ExecuteTime": {
+      "end_time": "2025-04-02T06:42:56.681032Z",
+      "start_time": "2025-04-02T06:42:19.346090Z"
+     },
+     "collapsed": true
+    },
+    "outputs": [],
+    "source": [
+     "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
+     "from peft import PeftModel\n",
+     "# 1. Load the base model and the LoRA adapter\n",
+     "base_model = AutoModelForCausalLM.from_pretrained(\"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\")  # GPU inference also works\n",
+     "model = PeftModel.from_pretrained(base_model, \"../DS_RL_model\")  # append .to(\"cuda\") for GPU-accelerated inference\n",
+     "tokenizer = AutoTokenizer.from_pretrained(\"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\")\n",
+     "tokenizer.pad_token = tokenizer.eos_token\n",
+     "\n"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 5,
+    "id": "c805805aeaabd6a8",
+    "metadata": {
+     "ExecuteTime": {
+      "end_time": "2025-04-02T07:41:36.715675Z",
+      "start_time": "2025-04-02T07:41:27.640736Z"
+     }
+    },
+    "outputs": [
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "模型输出: 根据以下关键词生成一首歌词,歌词中包含多个句子,确保句子通顺,诗意,格式正确.让我们一步一步的思考(思考过程包含在<think>和</think>之间):月光,欢乐,伊甸园,月光下的欢乐,小猪们、小羊们,月光下的欢乐。月光下的欢乐,小猪们、小羊们,月光下的欢乐,小猪们、小羊们,月光下的欢乐,月光下的欢乐,小猪们、小羊们,月光下的欢乐。月光,小猪们、小羊们,月光下的欢乐,月光下的欢乐,月光下的欢乐,月光下的欢乐,月光下的欢乐,月光下的欢乐,月光下的欢乐,月光下的欢乐,月光下的欢乐。\n",
+       "\n",
+       "嗯,我现在需要帮用户生成一首关于月光、欢乐、伊甸园的歌词。用户给了一个比较长的查询,里面有很多重复的句子,可能想要更简洁或者更流畅的歌词。我得先理解用户的需求,可能他们是在做一个儿童文学作品,或者是在学习如何创作歌词。\n",
+       "\n",
+       "首先,关键词有月光、欢乐、伊甸园、小猪、小羊。所以歌词里要包含这些元素。用户给出的回复里有很多重复,可能是因为想通过多个句子来强调主题,让读者更容易理解和记忆。\n",
+       "\n",
+       "我需要确保歌词结构合理,有起承转合,句子通顺。可能用户希望歌词有一定的押韵和节奏感,这样读起来更顺口。同时,格式要正确,可能需要遵循中文诗歌的格式,比如分句、押韵等。\n",
+       "\n",
+       "另外,用户提供的回复是多次重复的句子,可能是因为想强调月光下的欢乐,让读者感受到那种温馨和欢乐。我需要在生成歌词时,把这些元素自然地融入进去,而不是单纯地重复。\n",
+       "\n",
+       "我还得考虑歌词的情感基调,是欢快的还是带有感慨的。用户没有特别说明,但关键词中提到“欢乐”和“月光”,感觉偏向于积极向上的情感。\n",
+       "\n",
+       "可能需要避免过于复杂的结构,保持歌词简洁明了,同时有足够的意象来传达主题。比如,用“月光下的欢乐”这样的词句,可以增强画面感,让读者有身临其境的感觉。\n",
+       "\n",
+       "另外,用户提到了小猪和小羊,可能是在描绘一个小动物们的场景,或者是在描述一个充满欢乐的小世界。可能需要把这些元素融合在歌词中,让读者感受到那种温暖和快乐。\n",
+       "\n",
+       "我还需要注意押韵,虽然中文诗歌不一定严格押韵,但要有一定的节奏感。选择合适的结尾词来增强主题的表达。\n",
+       "\n",
+       "总的来说,我需要把月光、欢乐、小猪、小羊、伊甸园这几个元素有机地融入歌词中,确保结构合理,情感流畅,同时保持格式正确。可能需要多试几遍,调整用词和句式,直到满意为止。\n",
+       "</think>\n",
+       "\n",
+       "## 《月光下的欢乐》\n",
+       "\n",
+       "月光如水般温柔,\n",
+       "在掌心流淌着幸福的泪。\n",
+       "\n",
+       "小猪们、小羊们,\n",
+       "在伊甸园里跳跃舞。\n",
+       "月光下欢声笑语,\n",
+       "欢声笑语映照着我们的脸。\n",
+       "\n",
+       "月光下欢声笑语,\n",
+       "月光下欢声笑语,\n",
+       "月光下欢声笑语,\n",
+       "月光下欢声笑语。\n",
+       "\n",
+       "月光下欢声笑语,\n",
+       "月光下欢声笑语,\n",
+       "月光下欢声笑语,\n",
+       "月光下欢声笑语。\n",
+       "\n",
+       "月光下欢声笑语,\n",
+       "月光下欢声笑语,\n",
+       "月光下欢声笑语,\n",
+       "月光下欢声笑语。\n",
+       "\n",
+       "月光下欢声笑语,\n",
+       "月光下欢声笑语,\n",
+       "月光下欢声笑语,\n",
+       "月光下欢声笑语。\n",
+       "\n",
+       "月光下欢声笑语,\n",
+       "月光下欢声笑语,\n",
+       "月光下欢声笑语,\n",
+       "月光下欢声笑语。\n"
+      ]
+     }
+    ],
+    "source": [
+     "# 2. Prepare the prompt\n",
+     "prompt = \"根据以下关键词生成一首歌词,歌词中包含多个句子,确保句子通顺,诗意,格式正确.让我们一步一步的思考(思考过程包含在<think>和</think>之间):月光,欢乐,伊甸园\"\n",
+     "\n",
+     "# 3. Tokenize the prompt and generate a reply\n",
+     "inputs = tokenizer(prompt, return_tensors=\"pt\")\n",
+     "\n",
+     "# 4. Generation settings\n",
+     "outputs = model.generate(\n",
+     "    input_ids=inputs[\"input_ids\"],\n",
+     "    attention_mask=inputs[\"attention_mask\"],\n",
+     "    max_new_tokens=2048,  # maximum number of generated tokens\n",
+     "    do_sample=True,  # enable random sampling\n",
+     "    temperature=0.9,  # controls randomness (0.1-1.0)\n",
+     "    top_p=0.9,  # nucleus sampling parameter\n",
+     "    pad_token_id=tokenizer.eos_token_id\n",
+     ")\n",
+     "\n",
+     "# 5. Decode and print the result\n",
+     "generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
+     "print(\"模型输出:\", generated_text)"
+    ]
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": ".venv",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.11.3"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 5
+ }
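Since `compute_rewards` takes raw completion strings, the notebook's `generated_text` can be scored directly with the same reward used for RL training. A short sketch, assuming `reward.py` is importable from the notebook's working directory and a CUDA device is available (the embedder moves to the GPU at import):

```python
from reward import compute_rewards

# Score the single generated sample; keywords mirror the prompt above.
scores = compute_rewards(
    completions=[generated_text],
    keywords=[["月光", "欢乐", "伊甸园"]],
    return_components=True,
)
print(scores["rewards"], scores["format_rewards"])
```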
code/threads_data_extract.py ADDED
@@ -0,0 +1,194 @@
+ import os
+ import openai
+ import threading
+ from concurrent.futures import ThreadPoolExecutor
+ from openai import APIError
+ from typing import List, Tuple
+
+ class FormatValidator:
+     """Validator for the output data format"""
+     @staticmethod
+     def validate_line(keywords: List[str], original: str) -> str:
+         """
+         Output format: keyword1,keyword2,keyword3:original text (full-width colon separator)
+         """
+         # Strip illegal characters from the keywords
+         cleaned_keywords = [
+             kw.strip().replace(':', '').replace('\n', '')[:10]  # cap keyword length at 10 chars
+             for kw in keywords if kw.strip()
+         ][:3]  # keep at most the first 3 keywords
+
+         # Handle the no-keyword case
+         if not cleaned_keywords:
+             keywords_str = "无关键词"  # placeholder written into the data file
+         else:
+             keywords_str = ",".join(cleaned_keywords)
+
+         # Remove newlines from the original text
+         cleaned_original = original.strip().replace('\n', ' ')
+         return f"{keywords_str}:{cleaned_original}"
+
+ class ThreadSafeWriter:
+     """Thread-safe writer with progress tracking"""
+     def __init__(self, output_path: str):
+         self.file = open(output_path, 'a+', encoding='utf-8')
+         self.lock = threading.Lock()
+         self.counter = 0  # lines written so far
+
+     def write_line(self, content: str):
+         with self.lock:
+             self.file.write(content + '\n')
+             self.file.flush()
+             self.counter += 1
+
+     def get_progress(self):
+         with self.lock:
+             return self.counter
+
+     def close(self):
+         self.file.close()
+
+ class DeepSeekBatchProcessor:
+     def __init__(self, max_workers: int = 100):
+         self.client = openai.OpenAI(
+             api_key=os.getenv("DEEPSEEK_API_KEY"),  # read the key from the environment; never commit API keys
+             base_url="https://api.deepseek.com"
+         )
+         self.max_workers = max_workers
+         self.error_flag = threading.Event()
+         self.rate_limiter = threading.Semaphore(20)  # caps concurrent API requests
+
+     def process_batch(self, batch: List[Tuple[int, str]], writer: ThreadSafeWriter):
+         """Process one batch, preserving submission order"""
+         futures = []
+         with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
+             for line_num, original in batch:
+                 if self.error_flag.is_set():
+                     break
+                 futures.append(
+                     executor.submit(
+                         self._process_single_line,
+                         line_num,
+                         original,
+                         writer
+                     )
+                 )
+
+             # Wait for the current batch to finish
+             for future in futures:
+                 future.result()
+
+     def _process_single_line(self, line_num: int, original: str, writer: ThreadSafeWriter):
+         if self.error_flag.is_set():
+             return
+
+         retries = 0
+         while retries < 3 and not self.error_flag.is_set():
+             try:
+                 with self.rate_limiter:
+                     response = self.client.chat.completions.create(
+                         model="deepseek-reasoner",
+                         messages=[
+                             {"role": "system", "content": self._get_prompt()},
+                             {"role": "user", "content": original}
+                         ],
+                         temperature=0.1,
+                         max_tokens=30
+                     )
+
+                 # Parse the response
+                 keywords = self._parse_response(response)
+                 formatted_line = FormatValidator.validate_line(keywords, original)
+                 writer.write_line(formatted_line)
+
+                 # Update progress
+                 progress = writer.get_progress()
+                 print(f"\rProcessed {progress} lines", end='')
+
+                 break  # success: leave the retry loop
+
+             except APIError as e:
+                 if e.status_code == 402:  # insufficient balance
+                     print(f"\nLine {line_num} failed: insufficient API balance")
+                     self.error_flag.set()
+                     return
+                 elif e.status_code == 429:  # rate limited
+                     print(f"\nLine {line_num} rate-limited, retrying...")
+                     retries += 1
+                     if retries >= 3:
+                         print(f"Line {line_num}: retries exhausted")
+                 else:
+                     print(f"\nLine {line_num} API error [{e.status_code}]: {e.message}")
+                     return  # do not retry other API errors
+
+             except Exception as e:
+                 print(f"\nLine {line_num} processing error: {e}")
+                 retries += 1
+                 if retries >= 3:
+                     print(f"Line {line_num}: retries exhausted")
+
+         # All retries failed
+         if retries >= 3 and not self.error_flag.is_set():
+             writer.write_line(f"处理失败:{original}")  # record the failed line
+
+     @staticmethod
+     def _get_prompt() -> str:
+         """System prompt for keyword extraction."""
+         return ""  # the prompt text itself was not included in the uploaded file
+
+     @staticmethod
+     def _parse_response(response) -> List[str]:
+         content = response.choices[0].message.content.strip()
+         # Normalize full-width commas, split, and strip stray punctuation
+         return [kw.strip("。、") for kw in content.replace(',', ',').split(',') if kw]
+
+ def process_large_file(
+     input_path: str,
+     output_path: str,
+     batch_size: int = 500,
+     max_workers: int = 100
+ ):
+     """Entry point for large-file processing"""
+     # Initialize components
+     processor = DeepSeekBatchProcessor(max_workers)
+     writer = ThreadSafeWriter(output_path)
+
+     try:
+         # Read the input and build batches
+         with open(input_path, 'r', encoding='utf-8') as f:
+             # Batches of (line number, content) pairs
+             batches = []
+             current_batch = []
+             for line_num, line in enumerate(f, 1):
+                 if line.strip():
+                     current_batch.append((line_num, line.strip()))
+                 if len(current_batch) >= batch_size:
+                     batches.append(current_batch)
+                     current_batch = []
+             if current_batch:
+                 batches.append(current_batch)
+
+         # Process batch by batch (batch order is preserved)
+         total = sum(len(b) for b in batches)
+         print(f"Total examples: {total}")
+
+         for batch in batches:
+             if processor.error_flag.is_set():
+                 break
+             processor.process_batch(batch, writer)
+
+         print("\nDone!")
+
+     finally:
+         writer.close()
+
+ if __name__ == '__main__':
+     # File path configuration
+     input_file = "data/DSdata.txt"
+     output_file = "data/CoTdata.txt"
+
+     # Start processing
+     process_large_file(
+         input_path=input_file,
+         output_path=output_file,
+         batch_size=500,
+         max_workers=100
+     )
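One design note: `threading.Semaphore(20)` bounds how many requests are in flight at once, not requests per second. If a true requests-per-second cap were needed, a pacing limiter could be dropped in as the same context manager; a sketch under that assumption (the class is illustrative, not part of this upload):

```python
import threading
import time

class RateLimiter:
    """Pacing limiter: allows at most `rate` acquisitions per second (sketch)."""
    def __init__(self, rate: float):
        self.interval = 1.0 / rate
        self.lock = threading.Lock()
        self.next_time = time.monotonic()

    def __enter__(self):
        # Reserve the next available slot under the lock, then sleep outside it.
        with self.lock:
            now = time.monotonic()
            wait = self.next_time - now
            self.next_time = max(now, self.next_time) + self.interval
        if wait > 0:
            time.sleep(wait)

    def __exit__(self, *exc):
        return False
```

With this, `self.rate_limiter = RateLimiter(rate=20)` would pace calls at roughly 20 per second while keeping the `with self.rate_limiter:` usage unchanged.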
requirements.txt ADDED
@@ -0,0 +1,16 @@
+ PyQt5>=5.15
+ transformers>=4.30
+ peft>=0.15
+ torch>=2.0
+ numpy
+ matplotlib
+ jupyter
+ trl
+ datasets
+ accelerate
+ safetensors
+ scipy
+ tqdm
+ tensorboard
+ sentence-transformers
+ openai