Upload 30 files
Browse files- DS_LoRA/README.md +202 -0
- DS_LoRA/adapter_config.json +34 -0
- DS_LoRA/adapter_model.safetensors +3 -0
- DS_RL_model/README.md +202 -0
- DS_RL_model/adapter_config.json +38 -0
- DS_RL_model/adapter_model.safetensors +3 -0
- Qwen_CoT_LoRA/README.md +202 -0
- Qwen_CoT_LoRA/adapter_config.json +35 -0
- Qwen_CoT_LoRA/adapter_model.safetensors +3 -0
- Qwen_LoRA/README.md +202 -0
- Qwen_LoRA/adapter_config.json +34 -0
- Qwen_LoRA/adapter_model.safetensors +3 -0
- code/GRPO.ipynb +460 -0
- code/LORA.py +89 -0
- code/LORA_with_CoT.py +119 -0
- code/UI.py +181 -0
- code/_MyModel.py +24 -0
- code/__main__.py +15 -0
- code/__pycache__/UI.cpython-311.pyc +0 -0
- code/__pycache__/_MyModel.cpython-311.pyc +0 -0
- code/__pycache__/deepseek_vaule.cpython-311.pyc +0 -0
- code/__pycache__/reward.cpython-311.pyc +0 -0
- code/__pycache__/train_nessary.cpython-311.pyc +0 -0
- code/data_process.py +62 -0
- code/deepseek_vaule.py +187 -0
- code/getCOT.py +149 -0
- code/reward.py +147 -0
- code/test.ipynb +144 -0
- code/threads_data_extract.py +194 -0
- requirements.txt +16 -0
DS_LoRA/README.md
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
base_model: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
|
| 3 |
+
library_name: peft
|
| 4 |
+
---
|
| 5 |
+
|
| 6 |
+
# Model Card for Model ID
|
| 7 |
+
|
| 8 |
+
<!-- Provide a quick summary of what the model is/does. -->
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
## Model Details
|
| 13 |
+
|
| 14 |
+
### Model Description
|
| 15 |
+
|
| 16 |
+
<!-- Provide a longer summary of what this model is. -->
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
- **Developed by:** [More Information Needed]
|
| 21 |
+
- **Funded by [optional]:** [More Information Needed]
|
| 22 |
+
- **Shared by [optional]:** [More Information Needed]
|
| 23 |
+
- **Model type:** [More Information Needed]
|
| 24 |
+
- **Language(s) (NLP):** [More Information Needed]
|
| 25 |
+
- **License:** [More Information Needed]
|
| 26 |
+
- **Finetuned from model [optional]:** [More Information Needed]
|
| 27 |
+
|
| 28 |
+
### Model Sources [optional]
|
| 29 |
+
|
| 30 |
+
<!-- Provide the basic links for the model. -->
|
| 31 |
+
|
| 32 |
+
- **Repository:** [More Information Needed]
|
| 33 |
+
- **Paper [optional]:** [More Information Needed]
|
| 34 |
+
- **Demo [optional]:** [More Information Needed]
|
| 35 |
+
|
| 36 |
+
## Uses
|
| 37 |
+
|
| 38 |
+
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
|
| 39 |
+
|
| 40 |
+
### Direct Use
|
| 41 |
+
|
| 42 |
+
<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
|
| 43 |
+
|
| 44 |
+
[More Information Needed]
|
| 45 |
+
|
| 46 |
+
### Downstream Use [optional]
|
| 47 |
+
|
| 48 |
+
<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
|
| 49 |
+
|
| 50 |
+
[More Information Needed]
|
| 51 |
+
|
| 52 |
+
### Out-of-Scope Use
|
| 53 |
+
|
| 54 |
+
<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
|
| 55 |
+
|
| 56 |
+
[More Information Needed]
|
| 57 |
+
|
| 58 |
+
## Bias, Risks, and Limitations
|
| 59 |
+
|
| 60 |
+
<!-- This section is meant to convey both technical and sociotechnical limitations. -->
|
| 61 |
+
|
| 62 |
+
[More Information Needed]
|
| 63 |
+
|
| 64 |
+
### Recommendations
|
| 65 |
+
|
| 66 |
+
<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
|
| 67 |
+
|
| 68 |
+
Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
|
| 69 |
+
|
| 70 |
+
## How to Get Started with the Model
|
| 71 |
+
|
| 72 |
+
Use the code below to get started with the model.
|
| 73 |
+
|
| 74 |
+
[More Information Needed]
|
| 75 |
+
|
| 76 |
+
## Training Details
|
| 77 |
+
|
| 78 |
+
### Training Data
|
| 79 |
+
|
| 80 |
+
<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
|
| 81 |
+
|
| 82 |
+
[More Information Needed]
|
| 83 |
+
|
| 84 |
+
### Training Procedure
|
| 85 |
+
|
| 86 |
+
<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
|
| 87 |
+
|
| 88 |
+
#### Preprocessing [optional]
|
| 89 |
+
|
| 90 |
+
[More Information Needed]
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
#### Training Hyperparameters
|
| 94 |
+
|
| 95 |
+
- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
|
| 96 |
+
|
| 97 |
+
#### Speeds, Sizes, Times [optional]
|
| 98 |
+
|
| 99 |
+
<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
|
| 100 |
+
|
| 101 |
+
[More Information Needed]
|
| 102 |
+
|
| 103 |
+
## Evaluation
|
| 104 |
+
|
| 105 |
+
<!-- This section describes the evaluation protocols and provides the results. -->
|
| 106 |
+
|
| 107 |
+
### Testing Data, Factors & Metrics
|
| 108 |
+
|
| 109 |
+
#### Testing Data
|
| 110 |
+
|
| 111 |
+
<!-- This should link to a Dataset Card if possible. -->
|
| 112 |
+
|
| 113 |
+
[More Information Needed]
|
| 114 |
+
|
| 115 |
+
#### Factors
|
| 116 |
+
|
| 117 |
+
<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
|
| 118 |
+
|
| 119 |
+
[More Information Needed]
|
| 120 |
+
|
| 121 |
+
#### Metrics
|
| 122 |
+
|
| 123 |
+
<!-- These are the evaluation metrics being used, ideally with a description of why. -->
|
| 124 |
+
|
| 125 |
+
[More Information Needed]
|
| 126 |
+
|
| 127 |
+
### Results
|
| 128 |
+
|
| 129 |
+
[More Information Needed]
|
| 130 |
+
|
| 131 |
+
#### Summary
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
## Model Examination [optional]
|
| 136 |
+
|
| 137 |
+
<!-- Relevant interpretability work for the model goes here -->
|
| 138 |
+
|
| 139 |
+
[More Information Needed]
|
| 140 |
+
|
| 141 |
+
## Environmental Impact
|
| 142 |
+
|
| 143 |
+
<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
|
| 144 |
+
|
| 145 |
+
Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
|
| 146 |
+
|
| 147 |
+
- **Hardware Type:** [More Information Needed]
|
| 148 |
+
- **Hours used:** [More Information Needed]
|
| 149 |
+
- **Cloud Provider:** [More Information Needed]
|
| 150 |
+
- **Compute Region:** [More Information Needed]
|
| 151 |
+
- **Carbon Emitted:** [More Information Needed]
|
| 152 |
+
|
| 153 |
+
## Technical Specifications [optional]
|
| 154 |
+
|
| 155 |
+
### Model Architecture and Objective
|
| 156 |
+
|
| 157 |
+
[More Information Needed]
|
| 158 |
+
|
| 159 |
+
### Compute Infrastructure
|
| 160 |
+
|
| 161 |
+
[More Information Needed]
|
| 162 |
+
|
| 163 |
+
#### Hardware
|
| 164 |
+
|
| 165 |
+
[More Information Needed]
|
| 166 |
+
|
| 167 |
+
#### Software
|
| 168 |
+
|
| 169 |
+
[More Information Needed]
|
| 170 |
+
|
| 171 |
+
## Citation [optional]
|
| 172 |
+
|
| 173 |
+
<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
|
| 174 |
+
|
| 175 |
+
**BibTeX:**
|
| 176 |
+
|
| 177 |
+
[More Information Needed]
|
| 178 |
+
|
| 179 |
+
**APA:**
|
| 180 |
+
|
| 181 |
+
[More Information Needed]
|
| 182 |
+
|
| 183 |
+
## Glossary [optional]
|
| 184 |
+
|
| 185 |
+
<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
|
| 186 |
+
|
| 187 |
+
[More Information Needed]
|
| 188 |
+
|
| 189 |
+
## More Information [optional]
|
| 190 |
+
|
| 191 |
+
[More Information Needed]
|
| 192 |
+
|
| 193 |
+
## Model Card Authors [optional]
|
| 194 |
+
|
| 195 |
+
[More Information Needed]
|
| 196 |
+
|
| 197 |
+
## Model Card Contact
|
| 198 |
+
|
| 199 |
+
[More Information Needed]
|
| 200 |
+
### Framework versions
|
| 201 |
+
|
| 202 |
+
- PEFT 0.15.0
|
DS_LoRA/adapter_config.json
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"alpha_pattern": {},
|
| 3 |
+
"auto_mapping": null,
|
| 4 |
+
"base_model_name_or_path": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
|
| 5 |
+
"bias": "none",
|
| 6 |
+
"corda_config": null,
|
| 7 |
+
"eva_config": null,
|
| 8 |
+
"exclude_modules": null,
|
| 9 |
+
"fan_in_fan_out": false,
|
| 10 |
+
"inference_mode": true,
|
| 11 |
+
"init_lora_weights": true,
|
| 12 |
+
"layer_replication": null,
|
| 13 |
+
"layers_pattern": null,
|
| 14 |
+
"layers_to_transform": null,
|
| 15 |
+
"loftq_config": {},
|
| 16 |
+
"lora_alpha": 32,
|
| 17 |
+
"lora_bias": false,
|
| 18 |
+
"lora_dropout": 0.1,
|
| 19 |
+
"megatron_config": null,
|
| 20 |
+
"megatron_core": "megatron.core",
|
| 21 |
+
"modules_to_save": null,
|
| 22 |
+
"peft_type": "LORA",
|
| 23 |
+
"r": 8,
|
| 24 |
+
"rank_pattern": {},
|
| 25 |
+
"revision": null,
|
| 26 |
+
"target_modules": [
|
| 27 |
+
"q_proj",
|
| 28 |
+
"v_proj"
|
| 29 |
+
],
|
| 30 |
+
"task_type": "CAUSAL_LM",
|
| 31 |
+
"trainable_token_indices": null,
|
| 32 |
+
"use_dora": false,
|
| 33 |
+
"use_rslora": false
|
| 34 |
+
}
|
DS_LoRA/adapter_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1b03770158458f7f2108bad02721e94cba1a8935ebe9e81038129be01a5f03bc
|
| 3 |
+
size 4372840
|
DS_RL_model/README.md
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
base_model: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
|
| 3 |
+
library_name: peft
|
| 4 |
+
---
|
| 5 |
+
|
| 6 |
+
# Model Card for Model ID
|
| 7 |
+
|
| 8 |
+
<!-- Provide a quick summary of what the model is/does. -->
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
## Model Details
|
| 13 |
+
|
| 14 |
+
### Model Description
|
| 15 |
+
|
| 16 |
+
<!-- Provide a longer summary of what this model is. -->
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
- **Developed by:** [More Information Needed]
|
| 21 |
+
- **Funded by [optional]:** [More Information Needed]
|
| 22 |
+
- **Shared by [optional]:** [More Information Needed]
|
| 23 |
+
- **Model type:** [More Information Needed]
|
| 24 |
+
- **Language(s) (NLP):** [More Information Needed]
|
| 25 |
+
- **License:** [More Information Needed]
|
| 26 |
+
- **Finetuned from model [optional]:** [More Information Needed]
|
| 27 |
+
|
| 28 |
+
### Model Sources [optional]
|
| 29 |
+
|
| 30 |
+
<!-- Provide the basic links for the model. -->
|
| 31 |
+
|
| 32 |
+
- **Repository:** [More Information Needed]
|
| 33 |
+
- **Paper [optional]:** [More Information Needed]
|
| 34 |
+
- **Demo [optional]:** [More Information Needed]
|
| 35 |
+
|
| 36 |
+
## Uses
|
| 37 |
+
|
| 38 |
+
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
|
| 39 |
+
|
| 40 |
+
### Direct Use
|
| 41 |
+
|
| 42 |
+
<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
|
| 43 |
+
|
| 44 |
+
[More Information Needed]
|
| 45 |
+
|
| 46 |
+
### Downstream Use [optional]
|
| 47 |
+
|
| 48 |
+
<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
|
| 49 |
+
|
| 50 |
+
[More Information Needed]
|
| 51 |
+
|
| 52 |
+
### Out-of-Scope Use
|
| 53 |
+
|
| 54 |
+
<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
|
| 55 |
+
|
| 56 |
+
[More Information Needed]
|
| 57 |
+
|
| 58 |
+
## Bias, Risks, and Limitations
|
| 59 |
+
|
| 60 |
+
<!-- This section is meant to convey both technical and sociotechnical limitations. -->
|
| 61 |
+
|
| 62 |
+
[More Information Needed]
|
| 63 |
+
|
| 64 |
+
### Recommendations
|
| 65 |
+
|
| 66 |
+
<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
|
| 67 |
+
|
| 68 |
+
Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
|
| 69 |
+
|
| 70 |
+
## How to Get Started with the Model
|
| 71 |
+
|
| 72 |
+
Use the code below to get started with the model.
|
| 73 |
+
|
| 74 |
+
[More Information Needed]
|
| 75 |
+
|
| 76 |
+
## Training Details
|
| 77 |
+
|
| 78 |
+
### Training Data
|
| 79 |
+
|
| 80 |
+
<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
|
| 81 |
+
|
| 82 |
+
[More Information Needed]
|
| 83 |
+
|
| 84 |
+
### Training Procedure
|
| 85 |
+
|
| 86 |
+
<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
|
| 87 |
+
|
| 88 |
+
#### Preprocessing [optional]
|
| 89 |
+
|
| 90 |
+
[More Information Needed]
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
#### Training Hyperparameters
|
| 94 |
+
|
| 95 |
+
- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
|
| 96 |
+
|
| 97 |
+
#### Speeds, Sizes, Times [optional]
|
| 98 |
+
|
| 99 |
+
<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
|
| 100 |
+
|
| 101 |
+
[More Information Needed]
|
| 102 |
+
|
| 103 |
+
## Evaluation
|
| 104 |
+
|
| 105 |
+
<!-- This section describes the evaluation protocols and provides the results. -->
|
| 106 |
+
|
| 107 |
+
### Testing Data, Factors & Metrics
|
| 108 |
+
|
| 109 |
+
#### Testing Data
|
| 110 |
+
|
| 111 |
+
<!-- This should link to a Dataset Card if possible. -->
|
| 112 |
+
|
| 113 |
+
[More Information Needed]
|
| 114 |
+
|
| 115 |
+
#### Factors
|
| 116 |
+
|
| 117 |
+
<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
|
| 118 |
+
|
| 119 |
+
[More Information Needed]
|
| 120 |
+
|
| 121 |
+
#### Metrics
|
| 122 |
+
|
| 123 |
+
<!-- These are the evaluation metrics being used, ideally with a description of why. -->
|
| 124 |
+
|
| 125 |
+
[More Information Needed]
|
| 126 |
+
|
| 127 |
+
### Results
|
| 128 |
+
|
| 129 |
+
[More Information Needed]
|
| 130 |
+
|
| 131 |
+
#### Summary
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
## Model Examination [optional]
|
| 136 |
+
|
| 137 |
+
<!-- Relevant interpretability work for the model goes here -->
|
| 138 |
+
|
| 139 |
+
[More Information Needed]
|
| 140 |
+
|
| 141 |
+
## Environmental Impact
|
| 142 |
+
|
| 143 |
+
<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
|
| 144 |
+
|
| 145 |
+
Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
|
| 146 |
+
|
| 147 |
+
- **Hardware Type:** [More Information Needed]
|
| 148 |
+
- **Hours used:** [More Information Needed]
|
| 149 |
+
- **Cloud Provider:** [More Information Needed]
|
| 150 |
+
- **Compute Region:** [More Information Needed]
|
| 151 |
+
- **Carbon Emitted:** [More Information Needed]
|
| 152 |
+
|
| 153 |
+
## Technical Specifications [optional]
|
| 154 |
+
|
| 155 |
+
### Model Architecture and Objective
|
| 156 |
+
|
| 157 |
+
[More Information Needed]
|
| 158 |
+
|
| 159 |
+
### Compute Infrastructure
|
| 160 |
+
|
| 161 |
+
[More Information Needed]
|
| 162 |
+
|
| 163 |
+
#### Hardware
|
| 164 |
+
|
| 165 |
+
[More Information Needed]
|
| 166 |
+
|
| 167 |
+
#### Software
|
| 168 |
+
|
| 169 |
+
[More Information Needed]
|
| 170 |
+
|
| 171 |
+
## Citation [optional]
|
| 172 |
+
|
| 173 |
+
<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
|
| 174 |
+
|
| 175 |
+
**BibTeX:**
|
| 176 |
+
|
| 177 |
+
[More Information Needed]
|
| 178 |
+
|
| 179 |
+
**APA:**
|
| 180 |
+
|
| 181 |
+
[More Information Needed]
|
| 182 |
+
|
| 183 |
+
## Glossary [optional]
|
| 184 |
+
|
| 185 |
+
<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
|
| 186 |
+
|
| 187 |
+
[More Information Needed]
|
| 188 |
+
|
| 189 |
+
## More Information [optional]
|
| 190 |
+
|
| 191 |
+
[More Information Needed]
|
| 192 |
+
|
| 193 |
+
## Model Card Authors [optional]
|
| 194 |
+
|
| 195 |
+
[More Information Needed]
|
| 196 |
+
|
| 197 |
+
## Model Card Contact
|
| 198 |
+
|
| 199 |
+
[More Information Needed]
|
| 200 |
+
### Framework versions
|
| 201 |
+
|
| 202 |
+
- PEFT 0.15.1
|
DS_RL_model/adapter_config.json
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"alpha_pattern": {},
|
| 3 |
+
"auto_mapping": {
|
| 4 |
+
"base_model_class": "PeftModelForCausalLM",
|
| 5 |
+
"parent_library": "peft.peft_model"
|
| 6 |
+
},
|
| 7 |
+
"base_model_name_or_path": null,
|
| 8 |
+
"bias": "none",
|
| 9 |
+
"corda_config": null,
|
| 10 |
+
"eva_config": null,
|
| 11 |
+
"exclude_modules": null,
|
| 12 |
+
"fan_in_fan_out": false,
|
| 13 |
+
"inference_mode": true,
|
| 14 |
+
"init_lora_weights": true,
|
| 15 |
+
"layer_replication": null,
|
| 16 |
+
"layers_pattern": null,
|
| 17 |
+
"layers_to_transform": null,
|
| 18 |
+
"loftq_config": {},
|
| 19 |
+
"lora_alpha": 16,
|
| 20 |
+
"lora_bias": false,
|
| 21 |
+
"lora_dropout": 0.0,
|
| 22 |
+
"megatron_config": null,
|
| 23 |
+
"megatron_core": "megatron.core",
|
| 24 |
+
"modules_to_save": null,
|
| 25 |
+
"peft_type": "LORA",
|
| 26 |
+
"r": 8,
|
| 27 |
+
"rank_pattern": {},
|
| 28 |
+
"revision": null,
|
| 29 |
+
"target_modules": [
|
| 30 |
+
"k_proj",
|
| 31 |
+
"q_proj",
|
| 32 |
+
"v_proj"
|
| 33 |
+
],
|
| 34 |
+
"task_type": null,
|
| 35 |
+
"trainable_token_indices": null,
|
| 36 |
+
"use_dora": false,
|
| 37 |
+
"use_rslora": false
|
| 38 |
+
}
|
DS_RL_model/adapter_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:92cda569f4ae4f21ecbd0790b623428f349855221041023c8f578c0b188d6ac2
|
| 3 |
+
size 5988672
|
Qwen_CoT_LoRA/README.md
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
base_model: Qwen/Qwen2.5-0.5B-Instruct
|
| 3 |
+
library_name: peft
|
| 4 |
+
---
|
| 5 |
+
|
| 6 |
+
# Model Card for Model ID
|
| 7 |
+
|
| 8 |
+
<!-- Provide a quick summary of what the model is/does. -->
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
## Model Details
|
| 13 |
+
|
| 14 |
+
### Model Description
|
| 15 |
+
|
| 16 |
+
<!-- Provide a longer summary of what this model is. -->
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
- **Developed by:** [More Information Needed]
|
| 21 |
+
- **Funded by [optional]:** [More Information Needed]
|
| 22 |
+
- **Shared by [optional]:** [More Information Needed]
|
| 23 |
+
- **Model type:** [More Information Needed]
|
| 24 |
+
- **Language(s) (NLP):** [More Information Needed]
|
| 25 |
+
- **License:** [More Information Needed]
|
| 26 |
+
- **Finetuned from model [optional]:** [More Information Needed]
|
| 27 |
+
|
| 28 |
+
### Model Sources [optional]
|
| 29 |
+
|
| 30 |
+
<!-- Provide the basic links for the model. -->
|
| 31 |
+
|
| 32 |
+
- **Repository:** [More Information Needed]
|
| 33 |
+
- **Paper [optional]:** [More Information Needed]
|
| 34 |
+
- **Demo [optional]:** [More Information Needed]
|
| 35 |
+
|
| 36 |
+
## Uses
|
| 37 |
+
|
| 38 |
+
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
|
| 39 |
+
|
| 40 |
+
### Direct Use
|
| 41 |
+
|
| 42 |
+
<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
|
| 43 |
+
|
| 44 |
+
[More Information Needed]
|
| 45 |
+
|
| 46 |
+
### Downstream Use [optional]
|
| 47 |
+
|
| 48 |
+
<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
|
| 49 |
+
|
| 50 |
+
[More Information Needed]
|
| 51 |
+
|
| 52 |
+
### Out-of-Scope Use
|
| 53 |
+
|
| 54 |
+
<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
|
| 55 |
+
|
| 56 |
+
[More Information Needed]
|
| 57 |
+
|
| 58 |
+
## Bias, Risks, and Limitations
|
| 59 |
+
|
| 60 |
+
<!-- This section is meant to convey both technical and sociotechnical limitations. -->
|
| 61 |
+
|
| 62 |
+
[More Information Needed]
|
| 63 |
+
|
| 64 |
+
### Recommendations
|
| 65 |
+
|
| 66 |
+
<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
|
| 67 |
+
|
| 68 |
+
Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
|
| 69 |
+
|
| 70 |
+
## How to Get Started with the Model
|
| 71 |
+
|
| 72 |
+
Use the code below to get started with the model.
|
| 73 |
+
|
| 74 |
+
[More Information Needed]
|
| 75 |
+
|
| 76 |
+
## Training Details
|
| 77 |
+
|
| 78 |
+
### Training Data
|
| 79 |
+
|
| 80 |
+
<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
|
| 81 |
+
|
| 82 |
+
[More Information Needed]
|
| 83 |
+
|
| 84 |
+
### Training Procedure
|
| 85 |
+
|
| 86 |
+
<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
|
| 87 |
+
|
| 88 |
+
#### Preprocessing [optional]
|
| 89 |
+
|
| 90 |
+
[More Information Needed]
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
#### Training Hyperparameters
|
| 94 |
+
|
| 95 |
+
- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
|
| 96 |
+
|
| 97 |
+
#### Speeds, Sizes, Times [optional]
|
| 98 |
+
|
| 99 |
+
<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
|
| 100 |
+
|
| 101 |
+
[More Information Needed]
|
| 102 |
+
|
| 103 |
+
## Evaluation
|
| 104 |
+
|
| 105 |
+
<!-- This section describes the evaluation protocols and provides the results. -->
|
| 106 |
+
|
| 107 |
+
### Testing Data, Factors & Metrics
|
| 108 |
+
|
| 109 |
+
#### Testing Data
|
| 110 |
+
|
| 111 |
+
<!-- This should link to a Dataset Card if possible. -->
|
| 112 |
+
|
| 113 |
+
[More Information Needed]
|
| 114 |
+
|
| 115 |
+
#### Factors
|
| 116 |
+
|
| 117 |
+
<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
|
| 118 |
+
|
| 119 |
+
[More Information Needed]
|
| 120 |
+
|
| 121 |
+
#### Metrics
|
| 122 |
+
|
| 123 |
+
<!-- These are the evaluation metrics being used, ideally with a description of why. -->
|
| 124 |
+
|
| 125 |
+
[More Information Needed]
|
| 126 |
+
|
| 127 |
+
### Results
|
| 128 |
+
|
| 129 |
+
[More Information Needed]
|
| 130 |
+
|
| 131 |
+
#### Summary
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
## Model Examination [optional]
|
| 136 |
+
|
| 137 |
+
<!-- Relevant interpretability work for the model goes here -->
|
| 138 |
+
|
| 139 |
+
[More Information Needed]
|
| 140 |
+
|
| 141 |
+
## Environmental Impact
|
| 142 |
+
|
| 143 |
+
<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
|
| 144 |
+
|
| 145 |
+
Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
|
| 146 |
+
|
| 147 |
+
- **Hardware Type:** [More Information Needed]
|
| 148 |
+
- **Hours used:** [More Information Needed]
|
| 149 |
+
- **Cloud Provider:** [More Information Needed]
|
| 150 |
+
- **Compute Region:** [More Information Needed]
|
| 151 |
+
- **Carbon Emitted:** [More Information Needed]
|
| 152 |
+
|
| 153 |
+
## Technical Specifications [optional]
|
| 154 |
+
|
| 155 |
+
### Model Architecture and Objective
|
| 156 |
+
|
| 157 |
+
[More Information Needed]
|
| 158 |
+
|
| 159 |
+
### Compute Infrastructure
|
| 160 |
+
|
| 161 |
+
[More Information Needed]
|
| 162 |
+
|
| 163 |
+
#### Hardware
|
| 164 |
+
|
| 165 |
+
[More Information Needed]
|
| 166 |
+
|
| 167 |
+
#### Software
|
| 168 |
+
|
| 169 |
+
[More Information Needed]
|
| 170 |
+
|
| 171 |
+
## Citation [optional]
|
| 172 |
+
|
| 173 |
+
<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
|
| 174 |
+
|
| 175 |
+
**BibTeX:**
|
| 176 |
+
|
| 177 |
+
[More Information Needed]
|
| 178 |
+
|
| 179 |
+
**APA:**
|
| 180 |
+
|
| 181 |
+
[More Information Needed]
|
| 182 |
+
|
| 183 |
+
## Glossary [optional]
|
| 184 |
+
|
| 185 |
+
<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
|
| 186 |
+
|
| 187 |
+
[More Information Needed]
|
| 188 |
+
|
| 189 |
+
## More Information [optional]
|
| 190 |
+
|
| 191 |
+
[More Information Needed]
|
| 192 |
+
|
| 193 |
+
## Model Card Authors [optional]
|
| 194 |
+
|
| 195 |
+
[More Information Needed]
|
| 196 |
+
|
| 197 |
+
## Model Card Contact
|
| 198 |
+
|
| 199 |
+
[More Information Needed]
|
| 200 |
+
### Framework versions
|
| 201 |
+
|
| 202 |
+
- PEFT 0.15.1
|
Qwen_CoT_LoRA/adapter_config.json
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"alpha_pattern": {},
|
| 3 |
+
"auto_mapping": null,
|
| 4 |
+
"base_model_name_or_path": null,
|
| 5 |
+
"bias": "none",
|
| 6 |
+
"corda_config": null,
|
| 7 |
+
"eva_config": null,
|
| 8 |
+
"exclude_modules": null,
|
| 9 |
+
"fan_in_fan_out": false,
|
| 10 |
+
"inference_mode": true,
|
| 11 |
+
"init_lora_weights": true,
|
| 12 |
+
"layer_replication": null,
|
| 13 |
+
"layers_pattern": null,
|
| 14 |
+
"layers_to_transform": null,
|
| 15 |
+
"loftq_config": {},
|
| 16 |
+
"lora_alpha": 32,
|
| 17 |
+
"lora_bias": false,
|
| 18 |
+
"lora_dropout": 0.1,
|
| 19 |
+
"megatron_config": null,
|
| 20 |
+
"megatron_core": "megatron.core",
|
| 21 |
+
"modules_to_save": null,
|
| 22 |
+
"peft_type": "LORA",
|
| 23 |
+
"r": 8,
|
| 24 |
+
"rank_pattern": {},
|
| 25 |
+
"revision": null,
|
| 26 |
+
"target_modules": [
|
| 27 |
+
"k_proj",
|
| 28 |
+
"v_proj",
|
| 29 |
+
"q_proj"
|
| 30 |
+
],
|
| 31 |
+
"task_type": "CAUSAL_LM",
|
| 32 |
+
"trainable_token_indices": null,
|
| 33 |
+
"use_dora": false,
|
| 34 |
+
"use_rslora": false
|
| 35 |
+
}
|
Qwen_CoT_LoRA/adapter_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:dbb7d5e3fea8f9085b091a8e9b37028f06cd23297d9d02a1e7fbf747310e4c86
|
| 3 |
+
size 2970304
|
Qwen_LoRA/README.md
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
base_model: /workspace/local_model
|
| 3 |
+
library_name: peft
|
| 4 |
+
---
|
| 5 |
+
|
| 6 |
+
# Model Card for Model ID
|
| 7 |
+
|
| 8 |
+
<!-- Provide a quick summary of what the model is/does. -->
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
## Model Details
|
| 13 |
+
|
| 14 |
+
### Model Description
|
| 15 |
+
|
| 16 |
+
<!-- Provide a longer summary of what this model is. -->
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
- **Developed by:** [More Information Needed]
|
| 21 |
+
- **Funded by [optional]:** [More Information Needed]
|
| 22 |
+
- **Shared by [optional]:** [More Information Needed]
|
| 23 |
+
- **Model type:** [More Information Needed]
|
| 24 |
+
- **Language(s) (NLP):** [More Information Needed]
|
| 25 |
+
- **License:** [More Information Needed]
|
| 26 |
+
- **Finetuned from model [optional]:** [More Information Needed]
|
| 27 |
+
|
| 28 |
+
### Model Sources [optional]
|
| 29 |
+
|
| 30 |
+
<!-- Provide the basic links for the model. -->
|
| 31 |
+
|
| 32 |
+
- **Repository:** [More Information Needed]
|
| 33 |
+
- **Paper [optional]:** [More Information Needed]
|
| 34 |
+
- **Demo [optional]:** [More Information Needed]
|
| 35 |
+
|
| 36 |
+
## Uses
|
| 37 |
+
|
| 38 |
+
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
|
| 39 |
+
|
| 40 |
+
### Direct Use
|
| 41 |
+
|
| 42 |
+
<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
|
| 43 |
+
|
| 44 |
+
[More Information Needed]
|
| 45 |
+
|
| 46 |
+
### Downstream Use [optional]
|
| 47 |
+
|
| 48 |
+
<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
|
| 49 |
+
|
| 50 |
+
[More Information Needed]
|
| 51 |
+
|
| 52 |
+
### Out-of-Scope Use
|
| 53 |
+
|
| 54 |
+
<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
|
| 55 |
+
|
| 56 |
+
[More Information Needed]
|
| 57 |
+
|
| 58 |
+
## Bias, Risks, and Limitations
|
| 59 |
+
|
| 60 |
+
<!-- This section is meant to convey both technical and sociotechnical limitations. -->
|
| 61 |
+
|
| 62 |
+
[More Information Needed]
|
| 63 |
+
|
| 64 |
+
### Recommendations
|
| 65 |
+
|
| 66 |
+
<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
|
| 67 |
+
|
| 68 |
+
Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
|
| 69 |
+
|
| 70 |
+
## How to Get Started with the Model
|
| 71 |
+
|
| 72 |
+
Use the code below to get started with the model.
|
| 73 |
+
|
| 74 |
+
[More Information Needed]
|
| 75 |
+
|
| 76 |
+
## Training Details
|
| 77 |
+
|
| 78 |
+
### Training Data
|
| 79 |
+
|
| 80 |
+
<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
|
| 81 |
+
|
| 82 |
+
[More Information Needed]
|
| 83 |
+
|
| 84 |
+
### Training Procedure
|
| 85 |
+
|
| 86 |
+
<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
|
| 87 |
+
|
| 88 |
+
#### Preprocessing [optional]
|
| 89 |
+
|
| 90 |
+
[More Information Needed]
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
#### Training Hyperparameters
|
| 94 |
+
|
| 95 |
+
- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
|
| 96 |
+
|
| 97 |
+
#### Speeds, Sizes, Times [optional]
|
| 98 |
+
|
| 99 |
+
<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
|
| 100 |
+
|
| 101 |
+
[More Information Needed]
|
| 102 |
+
|
| 103 |
+
## Evaluation
|
| 104 |
+
|
| 105 |
+
<!-- This section describes the evaluation protocols and provides the results. -->
|
| 106 |
+
|
| 107 |
+
### Testing Data, Factors & Metrics
|
| 108 |
+
|
| 109 |
+
#### Testing Data
|
| 110 |
+
|
| 111 |
+
<!-- This should link to a Dataset Card if possible. -->
|
| 112 |
+
|
| 113 |
+
[More Information Needed]
|
| 114 |
+
|
| 115 |
+
#### Factors
|
| 116 |
+
|
| 117 |
+
<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
|
| 118 |
+
|
| 119 |
+
[More Information Needed]
|
| 120 |
+
|
| 121 |
+
#### Metrics
|
| 122 |
+
|
| 123 |
+
<!-- These are the evaluation metrics being used, ideally with a description of why. -->
|
| 124 |
+
|
| 125 |
+
[More Information Needed]
|
| 126 |
+
|
| 127 |
+
### Results
|
| 128 |
+
|
| 129 |
+
[More Information Needed]
|
| 130 |
+
|
| 131 |
+
#### Summary
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
## Model Examination [optional]
|
| 136 |
+
|
| 137 |
+
<!-- Relevant interpretability work for the model goes here -->
|
| 138 |
+
|
| 139 |
+
[More Information Needed]
|
| 140 |
+
|
| 141 |
+
## Environmental Impact
|
| 142 |
+
|
| 143 |
+
<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
|
| 144 |
+
|
| 145 |
+
Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
|
| 146 |
+
|
| 147 |
+
- **Hardware Type:** [More Information Needed]
|
| 148 |
+
- **Hours used:** [More Information Needed]
|
| 149 |
+
- **Cloud Provider:** [More Information Needed]
|
| 150 |
+
- **Compute Region:** [More Information Needed]
|
| 151 |
+
- **Carbon Emitted:** [More Information Needed]
|
| 152 |
+
|
| 153 |
+
## Technical Specifications [optional]
|
| 154 |
+
|
| 155 |
+
### Model Architecture and Objective
|
| 156 |
+
|
| 157 |
+
[More Information Needed]
|
| 158 |
+
|
| 159 |
+
### Compute Infrastructure
|
| 160 |
+
|
| 161 |
+
[More Information Needed]
|
| 162 |
+
|
| 163 |
+
#### Hardware
|
| 164 |
+
|
| 165 |
+
[More Information Needed]
|
| 166 |
+
|
| 167 |
+
#### Software
|
| 168 |
+
|
| 169 |
+
[More Information Needed]
|
| 170 |
+
|
| 171 |
+
## Citation [optional]
|
| 172 |
+
|
| 173 |
+
<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
|
| 174 |
+
|
| 175 |
+
**BibTeX:**
|
| 176 |
+
|
| 177 |
+
[More Information Needed]
|
| 178 |
+
|
| 179 |
+
**APA:**
|
| 180 |
+
|
| 181 |
+
[More Information Needed]
|
| 182 |
+
|
| 183 |
+
## Glossary [optional]
|
| 184 |
+
|
| 185 |
+
<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
|
| 186 |
+
|
| 187 |
+
[More Information Needed]
|
| 188 |
+
|
| 189 |
+
## More Information [optional]
|
| 190 |
+
|
| 191 |
+
[More Information Needed]
|
| 192 |
+
|
| 193 |
+
## Model Card Authors [optional]
|
| 194 |
+
|
| 195 |
+
[More Information Needed]
|
| 196 |
+
|
| 197 |
+
## Model Card Contact
|
| 198 |
+
|
| 199 |
+
[More Information Needed]
|
| 200 |
+
### Framework versions
|
| 201 |
+
|
| 202 |
+
- PEFT 0.15.0
|
Qwen_LoRA/adapter_config.json
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"alpha_pattern": {},
|
| 3 |
+
"auto_mapping": null,
|
| 4 |
+
"base_model_name_or_path": "/workspace/local_model",
|
| 5 |
+
"bias": "none",
|
| 6 |
+
"corda_config": null,
|
| 7 |
+
"eva_config": null,
|
| 8 |
+
"exclude_modules": null,
|
| 9 |
+
"fan_in_fan_out": false,
|
| 10 |
+
"inference_mode": true,
|
| 11 |
+
"init_lora_weights": true,
|
| 12 |
+
"layer_replication": null,
|
| 13 |
+
"layers_pattern": null,
|
| 14 |
+
"layers_to_transform": null,
|
| 15 |
+
"loftq_config": {},
|
| 16 |
+
"lora_alpha": 32,
|
| 17 |
+
"lora_bias": false,
|
| 18 |
+
"lora_dropout": 0.1,
|
| 19 |
+
"megatron_config": null,
|
| 20 |
+
"megatron_core": "megatron.core",
|
| 21 |
+
"modules_to_save": null,
|
| 22 |
+
"peft_type": "LORA",
|
| 23 |
+
"r": 8,
|
| 24 |
+
"rank_pattern": {},
|
| 25 |
+
"revision": null,
|
| 26 |
+
"target_modules": [
|
| 27 |
+
"v_proj",
|
| 28 |
+
"q_proj"
|
| 29 |
+
],
|
| 30 |
+
"task_type": "CAUSAL_LM",
|
| 31 |
+
"trainable_token_indices": null,
|
| 32 |
+
"use_dora": false,
|
| 33 |
+
"use_rslora": false
|
| 34 |
+
}
|
Qwen_LoRA/adapter_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:45708acd3a64f3f203e8eb44855bfc122a62008d16776b5c73672ec6e102eab2
|
| 3 |
+
size 2175168
|
code/GRPO.ipynb
ADDED
|
@@ -0,0 +1,460 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "e325d4b51fe4fad2",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"# 加载模型以及数据"
|
| 9 |
+
]
|
| 10 |
+
},
|
| 11 |
+
{
|
| 12 |
+
"cell_type": "code",
|
| 13 |
+
"execution_count": 1,
|
| 14 |
+
"id": "initial_id",
|
| 15 |
+
"metadata": {
|
| 16 |
+
"ExecuteTime": {
|
| 17 |
+
"end_time": "2025-03-31T15:57:38.940490Z",
|
| 18 |
+
"start_time": "2025-03-31T15:57:29.198500Z"
|
| 19 |
+
},
|
| 20 |
+
"collapsed": true
|
| 21 |
+
},
|
| 22 |
+
"outputs": [
|
| 23 |
+
{
|
| 24 |
+
"name": "stderr",
|
| 25 |
+
"output_type": "stream",
|
| 26 |
+
"text": [
|
| 27 |
+
"E:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
| 28 |
+
" from .autonotebook import tqdm as notebook_tqdm\n"
|
| 29 |
+
]
|
| 30 |
+
}
|
| 31 |
+
],
|
| 32 |
+
"source": [
|
| 33 |
+
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
|
| 34 |
+
"from trl import GRPOTrainer, GRPOConfig # 假设 trl 库中有 GRPOTrainer 模块\n",
|
| 35 |
+
"from peft import PeftModel\n",
|
| 36 |
+
"from reward import compute_rewards\n",
|
| 37 |
+
"def load_data(input_path):\n",
|
| 38 |
+
" data = []\n",
|
| 39 |
+
" with open(input_path, 'r', encoding='utf-8') as f:\n",
|
| 40 |
+
" for line in f:\n",
|
| 41 |
+
" line = line.strip()\n",
|
| 42 |
+
" if not line:\n",
|
| 43 |
+
" continue\n",
|
| 44 |
+
" parts = line.split('<think>', 1)\n",
|
| 45 |
+
" if len(parts) != 2:\n",
|
| 46 |
+
" print(f\"警告: 格式错误的行,已跳过: {line}\")\n",
|
| 47 |
+
" continue\n",
|
| 48 |
+
" keywords_part, lyrics = parts[0], parts[1]\n",
|
| 49 |
+
" keywords = [kw.strip() for kw in keywords_part.split(',')]\n",
|
| 50 |
+
"\n",
|
| 51 |
+
" # 关键修改:使用关键词作为 prompt,歌词作为 completion\n",
|
| 52 |
+
" data.append({\n",
|
| 53 |
+
" 'prompt': \"根据以下关键词生成一首歌词,歌词中包含多个句子,确保句子通顺,诗意,格式正确.让我们一步一步的思考(思考过程包含在<think>和</think>之间):\" + \",\".join(keywords), # 关键词拼接成字符串,作为模型输入\n",
|
| 54 |
+
" 'completion': \"<think>\" + lyrics, # 歌词(去掉多余空格),作为模型输出\n",
|
| 55 |
+
" 'keywords': keywords, # 关键词拼接成字符串,作为模型输入\n",
|
| 56 |
+
" })\n",
|
| 57 |
+
" \n",
|
| 58 |
+
" print(f\"成功加载 {len(data)} 条数据\")\n",
|
| 59 |
+
" return data\n",
|
| 60 |
+
"\n",
|
| 61 |
+
"\n"
|
| 62 |
+
]
|
| 63 |
+
},
|
| 64 |
+
{
|
| 65 |
+
"cell_type": "code",
|
| 66 |
+
"execution_count": 2,
|
| 67 |
+
"id": "dc98a63e850836a",
|
| 68 |
+
"metadata": {
|
| 69 |
+
"ExecuteTime": {
|
| 70 |
+
"end_time": "2025-03-31T15:57:38.963191Z",
|
| 71 |
+
"start_time": "2025-03-31T15:57:38.951488Z"
|
| 72 |
+
}
|
| 73 |
+
},
|
| 74 |
+
"outputs": [
|
| 75 |
+
{
|
| 76 |
+
"name": "stdout",
|
| 77 |
+
"output_type": "stream",
|
| 78 |
+
"text": [
|
| 79 |
+
"成功加载 1000 条数据\n",
|
| 80 |
+
"第一条数据: {'prompt': '根据以下关键词生成一首歌词,歌词中包含多个句子,确保句子通顺,诗意,格式正确.让我们一步一步的思考(思考过程包含在<think>和</think>之间):午夜,寒冬,心动', 'completion': '<think>嗯,用户让我根据“午夜,寒冬,心动”这三个关键词写一首诗。首先,我需要理解每个词带来的意象和情感。午夜通常给人一种寂静、神秘的感觉,可能带有孤独或反思的情绪。寒冬则联想到寒冷、冰雪,可能还有萧瑟的景象,但也可以有温暖的反差,比如炉火或者温暖的回忆。心动是情感的核心,可能涉及爱情、突然的情感波动,或者内心的温暖。接下来,我要考虑如何将这三个元素融合在一起。可能需要营造一个寒冷的冬夜环境,然后在这样的背景下引入心动的瞬间,形成对比。比如,用寒冷的景象来衬托内心的温暖或突然的情感变化。然后,结构方面,可能需要分几段,每段围绕一个关键词展开,但又要相互联系。比如,前两段描绘午夜和寒冬的景象,第三段引入心动的元素。或者交替出现,交织在一起。语言风格方面,用户可能希望有诗意和意象化的表达,避免直白。使用比喻、拟人等修辞手法,比如将心跳比作钟摆,或者用冰晶、霜花来表现寒冷。同时,注意押韵和节奏,使诗歌流畅。还需要考虑用户可能的深层需求。他们可能不仅仅想要一首描述这三个词的诗,而是希望表达在寒冷孤独的环境中突然感受到的温暖或爱情,强调对比和情感冲击。因此,诗里需要有转折点,从冷到暖,从静到动。可能会遇到的挑战是如何自然地将三个关键词串联起来,不显得生硬。可能需要找到一个共同的场景或隐喻,比如午夜寒冬中的相遇,或者内心的变化。同时,避免陈词滥调,用新颖的比喻,比如将霜花比作时针,或者心跳与钟摆的共振。最后,检查诗的整体意境是否统一,情感是否连贯,是否有足够的意象让读��产生共鸣。可能需要调整用词和结构,使每个部分都服务于主题,增强整体效果。</think>:《冰裂时刻》/石英血管里霜花正逆向生长/成为时针,成为所有表盘/反向旋转的证词。玻璃在窗框深处/爆出细纹,像候鸟迁徙时/骨骼深处迸发的裂响//壁炉吞下最后一块松木的年轮/灰烬在铁架上凝结成第六种晶体/这时钟摆突然停驻——/某种振频正穿过雪原/穿过二十年冻土下沉默的矿脉//冰层深处传来远古鲸群的应和/我们站在地磁偏转的切线上/听见彼此胸腔里/石英开始顺时针碎裂/春天正从裂缝中/涌出第一滴蓝', 'keywords': ['午夜', '寒冬', '心动']}\n"
|
| 81 |
+
]
|
| 82 |
+
}
|
| 83 |
+
],
|
| 84 |
+
"source": [
|
| 85 |
+
"dataset = load_data('../data/CoTdata.txt')\n",
|
| 86 |
+
"if dataset:\n",
|
| 87 |
+
" print(\"第一条数据:\", dataset[0])\n",
|
| 88 |
+
"else:\n",
|
| 89 |
+
" print(\"未加载到有效数据。\")"
|
| 90 |
+
]
|
| 91 |
+
},
|
| 92 |
+
{
|
| 93 |
+
"cell_type": "code",
|
| 94 |
+
"execution_count": 3,
|
| 95 |
+
"id": "17175326d7901595",
|
| 96 |
+
"metadata": {
|
| 97 |
+
"ExecuteTime": {
|
| 98 |
+
"end_time": "2025-03-31T15:57:45.180036Z",
|
| 99 |
+
"start_time": "2025-03-31T15:57:41.329675Z"
|
| 100 |
+
}
|
| 101 |
+
},
|
| 102 |
+
"outputs": [
|
| 103 |
+
{
|
| 104 |
+
"name": "stderr",
|
| 105 |
+
"output_type": "stream",
|
| 106 |
+
"text": [
|
| 107 |
+
"Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.\n"
|
| 108 |
+
]
|
| 109 |
+
},
|
| 110 |
+
{
|
| 111 |
+
"name": "stdout",
|
| 112 |
+
"output_type": "stream",
|
| 113 |
+
"text": [
|
| 114 |
+
"trainable params: 372,736 || all params: 1,777,460,736 || trainable%: 0.0210\n",
|
| 115 |
+
"None\n"
|
| 116 |
+
]
|
| 117 |
+
},
|
| 118 |
+
{
|
| 119 |
+
"name": "stderr",
|
| 120 |
+
"output_type": "stream",
|
| 121 |
+
"text": [
|
| 122 |
+
"E:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\peft\\mapping_func.py:73: UserWarning: You are trying to modify a model with PEFT for a second time. If you want to reload the model with a different config, make sure to call `.unload()` before.\n",
|
| 123 |
+
" warnings.warn(\n",
|
| 124 |
+
"E:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\peft\\tuners\\tuners_utils.py:167: UserWarning: Already found a `peft_config` attribute in the model. This will lead to having multiple adapters in the model. Make sure to know what you are doing!\n",
|
| 125 |
+
" warnings.warn(\n"
|
| 126 |
+
]
|
| 127 |
+
}
|
| 128 |
+
],
|
| 129 |
+
"source": [
|
| 130 |
+
"base_model = AutoModelForCausalLM.from_pretrained(\"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\").to(\"cuda\")\n",
|
| 131 |
+
"# 2. 加载 LoRA 适配器\n",
|
| 132 |
+
"model = PeftModel.from_pretrained(base_model, \"../3_26_LoRA\").to(\"cuda\") # 你的 LoRA 路径\n",
|
| 133 |
+
"tokenizer = AutoTokenizer.from_pretrained(\"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\")\n",
|
| 134 |
+
"tokenizer.pad_token = tokenizer.eos_token\n",
|
| 135 |
+
"# Load LoRA lora_config = LoraConfig( task_type=\"CAUSAL_LM\", r=16, lora_alpha=32, target_modules=\"all-linear\", ) model = get_peft_model(model, lora_config) print(model.print_trainable_parameters()) \n",
|
| 136 |
+
"from peft import LoraConfig, get_peft_model\n",
|
| 137 |
+
"target_modules = [\"q_proj\", \"k_proj\", \"v_proj\"] \n",
|
| 138 |
+
"lora_config = LoraConfig(\n",
|
| 139 |
+
" r=2, # 秩(可尝试8~32)\n",
|
| 140 |
+
" lora_alpha=32, # 缩放系数(通常设为2*r)\n",
|
| 141 |
+
" target_modules=target_modules, \n",
|
| 142 |
+
" bias=\"none\", # 不训练偏置项\n",
|
| 143 |
+
")\n",
|
| 144 |
+
"model = get_peft_model(model, lora_config) \n",
|
| 145 |
+
"print(model.print_trainable_parameters()) "
|
| 146 |
+
]
|
| 147 |
+
},
|
| 148 |
+
{
|
| 149 |
+
"cell_type": "markdown",
|
| 150 |
+
"id": "d2a601c596644fc",
|
| 151 |
+
"metadata": {},
|
| 152 |
+
"source": [
|
| 153 |
+
"# 配置训练参数"
|
| 154 |
+
]
|
| 155 |
+
},
|
| 156 |
+
{
|
| 157 |
+
"cell_type": "code",
|
| 158 |
+
"execution_count": 4,
|
| 159 |
+
"id": "8ce28487acb606a2",
|
| 160 |
+
"metadata": {
|
| 161 |
+
"ExecuteTime": {
|
| 162 |
+
"end_time": "2025-03-31T15:57:48.223131Z",
|
| 163 |
+
"start_time": "2025-03-31T15:57:48.149696Z"
|
| 164 |
+
}
|
| 165 |
+
},
|
| 166 |
+
"outputs": [],
|
| 167 |
+
"source": [
|
| 168 |
+
"# 配置 GRPO 的超参数(这里的参数可以根据需要进行调整)\n",
|
| 169 |
+
"config = GRPOConfig(\n",
|
| 170 |
+
" gradient_accumulation_steps = 50, # 多少步更新一次参考模型\n",
|
| 171 |
+
" per_device_train_batch_size=2, # 每个批次的样本数\n",
|
| 172 |
+
" epsilon=0.2, # GRPO 中的 clip 范围\n",
|
| 173 |
+
" beta=0.05, # KL 惩罚系数\n",
|
| 174 |
+
" num_train_epochs=1, # 总训练步数(总周期)\n",
|
| 175 |
+
" num_generations=2, # 分组采样的大小\n",
|
| 176 |
+
" learning_rate=1e-5, # 优化器的学习率\n",
|
| 177 |
+
" bf16=True, \n",
|
| 178 |
+
" adam_beta1=0.9,\n",
|
| 179 |
+
" adam_beta2=0.98,\n",
|
| 180 |
+
" optim=\"adamw_8bit\", # 优化器\n",
|
| 181 |
+
" max_grad_norm=0.1, # 梯度裁剪的最大值\n",
|
| 182 |
+
" save_steps=1000, # 多少步保存一次模型\n",
|
| 183 |
+
" save_total_limit=2, # 最多保存几个模型 \n",
|
| 184 |
+
" logging_steps=5, # 多少步打印一次训练信息\n",
|
| 185 |
+
" output_dir=\"GRPO\", # 模型保存路径\n",
|
| 186 |
+
" weight_decay=0.01, # 权重衰减\n",
|
| 187 |
+
" warmup_ratio=0.03, # 预热比例\n",
|
| 188 |
+
" max_prompt_length=256,\n",
|
| 189 |
+
" max_completion_length=1024, # 最大输出长度\n",
|
| 190 |
+
" report_to='tensorboard', # or `tensorboard`\n",
|
| 191 |
+
")\n",
|
| 192 |
+
"# Training arguments training_args = GRPOConfig( \n",
|
| 193 |
+
"# output_dir=\"GRPO\", \n",
|
| 194 |
+
"# learning_rate=2e-5, \n",
|
| 195 |
+
"# per_device_train_batch_size=8, \n",
|
| 196 |
+
"# gradient_accumulation_steps=2, \n",
|
| 197 |
+
"# max_prompt_length=512, \n",
|
| 198 |
+
"# max_completion_length=96, \n",
|
| 199 |
+
"# num_generations=8, \n",
|
| 200 |
+
"# optim=\"adamw_8bit\", \n",
|
| 201 |
+
"# num_train_epochs=1, \n",
|
| 202 |
+
"# bf16=True, \n",
|
| 203 |
+
"# report_to=[\"wandb\"], \n",
|
| 204 |
+
"# remove_unused_columns=False, \n",
|
| 205 |
+
"# logging_steps=1, \n",
|
| 206 |
+
"# ) "
|
| 207 |
+
]
|
| 208 |
+
},
|
| 209 |
+
{
|
| 210 |
+
"cell_type": "markdown",
|
| 211 |
+
"id": "793b3094cd98fed6",
|
| 212 |
+
"metadata": {},
|
| 213 |
+
"source": [
|
| 214 |
+
"# 训练模型"
|
| 215 |
+
]
|
| 216 |
+
},
|
| 217 |
+
{
|
| 218 |
+
"cell_type": "code",
|
| 219 |
+
"execution_count": 5,
|
| 220 |
+
"id": "19094188d22e45c2",
|
| 221 |
+
"metadata": {
|
| 222 |
+
"ExecuteTime": {
|
| 223 |
+
"end_time": "2025-03-31T17:08:58.041584Z",
|
| 224 |
+
"start_time": "2025-03-31T15:57:50.316805Z"
|
| 225 |
+
}
|
| 226 |
+
},
|
| 227 |
+
"outputs": [
|
| 228 |
+
{
|
| 229 |
+
"name": "stderr",
|
| 230 |
+
"output_type": "stream",
|
| 231 |
+
"text": [
|
| 232 |
+
"No label_names provided for model class `PeftModel`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.\n"
|
| 233 |
+
]
|
| 234 |
+
},
|
| 235 |
+
{
|
| 236 |
+
"data": {
|
| 237 |
+
"text/html": [
|
| 238 |
+
"\n",
|
| 239 |
+
" <div>\n",
|
| 240 |
+
" \n",
|
| 241 |
+
" <progress value='2' max='20' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
| 242 |
+
" [ 2/20 : < :, Epoch 0.05/1]\n",
|
| 243 |
+
" </div>\n",
|
| 244 |
+
" <table border=\"1\" class=\"dataframe\">\n",
|
| 245 |
+
" <thead>\n",
|
| 246 |
+
" <tr style=\"text-align: left;\">\n",
|
| 247 |
+
" <th>Step</th>\n",
|
| 248 |
+
" <th>Training Loss</th>\n",
|
| 249 |
+
" </tr>\n",
|
| 250 |
+
" </thead>\n",
|
| 251 |
+
" <tbody>\n",
|
| 252 |
+
" </tbody>\n",
|
| 253 |
+
"</table><p>"
|
| 254 |
+
],
|
| 255 |
+
"text/plain": [
|
| 256 |
+
"<IPython.core.display.HTML object>"
|
| 257 |
+
]
|
| 258 |
+
},
|
| 259 |
+
"metadata": {},
|
| 260 |
+
"output_type": "display_data"
|
| 261 |
+
},
|
| 262 |
+
{
|
| 263 |
+
"ename": "KeyboardInterrupt",
|
| 264 |
+
"evalue": "",
|
| 265 |
+
"output_type": "error",
|
| 266 |
+
"traceback": [
|
| 267 |
+
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
|
| 268 |
+
"\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)",
|
| 269 |
+
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[5]\u001b[39m\u001b[32m, line 16\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;66;03m# Trainer trainer = GRPOTrainer( \u001b[39;00m\n\u001b[32m 2\u001b[39m \u001b[38;5;66;03m# model=model, \u001b[39;00m\n\u001b[32m 3\u001b[39m \u001b[38;5;66;03m# reward_funcs=[reward_len], \u001b[39;00m\n\u001b[32m (...)\u001b[39m\u001b[32m 8\u001b[39m \u001b[38;5;66;03m# wandb.init(project=\"GRPO\") \u001b[39;00m\n\u001b[32m 9\u001b[39m \u001b[38;5;66;03m# trainer.train()\u001b[39;00m\n\u001b[32m 10\u001b[39m trainer = GRPOTrainer(\n\u001b[32m 11\u001b[39m model=model,\n\u001b[32m 12\u001b[39m reward_funcs=[compute_rewards],\n\u001b[32m 13\u001b[39m args=config,\n\u001b[32m 14\u001b[39m train_dataset=dataset\n\u001b[32m 15\u001b[39m )\n\u001b[32m---> \u001b[39m\u001b[32m16\u001b[39m \u001b[43mtrainer\u001b[49m\u001b[43m.\u001b[49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 270 |
+
"\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\transformers\\trainer.py:2245\u001b[39m, in \u001b[36mTrainer.train\u001b[39m\u001b[34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[39m\n\u001b[32m 2243\u001b[39m hf_hub_utils.enable_progress_bars()\n\u001b[32m 2244\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m2245\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 2246\u001b[39m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m=\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 2247\u001b[39m \u001b[43m \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m=\u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 2248\u001b[39m \u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[43m=\u001b[49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 2249\u001b[39m \u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m=\u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 2250\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
|
| 271 |
+
"\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\transformers\\trainer.py:2556\u001b[39m, in \u001b[36mTrainer._inner_training_loop\u001b[39m\u001b[34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[39m\n\u001b[32m 2549\u001b[39m context = (\n\u001b[32m 2550\u001b[39m functools.partial(\u001b[38;5;28mself\u001b[39m.accelerator.no_sync, model=model)\n\u001b[32m 2551\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m i != \u001b[38;5;28mlen\u001b[39m(batch_samples) - \u001b[32m1\u001b[39m\n\u001b[32m 2552\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m.accelerator.distributed_type != DistributedType.DEEPSPEED\n\u001b[32m 2553\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m contextlib.nullcontext\n\u001b[32m 2554\u001b[39m )\n\u001b[32m 2555\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m context():\n\u001b[32m-> \u001b[39m\u001b[32m2556\u001b[39m tr_loss_step = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_items_in_batch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 2558\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[32m 2559\u001b[39m args.logging_nan_inf_filter\n\u001b[32m 2560\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torch_xla_available()\n\u001b[32m 2561\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m (torch.isnan(tr_loss_step) \u001b[38;5;129;01mor\u001b[39;00m torch.isinf(tr_loss_step))\n\u001b[32m 2562\u001b[39m ):\n\u001b[32m 2563\u001b[39m \u001b[38;5;66;03m# if loss is nan or inf simply add the average of previous logged losses\u001b[39;00m\n\u001b[32m 2564\u001b[39m tr_loss = tr_loss + tr_loss / (\u001b[32m1\u001b[39m + \u001b[38;5;28mself\u001b[39m.state.global_step - 
\u001b[38;5;28mself\u001b[39m._globalstep_last_logged)\n",
|
| 272 |
+
"\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\transformers\\trainer.py:3712\u001b[39m, in \u001b[36mTrainer.training_step\u001b[39m\u001b[34m(self, model, inputs, num_items_in_batch)\u001b[39m\n\u001b[32m 3709\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(\u001b[38;5;28mself\u001b[39m.optimizer, \u001b[33m\"\u001b[39m\u001b[33mtrain\u001b[39m\u001b[33m\"\u001b[39m) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mcallable\u001b[39m(\u001b[38;5;28mself\u001b[39m.optimizer.train):\n\u001b[32m 3710\u001b[39m \u001b[38;5;28mself\u001b[39m.optimizer.train()\n\u001b[32m-> \u001b[39m\u001b[32m3712\u001b[39m inputs = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_prepare_inputs\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 3713\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m is_sagemaker_mp_enabled():\n\u001b[32m 3714\u001b[39m loss_mb = smp_forward_backward(model, inputs, \u001b[38;5;28mself\u001b[39m.args.gradient_accumulation_steps)\n",
|
| 273 |
+
"\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\trl\\extras\\profiling.py:87\u001b[39m, in \u001b[36mprofiling_decorator.<locals>.wrapper\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 84\u001b[39m \u001b[38;5;129m@functools\u001b[39m.wraps(func)\n\u001b[32m 85\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mwrapper\u001b[39m(\u001b[38;5;28mself\u001b[39m, *args, **kwargs):\n\u001b[32m 86\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m profiling_context(\u001b[38;5;28mself\u001b[39m, func.\u001b[34m__name__\u001b[39m):\n\u001b[32m---> \u001b[39m\u001b[32m87\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 274 |
+
"\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\trl\\trainer\\grpo_trainer.py:647\u001b[39m, in \u001b[36mGRPOTrainer._prepare_inputs\u001b[39m\u001b[34m(self, inputs)\u001b[39m\n\u001b[32m 645\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m mode == \u001b[33m\"\u001b[39m\u001b[33mtrain\u001b[39m\u001b[33m\"\u001b[39m:\n\u001b[32m 646\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.state.global_step % \u001b[38;5;28mself\u001b[39m.num_iterations == \u001b[32m0\u001b[39m:\n\u001b[32m--> \u001b[39m\u001b[32m647\u001b[39m inputs = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_generate_and_score_completions\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 648\u001b[39m \u001b[38;5;28mself\u001b[39m._buffered_inputs[\u001b[38;5;28mself\u001b[39m._step % \u001b[38;5;28mself\u001b[39m.args.gradient_accumulation_steps] = inputs\n\u001b[32m 649\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n",
|
| 275 |
+
"\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\trl\\trainer\\grpo_trainer.py:719\u001b[39m, in \u001b[36mGRPOTrainer._generate_and_score_completions\u001b[39m\u001b[34m(self, inputs)\u001b[39m\n\u001b[32m 714\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 715\u001b[39m \u001b[38;5;66;03m# Regular generation path\u001b[39;00m\n\u001b[32m 716\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m unwrap_model_for_generation(\n\u001b[32m 717\u001b[39m \u001b[38;5;28mself\u001b[39m.model_wrapped, \u001b[38;5;28mself\u001b[39m.accelerator, gather_deepspeed3_params=\u001b[38;5;28mself\u001b[39m.args.ds3_gather_for_generation\n\u001b[32m 718\u001b[39m ) \u001b[38;5;28;01mas\u001b[39;00m unwrapped_model:\n\u001b[32m--> \u001b[39m\u001b[32m719\u001b[39m prompt_completion_ids = \u001b[43munwrapped_model\u001b[49m\u001b[43m.\u001b[49m\u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 720\u001b[39m \u001b[43m \u001b[49m\u001b[43mprompt_ids\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m=\u001b[49m\u001b[43mprompt_mask\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgeneration_config\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mgeneration_config\u001b[49m\n\u001b[32m 721\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 723\u001b[39m \u001b[38;5;66;03m# Compute prompt length and extract completion ids\u001b[39;00m\n\u001b[32m 724\u001b[39m prompt_length = prompt_ids.size(\u001b[32m1\u001b[39m)\n",
|
| 276 |
+
"\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\peft\\peft_model.py:823\u001b[39m, in \u001b[36mPeftModel.generate\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 821\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m._enable_peft_forward_hooks(*args, **kwargs):\n\u001b[32m 822\u001b[39m kwargs = {k: v \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m kwargs.items() \u001b[38;5;28;01mif\u001b[39;00m k \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m.special_peft_forward_args}\n\u001b[32m--> \u001b[39m\u001b[32m823\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mget_base_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 277 |
+
"\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\peft\\peft_model.py:1874\u001b[39m, in \u001b[36mPeftModelForCausalLM.generate\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1872\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m._enable_peft_forward_hooks(*args, **kwargs):\n\u001b[32m 1873\u001b[39m kwargs = {k: v \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m kwargs.items() \u001b[38;5;28;01mif\u001b[39;00m k \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m.special_peft_forward_args}\n\u001b[32m-> \u001b[39m\u001b[32m1874\u001b[39m outputs = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mbase_model\u001b[49m\u001b[43m.\u001b[49m\u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1875\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 1876\u001b[39m outputs = \u001b[38;5;28mself\u001b[39m.base_model.generate(**kwargs)\n",
|
| 278 |
+
"\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\torch\\utils\\_contextlib.py:116\u001b[39m, in \u001b[36mcontext_decorator.<locals>.decorate_context\u001b[39m\u001b[34m(*args, **kwargs)\u001b[39m\n\u001b[32m 113\u001b[39m \u001b[38;5;129m@functools\u001b[39m.wraps(func)\n\u001b[32m 114\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mdecorate_context\u001b[39m(*args, **kwargs):\n\u001b[32m 115\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[32m--> \u001b[39m\u001b[32m116\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 279 |
+
"\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\transformers\\generation\\utils.py:2326\u001b[39m, in \u001b[36mGenerationMixin.generate\u001b[39m\u001b[34m(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, use_model_defaults, **kwargs)\u001b[39m\n\u001b[32m 2318\u001b[39m input_ids, model_kwargs = \u001b[38;5;28mself\u001b[39m._expand_inputs_for_generation(\n\u001b[32m 2319\u001b[39m input_ids=input_ids,\n\u001b[32m 2320\u001b[39m expand_size=generation_config.num_return_sequences,\n\u001b[32m 2321\u001b[39m is_encoder_decoder=\u001b[38;5;28mself\u001b[39m.config.is_encoder_decoder,\n\u001b[32m 2322\u001b[39m **model_kwargs,\n\u001b[32m 2323\u001b[39m )\n\u001b[32m 2325\u001b[39m \u001b[38;5;66;03m# 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)\u001b[39;00m\n\u001b[32m-> \u001b[39m\u001b[32m2326\u001b[39m result = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_sample\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 2327\u001b[39m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 2328\u001b[39m \u001b[43m \u001b[49m\u001b[43mlogits_processor\u001b[49m\u001b[43m=\u001b[49m\u001b[43mprepared_logits_processor\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 2329\u001b[39m \u001b[43m \u001b[49m\u001b[43mstopping_criteria\u001b[49m\u001b[43m=\u001b[49m\u001b[43mprepared_stopping_criteria\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 2330\u001b[39m \u001b[43m \u001b[49m\u001b[43mgeneration_config\u001b[49m\u001b[43m=\u001b[49m\u001b[43mgeneration_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 2331\u001b[39m \u001b[43m \u001b[49m\u001b[43msynced_gpus\u001b[49m\u001b[43m=\u001b[49m\u001b[43msynced_gpus\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 2332\u001b[39m \u001b[43m 
\u001b[49m\u001b[43mstreamer\u001b[49m\u001b[43m=\u001b[49m\u001b[43mstreamer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 2333\u001b[39m \u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mmodel_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 2334\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 2336\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m generation_mode \u001b[38;5;129;01min\u001b[39;00m (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):\n\u001b[32m 2337\u001b[39m \u001b[38;5;66;03m# 11. interleave input_ids with `num_beams` additional sequences per batch\u001b[39;00m\n\u001b[32m 2338\u001b[39m input_ids, model_kwargs = \u001b[38;5;28mself\u001b[39m._expand_inputs_for_generation(\n\u001b[32m 2339\u001b[39m input_ids=input_ids,\n\u001b[32m 2340\u001b[39m expand_size=generation_config.num_beams,\n\u001b[32m 2341\u001b[39m is_encoder_decoder=\u001b[38;5;28mself\u001b[39m.config.is_encoder_decoder,\n\u001b[32m 2342\u001b[39m **model_kwargs,\n\u001b[32m 2343\u001b[39m )\n",
|
| 280 |
+
"\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\transformers\\generation\\utils.py:3289\u001b[39m, in \u001b[36mGenerationMixin._sample\u001b[39m\u001b[34m(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, **model_kwargs)\u001b[39m\n\u001b[32m 3287\u001b[39m is_prefill = \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[32m 3288\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m3289\u001b[39m outputs = \u001b[43mmodel_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mmodel_inputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[32m 3291\u001b[39m \u001b[38;5;66;03m# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping\u001b[39;00m\n\u001b[32m 3292\u001b[39m model_kwargs = \u001b[38;5;28mself\u001b[39m._update_model_kwargs_for_generation(\n\u001b[32m 3293\u001b[39m outputs,\n\u001b[32m 3294\u001b[39m model_kwargs,\n\u001b[32m 3295\u001b[39m is_encoder_decoder=\u001b[38;5;28mself\u001b[39m.config.is_encoder_decoder,\n\u001b[32m 3296\u001b[39m )\n",
|
| 281 |
+
"\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1739\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1737\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m 1738\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1739\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 282 |
+
"\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1750\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1745\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m 1746\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m 1747\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m 1748\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m 1749\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1750\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1752\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1753\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n",
|
| 283 |
+
"\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\transformers\\utils\\deprecation.py:172\u001b[39m, in \u001b[36mdeprecate_kwarg.<locals>.wrapper.<locals>.wrapped_func\u001b[39m\u001b[34m(*args, **kwargs)\u001b[39m\n\u001b[32m 168\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m minimum_action \u001b[38;5;129;01min\u001b[39;00m (Action.NOTIFY, Action.NOTIFY_ALWAYS) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torchdynamo_compiling():\n\u001b[32m 169\u001b[39m \u001b[38;5;66;03m# DeprecationWarning is ignored by default, so we use FutureWarning instead\u001b[39;00m\n\u001b[32m 170\u001b[39m warnings.warn(message, \u001b[38;5;167;01mFutureWarning\u001b[39;00m, stacklevel=\u001b[32m2\u001b[39m)\n\u001b[32m--> \u001b[39m\u001b[32m172\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 284 |
+
"\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\transformers\\models\\qwen2\\modeling_qwen2.py:855\u001b[39m, in \u001b[36mQwen2ForCausalLM.forward\u001b[39m\u001b[34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, logits_to_keep, **kwargs)\u001b[39m\n\u001b[32m 852\u001b[39m return_dict = return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m.config.use_return_dict\n\u001b[32m 854\u001b[39m \u001b[38;5;66;03m# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m855\u001b[39m outputs = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 856\u001b[39m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m=\u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 857\u001b[39m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m=\u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 858\u001b[39m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[43m=\u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 859\u001b[39m \u001b[43m \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[43m=\u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 860\u001b[39m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[43m=\u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 861\u001b[39m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[43m=\u001b[49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 862\u001b[39m \u001b[43m 
\u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[43m=\u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 863\u001b[39m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m=\u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 864\u001b[39m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[43m=\u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 865\u001b[39m \u001b[43m \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[43m=\u001b[49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 866\u001b[39m \u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 867\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 869\u001b[39m hidden_states = outputs[\u001b[32m0\u001b[39m]\n\u001b[32m 870\u001b[39m \u001b[38;5;66;03m# Only compute necessary logits, and do not upcast them to float if we are not computing the loss\u001b[39;00m\n",
|
| 285 |
+
"\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1739\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1737\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m 1738\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1739\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 286 |
+
"\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1750\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1745\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m 1746\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m 1747\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m 1748\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m 1749\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1750\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1752\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1753\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n",
|
| 287 |
+
"\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\transformers\\models\\qwen2\\modeling_qwen2.py:579\u001b[39m, in \u001b[36mQwen2Model.forward\u001b[39m\u001b[34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **flash_attn_kwargs)\u001b[39m\n\u001b[32m 567\u001b[39m layer_outputs = \u001b[38;5;28mself\u001b[39m._gradient_checkpointing_func(\n\u001b[32m 568\u001b[39m decoder_layer.\u001b[34m__call__\u001b[39m,\n\u001b[32m 569\u001b[39m hidden_states,\n\u001b[32m (...)\u001b[39m\u001b[32m 576\u001b[39m position_embeddings,\n\u001b[32m 577\u001b[39m )\n\u001b[32m 578\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m579\u001b[39m layer_outputs = \u001b[43mdecoder_layer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 580\u001b[39m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 581\u001b[39m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m=\u001b[49m\u001b[43mcausal_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 582\u001b[39m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[43m=\u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 583\u001b[39m \u001b[43m \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[43m=\u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 584\u001b[39m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[43m=\u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 585\u001b[39m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[43m=\u001b[49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 586\u001b[39m \u001b[43m \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[43m=\u001b[49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 587\u001b[39m \u001b[43m 
\u001b[49m\u001b[43mposition_embeddings\u001b[49m\u001b[43m=\u001b[49m\u001b[43mposition_embeddings\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 588\u001b[39m \u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mflash_attn_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 589\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 591\u001b[39m hidden_states = layer_outputs[\u001b[32m0\u001b[39m]\n\u001b[32m 593\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m output_attentions:\n",
|
| 288 |
+
"\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1739\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1737\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m 1738\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1739\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 289 |
+
"\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1750\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1745\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m 1746\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m 1747\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m 1748\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m 1749\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1750\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1752\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1753\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n",
|
| 290 |
+
"\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\transformers\\models\\qwen2\\modeling_qwen2.py:276\u001b[39m, in \u001b[36mQwen2DecoderLayer.forward\u001b[39m\u001b[34m(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, position_embeddings, **kwargs)\u001b[39m\n\u001b[32m 274\u001b[39m residual = hidden_states\n\u001b[32m 275\u001b[39m hidden_states = \u001b[38;5;28mself\u001b[39m.post_attention_layernorm(hidden_states)\n\u001b[32m--> \u001b[39m\u001b[32m276\u001b[39m hidden_states = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mmlp\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 277\u001b[39m hidden_states = residual + hidden_states\n\u001b[32m 279\u001b[39m outputs = (hidden_states,)\n",
|
| 291 |
+
"\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1739\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1737\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m 1738\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1739\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 292 |
+
"\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1750\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1745\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m 1746\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m 1747\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m 1748\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m 1749\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1750\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1752\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1753\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n",
|
| 293 |
+
"\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\transformers\\models\\qwen2\\modeling_qwen2.py:57\u001b[39m, in \u001b[36mQwen2MLP.forward\u001b[39m\u001b[34m(self, x)\u001b[39m\n\u001b[32m 56\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[32m---> \u001b[39m\u001b[32m57\u001b[39m down_proj = \u001b[38;5;28mself\u001b[39m.down_proj(\u001b[38;5;28mself\u001b[39m.act_fn(\u001b[38;5;28mself\u001b[39m.gate_proj(x)) * \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mup_proj\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[32m 58\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m down_proj\n",
|
| 294 |
+
"\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1739\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1737\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m 1738\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1739\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 295 |
+
"\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1750\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1745\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m 1746\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m 1747\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m 1748\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m 1749\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1750\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1752\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1753\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n",
|
| 296 |
+
"\u001b[36mFile \u001b[39m\u001b[32mE:\\共享\\GoodMusicV2.0\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\linear.py:125\u001b[39m, in \u001b[36mLinear.forward\u001b[39m\u001b[34m(self, input)\u001b[39m\n\u001b[32m 124\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) -> Tensor:\n\u001b[32m--> \u001b[39m\u001b[32m125\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[43m.\u001b[49m\u001b[43mlinear\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 297 |
+
"\u001b[31mKeyboardInterrupt\u001b[39m: "
|
| 298 |
+
]
|
| 299 |
+
}
|
| 300 |
+
],
|
| 301 |
+
"source": [
|
| 302 |
+
"# Trainer trainer = GRPOTrainer( \n",
|
| 303 |
+
"# model=model, \n",
|
| 304 |
+
"# reward_funcs=[reward_len], \n",
|
| 305 |
+
"# args=training_args, \n",
|
| 306 |
+
"# train_dataset=dataset[\"train\"], \n",
|
| 307 |
+
"# ) \n",
|
| 308 |
+
"# Train model \n",
|
| 309 |
+
"# wandb.init(project=\"GRPO\") \n",
|
| 310 |
+
"# trainer.train()\n",
|
| 311 |
+
"trainer = GRPOTrainer(\n",
|
| 312 |
+
" model=model,\n",
|
| 313 |
+
" reward_funcs=[compute_rewards],\n",
|
| 314 |
+
" args=config,\n",
|
| 315 |
+
" train_dataset=dataset\n",
|
| 316 |
+
")\n",
|
| 317 |
+
"trainer.train()"
|
| 318 |
+
]
|
| 319 |
+
},
|
| 320 |
+
{
|
| 321 |
+
"cell_type": "markdown",
|
| 322 |
+
"id": "f621c33533e55b00",
|
| 323 |
+
"metadata": {},
|
| 324 |
+
"source": [
|
| 325 |
+
"# 评估"
|
| 326 |
+
]
|
| 327 |
+
},
|
| 328 |
+
{
|
| 329 |
+
"cell_type": "code",
|
| 330 |
+
"execution_count": null,
|
| 331 |
+
"id": "d17d0e3eb9069545",
|
| 332 |
+
"metadata": {},
|
| 333 |
+
"outputs": [],
|
| 334 |
+
"source": [
|
| 335 |
+
"import matplotlib.pyplot as plt\n",
|
| 336 |
+
"from datetime import datetime\n",
|
| 337 |
+
"\n",
|
| 338 |
+
"def plot_training_metrics(losses, kls, avg_rewards, output_dir=\".\"):\n",
|
| 339 |
+
" \"\"\"\n",
|
| 340 |
+
" 绘制并保存训练指标图表\n",
|
| 341 |
+
" \n",
|
| 342 |
+
" 参数:\n",
|
| 343 |
+
" losses: 训练损失列表\n",
|
| 344 |
+
" kls: KL散度列表\n",
|
| 345 |
+
" avg_rewards: 平均奖励列表\n",
|
| 346 |
+
" output_dir: 输出目录路径\n",
|
| 347 |
+
" \"\"\"\n",
|
| 348 |
+
" # 生成带时间戳的唯一文件名\n",
|
| 349 |
+
" timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
|
| 350 |
+
" output_path = f\"{output_dir}/training_curves_{timestamp}.png\"\n",
|
| 351 |
+
" \n",
|
| 352 |
+
" # 创建画布\n",
|
| 353 |
+
" plt.figure(figsize=(15, 5), dpi=300)\n",
|
| 354 |
+
" \n",
|
| 355 |
+
" # 1. Loss 曲线\n",
|
| 356 |
+
" plt.subplot(1, 3, 1)\n",
|
| 357 |
+
" plt.plot(losses, label=\"Loss\", linewidth=1.5, color='blue')\n",
|
| 358 |
+
" plt.title(\"Training Loss\", fontsize=10)\n",
|
| 359 |
+
" plt.xlabel(\"Step\", fontsize=9)\n",
|
| 360 |
+
" plt.ylabel(\"Loss\", fontsize=9)\n",
|
| 361 |
+
" plt.grid(True, alpha=0.3)\n",
|
| 362 |
+
" \n",
|
| 363 |
+
" # 2. KL 散度曲线\n",
|
| 364 |
+
" plt.subplot(1, 3, 2)\n",
|
| 365 |
+
" plt.plot(kls, label=\"KL Divergence\", linewidth=1.5, color='orange')\n",
|
| 366 |
+
" plt.title(\"KL Divergence\", fontsize=10)\n",
|
| 367 |
+
" plt.xlabel(\"Step\", fontsize=9)\n",
|
| 368 |
+
" plt.ylabel(\"KL Divergence\", fontsize=9)\n",
|
| 369 |
+
" plt.grid(True, alpha=0.3)\n",
|
| 370 |
+
" \n",
|
| 371 |
+
" # 3. 平均奖励曲线\n",
|
| 372 |
+
" plt.subplot(1, 3, 3)\n",
|
| 373 |
+
" plt.plot(avg_rewards, label=\"Avg Reward\", linewidth=1.5, color='green')\n",
|
| 374 |
+
" plt.title(\"Average Reward\", fontsize=10)\n",
|
| 375 |
+
" plt.xlabel(\"Step\", fontsize=9)\n",
|
| 376 |
+
" plt.ylabel(\"Reward\", fontsize=9)\n",
|
| 377 |
+
" plt.grid(True, alpha=0.3)\n",
|
| 378 |
+
" \n",
|
| 379 |
+
" # 调整布局并保存\n",
|
| 380 |
+
" plt.tight_layout()\n",
|
| 381 |
+
" plt.savefig(\n",
|
| 382 |
+
" output_path,\n",
|
| 383 |
+
" bbox_inches='tight',\n",
|
| 384 |
+
" facecolor='white',\n",
|
| 385 |
+
" dpi=300\n",
|
| 386 |
+
" )\n",
|
| 387 |
+
" plt.close()\n",
|
| 388 |
+
" \n",
|
| 389 |
+
" print(f\"训练指标图表已保存至: {output_path}\")\n",
|
| 390 |
+
"\n",
|
| 391 |
+
"# 使用示例 (假设你已经有了这些数据)\n",
|
| 392 |
+
"# losses = [...] # 你的损失数据\n",
|
| 393 |
+
"# kls = [...] # 你的KL散度数据\n",
|
| 394 |
+
"# avg_rewards = [...] # 你的平均奖励数据\n",
|
| 395 |
+
"# plot_training_metrics(losses, kls, avg_rewards)\n",
|
| 396 |
+
"\n"
|
| 397 |
+
]
|
| 398 |
+
},
|
| 399 |
+
{
|
| 400 |
+
"cell_type": "code",
|
| 401 |
+
"execution_count": null,
|
| 402 |
+
"id": "b6a739a2f9d0a343",
|
| 403 |
+
"metadata": {},
|
| 404 |
+
"outputs": [],
|
| 405 |
+
"source": [
|
| 406 |
+
"class MetricsCallback(TrainerCallback):\n",
|
| 407 |
+
" def __init__(self):\n",
|
| 408 |
+
" super().__init__()\n",
|
| 409 |
+
" self.metrics = {\n",
|
| 410 |
+
" 'loss': [], \n",
|
| 411 |
+
" 'kl_divergence': [], \n",
|
| 412 |
+
" 'avg_reward': []\n",
|
| 413 |
+
" }\n",
|
| 414 |
+
" \n",
|
| 415 |
+
" def on_log(self, args, state, control, logs=None, **kwargs):\n",
|
| 416 |
+
" if logs is not None:\n",
|
| 417 |
+
" if 'loss' in logs:\n",
|
| 418 |
+
" self.metrics['loss'].append(logs['loss'])\n",
|
| 419 |
+
" if 'kl_divergence' in logs:\n",
|
| 420 |
+
" self.metrics['kl_divergence'].append(logs['kl_divergence'])\n",
|
| 421 |
+
" if 'rewards' in logs: # 假设返回的是列表,取其平均值\n",
|
| 422 |
+
" avg_reward = sum(logs['rewards'])/len(logs['rewards'])\n",
|
| 423 |
+
" self.metrics['avg_reward'].append(avg_reward)\n",
|
| 424 |
+
" \n",
|
| 425 |
+
" "
|
| 426 |
+
]
|
| 427 |
+
},
|
| 428 |
+
{
|
| 429 |
+
"cell_type": "code",
|
| 430 |
+
"execution_count": null,
|
| 431 |
+
"id": "14cee34aa3bb165",
|
| 432 |
+
"metadata": {},
|
| 433 |
+
"outputs": [],
|
| 434 |
+
"source": [
|
| 435 |
+
"plot_training_metrics(metrics_callback.metrics['loss'],metrics_callback.metrics['kl_divergence'],metrics_callback.metrics['avg_reward'])"
|
| 436 |
+
]
|
| 437 |
+
}
|
| 438 |
+
],
|
| 439 |
+
"metadata": {
|
| 440 |
+
"kernelspec": {
|
| 441 |
+
"display_name": ".venv",
|
| 442 |
+
"language": "python",
|
| 443 |
+
"name": "python3"
|
| 444 |
+
},
|
| 445 |
+
"language_info": {
|
| 446 |
+
"codemirror_mode": {
|
| 447 |
+
"name": "ipython",
|
| 448 |
+
"version": 2
|
| 449 |
+
},
|
| 450 |
+
"file_extension": ".py",
|
| 451 |
+
"mimetype": "text/x-python",
|
| 452 |
+
"name": "python",
|
| 453 |
+
"nbconvert_exporter": "python",
|
| 454 |
+
"pygments_lexer": "ipython2",
|
| 455 |
+
"version": "3.11.3"
|
| 456 |
+
}
|
| 457 |
+
},
|
| 458 |
+
"nbformat": 4,
|
| 459 |
+
"nbformat_minor": 5
|
| 460 |
+
}
|
code/LORA.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datasets import Dataset
|
| 2 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
|
| 3 |
+
from peft import LoraConfig, get_peft_model, PeftModel
|
| 4 |
+
|
| 5 |
+
raw_data_path = ""  # fill in: path to the raw training-data file
# One training sample per line.
with open(raw_data_path, "r", encoding="utf-8") as f:
    raw_lines = f.readlines()
|
| 8 |
+
|
| 9 |
+
def process_line(line):
    """Drop the trailing '/'-separated segment of *line*.

    Returns the stripped line unchanged when it contains no '/'.
    """
    stripped = line.strip()
    parts = stripped.split("/")
    if len(parts) <= 1:
        return stripped
    return "/".join(parts[:-1])
|
| 12 |
+
|
| 13 |
+
# Strip the trailing segment from every non-blank line and build the dataset.
processed_samples = [process_line(line) for line in raw_lines if line.strip()]
dataset = Dataset.from_dict({"text": processed_samples})

model_name = ""  # fill in: path or hub id of the base model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

lora_config = LoraConfig(
    r=8,  # rank of the low-rank matrices; 8/16/32 are typical
    lora_alpha=32,  # scaling factor controlling the LoRA contribution
    target_modules=["q_proj", "v_proj"],  # attention projections to adapt
    lora_dropout=0.1,  # dropout to curb overfitting
    bias="none",  # do not train bias terms
    task_type="CAUSAL_LM"  # causal language modelling task
)
model = get_peft_model(model, lora_config)
|
| 29 |
+
|
| 30 |
+
def tokenize_function(examples):
    """Prepend the fixed generation prompt to each sample and tokenize.

    Labels are a copy of input_ids (standard causal-LM objective).
    Relies on the module-level `tokenizer`.
    """
    prompt = "根据以下关键词生成一首歌词,歌词中包含多个句子,句子与句子之间使用/隔开,让我们一步一步的思考(思考过程包含在<think>和</think>之间):"
    batch = tokenizer(
        [prompt + sample for sample in examples["text"]],
        truncation=True,
        padding="max_length",
        max_length=256,
    )
    batch["labels"] = batch["input_ids"].copy()
    return batch
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# Tokenize the whole dataset up front.
tokenized_dataset = dataset.map(tokenize_function, batched=True)

training_args = TrainingArguments(
    output_dir="./lora",
    num_train_epochs=8,
    per_device_train_batch_size=10,
    learning_rate=2e-5,
    weight_decay=0.01,
    logging_steps=10000,
    save_steps=15000,
    fp16=True,  # mixed-precision training
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer,
)


trainer.train()

# Inference example
generation_config = {
    "max_new_tokens": 1024,
    "temperature": 1.0,
    "top_p": 0.9,
    "top_k": 40,
    "repetition_penalty": 1.2,
    "do_sample": True,
    "encoder_no_repeat_ngram_size": 4,
}
# NOTE(review): `if True:` appears to be a manual toggle for the demo below —
# confirm and replace with a real flag.
if True:
    prompt = "根据以下关键词生成一首歌词,歌词中包含多个句子,句子与句子之间使用/隔开,让我们一步一步的思考(思考过程包含在<think>和</think>之间):温柔,轮廓,洒脱:"
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(input_ids, **generation_config)
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=False)

    print(decoded)

model.save_pretrained("")  # fill in: directory to save the LoRA adapter
|
code/LORA_with_CoT.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from datasets import Dataset
|
| 3 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
|
| 4 |
+
from peft import LoraConfig, get_peft_model, PeftModel
|
| 5 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 6 |
+
|
| 7 |
+
model_name = "Qwen/Qwen2.5-0.5B-Instruct"

base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",  # pick dtype from the checkpoint
    device_map="auto"    # place weights automatically across available devices
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load the chain-of-thought training data (one sample per line).
raw_data_path = r"data/CoTdata.txt"
with open(raw_data_path, "r", encoding="utf-8") as f:
    raw_lines = f.readlines()
|
| 20 |
+
|
| 21 |
+
# 处理每一行数据,解析出关键词、思维链和诗歌内容
|
| 22 |
+
def process_line(line):
    """Parse one raw CoT line into a supervised training example.

    Expected layout: "<keywords><think>...</think>:<poem>", where the colon
    after </think> may be ASCII ':' or full-width ':'.  Returns the formatted
    training text, or None (after printing a notice) when the line does not
    match.
    """
    stripped = line.strip()
    match = re.match(r"^(.*?)<think>(.*?)</think>[::](.*)$", stripped)
    if not match:
        # Malformed line: report it and let the caller drop it.
        print("跳过格式错误的行:", stripped)
        return None
    keywords, cot, poem = (part.strip() for part in match.groups())
    # Input section carries the prompt + keywords; output carries the full
    # chain of thought followed by the answer.
    return (
        f"【输入】:根据以下关键词生成一首歌词,歌词中包含多个句子,确保句子通顺、诗意、格式正确。"
        f"让我们一步一步的思考(思考过程包含在<think>和</think>之间):{keywords}\n\n"
        f"【输出】:<think>{cot}</think>\n{poem}"
    )
|
| 41 |
+
|
| 42 |
+
# 解析所有数据行
|
| 43 |
+
# Parse every data line; drop the ones that fail the format check.
processed_samples = []
for line in raw_lines:
    result = process_line(line)
    if result:
        processed_samples.append(result)

# Build the Hugging Face dataset.
dataset = Dataset.from_dict({"text": processed_samples})

# Load the base model plus an already-trained LoRA adapter.
model = PeftModel.from_pretrained(base_model, r"D:\GoodMusicV3.0\3_24_LoRA").to("cuda")  # replace with your LoRA path
tokenizer.pad_token = tokenizer.eos_token

# NOTE(review): get_peft_model below stacks a NEW LoRA on top of the loaded
# PeftModel — confirm this nesting is intended rather than resuming training
# of the existing adapter.
lora_config = LoraConfig(
    r=8,  # rank of the low-rank matrices
    lora_alpha=32,  # scaling factor controlling the LoRA contribution
    target_modules=["q_proj", "k_proj", "v_proj"],  # attention projections to adapt
    lora_dropout=0.1,  # dropout to curb overfitting
    bias="none",  # do not train bias terms
    task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)
model.cuda()
|
| 66 |
+
|
| 67 |
+
# 分词函数:对文本进行分词,并构造 labels
|
| 68 |
+
def tokenize_function(examples):
    """Tokenize the pre-built input+output texts; labels mirror input_ids.

    Relies on the module-level `tokenizer`.
    """
    batch = tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=256,
    )
    batch["labels"] = batch["input_ids"].copy()
    return batch
|
| 73 |
+
|
| 74 |
+
# 对数据集进行映射处理
|
| 75 |
+
# Tokenize the whole dataset up front.
tokenized_dataset = dataset.map(tokenize_function, batched=True)

# Training configuration.
training_args = TrainingArguments(
    output_dir="./lora",
    num_train_epochs=1000,  # NOTE(review): 1000 epochs on a small CoT set will overfit heavily — confirm
    per_device_train_batch_size=16,
    learning_rate=2e-5,
    weight_decay=0.01,
    logging_steps=10000,
    save_steps=15000,
    fp16=True,  # mixed-precision training
)

# Build the Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer,
)

# Start training.
trainer.train()

# Inference example
generation_config = {
    "max_new_tokens": 1024,
    "temperature": 1.0,
    "top_p": 0.9,
    "top_k": 40,
    "repetition_penalty": 1.2,
    "do_sample": True,
    "encoder_no_repeat_ngram_size": 4,
}
# NOTE(review): `if True:` appears to be a manual toggle for the demo below.
if True:
    prompt = "根据以下关键词生成一首歌词,歌词中包含多个句子,句子与句子之间使用/隔开,让我们一步一步的思考(思考过程包含在<think>和</think>之间):温柔,轮廓,洒脱:"
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(input_ids, **generation_config)
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=False)

    print(decoded)

# Save the new adapter weights.
model.save_pretrained("4_2_LoRA_3")
|
code/UI.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
from PyQt5.QtWidgets import QApplication, QWidget, QVBoxLayout, QPushButton, QTextEdit, QLineEdit, QListWidget, QLabel, QHBoxLayout, QListWidgetItem
|
| 3 |
+
import _MyModel
|
| 4 |
+
from PyQt5.QtGui import QColor, QPalette
|
| 5 |
+
from PyQt5.QtCore import Qt
|
| 6 |
+
class ChatSession:
    """Container for a single conversation: a topic plus its message log."""

    def __init__(self, topic="新对话"):
        self.topic = topic
        # Chronological list of (sender, text) pairs.
        self.messages = []

    def add_message(self, sender, text):
        """Record one message; *sender* is 'user' or 'ai'."""
        self.messages.append((sender, text))
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ChatGPTUI(QWidget):
    """Main chat window: session history on the left, chat view + input on the right.

    Holds a list of ChatSession objects and renders the currently selected one.
    *MyModel* must expose predict(text) -> str (see _MyModel.MyModel).
    """

    def __init__(self, MyModel):
        super().__init__()
        self.model = MyModel
        # NOTE(review): initialized as a QListWidget but later reassigned to a
        # QListWidgetItem in add_chat_item — the initial type looks inconsistent; confirm.
        self.first_list_item = QListWidget()
        self.setWindowTitle("ChatGPT 聊天界面")
        self.setGeometry(200, 200, 800, 600)
        self.setStyleSheet("background-color: #DCB272; color: white;")  # themed background
        #self.setWindowFlags(Qt.FramelessWindowHint)  # frameless window (disabled)
        # Main layout: left column (history) + right column (chat).
        main_layout = QHBoxLayout(self)
        # Left side:
        left_layout = QVBoxLayout()
        # "New chat" button.
        self.new_chat_button = QPushButton("新建对话")
        self.new_chat_button.setStyleSheet("background-color: #0FA958; color: white; padding: 8px; border-radius: 5px;")
        self.new_chat_button.clicked.connect(self.create_new_chat)
        left_layout.addWidget(self.new_chat_button)
        # Left side: conversation history list.
        self.history_list = QListWidget()
        self.history_list.setStyleSheet("background-color: #E4DECE; color: black; border: none;")
        self.history_list.itemClicked.connect(self.load_selected_chat)  # switch sessions on click
        left_layout.addWidget(self.history_list)

        # Right side: chat area.
        right_layout = QVBoxLayout()

        # Topic input box.
        # NOTE(review): topic_input is created but never added to a layout, so it
        # stays invisible and create_new_chat always falls back to "新对话" — confirm.
        self.topic_input = QLineEdit()
        self.topic_input.setPlaceholderText("请输入对话主题...")
        self.topic_input.setStyleSheet("background-color: #E4DECE; color: black; padding: 5px; border-radius: 5px;")

        # Read-only chat transcript view.
        self.chat_display = QTextEdit()
        self.chat_display.setReadOnly(True)
        self.chat_display.setStyleSheet("background-color: #E4DECE; color: black; border: none; padding: 10px;")
        right_layout.addWidget(self.chat_display, 7)

        # Input row (horizontal layout).
        input_layout = QHBoxLayout()

        # User input field; Enter key also sends.
        self.input_field = QLineEdit()
        self.input_field.setPlaceholderText("输入消息...")
        self.input_field.setStyleSheet("background-color: #E4DECE; color: black; padding: 5px; border-radius: 5px;")
        input_layout.addWidget(self.input_field, 8)
        self.input_field.returnPressed.connect(self.send_message)

        # Send button.
        self.send_button = QPushButton("发送")
        self.send_button.setStyleSheet("background-color: #DA8D6D; color: white; padding: 8px; border-radius: 5px;")
        self.send_button.clicked.connect(self.send_message)
        input_layout.addWidget(self.send_button, 2)

        right_layout.addLayout(input_layout)

        # Assemble the main layout (history gets 2/10 of the width).
        main_layout.addLayout(left_layout, 2)
        main_layout.addLayout(right_layout, 8)

        self.setLayout(main_layout)

        # Session storage.
        self.chat_sessions = []  # all sessions, parallel to history_list rows
        self.current_session = None
        self.create_new_chat()  # start with one default session

    def create_new_chat(self):
        """Create a session (topic from topic_input, default '新对话') and select it."""
        topic = self.topic_input.text().strip()
        if not topic:
            topic = "新对话"

        new_session = ChatSession(topic)
        self.chat_sessions.append(new_session)
        self.current_session = new_session

        # Update the history list and select the new entry.
        self.add_chat_item(topic)
        self.history_list.setCurrentRow(self.history_list.count() - 1)
        self.chat_display.clear()

    def load_selected_chat(self):
        """Switch to the session the user clicked in the history list."""
        selected_index = self.history_list.currentRow()
        if selected_index >= 0:
            self.current_session = self.chat_sessions[selected_index]
            self.display_chat_history()

    def display_chat_history(self):
        """Re-render the current session's full transcript."""
        self.chat_display.clear()
        for sender, text in self.current_session.messages:
            if sender == 'user':
                self.chat_display.append(f"<b><span style='color: #9b7438; font-family: 微软雅黑; font-size: 28px'>主题 : </span><span style='color: #1B2131; font-family: 微软雅黑; font-size: 28px'> {text}</span></b>")
            else:
                self.chat_display.append(f"<b>{'用户' if sender == 'user' else 'ChatGPT'}:</b> {text}")

    def send_message(self):
        """Send the user's input, then request and display the AI reply."""
        user_text = self.input_field.text().strip()
        if user_text and self.current_session:
            self.current_session.add_message("user", user_text)
            self.chat_display.append(f"<b><span style='color: #9b7438; font-family: 微软雅黑; font-size: 28px'>主题 : </span><span style='color: #1B2131; font-family: 微软雅黑; font-size: 28px'> {user_text}</span></b>")
            self.input_field.clear()

            # Synchronous model call; the UI blocks until predict() returns.
            ai_reply = self.get_ai_response(user_text)
            self.receive_message(ai_reply)

    def receive_message(self, text):
        """Append an AI reply to the session and the transcript view."""
        if self.current_session:
            self.current_session.add_message("ai", text)
            self.chat_display.append(f"<b>ChatGPT:</b> {text}")

    def get_ai_response(self, user_input):
        """Run the injected model on *user_input* and wrap the output in HTML."""
        output = self.model.predict(user_input)
        return f"<span style='font-size: 20px;'>{output}</span>"

    def add_chat_item(self, text):
        """Append a history row containing a label plus a delete button."""
        item_widget = QWidget()
        item_layout = QHBoxLayout(item_widget)
        item_layout.setContentsMargins(5, 2, 5, 2)

        label = QLabel(text)
        delete_button = QPushButton("×")
        delete_button.setFixedSize(20, 20)
        delete_button.setStyleSheet("background-color: #cc6666; color: white; border-radius: 10px;")

        item_layout.addWidget(label)
        item_layout.addWidget(delete_button)
        item_layout.addStretch()

        list_item = QListWidgetItem(self.history_list)
        list_item.setSizeHint(item_widget.sizeHint())

        self.history_list.addItem(list_item)
        self.history_list.setItemWidget(list_item, item_widget)

        # Each row's delete button removes that specific row.
        delete_button.clicked.connect(lambda: self.remove_chat_item(list_item))
        self.first_list_item = list_item  # remembers the most recently added item

    def remove_chat_item(self, item):
        """Delete a history row and its backing session."""
        row = self.history_list.row(item)
        del self.chat_sessions[row]
        self.history_list.takeItem(row)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
# Run the PyQt5 application (standalone entry point; __main__.py is the packaged one).
if __name__ == "__main__":

    app = QApplication(sys.argv)
    # Bug fix: ChatGPTUI.__init__ requires a model instance; calling it with
    # no argument raised TypeError.  Mirror __main__.py and pass a MyModel.
    window = ChatGPTUI(_MyModel.MyModel())
    window.show()
    # Drop the placeholder session created during __init__, then start fresh.
    window.remove_chat_item(window.first_list_item)
    window.create_new_chat()
    sys.exit(app.exec_())
|
code/_MyModel.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from peft import LoraConfig, get_peft_model, PeftModel
|
| 3 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 4 |
+
|
| 5 |
+
class MyModel():
    """Inference wrapper: DeepSeek base model plus the RL-trained LoRA adapter."""

    def __init__(self):
        model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
        lora_path = "DS_RL_model"  # relative path to the adapter weights
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)
        # Attach the LoRA adapter on top of the frozen base weights.
        self.model = PeftModel.from_pretrained(model, lora_path)
        # NOTE(review): do_sample is not set, so generation is greedy by default
        # and temperature/top_p are ignored — confirm that is intended.
        self.generation_config = {
            "max_new_tokens": 2048,
            "temperature": 0.9,
            "top_p": 1.0,
            "repetition_penalty": 1.2,
        }
    def predict(self, text):
        """Generate lyrics for keyword string *text*.

        Returns the full decoded sequence (prompt + <think> trace + lyrics);
        special tokens are kept (skip_special_tokens=False).
        """
        prompt = "根据以下关键词生成一首歌词,歌词中包含多个句子,句子与句子之间使用/隔开,让我们一步一步的思考(思考过程包含在<think>和</think>之间):" + text
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.model.device)
        outputs = self.model.generate(input_ids, **self.generation_config)
        decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=False)
        return decoded
    # Example keywords: 诗,样子,天地
|
code/__main__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
import os

import _MyModel
from UI import QApplication, ChatGPTUI

# Raw string: the original literal relied on '\M', '\c', '\L', ... not being
# recognized escape sequences; that emits SyntaxWarning on Python 3.12+.
# The raw-string value is byte-identical to the original.
# NOTE(review): machine-specific path — should come from configuration
# rather than being hard-coded.
os.environ["QT_QPA_PLATFORM_PLUGIN_PATH"] = r"D:\不会编程\Machine_Learning\class_project\project\.venv\Lib\site-packages\PyQt5\Qt5\plugins"


if __name__ == '__main__':
    # Load the model first (slow), then build the Qt application around it.
    myModel = _MyModel.MyModel()
    app = QApplication(sys.argv)
    window = ChatGPTUI(myModel)
    window.show()
    # Drop the placeholder session created during __init__, then start fresh.
    window.remove_chat_item(window.first_list_item)
    window.create_new_chat()
    sys.exit(app.exec_())
|
code/__pycache__/UI.cpython-311.pyc
ADDED
|
Binary file (12 kB). View file
|
|
|
code/__pycache__/_MyModel.cpython-311.pyc
ADDED
|
Binary file (2.11 kB). View file
|
|
|
code/__pycache__/deepseek_vaule.cpython-311.pyc
ADDED
|
Binary file (10 kB). View file
|
|
|
code/__pycache__/reward.cpython-311.pyc
ADDED
|
Binary file (5.73 kB). View file
|
|
|
code/__pycache__/train_nessary.cpython-311.pyc
ADDED
|
Binary file (8.72 kB). View file
|
|
|
code/data_process.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
|
| 3 |
+
def contains_chinese(text):
    """Return True when *text* holds at least one Chinese character.

    The Unicode range U+4E00–U+9FFF covers the common Han ideographs.
    """
    return bool(re.search(r'[\u4e00-\u9fff]', text))
|
| 8 |
+
|
| 9 |
+
def process_lyrics(text):
    """Clean a '/'-separated lyrics dump.

    Pipeline: split on '/', trim whitespace, drop empties and lines shorter
    than 3 characters, drop lines without Chinese characters (treated as
    English), and de-duplicate while preserving first-seen order.
    """
    kept = []
    seen = set()
    for raw in text.split('/'):
        candidate = raw.strip()
        # Skip blanks and too-short fragments.
        if not candidate or len(candidate) < 3:
            continue
        # Skip lines with no Chinese — treated as English lyrics.
        if not contains_chinese(candidate):
            continue
        # De-duplicate, keeping the first occurrence.
        if candidate in seen:
            continue
        seen.add(candidate)
        kept.append(candidate)
    return kept
|
| 39 |
+
|
| 40 |
+
def main():
    """Read the raw lyrics file, clean it, and write the processed result."""
    # Raw strings: '\l' and '\p' are not escape sequences, but un-prefixed
    # they trigger SyntaxWarning on Python 3.12+; the values are unchanged.
    input_filename = r'data\lyrics.txt'
    output_filename = r'data\processed_data.txt'

    # Read the raw data file (UTF-8).
    with open(input_filename, 'r', encoding='utf-8') as f:
        content = f.read()

    # Clean the lyrics.
    processed = process_lyrics(content)

    # Re-join with '/' (could also be written one lyric per line).
    output_content = '/'.join(processed)

    # Write the processed data out.
    with open(output_filename, 'w', encoding='utf-8') as f:
        f.write(output_content)

    print(f'处理完成,结果保存在 {output_filename}')


if __name__ == '__main__':
    main()
|
code/deepseek_vaule.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import openai
|
| 3 |
+
from openai import APIError
|
| 4 |
+
from typing import Dict, List, Union
|
| 5 |
+
|
| 6 |
+
# Custom exceptions
class InsufficientBalanceError(Exception):
    """Raised when the DeepSeek API reports insufficient balance (status 402)."""


class EvaluationError(Exception):
    """Raised for any other failure while requesting or parsing an evaluation."""
|
| 12 |
+
|
| 13 |
+
# System prompt: the detailed scoring rubric sent as the "system" message
# to the DeepSeek chat API (weights 25% per dimension, 10-point scale).
SYS_PROMPT = """你是一个专业的文本质量评估专家。请根据以下标准对文本进行评分(满分10分):
1. 创意性(权重25%): 内容的原创性和新颖性
2. 文采(权重25%): 语言表达的优美程度和修辞手法
3. 格式(权重25%): 结构清晰度、可读性和符合要求的格式
4. 长度(权重25%): 内容长度是否适中(50-300字为佳)
5. 总分(根据四个维度进行加权计算)

评分要求:
- 使用表格形式输出,得到得分表格.
- 每项评分保留1位小数
- 最后简要对目标文本的评价,而不是让你自己再写一个,切记
"""
|
| 26 |
+
|
| 27 |
+
def evaluate_text_quality(
    text: str,
    api_key: str = None,
    model: str = "deepseek-chat",
    temperature: float = 0.3,
    max_tokens: int = 300
) -> Dict[str, Union[float, str]]:
    """Score *text* with the DeepSeek chat API against the SYS_PROMPT rubric.

    Args:
        text: the text to evaluate.
        api_key: DeepSeek key; falls back to the "you_api_key" env var
            (NOTE(review): variable name looks like a placeholder — confirm).
        model, temperature, max_tokens: forwarded to the chat completion call.

    Returns:
        Dict from parse_evaluation_result: {"scores": {...}, "evaluation": str}.

    Raises:
        ValueError: when no API key is available.
        InsufficientBalanceError: when the API returns status 402.
        EvaluationError: for any other API or parsing failure.
    """
    # Resolve the API key.
    api_key = api_key or os.getenv("you_api_key")
    if not api_key:
        raise ValueError("DeepSeek API密钥未提供")

    # Build a client pointed at the DeepSeek endpoint (OpenAI-compatible).
    client = openai.OpenAI(
        api_key=api_key,
        base_url="https://api.deepseek.com/v1"
    )

    try:
        # Call the API.
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": SYS_PROMPT},
                {"role": "user", "content": text}
            ],
            temperature=temperature,
            max_tokens=max_tokens,
            stream=False
        )

        # Extract the model's answer.
        output = response.choices[0].message.content.strip()

        # Parse the scores and critique out of the response.
        return parse_evaluation_result(output)

    except APIError as e:
        if e.status_code == 402:  # assumes 402 signals insufficient balance
            raise InsufficientBalanceError("API余额不足,请充值") from e
        else:
            raise EvaluationError(f"API错误[{e.status_code}]: {e.message}") from e
    except Exception as e:
        raise EvaluationError(f"评估失败: {str(e)}") from e
|
| 71 |
+
|
| 72 |
+
def parse_evaluation_result(output: str) -> Dict[str, Union[float, str]]:
    """Parse the raw API answer into per-dimension scores plus a critique.

    Scores default to 0.0 when a dimension is not found; "evaluation" falls
    back to the full raw output when no explicit critique marker
    ("评价:"/"评语:"/"总结:") is present.
    """
    result = {
        "scores": {
            "creativity": 0.0,
            "language": 0.0,
            "format": 0.0,
            "length": 0.0,
            "total": 0.0
        },
        "evaluation": output  # default: keep the whole raw output
    }

    # Scan non-empty lines for the score table.
    lines = [line.strip() for line in output.split('\n') if line.strip()]

    for line in lines:
        # creativity
        if "创意性" in line:
            result["scores"]["creativity"] = extract_score_from_line(line)
        # language / style
        elif any(key in line for key in ["文采", "语言表达"]):
            result["scores"]["language"] = extract_score_from_line(line)
        # format
        elif "格式" in line:
            result["scores"]["format"] = extract_score_from_line(line)
        # length
        elif "长度" in line:
            result["scores"]["length"] = extract_score_from_line(line)
        # total
        elif any(key in line for key in ["总分", "总计", "平均"]):
            result["scores"]["total"] = extract_score_from_line(line)

    # Collect everything from a "评价:"-style marker onward as the critique.
    evaluation_lines = []
    found_evaluation = False
    for line in lines:
        if any(prefix in line for prefix in ["评价:", "评语:", "总结:"]):
            found_evaluation = True
            line = line.split(":", 1)[-1].strip()
        if found_evaluation and line:
            evaluation_lines.append(line)

    if evaluation_lines:
        result["evaluation"] = "\n".join(evaluation_lines)

    return result
|
| 121 |
+
|
| 122 |
+
def extract_score_from_line(line: str) -> float:
    """Pull the first numeric score out of a rubric line.

    Handles three shapes: markdown table cells ("| 创意性 | 8.5 |"),
    colon-separated pairs ("创意性: 8.5", ASCII or full-width colon), and a
    bare number anywhere in the line.  Returns 0.0 when nothing numeric is
    found.
    """
    try:
        # Markdown table row: take the first cell that is purely numeric.
        if "|" in line:
            for cell in (c.strip() for c in line.split("|")):
                if cell and cell.replace('.', '').isdigit():
                    return float(cell)

        # "label: value" — split on whichever colon appears (ASCII wins).
        if ":" in line or ":" in line:
            sep = ":" if ":" in line else ":"
            tail = line.split(sep, 1)[-1].strip()
            for token in tail.split():
                token = token.replace('/', '').replace('分', '')
                if token.replace('.', '').isdigit():
                    return float(token)

        # Fallback: any standalone number in the line.
        for token in line.split():
            token = token.replace('分', '').replace('/', '')
            if token.replace('.', '').isdigit():
                return float(token)

    except (ValueError, IndexError):
        pass

    return 0.0
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def print_evaluation_result(
    evaluation: Dict[str, Union[float, str]],
    show_details: bool = True,
    score_only: bool = False
) -> None:
    """Pretty-print an evaluation dict produced by evaluate_text_quality.

    Args:
        evaluation: dict with "scores" and "evaluation" keys.
        show_details: include the textual critique after the score table.
        score_only: print scores only (takes precedence over show_details).
    """
    if not evaluation:
        print("无有效评估结果")
        return

    scores = evaluation.get("scores", {})
    evaluation_text = evaluation.get("evaluation", "")

    # Score summary, one dimension per line.
    print("\n=== 文本质量评估 ===")
    for label, key in (("创意性", "creativity"), ("文采", "language"),
                       ("格式", "format"), ("长度", "length")):
        print(f"[{label}] {scores.get(key, 0.0):.1f}/10")
    print("-" * 25)
    print(f"[总分] {scores.get('total', 0.0):.1f}/10")

    # Optionally append the detailed critique.
    if evaluation_text and show_details and not score_only:
        print("\n=== 详细评价 ===")
        print(evaluation_text)
code/getCOT.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import openai
|
| 3 |
+
import threading
|
| 4 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 5 |
+
from openai import APIError
|
| 6 |
+
|
| 7 |
+
API_KEY = os.getenv("DEEPSEEK_API_KEY", "your_api_key")
|
| 8 |
+
|
| 9 |
+
class ThreadSafeWriter:
    """Append-only file writer that is safe to share across worker threads."""

    def __init__(self, output_path: str):
        # 'a+' so a restarted run appends instead of truncating earlier output.
        self.file = open(output_path, mode='a+', encoding='utf-8')
        self.lock = threading.Lock()
        self.counter = 0  # number of lines written so far

    def write_line(self, content: str):
        """Atomically append one line and flush it straight to disk."""
        with self.lock:
            self.file.write(f"{content}\n")
            # Flush immediately so progress survives a crash mid-batch.
            self.file.flush()
            self.counter = self.counter + 1

    def get_progress(self):
        """Return how many lines have been written."""
        with self.lock:
            return self.counter

    def close(self):
        """Close the underlying file handle."""
        self.file.close()
|
| 28 |
+
|
| 29 |
+
class DeepSeekBatchProcessor:
|
| 30 |
+
    def __init__(self, max_workers: int = 100):
        """Set up the DeepSeek client plus shared error/rate-limit state."""
        self.client = openai.OpenAI(
            api_key=API_KEY,
            base_url="https://api.deepseek.com/v1"
        )
        self.max_workers = max_workers
        # Set by any worker on a fatal error so the others bail out early.
        self.error_flag = threading.Event()
        # Caps in-flight API calls at 20 regardless of thread-pool size.
        self.rate_limiter = threading.Semaphore(20)
|
| 38 |
+
|
| 39 |
+
def process_batch(self, batch, writer: ThreadSafeWriter):
|
| 40 |
+
"""批量处理,每个任务单独线程"""
|
| 41 |
+
futures = []
|
| 42 |
+
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
| 43 |
+
for line_num, line in batch:
|
| 44 |
+
if self.error_flag.is_set():
|
| 45 |
+
break
|
| 46 |
+
futures.append(
|
| 47 |
+
executor.submit(
|
| 48 |
+
self._process_single_line,
|
| 49 |
+
line_num,
|
| 50 |
+
line,
|
| 51 |
+
writer
|
| 52 |
+
)
|
| 53 |
+
)
|
| 54 |
+
for future in futures:
|
| 55 |
+
future.result()
|
| 56 |
+
|
| 57 |
+
def _process_single_line(self, line_num: int, line: str, writer: ThreadSafeWriter):
|
| 58 |
+
if self.error_flag.is_set():
|
| 59 |
+
return
|
| 60 |
+
|
| 61 |
+
# 支持英文冒号(:)和中文全角冒号(:)
|
| 62 |
+
separator = None
|
| 63 |
+
if ':' in line:
|
| 64 |
+
separator = ':'
|
| 65 |
+
elif ':' in line:
|
| 66 |
+
separator = ':'
|
| 67 |
+
|
| 68 |
+
if not separator:
|
| 69 |
+
print(f"\n行 {line_num} 格式错误")
|
| 70 |
+
writer.write_line(f"格式错误:{line}")
|
| 71 |
+
return
|
| 72 |
+
|
| 73 |
+
keywords_part, original_text = line.split(separator, 1)
|
| 74 |
+
# 这里只提取关键词部分(例如“风,雾,寂寞”)
|
| 75 |
+
keywords = [kw.strip() for kw in keywords_part.split(",") if kw.strip()]
|
| 76 |
+
if not keywords:
|
| 77 |
+
keywords = ["无关键词"]
|
| 78 |
+
|
| 79 |
+
# 构造提示:根据关键词生成诗歌
|
| 80 |
+
prompt = "请根据以下关键词写一首诗:" + ",".join(keywords)
|
| 81 |
+
messages = [{"role": "user", "content": prompt}]
|
| 82 |
+
|
| 83 |
+
retries = 0
|
| 84 |
+
while retries < 3 and not self.error_flag.is_set():
|
| 85 |
+
try:
|
| 86 |
+
with self.rate_limiter:
|
| 87 |
+
response = self.client.chat.completions.create(
|
| 88 |
+
model="deepseek-reasoner",
|
| 89 |
+
messages=messages,
|
| 90 |
+
temperature=0.1
|
| 91 |
+
)
|
| 92 |
+
# 提取返回中的思考过程和诗歌原文
|
| 93 |
+
reasoning_content = response.choices[0].message.reasoning_content.replace('\n', '').replace('\r', '')
|
| 94 |
+
poem_original = response.choices[0].message.content.replace('\n', '/').replace('\r', '')
|
| 95 |
+
# 拼接最终结果:关键词<think>思考过程</think>:诗歌原文
|
| 96 |
+
final_line = f"{','.join(keywords)}<think>{reasoning_content}</think>:{poem_original}"
|
| 97 |
+
writer.write_line(final_line)
|
| 98 |
+
progress = writer.get_progress()
|
| 99 |
+
print(f"\r已处理 {progress} 条", end='')
|
| 100 |
+
break
|
| 101 |
+
|
| 102 |
+
except APIError as e:
|
| 103 |
+
if e.status_code == 402:
|
| 104 |
+
print(f"\n行 {line_num} 处理失败:API余额不足")
|
| 105 |
+
self.error_flag.set()
|
| 106 |
+
return
|
| 107 |
+
elif e.status_code == 429:
|
| 108 |
+
print(f"\n行 {line_num} 速率受限,重试中...")
|
| 109 |
+
retries += 1
|
| 110 |
+
if retries >= 3:
|
| 111 |
+
print(f"\n行 {line_num} 重试次数耗尽")
|
| 112 |
+
else:
|
| 113 |
+
print(f"\n行 {line_num} API错误[{e.status_code}]:{e.message}")
|
| 114 |
+
return
|
| 115 |
+
|
| 116 |
+
except Exception as e:
|
| 117 |
+
print(f"\n行 {line_num} 处理异常:{str(e)}")
|
| 118 |
+
retries += 1
|
| 119 |
+
if retries >= 3:
|
| 120 |
+
print(f"\n行 {line_num} 重试次数耗尽")
|
| 121 |
+
|
| 122 |
+
if retries >= 3 and not self.error_flag.is_set():
|
| 123 |
+
writer.write_line(f"处理失败:{line}")
|
| 124 |
+
|
| 125 |
+
def process_first_1000_lines(input_path: str, output_path: str, max_workers: int = 100):
    """Read at most the first 1000 lines of `input_path` and process them concurrently.

    Blank lines are skipped (note: a blank line does not trigger the 1000-line
    cutoff, matching the original short-circuit behaviour).
    """
    processor = DeepSeekBatchProcessor(max_workers)
    writer = ThreadSafeWriter(output_path)
    batch = []
    try:
        with open(input_path, 'r', encoding='utf-8') as f:
            for line_num, raw in enumerate(f, 1):
                text = raw.strip()
                if not text:
                    continue
                batch.append((line_num, text))
                if line_num >= 1000:
                    break

        print(f"总数据量:{len(batch)} 条")
        processor.process_batch(batch, writer)
        print("\n处理完成!")
    finally:
        # Always release the output handle, even if processing raised.
        writer.close()
|
| 145 |
+
|
| 146 |
+
if __name__ == '__main__':
    # Script entry point: annotate the first 1000 source lines with CoT data.
    process_first_1000_lines(
        "data/DSdata.txt",
        "data/CoTdata.txt",
        max_workers=100,
    )
|
code/reward.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import numpy as np
|
| 3 |
+
from typing import List, Dict, Union, Optional
|
| 4 |
+
from sentence_transformers import SentenceTransformer, util
|
| 5 |
+
from multiprocessing import Pool, cpu_count
|
| 6 |
+
|
| 7 |
+
# 全局初始化 SentenceTransformer 模型,并移动到 GPU
|
| 8 |
+
embedder = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2').to("cuda")
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def compute_rewards(
    completions: List[str],
    min_len: Union[int, List[int]] = 100,
    max_len: Union[int, List[int]] = 300,
    weights: Union[tuple, List[tuple]] = (0.25, 0.25, 0.25, 0.25),
    return_components: bool = False,
    **kwargs
) -> Union[List[float], Dict[str, List[float]]]:
    """Compute a weighted total reward per completion, parallelised across CPUs.

    Length, format and language rewards run on a process pool; the keyword
    reward is batched through the GPU embedder. `kwargs` must contain
    "keywords": one keyword list per completion.
    """
    keywords = kwargs["keywords"]
    n_samples = len(completions)

    # Broadcast scalar settings to one value per sample.
    min_len = _to_list(min_len, n_samples)
    max_len = _to_list(max_len, n_samples)
    weights = _to_list(weights, n_samples)

    # Fan the per-sample sub-rewards out over worker processes.
    with Pool(cpu_count()) as pool:
        length_rewards = pool.starmap(_length_reward, zip(completions, min_len, max_len))
        format_rewards = pool.map(_format_reward, completions)
        keyword_rewards = _batch_keyword_reward(completions, keywords)  # GPU path
        language_rewards = pool.map(_language_reward, completions)

    # Weighted sum of the four components, sample by sample.
    components = zip(length_rewards, format_rewards, keyword_rewards, language_rewards)
    total_rewards = [
        sum(weight * part for weight, part in zip(w, parts))
        for w, parts in zip(weights, components)
    ]

    if return_components:
        return {
            "rewards": total_rewards,
            "length_rewards": length_rewards,
            "format_rewards": format_rewards,
            "keyword_rewards": keyword_rewards,
            "language_rewards": language_rewards,
        }
    return total_rewards
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# -------------- 并行子函数 --------------
|
| 52 |
+
def _to_list(val: Union[any, List[any]], n: int) -> List[any]:
|
| 53 |
+
"""转换为样本级列表"""
|
| 54 |
+
return val if isinstance(val, list) else [val] * n
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _length_reward(text: str, min_len: int, max_len: int) -> float:
|
| 58 |
+
"""单样本长度奖励"""
|
| 59 |
+
original = text.split("</think>:", 1)[1].strip() if "</think>:" in text else text.strip()
|
| 60 |
+
length = len(original)
|
| 61 |
+
|
| 62 |
+
if length < min_len:
|
| 63 |
+
return length / min_len + 1 # 1~2线性增长
|
| 64 |
+
elif length > max_len:
|
| 65 |
+
return max_len / length + 1 # 2~1线性衰减
|
| 66 |
+
return 2.0
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def _format_reward(text: str) -> float:
|
| 70 |
+
"""单样本格式奖励"""
|
| 71 |
+
if "<think>" not in text or "</think>:" not in text:
|
| 72 |
+
return -2.0
|
| 73 |
+
think_content = text.split("<think>")[1].split("</think>")[0].strip()
|
| 74 |
+
return 2.0 if think_content else -2.0
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def _batch_keyword_reward(texts: List[str], keywords_list: List[List[str]]) -> List[float]:
    """Batch keyword-relevance reward via embedding cosine similarity (GPU embedder)."""
    # The scored body is whatever follows "</think>:"; otherwise the whole text.
    originals = [text.split("</think>:", 1)[1].strip() if "</think>:" in text else text.strip() for text in texts]
    # Only samples with both a non-empty body AND a non-empty keyword list are embedded.
    valid_indices = [i for i, orig in enumerate(originals) if orig and keywords_list[i]]

    if not valid_indices:
        # Nothing to embed: no-keyword samples get the neutral 0.8,
        # samples that had keywords but an empty body get the -2.0 penalty.
        return [0.8 if not kw else -2.0 for kw in keywords_list]

    valid_originals = [originals[i] for i in valid_indices]
    valid_keywords = [keywords_list[i] for i in valid_indices]

    # Encode with the module-level GPU-backed SentenceTransformer (`embedder`).
    original_embs = embedder.encode(valid_originals, convert_to_tensor=True)
    keyword_embs = [embedder.encode(kw, convert_to_tensor=True) for kw in valid_keywords]

    # Mean cosine similarity between each body and all of its keywords.
    # NOTE(review): iterating `original_embs` yields one row per sample;
    # presumably pytorch_cos_sim accepts the 1-D tensor — confirm on upgrade.
    similarities = [
        util.pytorch_cos_sim(orig_emb, kw_emb).mean().item()
        for orig_emb, kw_emb in zip(original_embs, keyword_embs)
    ]

    # Map each similarity to a tiered reward; unembedded samples get fallbacks.
    rewards = []
    sim_idx = 0
    for i, kw in enumerate(keywords_list):
        if i in valid_indices:
            sim = similarities[sim_idx]
            # Tiers: >=0.6 -> 2.0, >=0.4 -> 1.2, otherwise 0.8.
            rewards.append(2.0 if sim >= 0.6 else (1.2 if sim >= 0.4 else 0.8))
            sim_idx += 1
        else:
            # No keywords -> neutral 0.8; keywords but empty body -> -2.0.
            rewards.append(0.8 if not kw else -2.0)
    return rewards
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def _language_reward(text: str) -> float:
|
| 111 |
+
"""单样本语言奖励"""
|
| 112 |
+
chinese_chars = len(re.findall(r'[\u4e00-\u9fff]', text))
|
| 113 |
+
ratio = chinese_chars / max(1, len(text))
|
| 114 |
+
|
| 115 |
+
if ratio >= 0.9:
|
| 116 |
+
return 2.0
|
| 117 |
+
elif ratio >= 0.7:
|
| 118 |
+
return 1.4
|
| 119 |
+
return 0.7
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
# ------------ 运行示例 ------------
|
| 123 |
+
if __name__ == "__main__":
    # Smoke-test the reward pipeline on three hand-written samples.
    samples = [
        "科技<think>技术创新是关键</think>:人工智能在医疗领域的应用正在改变诊断方式。",
        "无效样本<think></think>:无意义内容",
        "经济<think>宏观经济分析</think>:全球供应链重构对发展中国家影响深远。"
    ]
    keywords = [
        ["科技", "人工智能"],
        [],  # deliberately empty keyword list
        ["经济", "供应链"]
    ]

    rewards = compute_rewards(
        completions=samples,
        keywords=keywords,
        min_len=[50, 10, 80],
        return_components=True
    )

    # Print the total followed by each component, one line apiece.
    for label, key in (
        ("总奖励", "rewards"),
        ("长度奖励", "length_rewards"),
        ("格式奖励", "format_rewards"),
        ("关键词奖励", "keyword_rewards"),
        ("语言奖励", "language_rewards"),
    ):
        print(f"{label}:", rewards[key])
|
code/test.ipynb
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"id": "initial_id",
|
| 7 |
+
"metadata": {
|
| 8 |
+
"ExecuteTime": {
|
| 9 |
+
"end_time": "2025-04-02T06:42:56.681032Z",
|
| 10 |
+
"start_time": "2025-04-02T06:42:19.346090Z"
|
| 11 |
+
},
|
| 12 |
+
"collapsed": true
|
| 13 |
+
},
|
| 14 |
+
"outputs": [],
|
| 15 |
+
"source": [
|
| 16 |
+
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
|
| 17 |
+
"from peft import PeftModel\n",
|
| 18 |
+
"# 1. 加载基础模型和LoRA适配器\n",
|
| 19 |
+
"base_model = AutoModelForCausalLM.from_pretrained(\"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\")#您也可以使用GPU推理\n",
|
| 20 |
+
"model = PeftModel.from_pretrained(base_model, \"../DS_RL_model\") # .to(\"cuda\")使用GPU加速推理\n",
|
| 21 |
+
"tokenizer = AutoTokenizer.from_pretrained(\"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\")\n",
|
| 22 |
+
"tokenizer.pad_token = tokenizer.eos_token\n",
|
| 23 |
+
"\n"
|
| 24 |
+
]
|
| 25 |
+
},
|
| 26 |
+
{
|
| 27 |
+
"cell_type": "code",
|
| 28 |
+
"execution_count": 5,
|
| 29 |
+
"id": "c805805aeaabd6a8",
|
| 30 |
+
"metadata": {
|
| 31 |
+
"ExecuteTime": {
|
| 32 |
+
"end_time": "2025-04-02T07:41:36.715675Z",
|
| 33 |
+
"start_time": "2025-04-02T07:41:27.640736Z"
|
| 34 |
+
}
|
| 35 |
+
},
|
| 36 |
+
"outputs": [
|
| 37 |
+
{
|
| 38 |
+
"name": "stdout",
|
| 39 |
+
"output_type": "stream",
|
| 40 |
+
"text": [
|
| 41 |
+
"模型输出: 根据以下关键词生成一首歌词,歌词中包含多个句子,确保句子通顺,诗意,格式正确.让我们一步一步的思考(思考过程包含在<think>和</think>之间):月光,欢乐,伊甸园,月光下的欢乐,小猪们、小羊们,月光下的欢乐。月光下的欢乐,小猪们、小羊们,月光下的欢乐,小猪们、小羊们,月光下的欢乐,月光下的欢乐,小猪们、小羊们,月光下的欢乐。月光,小猪们、小羊们,月光下的欢乐,月光下的欢乐,月光下的欢乐,月光下的欢乐,月光下的欢乐,月光下的欢乐,月光下的欢乐,月光下的欢乐,月光下的欢乐。\n",
|
| 42 |
+
"\n",
|
| 43 |
+
"嗯,我现在需要帮用户生成一首关于月光、欢乐、伊甸园的歌词。用户给了一个比较长的查询,里面有很多重复的句子,可能想要更简洁或者更流畅的歌词。我得先理解用户的需求,可能他们是在做一个儿童文学作品,或者是在学习如何创作歌词。\n",
|
| 44 |
+
"\n",
|
| 45 |
+
"首先,关键词有月光、欢乐、伊甸园、小猪、小羊。所以歌词里要包含这些元素。用户给出的回复里有很多重复,可能是因为想通过多个句子来强调主题,让读者更容易理解和记忆。\n",
|
| 46 |
+
"\n",
|
| 47 |
+
"我需要确保歌词结构合理,有起承转合,句子通顺。可能用户希望歌词有一定的押韵和节奏感,这样读起来更顺口。同时,格式要正确,可能需要遵循中文诗歌的格式,比如分句、押韵等。\n",
|
| 48 |
+
"\n",
|
| 49 |
+
"另外,用户提供的回复是多次重复的句子,可能是因为想强调月光下的欢乐,让读者感受到那种温馨和欢乐。我需要在生成歌词时,把这些元素自然地融入进去,而不是单纯地重复。\n",
|
| 50 |
+
"\n",
|
| 51 |
+
"我还得考虑歌词的情感基调,是欢快的还是带有感慨的。用户没有特别说明,但关键词中提到“欢乐”和“月光”,感觉偏向于积极向上的情感。\n",
|
| 52 |
+
"\n",
|
| 53 |
+
"可能需要避免过于复杂的结构,保持歌词简洁明了,同时有足够的意象来传达主题。比如,用“月光下的欢乐”这样的词句,可以增强画面感,让读者有身临其境的感觉。\n",
|
| 54 |
+
"\n",
|
| 55 |
+
"另外,用户提到了小猪和小羊,可能是在描绘一个小动物们的场景,或者是在描述一个充满欢乐的小世界。可能需要把这些元素融合在歌词中,让读者感受到那种温暖和快乐。\n",
|
| 56 |
+
"\n",
|
| 57 |
+
"我还需要注意押韵,虽然中文诗歌不一定严格押韵,但要有一定的节奏感。选择合适的结尾词来增强主题的表达。\n",
|
| 58 |
+
"\n",
|
| 59 |
+
"总的来说,我需要把月光、欢乐、小猪、小羊、伊甸园这几个元素有机地融入歌词中,确保结构合理,情感流畅,同时保持格式正确。可能需要多试几遍,调整用词和句式,直到满意为止。\n",
|
| 60 |
+
"</think>\n",
|
| 61 |
+
"\n",
|
| 62 |
+
"## 《月光下的欢乐》\n",
|
| 63 |
+
"\n",
|
| 64 |
+
"月光如水般温柔,\n",
|
| 65 |
+
"在掌心流淌着幸福的泪。\n",
|
| 66 |
+
"\n",
|
| 67 |
+
"小猪们、小羊们,\n",
|
| 68 |
+
"在伊甸园里跳跃舞。\n",
|
| 69 |
+
"月光下欢声笑语,\n",
|
| 70 |
+
"欢声笑语映照着我们的脸。\n",
|
| 71 |
+
"\n",
|
| 72 |
+
"月光下欢声笑语,\n",
|
| 73 |
+
"月光下欢声笑语,\n",
|
| 74 |
+
"月光下欢声笑语,\n",
|
| 75 |
+
"月光下欢声笑语。\n",
|
| 76 |
+
"\n",
|
| 77 |
+
"月光下欢声笑语,\n",
|
| 78 |
+
"月光下欢声笑语,\n",
|
| 79 |
+
"月光下欢声笑语,\n",
|
| 80 |
+
"月光下欢声笑语。\n",
|
| 81 |
+
"\n",
|
| 82 |
+
"月光下欢声笑语,\n",
|
| 83 |
+
"月光下欢声笑语,\n",
|
| 84 |
+
"月光下欢声笑语,\n",
|
| 85 |
+
"月光下欢声笑语。\n",
|
| 86 |
+
"\n",
|
| 87 |
+
"月光下欢声笑语,\n",
|
| 88 |
+
"月光下欢声笑语,\n",
|
| 89 |
+
"月光下欢声笑语,\n",
|
| 90 |
+
"月光下欢声笑语。\n",
|
| 91 |
+
"\n",
|
| 92 |
+
"月光下欢声笑语,\n",
|
| 93 |
+
"月光下欢声笑语,\n",
|
| 94 |
+
"月光下欢声笑语,\n",
|
| 95 |
+
"月光下欢声笑语。\n"
|
| 96 |
+
]
|
| 97 |
+
}
|
| 98 |
+
],
|
| 99 |
+
"source": [
|
| 100 |
+
"# 2. 准备提示词\n",
|
| 101 |
+
"prompt = \"根据以下关键词生成一首歌词,歌词中包含多个句子,确保句子通顺,诗意,格式正确.让我们一步一步的思考(思考过程包含在<think>和</think>之间):月光,欢乐,伊甸园\" \n",
|
| 102 |
+
"\n",
|
| 103 |
+
"# 3. 编码并生成回复\n",
|
| 104 |
+
"inputs = tokenizer(prompt, return_tensors=\"pt\")\n",
|
| 105 |
+
"\n",
|
| 106 |
+
"# 4. 生成参数设置\n",
|
| 107 |
+
"outputs = model.generate(\n",
|
| 108 |
+
" input_ids=inputs[\"input_ids\"],\n",
|
| 109 |
+
" attention_mask=inputs[\"attention_mask\"],\n",
|
| 110 |
+
" max_new_tokens=2048, # 生成的最大token数\n",
|
| 111 |
+
" do_sample=True, # 启用随机采样\n",
|
| 112 |
+
" temperature=0.9, # 控制随机性 (0.1-1.0)\n",
|
| 113 |
+
" top_p=0.9, # nucleus sampling参数\n",
|
| 114 |
+
" pad_token_id=tokenizer.eos_token_id\n",
|
| 115 |
+
")\n",
|
| 116 |
+
"\n",
|
| 117 |
+
"# 5. 解码并打印结果\n",
|
| 118 |
+
"generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
|
| 119 |
+
"print(\"模型输出:\", generated_text)"
|
| 120 |
+
]
|
| 121 |
+
}
|
| 122 |
+
],
|
| 123 |
+
"metadata": {
|
| 124 |
+
"kernelspec": {
|
| 125 |
+
"display_name": ".venv",
|
| 126 |
+
"language": "python",
|
| 127 |
+
"name": "python3"
|
| 128 |
+
},
|
| 129 |
+
"language_info": {
|
| 130 |
+
"codemirror_mode": {
|
| 131 |
+
"name": "ipython",
|
| 132 |
+
"version": 3
|
| 133 |
+
},
|
| 134 |
+
"file_extension": ".py",
|
| 135 |
+
"mimetype": "text/x-python",
|
| 136 |
+
"name": "python",
|
| 137 |
+
"nbconvert_exporter": "python",
|
| 138 |
+
"pygments_lexer": "ipython3",
|
| 139 |
+
"version": "3.11.3"
|
| 140 |
+
}
|
| 141 |
+
},
|
| 142 |
+
"nbformat": 4,
|
| 143 |
+
"nbformat_minor": 5
|
| 144 |
+
}
|
code/threads_data_extract.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import openai
|
| 3 |
+
import threading
|
| 4 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 5 |
+
from openai import APIError
|
| 6 |
+
from typing import List, Tuple
|
| 7 |
+
|
| 8 |
+
class FormatValidator:
    """Validates and normalises one output record: "kw1,kw2,kw3:original"."""

    @staticmethod
    def validate_line(keywords: List[str], original: str) -> str:
        """Build a single output line "kw1,kw2,kw3:original".

        Keywords are stripped, scrubbed of separator characters, truncated to
        10 characters each and capped at three; the original text is flattened
        to one line. An empty keyword list yields the "无关键词" placeholder.
        """
        cleaned_keywords = []
        for kw in keywords:
            kw = kw.strip()
            if not kw:
                continue
            # Scrub BOTH the full-width colon (:) and the ASCII colon (:):
            # downstream parsers (e.g. deepseek_vaule) split on the FIRST
            # colon of either kind, so a colon inside a keyword would corrupt
            # the record. Also drop newlines and cap each keyword at 10 chars.
            kw = kw.replace(':', '').replace(':', '').replace('\n', '')[:10]
            if kw:  # skip keywords that became empty after scrubbing
                cleaned_keywords.append(kw)
            if len(cleaned_keywords) == 3:  # keep at most the first 3 keywords
                break

        keywords_str = ",".join(cleaned_keywords) if cleaned_keywords else "无关键词"

        # Flatten the original text onto one line.
        cleaned_original = original.strip().replace('\n', ' ')
        return f"{keywords_str}:{cleaned_original}"
|
| 30 |
+
|
| 31 |
+
class ThreadSafeWriter:
    """Lock-guarded line writer shared by all worker threads."""

    def __init__(self, output_path: str):
        # Append mode lets an interrupted run be resumed without data loss.
        self.file = open(output_path, 'a+', encoding='utf-8')
        self.lock = threading.Lock()
        self.counter = 0  # lines successfully written

    def write_line(self, content: str):
        """Append one line and flush so output is durable on crash."""
        with self.lock:
            self.file.write(f"{content}\n")
            self.file.flush()
            self.counter = self.counter + 1

    def get_progress(self):
        """Return the current written-line count."""
        with self.lock:
            return self.counter

    def close(self):
        """Close the underlying file handle."""
        self.file.close()
|
| 50 |
+
|
| 51 |
+
class DeepSeekBatchProcessor:
    """Extracts keywords for each input line via the DeepSeek chat API, in parallel."""

    def __init__(self, max_workers: int = 100):
        # SECURITY FIX: a literal API key was previously committed here as the
        # getenv fallback. Keys must come from the environment only; that key
        # should also be revoked, since it is in the repository history.
        self.client = openai.OpenAI(
            api_key=os.getenv("DEEPSEEK_API_KEY", ""),
            base_url="https://api.deepseek.com"
        )
        self.max_workers = max_workers
        # Set once on a fatal error (e.g. HTTP 402) so all workers bail out.
        self.error_flag = threading.Event()
        # At most 20 requests in flight at once, independent of pool size.
        self.rate_limiter = threading.Semaphore(20)

    def process_batch(self, batch: List[Tuple[int, str]], writer: ThreadSafeWriter):
        """Run one batch on a thread pool; blocks until every line is finished."""
        futures = []
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            for line_num, original in batch:
                if self.error_flag.is_set():
                    break
                futures.append(
                    executor.submit(
                        self._process_single_line,
                        line_num,
                        original,
                        writer
                    )
                )

            # Wait for the current batch to complete.
            for future in futures:
                future.result()  # re-raise any worker exception

    def _process_single_line(self, line_num: int, original: str, writer: ThreadSafeWriter):
        """Ask the model for this line's keywords, then write the formatted record."""
        if self.error_flag.is_set():
            return

        retries = 0
        while retries < 3 and not self.error_flag.is_set():
            try:
                with self.rate_limiter:
                    response = self.client.chat.completions.create(
                        model="deepseek-reasoner",
                        messages=[
                            {"role": "system", "content": self._get_prompt()},
                            {"role": "user", "content": original}
                        ],
                        temperature=0.1,
                        max_tokens=30  # keywords only; keep the reply short
                    )

                # Parse the reply into keywords and emit the formatted line.
                keywords = self._parse_response(response)
                formatted_line = FormatValidator.validate_line(keywords, original)
                writer.write_line(formatted_line)

                # Progress indicator (overwrites the same console line).
                progress = writer.get_progress()
                print(f"\r已处理 {progress} 条", end='')

                break  # success: leave the retry loop

            except APIError as e:
                if e.status_code == 402:  # insufficient balance: fatal for the run
                    print(f"\n行 {line_num} 处理失败:API余额不足")
                    self.error_flag.set()
                    return
                elif e.status_code == 429:  # rate limited: retry
                    print(f"\n行 {line_num} 速率受限,重试中...")
                    retries += 1
                    if retries >= 3:
                        print(f"行 {line_num} 重试次数耗尽")
                else:
                    print(f"\n行 {line_num} API错误[{e.status_code}]:{e.message}")
                    return  # other API errors are not retried

            except Exception as e:
                # Network/parse errors etc. count against the retry budget.
                print(f"\n行 {line_num} 处理异常:{str(e)}")
                retries += 1
                if retries >= 3:
                    print(f"行 {line_num} 重试次数耗尽")

        # Retry budget exhausted: record the failed line so it can be redone.
        if retries >= 3 and not self.error_flag.is_set():
            writer.write_line(f"处理失败:{original}")

    @staticmethod
    def _get_prompt() -> str:
        # BUG FIX: the original body was a bare `return`, so the system message
        # content was None and every request would fail. The prompt below asks
        # for comma-separated keywords, matching _parse_response's splitting.
        return (
            "请从用户提供的文本中提取最能概括其主题的关键词,最多3个,"
            "直接输出关键词本身,用中文逗号分隔,不要输出任何其他内容。"
        )

    @staticmethod
    def _parse_response(response) -> List[str]:
        # Normalise ASCII commas to full-width, split, and trim trailing punctuation.
        content = response.choices[0].message.content.strip()
        return [kw.strip("。、") for kw in content.replace(',', ',').split(',') if kw]
|
| 142 |
+
|
| 143 |
+
def process_large_file(
    input_path: str,
    output_path: str,
    batch_size: int = 500,
    max_workers: int = 100
):
    """Entry point: stream a large input file through the batch processor.

    Non-empty lines are grouped into fixed-size batches of (line_no, text);
    batches run sequentially while lines within a batch run concurrently.
    """
    processor = DeepSeekBatchProcessor(max_workers)
    writer = ThreadSafeWriter(output_path)

    try:
        # Collect non-empty lines into batches, tagged with their line numbers.
        batches = []
        current_batch = []
        with open(input_path, 'r', encoding='utf-8') as f:
            for line_num, raw in enumerate(f, 1):
                text = raw.strip()
                if not text:
                    continue
                current_batch.append((line_num, text))
                if len(current_batch) >= batch_size:
                    batches.append(current_batch)
                    current_batch = []
        if current_batch:  # trailing partial batch
            batches.append(current_batch)

        print(f"总数据量:{sum(len(b) for b in batches)}条")

        # Process batch by batch, stopping early on a fatal error.
        for batch in batches:
            if processor.error_flag.is_set():
                break
            processor.process_batch(batch, writer)

        print("\n处理完成!")

    finally:
        # Always release the output handle, even on failure.
        writer.close()
|
| 182 |
+
|
| 183 |
+
if __name__ == '__main__':
    # File path configuration.
    # FIX: the originals used backslash paths ("data\DSdata.txt"), which rely
    # on "\D"/"\C" not being valid escapes (a SyntaxWarning on modern Python)
    # and break outright on POSIX systems. Forward slashes work everywhere and
    # match the sibling script deepseek_vaule.py.
    input_file = "data/DSdata.txt"
    output_file = "data/CoTdata.txt"

    # Kick off the processing pipeline.
    process_large_file(
        input_path=input_file,
        output_path=output_file,
        batch_size=500,
        max_workers=100
    )
|
requirements.txt
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
PyQt5>=5.15
|
| 2 |
+
transformers>=4.30
|
| 3 |
+
peft>=0.15
|
| 4 |
+
torch>=2.0
|
| 5 |
+
numpy
|
| 6 |
+
matplotlib
|
| 7 |
+
jupyter
|
| 8 |
+
trl
|
| 9 |
+
datasets
|
| 10 |
+
accelerate
|
| 11 |
+
safetensors
|
| 12 |
+
scipy
|
| 13 |
+
tqdm
|
| 14 |
+
tensorboard
|
| 15 |
+
sentence-transformers
|
| 16 |
+
openai
|