Billy Lin committed on
Commit ·
97a5393
1
Parent(s): de88dca
text-emotion-classification
Browse files- .gitignore +46 -0
- LICENSE +201 -0
- README.md +59 -3
- README_zh.md +57 -0
- emotion-classification-train.csv +0 -0
- environment.yml +185 -0
- main.py +96 -0
- sentiment_roberta/config.json +56 -0
- sentiment_roberta/model.safetensors +3 -0
- sentiment_roberta/tokenizer.json +0 -0
- sentiment_roberta/tokenizer_config.json +14 -0
- text-emotion-classification.py +66 -0
- text-emotion.yaml +8 -0
- train-data-preload.py +66 -0
.gitignore
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
.venv/
|
| 6 |
+
venv/
|
| 7 |
+
ENV/
|
| 8 |
+
.env
|
| 9 |
+
.env.*
|
| 10 |
+
|
| 11 |
+
# Conda
|
| 12 |
+
.conda/
|
| 13 |
+
|
| 14 |
+
# Jupyter
|
| 15 |
+
.ipynb_checkpoints/
|
| 16 |
+
|
| 17 |
+
# OS / IDE
|
| 18 |
+
.DS_Store
|
| 19 |
+
Thumbs.db
|
| 20 |
+
.vscode/
|
| 21 |
+
.idea/
|
| 22 |
+
.cursor/
|
| 23 |
+
|
| 24 |
+
# Logs
|
| 25 |
+
*.log
|
| 26 |
+
|
| 27 |
+
# Build / packaging
|
| 28 |
+
build/
|
| 29 |
+
dist/
|
| 30 |
+
*.spec
|
| 31 |
+
|
| 32 |
+
# Hugging Face / Transformers caches
|
| 33 |
+
.cache/
|
| 34 |
+
|
| 35 |
+
# Model training artifacts (NOT needed for inference)
|
| 36 |
+
sentiment_roberta/checkpoint-*/
|
| 37 |
+
sentiment_roberta/**/optimizer.pt
|
| 38 |
+
sentiment_roberta/**/scheduler.pt
|
| 39 |
+
sentiment_roberta/**/scaler.pt
|
| 40 |
+
sentiment_roberta/**/rng_state.pth
|
| 41 |
+
sentiment_roberta/**/trainer_state.json
|
| 42 |
+
sentiment_roberta/**/training_args.bin
|
| 43 |
+
sentiment_roberta/training_args.bin
|
| 44 |
+
|
| 45 |
+
# Optional: keep only inference model files in sentiment_roberta/
|
| 46 |
+
# (Do NOT ignore model.safetensors/config/tokenizer* by default)
|
LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
README.md
CHANGED
|
@@ -1,3 +1,59 @@
|
|
| 1 |
-
--
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# text-emotion-classification
|
| 2 |
+
|
| 3 |
+
A text emotion recognition application that can be quickly deployed and used locally. You can perform interactive inference simply by running `main.py`.
|
| 4 |
+
|
| 5 |
+
[中文版](./README_zh.md)
|
| 6 |
+
|
| 7 |
+
## Features
|
| 8 |
+
|
| 9 |
+
- **Local Inference**: Loads the `sentiment_roberta` model directory within the repository for text emotion classification.
|
| 10 |
+
- **Label Mapping**: Reads `id -> Chinese Emotion Name` mapping from `text-emotion.yaml`.
|
| 11 |
+
- **Interactive CLI**: Enter text in the command line to output the emotion category and confidence level.
|
| 12 |
+
|
| 13 |
+
## Directory Structure (Key Files)
|
| 14 |
+
|
| 15 |
+
- `main.py`: Entry script (run directly).
|
| 16 |
+
- `sentiment_roberta/`: Exported Transformers model directory (contains `config.json`, `model.safetensors`, tokenizer, etc.).
|
| 17 |
+
- `text-emotion.yaml`: Label mapping file.
|
| 18 |
+
- `release-note.md`: Release notes (used by GitHub Actions as the release body).
|
| 19 |
+
|
| 20 |
+
## Environment Requirements
|
| 21 |
+
|
| 22 |
+
- **Python 3.10** (Recommended, matches the author's environment; 3.9+ is theoretically compatible but not fully verified).
|
| 23 |
+
- **Dependency Management**: Conda environment (recommended) or venv.
|
| 24 |
+
- **PyTorch**:
|
| 25 |
+
- **CPU Inference**: Install the CPU version of `torch`.
|
| 26 |
+
- **GPU Inference**: Requires an NVIDIA GPU + corresponding CUDA version (the author's environment uses `torch==2.10.0+cu128` / `torchvision==0.25.0+cu128` built with CUDA 12.8).
|
| 27 |
+
|
| 28 |
+
The author's conda environment export file is provided: `environment.yml`.
|
| 29 |
+
|
| 30 |
+
## Installation
|
| 31 |
+
|
| 32 |
+
### Using Conda Environment File (Recommended)
|
| 33 |
+
|
| 34 |
+
```bash
|
| 35 |
+
conda env create -f environment.yml
|
| 36 |
+
conda activate text-emotion-classification
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
## Usage
|
| 40 |
+
|
| 41 |
+
```bash
|
| 42 |
+
python main.py
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
Follow the prompts to enter text:
|
| 46 |
+
- **Enter any text**: Outputs emotion prediction and confidence.
|
| 47 |
+
- **Empty input (Enter)**: Exits the program.
|
| 48 |
+
|
| 49 |
+
## FAQ
|
| 50 |
+
|
| 51 |
+
- **Cannot find model directory `sentiment_roberta`**
|
| 52 |
+
- Ensure `sentiment_roberta/` exists in the root directory and contains files like `config.json` and `model.safetensors`.
|
| 53 |
+
- **Inference Device**
|
| 54 |
+
- The program automatically selects `cuda` if available; otherwise, it defaults to `cpu`.
|
| 55 |
+
|
| 56 |
+
## License
|
| 57 |
+
|
| 58 |
+
See [Apache 2.0 License](./LICENSE).
|
| 59 |
+
|
README_zh.md
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# text-emotion-classification
|
| 2 |
+
|
| 3 |
+
一个可本地快捷部署使用的文本情绪识别模型应用。项目直接运行 `main.py` 即可进行交互式推理。
|
| 4 |
+
|
| 5 |
+
## 功能
|
| 6 |
+
|
| 7 |
+
- **本地推理**:加载仓库内的 `sentiment_roberta` 模型目录进行文本情绪分类。
|
| 8 |
+
- **标签映射**:从 `text-emotion.yaml` 读取 `id -> 情绪中文名` 映射。
|
| 9 |
+
- **交互式 CLI**:命令行输入文本,输出情绪类别与置信度。
|
| 10 |
+
|
| 11 |
+
## 目录结构(关键文件)
|
| 12 |
+
|
| 13 |
+
- `main.py`:入口脚本(直接运行即可)
|
| 14 |
+
- `sentiment_roberta/`:已导出的 Transformers 模型目录(包含 `config.json`、`model.safetensors`、tokenizer 等)
|
| 15 |
+
- `text-emotion.yaml`:标签映射文件
|
| 16 |
+
- `release-note.md`:Release 说明(由 GitHub Action 用作 release body)
|
| 17 |
+
|
| 18 |
+
## 运行环境配置要求
|
| 19 |
+
|
| 20 |
+
- Python 3.10(推荐,与作者环境一致;3.9+ 理论可用但未完整验证)
|
| 21 |
+
- 依赖管理方式:Conda 环境(推荐)或 venv
|
| 22 |
+
- PyTorch:
|
| 23 |
+
- CPU 推理:安装 CPU 版 `torch`
|
| 24 |
+
- GPU 推理:需要 NVIDIA GPU + 对应版本 CUDA(本仓库作者环境为 `torch==2.10.0+cu128` / `torchvision==0.25.0+cu128`,即 CUDA 12.8 构建)
|
| 25 |
+
|
| 26 |
+
本仓库提供了作者的 conda 环境导出文件:`environment.yml`。
|
| 27 |
+
|
| 28 |
+
## 安装
|
| 29 |
+
|
| 30 |
+
### 使用 conda 环境文件(推荐,复现作者环境)
|
| 31 |
+
|
| 32 |
+
```bash
|
| 33 |
+
conda env create -f environment.yml
|
| 34 |
+
conda activate text-emotion-classification
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
## 运行
|
| 38 |
+
|
| 39 |
+
```bash
|
| 40 |
+
python main.py
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
运行后按提示输入文本:
|
| 44 |
+
|
| 45 |
+
- 输入任意文本并回车:输出情绪预测与置信度
|
| 46 |
+
- 直接回车:退出
|
| 47 |
+
|
| 48 |
+
## 常见问题
|
| 49 |
+
|
| 50 |
+
- **找不到模型目录 `sentiment_roberta`**
|
| 51 |
+
- 请确认仓库根目录下存在 `sentiment_roberta/`,且其中包含 `config.json`、`model.safetensors` 等文件。
|
| 52 |
+
- **模型推理设备**
|
| 53 |
+
- 程序会自动选择 `cuda`(如可用)否则使用 `cpu`。
|
| 54 |
+
|
| 55 |
+
## License
|
| 56 |
+
|
| 57 |
+
见 [Apache 2.0 License](./LICENSE)。
|
emotion-classification-train.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
environment.yml
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: text-emotion-classification
|
| 2 |
+
channels:
|
| 3 |
+
- conda-forge
|
| 4 |
+
- defaults
|
| 5 |
+
dependencies:
|
| 6 |
+
- asttokens=3.0.1
|
| 7 |
+
- blas=1.0
|
| 8 |
+
- boto3=1.42.34
|
| 9 |
+
- botocore=1.42.34
|
| 10 |
+
- bottleneck=1.4.2
|
| 11 |
+
- brotlicffi=1.2.0.0
|
| 12 |
+
- bzip2=1.0.8
|
| 13 |
+
- ca-certificates=2026.1.4
|
| 14 |
+
- cairo=1.18.4
|
| 15 |
+
- certifi=2026.01.04
|
| 16 |
+
- cffi=2.0.0
|
| 17 |
+
- charset-normalizer=3.4.4
|
| 18 |
+
- colorama=0.4.6
|
| 19 |
+
- comm=0.2.3
|
| 20 |
+
- contourpy=1.3.1
|
| 21 |
+
- cycler=0.12.1
|
| 22 |
+
- debugpy=1.8.20
|
| 23 |
+
- decorator=5.2.1
|
| 24 |
+
- exceptiongroup=1.3.1
|
| 25 |
+
- executing=2.2.1
|
| 26 |
+
- expat=2.7.4
|
| 27 |
+
- fontconfig=2.15.0
|
| 28 |
+
- fonttools=4.61.0
|
| 29 |
+
- freetype=2.14.1
|
| 30 |
+
- graphite2=1.3.14
|
| 31 |
+
- harfbuzz=12.3.0
|
| 32 |
+
- icc_rt=2022.1.0
|
| 33 |
+
- icu=73.1
|
| 34 |
+
- idna=3.11
|
| 35 |
+
- intel-openmp=2025.0.0
|
| 36 |
+
- ipykernel=7.2.0
|
| 37 |
+
- ipython=8.37.0
|
| 38 |
+
- jedi=0.19.2
|
| 39 |
+
- jmespath=1.1.0
|
| 40 |
+
- joblib=1.5.3
|
| 41 |
+
- jpeg=9f
|
| 42 |
+
- jupyter_client=8.8.0
|
| 43 |
+
- jupyter_core=5.9.1
|
| 44 |
+
- kiwisolver=1.4.9
|
| 45 |
+
- krb5=1.21.3
|
| 46 |
+
- lcms2=2.16
|
| 47 |
+
- lerc=3.0
|
| 48 |
+
- libdeflate=1.17
|
| 49 |
+
- libexpat=2.7.4
|
| 50 |
+
- libffi=3.4.4
|
| 51 |
+
- libglib=2.86.3
|
| 52 |
+
- libhwloc=2.12.1
|
| 53 |
+
- libiconv=1.16
|
| 54 |
+
- libkrb5=1.22.1
|
| 55 |
+
- libpng=1.6.54
|
| 56 |
+
- libpq=17.6
|
| 57 |
+
- libsodium=1.0.20
|
| 58 |
+
- libtiff=4.5.1
|
| 59 |
+
- libwebp-base=1.6.0
|
| 60 |
+
- libxml2=2.13.9
|
| 61 |
+
- libzlib=1.3.1
|
| 62 |
+
- lz4-c=1.9.4
|
| 63 |
+
- matplotlib=3.10.8
|
| 64 |
+
- matplotlib-base=3.10.8
|
| 65 |
+
- matplotlib-inline=0.2.1
|
| 66 |
+
- mkl=2025.0.0
|
| 67 |
+
- mkl-service=2.5.2
|
| 68 |
+
- mkl_fft=2.1.1
|
| 69 |
+
- mkl_random=1.3.0
|
| 70 |
+
- mysql-common=9.3.0
|
| 71 |
+
- mysql-libs=9.3.0
|
| 72 |
+
- nest-asyncio=1.6.0
|
| 73 |
+
- numexpr=2.14.1
|
| 74 |
+
- numpy-base=2.2.5
|
| 75 |
+
- openjpeg=2.5.2
|
| 76 |
+
- openssl=3.6.1
|
| 77 |
+
- packaging=25.0
|
| 78 |
+
- pandas=2.3.3
|
| 79 |
+
- parso=0.8.6
|
| 80 |
+
- pcre2=10.46
|
| 81 |
+
- pickleshare=0.7.5
|
| 82 |
+
- pip=26.0.1
|
| 83 |
+
- pixman=0.46.4
|
| 84 |
+
- platformdirs=4.5.1
|
| 85 |
+
- prompt-toolkit=3.0.52
|
| 86 |
+
- psutil=7.2.2
|
| 87 |
+
- pure_eval=0.2.3
|
| 88 |
+
- pycparser=2.23
|
| 89 |
+
- pygments=2.19.2
|
| 90 |
+
- pyparsing=3.2.5
|
| 91 |
+
- pyqt=6.9.1
|
| 92 |
+
- pyqt6-sip=13.10.2
|
| 93 |
+
- pysocks=1.7.1
|
| 94 |
+
- python=3.10.19
|
| 95 |
+
- python-dateutil=2.9.0post0
|
| 96 |
+
- python-tzdata=2025.3
|
| 97 |
+
- python_abi=3.10
|
| 98 |
+
- pytz=2025.2
|
| 99 |
+
- pywin32=311
|
| 100 |
+
- pyzmq=27.1.0
|
| 101 |
+
- qtbase=6.9.2
|
| 102 |
+
- qtdeclarative=6.9.2
|
| 103 |
+
- qtsvg=6.9.2
|
| 104 |
+
- qttools=6.9.2
|
| 105 |
+
- qtwebchannel=6.9.2
|
| 106 |
+
- qtwebsockets=6.9.2
|
| 107 |
+
- regex=2025.11.3
|
| 108 |
+
- requests=2.32.5
|
| 109 |
+
- s3transfer=0.16.0
|
| 110 |
+
- sacremoses=0.1.1
|
| 111 |
+
- scikit-learn=1.7.1
|
| 112 |
+
- scipy=1.15.3
|
| 113 |
+
- setuptools=80.10.2
|
| 114 |
+
- sip=6.12.0
|
| 115 |
+
- six=1.17.0
|
| 116 |
+
- sqlite=3.51.1
|
| 117 |
+
- stack_data=0.6.3
|
| 118 |
+
- tbb=2022.3.0
|
| 119 |
+
- tbb-devel=2022.3.0
|
| 120 |
+
- threadpoolctl=3.5.0
|
| 121 |
+
- tk=8.6.15
|
| 122 |
+
- tomli=2.4.0
|
| 123 |
+
- tornado=6.5.4
|
| 124 |
+
- tqdm=4.67.3
|
| 125 |
+
- traitlets=5.14.3
|
| 126 |
+
- typing_extensions=4.15.0
|
| 127 |
+
- tzdata=2025c
|
| 128 |
+
- ucrt=10.0.22621.0
|
| 129 |
+
- urllib3=2.6.3
|
| 130 |
+
- vc=14.3
|
| 131 |
+
- vc14_runtime=14.44.35208
|
| 132 |
+
- vs2015_runtime=14.44.35208
|
| 133 |
+
- wcwidth=0.6.0
|
| 134 |
+
- wheel=0.46.3
|
| 135 |
+
- win_inet_pton=1.1.0
|
| 136 |
+
- xz=5.6.4
|
| 137 |
+
- zeromq=4.3.5
|
| 138 |
+
- zlib=1.3.1
|
| 139 |
+
- zstd=1.5.7
|
| 140 |
+
- pip:
|
| 141 |
+
- accelerate==1.12.0
|
| 142 |
+
- aiohappyeyeballs==2.6.1
|
| 143 |
+
- aiohttp==3.13.3
|
| 144 |
+
- aiosignal==1.4.0
|
| 145 |
+
- annotated-doc==0.0.4
|
| 146 |
+
- anyio==4.12.1
|
| 147 |
+
- async-timeout==5.0.1
|
| 148 |
+
- attrs==25.4.0
|
| 149 |
+
- click==8.3.1
|
| 150 |
+
- datasets==4.5.0
|
| 151 |
+
- dill==0.4.0
|
| 152 |
+
- filelock==3.20.0
|
| 153 |
+
- frozenlist==1.8.0
|
| 154 |
+
- fsspec==2025.10.0
|
| 155 |
+
- h11==0.16.0
|
| 156 |
+
- hf-xet==1.2.0
|
| 157 |
+
- httpcore==1.0.9
|
| 158 |
+
- httpx==0.28.1
|
| 159 |
+
- huggingface-hub==1.4.1
|
| 160 |
+
- jinja2==3.1.6
|
| 161 |
+
- markdown-it-py==4.0.0
|
| 162 |
+
- markupsafe==2.1.5
|
| 163 |
+
- mdurl==0.1.2
|
| 164 |
+
- mpmath==1.3.0
|
| 165 |
+
- multidict==6.7.1
|
| 166 |
+
- multiprocess==0.70.18
|
| 167 |
+
- networkx==3.4.2
|
| 168 |
+
- numpy==2.2.6
|
| 169 |
+
- pillow==12.0.0
|
| 170 |
+
- propcache==0.4.1
|
| 171 |
+
- pyarrow==23.0.0
|
| 172 |
+
- pyyaml==6.0.3
|
| 173 |
+
- rich==14.3.2
|
| 174 |
+
- safetensors==0.7.0
|
| 175 |
+
- shellingham==1.5.4
|
| 176 |
+
- sympy==1.14.0
|
| 177 |
+
- tokenizers==0.22.2
|
| 178 |
+
- torch==2.10.0+cu128
|
| 179 |
+
- torchvision==0.25.0+cu128
|
| 180 |
+
- transformers==5.1.0
|
| 181 |
+
- typer==0.23.1
|
| 182 |
+
- typer-slim==0.23.1
|
| 183 |
+
- xxhash==3.6.0
|
| 184 |
+
- yarl==1.22.0
|
| 185 |
+
|
main.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import yaml
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def load_label_map(yaml_path: str):
    """Load an ``id -> emotion name`` mapping from a YAML file.

    Accepted layouts:
      * a mapping:      ``0: 伤心``
      * a list of one-entry mappings: ``- {0: 伤心}``
      * a list of ``"id: name"`` strings: ``- 0: 伤心`` parsed as text

    Raises ValueError when the document has none of these shapes or the
    resulting mapping is empty.
    """
    with open(yaml_path, "r", encoding="utf-8") as handle:
        parsed = yaml.safe_load(handle)

    mapping = {}
    if isinstance(parsed, dict):
        mapping = {int(key): str(value) for key, value in parsed.items()}
    elif isinstance(parsed, list):
        for entry in parsed:
            if isinstance(entry, dict):
                mapping.update({int(key): str(value) for key, value in entry.items()})
            elif isinstance(entry, str) and ":" in entry:
                # Plain "id: name" string — split on the first colon only.
                key, _, value = entry.partition(":")
                mapping[int(key.strip())] = value.strip()
    else:
        raise ValueError(f"无法解析标签映射:{yaml_path}")

    if not mapping:
        raise ValueError(f"标签映射为空:{yaml_path}")

    return mapping
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def predict(text: str, tokenizer, model, device: torch.device):
    """Classify *text* with *model* on *device*.

    Returns a tuple ``(pred_id, confidence, probs)`` where ``pred_id`` is the
    argmax class index, ``confidence`` its softmax probability, and ``probs``
    the full probability vector as a 1-D numpy array.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    model.eval()
    with torch.no_grad():
        logits = model(**encoded).logits
        # Softmax over the class dimension; move to CPU and drop the batch axis.
        probs = logits.softmax(dim=-1).detach().cpu().numpy()[0]

    best = int(probs.argmax())
    return best, float(probs[best]), probs
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def main():
    """Interactive CLI entry point.

    Loads the local ``sentiment_roberta`` model and the label map from
    ``text-emotion.yaml`` (both resolved relative to this file), then reads
    lines from stdin and prints the predicted emotion with its confidence.
    An empty line, EOF, or Ctrl-C exits.
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))
    model_dir = os.path.join(base_dir, "sentiment_roberta")
    yaml_path = os.path.join(base_dir, "text-emotion.yaml")

    # Fail fast with a hint when required artifacts are missing.
    if not os.path.isdir(model_dir):
        print(f"找不到模型目录:{model_dir}")
        print("请先训练并确保训练脚本 output_dir=./sentiment_roberta(相对 data_preload 目录)。")
        sys.exit(1)
    if not os.path.isfile(yaml_path):
        print(f"找不到标签映射文件:{yaml_path}")
        sys.exit(1)

    label_map = load_label_map(yaml_path)

    # Prefer GPU when available; fall back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"推理设备:{device}")

    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    model.to(device)

    print("请输入一段文本(直接回车退出):")
    while True:
        try:
            text = input("> ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\n退出")
            break

        if not text:
            print("退出")
            break

        pred_id, conf, _ = predict(text, tokenizer, model, device)
        label_cn = label_map.get(pred_id, f"未知标签({pred_id})")

        print(f"情绪预测:{label_cn}")
        print(f"置信度:{conf:.4f}")


