Upload folder using huggingface_hub
- .gitattributes +1 -0
- .github/workflows/docs.yaml +59 -0
- .github/workflows/python-publish.yml +39 -0
- .gitignore +95 -0
- ADDITIONS_README.md +71 -0
- CLEANUP_SUMMARY.md +164 -0
- COMPUTE_METRICS_README.md +191 -0
- DATA_FILES_NOTICE.md +107 -0
- DATA_TRANSFER_SUMMARY.md +95 -0
- HF_UPLOAD_GUIDE.md +180 -0
- LICENCE +203 -0
- MANIFEST.in +2 -0
- NOTICE +23 -0
- QUICK_START_HF.md +104 -0
- README.md +279 -0
- UPLOAD_CHECKLIST.md +116 -0
- cleanup_for_hf.py +293 -0
- compute_cascade_metrics.py +568 -0
- data/cascades/.gitkeep +3 -0
- data/cascades/README.md +101 -0
- docs/Makefile +20 -0
- docs/README.md +13 -0
- docs/images/thinning_algo.jpg +3 -0
- docs/make.bat +35 -0
- docs/source/advanced/implementation.rst +143 -0
- docs/source/advanced/performance_valid.rst +41 -0
- docs/source/advanced/tensorboard.rst +75 -0
- docs/source/advanced/thinning_algo.rst +56 -0
- docs/source/conf.py +59 -0
- docs/source/dev_guide/model_custom.rst +78 -0
- docs/source/get_started/install.rst +64 -0
- docs/source/get_started/introduction.rst +60 -0
- docs/source/get_started/quick_start.rst +106 -0
- docs/source/index.rst +56 -0
- docs/source/ref/config.rst +10 -0
- docs/source/ref/hpo.rst +10 -0
- docs/source/ref/models.rst +50 -0
- docs/source/ref/preprocess.rst +10 -0
- docs/source/ref/runner.rst +10 -0
- docs/source/ref/utils.rst +10 -0
- docs/source/ref/wrapper.rst +17 -0
- docs/source/user_guide/dataset.rst +124 -0
- docs/source/user_guide/run_eval.rst +97 -0
- docs/source/user_guide/run_train_pipeline.rst +245 -0
- easy_tpp/__init__.py +1 -0
- easy_tpp/config_factory/__init__.py +13 -0
- easy_tpp/config_factory/config.py +120 -0
- easy_tpp/config_factory/data_config.py +147 -0
- easy_tpp/config_factory/hpo_config.py +132 -0
- easy_tpp/config_factory/model_config.py +274 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+docs/images/thinning_algo.jpg filter=lfs diff=lfs merge=lfs -text
.github/workflows/docs.yaml
ADDED
@@ -0,0 +1,59 @@
name: docs

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]
  release:
    types: [ published ]

jobs:
  build:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v3
      with:
        fetch-depth: 0
    - name: Set up Python
      uses: actions/setup-python@v4
      with:
        python-version: '3.8'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip setuptools wheel
        sudo apt-get update
        sudo apt-get install openjdk-11-jdk
        sudo apt-get install pandoc
    - name: Build Sphinx docs
      run: |
        pip install tensorflow==2.2.0
        pip install torch
        pip install pandas
        pip install numpy
        pip install -r requirements-doc.txt
        cd docs
        make html
    # Publish built docs to gh-pages branch.
    # ===============================
    - name: Commit documentation changes
      run: |
        git clone https://github.com/ant-research/EasyTemporalPointProcess.git --branch gh-pages --single-branch gh-pages
        cp -r docs/build/html/* gh-pages/
        cd gh-pages
        touch .nojekyll
        git config --local user.email "action@github.com"
        git config --local user.name "GitHub Action"
        git add .
        git commit -m "Update documentation" -a || true
        # The above command will fail if no changes were present, so we ignore
        # that.
    - name: Push changes
      uses: ad-m/github-push-action@master
      with:
        branch: gh-pages
        directory: gh-pages
        github_token: ${{ secrets.GITHUB_TOKEN }}
    # ===============================
.github/workflows/python-publish.yml
ADDED
@@ -0,0 +1,39 @@
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
  release:
    types: [published]

permissions:
  contents: read

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.9'
    - name: Install dependencies
      run: |
        pip install -r requirements.txt
        pip install wheel
    - name: Build package
      run: python setup.py sdist bdist_wheel
    - name: Publish package
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}
.gitignore
ADDED
@@ -0,0 +1,95 @@
# python build
build/
dist/
easy_tpp.egg-info/
*.egg-info/

# python temp
*.pyc
*.pyo
*.pyd
__pycache__/
*.so
*.egg

# proto
protoc
protoc-3.4.0.tar.gz
*_pb2.py

# misc
experiments/
log/
logs/
*.swp
*.swo
.vscode/
.idea/

# OS files
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db

# IDE
.vscode/
.idea/
*.sublime-project
*.sublime-workspace

# Testing
.pytest_cache/
.mypy_cache/
.ruff_cache/
.coverage
htmlcov/
.tox/

# Checkpoints and outputs
examples/checkpoints/*
notebooks/checkpoints/*
results/
outputs/
*.pth
*.pt
*.ckpt

# Data files (large files should not be committed)
examples/data/*
!examples/data/.gitkeep
*.json.gz
*.pkl
*.h5
*.hdf5

# Large cascade data files (use Git LFS or external storage)
data/cascades/information_cascade*.json
data/cascades/*.json
!data/cascades/.gitkeep

# Model files (use Git LFS or Hugging Face Hub)
*.bin
*.safetensors
*.onnx

# Temporary files
*.tmp
*.temp
*.bak
*~

# Jupyter Notebook checkpoints
.ipynb_checkpoints/

# Environment
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
ADDITIONS_README.md
ADDED
@@ -0,0 +1,71 @@
# EasyTPP Additions

This repository extends the original EasyTPP with metric-computation functionality for information-cascade data.

## 🆕 New Files

### Core functionality
- **`compute_cascade_metrics.py`**: main script for computing cascade metrics
  - Sentiment Score
  - Sentiment Deviation
  - Contextual Deviation
  - Perplexity

### Documentation and tools
- **`COMPUTE_METRICS_README.md`**: detailed usage instructions
- **`HF_UPLOAD_GUIDE.md`**: Hugging Face upload guide
- **`UPLOAD_CHECKLIST.md`**: upload checklist (auto-generated)
- **`cleanup_for_hf.py`**: cleanup script to prepare for upload
- **`example_compute_metrics.sh`**: example usage script
- **`requirements_compute_metrics.txt`**: additional dependencies

## 🚀 Quick Start

### 1. Install dependencies

```bash
pip install -r requirements.txt
pip install -r requirements_compute_metrics.txt
```

### 2. Run the metric computation

```bash
python compute_cascade_metrics.py \
    --input_cascade information_cascade.json \
    --output output_with_metrics.json \
    --batch_size 32
```

See `COMPUTE_METRICS_README.md` for details.

## 📦 Uploading to Hugging Face

1. Run the cleanup script:
```bash
python cleanup_for_hf.py
```

2. Upload following the instructions in `HF_UPLOAD_GUIDE.md`

## 🔗 Related Documentation

- [Metric computation instructions](COMPUTE_METRICS_README.md)
- [Upload guide](HF_UPLOAD_GUIDE.md)
- [Original EasyTPP README](README.md)

## 📝 Use Cases

These additions are mainly intended for:
- Analyzing social-media information cascades (e.g., Weibo reposts and comments)
- Computing sentiment features and semantic deviations of text
- Providing additional feature inputs for TPP models

## ⚙️ Integration with EasyTPP

The computed metrics can be used by:
- `RobertTPPDataset`: loads data containing semantic and deviation features
- `RobertEventTokenizer`: handles the custom features
- `TorchRobotTHP`: a TPP model that consumes the semantic and deviation features

See `examples/train_robot_thp_with_features.py` for a complete example; a minimal loading sketch follows below.
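The exact dataset and tokenizer interfaces are defined in the repository itself; purely as an illustration of how the enriched JSON could be flattened into per-event feature arrays, here is a minimal sketch. The metric fields are those documented in `COMPUTE_METRICS_README.md`; the `timestamp` key is a hypothetical name for the event-time field.

```python
import json
import numpy as np

# Minimal sketch: flatten enriched cascades into per-event feature arrays.
with open("output_with_metrics.json", "r", encoding="utf-8") as f:
    data = json.load(f)

for cascade in data["cascades"]:
    events = cascade.get("repost_chain", [])
    feats = np.array([
        [e["sentiment_score"], e["perplexity"],
         e["contextual_deviation"], e["sentiment_deviation"]]
        for e in events
    ])  # shape: (num_events, 4) -- extra inputs for a feature-aware TPP model
    times = [e.get("timestamp") for e in events]  # hypothetical field name
```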
CLEANUP_SUMMARY.md
ADDED
@@ -0,0 +1,164 @@
# Folder Cleanup Summary

## ✅ Completed on
2025-01-19

## 📊 Folder Information

- **Location**: `/Users/chenshuyi/Downloads/EasyTemporalPointProcess-main`
- **Size**: 1.3 MB
- **Status**: ✅ cleaned up and ready to upload

## 🧹 Cleanup Work Done

### 1. Updated .gitignore
- ✅ Added Python cache-file patterns
- ✅ Added IDE configuration-file exclusions
- ✅ Added OS system-file exclusions
- ✅ Added test and build-file exclusions
- ✅ Added exclusion rules for data and model files

### 2. Created a cleanup script
- ✅ `cleanup_for_hf.py` — automatic cleanup script
- ✅ Already run; it confirmed there were no files to delete

### 3. File checks
- ✅ No large files (>50 MB)
- ✅ No sensitive information
- ✅ No temporary files

## 📁 File Structure

```
EasyTemporalPointProcess-main/
├── 📄 Core code
│   ├── easy_tpp/                        # EasyTPP core library
│   ├── examples/                        # example code
│   ├── notebooks/                       # Jupyter notebooks
│   └── tests/                           # test code
│
├── 🆕 New functionality (cascade metric computation)
│   ├── compute_cascade_metrics.py       # main computation script
│   ├── COMPUTE_METRICS_README.md        # usage instructions
│   ├── requirements_compute_metrics.txt # additional dependencies
│   └── example_compute_metrics.sh       # example script
│
├── 📚 Documentation
│   ├── README.md                        # original README
│   ├── ADDITIONS_README.md              # description of the additions
│   ├── HF_UPLOAD_GUIDE.md               # Hugging Face upload guide
│   ├── QUICK_START_HF.md                # quick-start guide
│   ├── UPLOAD_CHECKLIST.md              # upload checklist
│   └── CLEANUP_SUMMARY.md               # this file
│
├── 🛠️ Utility scripts
│   ├── cleanup_for_hf.py                # cleanup script
│   └── setup.py                         # install script
│
└── ⚙️ Configuration files
    ├── .gitignore                       # Git ignore rules (updated)
    ├── requirements.txt                 # base dependencies
    ├── requirements_compute_metrics.txt # metric-computation dependencies
    └── setup.cfg                        # install configuration
```

## 📋 List of New Files

### Core functionality
1. `compute_cascade_metrics.py` (19.5 KB)
   - Computes sentiment score, sentiment deviation, contextual deviation, and perplexity

### Documentation
2. `COMPUTE_METRICS_README.md` (5.9 KB)
   - Detailed metric-computation usage instructions

3. `HF_UPLOAD_GUIDE.md` (3.7 KB)
   - Complete Hugging Face upload guide

4. `ADDITIONS_README.md` (1.9 KB)
   - Overview of the additions

5. `QUICK_START_HF.md` (2.3 KB)
   - Quick upload guide

6. `UPLOAD_CHECKLIST.md` (3.0 KB)
   - Upload checklist (auto-generated)

7. `CLEANUP_SUMMARY.md` (this file)
   - Cleanup summary

### Tools and configuration
8. `cleanup_for_hf.py` (7.8 KB)
   - Automatic cleanup script

9. `example_compute_metrics.sh` (1.2 KB)
   - Example usage script

10. `requirements_compute_metrics.txt` (266 B)
    - Dependencies for metric computation

## 🎯 Next Steps

### 1. Upload to Hugging Face

```bash
# Install the CLI
pip install huggingface_hub

# Log in
huggingface-cli login

# Create the repository (on the website)
# https://huggingface.co/new

# Upload
cd /Users/chenshuyi/Downloads/EasyTemporalPointProcess-main
huggingface-cli upload <username>/<repo-name> . --repo-type dataset
```

### 2. Download on the cloud machine

```bash
huggingface-cli download <username>/<repo-name> --local-dir ./EasyTPP
```

### 3. Use the new functionality

```bash
cd EasyTPP
pip install -r requirements.txt
pip install -r requirements_compute_metrics.txt

python compute_cascade_metrics.py \
    --input_cascade information_cascade.json \
    --output output_with_metrics.json
```

## ✅ Checklist

- [x] Clean temporary files
- [x] Update .gitignore
- [x] Check for large files
- [x] Check for sensitive information
- [x] Create the upload guide
- [x] Create the usage documentation
- [x] Verify the file structure
- [ ] Upload to Hugging Face (pending)
- [ ] Test on the cloud machine (pending)

## 📝 Notes

1. **File size**: 1.3 MB; Git LFS is not needed
2. **License**: the original Apache 2.0 license is kept
3. **Dependencies**: make sure all dependencies are listed in the requirements files
4. **Documentation**: every addition has detailed documentation

## 🔗 Related Links

- [Hugging Face](https://huggingface.co/)
- [Hugging Face CLI docs](https://huggingface.co/docs/huggingface_hub/guides/cli)
- [Original EasyTPP project](https://github.com/ant-research/EasyTemporalPointProcess)

---

**Cleanup complete — ready to upload!** 🚀
COMPUTE_METRICS_README.md
ADDED
@@ -0,0 +1,191 @@
# Cascade Metric Computation — Usage Instructions

This script computes sentiment scores, sentiment deviation, contextual deviation, and perplexity for information-cascade data.

## Functionality

The script `compute_cascade_metrics.py` processes the following two JSON files:
- `information_cascade.json`: the full cascade data (original posts, comments, reposts)
- `information_cascade_original_posts.json`: the original-post data (optional)

It computes the following metrics:
1. **Sentiment Score**: the sentiment polarity of the text
2. **Sentiment Deviation**: sentiment deviation relative to the original post
3. **Contextual Deviation**: semantic deviation relative to the original post
4. **Perplexity**: language-model perplexity of the text

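The script's exact formulas are not reproduced in this README. As a rough sketch only — assuming the common definitions of cosine distance between BERT embeddings for contextual deviation and absolute score difference for sentiment deviation, which may differ from what `compute_cascade_metrics.py` actually implements:

```python
import numpy as np

def contextual_deviation(reply_emb, post_emb):
    # Assumed definition: cosine distance between the BERT embeddings
    # of a reply and of the original post (0 = same direction).
    reply_emb, post_emb = np.asarray(reply_emb), np.asarray(post_emb)
    cos_sim = reply_emb @ post_emb / (
        np.linalg.norm(reply_emb) * np.linalg.norm(post_emb) + 1e-12)
    return 1.0 - cos_sim

def sentiment_deviation(reply_score, post_score):
    # Assumed definition: absolute difference of sentiment scores.
    return abs(reply_score - post_score)
```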
## Installing Dependencies

Install the necessary dependencies on the cloud machine:

```bash
pip install torch transformers numpy tqdm
```

If you need GPU support:
```bash
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
```

## Usage

### Basic usage (default models)

```bash
python compute_cascade_metrics.py \
    --input_cascade information_cascade.json \
    --output output_with_metrics.json \
    --batch_size 32
```

### Full usage (all models specified)

```bash
python compute_cascade_metrics.py \
    --input_cascade information_cascade.json \
    --input_original information_cascade_original_posts.json \
    --output output_with_metrics.json \
    --bert_model bert-base-chinese \
    --sentiment_model <path to a sentiment-analysis model> \
    --perplexity_model <path to a language model> \
    --batch_size 32 \
    --max_length 512 \
    --device cuda
```

### Arguments

- `--input_cascade`: **required**, path to the input cascade JSON file
- `--input_original`: optional, path to the original-post JSON file
- `--output`: **required**, path to the output JSON file
- `--bert_model`: BERT model name or path (default: `bert-base-chinese`)
- `--sentiment_model`: path to a sentiment-analysis model (optional; if omitted, a simplified method is used)
- `--perplexity_model`: path to a language model (optional; if omitted, a simplified method is used)
- `--batch_size`: batch size (default: 32)
- `--max_length`: maximum sequence length (default: 512)
- `--device`: compute device, `cuda` or `cpu` (default: auto-select)
- `--max_cascades`: maximum number of cascades to process (for testing; default: process all)

## Output Format

The processed JSON file adds the following fields to each node:

### Original post (`post_info`)
```json
{
  "post_info": {
    "content": "original post text",
    "embedding": [0.1, 0.2, ...],   // BERT semantic vector (768-dim)
    "sentiment_score": 0.7,          // sentiment score
    "perplexity": 15.3               // perplexity
  }
}
```

### Comments (`comment_tree`)
```json
{
  "comment_tree": {
    "comment_id": {
      "content": "comment text",
      "embedding": [0.1, 0.2, ...],
      "sentiment_score": 0.6,
      "perplexity": 12.5,
      "contextual_deviation": 0.25,  // contextual deviation
      "sentiment_deviation": 0.1     // sentiment deviation
    }
  }
}
```

### Reposts (`repost_chain`)
```json
{
  "repost_chain": [
    {
      "forward_text": "repost text",
      "comment_content": "comment text",
      "embedding": [0.1, 0.2, ...],
      "sentiment_score": 0.5,
      "perplexity": 18.2,
      "contextual_deviation": 0.35,
      "sentiment_deviation": 0.2
    }
  ]
}
```

## Model Recommendations

### BERT models
- Chinese text: `bert-base-chinese`
- English text: `bert-base-uncased`
- Custom models: provide a local path

### Sentiment-analysis models
- Chinese: e.g. `uer/roberta-base-finetuned-chinanews-chinese` or another Chinese sentiment model
- English: e.g. `nlptown/bert-base-multilingual-uncased-sentiment`
- If none is provided, the script falls back to a simplified keyword-based method

### Perplexity models
- Chinese: e.g. `gpt2-chinese` or another Chinese language model
- English: e.g. `gpt2`
- If none is provided, the script falls back to a simplified method based on lexical diversity

## Notes

1. **Large files**: processing can take a long time for large JSON files. Suggestions:
   - Use GPU acceleration (`--device cuda`)
   - Tune the batch size (`--batch_size`)
   - Test on a small sample first with `--max_cascades`

2. **Memory usage**:
   - BERT models need a fair amount of memory
   - If you run out of memory, reduce `--batch_size`

3. **Simplified methods**:
   - If no sentiment or perplexity model is provided, the script uses simplified heuristics
   - Their results may be less accurate than dedicated models, but they are fast to compute

4. **Data format**:
   - Make sure the input JSON files are well formed
   - The JSON file should contain a `cascades` field, with each cascade containing `post_info`, `comment_tree`, and `repost_chain` (a quick structural check is sketched below)

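Before a long run it can be worth spot-checking the input structure; a minimal sketch based on the layout described above:

```python
import json

# Quick structural check of the input file before a long run.
# Field names follow the layout documented in this README.
with open("information_cascade.json", "r", encoding="utf-8") as f:
    data = json.load(f)

assert "cascades" in data, "missing top-level 'cascades' field"
for i, cascade in enumerate(data["cascades"][:5]):  # spot-check a few
    for key in ("post_info", "comment_tree", "repost_chain"):
        assert key in cascade, f"cascade {i} missing '{key}'"
print(f"OK: {len(data['cascades'])} cascades")
```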
## Examples

### Example 1: process data with the default settings
```bash
python compute_cascade_metrics.py \
    --input_cascade /path/to/information_cascade.json \
    --output /path/to/output.json
```

### Example 2: use a GPU and custom models
```bash
python compute_cascade_metrics.py \
    --input_cascade /path/to/information_cascade.json \
    --output /path/to/output.json \
    --bert_model bert-base-chinese \
    --sentiment_model /path/to/sentiment_model \
    --device cuda \
    --batch_size 64
```

### Example 3: test mode (process only the first 10 cascades)
```bash
python compute_cascade_metrics.py \
    --input_cascade /path/to/information_cascade.json \
    --output /path/to/output.json \
    --max_cascades 10
```

## Troubleshooting

1. **CUDA out of memory**: reduce `--batch_size` or use `--device cpu`
2. **Model download failure**: check the network connection, or download the model locally and pass its path
3. **Malformed JSON**: check that the input JSON file is well formed
4. **Slow processing**: use a GPU (`--device cuda`) and increase the batch size

## Integration with EasyTPP

The processed JSON file can be used for training with the EasyTPP framework. See `examples/train_robot_thp_with_features.py` for how to use these features.
DATA_FILES_NOTICE.md
ADDED
@@ -0,0 +1,107 @@
# ⚠️ Data File Notice

## 📁 Data File Locations

The data files have been copied into the `data/cascades/` directory:

- `data/cascades/information_cascade.json` (606 MB)
- `data/cascades/information_cascade_original_posts.json` (980 MB)

## ⚠️ Important

**These files are too large (about 1.6 GB in total) and will not be uploaded to Hugging Face!**

`.gitignore` is configured to exclude them, because they exceed the recommended size limits for Git/Hugging Face.

## 📥 Getting the Data Files onto the Cloud Machine

### Method 1: direct transfer (recommended)

```bash
# On the cloud machine, create the directory
mkdir -p data/cascades

# Transfer from the local machine with scp
scp -r user@local-machine:/Users/chenshuyi/Documents/research_projects/评论家罗伯特TPP/data/cascades/information_cascade*.json ./data/cascades/
```

### Method 2: use cloud storage

1. Upload the files to cloud storage (Google Drive, Dropbox, OneDrive, etc.)
2. Download them on the cloud machine

### Method 3: use Git LFS (if configured)

If you need to manage the large files through Git:

```bash
# Install Git LFS
git lfs install

# Track the large files
git lfs track "data/cascades/*.json"

# Add the files
git add .gitattributes
git add data/cascades/*.json
git commit -m "Add cascade data with LFS"
git push
```

### Method 4: use the Hugging Face Dataset Hub

The data files can be uploaded separately to the Hugging Face Dataset Hub:

```bash
# Install the dependency
pip install huggingface_hub

# Upload the data files
huggingface-cli upload <username>/cascade-data data/cascades/ --repo-type dataset
```

Then download them on the cloud machine:

```bash
huggingface-cli download <username>/cascade-data --local-dir ./data/cascades
```

## ✅ Verifying the Files

After uploading to Hugging Face, verify:

```bash
# Check that the files exist
ls -lh data/cascades/

# You should see:
# information_cascade.json
# information_cascade_original_posts.json
```

## 🚀 Using the Data Files

Once the files are in place, run the metric computation:

```bash
python compute_cascade_metrics.py \
    --input_cascade data/cascades/information_cascade.json \
    --input_original data/cascades/information_cascade_original_posts.json \
    --output output_with_metrics.json \
    --batch_size 32 \
    --device cuda
```

## 📝 File Provenance

Original location:
- `/Users/chenshuyi/Documents/research_projects/评论家罗伯特TPP/data/cascades/`

Copied to:
- `/Users/chenshuyi/Downloads/EasyTemporalPointProcess-main/data/cascades/`

## 🔗 Related Documentation

- [Data file README](data/cascades/README.md)
- [Metric computation instructions](COMPUTE_METRICS_README.md)
- [Upload guide](HF_UPLOAD_GUIDE.md)
DATA_TRANSFER_SUMMARY.md
ADDED
@@ -0,0 +1,95 @@
# Data File Transfer Summary

## ✅ Done

Both information-cascade files have been copied into the EasyTemporalPointProcess-main folder.

## 📁 File Locations

### Source
- `/Users/chenshuyi/Documents/research_projects/评论家罗伯特TPP/data/cascades/information_cascade.json`
- `/Users/chenshuyi/Documents/research_projects/评论家罗伯特TPP/data/cascades/information_cascade_original_posts.json`

### Destination
- `/Users/chenshuyi/Downloads/EasyTemporalPointProcess-main/data/cascades/information_cascade.json` (606 MB)
- `/Users/chenshuyi/Downloads/EasyTemporalPointProcess-main/data/cascades/information_cascade_original_posts.json` (980 MB)

## ⚠️ Important Notes

### File sizes
- **Total**: about 1.6 GB
- **information_cascade.json**: 606 MB
- **information_cascade_original_posts.json**: 980 MB

### Git exclusion
These files **will not be uploaded to Hugging Face**, because:
1. They are too large, exceeding the recommended Git/Hugging Face size limits
2. They are excluded via `.gitignore`:
   ```
   data/cascades/information_cascade*.json
   data/cascades/*.json
   ```

## 📥 Getting the Data Files onto the Cloud Machine

### Method 1: transfer with scp (recommended)

```bash
# On the cloud machine
mkdir -p data/cascades

# Transfer from the local machine
scp user@local-machine:/Users/chenshuyi/Documents/research_projects/评论家罗伯特TPP/data/cascades/information_cascade*.json ./data/cascades/
```

### Method 2: upload to the Hugging Face Dataset Hub

```bash
# Locally
cd /Users/chenshuyi/Downloads/EasyTemporalPointProcess-main
huggingface-cli upload <username>/cascade-data data/cascades/ --repo-type dataset

# Download on the cloud machine
huggingface-cli download <username>/cascade-data --local-dir ./data/cascades
```

### Method 3: use cloud storage

1. Upload the files to Google Drive / Dropbox / OneDrive
2. Download them on the cloud machine

## 📝 Related Documentation

- **Data file README**: `data/cascades/README.md`
- **Data file notice**: `DATA_FILES_NOTICE.md`
- **Upload guide**: `HF_UPLOAD_GUIDE.md`

## ✅ Verification

After uploading to Hugging Face, verify the data files:

```bash
# Check that the files exist
ls -lh data/cascades/

# You should see:
# information_cascade.json (606MB)
# information_cascade_original_posts.json (980MB)
```

## 🚀 Using the Data Files

Once the files are in place, run the metric computation:

```bash
python compute_cascade_metrics.py \
    --input_cascade data/cascades/information_cascade.json \
    --input_original data/cascades/information_cascade_original_posts.json \
    --output output_with_metrics.json \
    --batch_size 32 \
    --device cuda
```

---

**Data files transferred successfully!** ✅
HF_UPLOAD_GUIDE.md
ADDED
@@ -0,0 +1,180 @@
# Hugging Face Upload Guide

This guide explains how to upload EasyTemporalPointProcess-main to Hugging Face.

## 📋 Preparation

### 1. Run the cleanup script

```bash
cd /Users/chenshuyi/Downloads/EasyTemporalPointProcess-main
python cleanup_for_hf.py
```

This automatically:
- Removes temporary files such as `__pycache__/` and `.pyc` files
- Checks for large files
- Creates an upload checklist

### 2. Data file notice ⚠️

**Important**: the `data/cascades/` directory contains large files (about 1.6 GB) that **will not be uploaded to Hugging Face**.

They are excluded via `.gitignore`:
- `information_cascade.json` (606 MB)
- `information_cascade_original_posts.json` (980 MB)

**Ways to get the data files onto the cloud machine**:
- Method 1: transfer directly with scp (recommended)
- Method 2: upload to cloud storage, then download
- Method 3: use Git LFS (if configured)
- Method 4: upload separately to the Hugging Face Dataset Hub

See `DATA_FILES_NOTICE.md` for details.

### 3. Manual checks

- [ ] Check for sensitive information (API keys, passwords, etc.)
- [ ] Confirm the large files are correctly excluded (via .gitignore)
- [ ] Make sure `requirements.txt` is up to date
- [ ] Check that README.md is complete

## 🚀 Upload Methods

### Method 1: Hugging Face CLI (recommended)

```bash
# 1. Install the Hugging Face CLI
pip install huggingface_hub

# 2. Log in
huggingface-cli login
# Enter your Hugging Face token (from https://huggingface.co/settings/tokens)

# 3. Create the repository (on the website, or with the CLI)
# Visit https://huggingface.co/new to create a new repository
# Choose the "Dataset" type and name it e.g. easytpp-cascade-metrics

# 4. Upload the files
cd /Users/chenshuyi/Downloads/EasyTemporalPointProcess-main
huggingface-cli upload <your-username>/easytpp-cascade-metrics . --repo-type dataset
```

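The same upload can also be driven from Python with the `huggingface_hub` API; a minimal sketch, assuming the repository name used above and that you are already logged in:

```python
from huggingface_hub import HfApi

# Upload the cleaned project folder to an existing dataset repo.
api = HfApi()
api.upload_folder(
    folder_path=".",  # the cleaned project folder
    repo_id="<your-username>/easytpp-cascade-metrics",
    repo_type="dataset",
)
```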
### Method 2: Git

```bash
# 1. Initialize Git (if not already done)
cd /Users/chenshuyi/Downloads/EasyTemporalPointProcess-main
git init

# 2. Add the files
git add .
git commit -m "Add EasyTPP with cascade metrics computation"

# 3. Add the Hugging Face remote
# First create the repository at https://huggingface.co/new
git remote add origin https://huggingface.co/<your-username>/<repo-name>

# 4. Push
git push origin main
```

### Method 3: the web interface

1. Visit https://huggingface.co/new
2. Create a new Dataset repository
3. Click "Add file" → "Upload files"
4. Drag and drop, or select the folder to upload

## 📦 Downloading on the Cloud Machine

After the upload completes, download on the cloud machine:

```bash
# Method 1: Hugging Face CLI
pip install huggingface_hub
huggingface-cli download <your-username>/<repo-name> --local-dir ./EasyTPP

# Method 2: Git
git clone https://huggingface.co/datasets/<your-username>/<repo-name>
cd <repo-name>

# Method 3: Python
from huggingface_hub import snapshot_download
snapshot_download(repo_id="<your-username>/<repo-name>", repo_type="dataset", local_dir="./EasyTPP")
```

### 📥 Downloading the data files

**Important**: the code repository does not include the data files (excluded via .gitignore).

The data files must be obtained separately:

```bash
# Method 1: transfer from the local machine with scp (recommended)
mkdir -p data/cascades
scp user@local-machine:/path/to/information_cascade*.json ./data/cascades/

# Method 2: if already uploaded to the Hugging Face Dataset Hub
huggingface-cli download <username>/cascade-data --local-dir ./data/cascades

# Method 3: download from cloud storage
# (depends on which cloud-storage service you use)
```

See `DATA_FILES_NOTICE.md` for details.

## 📝 Description of the Additions

This repository adds the following functionality on top of the original EasyTPP:

### 1. Cascade metric computation (`compute_cascade_metrics.py`)

Computes metrics for information-cascade data:
- **Sentiment Score**
- **Sentiment Deviation**
- **Contextual Deviation**
- **Perplexity**

See `COMPUTE_METRICS_README.md` for details.

### 2. Related files

- `compute_cascade_metrics.py`: main computation script
- `COMPUTE_METRICS_README.md`: usage instructions
- `requirements_compute_metrics.txt`: additional dependencies
- `example_compute_metrics.sh`: example script
- `cleanup_for_hf.py`: cleanup script

## ⚠️ Notes

1. **Large files**
   - For files >50 MB, consider Git LFS
   - Or exclude the data files and use external links

2. **Sensitive information**
   - Do not upload files containing API keys or passwords
   - Check configuration files for sensitive data

3. **License**
   - Make sure all code carries an appropriate license
   - The original EasyTPP uses the Apache 2.0 license

4. **Version control**
   - Using Git for version control is recommended
   - Commit and push after every update

## 🔍 Verifying the Upload

After uploading, check that:
- [ ] All files have been uploaded
- [ ] The README renders correctly
- [ ] The code can be downloaded
- [ ] The dependencies install correctly

## 📞 Troubleshooting

If something goes wrong, check:
1. Whether the Hugging Face repository settings are correct
2. Whether a file exceeds the size limit
3. Whether there is a permissions problem
LICENCE
ADDED
@@ -0,0 +1,203 @@
Copyright 2022 The EasyTPP Authors. All rights reserved.

                              Apache License
                        Version 2.0, January 2004
                     http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

   "License" shall mean the terms and conditions for use, reproduction,
   and distribution as defined by Sections 1 through 9 of this document.

   "Licensor" shall mean the copyright owner or entity authorized by
   the copyright owner that is granting the License.

   "Legal Entity" shall mean the union of the acting entity and all
   other entities that control, are controlled by, or are under common
   control with that entity. For the purposes of this definition,
   "control" means (i) the power, direct or indirect, to cause the
   direction or management of such entity, whether by contract or
   otherwise, or (ii) ownership of fifty percent (50%) or more of the
   outstanding shares, or (iii) beneficial ownership of such entity.

   "You" (or "Your") shall mean an individual or Legal Entity
   exercising permissions granted by this License.

   "Source" form shall mean the preferred form for making modifications,
   including but not limited to software source code, documentation
   source, and configuration files.

   "Object" form shall mean any form resulting from mechanical
   transformation or translation of a Source form, including but
   not limited to compiled object code, generated documentation,
   and conversions to other media types.

   "Work" shall mean the work of authorship, whether in Source or
   Object form, made available under the License, as indicated by a
   copyright notice that is included in or attached to the work
   (an example is provided in the Appendix below).

   "Derivative Works" shall mean any work, whether in Source or Object
   form, that is based on (or derived from) the Work and for which the
   editorial revisions, annotations, elaborations, or other modifications
   represent, as a whole, an original work of authorship. For the purposes
   of this License, Derivative Works shall not include works that remain
   separable from, or merely link (or bind by name) to the interfaces of,
   the Work and Derivative Works thereof.

   "Contribution" shall mean any work of authorship, including
   the original version of the Work and any modifications or additions
   to that Work or Derivative Works thereof, that is intentionally
   submitted to Licensor for inclusion in the Work by the copyright owner
   or by an individual or Legal Entity authorized to submit on behalf of
   the copyright owner. For the purposes of this definition, "submitted"
   means any form of electronic, verbal, or written communication sent
   to the Licensor or its representatives, including but not limited to
   communication on electronic mailing lists, source code control systems,
   and issue tracking systems that are managed by, or on behalf of, the
   Licensor for the purpose of discussing and improving the Work, but
   excluding communication that is conspicuously marked or otherwise
   designated in writing by the copyright owner as "Not a Contribution."

   "Contributor" shall mean Licensor and any individual or Legal Entity
   on behalf of whom a Contribution has been received by Licensor and
   subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
   this License, each Contributor hereby grants to You a perpetual,
   worldwide, non-exclusive, no-charge, royalty-free, irrevocable
   copyright license to reproduce, prepare Derivative Works of,
   publicly display, publicly perform, sublicense, and distribute the
   Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
   this License, each Contributor hereby grants to You a perpetual,
   worldwide, non-exclusive, no-charge, royalty-free, irrevocable
   (except as stated in this section) patent license to make, have made,
   use, offer to sell, sell, import, and otherwise transfer the Work,
   where such license applies only to those patent claims licensable
   by such Contributor that are necessarily infringed by their
   Contribution(s) alone or by combination of their Contribution(s)
   with the Work to which such Contribution(s) was submitted. If You
   institute patent litigation against any entity (including a
   cross-claim or counterclaim in a lawsuit) alleging that the Work
   or a Contribution incorporated within the Work constitutes direct
   or contributory patent infringement, then any patent licenses
   granted to You under this License for that Work shall terminate
   as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
   Work or Derivative Works thereof in any medium, with or without
   modifications, and in Source or Object form, provided that You
   meet the following conditions:

   (a) You must give any other recipients of the Work or
       Derivative Works a copy of this License; and

   (b) You must cause any modified files to carry prominent notices
       stating that You changed the files; and

   (c) You must retain, in the Source form of any Derivative Works
       that You distribute, all copyright, patent, trademark, and
       attribution notices from the Source form of the Work,
       excluding those notices that do not pertain to any part of
       the Derivative Works; and

   (d) If the Work includes a "NOTICE" text file as part of its
       distribution, then any Derivative Works that You distribute must
       include a readable copy of the attribution notices contained
       within such NOTICE file, excluding those notices that do not
       pertain to any part of the Derivative Works, in at least one
       of the following places: within a NOTICE text file distributed
       as part of the Derivative Works; within the Source form or
       documentation, if provided along with the Derivative Works; or,
       within a display generated by the Derivative Works, if and
       wherever such third-party notices normally appear. The contents
       of the NOTICE file are for informational purposes only and
       do not modify the License. You may add Your own attribution
       notices within Derivative Works that You distribute, alongside
       or as an addendum to the NOTICE text from the Work, provided
       that such additional attribution notices cannot be construed
       as modifying the License.

   You may add Your own copyright statement to Your modifications and
   may provide additional or different license terms and conditions
   for use, reproduction, or distribution of Your modifications, or
   for any such Derivative Works as a whole, provided Your use,
   reproduction, and distribution of the Work otherwise complies with
   the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
   any Contribution intentionally submitted for inclusion in the Work
   by You to the Licensor shall be under the terms and conditions of
   this License, without any additional terms or conditions.
   Notwithstanding the above, nothing herein shall supersede or modify
   the terms of any separate license agreement you may have executed
   with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
   names, trademarks, service marks, or product names of the Licensor,
   except as required for reasonable and customary use in describing the
   origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
   agreed to in writing, Licensor provides the Work (and each
   Contributor provides its Contributions) on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
   implied, including, without limitation, any warranties or conditions
   of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
   PARTICULAR PURPOSE. You are solely responsible for determining the
   appropriateness of using or redistributing the Work and assume any
   risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
   whether in tort (including negligence), contract, or otherwise,
   unless required by applicable law (such as deliberate and grossly
   negligent acts) or agreed to in writing, shall any Contributor be
   liable to You for damages, including any direct, indirect, special,
   incidental, or consequential damages of any character arising as a
   result of this License or out of the use or inability to use the
   Work (including but not limited to damages for loss of goodwill,
   work stoppage, computer failure or malfunction, or any and all
   other commercial damages or losses), even if such Contributor
   has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
   the Work or Derivative Works thereof, You may choose to offer,
   and charge a fee for, acceptance of support, warranty, indemnity,
   or other liability obligations and/or rights consistent with this
   License. However, in accepting such obligations, You may act only
   on Your own behalf and on Your sole responsibility, not on behalf
   of any other Contributor, and only if You agree to indemnify,
   defend, and hold each Contributor harmless for any liability
   incurred by, or claims asserted against, such Contributor by reason
   of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

   To apply the Apache License to your work, attach the following
   boilerplate notice, with the fields enclosed by brackets "[]"
   replaced with your own identifying information. (Don't include
   the brackets!) The text should be enclosed in the appropriate
   comment syntax for the file format. We also recommend that a
   file or class name and description of purpose be included on the
   same "printed page" as the copyright notice for easier
   identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
MANIFEST.in
ADDED
@@ -0,0 +1,2 @@
include requirements.txt
include version.py
NOTICE
ADDED
@@ -0,0 +1,23 @@
=============================================================
EasyTPP is an open source tool developed by the Machine Intelligence Team
Copyright (c) 2020-2022, Ant Group Holding Limited.
Licensed under the Apache License, Version 2.0

=============================================================
This toolkit contains various third-party components under
different open source licenses

-----------------------------
Training evaluation pipeline
Apache License, Version 2.0
FuxiCTR authors

----------------------------
Training evaluation pipeline
Apache License, Version 2.0
EasyNLP, Alibaba Inc.

----------------------------
Tokenizer and DataLoader
Apache License, Version 2.0
The HuggingFace Inc. team
QUICK_START_HF.md
ADDED
@@ -0,0 +1,104 @@
# Quick Start: Uploading to Hugging Face

## ✅ Cleanup Complete

The folder has been cleaned up and is ready to upload to Hugging Face.

**Folder size**: 1.3 MB (suitable for upload)

## 📋 What Was Done

### Cleanup completed
- ✅ Updated the `.gitignore` file
- ✅ Created the cleanup script `cleanup_for_hf.py`
- ✅ Checked for large files (none found)
- ✅ Created the upload guide

### New files
- `compute_cascade_metrics.py` — cascade metric computation script
- `COMPUTE_METRICS_README.md` — metric computation instructions
- `HF_UPLOAD_GUIDE.md` — upload guide
- `ADDITIONS_README.md` — description of the additions
- `cleanup_for_hf.py` — cleanup script
- `requirements_compute_metrics.txt` — additional dependencies

## 🚀 Upload in Three Steps

### Step 1: install the Hugging Face CLI

```bash
pip install huggingface_hub
```

### Step 2: log in

```bash
huggingface-cli login
# Enter your token (from https://huggingface.co/settings/tokens)
```

### Step 3: create the repository and upload

```bash
# 1. Create the repository on the website
# Visit https://huggingface.co/new
# Choose "Dataset" and name it e.g. easytpp-cascade-metrics

# 2. Upload the files
cd /Users/chenshuyi/Downloads/EasyTemporalPointProcess-main
huggingface-cli upload <your-username>/easytpp-cascade-metrics . --repo-type dataset
```

## 📥 Downloading on the Cloud Machine

```bash
# Method 1: with the CLI
huggingface-cli download <your-username>/easytpp-cascade-metrics --local-dir ./EasyTPP

# Method 2: with Git
git clone https://huggingface.co/datasets/<your-username>/easytpp-cascade-metrics
cd easytpp-cascade-metrics
```

### ⚠️ Important: the data files must be obtained separately

The code repository does **not** include the data files (excluded via .gitignore because they are too large).

Transfer them separately:

```bash
# Create the directory on the cloud machine
mkdir -p data/cascades

# Method 1: transfer from the local machine with scp (recommended)
scp user@local-machine:/path/to/information_cascade*.json ./data/cascades/

# Method 2: if already uploaded to the Hugging Face Dataset Hub
huggingface-cli download <username>/cascade-data --local-dir ./data/cascades
```

See `DATA_FILES_NOTICE.md` for details.

## 📚 Related Documentation

- **Detailed upload guide**: `HF_UPLOAD_GUIDE.md`
- **Data file notice**: `DATA_FILES_NOTICE.md` ⚠️ **important**
- **Metric computation instructions**: `COMPUTE_METRICS_README.md`
- **Additions**: `ADDITIONS_README.md`
- **Upload checklist**: `UPLOAD_CHECKLIST.md`

## ⚠️ Notes

1. **File size**: the folder is currently 1.3 MB; Git LFS is not needed
2. **Sensitive information**: checked; none found
3. **Dependencies**: make sure `requirements.txt` and `requirements_compute_metrics.txt` are included

## 🎯 Next Steps

1. Upload to Hugging Face following the steps above
2. Download and test on the cloud machine
3. Run `compute_cascade_metrics.py` to compute the metrics

---

**All set — ready to upload!** 🎉
README.md
ADDED
@@ -0,0 +1,279 @@
# EasyTPP [ICLR 2024]

<div align="center">
  <a href="PyVersion">
    <img alt="Python Version" src="https://img.shields.io/badge/python-3.9+-blue.svg">
  </a>
  <a href="LICENSE-CODE">
    <img alt="Code License" src="https://img.shields.io/badge/license-Apache-000000.svg?&color=f5de53">
  </a>
  <a href="commit">
    <img alt="Last Commit" src="https://img.shields.io/github/last-commit/ant-research/EasyTemporalPointProcess">
  </a>
</div>

<div align="center">
  <a href="https://pypi.python.org/pypi/easy-tpp/">
    <img alt="PyPI version" src="https://img.shields.io/pypi/v/easy-tpp.svg?style=flat-square&color=b7534" />
  </a>
  <a href="https://static.pepy.tech/personalized-badge/easy-tpp">
    <img alt="Downloads" src="https://static.pepy.tech/personalized-badge/easy-tpp?period=total&units=international_system&left_color=black&right_color=blue&left_text=Downloads" />
  </a>
  <a href="https://huggingface.co/easytpp" target="_blank">
    <img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-EasyTPP-ffc107?color=ffc107&logoColor=white" />
  </a>
  <a href="https://github.com/ant-research/EasyTemporalPointProcess/issues">
    <img alt="Open Issues" src="https://img.shields.io/github/issues-raw/ant-research/EasyTemporalPointProcess" />
  </a>
</div>

`EasyTPP` is an easy-to-use development and application toolkit for [Temporal Point Process](https://mathworld.wolfram.com/TemporalPointProcess.html) (TPP), with key features in configurability, compatibility and reproducibility. We hope this project can benefit both researchers and practitioners by enabling easily customized development and open benchmarking in TPP.
<span id='top'/>

| <a href='#features'>Features</a> | <a href='#model-list'>Model List</a> | <a href='#dataset'>Dataset</a> | <a href='#quick-start'>Quick Start</a> | <a href='#benchmark'>Benchmark</a> | <a href='#doc'>Documentation</a> | <a href='#todo'>Todo List</a> | <a href='#citation'>Citation</a> | <a href='#acknowledgment'>Acknowledgement</a> | <a href='#star-history'>Star History</a> |

## News
<span id='news'/>

- ![new](https://img.shields.io/badge/-NEW-orange) [11-06-2025] We have released a new version of ``EasyTPP`` that exclusively supports PyTorch. TensorFlow support has been removed to streamline the codebase and focus on PyTorch-based implementations.
- ![new](https://img.shields.io/badge/-NEW-orange) [11-05-2025] Added the implementation of the [S2P2](https://openreview.net/pdf?id=74SvE2GZwW) model, presented at NeurIPS'2025.
- ![new](https://img.shields.io/badge/-NEW-orange) [02-17-2024] ``EasyTPP`` supports the HuggingFace dataset API: all datasets have been published in the [HuggingFace Repo](https://huggingface.co/easytpp); see the [tutorial notebook](https://github.com/ant-research/EasyTemporalPointProcess/blob/main/notebooks/easytpp_1_dataset.ipynb) for an example of usage.
- [01-16-2024] Our paper [EasyTPP: Towards Open Benchmarking Temporal Point Process](https://arxiv.org/abs/2307.08097) is accepted by ICLR'2024!
<details>
<summary>Click to see previous news</summary>
<p>

- [09-30-2023] We published two textual event sequence datasets, [GDELT](https://drive.google.com/drive/folders/1Ms-ATMMFf6v4eesfJndyuPLGtX58fCnk) and [Amazon-text-review](https://drive.google.com/drive/folders/1-SLYyrl7ucEG7NpSIF0eSoG9zcbZagZw), that are used in our paper [LAMP](https://arxiv.org/abs/2305.16646), where an LLM is applied for event prediction! See the [Documentation](https://ant-research.github.io/EasyTemporalPointProcess/user_guide/dataset.html#preprocessed-datasets) for more details.
- [09-30-2023] Two of our papers, [Language Model Can Improve Event Prediction by Few-Shot Abductive Reasoning](https://arxiv.org/abs/2305.16646) (LAMP) and [Prompt-augmented Temporal Point Process for Streaming Event Sequence](https://arxiv.org/abs/2310.04993) (PromptTPP), are accepted by NeurIPS'2023!
- [09-02-2023] We published two non-anthropogenic datasets, [earthquake](https://drive.google.com/drive/folders/1ubeIz_CCNjHyuu6-XXD0T-gdOLm12rf4) and [volcano eruption](https://drive.google.com/drive/folders/1KSWbNi8LUwC-dxz1T5sOnd9zwAot95Tp?usp=drive_link)! See <a href='#dataset'>Dataset</a> for details.
- [05-29-2023] We released ``EasyTPP`` v0.0.1!
- [12-27-2022] Our paper [Bellman Meets Hawkes: Model-Based Reinforcement Learning via Temporal Point Processes](https://arxiv.org/abs/2201.12569) was accepted by AAAI'2023!
- [10-01-2022] Our paper [HYPRO: A Hybridly Normalized Probabilistic Model for Long-Horizon Prediction of Event Sequences](https://arxiv.org/abs/2210.01753) was accepted by NeurIPS'2022!
- [05-01-2022] We started to develop `EasyTPP`.</p>
</details>
## Features <a href='#top'>[Back to Top]</a>
<span id='features'/>

- **Configurable and customizable**: models are modularized and configurable, with abstract classes to support developing customized TPP models.
- **PyTorch-based implementation**: `EasyTPP` implements state-of-the-art TPP models using PyTorch 1.7.0+, providing a clean and modern deep learning framework.
- **Reproducible**: all the benchmarks can be easily reproduced.
- **Hyper-parameter optimization**: a pipeline of [optuna](https://github.com/optuna/optuna)-based HPO is provided.
## Model List <a href='#top'>[Back to Top]</a>
<span id='model-list'/>

We provide reference implementations of various state-of-the-art TPP papers:

| No | Publication | Model | Paper | Implementation |
|:---:|:-----------:|:-------------:|:------|:---------------|
| 1 | KDD'16 | RMTPP | [Recurrent Marked Temporal Point Processes: Embedding Event History to Vector](https://www.kdd.org/kdd2016/papers/files/rpp1081-duA.pdf) | [PyTorch](easy_tpp/model/torch_model/torch_rmtpp.py) |
| 2 | NeurIPS'17 | NHP | [The Neural Hawkes Process: A Neurally Self-Modulating Multivariate Point Process](https://arxiv.org/abs/1612.09328) | [PyTorch](easy_tpp/model/torch_model/torch_nhp.py) |
| 3 | NeurIPS'19 | FullyNN | [Fully Neural Network based Model for General Temporal Point Processes](https://arxiv.org/abs/1905.09690) | [PyTorch](easy_tpp/model/torch_model/torch_fullynn.py) |
| 4 | ICML'20 | SAHP | [Self-Attentive Hawkes Process](https://arxiv.org/abs/1907.07561) | [PyTorch](easy_tpp/model/torch_model/torch_sahp.py) |
| 5 | ICML'20 | THP | [Transformer Hawkes Process](https://arxiv.org/abs/2002.09291) | [PyTorch](easy_tpp/model/torch_model/torch_thp.py) |
| 6 | ICLR'20 | IntensityFree | [Intensity-Free Learning of Temporal Point Processes](https://arxiv.org/abs/1909.12127) | [PyTorch](easy_tpp/model/torch_model/torch_intensity_free.py) |
| 7 | ICLR'21 | ODETPP | [Neural Spatio-Temporal Point Processes (simplified)](https://arxiv.org/abs/2011.04583) | [PyTorch](easy_tpp/model/torch_model/torch_ode_tpp.py) |
| 8 | ICLR'22 | AttNHP | [Transformer Embeddings of Irregularly Spaced Events and Their Participants](https://arxiv.org/abs/2201.00044) | [PyTorch](easy_tpp/model/torch_model/torch_attnhp.py) |
| 9 | NeurIPS'25 | S2P2 | [Deep Continuous-Time State-Space Models for Marked Event Sequences](https://openreview.net/pdf?id=74SvE2GZwW) | [PyTorch](easy_tpp/model/torch_model/torch_s2p2.py) |
## Dataset <a href='#top'>[Back to Top]</a>
<span id='dataset'/>

We preprocessed one synthetic and five real-world datasets from widely-cited works that contain diverse characteristics in terms of their application domains and temporal statistics:
- Synthetic: a univariate Hawkes process simulated by the [Tick](https://github.com/X-DataInitiative/tick) library.
- Retweet ([Zhou, 2013](http://proceedings.mlr.press/v28/zhou13.pdf)): timestamped user retweet events.
- Taxi ([Whong, 2014](https://chriswhong.com/open-data/foil_nyc_taxi/)): timestamped taxi pick-up events.
- StackOverflow ([Leskovec, 2014](https://snap.stanford.edu/data/)): timestamped user badge reward events on StackOverflow.
- Taobao ([Xue et al, 2022](https://arxiv.org/abs/2210.01753)): timestamped user online shopping behavior events on the Taobao platform.
- Amazon ([Xue et al, 2022](https://arxiv.org/abs/2210.01753)): timestamped user online shopping behavior events on the Amazon platform.

Per users' request, we processed two non-anthropogenic datasets:
- [Earthquake](https://drive.google.com/drive/folders/1ubeIz_CCNjHyuu6-XXD0T-gdOLm12rf4): timestamped earthquake events over the conterminous U.S. from 1996 to 2023, processed from [USGS](https://www.usgs.gov/programs/earthquake-hazards/science/earthquake-data).
- [Volcano eruption](https://drive.google.com/drive/folders/1KSWbNi8LUwC-dxz1T5sOnd9zwAot95Tp?usp=drive_link): timestamped volcano eruption events over the world in recent hundreds of years, processed from [The Smithsonian Institution](https://volcano.si.edu/).

All datasets are preprocessed into the `Gatech` format widely used by TPP researchers, and saved at [Google Drive](https://drive.google.com/drive/u/0/folders/1f8k82-NL6KFKuNMsUwozmbzDSFycYvz7) with public access.
## Quick Start <a href='#top'>[Back to Top]</a>
<span id='quick-start'/>

### Colab Tutorials

Explore the following tutorials that can be opened directly in Google Colab:

- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ant-research/EasyTemporalPointProcess/blob/main/notebooks/easytpp_1_dataset.ipynb) Tutorial 1: Dataset in EasyTPP.
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ant-research/EasyTemporalPointProcess/blob/main/notebooks/easytpp_2_tfb_wb.ipynb) Tutorial 2: Tensorboard in EasyTPP.
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ant-research/EasyTemporalPointProcess/blob/main/notebooks/easytpp_3_train_eval.ipynb) Tutorial 3: Training and Evaluation of TPPs.

### End-to-end Example

We provide an end-to-end example for users to run a standard TPP model with `EasyTPP`.

### Step 1. Installation

First of all, we can install the package either by using pip or from the source code on GitHub.

To install the latest stable version:
```bash
pip install easy-tpp
```

To install the latest version on GitHub:
```bash
git clone https://github.com/ant-research/EasyTemporalPointProcess.git
cd EasyTemporalPointProcess
python setup.py install
```

### Step 2. Prepare datasets

We need to put the datasets in a local directory before running a model, and the datasets should follow a certain format. See [OnlineDoc - Datasets](https://ant-research.github.io/EasyTemporalPointProcess/user_guide/dataset.html) for more details.

Suppose we use the [taxi dataset](https://chriswhong.com/open-data/foil_nyc_taxi/) in the example.
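For orientation, each split pickle in the `Gatech`-style layout mentioned above holds a dict with the number of event types and a list of event sequences. A minimal inspection sketch, assuming the conventional `dim_process`/`type_event` field names from the Neural Hawkes Process data release (check the downloaded files, as the exact names may vary):

```python
import pickle

# Load the training split of the taxi dataset (path assumes the layout shown in Step 3)
with open('data/taxi/train.pkl', 'rb') as f:
    data = pickle.load(f, encoding='latin-1')

print(data['dim_process'])   # number of event types
first_seq = data['train'][0] # one event sequence: a list of event dicts
# Each event is expected to carry fields like
# {'time_since_start': ..., 'time_since_last_event': ..., 'type_event': ...}
print(first_seq[0])
```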
### Step 3. Train the model

Before starting training, we need to set up the config file for the pipeline. We provide a preset config file in [Example Config](https://github.com/ant-research/EasyTemporalPointProcess/blob/main/examples/configs/experiment_config.yaml). The details of the configuration can be found in [OnlineDoc - Training Pipeline](https://ant-research.github.io/EasyTemporalPointProcess/user_guide/run_train_pipeline.html).

After the setup of data and config, the directory structure is as follows:

```bash

    data
     |______taxi
             |____ train.pkl
             |____ dev.pkl
             |____ test.pkl

    configs
     |______experiment_config.yaml

```

Then we start the training by simply running the script

```python

import argparse
from easy_tpp.config_factory import Config
from easy_tpp.runner import Runner


def main():
    parser = argparse.ArgumentParser()

    parser.add_argument('--config_dir', type=str, required=False, default='configs/experiment_config.yaml',
                        help='Dir of configuration yaml to train and evaluate the model.')

    parser.add_argument('--experiment_id', type=str, required=False, default='NHP_train',
                        help='Experiment id in the config file.')

    args = parser.parse_args()

    config = Config.build_from_yaml_file(args.config_dir, experiment_id=args.experiment_id)

    model_runner = Runner.build_from_config(config)

    model_runner.run()


if __name__ == '__main__':
    main()

```

A more detailed example can be found at [OnlineDoc - QuickStart](https://ant-research.github.io/EasyTemporalPointProcess/get_started/quick_start.html).
## Documentation <a href='#top'>[Back to Top]</a>
<span id='doc'/>

The classes and methods of `EasyTPP` are well documented, so users can generate the documentation by:

```shell
cd docs
pip install -r requirements.txt
make html
```
NOTE:
* The `docs/requirements.txt` is only for building the documentation with Sphinx, which can be done automatically by the GitHub Actions workflow `.github/workflows/docs.yaml` (triggered by pull request).

The full documentation is available on the [website](https://ant-research.github.io/EasyTemporalPointProcess/).
## Benchmark <a href='#top'>[Back to Top]</a>
<span id='benchmark'/>

In the [examples](https://github.com/ant-research/EasyTemporalPointProcess/tree/main/examples) folder, we provide a [script](https://github.com/ant-research/EasyTemporalPointProcess/blob/main/examples/benchmark_script.py) to benchmark the TPPs, with the Taxi dataset as the input.

To run the script, one should download the Taxi data following the above instructions. The [config](https://github.com/ant-research/EasyTemporalPointProcess/blob/main/examples/configs/experiment_config.yaml) file is already set up. Then run

```shell
cd examples
python run_retweet.py
```
## License <a href='#top'>[Back to Top]</a>

This project is licensed under the [Apache License (Version 2.0)](https://github.com/alibaba/EasyNLP/blob/master/LICENSE). This toolkit also contains some code modified from other repos under other open-source licenses. See the [NOTICE](https://github.com/ant-research/EasyTPP/blob/master/NOTICE) file for more information.

## Todo List <a href='#top'>[Back to Top]</a>
<span id='todo'/>

- [x] New dataset:
  - [x] Earthquake: the source data is available in [USGS](https://www.usgs.gov/programs/earthquake-hazards/science/earthquake-data).
  - [x] Volcano eruption: the source data is available in [NCEI](https://www.ngdc.noaa.gov/hazard/volcano.shtml).
- [ ] New model:
  - [ ] Meta Temporal Point Process, ICLR 2023.
  - [ ] Model-based RL via TPP, AAAI 2022.

## Citation <a href='#top'>[Back to Top]</a>

<span id='citation'/>

If you find `EasyTPP` useful for your research or development, please cite the following <a href="https://arxiv.org/abs/2307.08097" target="_blank">paper</a>:
```
@inproceedings{xue2024easytpp,
  title={EasyTPP: Towards Open Benchmarking Temporal Point Processes},
  author={Siqiao Xue and Xiaoming Shi and Zhixuan Chu and Yan Wang and Hongyan Hao and Fan Zhou and Caigao Jiang and Chen Pan and James Y. Zhang and Qingsong Wen and Jun Zhou and Hongyuan Mei},
  booktitle = {International Conference on Learning Representations (ICLR)},
  year = {2024},
  url = {https://arxiv.org/abs/2307.08097}
}
```

## Acknowledgment <a href='#top'>[Back to Top]</a>
<span id='acknowledgment'/>

The project is jointly initiated by Machine Intelligence Group, Alipay and DAMO Academy, Alibaba.

The following repositories are used in `EasyTPP`, either in close-to-original form or as an inspiration:

- [EasyRec](https://github.com/alibaba/EasyRec)
- [EasyNLP](https://github.com/alibaba/EasyNLP)
- [FuxiCTR](https://github.com/xue-pai/FuxiCTR)
- [Neural Hawkes Process](https://github.com/hongyuanmei/neurawkes)
- [Neural Hawkes Particle Smoothing](https://github.com/hongyuanmei/neural-hawkes-particle-smoothing)
- [Attentive Neural Hawkes Process](https://github.com/yangalan123/anhp-andtt)
- [Huggingface - transformers](https://github.com/huggingface/transformers)

## Star History <a href='#top'>[Back to Top]</a>
<span id='star-history'/>

![Star History Chart](https://api.star-history.com/svg?repos=ant-research/EasyTemporalPointProcess&type=Date)
UPLOAD_CHECKLIST.md
ADDED
@@ -0,0 +1,116 @@
# Hugging Face Upload Checklist

## ✅ Cleanup Complete

### Removed file types
- `__pycache__/` folders
- `*.pyc`, `*.pyo`, `*.pyd` files
- `.DS_Store` files (macOS)
- `.vscode/`, `.idea/` folders
- `*.swp`, `*.swo` files

### Items to check manually (a small automation sketch follows this list)

1. **Large files**
   - Check for files larger than 50MB
   - Consider Git LFS or excluding them

2. **Sensitive information**
   - Check for API keys, passwords, and other secrets
   - Check configuration files for sensitive data

3. **Data files**
   - Check the `examples/data/` directory
   - If the data files are large, consider excluding them or linking to them externally

4. **Model files**
   - Check for pretrained model files
   - Large model files should go through Git LFS or the Hugging Face Model Hub

5. **Log files**
   - Make sure no log files are included
   - Check the `log/` and `logs/` directories
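The first two items can be partly automated; `cleanup_for_hf.py` already reports large files. A small standalone sketch, where the 50MB threshold matches the item above but the keyword list is only an assumed starting point:

```python
from pathlib import Path

THRESHOLD = 50 * 1024 * 1024  # 50MB, matching the checklist item above
SUSPICIOUS = ('api_key', 'secret', 'password', 'token')  # assumed keyword list

for path in Path('.').rglob('*'):
    if not path.is_file():
        continue
    # Flag files above the size threshold
    if path.stat().st_size > THRESHOLD:
        print(f"large file: {path} ({path.stat().st_size / 2**20:.1f} MB)")
    # Crude keyword scan over text-ish config and source files
    if path.suffix in {'.py', '.yaml', '.yml', '.json'}:
        text = path.read_text(encoding='utf-8', errors='ignore').lower()
        if any(word in text for word in SUSPICIOUS):
            print(f"possible secret in: {path}")
```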
## 📦 Uploading to Hugging Face

### Option 1: Hugging Face CLI

```bash
# Install the Hugging Face CLI
pip install huggingface_hub

# Log in
huggingface-cli login

# Create a repository (if you do not have one yet)
# Create it at https://huggingface.co/new

# Upload the files
cd /path/to/EasyTemporalPointProcess-main
huggingface-cli upload <your-username>/<repo-name> . --repo-type dataset
```

### Option 2: Git

```bash
# Initialize a Git repository (if not already initialized)
git init
git add .
git commit -m "Initial commit"

# Add the Hugging Face remote
git remote add origin https://huggingface.co/<your-username>/<repo-name>

# Push
git push origin main
```

### Option 3: Web UI

1. Visit https://huggingface.co/new
2. Create a new Dataset or Space
3. Upload the files through the web UI

## 📝 Directory Layout

```
EasyTemporalPointProcess-main/
├── easy_tpp/                          # core library code
├── examples/                          # example code
├── notebooks/                         # Jupyter notebooks
├── tests/                             # tests
├── docs/                              # documentation
├── compute_cascade_metrics.py         # new: cascade metrics script
├── COMPUTE_METRICS_README.md          # new: metrics computation guide
├── requirements.txt                   # base dependencies
├── requirements_compute_metrics.txt   # new: metrics dependencies
├── setup.py                           # install script
└── README.md                          # project readme
```

## ⚠️ Notes

1. **Do not commit large files to the Git repository**
   - Use Git LFS or Hugging Face's storage system
   - Consider referencing large files via external links

2. **Check licenses**
   - Make sure all code carries an appropriate license
   - Check the license compatibility of third-party dependencies

3. **README**
   - Make sure README.md clearly explains what the project is for
   - Include installation and usage instructions

4. **Dependency management**
   - Make sure requirements.txt is up to date
   - Consider using `pip freeze` to pin exact versions

## 🔍 Verifying the Upload

After uploading, check that:
- [ ] all files were uploaded
- [ ] file sizes look reasonable
- [ ] no sensitive information leaked
- [ ] the README renders correctly
- [ ] the code can be downloaded and used normally
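One way to run these checks end to end is to re-download the uploaded repository into a scratch directory. A minimal Python sketch with the same placeholder names as above:

```python
from huggingface_hub import snapshot_download

# Pull the uploaded repo into a scratch directory (placeholder repo id)
local_dir = snapshot_download(
    repo_id="<your-username>/<repo-name>",
    repo_type="dataset",
    local_dir="/tmp/hf-check",
)
print("Downloaded to:", local_dir)
# Then inspect the files and try: pip install -r /tmp/hf-check/requirements.txt
```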
cleanup_for_hf.py
ADDED
@@ -0,0 +1,293 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Cleanup script: prepare the folder for uploading to Hugging Face.

This script will:
1. Remove unnecessary files (__pycache__, .pyc, .pyo, etc.)
2. Check for large files
3. Create an upload checklist
"""

import os
import shutil
from pathlib import Path
from typing import List, Tuple


def find_and_remove_patterns(root_dir: str, patterns: List[str]) -> List[str]:
    """
    Find and remove files/folders matching the given patterns.

    Args:
        root_dir: root directory
        patterns: list of file/folder patterns

    Returns:
        List of removed files/folders
    """
    removed = []
    root_path = Path(root_dir)

    for pattern in patterns:
        for item in root_path.rglob(pattern):
            if item.exists():
                try:
                    if item.is_file():
                        item.unlink()
                        removed.append(str(item))
                    elif item.is_dir():
                        shutil.rmtree(item)
                        removed.append(str(item))
                except Exception as e:
                    print(f"Warning: could not remove {item}: {e}")

    return removed


def find_large_files(root_dir: str, size_mb: int = 50) -> List[Tuple[str, float]]:
    """
    Find large files.

    Args:
        root_dir: root directory
        size_mb: file size threshold (MB)

    Returns:
        List of (file path, size in MB) tuples
    """
    large_files = []
    root_path = Path(root_dir)
    size_bytes = size_mb * 1024 * 1024

    for item in root_path.rglob('*'):
        if item.is_file():
            try:
                size = item.stat().st_size
                if size > size_bytes:
                    size_mb_actual = size / (1024 * 1024)
                    large_files.append((str(item), size_mb_actual))
            except Exception as e:
                print(f"Warning: could not check {item}: {e}")

    return large_files


def check_gitignore(root_dir: str) -> bool:
    """
    Check whether a .gitignore file exists.

    Args:
        root_dir: root directory

    Returns:
        Whether .gitignore exists
    """
    gitignore_path = Path(root_dir) / '.gitignore'
    return gitignore_path.exists()


def create_upload_checklist(root_dir: str) -> str:
    """
    Create the upload checklist.

    Args:
        root_dir: root directory

    Returns:
        Checklist content
    """
    checklist = """# Hugging Face Upload Checklist

## ✅ Cleanup Complete

### Removed file types
- `__pycache__/` folders
- `*.pyc`, `*.pyo`, `*.pyd` files
- `.DS_Store` files (macOS)
- `.vscode/`, `.idea/` folders
- `*.swp`, `*.swo` files

### Items to check manually

1. **Large files**
   - Check for files larger than 50MB
   - Consider Git LFS or excluding them

2. **Sensitive information**
   - Check for API keys, passwords, and other secrets
   - Check configuration files for sensitive data

3. **Data files**
   - Check the `examples/data/` directory
   - If the data files are large, consider excluding them or linking to them externally

4. **Model files**
   - Check for pretrained model files
   - Large model files should go through Git LFS or the Hugging Face Model Hub

5. **Log files**
   - Make sure no log files are included
   - Check the `log/` and `logs/` directories

## 📦 Uploading to Hugging Face

### Option 1: Hugging Face CLI

```bash
# Install the Hugging Face CLI
pip install huggingface_hub

# Log in
huggingface-cli login

# Create a repository (if you do not have one yet)
# Create it at https://huggingface.co/new

# Upload the files
cd /path/to/EasyTemporalPointProcess-main
huggingface-cli upload <your-username>/<repo-name> . --repo-type dataset
```

### Option 2: Git

```bash
# Initialize a Git repository (if not already initialized)
git init
git add .
git commit -m "Initial commit"

# Add the Hugging Face remote
git remote add origin https://huggingface.co/<your-username>/<repo-name>

# Push
git push origin main
```

### Option 3: Web UI

1. Visit https://huggingface.co/new
2. Create a new Dataset or Space
3. Upload the files through the web UI

## 📝 Directory Layout

```
EasyTemporalPointProcess-main/
├── easy_tpp/                          # core library code
├── examples/                          # example code
├── notebooks/                         # Jupyter notebooks
├── tests/                             # tests
├── docs/                              # documentation
├── compute_cascade_metrics.py         # new: cascade metrics script
├── COMPUTE_METRICS_README.md          # new: metrics computation guide
├── requirements.txt                   # base dependencies
├── requirements_compute_metrics.txt   # new: metrics dependencies
├── setup.py                           # install script
└── README.md                          # project readme
```

## ⚠️ Notes

1. **Do not commit large files to the Git repository**
   - Use Git LFS or Hugging Face's storage system
   - Consider referencing large files via external links

2. **Check licenses**
   - Make sure all code carries an appropriate license
   - Check the license compatibility of third-party dependencies

3. **README**
   - Make sure README.md clearly explains what the project is for
   - Include installation and usage instructions

4. **Dependency management**
   - Make sure requirements.txt is up to date
   - Consider using `pip freeze` to pin exact versions

## 🔍 Verifying the Upload

After uploading, check that:
- [ ] all files were uploaded
- [ ] file sizes look reasonable
- [ ] no sensitive information leaked
- [ ] the README renders correctly
- [ ] the code can be downloaded and used normally
"""

    return checklist


def main():
    """Entry point."""
    root_dir = os.path.dirname(os.path.abspath(__file__))

    print("=" * 60)
    print("Cleanup script: preparing for upload to Hugging Face")
    print("=" * 60)

    # Patterns to remove
    patterns_to_remove = [
        '__pycache__',
        '*.pyc',
        '*.pyo',
        '*.pyd',
        '.DS_Store',
        '.vscode',
        '.idea',
        '*.swp',
        '*.swo',
        '*.log',
        '.pytest_cache',
        '.mypy_cache',
        '.ruff_cache',
    ]

    print("\n1. Removing unnecessary files...")
    removed = find_and_remove_patterns(root_dir, patterns_to_remove)
    if removed:
        print(f"   Removed {len(removed)} files/folders")
        for item in removed[:10]:  # show only the first 10
            print(f"   - {item}")
        if len(removed) > 10:
            print(f"   ... and {len(removed) - 10} more files/folders")
    else:
        print("   Nothing to remove")

    # Check for large files
    print("\n2. Checking for large files (>50MB)...")
    large_files = find_large_files(root_dir, size_mb=50)
    if large_files:
        print(f"   Found {len(large_files)} large files:")
        for file_path, size_mb in large_files:
            print(f"   - {file_path} ({size_mb:.2f} MB)")
        print("\n   ⚠️ Recommendation: large files should use Git LFS or be excluded from the upload")
    else:
        print("   ✅ No large files found")

    # Check .gitignore
    print("\n3. Checking .gitignore...")
    if check_gitignore(root_dir):
        print("   ✅ .gitignore file exists")
    else:
        print("   ⚠️ Warning: .gitignore file does not exist")

    # Create the checklist
    print("\n4. Creating the upload checklist...")
    checklist_content = create_upload_checklist(root_dir)
    checklist_path = Path(root_dir) / 'UPLOAD_CHECKLIST.md'
    with open(checklist_path, 'w', encoding='utf-8') as f:
        f.write(checklist_content)
    print(f"   ✅ Created: {checklist_path}")

    print("\n" + "=" * 60)
    print("Cleanup complete!")
    print("=" * 60)
    print("\nNext steps:")
    print("1. Read UPLOAD_CHECKLIST.md for the upload steps")
    print("2. Check for sensitive information that should be removed")
    print("3. Follow the checklist to upload to Hugging Face")


if __name__ == '__main__':
    main()
compute_cascade_metrics.py
ADDED
@@ -0,0 +1,568 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Compute metrics for information cascades: sentiment score, sentiment deviation,
contextual deviation, and perplexity.

This script processes information_cascade.json and
information_cascade_original_posts.json and computes the following metrics:
1. Sentiment score
2. Sentiment deviation
3. Contextual deviation
4. Perplexity

Usage (on the cloud machine):
python compute_cascade_metrics.py \
    --input_cascade information_cascade.json \
    --input_original information_cascade_original_posts.json \
    --output output_with_metrics.json \
    --bert_model bert-base-chinese \
    --sentiment_model <sentiment_model_path> \
    --perplexity_model <perplexity_model_path> \
    --batch_size 32
"""

import argparse
import json
import numpy as np
import torch
from typing import Dict, List, Any, Optional, Tuple
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM
import os


class CascadeMetricsComputer:
    """
    Compute the various metrics of cascade data.
    """

    def __init__(
            self,
            bert_model_name: str = 'bert-base-chinese',
            sentiment_model_name: Optional[str] = None,
            perplexity_model_name: Optional[str] = None,
            device: Optional[str] = None,
            batch_size: int = 32,
            max_length: int = 512
    ):
        """
        Initialize the metrics computer.

        Args:
            bert_model_name: BERT model name (for semantic embeddings and contextual deviation)
            sentiment_model_name: sentiment analysis model name (for sentiment scores)
            perplexity_model_name: language model name (for perplexity)
            device: compute device ('cuda' or 'cpu'); chosen automatically if None
            batch_size: batch size
            max_length: maximum sequence length
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.device = device
        self.batch_size = batch_size
        self.max_length = max_length

        print(f"Loading BERT model: {bert_model_name}")
        self.bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
        self.bert_model = AutoModel.from_pretrained(bert_model_name)
        self.bert_model.to(device)
        self.bert_model.eval()
        print(f"BERT model loaded onto device: {device}")

        # Load the sentiment analysis model
        if sentiment_model_name:
            print(f"Loading sentiment model: {sentiment_model_name}")
            self.sentiment_tokenizer = AutoTokenizer.from_pretrained(sentiment_model_name)
            self.sentiment_model = AutoModelForSequenceClassification.from_pretrained(sentiment_model_name)
            self.sentiment_model.to(device)
            self.sentiment_model.eval()
            print(f"Sentiment model loaded onto device: {device}")
        else:
            self.sentiment_tokenizer = None
            self.sentiment_model = None
            print("No sentiment model provided; falling back to the simplified sentiment heuristic")

        # Load the perplexity (language) model
        if perplexity_model_name:
            print(f"Loading perplexity model: {perplexity_model_name}")
            self.perplexity_tokenizer = AutoTokenizer.from_pretrained(perplexity_model_name)
            self.perplexity_model = AutoModelForCausalLM.from_pretrained(perplexity_model_name)
            self.perplexity_model.to(device)
            self.perplexity_model.eval()
            print(f"Perplexity model loaded onto device: {device}")
        else:
            self.perplexity_tokenizer = None
            self.perplexity_model = None
            print("No perplexity model provided; falling back to the simplified perplexity proxy")

    def compute_embeddings(self, texts: List[str]) -> np.ndarray:
        """
        Compute BERT semantic embeddings.

        Args:
            texts: list of texts

        Returns:
            Embedding matrix [num_texts, hidden_size]
        """
        embeddings = []

        with torch.no_grad():
            for i in range(0, len(texts), self.batch_size):
                batch_texts = texts[i:i + self.batch_size]

                # Handle empty texts
                batch_texts = [text if text else "[PAD]" for text in batch_texts]

                # Tokenize and encode
                inputs = self.bert_tokenizer(
                    batch_texts,
                    return_tensors='pt',
                    padding=True,
                    truncation=True,
                    max_length=self.max_length
                ).to(self.device)

                # Forward pass
                outputs = self.bert_model(**inputs)

                # Use the [CLS] token embedding
                batch_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
                embeddings.append(batch_embeddings)

        return np.vstack(embeddings)

    def compute_sentiment_scores(self, texts: List[str]) -> List[float]:
        """
        Compute sentiment scores.

        Args:
            texts: list of texts

        Returns:
            List of sentiment scores (one per text, typically in [-1, 1] or [0, 1])
        """
        if self.sentiment_model is None:
            # Fall back to the simplified sentiment heuristic
            return self._compute_sentiment_simple(texts)

        sentiment_scores = []

        with torch.no_grad():
            for i in range(0, len(texts), self.batch_size):
                batch_texts = texts[i:i + self.batch_size]
                batch_texts = [text if text else "[PAD]" for text in batch_texts]

                inputs = self.sentiment_tokenizer(
                    batch_texts,
                    return_tensors='pt',
                    padding=True,
                    truncation=True,
                    max_length=self.max_length
                ).to(self.device)

                outputs = self.sentiment_model(**inputs)
                logits = outputs.logits

                # Assume binary classification (positive/negative); softmax to get probabilities
                probs = torch.softmax(logits, dim=-1)

                # Sentiment score: positive probability - negative probability (other schemes possible)
                if probs.shape[1] == 2:
                    # Binary classification: [negative prob, positive prob]
                    batch_scores = (probs[:, 1] - probs[:, 0]).cpu().numpy().tolist()
                else:
                    # Multi-class or other cases: use the first class probability as the score
                    batch_scores = probs[:, 0].cpu().numpy().tolist()

                sentiment_scores.extend(batch_scores)

        return sentiment_scores

    def _compute_sentiment_simple(self, texts: List[str]) -> List[float]:
        """
        Simplified sentiment computation (heuristic rules).

        Args:
            texts: list of texts

        Returns:
            List of sentiment scores
        """
        scores = []
        for text in texts:
            if not text:
                scores.append(0.0)
                continue

            # Simple heuristic based on positive/negative cue words and emoji
            positive_words = ['好', '棒', '赞', '喜欢', '支持', '👍', '❤️', '😊', '😄']
            negative_words = ['差', '坏', '讨厌', '反对', '👎', '😢', '😠', '😡']

            positive_count = sum(1 for word in positive_words if word in text)
            negative_count = sum(1 for word in negative_words if word in text)

            # Compute the sentiment score (normalized to [-1, 1]; the character count is the denominator)
            total_words = len(text)
            if total_words > 0:
                score = (positive_count - negative_count) / max(total_words, 1)
                score = np.clip(score, -1.0, 1.0)
            else:
                score = 0.0

            scores.append(score)

        return scores

    def compute_perplexity(self, texts: List[str]) -> List[float]:
        """
        Compute perplexity.

        Args:
            texts: list of texts

        Returns:
            List of perplexities
        """
        if self.perplexity_model is None:
            # Fall back to the simplified perplexity proxy
            return self._compute_perplexity_simple(texts)

        perplexities = []

        with torch.no_grad():
            for text in texts:
                if not text:
                    perplexities.append(0.0)
                    continue

                # Tokenize
                inputs = self.perplexity_tokenizer(
                    text,
                    return_tensors='pt',
                    truncation=True,
                    max_length=self.max_length
                ).to(self.device)

                # Compute the perplexity
                outputs = self.perplexity_model(**inputs, labels=inputs['input_ids'])
                loss = outputs.loss

                # Perplexity = exp(loss)
                perplexity = torch.exp(loss).item()
                perplexities.append(perplexity)

        return perplexities

    def _compute_perplexity_simple(self, texts: List[str]) -> List[float]:
        """
        Simplified perplexity computation (based on lexical diversity).

        Args:
            texts: list of texts

        Returns:
            List of perplexities
        """
        perplexities = []

        for text in texts:
            if not text:
                perplexities.append(0.0)
                continue

            # Simplified proxy based on lexical diversity
            words = text.split()
            unique_words = len(set(words))
            total_words = len(words)

            if total_words > 0:
                # Lower lexical diversity maps to higher perplexity (crude proxy)
                perplexity_proxy = 1.0 - (unique_words / total_words)
            else:
                perplexity_proxy = 0.0

            perplexities.append(perplexity_proxy)

        return perplexities

    def compute_cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
        """
        Compute cosine similarity.

        Args:
            vec1: vector 1
            vec2: vector 2

        Returns:
            Cosine similarity in [-1, 1]
        """
        dot_product = np.dot(vec1, vec2)
        norm1 = np.linalg.norm(vec1)
        norm2 = np.linalg.norm(vec2)

        if norm1 == 0 or norm2 == 0:
            return 0.0

        similarity = dot_product / (norm1 * norm2)
        return float(similarity)

    def compute_contextual_deviation(self, root_embedding: np.ndarray, current_embedding: np.ndarray) -> float:
        """
        Compute the contextual deviation.

        Defined as: 1 - semantic similarity

        Args:
            root_embedding: embedding of the root post
            current_embedding: embedding of the current text

        Returns:
            Contextual deviation in [0, 1]; higher means further from the root post's context
        """
        similarity = self.compute_cosine_similarity(root_embedding, current_embedding)
        deviation = 1.0 - similarity
        return deviation

    def compute_sentiment_deviation(self, root_sentiment: float, current_sentiment: float) -> float:
        """
        Compute the sentiment deviation.

        Defined as: |current sentiment score - root post sentiment score|

        Args:
            root_sentiment: sentiment score of the root post
            current_sentiment: sentiment score of the current text

        Returns:
            Sentiment deviation in [0, 2] (when sentiment scores lie in [-1, 1])
        """
        deviation = abs(current_sentiment - root_sentiment)
        return deviation

    def process_cascade(self, cascade: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process a single cascade and compute all metrics.

        Args:
            cascade: cascade data dict

        Returns:
            The cascade data dict with the metrics attached
        """
        # 1. Collect all texts
        texts: List[str] = []
        indices: List[Tuple[str, Optional[str]]] = []

        # Root post
        post_info = cascade.get('post_info', {})
        post_content = post_info.get('content', '')
        texts.append(post_content)
        indices.append(('post', None))

        # Comments
        comment_tree = cascade.get('comment_tree', {})
        comment_ids = list(comment_tree.keys())
        for comment_id in comment_ids:
            node = comment_tree[comment_id]
            texts.append(node.get('content', ''))
            indices.append(('comment', comment_id))

        # Reposts
        repost_chain = cascade.get('repost_chain', [])
        for node in repost_chain:
            forward_text = node.get('forward_text', '') or ''
            comment_content = node.get('comment_content', '') or ''
            repost_text = forward_text + comment_content
            texts.append(repost_text)
            indices.append(('repost', node.get('repost_id')))

        # 2. Compute the features in batch
        if len(texts) == 0:
            return cascade

        embeddings = self.compute_embeddings(texts)
        sentiment_scores = self.compute_sentiment_scores(texts)
        perplexities = self.compute_perplexity(texts)

        # 3. Get the root post's features (used to compute the deviations)
        root_embedding = embeddings[0]
        root_sentiment = sentiment_scores[0]

        # 4. Attach the features to the cascade data
        # Root post
        post_info['embedding'] = root_embedding.tolist()
        post_info['sentiment_score'] = root_sentiment
        post_info['perplexity'] = perplexities[0]

        # Comments
        for i, comment_id in enumerate(comment_ids):
            node = comment_tree[comment_id]
            idx = 1 + i  # skip the root post

            node['embedding'] = embeddings[idx].tolist()
            node['sentiment_score'] = sentiment_scores[idx]
            node['perplexity'] = perplexities[idx]

            # Compute the deviations
            node['contextual_deviation'] = self.compute_contextual_deviation(
                root_embedding, embeddings[idx]
            )
            node['sentiment_deviation'] = self.compute_sentiment_deviation(
                root_sentiment, sentiment_scores[idx]
            )

        # Reposts
        offset = 1 + len(comment_ids)
        for j, node in enumerate(repost_chain):
            idx = offset + j

            node['embedding'] = embeddings[idx].tolist()
            node['sentiment_score'] = sentiment_scores[idx]
            node['perplexity'] = perplexities[idx]

            # Compute the deviations
            node['contextual_deviation'] = self.compute_contextual_deviation(
                root_embedding, embeddings[idx]
            )
            node['sentiment_deviation'] = self.compute_sentiment_deviation(
                root_sentiment, sentiment_scores[idx]
            )

        return cascade


def load_json_file(file_path: str) -> Dict[str, Any]:
    """
    Load a JSON file (the entire file is read into memory).

    Args:
        file_path: JSON file path

    Returns:
        Data dict
    """
    print(f"Loading JSON file: {file_path}")
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    print(f"Loaded {len(data.get('cascades', []))} cascades")
    return data


def main():
    parser = argparse.ArgumentParser(
        description='Compute information cascade metrics: sentiment score, sentiment deviation, contextual deviation, perplexity'
    )
    parser.add_argument(
        '--input_cascade',
        type=str,
        required=True,
        help='Path to the input cascade JSON file (information_cascade.json)'
    )
    parser.add_argument(
        '--input_original',
        type=str,
        default=None,
        help='Path to the input original-posts JSON file (information_cascade_original_posts.json), optional'
    )
    parser.add_argument(
        '--output',
        type=str,
        required=True,
        help='Path to the output JSON file'
    )
    parser.add_argument(
        '--bert_model',
        type=str,
        default='bert-base-chinese',
        help='BERT model name or path (for semantic embeddings)'
    )
    parser.add_argument(
        '--sentiment_model',
        type=str,
        default=None,
        help='Sentiment analysis model name or path (optional)'
    )
    parser.add_argument(
        '--perplexity_model',
        type=str,
        default=None,
        help='Language model name or path (for perplexity, optional)'
    )
    parser.add_argument(
        '--batch_size',
        type=int,
        default=32,
        help='Batch size'
    )
    parser.add_argument(
        '--max_length',
        type=int,
        default=512,
        help='Maximum sequence length'
    )
    parser.add_argument(
        '--device',
        type=str,
        default=None,
        help='Compute device (cuda/cpu); chosen automatically if None'
    )
    parser.add_argument(
        '--max_cascades',
        type=int,
        default=None,
        help='Maximum number of cascades to process (for testing; None processes all)'
    )

    args = parser.parse_args()

    # Load the data
    cascade_data = load_json_file(args.input_cascade)

    if args.input_original:
        original_data = load_json_file(args.input_original)
        # Merge the data here if needed
        # For now only cascade_data is processed

    # Initialize the metrics computer
    print("\nInitializing the metrics computer...")
    computer = CascadeMetricsComputer(
        bert_model_name=args.bert_model,
        sentiment_model_name=args.sentiment_model,
        perplexity_model_name=args.perplexity_model,
        device=args.device,
        batch_size=args.batch_size,
        max_length=args.max_length
    )

    # Process the cascades
    cascades = cascade_data.get('cascades', [])
    total_cascades = len(cascades)
    if args.max_cascades:
        cascades = cascades[:args.max_cascades]

    print(f"\nProcessing {len(cascades)}/{total_cascades} cascades...")
    processed_count = 0
    for idx, cascade in enumerate(tqdm(cascades, desc="Processing cascades")):
        try:
            cascade_data['cascades'][idx] = computer.process_cascade(cascade)
            processed_count += 1
        except Exception as e:
            print(f"\nError while processing cascade {idx}: {e}")
            import traceback
            traceback.print_exc()
            continue

    print(f"\nSuccessfully processed {processed_count}/{len(cascades)} cascades")

    # Save the results
    print(f"\nSaving results to: {args.output}")
    with open(args.output, 'w', encoding='utf-8') as f:
        json.dump(cascade_data, f, ensure_ascii=False, indent=2)

    print(f"✅ Done! Results saved to: {args.output}")


if __name__ == '__main__':
    main()
data/cascades/.gitkeep
ADDED
@@ -0,0 +1,3 @@
# This directory holds the cascade data files
# The data files are too large and are excluded via .gitignore
# See DATA_FILES_NOTICE.md for how to obtain the data files
data/cascades/README.md
ADDED
@@ -0,0 +1,101 @@
# Cascade Data Files

This directory contains the information cascade data files.

## 📁 File Description

### Main files

1. **`information_cascade.json`** (606MB)
   - Full cascade data, including original posts, comments, reposts, etc.
   - Used for computing cascade metrics and training models

2. **`information_cascade_original_posts.json`** (980MB)
   - Original post data
   - Contains the original Weibo posts

## ⚠️ Note on File Size

These files are large (about 1.6GB in total) and are **not automatically uploaded to Git/Hugging Face**.

## 📥 How to Obtain the Data Files

### Option 1: Manual download

The data files need to be downloaded or transferred to the cloud machine separately:

```bash
# Create the directory on the cloud machine
mkdir -p data/cascades

# Transfer the files with scp or another tool
scp user@local:/path/to/information_cascade.json ./data/cascades/
scp user@local:/path/to/information_cascade_original_posts.json ./data/cascades/
```

### Option 2: Git LFS (if configured)

If using Git LFS:

```bash
# Install Git LFS
git lfs install

# Track the large files
git lfs track "data/cascades/*.json"

# Add and commit
git add .gitattributes
git add data/cascades/*.json
git commit -m "Add cascade data files with LFS"
```

### Option 3: External storage

- Upload to cloud storage (e.g., Google Drive, Dropbox)
- Use the Hugging Face Dataset Hub storage system
- Use an object storage service (e.g., AWS S3, Alibaba Cloud OSS)

## 🚀 Using the Data Files

### Running the metrics computation

```bash
python compute_cascade_metrics.py \
    --input_cascade data/cascades/information_cascade.json \
    --input_original data/cascades/information_cascade_original_posts.json \
    --output output_with_metrics.json \
    --batch_size 32
```

### Data format

JSON file format:
```json
{
  "cascades": [
    {
      "post_info": {
        "content": "...",
        "timestamp": "..."
      },
      "comment_tree": {...},
      "repost_chain": [...]
    }
  ]
}
```

See the project documentation for a detailed format description.

## 📝 Notes

1. **File size**: these files are large; make sure you have enough disk space
2. **Memory**: loading a full file may require a lot of memory
3. **Processing**: batch processing is recommended
4. **Backup**: keeping a backup of the data files is recommended
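
### Streaming large files (sketch)

Since `information_cascade.json` alone is ~600MB, a streaming parser avoids holding the whole file in memory. Below is a minimal sketch using the third-party `ijson` package — an assumption, not part of this repo; any streaming JSON parser would do:

```python
import ijson  # pip install ijson (assumption: not a repo dependency)

# Iterate over cascades one at a time instead of json.load-ing everything.
with open('data/cascades/information_cascade.json', 'rb') as f:
    for cascade in ijson.items(f, 'cascades.item'):
        handle(cascade)  # hypothetical per-cascade handler
```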

## 🔗 Related Docs

- [Metrics computation guide](../COMPUTE_METRICS_README.md)
- [Upload guide](../HF_UPLOAD_GUIDE.md)
docs/Makefile
ADDED
@@ -0,0 +1,20 @@
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS    ?=
SPHINXBUILD   ?= sphinx-build
SOURCEDIR     = source
BUILDDIR      = build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
docs/README.md
ADDED
@@ -0,0 +1,13 @@
# Documentation for EasyTPP

This contains the full documentation of EasyTPP, which is hosted on GitHub and can be updated manually (for releases)
by pushing to the gh-pages branch.


To generate the documentation locally, type

```
pip install -r requirements-doc.txt
cd docs
make html
```
docs/images/thinning_algo.jpg
ADDED
Git LFS Details
docs/make.bat
ADDED
@@ -0,0 +1,35 @@
@ECHO OFF

pushd %~dp0

REM Command file for Sphinx documentation

if "%SPHINXBUILD%" == "" (
	set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build

%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
	echo.
	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
	echo.installed, then set the SPHINXBUILD environment variable to point
	echo.to the full path of the 'sphinx-build' executable. Alternatively you
	echo.may add the Sphinx directory to PATH.
	echo.
	echo.If you don't have Sphinx installed, grab it from
	echo.https://www.sphinx-doc.org/
	exit /b 1
)

if "%1" == "" goto help

%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%

:end
popd
docs/source/advanced/implementation.rst
ADDED
@@ -0,0 +1,143 @@
===================================
Model Implementation Details
===================================

Basic structure
===================================

In the model folder, `torch_basemodel` (**/model/torch_model/torch_basemodel.py**) / `tf_basemodel` (**/model/tf_model/tf_basemodel.py**) implements the functionality for computing the loglikelihood and the sampling procedures that are common
to all the TPP models. Models with specific structures are defined in the inherited classes, explained in the sections below.


Computing the loglikelihood of non-pad event sequence
------------------------------------------------------

The loglikelihood computation, following the definition in Equation 8 of `The Neural Hawkes Process: A Neurally Self-Modulating Multivariate Point Process <https://arxiv.org/abs/1612.09328>`_, is shared by all the TPP models.

It takes `time_delta_seqs`, `lambda_at_event`, `lambdas_loss_samples`, `seq_mask` and
`lambda_type_mask` as input and outputs the loglikelihood terms; please see `torch_basemodel` (**/model/torch_model/torch_basemodel.py**) / `tf_basemodel` (**/model/tf_model/tf_basemodel.py**)
for details.

It is noted that:

1. Sequential prediction: because we perform sequential prediction, i.e., predict the next event given the previous ones, we do not consider the last event as it has no label. To implement the `forward` function, we take as input `time_seqs[:, :-1]`
and `type_seqs[:, :-1]`. For `time_delta_seqs` it is different; please see the next point.

2. Continuous-time evolution: recall the definition in `Dataset <../user_guide/dataset.html>`_, and assume we have a sequence of 4 events and 1 pad event
at the end, i.e.,

.. code-block:: bash

    index:        0,    1,       2,       3,       4
    dtimes:       0,    t_1-t_0, t_2-t_1, t_3-t_2, pad
    types:        e_0,  e_1,     e_2,     e_3,     pad
    non_pad_mask: True, True,    True,    True,    False

For the i-th event, the i-th dtime denotes the time evolution (e.g., decay in NHP) up to the current event and
the (i+1)-th dtime denotes the time evolution to the next event. To compute the non-event loglikelihood,
we should consider the time evolution after the event happens. Therefore we should use `time_delta_seqs[:, 1:]` with the masks specified in the next point.

3. Masking: suppose we have predictions for the 0,1,2,3-th events and their labels are the 1,2,3,4-th events,
where the 4-th event needs to be masked. So we should set the sequence mask to `True, True, True, False`, i.e., `seq_mask=batch_non_pad_mask[:, 1:]`.
The same logic applies to the attention mask and the event type mask.

Therefore the following code is a typical example of calling the loglikelihood computation:


.. code-block:: python

    event_ll, non_event_ll, num_events = self.compute_loglikelihood(lambda_at_event=lambda_at_event,  # seq_len = max_len - 1
                                                                    lambdas_loss_samples=lambda_t_sample,  # seq_len = max_len - 1
                                                                    time_delta_seq=time_delta_seq[:, 1:],
                                                                    seq_mask=batch_non_pad_mask[:, 1:],
                                                                    lambda_type_mask=type_mask[:, 1:])



Computing the integral inside the loglikelihood
-----------------------------------------------


The loglikelihood of the parameters is the sum of the log-intensities of the events that happened, at the times they happened,
minus the integral of the total intensity over the observation interval [0, T]:

.. math::

    \sum_{t_i}\log \lambda_{k_i}(t_i) - \int_0^T \lambda(t) dt

The first term refers to the event loglikelihood and the second term (including the negative sign) refers to the non-event loglikelihood.
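
As a rough illustration (a minimal sketch, not the library's actual ``compute_loglikelihood``), the integral can be estimated by Monte Carlo: sample times inside each inter-event interval, average the total intensity there, and scale by the interval length. All tensor names below are assumptions made for this sketch:

.. code-block:: python

    import torch

    def loglike_sketch(log_lambda_at_event, lambda_at_samples, time_delta_seq, seq_mask):
        """Estimate sum_i log lambda_{k_i}(t_i) - int_0^T lambda(t) dt.

        log_lambda_at_event: [batch, seq_len], log-intensity of the observed type at each event.
        lambda_at_samples:   [batch, seq_len, num_mc], total intensity at times sampled
                             uniformly inside each interval.
        time_delta_seq:      [batch, seq_len], interval lengths.
        seq_mask:            [batch, seq_len], True at non-pad positions.
        """
        event_ll = (log_lambda_at_event * seq_mask).sum()
        # integral over each interval ~= mean sampled intensity * interval length
        non_event_ll = (lambda_at_samples.mean(dim=-1) * time_delta_seq * seq_mask).sum()
        return event_ll - non_event_ll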

Neural Hawkes Process (NHP)
===================================

We implement NHP based on the author's official pytorch code `Github:nce-mpp <https://github.com/hongyuanmei/nce-mpp/blob/main/ncempp/models/nhp.py>`_.

1. A continuous-time LSTM is introduced, with the code mainly coming from `Github:nce-mpp <https://github.com/hongyuanmei/nce-mpp/blob/main/ncempp/models/nhp.py>`_.
2. A `forward` function in the NHP class recursively updates the states: we compute the event embedding, pass it to the LSTM cell and then decay afterwards. Note that for the i-th event, we should use the (i+1)-th dt for the decay; hence we do not consider the last event as it has no decay time. A minimal sketch of this loop is given below.
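
The following is a self-contained sketch of this recursive update — our reading of the continuous-time LSTM equations in the NHP paper, not the library's actual implementation; all names are assumptions:

.. code-block:: python

    import torch
    import torch.nn as nn

    class CTLSTMCellSketch(nn.Module):
        """Continuous-time LSTM cell: update at an event, decay between events."""
        def __init__(self, hidden_size):
            super().__init__()
            self.linear = nn.Linear(2 * hidden_size, 7 * hidden_size)

        def forward(self, emb, h, c, c_bar):
            i, f, z, o, i_bar, f_bar, d = self.linear(torch.cat([emb, h], dim=-1)).chunk(7, dim=-1)
            i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
            i_bar, f_bar = torch.sigmoid(i_bar), torch.sigmoid(f_bar)
            z, delta = torch.tanh(z), nn.functional.softplus(d)
            c_new = f * c + i * z                   # cell state right after the event
            c_bar_new = f_bar * c_bar + i_bar * z   # target state it decays towards
            return c_new, c_bar_new, delta, o

    def nhp_forward_sketch(cell, embedding, time_delta_seqs, type_seqs):
        batch, seq_len = type_seqs.shape
        hidden = embedding.embedding_dim
        h = c = c_bar = torch.zeros(batch, hidden)
        hiddens = []
        for i in range(seq_len - 1):  # skip the last event: it has no decay time
            c, c_bar, delta, o = cell(embedding(type_seqs[:, i]), h, c, c_bar)
            dt = time_delta_seqs[:, i + 1].unsqueeze(-1)  # (i+1)-th dt for the decay
            c_t = c_bar + (c - c_bar) * torch.exp(-delta * dt)
            h = o * torch.tanh(c_t)
            hiddens.append(h)
        return torch.stack(hiddens, dim=1)  # [batch, seq_len - 1, hidden]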

Attentive Neural Hawkes Process (AttNHP)
========================================


We implement AttNHP based on the authors' official pytorch code `Github:anhp-andtt <https://github.com/yangalan123/anhp-andtt>`_
and, similarly to NHP, we factorize it into a base model and an inherited model.

The forward function is implemented faithfully to that of the authors' repo.


Transformer Hawkes Process (THP)
========================================

We implement THP based on a fixed version of the pytorch code `Github:anhp-andtt/thp <https://github.com/yangalan123/anhp-andtt/tree/master/thp>`_
and we factorize it into a base model and an inherited model.


Self-Attentive Hawkes Process (SAHP)
========================================

We implement SAHP based on a fixed version of the pytorch code `Github:anhp-andtt/sahp <https://github.com/yangalan123/anhp-andtt/tree/master/sahp>`_
and we factorize it into a base model and an inherited model.

`SAHP` shares a very similar structure with `THP`.



Recurrent Marked Temporal Point Processes (RMTPP)
====================================================

We implement RMTPP faithfully to the author's paper.


Intensity Free Learning of Temporal Point Process (IntensityFree)
==================================================================

We implement the model based on the author's torch code `Github:ifl-tpp <https://github.com/shchur/ifl-tpp>`_.

A small difference between our implementation and the author's is that we ignore the `context_init` (the initial state of the RNN) because in our data setup we do not need a learnable initial RNN state. This modification generally has little impact on the learning process.

It is worth noting that the thinning algorithm cannot be applied to this model because it is intensity-free. When comparing the performance of the model, we only look at its log-likelihood learning curve.


Fully Neural Network based Model for General Temporal Point Processes (FullyNN)
===============================================================================

We implement the model based on the author's keras code `Github:NeuralNetworkPointProcess <https://github.com/omitakahiro/NeuralNetworkPointProcess>`_.


ODE-based Temporal Point Process (ODETPP)
=========================================

We implement a TPP with Neural ODE state evolution, which is a simplified version of `Neural Spatio-Temporal Point Processes <https://arxiv.org/abs/2011.04583>`_. The ODE implementation uses the code from this `blog <https://msurtsukov.github.io/Neural-ODE/>`_.


Attentive Neural Hawkes Network (ANHN)
======================================

We implement the model based on the authors' paper: the attentive model without the graph regularizer is named ANHN.
docs/source/advanced/performance_valid.rst
ADDED
@@ -0,0 +1,41 @@
=========================================
Performance validation of EasyTPP models
=========================================

We run the experiments on various datasets to validate the implementations: each model is trained for a max number of epochs and
the best model is selected based on the performance on the valid set; then we report the results on the test set.


Simulated dataset
---------------------------
Conttime
**********************



+--------------+----------+----------+----------+--------------------+
| Models       | Loglike  | RMSE     | Acc      | Num Training Epochs|
+==============+==========+==========+==========+====================+
| Torch_NHP    | -0.93504 | 0.34000  | 0.38656  | 200                |
+--------------+----------+----------+----------+--------------------+
| Tf_NHP       | -0.85774 | 0.34014  | 0.38806  | 200                |
+--------------+----------+----------+----------+--------------------+
| Torch_AttNHP | -1.02001 | 0.33678  | 0.36782  | 200                |
+--------------+----------+----------+----------+--------------------+
| Tf_AttNHP    | -1.02315 | 0.33816  | 0.19456  | 200                |
+--------------+----------+----------+----------+--------------------+
| Torch_AttNHP | -1.00593 | 0.33685  | 0.37723  | 500                |
+--------------+----------+----------+----------+--------------------+
| Tf_AttNHP    | -0.99827 | 0.33717  | 0.36498  | 500                |
+--------------+----------+----------+----------+--------------------+
| Torch_THP    | -0.99827 | 0.33717  | 0.36498  | 500                |
+--------------+----------+----------+----------+--------------------+
| Tf_THP       | -1.01898 | 0.33677  | 0.37875  | 500                |
+--------------+----------+----------+----------+--------------------+



Real dataset
---------------------------
Taxi
**********************
docs/source/advanced/tensorboard.rst
ADDED
@@ -0,0 +1,75 @@
===================================
Launching the Tensorboard
===================================


Here we present how to launch the tensorboard within the ``EasyTPP`` framework.

Step 1: Activate the usage of tensorboard in the config file
============================================================


As shown in `Training Pipeline <../get_started/run_train_pipeline.html>`_, we first need to initialize the 'model_config.yaml' file to set up the running config before training or evaluating the model.

In the ``model config`` (the `modeling` attribute of the config), one needs to set ``use_tfb`` to ``True`` in `trainer`. Then, before the running process, summary writers tracking the performance on the training and valid sets are both initialized.

.. code-block:: yaml

    NHP_train:
      base_config:
        stage: train
        backend: torch
        dataset_id: taxi
        runner_id: std_tpp
        model_id: NHP # model name
        base_dir: './checkpoints/'
      trainer_config:
        batch_size: 256
        max_epoch: 200
        shuffle: False
        optimizer: adam
        learning_rate: 1.e-3
        valid_freq: 1
        use_tfb: True # Activate the tensorboard
        metrics: [ 'acc', 'rmse' ]
        seed: 2019
        gpu: -1
      model_config:
        hidden_size: 64
        loss_integral_num_sample_per_step: 20
        # pretrained_model_dir: ./checkpoints/75518_4377527680_230530-132355/models/saved_model
        thinning:
          num_seq: 10
          num_sample: 1
          num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm
          look_ahead_time: 10
          patience_counter: 5 # the maximum iteration used in adaptive thinning
          over_sample_rate: 5
          num_samples_boundary: 5
          dtime_max: 5
          num_step_gen: 1



Step 2: Launching the tensorboard
========================================================


We simply go to the output directory of the training runner (specified in `base_dir` of ``base_config``), find the tensorboard log directory and launch it.

A complete example of using tensorboard can be seen at *examples/run_tensorboard.py*.


.. code-block:: python

    import os

    def main():
        # one can find this dir in the config out file
        log_dir = './checkpoints/NHP_train_taxi_20220527-20:18:30/tfb_train'
        os.system('tensorboard --logdir={}'.format(log_dir))
        return


    if __name__ == '__main__':
        main()
docs/source/advanced/thinning_algo.rst
ADDED
@@ -0,0 +1,56 @@
==============================================
Thinning Algorithm for Sampling Event Sequence
==============================================

In ``EasyTPP`` we use the ``thinning algorithm`` depicted in Algorithm 2
of `The Neural Hawkes Process: A Neurally Self-Modulating Multivariate Point Process <https://arxiv.org/abs/1612.09328>`_
for event sampling.

The implementation of the algorithm
====================================


We implement the algorithm both in PyTorch and Tensorflow, as seen in *./model/torch_thinning.py* and
*./model/tf_thinning.py*, which basically follow the same procedure.

The corresponding code is in the function ``draw_next_time_one_step``, which consists of the following steps (a minimal sketch of the core sampling step follows the list):

1. Compute the upper bound of the intensity at each event timestamp in the function ``compute_intensity_upper_bound``, where we sample some timestamps inside the event intervals and output an upper-bound intensity matrix [batch_size, seq_len], denoting the upper bound of the predicted intensity (for the next time interval) for each sequence at each timestamp.
2. Sample the exponential distribution with the intensity computed in Step 1 in the function ``sample_exp_distribution``, where we simply divide the standard exponential numbers by the intensity, which is equivalent to sampling from Exp(sample_rate), according to `the properties of the Exponential Distribution <https://en.wikipedia.org/wiki/Exponential_distribution>`_. The exponential random variables have size [batch_size, seq_len, num_sample, num_exp], where num_sample refers to the number of event times sampled in every interval and num_exp refers to the number of i.i.d. Exp(intensity_bound) draws at one time in the thinning algorithm.
3. Compute the intensities at the sample times proposed in Step 2, with final size `[batch_size, seq_len, num_sample, num_exp]`.
4. Sample the standard uniform distribution with size `[batch_size, seq_len, num_sample, num_exp]`.
5. Perform the acceptance sampling with the corresponding probability in the function ``sample_accept``.
6. The earliest sampled dtimes are accepted. For draws with no accepted dtime, use the boundary/max sample time instead.
7. The final predicted dtimes have size `[batch_size, seq_len, num_sample]`, referring to the sampled dtimes for each sequence at each timestamp, along with an equal weight vector.
8. The product of the predicted dtimes and the weights is the final predicted dtimes, with size `[batch_size, seq_len]`.
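
The sketch below illustrates steps 2-6 for a single candidate draw per sequence; it is a simplified stand-in, not the library's actual ``draw_next_time_one_step``, and the ``intensity_fn`` argument plays the role of the model's ``compute_intensities_at_sample_times``:

.. code-block:: python

    import torch

    def draw_next_dtime_sketch(intensity_fn, intensity_bound, num_exp=500, dtime_max=5.0):
        """One thinning step: sample one inter-event time per sequence.

        intensity_fn(dtimes) -> total intensity at the sampled dtimes (same shape);
        intensity_bound: [batch_size, 1], upper bound of the intensity.
        """
        batch_size = intensity_bound.shape[0]
        # Step 2: Exp(bound) waiting times = standard exponentials / bound,
        # accumulated into candidate dtimes.
        exp_samples = torch.distributions.Exponential(torch.ones(batch_size, num_exp)).sample()
        candidate_dtimes = torch.cumsum(exp_samples / intensity_bound, dim=-1)
        # Step 3: intensities at the candidate times.
        lambdas = intensity_fn(candidate_dtimes)
        # Steps 4-5: accept each candidate with probability lambda / bound.
        accepted = torch.rand(batch_size, num_exp) * intensity_bound < lambdas
        # Step 6: take the earliest accepted candidate, else fall back to dtime_max.
        masked = torch.where(accepted, candidate_dtimes,
                             torch.full_like(candidate_dtimes, dtime_max))
        return masked.min(dim=-1).values

    # toy usage: constant intensity lambda = 1 with bound = 1
    dtimes = draw_next_dtime_sketch(lambda t: torch.ones_like(t), torch.ones(4, 1))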

.. image:: ../../images/thinning_algo.jpg
    :alt: thinning_algo



One-step prediction
====================================
By default, once the parameters of the thinning algorithm are given (by defining a ``thinning`` config as part of ``model_config``), we perform one-step prediction in model evaluation, i.e., predict the next event given the prefix. The implementation is in the function ``prediction_event_one_step`` in BaseModel (i.e., TorchBaseModel or TfBaseModel).


Multi-step prediction
====================================
The recursive multi-step prediction is activated by setting `num_step_gen` to a number bigger than 1 in the ``thinning`` config.

Note that we generate the multi-step events after the last non-pad event of each sequence. The implementation is in the function `predict_multi_step_since_last_event` in BaseModel (i.e., TorchBaseModel or TfBaseModel).


.. code-block:: yaml

    thinning:
      num_seq: 10
      num_sample: 1
      num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm
      look_ahead_time: 10
      patience_counter: 5 # the maximum iteration used in adaptive thinning
      over_sample_rate: 5
      num_samples_boundary: 5
      dtime_max: 5
      num_step_gen: 5 # by default it is single step, i.e., 1
docs/source/conf.py
ADDED
@@ -0,0 +1,59 @@
# Configuration file for the Sphinx documentation builder.
#
# For the full list of built-in configuration values, see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html


# -- Autodoc information -----------------------------------------------------
# https://sphinx-rtd-tutorial.readthedocs.io/en/latest/sphinx-config.html


import os
import sys

sys.path.insert(0, os.path.abspath('../../easy_tpp/'))

sys.path.insert(0, os.path.abspath('../..'))

# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information

project = 'EasyTPP'
copyright = '2022, Machine Intelligence, Alipay'
author = 'Machine Intelligence, Alipay'
release = '0.0.2'

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

extensions = [
    "sphinx.ext.autodoc",
    'sphinx.ext.viewcode',
    "sphinx.ext.todo",
    "sphinx.ext.mathjax",
    "sphinx.ext.napoleon",
    'sphinx.ext.autosummary'
]

napoleon_google_docstring = True
napoleon_numpy_docstring = False

templates_path = ['_templates']
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# These patterns also affect html_static_path and html_extra_path.
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]

# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output

html_theme = 'sphinx_rtd_theme'
html_static_path = ['_static']

autodoc_member_order = "bysource"
autodoc_default_flags = ["members"]
autodoc_default_options = {
    "members": True,
    "member-order": "bysource",
    "special-members": "__init__",
}
docs/source/dev_guide/model_custom.rst
ADDED
@@ -0,0 +1,78 @@
==================
Customize a Model
==================


Here we introduce how to customize a TPP model with the support of ``EasyTPP``.



Create a new TPP Model Class
=============================

Assume we are building a PyTorch model. We need to initialize the model by inheriting from the class `EasyTPP.model.torch_model.TorchBaseModel <../ref/models.html>`_.

.. code-block:: python

    from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel

    # Custom Torch TPP implementations need to
    # inherit from the TorchBaseModel interface
    class NewModel(TorchBaseModel):
        def __init__(self, model_config):
            super(NewModel, self).__init__(model_config)

        # Forward along the sequence, output the states / intensities at the event times
        def forward(self, batch):
            ...
            return states

        # Compute the loglikelihood loss
        def loglike_loss(self, batch):
            ...
            return loglike

        # Compute the intensities at given sampling times
        # Used in the Thinning sampler
        def compute_intensities_at_sample_times(self, batch, sample_times, **kwargs):
            ...
            return intensities


If we are building a Tensorflow model, we start with the following code

.. code-block:: python

    from easy_tpp.model.tf_model.tf_basemodel import TfBaseModel

    # Custom Tf TPP implementations need to
    # inherit from the TfBaseModel interface
    class NewModel(TfBaseModel):
        def __init__(self, model_config):
            super(NewModel, self).__init__(model_config)

        # Forward along the sequence, output the states / intensities at the event times
        def forward(self, batch):
            ...
            return states


        # Compute the loglikelihood loss
        def loglike_loss(self, batch):
            ...
            return loglike

        # Compute the intensities at given sampling times
        # Used in the Thinning sampler
        def compute_intensities_at_sample_times(self, batch, sample_times, **kwargs):
            ...
            return intensities

Rewrite Relevant Methods
==============================

There are three important functions that need to be implemented:

- `forward`: the input is the batch data and the output is the states at each step.
- `loglike_loss`: it computes the loglikelihood loss given the batch data.
- `compute_intensities_at_sample_times`: it computes the intensities at each sampling step.
docs/source/get_started/install.rst
ADDED
@@ -0,0 +1,64 @@
==================
Installation
==================


``EasyTPP`` provides an open-source library for `Neural TPP`, with a fully automated pipeline for model training and prediction.


Requirements
=============

.. code-block:: bash

    PyTorch version >= 1.8.0
    Python version >= 3.7
    Tensorflow version >= 1.13.1 (only needed when using the Tensorflow backend)



First, we need a python environment of version at least 3.7.0. If you don't have one, please refer to the `Documentation <https://docs.anaconda.com/anaconda/install/>`_ to install and configure the Anaconda environment.

.. code-block:: bash

    conda create -n easytpp python=3.8
    conda activate easytpp

Then, install PyTorch of version at least 1.8.0.

.. code-block:: bash

    pip install torch

By default, we assume the use of PyTorch. If one wants to use the Tensorflow backend, please install Tensorflow additionally. Both Tensorflow 1.13.1 and 2.x are supported.

.. code-block:: bash

    pip install tensorflow



Install
=====================


Install with pip
--------------------------


.. code-block:: bash

    pip install easy-tpp


Install from the source
--------------------------

Setup from the source:

.. code-block:: bash

    git clone https://github.com/ant-research/EasyTemporalPointProcess.git
    cd EasyTemporalPointProcess
    python setup.py install
docs/source/get_started/introduction.rst
ADDED
@@ -0,0 +1,60 @@
==================
Introduction
==================


``EasyTPP`` provides an open-source library for `Neural TPP`, with a fully automated pipeline for model training and prediction.


Framework
=========


``EasyTPP`` supports both Tensorflow and PyTorch: each model has two equivalent versions implemented in Tensorflow 1.13 and PyTorch 1.8 respectively. The data processing and model training / prediction pipeline are compatible with both Tensorflow and PyTorch as well.


At the module level, ``EasyTPP`` is a package that consists of the following components, designed as loosely-coupled modules that give users the flexibility to develop customized functionality.



======================== ==============================================================================
Name                     Description
======================== ==============================================================================
`Preprocess` module      Provides batch-wise padding, inter-event-time processing and other related work for raw sequences.
`Model` module           Implements a list of SOTA TPP models. Please refer to `Model Validation <../advanced/performance_valid.html>`_ for more details.
`Config` module          Encapsulates the construction of the configuration needed to run the pipeline.
`Runner` module          Controls the training and prediction pipeline.
======================== ==============================================================================



Install
=========

``EasyTPP`` can be installed either with pip or from the source. By default it is built on PyTorch. If one wants to run with the Tensorflow backend, one needs to install Tensorflow additionally.

Please see `Installation <./install.html>`_ for details of the requirements and installation.


Prepare Data
============

By default, we use data in the Gatech format, i.e., each dataset is a dict containing keys such as `time_since_last_event`, `time_since_start` and `type_event`. The `Preprocess <../ref/preprocess.html>`_ module
will preprocess the data and feed it into the model.


An example of building a pseudo dataloader can be found at `examples <https://github.com/ant-research/EasyTemporalPointProcess/tree/main/examples/data_loader.py>`_. Please refer to `Dataset <../user_guide/dataset.html>`_ for more explanations of the `TPP` dataset iterator.


Model Training and Prediction
==============================

The training and prediction pipeline consists of two steps:

1. Set up the config file, which specifies the dataset dir, model params and pipeline settings.
2. Launch the python script to run the whole pipeline.

Please see `Training Pipeline <../user_guide/run_train_pipeline.html>`_ and `Evaluation Pipeline <../user_guide/run_eval.html>`_ for more details.
docs/source/get_started/quick_start.rst
ADDED
@@ -0,0 +1,106 @@
====================
Quick Start
====================


We use the [Taxi]_ dataset as an example to show how to use ``EasyTPP`` to train a model. More details and results are provided in `Training Pipeline <../user_guide/run_train_pipeline.html>`_.


Download Dataset
===================



The Taxi dataset we used is preprocessed by `HYPRO <https://github.com/iLampard/hypro_tpp>`_. You can download either the dataset (in pickle) from Google Drive `here <https://drive.google.com/drive/folders/1vNX2gFuGfhoh-vngoebaQlj2-ZIZMiBo>`_ or the dataset (in json) from `HuggingFace <https://huggingface.co/easytpp>`_.


Note that if the data sources are pickle files, we need to write the data config (in `Example Config <https://github.com/ant-research/EasyTemporalPointProcess/blob/main/examples/configs/experiment_config.yaml>`_) in the following way

.. code-block:: yaml

    data:
      taxi:
        data_format: pickle
        train_dir: ./data/taxi/train.pkl
        valid_dir: ./data/taxi/dev.pkl
        test_dir: ./data/taxi/test.pkl

If we choose to directly load from HuggingFace, we can put it this way:

.. code-block:: yaml

    data:
      taxi:
        data_format: json
        train_dir: easytpp/taxi
        valid_dir: easytpp/taxi
        test_dir: easytpp/taxi


Meanwhile, it is also feasible to put the local directory of json files downloaded from HuggingFace in the config:

.. code-block:: yaml

    data:
      taxi:
        data_format: json
        train_dir: ./data/taxi/train.json
        valid_dir: ./data/taxi/dev.json
        test_dir: ./data/taxi/test.json




Setup the configuration file
==============================

We provide a preset config file in `Example Config <https://github.com/ant-research/EasyTemporalPointProcess/blob/main/examples/configs/experiment_config.yaml>`_. The details of the configuration can be found in `Training Pipeline <../user_guide/run_train_pipeline.html>`_.




Train the Model
=========================

At this stage we need to write a script to run the training pipeline. There is a preset script `train_nhp.py <https://github.com/ant-research/EasyTemporalPointProcess/blob/main/examples/train_nhp.py>`_ and one can simply copy it.

Taking the pickle data source as an example, after the setup of the data, config and running script, the directory structure is as follows:

.. code-block:: bash

    data
     |______taxi
             |____ train.pkl
             |____ dev.pkl
             |____ test.pkl

    configs
     |______experiment_config.yaml

    train_nhp.py



Then one can simply run the following command.


.. code-block:: bash

    python train_nhp.py



Reference
----------

.. [Taxi]

    .. code-block:: bash

        @misc{whong-14-taxi,
          title = {F{OIL}ing {NYC}’s Taxi Trip Data},
          author = {Whong, Chris},
          year = {2014},
          url = {https://chriswhong.com/open-data/foil_nyc_taxi/}
        }
docs/source/index.rst
ADDED
@@ -0,0 +1,56 @@
===================================
``EasyTPP`` Documentation
===================================


``EasyTPP`` is an easy-to-use development and application toolkit for `Neural Temporal Point Process <https://mathworld.wolfram.com/TemporalPointProcess.html>`_ (*Neural TPP*), with key features in configurability, compatibility and reproducibility. We hope this project could benefit both researchers and practitioners with the goal of easily customized development and open benchmarking.



.. toctree::
    :hidden:

.. toctree::
    :maxdepth: 2
    :caption: GETTING STARTED

    Introduction <get_started/introduction.rst>
    Installation <get_started/install.rst>
    Quick Start <get_started/quick_start.rst>


.. toctree::
    :maxdepth: 2
    :caption: USER GUIDE

    Dataset <user_guide/dataset.rst>
    Model Training <user_guide/run_train_pipeline.rst>
    Model Prediction <user_guide/run_eval.rst>

.. toctree::
    :maxdepth: 2
    :caption: DEVELOPER GUIDE

    Model Customization <dev_guide/model_custom.rst>


.. toctree::
    :maxdepth: 2
    :caption: ADVANCED TOPICS

    Thinning Algorithm <advanced/thinning_algo.rst>
    Tensorboard <advanced/tensorboard.rst>
    Performance Benchmarks <advanced/performance_valid.rst>
    Implementation Details <advanced/implementation.rst>

.. toctree::
    :maxdepth: 2
    :caption: API REFERENCE

    Config <ref/config.rst>
    Preprocess <ref/preprocess.rst>
    Model <ref/models.rst>
    Runner <ref/runner.rst>
    Hyper-parameter Optimization <ref/hpo.rst>
    Tf and Torch Wrapper <ref/wrapper.rst>
    Utilities <ref/utils.rst>
docs/source/ref/config.rst
ADDED
@@ -0,0 +1,10 @@
.. _api-config:

EasyTPP Config Modules
============================


.. automodule:: config_factory
    :members:
    :undoc-members:
    :show-inheritance:
docs/source/ref/hpo.rst
ADDED
@@ -0,0 +1,10 @@
.. _api-hpo:

EasyTPP HPO Modules
============================


.. automodule:: hpo
    :members:
    :undoc-members:
    :show-inheritance:
docs/source/ref/models.rst
ADDED
@@ -0,0 +1,50 @@
.. _api-model:

EasyTPP Models
====================



.. _api-tf_model:

model.tf_model module
------------------------------

.. automodule:: easy_tpp.model.tf_model
.. autosummary::
    :toctree: ../generated/

    tf_baselayer
    tf_basemodel
    tf_nhp
    tf_fullynn
    tf_intensity_free
    tf_ode_tpp
    tf_rmtpp
    tf_sahp
    tf_thp
    tf_attnhp
    tf_thinning


.. _api-torch_model:

model.torch_model module
------------------------------

.. automodule:: easy_tpp.model.torch_model
.. autosummary::
    :toctree: ../generated/

    torch_baselayer
    torch_basemodel
    torch_nhp
    torch_fullynn
    torch_intensity_free
    torch_ode_tpp
    torch_rmtpp
    torch_sahp
    torch_thp
    torch_attnhp
    torch_thinning
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.. _api-preprocess:
|
| 2 |
+
|
| 3 |
+
EasyTPP Preprocess Modules
|
| 4 |
+
==========================
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
.. automodule:: preprocess
|
| 8 |
+
:members:
|
| 9 |
+
:undoc-members:
|
| 10 |
+
:show-inheritance:
|
docs/source/ref/runner.rst
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.. _api-modelrunner:
|
| 2 |
+
|
| 3 |
+
EasyTPP Model Runner Modules
|
| 4 |
+
============================
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
.. automodule:: runner
|
| 8 |
+
:members:
|
| 9 |
+
:undoc-members:
|
| 10 |
+
:show-inheritance:
|
docs/source/ref/utils.rst
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.. _api-util:
|
| 2 |
+
|
| 3 |
+
EasyTPP Utilities Modules
|
| 4 |
+
==========================
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
.. automodule:: utils
|
| 8 |
+
:members:
|
| 9 |
+
:undoc-members:
|
| 10 |
+
:show-inheritance:
|
docs/source/ref/wrapper.rst
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.. _api-wrapper:
|
| 2 |
+
|
| 3 |
+
EasyTPP Tf and Torch Wrapper Modules
|
| 4 |
+
====================================
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
.. automodule:: tf_wrapper
|
| 8 |
+
:members:
|
| 9 |
+
:undoc-members:
|
| 10 |
+
:show-inheritance:
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
.. automodule:: torch_wrapper
|
| 15 |
+
:members:
|
| 16 |
+
:undoc-members:
|
| 17 |
+
:show-inheritance:
|
docs/source/user_guide/dataset.rst
ADDED
@@ -0,0 +1,124 @@
===========================================
Expected Dataset Format and Data Processing
===========================================

Required format
===================================

In EasyTPP we use data in the Gatech format, i.e., each dataset is a dict containing the following keys:

.. code-block:: bash

    dim_process: 5  # num of event types (no padding)
    'train': [[{'idx_event': 2, 'time_since_last_event': 1.0267814, 'time_since_last_same_event': 1.0267814, 'type_event': 3, 'time_since_start': 1.0267814}, {'idx_event': 3, 'time_since_last_event': 0.4029268, 'time_since_last_same_event': 1.4297082, 'type_event': 0, 'time_since_start': 1.4297082},...,],[{}...{}]]

where `dim_process` refers to the number of event types (without padding) and
`train` (or `dev` / `test`) contains a list of lists, each of which corresponds to an event sequence.

Each pickle file generates a set of event sequences, each containing three sub-sequences:

1. `time_seqs`: absolute timestamps of the events, corresponding to `time_since_start`.
2. `time_delta_seqs`: relative timestamps of the events, corresponding to `time_since_last_event`.
3. `type_seqs`: types of the events, corresponding to `type_event`. Note that the event type index starts from 0.
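
A minimal sketch of writing a toy dataset in this format (hypothetical values):

.. code-block:: python

    import pickle

    toy = {
        'dim_process': 2,  # two event types, indexed 0 and 1
        'train': [
            [  # one event sequence
                {'idx_event': 0, 'time_since_start': 0.0,
                 'time_since_last_event': 0.0, 'type_event': 1},
                {'idx_event': 1, 'time_since_start': 0.7,
                 'time_since_last_event': 0.7, 'type_event': 0},
            ],
        ],
    }

    with open('train.pkl', 'wb') as f:
        pickle.dump(toy, f)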

Data processing
===================================

The data processing follows a pipeline similar to that of the official code of `AttNHP <https://github.com/yangalan123/anhp-andtt>`_. We name it the process of `event tokenization`.


Sequence padding
----------------


time_seqs, time_delta_seqs and type_seqs are first padded to `the max length of the whole dataset` and then fed into the model in batches.

.. code-block:: bash

    input: raw event sequence (e_0, e_1, e_2, e_3) and max_len=6  # the max length among all data seqs

    output:

    index:  0,   1,       2,       3,       4,        5
    dtimes: 0,   t_1-t_0, t_2-t_1, t_3-t_2, time_pad, time_pad
    types:  e_0, e_1,     e_2,     e_3,     type_pad, type_pad


By default, we set the value of time_pad and type_pad to *num_event_types* (because we assume the event type index starts from 0, the integer value num_event_types is unused).

Sequence masking
----------------


After padding, we perform masking for the event sequences and generate three more sequences: batch_non_pad_mask, attention_mask and type_mask:

1. `batch_non_pad_mask`: indicates the non-pad positions in the sequence.
2. `attention_mask`: indicates the masks used in the attention calculation (one event can only attend to its past events).
3. `type_mask`: uses a one-hot vector to represent the event type. A padded event is a zero vector.

Finally, each batch contains six elements: time_seqs, time_delta_seqs, event_seq, batch_non_pad_mask, attention_mask, type_mask. The implementation of the padding mechanism can be found at `event_tokenizer <https://github.com/ant-research/EasyTemporalPointProcess/blob/main/easy_tpp/preprocess/event_tokenizer.py>`_.



An example
----------------

We take a real event sequence as an example. Assume we have an input sequence [1, 9, 5, 0] with num_event_types=11 and max_len=6.

Then the padded time_seqs, time_delta_seqs and type_seqs become

.. code-block:: bash

    # time_seqs
    [ 0.0000, 0.8252, 1.3806, 1.8349, 11.0000, 11.0000]

    # time_delta_seqs
    [ 0.0000, 0.8252, 0.5554, 0.4542, 11.0000, 11.0000]

    # type_seqs
    [ 1, 9, 5, 0, 11, 11]


The mask sequences are

.. code-block:: bash

    # batch_non_pad_mask
    [ True, True, True, True, False, False]

    # attention_mask
    [[True, True, True, True, True, True],
     [False, True, True, True, True, True],
     [False, False, True, True, True, True],
     [False, False, False, True, True, True],
     [False, False, False, False, True, True],
     [False, False, False, False, True, True]]

    # type_mask
    [[False, True, False, False, False, False, False, False, False, False, False],
     [False, False, False, False, False, False, False, False, False, True, False],
     [False, False, False, False, False, True, False, False, False, False, False],
     [True, False, False, False, False, False, False, False, False, False, False],
     [False, False, False, False, False, False, False, False, False, False, False],
     [False, False, False, False, False, False, False, False, False, False, False]]
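
These padded sequences and masks can be reproduced with a few lines of tensor code. The following is a minimal sketch, not the library's actual `event_tokenizer`, assuming the pad-value convention above:

.. code-block:: python

    import torch

    def pad_and_mask_sketch(type_seqs, num_event_types, max_len):
        """Pad variable-length type sequences and build the basic masks."""
        pad = num_event_types  # pad value = num_event_types
        types = torch.full((len(type_seqs), max_len), pad, dtype=torch.long)
        for i, seq in enumerate(type_seqs):
            types[i, :len(seq)] = torch.tensor(seq)
        non_pad_mask = types != pad
        # True = masked: an event can only attend to strictly earlier events
        attention_mask = torch.triu(torch.ones(max_len, max_len, dtype=torch.bool))
        # one-hot event types; pad positions stay all-zero
        type_mask = torch.nn.functional.one_hot(
            types.clamp(max=num_event_types - 1), num_event_types
        ) * non_pad_mask.unsqueeze(-1)
        return types, non_pad_mask, attention_mask, type_mask

    out = pad_and_mask_sketch([[1, 9, 5, 0]], num_event_types=11, max_len=6)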
|
| 107 |
+
The runnable examples of constructing and iterating the dataset object can be found at `examples/event_tokenizer.py <https://github.com/ant-research/EasyTemporalPointProcess/blob/main/examples/event_tokenizer.py>`_
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
Preprocessed Datasets
|
| 111 |
+
===================================
|
| 112 |
+
|
| 113 |
+
We have preprocessed some widely-used open source datasets in Gatech format, which can be found at `Google Drive <https://drive.google.com/drive/folders/0BwqmV0EcoUc8UklIR1BKV25YR1U?resourcekey=0-OrlU87jyc1m-dVMmY5aC4w>`_. We use them for validating and benchmarking EasyTPP models.
|
| 114 |
+
|
| 115 |
+
- Retweet (`Zhou, 2013 <http://proceedings.mlr.press/v28/zhou13.pdf>`_). This dataset contains time-stamped user retweet event sequences. The events are categorized into 3 types: retweets by “small”, “medium” and “large” users. Small users have fewer than 120 followers, medium users have fewer than 1363, and the rest are large users. We work on a subset of the 5200 most active users with an average sequence length of 70.
|
| 116 |
+
- Taxi (`Whong, 2014 <https://chriswhong.com/open-data/foil_nyc_taxi>`_). This dataset tracks the time-stamped taxi pick-up and drop-off events across the five boroughs of New York City; each (borough, pick-up or drop-off) combination defines an event type, so there are 10 event types in total. We work on a randomly sampled subset of 2000 drivers, each with one sequence. We randomly sampled disjoint train, dev and test sets with 1400, 200 and 400 sequences.
|
| 117 |
+
- StackOverflow (`Leskovec, 2014 <https://snap.stanford.edu/data/>`_). This dataset has two years of user awards on a question-answering website: each user received a sequence of badges and there are 22 different kinds of badges in total. We randomly sampled disjoint train, dev and test sets with 1400, 400 and 400 sequences from the dataset.
|
| 118 |
+
- Taobao (`Xue et al, 2022 <https://arxiv.org/abs/2210.01753>`_). This dataset contains time-stamped user click behaviors on Taobao shopping pages from November 25 to December 03, 2017. Each user has a sequence of item click events, with each event containing the timestamp and the category of the item. The categories of all items are first ranked by frequency; the top 19 are kept while the rest are merged into one category, with each category corresponding to an event type. We work on a subset of the 4800 most active users with an average sequence length of 150 and end up with 20 event types.
|
| 119 |
+
- Amazon (`Xue et al, 2022 <https://arxiv.org/abs/2210.01753>`_). This dataset includes time-stamped user product review behavior from January 2008 to October 2018. Each user has a sequence of product review events, with each event containing the timestamp and category of the reviewed product, and each category corresponding to an event type. We work on a subset of the 5200 most active users with an average sequence length of 70 and end up with 16 event types.
|
| 120 |
+
|
| 121 |
+
In addition, we have also published two textual event sequence datasets:
|
| 122 |
+
|
| 123 |
+
- GDELT (`Shi et al, 2023 <https://arxiv.org/abs/2305.16646>`_). The GDELT Project monitors events all over the world, with live datasets updated every 15 minutes. We only focused on the political events that happened in G20 countries from 2022-01-01 to 2022-07-31, ending up with a corpus of 109000 time-stamped event tokens. The event type of each token has a structured name of the format subject-predicate-object. Each {predicate} is one of the twenty CAMEO codes such as {CONSULT} and {INVESTIGATE}; each {subject} or {object} is one of the 2279 political entities (individuals, groups, and states) such as {Tesla} and {Australia}. We split the dataset into disjoint train, dev, and test sets based on their dates: the 83100 events that happened before 2022-07-05 are training data; the 16650 events after 2022-07-19 are test data; the 9250 events between these dates are development data.
|
| 124 |
+
- Amazon-text-review (`Shi et al, 2023 <https://arxiv.org/abs/2305.16646>`_). This dataset contains user reviews on Amazon shopping website from 2014-01-04 to 2016-10-02. We focused on the most active 2500 users and each user has a sequence of product review events. The type is the category of the product: we selected the most frequently-reviewed 23 categories and grouped all the others into a special OTHER category, ending up with 24 categories in total. Each review event also has a mark which is the actual content of the review. Each of the 2500 sequences is cut into three segments: the events that happened before 2015-08-01 are training data; those after 2016-02-01 are test data; the events between these dates are dev data. Then we have 49,680 training tokens, 7,020 dev tokens, and 13,090 test tokens.
|
docs/source/user_guide/run_eval.rst
ADDED
|
@@ -0,0 +1,97 @@
|
| 1 |
+
================================
|
| 2 |
+
Evaluate a Model
|
| 3 |
+
================================
|
| 4 |
+
|
| 5 |
+
Step 1: Setup the config file
|
| 6 |
+
===============================================
|
| 7 |
+
|
| 8 |
+
As in the training pipeline, we first need to initialize the task configuration in the config file.
|
| 9 |
+
|
| 10 |
+
Similar to the setup in `Training Pipeline <./run_train_pipeline.html>`_, we set the `stage` to `eval` and pass the `pretrained_model_dir` to the ``model_config``.
|
| 11 |
+
|
| 12 |
+
Note that the *pretrained_model_dir* can be found in the log of the training process.
|
| 13 |
+
|
| 14 |
+
.. code-block:: yaml
|
| 15 |
+
|
| 16 |
+
NHP_eval:
|
| 17 |
+
base_config:
|
| 18 |
+
stage: eval
|
| 19 |
+
backend: torch
|
| 20 |
+
dataset_id: taxi
|
| 21 |
+
runner_id: std_tpp
|
| 22 |
+
base_dir: './checkpoints/'
|
| 23 |
+
model_id: NHP
|
| 24 |
+
trainer_config:
|
| 25 |
+
batch_size: 256
|
| 26 |
+
max_epoch: 1
|
| 27 |
+
model_config:
|
| 28 |
+
hidden_size: 64
|
| 29 |
+
use_ln: False
|
| 30 |
+
seed: 2019
|
| 31 |
+
gpu: 0
|
| 32 |
+
pretrained_model_dir: ./checkpoints/26507_4380788096_231111-101848/models/saved_model # must provide this dir
|
| 33 |
+
thinning:
|
| 34 |
+
num_seq: 10
|
| 35 |
+
num_sample: 1
|
| 36 |
+
num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm
|
| 37 |
+
look_ahead_time: 10
|
| 38 |
+
patience_counter: 5 # the maximum iteration used in adaptive thinning
|
| 39 |
+
over_sample_rate: 5
|
| 40 |
+
num_samples_boundary: 5
|
| 41 |
+
dtime_max: 5
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
A complete example of this config can be seen at `examples/example_config.yaml <https://github.com/ant-research/EasyTemporalPointProcess/blob/main/examples/configs/experiment_config.yaml>`_.
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
Step 2: Run the evaluation script
|
| 50 |
+
=================================
|
| 51 |
+
|
| 52 |
+
As in the training pipeline, we need to initialize a ``Runner`` object to perform the evaluation.
|
| 53 |
+
|
| 54 |
+
The following code gives an example, copied from `examples/train_nhp.py <https://github.com/ant-research/EasyTemporalPointProcess/blob/main/examples/train_nhp.py>`_.
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
.. code-block:: python
|
| 58 |
+
|
| 59 |
+
import argparse
|
| 60 |
+
|
| 61 |
+
from easy_tpp.config_factory import RunnerConfig
|
| 62 |
+
from easy_tpp.runner import Runner
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def main():
|
| 66 |
+
parser = argparse.ArgumentParser()
|
| 67 |
+
|
| 68 |
+
parser.add_argument('--config_dir', type=str, required=False, default='configs/experiment_config.yaml',
|
| 69 |
+
help='Dir of configuration yaml to train and evaluate the model.')
|
| 70 |
+
|
| 71 |
+
parser.add_argument('--experiment_id', type=str, required=False, default='RMTPP_eval',
|
| 72 |
+
help='Experiment id in the config file.')
|
| 73 |
+
|
| 74 |
+
args = parser.parse_args()
|
| 75 |
+
|
| 76 |
+
config = RunnerConfig.build_from_yaml_file(args.config_dir, experiment_id=args.experiment_id)
|
| 77 |
+
|
| 78 |
+
model_runner = Runner.build_from_config(config)
|
| 79 |
+
|
| 80 |
+
model_runner.run()
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
if __name__ == '__main__':
|
| 84 |
+
main()
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
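Assuming the yaml above is saved as ``configs/experiment_config.yaml`` and contains the ``NHP_eval`` experiment, a typical invocation (paths are illustrative) looks like:

.. code-block:: bash

    python examples/train_nhp.py --config_dir configs/experiment_config.yaml --experiment_id NHP_eval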
|
| 89 |
+
Check out the output
|
| 90 |
+
====================
|
| 91 |
+
|
| 92 |
+
The evaluation result will be printed to the console and saved in the logs, whose directory is specified in the
|
| 93 |
+
output config file, i.e.:
|
| 94 |
+
|
| 95 |
+
.. code-block:: bash
|
| 96 |
+
|
| 97 |
+
'output_config_dir': './checkpoints/NHP_test_conttime_20221002-13:19:23/NHP_test_output.yaml'
|
docs/source/user_guide/run_train_pipeline.rst
ADDED
|
@@ -0,0 +1,245 @@
|
| 1 |
+
============================================
|
| 2 |
+
Training a Model & Configuration Explanation
|
| 3 |
+
============================================
|
| 4 |
+
|
| 5 |
+
This tutorial shows how one can use ``EasyTPP`` to train the implemented models.
|
| 6 |
+
|
| 7 |
+
First, we need to set up a config yaml file containing all the input configurations that guide the training and evaluation process. The overall structure of a config file is shown below:
|
| 8 |
+
|
| 9 |
+
.. code-block:: yaml
|
| 10 |
+
|
| 11 |
+
pipeline_config_id: .. # name of the config for guiding the pipeline
|
| 12 |
+
|
| 13 |
+
data:
|
| 14 |
+
[DATASET ID]: # name of the dataset, e.g., taxi
|
| 15 |
+
....
|
| 16 |
+
|
| 17 |
+
[EXPERIMENT ID]: # name of the experiment to run
|
| 18 |
+
base_config:
|
| 19 |
+
....
|
| 20 |
+
model_config:
|
| 21 |
+
...
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
After the config file is set up, we can run the script, specifying the `config directory` and `experiment id`, to start the pipeline. We currently provide a preset script at `examples/train_nhp.py`.
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
Step 1: Setup the config file containing data and model configs
|
| 28 |
+
================================================================
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
To be specific, one needs to define the following entries in the config file:
|
| 32 |
+
|
| 33 |
+
- **pipeline_config_id**: registered name of an EasyTPP.Config object, such as `runner_config` or `hpo_runner_config`. Based on this value, the corresponding configuration class is loaded to construct the pipeline.
|
| 34 |
+
|
| 35 |
+
.. code-block:: yaml
|
| 36 |
+
|
| 37 |
+
pipeline_config_id: runner_config
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
- **data**: dataset specifications. One can put multiple dataset specifications in the config file, but only one is used per experiment.
|
| 41 |
+
|
| 42 |
+
- *[DATASET ID]*: name of the dataset, e.g., taxi.
|
| 43 |
+
- *train_dir, valid_dir, test_dir*: directories of the data files. For the moment we only accept pkl files (please see `Dataset <./dataset.html>`_ for details).
|
| 44 |
+
- *data_spec*: defines the event type information.
|
| 45 |
+
|
| 46 |
+
.. code-block:: yaml
|
| 47 |
+
|
| 48 |
+
data:
|
| 49 |
+
taxi:
|
| 50 |
+
data_format: pkl
|
| 51 |
+
train_dir: ../data/taxi/train.pkl
|
| 52 |
+
valid_dir: ../data/taxi/dev.pkl
|
| 53 |
+
test_dir: ../data/taxi/test.pkl
|
| 54 |
+
data_spec:
|
| 55 |
+
num_event_types: 7 # num of types excluding pad events.
|
| 56 |
+
pad_token_id: 6 # event type index for pad events
|
| 57 |
+
padding_side: right # pad at the right end of the sequence
|
| 58 |
+
truncation_side: right # truncate at the right end of the sequence
|
| 59 |
+
max_len: 100 # max sequence length used as model input
|
| 60 |
+
|
| 61 |
+
- **[EXPERIMENT ID]**: name of the experiment to run in the pipeline. It contains three blocks of configs:
|
| 62 |
+
|
| 63 |
+
*base_config* contains the pipeline framework related specifications.
|
| 64 |
+
|
| 65 |
+
.. code-block:: yaml
|
| 66 |
+
|
| 67 |
+
base_config:
|
| 68 |
+
stage: train # train, eval and generate
|
| 69 |
+
backend: tensorflow # tensorflow and torch
|
| 70 |
+
dataset_id: conttime # name of the dataset
|
| 71 |
+
runner_id: std_tpp # registered name of the pipeline runner
|
| 72 |
+
model_id: RMTPP # registered name of the implemented model
|
| 73 |
+
base_dir: './checkpoints/' # base dir to save the logs and models.
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
*model_config* contains the model related specifications.
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
.. code-block:: yaml
|
| 81 |
+
|
| 82 |
+
model_config:
|
| 83 |
+
hidden_size: 32
|
| 84 |
+
time_emb_size: 16
|
| 85 |
+
num_layers: 2
|
| 86 |
+
num_heads: 2
|
| 87 |
+
mc_num_sample_per_step: 20
|
| 88 |
+
sharing_param_layer: False
|
| 89 |
+
loss_integral_num_sample_per_step: 20
|
| 90 |
+
dropout: 0.0
|
| 91 |
+
use_ln: False
|
| 92 |
+
thinning_params: # thinning algorithm for event sampling
|
| 93 |
+
num_seq: 10
|
| 94 |
+
num_sample: 1
|
| 95 |
+
num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm
|
| 96 |
+
look_ahead_time: 10
|
| 97 |
+
patience_counter: 5 # the maximum iteration used in adaptive thinning
|
| 98 |
+
over_sample_rate: 5
|
| 99 |
+
num_samples_boundary: 5
|
| 100 |
+
dtime_max: 5
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
*trainer_config* contains the training related specifications.
|
| 104 |
+
|
| 105 |
+
.. code-block:: yaml
|
| 106 |
+
|
| 107 |
+
trainer_config: # trainer arguments
|
| 108 |
+
seed: 2019
|
| 109 |
+
gpu: 0
|
| 110 |
+
batch_size: 256
|
| 111 |
+
max_epoch: 10
|
| 112 |
+
shuffle: False
|
| 113 |
+
optimizer: adam
|
| 114 |
+
learning_rate: 1.e-3
|
| 115 |
+
valid_freq: 1
|
| 116 |
+
use_tfb: False
|
| 117 |
+
metrics: ['acc', 'rmse']
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
A complete example of this file can be seen at *examples/example_config*.
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
Step 2: Run the training script
|
| 126 |
+
===============================================
|
| 127 |
+
|
| 128 |
+
To run the training process, we simply need to use two classes:
|
| 129 |
+
|
| 130 |
+
1. ``Config``: it reads the configs specified in Step 1 and does some processing to form a complete configuration.
|
| 131 |
+
2. ``Runner``: it reads the configuration and sets up the whole pipeline for training, evaluation and generation.
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
The following code gives an example, copied from *examples/train_nhp.py*.
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
.. code-block:: python
|
| 138 |
+
|
| 139 |
+
import argparse
|
| 140 |
+
from easy_tpp.config_factory import Config
|
| 141 |
+
from easy_tpp.runner import Runner
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def main():
|
| 145 |
+
parser = argparse.ArgumentParser()
|
| 146 |
+
|
| 147 |
+
parser.add_argument('--config_dir', type=str, required=False, default='configs/experiment_config.yaml',
|
| 148 |
+
help='Dir of configuration yaml to train and evaluate the model.')
|
| 149 |
+
|
| 150 |
+
parser.add_argument('--experiment_id', type=str, required=False, default='RMTPP_train',
|
| 151 |
+
help='Experiment id in the config file.')
|
| 152 |
+
|
| 153 |
+
args = parser.parse_args()
|
| 154 |
+
|
| 155 |
+
config = Config.build_from_yaml_file(args.config_dir, experiment_id=args.experiment_id)
|
| 156 |
+
|
| 157 |
+
model_runner = Runner.build_from_config(config)
|
| 158 |
+
|
| 159 |
+
model_runner.run()
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
if __name__ == '__main__':
|
| 163 |
+
main()
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
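Assuming the config file from Step 1 is saved as ``configs/experiment_config.yaml``, the pipeline can then be launched from the command line (the experiment id must match an entry in the yaml; the values here are illustrative):

.. code-block:: bash

    python examples/train_nhp.py --config_dir configs/experiment_config.yaml --experiment_id RMTPP_train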
|
| 169 |
+
Check out the output
|
| 170 |
+
========================
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
During training, the log, the best model (selected by validation-set performance) and the complete configuration file are all saved. The directory of the saved files is determined by the ``base_dir`` entry of the ``base_config``, i.e.,
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
In the `./checkpoints/` folder, one can find the corresponding subfolder, whose name concatenates the experiment id and the running timestamp. Inside that subfolder there is a complete configuration file, e.g., ``RMTPP_train_output.yaml``, that records all the information used in the pipeline:
|
| 178 |
+
|
| 179 |
+
.. code-block:: yaml
|
| 180 |
+
|
| 181 |
+
data_config:
|
| 182 |
+
train_dir: ../data/conttime/train.pkl
|
| 183 |
+
valid_dir: ../data/conttime/dev.pkl
|
| 184 |
+
test_dir: ../data/conttime/test.pkl
|
| 185 |
+
specs:
|
| 186 |
+
num_event_types_pad: 6
|
| 187 |
+
num_event_types: 5
|
| 188 |
+
event_pad_index: 5
|
| 189 |
+
data_format: pkl
|
| 190 |
+
base_config:
|
| 191 |
+
stage: train
|
| 192 |
+
backend: tensorflow
|
| 193 |
+
dataset_id: conttime
|
| 194 |
+
runner_id: std_tpp
|
| 195 |
+
model_id: RMTPP
|
| 196 |
+
base_dir: ./checkpoints/
|
| 197 |
+
exp_id: RMTPP_train
|
| 198 |
+
log_folder: ./checkpoints/98888_4299965824_221205-153425
|
| 199 |
+
saved_model_dir: ./checkpoints/98888_4299965824_221205-153425/models/saved_model
|
| 200 |
+
saved_log_dir: ./checkpoints/98888_4299965824_221205-153425/log
|
| 201 |
+
output_config_dir: ./checkpoints/98888_4299965824_221205-153425/RMTPP_train_output.yaml
|
| 202 |
+
model_config:
|
| 203 |
+
hidden_size: 32
|
| 204 |
+
time_emb_size: 16
|
| 205 |
+
num_layers: 2
|
| 206 |
+
num_heads: 2
|
| 207 |
+
mc_num_sample_per_step: 20
|
| 208 |
+
sharing_param_layer: false
|
| 209 |
+
loss_integral_num_sample_per_step: 20
|
| 210 |
+
dropout: 0.0
|
| 211 |
+
use_ln: false
|
| 212 |
+
seed: 2019
|
| 213 |
+
gpu: 0
|
| 214 |
+
thinning_params:
|
| 215 |
+
num_seq: 10
|
| 216 |
+
num_sample: 1
|
| 217 |
+
num_exp: 500
|
| 218 |
+
look_ahead_time: 10
|
| 219 |
+
patience_counter: 5
|
| 220 |
+
over_sample_rate: 5
|
| 221 |
+
num_samples_boundary: 5
|
| 222 |
+
dtime_max: 5
|
| 223 |
+
num_step_gen: 1
|
| 224 |
+
trainer:
|
| 225 |
+
batch_size: 256
|
| 226 |
+
max_epoch: 10
|
| 227 |
+
shuffle: false
|
| 228 |
+
optimizer: adam
|
| 229 |
+
learning_rate: 0.001
|
| 230 |
+
valid_freq: 1
|
| 231 |
+
use_tfb: false
|
| 232 |
+
metrics:
|
| 233 |
+
- acc
|
| 234 |
+
- rmse
|
| 235 |
+
seq_pad_end: true
|
| 236 |
+
is_training: true
|
| 237 |
+
num_event_types_pad: 6
|
| 238 |
+
num_event_types: 5
|
| 239 |
+
event_pad_index: 5
|
| 240 |
+
model_id: RMTPP
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
If we set ``use_tfb`` to ``true``, we can launch TensorBoard to track the training process; one
|
| 245 |
+
can see `Running Tensorboard <../advanced/tensorboard.html>`_ for details.
|
easy_tpp/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
| 1 |
+
__version__ = '0.1.0'
|
easy_tpp/config_factory/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
| 1 |
+
from easy_tpp.config_factory.config import Config
|
| 2 |
+
from easy_tpp.config_factory.data_config import DataConfig, DataSpecConfig
|
| 3 |
+
from easy_tpp.config_factory.hpo_config import HPOConfig, HPORunnerConfig
|
| 4 |
+
from easy_tpp.config_factory.runner_config import RunnerConfig, ModelConfig, BaseConfig
|
| 5 |
+
|
| 6 |
+
__all__ = ['Config',
|
| 7 |
+
'DataConfig',
|
| 8 |
+
'DataSpecConfig',
|
| 9 |
+
'ModelConfig',
|
| 10 |
+
'BaseConfig',
|
| 11 |
+
'RunnerConfig',
|
| 12 |
+
'HPOConfig',
|
| 13 |
+
'HPORunnerConfig']
|
easy_tpp/config_factory/config.py
ADDED
|
@@ -0,0 +1,120 @@
|
| 1 |
+
from abc import abstractmethod
|
| 2 |
+
from typing import Any
|
| 3 |
+
from omegaconf import OmegaConf
|
| 4 |
+
|
| 5 |
+
from easy_tpp.utils import save_yaml_config, Registrable, logger
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Config(Registrable):
|
| 9 |
+
|
| 10 |
+
def save_to_yaml_file(self, config_dir):
|
| 11 |
+
"""Save the config into the yaml file 'config_dir'.
|
| 12 |
+
|
| 13 |
+
Args:
|
| 14 |
+
config_dir (str): Target filename.
|
| 15 |
+
|
| 16 |
+
Returns:
|
| 17 |
+
"""
|
| 18 |
+
yaml_config = self.get_yaml_config()
|
| 19 |
+
OmegaConf.save(yaml_config, config_dir)
|
| 20 |
+
|
| 21 |
+
@staticmethod
|
| 22 |
+
def build_from_yaml_file(yaml_dir, **kwargs):
|
| 23 |
+
"""Load yaml config file from disk.
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
yaml_dir (str): Path of the yaml config file.
|
| 27 |
+
|
| 28 |
+
Returns:
|
| 29 |
+
EasyTPP.Config: Config object corresponding to cls.
|
| 30 |
+
"""
|
| 31 |
+
config = OmegaConf.load(yaml_dir)
|
| 32 |
+
pipeline_config = config.get('pipeline_config_id')
|
| 33 |
+
config_cls = Config.by_name(pipeline_config.lower())
|
| 34 |
+
logger.critical(f'Load pipeline config class {config_cls.__name__}')
|
| 35 |
+
return config_cls.parse_from_yaml_config(config, **kwargs)
|
| 36 |
+
|
| 37 |
+
@abstractmethod
|
| 38 |
+
def get_yaml_config(self):
|
| 39 |
+
"""Get the yaml format config from self.
|
| 40 |
+
|
| 41 |
+
Returns:
|
| 42 |
+
"""
|
| 43 |
+
pass
|
| 44 |
+
|
| 45 |
+
@staticmethod
|
| 46 |
+
@abstractmethod
|
| 47 |
+
def parse_from_yaml_config(yaml_config):
|
| 48 |
+
"""Parse from the yaml to generate the config object.
|
| 49 |
+
|
| 50 |
+
Args:
|
| 51 |
+
yaml_config (dict): configs from yaml file.
|
| 52 |
+
|
| 53 |
+
Returns:
|
| 54 |
+
EasyTPP.Config: Config class for data.
|
| 55 |
+
"""
|
| 56 |
+
pass
|
| 57 |
+
|
| 58 |
+
@abstractmethod
|
| 59 |
+
def copy(self):
|
| 60 |
+
"""Get a same and freely modifiable copy of self.
|
| 61 |
+
|
| 62 |
+
Returns:
|
| 63 |
+
"""
|
| 64 |
+
pass
|
| 65 |
+
|
| 66 |
+
def __str__(self):
|
| 67 |
+
"""Str representation of the config.
|
| 68 |
+
|
| 69 |
+
Returns:
|
| 70 |
+
str: str representation of the dict format of the config.
|
| 71 |
+
"""
|
| 72 |
+
return str(self.get_yaml_config())
|
| 73 |
+
|
| 74 |
+
def update(self, config):
|
| 75 |
+
"""Update the config.
|
| 76 |
+
|
| 77 |
+
Args:
|
| 78 |
+
config (dict): config dict.
|
| 79 |
+
|
| 80 |
+
Returns:
|
| 81 |
+
EasyTPP.Config: Config class for data.
|
| 82 |
+
"""
|
| 83 |
+
logger.critical(f'Update config class {self.__class__.__name__}')
|
| 84 |
+
return self.parse_from_yaml_config(config)
|
| 85 |
+
|
| 86 |
+
def pop(self, key: str, default_var: Any):
|
| 87 |
+
"""pop out the key-value item from the config.
|
| 88 |
+
|
| 89 |
+
Args:
|
| 90 |
+
key (str): key name.
|
| 91 |
+
default_var (Any): default value to pop.
|
| 92 |
+
|
| 93 |
+
Returns:
|
| 94 |
+
Any: value to pop.
|
| 95 |
+
"""
|
| 96 |
+
return vars(self).pop(key, None) or default_var
|
| 97 |
+
|
| 98 |
+
def get(self, key: str, default_var: Any):
|
| 99 |
+
"""Retrieve the key-value item from the config.
|
| 100 |
+
|
| 101 |
+
Args:
|
| 102 |
+
key (str): key name.
|
| 103 |
+
default_var (Any): default value to return if the key is absent.
|
| 104 |
+
|
| 105 |
+
Returns:
|
| 106 |
+
Any: value to get.
|
| 107 |
+
"""
|
| 108 |
+
return vars(self).get(key) or default_var
|
| 109 |
+
|
| 110 |
+
def set(self, key: str, var_to_set: Any):
|
| 111 |
+
"""Set the key-value item from the config.
|
| 112 |
+
|
| 113 |
+
Args:
|
| 114 |
+
key (str): key name.
|
| 115 |
+
var_to_set (Any): value to assign to the key.
|
| 116 |
+
|
| 117 |
+
Returns:
|
| 118 |
+
|
| 119 |
+
"""
|
| 120 |
+
vars(self)[key] = var_to_set
|
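For reference, here is a minimal usage sketch of this base class (file path and experiment id are illustrative): ``build_from_yaml_file`` reads the ``pipeline_config_id`` entry from the yaml and dispatches to the registered subclass via ``Config.by_name``.

.. code-block:: python

    from easy_tpp.config_factory import Config

    # The yaml must carry a 'pipeline_config_id' entry (e.g. 'runner_config');
    # the registered subclass is resolved by name and parses the remaining fields.
    config = Config.build_from_yaml_file('configs/experiment_config.yaml',
                                         experiment_id='RMTPP_train')

    # Snapshot the parsed configuration back to disk.
    config.save_to_yaml_file('./checkpoints/config_snapshot.yaml')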
easy_tpp/config_factory/data_config.py
ADDED
|
@@ -0,0 +1,147 @@
|
| 1 |
+
from easy_tpp.config_factory.config import Config
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class DataSpecConfig(Config):
|
| 5 |
+
def __init__(self, **kwargs):
|
| 6 |
+
"""Initialize the Config class.
|
| 7 |
+
"""
|
| 8 |
+
self.num_event_types = kwargs.get('num_event_types')
|
| 9 |
+
self.pad_token_id = kwargs.get('pad_token_id')
|
| 10 |
+
self.padding_side = kwargs.get('padding_side')
|
| 11 |
+
self.truncation_side = kwargs.get('truncation_side')
|
| 12 |
+
self.padding_strategy = kwargs.get('padding_strategy')
|
| 13 |
+
self.max_len = kwargs.get('max_len')
|
| 14 |
+
self.truncation_strategy = kwargs.get('truncation_strategy')
|
| 15 |
+
self.num_event_types_pad = self.num_event_types + 1 if self.num_event_types is not None else None  # pad token takes one extra type index
|
| 16 |
+
self.model_input_names = kwargs.get('model_input_names')
|
| 17 |
+
|
| 18 |
+
if self.padding_side is not None and self.padding_side not in ["right", "left"]:
|
| 19 |
+
raise ValueError(
|
| 20 |
+
f"Padding side should be selected between 'right' and 'left', current value: {self.padding_side}"
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
if self.truncation_side is not None and self.truncation_side not in ["right", "left"]:
|
| 24 |
+
raise ValueError(
|
| 25 |
+
f"Truncation side should be selected between 'right' and 'left', current value: {self.truncation_side}"
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
def get_yaml_config(self):
|
| 29 |
+
"""Return the config in dict (yaml compatible) format.
|
| 30 |
+
|
| 31 |
+
Returns:
|
| 32 |
+
dict: config of the data specs in dict format.
|
| 33 |
+
"""
|
| 34 |
+
return {
|
| 35 |
+
'num_event_types': self.num_event_types,
|
| 36 |
+
'pad_token_id': self.pad_token_id,
|
| 37 |
+
'padding_side': self.padding_side,
|
| 38 |
+
'truncation_side': self.truncation_side,
|
| 39 |
+
'padding_strategy': self.padding_strategy,
|
| 40 |
+
'truncation_strategy': self.truncation_strategy,
|
| 41 |
+
'max_len': self.max_len
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
@staticmethod
|
| 45 |
+
def parse_from_yaml_config(yaml_config):
|
| 46 |
+
"""Parse from the yaml to generate the config object.
|
| 47 |
+
|
| 48 |
+
Args:
|
| 49 |
+
yaml_config (dict): configs from yaml file.
|
| 50 |
+
|
| 51 |
+
Returns:
|
| 52 |
+
DataSpecConfig: Config class for data specs.
|
| 53 |
+
"""
|
| 54 |
+
return DataSpecConfig(**yaml_config)
|
| 55 |
+
|
| 56 |
+
def copy(self):
|
| 57 |
+
"""Copy the config.
|
| 58 |
+
|
| 59 |
+
Returns:
|
| 60 |
+
DataSpecConfig: a copy of current config.
|
| 61 |
+
"""
|
| 62 |
+
return DataSpecConfig(num_event_types_pad=self.num_event_types_pad,
|
| 63 |
+
num_event_types=self.num_event_types,
|
| 64 |
+
pad_token_id=self.pad_token_id,
|
| 65 |
+
padding_side=self.padding_side,
|
| 66 |
+
truncation_side=self.truncation_side,
|
| 67 |
+
padding_strategy=self.padding_strategy,
|
| 68 |
+
truncation_strategy=self.truncation_strategy,
|
| 69 |
+
max_len=self.max_len)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@Config.register('data_config')
|
| 73 |
+
class DataConfig(Config):
|
| 74 |
+
def __init__(self, train_dir, valid_dir, test_dir, data_format, specs=None):
|
| 75 |
+
"""Initialize the DataConfig object.
|
| 76 |
+
|
| 77 |
+
Args:
|
| 78 |
+
train_dir (str): dir of train set.
|
| 79 |
+
valid_dir (str): dir of valid set.
|
| 80 |
+
test_dir (str): dir of test set.
data_format (str): format of the data files, e.g., pkl; inferred from train_dir if None.
|
| 81 |
+
specs (dict, optional): specs of dataset. Defaults to None.
|
| 82 |
+
"""
|
| 83 |
+
self.train_dir = train_dir
|
| 84 |
+
self.valid_dir = valid_dir
|
| 85 |
+
self.test_dir = test_dir
|
| 86 |
+
self.data_specs = specs or DataSpecConfig()
|
| 87 |
+
self.data_format = train_dir.split('.')[-1] if data_format is None else data_format
|
| 88 |
+
|
| 89 |
+
def get_yaml_config(self):
|
| 90 |
+
"""Return the config in dict (yaml compatible) format.
|
| 91 |
+
|
| 92 |
+
Returns:
|
| 93 |
+
dict: config of the data in dict format.
|
| 94 |
+
"""
|
| 95 |
+
return {
|
| 96 |
+
'train_dir': self.train_dir,
|
| 97 |
+
'valid_dir': self.valid_dir,
|
| 98 |
+
'test_dir': self.test_dir,
|
| 99 |
+
'data_format': self.data_format,
|
| 100 |
+
'data_specs': self.data_specs.get_yaml_config(),
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
@staticmethod
|
| 104 |
+
def parse_from_yaml_config(yaml_config):
|
| 105 |
+
"""Parse from the yaml to generate the config object.
|
| 106 |
+
|
| 107 |
+
Args:
|
| 108 |
+
yaml_config (dict): configs from yaml file.
|
| 109 |
+
|
| 110 |
+
Returns:
|
| 111 |
+
EasyTPP.DataConfig: Config class for data.
|
| 112 |
+
"""
|
| 113 |
+
return DataConfig(
|
| 114 |
+
train_dir=yaml_config.get('train_dir'),
|
| 115 |
+
valid_dir=yaml_config.get('valid_dir'),
|
| 116 |
+
test_dir=yaml_config.get('test_dir'),
|
| 117 |
+
data_format=yaml_config.get('data_format'),
|
| 118 |
+
specs=DataSpecConfig.parse_from_yaml_config(yaml_config.get('data_specs'))
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
def copy(self):
|
| 122 |
+
"""Copy the config.
|
| 123 |
+
|
| 124 |
+
Returns:
|
| 125 |
+
EasyTPP.DataConfig: a copy of current config.
|
| 126 |
+
"""
|
| 127 |
+
return DataConfig(train_dir=self.train_dir,
|
| 128 |
+
valid_dir=self.valid_dir,
|
| 129 |
+
test_dir=self.test_dir,
|
| 130 |
+
data_format=self.data_format,
specs=self.data_specs)
|
| 131 |
+
|
| 132 |
+
def get_data_dir(self, split):
|
| 133 |
+
"""Get the dir of the source raw data.
|
| 134 |
+
|
| 135 |
+
Args:
|
| 136 |
+
split (str): dataset split notation, 'train', 'dev' or 'valid', 'test'.
|
| 137 |
+
|
| 138 |
+
Returns:
|
| 139 |
+
str: dir of the source raw data file.
|
| 140 |
+
"""
|
| 141 |
+
split = split.lower()
|
| 142 |
+
if split == 'train':
|
| 143 |
+
return self.train_dir
|
| 144 |
+
elif split in ['dev', 'valid']:
|
| 145 |
+
return self.valid_dir
|
| 146 |
+
else:
|
| 147 |
+
return self.test_dir
|
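For reference, a minimal sketch (all values are illustrative) of building a ``DataConfig`` from a dict with the same shape as the yaml block:

.. code-block:: python

    from easy_tpp.config_factory import DataConfig

    yaml_config = {
        'train_dir': '../data/taxi/train.pkl',
        'valid_dir': '../data/taxi/dev.pkl',
        'test_dir': '../data/taxi/test.pkl',
        'data_format': 'pkl',
        'data_specs': {'num_event_types': 7, 'pad_token_id': 6, 'max_len': 100},
    }
    data_config = DataConfig.parse_from_yaml_config(yaml_config)
    print(data_config.get_data_dir('valid'))  # ../data/taxi/dev.pkl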
easy_tpp/config_factory/hpo_config.py
ADDED
|
@@ -0,0 +1,132 @@
|
| 1 |
+
from easy_tpp.config_factory.config import Config
|
| 2 |
+
from easy_tpp.config_factory.runner_config import RunnerConfig
|
| 3 |
+
from easy_tpp.utils import parse_uri_to_protocol_and_path, py_assert
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class HPOConfig(Config):
|
| 7 |
+
def __init__(self, framework_id, storage_uri, is_continuous, num_trials, num_jobs):
|
| 8 |
+
"""Initialize the HPO Config
|
| 9 |
+
|
| 10 |
+
Args:
|
| 11 |
+
framework_id (str): hpo framework id.
|
| 12 |
+
storage_uri (str): result storage dir.
|
| 13 |
+
is_continuous (bool): whether to continuously do the optimization.
|
| 14 |
+
num_trials (int): num of trials used in optimization.
|
| 15 |
+
num_jobs (int): num of the jobs.
|
| 16 |
+
"""
|
| 17 |
+
self.framework_id = framework_id or 'optuna'
|
| 18 |
+
self.is_continuous = is_continuous if is_continuous is not None else True
|
| 19 |
+
self.num_trials = num_trials or 50
|
| 20 |
+
self.storage_uri = storage_uri
|
| 21 |
+
self.num_jobs = num_jobs if num_jobs is not None else 1
|
| 22 |
+
|
| 23 |
+
@property
|
| 24 |
+
def storage_protocol(self):
|
| 25 |
+
"""Get the storage protocol
|
| 26 |
+
|
| 27 |
+
Returns:
|
| 28 |
+
str: the protocol of the storage URI.
|
| 29 |
+
"""
|
| 30 |
+
storage_protocol, _ = parse_uri_to_protocol_and_path(self.storage_uri)
|
| 31 |
+
return storage_protocol
|
| 32 |
+
|
| 33 |
+
@property
|
| 34 |
+
def storage_path(self):
|
| 35 |
+
"""Get the storage protocol
|
| 36 |
+
|
| 37 |
+
Returns:
|
| 38 |
+
str: the dir of the hpo data storage.
|
| 39 |
+
"""
|
| 40 |
+
_, storage_path = parse_uri_to_protocol_and_path(self.storage_uri)
|
| 41 |
+
return storage_path
|
| 42 |
+
|
| 43 |
+
def get_yaml_config(self):
|
| 44 |
+
"""Return the config in dict (yaml compatible) format.
|
| 45 |
+
|
| 46 |
+
Returns:
|
| 47 |
+
dict: config of the HPO specs in dict format.
|
| 48 |
+
"""
|
| 49 |
+
return {
|
| 50 |
+
'framework_id': self.framework_id,
|
| 51 |
+
'storage_uri': self.storage_uri,
|
| 52 |
+
'is_continuous': self.is_continuous,
|
| 53 |
+
'num_trials': self.num_trials,
|
| 54 |
+
'num_jobs': self.num_jobs
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
@staticmethod
|
| 58 |
+
def parse_from_yaml_config(yaml_config, **kwargs):
|
| 59 |
+
"""Parse from the yaml to generate the config object.
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
yaml_config (dict): configs from yaml file.
|
| 63 |
+
|
| 64 |
+
Returns:
|
| 65 |
+
EasyTPP.HPOConfig: Config class for HPO specs.
|
| 66 |
+
"""
|
| 67 |
+
if yaml_config is None:
|
| 68 |
+
return None
|
| 69 |
+
else:
|
| 70 |
+
return HPOConfig(
|
| 71 |
+
framework_id=yaml_config.get('framework_id'),
|
| 72 |
+
storage_uri=yaml_config.get('storage_uri'),
|
| 73 |
+
is_continuous=yaml_config.get('is_continuous'),
|
| 74 |
+
num_trials=yaml_config.get('num_trials'),
|
| 75 |
+
num_jobs=yaml_config.get('num_jobs'),
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
def copy(self):
|
| 79 |
+
"""Copy the config.
|
| 80 |
+
|
| 81 |
+
Returns:
|
| 82 |
+
EasyTPP.HPOConfig: a copy of current config.
|
| 83 |
+
"""
|
| 84 |
+
return HPOConfig(
|
| 85 |
+
framework_id=self.framework_id,
|
| 86 |
+
storage_uri=self.storage_uri,
|
| 87 |
+
is_continuous=self.is_continuous,
|
| 88 |
+
num_trials=self.num_trials,
|
| 89 |
+
num_jobs=self.num_jobs
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
@Config.register('hpo_runner_config')
|
| 94 |
+
class HPORunnerConfig(Config):
|
| 95 |
+
def __init__(self, hpo_config, runner_config):
|
| 96 |
+
"""Initialize the config class
|
| 97 |
+
|
| 98 |
+
Args:
|
| 99 |
+
hpo_config (EasyTPP.HPOConfig): hpo config class.
|
| 100 |
+
runner_config (EasyTPP.RunnerConfig): runner config class.
|
| 101 |
+
"""
|
| 102 |
+
self.hpo_config = hpo_config
|
| 103 |
+
self.runner_config = runner_config
|
| 104 |
+
|
| 105 |
+
@staticmethod
|
| 106 |
+
def parse_from_yaml_config(yaml_config, **kwargs):
|
| 107 |
+
"""Parse from the yaml to generate the config object.
|
| 108 |
+
|
| 109 |
+
Args:
|
| 110 |
+
yaml_config (dict): configs from yaml file.
|
| 111 |
+
|
| 112 |
+
Returns:
|
| 113 |
+
EasyTPP.HPORunnerConfig: Config class for HPO specs.
|
| 114 |
+
"""
|
| 115 |
+
runner_config = RunnerConfig.parse_from_yaml_config(yaml_config, **kwargs)
|
| 116 |
+
hpo_config = HPOConfig.parse_from_yaml_config(yaml_config.get('hpo'), **kwargs)
|
| 117 |
+
py_assert(hpo_config is not None, ValueError, 'No hpo config is provided for HyperTuner')
|
| 118 |
+
return HPORunnerConfig(
|
| 119 |
+
hpo_config=hpo_config,
|
| 120 |
+
runner_config=runner_config
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
def copy(self):
|
| 124 |
+
"""Copy the config.
|
| 125 |
+
|
| 126 |
+
Returns:
|
| 127 |
+
EasyTPP.HPORunnerConfig: a copy of current config.
|
| 128 |
+
"""
|
| 129 |
+
return HPORunnerConfig(
|
| 130 |
+
hpo_config=self.hpo_config,
|
| 131 |
+
runner_config=self.runner_config
|
| 132 |
+
)
|
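For reference, a minimal sketch (keys and values are illustrative) of parsing an ``HPOConfig`` from a dict with the same shape as the ``hpo`` yaml block:

.. code-block:: python

    from easy_tpp.config_factory import HPOConfig

    hpo_config = HPOConfig.parse_from_yaml_config({
        'framework_id': 'optuna',
        'storage_uri': 'file://./hpo_results',
    })
    print(hpo_config.num_trials)  # falls back to the default of 50
    print(hpo_config.num_jobs)    # falls back to the default of 1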
easy_tpp/config_factory/model_config.py
ADDED
|
@@ -0,0 +1,274 @@
|
| 1 |
+
from easy_tpp.config_factory.config import Config
|
| 2 |
+
|
| 3 |
+
from easy_tpp.utils.const import Backend
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class TrainerConfig(Config):
|
| 7 |
+
|
| 8 |
+
def __init__(self, **kwargs):
|
| 9 |
+
"""Initialize the Config class.
|
| 10 |
+
"""
|
| 11 |
+
self.seed = kwargs.get('seed', 9899)
|
| 12 |
+
self.gpu = kwargs.get('gpu', -1)
|
| 13 |
+
self.batch_size = kwargs.get('batch_size', 256)
|
| 14 |
+
self.max_epoch = kwargs.get('max_epoch', 10)
|
| 15 |
+
self.shuffle = kwargs.get('shuffle', False)
|
| 16 |
+
self.optimizer = kwargs.get('optimizer', 'adam')
|
| 17 |
+
self.learning_rate = kwargs.get('learning_rate', 1.e-3)
|
| 18 |
+
self.valid_freq = kwargs.get('valid_freq', 1)
|
| 19 |
+
self.use_tfb = kwargs.get('use_tfb', False)
|
| 20 |
+
self.metrics = kwargs.get('metrics', ['acc', 'rmse'])
|
| 21 |
+
|
| 22 |
+
def get_yaml_config(self):
|
| 23 |
+
"""Return the config in dict (yaml compatible) format.
|
| 24 |
+
|
| 25 |
+
Returns:
|
| 26 |
+
dict: config of the trainer specs in dict format.
|
| 27 |
+
"""
|
| 28 |
+
return {'seed': self.seed,
|
| 29 |
+
'gpu': self.gpu,
|
| 30 |
+
'batch_size': self.batch_size,
|
| 31 |
+
'max_epoch': self.max_epoch,
|
| 32 |
+
'shuffle': self.shuffle,
|
| 33 |
+
'optimizer': self.optimizer,
|
| 34 |
+
'learning_rate': self.learning_rate,
|
| 35 |
+
'valid_freq': self.valid_freq,
|
| 36 |
+
'use_tfb': self.use_tfb,
|
| 37 |
+
'metrics': self.metrics
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
@staticmethod
|
| 41 |
+
def parse_from_yaml_config(yaml_config):
|
| 42 |
+
"""Parse from the yaml to generate the config object.
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
yaml_config (dict): configs from yaml file.
|
| 46 |
+
|
| 47 |
+
Returns:
|
| 48 |
+
EasyTPP.TrainerConfig: Config class for trainer specs.
|
| 49 |
+
"""
|
| 50 |
+
return TrainerConfig(**yaml_config)
|
| 51 |
+
|
| 52 |
+
def copy(self):
|
| 53 |
+
"""Copy the config.
|
| 54 |
+
|
| 55 |
+
Returns:
|
| 56 |
+
EasyTPP.TrainerConfig: a copy of current config.
|
| 57 |
+
"""
|
| 58 |
+
return TrainerConfig(batch_size=self.batch_size,
|
| 59 |
+
max_epoch=self.max_epoch,
|
| 60 |
+
shuffle=self.shuffle,
|
| 61 |
+
optimizer=self.optimizer,
|
| 62 |
+
learning_rate=self.learning_rate,
|
| 63 |
+
valid_freq=self.valid_freq,
|
| 64 |
+
use_tfb=self.use_tfb,
|
| 65 |
+
metrics=self.metrics
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class ThinningConfig(Config):
|
| 70 |
+
def __init__(self, **kwargs):
|
| 71 |
+
"""Initialize the Config class.
|
| 72 |
+
"""
|
| 73 |
+
self.num_seq = kwargs.get('num_seq', 10)
|
| 74 |
+
self.num_sample = kwargs.get('num_sample', 1)
|
| 75 |
+
self.num_exp = kwargs.get('num_exp', 500)
|
| 76 |
+
self.look_ahead_time = kwargs.get('look_ahead_time', 10)
|
| 77 |
+
self.patience_counter = kwargs.get('patience_counter', 5)
|
| 78 |
+
self.over_sample_rate = kwargs.get('over_sample_rate', 5)
|
| 79 |
+
self.num_samples_boundary = kwargs.get('num_samples_boundary', 5)
|
| 80 |
+
self.dtime_max = kwargs.get('dtime_max', 5)
|
| 81 |
+
# we pad the sequence at the front only in multi-step generation
|
| 82 |
+
self.num_step_gen = kwargs.get('num_step_gen', 1)
|
| 83 |
+
|
| 84 |
+
def get_yaml_config(self):
|
| 85 |
+
"""Return the config in dict (yaml compatible) format.
|
| 86 |
+
|
| 87 |
+
Returns:
|
| 88 |
+
dict: config of the thinning specs in dict format.
|
| 89 |
+
"""
|
| 90 |
+
return {'num_seq': self.num_seq,
|
| 91 |
+
'num_sample': self.num_sample,
|
| 92 |
+
'num_exp': self.num_exp,
|
| 93 |
+
'look_ahead_time': self.look_ahead_time,
|
| 94 |
+
'patience_counter': self.patience_counter,
|
| 95 |
+
'over_sample_rate': self.over_sample_rate,
|
| 96 |
+
'num_samples_boundary': self.num_samples_boundary,
|
| 97 |
+
'dtime_max': self.dtime_max,
|
| 98 |
+
'num_step_gen': self.num_step_gen}
|
| 99 |
+
|
| 100 |
+
@staticmethod
|
| 101 |
+
def parse_from_yaml_config(yaml_config):
|
| 102 |
+
"""Parse from the yaml to generate the config object.
|
| 103 |
+
|
| 104 |
+
Args:
|
| 105 |
+
yaml_config (dict): configs from yaml file.
|
| 106 |
+
|
| 107 |
+
Returns:
|
| 108 |
+
EasyTPP.ThinningConfig: Config class for thinning algorithms.
|
| 109 |
+
"""
|
| 110 |
+
return ThinningConfig(**yaml_config) if yaml_config is not None else None
|
| 111 |
+
|
| 112 |
+
def copy(self):
|
| 113 |
+
"""Copy the config.
|
| 114 |
+
|
| 115 |
+
Returns:
|
| 116 |
+
EasyTPP.ThinningConfig: a copy of current config.
|
| 117 |
+
"""
|
| 118 |
+
return ThinningConfig(num_seq=self.num_seq,
|
| 119 |
+
num_sample=self.num_sample,
|
| 120 |
+
num_exp=self.num_exp,
|
| 121 |
+
look_ahead_time=self.look_ahead_time,
|
| 122 |
+
patience_counter=self.patience_counter,
|
| 123 |
+
over_sample_rate=self.over_sample_rate,
|
| 124 |
+
num_samples_boundary=self.num_samples_boundary,
|
| 125 |
+
dtime_max=self.dtime_max,
|
| 126 |
+
num_step_gen=self.num_step_gen)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class BaseConfig(Config):
|
| 130 |
+
def __init__(self, **kwargs):
|
| 131 |
+
"""Initialize the Config class.
|
| 132 |
+
"""
|
| 133 |
+
self.stage = kwargs.get('stage')
|
| 134 |
+
self.backend = kwargs.get('backend')
|
| 135 |
+
self.dataset_id = kwargs.get('dataset_id')
|
| 136 |
+
self.runner_id = kwargs.get('runner_id')
|
| 137 |
+
self.model_id = kwargs.get('model_id')
|
| 138 |
+
self.exp_id = kwargs.get('exp_id')
|
| 139 |
+
self.base_dir = kwargs.get('base_dir')
|
| 140 |
+
self.specs = kwargs.get('specs', {})
|
| 141 |
+
self.backend = self.set_backend(self.backend)
|
| 142 |
+
|
| 143 |
+
@staticmethod
|
| 144 |
+
def set_backend(backend):
|
| 145 |
+
if backend.lower() in ['torch', 'pytorch']:
|
| 146 |
+
return Backend.Torch
|
| 147 |
+
else:
|
| 148 |
+
raise ValueError(
|
| 149 |
+
f"Backend should be 'torch' or 'pytorch', current value: {backend}"
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
def get_yaml_config(self):
|
| 153 |
+
"""Return the config in dict (yaml compatible) format.
|
| 154 |
+
|
| 155 |
+
Returns:
|
| 156 |
+
dict: config of the base config specs in dict format.
|
| 157 |
+
"""
|
| 158 |
+
return {'stage': self.stage,
|
| 159 |
+
'backend': str(self.backend),
|
| 160 |
+
'dataset_id': self.dataset_id,
|
| 161 |
+
'runner_id': self.runner_id,
|
| 162 |
+
'model_id': self.model_id,
|
| 163 |
+
'base_dir': self.base_dir,
|
| 164 |
+
'specs': self.specs}
|
| 165 |
+
|
| 166 |
+
@staticmethod
|
| 167 |
+
def parse_from_yaml_config(yaml_config):
|
| 168 |
+
"""Parse from the yaml to generate the config object.
|
| 169 |
+
|
| 170 |
+
Args:
|
| 171 |
+
yaml_config (dict): configs from yaml file.
|
| 172 |
+
|
| 173 |
+
Returns:
|
| 174 |
+
BaseConfig: Config class for trainer specs.
|
| 175 |
+
"""
|
| 176 |
+
return BaseConfig(**yaml_config)
|
| 177 |
+
|
| 178 |
+
def copy(self):
|
| 179 |
+
"""Copy the config.
|
| 180 |
+
|
| 181 |
+
Returns:
|
| 182 |
+
BaseConfig: a copy of current config.
|
| 183 |
+
"""
|
| 184 |
+
return BaseConfig(stage=self.stage,
|
| 185 |
+
backend=self.backend,
|
| 186 |
+
dataset_id=self.dataset_id,
|
| 187 |
+
runner_id=self.runner_id,
|
| 188 |
+
model_id=self.model_id,
|
| 189 |
+
base_dir=self.base_dir,
|
| 190 |
+
specs=self.specs)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
class ModelConfig(Config):
|
| 194 |
+
def __init__(self, **kwargs):
|
| 195 |
+
"""Initialize the Config class.
|
| 196 |
+
"""
|
| 197 |
+
self.rnn_type = kwargs.get('rnn_type', 'LSTM')
|
| 198 |
+
self.hidden_size = kwargs.get('hidden_size', 32)
|
| 199 |
+
self.time_emb_size = kwargs.get('time_emb_size', 16)
|
| 200 |
+
self.num_layers = kwargs.get('num_layers', 2)
|
| 201 |
+
self.num_heads = kwargs.get('num_heads', 2)
|
| 202 |
+
self.sharing_param_layer = kwargs.get('sharing_param_layer', False)
|
| 203 |
+
self.use_mc_samples = kwargs.get('use_mc_samples', True) # if using MC samples in computing log-likelihood
|
| 204 |
+
self.loss_integral_num_sample_per_step = kwargs.get('loss_integral_num_sample_per_step', 20) # mc_num_sample_per_step
|
| 205 |
+
self.dropout_rate = kwargs.get('dropout_rate', 0.0)
|
| 206 |
+
self.use_ln = kwargs.get('use_ln', False)
|
| 207 |
+
self.thinning = ThinningConfig.parse_from_yaml_config(kwargs.get('thinning'))
|
| 208 |
+
self.is_training = kwargs.get('training', False)
|
| 209 |
+
self.num_event_types_pad = kwargs.get('num_event_types_pad', None)
|
| 210 |
+
self.num_event_types = kwargs.get('num_event_types', None)
|
| 211 |
+
self.pad_token_id = kwargs.get('event_pad_index', None)
|
| 212 |
+
self.model_id = kwargs.get('model_id', None)
|
| 213 |
+
self.pretrained_model_dir = kwargs.get('pretrained_model_dir', None)
|
| 214 |
+
self.gpu = kwargs.get('gpu', -1)
|
| 215 |
+
self.model_specs = kwargs.get('model_specs', {})
|
| 216 |
+
|
| 217 |
+
def get_yaml_config(self):
|
| 218 |
+
"""Return the config in dict (yaml compatible) format.
|
| 219 |
+
|
| 220 |
+
Returns:
|
| 221 |
+
dict: config of the model config specs in dict format.
|
| 222 |
+
"""
|
| 223 |
+
return {'rnn_type': self.rnn_type,
|
| 224 |
+
'hidden_size': self.hidden_size,
|
| 225 |
+
'time_emb_size': self.time_emb_size,
|
| 226 |
+
'num_layers': self.num_layers,
'num_heads': self.num_heads,
|
| 227 |
+
'sharing_param_layer': self.sharing_param_layer,
|
| 228 |
+
'loss_integral_num_sample_per_step': self.loss_integral_num_sample_per_step,
|
| 229 |
+
'dropout_rate': self.dropout_rate,
|
| 230 |
+
'use_ln': self.use_ln,
|
| 231 |
+
# for some models / cases we may not need to pass thinning config
|
| 232 |
+
# e.g., for intensity-free model
|
| 233 |
+
'thinning': None if self.thinning is None else self.thinning.get_yaml_config(),
|
| 234 |
+
'num_event_types_pad': self.num_event_types_pad,
|
| 235 |
+
'num_event_types': self.num_event_types,
|
| 236 |
+
'event_pad_index': self.pad_token_id,
|
| 237 |
+
'model_id': self.model_id,
|
| 238 |
+
'pretrained_model_dir': self.pretrained_model_dir,
|
| 239 |
+
'gpu': self.gpu,
|
| 240 |
+
'model_specs': self.model_specs}
|
| 241 |
+
|
| 242 |
+
@staticmethod
|
| 243 |
+
def parse_from_yaml_config(yaml_config):
|
| 244 |
+
"""Parse from the yaml to generate the config object.
|
| 245 |
+
|
| 246 |
+
Args:
|
| 247 |
+
yaml_config (dict): configs from yaml file.
|
| 248 |
+
|
| 249 |
+
Returns:
|
| 250 |
+
ModelConfig: Config class for trainer specs.
|
| 251 |
+
"""
|
| 252 |
+
return ModelConfig(**yaml_config)
|
| 253 |
+
|
| 254 |
+
def copy(self):
|
| 255 |
+
"""Copy the config.
|
| 256 |
+
|
| 257 |
+
Returns:
|
| 258 |
+
ModelConfig: a copy of current config.
|
| 259 |
+
"""
|
| 260 |
+
return ModelConfig(rnn_type=self.rnn_type,
|
| 261 |
+
hidden_size=self.hidden_size,
|
| 262 |
+
time_emb_size=self.time_emb_size,
|
| 263 |
+
num_layers=self.num_layers,
num_heads=self.num_heads,
model_id=self.model_id,
|
| 264 |
+
sharing_param_layer=self.sharing_param_layer,
|
| 265 |
+
loss_integral_num_sample_per_step=self.loss_integral_num_sample_per_step,
|
| 266 |
+
dropout_rate=self.dropout_rate,
|
| 267 |
+
use_ln=self.use_ln,
|
| 268 |
+
thinning=self.thinning,
|
| 269 |
+
num_event_types_pad=self.num_event_types_pad,
|
| 270 |
+
num_event_types=self.num_event_types,
|
| 271 |
+
event_pad_index=self.pad_token_id,
|
| 272 |
+
pretrained_model_dir=self.pretrained_model_dir,
|
| 273 |
+
gpu=self.gpu,
|
| 274 |
+
model_specs=self.model_specs)
|