Abigail99216 committed on
Commit f43af3c · verified · 1 Parent(s): efa62b0

Upload folder using huggingface_hub

This view is limited to 50 files because the commit contains too many changes.

Files changed (50)
  1. .gitattributes +1 -0
  2. .github/workflows/docs.yaml +59 -0
  3. .github/workflows/python-publish.yml +39 -0
  4. .gitignore +95 -0
  5. ADDITIONS_README.md +71 -0
  6. CLEANUP_SUMMARY.md +164 -0
  7. COMPUTE_METRICS_README.md +191 -0
  8. DATA_FILES_NOTICE.md +107 -0
  9. DATA_TRANSFER_SUMMARY.md +95 -0
  10. HF_UPLOAD_GUIDE.md +180 -0
  11. LICENCE +203 -0
  12. MANIFEST.in +2 -0
  13. NOTICE +23 -0
  14. QUICK_START_HF.md +104 -0
  15. README.md +279 -0
  16. UPLOAD_CHECKLIST.md +116 -0
  17. cleanup_for_hf.py +293 -0
  18. compute_cascade_metrics.py +568 -0
  19. data/cascades/.gitkeep +3 -0
  20. data/cascades/README.md +101 -0
  21. docs/Makefile +20 -0
  22. docs/README.md +13 -0
  23. docs/images/thinning_algo.jpg +3 -0
  24. docs/make.bat +35 -0
  25. docs/source/advanced/implementation.rst +143 -0
  26. docs/source/advanced/performance_valid.rst +41 -0
  27. docs/source/advanced/tensorboard.rst +75 -0
  28. docs/source/advanced/thinning_algo.rst +56 -0
  29. docs/source/conf.py +59 -0
  30. docs/source/dev_guide/model_custom.rst +78 -0
  31. docs/source/get_started/install.rst +64 -0
  32. docs/source/get_started/introduction.rst +60 -0
  33. docs/source/get_started/quick_start.rst +106 -0
  34. docs/source/index.rst +56 -0
  35. docs/source/ref/config.rst +10 -0
  36. docs/source/ref/hpo.rst +10 -0
  37. docs/source/ref/models.rst +50 -0
  38. docs/source/ref/preprocess.rst +10 -0
  39. docs/source/ref/runner.rst +10 -0
  40. docs/source/ref/utils.rst +10 -0
  41. docs/source/ref/wrapper.rst +17 -0
  42. docs/source/user_guide/dataset.rst +124 -0
  43. docs/source/user_guide/run_eval.rst +97 -0
  44. docs/source/user_guide/run_train_pipeline.rst +245 -0
  45. easy_tpp/__init__.py +1 -0
  46. easy_tpp/config_factory/__init__.py +13 -0
  47. easy_tpp/config_factory/config.py +120 -0
  48. easy_tpp/config_factory/data_config.py +147 -0
  49. easy_tpp/config_factory/hpo_config.py +132 -0
  50. easy_tpp/config_factory/model_config.py +274 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+docs/images/thinning_algo.jpg filter=lfs diff=lfs merge=lfs -text
.github/workflows/docs.yaml ADDED
@@ -0,0 +1,59 @@
+name: docs
+
+on:
+  push:
+    branches: [ main ]
+  pull_request:
+    branches: [ main ]
+  release:
+    types: [ published ]
+
+jobs:
+  build:
+
+    runs-on: ubuntu-latest
+
+    steps:
+    - uses: actions/checkout@v3
+      with:
+        fetch-depth: 0
+    - name: Set up Python
+      uses: actions/setup-python@v4
+      with:
+        python-version: '3.8'
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip setuptools wheel
+        sudo apt-get update
+        sudo apt-get install openjdk-11-jdk
+        sudo apt-get install pandoc
+    - name: Build Sphinx docs
+      run: |
+        pip install tensorflow==2.2.0
+        pip install torch
+        pip install pandas
+        pip install numpy
+        pip install -r requirements-doc.txt
+        cd docs
+        make html
+    # Publish built docs to gh-pages branch.
+    # ===============================
+    - name: Commit documentation changes
+      run: |
+        git clone https://github.com/ant-research/EasyTemporalPointProcess.git --branch gh-pages --single-branch gh-pages
+        cp -r docs/build/html/* gh-pages/
+        cd gh-pages
+        touch .nojekyll
+        git config --local user.email "action@github.com"
+        git config --local user.name "GitHub Action"
+        git add .
+        git commit -m "Update documentation" -a || true
+        # The above command will fail if no changes were present, so we
+        # ignore that.
+    - name: Push changes
+      uses: ad-m/github-push-action@master
+      with:
+        branch: gh-pages
+        directory: gh-pages
+        github_token: ${{ secrets.GITHUB_TOKEN }}
+    # ===============================
.github/workflows/python-publish.yml ADDED
@@ -0,0 +1,39 @@
+# This workflow will upload a Python Package using Twine when a release is created
+# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
+
+# This workflow uses actions that are not certified by GitHub.
+# They are provided by a third-party and are governed by
+# separate terms of service, privacy policy, and support
+# documentation.
+
+name: Upload Python Package
+
+on:
+  release:
+    types: [published]
+
+permissions:
+  contents: read
+
+jobs:
+  deploy:
+
+    runs-on: ubuntu-latest
+
+    steps:
+    - uses: actions/checkout@v2
+    - name: Set up Python
+      uses: actions/setup-python@v2
+      with:
+        python-version: '3.9'
+    - name: Install dependencies
+      run: |
+        pip install -r requirements.txt
+        pip install wheel
+    - name: Build package
+      run: python setup.py sdist bdist_wheel
+    - name: Publish package
+      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
+      with:
+        user: __token__
+        password: ${{ secrets.PYPI_API_TOKEN }}
.gitignore ADDED
@@ -0,0 +1,95 @@
+# python build
+build/
+dist/
+easy_tpp.egg-info/
+*.egg-info/
+
+# python temp
+*.pyc
+*.pyo
+*.pyd
+__pycache__/
+*.so
+*.egg
+
+# proto
+protoc
+protoc-3.4.0.tar.gz
+*_pb2.py
+
+# misc
+experiments/
+log/
+logs/
+*.swp
+*.swo
+.vscode/
+.idea/
+
+# OS files
+.DS_Store
+.DS_Store?
+._*
+.Spotlight-V100
+.Trashes
+ehthumbs.db
+Thumbs.db
+
+# IDE
+.vscode/
+.idea/
+*.sublime-project
+*.sublime-workspace
+
+# Testing
+.pytest_cache/
+.mypy_cache/
+.ruff_cache/
+.coverage
+htmlcov/
+.tox/
+
+# Checkpoints and outputs
+examples/checkpoints/*
+notebooks/checkpoints/*
+results/
+outputs/
+*.pth
+*.pt
+*.ckpt
+
+# Data files (large files should not be committed)
+examples/data/*
+!examples/data/.gitkeep
+*.json.gz
+*.pkl
+*.h5
+*.hdf5
+
+# Large cascade data files (use Git LFS or external storage)
+data/cascades/information_cascade*.json
+data/cascades/*.json
+!data/cascades/.gitkeep
+
+# Model files (use Git LFS or Hugging Face Hub)
+*.bin
+*.safetensors
+*.onnx
+
+# Temporary files
+*.tmp
+*.temp
+*.bak
+*~
+
+# Jupyter Notebook checkpoints
+.ipynb_checkpoints/
+
+# Environment
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
ADDITIONS_README.md ADDED
@@ -0,0 +1,71 @@
+# EasyTPP Additions
+
+This repository extends the original EasyTPP with metric-computation functionality for information-cascade data.
+
+## 🆕 New Files
+
+### Core functionality
+- **`compute_cascade_metrics.py`**: main script for computing cascade metrics
+  - Sentiment Score
+  - Sentiment Deviation
+  - Contextual Deviation
+  - Perplexity
+
+### Documentation and tools
+- **`COMPUTE_METRICS_README.md`**: detailed usage instructions
+- **`HF_UPLOAD_GUIDE.md`**: Hugging Face upload guide
+- **`UPLOAD_CHECKLIST.md`**: upload checklist (auto-generated)
+- **`cleanup_for_hf.py`**: cleanup script to prepare the upload
+- **`example_compute_metrics.sh`**: example usage script
+- **`requirements_compute_metrics.txt`**: additional dependencies
+
+## 🚀 Quick Start
+
+### 1. Install dependencies
+
+```bash
+pip install -r requirements.txt
+pip install -r requirements_compute_metrics.txt
+```
+
+### 2. Run the metric computation
+
+```bash
+python compute_cascade_metrics.py \
+    --input_cascade information_cascade.json \
+    --output output_with_metrics.json \
+    --batch_size 32
+```
+
+See `COMPUTE_METRICS_README.md` for details.
+
+## 📦 Uploading to Hugging Face
+
+1. Run the cleanup script:
+```bash
+python cleanup_for_hf.py
+```
+
+2. Upload following the instructions in `HF_UPLOAD_GUIDE.md`.
+
+## 🔗 Related Documents
+
+- [Metric computation guide](COMPUTE_METRICS_README.md)
+- [Upload guide](HF_UPLOAD_GUIDE.md)
+- [Original EasyTPP README](README.md)
+
+## 📝 Use Cases
+
+These additions are mainly intended for:
+- analyzing social-media information cascades (e.g., Weibo reposts and comments)
+- computing sentiment features and semantic deviations of text
+- providing extra feature inputs to TPP models
+
+## ⚙️ Integration with EasyTPP
+
+The computed metrics can be consumed by:
+- `RobertTPPDataset`: loads data with semantic and deviation features
+- `RobertEventTokenizer`: handles the custom features
+- `TorchRobotTHP`: a TPP model that uses semantic and deviation features
+
+See `examples/train_robot_thp_with_features.py` for a complete example.
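
The integration section above implies a two-step pipeline: enrich the cascades first, then train on the enriched file. A minimal Python sketch of that ordering, using only the CLI documented in `COMPUTE_METRICS_README.md`; the `--data` flag of the example trainer is a hypothetical placeholder, so check `examples/train_robot_thp_with_features.py` for its real interface:

```python
import subprocess

# Step 1: enrich the cascade file with sentiment/deviation metrics
# (flags are those documented in COMPUTE_METRICS_README.md).
subprocess.run([
    "python", "compute_cascade_metrics.py",
    "--input_cascade", "information_cascade.json",
    "--output", "output_with_metrics.json",
    "--batch_size", "32",
], check=True)

# Step 2: train on the enriched file. The --data flag is an assumption
# for illustration; see examples/train_robot_thp_with_features.py for
# the trainer's actual arguments.
subprocess.run([
    "python", "examples/train_robot_thp_with_features.py",
    "--data", "output_with_metrics.json",
], check=True)
```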
CLEANUP_SUMMARY.md ADDED
@@ -0,0 +1,164 @@
+# Cleanup Summary
+
+## ✅ Cleanup completed
+2025-01-19
+
+## 📊 Folder Info
+
+- **Location**: `/Users/chenshuyi/Downloads/EasyTemporalPointProcess-main`
+- **Size**: 1.3MB
+- **Status**: ✅ cleaned up and ready to upload
+
+## 🧹 Completed Cleanup Work
+
+### 1. Updated .gitignore
+- ✅ added Python cache-file patterns
+- ✅ excluded IDE configuration files
+- ✅ excluded OS system files
+- ✅ excluded test and build artifacts
+- ✅ added exclusion rules for data and model files
+
+### 2. Created a cleanup script
+- ✅ `cleanup_for_hf.py` - automatic cleanup script
+- ✅ already run; confirmed nothing needed deletion
+
+### 3. File checks
+- ✅ no large files (>50MB)
+- ✅ no sensitive information
+- ✅ no temporary files
+
+## 📁 File Structure
+
+```
+EasyTemporalPointProcess-main/
+├── 📄 Core code
+│   ├── easy_tpp/                        # EasyTPP core library
+│   ├── examples/                        # example code
+│   ├── notebooks/                       # Jupyter notebooks
+│   └── tests/                           # test code
+│
+├── 🆕 Additions (cascade metric computation)
+│   ├── compute_cascade_metrics.py       # main computation script
+│   ├── COMPUTE_METRICS_README.md        # usage instructions
+│   ├── requirements_compute_metrics.txt # extra dependencies
+│   └── example_compute_metrics.sh       # example script
+│
+├── 📚 Documentation
+│   ├── README.md                        # original README
+│   ├── ADDITIONS_README.md              # additions overview
+│   ├── HF_UPLOAD_GUIDE.md               # Hugging Face upload guide
+│   ├── QUICK_START_HF.md                # quick-start guide
+│   ├── UPLOAD_CHECKLIST.md              # upload checklist
+│   └── CLEANUP_SUMMARY.md               # this file
+│
+├── 🛠️ Tool scripts
+│   ├── cleanup_for_hf.py                # cleanup script
+│   └── setup.py                         # installation script
+│
+└── ⚙️ Configuration files
+    ├── .gitignore                       # Git ignore rules (updated)
+    ├── requirements.txt                 # base dependencies
+    ├── requirements_compute_metrics.txt # metric-computation dependencies
+    └── setup.cfg                        # installation config
+```
+
+## 📋 List of New Files
+
+### Core functionality
+1. `compute_cascade_metrics.py` (19.5 KB)
+   - computes sentiment score, sentiment deviation, contextual deviation, and perplexity
+
+### Documentation
+2. `COMPUTE_METRICS_README.md` (5.9 KB)
+   - detailed usage instructions for the metric computation
+
+3. `HF_UPLOAD_GUIDE.md` (3.7 KB)
+   - complete Hugging Face upload guide
+
+4. `ADDITIONS_README.md` (1.9 KB)
+   - overview of the additions
+
+5. `QUICK_START_HF.md` (2.3 KB)
+   - quick upload guide
+
+6. `UPLOAD_CHECKLIST.md` (3.0 KB)
+   - upload checklist (auto-generated)
+
+7. `CLEANUP_SUMMARY.md` (this file)
+   - cleanup summary
+
+### Tools and configuration
+8. `cleanup_for_hf.py` (7.8 KB)
+   - automatic cleanup script
+
+9. `example_compute_metrics.sh` (1.2 KB)
+   - example usage script
+
+10. `requirements_compute_metrics.txt` (266 B)
+    - dependencies required for the metric computation
+
+## 🎯 Next Steps
+
+### 1. Upload to Hugging Face
+
+```bash
+# Install the CLI
+pip install huggingface_hub
+
+# Log in
+huggingface-cli login
+
+# Create the repository (on the website)
+# https://huggingface.co/new
+
+# Upload
+cd /Users/chenshuyi/Downloads/EasyTemporalPointProcess-main
+huggingface-cli upload <username>/<repo-name> . --repo-type dataset
+```
+
+### 2. Download on the cloud machine
+
+```bash
+huggingface-cli download <username>/<repo-name> --local-dir ./EasyTPP
+```
+
+### 3. Use the new functionality
+
+```bash
+cd EasyTPP
+pip install -r requirements.txt
+pip install -r requirements_compute_metrics.txt
+
+python compute_cascade_metrics.py \
+    --input_cascade information_cascade.json \
+    --output output_with_metrics.json
+```
+
+## ✅ Checklist
+
+- [x] clean temporary files
+- [x] update .gitignore
+- [x] check for large files
+- [x] check for sensitive information
+- [x] create upload guide
+- [x] create usage documentation
+- [x] verify the file structure
+- [ ] upload to Hugging Face (pending)
+- [ ] test on the cloud machine (pending)
+
+## 📝 Notes
+
+1. **File size**: 1.3MB; no Git LFS needed
+2. **License**: keeps the original Apache 2.0 license
+3. **Dependencies**: make sure all dependencies are listed in the requirements files
+4. **Documentation**: every addition comes with detailed documentation
+
+## 🔗 Related Links
+
+- [Hugging Face](https://huggingface.co/)
+- [Hugging Face CLI docs](https://huggingface.co/docs/huggingface_hub/guides/cli)
+- [Original EasyTPP project](https://github.com/ant-research/EasyTemporalPointProcess)
+
+---
+
+**Cleanup complete — ready to upload!** 🚀
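
As a reference for the "check for large files" item above, a minimal standard-library sketch of the kind of pre-upload scan `cleanup_for_hf.py` is described as performing (the actual script may differ):

```python
import os

LIMIT = 50 * 1024 * 1024  # 50MB, the threshold mentioned above

# Walk the repo and flag anything over the size limit before uploading.
for root, dirs, files in os.walk("."):
    dirs[:] = [d for d in dirs if d not in {".git", "__pycache__"}]
    for name in files:
        path = os.path.join(root, name)
        size = os.path.getsize(path)
        if size > LIMIT:
            print(f"Too large for upload: {path} ({size / 1e6:.0f} MB)")
```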
COMPUTE_METRICS_README.md ADDED
@@ -0,0 +1,191 @@
+# Computing Cascade Metrics: Usage Guide
+
+This script computes the sentiment score, sentiment deviation, contextual deviation, and perplexity of information-cascade data.
+
+## Overview
+
+The script `compute_cascade_metrics.py` processes the following two JSON files:
+- `information_cascade.json`: full cascade data (original posts, comments, reposts)
+- `information_cascade_original_posts.json`: original-post data (optional)
+
+It computes the following metrics:
+1. **Sentiment Score**: the sentiment polarity of a text
+2. **Sentiment Deviation**: sentiment deviation relative to the original post
+3. **Contextual Deviation**: semantic deviation relative to the original post
+4. **Perplexity**: language-model perplexity of the text
+
+## Installing Dependencies
+
+Install the required dependencies on the cloud machine:
+
+```bash
+pip install torch transformers numpy tqdm
+```
+
+If you need GPU support:
+```bash
+pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
+```
+
+## Usage
+
+### Basic usage (default models)
+
+```bash
+python compute_cascade_metrics.py \
+    --input_cascade information_cascade.json \
+    --output output_with_metrics.json \
+    --batch_size 32
+```
+
+### Full usage (all models specified)
+
+```bash
+python compute_cascade_metrics.py \
+    --input_cascade information_cascade.json \
+    --input_original information_cascade_original_posts.json \
+    --output output_with_metrics.json \
+    --bert_model bert-base-chinese \
+    --sentiment_model <path-to-sentiment-model> \
+    --perplexity_model <path-to-language-model> \
+    --batch_size 32 \
+    --max_length 512 \
+    --device cuda
+```
+
+### Arguments
+
+- `--input_cascade`: **required**, path to the input cascade JSON file
+- `--input_original`: optional, path to the original-posts JSON file
+- `--output`: **required**, path to the output JSON file
+- `--bert_model`: BERT model name or path (default: `bert-base-chinese`)
+- `--sentiment_model`: path to a sentiment-analysis model (optional; a simplified method is used if omitted)
+- `--perplexity_model`: path to a language model (optional; a simplified method is used if omitted)
+- `--batch_size`: batch size (default: 32)
+- `--max_length`: maximum sequence length (default: 512)
+- `--device`: compute device, `cuda` or `cpu` (default: auto-selected)
+- `--max_cascades`: maximum number of cascades to process (for testing; default: process all)
+
+## Output Format
+
+The processed JSON file gains the following fields on every node:
+
+### Original post (`post_info`)
+```json
+{
+  "post_info": {
+    "content": "original post content",
+    "embedding": [0.1, 0.2, ...],    // BERT semantic vector (768-dim)
+    "sentiment_score": 0.7,          // sentiment score
+    "perplexity": 15.3               // perplexity
+  }
+}
+```
+
+### Comments (`comment_tree`)
+```json
+{
+  "comment_tree": {
+    "comment_id": {
+      "content": "comment content",
+      "embedding": [0.1, 0.2, ...],
+      "sentiment_score": 0.6,
+      "perplexity": 12.5,
+      "contextual_deviation": 0.25,  // contextual deviation
+      "sentiment_deviation": 0.1     // sentiment deviation
+    }
+  }
+}
+```
+
+### Reposts (`repost_chain`)
+```json
+{
+  "repost_chain": [
+    {
+      "forward_text": "repost content",
+      "comment_content": "comment content",
+      "embedding": [0.1, 0.2, ...],
+      "sentiment_score": 0.5,
+      "perplexity": 18.2,
+      "contextual_deviation": 0.35,
+      "sentiment_deviation": 0.2
+    }
+  ]
+}
+```
+
+## Model Recommendations
+
+### BERT model
+- Chinese text: `bert-base-chinese`
+- English text: `bert-base-uncased`
+- Custom model: provide a local path
+
+### Sentiment-analysis model
+- Chinese: e.g. `uer/roberta-base-finetuned-chinanews-chinese` or another Chinese sentiment model
+- English: e.g. `nlptown/bert-base-multilingual-uncased-sentiment`
+- If none is provided, the script falls back to a keyword-based simplified method
+
+### Perplexity model
+- Chinese: e.g. `gpt2-chinese` or another Chinese language model
+- English: e.g. `gpt2`
+- If none is provided, the script falls back to a simplified method based on lexical diversity
+
+## Notes
+
+1. **Large files**: processing large JSON files can take a long time. Recommendations:
+   - use GPU acceleration (`--device cuda`)
+   - tune the batch size (`--batch_size`)
+   - test on a small sample first with `--max_cascades`
+
+2. **Memory usage**:
+   - BERT models need a fair amount of memory
+   - if memory runs out, reduce `--batch_size`
+
+3. **Simplified methods**:
+   - if no sentiment or perplexity model is provided, the script uses simplified heuristics
+   - the simplified methods are less accurate than dedicated models but much faster
+
+4. **Data format**:
+   - make sure the input JSON files are well-formed
+   - each JSON file should contain a `cascades` field, and each cascade should contain `post_info`, `comment_tree`, and `repost_chain`
+
+## Examples
+
+### Example 1: process data with default settings
+```bash
+python compute_cascade_metrics.py \
+    --input_cascade /path/to/information_cascade.json \
+    --output /path/to/output.json
+```
+
+### Example 2: use GPU and custom models
+```bash
+python compute_cascade_metrics.py \
+    --input_cascade /path/to/information_cascade.json \
+    --output /path/to/output.json \
+    --bert_model bert-base-chinese \
+    --sentiment_model /path/to/sentiment_model \
+    --device cuda \
+    --batch_size 64
+```
+
+### Example 3: test mode (only process the first 10 cascades)
+```bash
+python compute_cascade_metrics.py \
+    --input_cascade /path/to/information_cascade.json \
+    --output /path/to/output.json \
+    --max_cascades 10
+```
+
+## Troubleshooting
+
+1. **CUDA out of memory**: reduce `--batch_size` or use `--device cpu`
+2. **Model download fails**: check the network connection, or download the model manually and pass a local path
+3. **JSON format errors**: verify that the input JSON files are well-formed
+4. **Slow processing**: use the GPU (`--device cuda`) and increase the batch size
+
+## Integration with EasyTPP
+
+The processed JSON file can be used for training with the EasyTPP framework. See `examples/train_robot_thp_with_features.py` for how to consume these features.
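
For a quick sanity check of an output file, a short Python sketch such as the following can summarize the computed metrics. It assumes only the `cascades`/`comment_tree`/`repost_chain` layout and the field names documented above, and is best run on a small output produced with `--max_cascades`:

```python
import json

# Load the enriched output produced by compute_cascade_metrics.py.
with open("output_with_metrics.json", encoding="utf-8") as f:
    data = json.load(f)

deviations = []
for cascade in data["cascades"]:
    # Comment nodes carry per-node deviation fields (see "Output Format").
    for comment in cascade.get("comment_tree", {}).values():
        deviations.append(comment["contextual_deviation"])
    # Repost nodes carry the same fields.
    for repost in cascade.get("repost_chain", []):
        deviations.append(repost["contextual_deviation"])

print(f"{len(deviations)} reaction nodes, "
      f"mean contextual deviation = {sum(deviations) / len(deviations):.3f}")
```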
DATA_FILES_NOTICE.md ADDED
@@ -0,0 +1,107 @@
+# ⚠️ Data File Notice
+
+## 📁 Data File Locations
+
+The data files have been copied to the `data/cascades/` directory:
+
+- `data/cascades/information_cascade.json` (606MB)
+- `data/cascades/information_cascade_original_posts.json` (980MB)
+
+## ⚠️ Important
+
+**These files are too large (about 1.6GB in total) and will not be uploaded to Hugging Face!**
+
+`.gitignore` is configured to exclude them, since they exceed the recommended size limits for Git/Hugging Face.
+
+## 📥 Getting the Data Files onto the Cloud Machine
+
+### Option 1: direct transfer (recommended)
+
+```bash
+# On the cloud machine, create the directory
+mkdir -p data/cascades
+
+# Transfer from the local machine with scp
+scp -r user@local-machine:/Users/chenshuyi/Documents/research_projects/评论家罗伯特TPP/data/cascades/information_cascade*.json ./data/cascades/
+```
+
+### Option 2: cloud storage
+
+1. Upload the files to cloud storage (Google Drive, Dropbox, OneDrive, etc.)
+2. Download them on the cloud machine
+
+### Option 3: Git LFS (if configured)
+
+If you need Git to manage the large files:
+
+```bash
+# Install Git LFS
+git lfs install
+
+# Track the large files
+git lfs track "data/cascades/*.json"
+
+# Add the files
+git add .gitattributes
+git add data/cascades/*.json
+git commit -m "Add cascade data with LFS"
+git push
+```
+
+### Option 4: Hugging Face Dataset Hub
+
+The data files can be uploaded separately to the Hugging Face Dataset Hub:
+
+```bash
+# Install the dependency
+pip install huggingface_hub
+
+# Upload the data files
+huggingface-cli upload <username>/cascade-data data/cascades/ --repo-type dataset
+```
+
+Then download them on the cloud machine:
+
+```bash
+huggingface-cli download <username>/cascade-data --local-dir ./data/cascades
+```
+
+## ✅ Verifying the Files
+
+After the transfer, verify:
+
+```bash
+# Check that the files exist
+ls -lh data/cascades/
+
+# You should see:
+# information_cascade.json
+# information_cascade_original_posts.json
+```
+
+## 🚀 Using the Data Files
+
+Once the files are in place, run the metric computation:
+
+```bash
+python compute_cascade_metrics.py \
+    --input_cascade data/cascades/information_cascade.json \
+    --input_original data/cascades/information_cascade_original_posts.json \
+    --output output_with_metrics.json \
+    --batch_size 32 \
+    --device cuda
+```
+
+## 📝 File Provenance
+
+Original location:
+- `/Users/chenshuyi/Documents/research_projects/评论家罗伯特TPP/data/cascades/`
+
+Copied to:
+- `/Users/chenshuyi/Downloads/EasyTemporalPointProcess-main/data/cascades/`
+
+## 🔗 Related Documents
+
+- [Data file README](data/cascades/README.md)
+- [Metric computation guide](COMPUTE_METRICS_README.md)
+- [Upload guide](HF_UPLOAD_GUIDE.md)
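
Given the file sizes above, reading either JSON with a single `json.load` call needs several gigabytes of RAM. A streaming pass is a safer way to inspect them; this sketch uses the third-party `ijson` package (an extra dependency, not in the repo's requirements files) and the `cascades` key documented in `COMPUTE_METRICS_README.md`:

```python
import ijson  # pip install ijson; streaming JSON parser (extra dependency)

count = 0
with open("data/cascades/information_cascade.json", "rb") as f:
    # Iterate over cascades one at a time instead of loading 606MB at once.
    for cascade in ijson.items(f, "cascades.item"):
        count += 1  # inspect or preprocess each cascade here

print(f"{count} cascades in the file")
```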
DATA_TRANSFER_SUMMARY.md ADDED
@@ -0,0 +1,95 @@
+# Data Transfer Summary
+
+## ✅ Done
+
+Both information-cascade files have been copied into the EasyTemporalPointProcess-main folder.
+
+## 📁 File Locations
+
+### Source
+- `/Users/chenshuyi/Documents/research_projects/评论家罗伯特TPP/data/cascades/information_cascade.json`
+- `/Users/chenshuyi/Documents/research_projects/评论家罗伯特TPP/data/cascades/information_cascade_original_posts.json`
+
+### Destination
+- `/Users/chenshuyi/Downloads/EasyTemporalPointProcess-main/data/cascades/information_cascade.json` (606MB)
+- `/Users/chenshuyi/Downloads/EasyTemporalPointProcess-main/data/cascades/information_cascade_original_posts.json` (980MB)
+
+## ⚠️ Important Notes
+
+### File sizes
+- **Total**: about 1.6GB
+- **information_cascade.json**: 606MB
+- **information_cascade_original_posts.json**: 980MB
+
+### Git exclusion
+These files **will not be uploaded to Hugging Face**, because:
+1. they are too large, exceeding the recommended Git/Hugging Face sizes
+2. they are already excluded via `.gitignore`:
+   ```
+   data/cascades/information_cascade*.json
+   data/cascades/*.json
+   ```
+
+## 📥 Getting the Data Files onto the Cloud Machine
+
+### Option 1: transfer with scp (recommended)
+
+```bash
+# On the cloud machine
+mkdir -p data/cascades
+
+# Transfer from the local machine
+scp user@local-machine:/Users/chenshuyi/Documents/research_projects/评论家罗伯特TPP/data/cascades/information_cascade*.json ./data/cascades/
+```
+
+### Option 2: upload to the Hugging Face Dataset Hub
+
+```bash
+# Locally
+cd /Users/chenshuyi/Downloads/EasyTemporalPointProcess-main
+huggingface-cli upload <username>/cascade-data data/cascades/ --repo-type dataset
+
+# Download on the cloud machine
+huggingface-cli download <username>/cascade-data --local-dir ./data/cascades
+```
+
+### Option 3: cloud storage
+
+1. Upload the files to Google Drive / Dropbox / OneDrive
+2. Download them on the cloud machine
+
+## 📝 Related Documents
+
+- **Data file README**: `data/cascades/README.md`
+- **Data file notice**: `DATA_FILES_NOTICE.md`
+- **Upload guide**: `HF_UPLOAD_GUIDE.md`
+
+## ✅ Verification
+
+After the transfer, verify the data files:
+
+```bash
+# Check that the files exist
+ls -lh data/cascades/
+
+# You should see:
+# information_cascade.json (606MB)
+# information_cascade_original_posts.json (980MB)
+```
+
+## 🚀 Using the Data Files
+
+Once the files are in place, run the metric computation:
+
+```bash
+python compute_cascade_metrics.py \
+    --input_cascade data/cascades/information_cascade.json \
+    --input_original data/cascades/information_cascade_original_posts.json \
+    --output output_with_metrics.json \
+    --batch_size 32 \
+    --device cuda
+```
+
+---
+
+**Data files transferred successfully!** ✅
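
Since a 1.6GB `scp` transfer can be interrupted silently, comparing checksums on both machines is a stronger check than `ls -lh` alone. A minimal standard-library sketch:

```python
import hashlib

def file_md5(path, chunk_size=1 << 20):
    """MD5 of a large file, read in 1MB chunks to keep memory flat."""
    h = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

# Run on both machines and compare the printed digests.
for name in ("information_cascade.json",
             "information_cascade_original_posts.json"):
    print(name, file_md5(f"data/cascades/{name}"))
```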
HF_UPLOAD_GUIDE.md ADDED
@@ -0,0 +1,180 @@
+# Hugging Face Upload Guide
+
+This guide explains how to upload EasyTemporalPointProcess-main to Hugging Face.
+
+## 📋 Preparation
+
+### 1. Run the cleanup script
+
+```bash
+cd /Users/chenshuyi/Downloads/EasyTemporalPointProcess-main
+python cleanup_for_hf.py
+```
+
+This automatically:
+- deletes temporary files such as `__pycache__/` and `.pyc`
+- checks for large files
+- creates the upload checklist
+
+### 2. About the data files ⚠️
+
+**Important**: the `data/cascades/` directory contains large files (about 1.6GB) that **will not be uploaded to Hugging Face**.
+
+They are already excluded via `.gitignore`:
+- `information_cascade.json` (606MB)
+- `information_cascade_original_posts.json` (980MB)
+
+**Ways to get the data files onto the cloud machine**:
+- Option 1: direct transfer with scp (recommended)
+- Option 2: upload to cloud storage and download from there
+- Option 3: Git LFS (if configured)
+- Option 4: upload separately to the Hugging Face Dataset Hub
+
+See `DATA_FILES_NOTICE.md` for details.
+
+### 3. Manual checks
+
+- [ ] check for sensitive information (API keys, passwords, etc.)
+- [ ] confirm large files are properly excluded (via .gitignore)
+- [ ] make sure `requirements.txt` is up to date
+- [ ] check that README.md is complete
+
+## 🚀 Upload Methods
+
+### Method 1: Hugging Face CLI (recommended)
+
+```bash
+# 1. Install the Hugging Face CLI
+pip install huggingface_hub
+
+# 2. Log in
+huggingface-cli login
+# Enter your Hugging Face token (from https://huggingface.co/settings/tokens)
+
+# 3. Create the repository (on the website, or via the CLI)
+# Visit https://huggingface.co/new to create a new repository
+# Choose the "Dataset" type and name it e.g. easytpp-cascade-metrics
+
+# 4. Upload the files
+cd /Users/chenshuyi/Downloads/EasyTemporalPointProcess-main
+huggingface-cli upload <your-username>/easytpp-cascade-metrics . --repo-type dataset
+```
+
+### Method 2: Git
+
+```bash
+# 1. Initialize Git (if not done yet)
+cd /Users/chenshuyi/Downloads/EasyTemporalPointProcess-main
+git init
+
+# 2. Add the files
+git add .
+git commit -m "Add EasyTPP with cascade metrics computation"
+
+# 3. Add the Hugging Face remote
+# First create the repository at https://huggingface.co/new
+git remote add origin https://huggingface.co/<your-username>/<repo-name>
+
+# 4. Push
+git push origin main
+```
+
+### Method 3: Web upload
+
+1. Visit https://huggingface.co/new
+2. Create a new Dataset repository
+3. Click "Add file" → "Upload files"
+4. Drag and drop, or select, the folder to upload
+
+## 📦 Downloading on the Cloud Machine
+
+After the upload, download on the cloud machine:
+
+```bash
+# Method 1: Hugging Face CLI
+pip install huggingface_hub
+huggingface-cli download <your-username>/<repo-name> --local-dir ./EasyTPP
+
+# Method 2: Git
+git clone https://huggingface.co/datasets/<your-username>/<repo-name>
+cd <repo-name>
+```
+
+```python
+# Method 3: Python
+from huggingface_hub import snapshot_download
+snapshot_download(repo_id="<your-username>/<repo-name>", repo_type="dataset", local_dir="./EasyTPP")
+```
+
+### 📥 Downloading the data files
+
+**Important**: the code repository does not contain the data files (excluded via .gitignore).
+
+The data files must be fetched separately:
+
+```bash
+# Method 1: transfer from the local machine with scp (recommended)
+mkdir -p data/cascades
+scp user@local-machine:/path/to/information_cascade*.json ./data/cascades/
+
+# Method 2: if already uploaded to the Hugging Face Dataset Hub
+huggingface-cli download <username>/cascade-data --local-dir ./data/cascades
+
+# Method 3: download from cloud storage
+# (depends on the storage service you use)
+```
+
+See `DATA_FILES_NOTICE.md` for details.
+
+## 📝 About the Additions
+
+This repository extends the original EasyTPP with:
+
+### 1. Cascade metric computation (`compute_cascade_metrics.py`)
+
+Computes metrics of information-cascade data:
+- **Sentiment Score**
+- **Sentiment Deviation**
+- **Contextual Deviation**
+- **Perplexity**
+
+See `COMPUTE_METRICS_README.md` for details.
+
+### 2. Related files
+
+- `compute_cascade_metrics.py`: main computation script
+- `COMPUTE_METRICS_README.md`: usage instructions
+- `requirements_compute_metrics.txt`: additional dependencies
+- `example_compute_metrics.sh`: example script
+- `cleanup_for_hf.py`: cleanup script
+
+## ⚠️ Notes
+
+1. **Large files**
+   - for files >50MB, consider Git LFS
+   - or exclude the data files and link to them externally
+
+2. **Sensitive information**
+   - do not upload files containing API keys or passwords
+   - check configuration files for sensitive data
+
+3. **License**
+   - make sure all code carries an appropriate license
+   - the original EasyTPP uses the Apache 2.0 license
+
+4. **Version control**
+   - Git is recommended for version control
+   - commit and push after every update
+
+## 🔍 Verifying the Upload
+
+After uploading, check:
+- [ ] all files were uploaded
+- [ ] the README renders correctly
+- [ ] the code can be downloaded
+- [ ] the dependencies install cleanly
+
+## 📞 Troubleshooting
+
+If something goes wrong, check:
+1. whether the Hugging Face repository is configured correctly
+2. whether a file exceeds the size limits
+3. whether there is a permissions problem
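
The CLI upload in Method 1 can also be scripted from Python. A minimal sketch with the `huggingface_hub` API (`create_repo` and `upload_folder` are standard calls; the repo id is a placeholder, and note that `upload_folder` does not read `.gitignore`, so the large data files are excluded explicitly):

```python
from huggingface_hub import HfApi

api = HfApi()  # uses the token stored by `huggingface-cli login`

repo_id = "<your-username>/easytpp-cascade-metrics"  # placeholder
api.create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True)

# Upload the whole folder; the .gitignore'd data files are skipped
# explicitly here, since upload_folder does not consult .gitignore.
api.upload_folder(
    folder_path=".",
    repo_id=repo_id,
    repo_type="dataset",
    ignore_patterns=["data/cascades/*.json", "__pycache__/*"],
)
```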
LICENCE ADDED
@@ -0,0 +1,203 @@
+Copyright 2022 The EasyTPP Authors. All rights reserved.
+
+Apache License
+Version 2.0, January 2004
+http://www.apache.org/licenses/
+
+TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+1. Definitions.
+
+"License" shall mean the terms and conditions for use, reproduction,
+and distribution as defined by Sections 1 through 9 of this document.
+
+"Licensor" shall mean the copyright owner or entity authorized by
+the copyright owner that is granting the License.
+
+"Legal Entity" shall mean the union of the acting entity and all
+other entities that control, are controlled by, or are under common
+control with that entity. For the purposes of this definition,
+"control" means (i) the power, direct or indirect, to cause the
+direction or management of such entity, whether by contract or
+otherwise, or (ii) ownership of fifty percent (50%) or more of the
+outstanding shares, or (iii) beneficial ownership of such entity.
+
+"You" (or "Your") shall mean an individual or Legal Entity
+exercising permissions granted by this License.
+
+"Source" form shall mean the preferred form for making modifications,
+including but not limited to software source code, documentation
+source, and configuration files.
+
+"Object" form shall mean any form resulting from mechanical
+transformation or translation of a Source form, including but
+not limited to compiled object code, generated documentation,
+and conversions to other media types.
+
+"Work" shall mean the work of authorship, whether in Source or
+Object form, made available under the License, as indicated by a
+copyright notice that is included in or attached to the work
+(an example is provided in the Appendix below).
+
+"Derivative Works" shall mean any work, whether in Source or Object
+form, that is based on (or derived from) the Work and for which the
+editorial revisions, annotations, elaborations, or other modifications
+represent, as a whole, an original work of authorship. For the purposes
+of this License, Derivative Works shall not include works that remain
+separable from, or merely link (or bind by name) to the interfaces of,
+the Work and Derivative Works thereof.
+
+"Contribution" shall mean any work of authorship, including
+the original version of the Work and any modifications or additions
+to that Work or Derivative Works thereof, that is intentionally
+submitted to Licensor for inclusion in the Work by the copyright owner
+or by an individual or Legal Entity authorized to submit on behalf of
+the copyright owner. For the purposes of this definition, "submitted"
+means any form of electronic, verbal, or written communication sent
+to the Licensor or its representatives, including but not limited to
+communication on electronic mailing lists, source code control systems,
+and issue tracking systems that are managed by, or on behalf of, the
+Licensor for the purpose of discussing and improving the Work, but
+excluding communication that is conspicuously marked or otherwise
+designated in writing by the copyright owner as "Not a Contribution."
+
+"Contributor" shall mean Licensor and any individual or Legal Entity
+on behalf of whom a Contribution has been received by Licensor and
+subsequently incorporated within the Work.
+
+2. Grant of Copyright License. Subject to the terms and conditions of
+this License, each Contributor hereby grants to You a perpetual,
+worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+copyright license to reproduce, prepare Derivative Works of,
+publicly display, publicly perform, sublicense, and distribute the
+Work and such Derivative Works in Source or Object form.
+
+3. Grant of Patent License. Subject to the terms and conditions of
+this License, each Contributor hereby grants to You a perpetual,
+worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+(except as stated in this section) patent license to make, have made,
+use, offer to sell, sell, import, and otherwise transfer the Work,
+where such license applies only to those patent claims licensable
+by such Contributor that are necessarily infringed by their
+Contribution(s) alone or by combination of their Contribution(s)
+with the Work to which such Contribution(s) was submitted. If You
+institute patent litigation against any entity (including a
+cross-claim or counterclaim in a lawsuit) alleging that the Work
+or a Contribution incorporated within the Work constitutes direct
+or contributory patent infringement, then any patent licenses
+granted to You under this License for that Work shall terminate
+as of the date such litigation is filed.
+
+4. Redistribution. You may reproduce and distribute copies of the
+Work or Derivative Works thereof in any medium, with or without
+modifications, and in Source or Object form, provided that You
+meet the following conditions:
+
+(a) You must give any other recipients of the Work or
+Derivative Works a copy of this License; and
+
+(b) You must cause any modified files to carry prominent notices
+stating that You changed the files; and
+
+(c) You must retain, in the Source form of any Derivative Works
+that You distribute, all copyright, patent, trademark, and
+attribution notices from the Source form of the Work,
+excluding those notices that do not pertain to any part of
+the Derivative Works; and
+
+(d) If the Work includes a "NOTICE" text file as part of its
+distribution, then any Derivative Works that You distribute must
+include a readable copy of the attribution notices contained
+within such NOTICE file, excluding those notices that do not
+pertain to any part of the Derivative Works, in at least one
+of the following places: within a NOTICE text file distributed
+as part of the Derivative Works; within the Source form or
+documentation, if provided along with the Derivative Works; or,
+within a display generated by the Derivative Works, if and
+wherever such third-party notices normally appear. The contents
+of the NOTICE file are for informational purposes only and
+do not modify the License. You may add Your own attribution
+notices within Derivative Works that You distribute, alongside
+or as an addendum to the NOTICE text from the Work, provided
+that such additional attribution notices cannot be construed
+as modifying the License.
+
+You may add Your own copyright statement to Your modifications and
+may provide additional or different license terms and conditions
+for use, reproduction, or distribution of Your modifications, or
+for any such Derivative Works as a whole, provided Your use,
+reproduction, and distribution of the Work otherwise complies with
+the conditions stated in this License.
+
+5. Submission of Contributions. Unless You explicitly state otherwise,
+any Contribution intentionally submitted for inclusion in the Work
+by You to the Licensor shall be under the terms and conditions of
+this License, without any additional terms or conditions.
+Notwithstanding the above, nothing herein shall supersede or modify
+the terms of any separate license agreement you may have executed
+with Licensor regarding such Contributions.
+
+6. Trademarks. This License does not grant permission to use the trade
+names, trademarks, service marks, or product names of the Licensor,
+except as required for reasonable and customary use in describing the
+origin of the Work and reproducing the content of the NOTICE file.
+
+7. Disclaimer of Warranty. Unless required by applicable law or
+agreed to in writing, Licensor provides the Work (and each
+Contributor provides its Contributions) on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+implied, including, without limitation, any warranties or conditions
+of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+PARTICULAR PURPOSE. You are solely responsible for determining the
+appropriateness of using or redistributing the Work and assume any
+risks associated with Your exercise of permissions under this License.
+
+8. Limitation of Liability. In no event and under no legal theory,
+whether in tort (including negligence), contract, or otherwise,
+unless required by applicable law (such as deliberate and grossly
+negligent acts) or agreed to in writing, shall any Contributor be
+liable to You for damages, including any direct, indirect, special,
+incidental, or consequential damages of any character arising as a
+result of this License or out of the use or inability to use the
+Work (including but not limited to damages for loss of goodwill,
+work stoppage, computer failure or malfunction, or any and all
+other commercial damages or losses), even if such Contributor
+has been advised of the possibility of such damages.
+
+9. Accepting Warranty or Additional Liability. While redistributing
+the Work or Derivative Works thereof, You may choose to offer,
+and charge a fee for, acceptance of support, warranty, indemnity,
+or other liability obligations and/or rights consistent with this
+License. However, in accepting such obligations, You may act only
+on Your own behalf and on Your sole responsibility, not on behalf
+of any other Contributor, and only if You agree to indemnify,
+defend, and hold each Contributor harmless for any liability
+incurred by, or claims asserted against, such Contributor by reason
+of your accepting any such warranty or additional liability.
+
+END OF TERMS AND CONDITIONS
+
+APPENDIX: How to apply the Apache License to your work.
+
+To apply the Apache License to your work, attach the following
+boilerplate notice, with the fields enclosed by brackets "[]"
+replaced with your own identifying information. (Don't include
+the brackets!) The text should be enclosed in the appropriate
+comment syntax for the file format. We also recommend that a
+file or class name and description of purpose be included on the
+same "printed page" as the copyright notice for easier
+identification within third-party archives.
+
+Copyright [yyyy] [name of copyright owner]
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
MANIFEST.in ADDED
@@ -0,0 +1,2 @@
+include requirements.txt
+include version.py
NOTICE ADDED
@@ -0,0 +1,23 @@
+=============================================================
+EasyTPP is an open source tool developed by the Machine Intelligence Team
+Copyright (c) 2020-2022, Ant Group Holding Limited.
+Licensed under the Apache License, Version 2.0
+
+=============================================================
+This toolkit contains various third-party components under
+different open source licenses
+
+-----------------------------
+Training evaluation pipeline
+Apache License, Version 2.0
+FuxiCTR authors
+
+----------------------------
+Training evaluation pipeline
+Apache License, Version 2.0
+EasyNLP, Alibaba Inc.
+
+----------------------------
+Tokenizer and DataLoader
+Apache License, Version 2.0
+The HuggingFace Inc. team
QUICK_START_HF.md ADDED
@@ -0,0 +1,104 @@
+# Quick Start: Uploading to Hugging Face
+
+## ✅ Cleanup Complete
+
+The folder has been cleaned up and is ready to upload to Hugging Face.
+
+**Folder size**: 1.3MB (fine to upload)
+
+## 📋 What Was Done
+
+### Completed cleanup
+- ✅ updated the `.gitignore` file
+- ✅ created the cleanup script `cleanup_for_hf.py`
+- ✅ checked for large files (none found)
+- ✅ created the upload guides
+
+### New files
+- `compute_cascade_metrics.py` - cascade metric computation script
+- `COMPUTE_METRICS_README.md` - metric computation guide
+- `HF_UPLOAD_GUIDE.md` - upload guide
+- `ADDITIONS_README.md` - additions overview
+- `cleanup_for_hf.py` - cleanup script
+- `requirements_compute_metrics.txt` - additional dependencies
+
+## 🚀 Upload in Three Steps
+
+### Step 1: install the Hugging Face CLI
+
+```bash
+pip install huggingface_hub
+```
+
+### Step 2: log in
+
+```bash
+huggingface-cli login
+# Enter your token (from https://huggingface.co/settings/tokens)
+```
+
+### Step 3: create the repository and upload
+
+```bash
+# 1. Create the repository on the website
+# Visit https://huggingface.co/new
+# Choose "Dataset" and name it e.g. easytpp-cascade-metrics
+
+# 2. Upload the files
+cd /Users/chenshuyi/Downloads/EasyTemporalPointProcess-main
+huggingface-cli upload <your-username>/easytpp-cascade-metrics . --repo-type dataset
+```
+
+## 📥 Downloading on the Cloud Machine
+
+```bash
+# Method 1: CLI
+huggingface-cli download <your-username>/easytpp-cascade-metrics --local-dir ./EasyTPP
+
+# Method 2: Git
+git clone https://huggingface.co/datasets/<your-username>/easytpp-cascade-metrics
+cd easytpp-cascade-metrics
+```
+
+### ⚠️ Important: the data files must be fetched separately
+
+The code repository does **not** contain the data files (excluded via .gitignore because they are too large).
+
+They need to be transferred separately:
+
+```bash
+# Create the directory on the cloud machine
+mkdir -p data/cascades
+
+# Method 1: transfer from the local machine with scp (recommended)
+scp user@local-machine:/path/to/information_cascade*.json ./data/cascades/
+
+# Method 2: if already uploaded to the Hugging Face Dataset Hub
+huggingface-cli download <username>/cascade-data --local-dir ./data/cascades
+```
+
+See `DATA_FILES_NOTICE.md` for details.
+
+## 📚 Related Documents
+
+- **Detailed upload guide**: `HF_UPLOAD_GUIDE.md`
+- **Data file notice**: `DATA_FILES_NOTICE.md` ⚠️ **important**
+- **Metric computation guide**: `COMPUTE_METRICS_README.md`
+- **Additions overview**: `ADDITIONS_README.md`
+- **Upload checklist**: `UPLOAD_CHECKLIST.md`
+
+## ⚠️ Notes
+
+1. **File size**: the folder is currently 1.3MB; no Git LFS needed
+2. **Sensitive information**: checked, none found
+3. **Dependencies**: make sure `requirements.txt` and `requirements_compute_metrics.txt` are included
+
+## 🎯 Next Steps
+
+1. Upload to Hugging Face following the steps above
+2. Download and test on the cloud machine
+3. Run `compute_cascade_metrics.py` to compute the metrics
+
+---
+
+**All set — ready to upload!** 🎉
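
For step 2 of the next steps above ("download and test on the cloud machine"), the download and a small trial run can be chained in one script; `snapshot_download` is the standard `huggingface_hub` call, and `--max_cascades 10` is the test flag documented in `COMPUTE_METRICS_README.md`:

```python
import subprocess
from huggingface_hub import snapshot_download

# Pull the code repo from the Hub (placeholder repo id).
snapshot_download(repo_id="<your-username>/easytpp-cascade-metrics",
                  repo_type="dataset", local_dir="./EasyTPP")

# Trial run on the first 10 cascades only, so problems surface quickly.
# Assumes the data files were transferred into EasyTPP/data/cascades first.
subprocess.run([
    "python", "compute_cascade_metrics.py",
    "--input_cascade", "data/cascades/information_cascade.json",
    "--output", "smoke_test_output.json",
    "--max_cascades", "10",
], cwd="./EasyTPP", check=True)
```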
README.md ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # EasyTPP [ICLR 2024]
2
+
3
+ <div align="center">
4
+ <a href="PyVersion">
5
+ <img alt="Python Version" src="https://img.shields.io/badge/python-3.9+-blue.svg">
6
+ </a>
7
+ <a href="LICENSE-CODE">
8
+ <img alt="Code License" src="https://img.shields.io/badge/license-Apache-000000.svg?&color=f5de53">
9
+ </a>
10
+ <a href="commit">
11
+ <img alt="Last Commit" src="https://img.shields.io/github/last-commit/ant-research/EasyTemporalPointProcess">
12
+ </a>
13
+ </div>
14
+
15
+ <div align="center">
16
+ <a href="https://pypi.python.org/pypi/easy-tpp/">
17
+ <img alt="PyPI version" src="https://img.shields.io/pypi/v/easy-tpp.svg?style=flat-square&color=b7534" />
18
+ </a>
19
+ <a href="https://static.pepy.tech/personalized-badge/easy-tpp">
20
+ <img alt="Downloads" src="https://static.pepy.tech/personalized-badge/easy-tpp?period=total&units=international_system&left_color=black&right_color=blue&left_text=Downloads" />
21
+ </a>
22
+ <a href="https://huggingface.co/easytpp" target="_blank">
23
+ <img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-EasyTPP-ffc107?color=ffc107&logoColor=white" />
24
+ </a>
25
+ <a href="https://github.com/ant-research/EasyTemporalPointProcess/issues">
26
+ <img alt="Open Issues" src="https://img.shields.io/github/issues-raw/ant-research/EasyTemporalPointProcess" />
27
+ </a>
28
+ </div>
29
+
30
+ `EasyTPP` is an easy-to-use development and application toolkit for [Temporal Point Process](https://mathworld.wolfram.com/TemporalPointProcess.html) (TPP), with key features in configurability, compatibility and reproducibility. We hope this project could benefit both researchers and practitioners with the goal of easily customized development and open benchmarking in TPP.
31
+ <span id='top'/>
32
+
33
+
34
+
35
+ | <a href='#features'>Features</a> | <a href='#model-list'>Model List</a> | <a href='#dataset'>Dataset</a> | <a href='#quick-start'>Quick Start</a> | <a href='#benchmark'>Benchmark</a> |<a href='#doc'>Documentation</a> |<a href='#todo'>Todo List</a> | <a href='#citation'>Citation</a> |<a href='#acknowledgment'>Acknowledgement</a> | <a href='#star-history'>Star History</a> |
36
+
37
+ ## News
38
+ <span id='news'/>
39
+
40
+ - ![new](https://img.alicdn.com/imgextra/i4/O1CN01kUiDtl1HVxN6G56vN_!!6000000000764-2-tps-43-19.png) [11-06-2025] We have released a new version of ``EasyTPP`` that exclusively supports PyTorch. TensorFlow support has been removed to streamline the codebase and focus on PyTorch-based implementations.
41
+ - ![new](https://img.alicdn.com/imgextra/i4/O1CN01kUiDtl1HVxN6G56vN_!!6000000000764-2-tps-43-19.png) [11-05-2025] Added the implementation of the [S2P2](https://openreview.net/pdf?id=74SvE2GZwW) model, presented at NeurIPS'2025.
42
+ - ![new](https://img.alicdn.com/imgextra/i4/O1CN01kUiDtl1HVxN6G56vN_!!6000000000764-2-tps-43-19.png) [02-17-2024] ``EasyTPP`` supports HuggingFace dataset API: all datasets have been published in [HuggingFace Repo](https://huggingface.co/easytpp) and see [tutorial notebook](https://github.com/ant-research/EasyTemporalPointProcess/blob/main/notebooks/easytpp_1_dataset.ipynb) for an example of usage.
43
+ - [01-16-2024] Our paper [EasyTPP: Towards Open Benchmarking Temporal Point Process](https://arxiv.org/abs/2307.08097) is accepted by ICLR'2024!
44
+ <details>
45
+ <summary>Click to see previous news</summary>
46
+ <p>
47
+ - [09-30-2023] We published two textual event sequence datasets [GDELT](https://drive.google.com/drive/folders/1Ms-ATMMFf6v4eesfJndyuPLGtX58fCnk) and [Amazon-text-review](https://drive.google.com/drive/folders/1-SLYyrl7ucEG7NpSIF0eSoG9zcbZagZw) that are used in our paper [LAMP](https://arxiv.org/abs/2305.16646), where LLM can be applied for event prediction! See [Documentation](https://ant-research.github.io/EasyTemporalPointProcess/user_guide/dataset.html#preprocessed-datasets) for more details.
48
+ - [09-30-2023] Two of our papers [Language Model Can Improve Event Prediction by Few-Shot Abductive Reasoning](https://arxiv.org/abs/2305.16646) (LAMP) and [Prompt-augmented Temporal Point Process for Streaming Event Sequence](https://arxiv.org/abs/2310.04993) (PromptTPP) are accepted by NeurIPS'2023!
49
+ - [09-02-2023] We published two non-anthropogenic datasets [earthquake](https://drive.google.com/drive/folders/1ubeIz_CCNjHyuu6-XXD0T-gdOLm12rf4) and [volcano eruption](https://drive.google.com/drive/folders/1KSWbNi8LUwC-dxz1T5sOnd9zwAot95Tp?usp=drive_link)! See <a href='#dataset'>Dataset</a> for details.
50
+ - [05-29-2023] We released ``EasyTPP`` v0.0.1!
51
+ - [12-27-2022] Our paper [Bellman Meets Hawkes: Model-Based Reinforcement Learning via Temporal Point Processes](https://arxiv.org/abs/2201.12569) was accepted by AAAI'2023!
52
+ - [10-01-2022] Our paper [HYPRO: A Hybridly Normalized Probabilistic Model for Long-Horizon Prediction of Event Sequences](https://arxiv.org/abs/2210.01753) was accepted by NeurIPS'2022!
53
+ - [05-01-2022] We started to develop `EasyTPP`.</p>
54
+ </details>
55
+
56
+
57
+ ## Features <a href='#top'>[Back to Top]</a>
58
+ <span id='features'/>
59
+
60
+ - **Configurable and customizable**: models are modularized and configurable,with abstract classes to support developing customized
61
+ TPP models.
62
+ - **PyTorch-based implementation**: `EasyTPP` implements state-of-the-art TPP models using PyTorch 1.7.0+, providing a clean and modern deep learning framework.
63
+ - **Reproducible**: all the benchmarks can be easily reproduced.
64
+ - **Hyper-parameter optimization**: a pipeline of [optuna](https://github.com/optuna/optuna)-based HPO is provided.
65
+
66
+
67
+ ## Model List <a href='#top'>[Back to Top]</a>
68
+ <span id='model-list'/>
69
+
70
+ We provide reference implementations of various state-of-the-art TPP papers:
71
+
72
+ | No | Publication | Model | Paper | Implementation |
73
+ |:---:|:-----------:|:-------------:|:-----------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------|
74
+ | 1 | KDD'16 | RMTPP | [Recurrent Marked Temporal Point Processes: Embedding Event History to Vector](https://www.kdd.org/kdd2016/papers/files/rpp1081-duA.pdf) | [PyTorch](easy_tpp/model/torch_model/torch_rmtpp.py) |
75
+ | 2 | NeurIPS'17 | NHP | [The Neural Hawkes Process: A Neurally Self-Modulating Multivariate Point Process](https://arxiv.org/abs/1612.09328) | [PyTorch](easy_tpp/model/torch_model/torch_nhp.py) |
76
+ | 3 | NeurIPS'19 | FullyNN | [Fully Neural Network based Model for General Temporal Point Processes](https://arxiv.org/abs/1905.09690) | [PyTorch](easy_tpp/model/torch_model/torch_fullynn.py) |
77
+ | 4 | ICML'20 | SAHP | [Self-Attentive Hawkes process](https://arxiv.org/abs/1907.07561) | [PyTorch](easy_tpp/model/torch_model/torch_sahp.py) |
78
+ | 5 | ICML'20 | THP | [Transformer Hawkes process](https://arxiv.org/abs/2002.09291) | [PyTorch](easy_tpp/model/torch_model/torch_thp.py) |
79
+ | 6 | ICLR'20 | IntensityFree | [Intensity-Free Learning of Temporal Point Processes](https://arxiv.org/abs/1909.12127) | [PyTorch](easy_tpp/model/torch_model/torch_intensity_free.py) |
80
+ | 7 | ICLR'21 | ODETPP | [Neural Spatio-Temporal Point Processes (simplified)](https://arxiv.org/abs/2011.04583) | [PyTorch](easy_tpp/model/torch_model/torch_ode_tpp.py) |
81
+ | 8 | ICLR'22 | AttNHP | [Transformer Embeddings of Irregularly Spaced Events and Their Participants](https://arxiv.org/abs/2201.00044) | [PyTorch](easy_tpp/model/torch_model/torch_attnhp.py) |
82
+ | 9 | NeurIPS'25 | S2P2 | [Deep Continuous-Time State-Space Models for Marked Event Sequences](https://openreview.net/pdf?id=74SvE2GZwW) | [PyTorch](easy_tpp/model/torch_model/torch_s2p2.py) |
83
+
84
+
85
+
86
+ ## Dataset <a href='#top'>[Back to Top]</a>
87
+ <span id='dataset'/>
88
+
89
+ We preprocessed one synthetic and five real world datasets from widely-cited works that contain diverse characteristics in terms of their application domains and temporal statistics:
90
+ - Synthetic: a univariate Hawkes process simulated by [Tick](https://github.com/X-DataInitiative/tick) library.
91
+ - Retweet ([Zhou, 2013](http://proceedings.mlr.press/v28/zhou13.pdf)): timestamped user retweet events.
92
+ - Taxi ([Whong, 2014](https://chriswhong.com/open-data/foil_nyc_taxi/)): timestamped taxi pick-up events.
93
+ - StackOverflow ([Leskovec, 2014](https://snap.stanford.edu/data/)): timestamped user badge reward events in StackOverflow.
94
+ - Taobao ([Xue et al, 2022](https://arxiv.org/abs/2210.01753)): timestamped user online shopping behavior events in Taobao platform.
95
+ - Amazon ([Xue et al, 2022](https://arxiv.org/abs/2210.01753)): timestamped user online shopping behavior events in Amazon platform.
96
+
97
+ Per users' request, we processed two non-anthropogenic datasets
98
+ - [Earthquake](https://drive.google.com/drive/folders/1ubeIz_CCNjHyuu6-XXD0T-gdOLm12rf4): timestamped earthquake events over the Conterminous U.S from 1996 to 2023, processed from [USGS](https://www.usgs.gov/programs/earthquake-hazards/science/earthquake-data).
99
+ - [Volcano eruption](https://drive.google.com/drive/folders/1KSWbNi8LUwC-dxz1T5sOnd9zwAot95Tp?usp=drive_link): timestamped volcano eruption events over the world in recent hundreds of years, processed from [The Smithsonian Institution](https://volcano.si.edu/).
100
+
101
+
102
+ All datasets are preprocess to the `Gatech` format dataset widely used for TPP researchers, and saved at [Google Drive](https://drive.google.com/drive/u/0/folders/1f8k82-NL6KFKuNMsUwozmbzDSFycYvz7) with a public access.
103
+
104
+ ## Quick Start <a href='#top'>[Back to Top]</a>
105
+ <span id='quick-start'/>
106
+
107
+
108
+ ### Colab Tutorials
109
+
110
+ Explore the following tutorials that can be opened directly in Google Colab:
111
+
112
+ - [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ant-research/EasyTemporalPointProcess/blob/main/notebooks/easytpp_1_dataset.ipynb) Tutorial 1: Dataset in EasyTPP.
113
+ - [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ant-research/EasyTemporalPointProcess/blob/main/notebooks/easytpp_2_tfb_wb.ipynb) Tutorial 2: Tensorboard in EasyTPP.
114
+ - [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ant-research/EasyTemporalPointProcess/blob/main/notebooks/easytpp_3_train_eval.ipynb) Tutorial 3: Training and Evaluation of TPPs.
115
+
116
+ ### End-to-end Example
117
+
118
+ We provide an end-to-end example for users to run a standard TPP model with `EasyTPP`.
119
+
120
+
121
+ ### Step 1. Installation
122
+
123
+ First of all, we can install the package either by using pip or from the source code on Github.
124
+
125
+ To install the latest stable version:
126
+ ```bash
127
+ pip install easy-tpp
128
+ ```
129
+
130
+ To install the latest on GitHub:
131
+ ```bash
132
+ git clone https://github.com/ant-research/EasyTemporalPointProcess.git
133
+ cd EasyTemporalPointProcess
134
+ python setup.py install
135
+ ```
136
+
137
+
138
+ ### Step 2. Prepare datasets
139
+
140
+ We need to put the datasets in a local directory before running a model and the datasets should follow a certain format. See [OnlineDoc - Datasets](https://ant-research.github.io/EasyTemporalPointProcess/user_guide/dataset.html) for more details.
141
+
142
+ Suppose we use the [taxi dataset](https://chriswhong.com/open-data/foil_nyc_taxi/) in the example.
143
+
144
+ ### Step 3. Train the model
145
+
146
+
147
+ Before start training, we need to set up the config file for the pipeline. We provide a preset config file in [Example Config](https://github.com/ant-research/EasyTemporalPointProcess/blob/main/examples/configs/experiment_config.yaml). The details of the configuration can be found in [OnlineDoc - Training Pipeline](https://ant-research.github.io/EasyTemporalPointProcess/user_guide/run_train_pipeline.html).
148
+
149
+ After the setup of data and config, the directory structure is as follows:
150
+
151
+ ```bash
152
+
153
+ data
154
+ |______taxi
155
+ |____ train.pkl
156
+ |____ dev.pkl
157
+ |____ test.pkl
158
+
159
+ configs
160
+ |______experiment_config.yaml
161
+
162
+ ```
163
+
164
+
165
+ Then we start training by simply running the following script:
166
+
167
+ ```python
168
+
169
+ import argparse
170
+ from easy_tpp.config_factory import Config
171
+ from easy_tpp.runner import Runner
172
+
173
+
174
+ def main():
175
+     parser = argparse.ArgumentParser()
176
+ 
177
+     parser.add_argument('--config_dir', type=str, required=False, default='configs/experiment_config.yaml',
178
+                         help='Dir of configuration yaml to train and evaluate the model.')
179
+ 
180
+     parser.add_argument('--experiment_id', type=str, required=False, default='NHP_train',
181
+                         help='Experiment id in the config file.')
182
+ 
183
+     args = parser.parse_args()
184
+ 
185
+     config = Config.build_from_yaml_file(args.config_dir, experiment_id=args.experiment_id)
186
+ 
187
+     model_runner = Runner.build_from_config(config)
188
+ 
189
+     model_runner.run()
190
+
191
+
192
+ if __name__ == '__main__':
193
+ main()
194
+
195
+ ```
196
+
197
+ A more detailed example can be found at [OnlineDoc - QuickStart](https://ant-research.github.io/EasyTemporalPointProcess/get_started/quick_start.html).
198
+
199
+
200
+ ## Documentation <a href='#top'>[Back to Top]</a>
201
+ <span id='doc'/>
202
+
203
+ The classes and methods of `EasyTPP` have been well documented so that users can generate the documentation by:
204
+
205
+ ```shell
206
+ cd docs
207
+ pip install -r requirements.txt
208
+ make html
209
+ ```
210
+ NOTE:
211
+ * The `docs/requirements.txt` is only used for building the documentation with Sphinx; the docs can also be built automatically by the GitHub Actions workflow `.github/workflows/docs.yaml` (triggered by pull requests).
212
+
213
+ The full documentation is available on the [website](https://ant-research.github.io/EasyTemporalPointProcess/).
214
+
215
+ ## Benchmark <a href='#top'>[Back to Top]</a>
216
+ <span id='benchmark'/>
217
+
218
+ In the [examples](https://github.com/ant-research/EasyTemporalPointProcess/tree/main/examples) folder, we provide a [script](https://github.com/ant-research/EasyTemporalPointProcess/blob/main/examples/benchmark_script.py) to benchmark the TPPs, with the Taxi dataset as input.
219
+
220
+ To run the script, first download the Taxi data following the instructions above. The [config](https://github.com/ant-research/EasyTemporalPointProcess/blob/main/examples/configs/experiment_config.yaml) file is already set up. Then run:
221
+
222
+
223
+ ```shell
224
+ cd examples
225
+ python benchmark_script.py
226
+ ```
227
+
228
+
229
+ ## License <a href='#top'>[Back to Top]</a>
230
+
231
+ This project is licensed under the [Apache License (Version 2.0)](https://github.com/alibaba/EasyNLP/blob/master/LICENSE). This toolkit also contains some code modified from other repos under other open-source licenses. See the [NOTICE](https://github.com/ant-research/EasyTPP/blob/master/NOTICE) file for more information.
232
+
233
+
234
+ ## Todo List <a href='#top'>[Back to Top]</a>
235
+ <span id='todo'/>
236
+
237
+ - [x] New dataset:
238
+ - [x] Earthquake: the source data is available in [USGS](https://www.usgs.gov/programs/earthquake-hazards/science/earthquake-data).
239
+ - [x] Volcano eruption: the source data is available in [NCEI](https://www.ngdc.noaa.gov/hazard/volcano.shtml).
240
+ - [ ] New model:
241
+ - [ ] Meta Temporal Point Process, ICLR 2023.
242
+ - [ ] Model-based RL via TPP, AAAI 2022.
243
+
244
+ ## Citation <a href='#top'>[Back to Top]</a>
245
+
246
+ <span id='citation'/>
247
+
248
+ If you find `EasyTPP` useful for your research or development, please cite the following <a href="https://arxiv.org/abs/2307.08097" target="_blank">paper</a>:
249
+ ```
250
+ @inproceedings{xue2024easytpp,
251
+ title={EasyTPP: Towards Open Benchmarking Temporal Point Processes},
252
+ author={Siqiao Xue and Xiaoming Shi and Zhixuan Chu and Yan Wang and Hongyan Hao and Fan Zhou and Caigao Jiang and Chen Pan and James Y. Zhang and Qingsong Wen and Jun Zhou and Hongyuan Mei},
253
+ booktitle = {International Conference on Learning Representations (ICLR)},
254
+ year = {2024},
255
+ url ={https://arxiv.org/abs/2307.08097}
256
+ }
257
+ ```
258
+
259
+ ## Acknowledgment <a href='#top'>[Back to Top]</a>
260
+ <span id='acknowledgment'/>
261
+
262
+ The project is jointly initiated by Machine Intelligence Group, Alipay and DAMO Academy, Alibaba.
263
+
264
+ The following repositories are used in `EasyTPP`, either in a form close to the original or as an inspiration:
265
+
266
+ - [EasyRec](https://github.com/alibaba/EasyRec)
267
+ - [EasyNLP](https://github.com/alibaba/EasyNLP)
268
+ - [FuxiCTR](https://github.com/xue-pai/FuxiCTR)
269
+ - [Neural Hawkes Process](https://github.com/hongyuanmei/neurawkes)
270
+ - [Neural Hawkes Particle Smoothing](https://github.com/hongyuanmei/neural-hawkes-particle-smoothing)
271
+ - [Attentive Neural Hawkes Process](https://github.com/yangalan123/anhp-andtt)
272
+ - [Huggingface - transformers](https://github.com/huggingface/transformers)
273
+
274
+
275
+ ## Star History <a href='#top'>[Back to Top]</a>
276
+ <span id='star-history'/>
277
+
278
+ ![Star History Chart](https://api.star-history.com/svg?repos=ant-research/EasyTemporalPointProcess&type=Date)
279
+
UPLOAD_CHECKLIST.md ADDED
@@ -0,0 +1,116 @@
+ # Hugging Face Upload Checklist
+ 
+ ## ✅ Cleanup Complete
+ 
+ ### File types removed
+ - `__pycache__/` folders
+ - `*.pyc`, `*.pyo`, `*.pyd` files
+ - `.DS_Store` files (macOS)
+ - `.vscode/`, `.idea/` folders
+ - `*.swp`, `*.swo` files
+ 
+ ### Items to check manually
+ 
+ 1. **Large files**
+    - Check for files larger than 50MB
+    - Consider using Git LFS or excluding them
+ 
+ 2. **Sensitive information**
+    - Check for API keys, passwords and other secrets
+    - Check config files for sensitive data
+ 
+ 3. **Data files**
+    - Check the `examples/data/` directory
+    - If the data files are large, consider excluding them or linking to them externally
+ 
+ 4. **Model files**
+    - Check for pretrained model files
+    - Large model files should use Git LFS or the Hugging Face Model Hub
+ 
+ 5. **Log files**
+    - Make sure no log files are included
+    - Check the `log/` and `logs/` directories
+ 
+ ## 📦 Uploading to Hugging Face
+ 
+ ### Option 1: the Hugging Face CLI
+ 
+ ```bash
+ # Install the Hugging Face CLI
+ pip install huggingface_hub
+ 
+ # Log in
+ huggingface-cli login
+ 
+ # Create a repository (if you do not have one yet)
+ # Create a new repo at https://huggingface.co/new
+ 
+ # Upload the files
+ cd /path/to/EasyTemporalPointProcess-main
+ huggingface-cli upload <your-username>/<repo-name> . --repo-type dataset
+ ```
+ 
+ ### Option 2: Git
+ 
+ ```bash
+ # Initialize a Git repository (if not done yet)
+ git init
+ git add .
+ git commit -m "Initial commit"
+ 
+ # Add the Hugging Face remote
+ git remote add origin https://huggingface.co/<your-username>/<repo-name>
+ 
+ # Push
+ git push origin main
+ ```
+ 
+ ### Option 3: the web interface
+ 
+ 1. Go to https://huggingface.co/new
+ 2. Create a new Dataset or Space
+ 3. Upload the files via the web interface
+ 
+ ## 📝 File structure
+ 
+ ```
+ EasyTemporalPointProcess-main/
+ ├── easy_tpp/                         # core library code
+ ├── examples/                         # example code
+ ├── notebooks/                        # Jupyter notebooks
+ ├── tests/                            # test code
+ ├── docs/                             # documentation
+ ├── compute_cascade_metrics.py        # new: cascade metrics script
+ ├── COMPUTE_METRICS_README.md         # new: metrics computation guide
+ ├── requirements.txt                  # base dependencies
+ ├── requirements_compute_metrics.txt  # new: metrics dependencies
+ ├── setup.py                          # install script
+ └── README.md                         # project README
+ ```
+ 
+ ## ⚠️ Notes
+ 
+ 1. **Do not commit large files to the Git repository**
+    - Use Git LFS or Hugging Face's storage
+    - Consider referencing large files via external links
+ 
+ 2. **Check the licenses**
+    - Make sure all code carries an appropriate license
+    - Check the license compatibility of third-party dependencies
+ 
+ 3. **README file**
+    - Make sure README.md clearly explains what the project does
+    - Include installation and usage instructions
+ 
+ 4. **Dependency management**
+    - Make sure requirements.txt is up to date
+    - Consider using `pip freeze` to pin exact versions
+ 
+ ## 🔍 Verify the upload
+ 
+ After uploading, check that:
+ - [ ] All files have been uploaded
+ - [ ] The file sizes are reasonable
+ - [ ] No sensitive information has leaked
+ - [ ] The README renders correctly
+ - [ ] The code can be downloaded and used normally
cleanup_for_hf.py ADDED
@@ -0,0 +1,293 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ Cleanup script: prepare the repository for uploading to Hugging Face.
+ 
+ This script will:
+ 1. Remove unnecessary files (__pycache__, .pyc, .pyo, etc.)
+ 2. Check for large files
+ 3. Create an upload checklist
+ """
+ 
+ import os
+ import shutil
+ from pathlib import Path
+ from typing import List, Tuple
+ 
+ 
+ def find_and_remove_patterns(root_dir: str, patterns: List[str]) -> List[str]:
+     """
+     Find and remove files/folders matching the given patterns.
+ 
+     Args:
+         root_dir: root directory
+         patterns: list of file/folder patterns
+ 
+     Returns:
+         list of removed files/folders
+     """
+     removed = []
+     root_path = Path(root_dir)
+ 
+     for pattern in patterns:
+         for item in root_path.rglob(pattern):
+             if item.exists():
+                 try:
+                     if item.is_file():
+                         item.unlink()
+                         removed.append(str(item))
+                     elif item.is_dir():
+                         shutil.rmtree(item)
+                         removed.append(str(item))
+                 except Exception as e:
+                     print(f"Warning: cannot remove {item}: {e}")
+ 
+     return removed
+ 
+ 
+ def find_large_files(root_dir: str, size_mb: int = 50) -> List[Tuple[str, float]]:
+     """
+     Find large files.
+ 
+     Args:
+         root_dir: root directory
+         size_mb: file size threshold (MB)
+ 
+     Returns:
+         list of (file path, size in MB) tuples
+     """
+     large_files = []
+     root_path = Path(root_dir)
+     size_bytes = size_mb * 1024 * 1024
+ 
+     for item in root_path.rglob('*'):
+         if item.is_file():
+             try:
+                 size = item.stat().st_size
+                 if size > size_bytes:
+                     size_mb_actual = size / (1024 * 1024)
+                     large_files.append((str(item), size_mb_actual))
+             except Exception as e:
+                 print(f"Warning: cannot check {item}: {e}")
+ 
+     return large_files
+ 
+ 
+ def check_gitignore(root_dir: str) -> bool:
+     """
+     Check whether a .gitignore file exists.
+ 
+     Args:
+         root_dir: root directory
+ 
+     Returns:
+         whether .gitignore exists
+     """
+     gitignore_path = Path(root_dir) / '.gitignore'
+     return gitignore_path.exists()
+ 
+ 
+ def create_upload_checklist(root_dir: str) -> str:
+     """
+     Create the upload checklist.
+ 
+     Args:
+         root_dir: root directory
+ 
+     Returns:
+         checklist content
+     """
+     checklist = """# Hugging Face Upload Checklist
+ 
+ ## ✅ Cleanup Complete
+ 
+ ### File types removed
+ - `__pycache__/` folders
+ - `*.pyc`, `*.pyo`, `*.pyd` files
+ - `.DS_Store` files (macOS)
+ - `.vscode/`, `.idea/` folders
+ - `*.swp`, `*.swo` files
+ 
+ ### Items to check manually
+ 
+ 1. **Large files**
+    - Check for files larger than 50MB
+    - Consider using Git LFS or excluding them
+ 
+ 2. **Sensitive information**
+    - Check for API keys, passwords and other secrets
+    - Check config files for sensitive data
+ 
+ 3. **Data files**
+    - Check the `examples/data/` directory
+    - If the data files are large, consider excluding them or linking to them externally
+ 
+ 4. **Model files**
+    - Check for pretrained model files
+    - Large model files should use Git LFS or the Hugging Face Model Hub
+ 
+ 5. **Log files**
+    - Make sure no log files are included
+    - Check the `log/` and `logs/` directories
+ 
+ ## 📦 Uploading to Hugging Face
+ 
+ ### Option 1: the Hugging Face CLI
+ 
+ ```bash
+ # Install the Hugging Face CLI
+ pip install huggingface_hub
+ 
+ # Log in
+ huggingface-cli login
+ 
+ # Create a repository (if you do not have one yet)
+ # Create a new repo at https://huggingface.co/new
+ 
+ # Upload the files
+ cd /path/to/EasyTemporalPointProcess-main
+ huggingface-cli upload <your-username>/<repo-name> . --repo-type dataset
+ ```
+ 
+ ### Option 2: Git
+ 
+ ```bash
+ # Initialize a Git repository (if not done yet)
+ git init
+ git add .
+ git commit -m "Initial commit"
+ 
+ # Add the Hugging Face remote
+ git remote add origin https://huggingface.co/<your-username>/<repo-name>
+ 
+ # Push
+ git push origin main
+ ```
+ 
+ ### Option 3: the web interface
+ 
+ 1. Go to https://huggingface.co/new
+ 2. Create a new Dataset or Space
+ 3. Upload the files via the web interface
+ 
+ ## 📝 File structure
+ 
+ ```
+ EasyTemporalPointProcess-main/
+ ├── easy_tpp/                         # core library code
+ ├── examples/                         # example code
+ ├── notebooks/                        # Jupyter notebooks
+ ├── tests/                            # test code
+ ├── docs/                             # documentation
+ ├── compute_cascade_metrics.py        # new: cascade metrics script
+ ├── COMPUTE_METRICS_README.md         # new: metrics computation guide
+ ├── requirements.txt                  # base dependencies
+ ├── requirements_compute_metrics.txt  # new: metrics dependencies
+ ├── setup.py                          # install script
+ └── README.md                         # project README
+ ```
+ 
+ ## ⚠️ Notes
+ 
+ 1. **Do not commit large files to the Git repository**
+    - Use Git LFS or Hugging Face's storage
+    - Consider referencing large files via external links
+ 
+ 2. **Check the licenses**
+    - Make sure all code carries an appropriate license
+    - Check the license compatibility of third-party dependencies
+ 
+ 3. **README file**
+    - Make sure README.md clearly explains what the project does
+    - Include installation and usage instructions
+ 
+ 4. **Dependency management**
+    - Make sure requirements.txt is up to date
+    - Consider using `pip freeze` to pin exact versions
+ 
+ ## 🔍 Verify the upload
+ 
+ After uploading, check that:
+ - [ ] All files have been uploaded
+ - [ ] The file sizes are reasonable
+ - [ ] No sensitive information has leaked
+ - [ ] The README renders correctly
+ - [ ] The code can be downloaded and used normally
+ """
+ 
+     return checklist
+ 
+ 
+ def main():
+     """Main entry point."""
+     root_dir = os.path.dirname(os.path.abspath(__file__))
+ 
+     print("=" * 60)
+     print("Cleanup script: prepare for uploading to Hugging Face")
+     print("=" * 60)
+ 
+     # Patterns to remove
+     patterns_to_remove = [
+         '__pycache__',
+         '*.pyc',
+         '*.pyo',
+         '*.pyd',
+         '.DS_Store',
+         '.vscode',
+         '.idea',
+         '*.swp',
+         '*.swo',
+         '*.log',
+         '.pytest_cache',
+         '.mypy_cache',
+         '.ruff_cache',
+     ]
+ 
+     print("\n1. Removing unnecessary files...")
+     removed = find_and_remove_patterns(root_dir, patterns_to_remove)
+     if removed:
+         print(f"  Removed {len(removed)} files/folders")
+         for item in removed[:10]:  # show the first 10 only
+             print(f"    - {item}")
+         if len(removed) > 10:
+             print(f"    ... and {len(removed) - 10} more files/folders")
+     else:
+         print("  No files to remove")
+ 
+     # Check for large files
+     print("\n2. Checking for large files (>50MB)...")
+     large_files = find_large_files(root_dir, size_mb=50)
+     if large_files:
+         print(f"  Found {len(large_files)} large files:")
+         for file_path, size_mb in large_files:
+             print(f"    - {file_path} ({size_mb:.2f} MB)")
+         print("\n  ⚠️ Suggestion: large files should use Git LFS or be excluded from the upload")
+     else:
+         print("  ✅ No large files found")
+ 
+     # Check .gitignore
+     print("\n3. Checking .gitignore...")
+     if check_gitignore(root_dir):
+         print("  ✅ .gitignore file exists")
+     else:
+         print("  ⚠️ Warning: .gitignore file does not exist")
+ 
+     # Create the checklist
+     print("\n4. Creating the upload checklist...")
+     checklist_content = create_upload_checklist(root_dir)
+     checklist_path = Path(root_dir) / 'UPLOAD_CHECKLIST.md'
+     with open(checklist_path, 'w', encoding='utf-8') as f:
+         f.write(checklist_content)
+     print(f"  ✅ Created: {checklist_path}")
+ 
+     print("\n" + "=" * 60)
+     print("Cleanup complete!")
+     print("=" * 60)
+     print("\nNext steps:")
+     print("1. Read UPLOAD_CHECKLIST.md for the upload steps")
+     print("2. Check whether any sensitive information needs to be removed")
+     print("3. Follow the checklist to upload to Hugging Face")
+ 
+ 
+ if __name__ == '__main__':
+     main()
compute_cascade_metrics.py ADDED
@@ -0,0 +1,568 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ Compute metrics for information cascades: sentiment score, sentiment deviation, contextual deviation and perplexity.
+ 
+ This script processes information_cascade.json and information_cascade_original_posts.json
+ and computes the following metrics:
+ 1. Sentiment score
+ 2. Sentiment deviation
+ 3. Contextual deviation
+ 4. Perplexity
+ 
+ Usage (on the cloud machine):
+     python compute_cascade_metrics.py \
+         --input_cascade information_cascade.json \
+         --input_original information_cascade_original_posts.json \
+         --output output_with_metrics.json \
+         --bert_model bert-base-chinese \
+         --sentiment_model <sentiment_model_path> \
+         --perplexity_model <perplexity_model_path> \
+         --batch_size 32
+ """
+ 
+ import argparse
+ import json
+ import numpy as np
+ import torch
+ from typing import Dict, List, Any, Optional, Tuple
+ from tqdm import tqdm
+ from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM
+ import os
+ 
+ 
+ class CascadeMetricsComputer:
+     """
+     Compute the various metrics of cascade data.
+     """
+ 
+     def __init__(
+             self,
+             bert_model_name: str = 'bert-base-chinese',
+             sentiment_model_name: Optional[str] = None,
+             perplexity_model_name: Optional[str] = None,
+             device: Optional[str] = None,
+             batch_size: int = 32,
+             max_length: int = 512
+     ):
+         """
+         Initialize the metrics computer.
+ 
+         Args:
+             bert_model_name: BERT model name (used to compute semantic embeddings and contextual deviation)
+             sentiment_model_name: sentiment analysis model name (used to compute sentiment scores)
+             perplexity_model_name: language model name (used to compute perplexity)
+             device: computation device ('cuda' or 'cpu'); if None, chosen automatically
+             batch_size: batch size
+             max_length: maximum sequence length
+         """
+         if device is None:
+             device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ 
+         self.device = device
+         self.batch_size = batch_size
+         self.max_length = max_length
+ 
+         print(f"Loading BERT model: {bert_model_name}")
+         self.bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
+         self.bert_model = AutoModel.from_pretrained(bert_model_name)
+         self.bert_model.to(device)
+         self.bert_model.eval()
+         print(f"BERT model loaded on device: {device}")
+ 
+         # Load the sentiment analysis model
+         if sentiment_model_name:
+             print(f"Loading sentiment analysis model: {sentiment_model_name}")
+             self.sentiment_tokenizer = AutoTokenizer.from_pretrained(sentiment_model_name)
+             self.sentiment_model = AutoModelForSequenceClassification.from_pretrained(sentiment_model_name)
+             self.sentiment_model.to(device)
+             self.sentiment_model.eval()
+             print(f"Sentiment model loaded on device: {device}")
+         else:
+             self.sentiment_tokenizer = None
+             self.sentiment_model = None
+             print("No sentiment model provided; falling back to the simplified sentiment heuristic")
+ 
+         # Load the perplexity model (a causal language model)
+         if perplexity_model_name:
+             print(f"Loading perplexity model: {perplexity_model_name}")
+             self.perplexity_tokenizer = AutoTokenizer.from_pretrained(perplexity_model_name)
+             self.perplexity_model = AutoModelForCausalLM.from_pretrained(perplexity_model_name)
+             self.perplexity_model.to(device)
+             self.perplexity_model.eval()
+             print(f"Perplexity model loaded on device: {device}")
+         else:
+             self.perplexity_tokenizer = None
+             self.perplexity_model = None
+             print("No perplexity model provided; falling back to the simplified perplexity proxy")
+ 
+     def compute_embeddings(self, texts: List[str]) -> np.ndarray:
+         """
+         Compute BERT semantic embeddings.
+ 
+         Args:
+             texts: list of texts
+ 
+         Returns:
+             embedding matrix [num_texts, hidden_size]
+         """
+         embeddings = []
+ 
+         with torch.no_grad():
+             for i in range(0, len(texts), self.batch_size):
+                 batch_texts = texts[i:i + self.batch_size]
+ 
+                 # Handle empty texts
+                 batch_texts = [text if text else "[PAD]" for text in batch_texts]
+ 
+                 # Tokenize and encode
+                 inputs = self.bert_tokenizer(
+                     batch_texts,
+                     return_tensors='pt',
+                     padding=True,
+                     truncation=True,
+                     max_length=self.max_length
+                 ).to(self.device)
+ 
+                 # Forward pass
+                 outputs = self.bert_model(**inputs)
+ 
+                 # Use the [CLS] token embedding
+                 batch_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
+                 embeddings.append(batch_embeddings)
+ 
+         return np.vstack(embeddings)
+ 
+     def compute_sentiment_scores(self, texts: List[str]) -> List[float]:
+         """
+         Compute sentiment scores.
+ 
+         Args:
+             texts: list of texts
+ 
+         Returns:
+             list of sentiment scores (one per text, typically in [-1, 1] or [0, 1])
+         """
+         if self.sentiment_model is None:
+             # Fall back to the simplified sentiment heuristic
+             return self._compute_sentiment_simple(texts)
+ 
+         sentiment_scores = []
+ 
+         with torch.no_grad():
+             for i in range(0, len(texts), self.batch_size):
+                 batch_texts = texts[i:i + self.batch_size]
+                 batch_texts = [text if text else "[PAD]" for text in batch_texts]
+ 
+                 inputs = self.sentiment_tokenizer(
+                     batch_texts,
+                     return_tensors='pt',
+                     padding=True,
+                     truncation=True,
+                     max_length=self.max_length
+                 ).to(self.device)
+ 
+                 outputs = self.sentiment_model(**inputs)
+                 logits = outputs.logits
+ 
+                 # Assume binary classification (positive/negative); use softmax to get probabilities
+                 probs = torch.softmax(logits, dim=-1)
+ 
+                 # Sentiment score: positive probability - negative probability (or another scheme)
+                 if probs.shape[1] == 2:
+                     # Binary: [negative prob, positive prob]
+                     batch_scores = (probs[:, 1] - probs[:, 0]).cpu().numpy().tolist()
+                 else:
+                     # Multi-class or other cases: use the first class probability as the score
+                     batch_scores = probs[:, 0].cpu().numpy().tolist()
+ 
+                 sentiment_scores.extend(batch_scores)
+ 
+         return sentiment_scores
+ 
+     def _compute_sentiment_simple(self, texts: List[str]) -> List[float]:
+         """
+         Simplified sentiment computation (heuristic rules).
+ 
+         Args:
+             texts: list of texts
+ 
+         Returns:
+             list of sentiment scores
+         """
+         scores = []
+         for text in texts:
+             if not text:
+                 scores.append(0.0)
+                 continue
+ 
+             # Simple heuristic based on Chinese sentiment cue words (kept in Chinese on purpose)
+             positive_words = ['好', '棒', '赞', '喜欢', '支持', '👍', '❤️', '😊', '😄']
+             negative_words = ['差', '坏', '讨厌', '反对', '👎', '😢', '😠', '😡']
+ 
+             positive_count = sum(1 for word in positive_words if word in text)
+             negative_count = sum(1 for word in negative_words if word in text)
+ 
+             # Compute the sentiment score (clipped to [-1, 1])
+             total_words = len(text)
+             if total_words > 0:
+                 score = (positive_count - negative_count) / max(total_words, 1)
+                 score = np.clip(score, -1.0, 1.0)
+             else:
+                 score = 0.0
+ 
+             scores.append(score)
+ 
+         return scores
+ 
+     def compute_perplexity(self, texts: List[str]) -> List[float]:
+         """
+         Compute perplexity.
+ 
+         Args:
+             texts: list of texts
+ 
+         Returns:
+             list of perplexities
+         """
+         if self.perplexity_model is None:
+             # Fall back to the simplified perplexity proxy
+             return self._compute_perplexity_simple(texts)
+ 
+         perplexities = []
+ 
+         with torch.no_grad():
+             for text in texts:
+                 if not text:
+                     perplexities.append(0.0)
+                     continue
+ 
+                 # Tokenize
+                 inputs = self.perplexity_tokenizer(
+                     text,
+                     return_tensors='pt',
+                     truncation=True,
+                     max_length=self.max_length
+                 ).to(self.device)
+ 
+                 # Compute perplexity
+                 outputs = self.perplexity_model(**inputs, labels=inputs['input_ids'])
+                 loss = outputs.loss
+ 
+                 # Perplexity = exp(loss)
+                 perplexity = torch.exp(loss).item()
+                 perplexities.append(perplexity)
+ 
+         return perplexities
+ 
+     def _compute_perplexity_simple(self, texts: List[str]) -> List[float]:
+         """
+         Simplified perplexity proxy (based on lexical diversity).
+ 
+         Args:
+             texts: list of texts
+ 
+         Returns:
+             list of perplexity proxies
+         """
+         perplexities = []
+ 
+         for text in texts:
+             if not text:
+                 perplexities.append(0.0)
+                 continue
+ 
+             # Simplified proxy based on lexical diversity
+             # (note: whitespace tokenization is crude for Chinese text)
+             words = text.split()
+             unique_words = len(set(words))
+             total_words = len(words)
+ 
+             if total_words > 0:
+                 # Lower lexical diversity -> higher proxy value (a crude stand-in for perplexity)
+                 perplexity_proxy = 1.0 - (unique_words / total_words)
+             else:
+                 perplexity_proxy = 0.0
+ 
+             perplexities.append(perplexity_proxy)
+ 
+         return perplexities
+ 
+     def compute_cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
+         """
+         Compute cosine similarity.
+ 
+         Args:
+             vec1: vector 1
+             vec2: vector 2
+ 
+         Returns:
+             cosine similarity in [-1, 1]
+         """
+         dot_product = np.dot(vec1, vec2)
+         norm1 = np.linalg.norm(vec1)
+         norm2 = np.linalg.norm(vec2)
+ 
+         if norm1 == 0 or norm2 == 0:
+             return 0.0
+ 
+         similarity = dot_product / (norm1 * norm2)
+         return float(similarity)
+ 
+     def compute_contextual_deviation(self, root_embedding: np.ndarray, current_embedding: np.ndarray) -> float:
+         """
+         Compute the contextual deviation.
+ 
+         Defined as: 1 - semantic similarity.
+ 
+         Args:
+             root_embedding: embedding of the original post
+             current_embedding: embedding of the current text
+ 
+         Returns:
+             contextual deviation; higher means further from the original post's context
+         """
+         similarity = self.compute_cosine_similarity(root_embedding, current_embedding)
+         deviation = 1.0 - similarity
+         return deviation
+ 
+     def compute_sentiment_deviation(self, root_sentiment: float, current_sentiment: float) -> float:
+         """
+         Compute the sentiment deviation.
+ 
+         Defined as: |current sentiment score - original post sentiment score|.
+ 
+         Args:
+             root_sentiment: sentiment score of the original post
+             current_sentiment: sentiment score of the current text
+ 
+         Returns:
+             sentiment deviation in [0, 2] (when sentiment scores lie in [-1, 1])
+         """
+         deviation = abs(current_sentiment - root_sentiment)
+         return deviation
+ 
+     def process_cascade(self, cascade: Dict[str, Any]) -> Dict[str, Any]:
+         """
+         Process a single cascade and compute all metrics.
+ 
+         Args:
+             cascade: cascade data dict
+ 
+         Returns:
+             cascade data dict with metrics attached
+         """
+         # 1. Collect all texts
+         texts: List[str] = []
+         indices: List[Tuple[str, Optional[str]]] = []
+ 
+         # Original post
+         post_info = cascade.get('post_info', {})
+         post_content = post_info.get('content', '')
+         texts.append(post_content)
+         indices.append(('post', None))
+ 
+         # Comments
+         comment_tree = cascade.get('comment_tree', {})
+         comment_ids = list(comment_tree.keys())
+         for comment_id in comment_ids:
+             node = comment_tree[comment_id]
+             texts.append(node.get('content', ''))
+             indices.append(('comment', comment_id))
+ 
+         # Reposts
+         repost_chain = cascade.get('repost_chain', [])
+         for node in repost_chain:
+             forward_text = node.get('forward_text', '') or ''
+             comment_content = node.get('comment_content', '') or ''
+             repost_text = forward_text + comment_content
+             texts.append(repost_text)
+             indices.append(('repost', node.get('repost_id')))
+ 
+         # 2. Compute the features in batch
+         if len(texts) == 0:
+             return cascade
+ 
+         embeddings = self.compute_embeddings(texts)
+         sentiment_scores = self.compute_sentiment_scores(texts)
+         perplexities = self.compute_perplexity(texts)
+ 
+         # 3. Features of the original post (used to compute the deviations)
+         root_embedding = embeddings[0]
+         root_sentiment = sentiment_scores[0]
+ 
+         # 4. Attach the features to the cascade data
+         # Original post
+         post_info['embedding'] = root_embedding.tolist()
+         post_info['sentiment_score'] = root_sentiment
+         post_info['perplexity'] = perplexities[0]
+ 
+         # Comments
+         for i, comment_id in enumerate(comment_ids):
+             node = comment_tree[comment_id]
+             idx = 1 + i  # skip the original post
+ 
+             node['embedding'] = embeddings[idx].tolist()
+             node['sentiment_score'] = sentiment_scores[idx]
+             node['perplexity'] = perplexities[idx]
+ 
+             # Compute the deviations
+             node['contextual_deviation'] = self.compute_contextual_deviation(
+                 root_embedding, embeddings[idx]
+             )
+             node['sentiment_deviation'] = self.compute_sentiment_deviation(
+                 root_sentiment, sentiment_scores[idx]
+             )
+ 
+         # Reposts
+         offset = 1 + len(comment_ids)
+         for j, node in enumerate(repost_chain):
+             idx = offset + j
+ 
+             node['embedding'] = embeddings[idx].tolist()
+             node['sentiment_score'] = sentiment_scores[idx]
+             node['perplexity'] = perplexities[idx]
+ 
+             # Compute the deviations
+             node['contextual_deviation'] = self.compute_contextual_deviation(
+                 root_embedding, embeddings[idx]
+             )
+             node['sentiment_deviation'] = self.compute_sentiment_deviation(
+                 root_sentiment, sentiment_scores[idx]
+             )
+ 
+         return cascade
+ 
+ 
+ def load_json_file(file_path: str) -> Dict[str, Any]:
+     """
+     Load a JSON file (large files supported).
+ 
+     Args:
+         file_path: path to the JSON file
+ 
+     Returns:
+         data dict
+     """
+     print(f"Loading JSON file: {file_path}")
+     with open(file_path, 'r', encoding='utf-8') as f:
+         data = json.load(f)
+     print(f"Loaded {len(data.get('cascades', []))} cascades")
+     return data
+ 
+ 
+ def main():
+     parser = argparse.ArgumentParser(
+         description='Compute information cascade metrics: sentiment score, sentiment deviation, contextual deviation, perplexity'
+     )
+     parser.add_argument(
+         '--input_cascade',
+         type=str,
+         required=True,
+         help='Path to the input cascade JSON file (information_cascade.json)'
+     )
+     parser.add_argument(
+         '--input_original',
+         type=str,
+         default=None,
+         help='Path to the input original-posts JSON file (information_cascade_original_posts.json), optional'
+     )
+     parser.add_argument(
+         '--output',
+         type=str,
+         required=True,
+         help='Path of the output JSON file'
+     )
+     parser.add_argument(
+         '--bert_model',
+         type=str,
+         default='bert-base-chinese',
+         help='BERT model name or path (used to compute semantic embeddings)'
+     )
+     parser.add_argument(
+         '--sentiment_model',
+         type=str,
+         default=None,
+         help='Sentiment analysis model name or path (optional)'
+     )
+     parser.add_argument(
+         '--perplexity_model',
+         type=str,
+         default=None,
+         help='Language model name or path (used to compute perplexity, optional)'
+     )
+     parser.add_argument(
+         '--batch_size',
+         type=int,
+         default=32,
+         help='Batch size'
+     )
+     parser.add_argument(
+         '--max_length',
+         type=int,
+         default=512,
+         help='Maximum sequence length'
+     )
+     parser.add_argument(
+         '--device',
+         type=str,
+         default=None,
+         help='Computation device (cuda/cpu); if None, chosen automatically'
+     )
+     parser.add_argument(
+         '--max_cascades',
+         type=int,
+         default=None,
+         help='Maximum number of cascades to process (for testing; None means process all)'
+     )
+ 
+     args = parser.parse_args()
+ 
+     # Load the data
+     cascade_data = load_json_file(args.input_cascade)
+ 
+     if args.input_original:
+         original_data = load_json_file(args.input_original)
+         # Merge the data here if needed;
+         # for now only cascade_data is processed
+ 
+     # Initialize the metrics computer
+     print("\nInitializing the metrics computer...")
+     computer = CascadeMetricsComputer(
+         bert_model_name=args.bert_model,
+         sentiment_model_name=args.sentiment_model,
+         perplexity_model_name=args.perplexity_model,
+         device=args.device,
+         batch_size=args.batch_size,
+         max_length=args.max_length
+     )
+ 
+     # Process the cascades
+     cascades = cascade_data.get('cascades', [])
+     total_cascades = len(cascades)
+     if args.max_cascades:
+         cascades = cascades[:args.max_cascades]
+ 
+     print(f"\nProcessing {len(cascades)}/{total_cascades} cascades...")
+     processed_count = 0
+     for idx, cascade in enumerate(tqdm(cascades, desc="Processing cascades")):
+         try:
+             cascade_data['cascades'][idx] = computer.process_cascade(cascade)
+             processed_count += 1
+         except Exception as e:
+             print(f"\nError while processing cascade {idx}: {e}")
+             import traceback
+             traceback.print_exc()
+             continue
+ 
+     print(f"\nSuccessfully processed {processed_count}/{len(cascades)} cascades")
+ 
+     # Save the result
+     print(f"\nSaving the result to: {args.output}")
+     with open(args.output, 'w', encoding='utf-8') as f:
+         json.dump(cascade_data, f, ensure_ascii=False, indent=2)
+ 
+     print(f"✅ Done! Result saved to: {args.output}")
+ 
+ 
+ if __name__ == '__main__':
+     main()
data/cascades/.gitkeep ADDED
@@ -0,0 +1,3 @@
+ # This directory holds the cascade data files.
+ # The data files are too large and are excluded via .gitignore.
+ # See DATA_FILES_NOTICE.md for how to obtain them.
data/cascades/README.md ADDED
@@ -0,0 +1,101 @@
+ # Cascade Data Files
+ 
+ This directory contains the information cascade data files.
+ 
+ ## 📁 Files
+ 
+ ### Main files
+ 
+ 1. **`information_cascade.json`** (606MB)
+    - Full cascade data, including original posts, comments, reposts, etc.
+    - Used to compute cascade metrics and train models
+ 
+ 2. **`information_cascade_original_posts.json`** (980MB)
+    - Original post data
+    - Contains the original Weibo posts
+ 
+ ## ⚠️ File sizes
+ 
+ These files are large (about 1.6GB in total) and are **not uploaded to Git/Hugging Face automatically**.
+ 
+ ## 📥 How to get the data files
+ 
+ ### Option 1: Manual transfer
+ 
+ The data files need to be downloaded or transferred to the cloud machine separately:
+ 
+ ```bash
+ # Create the directory on the cloud machine
+ mkdir -p data/cascades
+ 
+ # Transfer the files with scp or another tool
+ scp user@local:/path/to/information_cascade.json ./data/cascades/
+ scp user@local:/path/to/information_cascade_original_posts.json ./data/cascades/
+ ```
+ 
+ ### Option 2: Git LFS (if configured)
+ 
+ If using Git LFS:
+ 
+ ```bash
+ # Install Git LFS
+ git lfs install
+ 
+ # Track the large files
+ git lfs track "data/cascades/*.json"
+ 
+ # Add and commit
+ git add .gitattributes
+ git add data/cascades/*.json
+ git commit -m "Add cascade data files with LFS"
+ ```
+ 
+ ### Option 3: External storage
+ 
+ - Upload to cloud storage (e.g., Google Drive, Dropbox)
+ - Use the Hugging Face Dataset Hub storage
+ - Use an object storage service (e.g., AWS S3, Aliyun OSS)
+ 
+ ## 🚀 Using the data files
+ 
+ ### Run the metrics computation
+ 
+ ```bash
+ python compute_cascade_metrics.py \
+     --input_cascade data/cascades/information_cascade.json \
+     --input_original data/cascades/information_cascade_original_posts.json \
+     --output output_with_metrics.json \
+     --batch_size 32
+ ```
+ 
+ ### Data format
+ 
+ JSON file layout:
+ ```json
+ {
+   "cascades": [
+     {
+       "post_info": {
+         "content": "...",
+         "timestamp": "..."
+       },
+       "comment_tree": {...},
+       "repost_chain": [...]
+     }
+   ]
+ }
+ ```
+ 
+ See the project documentation for the detailed format.
+ 
+ ## 📝 Notes
+ 
+ 1. **File size**: these files are large; make sure there is enough disk space
+ 2. **Memory**: loading a full file may require a lot of memory
+ 3. **Processing**: batch or streaming processing is recommended (see the streaming sketch at the end of this file)
+ 4. **Backup**: keeping a backup of the data files is recommended
+ 
+ ## 🔗 Related docs
+ 
+ - [Metrics computation guide](../COMPUTE_METRICS_README.md)
+ - [Upload guide](../HF_UPLOAD_GUIDE.md)
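+ 
+ ## 🧩 Streaming sketch
+ 
+ For files this large, a streaming parser avoids loading everything into memory at once. A minimal sketch, assuming the third-party `ijson` package is installed and the top-level `cascades` key layout shown above:
+ 
+ ```python
+ import ijson
+ 
+ # Stream cascades one at a time instead of json.load-ing the whole 606MB file
+ with open('data/cascades/information_cascade.json', 'rb') as f:
+     for cascade in ijson.items(f, 'cascades.item'):
+         post = cascade.get('post_info', {})
+         # ... process a single cascade here ...
+ ```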
docs/Makefile ADDED
@@ -0,0 +1,20 @@
1
+ # Minimal makefile for Sphinx documentation
2
+ #
3
+
4
+ # You can set these variables from the command line, and also
5
+ # from the environment for the first two.
6
+ SPHINXOPTS ?=
7
+ SPHINXBUILD ?= sphinx-build
8
+ SOURCEDIR = source
9
+ BUILDDIR = build
10
+
11
+ # Put it first so that "make" without argument is like "make help".
12
+ help:
13
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14
+
15
+ .PHONY: help Makefile
16
+
17
+ # Catch-all target: route all unknown targets to Sphinx using the new
18
+ # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19
+ %: Makefile
20
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
docs/README.md ADDED
@@ -0,0 +1,13 @@
1
+ # Documentation for EasyTPP
2
+
3
+ This contains the full documentation of EasyTPP, which is hosted on GitHub and can be updated manually (for releases)
4
+ by pushing to the gh-pages branch.
5
+
6
+
7
+ To generate the documentation locally, type
8
+
9
+ ```
10
+ pip install -r requirements-doc.txt
11
+ cd docs
12
+ make html
13
+ ```
docs/images/thinning_algo.jpg ADDED

Git LFS Details

  • SHA256: f025ac2ec9c15033ba78ddbdc2b3cb90a842d3df1ac4c73d056fd61ee1304fec
  • Pointer size: 131 Bytes
  • Size of remote file: 236 kB
docs/make.bat ADDED
@@ -0,0 +1,35 @@
1
+ @ECHO OFF
2
+
3
+ pushd %~dp0
4
+
5
+ REM Command file for Sphinx documentation
6
+
7
+ if "%SPHINXBUILD%" == "" (
8
+ set SPHINXBUILD=sphinx-build
9
+ )
10
+ set SOURCEDIR=source
11
+ set BUILDDIR=build
12
+
13
+ %SPHINXBUILD% >NUL 2>NUL
14
+ if errorlevel 9009 (
15
+ echo.
16
+ echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
17
+ echo.installed, then set the SPHINXBUILD environment variable to point
18
+ echo.to the full path of the 'sphinx-build' executable. Alternatively you
19
+ echo.may add the Sphinx directory to PATH.
20
+ echo.
21
+ echo.If you don't have Sphinx installed, grab it from
22
+ echo.https://www.sphinx-doc.org/
23
+ exit /b 1
24
+ )
25
+
26
+ if "%1" == "" goto help
27
+
28
+ %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29
+ goto end
30
+
31
+ :help
32
+ %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33
+
34
+ :end
35
+ popd
docs/source/advanced/implementation.rst ADDED
@@ -0,0 +1,143 @@
1
+ ===================================
2
+ Model Implementation Details
3
+ ===================================
4
+
5
+ Basic structure
6
+ ===================================
7
+
8
+ In the model folder, `torch_basemodel` (**/model/torch_model/torch_basemodel.py**) / `tf_basemodel` (**/model/tf_model/tf_basemodel.py**) implements the functionality for computing the loglikelihood and the sampling procedures that are common
9
+ to all the TPP models. The inherited classes define models with specific structures, as explained in the sections below.
10
+
11
+
12
+ Computing the loglikelihood of non-pad event sequence
13
+ ------------------------------------------------------
14
+
15
+ The loglikelihood computation, following the definition in Equation 8 of `The Neural Hawkes Process: A Neurally Self-Modulating Multivariate Point Process <https://arxiv.org/abs/1612.09328>`_, is shared by all the TPP models.
16
+
17
+ It takes `time_delta_seqs`, `lambda_at_event`, `lambdas_loss_samples`, `seq_mask`,
18
+ `lambda_type_mask` as input and outputs the loglikelihood terms; please see `torch_basemodel` (**/model/torch_model/torch_basemodel.py**) / `tf_basemodel` (**/model/tf_model/tf_basemodel.py**)
19
+ for details.
20
+
21
+ Note that:
22
+
23
+ 1. Sequential prediction: because we perform sequential prediction, i.e., predict the next event given the previous ones, we do not consider the last one as it has no label. To implement the `forward` function, we take input of `time_seqs[:, :-1]`
24
+ and `type_seqs[:, :-1]`. For `time_delta_seqs` it is different; please see the next point.
25
+
26
+
27
+
28
+ 2. Continuous-time evolution: recall the definition in `Dataset <../user_guide/dataset.html>`_, and assume we have a sequence of 4 events and 1 pad event
29
+ at the end, i.e.,
30
+
31
+ .. code-block:: bash
32
+
33
+ index: 0, 1, 2, 3, 4
34
+ dtimes: 0, t_1-t_0, t_2-t_1, t_3-t_2, pad
35
+ types: e_0, e_1, e_2, e_3, pad
36
+ non_pad_mask: True, True, True, True, False
37
+
38
+ For the i-th event, the i-th dtime denotes the time evolution (e.g., decay in NHP) up to the current event and
39
+ the (i+1)-th dtime denotes the time evolution to the next event. To compute the non-event loglikelihood,
40
+ we should consider the time evolution after the event happens. Therefore we should use `time_delta_seqs[:, 1:]` with the masks specified in the step below.
41
+
42
+ 3. Masking: suppose we have predictions for the 0,1,2,3-th events and their labels are the 1,2,3,4-th events,
43
+ where the 4-th event needs to be masked. So we should set the sequence mask to `True, True, True, False`, i.e., `seq_mask=batch_non_pad_mask[:, 1:]`.
44
+ The same logic applies to the attention mask and the event type mask.
45
+
46
+ Therefore the following code is a typical example of calling the loglikelihood computation:
47
+
48
+
49
+ .. code-block:: python
50
+
51
+ event_ll, non_event_ll, num_events = self.compute_loglikelihood(lambda_at_event=lambda_at_event, # seq_len = max_len - 1
52
+ lambdas_loss_samples=lambda_t_sample, # seq_len = max_len - 1
53
+ time_delta_seq=time_delta_seq[:, 1:],
54
+ seq_mask=batch_non_pad_mask[:, 1:],
55
+ lambda_type_mask=type_mask[:, 1:])
56
+
57
+
58
+
59
+ Computing the integral inside the loglikelihood
60
+ -----------------------------------------------
61
+
62
+
63
+ The loglikelihood of the parameters is the sum of the log-intensities of the events that happened, at the times they happened,
64
+ minus an integral of the total intensity over the observation interval [0, T]:
65
+
66
+ .. math::
67
+
68
+ \sum_{t_i}\log \lambda_{k_i}(t_i) - \int_0^T \lambda(t) dt
69
+
70
+ The first term refers to event loglikelihood and the second term (including the negative sign) refers to the non-event loglikelihood.
71
+
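+ As a reading aid, the snippet below sketches how the two terms are typically estimated, with the integral approximated by Monte Carlo samples inside each inter-event interval. It is an illustrative sketch with assumed tensor shapes, not the library's exact implementation:
+ 
+ .. code-block:: python
+ 
+     import torch
+ 
+     def loglike_sketch(log_lambda_at_event, lambda_at_sample_times, time_delta_seq, seq_mask):
+         # log_lambda_at_event: [batch, seq_len], log-intensity of the observed type at each event
+         # lambda_at_sample_times: [batch, seq_len, num_mc], total intensity at sampled times
+         # time_delta_seq, seq_mask: [batch, seq_len]
+         event_ll = (log_lambda_at_event * seq_mask).sum()
+         # Monte Carlo estimate of the integral on each interval: mean intensity * interval length
+         non_event_ll = (lambda_at_sample_times.mean(dim=-1) * time_delta_seq * seq_mask).sum()
+         return event_ll - non_event_ll
+ 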
72
+
73
+
74
+
75
+
76
+
77
+ Neural Hawkes Process (NHP)
78
+ ===================================
79
+
80
+ We implement NHP based on the author's official PyTorch code `Github:nce-mpp <https://github.com/hongyuanmei/nce-mpp/blob/main/ncempp/models/nhp.py>`_.
81
+ 
82
+ 1. A continuous-time LSTM is introduced, with the code mainly coming from `Github:nce-mpp <https://github.com/hongyuanmei/nce-mpp/blob/main/ncempp/models/nhp.py>`_.
83
+ 2. A `forward` function in the NHP class recursively updates the states: we compute the event embedding, pass it to the LSTM cell and then decay afterwards. Note that for the i-th event, we should use the (i+1)-th dt for the decay, so we do not consider the last event as it has no decay time. A sketch of the decay step is given below.
84
+
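+ As an illustration, the cell-state decay of the continuous-time LSTM between events can be sketched as below; this is a minimal standalone snippet with assumed shapes, not the repo's exact code:
+ 
+ .. code-block:: python
+ 
+     import torch
+ 
+     def decay_cell_state(c, c_bar, delta, dt):
+         # c, c_bar, delta: [batch, hidden_size]; dt: [batch], time elapsed since the last event
+         # c(t) = c_bar + (c - c_bar) * exp(-delta * (t - t_i))
+         return c_bar + (c - c_bar) * torch.exp(-delta * dt.unsqueeze(-1))
+ 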
85
+ Attentive Neural Hawkes Process (AttNHP)
86
+ ========================================
87
+
88
+
89
+ We implement AttNHP based on the authors' official PyTorch code `Github:anhp-andtt <https://github.com/yangalan123/anhp-andtt>`_
90
+ and, similar to NHP, we factorize it into a base model and an inherited model.
91
+ 
92
+ The forward function is implemented faithfully to that of the authors' repo.
93
+
94
+
95
+ Transformer Hawkes Process (THP)
96
+ ========================================
97
+
98
+ We implement THP based on a fixed version of the PyTorch code `Github:anhp-andtt/thp <https://github.com/yangalan123/anhp-andtt/tree/master/thp>`_
99
+ and we factorize it into a base model and an inherited model.
100
+
101
+
102
+ Self-Attentive Hawkes Process (SAHP)
103
+ ========================================
104
+
105
+ We implement SAHP based on a fixed version of the PyTorch code `Github:anhp-andtt/sahp <https://github.com/yangalan123/anhp-andtt/tree/master/sahp>`_
106
+ and we factorize it into a base model and an inherited model.
107
+ 
108
+ `SAHP` shares a very similar structure to that of `THP`.
109
+
110
+
111
+
112
+ Recurrent Marked Temporal Point Processes (RMTPP)
113
+ ====================================================
114
+
115
+ We implement RMTPP faithfully following the author's paper.
116
+
117
+
118
+ Intensity Free Learning of Temporal Point Process (IntensityFree)
119
+ ==================================================================
120
+
121
+ We implement the model based on the author's torch code `Github:ifl-tpp <https://github.com/shchur/ifl-tpp>`_.
122
+
123
+ A small difference between our implementation and the author's is that we ignore the `context_init` (the initial state of the RNN) because, in our data setup, we do not need a learnable initial RNN state. This modification generally has little impact on the learning process.
124
+ 
125
+ It is worth noting that the thinning algorithm cannot be applied to this model because it is intensity-free. When comparing model performance, we only look at its log-likelihood learning curve.
126
+
127
+
128
+ Fully Neural Network based Model for General Temporal Point Processes (FullyNN)
129
+ ===============================================================================
130
+
131
+ We implement the model based on the author's keras code `Github:NeuralNetworkPointProcess <https://github.com/omitakahiro/NeuralNetworkPointProcess>`_.
132
+
133
+
134
+ ODE-based Temporal Point Process (ODETPP)
135
+ =========================================
136
+
137
+ We implement a TPP with Neural ODE state evolution, which is a simplified version of `Neural Spatio-Temporal Point Processes <https://arxiv.org/abs/2011.04583>`_. The ODE implementation adapts the code from this `blog <https://msurtsukov.github.io/Neural-ODE/>`_.
138
+
139
+
140
+ Attentive Neural Hawkes Network (ANHN)
141
+ ======================================
142
+
143
+ We implement the model based on the author's paper: the attentive model without the graph regularizer is named ANHN.
docs/source/advanced/performance_valid.rst ADDED
@@ -0,0 +1,41 @@
1
+ =========================================
2
+ Performance validation of EasyTPP models
3
+ =========================================
4
+
5
+ We run experiments on various datasets to validate the implementations: each model is trained for a maximum number of epochs, and
6
+ the best model is selected based on performance on the validation set; we then report results on the test set.
7
+
8
+
9
+ Simulated dataset
10
+ ---------------------------
11
+ Conttime
12
+ **********************
13
+
14
+
15
+
16
+ +--------------+----------+----------+----------+--------------------+
17
+ | Models | Loglike | RMSE | Acc | Num Training Epochs|
18
+ +==============+==========+==========+==========+====================+
19
+ | Torch_NHP | -0.93504 | 0.34000 | 0.38656 | 200 |
20
+ +--------------+----------+----------+----------+--------------------+
21
+ | Tf_NHP | -0.85774 | 0.34014 | 0.38806 | 200 |
22
+ +--------------+----------+----------+----------+--------------------+
23
+ | Torch_AttNHP | -1.02001 | 0.33678 | 0.36782 | 200 |
24
+ +--------------+----------+----------+----------+--------------------+
25
+ | Tf_AttNHP | -1.02315 | 0.33816 | 0.19456 | 200 |
26
+ +--------------+----------+----------+----------+--------------------+
27
+ | Torch_AttNHP | -1.00593 | 0.33685 | 0.37723 | 500 |
28
+ +--------------+----------+----------+----------+--------------------+
29
+ | Tf_AttNHP | -0.99827 | 0.33717 | 0.36498 | 500 |
30
+ +--------------+----------+----------+----------+--------------------+
31
+ | Torch_THP | -0.99827 | 0.33717 | 0.36498 | 500 |
32
+ +--------------+----------+----------+----------+--------------------+
33
+ | Tf_THP | -1.01898 | 0.33677 | 0.37875 | 500 |
34
+ +--------------+----------+----------+----------+--------------------+
35
+
36
+
37
+
38
+ Real dataset
39
+ ---------------------------
+ Taxi
+ **********************
40
+
41
+
docs/source/advanced/tensorboard.rst ADDED
@@ -0,0 +1,75 @@
1
+ ===================================
2
+ Launching the Tensorboard
3
+ ===================================
4
+
5
+
6
+ Here we present how to launch the tensorboard within the ``EasyTPP`` framework.
7
+
8
+ Step 1: Activate the usage of tensorboard in Config file
9
+ ========================================================
10
+
11
+
12
+ As shown in `Training Pipeline <../get_started/run_train_pipeline.html>`_, we first need to initialize the 'model_config.yaml' file to set up the running config before training or evaluating the model.
13
+
14
+ In the ``model config`` (the `modeling` attribute of the config), one needs to set ``use_tfb`` to ``True`` in ``trainer_config``. Then, before the run starts, summary writers tracking performance on the training and validation sets are both initialized.
15
+
16
+ .. code-block:: yaml
17
+
18
+ NHP_train:
19
+ base_config:
20
+ stage: train
21
+ backend: torch
22
+ dataset_id: taxi
23
+ runner_id: std_tpp
24
+ model_id: NHP # model name
25
+ base_dir: './checkpoints/'
26
+ trainer_config:
27
+ batch_size: 256
28
+ max_epoch: 200
29
+ shuffle: False
30
+ optimizer: adam
31
+ learning_rate: 1.e-3
32
+ valid_freq: 1
33
+ use_tfb: True # Activate the tensorboard
34
+ metrics: [ 'acc', 'rmse' ]
35
+ seed: 2019
36
+ gpu: -1
37
+ model_config:
38
+ hidden_size: 64
39
+ loss_integral_num_sample_per_step: 20
40
+ # pretrained_model_dir: ./checkpoints/75518_4377527680_230530-132355/models/saved_model
41
+ thinning:
42
+ num_seq: 10
43
+ num_sample: 1
44
+ num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm
45
+ look_ahead_time: 10
46
+ patience_counter: 5 # the maximum iteration used in adaptive thinning
47
+ over_sample_rate: 5
48
+ num_samples_boundary: 5
49
+ dtime_max: 5
50
+ num_step_gen: 1
51
+
52
+
53
+
54
+ Step 2: Launching the tensorboard
55
+ ========================================================
56
+
57
+
58
+ We simply go to the output directory of the training runner (specified in `base_dir` of ``base_config``), find the tensorboard log directory and launch it.
59
+
60
+ A complete example of using tensorboard can be seen at *examples/run_tensorboard.py*.
61
+
62
+
63
+ .. code-block:: python
64
+
65
+ import os
66
+
67
+ def main():
68
+ # one can find this dir in the config out file
69
+ log_dir = './checkpoints/NHP_train_taxi_20220527-20:18:30/tfb_train'
70
+ os.system('tensorboard --logdir={}'.format(log_dir))
71
+ return
72
+
73
+
74
+ if __name__ == '__main__':
75
+ main()
docs/source/advanced/thinning_algo.rst ADDED
@@ -0,0 +1,56 @@
1
+ ==============================================
2
+ Thinning Algorithm for Sampling Event Sequence
3
+ ==============================================
4
+
5
+ In ``EasyTPP`` we use ``Thinning algorithm`` depicted in Algorithm 2
6
+ in `The Neural Hawkes Process: A Neurally Self-Modulating Multivariate Point Process <https://arxiv.org/abs/1612.09328>`_
7
+ for event sampling.
8
+
9
+ The implementation of the algorithm
10
+ ====================================
11
+
12
+
13
+ We implement the algorithm both in PyTorch and Tensorflow, as seen in *./model/torch_thinning.py* and
14
+ *./model/tf_thinning.py*, which basically follow the same procedure.
15
+
16
+ The corresponding code is in the function ``draw_next_time_one_step``, which consists of the following steps (a simplified sketch of the acceptance loop is given after the list):
17
+
18
+ 1. Compute the upper bound of the intensity at each event timestamp in function ``compute_intensity_upper_bound``, where we sample some timestamps inside event intervals and output an upper-bound intensity matrix [batch_size, seq_len], denoting the upper bound of the predicted intensity (for the next time interval) for each sequence at each timestamp.
19
+ 2. Sample the exponential distribution with the intensity computed in Step 1 in function ``sample_exp_distribution``, where we simply divide standard exponential draws by the intensity, which is equivalent to sampling from Exp(intensity_bound), according to `the properties of the Exponential Distribution <https://en.wikipedia.org/wiki/Exponential_distribution>`_. The exponential random variables have size [batch_size, seq_len, num_sample, num_exp], where num_sample refers to the number of event times sampled in every interval and num_exp refers to the number of i.i.d. Exp(intensity_bound) draws at one time in the thinning algorithm.
20
+ 3. Compute the intensities at the sample times proposed in Step 2, with final size `[batch_size, seq_len, num_sample, num_exp]`.
21
+ 4. Sample the standard uniform distribution with size `[batch_size, seq_len, num_sample, num_exp]`.
22
+ 5. Perform the acceptance sampling with certain probability in function ``sample_accept``.
23
+ 6. The earliest sampling dtimes are accepted. For unaccepted sampling dtimes, use boundary/maxsampletime for that draw.
24
+ 7. The final predicted dtimes has size `[batch_size, seq_len, num_sample]`, which refers to the sampling dtimes for each sequence at each timestamp, along with an equal weight vector.
25
+ 8. The product of the predicted dtimes and the weight is the final predicted dtimes, with size `[batch_size, seq_len]`.
26
+
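+ For intuition, the acceptance loop at a single position can be sketched as follows (a toy ``intensity_fn`` stands in for the model intensity; names and shapes are illustrative, not the library's API):
+ 
+ .. code-block:: python
+ 
+     import torch
+ 
+     def thinning_one_step(intensity_fn, t_last, bound, num_exp=500):
+         # i.i.d. Exp(bound) increments -> increasing candidate times after the last event
+         exp_draws = torch.distributions.Exponential(bound).sample((num_exp,))
+         candidates = t_last + torch.cumsum(exp_draws, dim=0)
+         # accept each candidate t with probability intensity(t) / bound
+         accept = torch.rand(num_exp) < intensity_fn(candidates) / bound
+         if accept.any():
+             return candidates[accept.nonzero()[0, 0]]  # the earliest accepted time
+         return candidates[-1]  # fall back to the boundary draw
+ 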
27
+
28
+ .. image:: ../../images/thinning_algo.jpg
29
+ :alt: thinning_algo
30
+
31
+
32
+
33
+ One-step prediction
34
+ ====================================
35
+ By default, once the parameters of the thinning algorithm are given (by defining a ``thinning`` config as part of ``model_config``), we perform one-step prediction in model evaluation, i.e., predict the next event given the prefix. The implementation is in function ``prediction_event_one_step`` in BaseModel (i.e., TorchBaseModel or TfBaseModel).
36
+
37
+
38
+ Multi-step prediction
39
+ ====================================
40
+ The recursive multi-step prediction is activated by setting `num_step_gen` to a number bigger than 1 in the ``thinning`` config.
41
+
42
+ Note that we generate the multi-step events after the last non-pad event of each sequence. The implementation is in function `predict_multi_step_since_last_event` in BaseModel (i.e., TorchBaseModel or TfBaseModel).
43
+
44
+
45
+ .. code-block:: yaml
46
+
47
+ thinning:
48
+ num_seq: 10
49
+ num_sample: 1
50
+ num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm
51
+ look_ahead_time: 10
52
+ patience_counter: 5 # the maximum iteration used in adaptive thinning
53
+ over_sample_rate: 5
54
+ num_samples_boundary: 5
55
+ dtime_max: 5
56
+ num_step_gen: 5 # by default it is single step, i.e., 1
docs/source/conf.py ADDED
@@ -0,0 +1,59 @@
1
+ # Configuration file for the Sphinx documentation builder.
2
+ #
3
+ # For the full list of built-in configuration values, see the documentation:
4
+ # https://www.sphinx-doc.org/en/master/usage/configuration.html
5
+
6
+
7
+ # -- Autodoc information -----------------------------------------------------
8
+ # https://sphinx-rtd-tutorial.readthedocs.io/en/latest/sphinx-config.html
9
+
10
+
11
+ import os
12
+ import sys
13
+
14
+ sys.path.insert(0, os.path.abspath('../../easy_tpp/'))
15
+
16
+ sys.path.insert(0, os.path.abspath('../..'))
17
+
18
+ # -- Project information -----------------------------------------------------
19
+ # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
20
+
21
+ project = 'EasyTPP'
22
+ copyright = '2022, Machine Intelligence, Alipay'
23
+ author = 'Machine Intelligence, Alipay'
24
+ release = '0.0.2'
25
+
26
+ # -- General configuration ---------------------------------------------------
27
+ # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
28
+
29
+ extensions = [
30
+ "sphinx.ext.autodoc",
31
+ 'sphinx.ext.viewcode',
32
+ "sphinx.ext.todo",
33
+ "sphinx.ext.mathjax",
34
+ "sphinx.ext.napoleon",
35
+ 'sphinx.ext.autosummary'
36
+ ]
37
+
38
+ napoleon_google_docstring = True
39
+ napoleon_numpy_docstring = False
40
+
41
+ templates_path = ['_templates']
42
+ # List of patterns, relative to source directory, that match files and
43
+ # directories to ignore when looking for source files.
44
+ # This patterns also effect to html_static_path and html_extra_path
45
+ exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
46
+
47
+ # -- Options for HTML output -------------------------------------------------
48
+ # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
49
+
50
+ html_theme = 'sphinx_rtd_theme'
51
+ html_static_path = ['_static']
52
+
53
+ autodoc_member_order = "bysource"
54
+ autodoc_default_flags = ["members"]
55
+ autodoc_default_options = {
56
+ "members": True,
57
+ "member-order": "bysource",
58
+ "special-members": "__init__",
59
+ }
docs/source/dev_guide/model_custom.rst ADDED
@@ -0,0 +1,78 @@
1
+ ==================
2
+ Customize a Model
3
+ ==================
4
+
5
+
6
+ Here we introduce how to customize a TPP model with the support of ``EasyTPP``.
7
+
8
+
9
+
10
+ Create a new TPP Model Class
11
+ =============================
12
+
13
+ Assume we are building a PyTorch model. We need to initialize the model by inheriting class `EasyTPP.model.torch_model.TorchBaseModel <../ref/models.html>`_.
14
+
15
+ .. code-block:: python
16
+
17
+ from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel
18
+
19
+ # Custom Torch TPP implementations need to
20
+ # inherit from the TorchBaseModel interface
21
+ class NewModel(TorchBaseModel):
22
+ def __init__(self, model_config):
23
+ super(NewModel, self).__init__(model_config)
24
+
25
+ # Forward along the sequence, output the states / intensities at the event times
26
+ def forward(self, batch):
27
+ ...
28
+ return states
29
+
30
+ # Compute the loglikelihood loss
31
+ def loglike_loss(self, batch):
32
+ ....
33
+ return loglike
34
+
35
+ # Compute the intensities at given sampling times
36
+ # Used in the Thinning sampler
37
+ def compute_intensities_at_sample_times(self, batch, sample_times, **kwargs):
38
+ ...
39
+ return intensities
40
+
41
+
42
+ If we are building a Tensorflow model, we start with the following code
43
+
44
+ .. code-block:: python
45
+
46
+ from easy_tpp.model.tf_model.tf_basemodel import TfBaseModel
47
+
48
+ # Custom Tf TPP implementations need to
49
+ # inherit from the TfBaseModel interface
50
+ class NewModel(TfBaseModel):
51
+ def __init__(self, model_config):
52
+ super(NewModel, self).__init__(model_config)
53
+
54
+ # Forward along the sequence, output the states / intensities at the event times
55
+ def forward(self, batch):
56
+ ...
57
+ return states
58
+
59
+
60
+ # Compute the loglikelihood loss
61
+ def loglike_loss(self, batch):
62
+ ...
63
+ return loglike
64
+
65
+ # Compute the intensities at given sampling times
66
+ # Used in the Thinning sampler
67
+ def compute_intensities_at_sample_times(self, batch, sample_times, **kwargs):
68
+ ...
69
+ return intensities
70
+
71
+ Rewrite Relevant Methods
72
+ ==============================
73
+
74
+ There are three important functions that need to be implemented (a toy end-to-end sketch follows the list):
75
+
76
+ - `forward`: the input is the batch data and the output is the states at each step.
77
+ - `loglike_loss`: it computes the log-likelihood loss given the batch data.
78
+ - `compute_intensities_at_sample_times`: it computes the intensities at each sampling step.
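+
+ To make these concrete, below is a toy end-to-end sketch: a model whose intensity for each event type is a single learnable constant. It is for orientation only, not one of the implemented models; it assumes each batch unpacks into the six elements described in `Dataset <../user_guide/dataset.html>`_ (time_seqs, time_delta_seqs, type_seqs, batch_non_pad_mask, attention_mask, type_mask), and the exact return convention of `loglike_loss` should be checked against ``TorchBaseModel``.
+
+ .. code-block:: python
+
+     import torch
+     import torch.nn as nn
+
+     from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel
+
+
+     class ConstIntensityModel(TorchBaseModel):
+         """Toy TPP: the intensity of each event type is a learnable constant."""
+
+         def __init__(self, model_config):
+             super(ConstIntensityModel, self).__init__(model_config)
+             self.n_types = model_config.num_event_types
+             # one learnable log-intensity per event type
+             self.log_lambda = nn.Parameter(torch.zeros(self.n_types))
+
+         def forward(self, batch):
+             time_seqs = batch[0]
+             batch_size, seq_len = time_seqs.size()
+             # constant "states": the same intensities at every event position
+             return self.log_lambda.exp().expand(batch_size, seq_len, self.n_types)
+
+         def loglike_loss(self, batch):
+             _, time_delta_seqs, type_seqs, batch_non_pad_mask, _, _ = batch
+             # clamp pad ids (== num_event_types) into the valid index range;
+             # the non-pad mask zeroes out their contribution anyway
+             safe_types = type_seqs.clamp(max=self.n_types - 1)
+             event_ll = (self.log_lambda[safe_types] * batch_non_pad_mask).sum()
+             # integral of the total intensity over the observed time horizon
+             non_event_ll = self.log_lambda.exp().sum() * (time_delta_seqs * batch_non_pad_mask).sum()
+             loglike = event_ll - non_event_ll
+             return -loglike  # negative log-likelihood as the training loss
+
+         def compute_intensities_at_sample_times(self, batch, sample_times, **kwargs):
+             # constant intensities broadcast to [..., num_event_types]
+             return self.log_lambda.exp().expand(*sample_times.size(), self.n_types)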
docs/source/get_started/install.rst ADDED
@@ -0,0 +1,64 @@
1
+ ==================
2
+ Installation
3
+ ==================
4
+
5
+
6
+ ``EasyTPP`` provides an open-source library for `Neural TPP`, with a fully automated pipeline for model training and prediction.
7
+
8
+
9
+ Requirements
10
+ =============
11
+
12
+ .. code-block:: bash
13
+
14
+ PyTorch version >= 1.8.0
15
+ Python version >= 3.7
16
+ Tensorflow version >= 1.13.1 (only needed when using Tensorflow backend)
17
+
18
+
19
+
20
+ First, we need a Python environment with version at least 3.7. If you don't have one, please refer to the `Documentation <https://docs.anaconda.com/anaconda/install/>`_ to install and configure the Anaconda environment.
21
+
22
+ .. code-block:: bash
23
+
24
+ conda create -n easytpp python=3.8
25
+ conda activate easytpp
26
+
27
+ Then, install PyTorch with version at least 1.8.0.
28
+
29
+ .. code-block:: bash
30
+
31
+ pip install torch
32
+
33
+ By default, we assume PyTorch is used. If one wants to use the Tensorflow backend, please install Tensorflow additionally. Both Tensorflow 1.13.1 and 2.x are supported.
34
+
35
+ .. code-block:: bash
36
+
37
+ pip install tensorflow
38
+
39
+
40
+
41
+ Install
42
+ =====================
43
+
44
+
45
+ Install with pip
46
+ --------------------------
47
+
48
+
49
+ .. code-block:: bash
50
+
51
+ pip install easy-tpp
52
+
53
+
54
+ Install from the source
55
+ --------------------------
56
+
57
+ Set it up from the source:
58
+
59
+ .. code-block:: bash
60
+
61
+ git clone https://github.com/ant-research/EasyTemporalPointProcess.git
62
+ cd EasyTemporalPointProcess
63
+ python setup.py install
64
+
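+ To verify the installation, one can print the package version (``__version__`` is defined in ``easy_tpp/__init__.py``):
+
+ .. code-block:: python
+
+     import easy_tpp
+
+     print(easy_tpp.__version__)  # e.g., '0.1.0' in this release
+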
docs/source/get_started/introduction.rst ADDED
@@ -0,0 +1,60 @@
1
+ ==================
2
+ Introduction
3
+ ==================
4
+
5
+
6
+ ``EasyTPP`` provides an open-source library for `Neural TPP`, with a fully automated pipeline for model training and prediction.
7
+
8
+
9
+ Framework
10
+ =========
11
+
12
+
13
+ ``EasyTPP`` supports both Tensorflow and PyTorch: each model has two equivalent versions implemented in Tensorflow 1.13 and PyTorch 1.8 respectively. The data processing and model training / prediction pipelines are compatible with both backends as well.
14
+
15
+
16
+ At the module level, ``EasyTPP`` is a package that consists of the following components, which are designed as loose-coupled modules that provide flexibility for users to develop customized functionalities.
17
+
18
+
19
+
20
+ ======================== ==============================================================================
21
+ Name Description
22
+ ======================== ==============================================================================
23
+ `Preprocess` module Provides batch-wise padding, inter-time processing and other related work for raw sequences.
24
+
25
+ `Model` module Implements a list of SOTA TPP models. Please refer to `Model Validation <../advanced/performance_valid.html>`_ for more details.
26
+
27
+ `Config` module Encapsulates the construction of the configurations needed to run the pipeline.
28
+
29
+ `Runner` module Controls the training and prediction pipeline.
30
+ ======================== ==============================================================================
31
+
32
+
33
+
34
+ Install
35
+ =========
36
+
37
+ ``EasyTPP`` can be installed either via pip or from the source. By default it is built on PyTorch. If one wants to run with the Tensorflow backend, one needs to install Tensorflow additionally.
38
+
39
+ Please see `Installation <./install.html>`_ for details of requirement and installation.
40
+
41
+
42
+ Prepare Data
43
+ ============
44
+
45
+ By default, we use the data in Gatech format, i.e., each dataset is a dict containing the keys such as `time_since_last_event`, `time_since_start` and `type_event`. `Preprocess <../ref/preprocess.html>`_ module
46
+ will preprocess the data and feed it into the model.
47
+
48
+
49
+ An example of building a pseudo dataloader can be found at `examples <https://github.com/ant-research/EasyTemporalPointProcess/tree/main/examples/data_loader.py>`_. Please refer to `Dataset <../user_guide/dataset.html>`_ for more explanations of the `TPP` dataset iterator.
50
+
51
+
52
+ Model Training and Prediction
53
+ ==============================
54
+
55
+ The training and prediction pipeline consists of two steps:
56
+
57
+ 1. Setup the config file, which specifies the dataset dir, model params and pipeline settings.
58
+ 2. Launch the python script to run the whole pipeline.
59
+
60
+ Please see `Training Pipeline <../user_guide/run_train_pipeline.html>`_ and `Evaluation Pipeline <../user_guide/run_eval.html>`_ for more details.
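+
+ For instance, the whole pipeline boils down to a few lines of Python; the following minimal sketch mirrors the preset ``examples/train_nhp.py`` script (the config path and experiment id are illustrative):
+
+ .. code-block:: python
+
+     from easy_tpp.config_factory import Config
+     from easy_tpp.runner import Runner
+
+     # read the yaml config and pick one experiment defined in it
+     config = Config.build_from_yaml_file('configs/experiment_config.yaml',
+                                          experiment_id='RMTPP_train')
+     # build the runner and launch the pipeline
+     model_runner = Runner.build_from_config(config)
+     model_runner.run()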
docs/source/get_started/quick_start.rst ADDED
@@ -0,0 +1,106 @@
1
+ ====================
2
+ Quick Start
3
+ ====================
4
+
5
+
6
+ We use the [Taxi]_ dataset as an example to show how to use ``EasyTPP`` to train a model. More details and results are provided in `Training Pipeline <../user_guide/run_train_pipeline.html>`_.
7
+
8
+
9
+ Download Dataset
10
+ ===================
11
+
12
+
13
+
14
+ The Taxi dataset we use is preprocessed by `HYPRO <https://github.com/iLampard/hypro_tpp>`_. You can download the dataset either in pickle format from Google Drive `here <https://drive.google.com/drive/folders/1vNX2gFuGfhoh-vngoebaQlj2-ZIZMiBo>`_ or in json format from `HuggingFace <https://huggingface.co/easytpp>`_.
15
+
16
+
17
+ Note that if the data sources are pickle files, we need to write the data config (in `Example Config <https://github.com/ant-research/EasyTemporalPointProcess/blob/main/examples/configs/experiment_config.yaml>`_) in the following way
18
+
19
+ .. code-block:: yaml
20
+
21
+ data:
22
+ taxi:
23
+ data_format: pickle
24
+ train_dir: ./data/taxi/train.pkl
25
+ valid_dir: ./data/taxi/dev.pkl
26
+ test_dir: ./data/taxi/test.pkl
27
+
28
+ If we choose to load the data directly from HuggingFace, we can put it this way:
29
+
30
+ .. code-block:: yaml
31
+
32
+ data:
33
+ taxi:
34
+ data_format: json
35
+ train_dir: easytpp/taxi
36
+ valid_dir: easytpp/taxi
37
+ test_dir: easytpp/taxi
38
+
39
+
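+ As a side note, since the json version is hosted as a HuggingFace dataset, it can also be inspected directly with the ``datasets`` library (a quick sketch; extra arguments may be needed depending on the dataset card):
+
+ .. code-block:: python
+
+     from datasets import load_dataset
+
+     ds = load_dataset('easytpp/taxi')  # the repo referenced in the config above
+     print(ds)
+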
40
+ Alternatively, it is also feasible to put the local directory of json files downloaded from HuggingFace in the config:
41
+
42
+ .. code-block:: yaml
43
+
44
+ data:
45
+ taxi:
46
+ data_format: json
47
+ train_dir: ./data/taxi/train.json
48
+ valid_dir: ./data/taxi/dev.json
49
+ test_dir: ./data/taxi/test.json
50
+
51
+
52
+
53
+
54
+ Setup the configuration file
55
+ ==============================
56
+
57
+ We provide a preset config file in `Example Config <https://github.com/ant-research/EasyTemporalPointProcess/blob/main/examples/configs/experiment_config.yaml>`_. The details of the configuration can be found in `Training Pipeline <../user_guide/run_train_pipeline.html>`_.
58
+
59
+
60
+
61
+
62
+ Train the Model
63
+ =========================
64
+
65
+ At this stage we need to write a script to run the training pipeline. There is a preset script `train_nhp.py <https://github.com/ant-research/EasyTemporalPointProcess/blob/main/examples/train_nhp.py>`_ and one can simply copy it.
66
+
67
+ Taking the pickle data source as an example: after setting up the data, the config and the running script, the directory structure is as follows:
68
+
69
+ .. code-block:: bash
70
+
71
+ data
72
+ |______taxi
73
+ |____ train.pkl
74
+ |____ dev.pkl
75
+ |____ test.pkl
76
+
77
+ configs
78
+ |______experiment_config.yaml
79
+
80
+ train_nhp.py
81
+
82
+
83
+
84
+ Then one can simply run the following command.
85
+
86
+
87
+ .. code-block:: bash
88
+
89
+ python train_nhp.py
90
+
91
+
92
+
93
+ Reference
94
+ ----------
95
+
96
+ .. [Taxi]
97
+
98
+ .. code-block:: bash
99
+
100
+ @misc{whong-14-taxi,
101
+ title = {F{OIL}ing {NYC}’s Taxi Trip Data},
102
+ author={Whong, Chris},
103
+ year = {2014},
104
+ url = {https://chriswhong.com/open-data/foil_nyc_taxi/}
105
+ }
106
+
docs/source/index.rst ADDED
@@ -0,0 +1,56 @@
1
+ ===================================
2
+ ``EasyTPP`` Documentation
3
+ ===================================
4
+
5
+
6
+ ``EasyTPP`` is an easy-to-use development and application toolkit for `Neural Temporal Point Process <https://mathworld.wolfram.com/TemporalPointProcess.html>`_ (*Neural TPP*), with key features in configurability, compatibility and reproducibility. We hope this project could benefit both researchers and practitioners with the goal of easily customized development and open benchmarking.
7
+
8
+
9
+
10
+ .. toctree::
11
+ :hidden:
12
+
13
+ .. toctree::
14
+ :maxdepth: 2
15
+ :caption: GETTING STARTED
16
+
17
+ Introduction <get_started/introduction.rst>
18
+ Installation <get_started/install.rst>
19
+ Quick Start <get_started/quick_start.rst>
20
+
21
+
22
+ .. toctree::
23
+ :maxdepth: 2
24
+ :caption: USER GUIDE
25
+
26
+ Dataset <user_guide/dataset.rst>
27
+ Model Training <user_guide/run_train_pipeline.rst>
28
+ Model Prediction <user_guide/run_eval.rst>
29
+
30
+ .. toctree::
31
+ :maxdepth: 2
32
+ :caption: DEVELOPER GUIDE
33
+
34
+ Model Customization <dev_guide/model_custom.rst>
35
+
36
+
37
+ .. toctree::
38
+ :maxdepth: 2
39
+ :caption: ADVANCED TOPICS
40
+
41
+ Thinning Algorithm <advanced/thinning_algo.rst>
42
+ Tensorboard <advanced/tensorboard.rst>
43
+ Performance Benchmarks <advanced/performance_valid.rst>
44
+ Implementation Details <advanced/implementation.rst>
45
+
46
+ .. toctree::
47
+ :maxdepth: 2
48
+ :caption: API REFERENCE
49
+
50
+ Config <ref/config.rst>
51
+ Preprocess <ref/preprocess.rst>
52
+ Model <ref/models.rst>
53
+ Runner <ref/runner.rst>
54
+ Hyper-parameter Optimization <ref/hpo.rst>
55
+ Tf and Torch Wrapper <ref/wrapper.rst>
56
+ Utilities <ref/utils.rst>
docs/source/ref/config.rst ADDED
@@ -0,0 +1,10 @@
1
+ .. _api-config:
2
+
3
+ EasyTPP Config Modules
4
+ ============================
5
+
6
+
7
+ .. automodule:: config_factory
8
+ :members:
9
+ :undoc-members:
10
+ :show-inheritance:
docs/source/ref/hpo.rst ADDED
@@ -0,0 +1,10 @@
1
+ .. _api-hpo:
2
+
3
+ EasyTPP HPO Modules
4
+ ============================
5
+
6
+
7
+ .. automodule:: hpo
8
+ :members:
9
+ :undoc-members:
10
+ :show-inheritance:
docs/source/ref/models.rst ADDED
@@ -0,0 +1,50 @@
1
+ .. _api-model:
2
+
3
+ EasyTPP Models
4
+ ====================
5
+
6
+
7
+
8
+ .. _api-tf_model:
9
+
10
+ model.tf_model module
11
+ ------------------------------
12
+
13
+ .. automodule:: easy_tpp.model.tf_model
14
+ .. autosummary::
15
+ :toctree: ../generated/
16
+
17
+ tf_baselayer
18
+ tf_basemodel
19
+ tf_nhp
20
+ tf_fullynn
21
+ tf_intensity_free
22
+ tf_ode_tpp
23
+ tf_rmtpp
24
+ tf_sahp
25
+ tf_thp
26
+ tf_attnhp
27
+ tf_thinning
28
+
29
+
30
+ .. _api-torch_model:
31
+
32
+ model.torch_model module
33
+ ------------------------------
34
+
35
+ .. automodule:: easy_tpp.model.torch_model
36
+ .. autosummary::
37
+ :toctree: ../generated/
38
+
39
+ torch_baselayer
40
+ torch_basemodel
41
+ torch_nhp
42
+ torch_fullynn
43
+ torch_intensity_free
44
+ torch_ode_tpp
45
+ torch_rmtpp
46
+ torch_sahp
47
+ torch_thp
48
+ torch_attnhp
49
+ torch_thinning
50
+
docs/source/ref/preprocess.rst ADDED
@@ -0,0 +1,10 @@
1
+ .. _api-preprocess:
2
+
3
+ EasyTPP Preprocess Modules
4
+ ==========================
5
+
6
+
7
+ .. automodule:: preprocess
8
+ :members:
9
+ :undoc-members:
10
+ :show-inheritance:
docs/source/ref/runner.rst ADDED
@@ -0,0 +1,10 @@
1
+ .. _api-modelrunner:
2
+
3
+ EasyTPP Model Runner Modules
4
+ ============================
5
+
6
+
7
+ .. automodule:: runner
8
+ :members:
9
+ :undoc-members:
10
+ :show-inheritance:
docs/source/ref/utils.rst ADDED
@@ -0,0 +1,10 @@
1
+ .. _api-util:
2
+
3
+ EasyTPP Utilities Modules
4
+ ==========================
5
+
6
+
7
+ .. automodule:: utils
8
+ :members:
9
+ :undoc-members:
10
+ :show-inheritance:
docs/source/ref/wrapper.rst ADDED
@@ -0,0 +1,17 @@
1
+ .. _api-wrapper:
2
+
3
+ EasyTPP Tf and Torch Wrapper Modules
4
+ ====================================
5
+
6
+
7
+ .. automodule:: tf_wrapper
8
+ :members:
9
+ :undoc-members:
10
+ :show-inheritance:
11
+
12
+
13
+
14
+ .. automodule:: torch_wrapper
15
+ :members:
16
+ :undoc-members:
17
+ :show-inheritance:
docs/source/user_guide/dataset.rst ADDED
@@ -0,0 +1,124 @@
1
+ ===========================================
2
+ Expected Dataset Format and Data Processing
3
+ ===========================================
4
+
5
+ Required format
6
+ ===================================
7
+
8
+ In EasyTPP we use the data in Gatech format, i.e., each dataset is a dict containing the following keys:
9
+
10
+ .. code-block:: bash
11
+
12
+ dim_process: 5 # num of event types (no padding)
13
+ 'train': [[{'idx_event': 2, 'time_since_last_event': 1.0267814, 'time_since_last_same_event': 1.0267814, 'type_event': 3, 'time_since_start': 1.0267814}, {'idx_event': 3, 'time_since_last_event': 0.4029268, 'time_since_last_same_event': 1.4297082, 'type_event': 0, 'time_since_start': 1.4297082},...,],[{}...{}]]
14
+
15
+ where `dim_process` refers to the number of event types (without padding) and
+ `train` (or `dev` / `test`) contains a list of lists, each of which corresponds to one event sequence.
17
+
18
+ Each pickle file yields a set of event sequences, each containing three sub-sequences:
19
+
20
+ 1. `time_seqs`: absolute timestamps of the events, corresponding to `time_since_start`.
21
+ 2. `time_delta_seqs`: relative timestamps (inter-event times) of the events, corresponding to `time_since_last_event`.
22
+ 3. `type_seqs`: types of the events, corresponding to `type_event`. Note that the event type index `starts from 0`.
23
+
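+ For example, a file in this format can be loaded and inspected with plain ``pickle`` (a minimal sketch; the file path is illustrative):
+
+ .. code-block:: python
+
+     import pickle
+
+     with open('./data/taxi/train.pkl', 'rb') as f:
+         data = pickle.load(f)
+
+     print(data['dim_process'])   # number of event types (no padding)
+     seq = data['train'][0]       # the first event sequence: a list of event dicts
+     event = seq[0]
+     print(event['type_event'],             # event type index, starts from 0
+           event['time_since_start'],       # absolute timestamp
+           event['time_since_last_event'])  # inter-event time
+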
24
+
25
+ Data processing
26
+ ===================================
27
+
28
+ The data processing follows a similar pipeline to the one in the official code of `AttNHP <https://github.com/yangalan123/anhp-andtt>`_. We call this process `event tokenization`.
29
+
30
+
31
+ Sequence padding
32
+ ----------------
33
+
34
+
35
+ time_seqs, time_delta_seqs and type_seqs are first padded to `the max length of the whole dataset` and then fed into the model in batches.
36
+
37
+ .. code-block:: bash
38
+
39
+ input: raw event sequence (e_0, e_1, e_2, e_3) and max_len=6 # the max length among all data seqs
40
+
41
+ output:
42
+
43
+ index: 0, 1, 2, 3, 4 5
44
+ dtimes: 0, t_1-t_0, t_2-t_1, t_3-t_2, time_pad, time_pad
45
+ types: e_0, e_1, e_2, e_3, type_pad, type_pad
46
+
47
+
48
+ By default, we set the value of time_pad and type_pad to *num_event_types* (since we assume the event type index starts from 0, the integer value num_event_types itself is unused).
49
+
50
+ Sequence masking
51
+ ----------------
52
+
53
+
54
+ After padding, we apply masking to the event sequences and generate three more sequences: batch_non_pad_mask, attention_mask and type_mask:
55
+
56
+ 1. `batch_non_pad_mask`: it indicates the positions of the real (non-padding) events in the sequence.
57
+ 2. `attention_mask`: it indicates the masks used in the attention calculation (one event can only attend to its past events).
58
+ 3. `type_mask`: it uses a one-hot vector to represent the event type. A padded event is represented by a zero vector.
59
+
60
+ Finally, each batch contains six elements: time_seqs, time_delta_seqs, type_seqs, batch_non_pad_mask, attention_mask and type_mask. The implementation of the padding mechanism can be found at `event_tokenizer <https://github.com/ant-research/EasyTemporalPointProcess/blob/main/easy_tpp/preprocess/event_tokenizer.py>`_.
61
+
62
+
63
+
64
+ An example
65
+ ----------------
66
+
67
+ We take a real event sequence as an example. Assume we have an input type sequence ``[1, 9, 5, 0]`` with num_event_types=11 and max_len=6.
68
+
69
+ Then the padded time_seqs, time_delta_seqs and type_seqs become
70
+
71
+ .. code-block:: bash
72
+
73
+ # time_seqs
74
+ [ 0.0000, 0.8252, 1.3806, 1.8349, 11.0000, 11.0000]
75
+
76
+ # time_delta_seqs
77
+ [ 0.0000, 0.8252, 0.5554, 0.4542, 11.0000, 11.0000]
78
+
79
+ # type_seqs
80
+ [ 1, 9, 5, 0, 11, 11]
81
+
82
+
83
+ The mask sequences are
84
+
85
+ .. code-block:: bash
86
+
87
+ # batch_non_pad_mask
88
+ [ True, True, True, True, False, False]
89
+
90
+ # attention_mask
91
+ [[True, True, True, True, True, True],
92
+ [False, True, True, True, True, True],
93
+ [False, False, True, True, True, True],
94
+ [False, False, False, True, True, True],
95
+ [False, False, False, False, True, True],
96
+ [False, False, False, False, True, True]]
97
+
98
+ # type_mask
99
+ [[False, True, False, False, False, False, False, False, False, False, False],
100
+ [False, False, False, False, False, False, False, False, False, True, False],
101
+ [False, False, False, False, False, True, False, False, False, False, False],
102
+ [True, False, False, False, False, False, False, False, False, False, False],
103
+ [False, False, False, False, False, False, False, False, False, False, False],
104
+ [False, False, False, False, False, False, False, False, False, False, False]],
105
+
106
+
107
+ The runnable examples of constructing and iterating the dataset object can be found at `examples/event_tokenizer.py <https://github.com/ant-research/EasyTemporalPointProcess/blob/main/examples/event_tokenizer.py>`_
108
+
109
+
110
+ Preprocessed Datasets
111
+ ===================================
112
+
113
+ We have preprocessed some widely-used open source datasets in Gatech format, which can be found at `Google Drive <https://drive.google.com/drive/folders/0BwqmV0EcoUc8UklIR1BKV25YR1U?resourcekey=0-OrlU87jyc1m-dVMmY5aC4w>`_. We use them for validating and benchmarking EasyTPP models.
114
+
115
+ - Retweet (`Zhou, 2013 <http://proceedings.mlr.press/v28/zhou13.pdf>`_). This dataset contains time-stamped user retweet event sequences. The events are categorized into 3 types: retweets by “small,” “medium” and “large” users. Small users have fewer than 120 followers, medium users have fewer than 1363, and the rest are large users. We work on a subset of the 5200 most active users with an average sequence length of 70.
116
+ - Taxi (`Whong, 2014 <https://chriswhong.com/open-data/foil_nyc_taxi>`_). This dataset tracks the time-stamped taxi pick-up and drop-off events across the five boroughs of New York City; each (borough, pick-up or drop-off) combination defines an event type, so there are 10 event types in total. We work on a randomly sampled subset of 2000 drivers and each driver has a sequence. We randomly sampled disjoint train, dev and test sets with 1400, 200 and 400 sequences.
117
+ - StackOverflow (`Leskovec, 2014 <https://snap.stanford.edu/data/>`_). This dataset has two years of user awards on a question-answering website: each user received a sequence of badges and there are 22 different kinds of badges in total. We randomly sampled disjoint train, dev and test sets with 1400, 400 and 400 sequences from the dataset.
118
+ - Taobao (`Xue et al, 2022 <https://arxiv.org/abs/2210.01753>`_). This dataset contains time-stamped user click behaviors on Taobao shopping pages from November 25 to December 03, 2017. Each user has a sequence of item click events, with each event containing the timestamp and the category of the item. The categories of all items are first ranked by frequencies; the top 19 are kept and the rest are merged into one category, with each category corresponding to an event type. We work on a subset of the 4800 most active users with an average sequence length of 150, ending up with 20 event types.
119
+ - Amazon (`Xue et al, 2022 <https://arxiv.org/abs/2210.01753>`_). This dataset includes time-stamped user product review behavior from January 2008 to October 2018. Each user has a sequence of product review events, with each event containing the timestamp and the category of the reviewed product; each category corresponds to an event type. We work on a subset of the 5200 most active users with an average sequence length of 70, ending up with 16 event types.
120
+
121
+ Besides, we also published two textual event sequence datasets:
122
+
123
+ - GDELT (`Shi et al, 2023 <https://arxiv.org/abs/2305.16646>`_). The GDELT Project monitors events all over the world, with live datasets updated every 15 minutes. We only focused on the political events that happened in G20 countries from 2022-01-01 to 2022-07-31, ending up with a corpus of 109000 time-stamped event tokens. The event type of each token has a structured name of the format subject-predicate-object. Each {predicate} is one of the twenty CAMEO codes such as {CONSULT} and {INVESTIGATE}; each {subject} or {object} is one of the 2279 political entities (individuals, groups, and states) such as {Tesla} and {Australia}. We split the dataset into disjoint train, dev, and test sets based on their dates: the 83100 events that happened before 2022-07-05 are training data; the 16650 events after 2022-07-19 are test data; the 9250 events between these dates are development data.
124
+ - Amazon-text-review (`Shi et al, 2023 <https://arxiv.org/abs/2305.16646>`_). This dataset contains user reviews on Amazon shopping website from 2014-01-04 to 2016-10-02. We focused on the most active 2500 users and each user has a sequence of product review events. The type is the category of the product: we selected the most frequently-reviewed 23 categories and grouped all the others into a special OTHER category, ending up with 24 categories in total. Each review event also has a mark which is the actual content of the review. Each of the 2500 sequences is cut into three segments: the events that happened before 2015-08-01 are training data; those after 2016-02-01 are test data; the events between these dates are dev data. Then we have 49,680 training tokens, 7,020 dev tokens, and 13,090 test tokens.
docs/source/user_guide/run_eval.rst ADDED
@@ -0,0 +1,97 @@
1
+ ================================
2
+ Evaluate a Model
3
+ ================================
4
+
5
+ Step 1: Setup the config file
6
+ ===============================================
7
+
8
+ Same as in the training pipeline, we first need to initialize the task configuration in the config file.
9
+
10
+ Similar to the setup in `Training Pipeline <./run_train_pipeline.html>`_, we set the `stage` to `eval` and pass the `pretrained_model_dir` to the ``model_config``.
11
+
12
+ Note that the *pretrained_model_dir* can be found in the log of the training process.
13
+
14
+ .. code-block:: yaml
15
+
16
+ NHP_eval:
17
+ base_config:
18
+ stage: eval
19
+ backend: torch
20
+ dataset_id: taxi
21
+ runner_id: std_tpp
22
+ base_dir: './checkpoints/'
23
+ model_id: NHP
24
+ trainer_config:
25
+ batch_size: 256
26
+ max_epoch: 1
27
+ model_config:
28
+ hidden_size: 64
29
+ use_ln: False
30
+ seed: 2019
31
+ gpu: 0
32
+ pretrained_model_dir: ./checkpoints/26507_4380788096_231111-101848/models/saved_model # must provide this dir
33
+ thinning:
34
+ num_seq: 10
35
+ num_sample: 1
36
+ num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm
37
+ look_ahead_time: 10
38
+ patience_counter: 5 # the maximum iteration used in adaptive thinning
39
+ over_sample_rate: 5
40
+ num_samples_boundary: 5
41
+ dtime_max: 5
42
+
43
+
44
+
45
+
46
+ A complete example of these files can be seen at `examples/example_config.yaml <https://github.com/ant-research/EasyTemporalPointProcess/blob/main/examples/configs/experiment_config.yaml>`_ .
47
+
48
+
49
+ Step 2: Run the evaluation script
50
+ =================================
51
+
52
+ Same as in the training pipeline, we need to initialize a ``ModelRunner`` object to do the evaluation.
53
+
54
+ The following code is an example, which is a copy from `examples/train_nhp.py <https://github.com/ant-research/EasyTemporalPointProcess/blob/main/examples/train_nhp.py>`_ .
55
+
56
+
57
+ .. code-block:: python
58
+
59
+ import argparse
60
+
61
+ from easy_tpp.config_factory import RunnerConfig
62
+ from easy_tpp.runner import Runner
63
+
64
+
65
+ def main():
66
+ parser = argparse.ArgumentParser()
67
+
68
+ parser.add_argument('--config_dir', type=str, required=False, default='configs/experiment_config.yaml',
69
+ help='Dir of configuration yaml to train and evaluate the model.')
70
+
71
+ parser.add_argument('--experiment_id', type=str, required=False, default='RMTPP_eval',
72
+ help='Experiment id in the config file.')
73
+
74
+ args = parser.parse_args()
75
+
76
+ config = RunnerConfig.build_from_yaml_file(args.config_dir, experiment_id=args.experiment_id)
77
+
78
+ model_runner = Runner.build_from_config(config)
79
+
80
+ model_runner.run()
81
+
82
+
83
+ if __name__ == '__main__':
84
+ main()
85
+
86
+
87
+
88
+
89
+ Checkout the output
90
+ ====================
91
+
92
+ The evaluation result will be printed in the console and saved in the logs, whose directory is specified in the
+ output config file, i.e.:
94
+
95
+ .. code-block:: bash
96
+
97
+ 'output_config_dir': './checkpoints/NHP_test_conttime_20221002-13:19:23/NHP_test_output.yaml'
docs/source/user_guide/run_train_pipeline.rst ADDED
@@ -0,0 +1,245 @@
1
+ ============================================
2
+ Training a Model & Configuration Explanation
3
+ ============================================
4
+
5
+ This tutorial shows how one can use ``EasyTPP`` to train the implemented models.
6
+
7
+ In principle, firstly we need to initialize a config yaml file, containing all the input configuration to guide the training and eval process. The overall structure of a config file is shown as below:
8
+
9
+ .. code-block:: yaml
10
+
11
+ pipeline_config_id: .. # name of the config for guiding the pipeline
12
+
13
+ data:
14
+ [Dataset ID]: # name of the dataset, e.g, taxi
15
+ ....
16
+
17
+ [EXPERIMENT ID]: # name of the experiment to run
18
+ base_config:
19
+ ....
20
+ model_config:
21
+ ...
22
+
23
+
24
+ After the config file is set up, we can run the script, specifying the `config directory` and `experiment id`, to start the pipeline. We currently provide a preset script at `examples/train_nhp.py`.
25
+
26
+
27
+ Step 1: Setup the config file containing data and model configs
28
+ ================================================================
29
+
30
+
31
+ To be specific, one needs to define the following entries in the config file:
32
+
33
+ - **pipeline_config_id**: registered name of EasyTPP.Config objects, such as `runner_config` or `hpo_runner_config`. By reading this, the corresponding configuration class will be loaded for constructing the pipeline.
34
+
35
+ .. code-block:: yaml
36
+
37
+ pipeline_config_id: runner_config
38
+
39
+
40
+ - **data**: dataset specifics. One can put multiple dataset specifics in the config file, but only one will be used in one experiment.
41
+
42
+ - *[DATASET ID]*: name of the dataset, e.g., taxi.
43
+ - *train_dir, valid_dir, test_dir*: directories of the data files, in pickle or json format (please see `Dataset <./dataset.html>`_ for details).
44
+ - *data_spec*: define the event type information.
45
+
46
+ .. code-block:: yaml
47
+
48
+ data:
49
+ taxi:
50
+ data_format: pkl
51
+ train_dir: ../data/taxi/train.pkl
52
+ valid_dir: ../data/taxi/dev.pkl
53
+ test_dir: ../data/taxi/test.pkl
54
+ data_spec:
55
+ num_event_types: 7 # num of types excluding pad events.
56
+ pad_token_id: 6 # event type index for pad events
57
+ padding_side: right # pad at the right end of the sequence
58
+ truncation_side: right # truncate at the right end of the sequence
59
+ max_len: 100 # max sequence length used as model input
60
+
61
+ - **[EXPERIMENT ID]**: name of the experiment to run in the pipeline. It contains two blocks of configs:
62
+
63
+ *base_config* contains the pipeline framework related specifications.
64
+
65
+ .. code-block:: yaml
66
+
67
+ base_config:
68
+ stage: train # train, eval and generate
69
+ backend: tensorflow # tensorflow and torch
70
+ dataset_id: conttime # name of the dataset
71
+ runner_id: std_tpp # registered name of the pipeline runner
72
+ model_id: RMTPP # registered name of the implemented model
73
+ base_dir: './checkpoints/' # base dir to save the logs and models.
74
+
75
+
76
+
77
+ *model_config* contains the model related specifications.
78
+
79
+
80
+ .. code-block:: yaml
81
+
82
+ model_config:
83
+ hidden_size: 32
84
+ time_emb_size: 16
85
+ num_layers: 2
86
+ num_heads: 2
87
+ mc_num_sample_per_step: 20
88
+ sharing_param_layer: False
89
+ loss_integral_num_sample_per_step: 20
90
+ dropout: 0.0
91
+ use_ln: False
92
+ thinning_params: # thinning algorithm for event sampling
93
+ num_seq: 10
94
+ num_sample: 1
95
+ num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm
96
+ look_ahead_time: 10
97
+ patience_counter: 5 # the maximum iteration used in adaptive thinning
98
+ over_sample_rate: 5
99
+ num_samples_boundary: 5
100
+ dtime_max: 5
101
+
102
+
103
+ *trainer_config* contains the training related specifications.
104
+
105
+ .. code-block:: yaml
106
+
107
+ trainer_config: # trainer arguments
108
+ seed: 2019
109
+ gpu: 0
110
+ batch_size: 256
111
+ max_epoch: 10
112
+ shuffle: False
113
+ optimizer: adam
114
+ learning_rate: 1.e-3
115
+ valid_freq: 1
116
+ use_tfb: False
117
+ metrics: ['acc', 'rmse']
118
+
119
+
120
+
121
+
122
+ A complete example of these files can be seen at *examples/example_config*.
123
+
124
+
125
+ Step 2: Run the training script
126
+ ===============================================
127
+
128
+ To run the training process, we simply need to call two functions:
129
+
130
+ 1. ``Config``: it reads the directory of the configs specified in Step 1 and does some processing to form a complete configuration.
131
+ 2. ``Runner``: it reads the configuration and sets up the whole pipeline for training, evaluation and generation.
132
+
133
+
134
+ The following code is an example, which is a copy from *examples/train_nhp.py*.
135
+
136
+
137
+ .. code-block:: python
138
+
139
+ import argparse
140
+ from easy_tpp.config_factory import Config
141
+ from easy_tpp.runner import Runner
142
+
143
+
144
+ def main():
145
+ parser = argparse.ArgumentParser()
146
+
147
+ parser.add_argument('--config_dir', type=str, required=False, default='configs/experiment_config.yaml',
148
+ help='Dir of configuration yaml to train and evaluate the model.')
149
+
150
+ parser.add_argument('--experiment_id', type=str, required=False, default='RMTPP_train',
151
+ help='Experiment id in the config file.')
152
+
153
+ args = parser.parse_args()
154
+
155
+ config = Config.build_from_yaml_file(args.config_dir, experiment_id=args.experiment_id)
156
+
157
+ model_runner = Runner.build_from_config(config)
158
+
159
+ model_runner.run()
160
+
161
+
162
+ if __name__ == '__main__':
163
+ main()
164
+
165
+
166
+
167
+
168
+
169
+ Checkout the output
170
+ ========================
171
+
172
+
173
+ During training, the log, the best model (selected by the performance on the valid set) and the complete configuration file are all saved. The directory of the saved files is determined by ``base_dir`` specified in the ``base_config``.
174
+
175
+
176
+
177
+ In the `./checkpoints/` folder, one finds the correct subfolder by concatenating the 'experiment_id' and the running timestamp. Inside that subfolder, there is a complete configuration file, e.g., ``NHP_train_output.yaml``, that records all the information used in the pipeline, for example:
178
+
179
+ .. code-block:: yaml
180
+
181
+ data_config:
182
+ train_dir: ../data/conttime/train.pkl
183
+ valid_dir: ../data/conttime/dev.pkl
184
+ test_dir: ../data/conttime/test.pkl
185
+ specs:
186
+ num_event_types_pad: 6
187
+ num_event_types: 5
188
+ event_pad_index: 5
189
+ data_format: pkl
190
+ base_config:
191
+ stage: train
192
+ backend: tensorflow
193
+ dataset_id: conttime
194
+ runner_id: std_tpp
195
+ model_id: RMTPP
196
+ base_dir: ./checkpoints/
197
+ exp_id: RMTPP_train
198
+ log_folder: ./checkpoints/98888_4299965824_221205-153425
199
+ saved_model_dir: ./checkpoints/98888_4299965824_221205-153425/models/saved_model
200
+ saved_log_dir: ./checkpoints/98888_4299965824_221205-153425/log
201
+ output_config_dir: ./checkpoints/98888_4299965824_221205-153425/RMTPP_train_output.yaml
202
+ model_config:
203
+ hidden_size: 32
204
+ time_emb_size: 16
205
+ num_layers: 2
206
+ num_heads: 2
207
+ mc_num_sample_per_step: 20
208
+ sharing_param_layer: false
209
+ loss_integral_num_sample_per_step: 20
210
+ dropout: 0.0
211
+ use_ln: false
212
+ seed: 2019
213
+ gpu: 0
214
+ thinning_params:
215
+ num_seq: 10
216
+ num_sample: 1
217
+ num_exp: 500
218
+ look_ahead_time: 10
219
+ patience_counter: 5
220
+ over_sample_rate: 5
221
+ num_samples_boundary: 5
222
+ dtime_max: 5
223
+ num_step_gen: 1
224
+ trainer:
225
+ batch_size: 256
226
+ max_epoch: 10
227
+ shuffle: false
228
+ optimizer: adam
229
+ learning_rate: 0.001
230
+ valid_freq: 1
231
+ use_tfb: false
232
+ metrics:
233
+ - acc
234
+ - rmse
235
+ seq_pad_end: true
236
+ is_training: true
237
+ num_event_types_pad: 6
238
+ num_event_types: 5
239
+ event_pad_index: 5
240
+ model_id: RMTPP
241
+
242
+
243
+
244
+ If we set ``use_tfb`` to ``true``, the Tensorboard can be launched to track the training process; one
+ can see `Running Tensorboard <../advanced/tensorboard.html>`_ for details.
easy_tpp/__init__.py ADDED
@@ -0,0 +1 @@
1
+ __version__ = '0.1.0'
easy_tpp/config_factory/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ from easy_tpp.config_factory.config import Config
2
+ from easy_tpp.config_factory.data_config import DataConfig, DataSpecConfig
3
+ from easy_tpp.config_factory.hpo_config import HPOConfig, HPORunnerConfig
4
+ from easy_tpp.config_factory.runner_config import RunnerConfig, ModelConfig, BaseConfig
5
+
6
+ __all__ = ['Config',
7
+ 'DataConfig',
8
+ 'DataSpecConfig',
9
+ 'ModelConfig',
10
+ 'BaseConfig',
11
+ 'RunnerConfig',
12
+ 'HPOConfig',
13
+ 'HPORunnerConfig']
easy_tpp/config_factory/config.py ADDED
@@ -0,0 +1,120 @@
1
+ from abc import abstractmethod
2
+ from typing import Any
3
+ from omegaconf import OmegaConf
4
+
5
+ from easy_tpp.utils import save_yaml_config, Registrable, logger
6
+
7
+
8
+ class Config(Registrable):
9
+
10
+ def save_to_yaml_file(self, config_dir):
11
+ """Save the config into the yaml file 'config_dir'.
12
+
13
+ Args:
14
+ config_dir (str): Target filename.
15
+
16
+ Returns:
17
+ """
18
+ yaml_config = self.get_yaml_config()
19
+ OmegaConf.save(yaml_config, config_dir)
20
+
21
+ @staticmethod
22
+ def build_from_yaml_file(yaml_dir, **kwargs):
23
+ """Load yaml config file from disk.
24
+
25
+ Args:
26
+ yaml_dir (str): Path of the yaml config file.
27
+
28
+ Returns:
29
+ EasyTPP.Config: Config object corresponding to cls.
30
+ """
31
+ config = OmegaConf.load(yaml_dir)
32
+ pipeline_config = config.get('pipeline_config_id')
33
+ config_cls = Config.by_name(pipeline_config.lower())
34
+ logger.critical(f'Load pipeline config class {config_cls.__name__}')
35
+ return config_cls.parse_from_yaml_config(config, **kwargs)
36
+
37
+ @abstractmethod
38
+ def get_yaml_config(self):
39
+ """Get the yaml format config from self.
40
+
41
+ Returns:
42
+ """
43
+ pass
44
+
45
+ @staticmethod
46
+ @abstractmethod
47
+ def parse_from_yaml_config(yaml_config):
48
+ """Parse from the yaml to generate the config object.
49
+
50
+ Args:
51
+ yaml_config (dict): configs from yaml file.
52
+
53
+ Returns:
54
+ EasyTPP.Config: Config class for data.
55
+ """
56
+ pass
57
+
58
+ @abstractmethod
59
+ def copy(self):
60
+ """Get a same and freely modifiable copy of self.
61
+
62
+ Returns:
63
+ """
64
+ pass
65
+
66
+ def __str__(self):
67
+ """Str representation of the config.
68
+
69
+ Returns:
70
+ str: str representation of the dict format of the config.
71
+ """
72
+ return str(self.get_yaml_config())
73
+
74
+ def update(self, config):
75
+ """Update the config.
76
+
77
+ Args:
78
+ config (dict): config dict.
79
+
80
+ Returns:
81
+ EasyTPP.Config: Config class for data.
82
+ """
83
+ logger.critical(f'Update config class {self.__class__.__name__}')
84
+ return self.parse_from_yaml_config(config)
85
+
86
+ def pop(self, key: str, default_var: Any):
87
+ """pop out the key-value item from the config.
88
+
89
+ Args:
90
+ key (str): key name.
91
+ default_var (Any): default value to return if the key is absent.
92
+
93
+ Returns:
94
+ Any: value to pop.
95
+ """
96
+ return vars(self).pop(key, default_var)
97
+
98
+ def get(self, key: str, default_var: Any):
99
+ """Retrieve the key-value item from the config.
100
+
101
+ Args:
102
+ key (str): key name.
103
+ default_var (Any): default value to return if the key is absent.
104
+
105
+ Returns:
106
+ Any: value to get.
107
+ """
108
+ return vars(self).get(key, default_var)
109
+
110
+ def set(self, key: str, var_to_set: Any):
111
+ """Set the key-value item from the config.
112
+
113
+ Args:
114
+ key (str): key name.
115
+ var_to_set (Any): value to set.
116
+
117
+ Returns:
118
+ None.
119
+ """
120
+ vars(self)[key] = var_to_set
easy_tpp/config_factory/data_config.py ADDED
@@ -0,0 +1,147 @@
1
+ from easy_tpp.config_factory.config import Config
2
+
3
+
4
+ class DataSpecConfig(Config):
5
+ def __init__(self, **kwargs):
6
+ """Initialize the Config class.
7
+ """
8
+ self.num_event_types = kwargs.get('num_event_types')
9
+ self.pad_token_id = kwargs.get('pad_token_id')
10
+ self.padding_side = kwargs.get('padding_side')
11
+ self.truncation_side = kwargs.get('truncation_side')
12
+ self.padding_strategy = kwargs.get('padding_strategy')
13
+ self.max_len = kwargs.get('max_len')
14
+ self.truncation_strategy = kwargs.get('truncation_strategy')
15
+ self.num_event_types_pad = self.num_event_types + 1 if self.num_event_types is not None else None
16
+ self.model_input_names = kwargs.get('model_input_names')
17
+
18
+ if self.padding_side is not None and self.padding_side not in ["right", "left"]:
19
+ raise ValueError(
20
+ f"Padding side should be selected between 'right' and 'left', current value: {self.padding_side}"
21
+ )
22
+
23
+ if self.truncation_side is not None and self.truncation_side not in ["right", "left"]:
24
+ raise ValueError(
25
+ f"Truncation side should be selected between 'right' and 'left', current value: {self.truncation_side}"
26
+ )
27
+
28
+ def get_yaml_config(self):
29
+ """Return the config in dict (yaml compatible) format.
30
+
31
+ Returns:
32
+ dict: config of the data specs in dict format.
33
+ """
34
+ return {
35
+ 'num_event_types': self.num_event_types,
36
+ 'pad_token_id': self.pad_token_id,
37
+ 'padding_side': self.padding_side,
38
+ 'truncation_side': self.truncation_side,
39
+ 'padding_strategy': self.padding_strategy,
40
+ 'truncation_strategy': self.truncation_strategy,
41
+ 'max_len': self.max_len
42
+ }
43
+
44
+ @staticmethod
45
+ def parse_from_yaml_config(yaml_config):
46
+ """Parse from the yaml to generate the config object.
47
+
48
+ Args:
49
+ yaml_config (dict): configs from yaml file.
50
+
51
+ Returns:
52
+ DataSpecConfig: Config class for data specs.
53
+ """
54
+ return DataSpecConfig(**yaml_config) if yaml_config is not None else None
55
+
56
+ def copy(self):
57
+ """Copy the config.
58
+
59
+ Returns:
60
+ DataSpecConfig: a copy of current config.
61
+ """
62
+ return DataSpecConfig(num_event_types_pad=self.num_event_types_pad,
63
+ num_event_types=self.num_event_types,
64
+ pad_token_id=self.pad_token_id,
65
+ padding_side=self.padding_side,
66
+ truncation_side=self.truncation_side,
67
+ padding_strategy=self.padding_strategy,
68
+ truncation_strategy=self.truncation_strategy,
69
+ max_len=self.max_len)
70
+
71
+
72
+ @Config.register('data_config')
73
+ class DataConfig(Config):
74
+ def __init__(self, train_dir, valid_dir, test_dir, data_format, specs=None):
75
+ """Initialize the DataConfig object.
76
+
77
+ Args:
78
+ train_dir (str): dir of train set.
79
+ valid_dir (str): dir of valid set.
80
+ test_dir (str): dir of test set.
81
+ data_format (str): format of the data files; if None, it is inferred from the suffix of train_dir.
+ specs (dict, optional): specs of dataset. Defaults to None.
82
+ """
83
+ self.train_dir = train_dir
84
+ self.valid_dir = valid_dir
85
+ self.test_dir = test_dir
86
+ self.data_specs = specs or DataSpecConfig()
87
+ self.data_format = train_dir.split('.')[-1] if data_format is None else data_format
88
+
89
+ def get_yaml_config(self):
90
+ """Return the config in dict (yaml compatible) format.
91
+
92
+ Returns:
93
+ dict: config of the data in dict format.
94
+ """
95
+ return {
96
+ 'train_dir': self.train_dir,
97
+ 'valid_dir': self.valid_dir,
98
+ 'test_dir': self.test_dir,
99
+ 'data_format': self.data_format,
100
+ 'data_specs': self.data_specs.get_yaml_config(),
101
+ }
102
+
103
+ @staticmethod
104
+ def parse_from_yaml_config(yaml_config):
105
+ """Parse from the yaml to generate the config object.
106
+
107
+ Args:
108
+ yaml_config (dict): configs from yaml file.
109
+
110
+ Returns:
111
+ EasyTPP.DataConfig: Config class for data.
112
+ """
113
+ return DataConfig(
114
+ train_dir=yaml_config.get('train_dir'),
115
+ valid_dir=yaml_config.get('valid_dir'),
116
+ test_dir=yaml_config.get('test_dir'),
117
+ data_format=yaml_config.get('data_format'),
118
+ specs=DataSpecConfig.parse_from_yaml_config(yaml_config.get('data_specs'))
119
+ )
120
+
121
+ def copy(self):
122
+ """Copy the config.
123
+
124
+ Returns:
125
+ EasyTPP.DataConfig: a copy of current config.
126
+ """
127
+ return DataConfig(train_dir=self.train_dir,
128
+ valid_dir=self.valid_dir,
129
+ test_dir=self.test_dir,
130
+ data_format=self.data_format,
+ specs=self.data_specs)
131
+
132
+ def get_data_dir(self, split):
133
+ """Get the dir of the source raw data.
134
+
135
+ Args:
136
+ split (str): dataset split notation, 'train', 'dev' or 'valid', 'test'.
137
+
138
+ Returns:
139
+ str: dir of the source raw data file.
140
+ """
141
+ split = split.lower()
142
+ if split == 'train':
143
+ return self.train_dir
144
+ elif split in ['dev', 'valid']:
145
+ return self.valid_dir
146
+ else:
147
+ return self.test_dir
easy_tpp/config_factory/hpo_config.py ADDED
@@ -0,0 +1,132 @@
1
+ from easy_tpp.config_factory.config import Config
2
+ from easy_tpp.config_factory.runner_config import RunnerConfig
3
+ from easy_tpp.utils import parse_uri_to_protocol_and_path, py_assert
4
+
5
+
6
+ class HPOConfig(Config):
7
+ def __init__(self, framework_id, storage_uri, is_continuous, num_trials, num_jobs):
8
+ """Initialize the HPO Config
9
+
10
+ Args:
11
+ framework_id (str): hpo framework id.
12
+ storage_uri (str): result storage dir.
13
+ is_continuous (bool): whether to continuously do the optimization.
14
+ num_trials (int): num of trials used in optimization.
15
+ num_jobs (int): num of the jobs.
16
+ """
17
+ self.framework_id = framework_id or 'optuna'
18
+ self.is_continuous = is_continuous if is_continuous is not None else True
19
+ self.num_trials = num_trials or 50
20
+ self.storage_uri = storage_uri
21
+ self.num_jobs = num_jobs if num_jobs is not None else 1
22
+
23
+ @property
24
+ def storage_protocol(self):
25
+ """Get the storage protocol
26
+
27
+ Returns:
28
+ str: the dir of the storage protocol.
29
+ """
30
+ storage_protocol, _ = parse_uri_to_protocol_and_path(self.storage_uri)
31
+ return storage_protocol
32
+
33
+ @property
34
+ def storage_path(self):
35
+ """Get the storage path
36
+
37
+ Returns:
38
+ str: the dir of the hpo data storage.
39
+ """
40
+ _, storage_path = parse_uri_to_protocol_and_path(self.storage_uri)
41
+ return storage_path
42
+
43
+ def get_yaml_config(self):
44
+ """Return the config in dict (yaml compatible) format.
45
+
46
+ Returns:
47
+ dict: config of the HPO specs in dict format.
48
+ """
49
+ return {
50
+ 'framework_id': self.framework_id,
51
+ 'storage_uri': self.storage_uri,
52
+ 'is_continuous': self.is_continuous,
53
+ 'num_trials': self.num_trials,
54
+ 'num_jobs': self.num_jobs
55
+ }
56
+
57
+ @staticmethod
58
+ def parse_from_yaml_config(yaml_config, **kwargs):
59
+ """Parse from the yaml to generate the config object.
60
+
61
+ Args:
62
+ yaml_config (dict): configs from yaml file.
63
+
64
+ Returns:
65
+ EasyTPP.HPOConfig: Config class for HPO specs.
66
+ """
67
+ if yaml_config is None:
68
+ return None
69
+ else:
70
+ return HPOConfig(
71
+ framework_id=yaml_config.get('framework_id'),
72
+ storage_uri=yaml_config.get('storage_uri'),
73
+ is_continuous=yaml_config.get('is_continuous'),
74
+ num_trials=yaml_config.get('num_trials'),
75
+ num_jobs=yaml_config.get('num_jobs'),
76
+ )
77
+
78
+ def copy(self):
79
+ """Copy the config.
80
+
81
+ Returns:
82
+ EasyTPP.HPOConfig: a copy of current config.
83
+ """
84
+ return HPOConfig(
85
+ framework_id=self.framework_id,
86
+ storage_uri=self.storage_uri,
87
+ is_continuous=self.is_continuous,
88
+ num_trials=self.num_trials,
89
+ num_jobs=self.num_jobs
90
+ )
91
+
92
+
93
+ @Config.register('hpo_runner_config')
94
+ class HPORunnerConfig(Config):
95
+ def __init__(self, hpo_config, runner_config):
96
+ """Initialize the config class
97
+
98
+ Args:
99
+ hpo_config (EasyTPP.HPOConfig): hpo config class.
100
+ runner_config (EasyTPP.RunnerConfig): runner config class.
101
+ """
102
+ self.hpo_config = hpo_config
103
+ self.runner_config = runner_config
104
+
105
+ @staticmethod
106
+ def parse_from_yaml_config(yaml_config, **kwargs):
107
+ """Parse from the yaml to generate the config object.
108
+
109
+ Args:
110
+ yaml_config (dict): configs from yaml file.
111
+
112
+ Returns:
113
+ EasyTPP.HPORunnerConfig: Config class for HPO specs.
114
+ """
115
+ runner_config = RunnerConfig.parse_from_yaml_config(yaml_config, **kwargs)
116
+ hpo_config = HPOConfig.parse_from_yaml_config(yaml_config.get('hpo'), **kwargs)
117
+ py_assert(hpo_config is not None, ValueError, 'No hpo configs is provided for HyperTuner')
118
+ return HPORunnerConfig(
119
+ hpo_config=hpo_config,
120
+ runner_config=runner_config
121
+ )
122
+
123
+ def copy(self):
124
+ """Copy the config.
125
+
126
+ Returns:
127
+ EasyTPP.HPORunnerConfig: a copy of current config.
128
+ """
129
+ return HPORunnerConfig(
130
+ hpo_config=self.hpo_config,
131
+ runner_config=self.runner_config
132
+ )
easy_tpp/config_factory/model_config.py ADDED
@@ -0,0 +1,274 @@
1
+ from easy_tpp.config_factory.config import Config
2
+
3
+ from easy_tpp.utils.const import Backend
4
+
5
+
6
+ class TrainerConfig(Config):
7
+
8
+ def __init__(self, **kwargs):
9
+ """Initialize the Config class.
10
+ """
11
+ self.seed = kwargs.get('seed', 9899)
12
+ self.gpu = kwargs.get('gpu', -1)
13
+ self.batch_size = kwargs.get('batch_size', 256)
14
+ self.max_epoch = kwargs.get('max_epoch', 10)
15
+ self.shuffle = kwargs.get('shuffle', False)
16
+ self.optimizer = kwargs.get('optimizer', 'adam')
17
+ self.learning_rate = kwargs.get('learning_rate', 1.e-3)
18
+ self.valid_freq = kwargs.get('valid_freq', 1)
19
+ self.use_tfb = kwargs.get('use_tfb', False)
20
+ self.metrics = kwargs.get('metrics', ['acc', 'rmse'])
21
+
22
+ def get_yaml_config(self):
23
+ """Return the config in dict (yaml compatible) format.
24
+
25
+ Returns:
26
+ dict: config of the trainer specs in dict format.
27
+ """
28
+ return {'seed': self.seed,
29
+ 'gpu': self.gpu,
30
+ 'batch_size': self.batch_size,
31
+ 'max_epoch': self.max_epoch,
32
+ 'shuffle': self.shuffle,
33
+ 'optimizer': self.optimizer,
34
+ 'learning_rate': self.learning_rate,
35
+ 'valid_freq': self.valid_freq,
36
+ 'use_tfb': self.use_tfb,
37
+ 'metrics': self.metrics
38
+ }
39
+
40
+ @staticmethod
41
+ def parse_from_yaml_config(yaml_config):
42
+ """Parse from the yaml to generate the config object.
43
+
44
+ Args:
45
+ yaml_config (dict): configs from yaml file.
46
+
47
+ Returns:
48
+ EasyTPP.TrainerConfig: Config class for trainer specs.
49
+ """
50
+ return TrainerConfig(**yaml_config)
51
+
52
+ def copy(self):
53
+ """Copy the config.
54
+
55
+ Returns:
56
+ EasyTPP.TrainerConfig: a copy of current config.
57
+ """
58
+ return TrainerConfig(seed=self.seed,
+ gpu=self.gpu,
+ batch_size=self.batch_size,
59
+ max_epoch=self.max_epoch,
60
+ shuffle=self.shuffle,
61
+ optimizer=self.optimizer,
62
+ learning_rate=self.learning_rate,
63
+ valid_freq=self.valid_freq,
64
+ use_tfb=self.use_tfb,
65
+ metrics=self.metrics
66
+ )
67
+
68
+
69
+ class ThinningConfig(Config):
70
+ def __init__(self, **kwargs):
71
+ """Initialize the Config class.
72
+ """
73
+ self.num_seq = kwargs.get('num_seq', 10)
74
+ self.num_sample = kwargs.get('num_sample', 1)
75
+ self.num_exp = kwargs.get('num_exp', 500)
76
+ self.look_ahead_time = kwargs.get('look_ahead_time', 10)
77
+ self.patience_counter = kwargs.get('patience_counter', 5)
78
+ self.over_sample_rate = kwargs.get('over_sample_rate', 5)
79
+ self.num_samples_boundary = kwargs.get('num_samples_boundary', 5)
80
+ self.dtime_max = kwargs.get('dtime_max', 5)
81
+ # we pad the sequence at the front only in multi-step generation
82
+ self.num_step_gen = kwargs.get('num_step_gen', 1)
83
+
84
+ def get_yaml_config(self):
85
+ """Return the config in dict (yaml compatible) format.
86
+
87
+ Returns:
88
+ dict: config of the thinning specs in dict format.
89
+ """
90
+ return {'num_seq': self.num_seq,
91
+ 'num_sample': self.num_sample,
92
+ 'num_exp': self.num_exp,
93
+ 'look_ahead_time': self.look_ahead_time,
94
+ 'patience_counter': self.patience_counter,
95
+ 'over_sample_rate': self.over_sample_rate,
96
+ 'num_samples_boundary': self.num_samples_boundary,
97
+ 'dtime_max': self.dtime_max,
98
+ 'num_step_gen': self.num_step_gen}
99
+
100
+ @staticmethod
101
+ def parse_from_yaml_config(yaml_config):
102
+ """Parse from the yaml to generate the config object.
103
+
104
+ Args:
105
+ yaml_config (dict): configs from yaml file.
106
+
107
+ Returns:
108
+ EasyTPP.ThinningConfig: Config class for thinning algorithms.
109
+ """
110
+ return ThinningConfig(**yaml_config) if yaml_config is not None else None
111
+
112
+ def copy(self):
113
+ """Copy the config.
114
+
115
+ Returns:
116
+ EasyTPP.ThinningConfig: a copy of current config.
117
+ """
118
+ return ThinningConfig(num_seq=self.num_seq,
119
+ num_sample=self.num_sample,
120
+ num_exp=self.num_exp,
121
+ look_ahead_time=self.look_ahead_time,
122
+ patience_counter=self.patience_counter,
123
+ over_sample_rate=self.over_sample_rate,
124
+ num_samples_boundary=self.num_samples_boundary,
125
+ dtime_max=self.dtime_max,
126
+ num_step_gen=self.num_step_gen)
127
+
128
+
129
+ class BaseConfig(Config):
130
+ def __init__(self, **kwargs):
131
+ """Initialize the Config class.
132
+ """
133
+ self.stage = kwargs.get('stage')
134
+ self.backend = kwargs.get('backend')
135
+ self.dataset_id = kwargs.get('dataset_id')
136
+ self.runner_id = kwargs.get('runner_id')
137
+ self.model_id = kwargs.get('model_id')
138
+ self.exp_id = kwargs.get('exp_id')
139
+ self.base_dir = kwargs.get('base_dir')
140
+ self.specs = kwargs.get('specs', {})
141
+ self.backend = self.set_backend(self.backend)
142
+
143
+ @staticmethod
144
+ def set_backend(backend):
145
+ if backend.lower() in ['torch', 'pytorch']:
+ return Backend.Torch
+ elif backend.lower() in ['tf', 'tensorflow']:
+ # tensorflow backend, as used in the doc examples
+ # (assumes Backend.TF is defined in easy_tpp.utils.const)
+ return Backend.TF
+ else:
+ raise ValueError(
+ f"Backend should be 'torch'/'pytorch' or 'tf'/'tensorflow', current value: {backend}"
150
+ )
151
+
152
+ def get_yaml_config(self):
153
+ """Return the config in dict (yaml compatible) format.
154
+
155
+ Returns:
156
+ dict: config of the base config specs in dict format.
157
+ """
158
+ return {'stage': self.stage,
159
+ 'backend': str(self.backend),
160
+ 'dataset_id': self.dataset_id,
161
+ 'runner_id': self.runner_id,
162
+ 'model_id': self.model_id,
163
+ 'base_dir': self.base_dir,
164
+ 'specs': self.specs}
165
+
166
+ @staticmethod
167
+ def parse_from_yaml_config(yaml_config):
168
+ """Parse from the yaml to generate the config object.
169
+
170
+ Args:
171
+ yaml_config (dict): configs from yaml file.
172
+
173
+ Returns:
174
+ BaseConfig: Config class for trainer specs.
175
+ """
176
+ return BaseConfig(**yaml_config)
177
+
178
+ def copy(self):
179
+ """Copy the config.
180
+
181
+ Returns:
182
+ BaseConfig: a copy of current config.
183
+ """
184
+ return BaseConfig(stage=self.stage,
185
+ backend=self.backend,
186
+ dataset_id=self.dataset_id,
187
+ runner_id=self.runner_id,
188
+ model_id=self.model_id,
189
+ base_dir=self.base_dir,
190
+ specs=self.specs)
191
+
192
+
193
+ class ModelConfig(Config):
194
+ def __init__(self, **kwargs):
195
+ """Initialize the Config class.
196
+ """
197
+ self.rnn_type = kwargs.get('rnn_type', 'LSTM')
198
+ self.hidden_size = kwargs.get('hidden_size', 32)
199
+ self.time_emb_size = kwargs.get('time_emb_size', 16)
200
+ self.num_layers = kwargs.get('num_layers', 2)
201
+ self.num_heads = kwargs.get('num_heads', 2)
202
+ self.sharing_param_layer = kwargs.get('sharing_param_layer', False)
203
+ self.use_mc_samples = kwargs.get('use_mc_samples', True) # if using MC samples in computing log-likelihood
204
+ self.loss_integral_num_sample_per_step = kwargs.get('loss_integral_num_sample_per_step', 20) # mc_num_sample_per_step
205
+ self.dropout_rate = kwargs.get('dropout_rate', 0.0)
206
+ self.use_ln = kwargs.get('use_ln', False)
207
+ self.thinning = ThinningConfig.parse_from_yaml_config(kwargs.get('thinning'))
208
+ self.is_training = kwargs.get('training', False)
209
+ self.num_event_types_pad = kwargs.get('num_event_types_pad', None)
210
+ self.num_event_types = kwargs.get('num_event_types', None)
211
+ self.pad_token_id = kwargs.get('event_pad_index', None)
212
+ self.model_id = kwargs.get('model_id', None)
213
+ self.pretrained_model_dir = kwargs.get('pretrained_model_dir', None)
214
+ self.gpu = kwargs.get('gpu', -1)
215
+ self.model_specs = kwargs.get('model_specs', {})
216
+
217
+ def get_yaml_config(self):
218
+ """Return the config in dict (yaml compatible) format.
219
+
220
+ Returns:
221
+ dict: config of the model config specs in dict format.
222
+ """
223
+ return {'rnn_type': self.rnn_type,
224
+ 'hidden_size': self.hidden_size,
225
+ 'time_emb_size': self.time_emb_size,
226
+ 'num_layers': self.num_layers,
227
+ 'num_heads': self.num_heads,
+ 'use_mc_samples': self.use_mc_samples,
+ 'sharing_param_layer': self.sharing_param_layer,
228
+ 'loss_integral_num_sample_per_step': self.loss_integral_num_sample_per_step,
229
+ 'dropout_rate': self.dropout_rate,
230
+ 'use_ln': self.use_ln,
231
+ # for some models / cases we may not need to pass thinning config
232
+ # e.g., for intensity-free model
233
+ 'thinning': None if self.thinning is None else self.thinning.get_yaml_config(),
234
+ 'num_event_types_pad': self.num_event_types_pad,
235
+ 'num_event_types': self.num_event_types,
236
+ 'event_pad_index': self.pad_token_id,
237
+ 'model_id': self.model_id,
238
+ 'pretrained_model_dir': self.pretrained_model_dir,
239
+ 'gpu': self.gpu,
240
+ 'model_specs': self.model_specs}
241
+
242
+ @staticmethod
243
+ def parse_from_yaml_config(yaml_config):
244
+ """Parse from the yaml to generate the config object.
245
+
246
+ Args:
247
+ yaml_config (dict): configs from yaml file.
248
+
249
+ Returns:
250
+ ModelConfig: Config class for trainer specs.
251
+ """
252
+ return ModelConfig(**yaml_config)
253
+
254
+ def copy(self):
255
+ """Copy the config.
256
+
257
+ Returns:
258
+ ModelConfig: a copy of current config.
259
+ """
260
+ return ModelConfig(rnn_type=self.rnn_type,
261
+ hidden_size=self.hidden_size,
262
+ time_emb_size=self.time_emb_size,
263
+ num_layers=self.num_layers,
264
+ num_heads=self.num_heads,
+ use_mc_samples=self.use_mc_samples,
+ sharing_param_layer=self.sharing_param_layer,
265
+ loss_integral_num_sample_per_step=self.loss_integral_num_sample_per_step,
266
+ dropout_rate=self.dropout_rate,
267
+ use_ln=self.use_ln,
268
+ thinning=self.thinning,
269
+ num_event_types_pad=self.num_event_types_pad,
270
+ num_event_types=self.num_event_types,
271
+ event_pad_index=self.pad_token_id,
272
+ pretrained_model_dir=self.pretrained_model_dir,
273
+ gpu=self.gpu,
274
+ model_id=self.model_id,
+ model_specs=self.model_specs)