if __name__ == "__main__":
    main()
main()
|
| 96 |
+
|
sentiment_roberta/config.json
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_cross_attention": false,
|
| 3 |
+
"architectures": [
|
| 4 |
+
"BertForSequenceClassification"
|
| 5 |
+
],
|
| 6 |
+
"attention_probs_dropout_prob": 0.1,
|
| 7 |
+
"bos_token_id": 0,
|
| 8 |
+
"classifier_dropout": null,
|
| 9 |
+
"directionality": "bidi",
|
| 10 |
+
"dtype": "float32",
|
| 11 |
+
"eos_token_id": 2,
|
| 12 |
+
"hidden_act": "gelu",
|
| 13 |
+
"hidden_dropout_prob": 0.1,
|
| 14 |
+
"hidden_size": 768,
|
| 15 |
+
"id2label": {
|
| 16 |
+
"0": "LABEL_0",
|
| 17 |
+
"1": "LABEL_1",
|
| 18 |
+
"2": "LABEL_2",
|
| 19 |
+
"3": "LABEL_3",
|
| 20 |
+
"4": "LABEL_4",
|
| 21 |
+
"5": "LABEL_5",
|
| 22 |
+
"6": "LABEL_6",
|
| 23 |
+
"7": "LABEL_7"
|
| 24 |
+
},
|
| 25 |
+
"initializer_range": 0.02,
|
| 26 |
+
"intermediate_size": 3072,
|
| 27 |
+
"is_decoder": false,
|
| 28 |
+
"label2id": {
|
| 29 |
+
"LABEL_0": 0,
|
| 30 |
+
"LABEL_1": 1,
|
| 31 |
+
"LABEL_2": 2,
|
| 32 |
+
"LABEL_3": 3,
|
| 33 |
+
"LABEL_4": 4,
|
| 34 |
+
"LABEL_5": 5,
|
| 35 |
+
"LABEL_6": 6,
|
| 36 |
+
"LABEL_7": 7
|
| 37 |
+
},
|
| 38 |
+
"layer_norm_eps": 1e-12,
|
| 39 |
+
"max_position_embeddings": 512,
|
| 40 |
+
"model_type": "bert",
|
| 41 |
+
"num_attention_heads": 12,
|
| 42 |
+
"num_hidden_layers": 12,
|
| 43 |
+
"output_past": true,
|
| 44 |
+
"pad_token_id": 0,
|
| 45 |
+
"pooler_fc_size": 768,
|
| 46 |
+
"pooler_num_attention_heads": 12,
|
| 47 |
+
"pooler_num_fc_layers": 3,
|
| 48 |
+
"pooler_size_per_head": 128,
|
| 49 |
+
"pooler_type": "first_token_transform",
|
| 50 |
+
"problem_type": "single_label_classification",
|
| 51 |
+
"tie_word_embeddings": true,
|
| 52 |
+
"transformers_version": "5.1.0",
|
| 53 |
+
"type_vocab_size": 2,
|
| 54 |
+
"use_cache": false,
|
| 55 |
+
"vocab_size": 21128
|
| 56 |
+
}
|
sentiment_roberta/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9a64010a9f27db8eab2ef283add822ec36106abc778399f2bd8dfa5c1d2f189e
|
| 3 |
+
size 409118672
|
sentiment_roberta/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
sentiment_roberta/tokenizer_config.json
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"backend": "tokenizers",
|
| 3 |
+
"cls_token": "[CLS]",
|
| 4 |
+
"do_lower_case": false,
|
| 5 |
+
"is_local": false,
|
| 6 |
+
"mask_token": "[MASK]",
|
| 7 |
+
"model_max_length": 1000000000000000019884624838656,
|
| 8 |
+
"pad_token": "[PAD]",
|
| 9 |
+
"sep_token": "[SEP]",
|
| 10 |
+
"strip_accents": null,
|
| 11 |
+
"tokenize_chinese_chars": true,
|
| 12 |
+
"tokenizer_class": "BertTokenizer",
|
| 13 |
+
"unk_token": "[UNK]"
|
| 14 |
+
}
|
text-emotion-classification.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)
|
| 4 |
+
|
| 5 |
+
# 1) Load the Chinese RoBERTa tokenizer and a sequence-classification head.
model_name = "hfl/chinese-roberta-wwm-ext"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Eight emotion classes (see text-emotion.yaml for the id -> name mapping).
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=8)

