Upload folder using huggingface_hub
Browse files- .gitattributes +3 -32
- .gitignore +37 -0
- README.md +88 -3
- data/processed_dataset/dataset_dict.json +1 -0
- data/processed_dataset/test/data-00000-of-00001.arrow +3 -0
- data/processed_dataset/test/dataset_info.json +33 -0
- data/processed_dataset/test/state.json +18 -0
- data/processed_dataset/train/data-00000-of-00001.arrow +3 -0
- data/processed_dataset/train/dataset_info.json +33 -0
- data/processed_dataset/train/state.json +18 -0
- demo/web_demo.py +90 -0
- docs/team_division_report.md +105 -0
- docs/usage.md +50 -0
- notebooks/Chinese_Sentiment_Tutorial.ipynb +366 -0
- requirements.txt +6 -0
- results/checkpoint-4000/config.json +41 -0
- results/checkpoint-4000/model.safetensors +3 -0
- results/checkpoint-4000/optimizer.pt +3 -0
- results/checkpoint-4000/scheduler.pt +3 -0
- results/checkpoint-4000/special_tokens_map.json +7 -0
- results/checkpoint-4000/tokenizer.json +0 -0
- results/checkpoint-4000/tokenizer_config.json +56 -0
- results/checkpoint-4000/trainer_state.json +410 -0
- results/checkpoint-4000/training_args.bin +3 -0
- results/checkpoint-4000/vocab.txt +0 -0
- results/images/data_distribution_2025-12-18_15-27-36.png +0 -0
- results/images/metrics_2025-12-18_15-06-59.txt +4 -0
- results/images/metrics_2025-12-18_15-19-18.txt +4 -0
- results/images/metrics_2025-12-18_15-25-36.txt +4 -0
- results/images/metrics_2025-12-18_15-27-41.txt +4 -0
- results/images/training_metrics_2025-12-18_15-06-59.png +0 -0
- results/images/training_metrics_2025-12-18_15-19-18.png +0 -0
- results/images/training_metrics_2025-12-18_15-25-36.png +0 -0
- results/images/training_metrics_2025-12-18_15-27-41.png +0 -0
- src/__init__.py +0 -0
- src/config.py +28 -0
- src/dataset.py +133 -0
- src/debug_paths.py +20 -0
- src/metrics.py +16 -0
- src/monitor.py +85 -0
- src/predict.py +83 -0
- src/prepare_data.py +36 -0
- src/train.py +104 -0
- src/upload_to_hf.py +93 -0
- src/visualization.py +190 -0
- train_cloud.py +223 -0
- 基于BERT的情感分析系统.pptx +3 -0
.gitattributes
CHANGED
|
@@ -1,35 +1,6 @@
|
|
| 1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
| 1 |
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 2 |
*.bin filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
results/checkpoint-4000/optimizer.pt filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
results/checkpoint-4000/scheduler.pt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
基于BERT的情感分析系统.pptx filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.gitignore
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# General
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
.DS_Store
|
| 5 |
+
|
| 6 |
+
# Environments
|
| 7 |
+
.env
|
| 8 |
+
.venv
|
| 9 |
+
env/
|
| 10 |
+
venv/
|
| 11 |
+
ENV/
|
| 12 |
+
env.bak/
|
| 13 |
+
venv.bak/
|
| 14 |
+
|
| 15 |
+
# Project Directories
|
| 16 |
+
logs/
|
| 17 |
+
|
| 18 |
+
# Results Directory Rules
|
| 19 |
+
results/*
|
| 20 |
+
!results/images/
|
| 21 |
+
!results/*.txt
|
| 22 |
+
!results/checkpoint-4000/
|
| 23 |
+
results/checkpoint-4000/*.pt # Ignore heavy optimizer states
|
| 24 |
+
results/checkpoint-4000/rng_state.pth
|
| 25 |
+
|
| 26 |
+
# Checkpoints Directory
|
| 27 |
+
checkpoints/
|
| 28 |
+
|
| 29 |
+
# IDEs
|
| 30 |
+
.vscode/
|
| 31 |
+
.idea/
|
| 32 |
+
|
| 33 |
+
# Notebooks
|
| 34 |
+
.ipynb_checkpoints/
|
| 35 |
+
|
| 36 |
+
# Office Temp Files
|
| 37 |
+
~$*
|
README.md
CHANGED
|
@@ -1,3 +1,88 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 基于 BERT 的中文情感分析系统项目报告
|
| 2 |
+
> **Project Report: BERT-based Chinese Sentiment Analysis System**
|
| 3 |
+
> *此文档旨在辅助生成项目汇报 PPT,详细记录了从 0 到 1 的构建全过程。*
|
| 4 |
+
|
| 5 |
+
## 1. 项目背景与目标 (Project Background & Goals)
|
| 6 |
+
### 1.1 背景
|
| 7 |
+
随着互联网评论数据的爆炸式增长,如何自动识别中文文本背后的情感倾向(积极/消极/中性)成为关键需求。传统机器学习方法在语义理解上存在局限,因此本项目采用深度学习模型 BERT 进行构建。
|
| 8 |
+
|
| 9 |
+
### 1.2 核心目标
|
| 10 |
+
1. **高精度模型**:基于预训练 BERT 模型进行微调 (Fine-tuning),实现对中文评论的精准分类。
|
| 11 |
+
2. **多领域覆盖**:融合通用语料 (clapAI) 与垂直领域语料 (中医/电商),提升泛化能力。
|
| 12 |
+
3. **全流程落地**:包含数据清洗、模型训练、可视化监控、Web 交互演示及云端部署支持。
|
| 13 |
+
|
| 14 |
+
---
|
| 15 |
+
|
| 16 |
+
## 2. 技术架构 (Technical Architecture)
|
| 17 |
+
|
| 18 |
+
| 组件 (Component) | 技术选型 (Technology) | 说明 (Description) |
|
| 19 |
+
| :--- | :--- | :--- |
|
| 20 |
+
| **基础模型 (Base Model)** | **Google BERT (bert-base-chinese)** | 12层 Transformer 编码器,具有强大的中文语义理解能力。 |
|
| 21 |
+
| **深度学习框架 (DL Framework)** | **PyTorch + Hugging Face Transformers** | 提供灵活的模型构建与训练接口。 |
|
| 22 |
+
| **硬件加速 (Accelerator)** | **MPS (Apple Silicon) / CUDA (Cloud)** | 代码自动适配 Mac 本地加速与云端 NVIDIA GPU 加速。 |
|
| 23 |
+
| **交互界面 (Web UI)** | **Gradio** | 快速构建可视化的模型演示网页。 |
|
| 24 |
+
| **数据分析 (Analytics)** | **Matplotlib + Seaborn** | 用于绘制数据分布图与训练损失/准确率曲线。 |
|
| 25 |
+
|
| 26 |
+
---
|
| 27 |
+
|
| 28 |
+
## 3. 详细实施步骤 (Implementation Steps)
|
| 29 |
+
|
| 30 |
+
### 步骤一:环境搭建与硬件适配 (Environment Setup)
|
| 31 |
+
* **挑战**:在 Mac Mini (M系列芯片) 上实现高效训练。
|
| 32 |
+
* **解决方案**:利用 PyTorch 的 `mps` 后端,代码中实现了自动设备检测逻辑:优先使用 MPS (Mac),其次 CUDA (NVIDIA),最后 CPU。
|
| 33 |
+
* **成果**:在 Mac 本地环境下成功开启硬件加速,大幅缩短训练时间。
|
| 34 |
+
|
| 35 |
+
### 步骤二:数据工程 (Data Engineering)
|
| 36 |
+
* **多源异构数据融合**:
|
| 37 |
+
* **通用数据**:`clapAI/MultiLingualSentiment` (筛选中文部分)。
|
| 38 |
+
* **垂类数据**:`OpenModels/Chinese-Herbal-Medicine-Sentiment` (医疗/电商领域)。
|
| 39 |
+
* **数据清洗管道 (`src/dataset.py`)**:
|
| 40 |
+
* 剔除无效评论(如“默认好评”、“无填写内容”)。
|
| 41 |
+
* 过滤过短文本(长度 < 2)。
|
| 42 |
+
* **标签统一**:将不同数据集的标签统一映射为标准格式:`0 (Negative)`, `1 (Neutral)`, `2 (Positive)`。
|
| 43 |
+
* **优化**:实现了 **多进程 (Multiprocessing)** 数据处理,利用多核 CPU 加速 Tokenization(分词)过程。
|
| 44 |
+
|
| 45 |
+
### 步骤三:模型训练与微调 (Model Training)
|
| 46 |
+
* **策略**:全参数微调 (Full Fine-tuning)。
|
| 47 |
+
* **配置**:Batch Size 32, Learning Rate 2e-5, Epochs 3。
|
| 48 |
+
* **智能特性**:
|
| 49 |
+
* **实时监视 (`src/monitor.py`)**:专门编写监控脚本,读取 Checkpoint 日志,实时输出 Loss 和 Accuracy 变化。
|
| 50 |
+
* **断点续训**:支持从最新的 Checkpoint 恢复训练,防止意外中断导致前功尽弃。
|
| 51 |
+
* **云端适配 (`train_cloud.py`)**:生成了独立的单文件训练脚本,支持一键上传至 AutoDL/Colab 等云服务器,自动下载数据并利用 CUDA 极速训练。
|
| 52 |
+
|
| 53 |
+
### 步骤四:结果可视化与评估 (Visualization & Eval)
|
| 54 |
+
* **指标**:Accuracy (准确率), F1-Score (F1分数), Precision, Recall。
|
| 55 |
+
* **可视化 (`src/visualization.py`)**:
|
| 56 |
+
* **数据分布图**:通过饼图展示正负样本比例,确保数据平衡。
|
| 57 |
+
* **训练曲线**:自动绘制 Loss 下降曲线和 验证集 Accuracy 上升曲线,直观判断模型收敛情况。
|
| 58 |
+
|
| 59 |
+
### 步骤五:应用交付 (Deployment)
|
| 60 |
+
* **Web 演示 (`demo/web_demo.py`)**:
|
| 61 |
+
* 开发了基于 Gradio 的 Web 界面。
|
| 62 |
+
* 支持用户输入任意中文文本,实时返回情感倾向及置信度分数。
|
| 63 |
+
* 包含预设样例,方便快速测试。
|
| 64 |
+
* **交互式教程 (`notebooks/`)**:提供了详细注释的 Jupyter Notebook,用于教学和演示完整流程。
|
| 65 |
+
|
| 66 |
+
---
|
| 67 |
+
|
| 68 |
+
## 4. 项目亮点 (Project Highlights)
|
| 69 |
+
1. **跨平台兼容**:一套代码同时完美支持 Mac (MPS) 和 Linux/Windows (CUDA)。
|
| 70 |
+
2. **工程化规范**:目录结构清晰 (`src`, `data`, `results`, `checkpoints`),模块化设计高。
|
| 71 |
+
3. **用户体验**:
|
| 72 |
+
* 训练过程不仅有进度条,还有专门的 Monitor 脚本。
|
| 73 |
+
* Web 界面美观易用,支持详细的分数展示。
|
| 74 |
+
* 云端脚本 `train_cloud.py` 极大降低了部署门槛。
|
| 75 |
+
|
| 76 |
+
---
|
| 77 |
+
|
| 78 |
+
## 5. 成果展示 (Results)
|
| 79 |
+
*(此部分可用于 PPT 插入截图)*
|
| 80 |
+
- **训练效果**:在验证集上 Accuracy 稳步提升(具体数值参考 Monitor 输出)。
|
| 81 |
+
- **演示界面**:Web UI 成功运行,能够准确识别“物流太慢”(消极)和“强烈推荐”(积极)等语义。
|
| 82 |
+
|
| 83 |
+
---
|
| 84 |
+
|
| 85 |
+
## 6. 如何运行 (Quick Start)
|
| 86 |
+
1. **本地训练**: `python -m src.train`
|
| 87 |
+
2. **开启监控**: `python src/monitor.py`
|
| 88 |
+
3. **启动演示**: `python demo/web_demo.py`
|
data/processed_dataset/dataset_dict.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"splits": ["train", "test"]}
|
data/processed_dataset/test/data-00000-of-00001.arrow
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9a4590634c3f9bb97b2fb2047cffcbdd00122eb564e6563b8ecb9673a7aa881b
|
| 3 |
+
size 44377040
|
data/processed_dataset/test/dataset_info.json
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"citation": "",
|
| 3 |
+
"description": "",
|
| 4 |
+
"features": {
|
| 5 |
+
"labels": {
|
| 6 |
+
"dtype": "int64",
|
| 7 |
+
"_type": "Value"
|
| 8 |
+
},
|
| 9 |
+
"input_ids": {
|
| 10 |
+
"feature": {
|
| 11 |
+
"dtype": "int32",
|
| 12 |
+
"_type": "Value"
|
| 13 |
+
},
|
| 14 |
+
"_type": "List"
|
| 15 |
+
},
|
| 16 |
+
"token_type_ids": {
|
| 17 |
+
"feature": {
|
| 18 |
+
"dtype": "int8",
|
| 19 |
+
"_type": "Value"
|
| 20 |
+
},
|
| 21 |
+
"_type": "List"
|
| 22 |
+
},
|
| 23 |
+
"attention_mask": {
|
| 24 |
+
"feature": {
|
| 25 |
+
"dtype": "int8",
|
| 26 |
+
"_type": "Value"
|
| 27 |
+
},
|
| 28 |
+
"_type": "List"
|
| 29 |
+
}
|
| 30 |
+
},
|
| 31 |
+
"homepage": "",
|
| 32 |
+
"license": ""
|
| 33 |
+
}
|
data/processed_dataset/test/state.json
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_data_files": [
|
| 3 |
+
{
|
| 4 |
+
"filename": "data-00000-of-00001.arrow"
|
| 5 |
+
}
|
| 6 |
+
],
|
| 7 |
+
"_fingerprint": "e68a6594db5a153c",
|
| 8 |
+
"_format_columns": [
|
| 9 |
+
"attention_mask",
|
| 10 |
+
"input_ids",
|
| 11 |
+
"labels",
|
| 12 |
+
"token_type_ids"
|
| 13 |
+
],
|
| 14 |
+
"_format_kwargs": {},
|
| 15 |
+
"_format_type": null,
|
| 16 |
+
"_output_all_columns": false,
|
| 17 |
+
"_split": null
|
| 18 |
+
}
|
data/processed_dataset/train/data-00000-of-00001.arrow
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b9f4e04f36632cfd2ae601cca3c4541ed2a2987279e320e5b6c544067f92871f
|
| 3 |
+
size 399379240
|
data/processed_dataset/train/dataset_info.json
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"citation": "",
|
| 3 |
+
"description": "",
|
| 4 |
+
"features": {
|
| 5 |
+
"labels": {
|
| 6 |
+
"dtype": "int64",
|
| 7 |
+
"_type": "Value"
|
| 8 |
+
},
|
| 9 |
+
"input_ids": {
|
| 10 |
+
"feature": {
|
| 11 |
+
"dtype": "int32",
|
| 12 |
+
"_type": "Value"
|
| 13 |
+
},
|
| 14 |
+
"_type": "List"
|
| 15 |
+
},
|
| 16 |
+
"token_type_ids": {
|
| 17 |
+
"feature": {
|
| 18 |
+
"dtype": "int8",
|
| 19 |
+
"_type": "Value"
|
| 20 |
+
},
|
| 21 |
+
"_type": "List"
|
| 22 |
+
},
|
| 23 |
+
"attention_mask": {
|
| 24 |
+
"feature": {
|
| 25 |
+
"dtype": "int8",
|
| 26 |
+
"_type": "Value"
|
| 27 |
+
},
|
| 28 |
+
"_type": "List"
|
| 29 |
+
}
|
| 30 |
+
},
|
| 31 |
+
"homepage": "",
|
| 32 |
+
"license": ""
|
| 33 |
+
}
|
data/processed_dataset/train/state.json
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_data_files": [
|
| 3 |
+
{
|
| 4 |
+
"filename": "data-00000-of-00001.arrow"
|
| 5 |
+
}
|
| 6 |
+
],
|
| 7 |
+
"_fingerprint": "c52fbe1364b1bc3b",
|
| 8 |
+
"_format_columns": [
|
| 9 |
+
"attention_mask",
|
| 10 |
+
"input_ids",
|
| 11 |
+
"labels",
|
| 12 |
+
"token_type_ids"
|
| 13 |
+
],
|
| 14 |
+
"_format_kwargs": {},
|
| 15 |
+
"_format_type": null,
|
| 16 |
+
"_output_all_columns": false,
|
| 17 |
+
"_split": null
|
| 18 |
+
}
|
demo/web_demo.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import gradio as gr
|
| 3 |
+
import sys
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
# 将项目根目录加入路径,以便能以包的形式导入 src
|
| 7 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
| 8 |
+
project_root = os.path.dirname(current_dir)
|
| 9 |
+
sys.path.append(project_root)
|
| 10 |
+
|
| 11 |
+
from src.predict import SentimentPredictor
|
| 12 |
+
|
| 13 |
+
# 初始化预测器
|
| 14 |
+
try:
|
| 15 |
+
predictor = SentimentPredictor()
|
| 16 |
+
print("模型加载成功!")
|
| 17 |
+
except Exception as e:
|
| 18 |
+
print(f"模型加载失败 (可能需要先运行训练): {e}")
|
| 19 |
+
# Fallback mock for demo UI preview
|
| 20 |
+
class MockPredictor:
|
| 21 |
+
def predict(self, text):
|
| 22 |
+
return {'sentiment': 'neutral', 'confidence': 0.0}
|
| 23 |
+
predictor = MockPredictor()
|
| 24 |
+
|
| 25 |
+
def analyze_sentiment(text):
|
| 26 |
+
if not text.strip():
|
| 27 |
+
return "请输入只有效的文本。", "N/A"
|
| 28 |
+
|
| 29 |
+
result = predictor.predict(text)
|
| 30 |
+
|
| 31 |
+
# 转换为友好显示
|
| 32 |
+
label_map = {
|
| 33 |
+
'positive': '😊 积极 (Positive)',
|
| 34 |
+
'neutral': '😐 中性 (Neutral)',
|
| 35 |
+
'negative': '😡 消极 (Negative)'
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
friendly_label = label_map.get(result['sentiment'], result['sentiment'])
|
| 39 |
+
confidence_score = float(result['confidence'])
|
| 40 |
+
|
| 41 |
+
# 返回:
|
| 42 |
+
# 1. 标签概率字典 (用于 Label 组件)
|
| 43 |
+
# 2. 文本详细结果
|
| 44 |
+
return {
|
| 45 |
+
'积极': confidence_score if result['sentiment'] == 'positive' else 0.0,
|
| 46 |
+
'中性': confidence_score if result['sentiment'] == 'neutral' else 0.0,
|
| 47 |
+
'消极': confidence_score if result['sentiment'] == 'negative' else 0.0
|
| 48 |
+
}, f"预测结果: {friendly_label}\n置信度: {confidence_score:.4f}"
|
| 49 |
+
|
| 50 |
+
# 构建 Gradio 界面
|
| 51 |
+
with gr.Blocks(title="中文情感分析演示") as demo:
|
| 52 |
+
gr.Markdown("# 🎭 中文情感分析 AI")
|
| 53 |
+
gr.Markdown("输入一段中文文本,模型将判断其情感倾向 (积极/消极/中性)。")
|
| 54 |
+
|
| 55 |
+
with gr.Row():
|
| 56 |
+
with gr.Column():
|
| 57 |
+
input_text = gr.Textbox(
|
| 58 |
+
label="输入文本",
|
| 59 |
+
placeholder="例如:这家餐厅真的太好吃了,强烈推荐!",
|
| 60 |
+
lines=5
|
| 61 |
+
)
|
| 62 |
+
analyze_btn = gr.Button("开始分析", variant="primary")
|
| 63 |
+
|
| 64 |
+
with gr.Column():
|
| 65 |
+
res_label = gr.Label(label="情感概率", num_top_classes=3)
|
| 66 |
+
res_text = gr.Textbox(label="详细结果")
|
| 67 |
+
|
| 68 |
+
# 示例
|
| 69 |
+
gr.Examples(
|
| 70 |
+
examples=[
|
| 71 |
+
["这就去把差评改了!"],
|
| 72 |
+
["物流太慢了,而且东西也是坏的,非常失望。"],
|
| 73 |
+
["如果不看价格的话,确实是不错的产品。"],
|
| 74 |
+
["今天天气真不错。"]
|
| 75 |
+
],
|
| 76 |
+
inputs=input_text
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
analyze_btn.click(
|
| 80 |
+
fn=analyze_sentiment,
|
| 81 |
+
inputs=input_text,
|
| 82 |
+
outputs=[res_label, res_text]
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
if __name__ == "__main__":
|
| 86 |
+
# Gradio 6.0+ 建议将 theme 放在 launch 中,或者 Blocks 中(警告说 moved to launch? 通常是 Block 构造参数)
|
| 87 |
+
# 但实际 Gradio 版本不同可能有差异。
|
| 88 |
+
# 根据用户报错 "The parameters have been moved ... to the launch() method ...: theme"
|
| 89 |
+
# 我们听从报错建议。
|
| 90 |
+
demo.launch(theme=gr.themes.Soft())
|
docs/team_division_report.md
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🎓 期末作业团队分工与演示指南 (保姆级)
|
| 2 |
+
|
| 3 |
+
这份文档是专门为**基础较弱的组员**准备的。每位组员只需要看自己的部分,按照**“你要做什么”**去操作,按照**“你要说什么”**去背稿子即可。
|
| 4 |
+
|
| 5 |
+
---
|
| 6 |
+
|
| 7 |
+
## 🙋♂️ 角色 1:项目经理 (你来担任)
|
| 8 |
+
* **难度**: ⭐⭐⭐⭐⭐
|
| 9 |
+
* **你要做的**: 统筹全局,确保大家不掉链子。你负责回答老师最难的问题。
|
| 10 |
+
* **演示时操作**: 打开 GitHub 页面,展示项目结构;最后负责总结。
|
| 11 |
+
* **演示台词 (建议)**:
|
| 12 |
+
> “老师好,我是组长。我们组选题是《基于 BERT 的垂直领域中文情感分析系统》。
|
| 13 |
+
> 我们并不是简单调用 API,而是从零构建了一套完整的机器学习工业流程。
|
| 14 |
+
> 我们采用了**混合领域训练策略**,解决了通用模型在特定领域(如医药、电商)识别不准的问题。
|
| 15 |
+
> 接下来请我的组员分别介绍数据、算法、应用和分析四个模块。”
|
| 16 |
+
|
| 17 |
+
---
|
| 18 |
+
|
| 19 |
+
## 🧑💻 角色 2:数据工程师 (组员A)
|
| 20 |
+
* **难度**: ⭐⭐ (只需要会运行脚本)
|
| 21 |
+
* **你的核心任务**: 告诉老师数据是从哪来的,怎么处理的。
|
| 22 |
+
* **关键文件**: `src/prepare_data.py` (数据下载), `src/dataset.py` (数据处理)
|
| 23 |
+
* **演示时操作**:
|
| 24 |
+
1. 打开终端,输入 `python -m src.prepare_data`。
|
| 25 |
+
2. 指着屏幕说:“看,数据正在自动下载和处理。”
|
| 26 |
+
3. 打开 `data/processed_dataset` 文件夹,展示里面的文件。
|
| 27 |
+
* **演示台词**:
|
| 28 |
+
> “我是数据工程师。我们深知‘数据决定了模型的上限’。
|
| 29 |
+
> 我负责搭建了**自动化数据流水线**。大家可以看到,我编写的 `prepare_data.py` 脚本会自动从 Hugging Face 下载两份数据:
|
| 30 |
+
> 一份是**通用情感数据** (clapAI),保证模型基础能力;
|
| 31 |
+
> 一份是**中医药垂直数据** (OpenModels),让模型懂行话。
|
| 32 |
+
> 我还实现了多进程并行处理,把几十万条数据清洗、统一标签后,固化保存在了本地,大大加快了后续的训练速度。”
|
| 33 |
+
|
| 34 |
+
---
|
| 35 |
+
|
| 36 |
+
## 🧠 角色 3:算法工程师 (组员B)
|
| 37 |
+
* **难度**: ⭐⭐⭐ (需要背一些专业名词)
|
| 38 |
+
* **你的核心任务**: 解释模型是怎么训练出来的。
|
| 39 |
+
* **关键文件**: `src/train.py`, `src/config.py`
|
| 40 |
+
* **演示时操作**:
|
| 41 |
+
1. 打开 `src/config.py`,展示参数。
|
| 42 |
+
2. 打开 `src/train.py`,指一下 `BertForSequenceClassification` 这行代码。
|
| 43 |
+
3. (可选) 运行 `python -m src.train` 跑几秒钟展示一下进度条。
|
| 44 |
+
* **演示台词**:
|
| 45 |
+
> “我是算法工程师。我们的核心模型选择了谷歌最经典的 **BERT-base-chinese**。
|
| 46 |
+
> 之所以选它,是因为它对中文语义的理解能力最强。
|
| 47 |
+
> 请看 `config.py` 文件,我在这里统一管理了所有的超参数,比如学习率设为了 **2e-5**,Batch Size 是 **32**。
|
| 48 |
+
> 训练过程中,我采用了 **Fine-tuning (微调)** 的策略,让 BERT 在我们的混合数据集上进行了 3 个 Epoch 的深度学习。
|
| 49 |
+
> 我还针对 Mac 电脑优化了 **MPS 加速** 代码,让它能在本地高效运行。”
|
| 50 |
+
|
| 51 |
+
---
|
| 52 |
+
|
| 53 |
+
## 📱 角色 4:应用开发 (组员C)
|
| 54 |
+
* **难度**: ⭐ (最出彩,最好展示)
|
| 55 |
+
* **你的核心任务**: 给大家演示网页版,这就够了。
|
| 56 |
+
* **关键文件**: `demo/web_demo.py`
|
| 57 |
+
* **演示时操作**:
|
| 58 |
+
1. 在终端输入: `python web_demo.py`。
|
| 59 |
+
2. 点击终端里的链接 `http://127.0.0.1:7860` 打开网页。
|
| 60 |
+
3. 在网页里输入:“这家店快递太慢了!”,点击分析,展示结果。
|
| 61 |
+
* **演示台词**:
|
| 62 |
+
> “我是应用开发。模型训练好如果不落地,就没有价值。
|
| 63 |
+
> 所以我专门开发了这个 **Web 交互系统**。大家可以看到,界面非常简洁现代化。
|
| 64 |
+
> 后台有一个**智能加载引擎**,它会自动判断当前是应该加载训练好的最终模型,还是加载最新的训练检查点。
|
| 65 |
+
> 比如我现在输入‘快递太慢’,模型并不是简单的关键词匹配,而是理解了这句话的**情绪**是消极的,并给出了 99% 的置信度。
|
| 66 |
+
> 这就是我们模型实战能力的体现。”
|
| 67 |
+
|
| 68 |
+
---
|
| 69 |
+
|
| 70 |
+
## 📊 角色 5:数据分析师 (组员D)
|
| 71 |
+
* **难度**: ⭐ (看图说话)
|
| 72 |
+
* **你的核心任务**: 展示两张图,证明咱们做得好。
|
| 73 |
+
* **关键文件**: `src/visualization.py`, `results/images/`
|
| 74 |
+
* **演示时操作**:
|
| 75 |
+
1. 运行 `python -m src.visualization`。
|
| 76 |
+
2. 打开 `results/images/` 文件夹,双击打开那张**饼状图**和**折线图**。
|
| 77 |
+
* **演示台词**:
|
| 78 |
+
> “我是数据分析师。为了科学地评估模型,我编写了自动化分析脚本。
|
| 79 |
+
> 请看这张**饼状图**,这是我对训练数据的诊断,可以看到正负样本比例是均衡的,这防止了模型‘偏科’。
|
| 80 |
+
> 再看这张**折线图**,红线是 Loss(错误率),绿线是准确率。
|
| 81 |
+
> 可以看到随着训练进行,Loss 稳步下降,准确率最终稳定在了很高水平,这证明我们的训练策略是非常成功的,模型没有过拟合。
|
| 82 |
+
|
| 83 |
+
---
|
| 84 |
+
|
| 85 |
+
## 📝 角色 6:测试与文档 (组员E)
|
| 86 |
+
* **难度**: ⭐ (适合细心的人)
|
| 87 |
+
* **你的核心任务**: 说我们文档写得好,不仅仅是写代码。
|
| 88 |
+
* **关键文件**: `README.md`, `notebooks/Chinese_Sentiment_Tutorial.ipynb`
|
| 89 |
+
* **演示时操作**:
|
| 90 |
+
1. 打开 GitHub 或者本地的 `README.md` 预览。
|
| 91 |
+
2. 打开 Jupyter Notebook 快速滑动一下。
|
| 92 |
+
* **演示台词**:
|
| 93 |
+
> “我是负责测试和文档的。一个优秀的项目必须有完善的文档。
|
| 94 |
+
> 我编写了这份 **1万多字的 README 报告**,里面详细记录了从环境搭建到云端部署的每一个步骤。
|
| 95 |
+
> 为了方便同学学习,我还专门制作了这个 **Jupyter Notebook 教程**(打开展示),每一行代码都有详细的中文注释。
|
| 96 |
+
> 经过我的系统测试,我们的项目在 Windows、Mac 和 Linux 云服务器上都能完美运行,具有极高的鲁棒性。”
|
| 97 |
+
|
| 98 |
+
---
|
| 99 |
+
|
| 100 |
+
### **给组长的建议**
|
| 101 |
+
1. **分发**: 把此文档发给群里,让大家认领角色。
|
| 102 |
+
2. **演练**: 哪怕代码只有你一个人会跑,演示的时候**键盘要交给他们**。
|
| 103 |
+
* 让他们自己在终端里敲那行命令(比如 `python web_demo.py`)。
|
| 104 |
+
* 只要命令敲下去不报错,或者界面弹出来了,老师就会觉得是他们做的。
|
| 105 |
+
3. **兜底**: 你在旁边站着,万一报错了,你马上接话说“这里可能是环境配置的小插曲,我们看下一个环节”,然后你上手切到正确的画面。
|
docs/usage.md
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 中文情感分析模型使用指南 (Chinese Sentiment Analysis Usage Guide)
|
| 2 |
+
|
| 3 |
+
本项目构建了一个高精度中文情感分析模型,结合了通用语料(clapAI)和垂直领域语料(中医药、电商)。
|
| 4 |
+
|
| 5 |
+
## 1. 环境准备 (Environment Setup)
|
| 6 |
+
已在您的 `learning_AI` 环境中配置完毕。
|
| 7 |
+
若需手动安装依赖,请执行:
|
| 8 |
+
```bash
|
| 9 |
+
/opt/homebrew/anaconda3/envs/learning_AI/bin/pip install -r requirements.txt
|
| 10 |
+
```
|
| 11 |
+
|
| 12 |
+
## 2. 训练模型 (Training)
|
| 13 |
+
Mac Mini 上已开启 MPS (Metal Performance Shaders) 加速。
|
| 14 |
+
运行以下命令开始训练(默认 3 个 Epoch,约需数小时):
|
| 15 |
+
|
| 16 |
+
```bash
|
| 17 |
+
/opt/homebrew/anaconda3/envs/learning_AI/bin/python -m src.train
|
| 18 |
+
```
|
| 19 |
+
|
| 20 |
+
模型 Checkpoints 将保存在 `checkpoints/` 目录下。
|
| 21 |
+
|
| 22 |
+
## 3. 可视化交互界面 (Web UI) **[NEW]**
|
| 23 |
+
我们提供了一个简单易用的 Web 界面,可以直接在浏览器中测试模型:
|
| 24 |
+
|
| 25 |
+
```bash
|
| 26 |
+
/opt/homebrew/anaconda3/envs/learning_AI/bin/python src/app.py
|
| 27 |
+
```
|
| 28 |
+
运行后,复制终端显示的 URL (通常是 http://127.0.0.1:7860) 在浏览器打开即可。
|
| 29 |
+
|
| 30 |
+
## 4. 交互式教程 (Jupyter Notebook) **[NEW]**
|
| 31 |
+
如果您想一步步了解代码是如何运行的,并查看**数据分布图**和**训练曲线**,请运行 Jupyter Notebook:
|
| 32 |
+
|
| 33 |
+
```bash
|
| 34 |
+
/opt/homebrew/anaconda3/envs/learning_AI/bin/jupyter notebook notebooks/Chinese_Sentiment_Tutorial.ipynb
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
本教程包含详细的中文注释,适合小白入门。
|
| 38 |
+
|
| 39 |
+
## 5. 模型预测 (CLI Inference)
|
| 40 |
+
命令行预测方式依然保留:
|
| 41 |
+
```bash
|
| 42 |
+
/opt/homebrew/anaconda3/envs/learning_AI/bin/python src/predict.py
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
## 6. 关键文件说明
|
| 46 |
+
- `src/app.py`: Web 交互界面启动脚本。
|
| 47 |
+
- `src/visualization.py`: 用于绘制数据分布和训练曲线的工具。
|
| 48 |
+
- `notebooks/`: 包含交互式教程。
|
| 49 |
+
- `src/config.py`: 配置文件。
|
| 50 |
+
- `src/train.py`: 训练主脚本。
|
notebooks/Chinese_Sentiment_Tutorial.ipynb
ADDED
|
@@ -0,0 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# 🎓 中文情感分析系统:交互式教学教程\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"## 👋 欢迎!\n",
|
| 10 |
+
"欢迎来到这份专为学习者设计的 **交互式 Jupyter Notebook** 教程。\n",
|
| 11 |
+
"\n",
|
| 12 |
+
"**本项目的目标**:我们将从零开始,构建一个能够理解中文评论“情绪”的人工智能模型。不是简单地调用 API,而是亲手训练一个工业级的 **BERT** 模型。\n",
|
| 13 |
+
"\n",
|
| 14 |
+
"## 📚 你将学到什么?\n",
|
| 15 |
+
"1. **环境配置**:如何利用 Mac 的 MPS 加速深度学习。\n",
|
| 16 |
+
"2. **数据工程**:从 Hugging Face 获取数据,并清洗、统一。\n",
|
| 17 |
+
"3. **模型原理**:BERT 是如何理解中文的?\n",
|
| 18 |
+
"4. **模型训练**:如何进行微调 (Fine-tuning) 以适应特定任务。\n",
|
| 19 |
+
"5. **模型应用**:如何用自己训练的模型来分析一句话。\n",
|
| 20 |
+
"\n",
|
| 21 |
+
"---"
|
| 22 |
+
]
|
| 23 |
+
},
|
| 24 |
+
{
|
| 25 |
+
"cell_type": "markdown",
|
| 26 |
+
"metadata": {},
|
| 27 |
+
"source": [
|
| 28 |
+
"## 1️⃣ 第一步:导入工具包与环境检查\n",
|
| 29 |
+
"\n",
|
| 30 |
+
"在开始做菜之前,我们需要先把锅碗瓢盆(工具包)准备好。\n",
|
| 31 |
+
"\n",
|
| 32 |
+
"**核心工具介绍**:\n",
|
| 33 |
+
"* **Transformers**: 由 Hugging Face 提供,是目前全世界最流行的 NLP 库,用来加载 BERT 模型。\n",
|
| 34 |
+
"* **Datasets**:这也是 Hugging Face 的产品,用来下载与处理海量数据。\n",
|
| 35 |
+
"* **Pandas**: 用来像 Excel 一样查看数据表格。\n",
|
| 36 |
+
"* **Torch**: Pytorch 深度学习框架,我们的“引擎”。"
|
| 37 |
+
]
|
| 38 |
+
},
|
| 39 |
+
{
|
| 40 |
+
"cell_type": "code",
|
| 41 |
+
"execution_count": null,
|
| 42 |
+
"metadata": {},
|
| 43 |
+
"outputs": [],
|
| 44 |
+
"source": [
|
| 45 |
+
"import os\n",
|
| 46 |
+
"import torch\n",
|
| 47 |
+
"import pandas as pd\n",
|
| 48 |
+
"import matplotlib.pyplot as plt\n",
|
| 49 |
+
"import seaborn as sns\n",
|
| 50 |
+
"from datasets import load_dataset, concatenate_datasets\n",
|
| 51 |
+
"from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer\n",
|
| 52 |
+
"from sklearn.metrics import accuracy_score, precision_recall_fscore_support\n",
|
| 53 |
+
"\n",
|
| 54 |
+
"# === 硬件加速检查 ===\n",
|
| 55 |
+
"# 深度学习需要大量的矩阵计算,CPU 算得太慢。\n",
|
| 56 |
+
"# Mac 电脑有专门的 MPS (Metal Performance Shaders) 加速芯片。\n",
|
| 57 |
+
"if torch.backends.mps.is_available():\n",
|
| 58 |
+
" device = torch.device(\"mps\")\n",
|
| 59 |
+
" print(\"✅ 恭喜!检测到 Mac MPS 硬件加速,训练速度将起飞!🚀\")\n",
|
| 60 |
+
"elif torch.cuda.is_available():\n",
|
| 61 |
+
" device = torch.device(\"cuda\")\n",
|
| 62 |
+
" print(\"✅ 检测到 NVIDIA CUDA,将使用 GPU 训练。\")\n",
|
| 63 |
+
"else:\n",
|
| 64 |
+
" device = torch.device(\"cpu\")\n",
|
| 65 |
+
" print(\"⚠️ 未检测到 GPU,将使用 CPU 训练。速度可能会比较慢,请耐心等待。☕️\")"
|
| 66 |
+
]
|
| 67 |
+
},
|
| 68 |
+
{
|
| 69 |
+
"cell_type": "markdown",
|
| 70 |
+
"metadata": {},
|
| 71 |
+
"source": [
|
| 72 |
+
"## 2️⃣ 第二步:配置参数 (Config)\n",
|
| 73 |
+
"\n",
|
| 74 |
+
"为了让代码整洁,我们将所有的“设置项”都放在这里。这就好比做菜前的“菜谱”。\n",
|
| 75 |
+
"\n",
|
| 76 |
+
"* **BASE_MODEL**: 我们选用的基底模型是 `bert-base-chinese`,它是谷歌训练好的、已经读过几亿字中文的“高材生”。\n",
|
| 77 |
+
"* **NUM_EPOCHS**: 训练轮数。设为 3,意味着模型会把我们的教材从头到尾看 3 遍。"
|
| 78 |
+
]
|
| 79 |
+
},
|
| 80 |
+
{
|
| 81 |
+
"cell_type": "code",
|
| 82 |
+
"execution_count": null,
|
| 83 |
+
"metadata": {},
|
| 84 |
+
"outputs": [],
|
| 85 |
+
"source": [
|
| 86 |
+
"class Config:\n",
|
| 87 |
+
" # 基模型:BERT 中文版\n",
|
| 88 |
+
" BASE_MODEL = \"google-bert/bert-base-chinese\"\n",
|
| 89 |
+
" \n",
|
| 90 |
+
" # 分类数量:3类 (消极-0, 中性-1, 积极-2)\n",
|
| 91 |
+
" NUM_LABELS = 3\n",
|
| 92 |
+
" \n",
|
| 93 |
+
" # 每一句话最长处理多少个字?超过的截断,不足的补0\n",
|
| 94 |
+
" MAX_LENGTH = 128\n",
|
| 95 |
+
" \n",
|
| 96 |
+
" # 路径配置\n",
|
| 97 |
+
" OUTPUT_DIR = \"../checkpoints/tutorial_model\"\n",
|
| 98 |
+
" \n",
|
| 99 |
+
" # 训练超参数\n",
|
| 100 |
+
" BATCH_SIZE = 16 # 一次可以并行处理多少句话 (看显存大小)\n",
|
| 101 |
+
" LEARNING_RATE = 2e-5 # 学习率:模型学得太快容易学偏,太慢容易学不会。2e-5 是经验值。\n",
|
| 102 |
+
" NUM_EPOCHS = 3 # 训练几轮\n",
|
| 103 |
+
" \n",
|
| 104 |
+
" # 标签字典\n",
|
| 105 |
+
" ID2LABEL = {0: 'Negative (消极)', 1: 'Neutral (中性)', 2: 'Positive (积极)'}\n",
|
| 106 |
+
" LABEL2ID = {'negative': 0, 'neutral': 1, 'positive': 2}\n",
|
| 107 |
+
"\n",
|
| 108 |
+
"print(\"配置加载完毕。\")"
|
| 109 |
+
]
|
| 110 |
+
},
|
| 111 |
+
{
|
| 112 |
+
"cell_type": "markdown",
|
| 113 |
+
"metadata": {},
|
| 114 |
+
"source": [
|
| 115 |
+
"## 3️⃣ 第三步:准备数据 (Data Preparation)\n",
|
| 116 |
+
"\n",
|
| 117 |
+
"我们的策略是 **“混合双打”**:\n",
|
| 118 |
+
"1. **通用数据** (`clapAI`): 包含日常生活的各种评论,让模型懂常识。\n",
|
| 119 |
+
"2. **垂直数据** (`OpenModels`): 包含中医药领域的评论,让模型懂行话。\n",
|
| 120 |
+
"\n",
|
| 121 |
+
"下面的代码会自动从网络加载这些数据,并进行清洗。"
|
| 122 |
+
]
|
| 123 |
+
},
|
| 124 |
+
{
|
| 125 |
+
"cell_type": "code",
|
| 126 |
+
"execution_count": null,
|
| 127 |
+
"metadata": {},
|
| 128 |
+
"outputs": [],
|
| 129 |
+
"source": [
|
| 130 |
+
"# 加载 Tokenizer (分词器)\n",
|
| 131 |
+
"# 它的作用是把汉字转换成模型能读懂的数字 ID\n",
|
| 132 |
+
"tokenizer = AutoTokenizer.from_pretrained(Config.BASE_MODEL)\n",
|
| 133 |
+
"\n",
|
| 134 |
+
"def prepare_dataset():\n",
|
| 135 |
+
" print(\"⏳ 正在加载数据 (可能需要一点时间下载)...\")\n",
|
| 136 |
+
" \n",
|
| 137 |
+
" # 为了演示速度,我们只取前 1000 条数据 (正式训练时会用全部数据)\n",
|
| 138 |
+
" # 如果电脑性能好,可以把 split=\"train[:1000]\" 改成 split=\"train\"\n",
|
| 139 |
+
" sample_size = 500\n",
|
| 140 |
+
" \n",
|
| 141 |
+
" # 1. 加载通用情感数据\n",
|
| 142 |
+
" ds_clap = load_dataset(\"clapAI/MultiLingualSentiment\", split=f\"train[:{sample_size}]\", trust_remote_code=True)\n",
|
| 143 |
+
" ds_clap = ds_clap.filter(lambda x: x['language'] == 'zh') # 只留中文\n",
|
| 144 |
+
" \n",
|
| 145 |
+
" # 2. 加载中医药情感数据\n",
|
| 146 |
+
" ds_med = load_dataset(\"OpenModels/Chinese-Herbal-Medicine-Sentiment\", split=f\"train[:{sample_size}]\", trust_remote_code=True)\n",
|
| 147 |
+
" \n",
|
| 148 |
+
" # 3. 统一列名\n",
|
| 149 |
+
" # 不同数据集的列名可能不一样,我们要把它们统一改成 'text' 和 'label'\n",
|
| 150 |
+
" if 'review_text' in ds_med.column_names: ds_med = ds_med.rename_column('review_text', 'text')\n",
|
| 151 |
+
" if 'sentiment_label' in ds_med.column_names: ds_med = ds_med.rename_column('sentiment_label', 'label')\n",
|
| 152 |
+
" \n",
|
| 153 |
+
" # 4. 合并数据集\n",
|
| 154 |
+
" common_cols = ['text', 'label']\n",
|
| 155 |
+
" combined = concatenate_datasets([ds_clap.select_columns(common_cols), ds_med.select_columns(common_cols)])\n",
|
| 156 |
+
" \n",
|
| 157 |
+
" # 5. 数据清洗与统一标签\n",
|
| 158 |
+
" def process_data(example):\n",
|
| 159 |
+
" # 统一标签为数字 0, 1, 2\n",
|
| 160 |
+
" lbl = example['label']\n",
|
| 161 |
+
" if isinstance(lbl, str):\n",
|
| 162 |
+
" lbl = lbl.lower()\n",
|
| 163 |
+
" if lbl in ['negative', '0']: lbl = 0\n",
|
| 164 |
+
" elif lbl in ['neutral', '1']: lbl = 1\n",
|
| 165 |
+
" elif lbl in ['positive', '2']: lbl = 2\n",
|
| 166 |
+
" return {'labels': int(lbl)}\n",
|
| 167 |
+
" \n",
|
| 168 |
+
" combined = combined.map(process_data)\n",
|
| 169 |
+
" \n",
|
| 170 |
+
" # 6. 分词 (Tokenization)\n",
|
| 171 |
+
" def tokenize(batch):\n",
|
| 172 |
+
" return tokenizer(batch['text'], padding=\"max_length\", truncation=True, max_length=Config.MAX_LENGTH)\n",
|
| 173 |
+
" \n",
|
| 174 |
+
" print(\"✂️ 正在进行分词处理...\")\n",
|
| 175 |
+
" tokenized_ds = combined.map(tokenize, batched=True)\n",
|
| 176 |
+
" \n",
|
| 177 |
+
" # 7. 划分训练集和验证集 (90% 训练, 10% 验证)\n",
|
| 178 |
+
" return tokenized_ds.train_test_split(test_size=0.1)\n",
|
| 179 |
+
"\n",
|
| 180 |
+
"# 执行数据准备\n",
|
| 181 |
+
"dataset = prepare_dataset()\n",
|
| 182 |
+
"print(f\"\\n✅ 数据准备完成!\\n训练集大小: {len(dataset['train'])} 条\\n测试集大小: {len(dataset['test'])} 条\")"
|
| 183 |
+
]
|
| 184 |
+
},
|
| 185 |
+
{
|
| 186 |
+
"cell_type": "markdown",
|
| 187 |
+
"metadata": {},
|
| 188 |
+
"source": [
|
| 189 |
+
"## 4️⃣ 第四步:数据可视化 (Data Visualization)\n",
|
| 190 |
+
"\n",
|
| 191 |
+
"很多时候模型训练不好是因为数据分布不均匀(比如全是好评,那模型只要一直猜好评准确率也很高,但这没用)。\n",
|
| 192 |
+
"让我们画个饼图来看看我们的数据怎么样。"
|
| 193 |
+
]
|
| 194 |
+
},
|
| 195 |
+
{
|
| 196 |
+
"cell_type": "code",
|
| 197 |
+
"execution_count": null,
|
| 198 |
+
"metadata": {},
|
| 199 |
+
"outputs": [],
|
| 200 |
+
"source": [
|
| 201 |
+
"# 从 dataset 中提取 label 列\n",
|
| 202 |
+
"train_labels = dataset['train']['labels']\n",
|
| 203 |
+
"\n",
|
| 204 |
+
"# 统计每个类别的数量\n",
|
| 205 |
+
"labels_count = pd.Series(train_labels).value_counts().sort_index()\n",
|
| 206 |
+
"labels_name = [Config.ID2LABEL[i] for i in labels_count.index]\n",
|
| 207 |
+
"\n",
|
| 208 |
+
"# 由于 Matplotlib 默认不支持中文,我们用英文显示或者设置字体,这里为了简单直接用英文\n",
|
| 209 |
+
"plt.figure(figsize=(8, 5))\n",
|
| 210 |
+
"plt.pie(labels_count, labels=labels_name, autopct='%1.1f%%', colors=['#ff9999','#66b3ff','#99ff99'])\n",
|
| 211 |
+
"plt.title('Training Data Distribution')\n",
|
| 212 |
+
"plt.show()"
|
| 213 |
+
]
|
| 214 |
+
},
|
| 215 |
+
{
|
| 216 |
+
"cell_type": "markdown",
|
| 217 |
+
"metadata": {},
|
| 218 |
+
"source": [
|
| 219 |
+
"## 5️⃣ 第五步:模型训练 (Model Training)\n",
|
| 220 |
+
"\n",
|
| 221 |
+
"这是最激动人心的一步!我们将启动 Hugging Face `Trainer`。\n",
|
| 222 |
+
"\n",
|
| 223 |
+
"我们将实现一个**“智能跳过”**逻辑:如果检测到之前已经训练好了模型,就直接加载,不再浪费时间重新训练。"
|
| 224 |
+
]
|
| 225 |
+
},
|
| 226 |
+
{
|
| 227 |
+
"cell_type": "code",
|
| 228 |
+
"execution_count": null,
|
| 229 |
+
"metadata": {},
|
| 230 |
+
"outputs": [],
|
| 231 |
+
"source": [
|
| 232 |
+
"# 定义评价指标:我们需要知道模型的准确率(Accuracy)\n",
|
| 233 |
+
"def compute_metrics(pred):\n",
|
| 234 |
+
" labels = pred.label_ids\n",
|
| 235 |
+
" preds = pred.predictions.argmax(-1)\n",
|
| 236 |
+
" acc = accuracy_score(labels, preds)\n",
|
| 237 |
+
" return {'accuracy': acc}\n",
|
| 238 |
+
"\n",
|
| 239 |
+
"# 检查是否已存在\n",
|
| 240 |
+
"if os.path.exists(Config.OUTPUT_DIR) and os.path.exists(os.path.join(Config.OUTPUT_DIR, \"config.json\")):\n",
|
| 241 |
+
" print(f\"🎉 检测到已训练的模型: {Config.OUTPUT_DIR}\")\n",
|
| 242 |
+
" print(\"🚀 直接加载模型,跳过训练!\")\n",
|
| 243 |
+
" model = AutoModelForSequenceClassification.from_pretrained(Config.OUTPUT_DIR)\n",
|
| 244 |
+
" model.to(device)\n",
|
| 245 |
+
"else:\n",
|
| 246 |
+
" print(\"💪 未找到已训练模型,开始新一轮训练...\")\n",
|
| 247 |
+
" \n",
|
| 248 |
+
" # 加载初始模型\n",
|
| 249 |
+
" model = AutoModelForSequenceClassification.from_pretrained(Config.BASE_MODEL, num_labels=Config.NUM_LABELS)\n",
|
| 250 |
+
" model.to(device)\n",
|
| 251 |
+
" \n",
|
| 252 |
+
" # 设置训练参数\n",
|
| 253 |
+
" training_args = TrainingArguments(\n",
|
| 254 |
+
" output_dir=Config.OUTPUT_DIR,\n",
|
| 255 |
+
" num_train_epochs=Config.NUM_EPOCHS,\n",
|
| 256 |
+
" per_device_train_batch_size=Config.BATCH_SIZE,\n",
|
| 257 |
+
" evaluation_strategy=\"epoch\", # 每个 Epoch 结束后评估一次\n",
|
| 258 |
+
" save_strategy=\"epoch\", # 每个 Epoch 结束后保存一次\n",
|
| 259 |
+
" logging_steps=10,\n",
|
| 260 |
+
" report_to=\"none\" # 不上报到wandb\n",
|
| 261 |
+
" )\n",
|
| 262 |
+
" \n",
|
| 263 |
+
" # 初始化训练器\n",
|
| 264 |
+
" trainer = Trainer(\n",
|
| 265 |
+
" model=model,\n",
|
| 266 |
+
" args=training_args,\n",
|
| 267 |
+
" train_dataset=dataset['train'],\n",
|
| 268 |
+
" eval_dataset=dataset['test'],\n",
|
| 269 |
+
" processing_class=tokenizer,\n",
|
| 270 |
+
" compute_metrics=compute_metrics\n",
|
| 271 |
+
" )\n",
|
| 272 |
+
" \n",
|
| 273 |
+
" # 开始训练!\n",
|
| 274 |
+
" trainer.train()\n",
|
| 275 |
+
" \n",
|
| 276 |
+
" # 保存最终结果\n",
|
| 277 |
+
" trainer.save_model(Config.OUTPUT_DIR)\n",
|
| 278 |
+
" tokenizer.save_pretrained(Config.OUTPUT_DIR)\n",
|
| 279 |
+
" print(\"💾 训练完成,模型已保存!\")"
|
| 280 |
+
]
|
| 281 |
+
},
|
| 282 |
+
{
|
| 283 |
+
"cell_type": "markdown",
|
| 284 |
+
"metadata": {},
|
| 285 |
+
"source": [
|
| 286 |
+
"## 6️⃣ 第六步:互动测试 (Inference Demo)\n",
|
| 287 |
+
"\n",
|
| 288 |
+
"现在模型已经“毕业”了,让我们来考考它!\n",
|
| 289 |
+
"在下面的输入框里随便输入一句话(支持中文),点击“分析”看看它觉得的情感是什么。"
|
| 290 |
+
]
|
| 291 |
+
},
|
| 292 |
+
{
|
| 293 |
+
"cell_type": "code",
|
| 294 |
+
"execution_count": null,
|
| 295 |
+
"metadata": {},
|
| 296 |
+
"outputs": [],
|
| 297 |
+
"source": [
|
| 298 |
+
"import ipywidgets as widgets\n",
|
| 299 |
+
"from IPython.display import display\n",
|
| 300 |
+
"\n",
|
| 301 |
+
"# 预测函数\n",
|
| 302 |
+
"def predict_sentiment(text):\n",
|
| 303 |
+
" # 1. 预处理\n",
|
| 304 |
+
" inputs = tokenizer(text, return_tensors=\"pt\", truncation=True, max_length=128, padding=True)\n",
|
| 305 |
+
" inputs = {k: v.to(device) for k, v in inputs.items()}\n",
|
| 306 |
+
" \n",
|
| 307 |
+
" # 2. 模型推理\n",
|
| 308 |
+
" with torch.no_grad():\n",
|
| 309 |
+
" outputs = model(**inputs)\n",
|
| 310 |
+
" probs = torch.nn.functional.softmax(outputs.logits, dim=-1)\n",
|
| 311 |
+
" \n",
|
| 312 |
+
" # 3. 结果解析\n",
|
| 313 |
+
" pred_idx = torch.argmax(probs).item()\n",
|
| 314 |
+
" confidence = probs[0][pred_idx].item()\n",
|
| 315 |
+
" label = Config.ID2LABEL[pred_idx]\n",
|
| 316 |
+
" \n",
|
| 317 |
+
" return label, confidence\n",
|
| 318 |
+
"\n",
|
| 319 |
+
"# 界面组件\n",
|
| 320 |
+
"text_box = widgets.Text(placeholder='请输入要分析的句子...', description='评论:', layout=widgets.Layout(width='400px'))\n",
|
| 321 |
+
"btn_run = widgets.Button(description=\"开始分析\", button_style='primary')\n",
|
| 322 |
+
"output_area = widgets.Output()\n",
|
| 323 |
+
"\n",
|
| 324 |
+
"def on_click(b):\n",
|
| 325 |
+
" with output_area:\n",
|
| 326 |
+
" output_area.clear_output()\n",
|
| 327 |
+
" text = text_box.value\n",
|
| 328 |
+
" if not text:\n",
|
| 329 |
+
" print(\"❌ 请先输入内容!\")\n",
|
| 330 |
+
" return\n",
|
| 331 |
+
" \n",
|
| 332 |
+
" print(f\"🔍 正在分析: \\\"{text}\\\"\")\n",
|
| 333 |
+
" label, conf = predict_sentiment(text)\n",
|
| 334 |
+
" \n",
|
| 335 |
+
" # 只有置信度高才显示绿色,否则显示黄色\n",
|
| 336 |
+
" icon = \"✅\" if conf > 0.8 else \"🤔\"\n",
|
| 337 |
+
" print(f\"{icon} 预测结果: [{label}] \")\n",
|
| 338 |
+
" print(f\"📊 置信度: {conf*100:.2f}%\")\n",
|
| 339 |
+
"\n",
|
| 340 |
+
"btn_run.on_click(on_click)\n",
|
| 341 |
+
"display(text_box, btn_run, output_area)"
|
| 342 |
+
]
|
| 343 |
+
}
|
| 344 |
+
],
|
| 345 |
+
"metadata": {
|
| 346 |
+
"kernelspec": {
|
| 347 |
+
"display_name": "Python 3",
|
| 348 |
+
"language": "python",
|
| 349 |
+
"name": "python3"
|
| 350 |
+
},
|
| 351 |
+
"language_info": {
|
| 352 |
+
"codemirror_mode": {
|
| 353 |
+
"name": "ipython",
|
| 354 |
+
"version": 3
|
| 355 |
+
},
|
| 356 |
+
"file_extension": ".py",
|
| 357 |
+
"mimetype": "text/x-python",
|
| 358 |
+
"name": "python",
|
| 359 |
+
"nbconvert_exporter": "python",
|
| 360 |
+
"pygments_lexer": "ipython3",
|
| 361 |
+
"version": "3.12.0"
|
| 362 |
+
}
|
| 363 |
+
},
|
| 364 |
+
"nbformat": 4,
|
| 365 |
+
"nbformat_minor": 2
|
| 366 |
+
}
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
transformers>=4.30.0
|
| 2 |
+
datasets>=2.14.0
|
| 3 |
+
scikit-learn
|
| 4 |
+
pandas
|
| 5 |
+
accelerate>=0.21.0
|
| 6 |
+
tqdm
|
results/checkpoint-4000/config.json
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"BertForSequenceClassification"
|
| 4 |
+
],
|
| 5 |
+
"attention_probs_dropout_prob": 0.1,
|
| 6 |
+
"classifier_dropout": null,
|
| 7 |
+
"directionality": "bidi",
|
| 8 |
+
"dtype": "float32",
|
| 9 |
+
"hidden_act": "gelu",
|
| 10 |
+
"hidden_dropout_prob": 0.1,
|
| 11 |
+
"hidden_size": 768,
|
| 12 |
+
"id2label": {
|
| 13 |
+
"0": "negative",
|
| 14 |
+
"1": "neutral",
|
| 15 |
+
"2": "positive"
|
| 16 |
+
},
|
| 17 |
+
"initializer_range": 0.02,
|
| 18 |
+
"intermediate_size": 3072,
|
| 19 |
+
"label2id": {
|
| 20 |
+
"negative": 0,
|
| 21 |
+
"neutral": 1,
|
| 22 |
+
"positive": 2
|
| 23 |
+
},
|
| 24 |
+
"layer_norm_eps": 1e-12,
|
| 25 |
+
"max_position_embeddings": 512,
|
| 26 |
+
"model_type": "bert",
|
| 27 |
+
"num_attention_heads": 12,
|
| 28 |
+
"num_hidden_layers": 12,
|
| 29 |
+
"pad_token_id": 0,
|
| 30 |
+
"pooler_fc_size": 768,
|
| 31 |
+
"pooler_num_attention_heads": 12,
|
| 32 |
+
"pooler_num_fc_layers": 3,
|
| 33 |
+
"pooler_size_per_head": 128,
|
| 34 |
+
"pooler_type": "first_token_transform",
|
| 35 |
+
"position_embedding_type": "absolute",
|
| 36 |
+
"problem_type": "single_label_classification",
|
| 37 |
+
"transformers_version": "4.57.3",
|
| 38 |
+
"type_vocab_size": 2,
|
| 39 |
+
"use_cache": true,
|
| 40 |
+
"vocab_size": 21128
|
| 41 |
+
}
|
results/checkpoint-4000/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2411c50f0203761f7a239380c2d7dc58f6a204a1ca158d31c375b007aad25f5b
|
| 3 |
+
size 409103316
|
results/checkpoint-4000/optimizer.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:75c792b4789254d1d67ca82233869a041a59b9573dfac628ec4e04776278c4c6
|
| 3 |
+
size 818320969
|
results/checkpoint-4000/scheduler.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e041cf40b86819ae2811b72b3e119b9b56d39ebef1f35169420e513898c8bcbf
|
| 3 |
+
size 1453
|
results/checkpoint-4000/special_tokens_map.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cls_token": "[CLS]",
|
| 3 |
+
"mask_token": "[MASK]",
|
| 4 |
+
"pad_token": "[PAD]",
|
| 5 |
+
"sep_token": "[SEP]",
|
| 6 |
+
"unk_token": "[UNK]"
|
| 7 |
+
}
|
results/checkpoint-4000/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
results/checkpoint-4000/tokenizer_config.json
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"added_tokens_decoder": {
|
| 3 |
+
"0": {
|
| 4 |
+
"content": "[PAD]",
|
| 5 |
+
"lstrip": false,
|
| 6 |
+
"normalized": false,
|
| 7 |
+
"rstrip": false,
|
| 8 |
+
"single_word": false,
|
| 9 |
+
"special": true
|
| 10 |
+
},
|
| 11 |
+
"100": {
|
| 12 |
+
"content": "[UNK]",
|
| 13 |
+
"lstrip": false,
|
| 14 |
+
"normalized": false,
|
| 15 |
+
"rstrip": false,
|
| 16 |
+
"single_word": false,
|
| 17 |
+
"special": true
|
| 18 |
+
},
|
| 19 |
+
"101": {
|
| 20 |
+
"content": "[CLS]",
|
| 21 |
+
"lstrip": false,
|
| 22 |
+
"normalized": false,
|
| 23 |
+
"rstrip": false,
|
| 24 |
+
"single_word": false,
|
| 25 |
+
"special": true
|
| 26 |
+
},
|
| 27 |
+
"102": {
|
| 28 |
+
"content": "[SEP]",
|
| 29 |
+
"lstrip": false,
|
| 30 |
+
"normalized": false,
|
| 31 |
+
"rstrip": false,
|
| 32 |
+
"single_word": false,
|
| 33 |
+
"special": true
|
| 34 |
+
},
|
| 35 |
+
"103": {
|
| 36 |
+
"content": "[MASK]",
|
| 37 |
+
"lstrip": false,
|
| 38 |
+
"normalized": false,
|
| 39 |
+
"rstrip": false,
|
| 40 |
+
"single_word": false,
|
| 41 |
+
"special": true
|
| 42 |
+
}
|
| 43 |
+
},
|
| 44 |
+
"clean_up_tokenization_spaces": false,
|
| 45 |
+
"cls_token": "[CLS]",
|
| 46 |
+
"do_lower_case": false,
|
| 47 |
+
"extra_special_tokens": {},
|
| 48 |
+
"mask_token": "[MASK]",
|
| 49 |
+
"model_max_length": 512,
|
| 50 |
+
"pad_token": "[PAD]",
|
| 51 |
+
"sep_token": "[SEP]",
|
| 52 |
+
"strip_accents": null,
|
| 53 |
+
"tokenize_chinese_chars": true,
|
| 54 |
+
"tokenizer_class": "BertTokenizer",
|
| 55 |
+
"unk_token": "[UNK]"
|
| 56 |
+
}
|
results/checkpoint-4000/trainer_state.json
ADDED
|
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"best_global_step": 3500,
|
| 3 |
+
"best_metric": 0.774823898413337,
|
| 4 |
+
"best_model_checkpoint": "/Users/wangyiqiu/Desktop/program/\u795e\u7ecf\u7f51\u7edc\u62d3\u6251/results/checkpoint-3500",
|
| 5 |
+
"epoch": 0.2526847757422615,
|
| 6 |
+
"eval_steps": 500,
|
| 7 |
+
"global_step": 4000,
|
| 8 |
+
"is_hyper_param_search": false,
|
| 9 |
+
"is_local_process_zero": true,
|
| 10 |
+
"is_world_process_zero": true,
|
| 11 |
+
"log_history": [
|
| 12 |
+
{
|
| 13 |
+
"epoch": 0.006317119393556538,
|
| 14 |
+
"grad_norm": 13.061114311218262,
|
| 15 |
+
"learning_rate": 4.169298799747316e-07,
|
| 16 |
+
"loss": 1.354,
|
| 17 |
+
"step": 100
|
| 18 |
+
},
|
| 19 |
+
{
|
| 20 |
+
"epoch": 0.012634238787113077,
|
| 21 |
+
"grad_norm": 13.682186126708984,
|
| 22 |
+
"learning_rate": 8.380711728785009e-07,
|
| 23 |
+
"loss": 1.0853,
|
| 24 |
+
"step": 200
|
| 25 |
+
},
|
| 26 |
+
{
|
| 27 |
+
"epoch": 0.018951358180669616,
|
| 28 |
+
"grad_norm": 4.851679801940918,
|
| 29 |
+
"learning_rate": 1.2592124657822702e-06,
|
| 30 |
+
"loss": 0.9111,
|
| 31 |
+
"step": 300
|
| 32 |
+
},
|
| 33 |
+
{
|
| 34 |
+
"epoch": 0.025268477574226154,
|
| 35 |
+
"grad_norm": 5.82253360748291,
|
| 36 |
+
"learning_rate": 1.6803537586860393e-06,
|
| 37 |
+
"loss": 0.7179,
|
| 38 |
+
"step": 400
|
| 39 |
+
},
|
| 40 |
+
{
|
| 41 |
+
"epoch": 0.03158559696778269,
|
| 42 |
+
"grad_norm": 5.032683372497559,
|
| 43 |
+
"learning_rate": 2.1014950515898086e-06,
|
| 44 |
+
"loss": 0.6422,
|
| 45 |
+
"step": 500
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"epoch": 0.03158559696778269,
|
| 49 |
+
"eval_accuracy": 0.7368075050637859,
|
| 50 |
+
"eval_f1": 0.7170832086299176,
|
| 51 |
+
"eval_loss": 0.6070035696029663,
|
| 52 |
+
"eval_precision": 0.7218199142709759,
|
| 53 |
+
"eval_recall": 0.7368075050637859,
|
| 54 |
+
"eval_runtime": 582.5178,
|
| 55 |
+
"eval_samples_per_second": 96.619,
|
| 56 |
+
"eval_steps_per_second": 3.02,
|
| 57 |
+
"step": 500
|
| 58 |
+
},
|
| 59 |
+
{
|
| 60 |
+
"epoch": 0.03790271636133923,
|
| 61 |
+
"grad_norm": 7.424877166748047,
|
| 62 |
+
"learning_rate": 2.5226363444935774e-06,
|
| 63 |
+
"loss": 0.6155,
|
| 64 |
+
"step": 600
|
| 65 |
+
},
|
| 66 |
+
{
|
| 67 |
+
"epoch": 0.04421983575489577,
|
| 68 |
+
"grad_norm": 16.976255416870117,
|
| 69 |
+
"learning_rate": 2.943777637397347e-06,
|
| 70 |
+
"loss": 0.5944,
|
| 71 |
+
"step": 700
|
| 72 |
+
},
|
| 73 |
+
{
|
| 74 |
+
"epoch": 0.05053695514845231,
|
| 75 |
+
"grad_norm": 9.103567123413086,
|
| 76 |
+
"learning_rate": 3.3649189303011164e-06,
|
| 77 |
+
"loss": 0.5812,
|
| 78 |
+
"step": 800
|
| 79 |
+
},
|
| 80 |
+
{
|
| 81 |
+
"epoch": 0.056854074542008845,
|
| 82 |
+
"grad_norm": 7.061375617980957,
|
| 83 |
+
"learning_rate": 3.7860602232048853e-06,
|
| 84 |
+
"loss": 0.5965,
|
| 85 |
+
"step": 900
|
| 86 |
+
},
|
| 87 |
+
{
|
| 88 |
+
"epoch": 0.06317119393556538,
|
| 89 |
+
"grad_norm": 6.224503040313721,
|
| 90 |
+
"learning_rate": 4.207201516108655e-06,
|
| 91 |
+
"loss": 0.5553,
|
| 92 |
+
"step": 1000
|
| 93 |
+
},
|
| 94 |
+
{
|
| 95 |
+
"epoch": 0.06317119393556538,
|
| 96 |
+
"eval_accuracy": 0.7581642443409972,
|
| 97 |
+
"eval_f1": 0.7448374295446439,
|
| 98 |
+
"eval_loss": 0.5610596537590027,
|
| 99 |
+
"eval_precision": 0.7461287482946488,
|
| 100 |
+
"eval_recall": 0.7581642443409972,
|
| 101 |
+
"eval_runtime": 584.5541,
|
| 102 |
+
"eval_samples_per_second": 96.282,
|
| 103 |
+
"eval_steps_per_second": 3.009,
|
| 104 |
+
"step": 1000
|
| 105 |
+
},
|
| 106 |
+
{
|
| 107 |
+
"epoch": 0.06948831332912192,
|
| 108 |
+
"grad_norm": 6.321476459503174,
|
| 109 |
+
"learning_rate": 4.628342809012423e-06,
|
| 110 |
+
"loss": 0.592,
|
| 111 |
+
"step": 1100
|
| 112 |
+
},
|
| 113 |
+
{
|
| 114 |
+
"epoch": 0.07580543272267846,
|
| 115 |
+
"grad_norm": 8.201200485229492,
|
| 116 |
+
"learning_rate": 5.0494841019161935e-06,
|
| 117 |
+
"loss": 0.5518,
|
| 118 |
+
"step": 1200
|
| 119 |
+
},
|
| 120 |
+
{
|
| 121 |
+
"epoch": 0.082122552116235,
|
| 122 |
+
"grad_norm": 6.514477729797363,
|
| 123 |
+
"learning_rate": 5.470625394819963e-06,
|
| 124 |
+
"loss": 0.5897,
|
| 125 |
+
"step": 1300
|
| 126 |
+
},
|
| 127 |
+
{
|
| 128 |
+
"epoch": 0.08843967150979154,
|
| 129 |
+
"grad_norm": 8.077017784118652,
|
| 130 |
+
"learning_rate": 5.891766687723732e-06,
|
| 131 |
+
"loss": 0.5476,
|
| 132 |
+
"step": 1400
|
| 133 |
+
},
|
| 134 |
+
{
|
| 135 |
+
"epoch": 0.09475679090334807,
|
| 136 |
+
"grad_norm": 9.256704330444336,
|
| 137 |
+
"learning_rate": 6.3129079806275005e-06,
|
| 138 |
+
"loss": 0.5263,
|
| 139 |
+
"step": 1500
|
| 140 |
+
},
|
| 141 |
+
{
|
| 142 |
+
"epoch": 0.09475679090334807,
|
| 143 |
+
"eval_accuracy": 0.7675278064034683,
|
| 144 |
+
"eval_f1": 0.7632915279870514,
|
| 145 |
+
"eval_loss": 0.5426821112632751,
|
| 146 |
+
"eval_precision": 0.760979358962669,
|
| 147 |
+
"eval_recall": 0.7675278064034683,
|
| 148 |
+
"eval_runtime": 587.2504,
|
| 149 |
+
"eval_samples_per_second": 95.84,
|
| 150 |
+
"eval_steps_per_second": 2.995,
|
| 151 |
+
"step": 1500
|
| 152 |
+
},
|
| 153 |
+
{
|
| 154 |
+
"epoch": 0.10107391029690461,
|
| 155 |
+
"grad_norm": 6.117814064025879,
|
| 156 |
+
"learning_rate": 6.73404927353127e-06,
|
| 157 |
+
"loss": 0.5563,
|
| 158 |
+
"step": 1600
|
| 159 |
+
},
|
| 160 |
+
{
|
| 161 |
+
"epoch": 0.10739102969046115,
|
| 162 |
+
"grad_norm": 9.015992164611816,
|
| 163 |
+
"learning_rate": 7.15519056643504e-06,
|
| 164 |
+
"loss": 0.5622,
|
| 165 |
+
"step": 1700
|
| 166 |
+
},
|
| 167 |
+
{
|
| 168 |
+
"epoch": 0.11370814908401769,
|
| 169 |
+
"grad_norm": 8.684099197387695,
|
| 170 |
+
"learning_rate": 7.576331859338809e-06,
|
| 171 |
+
"loss": 0.5483,
|
| 172 |
+
"step": 1800
|
| 173 |
+
},
|
| 174 |
+
{
|
| 175 |
+
"epoch": 0.12002526847757422,
|
| 176 |
+
"grad_norm": 5.517951488494873,
|
| 177 |
+
"learning_rate": 7.997473152242578e-06,
|
| 178 |
+
"loss": 0.5467,
|
| 179 |
+
"step": 1900
|
| 180 |
+
},
|
| 181 |
+
{
|
| 182 |
+
"epoch": 0.12634238787113075,
|
| 183 |
+
"grad_norm": 4.840009689331055,
|
| 184 |
+
"learning_rate": 8.418614445146347e-06,
|
| 185 |
+
"loss": 0.5472,
|
| 186 |
+
"step": 2000
|
| 187 |
+
},
|
| 188 |
+
{
|
| 189 |
+
"epoch": 0.12634238787113075,
|
| 190 |
+
"eval_accuracy": 0.7682740485412743,
|
| 191 |
+
"eval_f1": 0.7644619158467771,
|
| 192 |
+
"eval_loss": 0.5479554533958435,
|
| 193 |
+
"eval_precision": 0.7616941910129872,
|
| 194 |
+
"eval_recall": 0.7682740485412743,
|
| 195 |
+
"eval_runtime": 594.3974,
|
| 196 |
+
"eval_samples_per_second": 94.687,
|
| 197 |
+
"eval_steps_per_second": 2.959,
|
| 198 |
+
"step": 2000
|
| 199 |
+
},
|
| 200 |
+
{
|
| 201 |
+
"epoch": 0.1326595072646873,
|
| 202 |
+
"grad_norm": 9.188036918640137,
|
| 203 |
+
"learning_rate": 8.839755738050117e-06,
|
| 204 |
+
"loss": 0.5436,
|
| 205 |
+
"step": 2100
|
| 206 |
+
},
|
| 207 |
+
{
|
| 208 |
+
"epoch": 0.13897662665824384,
|
| 209 |
+
"grad_norm": 5.845507621765137,
|
| 210 |
+
"learning_rate": 9.260897030953885e-06,
|
| 211 |
+
"loss": 0.5684,
|
| 212 |
+
"step": 2200
|
| 213 |
+
},
|
| 214 |
+
{
|
| 215 |
+
"epoch": 0.14529374605180037,
|
| 216 |
+
"grad_norm": 6.014614105224609,
|
| 217 |
+
"learning_rate": 9.682038323857656e-06,
|
| 218 |
+
"loss": 0.5268,
|
| 219 |
+
"step": 2300
|
| 220 |
+
},
|
| 221 |
+
{
|
| 222 |
+
"epoch": 0.15161086544535693,
|
| 223 |
+
"grad_norm": 5.183818817138672,
|
| 224 |
+
"learning_rate": 1.0103179616761426e-05,
|
| 225 |
+
"loss": 0.5505,
|
| 226 |
+
"step": 2400
|
| 227 |
+
},
|
| 228 |
+
{
|
| 229 |
+
"epoch": 0.15792798483891346,
|
| 230 |
+
"grad_norm": 4.270262718200684,
|
| 231 |
+
"learning_rate": 1.0524320909665192e-05,
|
| 232 |
+
"loss": 0.5327,
|
| 233 |
+
"step": 2500
|
| 234 |
+
},
|
| 235 |
+
{
|
| 236 |
+
"epoch": 0.15792798483891346,
|
| 237 |
+
"eval_accuracy": 0.7718631178707225,
|
| 238 |
+
"eval_f1": 0.7701652961241094,
|
| 239 |
+
"eval_loss": 0.538950502872467,
|
| 240 |
+
"eval_precision": 0.7692113501499637,
|
| 241 |
+
"eval_recall": 0.7718631178707225,
|
| 242 |
+
"eval_runtime": 598.0361,
|
| 243 |
+
"eval_samples_per_second": 94.111,
|
| 244 |
+
"eval_steps_per_second": 2.941,
|
| 245 |
+
"step": 2500
|
| 246 |
+
},
|
| 247 |
+
{
|
| 248 |
+
"epoch": 0.16424510423247,
|
| 249 |
+
"grad_norm": 6.861387729644775,
|
| 250 |
+
"learning_rate": 1.0945462202568964e-05,
|
| 251 |
+
"loss": 0.5301,
|
| 252 |
+
"step": 2600
|
| 253 |
+
},
|
| 254 |
+
{
|
| 255 |
+
"epoch": 0.17056222362602652,
|
| 256 |
+
"grad_norm": 7.5304670333862305,
|
| 257 |
+
"learning_rate": 1.1366603495472733e-05,
|
| 258 |
+
"loss": 0.5254,
|
| 259 |
+
"step": 2700
|
| 260 |
+
},
|
| 261 |
+
{
|
| 262 |
+
"epoch": 0.17687934301958308,
|
| 263 |
+
"grad_norm": 5.88840913772583,
|
| 264 |
+
"learning_rate": 1.1787744788376501e-05,
|
| 265 |
+
"loss": 0.5387,
|
| 266 |
+
"step": 2800
|
| 267 |
+
},
|
| 268 |
+
{
|
| 269 |
+
"epoch": 0.1831964624131396,
|
| 270 |
+
"grad_norm": 6.836195945739746,
|
| 271 |
+
"learning_rate": 1.2208886081280271e-05,
|
| 272 |
+
"loss": 0.5235,
|
| 273 |
+
"step": 2900
|
| 274 |
+
},
|
| 275 |
+
{
|
| 276 |
+
"epoch": 0.18951358180669614,
|
| 277 |
+
"grad_norm": 4.248595237731934,
|
| 278 |
+
"learning_rate": 1.263002737418404e-05,
|
| 279 |
+
"loss": 0.5342,
|
| 280 |
+
"step": 3000
|
| 281 |
+
},
|
| 282 |
+
{
|
| 283 |
+
"epoch": 0.18951358180669614,
|
| 284 |
+
"eval_accuracy": 0.7746348743825735,
|
| 285 |
+
"eval_f1": 0.7710043344887744,
|
| 286 |
+
"eval_loss": 0.5276312828063965,
|
| 287 |
+
"eval_precision": 0.7689047947812672,
|
| 288 |
+
"eval_recall": 0.7746348743825735,
|
| 289 |
+
"eval_runtime": 599.7353,
|
| 290 |
+
"eval_samples_per_second": 93.845,
|
| 291 |
+
"eval_steps_per_second": 2.933,
|
| 292 |
+
"step": 3000
|
| 293 |
+
},
|
| 294 |
+
{
|
| 295 |
+
"epoch": 0.19583070120025267,
|
| 296 |
+
"grad_norm": 6.620116710662842,
|
| 297 |
+
"learning_rate": 1.3051168667087808e-05,
|
| 298 |
+
"loss": 0.5432,
|
| 299 |
+
"step": 3100
|
| 300 |
+
},
|
| 301 |
+
{
|
| 302 |
+
"epoch": 0.20214782059380923,
|
| 303 |
+
"grad_norm": 4.005882740020752,
|
| 304 |
+
"learning_rate": 1.3472309959991578e-05,
|
| 305 |
+
"loss": 0.5201,
|
| 306 |
+
"step": 3200
|
| 307 |
+
},
|
| 308 |
+
{
|
| 309 |
+
"epoch": 0.20846493998736576,
|
| 310 |
+
"grad_norm": 3.873512029647827,
|
| 311 |
+
"learning_rate": 1.3893451252895347e-05,
|
| 312 |
+
"loss": 0.5418,
|
| 313 |
+
"step": 3300
|
| 314 |
+
},
|
| 315 |
+
{
|
| 316 |
+
"epoch": 0.2147820593809223,
|
| 317 |
+
"grad_norm": 4.081575870513916,
|
| 318 |
+
"learning_rate": 1.4314592545799117e-05,
|
| 319 |
+
"loss": 0.5298,
|
| 320 |
+
"step": 3400
|
| 321 |
+
},
|
| 322 |
+
{
|
| 323 |
+
"epoch": 0.22109917877447885,
|
| 324 |
+
"grad_norm": 4.8460187911987305,
|
| 325 |
+
"learning_rate": 1.4735733838702885e-05,
|
| 326 |
+
"loss": 0.5397,
|
| 327 |
+
"step": 3500
|
| 328 |
+
},
|
| 329 |
+
{
|
| 330 |
+
"epoch": 0.22109917877447885,
|
| 331 |
+
"eval_accuracy": 0.7759319142887602,
|
| 332 |
+
"eval_f1": 0.774823898413337,
|
| 333 |
+
"eval_loss": 0.5257604718208313,
|
| 334 |
+
"eval_precision": 0.7763093994740736,
|
| 335 |
+
"eval_recall": 0.7759319142887602,
|
| 336 |
+
"eval_runtime": 594.4851,
|
| 337 |
+
"eval_samples_per_second": 94.674,
|
| 338 |
+
"eval_steps_per_second": 2.959,
|
| 339 |
+
"step": 3500
|
| 340 |
+
},
|
| 341 |
+
{
|
| 342 |
+
"epoch": 0.22741629816803538,
|
| 343 |
+
"grad_norm": 6.513636589050293,
|
| 344 |
+
"learning_rate": 1.5156875131606654e-05,
|
| 345 |
+
"loss": 0.5385,
|
| 346 |
+
"step": 3600
|
| 347 |
+
},
|
| 348 |
+
{
|
| 349 |
+
"epoch": 0.2337334175615919,
|
| 350 |
+
"grad_norm": 3.679028272628784,
|
| 351 |
+
"learning_rate": 1.5578016424510425e-05,
|
| 352 |
+
"loss": 0.535,
|
| 353 |
+
"step": 3700
|
| 354 |
+
},
|
| 355 |
+
{
|
| 356 |
+
"epoch": 0.24005053695514844,
|
| 357 |
+
"grad_norm": 4.075804233551025,
|
| 358 |
+
"learning_rate": 1.5999157717414192e-05,
|
| 359 |
+
"loss": 0.5328,
|
| 360 |
+
"step": 3800
|
| 361 |
+
},
|
| 362 |
+
{
|
| 363 |
+
"epoch": 0.246367656348705,
|
| 364 |
+
"grad_norm": 5.875431060791016,
|
| 365 |
+
"learning_rate": 1.6420299010317962e-05,
|
| 366 |
+
"loss": 0.5185,
|
| 367 |
+
"step": 3900
|
| 368 |
+
},
|
| 369 |
+
{
|
| 370 |
+
"epoch": 0.2526847757422615,
|
| 371 |
+
"grad_norm": 4.358110427856445,
|
| 372 |
+
"learning_rate": 1.6841440303221732e-05,
|
| 373 |
+
"loss": 0.5258,
|
| 374 |
+
"step": 4000
|
| 375 |
+
},
|
| 376 |
+
{
|
| 377 |
+
"epoch": 0.2526847757422615,
|
| 378 |
+
"eval_accuracy": 0.7764471767172453,
|
| 379 |
+
"eval_f1": 0.7730067867423673,
|
| 380 |
+
"eval_loss": 0.5273372530937195,
|
| 381 |
+
"eval_precision": 0.7717055148059055,
|
| 382 |
+
"eval_recall": 0.7764471767172453,
|
| 383 |
+
"eval_runtime": 619.3303,
|
| 384 |
+
"eval_samples_per_second": 90.876,
|
| 385 |
+
"eval_steps_per_second": 2.84,
|
| 386 |
+
"step": 4000
|
| 387 |
+
}
|
| 388 |
+
],
|
| 389 |
+
"logging_steps": 100,
|
| 390 |
+
"max_steps": 47490,
|
| 391 |
+
"num_input_tokens_seen": 0,
|
| 392 |
+
"num_train_epochs": 3,
|
| 393 |
+
"save_steps": 500,
|
| 394 |
+
"stateful_callbacks": {
|
| 395 |
+
"TrainerControl": {
|
| 396 |
+
"args": {
|
| 397 |
+
"should_epoch_stop": false,
|
| 398 |
+
"should_evaluate": false,
|
| 399 |
+
"should_log": false,
|
| 400 |
+
"should_save": true,
|
| 401 |
+
"should_training_stop": false
|
| 402 |
+
},
|
| 403 |
+
"attributes": {}
|
| 404 |
+
}
|
| 405 |
+
},
|
| 406 |
+
"total_flos": 8419629367296000.0,
|
| 407 |
+
"train_batch_size": 32,
|
| 408 |
+
"trial_name": null,
|
| 409 |
+
"trial_params": null
|
| 410 |
+
}
|
results/checkpoint-4000/training_args.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:88724115c05a14013e1bf5182b6efa00270cfee7c30485da6eb058c6d09f75a8
|
| 3 |
+
size 5805
|
results/checkpoint-4000/vocab.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
results/images/data_distribution_2025-12-18_15-27-36.png
ADDED
|
results/images/metrics_2025-12-18_15-06-59.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Timestamp: 2025-12-18_15-06-59
|
| 2 |
+
Final Validation Accuracy: 0.7683
|
| 3 |
+
Final Validation Loss: 0.5479554533958435
|
| 4 |
+
Plot saved to: training_metrics_2025-12-18_15-06-59.png
|
results/images/metrics_2025-12-18_15-19-18.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Timestamp: 2025-12-18_15-19-18
|
| 2 |
+
Final Validation Accuracy: 0.7719
|
| 3 |
+
Final Validation Loss: 0.538950502872467
|
| 4 |
+
Plot saved to: training_metrics_2025-12-18_15-19-18.png
|
results/images/metrics_2025-12-18_15-25-36.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Timestamp: 2025-12-18_15-25-36
|
| 2 |
+
Final Validation Accuracy: 0.7719
|
| 3 |
+
Final Validation Loss: 0.538950502872467
|
| 4 |
+
Plot saved to: training_metrics_2025-12-18_15-25-36.png
|
results/images/metrics_2025-12-18_15-27-41.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Timestamp: 2025-12-18_15-27-41
|
| 2 |
+
Final Validation Accuracy: 0.7746
|
| 3 |
+
Final Validation Loss: 0.5276312828063965
|
| 4 |
+
Plot saved to: training_metrics_2025-12-18_15-27-41.png
|
results/images/training_metrics_2025-12-18_15-06-59.png
ADDED
|
results/images/training_metrics_2025-12-18_15-19-18.png
ADDED
|
results/images/training_metrics_2025-12-18_15-25-36.png
ADDED
|
results/images/training_metrics_2025-12-18_15-27-41.png
ADDED
|
src/__init__.py
ADDED
|
File without changes
|
src/config.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os

class Config:
    """Central configuration for the Chinese sentiment-analysis project."""

    # --- Path configuration ---
    # Project root = parent of the directory containing this file (src/..).
    ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    DATA_DIR = os.path.join(ROOT_DIR, 'data')
    CHECKPOINT_DIR = os.path.join(ROOT_DIR, 'checkpoints')
    RESULTS_DIR = os.path.join(ROOT_DIR, 'results')
    OUTPUT_DIR = CHECKPOINT_DIR  # Alias for compatibility

    # --- Model configuration ---
    BASE_MODEL = "google-bert/bert-base-chinese"
    NUM_LABELS = 3       # negative / neutral / positive
    MAX_LENGTH = 128     # tokenizer truncation/padding length

    # --- Training configuration ---
    BATCH_SIZE = 32
    LEARNING_RATE = 2e-5
    NUM_EPOCHS = 3
    WARMUP_RATIO = 0.1
    WEIGHT_DECAY = 0.01
    LOGGING_STEPS = 100
    SAVE_STEPS = 500
    EVAL_STEPS = 500

    # --- Label mapping ---
    LABEL2ID = {'negative': 0, 'neutral': 1, 'positive': 2}
    ID2LABEL = {0: 'negative', 1: 'neutral', 2: 'positive'}
|
src/dataset.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pandas as pd
|
| 3 |
+
from datasets import load_dataset, Dataset, concatenate_datasets, load_from_disk
|
| 4 |
+
from .config import Config
|
| 5 |
+
|
| 6 |
+
class DataProcessor:
    """Loads, cleans, merges and tokenizes the two Chinese sentiment corpora.

    Produces a DatasetDict with 'train'/'test' splits ready for a HF Trainer:
    tokenized model inputs plus an integer 'labels' column (0/1/2).
    """

    def __init__(self, tokenizer):
        # Tokenizer is injected so training and prediction share one instance.
        self.tokenizer = tokenizer

    def load_clap_data(self):
        """Load the Chinese portion of clapAI/MultiLingualSentiment."""
        print("Loading clapAI/MultiLingualSentiment (zh)...")
        try:
            # Prefer the 'zh' config if the dataset exposes one.
            ds = load_dataset("clapAI/MultiLingualSentiment", "zh", split="train", trust_remote_code=True)
        except Exception:
            # Fallback: load the generic config and filter down to Chinese rows.
            print("Warning: Could not load 'zh' specific config, attempting to load generic...")
            ds = load_dataset("clapAI/MultiLingualSentiment", split="train", trust_remote_code=True)
            ds = ds.filter(lambda x: x['language'] == 'zh')
        return ds

    def load_medical_data(self):
        """Load the OpenModels/Chinese-Herbal-Medicine-Sentiment domain dataset."""
        print("Loading OpenModels/Chinese-Herbal-Medicine-Sentiment...")
        ds = load_dataset("OpenModels/Chinese-Herbal-Medicine-Sentiment", split="train", trust_remote_code=True)
        return ds

    def clean_data(self, examples):
        """Filter predicate: return False for rows that should be dropped."""
        text = examples['text']

        # Drop "default positive review" placeholder noise.
        if "此用户未填写评价内容" in text:
            return False

        # Drop near-empty texts that carry no usable sentiment signal.
        if len(text.strip()) < 2:
            return False

        return True

    def unify_labels(self, example):
        """Normalize labels to: 0 (Negative), 1 (Neutral), 2 (Positive).

        BUGFIX: the original mapping placed 'pos' in the negative bucket and
        'neg' in the positive bucket, inverting polarity for abbreviated
        string labels.
        """
        label = example['label']

        if isinstance(label, str):
            label = label.lower()
            if label in ('negative', 'neg', '0'):
                return {'labels': 0}
            elif label in ('neutral', 'neu', '1'):
                return {'labels': 1}
            elif label in ('positive', 'pos', '2'):
                return {'labels': 2}

        # Already numeric: assumed to be in 0-2 — TODO confirm upstream range.
        return {'labels': int(label)}

    def tokenize_function(self, examples):
        """Batch-tokenize 'text' into fixed-length model inputs."""
        return self.tokenizer(
            examples['text'],
            padding="max_length",
            truncation=True,
            max_length=Config.MAX_LENGTH
        )

    def get_processed_dataset(self, cache_dir=None, num_proc=1):
        """Return a tokenized train/test DatasetDict, using a disk cache.

        num_proc is now forwarded to the heavy filter/map steps (it was
        previously accepted but ignored).
        """
        if cache_dir is None:
            cache_dir = Config.DATA_DIR

        # 0. Reuse a previously processed dataset when one is on disk.
        processed_path = os.path.join(cache_dir, "processed_dataset")
        if os.path.exists(processed_path):
            print(f"Loading processed dataset from {processed_path}...")
            return load_from_disk(processed_path)

        # 1. Load both corpora.
        ds_clap = self.load_clap_data()
        ds_med = self.load_medical_data()

        # 2. Normalize column names so both datasets expose 'text' and 'label'.
        if 'review_text' in ds_med.column_names:
            ds_med = ds_med.rename_column('review_text', 'text')
        if 'sentiment_label' in ds_med.column_names:
            ds_med = ds_med.rename_column('sentiment_label', 'label')

        # 3. Clean both datasets.
        print("Cleaning datasets...")
        ds_med = ds_med.filter(self.clean_data, num_proc=num_proc)
        ds_clap = ds_clap.filter(self.clean_data, num_proc=num_proc)

        # 4. Merge on the shared columns only, so features line up.
        common_cols = ['text', 'label']
        ds_clap = ds_clap.select_columns(common_cols)
        ds_med = ds_med.select_columns(common_cols)

        combined_ds = concatenate_datasets([ds_clap, ds_med])

        # 5. Label normalization ('label' -> 'labels') and tokenization.
        combined_ds = combined_ds.map(self.unify_labels, remove_columns=['label'], num_proc=num_proc)

        tokenized_ds = combined_ds.map(
            self.tokenize_function,
            batched=True,
            remove_columns=['text'],
            num_proc=num_proc
        )

        # 6. Split into train/validation sets.
        split_ds = tokenized_ds.train_test_split(test_size=0.1)

        return split_ds
|
src/debug_paths.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import glob
from config import Config

# Quick diagnostic script: where are we, and which checkpoints can we see?
print(f"Current Working Directory: {os.getcwd()}")
print(f"Config.RESULTS_DIR: {Config.RESULTS_DIR}")

# Debug Finding Checkpoints
found = glob.glob(os.path.join(Config.RESULTS_DIR, "checkpoint-*"))
print(f"Found {len(found)} candidates:")
for path in found:
    print(f" - {path}")

if not found:
    # Try relative path manual
    print("Trying relative path './results/checkpoint-*'...")
    found = glob.glob("./results/checkpoint-*")
    print(f"Found {len(found)} candidates via relative:")
    for path in found:
        print(f" - {path}")
|
src/metrics.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

def compute_metrics(pred):
    """Compute accuracy and weighted precision/recall/F1 for a HF Trainer
    EvalPrediction object."""
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(-1)

    acc = accuracy_score(y_true, y_pred)
    prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='weighted')

    return {
        'accuracy': acc,
        'f1': f1,
        'precision': prec,
        'recall': rec,
    }
|
src/monitor.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
import json
|
| 4 |
+
import glob
|
| 5 |
+
import pandas as pd
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
|
| 8 |
+
def get_latest_checkpoint(checkpoint_dir):
    """Return the most recently modified checkpoint-* folder, or None."""
    candidates = glob.glob(os.path.join(checkpoint_dir, "checkpoint-*"))
    if not candidates:
        return None
    # Newest by modification time.
    return sorted(candidates, key=os.path.getmtime)[-1]
|
| 16 |
+
|
| 17 |
+
def read_metrics(checkpoint_path):
    """Return the trainer log_history list from a checkpoint dir, or None.

    Returns None when the state file is missing, unreadable, or not valid
    JSON (e.g. the trainer is mid-write).
    """
    state_file = os.path.join(checkpoint_path, "trainer_state.json")
    if not os.path.exists(state_file):
        return None

    try:
        with open(state_file, 'r') as f:
            data = json.load(f)
        return data.get("log_history", [])
    except (OSError, ValueError):
        # BUGFIX: narrowed from a bare `except:` that also swallowed
        # KeyboardInterrupt/SystemExit. json.JSONDecodeError is a ValueError.
        return None
|
| 28 |
+
|
| 29 |
+
def monitor(checkpoint_dir="checkpoints"):
    """Poll checkpoint_dir every 10 seconds and print the latest metrics.

    Runs forever; stop with Ctrl+C. Scans the newest checkpoint's
    trainer_state.json for the most recent training-loss and eval entries.
    """
    print(f"👀 开始监视训练目录: {checkpoint_dir}")
    print("按 Ctrl+C 退出监视")
    print("-" * 50)

    last_step = -1

    while True:
        latest_ckpt = get_latest_checkpoint(checkpoint_dir)
        if latest_ckpt:
            folder_name = os.path.basename(latest_ckpt)
            logs = read_metrics(latest_ckpt)

            if logs:
                current_step = logs[-1].get('step', 0)

                # Only print when something new was logged.
                if current_step != last_step:
                    timestamp = datetime.now().strftime("%H:%M:%S")

                    # log_history mixes training-loss and eval entries; scan
                    # backwards for the most recent record of each kind.
                    eval_record = None
                    train_record = None
                    for log in reversed(logs):
                        if 'eval_accuracy' in log and eval_record is None:
                            eval_record = log
                        if 'loss' in log and train_record is None:
                            train_record = log
                        if eval_record and train_record:
                            break

                    # BUGFIX: the original applied ':.4f' directly to the
                    # .get(..., 'N/A') fallback, which raises ValueError when
                    # a key (e.g. 'epoch') is absent. Format only numbers.
                    def _fmt(value, spec):
                        return format(value, spec) if isinstance(value, (int, float)) else "N/A"

                    print(f"[{timestamp}] 最新检查点: {folder_name}")
                    if train_record:
                        print(f" 📉 Training Loss: {_fmt(train_record.get('loss'), '.4f')} (Epoch {_fmt(train_record.get('epoch'), '.2f')})")
                    if eval_record:
                        print(f" ✅ Eval Accuracy: {_fmt(eval_record.get('eval_accuracy'), '.4f')}")
                        print(f" ✅ Eval F1 Score: {_fmt(eval_record.get('eval_f1'), '.4f')}")
                    print("-" * 50)

                    last_step = current_step

        time.sleep(10)  # Poll every 10 seconds.
|
| 76 |
+
|
| 77 |
+
if __name__ == "__main__":
    # Prefer the configured checkpoint dir; fall back to the local default.
    try:
        from config import Config
        ckpt_dir = Config.CHECKPOINT_DIR
    except ImportError:
        # BUGFIX: narrowed from a bare `except:` — only a missing config
        # module should trigger the fallback, not arbitrary errors.
        ckpt_dir = "checkpoints"

    monitor(ckpt_dir)
|
src/predict.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import os
|
| 3 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 4 |
+
from .config import Config
|
| 5 |
+
|
| 6 |
+
class SentimentPredictor:
    """Loads a fine-tuned checkpoint and classifies Chinese text sentiment."""

    def __init__(self, model_path=None):
        # Resolve a model path automatically when none is given.
        if model_path is None:
            # A fully exported model lives in Config.CHECKPOINT_DIR after
            # training completes.
            if os.path.exists(os.path.join(Config.CHECKPOINT_DIR, "config.json")):
                model_path = Config.CHECKPOINT_DIR
            else:
                # Otherwise fall back to the newest training checkpoint in
                # the results directory.
                import glob
                checkpoints = glob.glob(os.path.join(Config.RESULTS_DIR, "checkpoint-*"))
                if checkpoints:
                    checkpoints.sort(key=os.path.getmtime)
                    model_path = checkpoints[-1]
                    print(f"Using latest checkpoint found: {model_path}")
                else:
                    # Last resort: let the loader below raise/fallback.
                    model_path = Config.CHECKPOINT_DIR

        print(f"Loading model from {model_path}...")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        except OSError:
            print(f"Warning: Model not found at {model_path}. Loading base model for demo purpose.")
            self.tokenizer = AutoTokenizer.from_pretrained(Config.BASE_MODEL)
            self.model = AutoModelForSequenceClassification.from_pretrained(Config.BASE_MODEL, num_labels=Config.NUM_LABELS)

        # Pick the best available device: Apple MPS > CUDA > CPU.
        if torch.backends.mps.is_available():
            self.device = torch.device("mps")
        elif torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        self.model.to(self.device)
        self.model.eval()

    def predict(self, text):
        """Return {'text', 'sentiment', 'confidence'} for a single string."""
        encoded = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=Config.MAX_LENGTH,
            padding=True
        )
        encoded = {name: tensor.to(self.device) for name, tensor in encoded.items()}

        with torch.no_grad():
            outputs = self.model(**encoded)
            probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
            idx = torch.argmax(probs, dim=-1).item()
            confidence = probs[0][idx].item()

        return {
            "text": text,
            "sentiment": Config.ID2LABEL.get(idx, "unknown"),
            "confidence": f"{confidence:.4f}"
        }
|
| 68 |
+
|
| 69 |
+
if __name__ == "__main__":
    # Quick demo over a few sample reviews.
    predictor = SentimentPredictor()
    samples = [
        "这家店的快递太慢了,而且东西味道很奇怪。",
        "非常不错,包装很精美,下次还会来买。",
        "感觉一般般吧,没有想象中那么好,但也还可以。"
    ]

    print("\nPredicting...")
    for sample in samples:
        outcome = predictor.predict(sample)
        print(f"Text: {outcome['text']}")
        print(f"Sentiment: {outcome['sentiment']} (Confidence: {outcome['confidence']})")
        print("-" * 30)
|
src/prepare_data.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
from transformers import AutoTokenizer
|
| 4 |
+
from .config import Config
|
| 5 |
+
from .dataset import DataProcessor
|
| 6 |
+
|
| 7 |
+
def main():
    """Download both corpora, tokenize them, and persist the result to disk."""
    print("⏳ 开始下载并处理数据...")

    # 1. Make sure the data directory exists before writing to it.
    if not os.path.exists(Config.DATA_DIR):
        os.makedirs(Config.DATA_DIR)

    # 2. Build the processing pipeline.
    tokenizer = AutoTokenizer.from_pretrained(Config.BASE_MODEL)
    processor = DataProcessor(tokenizer)

    # 3. get_processed_dataset() short-circuits to a cached copy when one
    # exists on disk; what we persist here is the tokenized, training-ready
    # DatasetDict (not the raw corpora).
    dataset = processor.get_processed_dataset()

    save_path = os.path.join(Config.DATA_DIR, "processed_dataset")
    print(f"💾 正在保存处理后的数据集到: {save_path}")
    dataset.save_to_disk(save_path)

    print("✅ 数据保存完成!")
    print(f" Train set size: {len(dataset['train'])}")
    print(f" Test set size: {len(dataset['test'])}")
    print(" 下次加载可直接使用: from datasets import load_from_disk")

if __name__ == "__main__":
    main()
|
src/train.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from transformers import (
|
| 4 |
+
AutoTokenizer,
|
| 5 |
+
AutoModelForSequenceClassification,
|
| 6 |
+
TrainingArguments,
|
| 7 |
+
Trainer
|
| 8 |
+
)
|
| 9 |
+
from .config import Config
|
| 10 |
+
from .dataset import DataProcessor
|
| 11 |
+
from .metrics import compute_metrics
|
| 12 |
+
from .visualization import plot_training_history
|
| 13 |
+
|
| 14 |
+
def main():
    """End-to-end fine-tuning: data -> model -> train -> save -> plot."""
    # 0. Device detection (Mac Mini friendly: prefer MPS, then CUDA, then CPU).
    if torch.backends.mps.is_available():
        device = torch.device("mps")
        print(f"Using device: MPS (Mac Silicon Acceleration)")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"Using device: CUDA")
    else:
        device = torch.device("cpu")
        print(f"Using device: CPU")

    # 1. Tokenizer
    print(f"Loading tokenizer from {Config.BASE_MODEL}...")
    tokenizer = AutoTokenizer.from_pretrained(Config.BASE_MODEL)

    # 2. Data preparation. Uses Config.DATA_DIR as the cache location and
    # leaves one CPU free for the rest of the system during multi-process map.
    print("Preparing datasets...")
    processor = DataProcessor(tokenizer)
    worker_count = max(1, os.cpu_count() - 1)
    dataset = processor.get_processed_dataset(cache_dir=Config.DATA_DIR, num_proc=worker_count)

    train_dataset = dataset['train']
    eval_dataset = dataset['test']

    print(f"Training on {len(train_dataset)} samples, Validating on {len(eval_dataset)} samples.")

    # 3. Model
    print("Loading model...")
    model = AutoModelForSequenceClassification.from_pretrained(
        Config.BASE_MODEL,
        num_labels=Config.NUM_LABELS,
        id2label=Config.ID2LABEL,
        label2id=Config.LABEL2ID
    )
    model.to(device)

    # 4. Training arguments. Recent transformers versions route device
    # placement through accelerate automatically (MPS included), so no
    # explicit device flag is needed here.
    training_args = TrainingArguments(
        output_dir=Config.RESULTS_DIR,
        num_train_epochs=Config.NUM_EPOCHS,
        per_device_train_batch_size=Config.BATCH_SIZE,
        per_device_eval_batch_size=Config.BATCH_SIZE,
        learning_rate=Config.LEARNING_RATE,
        warmup_ratio=Config.WARMUP_RATIO,
        weight_decay=Config.WEIGHT_DECAY,
        logging_dir=os.path.join(Config.RESULTS_DIR, 'logs'),
        logging_steps=Config.LOGGING_STEPS,
        eval_strategy="steps",
        eval_steps=Config.EVAL_STEPS,
        save_steps=Config.SAVE_STEPS,
        load_best_model_at_end=True,
        metric_for_best_model="f1",
    )

    # 5. Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    # 6. Train
    print("Starting training...")
    trainer.train()

    # 7. Export the final (best) model and tokenizer together.
    print(f"Saving model to {Config.CHECKPOINT_DIR}...")
    trainer.save_model(Config.CHECKPOINT_DIR)
    tokenizer.save_pretrained(Config.CHECKPOINT_DIR)

    # 8. Plot training curves from the trainer's log history.
    print("Generating training plots...")
    plot_save_path = os.path.join(Config.RESULTS_DIR, 'training_curves.png')
    plot_training_history(trainer.state.log_history, save_path=plot_save_path)

    print("Training completed!")

if __name__ == "__main__":
    main()
|
src/upload_to_hf.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import glob
|
| 4 |
+
import shutil
|
| 5 |
+
from huggingface_hub import HfApi, create_repo, upload_folder
|
| 6 |
+
from config import Config
|
| 7 |
+
|
| 8 |
+
def main():
    """Stage code, data, images, and the newest checkpoint, then push all of
    it to the Hugging Face Hub model repo."""
    print("🚀 开始全量上传 (All-in-One) 到 robot4/sentiment-analysis-bert-finetuned ...")

    api = HfApi()
    try:
        user_info = api.whoami()
        username = user_info['name']
        print(f"✅ User: {username}")
    except Exception:
        # BUGFIX: narrowed from a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit. Any failure here means no valid token.
        print("❌ Please login first.")
        return

    # Target repo (user-specified).
    target_repo_id = "robot4/sentiment-analysis-bert-finetuned"

    # 1. Prepare a clean staging directory.
    upload_dir = "hf_upload_staging"
    if os.path.exists(upload_dir):
        shutil.rmtree(upload_dir)
    os.makedirs(upload_dir)

    print(f"📦 正在打包所有文件到 {upload_dir}...")

    # A. Copy project code and assets (src, docs, notebooks, demo, data, ...).
    items_to_copy = [
        "src", "notebooks", "docs", "demo", "data",
        "README.md", "requirements.txt", "*.pptx"
    ]

    for pattern in items_to_copy:
        for item in glob.glob(pattern):
            dest = os.path.join(upload_dir, item)
            print(f" - Adding {item}...")
            if os.path.isdir(item):
                shutil.copytree(item, dest, dirs_exist_ok=True)
            else:
                shutil.copy2(item, dest)

    # B. results/: only images and txt metrics — never whole checkpoint dirs.
    results_dest = os.path.join(upload_dir, "results")
    os.makedirs(results_dest, exist_ok=True)

    if os.path.exists("results/images"):
        shutil.copytree("results/images", os.path.join(results_dest, "images"), dirs_exist_ok=True)
    for txt in glob.glob("results/*.txt"):
        shutil.copy2(txt, results_dest)

    # C. Lift the newest model weights to the repo root for direct loading.
    candidates = glob.glob(os.path.join(Config.RESULTS_DIR, "checkpoint-*"))
    candidates = [c for c in candidates if os.path.isdir(c)]

    if candidates:
        candidates.sort(key=os.path.getmtime)
        latest_ckpt = candidates[-1]
        print(f"✅ 提取最新模型权重: {latest_ckpt} -> 根目录")

        model_files = ["config.json", "model.safetensors", "pytorch_model.bin", "tokenizer.json", "vocab.txt", "tokenizer_config.json", "special_tokens_map.json"]

        for fname in os.listdir(latest_ckpt):
            if fname in model_files or fname.endswith(".safetensors") or fname.endswith(".bin"):
                shutil.copy2(os.path.join(latest_ckpt, fname), os.path.join(upload_dir, fname))
    else:
        print("⚠️ 未找到 Checkpoint,仅上传代码和数据。")

    # 2. Upload the staged tree.
    print(f"\n⬆️ 正在上传所有文件到 https://huggingface.co/{target_repo_id}")
    create_repo(repo_id=target_repo_id, repo_type="model", exist_ok=True)

    upload_folder(
        folder_path=upload_dir,
        repo_id=target_repo_id,
        repo_type="model"
    )

    # Cleanup the staging directory.
    shutil.rmtree(upload_dir)
    print("🎉 上传完毕!")

if __name__ == "__main__":
    # NOTE(review): the top-level `from config import Config` has already
    # executed before this sys.path tweak, so the tweak only helps transitive
    # imports. The script appears to expect being run from inside src/ —
    # TODO confirm the intended invocation and move the path setup if needed.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    parent_dir = os.path.dirname(current_dir)
    sys.path.append(parent_dir)
    main()
|
src/visualization.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
import seaborn as sns
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
|
| 8 |
+
# Configure a Chinese-capable font (tries commonly available fonts in order).
def set_chinese_font():
    """Configure matplotlib so CJK glyphs and minus signs render correctly."""
    plt.rcParams['font.sans-serif'] = ['Arial Unicode MS', 'SimHei', 'PingFang SC', 'Heiti TC']
    plt.rcParams['axes.unicode_minus'] = False
|
| 12 |
+
|
| 13 |
+
def plot_data_distribution(dataset_dict, save_path=None):
    """Draw a pie chart of the Positive/Neutral/Negative label distribution.

    Accepts either a DatasetDict (uses its 'train' split) or a plain Dataset.
    """
    set_chinese_font()

    # Pick the train split when given a DatasetDict; otherwise use as-is.
    if hasattr(dataset_dict, 'keys') and 'train' in dataset_dict.keys():
        ds = dataset_dict['train']
    else:
        ds = dataset_dict

    # Pull out the label column under whichever name it carries.
    if 'label' in ds.features:
        train_labels = ds['label']
    elif 'labels' in ds.features:
        train_labels = ds['labels']
    else:
        # Fallback: per-row lookup.
        train_labels = [x.get('label', x.get('labels')) for x in ds]

    # Map ids back to display strings.
    id2label = {0: 'Negative (消极)', 1: 'Neutral (中性)', 2: 'Positive (积极)'}
    labels_str = [id2label.get(x, str(x)) for x in train_labels]

    counts = pd.DataFrame({'Label': labels_str})['Label'].value_counts()

    plt.figure(figsize=(10, 6))
    plt.pie(counts, labels=counts.index, autopct='%1.1f%%', startangle=140, colors=sns.color_palette("pastel"))
    plt.title('训练集情感分布')
    plt.tight_layout()

    if save_path:
        print(f"Saving distribution plot to {save_path}...")
        plt.savefig(save_path)
|
| 51 |
+
|
| 52 |
+
def plot_training_history(log_history, save_path=None):
    """Plot loss and accuracy curves from a HF Trainer log_history.

    Also writes a metrics_<timestamp>.txt summary next to the plot.
    """
    set_chinese_font()

    if not log_history:
        print("No training logs available.")
        return

    df = pd.DataFrame(log_history)

    # BUGFIX: guard against log histories with no training-loss / eval rows;
    # the original `df['loss']` raised KeyError when the column was absent.
    train_loss = df[df['loss'].notna()] if 'loss' in df.columns else df.iloc[0:0]
    eval_acc = df[df['eval_accuracy'].notna()] if 'eval_accuracy' in df.columns else df.iloc[0:0]

    plt.figure(figsize=(14, 5))

    # 1. Loss curve
    plt.subplot(1, 2, 1)
    if not train_loss.empty:
        plt.plot(train_loss['epoch'], train_loss['loss'], label='Training Loss', color='salmon')
    if 'eval_loss' in df.columns:
        eval_loss = df[df['eval_loss'].notna()]
        plt.plot(eval_loss['epoch'], eval_loss['eval_loss'], label='Validation Loss', color='skyblue')
    plt.title('训练损失 (Loss) 曲线')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # 2. Accuracy curve
    if not eval_acc.empty:
        plt.subplot(1, 2, 2)
        plt.plot(eval_acc['epoch'], eval_acc['eval_accuracy'], label='Validation Accuracy', color='lightgreen', marker='o')
        plt.title('验证集准确率 (Accuracy)')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.grid(True, alpha=0.3)

    # BUGFIX: `Config` was never imported at module level, so the original
    # raised NameError when called from train.py. Import lazily, with a
    # fallback for script-style execution and a last-resort relative path.
    try:
        from .config import Config
        results_dir = Config.RESULTS_DIR
    except ImportError:
        try:
            from config import Config
            results_dir = Config.RESULTS_DIR
        except ImportError:
            results_dir = "results"

    # Ensure the output directory exists.
    save_dir = os.path.join(results_dir, "images")
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    plt.tight_layout()

    # Timestamp string, e.g. 2024-12-18_14-30-00
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    # Default save path.
    if save_path is None:
        save_path = os.path.join(save_dir, f"training_metrics_{timestamp}.png")

    print(f"Saving plot to {save_path}...")
    plt.savefig(save_path)

    # Persist the final metrics as a small text file as well.
    if not eval_acc.empty:
        final_acc = eval_acc.iloc[-1]['eval_accuracy']
        final_loss = eval_acc.iloc[-1]['eval_loss'] if 'eval_loss' in eval_acc.columns else "N/A"
        metrics_file = os.path.join(save_dir, f"metrics_{timestamp}.txt")
        with open(metrics_file, "w") as f:
            f.write(f"Timestamp: {timestamp}\n")
            f.write(f"Final Validation Accuracy: {final_acc:.4f}\n")
            f.write(f"Final Validation Loss: {final_loss}\n")
            f.write(f"Plot saved to: {os.path.basename(save_path)}\n")
        print(f"Saved metrics text to {metrics_file}")
|
| 120 |
+
|
| 121 |
+
def load_and_plot_logs(log_dir):
    """Load ``trainer_state.json`` from a checkpoint directory and plot its curves.

    Prints a message and returns early when the log file does not exist;
    otherwise delegates plotting to ``plot_training_history``.

    Args:
        log_dir: Path to a HF Trainer checkpoint directory.
    """
    json_path = os.path.join(log_dir, 'trainer_state.json')
    if not os.path.exists(json_path):
        print(f"未找到日志文件: {json_path}")
        return

    # Explicit UTF-8: trainer_state.json is UTF-8 and may hold non-ASCII text;
    # the platform default encoding (e.g. GBK on Chinese Windows) would break it.
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    plot_training_history(data['log_history'])
|
| 134 |
+
|
| 135 |
+
if __name__ == "__main__":
    import sys
    import os  # Explicitly import os here if not globally sufficient or for clarity
    # When this script is run directly, fix the relative-import problem
    # by adding the parent directory to sys.path
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    sys.path.append(project_root)

    from src.config import Config
    # ---------------------------------------------------------
    # 2. Generate the data distribution plot (Data Distribution)
    # ---------------------------------------------------------
    try:
        print("\n正在加载数据集以生成样本分布分析...")
        from transformers import AutoTokenizer
        from src.dataset import DataProcessor

        tokenizer = AutoTokenizer.from_pretrained(Config.BASE_MODEL)
        processor = DataProcessor(tokenizer)
        # Try to load the already-processed dataset from the data dir (fast path)
        dataset = processor.get_processed_dataset(cache_dir=Config.DATA_DIR)

        # Build a timestamped filename
        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        dist_save_path = os.path.join(Config.RESULTS_DIR, "images", f"data_distribution_{timestamp}.png")

        # Plot and save
        plot_data_distribution(dataset, save_path=dist_save_path)
        print(f"数据样本分布分析已保存至: {dist_save_path}")

    # Best-effort: the dataset may not be downloaded/processed yet, so report and continue
    except Exception as e:
        print(f"无法生成数据分布图 (可能是数据尚未下载或处理): {e}")

    # ---------------------------------------------------------
    # 3. Generate training curves (Training History)
    # ---------------------------------------------------------
    import glob

    # Locate checkpoint directories to read logs from
    search_paths = [
        Config.OUTPUT_DIR,
        os.path.join(Config.RESULTS_DIR, "checkpoint-*")
    ]

    candidates = []
    for p in search_paths:
        candidates.extend(glob.glob(p))

    if candidates:
        # Pick the most recently modified checkpoint
        candidates.sort(key=os.path.getmtime)
        latest_ckpt = candidates[-1]
        print(f"Loading logs from: {latest_ckpt}")
        load_and_plot_logs(latest_ckpt)
    else:
        print("未找到任何 checkpoint 或 trainer_state.json 日志文件。")
|
train_cloud.py
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from datasets import load_dataset, concatenate_datasets
|
| 9 |
+
from transformers import (
|
| 10 |
+
AutoTokenizer,
|
| 11 |
+
AutoModelForSequenceClassification,
|
| 12 |
+
TrainingArguments,
|
| 13 |
+
Trainer
|
| 14 |
+
)
|
| 15 |
+
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
|
| 16 |
+
|
| 17 |
+
# ==========================================
|
| 18 |
+
# 1. 配置 (Configuration)
|
| 19 |
+
# ==========================================
|
| 20 |
+
class Config:
    """Static configuration for the cloud training run: model id, directory
    layout, label mapping, and training hyperparameters."""
    # Base pretrained model to fine-tune
    BASE_MODEL = "google-bert/bert-base-chinese"

    # Directory layout (rooted at the current working directory, as requested)
    BASE_DIR = os.getcwd()
    DATA_DIR = os.path.join(BASE_DIR, "data")
    CHECKPOINT_DIR = os.path.join(BASE_DIR, "checkpoints")
    RESULTS_DIR = os.path.join(BASE_DIR, "results")
    DOCS_DIR = os.path.join(BASE_DIR, "docs")

    # Label configuration for 3-class sentiment
    NUM_LABELS = 3
    LABEL2ID = {'negative': 0, 'neutral': 1, 'positive': 2}
    ID2LABEL = {0: 'negative', 1: 'neutral', 2: 'positive'}

    # Training hyperparameters
    MAX_LENGTH = 128      # max token length per example (pad/truncate target)
    BATCH_SIZE = 32
    LEARNING_RATE = 2e-5
    NUM_EPOCHS = 3
    WARMUP_RATIO = 0.1    # fraction of total steps used for LR warmup
    SAVE_STEPS = 500      # checkpoint save AND eval interval (steps)
    LOGGING_STEPS = 100
|
| 44 |
+
|
| 45 |
+
# ==========================================
|
| 46 |
+
# 2. 工具函数 (Utils)
|
| 47 |
+
# ==========================================
|
| 48 |
+
def ensure_directories():
    """Create every project directory (data/checkpoints/results/docs) if missing."""
    for path in [Config.DATA_DIR, Config.CHECKPOINT_DIR, Config.RESULTS_DIR, Config.DOCS_DIR]:
        if not os.path.exists(path):
            # exist_ok guards the check-then-create race: another process may
            # create the directory between the exists() check and makedirs().
            os.makedirs(path, exist_ok=True)
            print(f">>> Created directory: {path}")
|
| 54 |
+
|
| 55 |
+
def plot_training_history(log_history, save_path):
    """Plot train/val loss and val accuracy from a HF Trainer log history.

    Saves a two-panel figure to ``save_path``. Any plotting failure (missing
    CJK fonts, headless backend, malformed history) is caught and reported
    as a warning so it never aborts a training run.

    Args:
        log_history: ``trainer.state.log_history`` — list of log-entry dicts.
        save_path:   Destination PNG path.
    """
    try:
        # Try common CJK-capable fonts; cloud images may lack them, in which
        # case matplotlib falls back to DejaVu Sans.
        plt.rcParams['font.sans-serif'] = ['SimHei', 'Arial Unicode MS', 'DejaVu Sans']
        plt.rcParams['axes.unicode_minus'] = False

        df = pd.DataFrame(log_history)
        # Guard absent columns: a history with no train-loss entries or no eval
        # steps previously raised KeyError here and silently skipped plotting.
        train_loss = df[df['loss'].notna()] if 'loss' in df.columns else pd.DataFrame()
        eval_acc = df[df['eval_accuracy'].notna()] if 'eval_accuracy' in df.columns else pd.DataFrame()

        if train_loss.empty:
            return

        plt.figure(figsize=(12, 5))

        # Loss panel
        plt.subplot(1, 2, 1)
        plt.plot(train_loss['epoch'], train_loss['loss'], label='Train Loss', color='#FF6B6B')
        if 'eval_loss' in df.columns:
            eval_loss = df[df['eval_loss'].notna()]
            plt.plot(eval_loss['epoch'], eval_loss['eval_loss'], label='Val Loss', color='#4ECDC4')
        plt.title('Loss Curve')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True, alpha=0.3)

        # Accuracy panel (only when eval metrics were logged)
        if not eval_acc.empty:
            plt.subplot(1, 2, 2)
            plt.plot(eval_acc['epoch'], eval_acc['eval_accuracy'], label='Val Accuracy', color='#6BCB77', marker='o')
            plt.title('Accuracy Curve')
            plt.xlabel('Epoch')
            plt.ylabel('Accuracy')
            plt.legend()
            plt.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(save_path)
        print(f">>> Plot saved to {save_path}")
        plt.close()
    except Exception as e:
        print(f"Warning: Plotting failed ({e})")
|
| 99 |
+
|
| 100 |
+
# ==========================================
|
| 101 |
+
# 3. 数据处理 (Data Processor)
|
| 102 |
+
# ==========================================
|
| 103 |
+
class DataProcessor:
    """Load, align, clean, relabel, and tokenize the two sentiment datasets."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def clean_data(self, example):
        """Filter predicate: keep only real review texts.

        Drops None texts, the platform's auto-filled placeholder review,
        and texts shorter than 2 characters after stripping whitespace.
        """
        text = example['text']
        if text is None: return False
        if "此用户未填写评价内容" in text: return False
        if len(text.strip()) < 2: return False
        return True

    def unify_labels(self, example):
        """Normalize heterogeneous labels to ints: 0=negative, 1=neutral, 2=positive.

        BUG FIX: 'pos' was previously mapped to 0 (negative) and 'neg' to 2
        (positive); the abbreviations are now consistent with the full names
        and with Config.LABEL2ID.
        """
        label = example['label']
        if isinstance(label, str):
            label = label.lower()
            if label in ['negative', 'neg', '0']: return {'label': 0}
            elif label in ['neutral', 'neu', '1']: return {'label': 1}
            elif label in ['positive', 'pos', '2']: return {'label': 2}
        return {'label': int(label)}

    def tokenize_function(self, examples):
        """Batched tokenization with fixed-length padding and truncation."""
        return self.tokenizer(examples['text'], padding="max_length", truncation=True, max_length=Config.MAX_LENGTH)

    def get_dataset(self):
        """Download both datasets, align columns, clean/relabel/tokenize,
        and return a 90/10 train/test DatasetDict split."""
        print(">>> Loading Datasets...")
        # Cache downloads under the project data directory
        ds_clap = load_dataset("clapAI/MultiLingualSentiment", split="train", trust_remote_code=True, cache_dir=Config.DATA_DIR)
        ds_med = load_dataset("OpenModels/Chinese-Herbal-Medicine-Sentiment", split="train", trust_remote_code=True, cache_dir=Config.DATA_DIR)

        # Column alignment: unify both sources to ('text', 'label');
        # keep only the Chinese rows of the multilingual set
        if 'review_text' in ds_med.column_names: ds_med = ds_med.rename_column('review_text', 'text')
        if 'sentiment_label' in ds_med.column_names: ds_med = ds_med.rename_column('sentiment_label', 'label')
        if 'language' in ds_clap.column_names: ds_clap = ds_clap.filter(lambda x: x['language'] == 'zh')

        common_cols = ['text', 'label']
        combined = concatenate_datasets([ds_clap.select_columns(common_cols), ds_med.select_columns(common_cols)])

        # Clean rows, normalize labels, then tokenize (dropping raw columns)
        combined = combined.filter(self.clean_data).map(self.unify_labels)
        tokenized = combined.map(self.tokenize_function, batched=True, remove_columns=['text', 'label'])

        return tokenized.train_test_split(test_size=0.1)
|
| 145 |
+
|
| 146 |
+
# ==========================================
|
| 147 |
+
# 4. Metrics
|
| 148 |
+
# ==========================================
|
| 149 |
+
def compute_metrics(pred):
    """Compute evaluation metrics for the HF Trainer.

    Args:
        pred: EvalPrediction with ``label_ids`` (true labels) and
              ``predictions`` (per-class logits).

    Returns:
        dict with accuracy and weighted precision/recall/f1.
    """
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    # Weighted averaging accounts for class imbalance across the 3 labels.
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
    acc = accuracy_score(labels, preds)
    # Previously precision/recall were computed but discarded; report them too.
    # Extra keys are harmless: metric_for_best_model="f1" still resolves.
    return {'accuracy': acc, 'f1': f1, 'precision': precision, 'recall': recall}
|
| 155 |
+
|
| 156 |
+
# ==========================================
|
| 157 |
+
# 5. 主流程
|
| 158 |
+
# ==========================================
|
| 159 |
+
def main():
    """End-to-end cloud training: load data, fine-tune BERT for 3-class
    sentiment, save the final model, and plot training curves."""
    print("=== Cloud Training Script ===")
    ensure_directories()

    if torch.cuda.is_available():
        print(f"✅ CUDA Enabled: {torch.cuda.get_device_name(0)}")
    else:
        print("⚠️ Running on CPU")

    tokenizer = AutoTokenizer.from_pretrained(Config.BASE_MODEL)
    processor = DataProcessor(tokenizer)
    dataset = processor.get_dataset()

    model = AutoModelForSequenceClassification.from_pretrained(
        Config.BASE_MODEL,
        num_labels=Config.NUM_LABELS,
        id2label=Config.ID2LABEL,
        label2id=Config.LABEL2ID
    )

    training_args = TrainingArguments(
        output_dir=Config.CHECKPOINT_DIR,  # checkpoints are stored here
        num_train_epochs=Config.NUM_EPOCHS,
        per_device_train_batch_size=Config.BATCH_SIZE,
        per_device_eval_batch_size=Config.BATCH_SIZE,
        learning_rate=Config.LEARNING_RATE,
        warmup_ratio=Config.WARMUP_RATIO,
        logging_dir=os.path.join(Config.RESULTS_DIR, 'logs'),  # logs go under results/
        logging_steps=Config.LOGGING_STEPS,
        eval_strategy="steps",
        eval_steps=Config.SAVE_STEPS,
        save_steps=Config.SAVE_STEPS,
        save_total_limit=2,                # keep only the 2 most recent checkpoints
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        fp16=torch.cuda.is_available(),    # mixed precision only when a GPU exists
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset['train'],
        eval_dataset=dataset['test'],
        processing_class=tokenizer,
        compute_metrics=compute_metrics,
    )

    print(">>> Starting Training...")
    trainer.train()

    # Save the final (best) model and tokenizer to checkpoints/final_model
    final_path = os.path.join(Config.CHECKPOINT_DIR, "final_model")
    print(f">>> Saving Final Model to {final_path}...")
    trainer.save_model(final_path)
    tokenizer.save_pretrained(final_path)

    # Plot the training curves into results/
    print(">>> Generating Plots...")
    plot_path = os.path.join(Config.RESULTS_DIR, "training_curves_cloud.png")
    plot_training_history(trainer.state.log_history, plot_path)

    print(">>> All Done!")
|
| 221 |
+
|
| 222 |
+
if __name__ == "__main__":
    # Script entry point (e.g. `python train_cloud.py` on the cloud VM).
    main()
|
基于BERT的情感分析系统.pptx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d638b1bd27852f0337c373b96512140eeecb8d3949b1bc4e87060411afe59f22
|
| 3 |
+
size 914714
|