# 2) Load the emotion dataset.
# The CSV must contain at least two columns: text (string) and label (int).
# BUG FIX: the original config pointed both "train" and "test" at the SAME
# file, so evaluation (and best-model selection below) was done on the
# training data. Hold out 10% of rows as a genuine evaluation split instead;
# the fixed seed keeps the split reproducible across runs.
raw = load_dataset("csv", data_files={"train": "emotion-classification-train.csv"})
dataset = raw["train"].train_test_split(test_size=0.1, seed=42)
|
| 21 |
+
|
| 22 |
+
def preprocess(batch):
    """Tokenize a batch of texts to fixed-length (128) padded/truncated inputs."""
    return tokenizer(batch["text"], truncation=True, padding="max_length", max_length=128)


dataset = dataset.map(preprocess, batched=True)

# The Transformers Trainer expects the target column to be named "label".
train_columns = dataset["train"].column_names
if "labels" in train_columns and "label" not in train_columns:
    dataset = dataset.rename_column("labels", "label")
|
| 30 |
+
|
| 31 |
+
def compute_metrics(eval_pred):
    """Compute accuracy for the Trainer's evaluation loop.

    Args:
        eval_pred: a ``(logits, labels)`` pair of numpy arrays, where logits
            has shape (batch, num_labels) and labels shape (batch,).

    Returns:
        dict with a single ``"accuracy"`` float in [0, 1].
    """
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    # (preds == labels).mean() is always a numpy scalar, so a plain float()
    # conversion suffices; the original hasattr(..., "item") fallback dance
    # was redundant.
    return {"accuracy": float((preds == labels).mean())}
|
| 36 |
+
|
| 37 |
+
# 3) Training configuration; keep the checkpoint with the best eval accuracy.
training_args = TrainingArguments(
    output_dir="./sentiment_roberta",
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    greater_is_better=True,
    save_total_limit=2,
    # BUG FIX: an unconditional fp16=True raises a ValueError on CPU-only
    # machines; enable mixed precision only when a CUDA device is available.
    fp16=torch.cuda.is_available(),
)
|
| 52 |
+
|
| 53 |
+
# 4) Fine-tune.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    compute_metrics=compute_metrics,
)
trainer.train()

# 5) Persist the (best) model weights and the tokenizer to output_dir so the
# inference script can reload them with from_pretrained().
trainer.save_model(training_args.output_dir)
tokenizer.save_pretrained(training_args.output_dir)
|
text-emotion.yaml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
- 0: 伤心
|
| 2 |
+
- 1: 生气
|
| 3 |
+
- 2: 关心
|
| 4 |
+
- 3: 惊讶
|
| 5 |
+
- 4: 开心
|
| 6 |
+
- 5: 平静
|
| 7 |
+
- 6: 厌恶
|
| 8 |
+
- 7: 恐惧
|
train-data-preload.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd

# Source 1: Simplified-Chinese multi-emotion dialogue dataset (hf:// requires
# the huggingface_hub fsspec backend).
df1 = pd.read_csv(
    "hf://datasets/zzhdbw/Simplified_Chinese_Multi-Emotion_Dialogue_Dataset/Simplified_Chinese_Multi-Emotion_Dialogue_Dataset.csv"
)

# Source 2: short-text multi-labeled emotion classification dataset.
df2 = pd.read_csv(
    "hf://datasets/jakeazcona/short-text-multi-labeled-emotion-classification/FINALDATA.csv"
)
|
| 10 |
+
|
| 11 |
+
# df1: Chinese emotion name -> unified label id.
# 伤心:0, 生气:1, 关心:2, 惊讶:3, 开心:4, 平静:5, 厌恶:6
# (any other emotion present in df1 maps to NA and is dropped downstream)
DF1_LABEL_MAP = {
    "伤心": 0, "生气": 1, "关心": 2, "惊讶": 3,
    "开心": 4, "平静": 5, "厌恶": 6,
}

# df2: numeric emotion code -> unified label id.
# Note there is no source code for 6 (厌恶); code 1 maps to 7 (恐惧).
DF2_EMOTION_MAP = {0: 1, 1: 7, 2: 2, 3: 3, 4: 4, 5: 5}
|
| 34 |
+
|
| 35 |
+
# Normalize column names (df1 already uses "text"; df2's "sample" becomes "text").
# Validate the columns we depend on before transforming anything.
if "label" not in df1.columns or "text" not in df1.columns:
    raise KeyError(f"df1 缺少必要列: label 或 text,现有列: {list(df1.columns)}")
if "emotion" not in df2.columns or "sample" not in df2.columns:
    raise KeyError(f"df2 缺少必要列: emotion 或 sample,现有列: {list(df2.columns)}")

# df1: remap Chinese emotion names onto the unified integer labels.
part1 = df1.copy()
part1["label"] = part1["label"].map(DF1_LABEL_MAP)
# df1 keeps its "text" column name unchanged.

# df2: coerce the emotion code to numeric so it can be mapped as int,
# remap it, then rename "sample" -> "text".
part2 = df2.copy()
part2["emotion"] = pd.to_numeric(part2["emotion"], errors="coerce")
part2["label"] = part2["emotion"].map(DF2_EMOTION_MAP)
part2 = part2.rename(columns={"sample": "text"})

# Concatenate on the shared (text, label) schema only.
final_cols = ["text", "label"]
merged = pd.concat([part1[final_cols], part2[final_cols]], ignore_index=True)

# Drop rows whose emotion could not be mapped, then force integer labels.
merged = merged.dropna(subset=["label"]).copy()
merged["label"] = merged["label"].astype(int)

# Write with a BOM (utf-8-sig) so spreadsheet tools detect UTF-8 correctly.
merged.to_csv("emotion-classification-train.csv", index=False, encoding="utf-8-sig")
print(f"merged saved: emotion-classification-train.csv, rows={len(merged)}, cols={len(merged.columns)}")
|