diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..cd3cdcaa0a9c99d41cd9f2b2f3e2f7bb640e0e99 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +docs/images/thinning_algo.jpg filter=lfs diff=lfs merge=lfs -text diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml new file mode 100644 index 0000000000000000000000000000000000000000..31f7744adf92cfde5a2ef9d17ea870c8e9e651b4 --- /dev/null +++ b/.github/workflows/docs.yaml @@ -0,0 +1,59 @@ +name: docs + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + release: + types: [ published ] + +jobs: + build: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.8' + - name: Install dependencies + run: | + python -m pip install --upgrade pip setuptools wheel + sudo apt-get update + sudo apt-get install openjdk-11-jdk + sudo apt-get install pandoc + - name: Build Sphinx docs + run: | + pip install tensorflow==2.2.0 + pip install torch + pip install pandas + pip install numpy + pip install -r requirements-doc.txt + cd docs + make html + # Publish built docs to gh-pages branch. + # =============================== + - name: Commit documentation changes + run: | + git clone https://github.com/ant-research/EasyTemporalPointProcess.git --branch gh-pages --single-branch gh-pages + cp -r docs/build/html/* gh-pages/ + cd gh-pages + touch .nojekyll + git config --local user.email "action@github.com" + git config --local user.name "GitHub Action" + git add . + git commit -m "Update documentation" -a || true + # The above command will fail if no changes were present, so we ignore + # that. + - name: Push changes + uses: ad-m/github-push-action@master + with: + branch: gh-pages + directory: gh-pages + github_token: ${{ secrets.GITHUB_TOKEN }} + # =============================== diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml new file mode 100644 index 0000000000000000000000000000000000000000..fcea725c36d11a9acfb5788665d50cdce838f639 --- /dev/null +++ b/.github/workflows/python-publish.yml @@ -0,0 +1,39 @@ +# This workflow will upload a Python Package using Twine when a release is created +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries + +# This workflow uses actions that are not certified by GitHub. +# They are provided by a third-party and are governed by +# separate terms of service, privacy policy, and support +# documentation. + +name: Upload Python Package + +on: + release: + types: [published] + +permissions: + contents: read + +jobs: + deploy: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.9' + - name: Install dependencies + run: | + pip install -r requirements.txt + pip install wheel + - name: Build package + run: python setup.py sdist bdist_wheel + - name: Publish package + uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 + with: + user: __token__ + password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..d40bd146b6391fe120cdc3ee1c615bfddf075473 --- /dev/null +++ b/.gitignore @@ -0,0 +1,95 @@ +# python build +build/ +dist/ +easy_tpp.egg-info/ +*.egg-info/ + +# python temp +*.pyc +*.pyo +*.pyd +__pycache__/ +*.so +*.egg + +# proto +protoc +protoc-3.4.0.tar.gz +*_pb2.py + +# misc +experiments/ +log/ +logs/ +*.swp +*.swo +.vscode/ +.idea/ + +# OS files +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +ehthumbs.db +Thumbs.db + +# IDE +.vscode/ +.idea/ +*.sublime-project +*.sublime-workspace + +# Testing +.pytest_cache/ +.mypy_cache/ +.ruff_cache/ +.coverage +htmlcov/ +.tox/ + +# Checkpoints and outputs +examples/checkpoints/* +notebooks/checkpoints/* +results/ +outputs/ +*.pth +*.pt +*.ckpt + +# Data files (large files should not be committed) +examples/data/* +!examples/data/.gitkeep +*.json.gz +*.pkl +*.h5 +*.hdf5 + +# Large cascade data files (use Git LFS or external storage) +data/cascades/information_cascade*.json +data/cascades/*.json +!data/cascades/.gitkeep + +# Model files (use Git LFS or Hugging Face Hub) +*.bin +*.safetensors +*.onnx + +# Temporary files +*.tmp +*.temp +*.bak +*~ + +# Jupyter Notebook checkpoints +.ipynb_checkpoints/ + +# Environment +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ \ No newline at end of file diff --git a/ADDITIONS_README.md b/ADDITIONS_README.md new file mode 100644 index 0000000000000000000000000000000000000000..2cd28f1e5060c21d78ef4b25cabe2f3796f857d0 --- /dev/null +++ b/ADDITIONS_README.md @@ -0,0 +1,71 @@ +# EasyTPP 新增功能说明 + +本仓库在原始 EasyTPP 基础上新增了用于处理信息级联数据的指标计算功能。 + +## 🆕 新增文件 + +### 核心功能 +- **`compute_cascade_metrics.py`**: 计算级联指标的主脚本 + - 情感得分 (Sentiment Score) + - 情感偏差 (Sentiment Deviation) + - 语境偏差 (Contextual Deviation) + - 困惑度 (Perplexity) + +### 文档和工具 +- **`COMPUTE_METRICS_README.md`**: 详细使用说明 +- **`HF_UPLOAD_GUIDE.md`**: Hugging Face 上传指南 +- **`UPLOAD_CHECKLIST.md`**: 上传检查清单(自动生成) +- **`cleanup_for_hf.py`**: 清理脚本,准备上传 +- **`example_compute_metrics.sh`**: 使用示例脚本 +- **`requirements_compute_metrics.txt`**: 额外依赖包 + +## 🚀 快速开始 + +### 1. 安装依赖 + +```bash +pip install -r requirements.txt +pip install -r requirements_compute_metrics.txt +``` + +### 2. 运行指标计算 + +```bash +python compute_cascade_metrics.py \ + --input_cascade information_cascade.json \ + --output output_with_metrics.json \ + --batch_size 32 +``` + +详细说明请参考 `COMPUTE_METRICS_README.md` + +## 📦 上传到 Hugging Face + +1. 运行清理脚本: +```bash +python cleanup_for_hf.py +``` + +2. 按照 `HF_UPLOAD_GUIDE.md` 的说明上传 + +## 🔗 相关文档 + +- [指标计算说明](COMPUTE_METRICS_README.md) +- [上传指南](HF_UPLOAD_GUIDE.md) +- [原始 EasyTPP README](README.md) + +## 📝 使用场景 + +这些新增功能主要用于: +- 分析社交媒体信息级联(如微博转发、评论) +- 计算文本的情感特征和语义偏差 +- 为 TPP 模型提供额外的特征输入 + +## ⚙️ 与 EasyTPP 集成 + +计算出的指标可以用于: +- `RobertTPPDataset`: 加载包含语义和偏差特征的数据 +- `RobertEventTokenizer`: 处理自定义特征 +- `TorchRobotTHP`: 使用语义和偏差特征的 TPP 模型 + +参考 `examples/train_robot_thp_with_features.py` 了解完整示例。 diff --git a/CLEANUP_SUMMARY.md b/CLEANUP_SUMMARY.md new file mode 100644 index 0000000000000000000000000000000000000000..e818edc820153a7d5b0fc79e9d5d4cb1590c7c22 --- /dev/null +++ b/CLEANUP_SUMMARY.md @@ -0,0 +1,164 @@ +# 文件夹整理总结 + +## ✅ 整理完成时间 +2025-01-19 + +## 📊 文件夹信息 + +- **位置**: `/Users/chenshuyi/Downloads/EasyTemporalPointProcess-main` +- **大小**: 1.3MB +- **状态**: ✅ 已整理,可以上传 + +## 🧹 已完成的清理工作 + +### 1. 更新 .gitignore +- ✅ 添加了 Python 缓存文件模式 +- ✅ 添加了 IDE 配置文件排除 +- ✅ 添加了 OS 系统文件排除 +- ✅ 添加了测试和构建文件排除 +- ✅ 添加了数据文件和模型文件排除规则 + +### 2. 创建清理脚本 +- ✅ `cleanup_for_hf.py` - 自动清理脚本 +- ✅ 已运行,确认无需要删除的文件 + +### 3. 文件检查 +- ✅ 无大文件(>50MB) +- ✅ 无敏感信息 +- ✅ 无临时文件 + +## 📁 文件结构 + +``` +EasyTemporalPointProcess-main/ +├── 📄 核心代码 +│ ├── easy_tpp/ # EasyTPP 核心库 +│ ├── examples/ # 示例代码 +│ ├── notebooks/ # Jupyter notebooks +│ └── tests/ # 测试代码 +│ +├── 🆕 新增功能(级联指标计算) +│ ├── compute_cascade_metrics.py # 主计算脚本 +│ ├── COMPUTE_METRICS_README.md # 使用说明 +│ ├── requirements_compute_metrics.txt # 额外依赖 +│ └── example_compute_metrics.sh # 示例脚本 +│ +├── 📚 文档 +│ ├── README.md # 原始 README +│ ├── ADDITIONS_README.md # 新增功能说明 +│ ├── HF_UPLOAD_GUIDE.md # Hugging Face 上传指南 +│ ├── QUICK_START_HF.md # 快速开始指南 +│ ├── UPLOAD_CHECKLIST.md # 上传检查清单 +│ └── CLEANUP_SUMMARY.md # 本文件 +│ +├── 🛠️ 工具脚本 +│ ├── cleanup_for_hf.py # 清理脚本 +│ └── setup.py # 安装脚本 +│ +└── ⚙️ 配置文件 + ├── .gitignore # Git 忽略规则(已更新) + ├── requirements.txt # 基础依赖 + ├── requirements_compute_metrics.txt # 指标计算依赖 + └── setup.cfg # 安装配置 +``` + +## 📋 新增文件列表 + +### 核心功能 +1. `compute_cascade_metrics.py` (19.5 KB) + - 计算情感得分、情感偏差、语境偏差、困惑度 + +### 文档 +2. `COMPUTE_METRICS_README.md` (5.9 KB) + - 详细的指标计算使用说明 + +3. `HF_UPLOAD_GUIDE.md` (3.7 KB) + - Hugging Face 上传完整指南 + +4. `ADDITIONS_README.md` (1.9 KB) + - 新增功能概述 + +5. `QUICK_START_HF.md` (2.3 KB) + - 快速上传指南 + +6. `UPLOAD_CHECKLIST.md` (3.0 KB) + - 上传检查清单(自动生成) + +7. `CLEANUP_SUMMARY.md` (本文件) + - 整理总结 + +### 工具和配置 +8. `cleanup_for_hf.py` (7.8 KB) + - 自动清理脚本 + +9. `example_compute_metrics.sh` (1.2 KB) + - 使用示例脚本 + +10. `requirements_compute_metrics.txt` (266 B) + - 指标计算所需依赖 + +## 🎯 下一步操作 + +### 1. 上传到 Hugging Face + +```bash +# 安装 CLI +pip install huggingface_hub + +# 登录 +huggingface-cli login + +# 创建仓库(在网页上) +# https://huggingface.co/new + +# 上传 +cd /Users/chenshuyi/Downloads/EasyTemporalPointProcess-main +huggingface-cli upload / . --repo-type dataset +``` + +### 2. 在云电脑上下载 + +```bash +huggingface-cli download / --local-dir ./EasyTPP +``` + +### 3. 使用新功能 + +```bash +cd EasyTPP +pip install -r requirements.txt +pip install -r requirements_compute_metrics.txt + +python compute_cascade_metrics.py \ + --input_cascade information_cascade.json \ + --output output_with_metrics.json +``` + +## ✅ 检查清单 + +- [x] 清理临时文件 +- [x] 更新 .gitignore +- [x] 检查大文件 +- [x] 检查敏感信息 +- [x] 创建上传指南 +- [x] 创建使用文档 +- [x] 验证文件结构 +- [ ] 上传到 Hugging Face(待执行) +- [ ] 在云电脑上测试(待执行) + +## 📝 注意事项 + +1. **文件大小**: 1.3MB,无需 Git LFS +2. **许可证**: 保持原始 Apache 2.0 许可证 +3. **依赖**: 确保所有依赖都在 requirements 文件中 +4. **文档**: 所有新增功能都有详细文档 + +## 🔗 相关链接 + +- [Hugging Face](https://huggingface.co/) +- [Hugging Face CLI 文档](https://huggingface.co/docs/huggingface_hub/guides/cli) +- [原始 EasyTPP 项目](https://github.com/ant-research/EasyTemporalPointProcess) + +--- + +**整理完成!可以开始上传了!** 🚀 diff --git a/COMPUTE_METRICS_README.md b/COMPUTE_METRICS_README.md new file mode 100644 index 0000000000000000000000000000000000000000..d320639ecbc534a2056edfb305efe47d7089c185 --- /dev/null +++ b/COMPUTE_METRICS_README.md @@ -0,0 +1,191 @@ +# 计算级联指标使用说明 + +本脚本用于计算信息级联数据的情感得分、情感deviation、contextual deviation和perplexity。 + +## 功能说明 + +脚本 `compute_cascade_metrics.py` 会处理以下两个JSON文件: +- `information_cascade.json`: 包含完整级联数据(原帖、评论、转发) +- `information_cascade_original_posts.json`: 包含原帖数据(可选) + +计算以下指标: +1. **情感得分 (Sentiment Score)**: 文本的情感倾向得分 +2. **情感偏差 (Sentiment Deviation)**: 相对于原帖的情感偏差 +3. **语境偏差 (Contextual Deviation)**: 相对于原帖的语义偏差 +4. **困惑度 (Perplexity)**: 文本的语言模型困惑度 + +## 安装依赖 + +在云电脑上安装必要的依赖: + +```bash +pip install torch transformers numpy tqdm +``` + +如果需要使用GPU: +```bash +pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 +``` + +## 使用方法 + +### 基本用法(使用默认模型) + +```bash +python compute_cascade_metrics.py \ + --input_cascade information_cascade.json \ + --output output_with_metrics.json \ + --batch_size 32 +``` + +### 完整用法(指定所有模型) + +```bash +python compute_cascade_metrics.py \ + --input_cascade information_cascade.json \ + --input_original information_cascade_original_posts.json \ + --output output_with_metrics.json \ + --bert_model bert-base-chinese \ + --sentiment_model <情感分析模型路径> \ + --perplexity_model <语言模型路径> \ + --batch_size 32 \ + --max_length 512 \ + --device cuda +``` + +### 参数说明 + +- `--input_cascade`: **必需**,输入级联JSON文件路径 +- `--input_original`: 可选,输入原帖JSON文件路径 +- `--output`: **必需**,输出JSON文件路径 +- `--bert_model`: BERT模型名称或路径(默认: `bert-base-chinese`) +- `--sentiment_model`: 情感分析模型路径(可选,不提供则使用简化方法) +- `--perplexity_model`: 语言模型路径(可选,不提供则使用简化方法) +- `--batch_size`: 批处理大小(默认: 32) +- `--max_length`: 最大序列长度(默认: 512) +- `--device`: 计算设备,`cuda` 或 `cpu`(默认: 自动选择) +- `--max_cascades`: 最大处理级联数量(用于测试,默认: 处理所有) + +## 输出格式 + +处理后的JSON文件会在每个节点中添加以下字段: + +### 原帖 (`post_info`) +```json +{ + "post_info": { + "content": "原帖内容", + "embedding": [0.1, 0.2, ...], // BERT语义向量 (768维) + "sentiment_score": 0.7, // 情感得分 + "perplexity": 15.3 // 困惑度 + } +} +``` + +### 评论 (`comment_tree`) +```json +{ + "comment_tree": { + "comment_id": { + "content": "评论内容", + "embedding": [0.1, 0.2, ...], + "sentiment_score": 0.6, + "perplexity": 12.5, + "contextual_deviation": 0.25, // 语境偏差 + "sentiment_deviation": 0.1 // 情感偏差 + } + } +} +``` + +### 转发 (`repost_chain`) +```json +{ + "repost_chain": [ + { + "forward_text": "转发内容", + "comment_content": "评论内容", + "embedding": [0.1, 0.2, ...], + "sentiment_score": 0.5, + "perplexity": 18.2, + "contextual_deviation": 0.35, + "sentiment_deviation": 0.2 + } + ] +} +``` + +## 模型选择建议 + +### BERT模型 +- 中文文本:`bert-base-chinese` +- 英文文本:`bert-base-uncased` +- 自定义模型:提供本地路径 + +### 情感分析模型 +- 中文:可以使用 `uer/roberta-base-finetuned-chinanews-chinese` 或其他中文情感分析模型 +- 英文:可以使用 `nlptown/bert-base-multilingual-uncased-sentiment` 等 +- 如果不提供,脚本会使用基于关键词的简化方法 + +### 困惑度模型 +- 中文:可以使用 `gpt2-chinese` 或其他中文语言模型 +- 英文:可以使用 `gpt2` 等 +- 如果不提供,脚本会使用基于词汇多样性的简化方法 + +## 注意事项 + +1. **大文件处理**: 如果JSON文件很大,处理时间可能较长。建议: + - 使用GPU加速(`--device cuda`) + - 调整批处理大小(`--batch_size`) + - 先用 `--max_cascades` 测试少量数据 + +2. **内存使用**: + - BERT模型需要较多内存 + - 如果内存不足,减小 `--batch_size` + +3. **简化方法**: + - 如果不提供情感分析模型或困惑度模型,脚本会使用简化的启发式方法 + - 简化方法的结果可能不如专业模型准确,但计算速度快 + +4. **数据格式**: + - 确保输入的JSON文件格式正确 + - JSON文件应包含 `cascades` 字段,每个级联包含 `post_info`、`comment_tree`、`repost_chain` + +## 示例 + +### 示例1:使用默认设置处理数据 +```bash +python compute_cascade_metrics.py \ + --input_cascade /path/to/information_cascade.json \ + --output /path/to/output.json +``` + +### 示例2:使用GPU和自定义模型 +```bash +python compute_cascade_metrics.py \ + --input_cascade /path/to/information_cascade.json \ + --output /path/to/output.json \ + --bert_model bert-base-chinese \ + --sentiment_model /path/to/sentiment_model \ + --device cuda \ + --batch_size 64 +``` + +### 示例3:测试模式(只处理前10个级联) +```bash +python compute_cascade_metrics.py \ + --input_cascade /path/to/information_cascade.json \ + --output /path/to/output.json \ + --max_cascades 10 +``` + +## 故障排除 + +1. **CUDA内存不足**: 减小 `--batch_size` 或使用 `--device cpu` +2. **模型下载失败**: 检查网络连接,或手动下载模型到本地后指定路径 +3. **JSON格式错误**: 检查输入JSON文件格式是否正确 +4. **处理速度慢**: 使用GPU(`--device cuda`)和增大批处理大小 + +## 与EasyTPP集成 + +处理后的JSON文件可以用于EasyTPP框架的训练。参考 `examples/train_robot_thp_with_features.py` 了解如何使用这些特征。 diff --git a/DATA_FILES_NOTICE.md b/DATA_FILES_NOTICE.md new file mode 100644 index 0000000000000000000000000000000000000000..fdf324c4e04ed16811b9cb5c9ec3f4ef5e2ffecb --- /dev/null +++ b/DATA_FILES_NOTICE.md @@ -0,0 +1,107 @@ +# ⚠️ 数据文件说明 + +## 📁 数据文件位置 + +数据文件已复制到 `data/cascades/` 目录: + +- `data/cascades/information_cascade.json` (606MB) +- `data/cascades/information_cascade_original_posts.json` (980MB) + +## ⚠️ 重要提示 + +**这些文件太大(总计约 1.6GB),不会上传到 Hugging Face!** + +`.gitignore` 已配置为排除这些文件,因为它们超过了 Git/Hugging Face 的推荐大小限制。 + +## 📥 在云电脑上获取数据文件 + +### 方法1: 直接传输(推荐) + +```bash +# 在云电脑上创建目录 +mkdir -p data/cascades + +# 使用 scp 从本地传输 +scp -r user@local-machine:/Users/chenshuyi/Documents/research_projects/评论家罗伯特TPP/data/cascades/information_cascade*.json ./data/cascades/ +``` + +### 方法2: 使用云存储 + +1. 将文件上传到云存储(Google Drive, Dropbox, OneDrive 等) +2. 在云电脑上下载 + +### 方法3: 使用 Git LFS(如果配置) + +如果需要通过 Git 管理大文件: + +```bash +# 安装 Git LFS +git lfs install + +# 跟踪大文件 +git lfs track "data/cascades/*.json" + +# 添加文件 +git add .gitattributes +git add data/cascades/*.json +git commit -m "Add cascade data with LFS" +git push +``` + +### 方法4: 使用 Hugging Face Dataset Hub + +可以将数据文件单独上传到 Hugging Face Dataset Hub: + +```bash +# 安装依赖 +pip install huggingface_hub + +# 上传数据文件 +huggingface-cli upload /cascade-data data/cascades/ --repo-type dataset +``` + +然后在云电脑上下载: + +```bash +huggingface-cli download /cascade-data --local-dir ./data/cascades +``` + +## ✅ 验证文件 + +上传到 Hugging Face 后,验证: + +```bash +# 检查文件是否存在 +ls -lh data/cascades/ + +# 应该看到: +# information_cascade.json +# information_cascade_original_posts.json +``` + +## 🚀 使用数据文件 + +文件准备好后,运行指标计算: + +```bash +python compute_cascade_metrics.py \ + --input_cascade data/cascades/information_cascade.json \ + --input_original data/cascades/information_cascade_original_posts.json \ + --output output_with_metrics.json \ + --batch_size 32 \ + --device cuda +``` + +## 📝 文件来源 + +原始文件位置: +- `/Users/chenshuyi/Documents/research_projects/评论家罗伯特TPP/data/cascades/` + +已复制到: +- `/Users/chenshuyi/Downloads/EasyTemporalPointProcess-main/data/cascades/` + +## 🔗 相关文档 + +- [数据文件说明](data/cascades/README.md) +- [指标计算说明](COMPUTE_METRICS_README.md) +- [上传指南](HF_UPLOAD_GUIDE.md) diff --git a/DATA_TRANSFER_SUMMARY.md b/DATA_TRANSFER_SUMMARY.md new file mode 100644 index 0000000000000000000000000000000000000000..36b00a275a65ffe290a945c832958ae06e4638c5 --- /dev/null +++ b/DATA_TRANSFER_SUMMARY.md @@ -0,0 +1,95 @@ +# 数据文件转移总结 + +## ✅ 已完成 + +两个 information cascade 文件已成功复制到 EasyTemporalPointProcess-main 文件夹。 + +## 📁 文件位置 + +### 源文件位置 +- `/Users/chenshuyi/Documents/research_projects/评论家罗伯特TPP/data/cascades/information_cascade.json` +- `/Users/chenshuyi/Documents/research_projects/评论家罗伯特TPP/data/cascades/information_cascade_original_posts.json` + +### 目标位置 +- `/Users/chenshuyi/Downloads/EasyTemporalPointProcess-main/data/cascades/information_cascade.json` (606MB) +- `/Users/chenshuyi/Downloads/EasyTemporalPointProcess-main/data/cascades/information_cascade_original_posts.json` (980MB) + +## ⚠️ 重要说明 + +### 文件大小 +- **总大小**: 约 1.6GB +- **information_cascade.json**: 606MB +- **information_cascade_original_posts.json**: 980MB + +### Git 排除配置 +这些文件**不会上传到 Hugging Face**,因为: +1. 文件太大,超过 Git/Hugging Face 推荐大小 +2. 已通过 `.gitignore` 排除: + ``` + data/cascades/information_cascade*.json + data/cascades/*.json + ``` + +## 📥 在云电脑上获取数据文件 + +### 方法1: 使用 scp 传输(推荐) + +```bash +# 在云电脑上 +mkdir -p data/cascades + +# 从本地传输 +scp user@local-machine:/Users/chenshuyi/Documents/research_projects/评论家罗伯特TPP/data/cascades/information_cascade*.json ./data/cascades/ +``` + +### 方法2: 上传到 Hugging Face Dataset Hub + +```bash +# 在本地 +cd /Users/chenshuyi/Downloads/EasyTemporalPointProcess-main +huggingface-cli upload /cascade-data data/cascades/ --repo-type dataset + +# 在云电脑上下载 +huggingface-cli download /cascade-data --local-dir ./data/cascades +``` + +### 方法3: 使用云存储 + +1. 将文件上传到 Google Drive / Dropbox / OneDrive +2. 在云电脑上下载 + +## 📝 相关文档 + +- **数据文件说明**: `data/cascades/README.md` +- **数据文件注意事项**: `DATA_FILES_NOTICE.md` +- **上传指南**: `HF_UPLOAD_GUIDE.md` + +## ✅ 验证 + +上传到 Hugging Face 后,验证数据文件: + +```bash +# 检查文件是否存在 +ls -lh data/cascades/ + +# 应该看到: +# information_cascade.json (606MB) +# information_cascade_original_posts.json (980MB) +``` + +## 🚀 使用数据文件 + +文件准备好后,运行指标计算: + +```bash +python compute_cascade_metrics.py \ + --input_cascade data/cascades/information_cascade.json \ + --input_original data/cascades/information_cascade_original_posts.json \ + --output output_with_metrics.json \ + --batch_size 32 \ + --device cuda +``` + +--- + +**数据文件已成功转移!** ✅ diff --git a/HF_UPLOAD_GUIDE.md b/HF_UPLOAD_GUIDE.md new file mode 100644 index 0000000000000000000000000000000000000000..7bcc156ff2031a590214700e54e1b13089f26e9d --- /dev/null +++ b/HF_UPLOAD_GUIDE.md @@ -0,0 +1,180 @@ +# Hugging Face 上传指南 + +本指南说明如何将 EasyTemporalPointProcess-main 上传到 Hugging Face。 + +## 📋 准备工作 + +### 1. 运行清理脚本 + +```bash +cd /Users/chenshuyi/Downloads/EasyTemporalPointProcess-main +python cleanup_for_hf.py +``` + +这会自动: +- 删除 `__pycache__/`、`.pyc` 等临时文件 +- 检查大文件 +- 创建上传检查清单 + +### 2. 数据文件说明 ⚠️ + +**重要**: `data/cascades/` 目录包含大文件(约 1.6GB),**不会上传到 Hugging Face**。 + +这些文件已通过 `.gitignore` 排除: +- `information_cascade.json` (606MB) +- `information_cascade_original_posts.json` (980MB) + +**在云电脑上获取数据文件的方法**: +- 方法1: 使用 scp 直接传输(推荐) +- 方法2: 上传到云存储后下载 +- 方法3: 使用 Git LFS(如果配置) +- 方法4: 单独上传到 Hugging Face Dataset Hub + +详细说明请参考 `DATA_FILES_NOTICE.md` + +### 3. 手动检查 + +- [ ] 检查是否有敏感信息(API密钥、密码等) +- [ ] 确认大文件已正确排除(通过 .gitignore) +- [ ] 确保 `requirements.txt` 是最新的 +- [ ] 检查 README.md 是否完整 + +## 🚀 上传方法 + +### 方法1: 使用 Hugging Face CLI(推荐) + +```bash +# 1. 安装 Hugging Face CLI +pip install huggingface_hub + +# 2. 登录 +huggingface-cli login +# 输入你的 Hugging Face token(在 https://huggingface.co/settings/tokens 获取) + +# 3. 创建仓库(在网页上创建,或使用 CLI) +# 访问 https://huggingface.co/new 创建新仓库 +# 选择 "Dataset" 类型,命名为例如:easytpp-cascade-metrics + +# 4. 上传文件 +cd /Users/chenshuyi/Downloads/EasyTemporalPointProcess-main +huggingface-cli upload /easytpp-cascade-metrics . --repo-type dataset +``` + +### 方法2: 使用 Git + +```bash +# 1. 初始化 Git(如果还没有) +cd /Users/chenshuyi/Downloads/EasyTemporalPointProcess-main +git init + +# 2. 添加文件 +git add . +git commit -m "Add EasyTPP with cascade metrics computation" + +# 3. 添加 Hugging Face 远程仓库 +# 先在 https://huggingface.co/new 创建仓库 +git remote add origin https://huggingface.co// + +# 4. 推送 +git push origin main +``` + +### 方法3: 使用 Web 界面上传 + +1. 访问 https://huggingface.co/new +2. 创建新的 Dataset 仓库 +3. 点击 "Add file" → "Upload files" +4. 拖拽或选择文件夹上传 + +## 📦 在云电脑上下载 + +上传完成后,在云电脑上下载: + +```bash +# 方法1: 使用 Hugging Face CLI +pip install huggingface_hub +huggingface-cli download / --local-dir ./EasyTPP + +# 方法2: 使用 Git +git clone https://huggingface.co/datasets// +cd + +# 方法3: 使用 Python +from huggingface_hub import snapshot_download +snapshot_download(repo_id="/", repo_type="dataset", local_dir="./EasyTPP") +``` + +### 📥 下载数据文件 + +**重要**: 代码仓库不包含数据文件(已通过 .gitignore 排除)。 + +数据文件需要单独获取: + +```bash +# 方法1: 使用 scp 从本地传输(推荐) +mkdir -p data/cascades +scp user@local-machine:/path/to/information_cascade*.json ./data/cascades/ + +# 方法2: 如果已上传到 Hugging Face Dataset Hub +huggingface-cli download /cascade-data --local-dir ./data/cascades + +# 方法3: 从云存储下载 +# (根据你使用的云存储服务) +``` + +详细说明请参考 `DATA_FILES_NOTICE.md` + +## 📝 新增功能说明 + +本仓库在原始 EasyTPP 基础上新增了以下功能: + +### 1. 级联指标计算 (`compute_cascade_metrics.py`) + +用于计算信息级联数据的指标: +- **情感得分** (Sentiment Score) +- **情感偏差** (Sentiment Deviation) +- **语境偏差** (Contextual Deviation) +- **困惑度** (Perplexity) + +详细说明请参考 `COMPUTE_METRICS_README.md` + +### 2. 相关文件 + +- `compute_cascade_metrics.py`: 主计算脚本 +- `COMPUTE_METRICS_README.md`: 使用说明 +- `requirements_compute_metrics.txt`: 额外依赖 +- `example_compute_metrics.sh`: 示例脚本 +- `cleanup_for_hf.py`: 清理脚本 + +## ⚠️ 注意事项 + +1. **大文件处理** + - 如果文件 >50MB,考虑使用 Git LFS + - 或排除数据文件,使用外部链接 + +2. **敏感信息** + - 不要上传包含 API 密钥、密码的文件 + - 检查配置文件中的敏感数据 + +3. **许可证** + - 确保所有代码都有适当的许可证 + - 原始 EasyTPP 使用 Apache 2.0 许可证 + +4. **版本控制** + - 建议使用 Git 进行版本控制 + - 每次更新后提交并推送 + +## 🔍 验证上传 + +上传后检查: +- [ ] 所有文件都已上传 +- [ ] README 显示正确 +- [ ] 代码可以正常下载 +- [ ] 依赖可以正常安装 + +## 📞 问题反馈 + +如有问题,请检查: +1. Hugging Face 仓库设置是否正确 +2. 文件大小是否超过限制 +3. 是否有权限问题 diff --git a/LICENCE b/LICENCE new file mode 100644 index 0000000000000000000000000000000000000000..956857c012635513993905638b2a1d5487a29373 --- /dev/null +++ b/LICENCE @@ -0,0 +1,203 @@ +Copyright 2022 The EasyTPP Authors. All rights reserved. + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000000000000000000000000000000000000..18a3968a99186440db99df54cd14bd875dfd8feb --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,2 @@ +include requirements.txt +include version.py \ No newline at end of file diff --git a/NOTICE b/NOTICE new file mode 100644 index 0000000000000000000000000000000000000000..b1e48d9c985c258b4b1e21fcc5f4e3174e0a777d --- /dev/null +++ b/NOTICE @@ -0,0 +1,23 @@ +============================================================= +EasyTPP is a open source tool developed by Machine Intelligence Team +Copyright (c) 2020-2022, Ant Group Holding Limited. +Licensed under the Apache License, Version 2.0 + +============================================================= +This toolkit contains various third-party components under +different open source licenses + +----------------------------- +Training evaluation pipeline +Apache License, Version 2.0 +FuxiCTR authors + +---------------------------- +Training evaluation pipeline +Apache License, Version 2.0 +EasyNLP, Alibaba Inc. + +---------------------------- +Tokenizer and DataLoader +Apache License, Version 2.0 +The HuggingFace Inc. team \ No newline at end of file diff --git a/QUICK_START_HF.md b/QUICK_START_HF.md new file mode 100644 index 0000000000000000000000000000000000000000..c0736b118a510ffbeb9e0e26c32b5911b59c7c0f --- /dev/null +++ b/QUICK_START_HF.md @@ -0,0 +1,104 @@ +# 快速开始:上传到 Hugging Face + +## ✅ 整理完成 + +文件夹已整理完成,可以上传到 Hugging Face。 + +**文件夹大小**: 1.3MB(适合上传) + +## 📋 整理内容 + +### 已完成的清理 +- ✅ 更新了 `.gitignore` 文件 +- ✅ 创建了清理脚本 `cleanup_for_hf.py` +- ✅ 检查了大文件(无大文件) +- ✅ 创建了上传指南 + +### 新增文件 +- `compute_cascade_metrics.py` - 级联指标计算脚本 +- `COMPUTE_METRICS_README.md` - 指标计算说明 +- `HF_UPLOAD_GUIDE.md` - 上传指南 +- `ADDITIONS_README.md` - 新增功能说明 +- `cleanup_for_hf.py` - 清理脚本 +- `requirements_compute_metrics.txt` - 额外依赖 + +## 🚀 三步上传 + +### 步骤1: 安装 Hugging Face CLI + +```bash +pip install huggingface_hub +``` + +### 步骤2: 登录 + +```bash +huggingface-cli login +# 输入你的 token(在 https://huggingface.co/settings/tokens 获取) +``` + +### 步骤3: 创建仓库并上传 + +```bash +# 1. 在网页上创建仓库 +# 访问 https://huggingface.co/new +# 选择 "Dataset",命名为例如:easytpp-cascade-metrics + +# 2. 上传文件 +cd /Users/chenshuyi/Downloads/EasyTemporalPointProcess-main +huggingface-cli upload /easytpp-cascade-metrics . --repo-type dataset +``` + +## 📥 在云电脑上下载 + +```bash +# 方法1: 使用 CLI +huggingface-cli download /easytpp-cascade-metrics --local-dir ./EasyTPP + +# 方法2: 使用 Git +git clone https://huggingface.co/datasets//easytpp-cascade-metrics +cd easytpp-cascade-metrics +``` + +### ⚠️ 重要:数据文件需要单独获取 + +代码仓库**不包含**数据文件(已通过 .gitignore 排除,因为文件太大)。 + +数据文件需要单独传输: + +```bash +# 在云电脑上创建目录 +mkdir -p data/cascades + +# 方法1: 使用 scp 从本地传输(推荐) +scp user@local-machine:/path/to/information_cascade*.json ./data/cascades/ + +# 方法2: 如果已上传到 Hugging Face Dataset Hub +huggingface-cli download /cascade-data --local-dir ./data/cascades +``` + +详细说明请参考 `DATA_FILES_NOTICE.md` + +## 📚 相关文档 + +- **详细上传指南**: `HF_UPLOAD_GUIDE.md` +- **数据文件说明**: `DATA_FILES_NOTICE.md` ⚠️ **重要** +- **指标计算说明**: `COMPUTE_METRICS_README.md` +- **新增功能**: `ADDITIONS_README.md` +- **上传检查清单**: `UPLOAD_CHECKLIST.md` + +## ⚠️ 注意事项 + +1. **文件大小**: 当前文件夹 1.3MB,无需 Git LFS +2. **敏感信息**: 已检查,无敏感信息 +3. **依赖**: 确保 `requirements.txt` 和 `requirements_compute_metrics.txt` 已包含 + +## 🎯 下一步 + +1. 按照上述步骤上传到 Hugging Face +2. 在云电脑上下载并测试 +3. 运行 `compute_cascade_metrics.py` 计算指标 + +--- + +**准备好了!可以开始上传了!** 🎉 diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c94dedc33c0f3f0db185bde7384a2a2d85acc546 --- /dev/null +++ b/README.md @@ -0,0 +1,279 @@ +# EasyTPP [ICLR 2024] + + + + + +`EasyTPP` is an easy-to-use development and application toolkit for [Temporal Point Process](https://mathworld.wolfram.com/TemporalPointProcess.html) (TPP), with key features in configurability, compatibility and reproducibility. We hope this project could benefit both researchers and practitioners with the goal of easily customized development and open benchmarking in TPP. + + + + +| Features | Model List | Dataset | Quick Start | Benchmark |Documentation |Todo List | Citation |Acknowledgement | Star History | + +## News + + +- ![new](https://img.alicdn.com/imgextra/i4/O1CN01kUiDtl1HVxN6G56vN_!!6000000000764-2-tps-43-19.png) [11-06-2025] We have released a new version of ``EasyTPP`` that exclusively supports PyTorch. TensorFlow support has been removed to streamline the codebase and focus on PyTorch-based implementations. +- ![new](https://img.alicdn.com/imgextra/i4/O1CN01kUiDtl1HVxN6G56vN_!!6000000000764-2-tps-43-19.png) [11-05-2025] Added the implementation of the [S2P2](https://openreview.net/pdf?id=74SvE2GZwW) model, presented at NeurIPS'2025. +- ![new](https://img.alicdn.com/imgextra/i4/O1CN01kUiDtl1HVxN6G56vN_!!6000000000764-2-tps-43-19.png) [02-17-2024] ``EasyTPP`` supports HuggingFace dataset API: all datasets have been published in [HuggingFace Repo](https://huggingface.co/easytpp) and see [tutorial notebook](https://github.com/ant-research/EasyTemporalPointProcess/blob/main/notebooks/easytpp_1_dataset.ipynb) for an example of usage. +- [01-16-2024] Our paper [EasyTPP: Towards Open Benchmarking Temporal Point Process](https://arxiv.org/abs/2307.08097) is accepted by ICLR'2024! +
+ Click to see previous news +

+- [09-30-2023] We published two textual event sequence datasets [GDELT](https://drive.google.com/drive/folders/1Ms-ATMMFf6v4eesfJndyuPLGtX58fCnk) and [Amazon-text-review](https://drive.google.com/drive/folders/1-SLYyrl7ucEG7NpSIF0eSoG9zcbZagZw) that are used in our paper [LAMP](https://arxiv.org/abs/2305.16646), where LLM can be applied for event prediction! See [Documentation](https://ant-research.github.io/EasyTemporalPointProcess/user_guide/dataset.html#preprocessed-datasets) for more details. +- [09-30-2023] Two of our papers [Language Model Can Improve Event Prediction by Few-Shot Abductive Reasoning](https://arxiv.org/abs/2305.16646) (LAMP) and [Prompt-augmented Temporal Point Process for Streaming Event Sequence](https://arxiv.org/abs/2310.04993) (PromptTPP) are accepted by NeurIPS'2023! +- [09-02-2023] We published two non-anthropogenic datasets [earthquake](https://drive.google.com/drive/folders/1ubeIz_CCNjHyuu6-XXD0T-gdOLm12rf4) and [volcano eruption](https://drive.google.com/drive/folders/1KSWbNi8LUwC-dxz1T5sOnd9zwAot95Tp?usp=drive_link)! See Dataset for details. +- [05-29-2023] We released ``EasyTPP`` v0.0.1! +- [12-27-2022] Our paper [Bellman Meets Hawkes: Model-Based Reinforcement Learning via Temporal Point Processes](https://arxiv.org/abs/2201.12569) was accepted by AAAI'2023! +- [10-01-2022] Our paper [HYPRO: A Hybridly Normalized Probabilistic Model for Long-Horizon Prediction of Event Sequences](https://arxiv.org/abs/2210.01753) was accepted by NeurIPS'2022! +- [05-01-2022] We started to develop `EasyTPP`.

+
+ + +## Features [Back to Top] + + +- **Configurable and customizable**: models are modularized and configurable,with abstract classes to support developing customized + TPP models. +- **PyTorch-based implementation**: `EasyTPP` implements state-of-the-art TPP models using PyTorch 1.7.0+, providing a clean and modern deep learning framework. +- **Reproducible**: all the benchmarks can be easily reproduced. +- **Hyper-parameter optimization**: a pipeline of [optuna](https://github.com/optuna/optuna)-based HPO is provided. + + +## Model List [Back to Top] + + +We provide reference implementations of various state-of-the-art TPP papers: + +| No | Publication | Model | Paper | Implementation | +|:---:|:-----------:|:-------------:|:-----------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------| +| 1 | KDD'16 | RMTPP | [Recurrent Marked Temporal Point Processes: Embedding Event History to Vector](https://www.kdd.org/kdd2016/papers/files/rpp1081-duA.pdf) | [PyTorch](easy_tpp/model/torch_model/torch_rmtpp.py) | +| 2 | NeurIPS'17 | NHP | [The Neural Hawkes Process: A Neurally Self-Modulating Multivariate Point Process](https://arxiv.org/abs/1612.09328) | [PyTorch](easy_tpp/model/torch_model/torch_nhp.py) | +| 3 | NeurIPS'19 | FullyNN | [Fully Neural Network based Model for General Temporal Point Processes](https://arxiv.org/abs/1905.09690) | [PyTorch](easy_tpp/model/torch_model/torch_fullynn.py) | +| 4 | ICML'20 | SAHP | [Self-Attentive Hawkes process](https://arxiv.org/abs/1907.07561) | [PyTorch](easy_tpp/model/torch_model/torch_sahp.py) | +| 5 | ICML'20 | THP | [Transformer Hawkes process](https://arxiv.org/abs/2002.09291) | [PyTorch](easy_tpp/model/torch_model/torch_thp.py) | +| 6 | ICLR'20 | IntensityFree | [Intensity-Free Learning of Temporal Point Processes](https://arxiv.org/abs/1909.12127) | [PyTorch](easy_tpp/model/torch_model/torch_intensity_free.py) | +| 7 | ICLR'21 | ODETPP | [Neural Spatio-Temporal Point Processes (simplified)](https://arxiv.org/abs/2011.04583) | [PyTorch](easy_tpp/model/torch_model/torch_ode_tpp.py) | +| 8 | ICLR'22 | AttNHP | [Transformer Embeddings of Irregularly Spaced Events and Their Participants](https://arxiv.org/abs/2201.00044) | [PyTorch](easy_tpp/model/torch_model/torch_attnhp.py) | +| 9 | NeurIPS'25 | S2P2 | [Deep Continuous-Time State-Space Models for Marked Event Sequences](https://openreview.net/pdf?id=74SvE2GZwW) | [PyTorch](easy_tpp/model/torch_model/torch_s2p2.py) | + + + +## Dataset [Back to Top] + + +We preprocessed one synthetic and five real world datasets from widely-cited works that contain diverse characteristics in terms of their application domains and temporal statistics: +- Synthetic: a univariate Hawkes process simulated by [Tick](https://github.com/X-DataInitiative/tick) library. +- Retweet ([Zhou, 2013](http://proceedings.mlr.press/v28/zhou13.pdf)): timestamped user retweet events. +- Taxi ([Whong, 2014](https://chriswhong.com/open-data/foil_nyc_taxi/)): timestamped taxi pick-up events. +- StackOverflow ([Leskovec, 2014](https://snap.stanford.edu/data/)): timestamped user badge reward events in StackOverflow. +- Taobao ([Xue et al, 2022](https://arxiv.org/abs/2210.01753)): timestamped user online shopping behavior events in Taobao platform. +- Amazon ([Xue et al, 2022](https://arxiv.org/abs/2210.01753)): timestamped user online shopping behavior events in Amazon platform. + +Per users' request, we processed two non-anthropogenic datasets +- [Earthquake](https://drive.google.com/drive/folders/1ubeIz_CCNjHyuu6-XXD0T-gdOLm12rf4): timestamped earthquake events over the Conterminous U.S from 1996 to 2023, processed from [USGS](https://www.usgs.gov/programs/earthquake-hazards/science/earthquake-data). +- [Volcano eruption](https://drive.google.com/drive/folders/1KSWbNi8LUwC-dxz1T5sOnd9zwAot95Tp?usp=drive_link): timestamped volcano eruption events over the world in recent hundreds of years, processed from [The Smithsonian Institution](https://volcano.si.edu/). + + + All datasets are preprocess to the `Gatech` format dataset widely used for TPP researchers, and saved at [Google Drive](https://drive.google.com/drive/u/0/folders/1f8k82-NL6KFKuNMsUwozmbzDSFycYvz7) with a public access. + +## Quick Start [Back to Top] + + + +### Colab Tutorials + +Explore the following tutorials that can be opened directly in Google Colab: + +- [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ant-research/EasyTemporalPointProcess/blob/main/notebooks/easytpp_1_dataset.ipynb) Tutorial 1: Dataset in EasyTPP. +- [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ant-research/EasyTemporalPointProcess/blob/main/notebooks/easytpp_2_tfb_wb.ipynb) Tutorial 2: Tensorboard in EasyTPP. +- [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ant-research/EasyTemporalPointProcess/blob/main/notebooks/easytpp_3_train_eval.ipynb) Tutorial 3: Training and Evaluation of TPPs. + +### End-to-end Example + +We provide an end-to-end example for users to run a standard TPP model with `EasyTPP`. + + +### Step 1. Installation + +First of all, we can install the package either by using pip or from the source code on Github. + +To install the latest stable version: +```bash +pip install easy-tpp +``` + +To install the latest on GitHub: +```bash +git clone https://github.com/ant-research/EasyTemporalPointProcess.git +cd EasyTemporalPointProcess +python setup.py install +``` + + +### Step 2. Prepare datasets + +We need to put the datasets in a local directory before running a model and the datasets should follow a certain format. See [OnlineDoc - Datasets](https://ant-research.github.io/EasyTemporalPointProcess/user_guide/dataset.html) for more details. + +Suppose we use the [taxi dataset](https://chriswhong.com/open-data/foil_nyc_taxi/) in the example. + +### Step 3. Train the model + + +Before start training, we need to set up the config file for the pipeline. We provide a preset config file in [Example Config](https://github.com/ant-research/EasyTemporalPointProcess/blob/main/examples/configs/experiment_config.yaml). The details of the configuration can be found in [OnlineDoc - Training Pipeline](https://ant-research.github.io/EasyTemporalPointProcess/user_guide/run_train_pipeline.html). + +After the setup of data and config, the directory structure is as follows: + +```bash + + data + |______taxi + |____ train.pkl + |____ dev.pkl + |____ test.pkl + + configs + |______experiment_config.yaml + +``` + + +Then we start the training by simply running the script + +```python + +import argparse +from easy_tpp.config_factory import Config +from easy_tpp.runner import Runner + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument('--config_dir', type=str, required=False, default='configs/experiment_config.yaml', + help='Dir of configuration yaml to train and evaluate the model.') + + parser.add_argument('--experiment_id', type=str, required=False, default='NHP_train', + help='Experiment id in the config file.') + + args = parser.parse_args() + + config = Config.build_from_yaml_file(args.config_dir, experiment_id=args.experiment_id) + + model_runner = Runner.build_from_config(config) + + model_runner.run() + + +if __name__ == '__main__': + main() + +``` + +A more detailed example can be found at [OnlineDoc - QuickStart](https://ant-research.github.io/EasyTemporalPointProcess/get_started/quick_start.html). + + +## Documentation [Back to Top] + + +The classes and methods of `EasyTPP` have been well documented so that users can generate the documentation by: + +```shell +cd doc +pip install -r requirements.txt +make html +``` +NOTE: +* The `doc/requirements.txt` is only for documentation by Sphinx, which can be automatically generated by Github actions `.github/workflows/docs.yml`. (Trigger by pull request.) + +The full documentation is available on the [website](https://ant-research.github.io/EasyTemporalPointProcess/). + +## Benchmark [Back to Top] + + +In the [examples](https://github.com/ant-research/EasyTemporalPointProcess/tree/main/examples) folder, we provide a [script](https://github.com/ant-research/EasyTemporalPointProcess/blob/main/examples/benchmark_script.py) to benchmark the TPPs, with Taxi dataset as the input. + +To run the script, one should download the Taxi data following the above instructions. The [config](https://github.com/ant-research/EasyTemporalPointProcess/blob/main/examples/configs/experiment_config.yaml) file is readily setup up. Then run + + +```shell +cd examples +python run_retweet.py +``` + + +## License [Back to Top] + +This project is licensed under the [Apache License (Version 2.0)](https://github.com/alibaba/EasyNLP/blob/master/LICENSE). This toolkit also contains some code modified from other repos under other open-source licenses. See the [NOTICE](https://github.com/ant-research/EasyTPP/blob/master/NOTICE) file for more information. + + +## Todo List [Back to Top] + + +- [x] New dataset: + - [x] Earthquake: the source data is available in [USGS](https://www.usgs.gov/programs/earthquake-hazards/science/earthquake-data). + - [x] Volcano eruption: the source data is available in [NCEI](https://www.ngdc.noaa.gov/hazard/volcano.shtml). +- [ ] New model: + - [ ] Meta Temporal Point Process, ICLR 2023. + - [ ] Model-based RL via TPP, AAAI 2022. + +## Citation [Back to Top] + + + +If you find `EasyTPP` useful for your research or development, please cite the following paper: +``` +@inproceedings{xue2024easytpp, + title={EasyTPP: Towards Open Benchmarking Temporal Point Processes}, + author={Siqiao Xue and Xiaoming Shi and Zhixuan Chu and Yan Wang and Hongyan Hao and Fan Zhou and Caigao Jiang and Chen Pan and James Y. Zhang and Qingsong Wen and Jun Zhou and Hongyuan Mei}, + booktitle = {International Conference on Learning Representations (ICLR)}, + year = {2024}, + url ={https://arxiv.org/abs/2307.08097} +} +``` + +## Acknowledgment [Back to Top] + + +The project is jointly initiated by Machine Intelligence Group, Alipay and DAMO Academy, Alibaba. + +The following repositories are used in `EasyTPP`, either in close to original form or as an inspiration: + +- [EasyRec](https://github.com/alibaba/EasyRec) +- [EasyNLP](https://github.com/alibaba/EasyNLP) +- [FuxiCTR](https://github.com/xue-pai/FuxiCTR) +- [Neural Hawkes Process](https://github.com/hongyuanmei/neurawkes) +- [Neural Hawkes Particle Smoothing](https://github.com/hongyuanmei/neural-hawkes-particle-smoothing) +- [Attentive Neural Hawkes Process](https://github.com/yangalan123/anhp-andtt) +- [Huggingface - transformers](https://github.com/huggingface/transformers) + + +## Star History [Back to Top] + + +![Star History Chart](https://api.star-history.com/svg?repos=ant-research/EasyTemporalPointProcess&type=Date) + diff --git a/UPLOAD_CHECKLIST.md b/UPLOAD_CHECKLIST.md new file mode 100644 index 0000000000000000000000000000000000000000..d00253dc9a35ba59e03540abf2c14ea069c5768c --- /dev/null +++ b/UPLOAD_CHECKLIST.md @@ -0,0 +1,116 @@ +# Hugging Face 上传检查清单 + +## ✅ 清理完成 + +### 已删除的文件类型 +- `__pycache__/` 文件夹 +- `*.pyc`, `*.pyo`, `*.pyd` 文件 +- `.DS_Store` 文件(macOS) +- `.vscode/`, `.idea/` 文件夹 +- `*.swp`, `*.swo` 文件 + +### 需要手动检查的项目 + +1. **大文件检查** + - 检查是否有超过50MB的文件 + - 考虑使用 Git LFS 或排除这些文件 + +2. **敏感信息检查** + - 检查是否有API密钥、密码等敏感信息 + - 检查配置文件中的敏感数据 + +3. **数据文件** + - 检查 `examples/data/` 目录 + - 如果数据文件很大,考虑排除或使用外部链接 + +4. **模型文件** + - 检查是否有预训练模型文件 + - 大模型文件应使用 Git LFS 或 Hugging Face Model Hub + +5. **日志文件** + - 确保没有日志文件被包含 + - 检查 `log/`, `logs/` 目录 + +## 📦 上传到 Hugging Face + +### 方法1: 使用 Hugging Face CLI + +```bash +# 安装 Hugging Face CLI +pip install huggingface_hub + +# 登录 +huggingface-cli login + +# 创建仓库(如果还没有) +# 在 https://huggingface.co/new 创建新仓库 + +# 上传文件 +cd /path/to/EasyTemporalPointProcess-main +huggingface-cli upload / . --repo-type dataset +``` + +### 方法2: 使用 Git + +```bash +# 初始化 Git 仓库(如果还没有) +git init +git add . +git commit -m "Initial commit" + +# 添加 Hugging Face 远程仓库 +git remote add origin https://huggingface.co// + +# 推送 +git push origin main +``` + +### 方法3: 使用 Web 界面 + +1. 访问 https://huggingface.co/new +2. 创建新的 Dataset 或 Space +3. 使用 Web 界面上传文件 + +## 📝 文件结构说明 + +``` +EasyTemporalPointProcess-main/ +├── easy_tpp/ # 核心库代码 +├── examples/ # 示例代码 +├── notebooks/ # Jupyter notebooks +├── tests/ # 测试代码 +├── docs/ # 文档 +├── compute_cascade_metrics.py # 新增:级联指标计算脚本 +├── COMPUTE_METRICS_README.md # 新增:指标计算说明 +├── requirements.txt # 基础依赖 +├── requirements_compute_metrics.txt # 新增:指标计算依赖 +├── setup.py # 安装脚本 +└── README.md # 项目说明 +``` + +## ⚠️ 注意事项 + +1. **不要上传大文件到 Git 仓库** + - 使用 Git LFS 或 Hugging Face 的存储系统 + - 考虑使用外部链接引用大文件 + +2. **检查许可证** + - 确保所有代码都有适当的许可证 + - 检查第三方依赖的许可证兼容性 + +3. **README 文件** + - 确保 README.md 清晰说明项目用途 + - 包含安装和使用说明 + +4. **依赖管理** + - 确保 requirements.txt 是最新的 + - 考虑使用 `pip freeze` 生成精确版本 + +## 🔍 验证上传 + +上传后,检查: +- [ ] 所有文件都已上传 +- [ ] 文件大小合理 +- [ ] 没有敏感信息泄露 +- [ ] README 显示正确 +- [ ] 代码可以正常下载和使用 diff --git a/cleanup_for_hf.py b/cleanup_for_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..0daf200cb964364b35220ff617cbe93d61971cd6 --- /dev/null +++ b/cleanup_for_hf.py @@ -0,0 +1,293 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +清理脚本:准备上传到 Hugging Face + +该脚本会: +1. 清理不必要的文件(__pycache__, .pyc, .pyo等) +2. 检查大文件 +3. 创建上传检查清单 +""" + +import os +import shutil +from pathlib import Path +from typing import List, Tuple + + +def find_and_remove_patterns(root_dir: str, patterns: List[str]) -> List[str]: + """ + 查找并删除匹配模式的文件/文件夹 + + Args: + root_dir: 根目录 + patterns: 文件/文件夹模式列表 + + Returns: + 已删除的文件/文件夹列表 + """ + removed = [] + root_path = Path(root_dir) + + for pattern in patterns: + for item in root_path.rglob(pattern): + if item.exists(): + try: + if item.is_file(): + item.unlink() + removed.append(str(item)) + elif item.is_dir(): + shutil.rmtree(item) + removed.append(str(item)) + except Exception as e: + print(f"警告: 无法删除 {item}: {e}") + + return removed + + +def find_large_files(root_dir: str, size_mb: int = 50) -> List[Tuple[str, float]]: + """ + 查找大文件 + + Args: + root_dir: 根目录 + size_mb: 文件大小阈值(MB) + + Returns: + (文件路径, 大小MB) 列表 + """ + large_files = [] + root_path = Path(root_dir) + size_bytes = size_mb * 1024 * 1024 + + for item in root_path.rglob('*'): + if item.is_file(): + try: + size = item.stat().st_size + if size > size_bytes: + size_mb_actual = size / (1024 * 1024) + large_files.append((str(item), size_mb_actual)) + except Exception as e: + print(f"警告: 无法检查 {item}: {e}") + + return large_files + + +def check_gitignore(root_dir: str) -> bool: + """ + 检查是否存在 .gitignore 文件 + + Args: + root_dir: 根目录 + + Returns: + 是否存在 .gitignore + """ + gitignore_path = Path(root_dir) / '.gitignore' + return gitignore_path.exists() + + +def create_upload_checklist(root_dir: str) -> str: + """ + 创建上传检查清单 + + Args: + root_dir: 根目录 + + Returns: + 检查清单内容 + """ + checklist = """# Hugging Face 上传检查清单 + +## ✅ 清理完成 + +### 已删除的文件类型 +- `__pycache__/` 文件夹 +- `*.pyc`, `*.pyo`, `*.pyd` 文件 +- `.DS_Store` 文件(macOS) +- `.vscode/`, `.idea/` 文件夹 +- `*.swp`, `*.swo` 文件 + +### 需要手动检查的项目 + +1. **大文件检查** + - 检查是否有超过50MB的文件 + - 考虑使用 Git LFS 或排除这些文件 + +2. **敏感信息检查** + - 检查是否有API密钥、密码等敏感信息 + - 检查配置文件中的敏感数据 + +3. **数据文件** + - 检查 `examples/data/` 目录 + - 如果数据文件很大,考虑排除或使用外部链接 + +4. **模型文件** + - 检查是否有预训练模型文件 + - 大模型文件应使用 Git LFS 或 Hugging Face Model Hub + +5. **日志文件** + - 确保没有日志文件被包含 + - 检查 `log/`, `logs/` 目录 + +## 📦 上传到 Hugging Face + +### 方法1: 使用 Hugging Face CLI + +```bash +# 安装 Hugging Face CLI +pip install huggingface_hub + +# 登录 +huggingface-cli login + +# 创建仓库(如果还没有) +# 在 https://huggingface.co/new 创建新仓库 + +# 上传文件 +cd /path/to/EasyTemporalPointProcess-main +huggingface-cli upload / . --repo-type dataset +``` + +### 方法2: 使用 Git + +```bash +# 初始化 Git 仓库(如果还没有) +git init +git add . +git commit -m "Initial commit" + +# 添加 Hugging Face 远程仓库 +git remote add origin https://huggingface.co// + +# 推送 +git push origin main +``` + +### 方法3: 使用 Web 界面 + +1. 访问 https://huggingface.co/new +2. 创建新的 Dataset 或 Space +3. 使用 Web 界面上传文件 + +## 📝 文件结构说明 + +``` +EasyTemporalPointProcess-main/ +├── easy_tpp/ # 核心库代码 +├── examples/ # 示例代码 +├── notebooks/ # Jupyter notebooks +├── tests/ # 测试代码 +├── docs/ # 文档 +├── compute_cascade_metrics.py # 新增:级联指标计算脚本 +├── COMPUTE_METRICS_README.md # 新增:指标计算说明 +├── requirements.txt # 基础依赖 +├── requirements_compute_metrics.txt # 新增:指标计算依赖 +├── setup.py # 安装脚本 +└── README.md # 项目说明 +``` + +## ⚠️ 注意事项 + +1. **不要上传大文件到 Git 仓库** + - 使用 Git LFS 或 Hugging Face 的存储系统 + - 考虑使用外部链接引用大文件 + +2. **检查许可证** + - 确保所有代码都有适当的许可证 + - 检查第三方依赖的许可证兼容性 + +3. **README 文件** + - 确保 README.md 清晰说明项目用途 + - 包含安装和使用说明 + +4. **依赖管理** + - 确保 requirements.txt 是最新的 + - 考虑使用 `pip freeze` 生成精确版本 + +## 🔍 验证上传 + +上传后,检查: +- [ ] 所有文件都已上传 +- [ ] 文件大小合理 +- [ ] 没有敏感信息泄露 +- [ ] README 显示正确 +- [ ] 代码可以正常下载和使用 +""" + + return checklist + + +def main(): + """主函数""" + root_dir = os.path.dirname(os.path.abspath(__file__)) + + print("=" * 60) + print("清理脚本:准备上传到 Hugging Face") + print("=" * 60) + + # 要删除的模式 + patterns_to_remove = [ + '__pycache__', + '*.pyc', + '*.pyo', + '*.pyd', + '.DS_Store', + '.vscode', + '.idea', + '*.swp', + '*.swo', + '*.log', + '.pytest_cache', + '.mypy_cache', + '.ruff_cache', + ] + + print("\n1. 清理不必要的文件...") + removed = find_and_remove_patterns(root_dir, patterns_to_remove) + if removed: + print(f" 已删除 {len(removed)} 个文件/文件夹") + for item in removed[:10]: # 只显示前10个 + print(f" - {item}") + if len(removed) > 10: + print(f" ... 还有 {len(removed) - 10} 个文件/文件夹") + else: + print(" 没有找到需要删除的文件") + + # 检查大文件 + print("\n2. 检查大文件(>50MB)...") + large_files = find_large_files(root_dir, size_mb=50) + if large_files: + print(f" 找到 {len(large_files)} 个大文件:") + for file_path, size_mb in large_files: + print(f" - {file_path} ({size_mb:.2f} MB)") + print("\n ⚠️ 建议:大文件应使用 Git LFS 或排除在上传之外") + else: + print(" ✅ 没有找到大文件") + + # 检查 .gitignore + print("\n3. 检查 .gitignore...") + if check_gitignore(root_dir): + print(" ✅ .gitignore 文件存在") + else: + print(" ⚠️ 警告: .gitignore 文件不存在") + + # 创建检查清单 + print("\n4. 创建上传检查清单...") + checklist_content = create_upload_checklist(root_dir) + checklist_path = Path(root_dir) / 'UPLOAD_CHECKLIST.md' + with open(checklist_path, 'w', encoding='utf-8') as f: + f.write(checklist_content) + print(f" ✅ 已创建: {checklist_path}") + + print("\n" + "=" * 60) + print("清理完成!") + print("=" * 60) + print("\n下一步:") + print("1. 查看 UPLOAD_CHECKLIST.md 了解上传步骤") + print("2. 检查是否有敏感信息需要移除") + print("3. 按照检查清单上传到 Hugging Face") + + +if __name__ == '__main__': + main() diff --git a/compute_cascade_metrics.py b/compute_cascade_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..2827567c23c5e202fb9566ecaab4f0c5ddb55a62 --- /dev/null +++ b/compute_cascade_metrics.py @@ -0,0 +1,568 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +计算信息级联的指标:情感得分、情感deviation、contextual deviation、perplexity + +该脚本处理 information_cascade.json 和 information_cascade_original_posts.json, +计算以下指标: +1. 情感得分 (sentiment score) +2. 情感deviation (sentiment deviation) +3. Contextual deviation (语境偏差) +4. Perplexity (困惑度) + +使用方法(在云电脑上): + python compute_cascade_metrics.py \ + --input_cascade information_cascade.json \ + --input_original information_cascade_original_posts.json \ + --output output_with_metrics.json \ + --bert_model bert-base-chinese \ + --sentiment_model \ + --perplexity_model \ + --batch_size 32 +""" + +import argparse +import json +import numpy as np +import torch +from typing import Dict, List, Any, Optional, Tuple +from tqdm import tqdm +from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM +import os + + +class CascadeMetricsComputer: + """ + 计算级联数据的各种指标 + """ + + def __init__( + self, + bert_model_name: str = 'bert-base-chinese', + sentiment_model_name: Optional[str] = None, + perplexity_model_name: Optional[str] = None, + device: Optional[str] = None, + batch_size: int = 32, + max_length: int = 512 + ): + """ + 初始化指标计算器 + + Args: + bert_model_name: BERT模型名称(用于计算语义向量和contextual deviation) + sentiment_model_name: 情感分析模型名称(用于计算情感得分) + perplexity_model_name: 语言模型名称(用于计算困惑度) + device: 计算设备('cuda'或'cpu'),如果为None则自动选择 + batch_size: 批处理大小 + max_length: 最大序列长度 + """ + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + + self.device = device + self.batch_size = batch_size + self.max_length = max_length + + print(f"正在加载BERT模型: {bert_model_name}") + self.bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name) + self.bert_model = AutoModel.from_pretrained(bert_model_name) + self.bert_model.to(device) + self.bert_model.eval() + print(f"BERT模型已加载到设备: {device}") + + # 加载情感分析模型 + if sentiment_model_name: + print(f"正在加载情感分析模型: {sentiment_model_name}") + self.sentiment_tokenizer = AutoTokenizer.from_pretrained(sentiment_model_name) + self.sentiment_model = AutoModelForSequenceClassification.from_pretrained(sentiment_model_name) + self.sentiment_model.to(device) + self.sentiment_model.eval() + print(f"情感分析模型已加载到设备: {device}") + else: + self.sentiment_tokenizer = None + self.sentiment_model = None + print("未提供情感分析模型,将使用简化的情感计算方法") + + # 加载困惑度模型(语言模型) + if perplexity_model_name: + print(f"正在加载困惑度模型: {perplexity_model_name}") + self.perplexity_tokenizer = AutoTokenizer.from_pretrained(perplexity_model_name) + self.perplexity_model = AutoModelForCausalLM.from_pretrained(perplexity_model_name) + self.perplexity_model.to(device) + self.perplexity_model.eval() + print(f"困惑度模型已加载到设备: {device}") + else: + self.perplexity_tokenizer = None + self.perplexity_model = None + print("未提供困惑度模型,将使用简化的困惑度计算方法") + + def compute_embeddings(self, texts: List[str]) -> np.ndarray: + """ + 计算BERT语义向量 + + Args: + texts: 文本列表 + + Returns: + 语义向量矩阵 [num_texts, hidden_size] + """ + embeddings = [] + + with torch.no_grad(): + for i in range(0, len(texts), self.batch_size): + batch_texts = texts[i:i + self.batch_size] + + # 处理空文本 + batch_texts = [text if text else "[PAD]" for text in batch_texts] + + # 分词和编码 + inputs = self.bert_tokenizer( + batch_texts, + return_tensors='pt', + padding=True, + truncation=True, + max_length=self.max_length + ).to(self.device) + + # 前向传播 + outputs = self.bert_model(**inputs) + + # 使用[CLS]标记的嵌入 + batch_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy() + embeddings.append(batch_embeddings) + + return np.vstack(embeddings) + + def compute_sentiment_scores(self, texts: List[str]) -> List[float]: + """ + 计算情感得分 + + Args: + texts: 文本列表 + + Returns: + 情感得分列表(每个文本一个得分,范围通常在[-1, 1]或[0, 1]) + """ + if self.sentiment_model is None: + # 使用简化的情感计算方法 + return self._compute_sentiment_simple(texts) + + sentiment_scores = [] + + with torch.no_grad(): + for i in range(0, len(texts), self.batch_size): + batch_texts = texts[i:i + self.batch_size] + batch_texts = [text if text else "[PAD]" for text in batch_texts] + + inputs = self.sentiment_tokenizer( + batch_texts, + return_tensors='pt', + padding=True, + truncation=True, + max_length=self.max_length + ).to(self.device) + + outputs = self.sentiment_model(**inputs) + logits = outputs.logits + + # 假设是二分类(正面/负面),使用softmax获取概率 + probs = torch.softmax(logits, dim=-1) + + # 计算情感得分:正面概率 - 负面概率(或使用其他方法) + if probs.shape[1] == 2: + # 二分类:[负面概率, 正面概率] + batch_scores = (probs[:, 1] - probs[:, 0]).cpu().numpy().tolist() + else: + # 多分类或其他情况,使用第一个类别的概率作为得分 + batch_scores = probs[:, 0].cpu().numpy().tolist() + + sentiment_scores.extend(batch_scores) + + return sentiment_scores + + def _compute_sentiment_simple(self, texts: List[str]) -> List[float]: + """ + 简化的情感计算方法(基于启发式规则) + + Args: + texts: 文本列表 + + Returns: + 情感得分列表 + """ + scores = [] + for text in texts: + if not text: + scores.append(0.0) + continue + + # 简单的启发式方法 + positive_words = ['好', '棒', '赞', '喜欢', '支持', '👍', '❤️', '😊', '😄'] + negative_words = ['差', '坏', '讨厌', '反对', '👎', '😢', '😠', '😡'] + + positive_count = sum(1 for word in positive_words if word in text) + negative_count = sum(1 for word in negative_words if word in text) + + # 计算情感得分(归一化到[-1, 1]) + total_words = len(text) + if total_words > 0: + score = (positive_count - negative_count) / max(total_words, 1) + score = np.clip(score, -1.0, 1.0) + else: + score = 0.0 + + scores.append(score) + + return scores + + def compute_perplexity(self, texts: List[str]) -> List[float]: + """ + 计算困惑度 + + Args: + texts: 文本列表 + + Returns: + 困惑度列表 + """ + if self.perplexity_model is None: + # 使用简化的困惑度计算方法 + return self._compute_perplexity_simple(texts) + + perplexities = [] + + with torch.no_grad(): + for text in texts: + if not text: + perplexities.append(0.0) + continue + + # 分词 + inputs = self.perplexity_tokenizer( + text, + return_tensors='pt', + truncation=True, + max_length=self.max_length + ).to(self.device) + + # 计算困惑度 + outputs = self.perplexity_model(**inputs, labels=inputs['input_ids']) + loss = outputs.loss + + # 困惑度 = exp(loss) + perplexity = torch.exp(loss).item() + perplexities.append(perplexity) + + return perplexities + + def _compute_perplexity_simple(self, texts: List[str]) -> List[float]: + """ + 简化的困惑度计算方法(基于词汇多样性) + + Args: + texts: 文本列表 + + Returns: + 困惑度列表 + """ + perplexities = [] + + for text in texts: + if not text: + perplexities.append(0.0) + continue + + # 基于词汇多样性的简化方法 + words = text.split() + unique_words = len(set(words)) + total_words = len(words) + + if total_words > 0: + # 词汇多样性越低,困惑度越高(简化代理) + perplexity_proxy = 1.0 - (unique_words / total_words) + else: + perplexity_proxy = 0.0 + + perplexities.append(perplexity_proxy) + + return perplexities + + def compute_cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float: + """ + 计算余弦相似度 + + Args: + vec1: 向量1 + vec2: 向量2 + + Returns: + 余弦相似度 [0, 1] + """ + dot_product = np.dot(vec1, vec2) + norm1 = np.linalg.norm(vec1) + norm2 = np.linalg.norm(vec2) + + if norm1 == 0 or norm2 == 0: + return 0.0 + + similarity = dot_product / (norm1 * norm2) + return float(similarity) + + def compute_contextual_deviation(self, root_embedding: np.ndarray, current_embedding: np.ndarray) -> float: + """ + 计算语境偏差(Contextual Deviation) + + 定义为:1 - 语义相似度 + + Args: + root_embedding: 原帖的语义向量 + current_embedding: 当前文本的语义向量 + + Returns: + 语境偏差值 [0, 1],越高表示越偏离原帖语境 + """ + similarity = self.compute_cosine_similarity(root_embedding, current_embedding) + deviation = 1.0 - similarity + return deviation + + def compute_sentiment_deviation(self, root_sentiment: float, current_sentiment: float) -> float: + """ + 计算情感偏差(Sentiment Deviation) + + 定义为:|当前情感得分 - 原帖情感得分| + + Args: + root_sentiment: 原帖的情感得分 + current_sentiment: 当前文本的情感得分 + + Returns: + 情感偏差值 [0, 2](如果情感得分范围是[-1, 1]) + """ + deviation = abs(current_sentiment - root_sentiment) + return deviation + + def process_cascade(self, cascade: Dict[str, Any]) -> Dict[str, Any]: + """ + 处理单个级联,计算所有指标 + + Args: + cascade: 级联数据字典 + + Returns: + 添加了指标后的级联数据字典 + """ + # 1. 收集所有文本 + texts: List[str] = [] + indices: List[Tuple[str, Optional[str]]] = [] + + # 原帖 + post_info = cascade.get('post_info', {}) + post_content = post_info.get('content', '') + texts.append(post_content) + indices.append(('post', None)) + + # 评论 + comment_tree = cascade.get('comment_tree', {}) + comment_ids = list(comment_tree.keys()) + for comment_id in comment_ids: + node = comment_tree[comment_id] + texts.append(node.get('content', '')) + indices.append(('comment', comment_id)) + + # 转发 + repost_chain = cascade.get('repost_chain', []) + for node in repost_chain: + forward_text = node.get('forward_text', '') or '' + comment_content = node.get('comment_content', '') or '' + repost_text = forward_text + comment_content + texts.append(repost_text) + indices.append(('repost', node.get('repost_id'))) + + # 2. 批量计算特征 + if len(texts) == 0: + return cascade + + embeddings = self.compute_embeddings(texts) + sentiment_scores = self.compute_sentiment_scores(texts) + perplexities = self.compute_perplexity(texts) + + # 3. 获取原帖的特征(用于计算偏差) + root_embedding = embeddings[0] + root_sentiment = sentiment_scores[0] + + # 4. 将特征附加到级联数据中 + # 原帖 + post_info['embedding'] = root_embedding.tolist() + post_info['sentiment_score'] = root_sentiment + post_info['perplexity'] = perplexities[0] + + # 评论 + for i, comment_id in enumerate(comment_ids): + node = comment_tree[comment_id] + idx = 1 + i # 跳过原帖 + + node['embedding'] = embeddings[idx].tolist() + node['sentiment_score'] = sentiment_scores[idx] + node['perplexity'] = perplexities[idx] + + # 计算偏差 + node['contextual_deviation'] = self.compute_contextual_deviation( + root_embedding, embeddings[idx] + ) + node['sentiment_deviation'] = self.compute_sentiment_deviation( + root_sentiment, sentiment_scores[idx] + ) + + # 转发 + offset = 1 + len(comment_ids) + for j, node in enumerate(repost_chain): + idx = offset + j + + node['embedding'] = embeddings[idx].tolist() + node['sentiment_score'] = sentiment_scores[idx] + node['perplexity'] = perplexities[idx] + + # 计算偏差 + node['contextual_deviation'] = self.compute_contextual_deviation( + root_embedding, embeddings[idx] + ) + node['sentiment_deviation'] = self.compute_sentiment_deviation( + root_sentiment, sentiment_scores[idx] + ) + + return cascade + + +def load_json_file(file_path: str) -> Dict[str, Any]: + """ + 加载JSON文件(支持大文件) + + Args: + file_path: JSON文件路径 + + Returns: + 数据字典 + """ + print(f"正在加载JSON文件: {file_path}") + with open(file_path, 'r', encoding='utf-8') as f: + data = json.load(f) + print(f"已加载 {len(data.get('cascades', []))} 个级联") + return data + + +def main(): + parser = argparse.ArgumentParser( + description='计算信息级联的指标:情感得分、情感deviation、contextual deviation、perplexity' + ) + parser.add_argument( + '--input_cascade', + type=str, + required=True, + help='输入级联JSON文件路径 (information_cascade.json)' + ) + parser.add_argument( + '--input_original', + type=str, + default=None, + help='输入原帖JSON文件路径 (information_cascade_original_posts.json),可选' + ) + parser.add_argument( + '--output', + type=str, + required=True, + help='输出JSON文件路径' + ) + parser.add_argument( + '--bert_model', + type=str, + default='bert-base-chinese', + help='BERT模型名称或路径(用于计算语义向量)' + ) + parser.add_argument( + '--sentiment_model', + type=str, + default=None, + help='情感分析模型名称或路径(可选)' + ) + parser.add_argument( + '--perplexity_model', + type=str, + default=None, + help='语言模型名称或路径(用于计算困惑度,可选)' + ) + parser.add_argument( + '--batch_size', + type=int, + default=32, + help='批处理大小' + ) + parser.add_argument( + '--max_length', + type=int, + default=512, + help='最大序列长度' + ) + parser.add_argument( + '--device', + type=str, + default=None, + help='计算设备(cuda/cpu),如果为None则自动选择' + ) + parser.add_argument( + '--max_cascades', + type=int, + default=None, + help='最大处理级联数量(用于测试,None表示处理所有)' + ) + + args = parser.parse_args() + + # 加载数据 + cascade_data = load_json_file(args.input_cascade) + + if args.input_original: + original_data = load_json_file(args.input_original) + # 如果需要合并数据,在这里处理 + # 目前先只处理cascade_data + + # 初始化指标计算器 + print("\n初始化指标计算器...") + computer = CascadeMetricsComputer( + bert_model_name=args.bert_model, + sentiment_model_name=args.sentiment_model, + perplexity_model_name=args.perplexity_model, + device=args.device, + batch_size=args.batch_size, + max_length=args.max_length + ) + + # 处理级联 + cascades = cascade_data.get('cascades', []) + total_cascades = len(cascades) + if args.max_cascades: + cascades = cascades[:args.max_cascades] + + print(f"\n开始处理 {len(cascades)}/{total_cascades} 个级联...") + processed_count = 0 + for idx, cascade in enumerate(tqdm(cascades, desc="处理级联")): + try: + cascade_data['cascades'][idx] = computer.process_cascade(cascade) + processed_count += 1 + except Exception as e: + print(f"\n处理级联 {idx} 时出错: {e}") + import traceback + traceback.print_exc() + continue + + print(f"\n成功处理 {processed_count}/{len(cascades)} 个级联") + + # 保存结果 + print(f"\n正在保存结果到: {args.output}") + with open(args.output, 'w', encoding='utf-8') as f: + json.dump(cascade_data, f, ensure_ascii=False, indent=2) + + print(f"✅ 完成!结果已保存到: {args.output}") + + +if __name__ == '__main__': + main() diff --git a/data/cascades/.gitkeep b/data/cascades/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..4f4ae210f199be735fbfa04565f1dcbf77eb92c3 --- /dev/null +++ b/data/cascades/.gitkeep @@ -0,0 +1,3 @@ +# 此目录用于存放级联数据文件 +# 数据文件太大,已通过 .gitignore 排除 +# 请参考 DATA_FILES_NOTICE.md 了解如何获取数据文件 diff --git a/data/cascades/README.md b/data/cascades/README.md new file mode 100644 index 0000000000000000000000000000000000000000..0c1b728d52961266a891c8bbb5219a77fc1708b3 --- /dev/null +++ b/data/cascades/README.md @@ -0,0 +1,101 @@ +# Cascade Data Files + +本目录包含信息级联数据文件。 + +## 📁 文件说明 + +### 主要文件 + +1. **`information_cascade.json`** (606MB) + - 完整的级联数据,包含原帖、评论、转发等信息 + - 用于计算级联指标和训练模型 + +2. **`information_cascade_original_posts.json`** (980MB) + - 原帖数据 + - 包含原始微博帖子信息 + +## ⚠️ 文件大小说明 + +这些文件较大(总计约 1.6GB),**不会自动上传到 Git/Hugging Face**。 + +## 📥 如何获取数据文件 + +### 方法1: 手动下载 + +数据文件需要单独下载或传输到云电脑: + +```bash +# 在云电脑上创建目录 +mkdir -p data/cascades + +# 使用 scp 或其他方式传输文件 +scp user@local:/path/to/information_cascade.json ./data/cascades/ +scp user@local:/path/to/information_cascade_original_posts.json ./data/cascades/ +``` + +### 方法2: 使用 Git LFS(如果配置) + +如果使用 Git LFS: + +```bash +# 安装 Git LFS +git lfs install + +# 跟踪大文件 +git lfs track "data/cascades/*.json" + +# 添加并提交 +git add .gitattributes +git add data/cascades/*.json +git commit -m "Add cascade data files with LFS" +``` + +### 方法3: 使用外部存储 + +- 上传到云存储(如 Google Drive, Dropbox) +- 使用 Hugging Face Dataset Hub 的存储系统 +- 使用对象存储服务(如 AWS S3, 阿里云 OSS) + +## 🚀 使用数据文件 + +### 运行指标计算 + +```bash +python compute_cascade_metrics.py \ + --input_cascade data/cascades/information_cascade.json \ + --input_original data/cascades/information_cascade_original_posts.json \ + --output output_with_metrics.json \ + --batch_size 32 +``` + +### 数据格式 + +JSON 文件格式: +```json +{ + "cascades": [ + { + "post_info": { + "content": "...", + "timestamp": "..." + }, + "comment_tree": {...}, + "repost_chain": [...] + } + ] +} +``` + +详细格式说明请参考项目文档。 + +## 📝 注意事项 + +1. **文件大小**: 这些文件很大,确保有足够的磁盘空间 +2. **内存**: 加载完整文件可能需要大量内存 +3. **处理**: 建议使用批处理方式处理数据 +4. **备份**: 建议保留数据文件的备份 + +## 🔗 相关文档 + +- [指标计算说明](../COMPUTE_METRICS_README.md) +- [上传指南](../HF_UPLOAD_GUIDE.md) diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..d0c3cbf1020d5c292abdedf27627c6abe25e2293 --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 0000000000000000000000000000000000000000..bd0cbdd74653ab558ff2b83d11bab245f8a07c3c --- /dev/null +++ b/docs/README.md @@ -0,0 +1,13 @@ +# Documentation for EasyTPP + +This contains the full documentation of EasyTPP, which is hosted at github and can be updated manually (for releases) +by pushing to the gh-pages branch. + + +To generate the documentation locally, type + +``` +pip install -r requirements-doc.txt +cd docs +make html +``` \ No newline at end of file diff --git a/docs/images/thinning_algo.jpg b/docs/images/thinning_algo.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b686b35eb08fe4e51fa2fa8dd9363b2556db9856 --- /dev/null +++ b/docs/images/thinning_algo.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f025ac2ec9c15033ba78ddbdc2b3cb90a842d3df1ac4c73d056fd61ee1304fec +size 235643 diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 0000000000000000000000000000000000000000..dc1312ab09ca6fb0267dee6b28a38e69c253631a --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/source/advanced/implementation.rst b/docs/source/advanced/implementation.rst new file mode 100644 index 0000000000000000000000000000000000000000..8e7b4b536efd14e7a868db9151afb67fcc6434a0 --- /dev/null +++ b/docs/source/advanced/implementation.rst @@ -0,0 +1,143 @@ +=================================== +Model Implementation Details +=================================== + +Basic structure +=================================== + +In the model folder, `torch_basemodel` (**/model/torch_model/torch_basemodel.py**) / `tf_basemodel` (**/model/tf_model/tf_basemodel.py**) implements functionalities of computing loglikelihood and sampling procedures that are common +to all the TPP models. In the inherited class, models with specific structures are defined, explained in below sections. + + +Computing the loglikelihood of non-pad event sequence +------------------------------------------------------ + +The loglikelihood computation, following the definition in Equation 8 of `The Neural Hawkes Process: A Neurally Self-Modulating Multivariate Point Process `_, is shared by all the TPP models. + +it takes `time_delta_seqs`, `lambda_at_event`, `lambdas_loss_samples`, `seq_mask`, + `lambda_type_mask` as the input and output the loglikelihood items, please see `torch_basemodel` (**/model/torch_model/torch_basemodel.py**) / `tf_basemodel` (**/model/tf_model/tf_basemodel.py**) +for details. + +It is noted that: + +1. Sequential prediction: because we performance sequential prediction, i.e., predict next one given previous, we do not consider the last one as it has no labels. To implement the `forward` function, we take input of `time_seqs[:, :-1]` +and `type_seqs[:, :-1]`. For `time_delta_seqs` it is different; please see the next point. + + + +2. Continuous-time evolution: recall the definition in [dataset](./dataset.rst), assume we have a sequence of 4 events and 1 pad event +at the end, i.e., + +.. code-block:: bash + + index: 0, 1, 2, 3, 4 + dtimes: 0, t_1-t_0, t_2-t_1, t_3-t_2, pad + types: e_0, e_1, e_2, e_3, pad + non_pad_mask: True, True, True, True, False + +For the i-th event, i-th dtime denotes the time evolution (e.g., decay in NHP) to the current event and +(i+1)-th dtime denotes the time evolution to the next event. To compute the non-event loglikelihood, +we should consider the time evolution after the event happens. Therefore we should use `type_delta_seqs[:, 1:]` with masks specified in the below step. + +3. Masking: suppose we have predictions of 0,1,2,3-th event and their labels are 1,2,3,4-th events +where $4$-th event needed to be masked. So we should set the sequence mask as `True, True, True, False`, i.e., `seq_mask=batch_non_pad_mask[:, 1:]`. +The same logic applies to the attention mask and event type mask. + +Therefore the following code is a typical example of calling the loglikelihood computation: + + +.. code-block:: python + + event_ll, non_event_ll, num_events = self.compute_loglikelihood(lambda_at_event=lambda_at_event, # seq_len = max_len - 1 + lambdas_loss_samples=lambda_t_sample, # seq_len = max_len - 1 + time_delta_seq=time_delta_seq[:, 1:], + seq_mask=batch_non_pad_mask[:, 1:], + lambda_type_mask=type_mask[:, 1:]) + + + +Computing the integral inside the loglikelihood +----------------------------------------------- + + +The loglikelihood of the parameters is the sum of the log-intensities of the events that happened, at the times they happened, +minus an integral of the total intensities over the observation interval over [0,T]: + +.. math:: + + \sum_{t_i}\log \lambda_{k_i}(t_i) - \int_0^T \lambda(t) dt + +The first term refers to event loglikelihood and the second term (including the negative sign) refers to the non-event loglikelihood. + + + + + + +Neural Hawkes Process (NHP) +=================================== + +We implement NHP based on author's official pytorch code `Github:nce-mpp `_. + +1. A continuous-time LSTM is introduced, with the code mainly come from `Github:nce-mpp `_. +2. A `forward` function in NHP class that recursively update the states: we compute the event embedding, pass to the LSTM cell and then decay afterwards. Noted that for i-th event, we should use (i+1)-th dt for the decay. So we do not consider the last event as it has no decay time. + +Attentive Neural Hawkes Process (AttNHP) +======================================== + + +We implement AttNHP based on the authors' official pytorch code `Github:anhp-andtt `_ +and similar to NHP, we factorize it into based model and inherited model. + +The forward functions is implemented faithfully to that of the author's repo. + + +Transformer Hawkes Process (THP) +======================================== + +We implement THP based on a fixed version of pytorch code `Github:anhp-andtt/thp `_ +and we factorize it into based model and inherited model. + + +Self-Attentive Hawkes Process (SAHP) +======================================== + +We implement SAHP based on a fixed version of pytorch code `Github:anhp-andtt/sahp `_ +and we factorize it into based model and inherited model. + +`SAHP` basically shares very similar structure to that of `THP`. + + + +Recurrent Marked Temporal Point Processes (RMTPP) +==================================================== + +We implement RMTPP faithfully to the author's paper. + + +Intensity Free Learning of Temporal Point Process (IntensityFree) +================================================================== + +We implement the model based on the author's torch code `Github:ifl-tpp `_. + +A small difference between our implementation and the author's is we ignore the `context_init` (the initial state of the RNN) because in our data setup, we do not need a learnable initial RNN state. This modification generally makes little impact on the learning process. + +It is worth noting that the thinning algorithm can not be applied to this model because it is intensity-free. When comparing the performance of the model, we only look at its log-likelihood learning curve. + + +Fully Neural Network based Model for General Temporal Point Processes (FullyNN) +=============================================================================== + +We implement the model based on the author's keras code `Github:NeuralNetworkPointProcess `_. + + +ODE-based Temporal Point Process (ODETPP) +========================================= + +We implement a TPP with Neural ODE state evolution, which is a simplified version of `Neural Spatio-Temporal Point Processes `_. The ODE implementation uses the code from the `blog `_ + + +Attentive Neural Hawkes Network (ANHN) +====================================== + +We implement the model based on the author's paper: the attentive model without the graph regularizer is named ANHN. diff --git a/docs/source/advanced/performance_valid.rst b/docs/source/advanced/performance_valid.rst new file mode 100644 index 0000000000000000000000000000000000000000..522e10e5cfcd7175611fb4c9e7bcb5a5577d2e62 --- /dev/null +++ b/docs/source/advanced/performance_valid.rst @@ -0,0 +1,41 @@ +========================================= +Performance validation of EasyTPP models +========================================= + +We run the experiments on various dataset to validate the implementations: each model is trained with a max number of epochs and +the best model is selected based on the performance on the valid set, then we report the results on the test set. + + +Simulated dataset +--------------------------- +Conttime +********************** + + + ++--------------+----------+----------+----------+--------------------+ +| Models | Loglike | RMSE | Acc | Num Training Epochs| ++==============+==========+==========+==========+====================+ +| Torch_NHP | -0.93504 | 0.34000 | 0.38656 | 200 | ++--------------+----------+----------+----------+--------------------+ +| Tf_NHP | -0.85774 | 0.34014 | 0.38806 | 200 | ++--------------+----------+----------+----------+--------------------+ +| Torch_AttNHP | -1.02001 | 0.33678 | 0.36782 | 200 | ++--------------+----------+----------+----------+--------------------+ +| Tf_AttNHP | -1.02315 | 0.33816 | 0.19456 | 200 | ++--------------+----------+----------+----------+--------------------+ +| Torch_AttNHP | -1.00593 | 0.33685 | 0.37723 | 500 | ++--------------+----------+----------+----------+--------------------+ +| Tf_AttNHP | -0.99827 | 0.33717 | 0.36498 | 500 | ++--------------+----------+----------+----------+--------------------+ +| Torch_THP | -0.99827 | 0.33717 | 0.36498 | 500 | ++--------------+----------+----------+----------+--------------------+ +| Tf_THP | -1.01898 | 0.33677 | 0.37875 | 500 | ++--------------+----------+----------+----------+--------------------+ + + + +## Real dataset +### Taxi + + diff --git a/docs/source/advanced/tensorboard.rst b/docs/source/advanced/tensorboard.rst new file mode 100644 index 0000000000000000000000000000000000000000..05349adcb665188d1a31c70bd4c2845840eeb5be --- /dev/null +++ b/docs/source/advanced/tensorboard.rst @@ -0,0 +1,75 @@ +=================================== +Launching the Tensorboard +=================================== + + +Here we present how to launch the tensorboard within the ``EasyTPP`` framework. + +Step 1: Activate the usage of tensorboard in Config file +======================================================== + + +As shown in `Training Pipeline <../get_started/run_train_pipeline.html>`_, we need to firstly initialize the 'model_config.yaml' file to setup the running config before training or evaluating the model. + +In the ``model config`` (`modeling` attribute of the config), one needs to set ``use_tfb`` to ``True`` in `trainer`. Then before the running process, summary writers tracking the performance on training and valid sets are both initialized. + +.. code-block:: yaml + + NHP_train: + base_config: + stage: train + backend: torch + dataset_id: taxi + runner_id: std_tpp + model_id: NHP # model name + base_dir: './checkpoints/' + trainer_config: + batch_size: 256 + max_epoch: 200 + shuffle: False + optimizer: adam + learning_rate: 1.e-3 + valid_freq: 1 + use_tfb: True # Activate the tensorboard + metrics: [ 'acc', 'rmse' ] + seed: 2019 + gpu: -1 + model_config: + hidden_size: 64 + loss_integral_num_sample_per_step: 20 + # pretrained_model_dir: ./checkpoints/75518_4377527680_230530-132355/models/saved_model + thinning: + num_seq: 10 + num_sample: 1 + num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm + look_ahead_time: 10 + patience_counter: 5 # the maximum iteration used in adaptive thinning + over_sample_rate: 5 + num_samples_boundary: 5 + dtime_max: 5 + num_step_gen: 1 + + + +Step 2: Launching the tensorboard +======================================================== + + +We simply go to the output file of the training runner (its directory is specified in `base_dir` of ``base_config``), find out the tensorboard file address and launch it. + +A complete example of using tensorboard can be seen at *examples/run_tensorboard.py*. + + +.. code-block:: python + + import os + + def main(): + # one can find this dir in the config out file + log_dir = './checkpoints/NHP_train_taxi_20220527-20:18:30/tfb_train' + os.system('tensorboard --logdir={}'.format(log_dir)) + return + + + if __name__ == '__main__': + main() \ No newline at end of file diff --git a/docs/source/advanced/thinning_algo.rst b/docs/source/advanced/thinning_algo.rst new file mode 100644 index 0000000000000000000000000000000000000000..79d5e8841df7469e29d958c026f7b584ecd65356 --- /dev/null +++ b/docs/source/advanced/thinning_algo.rst @@ -0,0 +1,56 @@ +============================================== +Thinning Algorithm for Sampling Event Sequence +============================================== + +In ``EasyTPP`` we use ``Thinning algorithm`` depicted in Algorithm 2 +in `The Neural Hawkes Process: A Neurally Self-Modulating Multivariate Point Process `_ +for event sampling. + +The implementation of the algorithm +==================================== + + +We implement the algorithm both in PyTorch and Tensorflow, as seen in *./model/torch_thinning.py* and +*./model/tf_thinning.py*, which basically follow the same procedure. + +The corresponding code is in function ``draw_next_time_one_step``, which consists of the following steps: + +1. Compute the upper bound of the intensity at each event timestamp in function ``compute_intensity_upper_bound``, where we sample some timestamps inside event intervals and output a upper bound intensity matrix [batch_size, seq_len], denoting the upper bound of prediced intensity (for next time interval) for each sequence at each timestamp. +2. Sample the exponential distribution with the intensity computed in Step 1 in function ``sample_exp_distribution``, where we simply divide the standard exponential number with the intensity, which is equivalent to sampling with exp(sample_rate), according to `the properties of Exponential Distribution `_. The exponential random variables have size [batch_size, seq_len, num_sample, num_exp], where num_sample refers to the number of event times sampled in every interval and num_exp refers to number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm. +3. Compute the intensities at the sample times proposed in Step 2, with final size `[batch_size, seq_len, num_sample, num_exp]`. +4. Sample the standard uniform distribution with size `[batch_size, seq_len, num_sample, num_exp]`. +5. Perform the acceptance sampling with certain probability in function ``sample_accept``. +6. The earliest sampling dtimes are accepted. For unaccepted sampling dtimes, use boundary/maxsampletime for that draw. +7. The final predicted dtimes has size `[batch_size, seq_len, num_sample]`, which refers to the sampling dtimes for each sequence at each timestamp, along with an equal weight vector. +8. The product of the predicted dtimes and the weight is the final predicted dtimes, with size `[batch_size, seq_len]`. + + +.. image:: ../../images/thinning_algo.jpg + :alt: thinning_algo + + + +One-step prediction +==================================== +By default, once given the parameters of thinning algo (defining a ``thinning`` config as part of ``model_config``), we perform the one-step prediction in model evaluation, i.e., predict the next event given the prefix. The implementation is in function ``prediction_event_one_step`` in BaseModel (i.e., TorchBaseModel or TfBaseModel). + + +Multi-step prediction +==================================== +The recursive multi-step prediction is activated by setting `num_step_gen` to a number bigger than 1 in the ``thinning`` config. + +Be noted that, we generate the multi-step events after the last non-pad event of each sequence. The implementation is in function `predict_multi_step_since_last_event` in BaseModel (i.e., TorchBaseModel or TfBaseModel). + + +.. code-block:: yaml + + thinning: + num_seq: 10 + num_sample: 1 + num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm + look_ahead_time: 10 + patience_counter: 5 # the maximum iteration used in adaptive thinning + over_sample_rate: 5 + num_samples_boundary: 5 + dtime_max: 5 + num_step_gen: 5 # by default it is single step, i.e., 1 \ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py new file mode 100644 index 0000000000000000000000000000000000000000..201cf041e73119c0c103b498348e85d16440eedf --- /dev/null +++ b/docs/source/conf.py @@ -0,0 +1,59 @@ +# Configuration file for the Sphinx documentation builder. +# +# For the full list of built-in configuration values, see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + + +# -- Autodoc information ----------------------------------------------------- +# https://sphinx-rtd-tutorial.readthedocs.io/en/latest/sphinx-config.html + + +import os +import sys + +sys.path.insert(0, os.path.abspath('../../easy_tpp/')) + +sys.path.insert(0, os.path.abspath('../..')) + +# -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information + +project = 'EasyTPP' +copyright = '2022, Machine Intelligence, Alipay' +author = 'Machine Intelligence, Alipay' +release = '0.0.2' + +# -- General configuration --------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration + +extensions = [ + "sphinx.ext.autodoc", + 'sphinx.ext.viewcode', + "sphinx.ext.todo", + "sphinx.ext.mathjax", + "sphinx.ext.napoleon", + 'sphinx.ext.autosummary' +] + +napoleon_google_docstring = True +napoleon_numpy_docstring = False + +templates_path = ['_templates'] +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This patterns also effect to html_static_path and html_extra_path +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] + +# -- Options for HTML output ------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output + +html_theme = 'sphinx_rtd_theme' +html_static_path = ['_static'] + +autodoc_member_order = "bysource" +autodoc_default_flags = ["members"] +autodoc_default_options = { + "members": True, + "member-order": "bysource", + "special-members": "__init__", +} diff --git a/docs/source/dev_guide/model_custom.rst b/docs/source/dev_guide/model_custom.rst new file mode 100644 index 0000000000000000000000000000000000000000..0f81445f4d27778145e9fe533c8a4a7aec0a0856 --- /dev/null +++ b/docs/source/dev_guide/model_custom.rst @@ -0,0 +1,78 @@ +================== +Customize a Model +================== + + +Here we introduce how to customize a TPP model with the support of ``EasyTPP``. + + + +Create a new TPP Model Class +============================= + +Assume we are building a PyTorch model. We need to initialize the model by inheriting class `EasyTPP.model.torch_model.TorchBaseModel <../ref/models.html>`_. + +.. code-block:: python + + from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel + + # Custom Torch TPP implementations need to + # inherit from the TorchBaseModel interface + class NewModel(TorchBaseModel): + def __init__(self, model_config): + super(NewModel, self).__init__(model_config) + + # Forward along the sequence, output the states / intensities at the event times + def forward(self, batch): + ... + return states + + # Compute the loglikelihood loss + def loglike_loss(self, batch): + .... + return loglike + + # Compute the intensities at given sampling times + # Used in the Thinning sampler + def compute_intensities_at_sample_times(self, batch, sample_times, **kwargs): + ... + return intensities + + +If we are building a Tensorflow model, we start with the following code + +.. code-block:: python + + from easy_tpp.model.torch_model.tf_basemodel import TfBaseModel + + # Custom Tf TPP implementations need to + # inherit from the TorchBaseModel interface + class NewModel(TfBaseModel): + def __init__(self, model_config): + super(NewModel, self).__init__(model_config) + + # Forward along the sequence, output the states / intensities at the event times + def forward(self, batch): + ... + return states + + + # Compute the loglikelihood loss + def loglike_loss(self, batch): + .... + return loglike + + # Compute the intensities at given sampling times + # Used in the Thinning sampler + def compute_intensities_at_sample_times(self, batch, sample_times, **kwargs): + ... + return intensities + +Rewrite Relevant Methods +============================== + +There are three important functions needed to be implemented: + +- `forward`: the input is the batch data and the output is states at each step. +- `loglike_loss`: it computes the loglikihood loss given the batch data. +- `compute_intensities_at_sample_times`: it computes the intensities at each sampling steps. diff --git a/docs/source/get_started/install.rst b/docs/source/get_started/install.rst new file mode 100644 index 0000000000000000000000000000000000000000..f392ce629bd5c6cd19168eecf45296c2c2c58ffd --- /dev/null +++ b/docs/source/get_started/install.rst @@ -0,0 +1,64 @@ +================== +Installation +================== + + +``EasyTPP`` provides an open-source library for `Neural TPP`, with a fully automated pipeline for model training and prediction. + + +Requirements +============= + +.. code-block:: bash + + PyTorch version >= 1.8.0 + Python version >= 3.7 + Tensorflow version >= 1.13.1 (only needed when using Tensorflow backend) + + + +First, we need a python environment whose version is at least greater than 3.7.0. If you don’t have one, please refer to the `Documentation `_ to install and configure the Anaconda environment. + +.. code-block:: bash + + conda create -n easytpp python=3.8 + conda activate easytpp + +Then, install Pytorch and keep the version at least greater than 1.8.0. + +.. code-block:: bash + + pip install torch + +By default, we assume to use PyTorch. If one wants to use Tensorflow backend, please install tensorflow additionally. Both Tensorflow 1.13.1 and 2.x are supported. + +.. code-block:: bash + + pip install tensorflow + + + +Install +===================== + + +Install with pip +-------------------------- + + +.. code-block:: bash + + pip install easy-tpp + + +Install with the source +-------------------------- + +Setup from the source: + +.. code-block:: bash + + git clone https://github.com/ant-research/EasyTemporalPointProcess.git + cd EasyTemporalPointProcess + python setup.py install + diff --git a/docs/source/get_started/introduction.rst b/docs/source/get_started/introduction.rst new file mode 100644 index 0000000000000000000000000000000000000000..c1eace20ebc883de8669e18d9d9b0324ec2fc152 --- /dev/null +++ b/docs/source/get_started/introduction.rst @@ -0,0 +1,60 @@ +================== +Introduction +================== + + +``EasyTPP`` provides an open-source library for `Neural TPP`, with a fully automated pipeline for model training and prediction. + + +Framework +========= + + +``EasyTPP`` supports both Tensorflow and PyTorch: each model has two equivalent versions implemented in Tensorflow 1.13 and Pytorch 1.8 respectively. The data processing and model training / prediction pipeline are compatible with both Tensorflow and Pytorch as well. + + +At the module level, ``EasyTPP`` is a package that consists of the following components, which are designed as loose-coupled modules that provide flexibility for users to develop customized functionalities. + + + +======================== ============================================================================== +Name Description +======================== ============================================================================== +`Preprocess` module Provides data batch-wise padding, inter-time processing and other related work for raw sequence. + +`Model` module Implements a list of SOTA TPP models. Please refer to `Model Validation <../advanced/performance_valid.html>`_ for more details. + +`Config` module Encapsulate the construction of the configuration needed to run the pipeline. + +`Runner` module Controls the training and prediction pipeline. +======================== ============================================================================== + + + +Install +========= + +``EasyTPP`` can be installed either by pip or the source. By default it is built based on PyTorch. If one wants to run with the Tensorflow backend, one needs to install Tensorflow additionally. + +Please see `Installation <./install.html>`_ for details of requirement and installation. + + +Prepare Data +============ + +By default, we use the data in Gatech format, i.e., each dataset is a dict containing the keys such as `time_since_last_event`, `time_since_start` and `type_event`. `Preprocess <../ref/preprocess.html>`_ module +will preprocess the data and feed it into the model. + + +An example of building a pseudo dataloader can be found at `examples `_. Please refer to `Datatset <../user_guide/dataset.html>`_ for more explanations of the `TPP` dataset iterator. + + +Model Training and Prediction +============================== + +The training and prediction pipeline consists of two steps: + +1. Setup the config file, which specifies the dataset dir, model params and pipeline settings. +2. Launch the python script to run the whole pipeline. + +Please see `Training Pipeline <../user_guide/run_train_pipeline.html>`_ and `Evaluation Pipeline <../user_guide/run_eval.html>`_ for more details. \ No newline at end of file diff --git a/docs/source/get_started/quick_start.rst b/docs/source/get_started/quick_start.rst new file mode 100644 index 0000000000000000000000000000000000000000..10778aa49b43bd6489fd029814824ed7ffd121a9 --- /dev/null +++ b/docs/source/get_started/quick_start.rst @@ -0,0 +1,106 @@ +==================== +Quick Start +==================== + + +We use the [Taxi]_ dataset as an example to show how to use ``EasyTPP`` to train a model. More details and results are provided in `Training Pipeline <../user_guide/run_train_pipeline.html>`_. + + +Download Dataset +=================== + + + +The Taxi dataset we used is preprocessed by `HYPRO `_ . You can either download the dataset (in pickle) from Google Drive `here `_ or the dataset (in json) from `HuggingFace `_. + + +Note that if the data sources are pickle files, we need to write the data config (in `Example Config `_) in the following way + +.. code-block:: yaml + + data: + taxi: + data_format: pickle + train_dir: ./data/taxi/train.pkl + valid_dir: ./data/taxi/dev.pkl + test_dir: ./data/taxi/test.pkl + +If we choose to directly load from HuggingFace, we can put it this way: + +.. code-block:: yaml + + data: + taxi: + data_format: json + train_dir: easytpp/taxi + valid_dir: easytpp/taxi + test_dir: easytpp/taxi + + +Meanwhile, it is also feasible to put the local directory of json files downloaded from HuggingFace in the config: + +.. code-block:: yaml + + data: + taxi: + data_format: json + train_dir: ./data/taxi/train.json + valid_dir: ./data/taxi/dev.json + test_dir: ./data/taxi/test.json + + + + +Setup the configuration file +============================== + +We provide a preset config file in `Example Config `_. The details of the configuration can be found in `Training Pipeline <../user_guide/run_train_pipeline.html>`_. + + + + +Train the Model +========================= + +At this stage we need to write a script to run the training pipeline. There is a preset script `train_nhp.py `_ and one can simply copy it. + +Taking the pickle data source for example, after the setup of data, config and running script, the directory structure is as follows: + +.. code-block:: bash + + data + |______taxi + |____ train.pkl + |____ dev.pkl + |____ test.pkl + + configs + |______experiment_config.yaml + + train_nhp.py + + + +The one can simply run the following command. + + +.. code-block:: bash + + python train_nhp.py + + + +Reference +---------- + +.. [Taxi] + +.. code-block:: bash + + @misc{whong-14-taxi, + title = {F{OIL}ing {NYC}’s Taxi Trip Data}, + author={Whong, Chris}, + year = {2014}, + url = {https://chriswhong.com/open-data/foil_nyc_taxi/} + } + diff --git a/docs/source/index.rst b/docs/source/index.rst new file mode 100644 index 0000000000000000000000000000000000000000..c4b2d6ac04a2225a74b68ba9f5c958168d7bf2ad --- /dev/null +++ b/docs/source/index.rst @@ -0,0 +1,56 @@ +=================================== +``EasyTPP`` Documentation +=================================== + + +``EasyTPP`` is an easy-to-use development and application toolkit for `Neural Temporal Point Process `_ (*Neural TPP*), with key features in configurability, compatibility and reproducibility. We hope this project could benefit both researchers and practitioners with the goal of easily customized development and open benchmarking. + + + +.. toctree:: + :hidden: + +.. toctree:: + :maxdepth: 2 + :caption: GETTING STARTED + + Introduction + Installation + Quick Start + + +.. toctree:: + :maxdepth: 2 + :caption: USER GUIDE + + Dataset + Model Training + Model Prediction + +.. toctree:: + :maxdepth: 2 + :caption: DEVELOPER GUIDE + + Model Customization + + +.. toctree:: + :maxdepth: 2 + :caption: ADVANCED TOPICS + + Thinning Algorithm + Tensorboard + Performance Benchmarks + Implementation Details + +.. toctree:: + :maxdepth: 2 + :caption: API REFERENCE + + Config + Preprocess + Model + Runner + Hyper-parameter Optimization + Tf and Torch Wrapper + Utilities \ No newline at end of file diff --git a/docs/source/ref/config.rst b/docs/source/ref/config.rst new file mode 100644 index 0000000000000000000000000000000000000000..2ce82bba178bfaffe5ad2870b8bb959535b11149 --- /dev/null +++ b/docs/source/ref/config.rst @@ -0,0 +1,10 @@ +.. _api-config: + +EasyTPP Config Modules +============================ + + +.. automodule:: config_factory + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/ref/hpo.rst b/docs/source/ref/hpo.rst new file mode 100644 index 0000000000000000000000000000000000000000..c78288568fefd1bc108a78a9fc5f7d78ef1aba42 --- /dev/null +++ b/docs/source/ref/hpo.rst @@ -0,0 +1,10 @@ +.. _api-config: + +EasyTPP Config Modules +============================ + + +.. automodule:: hpo + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/ref/models.rst b/docs/source/ref/models.rst new file mode 100644 index 0000000000000000000000000000000000000000..1c2d3f26466583428e2f4cb96cedc41657cc8531 --- /dev/null +++ b/docs/source/ref/models.rst @@ -0,0 +1,50 @@ +.. _api-model: + +EasyTPP Models +==================== + + + +.. _api-tf_model: + +model.tf_model module +------------------------------ + +.. automodule:: easy_tpp.model.tf_model +.. autosummary:: + :toctree: ../generated/ + + tf_baselayer + tf_basemodel + tf_nhp + tf_fullynn + tf_intensity_free + tf_ode_tpp + tf_rmtpp + tf_sahp + tf_thp + tf_attnhp + tf_thinning + + +.. _api-torch_model: + +model.torch_model module +------------------------------ + +.. automodule:: easy_tpp.model.torch_model +.. autosummary:: + :toctree: ../generated/ + + torch_baselayer + torch_basemodel + torch_nhp + torch_fullynn + torch_intensity_free + torch_ode_tpp + torch_rmtpp + torch_sahp + torch_thp + torch_attnhp + torch_thinning + diff --git a/docs/source/ref/preprocess.rst b/docs/source/ref/preprocess.rst new file mode 100644 index 0000000000000000000000000000000000000000..c4d4347abb9bf10717fb10158da4fca510d48b4d --- /dev/null +++ b/docs/source/ref/preprocess.rst @@ -0,0 +1,10 @@ +.. _api-preprocess: + +EasyTPP Preprocess Modules +========================== + + +.. automodule:: preprocess + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/ref/runner.rst b/docs/source/ref/runner.rst new file mode 100644 index 0000000000000000000000000000000000000000..e9fec2ab020869d413ae8c7bd9773689605cc3ca --- /dev/null +++ b/docs/source/ref/runner.rst @@ -0,0 +1,10 @@ +.. _api-modelrunner: + +EasyTPP Model Runner Modules +============================ + + +.. automodule:: runner + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/ref/utils.rst b/docs/source/ref/utils.rst new file mode 100644 index 0000000000000000000000000000000000000000..73f4d5b37c223db91cca5fa773cd08e373ca6f7a --- /dev/null +++ b/docs/source/ref/utils.rst @@ -0,0 +1,10 @@ +.. _api-util: + +EasyTPP Utilities Modules +========================== + + +.. automodule:: utils + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/ref/wrapper.rst b/docs/source/ref/wrapper.rst new file mode 100644 index 0000000000000000000000000000000000000000..bb54f44185df12121775e5edb2aab5ee360c62c4 --- /dev/null +++ b/docs/source/ref/wrapper.rst @@ -0,0 +1,17 @@ +.. _api-wrapper: + +EasyTPP Tf and Torch Wrapper Modules +==================================== + + +.. automodule:: tf_wrapper + :members: + :undoc-members: + :show-inheritance: + + + +.. automodule:: torch_wrapper + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/user_guide/dataset.rst b/docs/source/user_guide/dataset.rst new file mode 100644 index 0000000000000000000000000000000000000000..8a1aece91a8b46b42c702832d0c51278a535e134 --- /dev/null +++ b/docs/source/user_guide/dataset.rst @@ -0,0 +1,124 @@ +=========================================== +Expected Dataset Format and Data Processing +=========================================== + +Required format +=================================== + +In EasyTPP we use the data in Gatech format, i.e., each dataset is a dict containing the following keys as + +.. code-block:: bash + + dim_process: 5 # num of event types (no padding) + 'train': [[{'idx_event': 2, 'time_since_last_event': 1.0267814, 'time_since_last_same_event': 1.0267814, 'type_event': 3, 'time_since_start': 1.0267814}, {'idx_event': 3, 'time_since_last_event': 0.4029268, 'time_since_last_same_event': 1.4297082, 'type_event': 0, 'time_since_start': 1.4297082},...,],[{}...{}]] + +where `dim_process` refers to the number of event types (without padding) and +`train` (or `dev` / `test`) contains a list of list which corresponds to an event sequence each. + +Each pickle file generates a set of event sequences, each containing three sub sequences: + +1. `time_seqs`: absolute timestamps of the events, correspond to `time_since_last_event`. +2. `time_delta_seqs`: relative timestamps of the events, correspond to `time_since_last_same_event`. +3. `type_seqs`: types of the events, correspond to `type_event`. Be noted that the event type index `starts from 0`. + + +Data processing +=================================== + +The data processing follows the similar pipeline as in official code of `AttNHP `_. We name it the process of `event tokenize`. + + +Sequence padding +---------------- + + +time_seqs, time_delta_seqs and type_seqs are firstly padded to `the max length of the whole dataset` and then fed into the model in batch. + +.. code-block:: bash + + input: raw event sequence (e_0, e_1, e_2, e_3) and max_len=6 # the max length among all data seqs + + output: + + index: 0, 1, 2, 3, 4 5 + dtimes: 0, t_1-t_0, t_2-t_1, t_3-t_2, time_pad, time_pad + types: e_0, e_1, e_2, e_3, type_pad, type_pad + + +By default, we set the value of time_pad and type_pad to be the *num_event_types* (because we assume the event type index starts from 0, therefore the integer value of num_event_types is unused). + +Sequence masking +---------------- + + +After padding, we perform the masking for the event sequences and generate three more seqs: batch_non_pad_mask, attention_mask, type_mask: + +1. `batch_non_pad_mask`: it indicates the position of masks in the sequence. +2. `attention_mask`: it indicates the masks used in the attention calculation (one event can only attend to its past events). +3. `type_mask`: it uses one-hot vector to represent the event type. The padded event is a zero vector. + +Finally, each batch contains six elements: time_seqs, time_delta_seqs, event_seq, batch_non_pad_mask, attention_mask, type_mask. The implementation of padding mechanism can be found at `event_tokenizer `_. + + + +An example +---------------- + +We take a real event sequence for example. Assume we have an input sequence $[ 1, 9, 5, 0]$ with num_event_types=11 and max_len=6. + +Then the padded time_seqs, time_delta_seqs and type_seqs become + +.. code-block:: bash + + # time_seqs + [ 0.0000, 0.8252, 1.3806, 1.8349, 11.0000, 11.0000] + + # time_delta_seqs + [ 0.0000, 0.8252, 0.5554, 0.4542, 11.0000, 11.0000] + + # type_seqs + [ 1, 9, 5, 0, 11, 11] + + +The mask sequences are + +.. code-block:: bash + + # batch_non_pad_mask + [ True, True, True, True, False, False] + + # attention_mask + [[True, True, True, True, True, True], + [False, True, True, True, True, True], + [False, False, True, True, True, True], + [False, False, False, True, True, True], + [False, False, False, False, True, True], + [False, False, False, False, True, True]] + + # type_mask + [[False, True, False, False, False, False, False, False, False, False, False], + [False, False, False, False, False, False, False, False, False, True, False], + [False, False, False, False, False, True, False, False, False, False, False], + [True, False, False, False, False, False, False, False, False, False, False], + [False, False, False, False, False, False, False, False, False, False, False], + [False, False, False, False, False, False, False, False, False, False, False]], + + +The runnable examples of constructing and iterating the dataset object can be found at `examples/event_tokenizer.py `_ + + +Preprocessed Datasets +=================================== + +We have preprocessed some widely-used open source datasets in Gatech format, which can be found at `Google Drive `_. We use them for validating and benchmarking EasyTPP models. + +- Retweet (`Zhou, 2013 `_). This dataset contains time-stamped user retweet event sequences. The events are categorized into 3 types: retweets by “small,” “medium” and “large” users. Small users have fewer than 120 followers, medium users have fewer than 1363, and the rest are large users. We work on a subset of 5200 most active users with an average sequence length of 70. +- Taxi (`Whong, 2014 `_). This dataset tracks the time-stamped taxi pick-up and drop-off events across the five boroughs of the New York City; each (borough, pick-up or drop-off) combination defines an event type, so there are 10 event types in total. We work on a randomly sampled subset of 2000 drivers and each driver has a sequence. We randomly sampled disjoint train, dev and test sets with 1400, 200 and 400 sequences. +- StackOverflow ( `Leskovec, 2014 `_). This dataset has two years of user awards on a question-answering website: each user received a sequence of badges and there are 22 different kinds of badges in total. We randomly sampled disjoint train, dev and test sets with 1400,400 and 400 sequences from the dataset. +- Taobao (`Xue et al, 2022 `_). This dataset contains time-stamped user click behaviors on Taobao shopping pages from November 25 to December 03, 2017. Each user has a sequence of item click events with each event containing the timestamp and the category of the item. The categories of all items are first ranked by frequencies and the top 19 are kept while the rest are merged into one category, with each category corresponding to an event type. We work on a subset of 4800 most active users with an average sequence length of 150 and then end up with 20 event types. +- Amazon (`Xue et al, 2022 `_). This dataset includes time-stamped user product reviews behavior from January, 2008 to October, 2018. Each user has a sequence of produce review events with each event containing the timestamp and category of the reviewed product, with each category corresponding to an event type. We work on a subset of 5200 most active users with an average sequence length of 70 and then end up with 16 event types. + +Besides, we also published two textual event sequence datasets: + +- GDELT (`Shi et al, 2023 `_). The GDELT Project monitors events all over the world, with live datasets updated every 15 minutes. We only focused on the political events that happened in G20 countries from 2022-01-01 to 2022-07-31, ending up with a corpus of 109000 time-stamped event tokens. The event type of each token has a structured name of the format subject-predicate-object. Each {predicate} is one of the twenty CAMEO codes such as {CONSULT} and {INVESTIGATE}; each {subject} or {object} is one of the 2279 political entities (individuals, groups, and states) such as {Tesla} and {Australia}. We split the dataset into disjoint train, dev, and test sets based on their dates: the 83100 events that happened before 2022-07-05 are training data; the 16650 events after 2022-07-19 are test data; the 9250 events between these dates are development data. +- Amazon-text-review (`Shi et al, 2023 `_). This dataset contains user reviews on Amazon shopping website from 2014-01-04 to 2016-10-02. We focused on the most active 2500 users and each user has a sequence of product review events. The type is the category of the product: we selected the most frequently-reviewed 23 categories and grouped all the others into a special OTHER category, ending up with 24 categories in total. Each review event also has a mark which is the actual content of the review. Each of the 2500 sequences is cut into three segments: the events that happened before 2015-08-01 are training data; those after 2016-02-01 are test data; the events between these dates are dev data. Then we have 49,680 training tokens, 7,020 dev tokens, and 13,090 test tokens. diff --git a/docs/source/user_guide/run_eval.rst b/docs/source/user_guide/run_eval.rst new file mode 100644 index 0000000000000000000000000000000000000000..bd556e6502d796022f59d4fd287e7e0b83d81b02 --- /dev/null +++ b/docs/source/user_guide/run_eval.rst @@ -0,0 +1,97 @@ +================================ +Evaluate a Model +================================ + +Step 1: Setup the config file +=============================================== + +Same as in the training pipeline, firstly we need to initialize the task configuration in the config file. + +Similar to the setup in `Training Pipeline <./run_train_pipeline.html>`_, we set the `stage` to `eval` and pass the `pretrained_model_dir` to ``the model_config`` + +Note that the *pretrained_model_dir* can be found in the log of the training process. + +.. code-block:: yaml + + NHP_eval: + base_config: + stage: eval + backend: torch + dataset_id: taxi + runner_id: std_tpp + base_dir: './checkpoints/' + model_id: NHP + trainer_config: + batch_size: 256 + max_epoch: 1 + model_config: + hidden_size: 64 + use_ln: False + seed: 2019 + gpu: 0 + pretrained_model_dir: ./checkpoints/26507_4380788096_231111-101848/models/saved_model # must provide this dir + thinning: + num_seq: 10 + num_sample: 1 + num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm + look_ahead_time: 10 + patience_counter: 5 # the maximum iteration used in adaptive thinning + over_sample_rate: 5 + num_samples_boundary: 5 + dtime_max: 5 + + + + +A complete example of these files can be seen at `examples/example_config.yaml `_ . + + +Step 2: Run the evaluation script +================================= + +Same as in the training pipeline, we need to initialize a ``ModelRunner`` object to do the evaluation. + +The following code is an example, which is a copy from `examples/train_nhp.py `_ . + + +.. code-block:: python + + import argparse + + from easy_tpp.config_factory import RunnerConfig + from easy_tpp.runner import Runner + + + def main(): + parser = argparse.ArgumentParser() + + parser.add_argument('--config_dir', type=str, required=False, default='configs/experiment_config.yaml', + help='Dir of configuration yaml to train and evaluate the model.') + + parser.add_argument('--experiment_id', type=str, required=False, default='RMTPP_eval', + help='Experiment id in the config file.') + + args = parser.parse_args() + + config = RunnerConfig.build_from_yaml_file(args.config_dir, experiment_id=args.experiment_id) + + model_runner = Runner.build_from_config(config) + + model_runner.run() + + + if __name__ == '__main__': + main() + + + + +Checkout the output +==================== + +The evaluation result will be print in the console and saved in the logs whose directory is specified in the +out config file, i.e.: + +.. code-block:: bash + + 'output_config_dir': './checkpoints/NHP_test_conttime_20221002-13:19:23/NHP_test_output.yaml' diff --git a/docs/source/user_guide/run_train_pipeline.rst b/docs/source/user_guide/run_train_pipeline.rst new file mode 100644 index 0000000000000000000000000000000000000000..9b44876d17586aa9358595382a0a76005c312ff8 --- /dev/null +++ b/docs/source/user_guide/run_train_pipeline.rst @@ -0,0 +1,245 @@ +============================================ +Training a Model & Configuration Explanation +============================================ + +This tutorial shows how one can use ``EasyTPP`` to train the implemented models. + +In principle, firstly we need to initialize a config yaml file, containing all the input configuration to guide the training and eval process. The overall structure of a config file is shown as below: + +.. code-block:: yaml + + pipeline_config_id: .. # name of the config for guiding the pipeline + + data: + [Dataset ID]: # name of the dataset, e.g, taxi + .... + + [EXPERIMENT ID]: # name of the experiment to run + base_config: + .... + model_config: + ... + + +After the config file is setup, we can run the script, by specifying the `config directory` and `experiment id`, to start the pipeline. We currently provide a preset script at `examples/train_nhp.py`. + + +Step 1: Setup the config file containing data and model configs +================================================================ + + +To be specific, one needs to define the following entries in the config file: + +- **pipeline_config_id**: registered name of EasyTPP.Config objects, such as `runner_config` or `hpo_runner_config`. By reading this, the corresponding configuration class will be loaded for constructing the pipeline. + +.. code-block:: yaml + + pipeline_config_id: runner_config + + +- **data**: dataset specifics. One can put multiple dataset specifics in the config file, but only one will be used in one experiment. + + - *[DATASET ID]*: name of the dataset, e.g., taxi. + - *train_dir, valid_dir, test_dir*: directory of the datafile. For the moment we only accept pkl file (please see `Dataset <./dataset.html>`_ for details) + - *data_spec*: define the event type information. + +.. code-block:: yaml + + data: + taxi: + data_format: pkl + train_dir: ../data/taxi/train.pkl + valid_dir: ../data/taxi/dev.pkl + test_dir: ../data/taxi/test.pkl + data_spec: + num_event_types: 7 # num of types excluding pad events. + pad_token_id: 6 # event type index for pad events + padding_side: right # pad at the right end of the sequence + truncation_side: right # truncate at the right end of the sequence + max_len: 100 # max sequence length used as model input + +- **[EXPERIMENT ID]**: name of the experiment to run in the pipeline. It contains two blocks of configs: + +*base_config* contains the pipeline framework related specifications. + +.. code-block:: yaml + + base_config: + stage: train # train, eval and generate + backend: tensorflow # tensorflow and torch + dataset_id: conttime # name of the dataset + runner_id: std_tpp # registered name of the pipeline runner + model_id: RMTPP # model name # registered name of the implemented model + base_dir: './checkpoints/' # base dir to save the logs and models. + + + +*model_config* contains the model related specifications. + + +.. code-block:: yaml + + model_config: + hidden_size: 32 + time_emb_size: 16 + num_layers: 2 + num_heads: 2 + mc_num_sample_per_step: 20 + sharing_param_layer: False + loss_integral_num_sample_per_step: 20 + dropout: 0.0 + use_ln: False + thinning_params: # thinning algorithm for event sampling + num_seq: 10 + num_sample: 1 + num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm + look_ahead_time: 10 + patience_counter: 5 # the maximum iteration used in adaptive thinning + over_sample_rate: 5 + num_samples_boundary: 5 + dtime_max: 5 + + +*trainer_config* contains the training related specifications. + +.. code-block:: yaml + + trainer_config: # trainer arguments + seed: 2019 + gpu: 0 + batch_size: 256 + max_epoch: 10 + shuffle: False + optimizer: adam + learning_rate: 1.e-3 + valid_freq: 1 + use_tfb: False + metrics: ['acc', 'rmse'] + + + + +A complete example of these files can be seen at *examples/example_config*. + + +Step 2: Run the training script +=============================================== + +To run the training process, we simply need to call two functions: + +1. ``Config``: it reads the directory of the configs specified in Step 1 and do some processing to form a complete configuration. +2. ``Runner``: it reads the configuration and setups the whole pipeline for training, evaluation and generation. + + +The following code is an example, which is a copy from *examples/train_nhp.py*. + + +.. code-block:: python + + import argparse + from easy_tpp.config_factory import Config + from easy_tpp.runner import Runner + + + def main(): + parser = argparse.ArgumentParser() + + parser.add_argument('--config_dir', type=str, required=False, default='configs/experiment_config.yaml', + help='Dir of configuration yaml to train and evaluate the model.') + + parser.add_argument('--experiment_id', type=str, required=False, default='RMTPP_train', + help='Experiment id in the config file.') + + args = parser.parse_args() + + config = Config.build_from_yaml_file(args.config_dir, experiment_id=args.experiment_id) + + model_runner = Runner.build_from_config(config) + + model_runner.run() + + + if __name__ == '__main__': + main() + + + + + +Checkout the output +======================== + + +During training, the log, the best model based on valid set performance, the complete configuration file are all saved. The directory of the saved files is specified in 'base' of ``model_config.yaml``, i.e., + + + +In the `./checkpoints/` folder, one find the correct subfolder by concatenating the 'experiment_id' and running timestamps. Inside that subfolder, there is a complete configuration file, e.g., ``NHP_train_output.yaml`` that records all the information used in the pipeline. The + +.. code-block:: yaml + + data_config: + train_dir: ../data/conttime/train.pkl + valid_dir: ../data/conttime/dev.pkl + test_dir: ../data/conttime/test.pkl + specs: + num_event_types_pad: 6 + num_event_types: 5 + event_pad_index: 5 + data_format: pkl + base_config: + stage: train + backend: tensorflow + dataset_id: conttime + runner_id: std_tpp + model_id: RMTPP + base_dir: ./checkpoints/ + exp_id: RMTPP_train + log_folder: ./checkpoints/98888_4299965824_221205-153425 + saved_model_dir: ./checkpoints/98888_4299965824_221205-153425/models/saved_model + saved_log_dir: ./checkpoints/98888_4299965824_221205-153425/log + output_config_dir: ./checkpoints/98888_4299965824_221205-153425/RMTPP_train_output.yaml + model_config: + hidden_size: 32 + time_emb_size: 16 + num_layers: 2 + num_heads: 2 + mc_num_sample_per_step: 20 + sharing_param_layer: false + loss_integral_num_sample_per_step: 20 + dropout: 0.0 + use_ln: false + seed: 2019 + gpu: 0 + thinning_params: + num_seq: 10 + num_sample: 1 + num_exp: 500 + look_ahead_time: 10 + patience_counter: 5 + over_sample_rate: 5 + num_samples_boundary: 5 + dtime_max: 5 + num_step_gen: 1 + trainer: + batch_size: 256 + max_epoch: 10 + shuffle: false + optimizer: adam + learning_rate: 0.001 + valid_freq: 1 + use_tfb: false + metrics: + - acc + - rmse + seq_pad_end: true + is_training: true + num_event_types_pad: 6 + num_event_types: 5 + event_pad_index: 5 + model_id: RMTPP + + + +If we set ``use_tfb`` to ``true``, it means we can launch the tensorboard to track the training process, one +can see `Running Tensorboard <../advanced/tensorboard.html>`_ for details. \ No newline at end of file diff --git a/easy_tpp/__init__.py b/easy_tpp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..541f859dc9ea238cf66986b5b7ccd3189c4c74dc --- /dev/null +++ b/easy_tpp/__init__.py @@ -0,0 +1 @@ +__version__ = '0.1.0' \ No newline at end of file diff --git a/easy_tpp/config_factory/__init__.py b/easy_tpp/config_factory/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..15da373a2af9387108c2bfd7c4ff2e77782c1b5c --- /dev/null +++ b/easy_tpp/config_factory/__init__.py @@ -0,0 +1,13 @@ +from easy_tpp.config_factory.config import Config +from easy_tpp.config_factory.data_config import DataConfig, DataSpecConfig +from easy_tpp.config_factory.hpo_config import HPOConfig, HPORunnerConfig +from easy_tpp.config_factory.runner_config import RunnerConfig, ModelConfig, BaseConfig + +__all__ = ['Config', + 'DataConfig', + 'DataSpecConfig', + 'ModelConfig', + 'BaseConfig', + 'RunnerConfig', + 'HPOConfig', + 'HPORunnerConfig'] diff --git a/easy_tpp/config_factory/config.py b/easy_tpp/config_factory/config.py new file mode 100644 index 0000000000000000000000000000000000000000..de0cc9de0ae6e0fd3ca8ce9a54a3a345370894f2 --- /dev/null +++ b/easy_tpp/config_factory/config.py @@ -0,0 +1,120 @@ +from abc import abstractmethod +from typing import Any +from omegaconf import OmegaConf + +from easy_tpp.utils import save_yaml_config, Registrable, logger + + +class Config(Registrable): + + def save_to_yaml_file(self, config_dir): + """Save the config into the yaml file 'config_dir'. + + Args: + config_dir (str): Target filename. + + Returns: + """ + yaml_config = self.get_yaml_config() + OmegaConf.save(yaml_config, config_dir) + + @staticmethod + def build_from_yaml_file(yaml_dir, **kwargs): + """Load yaml config file from disk. + + Args: + yaml_dir (str): Path of the yaml config file. + + Returns: + EasyTPP.Config: Config object corresponding to cls. + """ + config = OmegaConf.load(yaml_dir) + pipeline_config = config.get('pipeline_config_id') + config_cls = Config.by_name(pipeline_config.lower()) + logger.critical(f'Load pipeline config class {config_cls.__name__}') + return config_cls.parse_from_yaml_config(config, **kwargs) + + @abstractmethod + def get_yaml_config(self): + """Get the yaml format config from self. + + Returns: + """ + pass + + @staticmethod + @abstractmethod + def parse_from_yaml_config(yaml_config): + """Parse from the yaml to generate the config object. + + Args: + yaml_config (dict): configs from yaml file. + + Returns: + EasyTPP.Config: Config class for data. + """ + pass + + @abstractmethod + def copy(self): + """Get a same and freely modifiable copy of self. + + Returns: + """ + pass + + def __str__(self): + """Str representation of the config. + + Returns: + str: str representation of the dict format of the config. + """ + return str(self.get_yaml_config()) + + def update(self, config): + """Update the config. + + Args: + config (dict): config dict. + + Returns: + EasyTPP.Config: Config class for data. + """ + logger.critical(f'Update config class {self.__class__.__name__}') + return self.parse_from_yaml_config(config) + + def pop(self, key: str, default_var: Any): + """pop out the key-value item from the config. + + Args: + key (str): key name. + default_var (Any): default value to pop. + + Returns: + Any: value to pop. + """ + return vars(self).pop(key) or default_var + + def get(self, key: str, default_var: Any): + """Retrieve the key-value item from the config. + + Args: + key (str): key name. + default_var (Any): default value to pop. + + Returns: + Any: value to get. + """ + return vars(self)[key] or default_var + + def set(self, key: str, var_to_set: Any): + """Set the key-value item from the config. + + Args: + key (str): key name. + var_to_set (Any): default value to pop. + + Returns: + Any: value to get. + """ + vars(self)[key] = var_to_set diff --git a/easy_tpp/config_factory/data_config.py b/easy_tpp/config_factory/data_config.py new file mode 100644 index 0000000000000000000000000000000000000000..9b3fd30e761f0e958d25c3c75350398c4036f767 --- /dev/null +++ b/easy_tpp/config_factory/data_config.py @@ -0,0 +1,147 @@ +from easy_tpp.config_factory.config import Config + + +class DataSpecConfig(Config): + def __init__(self, **kwargs): + """Initialize the Config class. + """ + self.num_event_types = kwargs.get('num_event_types') + self.pad_token_id = kwargs.get('pad_token_id') + self.padding_side = kwargs.get('padding_side') + self.truncation_side = kwargs.get('truncation_side') + self.padding_strategy = kwargs.get('padding_strategy') + self.max_len = kwargs.get('max_len') + self.truncation_strategy = kwargs.get('truncation_strategy') + self.num_event_types_pad = self.num_event_types + 1 + self.model_input_names = kwargs.get('model_input_names') + + if self.padding_side is not None and self.padding_side not in ["right", "left"]: + raise ValueError( + f"Padding side should be selected between 'right' and 'left', current value: {self.padding_side}" + ) + + if self.truncation_side is not None and self.truncation_side not in ["right", "left"]: + raise ValueError( + f"Truncation side should be selected between 'right' and 'left', current value: {self.truncation_side}" + ) + + def get_yaml_config(self): + """Return the config in dict (yaml compatible) format. + + Returns: + dict: config of the data specs in dict format. + """ + return { + 'num_event_types': self.num_event_types, + 'pad_token_id': self.pad_token_id, + 'padding_side': self.padding_side, + 'truncation_side': self.truncation_side, + 'padding_strategy': self.padding_strategy, + 'truncation_strategy': self.truncation_strategy, + 'max_len': self.max_len + } + + @staticmethod + def parse_from_yaml_config(yaml_config): + """Parse from the yaml to generate the config object. + + Args: + yaml_config (dict): configs from yaml file. + + Returns: + DataSpecConfig: Config class for data specs. + """ + return DataSpecConfig(**yaml_config) + + def copy(self): + """Copy the config. + + Returns: + DataSpecConfig: a copy of current config. + """ + return DataSpecConfig(num_event_types_pad=self.num_event_types_pad, + num_event_types=self.num_event_types, + event_pad_index=self.pad_token_id, + padding_side=self.padding_side, + truncation_side=self.truncation_side, + padding_strategy=self.padding_strategy, + truncation_strategy=self.truncation_strategy, + max_len=self.max_len) + + +@Config.register('data_config') +class DataConfig(Config): + def __init__(self, train_dir, valid_dir, test_dir, data_format, specs=None): + """Initialize the DataConfig object. + + Args: + train_dir (str): dir of tran set. + valid_dir (str): dir of valid set. + test_dir (str): dir of test set. + specs (dict, optional): specs of dataset. Defaults to None. + """ + self.train_dir = train_dir + self.valid_dir = valid_dir + self.test_dir = test_dir + self.data_specs = specs or DataSpecConfig() + self.data_format = train_dir.split('.')[-1] if data_format is None else data_format + + def get_yaml_config(self): + """Return the config in dict (yaml compatible) format. + + Returns: + dict: config of the data in dict format. + """ + return { + 'train_dir': self.train_dir, + 'valid_dir': self.valid_dir, + 'test_dir': self.test_dir, + 'data_format': self.data_format, + 'data_specs': self.data_specs.get_yaml_config(), + } + + @staticmethod + def parse_from_yaml_config(yaml_config): + """Parse from the yaml to generate the config object. + + Args: + yaml_config (dict): configs from yaml file. + + Returns: + EasyTPP.DataConfig: Config class for data. + """ + return DataConfig( + train_dir=yaml_config.get('train_dir'), + valid_dir=yaml_config.get('valid_dir'), + test_dir=yaml_config.get('test_dir'), + data_format=yaml_config.get('data_format'), + specs=DataSpecConfig.parse_from_yaml_config(yaml_config.get('data_specs')) + ) + + def copy(self): + """Copy the config. + + Returns: + EasyTPP.DataConfig: a copy of current config. + """ + return DataConfig(train_dir=self.train_dir, + valid_dir=self.valid_dir, + test_dir=self.test_dir, + specs=self.data_specs) + + def get_data_dir(self, split): + """Get the dir of the source raw data. + + Args: + split (str): dataset split notation, 'train', 'dev' or 'valid', 'test'. + + Returns: + str: dir of the source raw data file. + """ + split = split.lower() + if split == 'train': + return self.train_dir + elif split in ['dev', 'valid']: + return self.valid_dir + else: + return self.test_dir diff --git a/easy_tpp/config_factory/hpo_config.py b/easy_tpp/config_factory/hpo_config.py new file mode 100644 index 0000000000000000000000000000000000000000..49e282cdc59ef54daeb8b310cb3a207e397d014b --- /dev/null +++ b/easy_tpp/config_factory/hpo_config.py @@ -0,0 +1,132 @@ +from easy_tpp.config_factory.config import Config +from easy_tpp.config_factory.runner_config import RunnerConfig +from easy_tpp.utils import parse_uri_to_protocol_and_path, py_assert + + +class HPOConfig(Config): + def __init__(self, framework_id, storage_uri, is_continuous, num_trials, num_jobs): + """Initialize the HPO Config + + Args: + framework_id (str): hpo framework id. + storage_uri (str): result storage dir. + is_continuous (bool): whether to continuously do the optimization. + num_trials (int): num of trails used in optimization. + num_jobs (int): num of the jobs. + """ + self.framework_id = framework_id or 'optuna' + self.is_continuous = is_continuous if is_continuous is not None else True + self.num_trials = num_trials or 50 + self.storage_uri = storage_uri + self.num_jobs = num_jobs if num_jobs is not None else 1 + + @property + def storage_protocol(self): + """Get the storage protocol + + Returns: + str: the dir of the storage protocol. + """ + storage_protocol, _ = parse_uri_to_protocol_and_path(self.storage_uri) + return storage_protocol + + @property + def storage_path(self): + """Get the storage protocol + + Returns: + str: the dir of the hpo data storage. + """ + _, storage_path = parse_uri_to_protocol_and_path(self.storage_uri) + return storage_path + + def get_yaml_config(self): + """Return the config in dict (yaml compatible) format. + + Returns: + dict: config of the HPO specs in dict format. + """ + return { + 'framework_id': self.framework_id, + 'storage_uri': self.storage_uri, + 'is_continuous': self.is_continuous, + 'num_trials': self.num_trials, + 'num_jobs': self.num_jobs + } + + @staticmethod + def parse_from_yaml_config(yaml_config, **kwargs): + """Parse from the yaml to generate the config object. + + Args: + yaml_config (dict): configs from yaml file. + + Returns: + EasyTPP.HPOConfig: Config class for HPO specs. + """ + if yaml_config is None: + return None + else: + return HPOConfig( + framework_id=yaml_config.get('framework_id'), + storage_uri=yaml_config.get('storage_uri'), + is_continuous=yaml_config.get('is_continuous'), + num_trials=yaml_config.get('num_trials'), + num_jobs=yaml_config.get('num_jobs'), + ) + + def copy(self): + """Copy the config. + + Returns: + EasyTPP.HPOConfig: a copy of current config. + """ + return HPOConfig( + framework_id=self.framework_id, + storage_uri=self.storage_uri, + is_continuous=self.is_continuous, + num_trials=self.num_trials, + num_jobs=self.num_jobs + ) + + +@Config.register('hpo_runner_config') +class HPORunnerConfig(Config): + def __init__(self, hpo_config, runner_config): + """Initialize the config class + + Args: + hpo_config (EasyTPP.HPOConfig): hpo config class. + runner_config (EasyTPP.RunnerConfig): runner config class. + """ + self.hpo_config = hpo_config + self.runner_config = runner_config + + @staticmethod + def parse_from_yaml_config(yaml_config, **kwargs): + """Parse from the yaml to generate the config object. + + Args: + yaml_config (dict): configs from yaml file. + + Returns: + EasyTPP.HPORunnerConfig: Config class for HPO specs. + """ + runner_config = RunnerConfig.parse_from_yaml_config(yaml_config, **kwargs) + hpo_config = HPOConfig.parse_from_yaml_config(yaml_config.get('hpo'), **kwargs) + py_assert(hpo_config is not None, ValueError, 'No hpo configs is provided for HyperTuner') + return HPORunnerConfig( + hpo_config=hpo_config, + runner_config=runner_config + ) + + def copy(self): + """Copy the config. + + Returns: + EasyTPP.HPORunnerConfig: a copy of current config. + """ + return HPORunnerConfig( + hpo_config=self.hpo_config, + runner_config=self.runner_config + ) diff --git a/easy_tpp/config_factory/model_config.py b/easy_tpp/config_factory/model_config.py new file mode 100644 index 0000000000000000000000000000000000000000..26756e9e5ef2c3afc7e8d5c0edddb3b706305418 --- /dev/null +++ b/easy_tpp/config_factory/model_config.py @@ -0,0 +1,274 @@ +from easy_tpp.config_factory.config import Config + +from easy_tpp.utils.const import Backend + + +class TrainerConfig(Config): + + def __init__(self, **kwargs): + """Initialize the Config class. + """ + self.seed = kwargs.get('seed', 9899) + self.gpu = kwargs.get('gpu', -1) + self.batch_size = kwargs.get('batch_size', 256) + self.max_epoch = kwargs.get('max_epoch', 10) + self.shuffle = kwargs.get('shuffle', False) + self.optimizer = kwargs.get('optimizer', 'adam') + self.learning_rate = kwargs.get('learning_rate', 1.e-3) + self.valid_freq = kwargs.get('valid_freq', 1) + self.use_tfb = kwargs.get('use_tfb', False) + self.metrics = kwargs.get('metrics', ['acc', 'rmse']) + + def get_yaml_config(self): + """Return the config in dict (yaml compatible) format. + + Returns: + dict: config of the trainer specs in dict format. + """ + return {'seed': self.seed, + 'gpu': self.gpu, + 'batch_size': self.batch_size, + 'max_epoch': self.max_epoch, + 'shuffle': self.shuffle, + 'optimizer': self.optimizer, + 'learning_rate': self.learning_rate, + 'valid_freq': self.valid_freq, + 'use_tfb': self.use_tfb, + 'metrics': self.metrics + } + + @staticmethod + def parse_from_yaml_config(yaml_config): + """Parse from the yaml to generate the config object. + + Args: + yaml_config (dict): configs from yaml file. + + Returns: + EasyTPP.TrainerConfig: Config class for trainer specs. + """ + return TrainerConfig(**yaml_config) + + def copy(self): + """Copy the config. + + Returns: + EasyTPP.TrainerConfig: a copy of current config. + """ + return TrainerConfig(batch_size=self.batch_size, + max_epoch=self.max_epoch, + shuffle=self.shuffle, + optimizer=self.optimizer, + learning_rate=self.learning_rate, + valid_freq=self.valid_freq, + use_tfb=self.use_tfb, + metrics=self.metrics + ) + + +class ThinningConfig(Config): + def __init__(self, **kwargs): + """Initialize the Config class. + """ + self.num_seq = kwargs.get('num_seq', 10) + self.num_sample = kwargs.get('num_sample', 1) + self.num_exp = kwargs.get('num_exp', 500) + self.look_ahead_time = kwargs.get('look_ahead_time', 10) + self.patience_counter = kwargs.get('patience_counter', 5) + self.over_sample_rate = kwargs.get('over_sample_rate', 5) + self.num_samples_boundary = kwargs.get('num_samples_boundary', 5) + self.dtime_max = kwargs.get('dtime_max', 5) + # we pad the sequence at the front only in multi-step generation + self.num_step_gen = kwargs.get('num_step_gen', 1) + + def get_yaml_config(self): + """Return the config in dict (yaml compatible) format. + + Returns: + dict: config of the thinning specs in dict format. + """ + return {'num_seq': self.num_seq, + 'num_sample': self.num_sample, + 'num_exp': self.num_exp, + 'look_ahead_time': self.look_ahead_time, + 'patience_counter': self.patience_counter, + 'over_sample_rate': self.over_sample_rate, + 'num_samples_boundary': self.num_samples_boundary, + 'dtime_max': self.dtime_max, + 'num_step_gen': self.num_step_gen} + + @staticmethod + def parse_from_yaml_config(yaml_config): + """Parse from the yaml to generate the config object. + + Args: + yaml_config (dict): configs from yaml file. + + Returns: + EasyTPP.ThinningConfig: Config class for thinning algorithms. + """ + return ThinningConfig(**yaml_config) if yaml_config is not None else None + + def copy(self): + """Copy the config. + + Returns: + EasyTPP.ThinningConfig: a copy of current config. + """ + return ThinningConfig(num_seq=self.num_seq, + num_sample=self.num_sample, + num_exp=self.num_exp, + look_ahead_time=self.look_ahead_time, + patience_counter=self.patience_counter, + over_sample_rate=self.over_sample_rate, + num_samples_boundary=self.num_samples_boundary, + dtime_max=self.dtime_max, + num_step_gen=self.num_step_gen) + + +class BaseConfig(Config): + def __init__(self, **kwargs): + """Initialize the Config class. + """ + self.stage = kwargs.get('stage') + self.backend = kwargs.get('backend') + self.dataset_id = kwargs.get('dataset_id') + self.runner_id = kwargs.get('runner_id') + self.model_id = kwargs.get('model_id') + self.exp_id = kwargs.get('exp_id') + self.base_dir = kwargs.get('base_dir') + self.specs = kwargs.get('specs', {}) + self.backend = self.set_backend(self.backend) + + @staticmethod + def set_backend(backend): + if backend.lower() in ['torch', 'pytorch']: + return Backend.Torch + else: + raise ValueError( + f"Backend should be 'torch' or 'pytorch', current value: {backend}" + ) + + def get_yaml_config(self): + """Return the config in dict (yaml compatible) format. + + Returns: + dict: config of the base config specs in dict format. + """ + return {'stage': self.stage, + 'backend': str(self.backend), + 'dataset_id': self.dataset_id, + 'runner_id': self.runner_id, + 'model_id': self.model_id, + 'base_dir': self.base_dir, + 'specs': self.specs} + + @staticmethod + def parse_from_yaml_config(yaml_config): + """Parse from the yaml to generate the config object. + + Args: + yaml_config (dict): configs from yaml file. + + Returns: + BaseConfig: Config class for trainer specs. + """ + return BaseConfig(**yaml_config) + + def copy(self): + """Copy the config. + + Returns: + BaseConfig: a copy of current config. + """ + return BaseConfig(stage=self.stage, + backend=self.backend, + dataset_id=self.dataset_id, + runner_id=self.runner_id, + model_id=self.model_id, + base_dir=self.base_dir, + specs=self.specs) + + +class ModelConfig(Config): + def __init__(self, **kwargs): + """Initialize the Config class. + """ + self.rnn_type = kwargs.get('rnn_type', 'LSTM') + self.hidden_size = kwargs.get('hidden_size', 32) + self.time_emb_size = kwargs.get('time_emb_size', 16) + self.num_layers = kwargs.get('num_layers', 2) + self.num_heads = kwargs.get('num_heads', 2) + self.sharing_param_layer = kwargs.get('sharing_param_layer', False) + self.use_mc_samples = kwargs.get('use_mc_samples', True) # if using MC samples in computing log-likelihood + self.loss_integral_num_sample_per_step = kwargs.get('loss_integral_num_sample_per_step', 20) # mc_num_sample_per_step + self.dropout_rate = kwargs.get('dropout_rate', 0.0) + self.use_ln = kwargs.get('use_ln', False) + self.thinning = ThinningConfig.parse_from_yaml_config(kwargs.get('thinning')) + self.is_training = kwargs.get('training', False) + self.num_event_types_pad = kwargs.get('num_event_types_pad', None) + self.num_event_types = kwargs.get('num_event_types', None) + self.pad_token_id = kwargs.get('event_pad_index', None) + self.model_id = kwargs.get('model_id', None) + self.pretrained_model_dir = kwargs.get('pretrained_model_dir', None) + self.gpu = kwargs.get('gpu', -1) + self.model_specs = kwargs.get('model_specs', {}) + + def get_yaml_config(self): + """Return the config in dict (yaml compatible) format. + + Returns: + dict: config of the model config specs in dict format. + """ + return {'rnn_type': self.rnn_type, + 'hidden_size': self.hidden_size, + 'time_emb_size': self.time_emb_size, + 'num_layers': self.num_layers, + 'sharing_param_layer': self.sharing_param_layer, + 'loss_integral_num_sample_per_step': self.loss_integral_num_sample_per_step, + 'dropout_rate': self.dropout_rate, + 'use_ln': self.use_ln, + # for some models / cases we may not need to pass thinning config + # e.g., for intensity-free model + 'thinning': None if self.thinning is None else self.thinning.get_yaml_config(), + 'num_event_types_pad': self.num_event_types_pad, + 'num_event_types': self.num_event_types, + 'event_pad_index': self.pad_token_id, + 'model_id': self.model_id, + 'pretrained_model_dir': self.pretrained_model_dir, + 'gpu': self.gpu, + 'model_specs': self.model_specs} + + @staticmethod + def parse_from_yaml_config(yaml_config): + """Parse from the yaml to generate the config object. + + Args: + yaml_config (dict): configs from yaml file. + + Returns: + ModelConfig: Config class for trainer specs. + """ + return ModelConfig(**yaml_config) + + def copy(self): + """Copy the config. + + Returns: + ModelConfig: a copy of current config. + """ + return ModelConfig(rnn_type=self.rnn_type, + hidden_size=self.hidden_size, + time_emb_size=self.time_emb_size, + num_layers=self.num_layers, + sharing_param_layer=self.sharing_param_layer, + loss_integral_num_sample_per_step=self.loss_integral_num_sample_per_step, + dropout_rate=self.dropout_rate, + use_ln=self.use_ln, + thinning=self.thinning, + num_event_types_pad=self.num_event_types_pad, + num_event_types=self.num_event_types, + event_pad_index=self.pad_token_id, + pretrained_model_dir=self.pretrained_model_dir, + gpu=self.gpu, + model_specs=self.model_specs) diff --git a/easy_tpp/config_factory/runner_config.py b/easy_tpp/config_factory/runner_config.py new file mode 100644 index 0000000000000000000000000000000000000000..ebb5bf3d06eaaae2de4a305806bfda337dd057e9 --- /dev/null +++ b/easy_tpp/config_factory/runner_config.py @@ -0,0 +1,161 @@ +import copy +import os + +from easy_tpp.config_factory.config import Config +from easy_tpp.config_factory.data_config import DataConfig +from easy_tpp.config_factory.model_config import TrainerConfig, ModelConfig, BaseConfig +from easy_tpp.utils import create_folder, logger, get_unique_id, get_stage, RunnerPhase, \ + MetricsHelper, DefaultRunnerConfig, py_assert, is_torch_available, \ + is_torch_gpu_available +from easy_tpp.utils.const import Backend + + +@Config.register('runner_config') +class RunnerConfig(Config): + def __init__(self, base_config, model_config, data_config, trainer_config): + """Initialize the Config class. + + Args: + base_config (EasyTPP.BaseConfig): BaseConfig object. + model_config (EasyTPP.ModelConfig): ModelConfig object. + data_config (EasyTPP.DataConfig): DataConfig object. + trainer_config (EasyTPP.TrainerConfig): TrainerConfig object + """ + self.data_config = data_config + self.model_config = model_config + self.base_config = base_config + self.trainer_config = trainer_config + + self.update_config() + + # save the complete config + save_dir = self.base_config.specs['output_config_dir'] + self.save_to_yaml_file(save_dir) + + logger.info(f'Save the config to {save_dir}') + + def get_yaml_config(self): + """Return the config in dict (yaml compatible) format. + + Returns: + dict: config of the runner config in dict format. + """ + return {'data_config': self.data_config.get_yaml_config(), + 'base_config': self.base_config.get_yaml_config(), + 'model_config': self.model_config.get_yaml_config(), + 'trainer_config': self.trainer_config.get_yaml_config()} + + @staticmethod + def parse_from_yaml_config(yaml_config, **kwargs): + """Parse from the yaml to generate the config object. + + Args: + yaml_config (dict): configs from yaml file. + + Returns: + RunnerConfig: Config class for trainer specs. + """ + direct_parse = kwargs.get('direct_parse', False) + if not direct_parse: + exp_id = kwargs.get('experiment_id') + yaml_exp_config = yaml_config[exp_id] + dataset_id = yaml_exp_config.get('base_config').get('dataset_id') + if dataset_id is None: + dataset_id = DefaultRunnerConfig.DEFAULT_DATASET_ID + try: + yaml_data_config = yaml_config['data'][dataset_id] + except KeyError: + raise RuntimeError('dataset_id={} is not found in config.'.format(dataset_id)) + + data_config = DataConfig.parse_from_yaml_config(yaml_data_config) + # add exp id to base config + yaml_exp_config.get('base_config').update(exp_id=exp_id) + + else: + yaml_exp_config = yaml_config + data_config = DataConfig.parse_from_yaml_config(yaml_exp_config.get('data_config')) + + base_config = BaseConfig.parse_from_yaml_config(yaml_exp_config.get('base_config')) + model_config = ModelConfig.parse_from_yaml_config(yaml_exp_config.get('model_config')) + trainer_config = TrainerConfig.parse_from_yaml_config(yaml_exp_config.get('trainer_config')) + + return RunnerConfig( + data_config=data_config, + base_config=base_config, + model_config=model_config, + trainer_config=trainer_config + ) + + def update_config(self): + """Updated config dict. + """ + model_folder_name = get_unique_id() + + log_folder = create_folder(self.base_config.base_dir, model_folder_name) + model_folder = create_folder(log_folder, 'models') + + self.base_config.specs['log_folder'] = log_folder + self.base_config.specs['saved_model_dir'] = os.path.join(model_folder, 'saved_model') + self.base_config.specs['saved_log_dir'] = os.path.join(log_folder, 'log') + self.base_config.specs['output_config_dir'] = os.path.join(log_folder, + f'{self.base_config.exp_id}_output.yaml') + + if self.trainer_config.use_tfb: + self.base_config.specs['tfb_train_dir'] = create_folder(log_folder, 'tfb_train') + self.base_config.specs['tfb_valid_dir'] = create_folder(log_folder, 'tfb_valid') + + current_stage = get_stage(self.base_config.stage) + is_training = current_stage == RunnerPhase.TRAIN + self.model_config.is_training = is_training + self.model_config.gpu = self.trainer_config.gpu + + # update the dataset config => model config + self.model_config.num_event_types_pad = self.data_config.data_specs.num_event_types_pad + self.model_config.num_event_types = self.data_config.data_specs.num_event_types + self.model_config.pad_token_id = self.data_config.data_specs.pad_token_id + self.model_config.max_len = self.data_config.data_specs.max_len + + # update base config => model config + model_id = self.base_config.model_id + self.model_config.model_id = model_id + + run = current_stage + use_torch = self.base_config.backend == Backend.Torch + device = 'GPU' if self.trainer_config.gpu >= 0 else 'CPU' + + py_assert(is_torch_available(), ValueError, + f'PyTorch is not available in the current environment!') + + if use_torch and device == 'GPU': + py_assert(is_torch_gpu_available(), + ValueError, + f'Torch cuda is not supported in the current environment yet!') + + critical_msg = '{run} model {model_name} using {device} ' \ + 'with {tf_torch} backend'.format(run=run, + model_name=model_id, + device=device, + tf_torch=self.base_config.backend) + + logger.critical(critical_msg) + + return + + def get_metric_functions(self): + return MetricsHelper.get_metrics_callback_from_names(self.trainer_config.metrics) + + def get_metric_direction(self, metric_name='rmse'): + return MetricsHelper.get_metric_direction(metric_name) + + def copy(self): + """Copy the config. + + Returns: + RunnerConfig: a copy of current config. + """ + return RunnerConfig( + base_config=copy.deepcopy(self.base_config), + model_config=copy.deepcopy(self.model_config), + data_config=copy.deepcopy(self.data_config), + trainer_config=copy.deepcopy(self.trainer_config) + ) diff --git a/easy_tpp/default_registers/__init__.py b/easy_tpp/default_registers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/easy_tpp/default_registers/register_metrics.py b/easy_tpp/default_registers/register_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..feeab1f46d7b2cb06aedb5bdf313eda87e464eb4 --- /dev/null +++ b/easy_tpp/default_registers/register_metrics.py @@ -0,0 +1,53 @@ +import numpy as np + +from easy_tpp.utils.const import PredOutputIndex +from easy_tpp.utils.metrics import MetricsHelper + + +@MetricsHelper.register(name='rmse', direction=MetricsHelper.MINIMIZE, overwrite=False) +def rmse_metric_function(predictions, labels, **kwargs): + """Compute rmse metrics of the time predictions. + + Args: + predictions (np.array): model predictions. + labels (np.array): ground truth. + + Returns: + float: average rmse of the time predictions. + """ + seq_mask = kwargs.get('seq_mask') + if seq_mask is None or len(seq_mask) == 0: + # If mask is empty or None, use all predictions + pred = predictions[PredOutputIndex.TimePredIndex] + label = labels[PredOutputIndex.TimePredIndex] + else: + pred = predictions[PredOutputIndex.TimePredIndex][seq_mask] + label = labels[PredOutputIndex.TimePredIndex][seq_mask] + + pred = np.reshape(pred, [-1]) + label = np.reshape(label, [-1]) + return np.sqrt(np.mean((pred - label) ** 2)) + + +@MetricsHelper.register(name='acc', direction=MetricsHelper.MAXIMIZE, overwrite=False) +def acc_metric_function(predictions, labels, **kwargs): + """Compute accuracy ratio metrics of the type predictions. + + Args: + predictions (np.array): model predictions. + labels (np.array): ground truth. + + Returns: + float: accuracy ratio of the type predictions. + """ + seq_mask = kwargs.get('seq_mask') + if seq_mask is None or len(seq_mask) == 0: + # If mask is empty or None, use all predictions + pred = predictions[PredOutputIndex.TypePredIndex] + label = labels[PredOutputIndex.TypePredIndex] + else: + pred = predictions[PredOutputIndex.TypePredIndex][seq_mask] + label = labels[PredOutputIndex.TypePredIndex][seq_mask] + pred = np.reshape(pred, [-1]) + label = np.reshape(label, [-1]) + return np.mean(pred == label) diff --git a/easy_tpp/default_registers/register_optuna_trials.py b/easy_tpp/default_registers/register_optuna_trials.py new file mode 100644 index 0000000000000000000000000000000000000000..479276d797d5eda747e5191aa96bc8c164f8390d --- /dev/null +++ b/easy_tpp/default_registers/register_optuna_trials.py @@ -0,0 +1,13 @@ +from easy_tpp.hpo.optuna_hpo import OptunaTuner + + +@OptunaTuner.register_trial_func(model_id='default', overwrite=False) +def default_trial(trial, **kwargs): + setting = { + "trainer_config": {"max_epoch": "suggest_int(40, 100, log=True)", + "batch_size": 256, + "optimizer": "adam", + "learning_rate": "suggest_float(5e-4, 1e-2, log=True)"}, + "model_config": {"hidden_size": "suggest_int(16, 32)"} + } + return setting diff --git a/easy_tpp/hpo/__init__.py b/easy_tpp/hpo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d6efc8ac9106f0a25e041a957e3b6ddcc1df9637 --- /dev/null +++ b/easy_tpp/hpo/__init__.py @@ -0,0 +1,6 @@ +from easy_tpp.hpo.base_hpo import HyperTuner +from easy_tpp.hpo.optuna_hpo import OptunaTuner +from easy_tpp.default_registers.register_optuna_trials import * + +__all__ = ['HyperTuner', + 'OptunaTuner'] \ No newline at end of file diff --git a/easy_tpp/hpo/base_hpo.py b/easy_tpp/hpo/base_hpo.py new file mode 100644 index 0000000000000000000000000000000000000000..43f8a79ec280de3f1fa1bc7d249f8561aa957609 --- /dev/null +++ b/easy_tpp/hpo/base_hpo.py @@ -0,0 +1,131 @@ +from abc import abstractmethod +from collections import defaultdict +from typing import List + +from easy_tpp.utils import logger, Registrable + + +class HyperTuner(Registrable): + _trial_register_center = defaultdict(dict) + + def __init__(self, config, trial_end_callbacks: List[callable] = None): + """Initialize the tuner + + Args: + config (EasyTPP.Config): config class + trial_end_callbacks (List[callable]): List of callback functions to be executed after each trial. + """ + self.config = config + self.trial_end_callbacks = trial_end_callbacks or [] + logger.info(f'Storage of hpo framework: {self.config.hpo_config.storage_uri}') + + @abstractmethod + def get_all_best_runner_configs(self): + pass + + @abstractmethod + def get_best_runner_config_by_name(self, runner_id): + """ + + Args: + runner_id (str): + + Returns: + + """ + pass + + @abstractmethod + def get_num_remain_trials_by_name(self, runner_id): + pass + + @staticmethod + def build_from_config(config, trial_end_callbacks: List[callable] = None): + """Load yaml config file from disk. + + Args: + config (EasyTPP.Config): config class + trial_end_callbacks (List[callable]): List of callback functions to be executed after each trial. + + Returns: + EasyTPP.Config: Config object corresponding to cls. + """ + runner_cls = HyperTuner.by_name(config.hpo_config.framework_id) + return runner_cls(config, trial_end_callbacks) + + # ---------------------- Trail Register and Get Functions --------------------- + + @classmethod + def register_trial_func(cls, model_id, overwrite=True): + """Register the trial functions in HPO + + Args: + model_id (str): id of the models. + overwrite (bool, optional): whether to overwrite the trial function. Defaults to True. + + Returns: + dict: the registered trial function + """ + register_center = HyperTuner._trial_register_center + + def _register_trial(func): + if model_id in register_center[cls]: + if overwrite: + register_center[cls][model_id] = func + logger.info(f'The trial for {model_id} is already registered, but overwrite it.') + else: + logger.warn(f'The trial for {model_id} is already registered, and cannot be overwritten!') + else: + register_center[cls][model_id] = func + logger.info(f'Trial register: {cls.get_registered_name()} - {model_id}') + return func + + return _register_trial + + @classmethod + def retrieve_trial_func_by_model_name(cls, name): + """Retrieve the trail function by the model id + + Args: + name (str): model id. + + Raises: + RuntimeError: non registered error for the hpo framework. + + Returns: + dict: registered trial center + """ + cls_trial_rc = HyperTuner._trial_register_center[cls] + if name not in cls_trial_rc: + if 'default' in cls_trial_rc: + logger.warn( + f'Trial for {name} in {cls.get_registered_name()} is not existed, and use default trial!' + ) + name = 'default' + else: + raise RuntimeError(f'This HPO Framework is not registered!') + return cls_trial_rc[name] + + @classmethod + def get_registered_name(cls): + """Get the name of the registered hpo class. + + Returns: + str: the name of the registered hpo class. + """ + hpo_rc = HyperTuner.registry_dict() + for registered_name, hpo_cls in hpo_rc.items(): + if cls in hpo_cls: + return registered_name + + logger.warn(f'The hpo framework is not registered: {cls}') + return None + + @abstractmethod + def run(self): + """Run the process. + + Raises: + NotImplementedError: error raised in base class. + """ + raise NotImplementedError diff --git a/easy_tpp/hpo/optuna_hpo.py b/easy_tpp/hpo/optuna_hpo.py new file mode 100644 index 0000000000000000000000000000000000000000..99d51450016c4beec7c5862076149853eacb7ca9 --- /dev/null +++ b/easy_tpp/hpo/optuna_hpo.py @@ -0,0 +1,376 @@ +import logging +import os +import sys + +import optuna +from optuna.samplers import TPESampler +from optuna.trial import TrialState + +from easy_tpp.config_factory import RunnerConfig +from easy_tpp.hpo.base_hpo import HyperTuner +from easy_tpp.preprocess import TPPDataLoader +from easy_tpp.runner import Runner +from easy_tpp.utils import Timer, dict_deep_update +from easy_tpp.utils.log_utils import get_logger + +optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout)) +logger = get_logger('optuna_hpo') + + +@HyperTuner.register(name='optuna') +class OptunaTuner(HyperTuner): + def __init__(self, config, trial_end_callbacks): + """Initialize the Optuna Tuner class. + + Args: + config (EasyTPP.Config): config class. + trial_end_callbacks (list): list of trial callbacks. + """ + super(OptunaTuner, self).__init__(config, trial_end_callbacks) + + # fetch db storage from the given storage_uri + self.storage_fn = self._fetch_storage() + + # optuna db storage uri + self.storage = 'sqlite:///{}'.format(self.storage_fn) if self.storage_fn else None + + self.runner_config = self.config.runner_config + self.hpo_config = self.config.hpo_config + + # build data reader + data_config = self.runner_config.data_config + backend = self.runner_config.base_config.backend + kwargs = self.runner_config.trainer_config.get_yaml_config() + self._data_loader = TPPDataLoader( + data_config=data_config, + backend=backend, + **kwargs + ) + + def get_all_best_runner_configs(self): + """Get all best runner configs. Obtain from storage. + + Returns: + Dict[str, EasyTPP.RunnerConfig]: Dict of all best runner configs. + """ + runner_configs = {} + for study_summary in optuna.get_all_study_summaries(self.storage): + runner_configs[study_summary.study_name] = self._build_runner_config_from_storage( + study=study_summary, + trial=study_summary.best_trial + ) + return runner_configs + + def get_best_runner_config_by_name(self, exp_id): + """Get the best runner config by runner_id. Obtain it from storage. + + Args: + exp_id (str): experiment id. + + Returns: + EasyTPP.RunnerConfig: best runner config. + """ + for study_summary in optuna.get_all_study_summaries(self.storage): + if exp_id == study_summary.study_name: + return self._build_runner_config_from_storage(study_summary, study_summary.best_trial) + return None + + def get_num_remain_trials_by_name(self, exp_id): + """Get the num of remaining trails by experiment id. + + Args: + exp_id (str): experiment id. + + Returns: + int: num of remaining trails. + """ + for study_summary in optuna.get_all_study_summaries(self.storage): + if exp_id == study_summary.study_name: + study = optuna.load_study(study_name=exp_id, storage=self.storage) + num_completed_trials = len(study.get_trials(deepcopy=False, states=(TrialState.COMPLETE,))) + num_remain_trials = self.hpo_config.num_trials - num_completed_trials + return num_remain_trials + return self.hpo_config.num_trials + + def optimize( + self, + base_runner_config, + train_loader, + valid_loader, + test_loader=None, + exp_id=None, + **kwargs + ): + """Run the optimization process. + + Args: + base_runner_config (EasyTPP.RunnerConfig): runner config. + train_loader (EasyTPP.DataLoader): train data loader. + valid_loader (EasyTPP.DataLoader): valid data loader + test_loader (EasyTPP.DataLoader, optional): test data loader. Defaults to None. + exp_id (str, optional): experiment id. Defaults to None. + + Raises: + RuntimeError: best trial is not found. + + Returns: + tuple: best_metric and best_runner_config + """ + # obtain parameters + storage = self.storage + load_if_exists = self.hpo_config.is_continuous + metric_direction = base_runner_config.get_metric_direction() + + # delete the study if it already existed when 'is_continue' is false + if not load_if_exists and exp_id is not None: + if exp_id in [std_summary.study_name for std_summary in optuna.get_all_study_summaries(storage)]: + optuna.delete_study( + study_name=exp_id, + storage=storage + ) + # create hpo study + study = optuna.create_study( + storage=storage, + direction=metric_direction, + load_if_exists=load_if_exists, + study_name=exp_id, + sampler=TPESampler(seed=9899), + ) + + # set user_attr to study + study.set_user_attr('data_config', base_runner_config.data_config.get_yaml_config()) + + # calculate the number of remaining trials + num_completed_trials = len(study.get_trials(deepcopy=False, states=(TrialState.COMPLETE,))) + num_remain_trials = self.hpo_config.num_trials - num_completed_trials + if num_remain_trials > 0: + logger.info(f'Number of hpo trials completed for runner {exp_id}: ' + f'{num_completed_trials}/{self.hpo_config.num_trials}') + objective_func = self._get_objective_func( + base_runner_config=base_runner_config, + train_loader=train_loader, + valid_loader=valid_loader, + test_loader=test_loader, + **kwargs + ) + # hpo optimize + study.optimize( + objective_func, + n_trials=num_remain_trials, + callbacks=[self._optimize_trial_end], + gc_after_trial=True, + n_jobs=self.hpo_config.num_jobs, + ) + + # statistics of this hpo + pruned_trials = study.get_trials(deepcopy=False, states=(TrialState.PRUNED,)) + complete_trials = study.get_trials(deepcopy=False, states=(TrialState.COMPLETE,)) + logger.info("HPO - Optuna statistics:") + logger.info(f"\tNumber of finished trials: {len(study.trials)}") + logger.info(f"\tNumber of pruned trials: {len(pruned_trials)}") + logger.info(f"\tNumber of complete trials: {len(complete_trials)}") + + if len(study.best_trials) == 0: + raise RuntimeError('Best trial is not found, please check the model or metric.') + trial = study.best_trial + logger.info(f"HPO - Best metric value ({metric_direction}): {trial.value}") + + logger.info(f"Best Parameters: ") + for key, value in trial.params.items(): + logger.info(f"\t{key}: {value}") + + best_metric = trial.value + best_runner_config = RunnerConfig.parse_from_yaml_config( + trial.user_attrs['runner_config'], + data_config=base_runner_config.data_config, + ) + return best_metric, best_runner_config + + def _get_objective_func( + self, + base_runner_config, + **kwargs, + ): + """Get the optimization objective function. + + Args: + base_runner_config (EasyTPP.Config): runner config. + + Raises: + e: RuntimeError + optuna.TrialPruned: message in trials. + + Returns: + _type_: _description_ + """ + trial_func = self.retrieve_trial_func_by_model_name(base_runner_config.base_config.exp_id) + + def objective(trial): + timer = Timer() + timer.start() + logger.info(f'Start the trial {trial.number} ...') + # get a copy of base runner config for isolation + # generate new runner runners + runner_config = base_runner_config.copy() + + trial_model_info = trial_func( + trial, + trainder_config=runner_config.trainer_config, + model_config=runner_config.model_config + ) + + # use predefined trial to update model_info + runner_config_dict = dict_deep_update( + target=runner_config.get_yaml_config(), + source=trial_model_info, + is_add_new_key=False) + + # eval the "suggest" in runner_config (actually run trial suggestion) + runner_config_dict = self._eval_str_trial_to_dict(trial, + runner_config_dict) + + runner_config = RunnerConfig.parse_from_yaml_config(runner_config_dict, direct_parse=True) + + runner = Runner.build_from_config( + runner_config=runner_config, + unique_model_dir=True, + skip_data_loader=True, + ) + try: + # train model + runner.train(**kwargs) + # evaluate model + metric = runner.evaluate(trial=trial, **kwargs) + # save the final model + runner.save() + except RuntimeError as e: + # add the error message into trial + err_msg = str(e) + trial.set_user_attr("error", err_msg) + + logger.error(f'Error in the trial {trial.number}: {err_msg}') + + # just prune the errors like 'out of memory' + if 'out of memory' not in err_msg: + raise e + raise optuna.TrialPruned() + finally: + # add model path into trial + trial.set_user_attr("model_dir", runner.runner_config.model_dir) + # trial.set_user_attr("model_config", runner.runner_config.model_config) + trial.set_user_attr("runner_config", runner.runner_config.get_yaml_config()) + + logger.info(f'End trial {trial.number} ! Cost time: {timer.end()}') + + return metric + + return objective + + def _optimize_trial_end(self, study, trial): + """End the process of trials. + + Args: + study (optuna.study.Study): an object of optuna :class:`~optuna.study` to be studied during optimization. + trial (optuna.trial.FrozenTrial): an object of optuna :class:`~optuna.trial` that stores trial information. + """ + # push storage to the specified uri + self._push_storage(trial) + + is_best_yet = (trial in study.best_trials) + + runner_id = study.study_name + runner_config = self._build_runner_config_from_storage(study, trial) + + # invoke callbacks + for callback in self.trial_end_callbacks: + callback(runner_id, runner_config, is_best_yet) + + # clean disk + if not is_best_yet and os.path.exists(trial.user_attrs['model_dir']): + os.system(f"rm -fr {trial.user_attrs['model_dir']}") + + def _eval_str_trial_to_dict(self, trial, a_dict): + for key, val in a_dict.items(): + if type(val) == str and val.startswith('suggest_'): + idx = val.find('(') + prefix = val[:idx + 1] + suffix = val[idx + 1:] + + # get trial variable name + trial_name = [k for k, v in locals().items() if v == trial][0] + + code = '''{0}.{1}"{2}",{3}'''.format(trial_name, prefix, key, suffix) + a_dict[key] = eval(code) + elif type(val) == dict: + self._eval_str_trial_to_dict(trial, a_dict[key]) + + return a_dict + + def _build_runner_config_from_storage( + self, + study, + trial + ): + """Initialize the RunnerConfig from the study and trial. + + Args: + study (optuna.study.Study): an object of optuna :class:`~optuna.study` to be studied during optimization. + trial (optuna.trial.FrozenTrial): an object of optuna :class:`~optuna.trial` that stores trial information. + + Returns: + EasyTPP.Config: RunnerConfig object. + """ + runner_config_dict = trial.user_attrs['runner_config'] + return RunnerConfig.parse_from_yaml_config(runner_config_dict, direct_parse=True) + + def run(self): + """Run the HPO process. + + Returns: + tuple: best_metric, best_runner_config + """ + # to avoid to load unused data + train_loader, valid_loader, test_loader = None, None, None + exp_id = self.runner_config.base_config.exp_id + + if self.get_num_remain_trials_by_name(exp_id) > 0: + train_loader = self._data_loader.get_loader(split='train') + valid_loader = self._data_loader.get_loader(split='dev') + if self.runner_config.data_config.test_dir is not None: + test_loader = self._data_loader.get_loader(split='test') + + best_metric, best_runner_config = self.optimize( + base_runner_config=self.runner_config, + train_loader=train_loader, + valid_loader=valid_loader, + test_loader=test_loader, + exp_id=exp_id + ) + + return best_metric, best_runner_config + + def _fetch_storage(self): + """Retrieve the stored model. + + Returns: + str: dir of the stored model. + """ + local_storage_fn = self.config.hpo_config.storage_path + + # return the local storage location + return local_storage_fn + + def _push_storage(self, trial): + """_summary_ + + Args: + trial (_type_): _description_ + + Raises: + NotImplementedError: _description_ + """ + # save hpo storage to remote if it's in remote + if self.config.hpo_config.storage_protocol == 'oss': + raise NotImplementedError + + return diff --git a/easy_tpp/model/__init__.py b/easy_tpp/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..915718a0655a3b5919aad213f1e51e1e8f8739c0 --- /dev/null +++ b/easy_tpp/model/__init__.py @@ -0,0 +1,25 @@ +from easy_tpp.model.torch_model.torch_anhn import ANHN as TorchANHN +from easy_tpp.model.torch_model.torch_attnhp import AttNHP as TorchAttNHP +from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel +from easy_tpp.model.torch_model.torch_fullynn import FullyNN as TorchFullyNN +from easy_tpp.model.torch_model.torch_intensity_free import IntensityFree as TorchIntensityFree +from easy_tpp.model.torch_model.torch_nhp import NHP as TorchNHP +from easy_tpp.model.torch_model.torch_ode_tpp import ODETPP as TorchODETPP +from easy_tpp.model.torch_model.torch_rmtpp import RMTPP as TorchRMTPP +from easy_tpp.model.torch_model.torch_s2p2 import S2P2 as TorchS2P2 +from easy_tpp.model.torch_model.torch_sahp import SAHP as TorchSAHP +from easy_tpp.model.torch_model.torch_thp import THP as TorchTHP +from easy_tpp.model.torch_model.torch_robot_thp import RobotTHP as TorchRobotTHP + +__all__ = ['TorchBaseModel', + 'TorchNHP', + 'TorchAttNHP', + 'TorchTHP', + 'TorchSAHP', + 'TorchFullyNN', + 'TorchIntensityFree', + 'TorchODETPP', + 'TorchRMTPP', + 'TorchANHN', + 'TorchS2P2', + 'TorchRobotTHP'] diff --git a/easy_tpp/model/torch_model/MODEL_FEATURES_SUPPORT.md b/easy_tpp/model/torch_model/MODEL_FEATURES_SUPPORT.md new file mode 100644 index 0000000000000000000000000000000000000000..55dee4b987d42074084bb5ea5bfa14f52f2bc531 --- /dev/null +++ b/easy_tpp/model/torch_model/MODEL_FEATURES_SUPPORT.md @@ -0,0 +1,206 @@ +# EasyTPP模型特征支持情况 + +## 概述 + +本文档说明EasyTPP框架中各个模型对评论罗伯特自定义特征的支持情况。 + +## 自定义特征列表 + +1. **语义向量(BERT)**: `semantic_vectors` - [batch_size, seq_len, semantic_dim] +2. **偏差特征**: `deviation_features` - [batch_size, seq_len, 3] (语境偏差、情感偏差、困惑度) +3. **自发/被@特征**: `is_spontaneous` - [batch_size, seq_len] (-1=不适用, 0=被@, 1=自发) +4. **结构感知掩码**: `structure_mask` - [batch_size, seq_len, seq_len] + +## 模型支持情况 + +| 模型 | 语义向量 | 偏差特征 | 自发/被@ | 结构掩码 | 说明 | +|------|---------|---------|---------|---------|------| +| **RobotTHP** | ✅ | ✅ | ✅ | ✅ | 专门设计支持所有特征 | +| THP | ❌ | ❌ | ❌ | ❌ | 仅支持标准特征 | +| NHP | ❌ | ❌ | ❌ | ❌ | 仅支持标准特征 | +| SAHP | ❌ | ❌ | ❌ | ❌ | 仅支持标准特征 | +| AttNHP | ❌ | ❌ | ❌ | ❌ | 仅支持标准特征 | +| RMTPP | ❌ | ❌ | ❌ | ❌ | 仅支持标准特征 | +| FullyNN | ❌ | ❌ | ❌ | ❌ | 仅支持标准特征 | +| ODETPP | ❌ | ❌ | ❌ | ❌ | 仅支持标准特征 | +| ANHN | ❌ | ❌ | ❌ | ❌ | 仅支持标准特征 | +| S2P2 | ❌ | ❌ | ❌ | ❌ | 仅支持标准特征 | + +## 标准特征(所有模型都支持) + +所有模型都支持以下标准特征: +- `time_seqs`: [batch_size, seq_len] - 时间序列 +- `time_delta_seqs`: [batch_size, seq_len] - 时间间隔序列 +- `type_seqs`: [batch_size, seq_len] - 事件类型序列 +- `batch_non_pad_mask`: [batch_size, seq_len] - 非padding掩码 +- `attention_mask`: [batch_size, seq_len, seq_len] - 注意力掩码 + +## 代码对比 + +### RobotTHP(支持自定义特征) + +```python +def forward(self, time_seqs, type_seqs, attention_mask, + semantic_vectors=None, deviation_features=None, + is_spontaneous=None, structure_mask=None): + # 支持所有自定义特征 + if self.use_semantic and semantic_vectors is not None: + semantic_emb = self.semantic_projection(semantic_vectors) + enc_output = enc_output + semantic_emb + # ... +``` + +### THP(不支持自定义特征) + +```python +def forward(self, time_seqs, type_seqs, attention_mask): + # 只接收标准参数 + tem_enc = self.layer_temporal_encoding(time_seqs) + enc_output = self.layer_type_emb(type_seqs) + # ... +``` + +### NHP(不支持自定义特征) + +```python +def forward(self, batch): + t_BN, dt_BN, marks_BN, _, _ = batch + # 只使用标准batch元素 + # ... +``` + +## 如何让其他模型支持自定义特征 + +如果需要让其他模型(如THP、NHP等)也支持这些特征,需要: + +### 1. 修改forward方法 + +添加自定义特征参数: + +```python +def forward(self, time_seqs, type_seqs, attention_mask, + semantic_vectors=None, deviation_features=None, + is_spontaneous=None, structure_mask=None): + # 原有代码 + enc_output = self.layer_type_emb(type_seqs) + + # 添加语义特征处理 + if semantic_vectors is not None: + semantic_projection = nn.Linear(semantic_dim, hidden_size) + semantic_emb = semantic_projection(semantic_vectors) + enc_output = enc_output + semantic_emb + + # 添加偏差特征处理 + if deviation_features is not None: + deviation_embedding = nn.Linear(3, hidden_size) + deviation_emb = deviation_embedding(deviation_features) + enc_output = enc_output + deviation_emb + + # ... +``` + +### 2. 修改loglike_loss方法 + +从batch中提取自定义特征: + +```python +def loglike_loss(self, batch): + time_seqs = batch[0] + time_delta_seqs = batch[1] + type_seqs = batch[2] + batch_non_pad_mask = batch[3] + attention_mask = batch[4] + + # 提取自定义特征(如果提供) + semantic_vectors = batch[5] if len(batch) > 5 else None + deviation_features = batch[6] if len(batch) > 6 else None + is_spontaneous = batch[7] if len(batch) > 7 else None + + # 调用forward时传入自定义特征 + enc_out = self.forward( + time_seqs[:, :-1], + type_seqs[:, :-1], + attention_mask[:, :-1, :-1], + semantic_vectors[:, :-1] if semantic_vectors is not None else None, + deviation_features[:, :-1] if deviation_features is not None else None, + is_spontaneous[:, :-1] if is_spontaneous is not None else None + ) + # ... +``` + +### 3. 添加必要的层 + +在`__init__`中添加特征处理层: + +```python +def __init__(self, model_config): + super().__init__(model_config) + + # 添加语义特征投影层 + if getattr(model_config, 'use_semantic', False): + self.semantic_projection = nn.Linear( + model_config.semantic_dim, + model_config.hidden_size + ) + + # 添加偏差特征嵌入层 + if getattr(model_config, 'use_deviation', False): + self.deviation_embedding = nn.Linear(3, model_config.hidden_size) + + # 添加自发/被@特征嵌入层 + self.spontaneous_embedding = nn.Embedding(3, model_config.hidden_size) +``` + +## 推荐方案 + +### 方案1:使用RobotTHP(推荐) + +如果需要在EasyTPP框架中使用自定义特征,**推荐直接使用RobotTHP**,因为: +- ✅ 已经完整支持所有自定义特征 +- ✅ 符合EasyTPP标准接口 +- ✅ 可以直接与其他模型比较 +- ✅ 使用方便,无需修改代码 + +### 方案2:扩展其他模型 + +如果需要让其他模型也支持这些特征: +1. 参考`RobotTHP`的实现 +2. 按照上述步骤修改目标模型 +3. 确保与EasyTPP标准接口兼容 + +## 示例:使用RobotTHP + +```python +from easy_tpp.model import TorchRobotTHP +from easy_tpp.preprocess.robert_dataset import RobertTPPDataset +from easy_tpp.preprocess.robert_tokenizer import RobertEventTokenizer + +# 准备数据(包含自定义特征) +data_dict = { + 'time_seqs': [...], + 'type_seqs': [...], + 'time_delta_seqs': [...], + 'semantic_vectors': [...], # 语义向量 + 'deviation_features': [...], # 偏差特征 + 'is_spontaneous': [...] # 自发/被@标记 +} + +# 创建数据集和分词器 +dataset = RobertTPPDataset(data_dict) +tokenizer = RobertEventTokenizer(config, use_semantic=True, use_deviation=True) + +# 创建模型(自动支持所有特征) +model = TorchRobotTHP(model_config) + +# 训练 +for batch in data_loader: + loss, num_events = model.loglike_loss(batch.values()) +``` + +## 总结 + +- **只有RobotTHP支持所有自定义特征** +- 其他模型仅支持标准特征(时间、类型、掩码) +- 如果需要使用自定义特征,建议使用RobotTHP +- 如果需要扩展其他模型,可以参考RobotTHP的实现方式 + diff --git a/easy_tpp/model/torch_model/ROBOT_THP_README.md b/easy_tpp/model/torch_model/ROBOT_THP_README.md new file mode 100644 index 0000000000000000000000000000000000000000..2852592b23f61cfbbb5bbe2ae31d8cd37e987e51 --- /dev/null +++ b/easy_tpp/model/torch_model/ROBOT_THP_README.md @@ -0,0 +1,149 @@ +# Robot-THP 模型说明 + +## 概述 + +`RobotTHP` 是专门为"评论罗伯特"场景设计的 Transformer Hawkes Process 模型,结合了: + +1. **EasyTPP THP的优点**: + - 可学习的 ScaledSoftplus(事件类型特定的beta参数) + - MC采样近似积分(精确计算survival term) + +2. **语义增强型THP的特点**: + - MLP强度函数(灵活表达) + - 多模态特征融合(语义、偏差、自发/被@) + - 结构感知注意力机制 + +## 核心特性 + +### 1. MLP强度函数 +```python +intensity_states = MLP(hidden_states) + factor_intensity_base +lambda = ScaledSoftplus(intensity_states) +``` + +### 2. ScaledSoftplus +- 每个事件类型有独立的可学习beta参数 +- 提供事件类型特定的激活曲线 +- 更精细的强度控制 + +### 3. MC采样积分 +- 使用 `make_dtime_loss_samples()` 采样时间点 +- 使用 `compute_states_at_sample_times()` 计算采样点强度 +- 使用 `compute_loglikelihood()` 计算精确的对数似然 + +### 4. 多模态特征支持 +- 语义向量(BERT嵌入) +- 偏差特征(语境偏差、情感偏差、困惑度) +- 自发/被@特征 +- 结构感知掩码 + +## 使用方法 + +### 基本使用 + +```python +from easy_tpp.model.torch_model.torch_robot_thp import RobotTHP + +# 模型配置 +model_config = ModelConfig( + hidden_size=128, + num_event_types=4, + num_event_types_pad=5, + num_layers=3, + num_heads=6, + dropout_rate=0.1, + semantic_dim=768, + use_semantic=True, + use_deviation=True, + use_structure_mask=True, + loss_integral_num_sample_per_step=20, + use_mc_samples=True +) + +# 创建模型 +model = RobotTHP(model_config) +``` + +### 训练 + +```python +# 批次数据格式(tuple/list) +batch = ( + time_seqs, # [batch_size, seq_len] + time_delta_seqs, # [batch_size, seq_len] + type_seqs, # [batch_size, seq_len] + batch_non_pad_mask, # [batch_size, seq_len] + attention_mask, # [batch_size, seq_len, seq_len] + semantic_vectors, # [batch_size, seq_len, semantic_dim] (可选) + deviation_features, # [batch_size, seq_len, 3] (可选) + is_spontaneous, # [batch_size, seq_len] (可选) + structure_mask # [batch_size, seq_len, seq_len] (可选) +) + +# 计算损失 +loss, num_events = model.loglike_loss(batch) +``` + +### 前向传播 + +```python +# 前向传播 +hidden_states = model.forward( + time_seqs=time_seqs, + type_seqs=type_seqs, + attention_mask=attention_mask, + semantic_vectors=semantic_vectors, # 可选 + deviation_features=deviation_features, # 可选 + is_spontaneous=is_spontaneous, # 可选 + structure_mask=structure_mask # 可选 +) +``` + +## 与EasyTPP THP的区别 + +| 特性 | EasyTPP THP | Robot-THP | +|------|-------------|-----------| +| 强度函数 | 线性组合 + ScaledSoftplus | MLP + ScaledSoftplus | +| 时间衰减 | 显式 `α*Δt` | 隐式(在Transformer中) | +| 特征输入 | 仅类型+时间 | 类型+时间+语义+偏差+自发/被@ | +| 积分计算 | MC采样 | MC采样(相同) | +| 适用场景 | 通用TPP | 社交网络级联 | + +## 配置参数 + +### 必需参数 +- `hidden_size`: 隐藏层维度 +- `num_event_types`: 事件类型数(不含padding) +- `num_event_types_pad`: 事件类型数(含padding) +- `num_layers`: Transformer层数 +- `num_heads`: 注意力头数 +- `dropout_rate`: Dropout率 + +### 可选参数 +- `semantic_dim`: 语义向量维度(默认768) +- `use_semantic`: 是否使用语义特征(默认False) +- `use_deviation`: 是否使用偏差特征(默认False) +- `use_structure_mask`: 是否使用结构感知掩码(默认False) +- `loss_integral_num_sample_per_step`: MC采样点数(默认20) +- `use_mc_samples`: 是否使用MC采样(默认True) + +## 优势 + +1. **理论完备性**:使用ScaledSoftplus和MC采样,符合TPP理论 +2. **表达能力**:MLP强度函数可以学习复杂模式 +3. **领域特定**:支持语义、偏差等社交网络特征 +4. **精确计算**:MC采样提供更准确的积分近似 + +## 注意事项 + +1. **数据格式**:支持tuple/list和dict两种输入格式 +2. **可选特征**:语义、偏差等特征是可选的,如果不提供会自动跳过 +3. **掩码处理**:会自动生成默认的因果掩码(如果未提供) +4. **设备管理**:自动使用配置中的device + +## 示例 + +完整的使用示例请参考: +- `easy_tpp/examples/` 目录下的示例代码 +- 或参考 `torch_thp.py` 的使用方式 + diff --git a/easy_tpp/model/torch_model/ROBOT_THP_USAGE.md b/easy_tpp/model/torch_model/ROBOT_THP_USAGE.md new file mode 100644 index 0000000000000000000000000000000000000000..debfd8e33f56877c9ba169e11b6353647de60f2f --- /dev/null +++ b/easy_tpp/model/torch_model/ROBOT_THP_USAGE.md @@ -0,0 +1,168 @@ +# RobotTHP 使用指南 + +## 概述 + +`RobotTHP` 已完全符合 EasyTPP 标准接口,可以直接在 EasyTPP 框架中使用并与其他模型进行比较。 + +## 接口规范 + +### 1. forward 方法 +```python +def forward(self, time_seqs, type_seqs, attention_mask, + semantic_vectors=None, deviation_features=None, + is_spontaneous=None, structure_mask=None): + """ + 前向传播 + + Args: + time_seqs: [batch_size, seq_len], 时间序列 + type_seqs: [batch_size, seq_len], 事件类型序列 + attention_mask: [batch_size, seq_len, seq_len], 注意力掩码 + semantic_vectors: [batch_size, seq_len, semantic_dim] (可选) + deviation_features: [batch_size, seq_len, 3] (可选) + is_spontaneous: [batch_size, seq_len] (可选) + structure_mask: [batch_size, seq_len, seq_len] (可选) + + Returns: + tensor: [batch_size, seq_len, hidden_size], 隐藏状态 + """ +``` + +### 2. loglike_loss 方法 +```python +def loglike_loss(self, batch): + """ + 计算对数似然损失(符合EasyTPP标准) + + Args: + batch (tuple, list): EasyTPP标准批次格式 + batch[0]: time_seqs [batch_size, seq_len] + batch[1]: time_delta_seqs [batch_size, seq_len] + batch[2]: type_seqs [batch_size, seq_len] + batch[3]: batch_non_pad_mask [batch_size, seq_len] + batch[4]: attention_mask [batch_size, seq_len, seq_len] + batch[5]: semantic_vectors (可选) + batch[6]: deviation_features (可选) + batch[7]: is_spontaneous (可选) + batch[8]: structure_mask (可选) + + Returns: + tuple: (loss, num_events) - 符合EasyTPP标准 + """ +``` + +### 3. compute_intensities_at_sample_times 方法 +```python +def compute_intensities_at_sample_times(self, + time_seqs, + time_delta_seqs, + type_seqs, + sample_dtimes, + **kwargs): + """ + 计算采样时间点的强度值(符合EasyTPP标准) + + Args: + time_seqs: [batch_size, seq_len] + time_delta_seqs: [batch_size, seq_len] + type_seqs: [batch_size, seq_len] + sample_dtimes: [batch_size, seq_len, num_samples] + **kwargs: 可选参数 + + Returns: + tensor: [batch_size, seq_len, num_samples, num_event_types] + """ +``` + +## 在EasyTPP中使用 + +### 1. 配置文件 + +创建配置文件 `configs/robot_thp.yaml`: + +```yaml +RobotTHP_train: + base_config: + stage: train + backend: torch + dataset_id: your_dataset + runner_id: std_tpp + model_id: RobotTHP # 模型ID + model_class_path: easy_tpp.model.torch_model.torch_robot_thp.RobotTHP + base_dir: './results/checkpoints/robot_thp/' + + trainer_config: + batch_size: 32 + max_epoch: 30 + optimizer: adam + learning_rate: 1e-4 + gpu: 0 + + model_config: + hidden_size: 128 + num_layers: 3 + num_heads: 6 + dropout_rate: 0.1 + num_event_types: 4 + num_event_types_pad: 5 + semantic_dim: 768 + use_semantic: true + use_deviation: true + use_structure_mask: true + loss_integral_num_sample_per_step: 20 + use_mc_samples: true +``` + +### 2. 运行训练 + +```bash +python -m easy_tpp.main --config_path configs/robot_thp.yaml +``` + +### 3. 与其他模型比较 + +RobotTHP 会自动与其他模型(THP, NHP, SAHP等)一起参与比较,因为: +- ✅ 继承自 `TorchBaseModel` +- ✅ 实现了标准接口 +- ✅ 已注册到模型列表 + +## 核心特性 + +1. **MLP强度函数**:灵活表达复杂模式 +2. **ScaledSoftplus**:事件类型特定的可学习beta参数 +3. **MC采样积分**:精确计算survival term +4. **多模态特征**:支持语义、偏差、自发/被@特征 + +## 与标准THP的区别 + +| 特性 | THP | RobotTHP | +|------|-----|----------| +| 强度函数 | 线性组合 | MLP | +| ScaledSoftplus | ✅ | ✅ | +| MC采样 | ✅ | ✅ | +| 语义特征 | ❌ | ✅ | +| 偏差特征 | ❌ | ✅ | +| 自发/被@ | ❌ | ✅ | + +## 注意事项 + +1. **可选特征**:如果不提供语义、偏差等特征,模型会自动跳过 +2. **批次格式**:必须使用EasyTPP的标准批次格式(tuple/list) +3. **模型注册**:已在 `easy_tpp/model/__init__.py` 中注册为 `TorchRobotTHP` + +## 示例 + +```python +from easy_tpp.model import TorchRobotTHP +from easy_tpp.config_factory import Config + +# 加载配置 +config = Config(config_path='configs/robot_thp.yaml') + +# 创建模型 +model = TorchRobotTHP(config.model_config) + +# 训练(EasyTPP会自动处理批次格式) +# 使用 EasyTPP 的训练流程即可 +``` + diff --git a/easy_tpp/model/torch_model/__init__.py b/easy_tpp/model/torch_model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/easy_tpp/model/torch_model/torch_anhn.py b/easy_tpp/model/torch_model/torch_anhn.py new file mode 100644 index 0000000000000000000000000000000000000000..f487439796c07770fa6c1ee2f1eca300e838d471 --- /dev/null +++ b/easy_tpp/model/torch_model/torch_anhn.py @@ -0,0 +1,273 @@ +import torch +from torch import nn + +from easy_tpp.model.torch_model.torch_baselayer import MultiHeadAttention +from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel + + +class ANHN(TorchBaseModel): + """Torch implementation of Attentive Neural Hawkes Network, IJCNN 2021. + http://arxiv.org/abs/2211.11758 + """ + + def __init__(self, model_config): + """Initialize the model + + Args: + model_config (EasyTPP.ModelConfig): config of model specs. + """ + super(ANHN, self).__init__(model_config) + + self.d_time = model_config['time_emb_size'] + self.use_norm = model_config['use_ln'] + + self.n_layers = model_config['num_layers'] + self.n_head = model_config['num_heads'] + self.dropout = model_config['dropout'] + + self.layer_rnn = nn.LSTM(input_size=self.hidden_size, hidden_size=self.hidden_size, batch_first=True) + + self.lambda_w = torch.empty([self.hidden_size, self.num_event_types]) + self.lambda_b = torch.empty([self.num_event_types, 1]) + nn.init.xavier_normal_(self.lambda_w) + nn.init.xavier_normal_(self.lambda_b) + + self.layer_time_delta = nn.Sequential(nn.Linear(2 * self.hidden_size, self.hidden_size), nn.Softplus()) + + self.layer_base_intensity = nn.Sequential(nn.Linear(self.hidden_size, self.hidden_size), nn.Sigmoid()) + + self.layer_att = MultiHeadAttention(self.n_head, + self.hidden_size, + self.hidden_size, + self.dropout) + + self.layer_intensity = nn.Sequential(nn.Linear(self.hidden_size, self.num_event_types), nn.Softplus()) + + self.layer_temporal_emb = nn.Linear(1, self.hidden_size) + + def forward(self, dtime_seqs, type_seqs, attention_mask): + """Call the model. + + Args: + dtime_seqs (tensor): [batch_size, seq_len]. + type_seqs (tensor): [batch_size, seq_len]. + attention_mask (tensor): [batch_size, seq_len, hidden_size]. + + Returns: + list: hidden states, [batch_size, seq_len, hidden_size], states right before the event happens; + stacked decay states, [batch_size, max_seq_length, 4, hidden_dim], states right after + the event happens. + """ + + # [batch_size, seq_len, hidden_size] + event_emb = self.layer_type_emb(type_seqs) + + # [batch_size, seq_len, hidden_size] + rnn_output, _ = self.layer_rnn(event_emb) + + # [batch_size, seq_len, hidden_size] + # mu in Equation (3) + intensity_base = self.layer_base_intensity(rnn_output) + + # [batch_size, num_head, seq_len, seq_len] + _, att_weight = self.layer_att(rnn_output, + rnn_output, + rnn_output, + mask=attention_mask, + output_weight=True) + + # [batch_size, seq_len, seq_len, 1] + att_weight = torch.sum(att_weight, dim=1)[..., None] + + # At each step, alpha and delta reply on all previous event embeddings because there is a cumsum in Equation + # (3), therefore the alpha and beta have shape [batch_size, seq_len, seq_len, hidden_size] when performing + # matrix operations. + # [batch_size, seq_len, seq_len, hidden_dim] + # alpha in Equation (3) + intensity_alpha = att_weight * rnn_output[:, None, :, :] + + # compute delta + max_len = event_emb.size()[1] + + # [batch_size, seq_len, seq_len, hidden_dim] + left = rnn_output[:, None, :, :].repeat(1, max_len, 1, 1) + right = rnn_output[:, :, None, :].repeat(1, 1, max_len, 1) + # [batch_size, seq_len, seq_len, hidden_dim * 2] + cur_prev_concat = torch.concat([left, right], dim=-1) + # [batch_size, seq_len, seq_len, hidden_dim] + intensity_delta = self.layer_time_delta(cur_prev_concat) + + # compute time elapse + # [batch_size, seq_len, seq_len, 1] + base_dtime, target_cumsum_dtime = self.compute_cumsum_dtime(dtime_seqs) + + # [batch_size, max_len, hidden_size] + imply_lambdas = self.compute_states_at_event_times(intensity_base, + intensity_alpha, + intensity_delta, + target_cumsum_dtime) + + return imply_lambdas, (intensity_base, intensity_alpha, intensity_delta), (base_dtime, target_cumsum_dtime) + + def loglike_loss(self, batch): + """Compute the loglike loss. + + Args: + batch (list): batch input. + + Returns: + tuple: loglikelihood loss and num of events. + """ + time_seqs, time_delta_seqs, type_seqs, batch_non_pad_mask, attention_mask, type_mask = batch + + imply_lambdas, (intensity_base, intensity_alpha, intensity_delta), (base_dtime, target_cumsum_dtime) \ + = self.forward(time_delta_seqs[:, 1:], + type_seqs[:, :-1], + attention_mask[:, 1:, :-1]) + lambda_at_event = self.layer_intensity(imply_lambdas) + + # Num of samples in each batch and num of event time point in the sequence + batch_size, seq_len, _ = lambda_at_event.size() + + # Compute the big lambda integral in equation (8) + # 1 - take num_mc_sample rand points in each event interval + # 2 - compute its lambda value for every sample point + # 3 - take average of these sample points + # 4 - times the interval length + + # interval_t_sample - [batch_size, num_times=max_len-1, num_mc_sample] + # for every batch and every event point => do a sampling (num_mc_sampling) + # the first dtime is zero, so we use time_delta_seqs[:, 1:] + interval_t_sample = self.make_dtime_loss_samples(time_delta_seqs[:, 1:]) + + state_t_sample = self.compute_states_at_sample_times(intensity_base, intensity_alpha, intensity_delta, + base_dtime, interval_t_sample) + lambda_t_sample = self.layer_intensity(state_t_sample) + + event_ll, non_event_ll, num_events = self.compute_loglikelihood(lambda_at_event=lambda_at_event, + lambdas_loss_samples=lambda_t_sample, + time_delta_seq=time_delta_seqs[:, 1:], + seq_mask=batch_non_pad_mask[:, 1:], + lambda_type_mask=type_mask[:, 1:]) + + # (num_samples, num_times) + loss = - (event_ll - non_event_ll).sum() + return loss, num_events + + def compute_cumsum_dtime(self, dtime_seqs): + """Compute cumulative delta times. + + Args: + dtime_seqs (tensor): [batch_size, seq_len]. + + Returns: + tensor: [batch_size, seq_len]. + """ + # try to replicate tf.cumsum() + # [batch_size, seq_len, num_sample] + # [0, dt_1, dt_2] => [dt_1 + dt_2, dt_2, 0] + cum_dtimes = torch.cumsum(torch.flip(dtime_seqs, dims=[-1]), dim=1) + cum_dtimes = torch.concat([torch.zeros_like(cum_dtimes[:, :1]), cum_dtimes[:, 1:]], dim=1) + + # [batch_size, seq_len, seq_len, 1] (lower triangular: positive, upper: negative, diagonal: zero) + base_elapses = torch.unsqueeze(cum_dtimes[:, None, :] - cum_dtimes[:, :, None], dim=-1) + + # [batch_size, seq_len, seq_len, 1] + target_cumsum = base_elapses + dtime_seqs[:, :, None, None] + + return base_elapses, target_cumsum + + def compute_states_at_event_times(self, intensity_base, intensity_alpha, intensity_delta, cumsum_dtimes): + """Compute implied lambda based on Equation (3). + + Args: + intensity_base (tensor): [batch_size, seq_len, (num_sample), hidden_size] + intensity_alpha (tensor): [batch_size, seq_len, seq_len, (num_sample), hidden_size] + intensity_delta (tensor): [batch_size, seq_len, seq_len, (num_sample), hidden_size] + cumsum_dtimes: [batch_size, seq_len, (num_sample), 1] + + Returns: + hidden states at all cumsum_dtimes: [batch_size, seq_len, num_samples, hidden_size] + + """ + # to avoid nan calculated by exp after (nan * 0 = nan) + elapse = torch.abs(cumsum_dtimes) + + # [batch_size, seq_len, hidden_dim] + cumsum_term = torch.sum(intensity_alpha * torch.exp(-intensity_delta * elapse), dim=-2) + # [batch_size, seq_len, hidden_dim] + imply_lambdas = intensity_base + cumsum_term + + return imply_lambdas + + def compute_states_at_sample_times(self, intensity_base, intensity_alpha, intensity_delta, base_dtime, + sample_dtimes): + """Compute the hidden states at sampled times. + + Args: + intensity_base (tensor): [batch_size, seq_len, hidden_size]. + intensity_alpha (tensor): [batch_size, seq_len, seq_len, hidden_size]. + intensity_delta (tensor): [batch_size, seq_len, seq_len, hidden_size]. + base_dtime (tensor): [batch_size, seq_len, seq_len, hidden_size]. + sample_dtimes (tensor): [batch_size, seq_len, num_samples]. + + Returns: + tensor: hidden state at each sampled time, [batch_size, seq_len, num_sample, hidden_size]. + """ + + # [batch_size, seq_len, 1, hidden_size] + mu = intensity_base[:, :, None] + + # [batch_size, seq_len, 1, seq_len, hidden_size] + alpha = intensity_alpha[:, :, None] + delta = intensity_delta[:, :, None] + base_elapses = base_dtime[:, :, None] + + # [batch_size, seq_len, num_samples, 1, 1] + sample_dtimes_ = sample_dtimes[:, :, :, None, None] + + states_samples = [] + seq_len = intensity_base.size()[1] + for _ in range(seq_len): + states_samples_ = self.compute_states_at_event_times(mu, alpha, delta, base_elapses + sample_dtimes_) + states_samples.append(states_samples_) + + # [batch_size, seq_len, num_sample, hidden_size] + states_samples = torch.stack(states_samples, dim=1) + return states_samples + + def compute_intensities_at_sample_times(self, time_seqs, time_delta_seqs, type_seqs, sample_dtimes, **kwargs): + """Compute the intensity at sampled times. + + Args: + time_seqs (tensor): [batch_size, seq_len], sequences of timestamps. + time_delta_seqs (tensor): [batch_size, seq_len], sequences of delta times. + type_seqs (tensor): [batch_size, seq_len], sequences of event types. + sampled_dtimes (tensor): [batch_size, seq_len, num_sample], sampled time delta sequence. + + Returns: + tensor: intensities as sampled_dtimes, [batch_size, seq_len, num_samples, event_num]. + """ + + attention_mask = kwargs.get('attention_mask', None) + compute_last_step_only = kwargs.get('compute_last_step_only', False) + + if attention_mask is None: + batch_size, seq_len = time_seqs.size() + attention_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).unsqueeze(0) + attention_mask = attention_mask.expand(batch_size, -1, -1).to(torch.bool) + + # [batch_size, seq_len, num_samples] + imply_lambdas, (intensity_base, intensity_alpha, intensity_delta), (base_dtime, target_cumsum_dtime) \ + = self.forward(time_delta_seqs, type_seqs, attention_mask) + + # [batch_size, seq_len, num_samples, hidden_size] + encoder_output = self.compute_states_at_sample_times(intensity_base, intensity_alpha, intensity_delta, + base_dtime, sample_dtimes) + + if compute_last_step_only: + lambdas = self.softplus(encoder_output[:, -1:, :, :]) + else: + # [batch_size, seq_len, num_samples, num_event_types] + lambdas = self.softplus(encoder_output) + return lambdas diff --git a/easy_tpp/model/torch_model/torch_attnhp.py b/easy_tpp/model/torch_model/torch_attnhp.py new file mode 100644 index 0000000000000000000000000000000000000000..2d7e70bcb89f0ff97a5c46e125d54ba9b9ba1bda --- /dev/null +++ b/easy_tpp/model/torch_model/torch_attnhp.py @@ -0,0 +1,321 @@ +import math + +import torch +from torch import nn + +from easy_tpp.model.torch_model.torch_baselayer import EncoderLayer, MultiHeadAttention, ScaledSoftplus +from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel + + +class AttNHP(TorchBaseModel): + """Torch implementation of Attentive Neural Hawkes Process, ICLR 2022. + https://arxiv.org/abs/2201.00044. + Source code: https://github.com/yangalan123/anhp-andtt/blob/master/anhp/model/xfmr_nhp_fast.py + """ + + def __init__(self, model_config): + """Initialize the model + + Args: + model_config (EasyTPP.ModelConfig): config of model specs. + """ + super(AttNHP, self).__init__(model_config) + self.d_model = model_config.hidden_size + self.use_norm = model_config.use_ln + self.d_time = model_config.time_emb_size + + self.div_term = torch.exp(torch.arange(0, self.d_time, 2) * -(math.log(10000.0) / self.d_time)).reshape(1, 1, + -1) + + self.n_layers = model_config.num_layers + self.n_head = model_config.num_heads + self.dropout = model_config.dropout_rate + + self.heads = [] + for i in range(self.n_head): + self.heads.append( + nn.ModuleList( + [EncoderLayer( + self.d_model + self.d_time, + MultiHeadAttention(1, self.d_model + self.d_time, self.d_model, self.dropout, + output_linear=False), + + use_residual=False, + dropout=self.dropout + ) + for _ in range(self.n_layers) + ] + ) + ) + self.heads = nn.ModuleList(self.heads) + + if self.use_norm: + self.norm = nn.LayerNorm(self.d_model) + self.inten_linear = nn.Linear(self.d_model * self.n_head, self.num_event_types) + self.softplus = ScaledSoftplus(self.num_event_types) # learnable mark-specific beta + self.layer_event_emb = nn.Linear(self.d_model + self.d_time, self.d_model) + self.layer_intensity = nn.Sequential(self.inten_linear, self.softplus) + self.eps = torch.finfo(torch.float32).eps + + def compute_temporal_embedding(self, time): + """Compute the temporal embedding. + + Args: + time (tensor): [batch_size, seq_len]. + + Returns: + tensor: [batch_size, seq_len, emb_size]. + """ + batch_size = time.size(0) + seq_len = time.size(1) + pe = torch.zeros(batch_size, seq_len, self.d_time).to(time) + _time = time.unsqueeze(-1) + div_term = self.div_term.to(time) + pe[..., 0::2] = torch.sin(_time * div_term) + pe[..., 1::2] = torch.cos(_time * div_term) + + return pe + + def forward_pass(self, init_cur_layer, time_emb, sample_time_emb, event_emb, combined_mask): + """update the structure sequentially. + + Args: + init_cur_layer (tensor): [batch_size, seq_len, hidden_size] + time_emb (tensor): [batch_size, seq_len, hidden_size] + sample_time_emb (tensor): [batch_size, seq_len, hidden_size] + event_emb (tensor): [batch_size, seq_len, hidden_size] + combined_mask (tensor): [batch_size, seq_len, hidden_size] + + Returns: + tensor: [batch_size, seq_len, hidden_size*2] + """ + cur_layers = [] + seq_len = event_emb.size(1) + for head_i in range(self.n_head): + # [batch_size, seq_len, hidden_size] + cur_layer_ = init_cur_layer + for layer_i in range(self.n_layers): + # each layer concats the temporal emb + # [batch_size, seq_len, hidden_size*2] + layer_ = torch.cat([cur_layer_, sample_time_emb], dim=-1) + # make combined input from event emb + layer emb + # [batch_size, seq_len*2, hidden_size*2] + _combined_input = torch.cat([event_emb, layer_], dim=1) + enc_layer = self.heads[head_i][layer_i] + # compute the output + enc_output = enc_layer(_combined_input, combined_mask) + + # the layer output + # [batch_size, seq_len, hidden_size] + _cur_layer_ = enc_output[:, seq_len:, :] + # add residual connection + cur_layer_ = torch.tanh(_cur_layer_) + cur_layer_ + + # event emb + event_emb = torch.cat([enc_output[:, :seq_len, :], time_emb], dim=-1) + + if self.use_norm: + cur_layer_ = self.norm(cur_layer_) + cur_layers.append(cur_layer_) + cur_layer_ = torch.cat(cur_layers, dim=-1) + + return cur_layer_ + + def seq_encoding(self, time_seqs, event_seqs): + """Encode the sequence. + + Args: + time_seqs (tensor): time seqs input, [batch_size, seq_len]. + event_seqs (_type_): event type seqs input, [batch_size, seq_len]. + + Returns: + tuple: event embedding, time embedding and type embedding. + """ + # [batch_size, seq_len, hidden_size] + time_emb = self.compute_temporal_embedding(time_seqs) + # [batch_size, seq_len, hidden_size] + type_emb = torch.tanh(self.layer_type_emb(event_seqs.long())) + # [batch_size, seq_len, hidden_size*2] + event_emb = torch.cat([type_emb, time_emb], dim=-1) + + return event_emb, time_emb, type_emb + + def make_layer_mask(self, attention_mask): + """Create a tensor to do masking on layers. + + Args: + attention_mask (tensor): mask for attention operation, [batch_size, seq_len, seq_len] + + Returns: + tensor: aim to keep the current layer, the same size of attention mask + a diagonal matrix, [batch_size, seq_len, seq_len] + """ + # [batch_size, seq_len, seq_len] + layer_mask = (torch.eye(attention_mask.size(1), device=self.device) < 1).unsqueeze(0).expand_as(attention_mask) + return layer_mask + + def make_combined_att_mask(self, attention_mask, layer_mask): + """Combined attention mask and layer mask. + + Args: + attention_mask (tensor): mask for attention operation, [batch_size, seq_len, seq_len] + layer_mask (tensor): mask for other layers, [batch_size, seq_len, seq_len] + + Returns: + tensor: [batch_size, seq_len * 2, seq_len * 2] + """ + # [batch_size, seq_len, seq_len * 2] + combined_mask = torch.cat([attention_mask, layer_mask], dim=-1) + # [batch_size, seq_len, seq_len * 2] + contextual_mask = torch.cat([attention_mask, torch.ones_like(layer_mask)], dim=-1) + # [batch_size, seq_len * 2, seq_len * 2] + combined_mask = torch.cat([contextual_mask, combined_mask], dim=1) + return combined_mask + + def forward(self, time_seqs, event_seqs, attention_mask, sample_times=None): + """Call the model. + + Args: + time_seqs (tensor): [batch_size, seq_len], sequences of timestamps. + event_seqs (tensor): [batch_size, seq_len], sequences of event types. + attention_mask (tensor): [batch_size, seq_len, seq_len], masks for event sequences. + sample_times (tensor, optional): [batch_size, seq_len, num_samples]. Defaults to None. + + Returns: + tensor: states at sampling times, [batch_size, seq_len, num_samples]. + """ + event_emb, time_emb, type_emb = self.seq_encoding(time_seqs, event_seqs) + init_cur_layer = torch.zeros_like(type_emb) + layer_mask = self.make_layer_mask(attention_mask) + if sample_times is None: + sample_time_emb = time_emb + else: + sample_time_emb = self.compute_temporal_embedding(sample_times) + combined_mask = self.make_combined_att_mask(attention_mask, layer_mask) + cur_layer_ = self.forward_pass(init_cur_layer, time_emb, sample_time_emb, event_emb, combined_mask) + + return cur_layer_ + + def loglike_loss(self, batch): + """Compute the loglike loss. + + Args: + batch (list): batch input. + + Returns: + list: loglike loss, num events. + """ + time_seqs, time_delta_seqs, type_seqs, batch_non_pad_mask, attention_mask = batch + # 1. compute event-loglik + # the prediction of last event has no label, so we proceed to the last but one + # att mask => diag is False, not mask. + enc_out = self.forward(time_seqs[:, :-1], type_seqs[:, :-1], attention_mask[:, :-1, :-1], time_seqs[:, 1:]) + # [batch_size, seq_len, num_event_types] + lambda_at_event = self.layer_intensity(enc_out) + + # 2. compute non-event-loglik (using MC sampling to compute integral) + # 2.1 sample times + # [batch_size, seq_len, num_sample] + temp_time = self.make_dtime_loss_samples(time_delta_seqs[:, 1:]) + + # [batch_size, seq_len, num_sample] + sample_times = temp_time + time_seqs[:, :-1].unsqueeze(-1) + + # 2.2 compute intensities at sampled times + # [batch_size, seq_len = max_len - 1, num_sample, event_num] + lambda_t_sample = self.compute_intensities_at_sample_times(time_seqs[:, :-1], + time_delta_seqs[:, :-1], # not used + type_seqs[:, :-1], + sample_times, + attention_mask=attention_mask[:, :-1, :-1]) + + event_ll, non_event_ll, num_events = self.compute_loglikelihood(lambda_at_event=lambda_at_event, + lambdas_loss_samples=lambda_t_sample, + time_delta_seq=time_delta_seqs[:, 1:], + seq_mask=batch_non_pad_mask[:, 1:], + type_seq=type_seqs[:, 1:]) + + # compute loss to minimize + loss = - (event_ll - non_event_ll).sum() + return loss, num_events + + def compute_states_at_sample_times(self, + time_seqs, + type_seqs, + attention_mask, + sample_times): + """Compute the states at sampling times. + + Args: + time_seqs (tensor): [batch_size, seq_len], sequences of timestamps. + time_delta_seqs (tensor): [batch_size, seq_len], sequences of delta times. + type_seqs (tensor): [batch_size, seq_len], sequences of event types. + attention_mask (tensor): [batch_size, seq_len, seq_len], masks for event sequences. + sample_dtimes (tensor): delta times in sampling. + + Returns: + tensor: hiddens states at sampling times. + """ + batch_size = type_seqs.size(0) + seq_len = type_seqs.size(1) + num_samples = sample_times.size(-1) + + # [num_samples, batch_size, seq_len] + sample_times = sample_times.permute((2, 0, 1)) + # [num_samples * batch_size, seq_len] + _sample_time = sample_times.reshape(num_samples * batch_size, -1) + # [num_samples * batch_size, seq_len] + _types = type_seqs.expand(num_samples, -1, -1).reshape(num_samples * batch_size, -1) + # [num_samples * batch_size, seq_len] + _times = time_seqs.expand(num_samples, -1, -1).reshape(num_samples * batch_size, -1) + # [num_samples * batch_size, seq_len] + _attn_mask = attention_mask.unsqueeze(0).expand(num_samples, -1, -1, -1).reshape(num_samples * batch_size, + seq_len, + seq_len) + # [num_samples * batch_size, seq_len, hidden_size] + encoder_output = self.forward(_times, + _types, + _attn_mask, + _sample_time) + + # [num_samples, batch_size, seq_len, hidden_size] + encoder_output = encoder_output.reshape(num_samples, batch_size, seq_len, -1) + # [batch_size, seq_len, num_samples, hidden_size] + encoder_output = encoder_output.permute((1, 2, 0, 3)) + return encoder_output + + def compute_intensities_at_sample_times(self, time_seqs, time_delta_seqs, type_seqs, sample_dtimes, **kwargs): + """Compute the intensity at sampled times. + + Args: + time_seqs (tensor): [batch_size, seq_len], sequences of timestamps. + time_delta_seqs (tensor): [batch_size, seq_len], sequences of delta times. + type_seqs (tensor): [batch_size, seq_len], sequences of event types. + sampled_dtimes (tensor): [batch_size, seq_len, num_sample], sampled time delta sequence. + + Returns: + tensor: intensities as sampled_dtimes, [batch_size, seq_len, num_samples, event_num]. + """ + attention_mask = kwargs.get('attention_mask', None) + compute_last_step_only = kwargs.get('compute_last_step_only', False) + + if attention_mask is None: + batch_size, seq_len = time_seqs.size() + attention_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).unsqueeze(0).to(type_seqs.device) + attention_mask = attention_mask.expand(batch_size, -1, -1).to(torch.bool) + + if sample_dtimes.size()[1] < time_seqs.size()[1]: + # we pass sample_dtimes for last time step here + # we do a temp solution + # [batch_size, seq_len, num_samples] + sample_dtimes = time_seqs[:, :, None] + torch.tile(sample_dtimes, [1, time_seqs.size()[1], 1]) + + # [batch_size, seq_len, num_samples, hidden_size] + encoder_output = self.compute_states_at_sample_times(time_seqs, type_seqs, attention_mask, sample_dtimes) + + if compute_last_step_only: + lambdas = self.layer_intensity(encoder_output[:, -1:, :, :]) + else: + # [batch_size, seq_len, num_samples, num_event_types] + lambdas = self.layer_intensity(encoder_output) + return lambdas diff --git a/easy_tpp/model/torch_model/torch_baselayer.py b/easy_tpp/model/torch_model/torch_baselayer.py new file mode 100644 index 0000000000000000000000000000000000000000..e4ade06a5e4ff4d8dd4d93ad177459d340e6504e --- /dev/null +++ b/easy_tpp/model/torch_model/torch_baselayer.py @@ -0,0 +1,289 @@ +import math + +import torch +from torch import nn + + +def attention(query, key, value, mask=None, dropout=None): + d_k = query.size(-1) + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) + if mask is not None: + # small change here -- we use "1" for masked element + scores = scores.masked_fill(mask > 0, -1e9) + p_attn = torch.softmax(scores, dim=-1) + if dropout is not None: + p_attn = dropout(p_attn) + return torch.matmul(p_attn, value), p_attn + + +class ScaledSoftplus(nn.Module): + ''' + Use different beta for mark-specific intensities + ''' + def __init__(self, num_marks, threshold=20.): + super(ScaledSoftplus, self).__init__() + self.threshold = threshold + self.log_beta = nn.Parameter(torch.zeros(num_marks), requires_grad=True) # [num_marks] + + def forward(self, x): + ''' + :param x: [..., num_marks] + ''' + beta = self.log_beta.exp() + beta_x = beta * x + return torch.where( + beta_x <= self.threshold, + torch.log1p(beta_x.clamp(max=math.log(1e5)).exp()) / beta, + x, # if above threshold, then the transform is effectively linear + ) + + +class MultiHeadAttention(nn.Module): + def __init__(self, n_head, d_input, d_model, dropout=0.1, output_linear=False): + super(MultiHeadAttention, self).__init__() + assert d_model % n_head == 0 + self.n_head = n_head + self.d_k = d_model // n_head + self.d_v = self.d_k + self.d_model = d_model + self.output_linear = output_linear + + if output_linear: + self.linears = nn.ModuleList( + [nn.Linear(d_input, d_model) for _ in range(3)] + [nn.Linear(d_model, d_model), ]) + else: + self.linears = nn.ModuleList([nn.Linear(d_input, d_model) for _ in range(3)]) + + self.dropout = nn.Dropout(p=dropout) + + def forward(self, query, key, value, mask, output_weight=False): + if mask is not None: + mask = mask.unsqueeze(1) + nbatches = query.size(0) + + query, key, value = [ + lin_layer(x).view(nbatches, -1, self.n_head, self.d_k).transpose(1, 2) + for lin_layer, x in zip(self.linears, (query, key, value)) + ] + x, attn_weight = attention(query, key, value, mask=mask, dropout=self.dropout) + + x = x.transpose(1, 2).contiguous() \ + .view(nbatches, -1, self.n_head * self.d_k) + + if self.output_linear: + if output_weight: + return self.linears[-1](x), attn_weight + else: + return self.linears[-1](x) + else: + if output_weight: + return x, attn_weight + else: + return x + + +class SublayerConnection(nn.Module): + # used for residual connection + def __init__(self, d_model, dropout): + super(SublayerConnection, self).__init__() + self.norm = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, sublayer): + return x + self.dropout(sublayer(self.norm(x))) + + +class EncoderLayer(nn.Module): + def __init__(self, d_model, self_attn, feed_forward=None, use_residual=False, dropout=0.1): + super(EncoderLayer, self).__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.use_residual = use_residual + if use_residual: + self.sublayer = nn.ModuleList([SublayerConnection(d_model, dropout) for _ in range(2)]) + self.d_model = d_model + + def forward(self, x, mask): + if self.use_residual: + x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) + if self.feed_forward is not None: + return self.sublayer[1](x, self.feed_forward) + else: + return x + else: + x = self.self_attn(x, x, x, mask) + if self.feed_forward is not None: + return self.feed_forward(x) + else: + return x + + +class TimePositionalEncoding(nn.Module): + """Temporal encoding in THP, ICML 2020 + """ + + def __init__(self, d_model, max_len=5000, device='cpu'): + super().__init__() + i = torch.arange(0, d_model, 1, device=device) + div_term = (2 * (i // 2).float() * -(math.log(10000.0) / d_model)).exp() + self.register_buffer('div_term', div_term) + + def forward(self, x): + """Compute time positional encoding defined in Equation (2) in THP model. + + Args: + x (tensor): time_seqs, [batch_size, seq_len] + + Returns: + temporal encoding vector, [batch_size, seq_len, model_dim] + + """ + result = x.unsqueeze(-1) * self.div_term + result[:, :, 0::2] = torch.sin(result[:, :, 0::2]) + result[:, :, 1::2] = torch.cos(result[:, :, 1::2]) + return result + + +class TimeShiftedPositionalEncoding(nn.Module): + """Time shifted positional encoding in SAHP, ICML 2020 + """ + + def __init__(self, d_model, max_len=5000, device='cpu'): + super().__init__() + # [max_len, 1] + position = torch.arange(0, max_len, device=device).float().unsqueeze(1) + # [model_dim //2 ] + div_term = (torch.arange(0, d_model, 2, device=device).float() * -(math.log(10000.0) / d_model)).exp() + + self.layer_time_delta = nn.Linear(1, d_model // 2, bias=False) + + self.register_buffer('position', position) + self.register_buffer('div_term', div_term) + + def forward(self, x, interval): + """ + + Args: + x: time_seq, [batch_size, seq_len] + interval: time_delta_seq, [batch_size, seq_len] + + Returns: + Time shifted positional encoding defined in Equation (8) in SAHP model + + """ + phi = self.layer_time_delta(interval.unsqueeze(-1)) + aa = len(x.size()) + if aa > 1: + length = x.size(1) + else: + length = x.size(0) + + arc = (self.position[:length] * self.div_term).unsqueeze(0) + + pe_sin = torch.sin(arc + phi) + pe_cos = torch.cos(arc + phi) + pe = torch.cat([pe_sin, pe_cos], dim=-1) + + return pe + + +class GELU(nn.Module): + """GeLu activation function + """ + + def forward(self, x): + return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + + +class Identity(nn.Module): + + def forward(self, inputs): + return inputs + + +def activation_layer(act_name): + """Construct activation layers + Args: + act_name: str or nn.Module, name of activation function + Return: + act_layer: activation layer + + """ + if isinstance(act_name, str): + if act_name.lower() == 'sigmoid': + act_layer = nn.Sigmoid() + elif act_name.lower() == 'linear': + act_layer = Identity() + elif act_name.lower() == 'relu': + act_layer = nn.ReLU(inplace=True) + elif act_name.lower() == 'prelu': + act_layer = nn.PReLU() + elif act_name.lower() == 'gelu': + act_layer = GELU() + elif issubclass(act_name, nn.Module): + act_layer = act_name() + else: + raise NotImplementedError + + return act_layer + + +class DNN(nn.Module): + """The Multi Layer Percetron + Input shape + - nD tensor with shape: ``(batch_size, ..., input_dim)``. + The most common situation would be a 2D input with shape ``(batch_size, input_dim)``. + Output shape + - nD tensor with shape: ``(batch_size, ..., hidden_size[-1])``. + For instance, for a 2D input with shape ``(batch_size, input_dim)``, + the output would have shape ``(batch_size, hidden_size[-1])``. + Arguments + - **inputs_dim**: input feature dimension. + - **hidden_size**:list of positive integer, the layer number and units in each layer. + - **activation**: Activation function to use. + - **l2_reg**: float between 0 and 1. L2 regularizer strength applied to the kernel weights matrix. + - **dropout_rate**: float in [0,1). Fraction of the units to dropout. + - **use_bn**: bool. Whether use BatchNormalization before activation or not. + - **seed**: A Python integer to use as random seed. + """ + + def __init__(self, inputs_dim, hidden_size, activation='relu', l2_reg=0, dropout_rate=0, use_bn=False, + init_std=0.0001): + super(DNN, self).__init__() + self.dropout_rate = dropout_rate + self.dropout = nn.Dropout(dropout_rate) + self.l2_reg = l2_reg + self.use_bn = use_bn + if len(hidden_size) == 0: + raise ValueError("hidden_units is empty!!") + hidden_size = [inputs_dim] + list(hidden_size) + + self.linears = nn.ModuleList( + [nn.Linear(hidden_size[i], hidden_size[i + 1]) for i in range(len(hidden_size) - 1)]) + + if self.use_bn: + self.bn = nn.ModuleList( + [nn.BatchNorm1d(hidden_size[i + 1]) for i in range(len(hidden_size) - 1)]) + + self.activation_layers = nn.ModuleList( + [activation_layer(activation) for i in range(len(hidden_size) - 1)]) + + for name, tensor in self.linears.named_parameters(): + if 'weight' in name: + nn.init.normal_(tensor, mean=0, std=init_std) + + def forward(self, inputs): + deep_input = inputs + + for i in range(len(self.linears)): + + fc = self.linears[i](deep_input) + + if self.use_bn: + fc = self.bn[i](fc) + + fc = self.activation_layers[i](fc) + + fc = self.dropout(fc) + deep_input = fc + return deep_input diff --git a/easy_tpp/model/torch_model/torch_basemodel.py b/easy_tpp/model/torch_model/torch_basemodel.py new file mode 100644 index 0000000000000000000000000000000000000000..6b3e42c1a022257df55701e37f5ae81c17bee6ec --- /dev/null +++ b/easy_tpp/model/torch_model/torch_basemodel.py @@ -0,0 +1,270 @@ +""" Base model with common functionality """ + +import torch +from torch import nn +from torch.nn import functional as F + +from easy_tpp.model.torch_model.torch_thinning import EventSampler +from easy_tpp.utils import set_device + + +class TorchBaseModel(nn.Module): + def __init__(self, model_config): + """Initialize the BaseModel + + Args: + model_config (EasyTPP.ModelConfig): model spec of configs + """ + super(TorchBaseModel, self).__init__() + self.loss_integral_num_sample_per_step = model_config.loss_integral_num_sample_per_step + self.hidden_size = model_config.hidden_size + self.num_event_types = model_config.num_event_types # not include [PAD], [BOS], [EOS] + self.num_event_types_pad = model_config.num_event_types_pad # include [PAD], [BOS], [EOS] + self.pad_token_id = model_config.pad_token_id + self.eps = torch.finfo(torch.float32).eps + + self.layer_type_emb = nn.Embedding(self.num_event_types_pad, # have padding + self.hidden_size, + padding_idx=self.pad_token_id) + + self.gen_config = model_config.thinning + self.event_sampler = None + self.device = set_device(model_config.gpu) + self.use_mc_samples = model_config.use_mc_samples + + self.to(self.device) + + if self.gen_config: + self.event_sampler = EventSampler(num_sample=self.gen_config.num_sample, + num_exp=self.gen_config.num_exp, + over_sample_rate=self.gen_config.over_sample_rate, + patience_counter=self.gen_config.patience_counter, + num_samples_boundary=self.gen_config.num_samples_boundary, + dtime_max=self.gen_config.dtime_max, + device=self.device) + + @staticmethod + def generate_model_from_config(model_config): + """Generate the model in derived class based on model config. + + Args: + model_config (EasyTPP.ModelConfig): config of model specs. + """ + model_id = model_config.model_id + + for subclass in TorchBaseModel.__subclasses__(): + if subclass.__name__ == model_id: + return subclass(model_config) + + raise RuntimeError('No model named ' + model_id) + + @staticmethod + def get_logits_at_last_step(logits, batch_non_pad_mask, sample_len=None): + """Retrieve the hidden states of last non-pad events. + + Args: + logits (tensor): [batch_size, seq_len, hidden_dim], a sequence of logits + batch_non_pad_mask (tensor): [batch_size, seq_len], a sequence of masks + sample_len (tensor): default None, use batch_non_pad_mask to find out the last non-mask position + + ref: https://medium.com/analytics-vidhya/understanding-indexing-with-pytorch-gather-33717a84ebc4 + + Returns: + tensor: retrieve the logits of EOS event + """ + + seq_len = batch_non_pad_mask.sum(dim=1) + select_index = seq_len - 1 if sample_len is None else seq_len - 1 - sample_len + # [batch_size, hidden_dim] + select_index = select_index.unsqueeze(1).repeat(1, logits.size(-1)) + # [batch_size, 1, hidden_dim] + select_index = select_index.unsqueeze(1) + # [batch_size, hidden_dim] + last_logits = torch.gather(logits, dim=1, index=select_index).squeeze(1) + return last_logits + + def compute_loglikelihood(self, time_delta_seq, lambda_at_event, lambdas_loss_samples, seq_mask, type_seq): + """Compute the loglikelihood of the event sequence based on Equation (8) of NHP paper. + + Args: + time_delta_seq (tensor): [batch_size, seq_len], time_delta_seq from model input. + lambda_at_event (tensor): [batch_size, seq_len, num_event_types], unmasked intensity at + (right after) the event. + lambdas_loss_samples (tensor): [batch_size, seq_len, num_sample, num_event_types], + intensity at sampling times. + seq_mask (tensor): [batch_size, seq_len], sequence mask vector to mask the padded events. + type_seq (tensor): [batch_size, seq_len], sequence of mark ids, with padded events having a mark of self.pad_token_id + + Returns: + tuple: event loglike, non-event loglike, intensity at event with padding events masked + """ + + # First, add an epsilon to every marked intensity for stability + lambda_at_event = lambda_at_event + self.eps + lambdas_loss_samples = lambdas_loss_samples + self.eps + + log_marked_event_lambdas = lambda_at_event.log() + total_sampled_lambdas = lambdas_loss_samples.sum(dim=-1) + + # Compute event LL - [batch_size, seq_len] + event_ll = -F.nll_loss( + log_marked_event_lambdas.permute(0, 2, 1), # mark dimension needs to come second, not third to match nll_loss specs + target=type_seq, + ignore_index=self.pad_token_id, # Padded events have a pad_token_id as a value + reduction='none', # Does not aggregate, and replaces what would have been the log(marked intensity) with 0. + ) + + # Compute non-event LL [batch_size, seq_len] + # interval_integral = length_interval * average of sampled lambda(t) + if self.use_mc_samples: + non_event_ll = total_sampled_lambdas.mean(dim=-1) * time_delta_seq * seq_mask + else: # Use trapezoid rule + non_event_ll = 0.5 * (total_sampled_lambdas[..., 1:] + total_sampled_lambdas[..., :-1]).mean(dim=-1) * time_delta_seq * seq_mask + + num_events = torch.masked_select(event_ll, event_ll.ne(0.0)).size()[0] + return event_ll, non_event_ll, num_events + + def make_dtime_loss_samples(self, time_delta_seq): + """Generate the time point samples for every interval. + + Args: + time_delta_seq (tensor): [batch_size, seq_len]. + + Returns: + tensor: [batch_size, seq_len, n_samples] + """ + # [1, 1, n_samples] + dtimes_ratio_sampled = torch.linspace(start=0.0, + end=1.0, + steps=self.loss_integral_num_sample_per_step, + device=self.device)[None, None, :] + + # [batch_size, max_len, n_samples] + sampled_dtimes = time_delta_seq[:, :, None] * dtimes_ratio_sampled + + return sampled_dtimes + + def compute_states_at_sample_times(self, **kwargs): + raise NotImplementedError('This need to implemented in inherited class ! ') + + def predict_one_step_at_every_event(self, batch): + """One-step prediction for every event in the sequence. + + Args: + time_seqs (tensor): [batch_size, seq_len]. + time_delta_seqs (tensor): [batch_size, seq_len]. + type_seqs (tensor): [batch_size, seq_len]. + + Returns: + tuple: tensors of dtime and type prediction, [batch_size, seq_len]. + """ + time_seq, time_delta_seq, event_seq, batch_non_pad_mask, _ = batch + + # remove the last event, as the prediction based on the last event has no label + # note: the first dts is 0 + # [batch_size, seq_len] + time_seq, time_delta_seq, event_seq = time_seq[:, :-1], time_delta_seq[:, :-1], event_seq[:, :-1] + + # [batch_size, seq_len] + dtime_boundary = torch.max(time_delta_seq * self.event_sampler.dtime_max, + time_delta_seq + self.event_sampler.dtime_max) + + # [batch_size, seq_len, num_sample] + accepted_dtimes, weights = self.event_sampler.draw_next_time_one_step(time_seq, + time_delta_seq, + event_seq, + dtime_boundary, + self.compute_intensities_at_sample_times, + compute_last_step_only=False) # make it explicit + + # We should condition on each accepted time to sample event mark, but not conditioned on the expected event time. + # 1. Use all accepted_dtimes to get intensity. + # [batch_size, seq_len, num_sample, num_marks] + intensities_at_times = self.compute_intensities_at_sample_times(time_seq, + time_delta_seq, + event_seq, + accepted_dtimes) + + # 2. Normalize the intensity over last dim and then compute the weighted sum over the `num_sample` dimension. + # Each of the last dimension is a categorical distribution over all marks. + # [batch_size, seq_len, num_sample, num_marks] + intensities_normalized = intensities_at_times / intensities_at_times.sum(dim=-1, keepdim=True) + + # 3. Compute weighted sum of distributions and then take argmax. + # [batch_size, seq_len, num_marks] + intensities_weighted = torch.einsum('...s,...sm->...m', weights, intensities_normalized) + + # [batch_size, seq_len] + types_pred = torch.argmax(intensities_weighted, dim=-1) + + # [batch_size, seq_len] + dtimes_pred = torch.sum(accepted_dtimes * weights, dim=-1) # compute the expected next event time + return dtimes_pred, types_pred + + def predict_multi_step_since_last_event(self, batch, forward=False): + """Multi-step prediction since last event in the sequence. + + Args: + batch (tuple): A tuple containing: + - time_seq_label (tensor): Timestamps of events [batch_size, seq_len]. + - time_delta_seq_label (tensor): Time intervals between events [batch_size, seq_len]. + - event_seq_label (tensor): Event types [batch_size, seq_len]. + - batch_non_pad_mask_label (tensor): Mask for non-padding elements [batch_size, seq_len]. + - attention_mask (tensor): Mask for attention [batch_size, seq_len]. + forward (bool, optional): Whether to use the entire sequence for prediction. Defaults to False. + + Returns: + tuple: tensors of dtime and type prediction, [batch_size, seq_len]. + """ + time_seq_label, time_delta_seq_label, event_seq_label, _, _ = batch + + num_step = self.gen_config.num_step_gen + + if not forward: + time_seq = time_seq_label[:, :-num_step] + time_delta_seq = time_delta_seq_label[:, :-num_step] + event_seq = event_seq_label[:, :-num_step] + else: + time_seq, time_delta_seq, event_seq = time_seq_label, time_delta_seq_label, event_seq_label + + for i in range(num_step): + # [batch_size, seq_len] + dtime_boundary = time_delta_seq + self.event_sampler.dtime_max + + # [batch_size, 1, num_sample] + accepted_dtimes, weights = \ + self.event_sampler.draw_next_time_one_step(time_seq, + time_delta_seq, + event_seq, + dtime_boundary, + self.compute_intensities_at_sample_times, + compute_last_step_only=True) + + # [batch_size, 1] + dtimes_pred = torch.sum(accepted_dtimes * weights, dim=-1) + + # [batch_size, seq_len, 1, event_num] + intensities_at_times = self.compute_intensities_at_sample_times(time_seq, + time_delta_seq, + event_seq, + dtimes_pred[:, :, None], + max_steps=event_seq.size()[1]) + + # [batch_size, seq_len, event_num] + intensities_at_times = intensities_at_times.squeeze(dim=-2) + + # [batch_size, seq_len] + types_pred = torch.argmax(intensities_at_times, dim=-1) + + # [batch_size, 1] + types_pred_ = types_pred[:, -1:] + dtimes_pred_ = dtimes_pred[:, -1:] + time_pred_ = time_seq[:, -1:] + dtimes_pred_ + + # concat to the prefix sequence + time_seq = torch.cat([time_seq, time_pred_], dim=-1) + time_delta_seq = torch.cat([time_delta_seq, dtimes_pred_], dim=-1) + event_seq = torch.cat([event_seq, types_pred_], dim=-1) + + return time_delta_seq[:, -num_step - 1:], event_seq[:, -num_step - 1:], \ + time_delta_seq_label[:, -num_step - 1:], event_seq_label[:, -num_step - 1:] diff --git a/easy_tpp/model/torch_model/torch_fullynn.py b/easy_tpp/model/torch_model/torch_fullynn.py new file mode 100644 index 0000000000000000000000000000000000000000..52ae351c3439dcd0ad3fec73ad70252d1d8e9527 --- /dev/null +++ b/easy_tpp/model/torch_model/torch_fullynn.py @@ -0,0 +1,219 @@ +import torch +from torch import nn +from torch.nn import functional as F +from torch.autograd import grad + +from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel + + +class CumulHazardFunctionNetwork(nn.Module): + """Cumulative Hazard Function Network + ref: https://github.com/wassname/torch-neuralpointprocess + """ + + def __init__(self, model_config): + super(CumulHazardFunctionNetwork, self).__init__() + self.hidden_size = model_config.hidden_size + self.num_mlp_layers = model_config.model_specs['num_mlp_layers'] + self.num_event_types = model_config.num_event_types + self.proper_marked_intensities = model_config.model_specs["proper_marked_intensities"] + + # transform inter-event time embedding + self.layer_dense_1 = nn.Linear(in_features=1, out_features=self.hidden_size) + + # concat rnn states and inter-event time embedding + self.layer_dense_2 = nn.Linear(in_features=self.hidden_size * 2, out_features=self.hidden_size) + + # mlp layers + self.module_list = nn.ModuleList( + [nn.Linear(in_features=self.hidden_size, out_features=self.hidden_size) for _ in + range(self.num_mlp_layers - 1)]) + + self.layer_dense_3 = nn.Sequential(nn.Linear(in_features=self.hidden_size, + out_features=self.num_event_types), + nn.Softplus()) + + self.params_eps = torch.finfo(torch.float32).eps # ensure positiveness of parameters + + self.init_weights_positive() + + def init_weights_positive(self): + for p in self.parameters(): + p.data = torch.abs(p.data) + p.data = torch.clamp(p.data, min=self.params_eps) + + def forward(self, hidden_states, time_delta_seqs): + for p in self.parameters(): + p.data = torch.clamp(p.data, min=self.params_eps) + + time_delta_seqs.requires_grad_(True) + + # [batch_size, seq_len, hidden_size] + t = self.layer_dense_1(time_delta_seqs.unsqueeze(dim=-1)) + + # [batch_size, seq_len, hidden_size] + out = torch.tanh(self.layer_dense_2(torch.cat([hidden_states, t], dim=-1))) + for layer in self.module_list: + out = torch.tanh(layer(out)) + + # [batch_size, seq_len, num_event_types] + integral_lambda = self.layer_dense_3(out) + + # [batch_size, seq_len, num_event_types] + if self.proper_marked_intensities: + derivative_integral_lambdas = [] + for i in range(integral_lambda.shape[-1]): # iterate over marks + derivative_integral_lambdas.append(grad( + integral_lambda[..., i].sum(), + time_delta_seqs, + create_graph=True, retain_graph=True)[0]) + derivative_integral_lambda = torch.stack(derivative_integral_lambdas, dim=-1) # TODO: Check that it is okay to iterate over marks like this + else: + derivative_integral_lambda = grad( + integral_lambda.sum(), + time_delta_seqs, + create_graph=True, retain_graph=True)[0] + derivative_integral_lambda = derivative_integral_lambda.unsqueeze(-1).expand(*derivative_integral_lambda.shape, self.num_event_types) / self.num_event_types + + return integral_lambda, derivative_integral_lambda + + +class FullyNN(TorchBaseModel): + """Torch implementation of + Fully Neural Network based Model for General Temporal Point Processes, NeurIPS 2019. + https://arxiv.org/abs/1905.09690 + + ref: https://github.com/KanghoonYoon/torch-neuralpointprocess/blob/master/module.py; + https://github.com/wassname/torch-neuralpointprocess + """ + + def __init__(self, model_config): + """Initialize the model + + Args: + model_config (EasyTPP.ModelConfig): config of model specs. + """ + super(FullyNN, self).__init__(model_config) + + self.rnn_type = model_config.rnn_type + self.rnn_list = [nn.LSTM, nn.RNN, nn.GRU] + self.n_layers = model_config.num_layers + self.dropout_rate = model_config.dropout_rate + for sub_rnn_class in self.rnn_list: + if sub_rnn_class.__name__ == self.rnn_type: + self.layer_rnn = sub_rnn_class(input_size=1 + self.hidden_size, + hidden_size=self.hidden_size, + num_layers=self.n_layers, + batch_first=True, + dropout=self.dropout_rate) + + self.layer_intensity = CumulHazardFunctionNetwork(model_config) + + def forward(self, time_seqs, time_delta_seqs, type_seqs): + """Call the model + + Args: + time_seqs (tensor): [batch_size, seq_len], timestamp seqs. + time_delta_seqs (tensor): [batch_size, seq_len], inter-event time seqs. + type_seqs (tensor): [batch_size, seq_len], event type seqs. + + Returns: + tensor: hidden states at event times. + """ + # [batch_size, seq_len, hidden_size] + type_embedding = self.layer_type_emb(type_seqs) + + # [batch_size, seq_len, hidden_size + 1] + rnn_input = torch.cat((type_embedding, time_delta_seqs.unsqueeze(-1)), dim=-1) + + # [batch_size, seq_len, hidden_size] + # states right after the event + hidden_states, _ = self.layer_rnn(rnn_input) + + return hidden_states + + def loglike_loss(self, batch): + """Compute the loglike loss. + + Args: + batch (tuple, list): batch input. + + Returns: + list: loglike loss, num events. + """ + # [batch_size, seq_len] + time_seqs, time_delta_seqs, type_seqs, batch_non_pad_mask, _ = batch + + # [batch_size, seq_len, hidden_size] + hidden_states = self.forward( + time_seqs[:, :-1], + time_delta_seqs[:, :-1], + type_seqs[:, :-1], + ) + # [batch_size, seq_len, num_event_types] + integral_lambda, derivative_integral_lambda = self.layer_intensity(hidden_states, time_delta_seqs[:, 1:]) + + # First, add an epsilon to every marked intensity for stability + derivative_integral_lambda += self.eps + + # Compute components for each LL term + log_marked_event_lambdas = derivative_integral_lambda.log() + + # Compute event LL - [batch_size, seq_len] + event_ll = -F.nll_loss( + log_marked_event_lambdas.permute(0, 2, 1), # mark dimension needs to come second, not third to match nll_loss specs + target=type_seqs[:, 1:], + ignore_index=self.pad_token_id, # Padded events have a pad_token_id as a value + reduction='none', # Does not aggregate, and replaces what would have been the log(marked intensity) with 0. + ) + + # [batch_size, seq_len] + # multiplied by sequence mask + non_event_ll = integral_lambda.sum(-1) * batch_non_pad_mask[:, 1:] + num_events = torch.masked_select(event_ll, event_ll.ne(0.0)).size()[0] + loss = - (event_ll - non_event_ll).sum() + + return loss, num_events + + def compute_intensities_at_sample_times(self, + time_seqs, + time_delta_seqs, + type_seqs, + sample_dtimes, + **kwargs): + """Compute hidden states at sampled times. + + Args: + time_seqs (tensor): [batch_size, seq_len], times seqs. + time_delta_seqs (tensor): [batch_size, seq_len], time delta seqs. + type_seqs (tensor): [batch_size, seq_len], event type seqs. + sample_dtimes (tensor): [batch_size, seq_len, num_samples], sampled inter-event timestamps. + + Returns: + tensor: [batch_size, seq_len, num_samples, num_event_types], intensity at all sampled times. + """ + + compute_last_step_only = kwargs.get('compute_last_step_only', False) + + # [batch_size, seq_len, hidden_size] + hidden_states = self.forward( + time_seqs=time_seqs, + time_delta_seqs=time_delta_seqs, + type_seqs=type_seqs, + ) + + num_samples = sample_dtimes.size()[-1] + batch_size, seq_len, hidden_size = hidden_states.shape + + hidden_states_ = hidden_states[..., None, :].expand(batch_size, seq_len, num_samples, hidden_size) + _, derivative_integral_lambda = self.layer_intensity.forward( + hidden_states=hidden_states_, + time_delta_seqs=sample_dtimes, + ) + + if compute_last_step_only: + lambdas = derivative_integral_lambda[:, -1:, :, :] + else: + # [batch_size, seq_len, num_samples, num_event_types] + lambdas = derivative_integral_lambda + return lambdas diff --git a/easy_tpp/model/torch_model/torch_intensity_free.py b/easy_tpp/model/torch_model/torch_intensity_free.py new file mode 100644 index 0000000000000000000000000000000000000000..f30c4404e94afc2bf3830f1183b690e8cd5d17c3 --- /dev/null +++ b/easy_tpp/model/torch_model/torch_intensity_free.py @@ -0,0 +1,254 @@ +import torch +import torch.distributions as D +from torch import nn +from torch.distributions import Categorical, TransformedDistribution +from torch.distributions import MixtureSameFamily as TorchMixtureSameFamily +from torch.distributions import Normal as TorchNormal + +from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel + + +def clamp_preserve_gradients(x, min_val, max_val): + """Clamp the tensor while preserving gradients in the clamped region. + + Args: + x (tensor): tensor to be clamped. + min_val (float): minimum value. + max_val (float): maximum value. + """ + return x + (x.clamp(min_val, max_val) - x).detach() + + +class Normal(TorchNormal): + """Normal distribution, redefined `log_cdf` and `log_survival_function` due to + no numerically stable implementation of them is available for normal distribution. + """ + + def log_cdf(self, x): + cdf = clamp_preserve_gradients(self.cdf(x), 1e-7, 1 - 1e-7) + return cdf.log() + + def log_survival_function(self, x): + cdf = clamp_preserve_gradients(self.cdf(x), 1e-7, 1 - 1e-7) + return torch.log(1.0 - cdf) + + +class MixtureSameFamily(TorchMixtureSameFamily): + """Mixture (same-family) distribution, redefined `log_cdf` and `log_survival_function`. + """ + + def log_cdf(self, x): + x = self._pad(x) + log_cdf_x = self.component_distribution.log_cdf(x) + mix_logits = self.mixture_distribution.logits + return torch.logsumexp(log_cdf_x + mix_logits, dim=-1) + + def log_survival_function(self, x): + x = self._pad(x) + log_sf_x = self.component_distribution.log_survival_function(x) + mix_logits = self.mixture_distribution.logits + return torch.logsumexp(log_sf_x + mix_logits, dim=-1) + + +class LogNormalMixtureDistribution(TransformedDistribution): + """ + Mixture of log-normal distributions. + + Args: + locs (tensor): [batch_size, seq_len, num_mix_components]. + log_scales (tensor): [batch_size, seq_len, num_mix_components]. + log_weights (tensor): [batch_size, seq_len, num_mix_components]. + mean_log_inter_time (float): Average log-inter-event-time. + std_log_inter_time (float): Std of log-inter-event-times. + """ + + def __init__(self, locs, log_scales, log_weights, mean_log_inter_time, std_log_inter_time, validate_args=None): + mixture_dist = D.Categorical(logits=log_weights) + component_dist = Normal(loc=locs, scale=log_scales.exp()) + GMM = MixtureSameFamily(mixture_dist, component_dist) + if mean_log_inter_time == 0.0 and std_log_inter_time == 1.0: + transforms = [] + else: + transforms = [D.AffineTransform(loc=mean_log_inter_time, scale=std_log_inter_time)] + self.mean_log_inter_time = mean_log_inter_time + self.std_log_inter_time = std_log_inter_time + transforms.append(D.ExpTransform()) + + self.transforms = transforms + sign = 1 + for transform in self.transforms: + sign = sign * transform.sign + self.sign = int(sign) + super().__init__(GMM, transforms, validate_args=validate_args) + + def log_cdf(self, x): + for transform in self.transforms[::-1]: + x = transform.inv(x) + if self._validate_args: + self.base_dist._validate_sample(x) + + if self.sign == 1: + return self.base_dist.log_cdf(x) + else: + return self.base_dist.log_survival_function(x) + + def log_survival_function(self, x): + for transform in self.transforms[::-1]: + x = transform.inv(x) + if self._validate_args: + self.base_dist._validate_sample(x) + + if self.sign == 1: + return self.base_dist.log_survival_function(x) + else: + return self.base_dist.log_cdf(x) + + +class IntensityFree(TorchBaseModel): + """Torch implementation of Intensity-Free Learning of Temporal Point Processes, ICLR 2020. + https://openreview.net/pdf?id=HygOjhEYDH + + reference: https://github.com/shchur/ifl-tpp + """ + + def __init__(self, model_config): + """Initialize the model + + Args: + model_config (EasyTPP.ModelConfig): config of model specs. + + """ + super(IntensityFree, self).__init__(model_config) + + self.num_mix_components = model_config.model_specs['num_mix_components'] + self.mean_log_inter_time = model_config.get("mean_log_inter_time", 0.0) + self.std_log_inter_time = model_config.get("std_log_inter_time", 1.0) + + self.num_features = 1 + self.hidden_size + + self.layer_rnn = nn.GRU(input_size=self.num_features, + hidden_size=self.hidden_size, + num_layers=1, # used in original paper + batch_first=True) + + self.mark_linear = nn.Linear(self.hidden_size, self.num_event_types_pad) + self.linear = nn.Linear(self.hidden_size, 3 * self.num_mix_components) + + def forward(self, time_delta_seqs, type_seqs): + """Call the model. + + Args: + time_delta_seqs (tensor): [batch_size, seq_len], inter-event time seqs. + type_seqs (tensor): [batch_size, seq_len], event type seqs. + + Returns: + list: hidden states, [batch_size, seq_len, hidden_dim], states right before the event happens. + """ + # [batch_size, seq_len, hidden_size] + # We dont normalize inter-event time here + temporal_seqs = torch.log(time_delta_seqs + self.eps).unsqueeze(-1) + + # [batch_size, seq_len, hidden_size] + type_emb = self.layer_type_emb(type_seqs) + + # [batch_size, seq_len, hidden_size + 1] + rnn_input = torch.cat([temporal_seqs, type_emb], dim=-1) + + # [batch_size, seq_len, hidden_size] + context = self.layer_rnn(rnn_input)[0] + + return context + + def loglike_loss(self, batch): + """Compute the loglike loss. + + Args: + batch (list): batch input. + + Returns: + tuple: loglikelihood loss and num of events. + """ + time_seqs, time_delta_seqs, type_seqs, batch_non_pad_mask, _ = batch + + # [batch_size, seq_len, hidden_size] + context = self.forward(time_delta_seqs[:, :-1], type_seqs[:, :-1]) + + # [batch_size, seq_len, 3 * num_mix_components] + raw_params = self.linear(context) + locs = raw_params[..., :self.num_mix_components] + log_scales = raw_params[..., self.num_mix_components: (2 * self.num_mix_components)] + log_weights = raw_params[..., (2 * self.num_mix_components):] + + log_scales = clamp_preserve_gradients(log_scales, -5.0, 3.0) + log_weights = torch.log_softmax(log_weights, dim=-1) + inter_time_dist = LogNormalMixtureDistribution( + locs=locs, + log_scales=log_scales, + log_weights=log_weights, + mean_log_inter_time=self.mean_log_inter_time, + std_log_inter_time=self.std_log_inter_time + ) + + inter_times = time_delta_seqs[:, 1:].clamp(min=1e-5) + # [batch_size, seq_len] + event_mask = torch.logical_and(batch_non_pad_mask[:, 1:], type_seqs[:, 1:] != self.pad_token_id) + time_ll = inter_time_dist.log_prob(inter_times) * event_mask + + # [batch_size, seq_len, num_marks] + mark_logits = torch.log_softmax(self.mark_linear(context), dim=-1) + mark_dist = Categorical(logits=mark_logits) + mark_ll = mark_dist.log_prob(type_seqs[:, 1:]) * event_mask + + log_p = time_ll + mark_ll + + # [batch_size,] + loss = -log_p.sum() + + num_events = event_mask.sum().item() + + return loss, num_events + + def predict_one_step_at_every_event(self, batch): + """One-step prediction for every event in the sequence. + + Args: + time_seqs (tensor): [batch_size, seq_len]. + time_delta_seqs (tensor): [batch_size, seq_len]. + type_seqs (tensor): [batch_size, seq_len]. + + Returns: + tuple: tensors of dtime and type prediction, [batch_size, seq_len]. + """ + time_seq, time_delta_seq, event_seq, batch_non_pad_mask, _ = batch + + # remove the last event, as the prediction based on the last event has no label + # time_delta_seq should start from 1, because the first one is zero + time_seq, time_delta_seq, event_seq = time_seq[:, :-1], time_delta_seq[:, :-1], event_seq[:, :-1] + + # [batch_size, seq_len, hidden_size] + context = self.forward(time_delta_seq, event_seq) + + # [batch_size, seq_len, 3 * num_mix_components] + raw_params = self.linear(context) + locs = raw_params[..., :self.num_mix_components] + log_scales = raw_params[..., self.num_mix_components: (2 * self.num_mix_components)] + log_weights = raw_params[..., (2 * self.num_mix_components):] + + log_scales = clamp_preserve_gradients(log_scales, -5.0, 3.0) + log_weights = torch.log_softmax(log_weights, dim=-1) + inter_time_dist = LogNormalMixtureDistribution( + locs=locs, + log_scales=log_scales, + log_weights=log_weights, + mean_log_inter_time=self.mean_log_inter_time, + std_log_inter_time=self.std_log_inter_time + ) + + # [num_samples, batch_size, seq_len] + accepted_dtimes = inter_time_dist.sample((self.event_sampler.num_sample,)) + dtimes_pred = accepted_dtimes.mean(dim=0) + + # [batch_size, seq_len, num_marks] + mark_logits = torch.log_softmax(self.mark_linear(context), dim=-1) # Marks are modeled conditionally independently from times + types_pred = torch.argmax(mark_logits, dim=-1) + return dtimes_pred, types_pred diff --git a/easy_tpp/model/torch_model/torch_nhp.py b/easy_tpp/model/torch_model/torch_nhp.py new file mode 100644 index 0000000000000000000000000000000000000000..b38117d2a02387d9713b98e97e27936f36c585a6 --- /dev/null +++ b/easy_tpp/model/torch_model/torch_nhp.py @@ -0,0 +1,257 @@ +import torch +from torch import nn + +from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel +from easy_tpp.model.torch_model.torch_baselayer import ScaledSoftplus + + +class ContTimeLSTMCell(nn.Module): + """LSTM Cell in Neural Hawkes Process, NeurIPS'17. + """ + + def __init__(self, hidden_dim): + """Initialize the continuous LSTM cell. + + Args: + hidden_dim (int): dim of hidden state. + """ + super(ContTimeLSTMCell, self).__init__() + self.hidden_dim = hidden_dim + self.init_dense_layer(hidden_dim, bias=True) + + def init_dense_layer(self, hidden_dim, bias): + """Initialize linear layers given Equations (5a-6c) in the paper. + + Args: + hidden_dim (int): dim of hidden state. + """ + + self.linear_layer = nn.Linear(2 * hidden_dim, 7 * hidden_dim, bias=bias) + self.softplus = nn.Softplus() + + def forward(self, x_i, hidden_ti_minus, ct_ti_minus, c_bar_im1): + """Update the continuous-time LSTM cell. + + Args: + x_i (tensor): event embedding vector at t_i. + hidden_ti_minus (tensor): hidden state at t_i- + ct_ti_minus (tensor): cell state c(t) at t_i- + c_bar_im1 (tensor): cell state c_bar at t_{i-1} (c_bar_{i-1}) + + Returns: + list: cell state, cell bar state, decay and output at t_i + """ + + x_i_ = torch.cat((x_i, hidden_ti_minus), dim=1) + + i_i, i_bar_i, f_i, f_bar_i, z_i, o_i, delta_i = self.linear_layer(x_i_).chunk(7, dim=-1) + + i_i, i_bar_i, f_i, f_bar_i, z_i, o_i, delta_i = ( + torch.sigmoid(i_i), # Eq (5a) + torch.sigmoid(i_bar_i), # Eq (5a) - Bar version + torch.sigmoid(f_i), # Eq (5b) + torch.sigmoid(f_bar_i), # Eq (5b) - Bar version + torch.tanh(z_i), # Eq (5c) + torch.sigmoid(o_i), # Eq (5d) + self.softplus(delta_i) # Eq (6c) + ) + + # Eq (6a) + c_i = f_i * ct_ti_minus + i_i * z_i + + # Eq (6b) + c_bar_i = f_bar_i * c_bar_im1 + i_bar_i * z_i + + return c_i, c_bar_i, delta_i, o_i + + def decay(self, c_i, c_bar_i, delta_i, o_i, dtime): + """Cell and hidden state decay according to Equation (7). + + Args: + c_i (tensor): cell state c(t) at t_i. + c_bar_i (tensor): cell state c_bar at t_i (c_bar_i). + delta_i (tensor): gate decay state at t_i. + o_i (tensor): gate output state at t_i. + dtime (tensor): delta time to decay. + + Returns: + list: list of cell and hidden state tensors after the decay. + """ + + c_t = c_bar_i + (c_i - c_bar_i) * torch.exp(-delta_i * dtime) + h_t = o_i * torch.tanh(c_t) + return c_t, h_t + + +class NHP(TorchBaseModel): + """Torch implementation of The Neural Hawkes Process: A Neurally Self-Modulating Multivariate Point Process, + NeurIPS 2017, https://arxiv.org/abs/1612.09328. + """ + + def __init__(self, model_config): + """Initialize the NHP model. + + Args: + model_config (EasyTPP.ModelConfig): config of model specs. + """ + super(NHP, self).__init__(model_config) + self.beta = model_config.model_specs.get('beta', 1.0) + self.bias = model_config.model_specs.get('bias', True) + self.rnn_cell = ContTimeLSTMCell(self.hidden_size) + + self.layer_intensity = nn.Sequential( # eq. 4a, + nn.Linear(self.hidden_size, self.num_event_types, self.bias), + ScaledSoftplus(self.num_event_types)) # learnable mark-specific beta + + def get_init_state(self, batch_size): + c_t, c_bar_t, delta_t, o_t = torch.zeros( + batch_size, + 4 * self.hidden_size, + device=self.device).chunk(4, dim=1) + return c_t, c_bar_t, delta_t, o_t # Okay to initialize delta to be zero because c==c_bar at the beginning + + def forward(self, batch): + ''' + Suppose we have inputs with original sequence length N+1 + ts: [t0, t1, ..., t_N] + dts: [0, t1 - t0, t2 - t1, ..., t_N - t_{N-1}] + marks: [k0, k1, ..., k_N] (k0 and kN could be padded marks if t0 and tN correspond to left and right windows) + + Return: + Left limits of [t_1, ..., t_N] of shape: (batch_size, seq_len - 1, hidden_dim) + Right limits of [t_0, ..., t_{N-1}, t_N] of shape: (batch_size, seq_len, 4 * hidden_dim) + We need the right limit of t_N to sample continuation. + + > rnn_cell.recurrence(event_emb_t, h_tm1, c_tm1, c_bar_tm1) -> c_t, c_bar_t, gate_delta, gate_o + > rnn_cell.decay(c_t, c_bar_t, delta_t, o_t, dt) -> c_d_t, h_d_t + ''' + t_BN, dt_BN, marks_BN, _, _ = batch + B, N = dt_BN.shape + left_hs = [] + right_states = [] + + all_event_emb_BNP = self.layer_type_emb(marks_BN) + c_t, c_bar_t, delta_t, o_t = self.get_init_state(B) # initialize the right limits + for i in range(N): + # Take last right limit and evolve into left limit; we will discard this value for t0 because dt=0 + ct_d_t, h_d_t = self.rnn_cell.decay(c_t, c_bar_t, delta_t, o_t, dt_BN[..., i][..., None]) + + # Take left limit and update to be right limit + event_emb_t = all_event_emb_BNP[..., i, :] + c_t, c_bar_t, delta_t, o_t = self.rnn_cell( + x_i=event_emb_t, + hidden_ti_minus=h_d_t, + ct_ti_minus=ct_d_t, + c_bar_im1=c_bar_t, + ) + + left_hs.append(h_d_t) + right_states.append(torch.cat((c_t, c_bar_t, delta_t, o_t), dim=-1)) + + left_hiddens = torch.stack(left_hs[1:], dim=-2) # (batch_size, seq_len - 1, hidden_dim) + right_hiddens = torch.stack(right_states, dim=-2) # (batch_size, seq_len, 4 * hidden_dim) + return left_hiddens, right_hiddens + + def get_states(self, right_hiddens, sample_dts): + """ + right_hiddens: (batch_size, seq_len, 4 * hidden_dim): (c_t, c_bar_t, delta_t, o_t) + sample_dts: (batch_size, seq_len, MC_points) + + > rnn_cell.decay(c_t, c_bar_t, delta_t, o_t, dt) -> c_d_t, h_d_t + """ + c_t, c_bar_t, delta_t, o_t = torch.chunk(right_hiddens, 4, dim=-1) + _, h_ts = self.rnn_cell.decay(c_t[:, :, None, :], + c_bar_t[:, :, None, :], + delta_t[:, :, None, :], + o_t[:, :, None, :], + sample_dts[..., None]) + return h_ts + + def loglike_loss(self, batch): + """Compute the log-likelihood loss. + + Args: + batch (list): batch input. + + Returns: + tuple: loglikelihood loss and num of events. + """ + ts_BN, dts_BN, marks_BN, batch_non_pad_mask, _ = batch + + # 1. compute hidden states at event time + # left limits of [t_1, ..., t_N] + # right limits of [t_0, ..., t_{N-1}, t_N] + left_hiddens, right_hiddens = self.forward((ts_BN, dts_BN, marks_BN, None, None)) + right_hiddens = right_hiddens[..., :-1, :] # discard right limit at t_N for logL + + # 2. evaluate intensity values at each event *from the left limit* + intensity_B_Nm1_M = self.layer_intensity(left_hiddens) + + # 3. sample dts in each interval for estimating the integral + dts_sample_B_Nm1_G = self.make_dtime_loss_samples(dts_BN[:, 1:]) + + # 4. evaluate intensity at dt_samples for MC *from the left limit* after decay -> shape (B, N-1, G, M) + intensity_dts_B_Nm1_G_M = self.layer_intensity(self.get_states(right_hiddens, dts_sample_B_Nm1_G)) + + event_ll, non_event_ll, num_events = self.compute_loglikelihood( + lambda_at_event=intensity_B_Nm1_M, + lambdas_loss_samples=intensity_dts_B_Nm1_G_M, + time_delta_seq=dts_BN[:, 1:], + seq_mask=batch_non_pad_mask[:, 1:], + type_seq=marks_BN[:, 1:]) + + # compute loss to minimize + loss = - (event_ll - non_event_ll).sum() + return loss, num_events + + def compute_intensities_at_sample_times(self, time_seqs, time_delta_seqs, type_seqs, sample_dtimes, **kwargs): + """Compute the intensity at sampled times, not only event times. + + Args: + time_seqs (tensor): [batch_size, seq_len], times seqs. + time_delta_seqs (tensor): [batch_size, seq_len], time delta seqs. + type_seqs (tensor): [batch_size, seq_len], event type seqs. + sample_dtimes (tensor): [batch_size, seq_len, num_sample], sampled inter-event timestamps. + + Returns: + tensor: [batch_size, num_times, num_mc_sample, num_event_types], + intensity at each timestamp for each event type. + """ + + compute_last_step_only = kwargs.get('compute_last_step_only', False) + + _input = time_seqs, time_delta_seqs, type_seqs, None, None + + # We will need the right limit at the last given event to decay from and get the left limits for sampling + _, right_hiddens = self.forward(_input) + + c_i, c_bar_i, delta_i, o_i = torch.chunk(right_hiddens, 4, dim=-1) + + if compute_last_step_only: + interval_t_sample = sample_dtimes[:, -1:, :, None] + _, h_ts = self.rnn_cell.decay(c_i[:, -1:, None, :], + c_bar_i[:, -1:, None, :], + delta_i[:, -1:, None, :], + o_i[:, -1:, None, :], + interval_t_sample) + + # [batch_size, 1, num_mc_sample, num_marks] + sampled_intensities = self.layer_intensity(h_ts) + + else: + # interval_t_sample - [batch_size, seq_len, num_mc_sample, 1] + interval_t_sample = sample_dtimes[..., None] + # Use broadcasting to compute the decays at all time steps + # at all sample points + # h_ts shape (batch_size, seq_len, num_mc_sample, hidden_dim) + # cells[:, :, None, :] (batch_size, seq_len, 1, hidden_dim) + _, h_ts = self.rnn_cell.decay(c_i[:, :, None, :], + c_bar_i[:, :, None, :], + delta_i[:, :, None, :], + o_i[:, :, None, :], + interval_t_sample) + + # [batch_size, seq_len, num_mc_sample, num_marks] + sampled_intensities = self.layer_intensity(h_ts) + + return sampled_intensities diff --git a/easy_tpp/model/torch_model/torch_ode_tpp.py b/easy_tpp/model/torch_model/torch_ode_tpp.py new file mode 100644 index 0000000000000000000000000000000000000000..da5dbe270fc9687f57993aca51683d3e0b770606 --- /dev/null +++ b/easy_tpp/model/torch_model/torch_ode_tpp.py @@ -0,0 +1,298 @@ +import torch +from torch import nn + +from easy_tpp.model.torch_model.torch_baselayer import DNN +from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel +from easy_tpp.utils import rk4_step_method + + +def flatten_parameters(model): + p_shapes = [] + flat_parameters = [] + for p in model.parameters(): + p_shapes.append(p.size()) + flat_parameters.append(p.flatten()) + return torch.cat(flat_parameters) + + +class NeuralODEAdjoint(torch.autograd.Function): + + def __init__(self, device): + super(NeuralODEAdjoint, self).__init__() + self.device = device + + @staticmethod + def forward(ctx, z_init, delta_t, ode_fn, solver, num_sample_times, *model_parameters): + """ + + Args: + ctx: + input: (tensor): [batch_size] + model: + solver: + delta_t (tensor): [batch_size, num_sample_times] + + Returns: + + """ + + ctx.ode_fn = ode_fn + ctx.solver = solver + ctx.delta_t = delta_t + ctx.model_parameters = model_parameters + ctx.num_sample_times = num_sample_times + + total_state = [] + dt_ratio = 1.0 / num_sample_times + delta_t = delta_t * dt_ratio + with torch.no_grad(): + state = z_init + for i in range(num_sample_times): + # [batch_size, hidden_size] + state = solver(diff_func=ode_fn, dt=delta_t, z0=state) + total_state.append(state) + + # [batch_size, num_samples, hidden_size] + ctx.save_for_backward(state) + + return state + + @staticmethod + def backward(ctx, grad_z): + output_state = ctx.saved_tensors[0] # return a tuple + ode_fn = ctx.ode_fn + solver = ctx.solver + delta_t = ctx.delta_t + model_parameters = ctx.model_parameters + num_sample_times = ctx.num_sample_times + + # Dynamics of augmented system to be calculated backwards in time + def aug_dynamics(aug_states): + tmp_z = aug_states[0] + tmp_neg_a = -aug_states[1] + + with torch.set_grad_enabled(True): + tmp_z = tmp_z.detach().requires_grad_(True) + func_eval = ode_fn(tmp_z) + tmp_ds = torch.autograd.grad( + (func_eval,), (tmp_z, *model_parameters), + grad_outputs=tmp_neg_a, + allow_unused=True, + retain_graph=True) + + neg_adfdz = tmp_ds[0] + neg_adfdtheta = [torch.flatten(var) for var in tmp_ds[1:]] + + return [func_eval, neg_adfdz, *neg_adfdtheta] + + dt_ratio = 1.0 / num_sample_times + delta_t = delta_t * dt_ratio + + with torch.no_grad(): + # Construct back-state for ode solver + # reshape variable \theta for batch solving + init_var_grad = [torch.zeros_like(torch.flatten(var)) for var in model_parameters] + + # [z(t_1), a(t_1), \theta] + z1 = output_state + a1 = grad_z + states = [z1, a1, *init_var_grad] + + for i in range(num_sample_times): + states = solver(aug_dynamics, -delta_t, states) + + grad_z0 = states[1] + + grad_theta = [torch.reshape(torch.mean(var_grad, dim=0), var.shape) for var, var_grad in + zip(model_parameters, states[2:])] + + return (grad_z0, None, None, None, None, *grad_theta) + + +class NeuralODE(nn.Module): + def __init__(self, model, solver, num_sample_times, device): + super().__init__() + self.model = model + self.solver = solver + self.params = [w for w in model.parameters()] + self.num_sample_times = num_sample_times + self.device = device + + def forward(self, input_state, delta_time): + """ + + Args: + input_state: [batch_size, hidden_size] + return_state: + + Returns: + + """ + output_state = NeuralODEAdjoint.apply(input_state, + delta_time, + self.model, + self.solver, + self.num_sample_times, + *self.params) + + # [batch_size, num_sample_times, hidden_size] + return output_state + + +class ODETPP(TorchBaseModel): + """Torch implementation of a TPP with Neural ODE state evolution, which is a simplified version of TPP in + https://arxiv.org/abs/2011.04583, ICLR 2021 + + code reference: https://msurtsukov.github.io/Neural-ODE/; + https://github.com/liruilong940607/NeuralODE/blob/master/NeuralODE.py + + """ + + def __init__(self, model_config): + """Initialize the model + + Args: + model_config (EasyTPP.ModelConfig): config of model specs. + """ + super(ODETPP, self).__init__(model_config) + + self.layer_intensity = nn.Sequential( + nn.Linear(self.hidden_size, self.num_event_types), + nn.Softplus()) + + self.event_model = DNN(inputs_dim=self.hidden_size, + hidden_size=[self.hidden_size]) + + self.ode_num_sample_per_step = model_config.model_specs['ode_num_sample_per_step'] + + self.solver = rk4_step_method + + self.layer_neural_ode = NeuralODE(model=self.event_model, + solver=self.solver, + num_sample_times=self.ode_num_sample_per_step, + device=self.device) + + def forward(self, time_delta_seqs, type_seqs): + """Call the model. + + Args: + time_delta_seqs (tensor): [batch_size, seq_len], inter-event time seqs. + type_seqs (tensor): [batch_size, seq_len], event type seqs. + + Returns: + tensor: hidden states at event times. + + """ + # [batch_size, seq_len, hidden_size] + type_seq_emb = self.layer_type_emb(type_seqs) + time_delta_seqs_ = time_delta_seqs[..., None] + + left_limits, right_limits = [], [] + right_limit = torch.zeros_like(type_seq_emb[:, 0, :], device=self.device) + for type_emb, dt in zip(torch.unbind(type_seq_emb, dim=-2), + torch.unbind(time_delta_seqs_, dim=-2)): + left_limit = self.layer_neural_ode(right_limit, dt) + right_limit = left_limit + type_emb + + left_limits.append(left_limit) + right_limits.append(right_limit) + + # [batch_size, seq_len-1, hidden_size] + left_limits = torch.stack(left_limits[1:], dim=1) + # [batch_size, seq_len, hidden_size] + right_limits = torch.stack(right_limits, dim=1) + + return left_limits, right_limits + + def loglike_loss(self, batch): + """Compute the loglike loss. + + Args: + batch (list): batch input. + + Returns: + list: loglike loss, num events. + """ + time_seqs, time_delta_seqs, type_seqs, batch_non_pad_mask, _ = batch + + # compute hidden states at event time + # left limits of [t_1, ..., t_N] + # right limits of [t_0, ..., t_{N-1}] + left_limits, right_limits = self.forward(time_delta_seqs, type_seqs) + right_limits = right_limits[..., :-1, :] + + # Lambda(t) right before each event time point + # lambda_at_event - [batch_size, num_times=max_len-1, num_event_types] + # Here we drop the last event because it has no delta_time label (can not decay) + lambda_at_event = self.layer_intensity(left_limits) + + # interval_t_sample - [batch_size, num_times=max_len-1, num_mc_sample] + # for every batch and every event point => do a sampling (num_mc_sampling) + # the first dtime is zero, so we use time_delta_seq[:, 1:] + interval_t_sample = self.make_dtime_loss_samples(time_delta_seqs[:, 1:]) + + # [batch_size, num_times = max_len - 1, num_mc_sample, hidden_size] + sample_state_ti = self.compute_states_at_sample_times(right_limits, interval_t_sample) + + # [batch_size, num_times = max_len - 1, num_mc_sample, event_num] + lambda_t_sample = self.layer_intensity(sample_state_ti) + + event_ll, non_event_ll, num_events = self.compute_loglikelihood(lambda_at_event=lambda_at_event, + lambdas_loss_samples=lambda_t_sample, + time_delta_seq=time_delta_seqs[:, 1:], + seq_mask=batch_non_pad_mask[:, 1:], + type_seq=type_seqs[:, 1:]) + + # compute loss to optimize + loss = - (event_ll - non_event_ll).sum() + return loss, num_events + + def compute_states_at_sample_times(self, state_ti_plus, sample_dtimes): + """Compute the states at sampling times. + + Args: + state_ti_plus (tensor): [batch_size, seq_len, hidden_size], states right after the events. + sample_dtimes (tensor): [batch_size, seq_len, num_samples], delta times in sampling. + + Returns: + tensor: hiddens states at sampling times. + """ + + # Use broadcasting to compute the decays at all time steps + # at all sample points + # h_ts shape (batch_size, seq_len, num_samples, hidden_dim) + state = self.solver(diff_func=self.event_model, + dt=sample_dtimes[..., None], # [batch_size, seq_len, num_samples, 1] + z0=state_ti_plus[..., None, :]) # [batch_size, seq_len, 1, hidden_size] + + return state + + def compute_intensities_at_sample_times(self, time_seqs, time_delta_seqs, type_seqs, sample_dtimes, **kwargs): + """Compute the intensity at sampled times, not only event times. + + Args: + time_seqs (tensor): [batch_size, seq_len], times seqs. + time_delta_seqs (tensor): [batch_size, seq_len], time delta seqs. + type_seqs (tensor): [batch_size, seq_len], event type seqs. + sample_dtimes (tensor): [batch_size, seq_len, num_sample], sampled inter-event timestamps. + + Returns: + tensor: [batch_size, num_times, num_mc_sample, num_event_types], + intensity at each timestamp for each event type. + """ + + compute_last_step_only = kwargs.get('compute_last_step_only', False) + + _, right_limits = self.forward(time_delta_seqs, type_seqs) + + # [batch_size, num_sample_times, num_mc_sample, hidden_size] + sample_state_ti = self.compute_states_at_sample_times(right_limits, sample_dtimes) + + if compute_last_step_only: + # [batch_size, 1, num_mc_sample, num_event_types] + sampled_intensities = self.layer_intensity(sample_state_ti[:, -1:, :, :]) + else: + # [batch_size, num_sample_times, num_mc_sample, num_event_types] + sampled_intensities = self.layer_intensity(sample_state_ti) + + return sampled_intensities diff --git a/easy_tpp/model/torch_model/torch_rmtpp.py b/easy_tpp/model/torch_model/torch_rmtpp.py new file mode 100644 index 0000000000000000000000000000000000000000..4d87f9de178a19fef9dab3e15131b9e2926ab5d5 --- /dev/null +++ b/easy_tpp/model/torch_model/torch_rmtpp.py @@ -0,0 +1,117 @@ +import torch +from torch import nn +import math + +from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel + +class RMTPP(TorchBaseModel): + """Torch implementation of Recurrent Marked Temporal Point Processes, KDD 2016. + https://www.kdd.org/kdd2016/papers/files/rpp1081-duA.pdf + """ + + def __init__(self, model_config): + """Initialize the model + + Args: + model_config (EasyTPP.ModelConfig): config of model specs. + """ + super(RMTPP, self).__init__(model_config) + + self.layer_temporal_emb = nn.Linear(1, self.hidden_size) + self.layer_rnn = nn.RNN(input_size=self.hidden_size, hidden_size=self.hidden_size, + num_layers=1, batch_first=True) + + self.hidden_to_intensity_logits = nn.Linear(self.hidden_size, self.num_event_types) + self.b_t = nn.Parameter(torch.zeros(1, self.num_event_types)) + self.w_t = nn.Parameter(torch.zeros(1, self.num_event_types)) + nn.init.xavier_normal_(self.b_t) + nn.init.xavier_normal_(self.w_t) + + def evolve_and_get_intentsity(self, right_hiddens_BNH, dts_BNG): + """ + Eq.11 that computes intensity. + """ + past_influence_BNGM = self.hidden_to_intensity_logits(right_hiddens_BNH[..., None, :]) + intensity_BNGM = (past_influence_BNGM + self.w_t[None, None, :] * dts_BNG[..., None] + + self.b_t[None, None, :]).clamp(max=math.log(1e5)).exp() + return intensity_BNGM + + def forward(self, batch): + """ + Suppose we have inputs with original sequence length N+1 + ts: [t0, t1, ..., t_N] + dts: [0, t1 - t0, t2 - t1, ..., t_N - t_{N-1}] + marks: [k0, k1, ..., k_N] (k0 and kN could be padded marks if t0 and tN correspond to left and right windows) + + Return: + left limits of *intensity* at [t_1, ..., t_N] of shape: (batch_size, seq_len - 1, hidden_dim) + right limits of *hidden states* [t_0, ..., t_{N-1}, t_N] of shape: (batch_size, seq_len, hidden_dim) + We need the right limit of t_N to sample continuation. + """ + + t_BN, dt_BN, marks_BN, _, _ = batch + mark_emb_BNH = self.layer_type_emb(marks_BN) + time_emb_BNH = self.layer_temporal_emb(t_BN[..., None]) + right_hiddens_BNH, _ = self.layer_rnn(mark_emb_BNH + time_emb_BNH) + left_intensity_B_Nm1_M = self.evolve_and_get_intentsity(right_hiddens_BNH[:, :-1, :], dt_BN[:, 1:][...,None]).squeeze(-2) + return left_intensity_B_Nm1_M, right_hiddens_BNH + + + def loglike_loss(self, batch): + """Compute the log-likelihood loss. + + Args: + batch (list): batch input. + + Returns: + tuple: loglikelihood loss and num of events. + """ + ts_BN, dts_BN, marks_BN, batch_non_pad_mask, _ = batch + + # compute left intensity and hidden states at event time + # left limits of intensity at [t_1, ..., t_N] + # right limits of hidden states at [t_0, ..., t_{N-1}, t_N] + left_intensity_B_Nm1_M, right_hiddens_BNH = self.forward((ts_BN, dts_BN, marks_BN, None, None)) + right_hiddens_B_Nm1_H = right_hiddens_BNH[..., :-1, :] # discard right limit at t_N for logL + + dts_sample_B_Nm1_G = self.make_dtime_loss_samples(dts_BN[:, 1:]) + intensity_dts_B_Nm1_G_M = self.evolve_and_get_intentsity(right_hiddens_B_Nm1_H, dts_sample_B_Nm1_G) + + event_ll, non_event_ll, num_events = self.compute_loglikelihood( + lambda_at_event=left_intensity_B_Nm1_M, + lambdas_loss_samples=intensity_dts_B_Nm1_G_M, + time_delta_seq=dts_BN[:, 1:], + seq_mask=batch_non_pad_mask[:, 1:], + type_seq=marks_BN[:, 1:] + ) + + # compute loss to minimize + loss = - (event_ll - non_event_ll).sum() + return loss, num_events + + + + def compute_intensities_at_sample_times(self, time_seqs, time_delta_seqs, type_seqs, sample_dtimes, **kwargs): + """Compute the intensity at sampled times, not only event times. + + Args: + time_seq (tensor): [batch_size, seq_len], times seqs. + time_delta_seq (tensor): [batch_size, seq_len], time delta seqs. + event_seq (tensor): [batch_size, seq_len], event type seqs. + sample_dtimes (tensor): [batch_size, seq_len, num_sample], sampled inter-event timestamps. + + Returns: + tensor: [batch_size, num_times, num_mc_sample, num_event_types], + intensity at each timestamp for each event type. + """ + + compute_last_step_only = kwargs.get('compute_last_step_only', False) + + _input = time_seqs, time_delta_seqs, type_seqs, None, None + _, right_hiddens_BNH = self.forward(_input) + + if compute_last_step_only: + sampled_intensities = self.evolve_and_get_intentsity(right_hiddens_BNH[:, -1:, :], sample_dtimes[:, -1:, :]) + else: + sampled_intensities = self.evolve_and_get_intentsity(right_hiddens_BNH, sample_dtimes) # shape: [B, N, G, M] + return sampled_intensities diff --git a/easy_tpp/model/torch_model/torch_robot_thp.py b/easy_tpp/model/torch_model/torch_robot_thp.py new file mode 100644 index 0000000000000000000000000000000000000000..5507f40d5fe2cf037cba27a7a21b60426599934f --- /dev/null +++ b/easy_tpp/model/torch_model/torch_robot_thp.py @@ -0,0 +1,411 @@ +""" +评论罗伯特 Transformer Hawkes Process (Robot-THP) + +结合EasyTPP THP的优点和语义增强型THP的特点: +1. 保持MLP强度函数(灵活表达) +2. 使用可学习的ScaledSoftplus(事件类型特定性) +3. 使用MC采样近似积分(精确计算) +4. 支持多模态特征(语义、偏差、自发/被@) +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +from typing import Dict, Optional + +from easy_tpp.model.torch_model.torch_baselayer import ScaledSoftplus +from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel + + +class RobotTHP(TorchBaseModel): + """ + 评论罗伯特 Transformer Hawkes Process + + 结合了: + - EasyTPP THP的ScaledSoftplus和MC采样 + - 语义增强型THP的多模态特征融合 + """ + + def __init__(self, model_config): + """ + 初始化 Robot-THP 模型 + + Args: + model_config: 模型配置 + - hidden_size: 隐藏层维度 + - num_event_types: 事件类型数(不含padding) + - num_event_types_pad: 事件类型数(含padding) + - num_layers: Transformer层数 + - num_heads: 注意力头数 + - dropout: Dropout率 + - semantic_dim: 语义向量维度(可选) + - use_semantic: 是否使用语义特征 + - use_deviation: 是否使用偏差特征 + - use_structure_mask: 是否使用结构感知掩码 + """ + super(RobotTHP, self).__init__(model_config) + + # 基础参数 + self.d_model = model_config.hidden_size + self.num_layers = model_config.num_layers + self.n_head = model_config.num_heads + self.dropout = model_config.dropout_rate + + # 语义相关参数(可选) + self.semantic_dim = getattr(model_config, 'semantic_dim', 768) + self.use_semantic = getattr(model_config, 'use_semantic', False) + self.use_deviation = getattr(model_config, 'use_deviation', False) + self.use_structure_mask = getattr(model_config, 'use_structure_mask', False) + + # 事件类型嵌入 + # 使用基类的 layer_type_emb,但也可以自定义 + if not hasattr(self, 'layer_type_emb'): + self.layer_type_emb = nn.Embedding( + self.num_event_types_pad, + self.d_model, + padding_idx=self.pad_token_id + ) + + # 时间编码(正弦位置编码) + self.layer_temporal_encoding = TimePositionalEncoding(self.d_model, device=self.device) + + # 语义向量投影层(如果使用语义特征) + if self.use_semantic: + self.semantic_projection = nn.Linear(self.semantic_dim, self.d_model) + + # 偏差特征嵌入(如果使用偏差特征) + if self.use_deviation: + self.deviation_embedding = nn.Linear(3, self.d_model) # 3个偏差特征 + + # 自发/被@特征嵌入 + self.spontaneous_embedding = nn.Embedding( + num_embeddings=3, # 0=不适用, 1=自发, 2=被@ + embedding_dim=self.d_model + ) + + # Transformer编码器层 + from easy_tpp.model.torch_model.torch_baselayer import EncoderLayer, MultiHeadAttention + + # Feed-forward网络 + self.feed_forward = nn.Sequential( + nn.Linear(self.d_model, self.d_model * 2), + nn.ReLU(), + nn.Dropout(self.dropout), + nn.Linear(self.d_model * 2, self.d_model) + ) + + self.stack_layers = nn.ModuleList( + [EncoderLayer( + self.d_model, + MultiHeadAttention(self.n_head, self.d_model, self.d_model, self.dropout, + output_linear=False), + use_residual=False, + feed_forward=self.feed_forward, + dropout=self.dropout + ) for _ in range(self.num_layers)]) + + # MLP强度函数层(保持你的设计) + self.layer_intensity_hidden = nn.Sequential( + nn.Linear(self.d_model, self.d_model), + nn.ReLU(), + nn.Dropout(self.dropout), + nn.Linear(self.d_model, self.num_event_types) + ) + + # 可学习的ScaledSoftplus(事件类型特定的beta参数) + self.softplus = ScaledSoftplus(self.num_event_types) + + # 基础强度参数(可选,用于额外的灵活性) + self.factor_intensity_base = nn.Parameter( + torch.empty([1, self.num_event_types], device=self.device) + ) + nn.init.xavier_normal_(self.factor_intensity_base) + + def forward(self, time_seqs, type_seqs, attention_mask, + semantic_vectors=None, deviation_features=None, + is_spontaneous=None, structure_mask=None): + """ + 前向传播(符合EasyTPP标准接口) + + Args: + time_seqs: [batch_size, seq_len], 时间序列 + type_seqs: [batch_size, seq_len], 事件类型序列 + attention_mask: [batch_size, seq_len, seq_len], 注意力掩码 + semantic_vectors: [batch_size, seq_len, semantic_dim], 语义向量(可选) + deviation_features: [batch_size, seq_len, 3], 偏差特征(可选) + is_spontaneous: [batch_size, seq_len], 自发/被@标记(可选) + structure_mask: [batch_size, seq_len, seq_len], 结构掩码(可选) + + Returns: + tensor: [batch_size, seq_len, hidden_size], 隐藏状态 + """ + batch_size, seq_len = time_seqs.shape + + # 1. 时间编码 + tem_enc = self.layer_temporal_encoding(time_seqs) # [batch_size, seq_len, d_model] + + # 2. 事件类型嵌入 + enc_output = self.layer_type_emb(type_seqs) # [batch_size, seq_len, d_model] + + # 3. 语义向量(如果提供) + if self.use_semantic and semantic_vectors is not None: + semantic_emb = self.semantic_projection(semantic_vectors) # [batch_size, seq_len, d_model] + enc_output = enc_output + semantic_emb + + # 4. 偏差特征(如果提供) + if self.use_deviation and deviation_features is not None: + deviation_emb = self.deviation_embedding(deviation_features) # [batch_size, seq_len, d_model] + enc_output = enc_output + deviation_emb + + # 5. 自发/被@特征(如果提供) + if is_spontaneous is not None: + # 转换:-1->0(不适用), 0->2(被@), 1->1(自发) + spontaneous_indices = (is_spontaneous + 1).long().clamp(0, 2) + spontaneous_emb = self.spontaneous_embedding(spontaneous_indices) + enc_output = enc_output + spontaneous_emb + + # 6. Transformer编码器(参考EasyTPP THP的设计) + for enc_layer in self.stack_layers: + enc_output = enc_output + tem_enc # 在每层都添加时间编码 + enc_output = enc_layer(enc_output, mask=attention_mask) + + return enc_output + + def loglike_loss(self, batch): + """ + 计算对数似然损失(使用MC采样,符合EasyTPP标准接口) + + Args: + batch (tuple, list): EasyTPP标准批次格式 + batch[0]: time_seqs [batch_size, seq_len] + batch[1]: time_delta_seqs [batch_size, seq_len] + batch[2]: type_seqs [batch_size, seq_len] + batch[3]: batch_non_pad_mask [batch_size, seq_len] + batch[4]: attention_mask [batch_size, seq_len, seq_len] + batch[5]: semantic_vectors [batch_size, seq_len, semantic_dim] (可选) + batch[6]: deviation_features [batch_size, seq_len, 3] (可选) + batch[7]: is_spontaneous [batch_size, seq_len] (可选) + batch[8]: structure_mask [batch_size, seq_len, seq_len] (可选) + + Returns: + tuple: (loss, num_events) - 符合EasyTPP标准 + """ + # EasyTPP标准批次格式:tuple/list + time_seqs = batch[0] + time_delta_seqs = batch[1] + type_seqs = batch[2] + batch_non_pad_mask = batch[3] + attention_mask = batch[4] + + # 提取可选特征(如果提供) + semantic_vectors = batch[5] if len(batch) > 5 and batch[5] is not None else None + deviation_features = batch[6] if len(batch) > 6 and batch[6] is not None else None + is_spontaneous = batch[7] if len(batch) > 7 and batch[7] is not None else None + structure_mask = batch[8] if len(batch) > 8 and batch[8] is not None else None + + # 合并结构掩码和注意力掩码 + if self.use_structure_mask and structure_mask is not None: + if attention_mask.dim() == 3: + attention_mask = attention_mask | structure_mask.bool() + + # 1. 计算事件发生时的强度 + # [batch_size, seq_len-1, hidden_size] + # 处理可选特征的切片 + semantic_vec_slice = None + if semantic_vectors is not None and semantic_vectors.dim() > 1: + semantic_vec_slice = semantic_vectors[:, :-1] + + deviation_feat_slice = None + if deviation_features is not None and deviation_features.dim() > 1: + deviation_feat_slice = deviation_features[:, :-1] + + is_spon_slice = None + if is_spontaneous is not None and is_spontaneous.dim() > 1: + is_spon_slice = is_spontaneous[:, :-1] + + struct_mask_slice = None + if structure_mask is not None and structure_mask.dim() > 2: + struct_mask_slice = structure_mask[:, :-1, :-1] + + attn_mask_slice = attention_mask[:, :-1, :-1] if attention_mask.dim() > 2 else attention_mask + + enc_out = self.forward( + time_seqs[:, :-1], + type_seqs[:, :-1], + attn_mask_slice, + semantic_vec_slice, + deviation_feat_slice, + is_spon_slice, + struct_mask_slice + ) + + # MLP计算强度状态 + # [batch_size, seq_len-1, num_event_types] + intensity_states = self.layer_intensity_hidden(enc_out) + self.factor_intensity_base + + # 使用ScaledSoftplus + lambda_at_event = self.softplus(intensity_states) + + # 2. 使用MC采样计算非事件时间的强度积分 + # 2.1 采样时间点 + # [batch_size, seq_len-1, num_sample] + sample_dtimes = self.make_dtime_loss_samples(time_delta_seqs[:, 1:]) + + # 2.2 计算采样时间点的强度 + # [batch_size, seq_len-1, num_sample, num_event_types] + state_t_sample = self.compute_states_at_sample_times( + event_states=enc_out, + sample_dtimes=sample_dtimes + ) + lambda_t_sample = self.softplus(state_t_sample) + + # 3. 计算对数似然 + event_ll, non_event_ll, num_events = self.compute_loglikelihood( + lambda_at_event=lambda_at_event, + lambdas_loss_samples=lambda_t_sample, + time_delta_seq=time_delta_seqs[:, 1:], + seq_mask=batch_non_pad_mask[:, 1:], + type_seq=type_seqs[:, 1:] + ) + + # 4. 计算损失 + loss = - (event_ll - non_event_ll).sum() + return loss, num_events + + def compute_states_at_sample_times(self, event_states, sample_dtimes): + """ + 计算采样时间点的强度状态 + + Args: + event_states: [batch_size, seq_len, hidden_size], 事件发生时的隐藏状态 + sample_dtimes: [batch_size, seq_len, num_samples], 采样的时间间隔 + + Returns: + tensor: [batch_size, seq_len, num_samples, num_event_types], 采样时间点的强度状态 + """ + # [batch_size, seq_len, 1, hidden_size] + event_states = event_states[:, :, None, :] + + # [batch_size, seq_len, num_samples, hidden_size] + # 扩展event_states到所有采样时间点 + event_states_expanded = event_states.expand(-1, -1, sample_dtimes.size(-1), -1) + + # MLP计算强度状态 + # [batch_size, seq_len, num_samples, num_event_types] + intensity_states = self.layer_intensity_hidden(event_states_expanded) + + # 添加基础强度 + # [1, 1, 1, num_event_types] + factor_intensity_base = self.factor_intensity_base[None, None, ...] + intensity_states = intensity_states + factor_intensity_base + + return intensity_states + + def compute_intensities_at_sample_times(self, + time_seqs, + time_delta_seqs, + type_seqs, + sample_dtimes, + **kwargs): + """ + 计算采样时间点的强度值(用于生成,符合EasyTPP标准接口) + + Args: + time_seqs: [batch_size, seq_len], 时间序列 + time_delta_seqs: [batch_size, seq_len], 时间间隔序列 + type_seqs: [batch_size, seq_len], 事件类型序列 + sample_dtimes: [batch_size, seq_len, num_samples], 采样的时间间隔 + **kwargs: 其他参数 + - attention_mask: [batch_size, seq_len, seq_len] (可选) + - compute_last_step_only: bool (可选) + - semantic_vectors: [batch_size, seq_len, semantic_dim] (可选) + - deviation_features: [batch_size, seq_len, 3] (可选) + - is_spontaneous: [batch_size, seq_len] (可选) + - structure_mask: [batch_size, seq_len, seq_len] (可选) + + Returns: + tensor: [batch_size, seq_len, num_samples, num_event_types], 采样时间点的强度 + """ + attention_mask = kwargs.get('attention_mask', None) + compute_last_step_only = kwargs.get('compute_last_step_only', False) + + # 提取可选特征 + semantic_vectors = kwargs.get('semantic_vectors', None) + deviation_features = kwargs.get('deviation_features', None) + is_spontaneous = kwargs.get('is_spontaneous', None) + structure_mask = kwargs.get('structure_mask', None) + + # 生成默认的因果掩码(如果未提供) + if attention_mask is None: + batch_size, seq_len = time_seqs.size() + attention_mask = torch.triu( + torch.ones(seq_len, seq_len, device=self.device), + diagonal=1 + ).unsqueeze(0) + attention_mask = attention_mask.expand(batch_size, -1, -1).to(torch.bool) + + # 合并结构掩码 + if self.use_structure_mask and structure_mask is not None: + if attention_mask.dim() == 3: + attention_mask = attention_mask | structure_mask.bool() + + # [batch_size, seq_len, hidden_size] + enc_out = self.forward( + time_seqs, + type_seqs, + attention_mask, + semantic_vectors, + deviation_features, + is_spontaneous, + structure_mask + ) + + # [batch_size, seq_len, num_samples, num_event_types] + encoder_output = self.compute_states_at_sample_times(enc_out, sample_dtimes) + + if compute_last_step_only: + lambdas = self.softplus(encoder_output[:, -1:, :, :]) + else: + # [batch_size, seq_len, num_samples, num_event_types] + lambdas = self.softplus(encoder_output) + + return lambdas + + +class TimePositionalEncoding(nn.Module): + """时间位置编码(参考EasyTPP的实现)""" + + def __init__(self, d_model, device=None): + super(TimePositionalEncoding, self).__init__() + self.d_model = d_model + self.device = device + + def forward(self, time_seqs): + """ + Args: + time_seqs: [batch_size, seq_len], 时间序列 + + Returns: + tensor: [batch_size, seq_len, d_model], 时间编码 + """ + batch_size, seq_len = time_seqs.shape + device = time_seqs.device + + # 使用正弦/余弦编码 + div_term = torch.exp( + torch.arange(0, self.d_model, 2, device=device).float() * + (-math.log(10000.0) / self.d_model) + ) + + # [batch_size, seq_len, d_model] + time_emb = torch.zeros(batch_size, seq_len, self.d_model, device=device) + + # 对时间进行编码 + time_expanded = time_seqs.unsqueeze(-1) # [batch_size, seq_len, 1] + + time_emb[:, :, 0::2] = torch.sin(time_expanded * div_term) + time_emb[:, :, 1::2] = torch.cos(time_expanded * div_term) + + return time_emb + diff --git a/easy_tpp/model/torch_model/torch_s2p2.py b/easy_tpp/model/torch_model/torch_s2p2.py new file mode 100644 index 0000000000000000000000000000000000000000..6f609a35bd0cf65a3124cc3ec9d65475cbdcf0a6 --- /dev/null +++ b/easy_tpp/model/torch_model/torch_s2p2.py @@ -0,0 +1,322 @@ +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn + +from easy_tpp.model.torch_model.torch_baselayer import ScaledSoftplus +from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel +from easy_tpp.ssm.models import LLH, Int_Backward_LLH, Int_Forward_LLH + + +class ComplexEmbedding(nn.Module): + def __init__(self, *args, **kwargs): + super(ComplexEmbedding, self).__init__() + self.real_embedding = nn.Embedding(*args, **kwargs) + self.imag_embedding = nn.Embedding(*args, **kwargs) + + self.real_embedding.weight.data *= 1e-3 + self.imag_embedding.weight.data *= 1e-3 + + def forward(self, x): + return torch.complex( + self.real_embedding(x), + self.imag_embedding(x), + ) + + +class IntensityNet(nn.Module): + def __init__(self, input_dim, bias, num_event_types): + super().__init__() + self.intensity_net = nn.Linear(input_dim, num_event_types, bias=bias) + self.softplus = ScaledSoftplus(num_event_types) + + def forward(self, x): + return self.softplus(self.intensity_net(x)) + + +class S2P2(TorchBaseModel): + def __init__(self, model_config): + """Initialize the model + + Args: + model_config (EasyTPP.ModelConfig): config of model specs. + """ + super(S2P2, self).__init__(model_config) + self.n_layers = model_config.num_layers + self.P = model_config.model_specs["P"] # Hidden state dimension + self.H = model_config.hidden_size # Residual stream dimension + self.beta = model_config.model_specs.get("beta", 1.0) + self.bias = model_config.model_specs.get("bias", True) + self.simple_mark = model_config.model_specs.get("simple_mark", True) + + layer_kwargs = dict( + P=self.P, + H=self.H, + dt_init_min=model_config.model_specs.get("dt_init_min", 1e-4), + dt_init_max=model_config.model_specs.get("dt_init_max", 0.1), + act_func=model_config.model_specs.get("act_func", "full_glu"), + dropout_rate=model_config.model_specs.get("dropout_rate", 0.0), + for_loop=model_config.model_specs.get("for_loop", False), + pre_norm=model_config.model_specs.get("pre_norm", True), + post_norm=model_config.model_specs.get("post_norm", False), + simple_mark=self.simple_mark, + relative_time=model_config.model_specs.get("relative_time", False), + complex_values=model_config.model_specs.get("complex_values", True), + ) + + int_forward_variant = model_config.model_specs.get("int_forward_variant", False) + int_backward_variant = model_config.model_specs.get( + "int_backward_variant", False + ) + assert ( + int_forward_variant + int_backward_variant + ) <= 1 # Only one at most is allowed to be specified + + if int_forward_variant: + llh_layer = Int_Forward_LLH + elif int_backward_variant: + llh_layer = Int_Backward_LLH + else: + llh_layer = LLH + + self.backward_variant = int_backward_variant + + self.layers = nn.ModuleList( + [ + llh_layer(**layer_kwargs, is_first_layer=i == 0) + for i in range(self.n_layers) + ] + ) + self.layers_mark_emb = nn.Embedding( + self.num_event_types_pad, + self.H, + ) # One embedding to share amongst layers to be used as input into a layer-specific and input-aware impulse + self.layer_type_emb = None # Remove old embeddings from EasyTPP + self.intensity_net = IntensityNet( + input_dim=self.H, + bias=self.bias, + num_event_types=self.num_event_types, + ) + + def _get_intensity( + self, x_LP: Union[torch.tensor, List[torch.tensor]], right_us_BNH + ) -> torch.Tensor: + """ + Assume time has already been evolved, take a vertical stack of hidden states and produce intensity. + """ + left_u_H = None + for i, layer in enumerate(self.layers): + if isinstance( + x_LP, list + ): # Sometimes it is convenient to pass as a list over the layers rather than a single tensor + left_u_H = layer.depth_pass( + x_LP[i], current_left_u_H=left_u_H, prev_right_u_H=right_us_BNH[i] + ) + else: + left_u_H = layer.depth_pass( + x_LP[..., i, :], + current_left_u_H=left_u_H, + prev_right_u_H=right_us_BNH[i], + ) + + return self.intensity_net(left_u_H) # self.ScaledSoftplus(self.linear(left_u_H)) + + def _evolve_and_get_intensity_at_sampled_dts(self, x_LP, dt_G, right_us_H): + left_u_GH = None + for i, layer in enumerate(self.layers): + x_GP = layer.get_left_limit( + right_limit_P=x_LP[..., i, :], + dt_G=dt_G, + next_left_u_GH=left_u_GH, + current_right_u_H=right_us_H[i], + ) + left_u_GH = layer.depth_pass( + current_left_x_P=x_GP, + current_left_u_H=left_u_GH, + prev_right_u_H=right_us_H[i], + ) + return self.intensity_net(left_u_GH) # self.ScaledSoftplus(self.linear(left_u_GH)) + + def forward( + self, batch, initial_state_BLP: Optional[torch.Tensor] = None, **kwargs + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Batch operations of self._forward + """ + t_BN, dt_BN, marks_BN, batch_non_pad_mask, _ = batch + + right_xs_BNP = [] # including both t_0 and t_N + left_xs_BNm1P = [] + right_us_BNH = [ + None + ] # Start with None as this is the 'input' to the first layer + left_u_BNH, right_u_BNH = None, None + alpha_BNP = self.layers_mark_emb(marks_BN) + + for l_i, layer in enumerate(self.layers): + # for each event, compute the fixed impulse via alpha_m for event i of type m + init_state = ( + initial_state_BLP[:, l_i] if initial_state_BLP is not None else None + ) + + # Returns right limit of xs and us for [t0, t1, ..., tN] + # "layer" returns the right limit of xs at current layer, and us for the next layer (as transformations of ys) + # x_BNP: at time [t_0, t_1, ..., t_{N-1}, t_N] + # next_left_u_BNH: at time [t_0, t_1, ..., t_{N-1}, t_N] -- only available for backward variant + # next_right_u_BNH: at time [t_0, t_1, ..., t_{N-1}, t_N] -- always returned but only used for RT + x_BNP, next_layer_left_u_BNH, next_layer_right_u_BNH = layer.forward( + left_u_BNH, right_u_BNH, alpha_BNP, dt_BN, init_state + ) + assert next_layer_right_u_BNH is not None + + right_xs_BNP.append(x_BNP) + if next_layer_left_u_BNH is None: # NOT backward variant + left_xs_BNm1P.append( + layer.get_left_limit( # current and next at event level + x_BNP[..., :-1, :], # at time [t_0, t_1, ..., t_{N-1}] + dt_BN[..., 1:].unsqueeze( + -1 + ), # with dts [t1-t0, t2-t1, ..., t_N-t_{N-1}] + current_right_u_H=right_u_BNH + if right_u_BNH is None + else right_u_BNH[ + ..., :-1, : + ], # at time [t_0, t_1, ..., t_{N-1}] + next_left_u_GH=left_u_BNH + if left_u_BNH is None + else left_u_BNH[..., 1:, :].unsqueeze( + -2 + ), # at time [t_1, t_2 ..., t_N] + ).squeeze(-2) + ) + right_us_BNH.append(next_layer_right_u_BNH) + + left_u_BNH, right_u_BNH = next_layer_left_u_BNH, next_layer_right_u_BNH + + right_xs_BNLP = torch.stack(right_xs_BNP, dim=-2) + + ret_val = { + "right_xs_BNLP": right_xs_BNLP, # [t_0, ..., t_N] + "right_us_BNH": right_us_BNH, # [t_0, ..., t_N]; list starting with None + } + + if left_u_BNH is not None: # backward variant + ret_val["left_u_BNm1H"] = left_u_BNH[ + ..., 1:, : + ] # The next inputs after last layer -> transformation of ys + else: # NOT backward variant + ret_val["left_xs_BNm1LP"] = torch.stack(left_xs_BNm1P, dim=-2) + + # 'seq_len - 1' left limit for [t_1, ..., t_N] for events (u if available, x if not) + # 'seq_len' right limit for [t_0, t_1, ..., t_{N-1}, t_N] for events xs or us + return ret_val + + def loglike_loss(self, batch, **kwargs): + # hidden states at the left and right limits around event time; note for the shift by 1 in indices: + # consider a sequence [t0, t1, ..., tN] + # Produces the following: + # left_x: x0, x1, x2, ... <-> x_{t_1-}, x_{t_2-}, x_{t_3-}, ..., x_{t_N-} (note the shift in indices) for all layers + # OR ==> <-> u_{t_1-}, u_{t_2-}, u_{t_3-}, ..., u_{t_N-} for last layer + # + # right_x: x0, x1, x2, ... <-> x_{t_0+}, x_{t_1+}, ..., x_{t_N+} for all layers + # right_u: u0, u1, u2, ... <-> u_{t_0+}, u_{t_1+}, ..., u_{t_N+} for all layers + forward_results = self.forward( + batch + ) # N minus 1 comparing with sequence lengths + right_xs_BNLP, right_us_BNH = ( + forward_results["right_xs_BNLP"], + forward_results["right_us_BNH"], + ) + right_us_BNm1H = [ + None if right_u_BNH is None else right_u_BNH[:, :-1, :] + for right_u_BNH in right_us_BNH + ] + + ts_BN, dts_BN, marks_BN, batch_non_pad_mask, _ = batch + + # evaluate intensity values at each event *from the left limit*, _get_intensity: [LP] -> [M] + # left_xs_B_Nm1_LP = left_xs_BNm1LP[:, :-1, ...] # discard the left limit of t_N + # Note: no need to discard the left limit of t_N because "marks_mask" will deal with it + if "left_u_BNm1H" in forward_results: # ONLY backward variant + intensity_B_Nm1_M = self.intensity_net( + forward_results["left_u_BNm1H"] + ) # self.ScaledSoftplus(self.linear(forward_results["left_u_BNm1H"])) + else: # NOT backward variant + intensity_B_Nm1_M = self._get_intensity( + forward_results["left_xs_BNm1LP"], right_us_BNm1H + ) + + # sample dt in each interval for MC: [batch_size, num_times=N-1, num_mc_sample] + # N-1 because we only consider the intervals between N events + # G for grid points + dts_sample_B_Nm1_G = self.make_dtime_loss_samples(dts_BN[:, 1:]) + + # evaluate intensity at dt_samples for MC *from the left limit* after decay -> shape (B, N-1, MC, M) + intensity_dts_B_Nm1_G_M = self._evolve_and_get_intensity_at_sampled_dts( + right_xs_BNLP[ + :, :-1 + ], # x_{t_i+} will evolve up to x_{t_{i+1}-} and many times between for i=0,...,N-1 + dts_sample_B_Nm1_G, + right_us_BNm1H, + ) + + event_ll, non_event_ll, num_events = self.compute_loglikelihood( + lambda_at_event=intensity_B_Nm1_M, + lambdas_loss_samples=intensity_dts_B_Nm1_G_M, + time_delta_seq=dts_BN[:, 1:], + seq_mask=batch_non_pad_mask[:, 1:], + type_seq=marks_BN[:, 1:], + ) + + # compute loss to optimize + loss = -(event_ll - non_event_ll).sum() + + return loss, num_events + + def compute_intensities_at_sample_times( + self, event_times_BN, inter_event_times_BN, marks_BN, sample_dtimes, **kwargs + ): + """Compute the intensity at sampled times, not only event times. *from the left limit* + + Args: + time_seq (tensor): [batch_size, seq_len], times seqs. + time_delta_seq (tensor): [batch_size, seq_len], time delta seqs. + event_seq (tensor): [batch_size, seq_len], event type seqs. + sample_dtimes (tensor): [batch_size, seq_len, num_sample], sampled inter-event timestamps. + + Returns: + tensor: [batch_size, num_times, num_mc_sample, num_event_types], + intensity at each timestamp for each event type. + """ + + compute_last_step_only = kwargs.get("compute_last_step_only", False) + + # assume inter_event_times_BN always starts from 0 + _input = event_times_BN, inter_event_times_BN, marks_BN, None, None + + # 'seq_len - 1' left limit for [t_1, ..., t_N] + # 'seq_len' right limit for [t_0, t_1, ..., t_{N-1}, t_N] + + forward_results = self.forward( + _input + ) # N minus 1 comparing with sequence lengths + right_xs_BNLP, right_us_BNH = ( + forward_results["right_xs_BNLP"], + forward_results["right_us_BNH"], + ) + + if ( + compute_last_step_only + ): # fix indices for right_us_BNH: list [None, tensor([BNH]), ...] + right_us_B1H = [ + None if right_u_BNH is None else right_u_BNH[:, -1:, :] + for right_u_BNH in right_us_BNH + ] + sampled_intensity = self._evolve_and_get_intensity_at_sampled_dts( + right_xs_BNLP[:, -1:, :, :], sample_dtimes[:, -1:, :], right_us_B1H + ) # equiv. to right_xs_BNLP[:, -1, :, :][:, None, ...] + else: + sampled_intensity = self._evolve_and_get_intensity_at_sampled_dts( + right_xs_BNLP, sample_dtimes, right_us_BNH + ) + return sampled_intensity # [B, N, MC, M] diff --git a/easy_tpp/model/torch_model/torch_sahp.py b/easy_tpp/model/torch_model/torch_sahp.py new file mode 100644 index 0000000000000000000000000000000000000000..224cde87b9e6152cc6f75b61e270448b6107f410 --- /dev/null +++ b/easy_tpp/model/torch_model/torch_sahp.py @@ -0,0 +1,209 @@ +import torch +import torch.nn as nn + +from easy_tpp.model.torch_model.torch_baselayer import EncoderLayer, MultiHeadAttention, \ + TimeShiftedPositionalEncoding, ScaledSoftplus +from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel + + +class SAHP(TorchBaseModel): + """Torch implementation of Self-Attentive Hawkes Process, ICML 2020. + Part of the code is collected from https://github.com/yangalan123/anhp-andtt/blob/master/sahp + + I slightly modify the original code because it is not stable. + + """ + + def __init__(self, model_config): + """Initialize the model + + Args: + model_config (EasyTPP.ModelConfig): config of model specs. + """ + super(SAHP, self).__init__(model_config) + self.d_model = model_config.hidden_size + self.d_time = model_config.time_emb_size + + self.use_norm = model_config.use_ln + + # position vector, used for temporal encoding + self.layer_position_emb = TimeShiftedPositionalEncoding(d_model=self.d_model, + device=self.device) + + self.n_layers = model_config.num_layers + self.n_head = model_config.num_heads + self.dropout = model_config.dropout_rate + + # convert hidden vectors into a scalar + self.layer_intensity_hidden = nn.Linear(self.d_model, self.num_event_types) + self.softplus = ScaledSoftplus(self.num_event_types) # learnable mark-specific beta + + self.stack_layers = nn.ModuleList( + [EncoderLayer( + self.d_model, + MultiHeadAttention(self.n_head, self.d_model, self.d_model, self.dropout, + output_linear=False), + + use_residual=False, + dropout=self.dropout + ) for _ in range(self.n_layers)]) + + if self.use_norm: + self.norm = nn.LayerNorm(self.d_model) + + # Equation (12): mu = GELU(h*W_mu) + self.mu = nn.Sequential( + nn.Linear(self.d_model, self.num_event_types, bias=False), + nn.GELU(), + ) + + # Equation (13): eta = GELU(h*W_eta) + self.eta = nn.Sequential( + nn.Linear(self.d_model, self.num_event_types, bias=False), + nn.GELU(), + ) + + # Equation (14): gamma = Softplus(h*W_gamma) + self.gamma = nn.Sequential( + nn.Linear(self.d_model, self.num_event_types, bias=False), + nn.Softplus(), + ) + + def state_decay(self, encode_state, duration_t): + """Equation (15), which computes the pre-intensity states + + Args: + encode_state (tensor): [batch_size, seq_len, hidden_size]. + duration_t (tensor): [batch_size, seq_len, num_sample]. + + Returns: + tensor: hidden states at event times. + """ + mu, eta, gamma = self.mu(encode_state), self.eta(encode_state), self.gamma(encode_state) + + # [batch_size, hidden_dim] + states = mu + (eta - mu) * torch.exp(-gamma * duration_t) + return states + + def forward(self, time_seqs, time_delta_seqs, event_seqs, attention_mask): + """Call the model + + Args: + time_seqs (tensor): [batch_size, seq_len], timestamp seqs. + time_delta_seqs (tensor): [batch_size, seq_len], inter-event time seqs. + event_seqs (tensor): [batch_size, seq_len], event type seqs. + attention_mask (tensor): [batch_size, seq_len, hidden_size], attention masks. + + Returns: + tensor: hidden states at event times. + """ + type_embedding = self.layer_type_emb(event_seqs) + position_embedding = self.layer_position_emb(time_seqs, time_delta_seqs) + + enc_output = type_embedding + position_embedding + + for enc_layer in self.stack_layers: + enc_output = enc_layer( + enc_output, + mask=attention_mask) + if self.use_norm: + enc_output = self.norm(enc_output) + # [batch_size, seq_len, hidden_dim] + return enc_output + + def loglike_loss(self, batch): + """Compute the log-likelihood loss. + + Args: + batch (tuple, list): batch input. + + Returns: + list: loglike loss, num events. + """ + time_seqs, time_delta_seqs, type_seqs, batch_non_pad_mask, attention_mask = batch + + enc_out = self.forward(time_seqs[:, :-1], time_delta_seqs[:, :-1], type_seqs[:, :-1], attention_mask[:, :-1, :-1]) + + cell_t = self.state_decay(encode_state=enc_out, + duration_t=time_delta_seqs[:, 1:, None]) + + # [batch_size, seq_len, num_event_types] + lambda_at_event = self.softplus(cell_t) + + # 2. compute non-event-loglik (using MC sampling to compute integral) + # 2.1 sample times + # [batch_size, seq_len, num_sample] + sample_dtimes = self.make_dtime_loss_samples(time_delta_seqs[:, 1:]) + + # 2.2 compute intensities at sampled times + # [batch_size, num_times = max_len - 1, num_sample, event_num] + state_t_sample = self.compute_states_at_sample_times(encode_state=enc_out, + sample_dtimes=sample_dtimes) + lambda_t_sample = self.softplus(state_t_sample) + + event_ll, non_event_ll, num_events = self.compute_loglikelihood(lambda_at_event=lambda_at_event, + lambdas_loss_samples=lambda_t_sample, + time_delta_seq=time_delta_seqs[:, 1:], + seq_mask=batch_non_pad_mask[:, 1:], + type_seq=type_seqs[:, 1:]) + + # compute loss to minimize + loss = - (event_ll - non_event_ll).sum() + return loss, num_events + + def compute_states_at_sample_times(self, + encode_state, + sample_dtimes): + """Compute the hidden states at sampled times. + + Args: + encode_state (tensor): three tensors with each shape [batch_size, seq_len, hidden_size]. + sample_dtimes (tensor): [batch_size, seq_len, num_samples]. + + Returns: + tensor: [batch_size, seq_len, num_samples, hidden_size], hidden state at each sampled time. + """ + + cell_states = self.state_decay(encode_state[:, :, None, :], + sample_dtimes[:, :, :, None]) + + return cell_states + + def compute_intensities_at_sample_times(self, + time_seqs, + time_delta_seqs, + type_seqs, + sample_dtimes, + **kwargs): + """Compute hidden states at sampled times. + + Args: + time_seqs (tensor): [batch_size, seq_len], times seqs. + time_delta_seqs (tensor): [batch_size, seq_len], time delta seqs. + type_seqs (tensor): [batch_size, seq_len], event type seqs. + sample_dtimes (tensor): [batch_size, seq_len, num_samples], sampled inter-event timestamps. + + Returns: + tensor: [batch_size, seq_len, num_samples, num_event_types], intensity at all sampled times. + """ + + attention_mask = kwargs.get('attention_mask', None) + compute_last_step_only = kwargs.get('compute_last_step_only', False) + + if attention_mask is None: + batch_size, seq_len = time_seqs.size() + attention_mask = torch.triu(torch.ones(seq_len, seq_len, device=self.device), diagonal=1).unsqueeze(0) + attention_mask = attention_mask.expand(batch_size, -1, -1).to(torch.bool) + + # [batch_size, seq_len, num_samples] + enc_out = self.forward(time_seqs, time_delta_seqs, type_seqs, attention_mask) + + # [batch_size, seq_len, num_samples, hidden_size] + encoder_output = self.compute_states_at_sample_times(enc_out, sample_dtimes) + + if compute_last_step_only: + lambdas = self.softplus(encoder_output[:, -1:, :, :]) + else: + # [batch_size, seq_len, num_samples, num_event_types] + lambdas = self.softplus(encoder_output) + return lambdas diff --git a/easy_tpp/model/torch_model/torch_thinning.py b/easy_tpp/model/torch_model/torch_thinning.py new file mode 100644 index 0000000000000000000000000000000000000000..74721957059844c1b9278b57b2013df85679e30a --- /dev/null +++ b/easy_tpp/model/torch_model/torch_thinning.py @@ -0,0 +1,231 @@ +import torch +import torch.nn as nn +from easy_tpp.utils import logger + + +class EventSampler(nn.Module): + """Event Sequence Sampler based on thinning algorithm, which corresponds to Algorithm 2 of + The Neural Hawkes Process: A Neurally Self-Modulating Multivariate Point Process, + https://arxiv.org/abs/1612.09328. + + The implementation uses code from https://github.com/yangalan123/anhp-andtt/blob/master/anhp/esm/thinning.py. + """ + + def __init__(self, num_sample, num_exp, over_sample_rate, num_samples_boundary, dtime_max, patience_counter, + device): + """Initialize the event sampler. + + Args: + num_sample (int): number of sampled next event times via thinning algo for computing predictions. + num_exp (int): number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm + over_sample_rate (float): multiplier for the intensity up bound. + num_samples_boundary (int): number of sampled event times to compute the boundary of the intensity. + dtime_max (float): max value of delta times in sampling + patience_counter (int): the maximum iteration used in adaptive thinning. + device (torch.device): torch device index to select. + """ + super(EventSampler, self).__init__() + self.num_sample = num_sample + self.num_exp = num_exp + self.over_sample_rate = over_sample_rate + self.num_samples_boundary = num_samples_boundary + self.dtime_max = dtime_max + self.patience_counter = patience_counter + self.device = device + + def compute_intensity_upper_bound(self, time_seq, time_delta_seq, event_seq, intensity_fn, + compute_last_step_only): + # logger.critical(f'time_seq: {time_seq}') + # logger.critical(f'time_delta_seq: {time_delta_seq}') + # logger.critical(f'event_seq: {event_seq}') + # logger.critical(f'intensity_fn: {intensity_fn}') + # logger.critical(f'compute_last_step_only: {compute_last_step_only}') + """Compute the upper bound of intensity at each event timestamp. + + Args: + time_seq (tensor): [batch_size, seq_len], timestamp seqs. + time_delta_seq (tensor): [batch_size, seq_len], time delta seqs. + event_seq (tensor): [batch_size, seq_len], event type seqs. + intensity_fn (fn): a function that computes the intensity. + compute_last_step_only (bool): wheter to compute the last time step pnly. + + Returns: + tensor: [batch_size, seq_len] + """ + batch_size, seq_len = time_seq.size() + + # [1, 1, num_samples_boundary] + time_for_bound_sampled = torch.linspace(start=0.0, + end=1.0, + steps=self.num_samples_boundary, + device=self.device)[None, None, :] + + # [batch_size, seq_len, num_samples_boundary] + dtime_for_bound_sampled = time_delta_seq[:, :, None] * time_for_bound_sampled + + # [batch_size, seq_len, num_samples_boundary, event_num] + intensities_for_bound = intensity_fn(time_seq, + time_delta_seq, + event_seq, + dtime_for_bound_sampled, + max_steps=seq_len, + compute_last_step_only=compute_last_step_only) + + # [batch_size, seq_len] + bounds = intensities_for_bound.sum(dim=-1).max(dim=-1)[0] * self.over_sample_rate + + return bounds + + def sample_exp_distribution(self, sample_rate): + """Sample an exponential distribution. + + Args: + sample_rate (tensor): [batch_size, seq_len], intensity rate. + + Returns: + tensor: [batch_size, seq_len, num_exp], exp numbers at each event timestamp. + """ + + batch_size, seq_len = sample_rate.size() + + # For fast approximation, we reuse the rnd for all samples + # [batch_size, seq_len, num_exp] + exp_numbers = torch.empty(size=[batch_size, seq_len, self.num_exp], + dtype=torch.float32, + device=self.device) + + # [batch_size, seq_len, num_exp] + # exp_numbers.exponential_(1.0) + exp_numbers.exponential_(1.0) + + # [batch_size, seq_len, num_exp] + # exp_numbers = torch.tile(exp_numbers, [1, 1, self.num_sample, 1]) + + # [batch_size, seq_len, num_exp] + # div by sample_rate is equivalent to exp(sample_rate), + # see https://en.wikipedia.org/wiki/Exponential_distribution + exp_numbers = exp_numbers / sample_rate[:, :, None] + + return exp_numbers + + def sample_uniform_distribution(self, intensity_upper_bound): + """Sample an uniform distribution + + Args: + intensity_upper_bound (tensor): upper bound intensity computed in the previous step. + + Returns: + tensor: [batch_size, seq_len, num_sample, num_exp] + """ + batch_size, seq_len = intensity_upper_bound.size() + + unif_numbers = torch.empty(size=[batch_size, seq_len, self.num_sample, self.num_exp], + dtype=torch.float32, + device=self.device) + unif_numbers.uniform_(0.0, 1.0) + + return unif_numbers + + def sample_accept(self, unif_numbers, sample_rate, total_intensities, exp_numbers): + """Do the sample-accept process. + + For the accumulated exp (delta) samples drawn for each event timestamp, find (from left to right) the first + that makes the criterion < 1 and accept it as the sampled next-event time. If all exp samples are rejected + (criterion >= 1), then we set the sampled next-event time dtime_max. + + Args: + unif_numbers (tensor): [batch_size, max_len, num_sample, num_exp], sampled uniform random number. + sample_rate (tensor): [batch_size, max_len], sample rate (intensity). + total_intensities (tensor): [batch_size, seq_len, num_sample, num_exp] + exp_numbers (tensor): [batch_size, seq_len, num_sample, num_exp]: sampled exp numbers (delta in Algorithm 2). + + Returns: + result (tensor): [batch_size, seq_len, num_sample], sampled next-event times. + """ + + # [batch_size, max_len, num_sample, num_exp] + criterion = unif_numbers * sample_rate[:, :, None, None] / total_intensities + + # [batch_size, max_len, num_sample, num_exp] + masked_crit_less_than_1 = torch.where(criterion<1,1,0) + + # [batch_size, max_len, num_sample] + non_accepted_filter = (1-masked_crit_less_than_1).all(dim=3) + + # [batch_size, max_len, num_sample] + first_accepted_indexer = masked_crit_less_than_1.argmax(dim=3) + + # [batch_size, max_len, num_sample,1] + # indexer must be unsqueezed to 4D to match the number of dimensions of exp_numbers + result_non_accepted_unfiltered = torch.gather(exp_numbers, 3, first_accepted_indexer.unsqueeze(3)) + + # [batch_size, max_len, num_sample,1] + result = torch.where(non_accepted_filter.unsqueeze(3), torch.tensor(self.dtime_max), result_non_accepted_unfiltered) + + # [batch_size, max_len, num_sample] + result = result.squeeze(dim=-1) + + return result + + def draw_next_time_one_step(self, time_seq, time_delta_seq, event_seq, dtime_boundary, + intensity_fn, compute_last_step_only=False): + """Compute next event time based on Thinning algorithm. + + Args: + time_seq (tensor): [batch_size, seq_len], timestamp seqs. + time_delta_seq (tensor): [batch_size, seq_len], time delta seqs. + event_seq (tensor): [batch_size, seq_len], event type seqs. + dtime_boundary (tensor): [batch_size, seq_len], dtime upper bound. + intensity_fn (fn): a function to compute the intensity. + compute_last_step_only (bool, optional): whether to compute last event timestep only. Defaults to False. + + Returns: + tuple: next event time prediction and weight. + """ + # 1. compute the upper bound of the intensity at each timestamp + # the last event has no label (no next event), so we drop it + # [batch_size, seq_len=max_len - 1] + intensity_upper_bound = self.compute_intensity_upper_bound(time_seq, + time_delta_seq, + event_seq, + intensity_fn, + compute_last_step_only) + + # 2. draw exp distribution with intensity = intensity_upper_bound + # we apply fast approximation, i.e., re-use exp sample times for computation + # [batch_size, seq_len, num_exp] + exp_numbers = self.sample_exp_distribution(intensity_upper_bound) + exp_numbers = torch.cumsum(exp_numbers, dim=-1) + + # 3. compute intensity at sampled times from exp distribution + # [batch_size, seq_len, num_exp, event_num] + intensities_at_sampled_times = intensity_fn(time_seq, + time_delta_seq, + event_seq, + exp_numbers, + max_steps=time_seq.size(1), + compute_last_step_only=compute_last_step_only) + + # [batch_size, seq_len, num_exp] + total_intensities = intensities_at_sampled_times.sum(dim=-1) + + # add one dim of num_sample: re-use the intensity for samples for prediction + # [batch_size, seq_len, num_sample, num_exp] + total_intensities = torch.tile(total_intensities[:, :, None, :], [1, 1, self.num_sample, 1]) + + # [batch_size, seq_len, num_sample, num_exp] + exp_numbers = torch.tile(exp_numbers[:, :, None, :], [1, 1, self.num_sample, 1]) + + # 4. draw uniform distribution + # [batch_size, seq_len, num_sample, num_exp] + unif_numbers = self.sample_uniform_distribution(intensity_upper_bound) + + # 5. find out accepted intensities + # [batch_size, seq_len, num_sample] + res = self.sample_accept(unif_numbers, intensity_upper_bound, total_intensities, exp_numbers) + + # [batch_size, seq_len, num_sample] + weights = torch.ones_like(res)/res.shape[2] + + # add a upper bound here in case it explodes, e.g., in ODE models + return res.clamp(max=1e5), weights diff --git a/easy_tpp/model/torch_model/torch_thp.py b/easy_tpp/model/torch_model/torch_thp.py new file mode 100644 index 0000000000000000000000000000000000000000..a5effc23ba9070f6b45ef67ea107c0cbe0b8a193 --- /dev/null +++ b/easy_tpp/model/torch_model/torch_thp.py @@ -0,0 +1,194 @@ +import torch +import torch.nn as nn + +from easy_tpp.model.torch_model.torch_baselayer import EncoderLayer, MultiHeadAttention, TimePositionalEncoding, ScaledSoftplus +from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel + + +class THP(TorchBaseModel): + """Torch implementation of Transformer Hawkes Process, ICML 2020, https://arxiv.org/abs/2002.09291. + Note: Part of the code is collected from https://github.com/yangalan123/anhp-andtt/tree/master/thp. + """ + + def __init__(self, model_config): + """Initialize the model + + Args: + model_config (EasyTPP.ModelConfig): config of model specs. + """ + super(THP, self).__init__(model_config) + self.d_model = model_config.hidden_size + self.d_time = model_config.time_emb_size + self.use_norm = model_config.use_ln + + self.n_layers = model_config.num_layers + self.n_head = model_config.num_heads + self.dropout = model_config.dropout_rate + + self.layer_temporal_encoding = TimePositionalEncoding(self.d_model, device=self.device) + + self.factor_intensity_base = nn.Parameter(torch.empty([1, self.num_event_types], device=self.device)) + self.factor_intensity_decay = nn.Parameter(torch.empty([1, self.num_event_types], device=self.device)) + nn.init.xavier_normal_(self.factor_intensity_base) + nn.init.xavier_normal_(self.factor_intensity_decay) + + # convert hidden vectors into event-type-sized vector + self.layer_intensity_hidden = nn.Linear(self.d_model, self.num_event_types) + self.softplus = ScaledSoftplus(self.num_event_types) # learnable mark-specific beta + + # Add MLP layer + # Equation (5) + self.feed_forward = nn.Sequential( + nn.Linear(self.d_model, self.d_model * 2), + nn.ReLU(), + nn.Linear(self.d_model * 2, self.d_model) + ) + + self.stack_layers = nn.ModuleList( + [EncoderLayer( + self.d_model, + MultiHeadAttention(self.n_head, self.d_model, self.d_model, self.dropout, + output_linear=False), + use_residual=False, + feed_forward=self.feed_forward, + dropout=self.dropout + ) for _ in range(self.n_layers)]) + + def forward(self, time_seqs, type_seqs, attention_mask): + """Call the model + + Args: + time_seqs (tensor): [batch_size, seq_len], timestamp seqs. + type_seqs (tensor): [batch_size, seq_len], event type seqs. + attention_mask (tensor): [batch_size, seq_len, hidden_size], attention masks. + + Returns: + tensor: hidden states at event times. + """ + # [batch_size, seq_len, hidden_size] + tem_enc = self.layer_temporal_encoding(time_seqs) + enc_output = self.layer_type_emb(type_seqs) + + # [batch_size, seq_len, hidden_size] + for enc_layer in self.stack_layers: + enc_output += tem_enc + enc_output = enc_layer( + enc_output, + mask=attention_mask) + + return enc_output + + def loglike_loss(self, batch): + """Compute the loglike loss. + + Args: + batch (tuple, list): batch input. + + Returns: + tuple: loglike loss, num events. + """ + time_seqs, time_delta_seqs, type_seqs, batch_non_pad_mask, attention_mask = batch + + # 1. compute event-loglik + # [batch_size, seq_len, hidden_size] + enc_out = self.forward(time_seqs[:, :-1], type_seqs[:, :-1], attention_mask[:, :-1, :-1]) + + # [batch_size, seq_len, num_event_types] + # update time decay based on Equation (6) + # [1, 1, num_event_types] + factor_intensity_decay = self.factor_intensity_decay[None, ...] + factor_intensity_base = self.factor_intensity_base[None, ...] + + # update time decay based on Equation (6) + # [batch_size, seq_len, num_event_types] + intensity_states = factor_intensity_decay * time_delta_seqs[:, 1:, None] + self.layer_intensity_hidden( + enc_out) + factor_intensity_base + + lambda_at_event = self.softplus(intensity_states) + + # 2. compute non-event-loglik (using MC sampling to compute integral) + # 2.1 sample dtimes + # [batch_size, seq_len, num_sample] + sample_dtimes = self.make_dtime_loss_samples(time_delta_seqs[:, 1:]) + + # 2.2 compute intensities at sampled times + # [batch_size, num_times = max_len - 1, num_sample, event_num] + state_t_sample = self.compute_states_at_sample_times(event_states=enc_out, + sample_dtimes=sample_dtimes) + lambda_t_sample = self.softplus(state_t_sample) + + event_ll, non_event_ll, num_events = self.compute_loglikelihood(lambda_at_event=lambda_at_event, + lambdas_loss_samples=lambda_t_sample, + time_delta_seq=time_delta_seqs[:, 1:], + seq_mask=batch_non_pad_mask[:, 1:], + type_seq=type_seqs[:, 1:]) + + # compute loss to minimize + loss = - (event_ll - non_event_ll).sum() + return loss, num_events + + def compute_states_at_sample_times(self, event_states, sample_dtimes): + """Compute the hidden states at sampled times. + + Args: + event_states (tensor): [batch_size, seq_len, hidden_size]. + sample_dtimes (tensor): [batch_size, seq_len, num_samples]. + + Returns: + tensor: hidden state at each sampled time. + """ + # [batch_size, seq_len, 1, hidden_size] + event_states = event_states[:, :, None, :] + + # [batch_size, seq_len, num_samples, 1] + sample_dtimes = sample_dtimes[..., None] + + # [1, 1, 1, num_event_types] + factor_intensity_decay = self.factor_intensity_decay[None, None, ...] + factor_intensity_base = self.factor_intensity_base[None, None, ...] + + # update time decay based on Equation (6) + # [batch_size, seq_len, num_samples, num_event_types] + intensity_states = factor_intensity_decay * sample_dtimes + self.layer_intensity_hidden( + event_states) + factor_intensity_base + + return intensity_states + + def compute_intensities_at_sample_times(self, + time_seqs, + time_delta_seqs, + type_seqs, + sample_dtimes, + **kwargs): + """Compute hidden states at sampled times. + + Args: + time_seqs (tensor): [batch_size, seq_len], times seqs. + time_delta_seqs (tensor): [batch_size, seq_len], time delta seqs. + type_seqs (tensor): [batch_size, seq_len], event type seqs. + sample_dtimes (tensor): [batch_size, seq_len, num_samples], sampled inter-event timestamps. + + Returns: + tensor: [batch_size, seq_len, num_samples, num_event_types], intensity at all sampled times. + """ + + attention_mask = kwargs.get('attention_mask', None) + compute_last_step_only = kwargs.get('compute_last_step_only', False) + + if attention_mask is None: + batch_size, seq_len = time_seqs.size() + attention_mask = torch.triu(torch.ones(seq_len, seq_len, device=self.device), diagonal=1).unsqueeze(0) + attention_mask = attention_mask.expand(batch_size, -1, -1).to(torch.bool) + + # [batch_size, seq_len, num_samples] + enc_out = self.forward(time_seqs, type_seqs, attention_mask) + + # [batch_size, seq_len, num_samples, hidden_size] + encoder_output = self.compute_states_at_sample_times(enc_out, sample_dtimes) + + if compute_last_step_only: + lambdas = self.softplus(encoder_output[:, -1:, :, :]) + else: + # [batch_size, seq_len, num_samples, num_event_types] + lambdas = self.softplus(encoder_output) + return lambdas diff --git a/easy_tpp/preprocess/ROBERT_FEATURES_GUIDE.md b/easy_tpp/preprocess/ROBERT_FEATURES_GUIDE.md new file mode 100644 index 0000000000000000000000000000000000000000..0f6c1d1f19391b84c9c0991ad70b72abd7fecf2f --- /dev/null +++ b/easy_tpp/preprocess/ROBERT_FEATURES_GUIDE.md @@ -0,0 +1,183 @@ +# 评论罗伯特特征使用指南 + +## 概述 + +本指南说明如何在EasyTPP框架中使用RobotTHP模型,并加载语义特征、偏差特征等自定义特征。 + +## 文件说明 + +1. **robert_dataset.py**: 扩展的TPPDataset,支持加载语义特征、偏差特征等 +2. **robert_tokenizer.py**: 扩展的EventTokenizer,支持自定义特征的padding和批处理 +3. **train_robot_thp_with_features.py**: 完整的使用示例 + +## 数据格式 + +### 输入数据字典格式 + +```python +data_dict = { + 'time_seqs': [[0.0, 10.5, 25.3], ...], # 时间序列列表 + 'type_seqs': [[0, 1, 2], ...], # 事件类型序列列表 + 'time_delta_seqs': [[0.0, 10.5, 14.8], ...], # 时间间隔序列列表 + 'semantic_vectors': [[[0.1]*768, [0.2]*768], ...], # 语义向量(可选) + 'deviation_features': [[[0.0, 0.0, 0.0], [0.7, 0.5, 0.3]], ...], # 偏差特征(可选) + 'is_spontaneous': [[-1.0, 1.0, -1.0], ...] # 自发/被@标记(可选) +} +``` + +### 特征说明 + +- **semantic_vectors**: `[num_seqs, seq_len, semantic_dim]`,BERT语义向量 +- **deviation_features**: `[num_seqs, seq_len, 3]`,偏差特征 [语境偏差, 情感偏差, 困惑度] +- **is_spontaneous**: `[num_seqs, seq_len]`,标记值: + - `-1.0`: 不适用(非罗伯特评论) + - `0.0`: 被@(罗伯特被原帖作者@) + - `1.0`: 自发(罗伯特自发评论) + +## 使用方法 + +### 1. 准备数据 + +```python +from easy_tpp.preprocess.robert_dataset import RobertTPPDataset + +# 准备数据字典 +data_dict = { + 'time_seqs': [...], + 'type_seqs': [...], + 'time_delta_seqs': [...], + 'semantic_vectors': [...], # 可选 + 'deviation_features': [...], # 可选 + 'is_spontaneous': [...] # 可选 +} + +# 创建数据集 +dataset = RobertTPPDataset(data_dict) +``` + +### 2. 创建分词器 + +```python +from easy_tpp.preprocess.robert_tokenizer import RobertEventTokenizer +from easy_tpp.config_factory import DataSpecConfig + +config = DataSpecConfig.parse_from_yaml_config({ + 'num_event_types': 4, + 'batch_size': 32, + 'pad_token_id': 4 +}) + +tokenizer = RobertEventTokenizer( + config, + use_semantic=True, # 是否使用语义特征 + use_deviation=True, # 是否使用偏差特征 + semantic_dim=768 # 语义向量维度 +) +``` + +### 3. 创建数据加载器 + +```python +from easy_tpp.preprocess.data_collator import TPPDataCollator +from torch.utils.data import DataLoader + +data_collator = TPPDataCollator( + tokenizer=tokenizer, + return_tensors='pt', + max_length=tokenizer.model_max_length, + padding=True, + truncation=False +) + +data_loader = DataLoader( + dataset, + collate_fn=data_collator, + batch_size=32, + shuffle=True +) +``` + +### 4. 在模型中使用 + +RobotTHP模型的`loglike_loss`方法会自动从batch中提取这些特征: + +```python +from easy_tpp.model import TorchRobotTHP + +model = TorchRobotTHP(model_config) + +for batch in data_loader: + batch_values = batch.values() # 转换为tuple/list + loss, num_events = model.loglike_loss(batch_values) +``` + +## 批次格式 + +批次数据格式(tuple/list): +```python +batch = ( + time_seqs, # [0] [batch_size, seq_len] + time_delta_seqs, # [1] [batch_size, seq_len] + type_seqs, # [2] [batch_size, seq_len] + batch_non_pad_mask, # [3] [batch_size, seq_len] + attention_mask, # [4] [batch_size, seq_len, seq_len] + semantic_vectors, # [5] [batch_size, seq_len, semantic_dim] (可选) + deviation_features, # [6] [batch_size, seq_len, 3] (可选) + is_spontaneous, # [7] [batch_size, seq_len] (可选) + structure_mask # [8] [batch_size, seq_len, seq_len] (可选) +) +``` + +## 完整示例 + +参考 `examples/train_robot_thp_with_features.py` 获取完整的使用示例。 + +## 注意事项 + +1. **特征对齐**: 确保所有特征序列的长度与时间序列一致 +2. **Padding值**: + - 语义向量和偏差特征:padding使用0.0 + - is_spontaneous:padding使用-1.0(不适用) +3. **可选特征**: 如果某个特征未提供,模型会自动跳过该特征的处理 +4. **配置一致性**: 确保模型配置中的`use_semantic`和`use_deviation`与tokenizer设置一致 + +## 与标准EasyTPP的集成 + +要完全集成到EasyTPP框架中,需要: + +1. **自定义数据加载器**: 继承`TPPDataLoader`并重写`_build_input_from_json`方法 +2. **配置文件**: 在配置文件中指定使用自定义数据集和分词器 +3. **模型配置**: 设置`use_semantic=True`和`use_deviation=True` + +## 从JSON文件加载 + +如果你的数据是JSON格式,可以参考以下方式加载: + +```python +import json +import numpy as np + +# 加载JSON数据 +with open('your_data.json', 'r') as f: + json_data = json.load(f) + +# 提取特征 +time_seqs = [[event['time_since_start'] for event in seq] for seq in json_data] +type_seqs = [[event['type_event'] for event in seq] for seq in json_data] +time_delta_seqs = [[event['time_since_last_event'] for event in seq] for seq in json_data] + +# 提取语义特征(如果存在) +semantic_vectors = None +if 'semantic_vectors' in json_data[0][0]: + semantic_vectors = [[event['semantic_vectors'] for event in seq] for seq in json_data] + +# 创建数据字典 +data_dict = { + 'time_seqs': time_seqs, + 'type_seqs': type_seqs, + 'time_delta_seqs': time_delta_seqs, + 'semantic_vectors': semantic_vectors, + # ... 其他特征 +} +``` + diff --git a/easy_tpp/preprocess/__init__.py b/easy_tpp/preprocess/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..681f022bf725f2322d0e751c78e226bd44189272 --- /dev/null +++ b/easy_tpp/preprocess/__init__.py @@ -0,0 +1,6 @@ +from easy_tpp.preprocess.data_loader import TPPDataLoader, EventTokenizer, TPPDataset, get_data_loader + +__all__ = ['TPPDataLoader', + 'EventTokenizer', + 'TPPDataset', + 'get_data_loader'] diff --git a/easy_tpp/preprocess/data_collator.py b/easy_tpp/preprocess/data_collator.py new file mode 100644 index 0000000000000000000000000000000000000000..d1515b2d082c59321adedf7497624a84f57cff39 --- /dev/null +++ b/easy_tpp/preprocess/data_collator.py @@ -0,0 +1,49 @@ +from dataclasses import dataclass +from typing import Union, Optional + +from easy_tpp.preprocess.event_tokenizer import EventTokenizer +from easy_tpp.utils import PaddingStrategy, TruncationStrategy + + +@dataclass +class TPPDataCollator: + """ + Data collator that will dynamically pad the inputs of event sequences. + + Args: + tokenizer ([`EventTokenizer`]): + The tokenizer used for encoding the data. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding index) + among: + + - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single + sequence is provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths). + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + return_tensors (`str`): + The type of Tensor to return. Allowable values are "np", "pt" and "tf". + """ + + tokenizer: EventTokenizer + padding: Union[bool, str, PaddingStrategy] = True + max_length: Optional[int] = None + truncation: Union[bool, str, TruncationStrategy] = False + return_tensors: str = "pt" + + def __call__(self, features, return_tensors=None): + if return_tensors is None: + return_tensors = self.return_tensors + + batch = self.tokenizer.pad( + features, + padding=self.padding, + max_length=self.max_length, + truncation=self.truncation, + return_tensors=return_tensors, + ) + + return batch diff --git a/easy_tpp/preprocess/data_loader.py b/easy_tpp/preprocess/data_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..09d99df53031f4cdf4a4002ccd57ae6388f0814b --- /dev/null +++ b/easy_tpp/preprocess/data_loader.py @@ -0,0 +1,235 @@ +import matplotlib.pyplot as plt +import numpy as np +from collections import Counter +from easy_tpp.preprocess.dataset import TPPDataset +from easy_tpp.preprocess.dataset import get_data_loader +from easy_tpp.preprocess.event_tokenizer import EventTokenizer +from easy_tpp.utils import load_pickle, py_assert + + +class TPPDataLoader: + def __init__(self, data_config, **kwargs): + """Initialize the dataloader + + Args: + data_config (EasyTPP.DataConfig): data config. + backend (str): backend engine, defaults to 'torch'. + """ + self.data_config = data_config + self.num_event_types = data_config.data_specs.num_event_types + self.backend = kwargs.get('backend', 'torch') + self.kwargs = kwargs + + def build_input(self, source_dir, data_format, split): + """Helper function to load and process dataset based on file format. + + Args: + source_dir (str): Path to dataset directory. + split (str): Dataset split, e.g., 'train', 'dev', 'test'. + + Returns: + dict: Dictionary containing sequences of event times, types, and intervals. + """ + + if data_format == 'pkl': + return self._build_input_from_pkl(source_dir, split) + elif data_format == 'json': + return self._build_input_from_json(source_dir, split) + else: + raise ValueError(f"Unsupported file format: {data_format}") + + def _build_input_from_pkl(self, source_dir, split): + """Load and process data from a pickle file. + + Args: + source_dir (str): Path to the pickle file. + split (str): Dataset split, e.g., 'train', 'dev', 'test'. + + Returns: + dict: Dictionary with processed event sequences. + """ + data = load_pickle(source_dir) + py_assert(data["dim_process"] == self.num_event_types, + ValueError, "Inconsistent dim_process in different splits.") + + source_data = data[split] + return { + 'time_seqs': [[x["time_since_start"] for x in seq] for seq in source_data], + 'type_seqs': [[x["type_event"] for x in seq] for seq in source_data], + 'time_delta_seqs': [[x["time_since_last_event"] for x in seq] for seq in source_data] + } + + def _build_input_from_json(self, source_dir, split): + """Load and process data from a JSON file. + + Args: + source_dir (str): Path to the JSON file or Hugging Face dataset name. + split (str): Dataset split, e.g., 'train', 'dev', 'test'. + + Returns: + dict: Dictionary with processed event sequences. + """ + from datasets import load_dataset + split_mapped = 'validation' if split == 'dev' else split + if source_dir.endswith('.json'): + data = load_dataset('json', data_files={split_mapped: source_dir}, split=split_mapped) + elif source_dir.startswith('easytpp'): + data = load_dataset(source_dir, split=split_mapped) + else: + raise ValueError("Unsupported source directory format for JSON.") + + py_assert(data['dim_process'][0] == self.num_event_types, + ValueError, "Inconsistent dim_process in different splits.") + + return { + 'time_seqs': data['time_since_start'], + 'type_seqs': data['type_event'], + 'time_delta_seqs': data['time_since_last_event'] + } + + def get_loader(self, split='train', **kwargs): + """Get the corresponding data loader. + + Args: + split (str, optional): denote the train, valid and test set. Defaults to 'train'. + num_event_types (int, optional): num of event types in the data. Defaults to None. + + Raises: + NotImplementedError: the input of 'num_event_types' is inconsistent with the data. + + Returns: + EasyTPP.DataLoader: the data loader for tpp data. + """ + data_dir = self.data_config.get_data_dir(split) + data = self.build_input(data_dir, self.data_config.data_format, split) + + dataset = TPPDataset(data) + tokenizer = EventTokenizer(self.data_config.data_specs) + + # Remove 'shuffle' from kwargs if it exists to avoid conflict + shuffle = kwargs.pop('shuffle', self.kwargs.get('shuffle', False)) + + loader = get_data_loader(dataset, + self.backend, + tokenizer, + batch_size=self.kwargs['batch_size'], + shuffle=shuffle, + **kwargs) + + return loader + + def train_loader(self, **kwargs): + """Return the train loader + + Returns: + EasyTPP.DataLoader: data loader for train set. + """ + return self.get_loader('train', **kwargs) + + def valid_loader(self, **kwargs): + """Return the valid loader + + Returns: + EasyTPP.DataLoader: data loader for valid set. + """ + return self.get_loader('dev', **kwargs) + + def test_loader(self, **kwargs): + """Return the test loader + + Returns: + EasyTPP.DataLoader: data loader for test set. + """ + # for test set, we do not shuffle + kwargs['shuffle'] = False + return self.get_loader('test', **kwargs) + + def get_statistics(self, split='train'): + """Get basic statistics about the dataset. + + Args: + split (str): Dataset split, e.g., 'train', 'dev', 'test'. Default is 'train'. + + Returns: + dict: Dictionary containing statistics about the dataset. + """ + data_dir = self.data_config.get_data_dir(split) + data = self.build_input(data_dir, self.data_config.data_format, split) + + num_sequences = len(data['time_seqs']) + sequence_lengths = [len(seq) for seq in data['time_seqs']] + avg_sequence_length = sum(sequence_lengths) / num_sequences + all_event_types = [event for seq in data['type_seqs'] for event in seq] + event_type_counts = Counter(all_event_types) + + # Calculate time_delta_seqs statistics + all_time_deltas = [delta for seq in data['time_delta_seqs'] for delta in seq] + mean_time_delta = np.mean(all_time_deltas) if all_time_deltas else 0 + min_time_delta = np.min(all_time_deltas) if all_time_deltas else 0 + max_time_delta = np.max(all_time_deltas) if all_time_deltas else 0 + + stats = { + "num_sequences": num_sequences, + "avg_sequence_length": avg_sequence_length, + "event_type_distribution": dict(event_type_counts), + "max_sequence_length": max(sequence_lengths), + "min_sequence_length": min(sequence_lengths), + "mean_time_delta": mean_time_delta, + "min_time_delta": min_time_delta, + "max_time_delta": max_time_delta + } + + return stats + + def plot_event_type_distribution(self, split='train'): + """Plot the distribution of event types in the dataset. + + Args: + split (str): Dataset split, e.g., 'train', 'dev', 'test'. Default is 'train'. + """ + stats = self.get_statistics(split) + event_type_distribution = stats['event_type_distribution'] + + plt.figure(figsize=(8, 6)) + plt.bar(event_type_distribution.keys(), event_type_distribution.values(), color='skyblue') + plt.xlabel('Event Types') + plt.ylabel('Frequency') + plt.title(f'Event Type Distribution ({split} set)') + plt.show() + + def plot_event_delta_times_distribution(self, split='train'): + """Plot the distribution of event delta times in the dataset. + + Args: + split (str): Dataset split, e.g., 'train', 'dev', 'test'. Default is 'train'. + """ + data_dir = self.data_config.get_data_dir(split) + data = self.build_input(data_dir, self.data_config.data_format, split) + + # Flatten the time_delta_seqs to get all delta times + all_time_deltas = [delta for seq in data['time_delta_seqs'] for delta in seq] + + plt.figure(figsize=(10, 6)) + plt.hist(all_time_deltas, bins=30, color='skyblue', edgecolor='black') + plt.xlabel('Event Delta Times') + plt.ylabel('Frequency') + plt.title(f'Event Delta Times Distribution ({split} set)') + plt.grid(axis='y', alpha=0.75) + plt.show() + + def plot_sequence_length_distribution(self, split='train'): + """Plot the distribution of sequence lengths in the dataset. + + Args: + split (str): Dataset split, e.g., 'train', 'dev', 'test'. Default is 'train'. + """ + data_dir = self.data_config.get_data_dir(split) + data = self.build_input(data_dir, self.data_config.data_format, split) + sequence_lengths = [len(seq) for seq in data['time_seqs']] + + plt.figure(figsize=(8, 6)) + plt.hist(sequence_lengths, bins=10, color='salmon', edgecolor='black') + plt.xlabel('Sequence Length') + plt.ylabel('Frequency') + plt.title(f'Sequence Length Distribution ({split} set)') + plt.show() diff --git a/easy_tpp/preprocess/dataset.py b/easy_tpp/preprocess/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..9c47c4b066e5888060975796f42e7a50e75d2a54 --- /dev/null +++ b/easy_tpp/preprocess/dataset.py @@ -0,0 +1,80 @@ +import math +from typing import Dict + +import numpy as np +from torch.utils.data import Dataset, DataLoader + +from easy_tpp.preprocess.data_collator import TPPDataCollator +from easy_tpp.preprocess.event_tokenizer import EventTokenizer +from easy_tpp.utils import py_assert + + +class TPPDataset(Dataset): + def __init__(self, data: Dict): + self.data_dict = data + self.time_seqs = self.data_dict['time_seqs'] + self.time_delta_seqs = self.data_dict['time_delta_seqs'] + self.type_seqs = self.data_dict['type_seqs'] + + def __len__(self): + """ + + Returns: length of the dataset + + """ + + py_assert(len(self.time_seqs) == len(self.type_seqs) and len(self.time_delta_seqs) == len(self.type_seqs), + ValueError, + f"Inconsistent lengths for data! time_seq_len:{len(self.time_seqs)}, event_len: " + f"{len(self.type_seqs)}, time_delta_seq_len: {len(self.time_delta_seqs)}") + + return len(self.time_seqs) + + def __getitem__(self, idx): + """ + + Args: + idx: iteration index + + Returns: + dict: a dict of time_seqs, time_delta_seqs and type_seqs element + + """ + return dict({'time_seqs': self.time_seqs[idx], 'time_delta_seqs': self.time_delta_seqs[idx], + 'type_seqs': self.type_seqs[idx]}) + + def get_dt_stats(self): + x_bar, s_2_x, n = 0., 0., 0 + min_dt, max_dt = np.inf, -np.inf + + for dts, marks in zip(self.time_delta_seqs, self.type_seqs): + dts = np.array(dts[1:-1 if marks[-1] == -1 else None]) + min_dt = min(min_dt, dts.min()) + max_dt = max(max_dt, dts.max()) + y_bar = dts.mean() + s_2_y = dts.var() + m = dts.shape[0] + n += m + # Formula taken from https://math.stackexchange.com/questions/3604607/can-i-work-out-the-variance-in-batches + s_2_x = (((n - 1) * s_2_x + (m - 1) * s_2_y) / (n + m - 1)) + ( + (n * m * ((x_bar - y_bar) ** 2)) / ((n + m) * (n + m - 1))) + x_bar = (n * x_bar + m * y_bar) / (n + m) + + print(x_bar, (s_2_x ** 0.5)) + print(f'min_dt: {min_dt}') + print(f'max_dt: {max_dt}') + return x_bar, (s_2_x ** 0.5), min_dt, max_dt + + +def get_data_loader(dataset: TPPDataset, backend: str, tokenizer: EventTokenizer, **kwargs): + assert backend == 'torch', 'Only torch backend is supported.' + padding = True if tokenizer.padding_strategy is None else tokenizer.padding_strategy + truncation = False if tokenizer.truncation_strategy is None else tokenizer.truncation_strategy + data_collator = TPPDataCollator(tokenizer=tokenizer, + return_tensors='pt', + max_length=tokenizer.model_max_length, + padding=padding, + truncation=truncation) + return DataLoader(dataset, + collate_fn=data_collator, + **kwargs) diff --git a/easy_tpp/preprocess/event_tokenizer.py b/easy_tpp/preprocess/event_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..acaedb487803569b2f7855b6918cfe15b7ea4b69 --- /dev/null +++ b/easy_tpp/preprocess/event_tokenizer.py @@ -0,0 +1,543 @@ +import copy +from collections import UserDict +from typing import Optional, Union, Dict, Any, List, Mapping + +import numpy as np + +from easy_tpp.utils import logger, TruncationStrategy, PaddingStrategy, \ + TensorType, is_torch_device, requires_backends, is_numpy_array, py_assert + + +class BatchEncoding(UserDict): + """ + Holds the output of the [`~event_tokenizer.EventTokenizer.__call__`], + [`~event_tokenizer.EventTokenizer.encode_plus`] methods (tokens, attention_masks, etc). + + This class is derived from a python dictionary and can be used as a dictionary. + + Args: + data (`dict`): + Dictionary of lists/arrays/tensors returned by the `__call__`/`encode_plus`/`batch_encode_plus` methods + ('input_ids', 'attention_mask', etc.). + tensor_type (`Union[None, str, TensorType]`, *optional*): + You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at + initialization. + prepend_batch_axis (`bool`, *optional*, defaults to `False`): + Whether or not to add a batch axis when converting to tensors (see `tensor_type` above). + n_sequences (`Optional[int]`, *optional*): + You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at + initialization. + """ + + def __init__( + self, + data: Optional[Dict[str, Any]] = None, + tensor_type: Union[None, str, TensorType] = None, + prepend_batch_axis: bool = False + ): + super().__init__(data) + + self.convert_to_tensors(tensor_type=tensor_type, prepend_batch_axis=prepend_batch_axis) + + def keys(self): + return self.data.keys() + + def values(self): + return list(self.data.values()) + + def items(self): + return self.data.items() + + def convert_to_tensors( + self, tensor_type: Optional[Union[str, TensorType]] = None, prepend_batch_axis: bool = False + ): + """ + Convert the inner content to tensors. + + Args: + tensor_type (`str` or [`~utils.TensorType`], *optional*): + The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If + `None`, no modification is done. + prepend_batch_axis (`int`, *optional*, defaults to `False`): + Whether or not to add the batch dimension during the conversion. + """ + if tensor_type is None: + return self + + # Convert to TensorType + if not isinstance(tensor_type, TensorType): + tensor_type = TensorType(tensor_type) + + # Get a function reference for the correct framework + if tensor_type == TensorType.PYTORCH: + import torch + + as_tensor = torch.tensor + is_tensor = torch.is_tensor + else: + as_tensor = np.asarray + is_tensor = is_numpy_array + + # Do the tensor conversion in batch + for key, value in self.items(): + try: + if prepend_batch_axis: + value = [value] + + if not is_tensor(value): + tensor = as_tensor(value) + + self[key] = tensor + except Exception as e: + if key == "overflowing_tokens": + raise ValueError( + "Unable to create tensor returning overflowing tokens of different lengths. " + "Please see if a fast version of this tokenizer is available to have this feature available." + ) from e + raise ValueError( + "Unable to create tensor, you should probably activate truncation and/or padding with" + " 'padding=True' 'truncation=True' to have batched tensors with the same length. Perhaps your" + f" features (`{key}` in this case) have excessive nesting (inputs type `list` where type `int` is" + " expected)." + ) from e + + return self + + def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding": + """ + Send all values to device by calling `v.to(device)` (PyTorch only). + + Args: + device (`str` or `torch.device`): The device to put the tensors on. + + Returns: + [`BatchEncoding`]: The same instance after modification. + """ + requires_backends(self, ["torch"]) + + # This check catches things like APEX blindly calling "to" on all inputs to a module + # Otherwise it passes the casts down and casts the LongTensor containing the token idxs + # into a HalfTensor + if isinstance(device, str) or is_torch_device(device) or isinstance(device, int): + self.data = {k: v.to(device=device) for k, v in self.data.items()} + else: + logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.") + return self + + +class EventTokenizer: + """ + Base class for tokenizer event sequences, vendored from huggingface/transformer + """ + padding_side: str = "right" + truncation_side: str = "right" + model_input_names: List[str] = ["time_seqs", "time_delta_seqs", "type_seqs", "seq_non_pad_mask", "attention_mask"] + + def __init__(self, config): + config = copy.deepcopy(config) + self.num_event_types = config.num_event_types + self.pad_token_id = config.pad_token_id + + self.model_max_length = config.max_len + + self.padding_strategy = config.padding_strategy + self.truncation_strategy = config.truncation_strategy + + # Padding and truncation side are right by default and overridden in subclasses. If specified in the kwargs, it + # is changed. + self.padding_side = config.pop("padding_side", self.padding_side) + self.truncation_side = config.pop("truncation_side", self.truncation_side) + self.model_input_names = config.pop("model_input_names", self.model_input_names) + + def _get_padding_truncation_strategies( + self, padding=False, truncation=None, max_length=None, verbose=False, **kwargs + ): + padding_strategy, truncation_strategy = None, None + # If you only set max_length, it activates truncation for max_length + if max_length is not None and padding is False and truncation is None: + if verbose: + logger.warning( + "Truncation was not explicitly activated but `max_length` is provided a specific value, please" + " use `truncation=True` to explicitly truncate examples to max length. Defaulting to" + " 'longest_first' truncation strategy" + ) + truncation = "longest_first" + + # Get padding strategy + if padding is False: + if max_length is None: + padding_strategy = PaddingStrategy.LONGEST + else: + padding_strategy = PaddingStrategy.MAX_LENGTH + elif padding is not False: + if padding is True: + if verbose: + if max_length is not None and ( + truncation is None or truncation is False or truncation == "do_not_truncate" + ): + logger.warn( + "`max_length` is ignored when `padding`=`True` and there is no truncation strategy. " + "To pad to max length, use `padding='max_length'`." + ) + padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch + elif not isinstance(padding, PaddingStrategy): + padding_strategy = PaddingStrategy(padding) + elif isinstance(padding, PaddingStrategy): + padding_strategy = padding + else: + padding_strategy = PaddingStrategy.DO_NOT_PAD + + # Get truncation strategy + if truncation is not None and truncation is not False: + if truncation is True: + truncation_strategy = ( + TruncationStrategy.LONGEST_FIRST + ) # Default to truncate the longest sequences in pairs of inputs + elif not isinstance(truncation, TruncationStrategy): + truncation_strategy = TruncationStrategy(truncation) + elif isinstance(truncation, TruncationStrategy): + truncation_strategy = truncation + else: + truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE + + # Set max length if needed + if max_length is None: + if padding_strategy == PaddingStrategy.MAX_LENGTH: + max_length = self.model_max_length + if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE: + max_length = self.model_max_length + + # Test if we have a padding token + if padding_strategy != PaddingStrategy.DO_NOT_PAD and (not self.pad_token_id): + raise ValueError( + "Asking to pad but the tokenizer does not have a padding token. " + "Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` " + "or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`." + ) + + return padding_strategy, truncation_strategy, max_length, kwargs + + def _truncate(self, + encoded_inputs: Union[Dict[str, Any], + Dict[str, List]], + truncation_strategy: TruncationStrategy, + truncation_side: str, + max_length: Optional[int] = None): + if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE: + py_assert(max_length is not None, ValueError, 'must pass max_length when truncation is activated!') + for k, v in encoded_inputs.items(): + seq_ = [seq[:max_length] for seq in v] if truncation_side == 'right' \ + else [seq[-max_length:] for seq in v] + encoded_inputs[k] = seq_ + + return encoded_inputs + + def pad( + self, + encoded_inputs: Union[ + Dict[str, Any], + Dict[str, List], + ], + padding: Union[bool, str, PaddingStrategy] = True, + truncation: Union[bool, str, TruncationStrategy] = False, + max_length: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + verbose: bool = False, + ) -> BatchEncoding: + """ + Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length + in the batch. + + Padding side (left/right) padding token ids are defined at the tokenizer level (with `self.padding_side`, + `self.pad_token_id` and `self.pad_token_type_id`). + + Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the + text followed by a call to the `pad` method to get a padded encoding. + + + + If the `encoded_inputs` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the + result will use the same type unless you provide a different tensor type with `return_tensors`. In the case of + PyTorch tensors, you will lose the specific device of your tensors however. + + + + Args: + encoded_inputs ([`BatchEncoding`], list of [`BatchEncoding`]: + Tokenized inputs. Can represent one input ([`BatchEncoding`] or `Dict[str, List[int]]`) or a batch of + tokenized inputs (list of [`BatchEncoding`], *Dict[str, List[List[int]]]* or *List[Dict[str, + List[int]]]*) so you can use this method during preprocessing as well as in a PyTorch Dataloader + collate function. + + Instead of `List[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors), see + the note above for the return type. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + return_attention_mask (`bool`, *optional*): + Whether to return the attention mask. If left to the default, will return the attention mask according + to the specific tokenizer's default, defined by the `return_outputs` attribute. + + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. + verbose (`bool`, *optional*, defaults to `True`): + Whether or not to print more information and warnings. + """ + + # If we have a list of dicts, let's convert it in a dict of lists + # We do this to allow using this method as a collate_fn function in PyTorch Dataloader + if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], Mapping): + encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()} + + # The model's main input name, usually `time_seqs`, has be passed for padding + if self.model_input_names[0] not in encoded_inputs: + raise ValueError( + "You should supply an encoding or a list of encodings to this method " + f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}" + ) + + required_input = encoded_inputs[self.model_input_names[0]] + + padding_strategy, truncation_strategy, max_length, _ = self._get_padding_truncation_strategies( + padding=padding, max_length=max_length, truncation=truncation, verbose=verbose + ) + + encoded_inputs = self._truncate(encoded_inputs, + truncation_strategy=truncation_strategy, + max_length=max_length, + truncation_side=self.truncation_side) + + batch_size = len(required_input) + assert all( + len(v) == batch_size for v in encoded_inputs.values() + ), "Some items in the output dictionary have a different batch size than others." + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = max(len(inputs) for inputs in required_input) + padding_strategy = PaddingStrategy.MAX_LENGTH + + batch_output = self._pad( + encoded_inputs, + max_length=max_length, + padding_strategy=padding_strategy, + return_attention_mask=return_attention_mask, + ) + + return BatchEncoding(batch_output, tensor_type=return_tensors) + + def _pad( + self, + encoded_inputs: Union[Dict[str, Any], BatchEncoding], + max_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + return_attention_mask: Optional[bool] = None, + ) -> dict: + """ + Pad encoded inputs (on left/right and up to predefined length or max length in the batch) + + Args: + encoded_inputs: + Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). + max_length: maximum length of the returned list and optionally padding length (see below). + Will truncate by taking into account the special tokens. + padding_strategy: PaddingStrategy to use for padding. + + - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The tokenizer padding sides are defined in self.padding_side: + + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability + `>= 7.5` (Volta). + return_attention_mask: + (optional) Set to False to avoid returning attention mask (default: set to model specifics) + """ + # Load from model defaults + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + required_input = encoded_inputs[self.model_input_names[0]] + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = len(required_input) + + # check whether we need to pad it + seq_lens = np.array([len(seq) for seq in required_input]) + is_all_seq_equal_max_length = np.all(seq_lens == max_length) + needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and ~is_all_seq_equal_max_length + + batch_output = dict() + + if needs_to_be_padded: + # time seqs + batch_output[self.model_input_names[0]] = self.make_pad_sequence(encoded_inputs[self.model_input_names[0]], + self.pad_token_id, + padding_side=self.padding_side, + max_len=max_length) + # time_delta seqs + batch_output[self.model_input_names[1]] = self.make_pad_sequence(encoded_inputs[self.model_input_names[1]], + self.pad_token_id, + padding_side=self.padding_side, + max_len=max_length) + # type_seqs + batch_output[self.model_input_names[2]] = self.make_pad_sequence(encoded_inputs[self.model_input_names[2]], + self.pad_token_id, + padding_side=self.padding_side, + max_len=max_length, + dtype=np.int64) + else: + batch_output[self.model_input_names[0]] = np.array(encoded_inputs[self.model_input_names[0]], dtype=np.float32) + batch_output[self.model_input_names[1]] = np.array(encoded_inputs[self.model_input_names[1]], dtype=np.float32) + batch_output[self.model_input_names[2]] = np.array(encoded_inputs[self.model_input_names[2]], dtype=np.int64) + + # non_pad_mask; replaced the use of event types by using the original sequence length + seq_pad_mask = np.full_like(batch_output[self.model_input_names[2]], fill_value=True, dtype=bool) + for i, seq_len in enumerate(seq_lens): + seq_pad_mask[i, seq_len:] = False + batch_output[self.model_input_names[3]] = seq_pad_mask + + if return_attention_mask: + # attention_mask + batch_output[self.model_input_names[4]] = self.make_attn_mask_for_pad_sequence( + batch_output[self.model_input_names[2]], + self.pad_token_id) + else: + batch_output[self.model_input_names[4]] = [] + + return batch_output + + @staticmethod + def make_pad_sequence(seqs, + pad_token_id, + padding_side, + max_len, + dtype=np.float32, + group_by_event_types=False): + """Pad the sequence batch-wise. + + Args: + seqs (list): list of sequences with variational length + pad_token_id (int, float): optional, a value that used to pad the sequences. If None, then the pad index + is set to be the event_num_with_pad + max_len (int): optional, the maximum length of the sequence after padding. If None, then the + length is set to be the max length of all input sequences. + pad_at_end (bool): optional, whether to pad the sequnce at the end. If False, + the sequence is pad at the beginning + + Returns: + a numpy array of padded sequence + + + Example: + ```python + seqs = [[0, 1], [3, 4, 5]] + pad_sequence(seqs, 100) + >>> [[0, 1, 100], [3, 4, 5]] + + pad_sequence(seqs, 100, max_len=5) + >>> [[0, 1, 100, 100, 100], [3, 4, 5, 100, 100]] + ``` + + """ + if not group_by_event_types: + if padding_side == "right": + pad_seq = np.array([seq + [pad_token_id] * (max_len - len(seq)) for seq in seqs], dtype=dtype) + else: + pad_seq = np.array([[pad_token_id] * (max_len - len(seq)) + seq for seq in seqs], dtype=dtype) + else: + pad_seq = [] + for seq in seqs: + if padding_side == "right": + pad_seq.append(np.array([s + [pad_token_id] * (max_len - len(s)) for s in seq], dtype=dtype)) + else: + pad_seq.append(np.array([[pad_token_id] * (max_len - len(s)) + s for s in seqs], dtype=dtype)) + + pad_seq = np.array(pad_seq) + return pad_seq + + def make_attn_mask_for_pad_sequence(self, pad_seqs, pad_token_id): + """Make the attention masks for the sequence. + + Args: + pad_seqs (tensor): list of sequences that have been padded with fixed length + pad_token_id (int): optional, a value that used to pad the sequences. If None, then the pad index + is set to be the event_num_with_pad + + Returns: + np.array: a bool matrix of the same size of input, denoting the masks of the + sequence (True: non mask, False: mask) + + + Example: + ```python + seqs = [[ 1, 6, 0, 7, 12, 12], + [ 1, 0, 5, 1, 10, 9]] + make_attn_mask_for_pad_sequence(seqs, pad_index=12) + >>> + batch_non_pad_mask + ([[ True, True, True, True, False, False], + [ True, True, True, True, True, True]]) + attention_mask + [[[ False True True True True True] + [False False True True True True] + [False False False True True True] + [False False False False True True] + [False False False False True True] + [False False False False True True]] + + [[False True True True True True] + [False False True True True True] + [False False False True True True] + [False False False False True True] + [False False False False False True] + [False False False False False False]]] + ``` + + + """ + + seq_num, seq_len = pad_seqs.shape + + # [batch_size, seq_len] + seq_pad_mask = pad_seqs == pad_token_id + + # [batch_size, seq_len, seq_len] + attention_key_pad_mask = np.tile(seq_pad_mask[:, None, :], (1, seq_len, 1)) + subsequent_mask = np.tile(np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)[None, :, :], (seq_num, 1, 1)) + + attention_mask = subsequent_mask | attention_key_pad_mask + + return attention_mask + + def make_type_mask_for_pad_sequence(self, pad_seqs): + """Make the type mask. + + Args: + pad_seqs (tensor): a list of sequence events with equal length (i.e., padded sequence) + + Returns: + np.array: a 3-dim matrix, where the last dim (one-hot vector) indicates the type of event + + """ + type_mask = np.zeros([*pad_seqs.shape, self.num_event_types], dtype=np.int32) + for i in range(self.num_event_types): + type_mask[:, :, i] = pad_seqs == i + + return type_mask diff --git a/easy_tpp/preprocess/robert_dataset.py b/easy_tpp/preprocess/robert_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..9a2294dbaf47cfe216d0c3a911fdc2812a77b5f0 --- /dev/null +++ b/easy_tpp/preprocess/robert_dataset.py @@ -0,0 +1,95 @@ +""" +评论罗伯特数据集 - 支持语义特征、偏差特征等 + +扩展TPPDataset以支持多模态特征 +""" + +from typing import Dict, Optional +import numpy as np +from easy_tpp.preprocess.dataset import TPPDataset +from easy_tpp.utils import py_assert + + +class RobertTPPDataset(TPPDataset): + """ + 支持语义特征、偏差特征等的TPP数据集 + + 扩展标准TPPDataset以支持: + - semantic_vectors: 语义向量列表 + - deviation_features: 偏差特征列表 + - is_spontaneous: 自发/被@标记列表 + """ + + def __init__(self, data: Dict): + """ + 初始化数据集 + + Args: + data: 数据字典,包含: + - time_seqs: 时间序列列表 + - type_seqs: 事件类型序列列表 + - time_delta_seqs: 时间间隔序列列表 + - semantic_vectors: 语义向量列表(可选)[num_seqs, seq_len, semantic_dim] + - deviation_features: 偏差特征列表(可选)[num_seqs, seq_len, 3] + - is_spontaneous: 自发/被@标记列表(可选)[num_seqs, seq_len] + """ + super(RobertTPPDataset, self).__init__(data) + + # 可选特征 + self.semantic_vectors = self.data_dict.get('semantic_vectors', None) + self.deviation_features = self.data_dict.get('deviation_features', None) + self.is_spontaneous = self.data_dict.get('is_spontaneous', None) + + # 验证数据一致性 + if self.semantic_vectors is not None: + py_assert( + len(self.semantic_vectors) == len(self.time_seqs), + ValueError, + f"Inconsistent lengths: semantic_vectors={len(self.semantic_vectors)}, " + f"time_seqs={len(self.time_seqs)}" + ) + + if self.deviation_features is not None: + py_assert( + len(self.deviation_features) == len(self.time_seqs), + ValueError, + f"Inconsistent lengths: deviation_features={len(self.deviation_features)}, " + f"time_seqs={len(self.time_seqs)}" + ) + + if self.is_spontaneous is not None: + py_assert( + len(self.is_spontaneous) == len(self.time_seqs), + ValueError, + f"Inconsistent lengths: is_spontaneous={len(self.is_spontaneous)}, " + f"time_seqs={len(self.time_seqs)}" + ) + + def __getitem__(self, idx): + """ + 获取单个样本 + + Args: + idx: 样本索引 + + Returns: + dict: 包含时间、类型、可选特征的字典 + """ + item = { + 'time_seqs': self.time_seqs[idx], + 'time_delta_seqs': self.time_delta_seqs[idx], + 'type_seqs': self.type_seqs[idx] + } + + # 添加可选特征 + if self.semantic_vectors is not None: + item['semantic_vectors'] = self.semantic_vectors[idx] + + if self.deviation_features is not None: + item['deviation_features'] = self.deviation_features[idx] + + if self.is_spontaneous is not None: + item['is_spontaneous'] = self.is_spontaneous[idx] + + return item + diff --git a/easy_tpp/preprocess/robert_tokenizer.py b/easy_tpp/preprocess/robert_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..7edaf7de82b92989be54ac401cc5d68f99869426 --- /dev/null +++ b/easy_tpp/preprocess/robert_tokenizer.py @@ -0,0 +1,261 @@ +""" +评论罗伯特事件分词器 - 支持语义特征、偏差特征等 + +扩展EventTokenizer以支持多模态特征的padding和批处理 +""" + +from typing import Optional, Union, Dict, Any +import numpy as np +import torch +from easy_tpp.preprocess.event_tokenizer import EventTokenizer, BatchEncoding +from easy_tpp.utils import PaddingStrategy + + +class RobertEventTokenizer(EventTokenizer): + """ + 支持语义特征、偏差特征等的事件分词器 + + 扩展EventTokenizer以支持: + - semantic_vectors: 语义向量padding + - deviation_features: 偏差特征padding + - is_spontaneous: 自发/被@标记padding + """ + + def __init__(self, config, use_semantic=False, use_deviation=False, semantic_dim=768): + """ + 初始化分词器 + + Args: + config: 配置对象 + use_semantic: 是否使用语义特征 + use_deviation: 是否使用偏差特征 + semantic_dim: 语义向量维度 + """ + super(RobertEventTokenizer, self).__init__(config) + + self.use_semantic = use_semantic + self.use_deviation = use_deviation + self.semantic_dim = semantic_dim + + # 添加自定义特征到model_input_names + # 标准顺序:time_seqs, time_delta_seqs, type_seqs, batch_non_pad_mask, attention_mask + # 自定义特征添加在后面 + if self.use_semantic: + self.model_input_names.append('semantic_vectors') + if self.use_deviation: + self.model_input_names.append('deviation_features') + + # is_spontaneous总是添加(如果使用) + self.model_input_names.append('is_spontaneous') + + def _pad( + self, + encoded_inputs: Union[Dict[str, Any], BatchEncoding], + max_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + return_attention_mask: Optional[bool] = None, + ) -> dict: + """ + 填充编码输入(包括自定义特征) + + Args: + encoded_inputs: 编码后的输入 + max_length: 最大长度 + padding_strategy: 填充策略 + return_attention_mask: 是否返回注意力掩码 + + Returns: + dict: 填充后的批次数据 + """ + # 先处理标准字段(调用父类方法) + # 但我们需要重写以添加自定义特征的处理 + required_input = encoded_inputs[self.model_input_names[0]] + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = max(len(seq) for seq in required_input) + padding_strategy = PaddingStrategy.MAX_LENGTH + + # 获取序列长度 + seq_lens = np.array([len(seq) for seq in required_input]) + is_all_seq_equal_max_length = np.all(seq_lens == max_length) + needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and ~is_all_seq_equal_max_length + + batch_output = dict() + + # 处理标准字段(time_seqs, time_delta_seqs, type_seqs) + if needs_to_be_padded: + batch_output[self.model_input_names[0]] = self.make_pad_sequence( + encoded_inputs[self.model_input_names[0]], + self.pad_token_id, + padding_side=self.padding_side, + max_len=max_length + ) + batch_output[self.model_input_names[1]] = self.make_pad_sequence( + encoded_inputs[self.model_input_names[1]], + self.pad_token_id, + padding_side=self.padding_side, + max_len=max_length + ) + batch_output[self.model_input_names[2]] = self.make_pad_sequence( + encoded_inputs[self.model_input_names[2]], + self.pad_token_id, + padding_side=self.padding_side, + max_len=max_length, + dtype=np.int64 + ) + else: + batch_output[self.model_input_names[0]] = np.array( + encoded_inputs[self.model_input_names[0]], dtype=np.float32 + ) + batch_output[self.model_input_names[1]] = np.array( + encoded_inputs[self.model_input_names[1]], dtype=np.float32 + ) + batch_output[self.model_input_names[2]] = np.array( + encoded_inputs[self.model_input_names[2]], dtype=np.int64 + ) + + # non_pad_mask + seq_pad_mask = np.full_like(batch_output[self.model_input_names[2]], fill_value=True, dtype=bool) + for i, seq_len in enumerate(seq_lens): + seq_pad_mask[i, seq_len:] = False + batch_output[self.model_input_names[3]] = seq_pad_mask + + # attention_mask + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + if return_attention_mask: + batch_output[self.model_input_names[4]] = self.make_attn_mask_for_pad_sequence( + batch_output[self.model_input_names[2]], + self.pad_token_id + ) + else: + batch_output[self.model_input_names[4]] = [] + + # 处理自定义特征 + # 处理语义向量 + if self.use_semantic and 'semantic_vectors' in encoded_inputs: + semantic_vectors = encoded_inputs['semantic_vectors'] + if needs_to_be_padded: + batch_output['semantic_vectors'] = self.make_pad_sequence_for_features( + semantic_vectors, + pad_value=0.0, + max_len=max_length, + feature_dim=self.semantic_dim + ) + else: + batch_output['semantic_vectors'] = np.array(semantic_vectors, dtype=np.float32) + elif self.use_semantic: + # 如果没有提供但需要,创建零向量 + batch_size = len(required_input) + if needs_to_be_padded: + batch_output['semantic_vectors'] = np.zeros( + (batch_size, max_length, self.semantic_dim), dtype=np.float32 + ) + else: + # 使用最大长度 + max_seq_len = int(seq_lens.max()) + batch_output['semantic_vectors'] = np.zeros( + (batch_size, max_seq_len, self.semantic_dim), dtype=np.float32 + ) + + # 处理偏差特征 + if self.use_deviation and 'deviation_features' in encoded_inputs: + deviation_features = encoded_inputs['deviation_features'] + if needs_to_be_padded: + batch_output['deviation_features'] = self.make_pad_sequence_for_features( + deviation_features, + pad_value=0.0, + max_len=max_length, + feature_dim=3 + ) + else: + batch_output['deviation_features'] = np.array(deviation_features, dtype=np.float32) + elif self.use_deviation: + # 如果没有提供但需要,创建零向量 + batch_size = len(required_input) + if needs_to_be_padded: + batch_output['deviation_features'] = np.zeros( + (batch_size, max_length, 3), dtype=np.float32 + ) + else: + max_seq_len = int(seq_lens.max()) + batch_output['deviation_features'] = np.zeros( + (batch_size, max_seq_len, 3), dtype=np.float32 + ) + + # 处理is_spontaneous + if 'is_spontaneous' in encoded_inputs: + is_spontaneous = encoded_inputs['is_spontaneous'] + if needs_to_be_padded: + batch_output['is_spontaneous'] = self.make_pad_sequence_for_features( + is_spontaneous, + pad_value=-1.0, # -1表示不适用 + max_len=max_length, + feature_dim=1 # 标量 + ) + else: + batch_output['is_spontaneous'] = np.array(is_spontaneous, dtype=np.float32) + else: + # 如果没有提供,创建-1向量(不适用) + batch_size = len(required_input) + if needs_to_be_padded: + batch_output['is_spontaneous'] = np.full( + (batch_size, max_length), -1.0, dtype=np.float32 + ) + else: + max_seq_len = int(seq_lens.max()) + batch_output['is_spontaneous'] = np.full( + (batch_size, max_seq_len), -1.0, dtype=np.float32 + ) + + return batch_output + + def make_pad_sequence_for_features(self, seqs, pad_value, max_len, feature_dim=None, dtype=np.float32): + """ + 为特征序列创建padding(辅助方法) + + Args: + seqs: 序列列表 + pad_value: padding值 + max_len: 最大长度 + feature_dim: 特征维度(如果是多维特征) + dtype: 数据类型 + + Returns: + np.ndarray: 填充后的数组 + """ + padded_seqs = [] + for seq in seqs: + seq_len = len(seq) + if seq_len < max_len: + pad_len = max_len - seq_len + if isinstance(seq, np.ndarray): + if seq.ndim == 1: + # 一维数组 + pad = np.full(pad_len, pad_value, dtype=dtype) + padded_seq = np.concatenate([seq, pad], axis=0) + else: + # 多维数组 + pad_shape = (pad_len,) + seq.shape[1:] + pad = np.full(pad_shape, pad_value, dtype=dtype) + padded_seq = np.concatenate([seq, pad], axis=0) + else: + # 列表 + if isinstance(seq[0], (list, np.ndarray, tuple)): + # 嵌套列表(多维) + if feature_dim: + pad = [[pad_value] * feature_dim] * pad_len + else: + pad = [[pad_value] * len(seq[0])] * pad_len + padded_seq = seq + pad + else: + # 一维列表 + pad = [pad_value] * pad_len + padded_seq = seq + pad + padded_seqs.append(padded_seq) + else: + padded_seqs.append(seq[:max_len]) + + return np.array(padded_seqs, dtype=dtype) + diff --git a/easy_tpp/runner/__init__.py b/easy_tpp/runner/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a71d121f5860ab3c7a4d6967b76b31cfc43f2860 --- /dev/null +++ b/easy_tpp/runner/__init__.py @@ -0,0 +1,7 @@ +from easy_tpp.runner.base_runner import Runner +from easy_tpp.runner.tpp_runner import TPPRunner +# for register all necessary contents +from easy_tpp.default_registers.register_metrics import * + +__all__ = ['Runner', + 'TPPRunner'] \ No newline at end of file diff --git a/easy_tpp/runner/base_runner.py b/easy_tpp/runner/base_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..eb094fc46b60677e6a6235458984f1157ebdc345 --- /dev/null +++ b/easy_tpp/runner/base_runner.py @@ -0,0 +1,201 @@ +import logging +from abc import abstractmethod + +from easy_tpp.preprocess import TPPDataLoader +from easy_tpp.utils import Registrable, Timer, logger, get_unique_id, LogConst, get_stage, RunnerPhase + + +class Runner(Registrable): + """Registrable Base Runner class. + """ + + def __init__( + self, + runner_config, + unique_model_dir=False, + **kwargs): + """Initialize the base runner. + + Args: + runner_config (RunnerConfig): config for the runner. + unique_model_dir (bool, optional): whether to give unique dir to save the model. Defaults to False. + """ + self.runner_config = runner_config + # re-assign the model_dir + if unique_model_dir: + runner_config.model_dir = runner_config.base_config.specs['saved_model_dir'] + '_' + get_unique_id() + + self.save_log() + + skip_data_loader = kwargs.get('skip_data_loader', False) + if not skip_data_loader: + # build data reader + data_config = self.runner_config.data_config + backend = self.runner_config.base_config.backend + kwargs = self.runner_config.trainer_config.get_yaml_config() + self._data_loader = TPPDataLoader( + data_config=data_config, + backend=backend, + **kwargs + ) + + # Needed for Intensity Free model + mean_log_inter_time, std_log_inter_time, min_dt, max_dt = ( + self._data_loader.train_loader().dataset.get_dt_stats()) + runner_config.model_config.set("mean_log_inter_time", mean_log_inter_time) + runner_config.model_config.set("std_log_inter_time", std_log_inter_time) + self.timer = Timer() + + @staticmethod + def build_from_config(runner_config, unique_model_dir=False, **kwargs): + """Build up the runner from runner config. + + Args: + runner_config (RunnerConfig): config for the runner. + unique_model_dir (bool, optional): whether to give unique dir to save the model. Defaults to False. + + Returns: + Runner: the corresponding runner class. + """ + runner_cls = Runner.by_name(runner_config.base_config.runner_id) + return runner_cls(runner_config, unique_model_dir=unique_model_dir, **kwargs) + + def get_config(self): + return self.runner_config + + def set_model_dir(self, model_dir): + self.runner_config.base_config.specs['saved_model_dir'] = model_dir + + def get_model_dir(self): + return self.runner_config.base_config.specs['saved_model_dir'] + + def train( + self, + train_loader=None, + valid_loader=None, + test_loader=None, + **kwargs + ): + """Train the model. + + Args: + train_loader (EasyTPP.DataLoader, optional): data loader for train set. Defaults to None. + valid_loader (EasyTPP.DataLoader, optional): data loader for valid set. Defaults to None. + test_loader (EasyTPP.DataLoader, optional): data loader for test set. Defaults to None. + + Returns: + model: _description_ + """ + # no train and valid loader from outside + if train_loader is None and valid_loader is None: + train_loader = self._data_loader.train_loader() + valid_loader = self._data_loader.valid_loader() + + # no test loader from outside and there indeed exits test data in config + if test_loader is None and self.runner_config.data_config.test_dir is not None: + test_loader = self._data_loader.test_loader() + + logger.info(f'Data \'{self.runner_config.base_config.dataset_id}\' loaded...') + + timer = self.timer + timer.start() + model_id = self.runner_config.base_config.model_id + logger.info(f'Start {model_id} training...') + model = self._train_model( + train_loader, + valid_loader, + test_loader=test_loader, + **kwargs + ) + logger.info(f'End {model_id} train! Cost time: {timer.end()}') + return model + + def evaluate(self, valid_loader=None, **kwargs): + if valid_loader is None: + valid_loader = self._data_loader.valid_loader() + + logger.info(f'Data \'{self.runner_config.base_config.dataset_id}\' loaded...') + + timer = self.timer + timer.start() + model_id = self.runner_config.base_config.model_id + logger.info(f'Start {model_id} evaluation...') + + metric = self._evaluate_model( + valid_loader, + **kwargs + ) + logger.info(f'End {model_id} evaluation! Cost time: {timer.end()}') + return metric['rmse'] # return a list of scalr for HPO to use + + def gen(self, gen_loader=None, **kwargs): + if gen_loader is None: + gen_loader = self._data_loader.test_loader() + + logger.info(f'Data \'{self.runner_config.base_config.dataset_id}\' loaded...') + + timer = self.timer + timer.start() + model_name = self.runner_config.base_config.model_id + logger.info(f'Start {model_name} evaluation...') + + model = self._gen_model( + gen_loader, + **kwargs + ) + logger.info(f'End {model_name} generation! Cost time: {timer.end()}') + return model + + @abstractmethod + def _train_model(self, train_loader, valid_loader, **kwargs): + pass + + @abstractmethod + def _evaluate_model(self, data_loader, **kwargs): + pass + + @abstractmethod + def _gen_model(self, data_loader, **kwargs): + pass + + @abstractmethod + def _save_model(self, model_dir, **kwargs): + pass + + @abstractmethod + def _load_model(self, model_dir, **kwargs): + pass + + def save_log(self): + """Save log to local files + """ + log_dir = self.runner_config.base_config.specs['saved_log_dir'] + fh = logging.FileHandler(log_dir) + fh.setFormatter(logging.Formatter(LogConst.DEFAULT_FORMAT_LONG)) + logger.addHandler(fh) + logger.info(f'Save the log to {log_dir}') + return + + def save( + self, + model_dir=None, + **kwargs + ): + return self._save_model(model_dir, **kwargs) + + def run(self, **kwargs): + """Start the runner. + + Args: + **kwargs (dict): optional params. + + Returns: + EasyTPP.BaseModel, dict: the results of the process. + """ + current_stage = get_stage(self.runner_config.base_config.stage) + if current_stage == RunnerPhase.TRAIN: + return self.train(**kwargs) + elif current_stage == RunnerPhase.VALIDATE: + return self.evaluate(**kwargs) + else: + return self.gen(**kwargs) diff --git a/easy_tpp/runner/tpp_runner.py b/easy_tpp/runner/tpp_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..c8535d4658ecb7d723db9750cc906b8e230141dd --- /dev/null +++ b/easy_tpp/runner/tpp_runner.py @@ -0,0 +1,225 @@ +from collections import OrderedDict + +from easy_tpp.runner.base_runner import Runner +from easy_tpp.utils import RunnerPhase, logger, MetricsHelper, MetricsTracker, concat_element, save_pickle +from easy_tpp.utils.const import Backend + + +@Runner.register(name='std_tpp') +class TPPRunner(Runner): + """Standard TPP runner + """ + + def __init__(self, runner_config, unique_model_dir=False, **kwargs): + super(TPPRunner, self).__init__(runner_config, unique_model_dir, **kwargs) + + self.metrics_tracker = MetricsTracker() + if self.runner_config.trainer_config.metrics is not None: + self.metric_functions = self.runner_config.get_metric_functions() + + self._init_model() + + pretrain_dir = self.runner_config.model_config.pretrained_model_dir + if pretrain_dir is not None: + self._load_model(pretrain_dir) + + def _init_model(self): + """Initialize the model. + """ + self.use_torch = self.runner_config.base_config.backend == Backend.Torch + + if self.use_torch: + from easy_tpp.utils import set_seed + from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel + from easy_tpp.torch_wrapper import TorchModelWrapper + from easy_tpp.utils import count_model_params + set_seed(self.runner_config.trainer_config.seed) + + self.model = TorchBaseModel.generate_model_from_config(model_config=self.runner_config.model_config) + self.model_wrapper = TorchModelWrapper(self.model, + self.runner_config.base_config, + self.runner_config.model_config, + self.runner_config.trainer_config) + num_params = count_model_params(self.model) + + else: + from easy_tpp.utils.tf_utils import set_seed + from easy_tpp.model.tf_model.tf_basemodel import TfBaseModel + from easy_tpp.tf_wrapper import TfModelWrapper + from easy_tpp.utils.tf_utils import count_model_params + set_seed(self.runner_config.trainer_config.seed) + + self.model = TfBaseModel.generate_model_from_config(model_config=self.runner_config.model_config) + self.model_wrapper = TfModelWrapper(self.model, + self.runner_config.base_config, + self.runner_config.model_config, + self.runner_config.trainer_config) + num_params = count_model_params() + + info_msg = f'Num of model parameters {num_params}' + logger.info(info_msg) + + def _save_model(self, model_dir, **kwargs): + """Save the model. + + Args: + model_dir (str): the dir for model to save. + """ + if model_dir is None: + model_dir = self.runner_config.base_config.specs['saved_model_dir'] + self.model_wrapper.save(model_dir) + logger.critical(f'Save model to {model_dir}') + return + + def _load_model(self, model_dir, **kwargs): + """Load the model from the dir. + + Args: + model_dir (str): the dir for model to load. + """ + self.model_wrapper.restore(model_dir) + logger.critical(f'Load model from {model_dir}') + return + + def _train_model(self, train_loader, valid_loader, **kwargs): + """Train the model. + + Args: + train_loader (EasyTPP.DataLoader): data loader for the train set. + valid_loader (EasyTPP.DataLoader): data loader for the valid set. + """ + test_loader = kwargs.get('test_loader') + for i in range(self.runner_config.trainer_config.max_epoch): + train_metrics = self.run_one_epoch(train_loader, RunnerPhase.TRAIN) + + message = f"[ Epoch {i} (train) ]: train " + MetricsHelper.metrics_dict_to_str(train_metrics) + logger.info(message) + + self.model_wrapper.write_summary(i, train_metrics, RunnerPhase.TRAIN) + + # evaluate model + if i % self.runner_config.trainer_config.valid_freq == 0: + valid_metrics = self.run_one_epoch(valid_loader, RunnerPhase.VALIDATE) + + self.model_wrapper.write_summary(i, valid_metrics, RunnerPhase.VALIDATE) + + message = f"[ Epoch {i} (valid) ]: valid " + MetricsHelper.metrics_dict_to_str(valid_metrics) + logger.info(message) + + updated = self.metrics_tracker.update_best("loglike", valid_metrics['loglike'], i) + + message_valid = "current best loglike on valid set is {:.4f} (updated at epoch-{})".format( + self.metrics_tracker.current_best['loglike'], self.metrics_tracker.episode_best) + + if updated: + message_valid += f", best updated at this epoch" + self.model_wrapper.save(self.runner_config.base_config.specs['saved_model_dir']) + + if test_loader is not None: + test_metrics = self.run_one_epoch(test_loader, RunnerPhase.VALIDATE) + + message = f"[ Epoch {i} (test) ]: test " + MetricsHelper.metrics_dict_to_str(test_metrics) + logger.info(message) + + logger.critical(message_valid) + + self.model_wrapper.close_summary() + + return + + def _evaluate_model(self, data_loader, **kwargs): + """Evaluate the model on the valid dataset. + + Args: + data_loader (EasyTPP.DataLoader): data loader for the valid set + + Returns: + dict: metrics dict. + """ + + eval_metrics = self.run_one_epoch(data_loader, RunnerPhase.VALIDATE) + + self.model_wrapper.write_summary(0, eval_metrics, RunnerPhase.VALIDATE) + + self.model_wrapper.close_summary() + + message = f"Evaluation result: " + MetricsHelper.metrics_dict_to_str(eval_metrics) + + logger.critical(message) + + return eval_metrics + + def _gen_model(self, data_loader, **kwargs): + """Generation of the TPP, one-step and multi-step are both supported. + """ + + test_result = self.run_one_epoch(data_loader, RunnerPhase.PREDICT) + + # For the moment we save it to a pkl + + message = f'Save the prediction to pickle file pred.pkl' + + logger.critical(message) + + save_pickle('pred.pkl', test_result) + + return + + def run_one_epoch(self, data_loader, phase): + """Run one complete epoch. + + Args: + data_loader: data loader object defined in model runner + phase: enum, [train, dev, test] + + Returns: + a dict of metrics + """ + total_loss = 0 + total_num_event = 0 + epoch_label = [] + epoch_pred = [] + epoch_mask = [] + pad_index = self.runner_config.data_config.data_specs.pad_token_id + metrics_dict = OrderedDict() + if phase in [RunnerPhase.TRAIN, RunnerPhase.VALIDATE]: + for batch in data_loader: + batch_loss, batch_num_event, batch_pred, batch_label, batch_mask = \ + self.model_wrapper.run_batch(batch, phase=phase) + + total_loss += batch_loss + total_num_event += batch_num_event + epoch_pred.append(batch_pred) + epoch_label.append(batch_label) + epoch_mask.append(batch_mask) + + avg_loss = total_loss / total_num_event + + metrics_dict.update({'loglike': -avg_loss, 'num_events': total_num_event}) + + else: + for batch in data_loader: + batch_pred, batch_label = self.model_wrapper.run_batch(batch, phase=phase) + epoch_pred.append(batch_pred) + epoch_label.append(batch_label) + + # we need to improve the code here + # classify batch_output to list + pred_exists, label_exists = False, False + if epoch_pred[0][0] is not None: + epoch_pred = concat_element(epoch_pred, pad_index) + pred_exists = True + if len(epoch_label) > 0 and epoch_label[0][0] is not None: + epoch_label = concat_element(epoch_label, pad_index) + label_exists = True + if len(epoch_mask): + epoch_mask = concat_element(epoch_mask, False)[0] # retrieve the first element of concat array + epoch_mask = epoch_mask.astype(bool) + + if pred_exists and label_exists: + metrics_dict.update(self.metric_functions(epoch_pred, epoch_label, seq_mask=epoch_mask)) + + if phase == RunnerPhase.PREDICT: + metrics_dict.update({'pred': epoch_pred, 'label': epoch_label}) + + return metrics_dict diff --git a/easy_tpp/ssm/__init__.py b/easy_tpp/ssm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/easy_tpp/ssm/initializers.py b/easy_tpp/ssm/initializers.py new file mode 100644 index 0000000000000000000000000000000000000000..3027c85fba18fc3ef080d8d2b90187fe11fbf363 --- /dev/null +++ b/easy_tpp/ssm/initializers.py @@ -0,0 +1,117 @@ +import math + +import numpy as np +import numpy as onp +import torch as th +from numpy.linalg import eigh + + +def make_HiPPO(P): + """Create a HiPPO-LegS matrix. + From https://github.com/srush/annotated-s4/blob/main/s4/s4.py + Args: + P (int32): state size + Returns: + P x P HiPPO LegS matrix + """ + M = np.sqrt(1 + 2 * np.arange(P)) + A = M[:, np.newaxis] * M[np.newaxis, :] + A = np.tril(A) - np.diag(np.arange(P)) + return -A + + +def make_NPLR_HiPPO(P): + """ + Makes components needed for NPLR representation of HiPPO-LegS + From https://github.com/srush/annotated-s4/blob/main/s4/s4.py + Args: + P (int32): state size + + Returns: + P x P HiPPO LegS matrix, low-rank factor P, HiPPO input matrix B + + """ + # Make -HiPPO + hippo = make_HiPPO(P) + + # Add in a rank 1 term. Makes it Normal. + R1 = np.sqrt(np.arange(P) + 0.5) + + # HiPPO also specifies the B matrix + B = np.sqrt(2 * np.arange(P) + 1.0) + return hippo, R1, B + + +def make_DPLR_HiPPO(P): + """ + Makes components needed for DPLR representation of HiPPO-LegS + From https://github.com/srush/annotated-s4/blob/main/s4/s4.py + Note, we will only use the diagonal part + Args: + P: + + Returns: + eigenvalues Lambda, low-rank term R1, conjugated HiPPO input matrix B, + eigenvectors V, HiPPO B pre-conjugation + + """ + A, R1, B = make_NPLR_HiPPO(P) + + S = A + R1[:, np.newaxis] * R1[np.newaxis, :] + + S_diag = np.diagonal(S) + Lambda_real = np.mean(S_diag) * np.ones_like(S_diag) + + # Diagonalize S to V \Lambda V^* + Lambda_imag, V = eigh(S * -1j) + + R1 = V.conj().T @ R1 + B_orig = B + B = V.conj().T @ B + return ( + th.tensor(onp.asarray(Lambda_real + 1j * Lambda_imag), dtype=th.complex64), + th.tensor(onp.asarray(R1)), + th.tensor(onp.asarray(B)), + th.tensor(onp.asarray(V), dtype=th.complex64), + th.tensor(onp.asarray(B_orig)), + ) + + +def init_log_steps(P, dt_min, dt_max): + """Initialize an array of learnable timescale parameters. + initialized uniformly in log space. + Args: + input: + Returns: + initialized array of timescales (float32): (P,) + """ + unlog = th.rand(size=(P,)) + log = unlog * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) + return log + + +def lecun_normal_(tensor: th.Tensor) -> th.Tensor: + input_size = tensor.shape[ + -1 + ] # Assuming that the weights' input dimension is the last. + std = math.sqrt(1 / input_size) + with th.no_grad(): + return tensor.normal_(0, std) # or torch.nn.init.xavier_normal_ + + +def init_VinvB(shape, Vinv): + """Initialize B_tilde=V^{-1}B. First samples B. Then compute V^{-1}B. + Note we will parameterize this with two different matrices for complex + + Modified from https://github.com/lindermanlab/S5/blob/52cc7e22d6963459ad99a8674e4d3cfb0a480008/s5/ssm.py#L165 + + numbers. + Args: + shape (tuple): desired shape (P,H) + Vinv: (complex64) the inverse eigenvectors used for initialization + Returns: + B_tilde (complex64) of shape (P,H) + """ + B = lecun_normal_(th.zeros(shape)) + VinvB = Vinv @ B.type(th.complex64) + return VinvB diff --git a/easy_tpp/ssm/models.py b/easy_tpp/ssm/models.py new file mode 100644 index 0000000000000000000000000000000000000000..db695751e86e353115f1ff83b18e20241eaa7ac0 --- /dev/null +++ b/easy_tpp/ssm/models.py @@ -0,0 +1,820 @@ +from typing import Optional, Tuple + +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from .initializers import ( + make_DPLR_HiPPO, # , lecun_normal_ # init_VinvB, init_log_steps, +) + +MATRIX_SCALING_FACTOR = 1 + + +class LLH(nn.Module): + """ + This is canon: + L -- number of layers + N -- number of events. + P -- Hidden dimension. Dimensionality of x. + H -- output dimension. Dimensionality of y/u. + """ + + def __init__( + self, + P: int, + H: int, + dt_init_min: float = 1e-4, + dt_init_max: float = 0.1, + dropout_rate: float = 0.0, + act_func: str = "gelu", # F.gelu, + for_loop: bool = False, + pre_norm: bool = True, + post_norm: bool = False, + simple_mark: bool = True, + is_first_layer: bool = False, + relative_time: bool = False, + complex_values: bool = True, + ): + """ + + :param P: + :param H: + :param dt_init_min: + :param dt_init_max: + :param act_func: + """ + + super(LLH, self).__init__() + + # Inscribe the args. + self.P = P + self.H = H + self.dt_init_min = dt_init_min + self.dt_init_max = dt_init_max + self.dropout_rate = dropout_rate + self.complex_values = complex_values + + # select the activation function. + if act_func == "gelu": + self.act_func = nn.Sequential(nn.GELU(), nn.Dropout(p=self.dropout_rate)) + elif act_func == "full_glu": + self.act_func = nn.Sequential( + nn.Linear(self.H, 2 * self.H), + nn.Dropout(p=self.dropout_rate), + nn.GLU(), + nn.Dropout(p=self.dropout_rate), + ) + + elif ( + act_func == "half_glu" + ): # ref: https://github.com/lindermanlab/S5/blob/main/s5/layers.py#L76 + self.act_func1 = nn.Sequential( + nn.Dropout(p=self.dropout_rate), + nn.GELU(), + nn.Linear(self.H, self.H), + ) + self.act_func = lambda x: nn.Dropout(p=self.dropout_rate)( + x * nn.Sigmoid()(self.act_func1(x)) + ) + else: + raise NotImplementedError( + "Unrecognized activation function {}".format(act_func) + ) + + # Assume we always use conjugate symmetry. + self.conj_sym = True + + # Allow a learnable initial state. + # Needs to be =/= 0 since we take the log to compute + if self.complex_values: + self.initial_state_P = nn.Parameter( + th.complex( + th.randn( + self.P, + ), + th.randn( + self.P, + ), + ) + * 1e-3, + requires_grad=True, + ) + else: + self.initial_state_P = nn.Parameter( + th.randn( + self.P, + ), + requires_grad=True, + ) + + self.norm = nn.LayerNorm(self.H) + self.for_loop = for_loop + self.pre_norm = pre_norm + self.post_norm = post_norm + + self.is_first_layer = is_first_layer + self.relative_time = relative_time + + self._init_ssm_params() + + self.simple_mark = simple_mark + if not simple_mark: + self.mark_a_net = nn.Linear(self.H, self.P, bias=True) + self.mark_u_net = nn.Linear( + self.H, self.P, bias=False + ) # Only need one bias + self.mark_a_net.weight.data = th.complex( + nn.init.xavier_normal_(self.mark_a_net.weight.data) * 1e-3, + nn.init.xavier_normal_(self.mark_a_net.weight.data) * 1e-3, + ) + self.mark_a_net.bias.data = th.complex( + nn.init.xavier_normal_(self.mark_a_net.bias.data) * 1e-3, + nn.init.xavier_normal_(self.mark_a_net.bias.data) * 1e-3, + ) + self.mark_u_net.weight.data = th.complex( + nn.init.xavier_normal_(self.mark_u_net.weight.data) * 1e-3, + nn.init.xavier_normal_(self.mark_u_net.weight.data) * 1e-3, + ) + if not self.complex_values: + self.mark_a_net.weight.data = self.mark_a_net.weight.data.real + self.mark_a_net.bias.data = self.mark_a_net.bias.data.real + self.mark_u_net.weight.data = self.self.mark_u_net.weight.data.real + + def _init_ssm_params(self): + self._init_A() + if not self.is_first_layer: + self._init_B() + self._init_C() + if ( + not self.is_first_layer + ): # Could group, but left in same order to not mess with initialization + self._init_D() + self._init_E() + + def _init_A(self): + # Define the initial diagonal HiPPO matrix. + # Te throw the HiPPO B away. + Lambda_P, _, _, V_PP, _ = make_DPLR_HiPPO(self.P) + self.Lambda_P_log_neg_real = th.nn.Parameter((-Lambda_P.real).log()) + self.Lambda_P_imag = th.nn.Parameter(Lambda_P.imag) + + # Store these for use later. + self._V_PP = V_PP + self._Vc_PP = V_PP.conj().T + + # We also initialize the step size. + if self.relative_time: + self.delta_net = nn.Linear( + self.H, self.P, bias=True + ) # nn.Parameter(init_log_steps(self.P, self.dt_init_min, self.dt_init_max)) + with th.no_grad(): + self.delta_net.weight.copy_( + nn.init.xavier_normal_(self.delta_net.weight) + ) + bias = th.ones( + self.P, + ) + bias += th.log(-th.expm1(-bias)) + self.delta_net.bias.copy_(bias) + else: + self.log_step_size_P = nn.Parameter( + th.zeros(size=(self.P,)), requires_grad=False + ) + + @property + def Lambda_P(self): + if self.complex_values: + return th.complex( + -self.Lambda_P_log_neg_real.exp(), + self.Lambda_P_imag, + ) + else: + return -self.Lambda_P_log_neg_real.exp() + + def _init_B(self): + # Initialize the B outside the eigenbasis and then transform. + B = nn.init.xavier_normal_(th.zeros((self.P, self.H))) * MATRIX_SCALING_FACTOR + B_tilde_PH = self._Vc_PP @ B.type(th.complex64) + self.B_tilde_PH = ( + th.nn.Parameter(B_tilde_PH) + if self.complex_values + else th.nn.Parameter(B_tilde_PH.real) + ) + + def _init_C(self): + # Use the "complex_normal" initialization. + # See ~https://github.com/lindermanlab/S5/blob/52cc7e22d6963459ad99a8674e4d3cfb0a480008/s5/ssm.py#L183 + C = nn.init.xavier_normal_(th.zeros((self.H, self.P))) * MATRIX_SCALING_FACTOR + C_tilde_HP = C.type(th.complex64) @ self._V_PP + self.C_tilde_HP = ( + th.nn.Parameter(C_tilde_HP) + if self.complex_values + else th.nn.Parameter(C_tilde_HP.real) + ) + # self.C_tilde_HP.data *= 1e-3 + + def _init_D(self): + # Initialize feedthrough (D) matrix. Note the intensity depends on all layers. + D_HH = th.zeros(self.H) + nn.init.normal_(D_HH, std=1.0) + self.D_HH = nn.Parameter(D_HH, requires_grad=True) + + def _init_E(self): + E = ( + th.nn.init.xavier_normal_(th.zeros((self.P, self.H))) + * MATRIX_SCALING_FACTOR + ) + E_tilde_PH = self._Vc_PP @ E.type(th.complex64) + self.E_tilde_PH = ( + th.nn.Parameter(E_tilde_PH) + if self.complex_values + else th.nn.Parameter(E_tilde_PH.real) + ) + + def compute_impulse(self, right_u_H, mark_embedding_H): + # Compute impulse to add to left limit of x to make right limit. + alpha_P = th.einsum( + "ph,...h->...p", + self.E_tilde_PH, + mark_embedding_H.type(th.complex64) + if self.complex_values + else mark_embedding_H, + ) + return alpha_P + + def get_lambda(self, right_u_NH, shift_u=True): + if self.relative_time and (right_u_NH is not None): + if shift_u: # during "forward" when dts = [0, t1-t0, ..., t_N-t_{N-1}] + right_u_NH = F.pad( + right_u_NH[..., :-1, :], (0, 0, 1, 0) + ) # pad default 0 at beginning of second to last dim + lambda_rescaled_NP = ( + F.softplus(self.delta_net(right_u_NH)) * self.Lambda_P + ) # predict delta_i from right_u_i + return {"lambda_rescaled_NP": lambda_rescaled_NP} + else: + if self.relative_time: + lambda_rescaled_P = F.softplus(self.delta_net.bias) * self.Lambda_P + else: + lambda_rescaled_P = th.exp(self.log_step_size_P) * self.Lambda_P + return {"lambda_rescaled_P": lambda_rescaled_P} + + def forward( + self, + left_u_NH: Optional[th.Tensor], # Very first layer, should feed in `None` + right_u_NH: Optional[th.Tensor], # Very first layer, should feed in `None` + mark_embedding_NH: th.Tensor, + dt_N: th.Tensor, + initial_state_P: Optional[th.Tensor] = None, + ) -> Tuple[th.Tensor, th.Tensor]: + """ + Apply the linear SSM to the inputs. + + In the context of TPPs, this returns the right limit of the "intensity function". + This intensity will have been passed through a non-linearity, though, and so there is no + guarantee for it is positive. + + :param u_NH: [..., seq_len, input_dim] + :param alpha_NP: [..., seq_len, hidden_dim] + :param dt_N: [..., seq_len] + :param initial_state_P: [..., hidden_dim] + :return: + """ + # Pull out the dimensions. + *leading_dims, _, _ = mark_embedding_NH.shape + num_leading_dims = len(leading_dims) + + if initial_state_P is None: + # Pad and expand to match leading dimensions of input + initial_state_P = self.initial_state_P.view( + *[1 for _ in range(num_leading_dims)], -1 + ).expand(*leading_dims, -1) + + # Add layer norm + prime_left_u_NH = left_u_NH + prime_right_u_NH = right_u_NH + if prime_left_u_NH is not None: # ONLY for backward variant + assert all( + u_d == a_d + for u_d, a_d in zip(prime_left_u_NH.shape, mark_embedding_NH.shape) + ) # All but last dimensions should match + if self.pre_norm: + prime_left_u_NH = self.norm(prime_left_u_NH) + if prime_right_u_NH is not None: + assert all( + u_d == a_d + for u_d, a_d in zip(prime_right_u_NH.shape, mark_embedding_NH.shape) + ) # All but last dimensions should match + if self.pre_norm: + prime_right_u_NH = self.norm(prime_right_u_NH) + + right_x_NP, left_y_NH, right_y_NH = self._ssm( + left_u_NH=prime_left_u_NH, + right_u_NH=prime_right_u_NH, + impulse_NP=self.compute_impulse(prime_right_u_NH, mark_embedding_NH), + dt_N=dt_N, + initial_state_P=initial_state_P, + ) + + # Given the following: + # right_u: u0, u1, u2, ... <-> u_{t_0}, u_{t_1}, u_{t_2}, ... + # left_u: u0, u1, u2, ... <-> u_{t_0-}, u_{t_1-}, u_{t_2-}, ... + # a: a0, a1, a2, ... <-> mark embeddings for m_0, m_1, m_2, ... at times t_0, t_1, t_2 + # dt: dt0, dt1, dt2, ... <-> 0, t_1-t_0, t_2-t_1, ... + # initial_state_p: hidden state to evolve to to compute x_{0} + + # Returns the following: + # right_x: x0, x1, x2, ... <-> x_{t_0}, x_{t_1}, x_{t_2}, ... + # right_y: y0, y1, y2, ... <-> y_{t_0}, y_{t_1}, y_{t_2}, ... + # left_y: y0, y1, y2, ... <-> y_{t_0-}, y_{t_1-}, y_{t_2-}, ... + + next_layer_left_u_NH = next_layer_right_u_NH = None + if left_y_NH is not None: + next_layer_left_u_NH = self.act_func(left_y_NH) + ( + left_u_NH if left_u_NH is not None else 0.0 + ) + if self.post_norm: + next_layer_left_u_NH = self.norm(next_layer_left_u_NH) + if right_y_NH is not None: + next_layer_right_u_NH = self.act_func(right_y_NH) + ( + right_u_NH if right_u_NH is not None else 0.0 + ) + if self.post_norm: + next_layer_right_u_NH = self.norm(next_layer_right_u_NH) + return right_x_NP, next_layer_left_u_NH, next_layer_right_u_NH + + def _ssm( + self, + left_u_NH: Optional[th.Tensor], # Very first layer, should feed in `None` + right_u_NH: Optional[th.Tensor], # Very first layer, should feed in `None` + impulse_NP: th.Tensor, + dt_N: th.Tensor, # [0, t_1 - t_0, ..., t_N - t_{N-1}] + initial_state_P: th.Tensor, + ): + *leading_dims, N, P = impulse_NP.shape + u_NH = right_u_NH # This implementation does not use left_u, nor does it compute left_y + if u_NH is not None: + impulse_NP = impulse_NP + th.einsum( + "ph,...nh->...np", + self.B_tilde_PH, + u_NH.type(th.complex64) if self.complex_values else u_NH, + ) + y_u_res_NH = th.einsum( + "...nh,h->...nh", u_NH, self.D_HH + ) # D_HH should really be D_H + else: + assert self.is_first_layer + y_u_res_NH = 0.0 + + lambda_res = self.get_lambda(right_u_NH=right_u_NH, shift_u=True) + if "lambda_rescaled_P" in lambda_res: # original formulation + lambda_dt_NP = th.einsum( + "...n,p->...np", dt_N, lambda_res["lambda_rescaled_P"] + ) + else: # relative time + lambda_dt_NP = th.einsum( + "...n,...np->...np", dt_N, lambda_res["lambda_rescaled_NP"] + ) + + if self.for_loop: + right_x_P = initial_state_P + right_x_NP = [] + for i in range(N): + right_x_P = ( + lambda_dt_NP[..., i, :].exp() * right_x_P + impulse_NP[..., i, :] + ) + right_x_NP.append(right_x_P) + right_x_NP = th.stack(right_x_NP, dim=-2) + else: + # Trick inspired by: https://github.com/PeaBrane/mamba-tiny/blob/master/scans.py + # .unsqueeze(-2) to add sequence dimension to initial state + log_impulse_Np1_P = th.concat( + (initial_state_P.unsqueeze(-2), impulse_NP), dim=-2 + ).log() + lamdba_dt_star = F.pad(lambda_dt_NP.cumsum(-2), (0, 0, 1, 0)) + right_x_log_NP = ( + th.logcumsumexp(log_impulse_Np1_P - lamdba_dt_star, -2) + lamdba_dt_star + ) + right_x_NP = right_x_log_NP.exp()[..., 1:, :] + + conj_sym_mult = 2 if self.conj_sym else 1 + y_NH = ( + conj_sym_mult + * th.einsum("...np,hp->...nh", right_x_NP, self.C_tilde_HP).real + + y_u_res_NH + ) + + return right_x_NP, None, y_NH + + def get_left_limit( + self, + right_limit_P: th.Tensor, # Along with dt, can have any number of leading dimensions, produces a tensor of dim ...MP + dt_G: th.Tensor, + current_right_u_H: th.Tensor, + next_left_u_GH: th.Tensor, + ) -> th.Tensor: + """ + To get the left limit, we roll on the layer for the right dt. + Computed for a single point (vmap for multiple). + + :param right_limit_P: at [t_0, ..., t_{N-1}] + :param dt: Length of time to roll the layer on for. at [t_1 - t_0, ..., t_N - t_{N-1}] + :param current_right_u_H: at [t_0, ..., t_{N-1}] -- for relative-time variant + :param next_left_u_GH: at [t_1, ..., t_N] -- for backward variant + + :return: + """ + + if current_right_u_H is not None and self.pre_norm: + current_right_u_H = self.norm(current_right_u_H) + + lambda_res = self.get_lambda( + current_right_u_H, shift_u=False + ) # U should already be shifted + if "lambda_rescaled_P" in lambda_res: + lambda_bar_GP = th.exp( + th.einsum("...g,p->...gp", dt_G, lambda_res["lambda_rescaled_P"]) + ) + else: + lambda_bar_GP = th.exp( + th.einsum("...g,...p->...gp", dt_G, lambda_res["lambda_rescaled_NP"]) + ) + + return th.einsum("...p,...gp->...gp", right_limit_P, lambda_bar_GP) + + def depth_pass( + self, + current_left_x_P: th.Tensor, # No leading dimensions (seq, batch, etc.) here because we accommodate any of them + current_left_u_H: Optional[ + th.Tensor + ], # Just assume that x and u match in the leading dimensions. Produces y_H with equivalent leading dimensions + prev_right_u_H: Optional[ + th.Tensor + ], # Just assume that x and u match in the leading dimensions. Produces y_H with equivalent leading dimensions + ) -> th.Tensor: + if current_left_u_H is not None: + if self.pre_norm: + prime_u_H = self.norm(current_left_u_H) + else: + prime_u_H = current_left_u_H + y_u_res_H = th.einsum( + "...h,h->...h", prime_u_H, self.D_HH + ) # D_HH should really be D_H + else: + assert self.is_first_layer + y_u_res_H = 0.0 + + conj_sym_mult = 2 if self.conj_sym else 1 + y_H = ( + conj_sym_mult + * th.einsum("...p,hp->...h", current_left_x_P, self.C_tilde_HP).real + + y_u_res_H + ) + + # Apply an activation function. + if self.post_norm: + new_u_H = self.norm( + self.act_func(y_H) + + (current_left_u_H if current_left_u_H is not None else 0.0) + ) + else: + new_u_H = self.act_func(y_H) + ( + current_left_u_H if current_left_u_H is not None else 0.0 + ) + + return new_u_H + + +class Int_Forward_LLH(LLH): + # LLH but Bu_t is integrated w.r.t dt instead of dN_t + # After discretization, when evolving x_t to x_t', applies ZOH on u_t over [t,t'] forward in time + # (as opposed to u_{t'} backwards over [t,t']) + + def _ssm( + self, + left_u_NH: Optional[th.Tensor], # Very first layer, should feed in `None` + right_u_NH: Optional[th.Tensor], # Very first layer, should feed in `None` + impulse_NP: th.Tensor, + dt_N: th.Tensor, + initial_state_P: th.Tensor, + ) -> Tuple[th.Tensor, th.Tensor]: + """ + Apply the linear SSM to the inputs. + + In the context of TPPs, this returns the right limit of the "intensity function". + This intensity will have been passed through a non-linearity, though, and so there is no + guarantee for it is positive. + + :param u_NH: [..., seq_len, input_dim] + :param alpha_NP: [..., seq_len, hidden_dim] + :param dt_N: [..., seq_len] + :param initial_state_P: [..., hidden_dim] + + :return: + """ + # Pull out the dimensions. + *leading_dims, N, P = impulse_NP.shape + + lambda_res = self.get_lambda(right_u_NH=right_u_NH, shift_u=True) + if "lambda_rescaled_P" in lambda_res: + lambda_rescaled = lambda_res["lambda_rescaled_P"] + lambda_dt_NP = th.einsum( + "...n,p->...np", dt_N, lambda_res["lambda_rescaled_P"] + ) + else: + lambda_rescaled = lambda_res["lambda_rescaled_NP"] + lambda_dt_NP = th.einsum( + "...n,...np->...np", dt_N, lambda_res["lambda_rescaled_NP"] + ) + + if left_u_NH is not None: + left_Du_NH = th.einsum( + "...nh,h->...nh", + left_u_NH, + self.D_HH, + ) + else: + assert self.is_first_layer + left_Du_NH = 0.0 + + if right_u_NH is not None: + right_u_NH = F.pad(right_u_NH[..., :-1, :], (0, 0, 1, 0)) + right_Bu_NP = th.einsum( + "...np,ph,...nh->...np", + lambda_dt_NP.exp() - 1.0, # dts: [0, t1-t0, t2-t1, ...] + self.B_tilde_PH, + right_u_NH.type(th.complex64) if self.complex_values else right_u_NH, + ) + right_Du_NH = th.einsum( + "...nh,h->...nh", + right_u_NH, + self.D_HH, + ) + else: + assert self.is_first_layer + right_Bu_NP = right_Du_NH = 0.0 + + if self.for_loop: + right_x_P = initial_state_P + left_x_NP, right_x_NP = [], [] + for i in range(N): + left_x_P = lambda_dt_NP[..., i, :].exp() * right_x_P + ( + right_Bu_NP[..., i, :] if left_u_NH is not None else 0.0 + ) + right_x_P = left_x_P + impulse_NP[..., i, :] + left_x_NP.append(left_x_P) + right_x_NP.append(right_x_P) + right_x_NP = th.stack( + right_x_NP, dim=-2 + ) # discard initial_hidden_states, right_limit of xs for [t0, t1, ...] + left_x_NP = th.stack( + left_x_NP, dim=-2 + ) # discard initial_hidden_states, left_limit of xs for [t0, t1, ...] + else: + # Trick inspired by: https://github.com/PeaBrane/mamba-tiny/blob/master/scans.py + # .unsqueeze(-2) to add sequence dimension to initial state + log_impulse_Np1_P = th.concat( + (initial_state_P.unsqueeze(-2), right_Bu_NP + impulse_NP), dim=-2 + ).log() + lamdba_dt_star = F.pad(lambda_dt_NP.cumsum(-2), (0, 0, 1, 0)) + right_x_log_NP = ( + th.logcumsumexp(log_impulse_Np1_P - lamdba_dt_star, -2) + lamdba_dt_star + ) + right_x_NP = right_x_log_NP.exp() # Contains initial_state_P in index 0 + left_x_NP = ( + lambda_dt_NP.exp() * right_x_NP[..., :-1, :] + right_Bu_NP + ) # Evolves previous hidden state forward to compute left limit + right_x_NP = right_x_NP[..., 1:, :] + + conj_sym_mult = 2 if self.conj_sym else 1 + left_y_NH = ( + conj_sym_mult + * th.einsum("hp,...np->...nh", self.C_tilde_HP, left_x_NP).real + + left_Du_NH + ) # ys for [t0, t1, ...] + right_y_NH = ( + conj_sym_mult + * th.einsum("hp,...np->...nh", self.C_tilde_HP, right_x_NP).real + + right_Du_NH + ) # ys for [t0, t1, ...] + + return right_x_NP, left_y_NH, right_y_NH + + def get_left_limit( + self, + right_limit_P: th.Tensor, # Along with dt, can have any number of leading dimensions, produces a tensor of dim ...MP + dt_G: th.Tensor, + current_right_u_H: Optional[th.Tensor], + next_left_u_GH: Optional[th.Tensor], + ) -> th.Tensor: + """ + To get the left limit, we roll on the layer for the right dt. + Computed for a single point (vmap for multiple). + + :param right_limit_P: + :param dt: Length of time to roll the layer on for. + :return: + """ + if current_right_u_H is not None and self.pre_norm: + current_right_u_H = self.norm(current_right_u_H) + + lambda_res = self.get_lambda( + current_right_u_H, shift_u=False + ) # U should already be shifted + if "lambda_rescaled_P" in lambda_res: + lambda_bar_GP = th.exp( + th.einsum("...g,p->...gp", dt_G, lambda_res["lambda_rescaled_P"]) + ) + else: + lambda_bar_GP = th.exp( + th.einsum("...g,...p->...gp", dt_G, lambda_res["lambda_rescaled_NP"]) + ) + + # lambda_rescaled_P = th.exp(self.log_step_size_P) * self.Lambda_P + # lambda_bar_GP = th.exp(th.einsum('...g,p->...gp', dt_G, lambda_rescaled_P)) + int_hidden_GP = th.einsum("...p,...gp->...gp", right_limit_P, lambda_bar_GP) + + if current_right_u_H is None: # no Bu term + assert self.is_first_layer + return int_hidden_GP + else: # add Bu to impulse + if self.pre_norm: + current_right_u_H = self.norm(current_right_u_H) + + impulse_GP = th.einsum( + "...gp,ph,...h->...gp", + lambda_bar_GP - 1.0, + self.B_tilde_PH, + current_right_u_H.type(th.complex64) + if self.complex_values + else current_right_u_H, + ) + + return int_hidden_GP + impulse_GP + + +class Int_Backward_LLH(Int_Forward_LLH): + # LLH but Bu_t is integrated w.r.t dt instead of dN_t + # After discretization, when evolving x_t to x_t', applies ZOH on u_t' over [t,t'] backwards in time + # (as opposed to u_{t} forwards over [t,t']) + + def _ssm( + self, + left_u_NH: Optional[th.Tensor], # Very first layer, should feed in `None` + right_u_NH: Optional[th.Tensor], # Very first layer, should feed in `None` + impulse_NP: th.Tensor, + dt_N: th.Tensor, + initial_state_P: th.Tensor, + ) -> Tuple[th.Tensor, th.Tensor]: + """ + Apply the linear SSM to the inputs. + + In the context of TPPs, this returns the right limit of the "intensity function". + This intensity will have been passed through a non-linearity, though, and so there is no + guarantee for it is positive. + + :param u_NH: [..., seq_len, input_dim] + :param alpha_NP: [..., seq_len, hidden_dim] + :param dt_N: [..., seq_len] + :param initial_state_P: [..., hidden_dim] + + :return: + """ + # Pull out the dimensions. + *leading_dims, N, P = impulse_NP.shape + + # lambda_rescaled_P = th.exp(self.log_step_size_P) * self.Lambda_P + # lambda_dt_NP = th.einsum('...n,p->...np', dt_N, lambda_rescaled_P) + lambda_res = self.get_lambda(right_u_NH=right_u_NH, shift_u=True) + if "lambda_rescaled_P" in lambda_res: + lambda_dt_NP = th.einsum( + "...n,p->...np", dt_N, lambda_res["lambda_rescaled_P"] + ) + else: + lambda_dt_NP = th.einsum( + "...n,...np->...np", dt_N, lambda_res["lambda_rescaled_NP"] + ) + + if left_u_NH is not None: + left_Bu_NP = th.einsum( + "...np,ph,...nh->...np", + lambda_dt_NP.exp() - 1.0, # dts: [0, t1-t0, t2-t1, ...] + self.B_tilde_PH, + left_u_NH.type(th.complex64) if self.complex_values else left_u_NH, + ) + left_Du_NH = th.einsum( + "...nh,h->...nh", + left_u_NH, + self.D_HH, + ) + else: + assert self.is_first_layer + left_Bu_NP = left_Du_NH = 0.0 + + if right_u_NH is not None: + right_Du_NH = th.einsum( + "...nh,h->...nh", + right_u_NH, + self.D_HH, + ) + else: + assert self.is_first_layer + right_Du_NH = 0.0 + + if self.for_loop: + right_x_P = initial_state_P + left_x_NP, right_x_NP = [], [] + for i in range(N): + left_x_P = lambda_dt_NP[..., i, :].exp() * right_x_P + ( + left_Bu_NP[..., i, :] if left_u_NH is not None else 0.0 + ) + right_x_P = left_x_P + impulse_NP[..., i, :] + left_x_NP.append(left_x_P) + right_x_NP.append(right_x_P) + right_x_NP = th.stack( + right_x_NP, dim=-2 + ) # discard initial_hidden_states, right_limit of xs for [t0, t1, ...] + left_x_NP = th.stack( + left_x_NP, dim=-2 + ) # discard initial_hidden_states, left_limit of xs for [t0, t1, ...] + else: + # Trick inspired by: https://github.com/PeaBrane/mamba-tiny/blob/master/scans.py + # .unsqueeze(-2) to add sequence dimension to initial state + log_impulse_Np1_P = th.concat( + (initial_state_P.unsqueeze(-2), left_Bu_NP + impulse_NP), dim=-2 + ).log() + lamdba_dt_star = F.pad(lambda_dt_NP.cumsum(-2), (0, 0, 1, 0)) + right_x_log_NP = ( + th.logcumsumexp(log_impulse_Np1_P - lamdba_dt_star, -2) + lamdba_dt_star + ) + right_x_NP = right_x_log_NP.exp() # Contains initial_state_P in index 0 + left_x_NP = ( + lambda_dt_NP.exp() * right_x_NP[..., :-1, :] + left_Bu_NP + ) # Evolves previous hidden state forward to compute left limit + right_x_NP = right_x_NP[..., 1:, :] + + conj_sym_mult = 2 if self.conj_sym else 1 + left_y_NH = ( + conj_sym_mult + * th.einsum("hp,...np->...nh", self.C_tilde_HP, left_x_NP).real + + left_Du_NH + ) # ys for [t0, t1, ...] + right_y_NH = ( + conj_sym_mult + * th.einsum("hp,...np->...nh", self.C_tilde_HP, right_x_NP).real + + right_Du_NH + ) # ys for [t0, t1, ...] + + return right_x_NP, left_y_NH, right_y_NH + + def get_left_limit( + self, + right_limit_P: th.Tensor, # Along with dt, can have any number of leading dimensions, produces a tensor of dim ...MP + dt_G: th.Tensor, + current_right_u_H: th.Tensor, + next_left_u_GH: th.Tensor, + ) -> th.Tensor: + """ + To get the left limit, we roll on the layer for the right dt. + Computed for a single point (vmap for multiple). + + :param right_limit_P: + :param dt: Length of time to roll the layer on for. + :return: + """ + + if current_right_u_H is not None and self.pre_norm: + current_right_u_H = self.norm(current_right_u_H) + + lambda_res = self.get_lambda( + current_right_u_H, shift_u=False + ) # U should already be shifted + if "lambda_rescaled_P" in lambda_res: + lambda_bar_GP = th.exp( + th.einsum("...g,p->...gp", dt_G, lambda_res["lambda_rescaled_P"]) + ) + else: + lambda_bar_GP = th.exp( + th.einsum("...g,...p->...gp", dt_G, lambda_res["lambda_rescaled_NP"]) + ) + + int_hidden_GP = th.einsum("...p,...gp->...gp", right_limit_P, lambda_bar_GP) + + if next_left_u_GH is None: # no Bu term + assert self.is_first_layer + return int_hidden_GP + else: # add Bu to impulse + if self.pre_norm: + next_left_u_GH = self.norm(next_left_u_GH) + + impulse_GP = th.einsum( + "...gp,ph,...gh->...gp", + lambda_bar_GP - 1.0, + self.B_tilde_PH, + next_left_u_GH.type(th.complex64) + if self.complex_values + else next_left_u_GH, + ) + + return int_hidden_GP + impulse_GP diff --git a/easy_tpp/ssm/ssm_util.py b/easy_tpp/ssm/ssm_util.py new file mode 100644 index 0000000000000000000000000000000000000000..72d9b5ce668b501f36864e9b5eb984710c2049f8 --- /dev/null +++ b/easy_tpp/ssm/ssm_util.py @@ -0,0 +1,80 @@ +# @title Imports and environment +import torch as th + + +def discretize_zoh(Lambda, B_tilde, Delta): + """Discretize a diagonalized, continuous-time linear SSM + using zero-order hold method. + + modified from: https://github.com/lindermanlab/S5/blob/3c18fdb6b06414da35e77b94b9cd855f6a95ef17/s5/ssm.py#L29 + + Args: + Lambda (complex64): diagonal state matrix (P,) + B_tilde (complex64): input matrix (P, H) + Delta (float32): discretization step sizes (P,) + Returns: + discretized Lambda_bar (complex64), B_bar (complex64) (P,), (P,H) + """ + Identity = th.ones(Lambda.shape[0]) + Lambda_bar = th.exp(Lambda * Delta) + B_bar = (1 / Lambda * (Lambda_bar - Identity))[..., None] * B_tilde + return Lambda_bar, B_bar + + +def apply_ssm( + Lambda_bar_NP, + B_bar_NPH, + C_tilde_HP, + input_sequence_NH, + alpha_NP, + conj_sym, + initial_state_P=None, +): + """Compute the NxH output of discretized SSM given an NxH input. + + modified from: https://github.com/lindermanlab/S5/blob/3c18fdb6b06414da35e77b94b9cd855f6a95ef17/s5/ssm.py#L60 + - removed bidirectionality. + - assume Lambda_bar is N-length. + + Args: + Lambda_bar_NP (complex64): discretized diagonal state matrix for each interval (N, P) + B_bar_NPH (complex64): "discretized" input matrix. Note: may be outside ZOH (N, P, H) + C_tilde_HP (complex64): output matrix (H, P) + input_sequence_NH (float32): input sequence of features (N, H) + alpha_NP (complex64): mark-specific biases (N, P) + conj_sym (bool): Whether conjugate symmetry is enforced + initial_state_P (): Allow passing in a specific initial state (otherwise zero is assumed.) + Returns: + ys_NH (float32): the SSM outputs (S5 layer preactivations) (N, H) + """ + N, P, H = B_bar_NPH.shape + + # Compute effective inputs. + Bu_elements_NP = th.vmap(lambda b, u, alpha: b @ u.type(th.complex64) + alpha)( + B_bar_NPH, input_sequence_NH, alpha_NP + ) + + # # Torch doesn't roll an associative scan... yet... + # _, xs = jax.lax.associative_scan(binary_operator, (Lambda_elements, Bu_elements)) + + # Set the initial state if we haven't already. + if initial_state_P is None: + state = th.zeros((P,)) + else: + state = initial_state_P + + # Accumulate the hidden states here. Note the initial state shouldn't be returned. + # xs = th.zeros((L, P)).type(th.complex64) + xs = [state] + + for i, (lam_P, bu_P) in enumerate(zip(Lambda_bar_NP, Bu_elements_NP)): + # state = lam_P * state + bu_P + # xs[i] = state + xs.append(lam_P * xs[-1] + bu_P) + xs = th.stack(xs)[1:] + + # Output the xs and ys after projecting. + if conj_sym: + return xs, th.vmap(lambda x: 2 * (C_tilde_HP @ x).real)(xs) + else: + return xs, th.vmap(lambda x: (C_tilde_HP @ x).real)(xs) diff --git a/easy_tpp/torch_wrapper.py b/easy_tpp/torch_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..870d2f3b3efdb401e1f49a56a1d1a4526e2c411f --- /dev/null +++ b/easy_tpp/torch_wrapper.py @@ -0,0 +1,147 @@ +""" Initialize a Pytorch model wrapper that feed into Model Runner """ + +import torch +from torch.utils.tensorboard import SummaryWriter + +from easy_tpp.utils import RunnerPhase, set_optimizer, set_device + + +class TorchModelWrapper: + def __init__(self, model, base_config, model_config, trainer_config): + """A wrapper class for Torch backends. + + Args: + model (BaseModel): a TPP model. + base_config (EasyTPP.Config): basic configs. + model_config (EasyTPP.ModelConfig): model spec configs. + trainer_config (EasyTPP.TrainerConfig): trainer spec configs. + """ + self.model = model + self.base_config = base_config + self.model_config = model_config + self.trainer_config = trainer_config + + self.model_id = self.base_config.model_id + # Sometimes PyTorch may not switch the active device context for all operations + # This causes illegal memory access error + if self.trainer_config.gpu!=-1: + torch.cuda.set_device(self.trainer_config.gpu) + self.device = set_device(self.trainer_config.gpu) + + self.model.to(self.device) + + if self.model_config.is_training: + # set up optimizer + optimizer = self.trainer_config.optimizer + self.learning_rate = self.trainer_config.learning_rate + self.opt = set_optimizer(optimizer, self.model.parameters(), self.learning_rate) + + # set up tensorboard + self.train_summary_writer, self.valid_summary_writer = None, None + if self.trainer_config.use_tfb: + self.train_summary_writer = SummaryWriter(log_dir=self.base_config.specs['tfb_train_dir']) + self.valid_summary_writer = SummaryWriter(log_dir=self.base_config.specs['tfb_valid_dir']) + + def restore(self, ckpt_dir): + """Load the checkpoint to restore the model. + + Args: + ckpt_dir (str): path for the checkpoint. + """ + + self.model.load_state_dict(torch.load(ckpt_dir), strict=False) + + def save(self, ckpt_dir): + """Save the checkpoint for the model. + + Args: + ckpt_dir (str): path for the checkpoint. + """ + torch.save(self.model.state_dict(), ckpt_dir) + + def write_summary(self, epoch, kv_pairs, phase): + """Write the kv_paris into the tensorboard + + Args: + epoch (int): epoch index in the training. + kv_pairs (dict): metrics dict. + phase (RunnerPhase): a const that defines the stage of model runner. + """ + if self.trainer_config.use_tfb: + summary_writer = None + if phase == RunnerPhase.TRAIN: + summary_writer = self.train_summary_writer + elif phase == RunnerPhase.VALIDATE: + summary_writer = self.valid_summary_writer + elif phase == RunnerPhase.PREDICT: + pass + + if summary_writer is not None: + for k, v in kv_pairs.items(): + if k != 'num_events': + summary_writer.add_scalar(k, v, epoch) + + summary_writer.flush() + return + + def close_summary(self): + """Close the tensorboard summary writer. + """ + if self.train_summary_writer is not None: + self.train_summary_writer.close() + + if self.valid_summary_writer is not None: + self.valid_summary_writer.close() + return + + def run_batch(self, batch, phase): + """Run one batch. + + Args: + batch (EasyTPP.BatchEncoding): preprocessed batch data that go into the model. + phase (RunnerPhase): a const that defines the stage of model runner. + + Returns: + tuple: for training and validation we return loss, prediction and labels; + for prediction we return prediction. + """ + + batch = batch.to(self.device).values() + if phase in (RunnerPhase.TRAIN, RunnerPhase.VALIDATE): + # set mode to train + is_training = (phase == RunnerPhase.TRAIN) + self.model.train(is_training) + + # FullyRNN needs grad event in validation stage + grad_flag = is_training if not self.model_id == 'FullyNN' else True + # run model + with torch.set_grad_enabled(grad_flag): + loss, num_event = self.model.loglike_loss(batch) + + # Assume we dont do prediction on train set + pred_dtime, pred_type, label_dtime, label_type, mask = None, None, None, None, None + + # update grad + if is_training: + self.opt.zero_grad() + (loss / num_event).backward() + self.opt.step() + else: # by default we do not do evaluation on train set which may take a long time + if self.model.event_sampler: + self.model.eval() + with torch.no_grad(): + if batch[1] is not None and batch[2] is not None: + label_dtime, label_type = batch[1][:, 1:].cpu().numpy(), batch[2][:, 1:].cpu().numpy() + if batch[3] is not None: + mask = batch[3][:, 1:].cpu().numpy() + pred_dtime, pred_type = self.model.predict_one_step_at_every_event(batch=batch) + pred_dtime = pred_dtime.detach().cpu().numpy() + pred_type = pred_type.detach().cpu().numpy() + return loss.item(), num_event, (pred_dtime, pred_type), (label_dtime, label_type), (mask,) + else: + pred_dtime, pred_type, label_dtime, label_type = self.model.predict_multi_step_since_last_event(batch=batch) + pred_dtime = pred_dtime.detach().cpu().numpy() + pred_type = pred_type.detach().cpu().numpy() + label_dtime = label_dtime.detach().cpu().numpy() + label_type = label_type.detach().cpu().numpy() + return (pred_dtime, pred_type), (label_dtime, label_type) diff --git a/easy_tpp/utils/__init__.py b/easy_tpp/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ec65684d5888b8b790e71ea53cdab6199669a3e9 --- /dev/null +++ b/easy_tpp/utils/__init__.py @@ -0,0 +1,62 @@ +from easy_tpp.utils.const import RunnerPhase, LogConst, DefaultRunnerConfig, PaddingStrategy, TensorType, ExplicitEnum, \ + TruncationStrategy +from easy_tpp.utils.import_utils import is_torchvision_available, \ + is_torch_cuda_available, is_torch_available, requires_backends, is_torch_gpu_available +from easy_tpp.utils.log_utils import default_logger as logger, DEFAULT_FORMATTER +from easy_tpp.utils.metrics import MetricsHelper, MetricsTracker +from easy_tpp.utils.misc import py_assert, make_config_string, create_folder, save_yaml_config, load_yaml_config, \ + load_pickle, has_key, array_pad_cols, save_pickle, concat_element, get_stage, to_dict, \ + dict_deep_update, save_json, load_json +from easy_tpp.utils.multiprocess_utils import get_unique_id, Timer, parse_uri_to_protocol_and_path, is_master_process, \ + is_local_master_process +from easy_tpp.utils.ode_utils import rk4_step_method +from easy_tpp.utils.registrable import Registrable +from easy_tpp.utils.torch_utils import set_device, set_optimizer, set_seed, count_model_params +from easy_tpp.utils.generic import is_torch_device, is_numpy_array +from easy_tpp.utils.gen_utils import generate_and_save_json + +__all__ = ['py_assert', + 'make_config_string', + 'create_folder', + 'save_yaml_config', + 'load_yaml_config', + 'RunnerPhase', + 'LogConst', + 'load_pickle', + 'has_key', + 'array_pad_cols', + 'MetricsHelper', + 'MetricsTracker', + 'set_device', + 'set_optimizer', + 'set_seed', + 'save_pickle', + 'count_model_params', + 'Registrable', + 'logger', + 'get_unique_id', + 'Timer', + 'concat_element', + 'get_stage', + 'to_dict', + 'DEFAULT_FORMATTER', + 'parse_uri_to_protocol_and_path', + 'is_master_process', + 'is_local_master_process', + 'dict_deep_update', + 'DefaultRunnerConfig', + 'rk4_step_method', + 'is_torchvision_available', + 'is_torch_cuda_available', + 'is_torch_gpu_available', + 'is_torch_available', + 'requires_backends', + 'PaddingStrategy', + 'ExplicitEnum', + 'TruncationStrategy', + 'TensorType', + 'is_torch_device', + 'is_numpy_array', + 'save_json', + 'load_json', + 'generate_and_save_json'] diff --git a/easy_tpp/utils/const.py b/easy_tpp/utils/const.py new file mode 100644 index 0000000000000000000000000000000000000000..ae4de19e8b69dffe91ba513cd71f553e70ffc9bf --- /dev/null +++ b/easy_tpp/utils/const.py @@ -0,0 +1,90 @@ +from enum import Enum + + +class ExplicitEnum(str, Enum): + """ + Enum with more explicit error message for missing values. + """ + + def __str__(self): + return str(self.value) + + @classmethod + def _missing_(cls, value): + raise ValueError( + f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}" + ) + + +class PaddingStrategy(ExplicitEnum): + """ + Possible values for the `padding` argument in [`EventTokenizer.__call__`]. Useful for tab-completion in an + IDE. + """ + + LONGEST = "longest" + MAX_LENGTH = "max_length" + DO_NOT_PAD = "do_not_pad" + + +class TensorType(ExplicitEnum): + """ + Possible values for the `return_tensors` argument in [`EventTokenizerBase.__call__`]. Useful for + tab-completion in an IDE. + """ + + PYTORCH = "pt" + NUMPY = "np" + + +class RunnerPhase(ExplicitEnum): + """Model runner phase enum. + """ + TRAIN = 'train' + VALIDATE = 'validate' + PREDICT = 'predict' + + +class LossFunction(ExplicitEnum): + """Loss function for neural TPP model. + """ + LOGLIKE = 'loglike' + PARTIAL_TIME_LOSS = 'rmse' + PARTIAL_EVENT_LOSS = 'accuracy' + + +class LogConst: + """Format for log handler. + """ + DEFAULT_FORMAT = '[%(asctime)s] [%(levelname)s] %(message)s' + DEFAULT_FORMAT_LONG = '%(asctime)s - %(filename)s[pid:%(process)d;line:%(lineno)d:%(funcName)s]' \ + ' - %(levelname)s: %(message)s' + + +class PredOutputIndex: + """Positional index for the output tuple in ModelRunner. + """ + TimePredIndex = 0 + TypePredIndex = 1 + + +class DefaultRunnerConfig: + DEFAULT_DATASET_ID = 'conttime' + + +class TruncationStrategy(ExplicitEnum): + """ + Possible values for the `truncation` argument in [`EventTokenizer.__call__`]. Useful for tab-completion in + an IDE. + """ + + LONGEST_FIRST = "longest_first" + DO_NOT_TRUNCATE = "do_not_truncate" + + +class Backend(ExplicitEnum): + """ + Possible values for the `backend` argument in configuration. + """ + + Torch = 'torch' diff --git a/easy_tpp/utils/gen_utils.py b/easy_tpp/utils/gen_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d0068764f4b2c4aab6199620bc65910137a689d9 --- /dev/null +++ b/easy_tpp/utils/gen_utils.py @@ -0,0 +1,120 @@ +import numpy as np +from easy_tpp.utils.misc import save_json + +def generate_synthetic_data(n_nodes=3, end_time=1000, baseline=0.1, adjacency=0.5, decay=1.0): + """ + Generates synthetic data using a multivariate Hawkes process with exponential kernels. + + Args: + n_nodes (int): Number of nodes (or dimensions) in the Hawkes process. + end_time (float): The time until which the process is simulated. + baseline (float): Baseline intensity for each node. + adjacency (float): Adjacency matrix value for the influence between nodes. + decay (float): Decay parameter for the exponential kernel. + + Returns: + list: A list of lists, where each sublist contains dictionaries representing events for a node. + """ + baseline_vector = np.full(n_nodes, baseline) + adjacency_matrix = np.full((n_nodes, n_nodes), adjacency) + events = [[] for _ in range(n_nodes)] + current_time = 0 + + while current_time < end_time: + # Calculate the intensity for each node + intensities = baseline_vector.copy() + for i in range(n_nodes): + for j in range(n_nodes): + if events[j]: + last_event_time = events[j][-1]['time_since_start'] + intensities[i] += adjacency_matrix[i, j] * np.exp(-decay * (current_time - last_event_time)) + + # Determine the next event time + total_intensity = np.sum(intensities) + if total_intensity == 0: + break + time_to_next_event = np.random.exponential(1 / total_intensity) + current_time += time_to_next_event + + if current_time >= end_time: + break + + # Determine which node the event occurs in + probabilities = intensities / total_intensity + node = np.random.choice(n_nodes, p=probabilities) + + # Record the event as a dictionary + if events[node]: + last_event_time = events[node][-1]['time_since_start'] + else: + last_event_time = 0 + + event = { + 'time_since_start': current_time, + 'time_since_last_event': current_time - last_event_time, + 'type_event': node + } + events[node].append(event) + + return events + +def format_tick_data_to_hf(events, dim_process, max_seq_len): + """ + Formats the synthetic data from a multivariate Hawkes process to the Hugging Face dataset format. + + Args: + events (list): A list of lists, where each sublist contains dictionaries representing events for a node. + dim_process (int): Number of nodes (or dimensions) in the Hawkes process. + max_seq_len (int): Maximum sequence length. + + Returns: + list: A list of dictionaries, where each dictionary represents a sequence. + """ + # Flatten all events into a single list + all_events = [event for node_events in events for event in node_events] + + # Sort events by time_since_start + all_events.sort(key=lambda x: x['time_since_start']) + + # Split into multiple sequences based on max_seq_len + formatted_data = [] + for seq_idx in range(0, len(all_events), max_seq_len): + seq_events = all_events[seq_idx:seq_idx + max_seq_len] + + # Adjust time_since_start to have zero start timestamps + start_time = seq_events[0]['time_since_start'] + time_since_start = [event['time_since_start'] - start_time for event in seq_events] + time_since_last_event = [event['time_since_last_event'] for event in seq_events] + type_event = [event['type_event'] for event in seq_events] + + temp_dict = { + 'dim_process': dim_process, + 'seq_idx': seq_idx // max_seq_len, + 'seq_len': len(seq_events), + 'time_since_start': time_since_start, + 'time_since_last_event': time_since_last_event, + 'type_event': type_event + } + formatted_data.append(temp_dict) + + return formatted_data + +def generate_and_save_json(n_nodes, end_time, baseline, adjacency, decay, max_seq_len, target_file): + """ + Generates synthetic data, formats it, and saves it to a file in Hugging Face format. + + Args: + n_nodes (int): Number of nodes (or dimensions) in the Hawkes process. + end_time (float): The time until which the process is simulated. + baseline (float): Baseline intensity for each node. + adjacency (float): Adjacency matrix value for the influence between nodes. + decay (float): Decay parameter for the exponential kernel. + max_seq_len (int): Maximum sequence length. + target_file (str): Path to the file where the formatted data will be saved. + + Raises: + IOError: If the file cannot be opened or written to. + """ + events = generate_synthetic_data(n_nodes, end_time, baseline, adjacency, decay) + formatted_data = format_tick_data_to_hf(events, dim_process=n_nodes, max_seq_len=max_seq_len) + save_json(formatted_data, target_file) \ No newline at end of file diff --git a/easy_tpp/utils/generic.py b/easy_tpp/utils/generic.py new file mode 100644 index 0000000000000000000000000000000000000000..95534354741fd5424c3f96df61d2d3173abcd71a --- /dev/null +++ b/easy_tpp/utils/generic.py @@ -0,0 +1,71 @@ +import numpy as np + +from easy_tpp.utils import is_torch_available + + +def is_tensor(x): + """ + Tests if `x` is a `torch.Tensor` or `np.ndarray`. + """ + if is_torch_available(): + import torch + + if isinstance(x, torch.Tensor): + return True + + return isinstance(x, np.ndarray) + + +def _is_numpy(x): + return isinstance(x, np.ndarray) + + +def is_numpy_array(x): + """ + Tests if `x` is a numpy array or not. + """ + return _is_numpy(x) + + +def _is_torch(x): + import torch + + return isinstance(x, torch.Tensor) + + +def is_torch_tensor(x): + """ + Tests if `x` is a torch tensor or not. Safe to call even if torch is not installed. + """ + return False if not is_torch_available() else _is_torch(x) + + +def _is_torch_device(x): + import torch + + return isinstance(x, torch.device) + + +def is_torch_device(x): + """ + Tests if `x` is a torch device or not. Safe to call even if torch is not installed. + """ + return False if not is_torch_available() else _is_torch_device(x) + + +def _is_torch_dtype(x): + import torch + + if isinstance(x, str): + if hasattr(torch, x): + x = getattr(torch, x) + else: + return False + return isinstance(x, torch.dtype) + + +def is_torch_dtype(x): + """ + Tests if `x` is a torch dtype or not. Safe to call even if torch is not installed. + """ + return False if not is_torch_available() else _is_torch_dtype(x) diff --git a/easy_tpp/utils/import_utils.py b/easy_tpp/utils/import_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cfd23536b877dca491d7ed88eab3a83e2748511e --- /dev/null +++ b/easy_tpp/utils/import_utils.py @@ -0,0 +1,120 @@ +import importlib.util +import sys +from collections import OrderedDict +from typing import Union, Tuple + +from easy_tpp.utils.log_utils import default_logger as logger + +if sys.version_info < (3, 8): + import importlib_metadata +else: + import importlib.metadata as importlib_metadata + + +def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]: + # Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version + package_exists = importlib.util.find_spec(pkg_name) is not None + package_version = "N/A" + if package_exists: + try: + package_version = importlib_metadata.version(pkg_name) + except importlib_metadata.PackageNotFoundError: + pass + logger.debug(f"Detected {pkg_name} version {package_version}") + if return_version: + return package_exists, package_version + else: + return package_exists + + +_torchdistx_available = _is_package_available("torchdistx") +_torchvision_available = _is_package_available("torchvision") + +_torch_available, _torch_version = _is_package_available("torch", return_version=True) + + +def is_torch_available(): + return _torch_available + + +def get_torch_version(): + return _torch_version + + +def is_torchvision_available(): + return _torchvision_available + + +def is_torch_cuda_available(): + if is_torch_available(): + import torch + + return torch.cuda.is_available() + else: + return False + + +def is_torch_mps_available(): + if is_torch_available(): + try: + import torch + torch.device('mps') + return True + except RuntimeError: + return False + else: + return False + + +def is_torch_gpu_available(): + is_cuda_available = is_torch_cuda_available() + + is_mps_available = is_torch_mps_available() + + return is_cuda_available | is_mps_available + + +def torch_only_method(fn): + def wrapper(*args, **kwargs): + if not _torch_available: + raise ImportError( + "You need to install pytorch to use this method or class." + ) + else: + return fn(*args, **kwargs) + + return wrapper + + +# docstyle-ignore +PYTORCH_IMPORT_ERROR = """ +{0} requires the PyTorch library but it was not found in your environment. Checkout the instructions on the +installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment. +Please note that you may need to restart your runtime after installation. +""" + +# docstyle-ignore +TORCHVISION_IMPORT_ERROR = """ +{0} requires the Torchvision library but it was not found in your environment. Checkout the instructions on the +installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment. +Please note that you may need to restart your runtime after installation. +""" + +BACKENDS_MAPPING = OrderedDict( + [ + ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)), + ("torchvision", (is_torchvision_available, TORCHVISION_IMPORT_ERROR)) + ] +) + + +def requires_backends(obj, backends): + if not isinstance(backends, (list, tuple)): + backends = [backends] + + name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__ + + checks = (BACKENDS_MAPPING[backend] for backend in backends) + failed = [msg.format(name) for available, msg in checks if not available()] + if failed: + raise ImportError("".join(failed)) diff --git a/easy_tpp/utils/log_utils.py b/easy_tpp/utils/log_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..adfe0d371a93c2e2925023bfa98e33d904252aa8 --- /dev/null +++ b/easy_tpp/utils/log_utils.py @@ -0,0 +1,54 @@ +import logging +import sys +import typing + +from easy_tpp.utils.const import LogConst + +# -------- log setting --------- +DEFAULT_LOGGER = "easytpp.logger" + + +class CustomFormatter(logging.Formatter): + grey = "\x1b[38;20m" + yellow = "\x1b[33;20m" + red = "\x1b[31;20m" + bold_red = "\x1b[31;1m" + reset = "\x1b[0m" + format = LogConst.DEFAULT_FORMAT_LONG + + FORMATS = { + logging.DEBUG: grey + format + reset, + logging.INFO: grey + format + reset, + logging.WARNING: yellow + format + reset, + logging.ERROR: red + format + reset, + logging.CRITICAL: bold_red + format + reset + } + + def format(self, record): + log_fmt = self.FORMATS.get(record.levelno) + formatter = logging.Formatter(log_fmt) + return formatter.format(record) + + +DEFAULT_FORMATTER = CustomFormatter() + +_ch = logging.StreamHandler(stream=sys.stdout) +_ch.setFormatter(DEFAULT_FORMATTER) + +_DEFAULT_HANDLERS = [_ch] + +_LOGGER_CACHE = {} # type: typing.Dict[str, logging.Logger] + + +def get_logger(name, level="INFO", handlers=None, update=False): + if name in _LOGGER_CACHE and not update: + return _LOGGER_CACHE[name] + logger = logging.getLogger(name) + logger.setLevel(level) + logger.handlers = handlers or _DEFAULT_HANDLERS + logger.propagate = False + return logger + + +# -------------------------- Singleton Object -------------------------- +default_logger = get_logger(DEFAULT_LOGGER) diff --git a/easy_tpp/utils/metrics.py b/easy_tpp/utils/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..09378d48813684ea64091bba8fa4a468d255102d --- /dev/null +++ b/easy_tpp/utils/metrics.py @@ -0,0 +1,112 @@ +from collections import defaultdict + +import numpy as np + +from easy_tpp.utils.log_utils import default_logger as logger + + +class MetricsHelper: + MAXIMIZE = 'maximize' + MINIMIZE = 'minimize' + _registry_center = defaultdict(tuple) + + @staticmethod + def get_metric_function(name): + if name in MetricsHelper._registry_center: + return MetricsHelper._registry_center[name][0] + else: + logger.warn(f'Metric is not found: {name}') + return None + + @staticmethod + def get_metric_direction(name): + if name in MetricsHelper._registry_center: + return MetricsHelper._registry_center[name][1] + else: + return None + + @staticmethod + def get_all_registered_metric(): + return MetricsHelper._registry_center.values + + @staticmethod + def register(name, direction, overwrite=True): + registry_center = MetricsHelper._registry_center + + def _add_metric_to_registry(func): + if name in registry_center: + if overwrite: + registry_center[name] = (func, direction) + else: + logger.warn(f'The metric {name} is already registered, and cannot be overwritten!') + else: + registry_center[name] = (func, direction) + return func + + return _add_metric_to_registry + + @staticmethod + def metrics_dict_to_str(metrics_dict): + """ Convert metrics to a string to show in console """ + eval_info = '' + for k, v in metrics_dict.items(): + eval_info += '{0} is {1}, '.format(k, v) + + return eval_info[:-2] + + @staticmethod + def get_metrics_callback_from_names(metric_names): + """ Metrics function callbacks """ + metric_functions = [] + metric_names_ = [] + for name in metric_names: + metric = MetricsHelper.get_metric_function(name) + if metric is not None: + metric_functions.append(metric) + metric_names_.append(name) + + def metrics(preds, labels, **kwargs): + """ call metrics functions """ + res = dict() + for metric_name, metric_func in zip(metric_names_, metric_functions): + res[metric_name.lower()] = metric_func(preds, labels, **kwargs) + return res + + return metrics + + +class MetricsTracker: + """Track and record the metrics. + """ + + def __init__(self): + self.current_best = { + 'loglike': np.finfo(float).min, + 'distance': np.finfo(float).max + } + self.episode_best = 'NeverUpdated' + + def update_best(self, key, value, epoch): + """Update the recorder for the best metrics. + + Args: + key (str): metrics key. + value (float): metrics value. + epoch (int): num of epoch. + + Raises: + NotImplementedError: for keys other than 'loglike'. + + Returns: + bool: whether the recorder has been updated. + """ + updated = False + if key == 'loglike': + if value > self.current_best[key]: + updated = True + self.current_best[key] = value + self.episode_best = epoch + else: + raise NotImplementedError + + return updated diff --git a/easy_tpp/utils/misc.py b/easy_tpp/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..7819b328f9279c806c87e2975ed0364dfbefff91 --- /dev/null +++ b/easy_tpp/utils/misc.py @@ -0,0 +1,286 @@ +import copy +import os +import pickle + +import numpy as np +import yaml +import json +from easy_tpp.utils.const import RunnerPhase + + +def py_assert(condition, exception_type, msg): + """An assert function that ensures the condition holds, otherwise throws a message. + + Args: + condition (bool): a formula to ensure validity. + exception_type (_StandardError): Error type, such as ValueError. + msg (str): a message to throw out. + + Raises: + exception_type: throw an error when the condition does not hold. + """ + if not condition: + raise exception_type(msg) + + +def make_config_string(config, max_num_key=4): + """Generate a name for config files. + + Args: + config (dict): configuration dict. + max_num_key (int, optional): max number of keys to concat in the output. Defaults to 4. + + Returns: + dict: a concatenated string from config dict. + """ + str_config = '' + num_key = 0 + for k, v in config.items(): + if num_key < max_num_key: # for the moment we only record model name + if k == 'name': + str_config += str(v) + '_' + num_key += 1 + return str_config[:-1] + + +def save_yaml_config(save_dir, config): + """A function that saves a dict of config to yaml format file. + + Args: + save_dir (str): the path to save config file. + config (dict): the target config object. + """ + prt_dir = os.path.dirname(save_dir) + + from collections import OrderedDict + # add yaml representer for different type + yaml.add_representer( + OrderedDict, + lambda dumper, data: dumper.represent_mapping('tag:yaml.org,2002:map', data.items()) + ) + + if prt_dir != '' and not os.path.exists(prt_dir): + os.makedirs(prt_dir) + + with open(save_dir, 'w') as f: + yaml.dump(config, stream=f, default_flow_style=False, sort_keys=False) + + return + + +def load_yaml_config(config_dir): + """ Load yaml config file from disk. + + Args: + config_dir: str or Path + The path of the config file. + + Returns: + Config: dict. + """ + with open(config_dir) as config_file: + # load configs + config = yaml.load(config_file, Loader=yaml.FullLoader) + + return config + + +def get_stage(stage): + stage = stage.lower() + if stage in ['train', 'training']: + return RunnerPhase.TRAIN + elif stage in ['valid', 'dev', 'eval']: + return RunnerPhase.VALIDATE + else: + return RunnerPhase.PREDICT + + +def create_folder(*args): + """Create path if the folder doesn't exist. + + Returns: + str: the created folder's path. + """ + path = os.path.join(*args) + if not os.path.exists(path): + os.makedirs(path) + return path + + +def load_pickle(file_dir): + """Load from pickle file. + + Args: + file_dir (BinaryIO): dir of the pickle file. + + Returns: + any type: the loaded data. + """ + with open(file_dir, 'rb') as file: + try: + data = pickle.load(file, encoding='latin-1') + except Exception: + data = pickle.load(file) + + return data + + +def save_pickle(file_dir, object_to_save): + """Save the object to a pickle file. + + Args: + file_dir (str): dir of the pickle file. + object_to_save (any): the target data to be saved. + """ + + with open(file_dir, "wb") as f_out: + pickle.dump(object_to_save, f_out) + + return + + +def save_json(data, file_dir): + """ + Save data to a JSON file. + + Args: + data: The data to be saved. It should be JSON serializable (e.g., a dictionary or list). + file_dir (str): The path to the file where the data will be saved. + + Raises: + IOError: If the file cannot be opened or written to. + """ + with open(file_dir, 'w') as outfile: + json.dump(data, outfile, indent=4) + print(f"Data successfully saved to {file_dir}") + + +def load_json(file_dir): + """ + Reads data from a JSON file. + + Args: + file_dir (str): The path to the JSON file to be read. + + Returns: + The data read from the JSON file. + + Raises: + IOError: If the file cannot be opened or read. + json.JSONDecodeError: If the file is not a valid JSON. + """ + with open(file_dir, 'r') as infile: + data = json.load(infile) + return data + + +def has_key(target_dict, target_keys): + """Check if the keys exist in the target dict. + + Args: + target_dict (dict): a dict. + target_keys (str, list): list of keys. + + Returns: + bool: True if all the key exist in the dict; False otherwise. + """ + if not isinstance(target_keys, list): + target_keys = [target_keys] + for k in target_keys: + if k not in target_dict: + return False + return True + + +def array_pad_cols(arr, max_num_cols, pad_index): + """Pad the array by columns. + + Args: + arr (np.array): target array to be padded. + max_num_cols (int): target num cols for padded array. + pad_index (int): pad index to fill out the padded elements + + Returns: + np.array: the padded array. + """ + res = np.ones((arr.shape[0], max_num_cols)) * pad_index + + res[:, :arr.shape[1]] = arr + + return res + + +def concat_element(arrs, pad_index): + """ Concat element from each batch output """ + + n_lens = len(arrs) + n_elements = len(arrs[0]) + + # found out the max seq len (num cols) in output arrays + max_len = max([x[0].shape[1] for x in arrs]) + + concated_outputs = [] + for j in range(n_elements): + a_output = [] + for i in range(n_lens): + arrs_ = array_pad_cols(arrs[i][j], max_num_cols=max_len, pad_index=pad_index) + a_output.append(arrs_) + + concated_outputs.append(np.concatenate(a_output, axis=0)) + + # n_elements * [ [n_lens, dim_of_element] ] + return concated_outputs + + +def to_dict(obj, classkey=None): + if isinstance(obj, dict): + data = {} + for (k, v) in obj.items(): + data[k] = to_dict(v, classkey) + return data + elif hasattr(obj, "_ast"): + return to_dict(obj._ast()) + elif hasattr(obj, "__iter__"): + return [to_dict(v, classkey) for v in obj] + elif hasattr(obj, "__dict__"): + data = dict([(key, to_dict(value, classkey)) + for key, value in obj.__dict__.iteritems() + if not callable(value) and not key.startswith('_') and key not in ['name']]) + if classkey is not None and hasattr(obj, "__class__"): + data[classkey] = obj.__class__.__name__ + return data + else: + return obj + + +def dict_deep_update(target, source, is_add_new_key=True): + """ Update 'target' dict by 'source' dict deeply, and return a new dict copied from target and source deeply. + + Args: + target: dict + source: dict + is_add_new_key: bool, default True. + Identify if add a key that in source but not in target into target. + + Returns: + New target: dict. It contains the both target and source values, but keeps the values from source when the key + is duplicated. + """ + # deep copy for avoiding to modify the original dict + result = copy.deepcopy(target) if target is not None else {} + + if source is None: + return result + + for key, value in source.items(): + if key not in result: + if is_add_new_key: + result[key] = value + continue + # both target and source have the same key + base_type_list = [int, float, str, tuple, bool] + if type(result[key]) in base_type_list or type(source[key]) in base_type_list: + result[key] = value + else: + result[key] = dict_deep_update(result[key], source[key], is_add_new_key=is_add_new_key) + return result diff --git a/easy_tpp/utils/multiprocess_utils.py b/easy_tpp/utils/multiprocess_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..418a6223206339d27711eb9bb8d4eb2aa4da5bb7 --- /dev/null +++ b/easy_tpp/utils/multiprocess_utils.py @@ -0,0 +1,115 @@ +import os +import time + + +def is_master_process(): + """ Check if the process is the master process in all machines. + + Returns: + bool + """ + rank = 0 if os.getenv('RANK') is None else int(os.getenv('RANK')) + if rank == 0: + return True + else: + return False + + +def is_local_master_process(): + """ Check if the process is the master process in the local machine. + + Returns: + bool + """ + rank = 0 if os.getenv('RANK') is None else int(os.getenv('RANK')) + local_world_size = 1 if os.getenv('LOCAL_WORLD_SIZE') is None else int(os.getenv('LOCAL_WORLD_SIZE')) + if local_world_size == 0 or rank % local_world_size == 0: + return True + else: + return False + + +def get_now_timestamp_id(): + """ Get the current timestamp string. + + Returns: + A string like yyMMdd_hhmmss + """ + import datetime + return datetime.datetime.now().strftime('%y%m%d-%H%M%S') + + +def get_unique_id(): + """ Generate a unique id string based on process id (pid), thread id and timestamp. + + Returns: + Unique id: str + """ + import os + import threading + pid = os.getpid() + tid = threading.currentThread().ident + ts_id = get_now_timestamp_id() + + return '{}_{}_{}'.format(pid, tid, ts_id) + + +def parse_uri_to_protocol_and_path(uri): + """ Parse a uri into two parts, protocol and path. Set 'file' as default protocol when lack protocol. + + Args: + uri: str + The uri to identify a resource, whose format is like 'protocol://uri'. + + Returns: + Protocol: str. The method to access the resource. + URI: str. The location of the resource. + """ + + if uri is None: + return None, None + tokens = uri.split('://') + if len(tokens) == 2: + protocol = tokens[0] + path = tokens[1] + elif len(tokens) == 1: + protocol = 'file' + path = tokens[0] + else: + raise RuntimeError(f'Wrong url format: {uri}') + + return protocol, path + + +class Timer: + """Count the elapsing time between start and end. + """ + + def __init__(self, unit='m'): + unit = unit.lower() + if unit == 's': + self._unit = 1 + elif unit == 'm': + self._unit = 60 + elif unit == 'h': + self._unit = 1440 + else: + raise RuntimeError('Unknown unit:', unit) + + self.unit = unit + # default start time is set to the time the object initialized + self._start_time = time.time() + + def start(self): + self._start_time = time.time() + + def end(self): + end_time = time.time() + cost = (end_time - self._start_time) / self._unit + # reset the start time using the end time + self._start_time = end_time + return '%.3f%s' % (cost, self.unit) + + +# -------------------------- Singleton Object -------------------------- +default_timer = Timer() diff --git a/easy_tpp/utils/ode_utils.py b/easy_tpp/utils/ode_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..66a5e3cb5f42948c444a306c9103708b8fc5c54c --- /dev/null +++ b/easy_tpp/utils/ode_utils.py @@ -0,0 +1,91 @@ +def ode_update_op(z0, dz, dt): + """ + General update operation for solving ODEs. + + Args: + z0: Tensor or a list for Tensor whose shape is [..., dim] + State at t0. + dz: Tensor or a list for Tensor whose shape is [..., dim] + Differentiation of state. + dt: Tensor with shape [..., 1] + Equal to t1 - t0. + + Returns: + + """ + if isinstance(z0, list) or isinstance(z0, tuple): + return [item_z + dt * item_dz for item_z, item_dz in zip(z0, dz)] + else: + return z0 + dt * dz + + +def euler_step_method(diff_func, dt, z0): + """ + Euler method for solving ODEs. + + Args: + diff_func: function(state) + Differential equation. + dt: Tensor with shape [..., 1] + Equal to t1 - t0. + z0: Tensor or a list for Tensor whose shape is [..., dim] + State at t0. + + Returns: + Tensor or a list for Tensor whose shape is [..., dim], which is updated state. + """ + dz = diff_func(z0) + return ode_update_op(z0, dz, dt) + + +def rk2_step_method(diff_func, dt, z0): + """ + Second order Runge-Kutta method for solving ODEs. + + Args: + diff_func: function(dt, state) + Differential equation. + dt: Tensor with shape [..., 1] + Equal to t1 - t0. + z0: Tensor or a list for Tensor whose shape is [..., dim] + State at t0. + + Returns: + Tensor or a list for Tensor whose shape is [..., dim] + """ + # shape -> [..., dim] + k1 = diff_func(z0) + k2 = diff_func(ode_update_op(z0, k1, dt)) + + if isinstance(z0, list) or isinstance(z0, tuple): + return [item_z + (item_k1 + item_k2) * dt * 0.5 for item_z, item_k1, item_k2 in zip(z0, k1, k2)] + else: + return z0 + dt * (k1 + k2) * 0.5 + + +def rk4_step_method(diff_func, dt, z0): + """ + Fourth order Runge-Kutta method for solving ODEs. + + Args: + diff_func: function(dt, state) + Differential equation. + dt: Tensor with shape [..., 1] + Equal to t1 - t0. + z0: Tensor with shape [..., dim] + State at t0. + + Returns: + Tensor with shape [..., dim], which is updated state. + """ + # shape -> [..., dim] + k1 = diff_func(z0) + k2 = diff_func(ode_update_op(z0, k1, dt / 2.0)) + k3 = diff_func(ode_update_op(z0, k2, dt / 2.0)) + k4 = diff_func(ode_update_op(z0, k3, dt)) + + if isinstance(z0, list) or isinstance(z0, tuple): + return [item_z + (item_k1 + 2.0 * item_k2 + 2.0 * item_k3 + item_k4) * dt / 6.0 + for item_z, item_k1, item_k2, item_k3, item_k4 in zip(z0, k1, k2, k3, k4)] + else: + return z0 + dt * (k1 + k2 * 2.0 + k3 * 2.0 + k4) / 6.0 diff --git a/easy_tpp/utils/registrable.py b/easy_tpp/utils/registrable.py new file mode 100644 index 0000000000000000000000000000000000000000..01b1142a3a8eedf7502d73b5dadddd1fead1f813 --- /dev/null +++ b/easy_tpp/utils/registrable.py @@ -0,0 +1,156 @@ +from collections import defaultdict + +from .log_utils import default_logger as logger + + +class Registrable: + """Any class that inherits from ``Registrable`` gains access to a named registry for its subclasses. To register them, just decorate them with the classmethod ``@BaseClass.register(name)``. + + After which you can call ``BaseClass.list_available()`` to get the keys for the registered subclasses, and ``BaseClass.by_name(name)`` to get the corresponding subclass. + + Note that the registry stores the subclasses themselves; not class instances. In most cases you would then call ``from_params(params)`` on the returned subclass. + """ + + _registry = defaultdict(dict) + _default_impl = None + + @classmethod + def register(cls, name, constructor=None, overwrite=False): + """Register a class under a particular name. + Args: + name (str): The name to register the class under. + constructor (str): optional (default=None) + The name of the method to use on the class to construct the object. If this is given, + we will use this method (which must be a ``classmethod``) instead of the default + constructor. + overwrite (bool) : optional (default=False) + If True, overwrites any existing models registered under ``name``. Else, + throws an error if a model is already registered under ``name``. + + # Examples + To use this class, you would typically have a base class that inherits from ``Registrable``: + ```python + class Transform(Registrable): + ... + ``` + Then, if you want to register a subclass, you decorate it like this: + ```python + @Transform.register("shift-transform") + class ShiftTransform(Transform): + def __init__(self, param1: int, param2: str): + ... + ``` + Registering a class like this will let you instantiate a class from a config file, where you + give ``"type": "shift-transform"``, and keys corresponding to the parameters of the ``__init__`` + method (note that for this to work, those parameters must have type annotations). + If you want to have the instantiation from a config file call a method other than the + constructor, either because you have several different construction paths that could be + taken for the same object (as we do in ``Transform``) or because you have logic you want to + happen before you get to the constructor, you can register a specific ``@classmethod`` as the constructor to use. + """ + registry = Registrable._registry[cls] + + def add_subclass_to_registry(subclass): + # Add to registry, raise an error if key has already been used. + if name in registry: + if overwrite: + message = ( + f"{name} has already been registered as {registry[name][0].__name__}, but " + f"overwrite=True, so overwriting with {cls.__name__}" + ) + logger.info(message) + else: + message = ( + f"Cannot register {name} as {cls.__name__}; " + f"name already in use for {registry[name][0].__name__}" + ) + raise RuntimeError(message) + registry[name] = (subclass, constructor) + return subclass + + return add_subclass_to_registry + + @classmethod + def by_name(cls, name): + """ + Returns a callable function that constructs an argument of the registered class. Because + you can register particular functions as constructors for specific names, this isn't + necessarily the ``__init__`` method of some class. + """ + logger.debug(f"instantiating registered subclass {name} of {cls}") + subclass, constructor = cls.resolve_class_name(name) + if not constructor: + return subclass + else: + return getattr(subclass, constructor) + + @classmethod + def resolve_class_name(cls, name): + """ + Returns the subclass that corresponds to the given ``name``, along with the name of the + method that was registered as a constructor for that ``name``, if any. + This method also allows ``name`` to be a fully-specified module name, instead of a name that + was already added to the ``Registry``. In that case, you cannot use a separate function as + a constructor (as you need to call ``cls.register()`` in order to tell us what separate + function to use). + """ + if name in Registrable._registry[cls]: + subclass, constructor = Registrable._registry[cls].get(name) + return subclass, constructor + else: + for base_cls, v in Registrable._registry.items(): + if name in v: + subclass, constructor = Registrable._registry[base_cls].get(name) + return subclass, constructor + + if "." in name: + # This might be a fully qualified class name, so we'll try importing its "module" + # and finding it there. + parts = name.split(".") + submodule = ".".join(parts[:-1]) + class_name = parts[-1] + import importlib + try: + module = importlib.import_module(submodule) + except ModuleNotFoundError: + raise RuntimeError( + f"tried to interpret {name} as a path to a class " + f"but unable to import module {submodule}" + ) + + try: + subclass = getattr(module, class_name) + constructor = None + return subclass, constructor + except AttributeError: + raise RuntimeError( + f"tried to interpret {name} as a path to a class " + f"but unable to find class {class_name} in {submodule}" + ) + + else: + # is not a qualified class name + raise RuntimeError( + f"{name} is not a registered name for {cls.__name__}. " + "You probably need to use the --include-package flag " + "to load your custom code. Alternatively, you can specify your choices " + """using fully-qualified paths, e.g. {"model": "my_module.models.MyModel"} """ + "in which case they will be automatically imported correctly." + ) + + @classmethod + def list_available(cls): + """List default first if it exists""" + keys = list(Registrable._registry[cls].keys()) + default = cls._default_impl + + if default is None: + return keys + elif default not in keys: + raise RuntimeError(f"Default implementation {default} is not registered") + else: + return [default] + [k for k in keys if k != default] + + @classmethod + def registry_dict(cls): + return Registrable._registry[cls] diff --git a/easy_tpp/utils/torch_utils.py b/easy_tpp/utils/torch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..dae7406a0cd1556a5ebbd339237946f2ff0e1a33 --- /dev/null +++ b/easy_tpp/utils/torch_utils.py @@ -0,0 +1,74 @@ +import os +import random + +import numpy as np +import torch + +from easy_tpp.utils.import_utils import is_torch_mps_available + + +def set_seed(seed=1029): + """Setup random seed. + + Args: + seed (int, optional): random seed. Defaults to 1029. + """ + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + + +def set_device(gpu=-1): + """Setup the device. + + Args: + gpu (int, optional): num of GPU to use. Defaults to -1 (not use GPU, i.e., use CPU). + """ + if gpu >= 0: + if torch.cuda.is_available(): + device = torch.device("cuda:" + str(gpu)) + elif is_torch_mps_available(): + device = torch.device("mps") + else: + device = torch.device("cpu") + return device + + +def set_optimizer(optimizer, params, lr): + """Setup the optimizer. + + Args: + optimizer (str): name of the optimizer. + params (dict): dict of params for the optimizer. + lr (float): learning rate. + + Raises: + NotImplementedError: if the optimizer's name is wrong or the optimizer is not supported, + we raise error. + + Returns: + torch.optim: torch optimizer. + """ + if isinstance(optimizer, str): + if optimizer.lower() == "adam": + optimizer = "Adam" + try: + optimizer = getattr(torch.optim, optimizer)(params, lr=lr) + except Exception: + raise NotImplementedError("optimizer={} is not supported.".format(optimizer)) + return optimizer + + +def count_model_params(model): + """Count the number of params of the model. + + Args: + model (torch.nn.Moduel): a torch model. + + Returns: + int: total num of the parameters. + """ + return sum(p.numel() for p in model.parameters()) diff --git a/example_compute_metrics.sh b/example_compute_metrics.sh new file mode 100644 index 0000000000000000000000000000000000000000..672b162eb45adf91ee4d537850c761ea875f840f --- /dev/null +++ b/example_compute_metrics.sh @@ -0,0 +1,40 @@ +#!/bin/bash +# 计算级联指标的示例脚本 + +# 设置路径(请根据实际情况修改) +INPUT_CASCADE="information_cascade.json" +INPUT_ORIGINAL="information_cascade_original_posts.json" +OUTPUT="output_with_metrics.json" + +# 基本用法(使用默认模型和简化方法) +echo "=== 方法1: 基本用法(使用默认模型)===" +python compute_cascade_metrics.py \ + --input_cascade "$INPUT_CASCADE" \ + --output "$OUTPUT" \ + --batch_size 32 + +# 使用GPU加速 +echo "=== 方法2: 使用GPU加速 ===" +python compute_cascade_metrics.py \ + --input_cascade "$INPUT_CASCADE" \ + --output "${OUTPUT%.json}_gpu.json" \ + --device cuda \ + --batch_size 64 + +# 使用自定义BERT模型 +echo "=== 方法3: 使用自定义BERT模型 ===" +python compute_cascade_metrics.py \ + --input_cascade "$INPUT_CASCADE" \ + --output "${OUTPUT%.json}_custom.json" \ + --bert_model bert-base-chinese \ + --batch_size 32 + +# 测试模式(只处理前10个级联) +echo "=== 方法4: 测试模式(处理前10个级联)===" +python compute_cascade_metrics.py \ + --input_cascade "$INPUT_CASCADE" \ + --output "${OUTPUT%.json}_test.json" \ + --max_cascades 10 \ + --batch_size 16 + +echo "完成!" diff --git a/examples/configs/experiment_config.yaml b/examples/configs/experiment_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..27ece55ae84c56fb1e1983e24c49471f8c61bbc4 --- /dev/null +++ b/examples/configs/experiment_config.yaml @@ -0,0 +1,625 @@ +pipeline_config_id: runner_config + +data: + taxi: + data_format: json + train_dir: easytpp/taxi # ./data/taxi/train.json + valid_dir: easytpp/taxi # ./data/taxi/dev.json + test_dir: easytpp/taxi # ./data/taxi/test.json + data_specs: + num_event_types: 10 + pad_token_id: 10 + padding_side: right +# padding_strategy: max_length +# truncation_strategy: longest_first # or Truncate to a maximum length specified with the argument `max_length` +# max_len: 20 + conttime: + data_format: pkl + train_dir: ../data/conttime/train.pkl + valid_dir: ../data/conttime/dev.pkl + test_dir: ../data/conttime/test.pkl + data_specs: + num_event_types: 5 + pad_token_id: 5 + padding_side: right + truncation_side: right +# padding_strategy: max_length # for ode tpp we have to set this to max_length +# max_len: 20 + hawkes_1d: + data_format: pkl + train_dir: ../data/hawkes/train.pkl + valid_dir: ../data/hawkes/dev.pkl + test_dir: ../data/hawkes/test.pkl + data_specs: + num_event_types: 1 + pad_token_id: 1 + padding_side: right + truncation_side: right + retweet: + data_format: pkl + train_dir: ../data/retweet/train.pkl + valid_dir: ../data/retweet/dev.pkl + test_dir: ../data/retweet/test.pkl + data_specs: + num_event_types: 3 + pad_token_id: 3 + padding_side: right + truncation_side: right + + + +RMTPP_train: + base_config: + stage: train + backend: torch + dataset_id: taxi + runner_id: std_tpp + model_id: RMTPP # model name + base_dir: './checkpoints/' + trainer_config: + batch_size: 256 + max_epoch: 20 + shuffle: False + optimizer: adam + learning_rate: 1.e-3 + valid_freq: 1 + use_tfb: False + metrics: [ 'acc', 'rmse' ] + seed: 2019 + gpu: -1 + model_config: + hidden_size: 32 + time_emb_size: 16 + num_layers: 2 + num_heads: 2 + mc_num_sample_per_step: 20 + sharing_param_layer: False + loss_integral_num_sample_per_step: 20 + dropout: 0.0 + use_ln: False + thinning: + num_seq: 10 + num_sample: 1 + num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm + look_ahead_time: 10 + patience_counter: 5 # the maximum iteration used in adaptive thinning + over_sample_rate: 5 + num_samples_boundary: 5 + dtime_max: 5 + + +RMTPP_gen: + base_config: + stage: gen + backend: torch + dataset_id: retweet + runner_id: std_tpp + base_dir: './checkpoints/' + model_id: RMTPP + model_config: + hidden_size: 32 + time_emb_size: 16 + mc_num_sample_per_step: 20 + sharing_param_layer: False + loss_integral_num_sample_per_step: 20 + dropout: 0.0 + use_ln: False + seed: 2019 + gpu: 0 + pretrained_model_dir: ./checkpoints/2555_4348724608_230603-155841/models/saved_model + thinning: + num_seq: 10 + num_sample: 1 + num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm + look_ahead_time: 10 + patience_counter: 5 # the maximum iteration used in adaptive thinning + over_sample_rate: 5 + num_samples_boundary: 5 + dtime_max: 5 + num_step_gen: 10 + +NHP_eval: + base_config: + stage: eval + backend: torch + dataset_id: taxi + runner_id: std_tpp + base_dir: './checkpoints/' + model_id: NHP + trainer_config: + batch_size: 256 + max_epoch: 1 + model_config: + hidden_size: 64 + use_ln: False + seed: 2019 + gpu: 0 + pretrained_model_dir: ./checkpoints/26507_4380788096_231111-101848/models/saved_model + thinning: + num_seq: 10 + num_sample: 1 + num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm + look_ahead_time: 10 + patience_counter: 5 # the maximum iteration used in adaptive thinning + over_sample_rate: 5 + num_samples_boundary: 5 + dtime_max: 5 + +NHP_gen: + base_config: + stage: eval + backend: torch + dataset_id: taxi + runner_id: std_tpp + model_id: NHP # model name + base_dir: './checkpoints/' + trainer_config: + batch_size: 256 + max_epoch: 20 + shuffle: False + optimizer: adam + learning_rate: 1.e-3 + valid_freq: 1 + use_tfb: False + metrics: [ 'acc', 'rmse' ] + seed: 2019 + gpu: -1 + model_config: + hidden_size: 64 + loss_integral_num_sample_per_step: 20 + pretrained_model_dir: ./checkpoints/75518_4377527680_230530-132355/models/saved_model + thinning: + num_seq: 10 + num_sample: 1 + num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm + look_ahead_time: 10 + patience_counter: 5 # the maximum iteration used in adaptive thinning + over_sample_rate: 5 + num_samples_boundary: 5 + dtime_max: 5 + num_step_gen: 1 + +FullyNN_train: + base_config: + stage: train + backend: torch + dataset_id: taxi + runner_id: std_tpp + model_id: FullyNN # model name + base_dir: './checkpoints/' + trainer_config: + batch_size: 256 + max_epoch: 200 + shuffle: False + optimizer: adam + learning_rate: 1.e-3 + valid_freq: 1 + use_tfb: False + metrics: [ 'acc', 'rmse' ] + seed: 2019 + gpu: 0 + model_config: + rnn_type: LSTM + hidden_size: 32 + time_emb_size: 4 + num_layers: 2 + num_heads: 2 + mc_num_sample_per_step: 20 + sharing_param_layer: False + loss_integral_num_sample_per_step: 20 + dropout: 0.0 + use_ln: False + model_specs: + num_mlp_layers: 3 +# thinning: +# num_seq: 10 +# num_sample: 1 +# num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm +# look_ahead_time: 10 +# patience_counter: 5 # the maximum iteration used in adaptive thinning +# over_sample_rate: 5 +# num_samples_boundary: 5 +# dtime_max: 5 +# num_step_gen: 1 + + + +IntensityFree_train: + base_config: + stage: train + backend: torch + dataset_id: taxi + runner_id: std_tpp + model_id: IntensityFree # model name + base_dir: './checkpoints/' + trainer_config: + batch_size: 256 + max_epoch: 200 + shuffle: False + optimizer: adam + learning_rate: 1.e-3 + valid_freq: 1 + use_tfb: False + metrics: [ 'acc', 'rmse' ] + seed: 2019 + gpu: 0 + model_config: + hidden_size: 32 + time_emb_size: 16 + num_layers: 2 + num_heads: 2 + mc_num_sample_per_step: 20 + sharing_param_layer: False + loss_integral_num_sample_per_step: 20 + dropout: 0.0 + use_ln: False + model_specs: + num_mix_components: 3 + thinning: + num_seq: 10 + num_sample: 1 + num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm + look_ahead_time: 10 + patience_counter: 5 # the maximum iteration used in adaptive thinning + over_sample_rate: 5 + num_samples_boundary: 5 + dtime_max: 5 + num_step_gen: 1 + + +ODETPP_train: + base_config: + stage: train + backend: torch + dataset_id: taxi + runner_id: std_tpp + model_id: ODETPP # model name + base_dir: './checkpoints/' + trainer_config: + batch_size: 32 + max_epoch: 200 + shuffle: False + optimizer: adam + learning_rate: 1.e-1 + valid_freq: 1 + use_tfb: False + metrics: [ 'acc', 'rmse' ] + seed: 2019 + gpu: -1 + model_config: + hidden_size: 4 + time_emb_size: 4 + num_layers: 1 + sharing_param_layer: False + loss_integral_num_sample_per_step: 20 + dropout: 0.0 + use_ln: False + model_specs: + ode_num_sample_per_step: 2 + time_factor: 100 + thinning: + num_seq: 10 + num_sample: 1 + num_exp: 50 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm + look_ahead_time: 10 + patience_counter: 5 # the maximum iteration used in adaptive thinning + over_sample_rate: 5 + num_samples_boundary: 5 + dtime_max: 5 + num_step_gen: 1 + +ODETPP_gen: + base_config: + stage: gen + backend: torch + dataset_id: retweet + runner_id: std_tpp + base_dir: './checkpoints/' + model_id: ODETPP + trainer_config: + batch_size: 256 + max_epoch: 1 + model_config: + hidden_size: 32 + time_emb_size: 16 + num_layers: 1 + sharing_param_layer: False + loss_integral_num_sample_per_step: 20 + dropout: 0.0 + use_ln: False + seed: 2019 + gpu: 0 + pretrained_model_dir: ./checkpoints/3538_4310828416_230603-165911/models/saved_model + model_specs: + ode_num_sample_per_step: 2 + time_factor: 100 + thinning: + num_seq: 10 + num_sample: 1 + num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm + look_ahead_time: 10 + patience_counter: 5 # the maximum iteration used in adaptive thinning + over_sample_rate: 5 + num_samples_boundary: 5 + dtime_max: 5 + num_step_gen: 10 + +NHP_train: + base_config: + stage: train + backend: torch + dataset_id: taxi + runner_id: std_tpp + model_id: NHP # model name + base_dir: './checkpoints/' + trainer_config: + batch_size: 256 + max_epoch: 2 + shuffle: False + optimizer: adam + learning_rate: 1.e-3 + valid_freq: 1 + use_tfb: False + metrics: [ 'acc', 'rmse' ] + seed: 2019 + gpu: -1 + model_config: + hidden_size: 64 + loss_integral_num_sample_per_step: 20 +# pretrained_model_dir: ./checkpoints/75518_4377527680_230530-132355/models/saved_model + thinning: + num_seq: 10 + num_sample: 1 + num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm + look_ahead_time: 10 + patience_counter: 5 # the maximum iteration used in adaptive thinning + over_sample_rate: 5 + num_samples_boundary: 5 + dtime_max: 5 + num_step_gen: 1 + + + +SAHP_train: + base_config: + stage: train + backend: torch + dataset_id: taxi + runner_id: std_tpp + model_id: SAHP # model name + base_dir: './checkpoints/' + trainer_config: + batch_size: 256 + max_epoch: 20 + shuffle: False + optimizer: adam + learning_rate: 1.e-3 + valid_freq: 1 + use_tfb: False + metrics: [ 'acc', 'rmse' ] + seed: 2019 + gpu: 0 + model_config: + hidden_size: 32 + time_emb_size: 16 + num_layers: 2 + num_heads: 2 + loss_integral_num_sample_per_step: 20 + use_ln: False + thinning: + num_seq: 10 + num_sample: 1 + num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm + look_ahead_time: 10 + patience_counter: 5 # the maximum iteration used in adaptive thinning + over_sample_rate: 5 + num_samples_boundary: 5 + dtime_max: 5 + num_step_gen: 1 + + + +SAHP_gen: + base_config: + stage: gen + backend: torch + dataset_id: retweet + runner_id: std_tpp + model_id: SAHP # model name + base_dir: './checkpoints/' + trainer_config: + batch_size: 256 + max_epoch: 1 + model_config: + hidden_size: 16 + time_emb_size: 4 + num_layers: 2 + num_heads: 2 + loss_integral_num_sample_per_step: 20 + use_ln: False + thinning: + num_seq: 10 + num_sample: 1 + num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm + look_ahead_time: 10 + patience_counter: 5 # the maximum iteration used in adaptive thinning + over_sample_rate: 5 + num_samples_boundary: 5 + dtime_max: 5 + num_step_gen: 10 + +THP_train: + base_config: + stage: train + backend: torch + dataset_id: taxi + runner_id: std_tpp + model_id: THP # model name + base_dir: './checkpoints/' + trainer_config: + batch_size: 256 + max_epoch: 30 + shuffle: False + optimizer: adam + learning_rate: 1.e-3 + valid_freq: 1 + use_tfb: False + metrics: [ 'acc', 'rmse' ] + seed: 2019 + gpu: -1 + model_config: + hidden_size: 32 + time_emb_size: 16 + num_layers: 2 + num_heads: 2 + mc_num_sample_per_step: 20 + loss_integral_num_sample_per_step: 20 + use_ln: False + thinning: + num_seq: 10 + num_sample: 1 + num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm + look_ahead_time: 10 + patience_counter: 5 # the maximum iteration used in adaptive thinning + over_sample_rate: 5 + num_samples_boundary: 5 + dtime_max: 5 + num_step_gen: 1 + + +THP_gen: + base_config: + stage: gen + backend: torch + dataset_id: retweet + runner_id: std_tpp + model_id: THP # model name + base_dir: './checkpoints/' + trainer_config: + batch_size: 256 + max_epoch: 1 + model_config: + hidden_size: 32 + time_emb_size: 16 + num_layers: 2 + num_heads: 2 + mc_num_sample_per_step: 20 + loss_integral_num_sample_per_step: 20 + use_ln: False +# pretrained_model_dir: ./checkpoints/2694_4384867712_230603-160544/models/saved_model + thinning: + num_seq: 10 + num_sample: 1 + num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm + look_ahead_time: 10 + patience_counter: 5 # the maximum iteration used in adaptive thinning + over_sample_rate: 5 + num_samples_boundary: 5 + dtime_max: 5 + num_step_gen: 10 + +AttNHP_train: + base_config: + stage: train + backend: torch + dataset_id: taxi + runner_id: std_tpp + model_id: AttNHP # model name + base_dir: './checkpoints/' + trainer_config: + batch_size: 256 + max_epoch: 200 + shuffle: False + optimizer: adam + learning_rate: 1.e-3 + valid_freq: 1 + use_tfb: False + metrics: [ 'acc', 'rmse' ] + seed: 2019 + gpu: -1 + model_config: + hidden_size: 16 + time_emb_size: 4 + num_layers: 2 + num_heads: 2 + loss_integral_num_sample_per_step: 10 + use_ln: False + thinning: + num_seq: 2 + num_sample: 1 + num_exp: 50 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm + look_ahead_time: 10 + patience_counter: 5 # the maximum iteration used in adaptive thinning + over_sample_rate: 5 + num_samples_boundary: 5 + dtime_max: 5 + num_step_gen: 1 + + +AttNHP_gen: + base_config: + stage: gen + backend: torch + dataset_id: retweet + runner_id: std_tpp + model_id: AttNHP # model name + base_dir: './checkpoints/' + trainer_config: + batch_size: 256 + max_epoch: 1 + model_config: + hidden_size: 16 + time_emb_size: 4 + num_layers: 2 + num_heads: 2 + mc_num_sample_per_step: 20 + loss_integral_num_sample_per_step: 20 + use_ln: False +# pretrained_model_dir: ./checkpoints/6934_4375315840_230603-222826/models/saved_model + thinning: + num_seq: 10 + num_sample: 1 + num_exp: 50 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm + look_ahead_time: 10 + patience_counter: 5 # the maximum iteration used in adaptive thinning + over_sample_rate: 5 + num_samples_boundary: 5 + dtime_max: 5 + num_step_gen: 10 + + +# Example configuration for training State-Space Point Process (S2P2) model. +S2P2_train: + base_config: + stage: train + backend: torch + dataset_id: taxi + runner_id: std_tpp + model_id: S2P2 + base_dir: './checkpoints/' + trainer_config: + batch_size: 256 + max_epoch: 300 + shuffle: True + optimizer: adam + learning_rate: 1.e-2 + valid_freq: 1 + use_tfb: False + metrics: [ 'acc', 'rmse' ] + seed: 2019 + gpu: -1 # ID of GPU to use. Set to -1 to use CPU instead. `mps` backend could lead to incorrect results, please use CPU or CUDA. + model_config: + hidden_size: 128 # Number of dimensions for u_t and y_t, labeled as H in the paper. + loss_integral_num_sample_per_step: 10 # How many time points to use to estimate the integrated intensity between each pair of subsequent events for the log-likelihood. + use_mc_samples: True # Use Monte-Carlo sampling for the integral estimation. If False, uses a quadrature with a grid of evenly spaced points. + num_layers: 4 # Number of LLH layers. + model_specs: + P: 16 # Number of dimensions for the hidden state x_t, labeled as P in the paper. + dropout_rate: 0.1 # Dropout rate, used immediately after the activation function between layers but before the normalization. Formally, we set u^{(l+1)}_t = LayerNorm(dropout(\sigma(y^{(l)}_t)) + u^{(l)}_t). + act_func: gelu # gelu | half_glu | full_glu # Activation function to use between layers. + for_loop: True # If enabled, uses for-loop for computing the recurrence in the LLH layers. If disabled, uses a parallel scan. + pre_norm: False # Should be set to False. If True, uses a LayerNorm on the inputs to a LLH layer. + post_norm: True # Should be set to True. If True, uses a LayerNorm on the outputs of a LLH layer (after transforming and adding the residual). + int_forward_variant: False # Should be set to False. If True, uses u_{t_i} as the ZOH constant for u_t with t \in (t_i, t_{i+1}]. + int_backward_variant: True # Should be set to True. If True, uses u_{t_{i+1}-} as the ZOH constant for u_t with t \in (t_i, t_{i+1}]. + relative_time: True # If True, predicts the scaling factor to be applied to the dynamics between each pair of subsequent events. See Sec. 3.3 of the paper. diff --git a/examples/configs/hpo_config.yaml b/examples/configs/hpo_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9302adebe347677509a0fe493eb44de5e9e46ced --- /dev/null +++ b/examples/configs/hpo_config.yaml @@ -0,0 +1,55 @@ +pipeline_config_id: hpo_runner_config + +data: + taxi: + data_format: pkl + train_dir: ./data/taxi/train.pkl + valid_dir: ./data/taxi/dev.pkl + test_dir: ./data/taxi/test.pkl + data_specs: + num_event_types: 10 + pad_token_id: 10 + padding_side: right + truncation_side: right + +hpo: + storage_uri: sqlite://hpo_test.db + is_continuous: False + framework_id: optuna # the framework of hpo + n_trials: 10 + + +NHP_train: + base_config: + stage: train + backend: torch + dataset_id: taxi + runner_id: std_tpp + model_id: NHP # model name + base_dir: './checkpoints/' + trainer_config: + batch_size: 256 + max_epoch: 200 + shuffle: False + optimizer: adam + learning_rate: 1.e-3 + valid_freq: 1 + use_tfb: False + metrics: [ 'acc', 'rmse' ] + seed: 2019 + gpu: -1 + model_config: + hidden_size: 64 + loss_integral_num_sample_per_step: 20 +# pretrained_model_dir: ./checkpoints/75518_4377527680_230530-132355/models/saved_model + thinning: + num_seq: 10 + num_sample: 1 + num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm + look_ahead_time: 10 + patience_counter: 5 # the maximum iteration used in adaptive thinning + over_sample_rate: 5 + num_samples_boundary: 5 + dtime_max: 5 + num_step_gen: 1 + diff --git a/examples/data/.gitkeep b/examples/data/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/data_inspection/config.yaml b/examples/data_inspection/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..91effc9aac86f1b61e0ac574aad387d4ed154efc --- /dev/null +++ b/examples/data_inspection/config.yaml @@ -0,0 +1,10 @@ +pipeline_config_id: data_config + +data_format: json +train_dir: easytpp/taxi # ./data/taxi/train.json +valid_dir: easytpp/taxi # ./data/taxi/dev.json +test_dir: easytpp/taxi # ./data/taxi/test.json +data_specs: + num_event_types: 10 + pad_token_id: 10 + padding_side: right \ No newline at end of file diff --git a/examples/data_inspection/data_inspection.py b/examples/data_inspection/data_inspection.py new file mode 100644 index 0000000000000000000000000000000000000000..8955c481bdc54972e8511c3d9809bbe57e296807 --- /dev/null +++ b/examples/data_inspection/data_inspection.py @@ -0,0 +1,20 @@ +import os +import sys +# Get the directory of the current file +current_file_path = os.path.abspath(__file__) +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))) + +from easy_tpp.config_factory import Config +from easy_tpp.preprocess.data_loader import TPPDataLoader + + +def main(): + config = Config.build_from_yaml_file('./config.yaml') + tpp_loader = TPPDataLoader(config) + stats = tpp_loader.get_statistics(split='train') + print(stats) + tpp_loader.plot_event_type_distribution() + tpp_loader.plot_event_delta_times_distribution() + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/examples/data_loader.py b/examples/data_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..4f003563b0d3e80c6a357aa8a9618736ba59092e --- /dev/null +++ b/examples/data_loader.py @@ -0,0 +1,52 @@ +import random + +from easy_tpp.config_factory import DataSpecConfig +from easy_tpp.preprocess import EventTokenizer +from easy_tpp.preprocess.dataset import TPPDataset, get_data_loader + + +def make_raw_data(): + data = [ + [{"time_since_last_event": 0, "time_since_start": 0, "type_event": 0}], + [{"time_since_last_event": 0, "time_since_start": 0, "type_event": 1}], + [{"time_since_last_event": 0, "time_since_start": 0, "type_event": 1}], + ] + for i, j in enumerate([2, 5, 3]): + start_time = 0 + for k in range(j): + delta_t = random.random() + start_time += delta_t + data[i].append({"time_since_last_event": delta_t, + "time_since_start": start_time, + "type_event": random.randint(0, 10) + }) + + return data + + +def main(): + source_data = make_raw_data() + + time_seqs = [[x["time_since_start"] for x in seq] for seq in source_data] + type_seqs = [[x["type_event"] for x in seq] for seq in source_data] + time_delta_seqs = [[x["time_since_last_event"] for x in seq] for seq in source_data] + + input_data = {'time_seqs': time_seqs, + 'type_seqs': type_seqs, + 'time_delta_seqs': time_delta_seqs} + + config = DataSpecConfig.parse_from_yaml_config({'num_event_types': 11, 'batch_size': 1, + 'pad_token_id': 11}) + + dataset = TPPDataset(input_data) + + tokenizer = EventTokenizer(config) + + loader = get_data_loader(dataset, 'torch', tokenizer) + + for batch in loader: + print(batch) + + +if __name__ == '__main__': + main() diff --git a/examples/event_tokenizer.py b/examples/event_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..363c84be79f491ecffbfbd8fa95e815f67273d37 --- /dev/null +++ b/examples/event_tokenizer.py @@ -0,0 +1,46 @@ +import random + +from easy_tpp.preprocess.event_tokenizer import EventTokenizer +from easy_tpp.config_factory import DataSpecConfig + +def make_raw_data(): + data = [ + [{"time_since_last_event": 0, "time_since_start": 0, "type_event": 0}], + [{"time_since_last_event": 0, "time_since_start": 0, "type_event": 1}], + [{"time_since_last_event": 0, "time_since_start": 0, "type_event": 1}], + ] + for i, j in enumerate([2, 5, 3]): + start_time = 0 + for k in range(j): + delta_t = random.random() + start_time += delta_t + data[i].append({"time_since_last_event": delta_t, + "time_since_start": start_time, + "type_event": random.randint(0, 10) + }) + + return data + + +def main(): + source_data = make_raw_data() + + time_seqs = [[x["time_since_start"] for x in seq] for seq in source_data] + type_seqs = [[x["type_event"] for x in seq] for seq in source_data] + time_delta_seqs = [[x["time_since_last_event"] for x in seq] for seq in source_data] + + input_data = {'time_seqs': time_seqs, + 'type_seqs': type_seqs, + 'time_delta_seqs': time_delta_seqs} + + config = DataSpecConfig.parse_from_yaml_config({'num_event_types': 11, 'pad_token_id': 11}) + + tokenizer = EventTokenizer(config) + + output = tokenizer.pad(input_data, return_tensors='pt') + + print(output) + + +if __name__ == '__main__': + main() diff --git a/examples/gen_synthetic_data.py b/examples/gen_synthetic_data.py new file mode 100644 index 0000000000000000000000000000000000000000..5cdb15b36f5171c4b477c318f49215fcd8c0d21d --- /dev/null +++ b/examples/gen_synthetic_data.py @@ -0,0 +1,10 @@ +from easy_tpp.utils.gen_utils import generate_and_save_json + +if __name__ == "__main__": + generate_and_save_json(n_nodes=3, + end_time=100, + baseline=1, + adjacency=0.5, + decay=0.1, + max_seq_len=40, + target_file='synthetic_data.json') diff --git a/examples/hf_data_loader.py b/examples/hf_data_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..117b81de56a7c246a5888a5c327ad473ca316e20 --- /dev/null +++ b/examples/hf_data_loader.py @@ -0,0 +1,62 @@ +from datasets import load_dataset + +def load_data_from_hf(hf_dir=None, local_dir=None): + if hf_dir: + ds = load_dataset(hf_dir) + else: + ds = load_dataset('json', data_files=local_dir) + + print("Dataset structure:") + print(ds) + + # Print available features for validation split + print("\nValidation split features:") + print(ds['validation'].features) + + # Try to access metadata fields if they exist, otherwise show available data + try: + print('\ndim process: ' + str(ds['validation'].data['dim_process'][0].as_py())) + except (KeyError, IndexError): + print("dim_process field not found in dataset") + + try: + print('num seqs: ' + str(ds['validation'].data['num_seqs'][0].as_py())) + except (KeyError, IndexError): + print("num_seqs field not found in dataset") + + try: + print('avg seq len: ' + str(ds['validation'].data['avg_seq_len'][0].as_py())) + except (KeyError, IndexError): + print("avg_seq_len field not found in dataset") + + try: + print('min seq len: ' + str(ds['validation'].data['min_seq_len'][0].as_py())) + except (KeyError, IndexError): + print("min_seq_len field not found in dataset") + + try: + print('max seq len: ' + str(ds['validation'].data['max_seq_len'][0].as_py())) + except (KeyError, IndexError): + print("max_seq_len field not found in dataset") + + # Show actual data structure + print("\nFirst few examples from validation split:") + for i, example in enumerate(ds['validation']): + if i < 3: # Show first 3 examples + print(f"Example {i}:") + for key, value in example.items(): + if isinstance(value, list) and len(value) > 10: + print(f" {key}: {value[:5]}... (length: {len(value)})") + else: + print(f" {key}: {value}") + else: + break + + return ds + + +if __name__ == '__main__': + # in case one fails to load from hf directly + # one can load the json data file locally + # load_data_from_hf(hf_dir=None, local_dir={'validation':'dev.json'}) + load_data_from_hf(hf_dir='easytpp/taxi') \ No newline at end of file diff --git a/examples/script_data_processing/earthquake.py b/examples/script_data_processing/earthquake.py new file mode 100644 index 0000000000000000000000000000000000000000..0efedeaf0cd5f21ef4c6dfe3661ed811014883f4 --- /dev/null +++ b/examples/script_data_processing/earthquake.py @@ -0,0 +1,110 @@ +import pickle +import warnings +from datetime import datetime + +import numpy as np +import pandas as pd + +warnings.filterwarnings('ignore') + + +# data source: https://earthquake.usgs.gov/earthquakes/search/ + +def event_type_map(mag): + if mag < 2.75: + return 0 + elif mag < 3.0: + return 1 + elif mag < 3.5: + return 2 + elif mag < 4.0: + return 3 + elif mag < 4.5: + return 4 + elif mag < 5.0: + return 5 + else: + return 6 + + +def clean_csv(source_dir): + df = pd.read_csv(source_dir, header=0) + + df.drop_duplicates(inplace=True) + + df.sort_values(by=['time'], inplace=True) + print(len(df)) + df = df[['time', 'mag']] + df['event_type'] = df['mag'].apply(lambda x: event_type_map(x)) + + df.to_csv('earthquake.csv', index=False, header=True) + return + + +def make_seq(df): + seq = [] + df['time_diff'] = df['event_time'].diff() + df.index = np.arange(len(df)) + for index, row in df.iterrows(): + if index == 0: + event_dict = {"time_since_last_event": 0.0, + "time_since_start": 0.0, + "type_event": row['event_type'] + } + start_event_time = row['event_time'] + else: + event_dict = {"time_since_last_event": row['time_diff'], + "time_since_start": row['event_time'] - start_event_time, + "type_event": row['event_type'] + } + seq.append(event_dict) + + return seq + + +def make_pkl(target_dir, dim_process, split, seqs): + with open(target_dir, "wb") as f_out: + pickle.dump( + { + "dim_process": dim_process, + split: seqs + }, f_out + ) + return + + +def make_dataset(source_dir): + df = pd.read_csv(source_dir, header=0) + df['time'] = pd.to_datetime(df['time']) + + norm_const = 10000 + df['event_time'] = df['time'].apply(lambda x: datetime.timestamp(x)) / norm_const + seq_len = np.random.randint(15, 19, 4300) + print(np.sum(seq_len)) + + seq_start_idx = [0] + list(np.cumsum(seq_len)[:-1] - 1) + seq_end_idx = np.cumsum(seq_len) - 1 + + total_seq = [make_seq(df.iloc[start_idx:end_idx, :]) for (start_idx, end_idx) in + zip(seq_start_idx, seq_end_idx)] + + print(len(total_seq)) + make_pkl('train.pkl', 7, 'train', total_seq[:3000]) + print(np.sum(seq_len[:3000])) + make_pkl('dev.pkl', 7, 'dev', total_seq[3000:3400]) + print(np.sum(seq_len[3000:3400])) + make_pkl('test.pkl', 7, 'test', total_seq[3400:]) + print(np.sum(seq_len[3400:])) + + # 70794 + # 4300 + # 49364 + # 6612 + # 14818 + + return + + +if __name__ == '__main__': + # clean_csv() + make_dataset('earthquake.csv') diff --git a/examples/script_data_processing/make_hf_dataset.py b/examples/script_data_processing/make_hf_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e7369cdc09c4a169a68a261171dfec4b308b3009 --- /dev/null +++ b/examples/script_data_processing/make_hf_dataset.py @@ -0,0 +1,66 @@ +import json + +import numpy as np + +from easy_tpp.utils import load_pickle + + +def make_json_serializable(input_dict): + for k, v in input_dict.items(): + if isinstance(v, np.float32): + input_dict[k] = float(v) + elif isinstance(v, np.int32): + input_dict[k] = int(v) + + return input_dict + + +def make_hf_dataset(source_dir, target_dir, split='test'): + data_pkl = load_pickle(source_dir) + + dim_process = int(data_pkl['dim_process']) + + data_json = [] + for idx, seq in enumerate(data_pkl[split]): + seq_len = len(seq) + time_since_start, time_since_last_event, type_event = [], [], [] + for idx_event, event in enumerate(data_pkl[split][idx]): + # if idx_event == 0 and event['time_since_start'] > 0: + # start_timestamp = event['time_since_start'] + # else: + # start_timestamp = 0 + if idx_event == 0 and event['time_since_last_event'] > 0: + event['time_since_last_event'] = 0 + + # event['time_since_start'] -= start_timestamp + + event = make_json_serializable(event) + time_since_start.append(time_since_start) + time_since_last_event.append(event['time_since_last_event']) + type_event.append(event['type_event']) + + # re-calculate the time_since start + from itertools import accumulate + time_since_start = list(accumulate(time_since_last_event)) + + temp_dict = {'dim_process': dim_process, + 'seq_idx': idx, + 'seq_len': seq_len, + 'time_since_start': time_since_start, + 'time_since_last_event': time_since_last_event, + 'type_event': type_event} + data_json.append(temp_dict) + + with open(target_dir, "w") as outfile: + json.dump(data_json, outfile) + + return + + +if __name__ == '__main__': + test_data_dir = ['amazon/test.pkl', 'amazon/test.json'] + dev_data_dir = ['amazon/dev.pkl', 'amazon/dev.json'] + train_data_dir = ['amazon/train.pkl', 'amazon/train.json'] + make_hf_dataset(source_dir=test_data_dir[0], target_dir=test_data_dir[1]) + make_hf_dataset(source_dir=dev_data_dir[0], target_dir=dev_data_dir[1], split='dev') + make_hf_dataset(source_dir=train_data_dir[0], target_dir=train_data_dir[1], split='train') diff --git a/examples/script_data_processing/taobao.py b/examples/script_data_processing/taobao.py new file mode 100644 index 0000000000000000000000000000000000000000..64b44bbbdab8df190437d5e50c515058c1e45b82 --- /dev/null +++ b/examples/script_data_processing/taobao.py @@ -0,0 +1,156 @@ +import pickle +import warnings + +import numpy as np +import pandas as pd + +warnings.filterwarnings('ignore') + + +# source data: https://tianchi.aliyun.com/dataset/dataDetail?dataId=649 + +def check_dominate_event_type(event_type_seq, threshold=0.7): + event_type = np.unique(event_type_seq) + total_len = len(event_type_seq) + type_ratio = [len(event_type_seq[event_type_seq == event_type_i]) / total_len for event_type_i in event_type] + + return True if max(type_ratio) > threshold else False + + +def cate_map(cate_id, cate_event_map_df): + res = cate_event_map_df[cate_event_map_df['cate'] == cate_id]['event_id'].to_list()[0] + return res + + +def read_data_step_3(source_dir, cate_dir, target_dir): + train_df = pd.read_csv(source_dir, header=0) + + cate_event_map_df = pd.read_csv(cate_dir, header=0) + + train_df['event_type'] = train_df['cate_id'].apply(lambda x: cate_map(x, cate_event_map_df)) + print(train_df['event_type'].value_counts(normalize=True)) + unique_user_id = np.unique(train_df['user_id']) + + for idx, user_id in enumerate(unique_user_id): + user_df = train_df[train_df['user_id'] == user_id] + prev_time = user_df.iloc[0, 4] + event_dtime = user_df['event_dtime'].values + event_time = user_df['event_time'].values + event_dtime[0] = 0.0 + + for i in range(1, len(event_time)): + if event_dtime[i] > 50.0: # too large interval + rand_dt = np.random.random() + 0.1 + event_time[i] = prev_time + rand_dt + event_dtime[i] = rand_dt + else: + event_time[i] = event_time[i - 1] + event_dtime[i] + prev_time = event_time[i] + + user_df['event_dtime'] = event_dtime + user_df['event_time'] = event_time + + print(min(event_dtime[1:]), max(event_dtime)) + + assert abs(np.mean(user_df['event_time'].diff().values[1:]) - np.mean(event_dtime[1:])) < 0.0001 + + train_df.to_csv(target_dir) + return + + +def read_data_step_2(source_dir): + train_df = pd.read_csv(source_dir, header=None) + train_df.columns = ['user_id', 'item_id', 'cate_id', 'event_type_raw', 'event_time'] + count = train_df['cate_id'].value_counts(normalize=True) + pd.DataFrame(count).to_csv('taobao_map.csv', header=True) + + return + + +def read_data_step_1(source_dir, target_dir): + train_df = pd.read_csv(source_dir, header=None) + train_df.columns = ['user_id', 'item_id', 'cate_id', 'event_type_raw', 'event_time'] + train_df['event_time'] /= 10000 + unique_user_id = np.unique(train_df['user_id']) + + train_df = train_df[train_df['event_type_raw'] == 'pv'] + + res = pd.DataFrame() + total_seq = 0 + + for idx, user_id in enumerate(unique_user_id): + print(f'user {idx}') + user_df = train_df[train_df['user_id'] == user_id] + + # drop consecutive duplicate on pv + user_df = user_df.loc[user_df['cate_id'].shift() != user_df['cate_id']] + user_df.fillna(0.0, inplace=True) + + user_df.sort_values(by=['event_time'], inplace=True) + user_df['event_dtime'] = user_df['event_time'].diff() + + user_df.fillna(0.0, inplace=True) + + # drop dtime < 0.05 + user_df = user_df[user_df['event_dtime'] > 0.1] + + if len(user_df) < 40: + print('user seq is too short, skip it') + continue + + total_seq += 1 + print(f'{total_seq} users have been recorded') + res = pd.concat([res, user_df]) + if total_seq > 2000: + break + + res.to_csv(target_dir, header=True, index=False) + + return + + +def save_data(source_dir): + df = pd.read_csv(source_dir, header=0) + unique_user_id = np.unique(df['user_id']) + res = [] + print(np.unique(df['event_type'])) + for idx, user_id in enumerate(unique_user_id): + print(f'user {idx}') + user_seq = [] + user_df = df[df['user_id'] == user_id] + length = 0 + for idx_row, row in user_df.iterrows(): + event_dtime = 0 if length == 0 else row['event_dtime'] + user_seq.append({"time_since_last_event": event_dtime, + "time_since_start": row['event_time'], + "type_event": row['event_type'] + }) + length += 1 + + res.append(user_seq) + + with open('../data/taobao/train.pkl', "wb") as f_out: + pickle.dump( + { + "dim_process": 17, + 'train': res[:1300] + }, f_out + ) + + with open('../data/taobao/dev.pkl', "wb") as f_out: + pickle.dump( + { + "dim_process": 17, + 'dev': res[1300:1500] + }, f_out + ) + + with open('../data/taobao/test.pkl', "wb") as f_out: + pickle.dump( + { + "dim_process": 17, + 'test': res[1500:] + }, f_out + ) + + return diff --git a/examples/script_data_processing/taxi.py b/examples/script_data_processing/taxi.py new file mode 100644 index 0000000000000000000000000000000000000000..8eae048d44c9bc13b5f8922fa675b9fa84d1dc76 --- /dev/null +++ b/examples/script_data_processing/taxi.py @@ -0,0 +1,52 @@ +import pickle +import warnings + +warnings.filterwarnings('ignore') + +def read_data_step_1(): + + def read_pkl(file_dir): + res = [] + taxi = pickle.load(open(file_dir, "rb" )) + count = 0 + for seq in taxi['seqs']: + if len(seq) > 34: + count += 1 + res.append(seq) + # print(np.max(seq['time_since_last_event'])) + print(count) + return res + + # from Mei et al 's paper on event imputation + train_res = read_pkl('pilottaxi/big/train.pkl') + dev_res = read_pkl('pilottaxi/big/dev.pkl') + test_res = read_pkl('pilottaxi/big/test1.pkl') + + with open('../data/taxi/train.pkl', "wb") as f_out: + pickle.dump( + { + "dim_process": 10, + 'train': train_res[:1500] + }, f_out + ) + + with open('../data/taxi/dev.pkl', "wb") as f_out: + pickle.dump( + { + "dim_process": 10, + 'dev': dev_res[:200] + }, f_out + ) + + with open('../data/taxi/test.pkl', "wb") as f_out: + pickle.dump( + { + "dim_process": 10, + 'test': test_res[:400] + }, f_out + ) + + return + +if __name__ == '__main__': + read_data_step_1() \ No newline at end of file diff --git a/examples/script_data_processing/volcano.py b/examples/script_data_processing/volcano.py new file mode 100644 index 0000000000000000000000000000000000000000..c0736313df01ed52ac3e1f14cb5b0d111e494332 --- /dev/null +++ b/examples/script_data_processing/volcano.py @@ -0,0 +1,105 @@ +import datetime +import pickle +import warnings + +import numpy as np +import pandas as pd + +warnings.filterwarnings('ignore') + + +def make_datetime(year, month, day): + try: + date = datetime.datetime(int(year), int(month), int(day)) + except ValueError as e: + if e.args[0] == 'day is out of range for month': + date = datetime.datetime(int(year), int(month), int(day)-1) + return datetime.datetime.timestamp(date) + 61851630000 # make sure the timestamp is positive + + +def clean_csv(): + source_dir = 'events.csv' + + df = pd.read_csv(source_dir, header=0) + + df = df[~df['event_date_year'].isna()] + df = df[df['event_date_year'] > 0] + df['event_date_month'].fillna(1, inplace=True) + df['event_date_day'].fillna(1, inplace=True) + df.drop_duplicates(inplace=True) + norm_const = 1000000 + df['event_timestamp'] = df.apply( + lambda x: make_datetime(x['event_date_year'], x['event_date_month'], x['event_date_day']), + axis=1)/norm_const + df.sort_values(by=['event_date_year', 'event_date_month', 'event_date_day'], inplace=True) + df['event_type'] = [0] * len(df) + + df.to_csv('volcano.csv', index=False, header=True) + return + + +def make_seq(df): + seq = [] + df['time_diff'] = df['event_timestamp'].diff() + df.index = np.arange(len(df)) + for index, row in df.iterrows(): + if index == 0: + event_dict = {"time_since_last_event": 0.0, + "time_since_start": 0.0, + "type_event": row['event_type'] + } + start_event_time = row['event_timestamp'] + else: + event_dict = {"time_since_last_event": row['time_diff'], + "time_since_start": row['event_timestamp'] - start_event_time, + "type_event": row['event_type'] + } + seq.append(event_dict) + + return seq + + +def make_pkl(target_dir, dim_process, split, seqs): + with open(target_dir, "wb") as f_out: + pickle.dump( + { + "dim_process": dim_process, + split: seqs + }, f_out + ) + return + + +def make_dataset(source_dir): + df = pd.read_csv(source_dir, header=0) + + vols = np.unique(df['volcano_name']) + total_seq = [] + for vol in vols: + df_ = df[df['volcano_name'] == vol] + df_.sort_values('event_timestamp', inplace=True) + total_seq.append(make_seq(df_)) + + + print(len(total_seq)) + make_pkl('train.pkl', 1, 'train', total_seq[:400]) + count_seq(total_seq[:400]) + make_pkl('dev.pkl', 1, 'dev', total_seq[400:450]) + count_seq(total_seq[400:450]) + make_pkl('test.pkl', 1, 'test', total_seq[450:]) + count_seq(total_seq[450:]) + + + return + + +def count_seq(seqs): + total_len = [len(seq) for seq in seqs] + print(np.mean(total_len)) + print(np.sum(total_len)) + + return + +if __name__ == '__main__': + # clean_csv() + make_dataset('volcano.csv') diff --git a/examples/synthetic_data.json b/examples/synthetic_data.json new file mode 100644 index 0000000000000000000000000000000000000000..eabc132dc87c450b4a7182dffe347263247ad5b8 --- /dev/null +++ b/examples/synthetic_data.json @@ -0,0 +1,987 @@ +[ + { + "dim_process": 3, + "seq_idx": 0, + "seq_len": 40, + "time_since_start": [ + 0.0, + 0.42086957652699325, + 1.385602268596856, + 1.424026474575376, + 1.7178400031639305, + 2.2605841408802134, + 2.301889005026473, + 2.4109621864239537, + 3.41213823098126, + 3.5691357672510318, + 3.636241821169924, + 3.867667527351244, + 4.043225313731539, + 4.918625817456199, + 4.979610747665842, + 5.323877683595276, + 5.737097451671871, + 5.8768209886411595, + 6.077383561983641, + 6.851224675407036, + 6.902069805040336, + 7.122899591785588, + 7.739227419142382, + 8.809729248458874, + 8.84679204431348, + 9.047153721465744, + 9.717751695477991, + 10.582909241737948, + 11.027245898896854, + 11.27394997658292, + 11.295775013544421, + 11.538774116917368, + 11.541904431831128, + 11.819460297042518, + 12.376159230211185, + 12.525532026571218, + 12.526613083044994, + 12.823684319510495, + 13.191426558062, + 13.331223097362722 + ], + "time_since_last_event": [ + 0.10781775661111342, + 0.5286873331381067, + 1.4934200252079695, + 0.03842420597851981, + 0.2938135285885546, + 2.2605841408802134, + 1.8810194284994801, + 0.6931221832600234, + 1.1515540901010466, + 0.15699753626977175, + 0.06710605391889235, + 0.23142570618131986, + 0.17555778638029462, + 0.8754005037246602, + 2.6777217426393687, + 0.3442669359294337, + 0.413219768076595, + 0.9581951711849603, + 0.20056257334248162, + 1.1141272237351654, + 4.491107618616382, + 1.0455160298019468, + 0.837157614102046, + 1.6868296566732868, + 1.1075646251710989, + 0.20036167715226405, + 0.6705979740122476, + 3.7316845663309124, + 2.2175166504379806, + 1.556198281104928, + 0.021825036961502065, + 0.9558648751794205, + 0.5146585329342734, + 0.2775558652113901, + 0.5566989331686667, + 0.9867579096538499, + 1.2308380695005727, + 0.2970712364655004, + 0.36774223855150545, + 0.1397965393007219 + ], + "type_event": [ + 0, + 2, + 1, + 1, + 1, + 0, + 2, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 2, + 2, + 2, + 0, + 0, + 2, + 1, + 0, + 1, + 0, + 1, + 1, + 1, + 2, + 0, + 1, + 1, + 2, + 0, + 0, + 0, + 2, + 1, + 1, + 1, + 1 + ] + }, + { + "dim_process": 3, + "seq_idx": 1, + "seq_len": 40, + "time_since_start": [ + 0.0, + 0.3471017403361518, + 0.7342550839632178, + 0.8471490022195951, + 1.1929590865712605, + 1.3317610899432033, + 1.505477006203563, + 3.115466803274767, + 3.1398290904188038, + 3.179831044316419, + 3.276885387142677, + 3.605949476683932, + 4.8703080011971025, + 4.888492291627866, + 6.0231232143034426, + 6.237864227799694, + 6.304928546372096, + 6.436213379510246, + 6.7270405273626075, + 7.249370401657135, + 7.39294684307702, + 7.5220412848241445, + 8.488473520943874, + 8.598979385467668, + 8.935339636286027, + 9.129493033115097, + 9.130827787293956, + 9.209265243142237, + 9.484788834784812, + 9.68614725822453, + 10.18856016102221, + 10.252664643152702, + 10.51952095459652, + 10.641348929403481, + 10.768016833763447, + 10.939742131904282, + 11.07780990674102, + 11.092078882076088, + 11.21680379710266, + 11.863278613149799 + ], + "time_since_last_event": [ + 0.23486586175346957, + 1.5370314692411586, + 1.774812016508191, + 0.11289391825637729, + 0.3458100843516654, + 1.3317610899432033, + 0.17371591626035965, + 1.9225077167035067, + 1.6343520842152408, + 0.06436424104165184, + 0.13705629672387332, + 3.2588477363477804, + 1.2643585245131703, + 1.611606904485189, + 1.15281521310634, + 1.3493719361718277, + 0.06706431857240247, + 3.256382335193827, + 0.7039173130591649, + 0.9444418552850387, + 0.6659063157144125, + 0.12909444174712448, + 2.052260141433628, + 1.3496089838105334, + 0.4468661153421536, + 0.5305136476474281, + 1.6087865024698118, + 0.07977221002714074, + 0.35396104749085566, + 0.47688201508229255, + 1.2532205247361823, + 0.76787580836789, + 0.26685631144381716, + 0.12182797480696195, + 0.5794566727412374, + 1.2535948736797522, + 0.13806777483673827, + 0.014268975335067324, + 0.12472491502657235, + 1.0952617793863517 + ], + "type_event": [ + 1, + 0, + 2, + 2, + 2, + 1, + 1, + 2, + 1, + 2, + 1, + 0, + 0, + 1, + 0, + 1, + 1, + 2, + 0, + 1, + 0, + 0, + 2, + 1, + 2, + 1, + 0, + 1, + 0, + 1, + 2, + 0, + 0, + 0, + 2, + 1, + 1, + 1, + 1, + 2 + ] + }, + { + "dim_process": 3, + "seq_idx": 2, + "seq_len": 40, + "time_since_start": [ + 0.0, + 0.0835389271618503, + 0.3415719690968011, + 0.5340965525389834, + 0.5932160750694919, + 0.6488294414801139, + 0.7792970960726642, + 1.326404746996122, + 1.3909145800310014, + 1.4680285691802872, + 1.9527988326428947, + 2.3358990201982195, + 2.407883832583252, + 2.413700503802339, + 2.7479362404094587, + 2.8412946057245527, + 3.105645109951716, + 3.661666124317133, + 3.6916858006822615, + 3.9092412231509037, + 3.987031812875376, + 4.161523819553075, + 5.2622920642277045, + 5.348513737580717, + 5.531624013402865, + 5.71939355882191, + 5.870036266064751, + 5.885231905738095, + 6.048098133193701, + 6.225220295525897, + 6.325630879652014, + 6.413536185349383, + 6.443055510538155, + 7.45185576298282, + 7.547509271668591, + 8.638304376144621, + 8.883646722258437, + 10.086409132175074, + 10.205060209985664, + 10.850965858064288 + ], + "time_since_last_event": [ + 1.3644790789042887, + 0.8725631383669601, + 0.48412136425477215, + 0.19252458344218226, + 0.5096771479076416, + 0.05561336641062198, + 0.24520054353368081, + 0.6775753055160081, + 1.3909145800310014, + 0.0771139891492858, + 0.48477026346260743, + 1.5566019241255553, + 0.07198481238503263, + 1.087295756806217, + 0.795137407766564, + 0.093358365315094, + 0.26435050422716344, + 0.5560210143654167, + 1.2779852968799226, + 1.5013573905676516, + 0.325365688558243, + 0.2522825964021713, + 1.570606263545443, + 0.08622167335301256, + 1.3701001938497903, + 0.18776954541904445, + 0.5215225284840344, + 1.8982000928627194, + 0.17806186712894956, + 0.339988389787802, + 0.606237320830104, + 0.18831588982348535, + 0.029519325188772427, + 1.008800252444665, + 0.09565350868577127, + 2.59020624295092, + 1.3361374505898453, + 1.4481047560304532, + 3.8794293303336502, + 0.6459056480786245 + ], + "type_event": [ + 0, + 1, + 2, + 2, + 1, + 1, + 2, + 1, + 0, + 0, + 0, + 2, + 2, + 1, + 0, + 0, + 0, + 0, + 1, + 2, + 0, + 2, + 1, + 1, + 2, + 2, + 1, + 0, + 1, + 0, + 2, + 0, + 0, + 0, + 0, + 1, + 0, + 1, + 2, + 2 + ] + }, + { + "dim_process": 3, + "seq_idx": 3, + "seq_len": 40, + "time_since_start": [ + 0.0, + 0.5057659306827347, + 1.257214062528, + 1.273849921713257, + 1.5588413451190135, + 1.7112787496534878, + 2.1908723937283696, + 2.9082913162192625, + 3.78851236518269, + 3.8146499591013168, + 4.0146586878427115, + 4.362935003689699, + 4.7318297459598355, + 5.011139726020957, + 5.607932457683674, + 5.912184748343371, + 6.017378567476413, + 6.2218398767317, + 6.648478446305454, + 6.914285506354226, + 7.557606650154526, + 7.65347655033348, + 7.888113400672381, + 8.090802954874917, + 8.15087393876626, + 9.677474789259527, + 10.252572426271463, + 10.293284058302213, + 10.441628087776024, + 10.531458512190596, + 10.838059842108557, + 10.935007289252205, + 11.435856722511687, + 11.803055404293474, + 12.468247863031294, + 12.686787410265389, + 13.01122917818713, + 13.368248042027481, + 13.388774344555742, + 13.738590987286365 + ], + "time_since_last_event": [ + 1.0695594700363813, + 2.7780878106357534, + 1.562216806675167, + 0.7680839910305224, + 0.28499142340575645, + 0.45406468712548786, + 0.6320310486093561, + 0.7174189224908929, + 2.0772336155292024, + 0.9063586428820543, + 0.20000872874139475, + 0.5744226385070093, + 4.7318297459598355, + 0.9964810381782456, + 1.2449974539939745, + 1.1803550023835356, + 1.0062388414554562, + 0.30965512838832865, + 0.426638569573754, + 0.26580706004877186, + 1.5402280826781123, + 0.739191043979254, + 0.3305067505178556, + 2.482870497191243, + 0.2627605380938789, + 2.0239982389260476, + 0.5750976370119361, + 2.2024811034272957, + 0.1890556615045611, + 2.3805845734243363, + 0.3066013299179602, + 0.09694744714364845, + 1.1425726642094745, + 0.8680481150412689, + 1.0323911405196071, + 0.8837320059719147, + 0.3244417679217406, + 0.3570188638403522, + 0.020526302528260487, + 1.2703431242550707 + ], + "type_event": [ + 1, + 0, + 2, + 0, + 0, + 2, + 0, + 0, + 2, + 0, + 0, + 2, + 1, + 0, + 2, + 1, + 0, + 1, + 1, + 1, + 0, + 1, + 0, + 2, + 0, + 1, + 1, + 2, + 1, + 0, + 0, + 0, + 2, + 0, + 2, + 0, + 0, + 0, + 0, + 2 + ] + }, + { + "dim_process": 3, + "seq_idx": 4, + "seq_len": 40, + "time_since_start": [ + 0.0, + 0.2629993444184251, + 0.6355563358983503, + 0.726971599714048, + 1.005208401176482, + 1.3336808113572545, + 1.4741932415242829, + 1.9410304274461936, + 2.013981457716163, + 2.9756912668971296, + 3.191094637623358, + 3.2456713841570135, + 3.2606916231387117, + 3.4541972683797226, + 3.478125295741677, + 3.838202479244302, + 3.8455883128337973, + 4.179801630236142, + 4.426297043989173, + 4.784822972518647, + 5.0686735193473, + 5.473549012416349, + 5.846572103038568, + 6.043745098665958, + 7.0637836298308585, + 7.322571308350703, + 7.643979841763326, + 7.914473878288696, + 8.124260506025514, + 8.486921974094997, + 8.707803642043778, + 9.237127055446344, + 9.72379895700012, + 10.055575306451871, + 10.223175655492206, + 10.575180893162141, + 10.578481306265871, + 11.197304735444511, + 11.622040518419958, + 11.88742475354072 + ], + "time_since_last_event": [ + 3.7621831500428797, + 0.7282195949509642, + 1.4505932291615125, + 0.09141526381569776, + 1.005208401176482, + 0.32847241018077256, + 0.1405124301670284, + 1.2140588277321456, + 0.5397882161918801, + 0.9617098091809666, + 2.928095293204933, + 0.2699801172598839, + 0.015020238981698242, + 1.513166840933529, + 0.2870306581183186, + 0.5775108561055902, + 0.007385833589495405, + 0.7256043618564192, + 0.9481717482474963, + 0.9392346596848498, + 0.888871889111158, + 1.047251968427176, + 1.061749130519921, + 0.5701960862496094, + 1.0200385311649, + 2.2538977890034033, + 0.32140853341262243, + 2.067901775250128, + 0.48028066426218885, + 0.36266146806948285, + 0.7933297637550822, + 0.7502050813513463, + 1.015995314956342, + 2.9917916766210126, + 0.1676003490403346, + 1.3380538377157976, + 0.3553056507736656, + 1.473505778444391, + 1.0468596252578166, + 0.2653842351207629 + ], + "type_event": [ + 1, + 2, + 0, + 0, + 1, + 1, + 1, + 0, + 1, + 1, + 2, + 1, + 1, + 0, + 2, + 1, + 1, + 0, + 2, + 1, + 0, + 2, + 1, + 2, + 2, + 0, + 0, + 1, + 0, + 0, + 1, + 0, + 1, + 2, + 2, + 0, + 2, + 1, + 0, + 0 + ] + }, + { + "dim_process": 3, + "seq_idx": 5, + "seq_len": 40, + "time_since_start": [ + 0.0, + 0.04116392793800827, + 0.9276586291640072, + 1.672186093253707, + 1.920373369464734, + 2.7347273307850557, + 3.2151736548244187, + 3.458535559053594, + 3.5179103002773147, + 4.2957361391840365, + 4.9364158471454616, + 5.907947990465509, + 5.943179885493741, + 6.1374826734423, + 6.344631696320242, + 6.676118300230449, + 6.827669659199678, + 8.308219476007444, + 8.409452858802418, + 8.521113264196451, + 9.123452739734951, + 9.253560404627784, + 9.293960661531138, + 9.647926308998386, + 9.729658490669642, + 9.848942066911029, + 10.091932395040082, + 10.609605725083703, + 11.304002487835497, + 11.943041867562805, + 12.62325682842566, + 12.79544148586826, + 13.55235362023604, + 13.5578562398615, + 13.677386969951094, + 13.892990860353827, + 13.956000499772934, + 14.628992444223421, + 14.895006909798433, + 15.210640812727512 + ], + "time_since_last_event": [ + 1.3139629953555811, + 0.7363034941149493, + 0.8864947012259989, + 1.672186093253707, + 1.9253929175454658, + 1.0625412375313488, + 0.48044632403936305, + 1.53816218958886, + 0.3027366454528959, + 0.8372005801304425, + 1.418505546868147, + 1.6122118512814723, + 5.015521256329734, + 1.201066826296838, + 0.43668370585473326, + 0.3314866039102071, + 0.6901869857573786, + 2.365039590513703, + 1.7333345585719684, + 0.11166040539403355, + 0.6023394755384999, + 0.9453409286203396, + 0.17050792179618668, + 0.35396564746724835, + 2.9019888314699642, + 0.5953816622832449, + 0.36227390437043994, + 0.7606636581726747, + 0.6943967627517935, + 1.8511094725227224, + 0.6802149608628554, + 0.1721846574425996, + 3.9044273112376544, + 2.2538537520260036, + 0.8819454840828342, + 0.335134620492326, + 0.40364687953689327, + 0.7360015838695944, + 0.9390064100254989, + 0.5816483685040907 + ], + "type_event": [ + 2, + 1, + 1, + 2, + 0, + 2, + 2, + 0, + 2, + 0, + 2, + 0, + 1, + 2, + 0, + 0, + 2, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 2, + 1, + 2, + 1, + 1, + 2, + 2, + 2, + 0, + 1, + 2, + 1, + 0, + 1, + 0, + 1 + ] + }, + { + "dim_process": 3, + "seq_idx": 6, + "seq_len": 40, + "time_since_start": [ + 0.0, + 0.17821624587439544, + 0.5176117554715205, + 0.5380318618628763, + 1.0700165471748306, + 3.2432032966767395, + 5.38267578063288, + 5.385707041094108, + 5.441018622294536, + 5.534433889783443, + 5.9640115001181755, + 6.313304481408068, + 6.483870351189552, + 6.523443054905513, + 6.760423744952689, + 7.1748194059020705, + 8.127342485401925, + 8.264984193554668, + 8.483245633755828, + 9.801732632176993, + 9.838784402134678, + 10.209904957308012, + 10.281687733211697, + 10.352358405215114, + 10.814676511876883, + 10.852440091195561, + 10.86686670545862, + 10.892637421930658, + 10.97301343179609, + 11.0275101354691, + 12.665242854871025, + 12.972589286887569, + 13.304964719420553, + 13.478552033423767, + 13.62430668167579, + 13.857570434117463, + 13.86212558819284, + 14.122798591907639, + 14.205862660103989, + 14.433571226643664 + ], + "time_since_last_event": [ + 0.49857254491863046, + 0.3611548878639468, + 0.5176117554715205, + 2.2542243466288454, + 0.5524047917033101, + 2.705171434813863, + 4.312659233458049, + 5.207490795219712, + 0.055311581200427895, + 2.291230593106704, + 0.5813357194852955, + 0.8722858591135321, + 0.9494364614061084, + 0.039572703715961666, + 0.2369806900471758, + 0.4143956609493813, + 1.8140380039938577, + 2.300972693436492, + 0.2182614402011609, + 1.6743901467750675, + 0.03705176995768511, + 1.7266593235521839, + 0.4429033310770194, + 0.14245344790710135, + 0.462318106661769, + 3.6776206852934905, + 0.5851789722469221, + 0.07796091005377548, + 0.12057334060052938, + 0.1348727135384422, + 1.6377327194019244, + 0.3073464320165442, + 0.3323754325329844, + 2.505538601627677, + 0.3193419622552369, + 0.3790184006936954, + 2.9952588827342197, + 0.4984919102318486, + 0.3482922259865262, + 0.22770856653967542 + ], + "type_event": [ + 0, + 1, + 0, + 2, + 0, + 2, + 0, + 1, + 1, + 2, + 0, + 1, + 2, + 2, + 2, + 2, + 1, + 0, + 0, + 1, + 1, + 0, + 1, + 0, + 0, + 2, + 1, + 0, + 2, + 0, + 0, + 0, + 0, + 2, + 0, + 2, + 1, + 0, + 2, + 2 + ] + }, + { + "dim_process": 3, + "seq_idx": 7, + "seq_len": 19, + "time_since_start": [ + 0.0, + 0.6745929461191906, + 0.7825554915484929, + 1.1916248138883248, + 2.316605230987207, + 2.407098385912306, + 2.4164424201281776, + 2.5830697884510414, + 2.611880337031451, + 2.9321211492924704, + 3.1138617869040814, + 3.2517989694562175, + 3.5424316488160486, + 3.606652029268858, + 4.434686880159873, + 5.341441928703389, + 5.800945557225404, + 7.031560030404819, + 7.0472581931252165 + ], + "time_since_last_event": [ + 0.7581466412882065, + 0.8612939489565719, + 1.2800291291218997, + 0.40906932233983184, + 1.1249804170988824, + 1.7325054397931154, + 2.4164424201281776, + 0.26646455746383424, + 0.1954379169032734, + 0.3202408122610194, + 0.18174063761161108, + 0.6687291810051761, + 1.1353332629037425, + 0.3548530598126405, + 0.8922552313438246, + 1.7347898994345314, + 0.45950362852201465, + 2.596873150244946, + 0.01569816272039759 + ], + "type_event": [ + 1, + 2, + 0, + 0, + 0, + 2, + 1, + 0, + 1, + 1, + 1, + 0, + 2, + 0, + 2, + 0, + 0, + 2, + 2 + ] + } +] \ No newline at end of file diff --git a/examples/train_experiment/retweet_config.yaml b/examples/train_experiment/retweet_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ac89912d546fb84d7056e200300f1b065bdbd199 --- /dev/null +++ b/examples/train_experiment/retweet_config.yaml @@ -0,0 +1,255 @@ +pipeline_config_id: runner_config + +data: + retweet: + data_format: json + train_dir: easytpp/retweet + valid_dir: easytpp/retweet + test_dir: easytpp/retweet + data_specs: + num_event_types: 3 + pad_token_id: 3 + padding_side: right + truncation_side: right + +NHP_train: + base_config: + stage: train + backend: torch + dataset_id: retweet + runner_id: std_tpp + model_id: NHP # model name + base_dir: './checkpoints/' + trainer_config: + batch_size: 256 + max_epoch: 20 + shuffle: False + optimizer: adam + learning_rate: 1.e-3 + valid_freq: 1 + use_tfb: False + metrics: [ 'acc', 'rmse' ] + seed: 2019 + gpu: -1 + model_config: + hidden_size: 64 + loss_integral_num_sample_per_step: 20 + thinning: + num_seq: 10 + num_sample: 1 + num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm + look_ahead_time: 10 + patience_counter: 5 # the maximum iteration used in adaptive thinning + over_sample_rate: 5 + num_samples_boundary: 5 + dtime_max: 5 + num_step_gen: 1 + + + +SAHP_train: + base_config: + stage: train + backend: torch + dataset_id: taxi + runner_id: std_tpp + model_id: SAHP # model name + base_dir: './checkpoints/' + trainer_config: + batch_size: 256 + max_epoch: 20 + shuffle: False + optimizer: adam + learning_rate: 1.e-3 + valid_freq: 1 + use_tfb: False + metrics: [ 'acc', 'rmse' ] + seed: 2019 + gpu: 0 + model_config: + hidden_size: 32 + time_emb_size: 16 + num_layers: 2 + num_heads: 2 + loss_integral_num_sample_per_step: 20 + use_ln: False + thinning: + num_seq: 10 + num_sample: 1 + num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm + look_ahead_time: 10 + patience_counter: 5 # the maximum iteration used in adaptive thinning + over_sample_rate: 5 + num_samples_boundary: 5 + dtime_max: 5 + num_step_gen: 1 + + + +SAHP_gen: + base_config: + stage: gen + backend: torch + dataset_id: retweet + runner_id: std_tpp + model_id: SAHP # model name + base_dir: './checkpoints/' + trainer_config: + batch_size: 256 + max_epoch: 1 + model_config: + hidden_size: 16 + time_emb_size: 4 + num_layers: 2 + num_heads: 2 + loss_integral_num_sample_per_step: 20 + use_ln: False + thinning: + num_seq: 10 + num_sample: 1 + num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm + look_ahead_time: 10 + patience_counter: 5 # the maximum iteration used in adaptive thinning + over_sample_rate: 5 + num_samples_boundary: 5 + dtime_max: 5 + num_step_gen: 10 + +THP_train: + base_config: + stage: train + backend: torch + dataset_id: taxi + runner_id: std_tpp + model_id: THP # model name + base_dir: './checkpoints/' + trainer_config: + batch_size: 256 + max_epoch: 30 + shuffle: False + optimizer: adam + learning_rate: 1.e-3 + valid_freq: 1 + use_tfb: False + metrics: [ 'acc', 'rmse' ] + seed: 2019 + gpu: -1 + model_config: + hidden_size: 32 + time_emb_size: 16 + num_layers: 2 + num_heads: 2 + mc_num_sample_per_step: 20 + loss_integral_num_sample_per_step: 20 + use_ln: False + thinning: + num_seq: 10 + num_sample: 1 + num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm + look_ahead_time: 10 + patience_counter: 5 # the maximum iteration used in adaptive thinning + over_sample_rate: 5 + num_samples_boundary: 5 + dtime_max: 5 + num_step_gen: 1 + + +THP_gen: + base_config: + stage: gen + backend: torch + dataset_id: retweet + runner_id: std_tpp + model_id: THP # model name + base_dir: './checkpoints/' + trainer_config: + batch_size: 256 + max_epoch: 1 + model_config: + hidden_size: 32 + time_emb_size: 16 + num_layers: 2 + num_heads: 2 + mc_num_sample_per_step: 20 + loss_integral_num_sample_per_step: 20 + use_ln: False +# pretrained_model_dir: ./checkpoints/2694_4384867712_230603-160544/models/saved_model + thinning: + num_seq: 10 + num_sample: 1 + num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm + look_ahead_time: 10 + patience_counter: 5 # the maximum iteration used in adaptive thinning + over_sample_rate: 5 + num_samples_boundary: 5 + dtime_max: 5 + num_step_gen: 10 + +AttNHP_train: + base_config: + stage: train + backend: torch + dataset_id: taxi + runner_id: std_tpp + model_id: AttNHP # model name + base_dir: './checkpoints/' + trainer_config: + batch_size: 256 + max_epoch: 200 + shuffle: False + optimizer: adam + learning_rate: 1.e-3 + valid_freq: 1 + use_tfb: False + metrics: [ 'acc', 'rmse' ] + seed: 2019 + gpu: -1 + model_config: + hidden_size: 16 + time_emb_size: 4 + num_layers: 2 + num_heads: 2 + loss_integral_num_sample_per_step: 10 + use_ln: False + thinning: + num_seq: 2 + num_sample: 1 + num_exp: 50 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm + look_ahead_time: 10 + patience_counter: 5 # the maximum iteration used in adaptive thinning + over_sample_rate: 5 + num_samples_boundary: 5 + dtime_max: 5 + num_step_gen: 1 + + +AttNHP_gen: + base_config: + stage: gen + backend: torch + dataset_id: retweet + runner_id: std_tpp + model_id: AttNHP # model name + base_dir: './checkpoints/' + trainer_config: + batch_size: 256 + max_epoch: 1 + model_config: + hidden_size: 16 + time_emb_size: 4 + num_layers: 2 + num_heads: 2 + mc_num_sample_per_step: 20 + loss_integral_num_sample_per_step: 20 + use_ln: False +# pretrained_model_dir: ./checkpoints/6934_4375315840_230603-222826/models/saved_model + thinning: + num_seq: 10 + num_sample: 1 + num_exp: 50 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm + look_ahead_time: 10 + patience_counter: 5 # the maximum iteration used in adaptive thinning + over_sample_rate: 5 + num_samples_boundary: 5 + dtime_max: 5 + num_step_gen: 10 \ No newline at end of file diff --git a/examples/train_experiment/run_retweet.py b/examples/train_experiment/run_retweet.py new file mode 100644 index 0000000000000000000000000000000000000000..0c04c03db7773e99b0bb9741f00fb78ba5a317bd --- /dev/null +++ b/examples/train_experiment/run_retweet.py @@ -0,0 +1,26 @@ +import argparse + +from easy_tpp.config_factory import Config +from easy_tpp.runner import Runner + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument('--config_dir', type=str, required=False, default='retweet_config.yaml', + help='Dir of configuration yaml to train and evaluate the model.') + + parser.add_argument('--experiment_id', type=str, required=False, default='NHP_train', + help='Experiment id in the config file.') + + args = parser.parse_args() + + config = Config.build_from_yaml_file(args.config_dir, experiment_id=args.experiment_id) + + model_runner = Runner.build_from_config(config) + + model_runner.run() + + +if __name__ == '__main__': + main() diff --git a/examples/train_nhp.py b/examples/train_nhp.py new file mode 100644 index 0000000000000000000000000000000000000000..708c07f50c368cbb0f9daf09eaf1e73fc3cdb0c0 --- /dev/null +++ b/examples/train_nhp.py @@ -0,0 +1,26 @@ +import argparse + +from easy_tpp.config_factory import Config +from easy_tpp.runner import Runner + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument('--config_dir', type=str, required=False, default='configs/experiment_config.yaml', + help='Dir of configuration yaml to train and evaluate the model.') + + parser.add_argument('--experiment_id', type=str, required=False, default='NHP_train', + help='Experiment id in the config file.') + + args = parser.parse_args() + + config = Config.build_from_yaml_file(args.config_dir, experiment_id=args.experiment_id) + + model_runner = Runner.build_from_config(config) + + model_runner.run() + + +if __name__ == '__main__': + main() diff --git a/examples/train_nhp_hpo.py b/examples/train_nhp_hpo.py new file mode 100644 index 0000000000000000000000000000000000000000..124ec08276fc34b168f3254e03264d03071847ca --- /dev/null +++ b/examples/train_nhp_hpo.py @@ -0,0 +1,26 @@ +import argparse + +from easy_tpp.config_factory import Config +from easy_tpp.hpo import HyperTuner + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument('--config_dir', type=str, required=False, default='configs/hpo_config.yaml', + help='Dir of configuration yaml to train and evaluate the model.') + + parser.add_argument('--experiment_id', type=str, required=False, default='NHP_train', + help='Experiment id in the config file.') + + args = parser.parse_args() + + config = Config.build_from_yaml_file(args.config_dir, experiment_id=args.experiment_id) + + tuner = HyperTuner.build_from_config(config) + + tuner.run() + + +if __name__ == '__main__': + main() diff --git a/examples/train_nhp_omegaconf.py b/examples/train_nhp_omegaconf.py new file mode 100644 index 0000000000000000000000000000000000000000..3045befda6bd787c297b91bbc3d687aeab6380a3 --- /dev/null +++ b/examples/train_nhp_omegaconf.py @@ -0,0 +1,29 @@ +from omegaconf import OmegaConf + +from easy_tpp.config_factory import ModelConfig +from easy_tpp.model.torch_model.torch_nhp import NHP + + +def main(): + config_omegaconf = OmegaConf.load('configs/experiment_config.yaml') + + model_config_dict = config_omegaconf.get('NHP_train').get('model_config') + model_config_dict['num_event_types'] = 10 + model_config_dict['num_event_types_pad'] = 11 + model_config_dict['event_pad_index'] = 10 + + model_config = ModelConfig.parse_from_yaml_config(model_config_dict) + + nhp_model = NHP(model_config) + + print(nhp_model.__dict__) + + # config = Config.build_from_yaml_file(args.config_dir, experiment_id=args.experiment_id) + # + # model_runner = Runner.build_from_config(config) + # + # model_runner.run() + + +if __name__ == '__main__': + main() diff --git a/examples/train_nhp_with_features.py b/examples/train_nhp_with_features.py new file mode 100644 index 0000000000000000000000000000000000000000..d5d61e707fe7c2eb14f5f9d1d47a79284a780b73 --- /dev/null +++ b/examples/train_nhp_with_features.py @@ -0,0 +1,411 @@ +import random +from typing import Optional, Union, Dict, Any + +import numpy as np +import torch +from torch import nn +from torch.utils.data import DataLoader + +from easy_tpp.config_factory import DataSpecConfig, Config +from easy_tpp.model import TorchNHP as NHP +from easy_tpp.preprocess import TPPDataset, EventTokenizer +from easy_tpp.preprocess.data_collator import TPPDataCollator +from easy_tpp.preprocess.event_tokenizer import BatchEncoding +from easy_tpp.utils import PaddingStrategy + + +def make_raw_data(): + data = [ + [{"time_since_last_event": 0, "time_since_start": 0, "type_event": 0, 'loan_amt': 10}], + [{"time_since_last_event": 0, "time_since_start": 0, "type_event": 1, 'loan_amt': 10}], + [{"time_since_last_event": 0, "time_since_start": 0, "type_event": 1, 'loan_amt': 20}], + [{"time_since_last_event": 0, "time_since_start": 0, "type_event": 1, 'loan_amt': 20}], + [{"time_since_last_event": 0, "time_since_start": 0, "type_event": 1, 'loan_amt': 20}], + [{"time_since_last_event": 0, "time_since_start": 0, "type_event": 1, 'loan_amt': 30}], + ] + for i, j in enumerate([2, 5, 3, 2, 4, 2]): + start_time = 0 + for k in range(j): + delta_t = random.random() + start_time += delta_t + data[i].append({"time_since_last_event": delta_t, + "time_since_start": start_time, + "type_event": random.randint(0, 10), + 'loan_amt': random.randint(10, 30)}) + + return data + + +class TPPDatasetV2(TPPDataset): + def __init__(self, data): + super(TPPDatasetV2, self).__init__(data) + self.loan_amt_seqs = self.data_dict['loan_amt_seqs'] + + def __getitem__(self, idx): + """ + + Args: + idx: iteration index + + Returns: + dict: a dict of time_seqs, time_delta_seqs and type_seqs element + + """ + return dict({'time_seqs': self.time_seqs[idx], 'time_delta_seqs': self.time_delta_seqs[idx], + 'type_seqs': self.type_seqs[idx], 'loan_amt_seqs': self.loan_amt_seqs[idx]}) + + +class EventTokenizerV2(EventTokenizer): + def __init__(self, config): + super(EventTokenizerV2, self).__init__(config) + self.model_input_names.append('loan_amt_seqs') + self.model_input_names.append('type_mask') + + def _pad( + self, + encoded_inputs: Union[Dict[str, Any], BatchEncoding], + max_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + return_attention_mask: Optional[bool] = None, + ) -> dict: + """ + Pad encoded inputs (on left/right and up to predefined length or max length in the batch) + + Args: + encoded_inputs: + Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). + max_length: maximum length of the returned list and optionally padding length (see below). + Will truncate by taking into account the special tokens. + padding_strategy: PaddingStrategy to use for padding. + + - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The tokenizer padding sides are defined in self.padding_side: + + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability + `>= 7.5` (Volta). + return_attention_mask: + (optional) Set to False to avoid returning attention mask (default: set to model specifics) + """ + # Load from model defaults + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + required_input = encoded_inputs[self.model_input_names[0]] + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = len(required_input) + + # check whether we need to pad it + is_all_seq_equal_max_length = [len(seq) == max_length for seq in required_input] + is_all_seq_equal_max_length = np.prod(is_all_seq_equal_max_length) + needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and ~is_all_seq_equal_max_length + + batch_output = dict() + + if needs_to_be_padded: + # time seqs + batch_output[self.model_input_names[0]] = self.make_pad_sequence(encoded_inputs[self.model_input_names[0]], + self.pad_token_id, + padding_side=self.padding_side, + max_len=max_length) + # time_delta seqs + batch_output[self.model_input_names[1]] = self.make_pad_sequence(encoded_inputs[self.model_input_names[1]], + self.pad_token_id, + padding_side=self.padding_side, + max_len=max_length) + # type_seqs + batch_output[self.model_input_names[2]] = self.make_pad_sequence(encoded_inputs[self.model_input_names[2]], + self.pad_token_id, + padding_side=self.padding_side, + max_len=max_length, + dtype=np.int32) + + else: + batch_output = encoded_inputs + + # non_pad_mask + # we must use type seqs to check the mask, because the pad_token_id maybe one of valid values in + # time seqs + seq_pad_mask = batch_output[self.model_input_names[2]] == self.pad_token_id + batch_output[self.model_input_names[3]] = ~ seq_pad_mask + + if return_attention_mask: + # attention_mask + batch_output[self.model_input_names[4]] = self.make_attn_mask_for_pad_sequence( + batch_output[self.model_input_names[2]], + self.pad_token_id) + else: + batch_output[self.model_input_names[4]] = [] + + # type_mask + batch_output[self.model_input_names[6]] = self.make_type_mask_for_pad_sequence( + batch_output[self.model_input_names[2]]) + + # loan_amt_seqs + batch_output[self.model_input_names[5]] = self.make_pad_sequence(encoded_inputs[self.model_input_names[-2]], + self.pad_token_id, + padding_side=self.padding_side, + max_len=max_length) + + return batch_output + + +def make_data_loader(): + source_data = make_raw_data() + + time_seqs = [[x["time_since_start"] for x in seq] for seq in source_data] + type_seqs = [[x["type_event"] for x in seq] for seq in source_data] + time_delta_seqs = [[x["time_since_last_event"] for x in seq] for seq in source_data] + loan_amt_seqs = [[x["loan_amt"] for x in seq] for seq in source_data] + + input_data = {'time_seqs': time_seqs, + 'type_seqs': type_seqs, + 'time_delta_seqs': time_delta_seqs, + 'loan_amt_seqs': loan_amt_seqs} + + config = DataSpecConfig.parse_from_yaml_config({'num_event_types': 11, 'batch_size': 1, + 'pad_token_id': 11}) + + dataset = TPPDatasetV2(input_data) + + tokenizer = EventTokenizerV2(config) + + padding = True if tokenizer.padding_strategy is None else tokenizer.padding_strategy + truncation = False if tokenizer.truncation_strategy is None else tokenizer.truncation_strategy + + data_collator = TPPDataCollator(tokenizer=tokenizer, + return_tensors='pt', + max_length=tokenizer.model_max_length, + padding=padding, + truncation=truncation) + + data_loader = DataLoader(dataset, collate_fn=data_collator, batch_size=1) + + return data_loader + + +class NHPV2(NHP): + def __init__(self, model_config): + super(NHPV2, self).__init__(model_config) + + self.layer_loan_amt = nn.Linear(1, model_config.hidden_size) + + self.layer_merge = nn.Linear(model_config.hidden_size * 2, model_config.hidden_size) + + def forward(self, batch, **kwargs): + """Call the model. + + Args: + batch (tuple, list): batch input. + + Returns: + list: hidden states, [batch_size, seq_len, hidden_dim], states right before the event happens; + stacked decay states, [batch_size, max_seq_length, 4, hidden_dim], states right after + the event happens. + """ + time_seq, time_delta_seq, event_seq, batch_non_pad_mask, _, type_mask, loan_amt_seq = batch + + all_hiddens = [] + all_outputs = [] + all_cells = [] + all_cell_bars = [] + all_decays = [] + + max_steps = kwargs.get('max_steps', None) + + max_decay_time = kwargs.get('max_decay_time', 5.0) + + # last event has no time label + max_seq_length = max_steps if max_steps is not None else event_seq.size(1) - 1 + + batch_size = len(event_seq) + c_t, c_bar_t, delta_t, o_t = self.get_init_state(batch_size) + h_t = o_t # Use o_t as the initial hidden state, as in the base NHP + c_t = c_t + c_bar_i = c_bar_t + + # if only one event, then we dont decay + if max_seq_length == 1: + types_sub_batch = event_seq[:, 0] + x_t = self.layer_type_emb(types_sub_batch) + + # i add loan emb here + loan_t = self.layer_loan_amt(loan_amt_seq[:, 0]) + x_t = self.layer_merge(torch.cat(x_t, loan_t)) + + cell_i, c_bar_i, decay_i, output_i = \ + self.rnn_cell(x_t, h_t, c_t, c_bar_i) + + # Append all output + all_outputs.append(output_i) + all_decays.append(decay_i) + all_cells.append(cell_i) + all_cell_bars.append(c_bar_i) + all_hiddens.append(h_t) + else: + # Loop over all events + for i in range(max_seq_length): + if i == event_seq.size(1) - 1: + dt = torch.ones_like(time_delta_seq[:, i]) * max_decay_time + else: + dt = time_delta_seq[:, i + 1] # need to carefully check here + types_sub_batch = event_seq[:, i] + x_t = self.layer_type_emb(types_sub_batch) + + # i add loan emb here + loan_t = self.layer_loan_amt(loan_amt_seq[:, i:i+1]) + x_t = self.layer_merge(torch.cat([x_t, loan_t], dim=-1)) + + # cell_i (batch_size, process_dim) + cell_i, c_bar_i, decay_i, output_i = \ + self.rnn_cell(x_t, h_t, c_t, c_bar_i) + + # States decay - Equation (7) in the paper + c_t, h_t = self.rnn_cell.decay(cell_i, + c_bar_i, + decay_i, + output_i, + dt[:, None]) + + # Append all output + all_outputs.append(output_i) + all_decays.append(decay_i) + all_cells.append(cell_i) + all_cell_bars.append(c_bar_i) + all_hiddens.append(h_t) + + # (batch_size, max_seq_length, hidden_dim) + cell_stack = torch.stack(all_cells, dim=1) + cell_bar_stack = torch.stack(all_cell_bars, dim=1) + decay_stack = torch.stack(all_decays, dim=1) + output_stack = torch.stack(all_outputs, dim=1) + + # [batch_size, max_seq_length, hidden_dim] + hiddens_stack = torch.stack(all_hiddens, dim=1) + + # [batch_size, max_seq_length, 4, hidden_dim] + decay_states_stack = torch.stack((cell_stack, + cell_bar_stack, + decay_stack, + output_stack), + dim=2) + + return hiddens_stack, decay_states_stack + + def loglike_loss(self, batch): + """Compute the loglike loss. + + Args: + batch (list): batch input. + + Returns: + list: loglike loss, num events. + """ + time_seqs, time_delta_seqs, type_seqs, batch_non_pad_mask, _, type_mask, loan_amt_seq = batch + + hiddens_ti, decay_states = self.forward(batch) + + # Num of samples in each batch and num of event time point in the sequence + batch_size, seq_len, _ = hiddens_ti.size() + + # Lambda(t) right before each event time point + # lambda_at_event - [batch_size, num_times=max_len-1, num_event_types] + # Here we drop the last event because it has no delta_time label (can not decay) + lambda_at_event = self.layer_intensity(hiddens_ti) + + # Compute the big lambda integral in Equation (8) + # 1 - take num_mc_sample rand points in each event interval + # 2 - compute its lambda value for every sample point + # 3 - take average of these sample points + # 4 - times the interval length + + # interval_t_sample - [batch_size, num_times=max_len-1, num_mc_sample] + # for every batch and every event point => do a sampling (num_mc_sampling) + # the first dtime is zero, so we use time_delta_seq[:, 1:] + interval_t_sample = self.make_dtime_loss_samples(time_delta_seqs[:, 1:]) + + # [batch_size, num_times = max_len - 1, num_mc_sample, hidden_size] + state_t_sample = self.compute_states_at_sample_times(decay_states, interval_t_sample) + + # [batch_size, num_times = max_len - 1, num_mc_sample, event_num] + lambda_t_sample = self.layer_intensity(state_t_sample) + + type_seqs = type_seqs.long() + event_ll, non_event_ll, num_events = self.compute_loglikelihood( + time_delta_seq=time_delta_seqs[:, 1:], + lambda_at_event=lambda_at_event, + lambdas_loss_samples=lambda_t_sample, + seq_mask=batch_non_pad_mask[:, 1:], + type_seq=type_seqs[:, 1:] + ) + + # (num_samples, num_times) + loss = - (event_ll - non_event_ll).sum() + return loss, num_events + + def compute_states_at_sample_times(self, decay_states, sample_dtimes): + """ + decay_states: (batch_size, seq_len, 4, hidden_dim) + sample_dtimes: (batch_size, seq_len, num_mc_sample) + """ + cell_stack, cell_bar_stack, decay_stack, output_stack = torch.unbind(decay_states, dim=2) + # Add a new axis for samples + _, h_ts = self.rnn_cell.decay( + cell_stack[:, :, None, :], + cell_bar_stack[:, :, None, :], + decay_stack[:, :, None, :], + output_stack[:, :, None, :], + sample_dtimes[..., None] + ) + return h_ts + +def make_model(): + config = Config.build_from_yaml_file('examples/configs/experiment_config.yaml', experiment_id='NHP_train') + model_config = config.model_config + + # hack this + model_config.num_event_types = 11 + model_config.num_event_types_pad = 12 + model_config.pad_token_id = 11 + + model = NHPV2(model_config) + + return model + + +def main(): + data_loader = make_data_loader() + + model = make_model() + + num_epochs = 10 + + opt = torch.optim.Adam(model.parameters(), lr=0.001) + + for i in range(num_epochs): + total_loss = 0 + total_num_event = 0 + for batch in data_loader: + with torch.set_grad_enabled(True): + batch_loss, batch_num_event = model.loglike_loss(batch = batch.values()) + + opt.zero_grad() + batch_loss.backward() + opt.step() + + total_loss += batch_loss + total_num_event += batch_num_event + + avg_loss = total_loss / total_num_event + print(f'epochs {i}: loss {avg_loss}') + + return + + +if __name__ == '__main__': + main() diff --git a/examples/train_robot_thp_with_features.py b/examples/train_robot_thp_with_features.py new file mode 100644 index 0000000000000000000000000000000000000000..9056c60c13ab0e739b4b03f6d67a85becf17a4c0 --- /dev/null +++ b/examples/train_robot_thp_with_features.py @@ -0,0 +1,202 @@ +""" +训练RobotTHP模型(带语义特征) + +展示如何在EasyTPP框架中使用RobotTHP模型,并加载语义特征、偏差特征等 +""" + +import torch +from torch.utils.data import DataLoader + +from easy_tpp.config_factory import DataSpecConfig +from easy_tpp.model import TorchRobotTHP +from easy_tpp.preprocess.robert_dataset import RobertTPPDataset +from easy_tpp.preprocess.robert_tokenizer import RobertEventTokenizer +from easy_tpp.preprocess.data_collator import TPPDataCollator + + +def prepare_robert_data(): + """ + 准备评论罗伯特数据(示例) + + 实际使用时,应该从JSON文件加载并处理 + """ + # 示例数据 + time_seqs = [ + [0.0, 10.5, 25.3, 45.2], + [0.0, 5.2, 12.8] + ] + type_seqs = [ + [0, 1, 2, 1], # post, bot_comment, user_comment, user_comment + [0, 1, 2] + ] + time_delta_seqs = [ + [0.0, 10.5, 14.8, 19.9], + [0.0, 5.2, 7.6] + ] + + # 语义向量(示例:768维BERT向量) + semantic_vectors = [ + [[0.1] * 768, [0.2] * 768, [0.3] * 768, [0.4] * 768], + [[0.1] * 768, [0.2] * 768, [0.3] * 768] + ] + + # 偏差特征(示例:3维 [语境偏差, 情感偏差, 困惑度]) + deviation_features = [ + [[0.0, 0.0, 0.0], [0.7, 0.5, 0.3], [0.2, 0.1, 0.1], [0.3, 0.2, 0.1]], + [[0.0, 0.0, 0.0], [0.6, 0.4, 0.2], [0.1, 0.1, 0.1]] + ] + + # 自发/被@标记(-1=不适用, 0=被@, 1=自发) + is_spontaneous = [ + [-1.0, 1.0, -1.0, -1.0], # 原帖不适用, 罗伯特自发, 用户评论不适用 + [-1.0, 0.0, -1.0] # 原帖不适用, 罗伯特被@, 用户评论不适用 + ] + + return { + 'time_seqs': time_seqs, + 'type_seqs': type_seqs, + 'time_delta_seqs': time_delta_seqs, + 'semantic_vectors': semantic_vectors, + 'deviation_features': deviation_features, + 'is_spontaneous': is_spontaneous + } + + +def create_data_loader(data_dict, config, use_semantic=True, use_deviation=True): + """ + 创建数据加载器 + + Args: + data_dict: 数据字典 + config: 数据配置 + use_semantic: 是否使用语义特征 + use_deviation: 是否使用偏差特征 + + Returns: + DataLoader: 数据加载器 + """ + # 创建数据集 + dataset = RobertTPPDataset(data_dict) + + # 创建分词器 + tokenizer = RobertEventTokenizer( + config, + use_semantic=use_semantic, + use_deviation=use_deviation, + semantic_dim=768 + ) + + # 创建数据整理器 + padding = True if tokenizer.padding_strategy is None else tokenizer.padding_strategy + truncation = False if tokenizer.truncation_strategy is None else tokenizer.truncation_strategy + + data_collator = TPPDataCollator( + tokenizer=tokenizer, + return_tensors='pt', + max_length=tokenizer.model_max_length, + padding=padding, + truncation=truncation + ) + + # 创建数据加载器 + data_loader = DataLoader( + dataset, + collate_fn=data_collator, + batch_size=config.batch_size, + shuffle=True + ) + + return data_loader + + +def main(): + """主函数""" + print("=" * 60) + print("训练RobotTHP模型(带语义特征)") + print("=" * 60) + + # 1. 准备数据 + print("\n1. 准备数据...") + data_dict = prepare_robert_data() + print(f" 序列数: {len(data_dict['time_seqs'])}") + + # 2. 创建配置 + print("\n2. 创建配置...") + config = DataSpecConfig.parse_from_yaml_config({ + 'num_event_types': 4, + 'batch_size': 2, + 'pad_token_id': 4 + }) + + # 3. 创建数据加载器 + print("\n3. 创建数据加载器...") + data_loader = create_data_loader( + data_dict, + config, + use_semantic=True, + use_deviation=True + ) + + # 4. 创建模型配置 + print("\n4. 创建模型...") + from easy_tpp.config_factory import ModelConfig + + model_config = ModelConfig.parse_from_yaml_config({ + 'hidden_size': 128, + 'num_layers': 3, + 'num_heads': 6, + 'dropout_rate': 0.1, + 'num_event_types': 4, + 'num_event_types_pad': 5, + 'pad_token_id': 4, + 'semantic_dim': 768, + 'use_semantic': True, + 'use_deviation': True, + 'use_structure_mask': False, + 'loss_integral_num_sample_per_step': 20, + 'use_mc_samples': True, + 'gpu': -1 + }) + + model = TorchRobotTHP(model_config) + print(f" 模型参数数量: {sum(p.numel() for p in model.parameters()):,}") + + # 5. 测试一个批次 + print("\n5. 测试数据加载...") + for batch in data_loader: + # batch是BatchEncoding对象,需要转换为tuple/list + batch_values = batch.values() + + print(f" 批次大小: {len(batch_values[0])}") + print(f" 序列长度: {batch_values[0].shape[1]}") + print(f" 时间序列形状: {batch_values[0].shape}") + print(f" 事件类型形状: {batch_values[2].shape}") + + if len(batch_values) > 5: + print(f" 语义向量形状: {batch_values[5].shape if batch_values[5] is not None else 'None'}") + if len(batch_values) > 6: + print(f" 偏差特征形状: {batch_values[6].shape if batch_values[6] is not None else 'None'}") + if len(batch_values) > 7: + print(f" 自发标记形状: {batch_values[7].shape if batch_values[7] is not None else 'None'}") + + # 6. 测试前向传播 + print("\n6. 测试前向传播...") + model.eval() + with torch.no_grad(): + loss, num_events = model.loglike_loss(batch_values) + print(f" 损失值: {loss.item():.4f}") + print(f" 事件数: {num_events}") + + break + + print("\n✅ 测试完成!") + print("\n使用说明:") + print("1. 将你的JSON数据转换为上述格式") + print("2. 使用RobertTPPDataset和RobertEventTokenizer加载数据") + print("3. 在EasyTPP配置文件中设置model_id为RobotTHP") + print("4. 运行训练即可") + + +if __name__ == '__main__': + main() + diff --git a/notebooks/easytpp_1_dataset.ipynb b/notebooks/easytpp_1_dataset.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..a4470210a6e55df29de521914dfe29573d6d5ab8 --- /dev/null +++ b/notebooks/easytpp_1_dataset.ipynb @@ -0,0 +1,2857 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "TmnzuOArbQk-" + }, + "source": [ + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ant-research/EasyTemporalPointProcess/blob/main/notebooks/easytpp_1_dataset.ipynb)\n", + "\n", + "# Tutorial 1: Dataset in EasyTPP" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "26Wvh9rZbTcg" + }, + "source": [ + "In this tutorial, we’ll explore the dataset-related functionalities in **EasyTPP**, an advanced library designed for temporal point process modeling. We will guide you through the installation process, data loading options, and configurations to set up a training pipeline effectively.\n", + "\n", + "\n", + "## Step 1: Install EasyTPP\n", + "First, let’s install the EasyTPP package. Run the following command to install the library in your Colab environment:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "U-gIiMZqMPFy", + "outputId": "40a622f2-57c2-4742-f913-9d58302d4e0e" + }, + "outputs": [], + "source": [ + "# ues the latest release\n", + "# !pip install easy-tpp\n", + "\n", + "# or use the git main branch\n", + "!pip install git+https://github.com/ant-research/EasyTemporalPointProcess.git" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "I5YUvAc7bngQ" + }, + "source": [ + "## Step 2: Loading Preprocessed Datasets\n", + "\n", + "EasyTPP provides two methods to load preprocessed datasets:\n", + "- [Google Drive](https://drive.google.com/drive/folders/1f8k82-NL6KFKuNMsUwozmbzDSFycYvz7): Download the dataset in pickle format.\n", + "- [HuggingFace](https://huggingface.co/easytpp): Load the dataset in JSON format from the HuggingFace repository.\n", + "\n", + "> Note: The pickle format from Google Drive will be deprecated in future releases, and we recommend using the JSON files from HuggingFace for better compatibility and performance.\n", + "\n", + "\n", + "### Option 1: Load Pickle Data Files (Deprecated Soon)\n", + "If you choose to use the pickle files, muanlly download the data files fromt he Google Drive mentioned above, place them under a data directory in your workspace, and specify the directory path in the configuration file.\n", + "\n", + "Here is an example configuration for loading a Taxi dataset in pickle format:" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6zfSHKhDmFSS" + }, + "source": [ + "\n", + "\n", + "```\n", + "data:\n", + " taxi:\n", + " data_format: pickle\n", + " train_dir: ./data/taxi/train.pkl\n", + " valid_dir: ./data/taxi/dev.pkl\n", + " test_dir: ./data/taxi/test.pkl\n", + "```\n", + "\n", + "Then we can launch the train/evaluation pipeline process. See [experiment_config](https://github.com/ant-research/EasyTemporalPointProcess/blob/main/examples/configs/experiment_config.yaml) for the full example.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HHPDzqud2wJf" + }, + "source": [ + "\n", + "### Option 2: Load JSON Data Files (Recommended)\n", + "\n", + "To use JSON data files from HuggingFace - [EasyTPP Repo](https://huggingface.co/easytpp), simply replace `data_format: pickle` with `data_format: json` in the config file, and update the directory paths accordingly. This setup is recommended for newer versions of EasyTPP and provides better compatibility with various processing functions in the library.\n", + "\n", + "To activate this loading process in the train/evaluation pipeline, similarly, we put the directory of huggingface repo in the config file, e.g.,\n", + "\n", + "```\n", + "data:\n", + " taxi:\n", + " data_format: json\n", + " train_dir: easytpp/taxi\n", + " valid_dir: easytpp/taxi\n", + " test_dir: easytpp/taxi\n", + "```\n", + "\n", + "Note that we can also manually put the locally directory of json files in the config:\n", + "\n", + "```\n", + "data:\n", + " taxi:\n", + " data_format: json\n", + " train_dir: ./data/taxi/train.json\n", + " valid_dir: ./data/taxi/dev.json\n", + " test_dir: ./data/taxi/test.json\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6HJd1lZB33mP" + }, + "source": [ + "## Step 3: Exploring Datasets\n", + "\n", + "The EasyTPP library offers several functions to streamline dataset loading and preprocessing. Let’s go over a few key functionalities:\n", + "\n", + "### Dataset Properties\n", + "\n", + "We firstly use the official HuggingFace APIs to directly download and inspect the dataset.\n", + "\n", + "In this example, the `load_dataset` function is used to load the \"taxi\" dataset, which is relatively small and suited for quick testing. The dataset is automatically split into three parts: train, validation, and test, with each split containing structured information on the events." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 368, + "referenced_widgets": [ + "41522af32da04925bc328c4a9d7e3a82", + "375ac02a2d6b43f4a70feaf921c5377d", + "ae9d5f81adc343edb3f7de9dccb15f47", + "03b8f2a927ff406b88d8ff844f620403", + "a4ac525ccb134ccbab57aaaafc37045b", + "8cfe483c4a924e6992a81fc44274f240", + "5d9a76f49322447d9ff1c71ab3f9bee7", + "8525620ea0104a05b5ba633671a95873", + "65d7f89d96cb4e9188da3db1300076aa", + "36ca2862fb10458abb4ba13a9e75ecf2", + "722c07901b4c454bb68055971a445778", + "5848c90c9f2144b588aa88290b529151", + "925eca71c7c64811ab88b672364f440d", + "3b29ed8d36864982a47b3510a83f00af", + "2b100553e6fe4bc5ad83e17ae523a42e", + "28db861ff5f845baac00370fbcd6521b", + "e6a5ce28f9974ba3a12d5044aa2e07d4", + "9c6f4364de0f4cfe90e189916df4a040", + "cd76860acc1b41b997d6df7427596106", + "500c3e320fef4e9e906bccc844c4bf4e", + "cf4cf5ecb14f45539e14971c4e8a7b07", + "5ff53796dc0046f8b077a84836d4fc61", + "1d41b7f7b722466f89c74067a3ed783d", + "7f82d7dbe42e445198a1f2e0bc512ea6", + "e0ad14cb4708466584a0ff46ebf70539", + "0b921bafccf148dfaa76e94580800dfb", + "cccb52e989424990a2f6fcd881a4137c", + "f3ad7c86e9604c27837ea509677cda39", + "599d263b0edc42c48b732f22b6f4bb52", + "b9cf1a3cdf9544578a1a1ac73dd182b9", + "8522fc20649e40aeb1fc9aabbd3501b0", + "c986defd720346949d9dfb4a4822bb68", + "e2ed3c0ee1bf4beeb1f564f43031648d", + "3cdce1231011425e913007b372d1354a", + "cfa4572b597345f7a5c69780eb2e19ab", + "494232ab57db40e391e34a90f0120dd7", + "df95b675f0504e34ab3b2f17d09109c2", + "02f0c7450af24e849e4c573ff7e0f9d8", + "a8eee5c478d84a4b828970f5f9c55016", + "faaf164d7dc142c3bfe55aac535efbf4", + "78bfe063ac114e4398007a26053956eb", + "0829c53e563b48c0b84dbad6baef16e0", + "86b40e77e7a34386b0876859f151b49a", + "59ebcd3733c14e8b87a42dc7399cc5e4", + "6d1e4b19483343bfa1dbd713c2efbd1f", + "ba666b3bd47e4568afeaec47a7acb912", + "38c9c5e1bd814f5ea92047d6f5833506", + "76579ac830bd404ba5008c5315c323c3", + "63da81f729204fb8b2dc5cc12074e09b", + "0e2d3e97ceed4b44b55ea7f70047ec7e", + "9adf81c38d8d4cce873ebed60c5b208c", + "3d24235ee98a4daeaaaf85a7b5763f0b", + "c4312014f1a54feab2dc7c2b58a837c5", + "75a31d51a0f940aa835b47236b4761b7", + "9bc1dbb8d296432f8c79f67a3cd5ef80", + "441615ebaf0f4db699d5fd569ee1655a", + "21ebc80b5d3c4eca8f8935ffd1a3590c", + "f622d1072c63472db30cb6ec6f6e72ae", + "879c8911eb3b41febff0d2a33cde55b4", + "22e1b3631d324a119e35e53088209296", + "93fbb1bbc4304d758998c904955ece3e", + "6bb36084b56f4277ad35da260978d9bf", + "fb6be1690b3247ee96124fd59c4c001d", + "1da7223d069345d88624d9a5485e1752", + "dc0b2a2b34fc483abe00409859000d02", + "5cdeca479dec41e9ba78fef207360283", + "3d4156a02f7a436f9d678531bdcad532", + "04ac17403fdd480d985d1f7885196c46", + "0b0bc44093404ad8b396dbc7f86a6967", + "f9d8cff1f4b24cd4b1aea8693746a81e", + "9b8a0445912840c194aed0846228972f", + "db91bc99d2964fc69fb4a71abeb0c254", + "0d554c589be24338bdeda06958b77140", + "33336a9fd5aa42acb54a06486781fb8b", + "47d50654d080473b9360a49e246d70df", + "f12c68ae19424e8eb8cad1e105674f4c", + "1d5ef40cf0004c5caa90c04564c8facb" + ] + }, + "id": "8sM6riIxQClw", + "outputId": "a5742b18-2ea9-408e-8ef4-dac9c6448cc5" + }, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "\n", + "# we choose taxi dataset as it is relatively small\n", + "dataset = load_dataset('easytpp/taxi')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DM6yw1u0E6kL" + }, + "source": [ + "Each dataset split is a Dataset object with multiple features such as `seq_len `(sequence length), `time_since_start`, `seq_idx` (sequence index), `time_since_last_event`, `type_event` (event type), and `dim_process` (dimension of the process). This structured format provides essential information about each event's timing and type, which is crucial for modeling temporal point processes. Additionally, the package simplifies data access, allowing users to select specific splits and features for further analysis or model input with minimal setup." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "BZYUTFsDRHmL", + "outputId": "dfc38373-6079-4ca6-9cbc-ceba5b81e2d3" + }, + "outputs": [], + "source": [ + "dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-CpGsgnMFa6Z" + }, + "source": [ + "In the easytpp dataset, the `type_event` feature represents the event type codes within each sequence. In the example shown, `dataset['train']['type_event'][0]` reveals a list of integer codes, such as [8, 3, 8, 3, 8, 3, ...], corresponding to different types of events in the first sequence of the training set. These codes are likely categorical identifiers used to differentiate various types of events in the temporal point process, which can be useful in understanding event dynamics and patterns over time. This feature enables the model to learn and predict not only the timing but also the type of future events within the sequence." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "NJKP0ATnv4_l", + "outputId": "7bd0df3f-14c1-4fda-bc99-d5ff7c9d9fbe" + }, + "outputs": [], + "source": [ + "dataset['train']['type_event'][0]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "int17B0zKRtR" + }, + "source": [ + "### Dataset Distributions\n", + "\n", + "In the following code snippet, several functions from the EasyTPP package are used to configure and analyze a TPP dataset, providing insights into event distribution and timing characteristics.\n", + "\n", + "#### Dataset Configuration\n", + "\n", + "The `Config.build_from_yaml_file` function loads configurations from a specified YAML file (`config.yaml`). This file contains settings for data preprocessing, model parameters, and other configurations needed by the `TPPDataLoader` to manage and process the TPP data. By centralizing settings in a configuration file, this function allows for easier parameter management and adjustments without altering the code." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "G6a74A43MLQN" + }, + "outputs": [], + "source": [ + "# As an illustrative example, we write the YAML content to a file\n", + "yaml_content = \"\"\"\n", + "pipeline_config_id: data_config\n", + "\n", + "data_format: json\n", + "train_dir: easytpp/taxi # ./data/taxi/train.json\n", + "valid_dir: easytpp/taxi # ./data/taxi/dev.json\n", + "test_dir: easytpp/taxi # ./data/taxi/test.json\n", + "data_specs:\n", + " num_event_types: 10\n", + " pad_token_id: 10\n", + " padding_side: right\n", + "\"\"\"\n", + "\n", + "# Save the content to a file named config.yaml\n", + "with open(\"config.yaml\", \"w\") as file:\n", + " file.write(yaml_content)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "rUBkm8JULMmP", + "outputId": "c1657049-7fd0-483f-ef24-ad5197a83714" + }, + "outputs": [], + "source": [ + "from easy_tpp.config_factory import Config\n", + "from easy_tpp.preprocess.data_loader import TPPDataLoader\n", + "\n", + "\n", + "config = Config.build_from_yaml_file('./config.yaml')\n", + "tpp_loader = TPPDataLoader(config)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wFBh_0sBM3EF" + }, + "source": [ + "#### Dataset Statistics\n", + "\n", + "\n", + "The `get_statistics` function retrieves statistical information about the dataset,\n", + "such as the distribution of event types, sequence lengths, and timing intervals. By specifying `split='train'`, this function targets only the training subset of the dataset. The resulting stats variable is printed to provide an overview of the dataset's\n", + "structure and characteristics, which can be helpful for understanding the data before model training." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "mz95yYH4NL3T", + "outputId": "40e28c82-c57f-4dd5-8687-e2d324ee9097" + }, + "outputs": [], + "source": [ + "stats = tpp_loader.get_statistics(split='train')\n", + "stats" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yA61VBnhLLEw" + }, + "source": [ + "#### Event Type Distribution Plot\n", + "\n", + "The following function generates a plot of the distribution of event types within the dataset. This visualization helps identify the frequency of different event types, which can be useful for analyzing class imbalance or the prevalence of certain types of events. Understanding event type distribution is essential for TPP models, as it informs the model about the likelihood and variety of event types." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 564 + }, + "id": "M3pheYYPNfYP", + "outputId": "3964b087-678f-4a69-ca34-50be977639cf" + }, + "outputs": [], + "source": [ + "tpp_loader.plot_event_type_distribution()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "F4Z-YvxkN66c" + }, + "source": [ + "#### Event Delta Time Distribution Plot\n", + "\n", + "\n", + "The `plot_event_delta_times_distribution` function visualizes the distribution of time intervals between consecutive events (delta times) in the dataset. This plot provides insights into the temporal patterns of events, such as whether they occur at regular intervals or vary widely. Understanding delta time distribution is crucial for TPP models since these patterns directly affect how the model learns to predict the timing of future events." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 564 + }, + "id": "ZLB47wR1ODAH", + "outputId": "bf0340e0-d259-42ea-c718-209f2bd2ee53" + }, + "outputs": [], + "source": [ + "tpp_loader.plot_event_delta_times_distribution()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "70Qundm8OFwM" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [ + "26Wvh9rZbTcg", + "I5YUvAc7bngQ" + ], + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "02f0c7450af24e849e4c573ff7e0f9d8": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "03b8f2a927ff406b88d8ff844f620403": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_36ca2862fb10458abb4ba13a9e75ecf2", + "placeholder": "​", + "style": "IPY_MODEL_722c07901b4c454bb68055971a445778", + "value": " 28.0/28.0 [00:00<00:00, 1.69kB/s]" + } + }, + "04ac17403fdd480d985d1f7885196c46": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_db91bc99d2964fc69fb4a71abeb0c254", + "placeholder": "​", + "style": "IPY_MODEL_0d554c589be24338bdeda06958b77140", + "value": "Generating test split: 100%" + } + }, + "0829c53e563b48c0b84dbad6baef16e0": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "0b0bc44093404ad8b396dbc7f86a6967": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_33336a9fd5aa42acb54a06486781fb8b", + "max": 400, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_47d50654d080473b9360a49e246d70df", + "value": 400 + } + }, + "0b921bafccf148dfaa76e94580800dfb": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_c986defd720346949d9dfb4a4822bb68", + "placeholder": "​", + "style": "IPY_MODEL_e2ed3c0ee1bf4beeb1f564f43031648d", + "value": " 327k/327k [00:00<00:00, 11.6MB/s]" + } + }, + "0d554c589be24338bdeda06958b77140": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "0e2d3e97ceed4b44b55ea7f70047ec7e": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "1d41b7f7b722466f89c74067a3ed783d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_7f82d7dbe42e445198a1f2e0bc512ea6", + "IPY_MODEL_e0ad14cb4708466584a0ff46ebf70539", + "IPY_MODEL_0b921bafccf148dfaa76e94580800dfb" + ], + "layout": "IPY_MODEL_cccb52e989424990a2f6fcd881a4137c" + } + }, + "1d5ef40cf0004c5caa90c04564c8facb": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "1da7223d069345d88624d9a5485e1752": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "21ebc80b5d3c4eca8f8935ffd1a3590c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_93fbb1bbc4304d758998c904955ece3e", + "placeholder": "​", + "style": "IPY_MODEL_6bb36084b56f4277ad35da260978d9bf", + "value": "Generating validation split: 100%" + } + }, + "22e1b3631d324a119e35e53088209296": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "28db861ff5f845baac00370fbcd6521b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2b100553e6fe4bc5ad83e17ae523a42e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_cf4cf5ecb14f45539e14971c4e8a7b07", + "placeholder": "​", + "style": "IPY_MODEL_5ff53796dc0046f8b077a84836d4fc61", + "value": " 2.29M/2.29M [00:00<00:00, 6.88MB/s]" + } + }, + "33336a9fd5aa42acb54a06486781fb8b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "36ca2862fb10458abb4ba13a9e75ecf2": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "375ac02a2d6b43f4a70feaf921c5377d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_8cfe483c4a924e6992a81fc44274f240", + "placeholder": "​", + "style": "IPY_MODEL_5d9a76f49322447d9ff1c71ab3f9bee7", + "value": "README.md: 100%" + } + }, + "38c9c5e1bd814f5ea92047d6f5833506": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_3d24235ee98a4daeaaaf85a7b5763f0b", + "max": 1400, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_c4312014f1a54feab2dc7c2b58a837c5", + "value": 1400 + } + }, + "3b29ed8d36864982a47b3510a83f00af": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_cd76860acc1b41b997d6df7427596106", + "max": 2287957, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_500c3e320fef4e9e906bccc844c4bf4e", + "value": 2287957 + } + }, + "3cdce1231011425e913007b372d1354a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_cfa4572b597345f7a5c69780eb2e19ab", + "IPY_MODEL_494232ab57db40e391e34a90f0120dd7", + "IPY_MODEL_df95b675f0504e34ab3b2f17d09109c2" + ], + "layout": "IPY_MODEL_02f0c7450af24e849e4c573ff7e0f9d8" + } + }, + "3d24235ee98a4daeaaaf85a7b5763f0b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "3d4156a02f7a436f9d678531bdcad532": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_04ac17403fdd480d985d1f7885196c46", + "IPY_MODEL_0b0bc44093404ad8b396dbc7f86a6967", + "IPY_MODEL_f9d8cff1f4b24cd4b1aea8693746a81e" + ], + "layout": "IPY_MODEL_9b8a0445912840c194aed0846228972f" + } + }, + "41522af32da04925bc328c4a9d7e3a82": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_375ac02a2d6b43f4a70feaf921c5377d", + "IPY_MODEL_ae9d5f81adc343edb3f7de9dccb15f47", + "IPY_MODEL_03b8f2a927ff406b88d8ff844f620403" + ], + "layout": "IPY_MODEL_a4ac525ccb134ccbab57aaaafc37045b" + } + }, + "441615ebaf0f4db699d5fd569ee1655a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_21ebc80b5d3c4eca8f8935ffd1a3590c", + "IPY_MODEL_f622d1072c63472db30cb6ec6f6e72ae", + "IPY_MODEL_879c8911eb3b41febff0d2a33cde55b4" + ], + "layout": "IPY_MODEL_22e1b3631d324a119e35e53088209296" + } + }, + "47d50654d080473b9360a49e246d70df": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "494232ab57db40e391e34a90f0120dd7": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_78bfe063ac114e4398007a26053956eb", + "max": 653866, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_0829c53e563b48c0b84dbad6baef16e0", + "value": 653866 + } + }, + "500c3e320fef4e9e906bccc844c4bf4e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "5848c90c9f2144b588aa88290b529151": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_925eca71c7c64811ab88b672364f440d", + "IPY_MODEL_3b29ed8d36864982a47b3510a83f00af", + "IPY_MODEL_2b100553e6fe4bc5ad83e17ae523a42e" + ], + "layout": "IPY_MODEL_28db861ff5f845baac00370fbcd6521b" + } + }, + "599d263b0edc42c48b732f22b6f4bb52": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "59ebcd3733c14e8b87a42dc7399cc5e4": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "5cdeca479dec41e9ba78fef207360283": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "5d9a76f49322447d9ff1c71ab3f9bee7": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "5ff53796dc0046f8b077a84836d4fc61": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "63da81f729204fb8b2dc5cc12074e09b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "65d7f89d96cb4e9188da3db1300076aa": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "6bb36084b56f4277ad35da260978d9bf": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "6d1e4b19483343bfa1dbd713c2efbd1f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_ba666b3bd47e4568afeaec47a7acb912", + "IPY_MODEL_38c9c5e1bd814f5ea92047d6f5833506", + "IPY_MODEL_76579ac830bd404ba5008c5315c323c3" + ], + "layout": "IPY_MODEL_63da81f729204fb8b2dc5cc12074e09b" + } + }, + "722c07901b4c454bb68055971a445778": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "75a31d51a0f940aa835b47236b4761b7": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "76579ac830bd404ba5008c5315c323c3": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_75a31d51a0f940aa835b47236b4761b7", + "placeholder": "​", + "style": "IPY_MODEL_9bc1dbb8d296432f8c79f67a3cd5ef80", + "value": " 1400/1400 [00:00<00:00, 6217.86 examples/s]" + } + }, + "78bfe063ac114e4398007a26053956eb": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "7f82d7dbe42e445198a1f2e0bc512ea6": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_f3ad7c86e9604c27837ea509677cda39", + "placeholder": "​", + "style": "IPY_MODEL_599d263b0edc42c48b732f22b6f4bb52", + "value": "dev.json: 100%" + } + }, + "8522fc20649e40aeb1fc9aabbd3501b0": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "8525620ea0104a05b5ba633671a95873": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "86b40e77e7a34386b0876859f151b49a": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "879c8911eb3b41febff0d2a33cde55b4": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_dc0b2a2b34fc483abe00409859000d02", + "placeholder": "​", + "style": "IPY_MODEL_5cdeca479dec41e9ba78fef207360283", + "value": " 200/200 [00:00<00:00, 3410.31 examples/s]" + } + }, + "8cfe483c4a924e6992a81fc44274f240": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "925eca71c7c64811ab88b672364f440d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_e6a5ce28f9974ba3a12d5044aa2e07d4", + "placeholder": "​", + "style": "IPY_MODEL_9c6f4364de0f4cfe90e189916df4a040", + "value": "train.json: 100%" + } + }, + "93fbb1bbc4304d758998c904955ece3e": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "9adf81c38d8d4cce873ebed60c5b208c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "9b8a0445912840c194aed0846228972f": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "9bc1dbb8d296432f8c79f67a3cd5ef80": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "9c6f4364de0f4cfe90e189916df4a040": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "a4ac525ccb134ccbab57aaaafc37045b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "a8eee5c478d84a4b828970f5f9c55016": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "ae9d5f81adc343edb3f7de9dccb15f47": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_8525620ea0104a05b5ba633671a95873", + "max": 28, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_65d7f89d96cb4e9188da3db1300076aa", + "value": 28 + } + }, + "b9cf1a3cdf9544578a1a1ac73dd182b9": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "ba666b3bd47e4568afeaec47a7acb912": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_0e2d3e97ceed4b44b55ea7f70047ec7e", + "placeholder": "​", + "style": "IPY_MODEL_9adf81c38d8d4cce873ebed60c5b208c", + "value": "Generating train split: 100%" + } + }, + "c4312014f1a54feab2dc7c2b58a837c5": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "c986defd720346949d9dfb4a4822bb68": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "cccb52e989424990a2f6fcd881a4137c": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "cd76860acc1b41b997d6df7427596106": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "cf4cf5ecb14f45539e14971c4e8a7b07": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "cfa4572b597345f7a5c69780eb2e19ab": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_a8eee5c478d84a4b828970f5f9c55016", + "placeholder": "​", + "style": "IPY_MODEL_faaf164d7dc142c3bfe55aac535efbf4", + "value": "test.json: 100%" + } + }, + "db91bc99d2964fc69fb4a71abeb0c254": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "dc0b2a2b34fc483abe00409859000d02": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "df95b675f0504e34ab3b2f17d09109c2": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_86b40e77e7a34386b0876859f151b49a", + "placeholder": "​", + "style": "IPY_MODEL_59ebcd3733c14e8b87a42dc7399cc5e4", + "value": " 654k/654k [00:00<00:00, 24.4MB/s]" + } + }, + "e0ad14cb4708466584a0ff46ebf70539": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_b9cf1a3cdf9544578a1a1ac73dd182b9", + "max": 327062, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_8522fc20649e40aeb1fc9aabbd3501b0", + "value": 327062 + } + }, + "e2ed3c0ee1bf4beeb1f564f43031648d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "e6a5ce28f9974ba3a12d5044aa2e07d4": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f12c68ae19424e8eb8cad1e105674f4c": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f3ad7c86e9604c27837ea509677cda39": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f622d1072c63472db30cb6ec6f6e72ae": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_fb6be1690b3247ee96124fd59c4c001d", + "max": 200, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_1da7223d069345d88624d9a5485e1752", + "value": 200 + } + }, + "f9d8cff1f4b24cd4b1aea8693746a81e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_f12c68ae19424e8eb8cad1e105674f4c", + "placeholder": "​", + "style": "IPY_MODEL_1d5ef40cf0004c5caa90c04564c8facb", + "value": " 400/400 [00:00<00:00, 4474.87 examples/s]" + } + }, + "faaf164d7dc142c3bfe55aac535efbf4": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "fb6be1690b3247ee96124fd59c4c001d": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/notebooks/easytpp_2_tfb_wb.ipynb b/notebooks/easytpp_2_tfb_wb.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..8acdd7658371a0571e8cf5e646e2e0126aa36053 --- /dev/null +++ b/notebooks/easytpp_2_tfb_wb.ipynb @@ -0,0 +1,248 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ant-research/EasyTemporalPointProcess/blob/main/notebooks/easytpp_2_tfb_wb.ipynb)\n", + "\n", + "\n", + "# Tutorial 2: Tensorboard and Weights & Biases in EasyTPP\n", + "\n", + "EasyTPP provides built-in support for both Tensorboard and Weights & Biases (W&B) to help you track and visualize your model training. These tools allow you to monitor metrics, compare experiments, and debug your models effectively.\n", + "\n", + "\n", + "## Example of using Tensorboard" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "ExecuteTime": { + "end_time": "2025-02-03T02:24:56.584850Z", + "start_time": "2025-02-03T02:24:56.580600Z" + } + }, + "outputs": [], + "source": [ + "# As an illustrative example, we write the YAML content to a file\n", + "yaml_content = \"\"\"\n", + "pipeline_config_id: runner_config\n", + "\n", + "data:\n", + " taxi:\n", + " data_format: json\n", + " train_dir: easytpp/taxi # ./data/taxi/train.json\n", + " valid_dir: easytpp/taxi # ./data/taxi/dev.json\n", + " test_dir: easytpp/taxi # ./data/taxi/test.json\n", + " data_specs:\n", + " num_event_types: 10\n", + " pad_token_id: 10\n", + " padding_side: right\n", + "\n", + "\n", + "NHP_train:\n", + " base_config:\n", + " stage: train\n", + " backend: torch\n", + " dataset_id: taxi\n", + " runner_id: std_tpp\n", + " model_id: NHP # model name\n", + " base_dir: './checkpoints/'\n", + " trainer_config:\n", + " batch_size: 256\n", + " max_epoch: 2\n", + " shuffle: False\n", + " optimizer: adam\n", + " learning_rate: 1.e-3\n", + " valid_freq: 1\n", + " use_tfb: True\n", + " metrics: [ 'acc', 'rmse' ]\n", + " seed: 2019\n", + " gpu: -1\n", + " model_config:\n", + " hidden_size: 32\n", + " loss_integral_num_sample_per_step: 20\n", + " thinning:\n", + " num_seq: 10\n", + " num_sample: 1\n", + " num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm\n", + " look_ahead_time: 10\n", + " patience_counter: 5 # the maximum iteration used in adaptive thinning\n", + " over_sample_rate: 5\n", + " num_samples_boundary: 5\n", + " dtime_max: 5\n", + " num_step_gen: 1\n", + "\"\"\"\n", + "\n", + "# Save the content to a file named config.yaml\n", + "with open(\"config.yaml\", \"w\") as file:\n", + " file.write(yaml_content)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then we run the following command to train the model:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[31;1m2025-02-03 10:32:32,085 - config.py[pid:91053;line:34:build_from_yaml_file] - CRITICAL: Load pipeline config class RunnerConfig\u001b[0m\n", + "\u001b[31;1m2025-02-03 10:32:32,089 - runner_config.py[pid:91053;line:161:update_config] - CRITICAL: train model NHP using CPU with torch backend\u001b[0m\n", + "\u001b[38;20m2025-02-03 10:32:32,098 - runner_config.py[pid:91053;line:36:__init__] - INFO: Save the config to ./checkpoints/91053_8345177088_250203-103232/NHP_train_output.yaml\u001b[0m\n", + "\u001b[38;20m2025-02-03 10:32:32,099 - base_runner.py[pid:91053;line:176:save_log] - INFO: Save the log to ./checkpoints/91053_8345177088_250203-103232/log\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/miniconda3/envs/llm/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "Downloading readme: 100%|██████████| 28.0/28.0 [00:00<00:00, 119B/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.2244252199397379 0.29228809611195583\n", + "min_dt: 0.000277777777777\n", + "max_dt: 5.721388888888889\n", + "\u001b[38;20m2025-02-03 10:32:38,267 - tpp_runner.py[pid:91053;line:60:_init_model] - INFO: Num of model parameters 15252\u001b[0m\n", + "\u001b[38;20m2025-02-03 10:32:45,909 - base_runner.py[pid:91053;line:98:train] - INFO: Data 'taxi' loaded...\u001b[0m\n", + "\u001b[38;20m2025-02-03 10:32:45,910 - base_runner.py[pid:91053;line:103:train] - INFO: Start NHP training...\u001b[0m\n", + "\u001b[38;20m2025-02-03 10:32:46,425 - tpp_runner.py[pid:91053;line:96:_train_model] - INFO: [ Epoch 0 (train) ]: train loglike is -1.7553733776992408, num_events is 50454\u001b[0m\n", + "\u001b[38;20m2025-02-03 10:32:47,128 - tpp_runner.py[pid:91053;line:107:_train_model] - INFO: [ Epoch 0 (valid) ]: valid loglike is -1.6691416010202664, num_events is 7204, acc is 0.4414214325374792, rmse is 0.3327808472052436\u001b[0m\n", + "\u001b[38;20m2025-02-03 10:32:48,150 - tpp_runner.py[pid:91053;line:122:_train_model] - INFO: [ Epoch 0 (test) ]: test loglike is -1.6577474861303745, num_events is 14420, acc is 0.44667128987517335, rmse is 0.3408341129976238\u001b[0m\n", + "\u001b[31;1m2025-02-03 10:32:48,150 - tpp_runner.py[pid:91053;line:124:_train_model] - CRITICAL: current best loglike on valid set is -1.6691 (updated at epoch-0), best updated at this epoch\u001b[0m\n", + "\u001b[38;20m2025-02-03 10:32:48,487 - tpp_runner.py[pid:91053;line:96:_train_model] - INFO: [ Epoch 1 (train) ]: train loglike is -1.6284447180538213, num_events is 50454\u001b[0m\n", + "\u001b[38;20m2025-02-03 10:32:48,995 - tpp_runner.py[pid:91053;line:107:_train_model] - INFO: [ Epoch 1 (valid) ]: valid loglike is -1.5259201159945863, num_events is 7204, acc is 0.4582176568573015, rmse is 0.33537458414488913\u001b[0m\n", + "\u001b[38;20m2025-02-03 10:32:49,999 - tpp_runner.py[pid:91053;line:122:_train_model] - INFO: [ Epoch 1 (test) ]: test loglike is -1.5121817706527392, num_events is 14420, acc is 0.45977808599167824, rmse is 0.34166548827945314\u001b[0m\n", + "\u001b[31;1m2025-02-03 10:32:50,000 - tpp_runner.py[pid:91053;line:124:_train_model] - CRITICAL: current best loglike on valid set is -1.5259 (updated at epoch-1), best updated at this epoch\u001b[0m\n", + "\u001b[38;20m2025-02-03 10:32:50,000 - base_runner.py[pid:91053;line:110:train] - INFO: End NHP train! Cost time: 0.068m\u001b[0m\n" + ] + } + ], + "source": [ + "from easy_tpp.config_factory import Config\n", + "from easy_tpp.runner import Runner\n", + "\n", + "config = Config.build_from_yaml_file('./config.yaml', experiment_id='NHP_train')\n", + "\n", + "model_runner = Runner.build_from_config(config)\n", + "\n", + "model_runner.run()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "source": [ + "After the training is done, we can see the tensorboard files in the `./checkpoints/` directory. " + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[34mcheckpoints\u001b[m\u001b[m easytpp_1_dataset.ipynb\n", + "config.yaml easytpp_2_tfb_wb.ipynb\n", + "\n", + "./checkpoints:\n", + "\u001b[34m91053_8345177088_250203-103232\u001b[m\u001b[m\n", + "\n", + "./checkpoints/91053_8345177088_250203-103232:\n", + "NHP_train_output.yaml \u001b[34mmodels\u001b[m\u001b[m \u001b[34mtfb_valid\u001b[m\u001b[m\n", + "log \u001b[34mtfb_train\u001b[m\u001b[m\n", + "\n", + "./checkpoints/91053_8345177088_250203-103232/models:\n", + "saved_model\n", + "\n", + "./checkpoints/91053_8345177088_250203-103232/tfb_train:\n", + "events.out.tfevents.1738549958.siqiaodeMacBook-Pro.local.91053.0\n", + "\n", + "./checkpoints/91053_8345177088_250203-103232/tfb_valid:\n", + "events.out.tfevents.1738549958.siqiaodeMacBook-Pro.local.91053.1\n" + ] + } + ], + "source": [ + "!ls -R" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then we can use the following script to visualize the training process:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "TensorFlow installation not found - running with reduced feature set.\n", + "Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all\n", + "TensorBoard 2.17.1 at http://localhost:6006/ (Press CTRL+C to quit)\n" + ] + } + ], + "source": [ + "! tensorboard --logdir \"./checkpoints/91053_8345177088_250203-103232/tfb_train/\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/notebooks/easytpp_3_train_eval.ipynb b/notebooks/easytpp_3_train_eval.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..e731a2e5dee07a94347fd12364a78af0ba027175 --- /dev/null +++ b/notebooks/easytpp_3_train_eval.ipynb @@ -0,0 +1,667 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ant-research/EasyTemporalPointProcess/blob/main/notebooks/easytpp_3_train_eval.ipynb)\n", + "\n", + "# Tutorial 3:Training and Evaluation Pipeline in EasyTPP\n", + "\n", + "In this tutorial, we'll walk through the complete process of training and evaluating temporal point process (TPP) models using the **EasyTPP** framework.\n", + "\n", + "In this notebook, we will cover the following key aspects:\n", + "- **Data Preparation**: Loading and preprocessing event sequence data.\n", + "- **Model Training**: Configuring and training a Neural Hawkes Process (NHP) model.\n", + "- **Model Evaluation**: Assessing model performance using various metrics.\n", + "- **Visualization**: Analyzing and visualizing model predictions and results.\n", + "\n", + "We begin by installing the package" + ], + "metadata": { + "id": "mprLutjnft_a" + } + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "collapsed": true, + "id": "aH28ufHMa-QU", + "outputId": "2e4ca8e1-41a4-4c44-deb5-9fc0a904a2fc" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Requirement already satisfied: easy-tpp in /usr/local/lib/python3.11/dist-packages (0.1.0)\n", + "Collecting easy-tpp\n", + " Downloading easy_tpp-0.1.2-py3-none-any.whl.metadata (533 bytes)\n", + "Requirement already satisfied: PyYAML>=5.1 in /usr/local/lib/python3.11/dist-packages (from easy-tpp) (6.0.2)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from easy-tpp) (2.0.2)\n", + "Requirement already satisfied: pandas in /usr/local/lib/python3.11/dist-packages (from easy-tpp) (2.2.2)\n", + "Requirement already satisfied: torch in /usr/local/lib/python3.11/dist-packages (from easy-tpp) (2.6.0+cu124)\n", + "Requirement already satisfied: tensorboard in /usr/local/lib/python3.11/dist-packages (from easy-tpp) (2.18.0)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.11/dist-packages (from easy-tpp) (24.2)\n", + "Requirement already satisfied: datasets in /usr/local/lib/python3.11/dist-packages (from easy-tpp) (3.5.0)\n", + "Requirement already satisfied: omegaconf in /usr/local/lib/python3.11/dist-packages (from easy-tpp) (2.3.0)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from datasets->easy-tpp) (3.18.0)\n", + "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.11/dist-packages (from datasets->easy-tpp) (18.1.0)\n", + "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.11/dist-packages (from datasets->easy-tpp) (0.3.8)\n", + "Requirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.11/dist-packages (from datasets->easy-tpp) (2.32.3)\n", + "Requirement already satisfied: tqdm>=4.66.3 in /usr/local/lib/python3.11/dist-packages (from datasets->easy-tpp) (4.67.1)\n", + "Requirement already satisfied: xxhash in /usr/local/lib/python3.11/dist-packages (from datasets->easy-tpp) (3.5.0)\n", + "Requirement already satisfied: multiprocess<0.70.17 in /usr/local/lib/python3.11/dist-packages (from datasets->easy-tpp) (0.70.16)\n", + "Requirement already satisfied: fsspec<=2024.12.0,>=2023.1.0 in /usr/local/lib/python3.11/dist-packages (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets->easy-tpp) (2024.12.0)\n", + "Requirement already satisfied: aiohttp in /usr/local/lib/python3.11/dist-packages (from datasets->easy-tpp) (3.11.15)\n", + "Requirement already satisfied: huggingface-hub>=0.24.0 in /usr/local/lib/python3.11/dist-packages (from datasets->easy-tpp) (0.30.1)\n", + "Requirement already satisfied: antlr4-python3-runtime==4.9.* in /usr/local/lib/python3.11/dist-packages (from omegaconf->easy-tpp) (4.9.3)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas->easy-tpp) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas->easy-tpp) (2025.2)\n", + "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas->easy-tpp) (2025.2)\n", + "Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.11/dist-packages (from tensorboard->easy-tpp) (1.4.0)\n", + "Requirement already satisfied: grpcio>=1.48.2 in /usr/local/lib/python3.11/dist-packages (from tensorboard->easy-tpp) (1.71.0)\n", + "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.11/dist-packages (from tensorboard->easy-tpp) (3.7)\n", + "Requirement already satisfied: protobuf!=4.24.0,>=3.19.6 in /usr/local/lib/python3.11/dist-packages (from tensorboard->easy-tpp) (5.29.4)\n", + "Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.11/dist-packages (from tensorboard->easy-tpp) (75.2.0)\n", + "Requirement already satisfied: six>1.9 in /usr/local/lib/python3.11/dist-packages (from tensorboard->easy-tpp) (1.17.0)\n", + "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /usr/local/lib/python3.11/dist-packages (from tensorboard->easy-tpp) (0.7.2)\n", + "Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from tensorboard->easy-tpp) (3.1.3)\n", + "Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.11/dist-packages (from torch->easy-tpp) (4.13.0)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch->easy-tpp) (3.4.2)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch->easy-tpp) (3.1.6)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch->easy-tpp) (12.4.127)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch->easy-tpp) (12.4.127)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch->easy-tpp) (12.4.127)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch->easy-tpp) (9.1.0.70)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.11/dist-packages (from torch->easy-tpp) (12.4.5.8)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.11/dist-packages (from torch->easy-tpp) (11.2.1.3)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.11/dist-packages (from torch->easy-tpp) (10.3.5.147)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.11/dist-packages (from torch->easy-tpp) (11.6.1.9)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.11/dist-packages (from torch->easy-tpp) (12.3.1.170)\n", + "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch->easy-tpp) (0.6.2)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch->easy-tpp) (2.21.5)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch->easy-tpp) (12.4.127)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch->easy-tpp) (12.4.127)\n", + "Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch->easy-tpp) (3.2.0)\n", + "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch->easy-tpp) (1.13.1)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch->easy-tpp) (1.3.0)\n", + "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets->easy-tpp) (2.6.1)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets->easy-tpp) (1.3.2)\n", + "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets->easy-tpp) (25.3.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets->easy-tpp) (1.5.0)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets->easy-tpp) (6.3.1)\n", + "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets->easy-tpp) (0.3.1)\n", + "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets->easy-tpp) (1.18.3)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests>=2.32.2->datasets->easy-tpp) (3.4.1)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests>=2.32.2->datasets->easy-tpp) (3.10)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests>=2.32.2->datasets->easy-tpp) (2.3.0)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests>=2.32.2->datasets->easy-tpp) (2025.1.31)\n", + "Requirement already satisfied: MarkupSafe>=2.1.1 in /usr/local/lib/python3.11/dist-packages (from werkzeug>=1.0.1->tensorboard->easy-tpp) (3.0.2)\n", + "Downloading easy_tpp-0.1.2-py3-none-any.whl (126 kB)\n", + "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m126.5/126.5 kB\u001B[0m \u001B[31m2.9 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", + "\u001B[?25hInstalling collected packages: easy-tpp\n", + " Attempting uninstall: easy-tpp\n", + " Found existing installation: easy-tpp 0.1.0\n", + " Uninstalling easy-tpp-0.1.0:\n", + " Successfully uninstalled easy-tpp-0.1.0\n", + "Successfully installed easy-tpp-0.1.2\n" + ] + }, + { + "output_type": "display_data", + "data": { + "application/vnd.colab-display-data+json": { + "pip_warning": { + "packages": [ + "easy_tpp" + ] + }, + "id": "a33b6aa166cb43cfa25793118ff1aafc" + } + }, + "metadata": {} + } + ], + "source": [ + "! pip install --upgrade easy-tpp" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Step 1: setup the config file\n", + "\n", + "The EasyTPP framework uses a YAML-based configuration system that consists of two main components:\n", + "- **data configuration**, which specifies the data sources and their formats. It defines where to find the training, validation, and test datasets (see below as an example). The full explanation of the dataset can be found in the previous tutorial [EasyTPP-Dataset](https://github.com/ant-research/EasyTemporalPointProcess/blob/main/notebooks/easytpp_1_dataset.ipynb).\n", + "- **model configuration**, which defines the model architecture, hyperparameters, and training settings.\n", + "\n", + "Let's first look at the data configuration:" + ], + "metadata": { + "id": "iz34tMDNg80K" + } + }, + { + "cell_type": "code", + "source": [ + "data_config = \"\"\"\n", + "data:\n", + " taxi:\n", + " data_format: json\n", + " train_dir: ./data/taxi/train.json\n", + " valid_dir: ./data/taxi/dev.json\n", + " test_dir: ./data/taxi/test.json\n", + "\n", + "\"\"\"" + ], + "metadata": { + "id": "3Sap8WlqgNQz" + }, + "execution_count": 1, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "The model configuration specifies the neural network architecture and training hyperparameters. It is structured into two main sections: base configuration for general settings and trainer-specific configuration for detailed training parameters:\n" + ], + "metadata": { + "id": "ZJTYO2u02q6h" + } + }, + { + "cell_type": "code", + "source": [ + "model_config = \"\"\"\n", + "NHP_train:\n", + " base_config:\n", + " stage: train\n", + " backend: torch\n", + " dataset_id: taxi\n", + " runner_id: std_tpp\n", + " model_id: NHP # model name\n", + " base_dir: './checkpoints/'\n", + " trainer_config:\n", + " batch_size: 256\n", + " max_epoch: 2\n", + " shuffle: True\n", + " optimizer: adam\n", + " learning_rate: 1.e-3\n", + " valid_freq: 1\n", + " use_tfb: False\n", + " metrics: [ 'acc', 'rmse' ]\n", + " seed: 2019\n", + " gpu: -1\n", + "\"\"\"" + ], + "metadata": { + "id": "Yq9hnfrAzXxI" + }, + "execution_count": 2, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "As an illustrative example, we write the YAML content to a file yaml_content as below. Note that `runner_config` is the pipeline configuration ID that tells EasyTPP which configuration to use. We combine the data configuration\n", + "and model configuration into a single YAML file. This file will be used to initialize the model runner, which will handle the training process based on the specified parameters." + ], + "metadata": { + "id": "ZhppAAZ-2xMN" + } + }, + { + "cell_type": "code", + "source": [ + "yaml_content = \"\"\"\n", + "pipeline_config_id: runner_config\n", + "\n", + "data:\n", + " taxi:\n", + " data_format: json\n", + " train_dir: easytpp/taxi # ./data/taxi/train.json\n", + " valid_dir: easytpp/taxi # ./data/taxi/dev.json\n", + " test_dir: easytpp/taxi # ./data/taxi/test.json\n", + " data_specs:\n", + " num_event_types: 10\n", + " pad_token_id: 10\n", + " padding_side: right\n", + "\n", + "\n", + "NHP_train:\n", + " base_config:\n", + " stage: train\n", + " backend: torch\n", + " dataset_id: taxi\n", + " runner_id: std_tpp\n", + " model_id: NHP # model name\n", + " base_dir: './checkpoints/'\n", + " trainer_config:\n", + " batch_size: 256\n", + " max_epoch: 2\n", + " shuffle: False\n", + " optimizer: adam\n", + " learning_rate: 1.e-3\n", + " valid_freq: 1\n", + " use_tfb: True\n", + " metrics: [ 'acc', 'rmse' ]\n", + " seed: 2019\n", + " gpu: -1\n", + " model_config:\n", + " hidden_size: 32\n", + " loss_integral_num_sample_per_step: 20\n", + " thinning:\n", + " num_seq: 10\n", + " num_sample: 1\n", + " num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm\n", + " look_ahead_time: 10\n", + " patience_counter: 5 # the maximum iteration used in adaptive thinning\n", + " over_sample_rate: 5\n", + " num_samples_boundary: 5\n", + " dtime_max: 5\n", + " num_step_gen: 1\n", + "\"\"\"\n", + "\n", + "# Save the content to a file named config.yaml\n", + "with open(\"config.yaml\", \"w\") as file:\n", + " file.write(yaml_content)" + ], + "metadata": { + "id": "BZk6_2092xrD" + }, + "execution_count": 3, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Step 2: Train the model\n", + "\n", + "Then we can initialize the model runner using the configuration file we just created. The model runner will handle the training process based on the specified parameters in our configuration.\n", + "\n", + "We'll use the `Config` class to build the configuration from the YAML file and specify an experiment ID (key of the `model_config`). Then, we'll create a `Runner` instance from this configuration and run the training process." + ], + "metadata": { + "id": "tWzAAQrk25zZ" + } + }, + { + "cell_type": "code", + "source": [ + "from easy_tpp.config_factory import Config\n", + "from easy_tpp.runner import Runner\n", + "\n", + "config = Config.build_from_yaml_file('./config.yaml', experiment_id='NHP_train')\n", + "\n", + "model_runner = Runner.build_from_config(config)\n", + "\n", + "model_runner.run()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "mkwjV7hL21Q6", + "outputId": "b5002b2f-ccf3-4267-c009-2848d679d311" + }, + "execution_count": 4, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u001B[31;1m2025-04-06 05:42:18,953 - config.py[pid:60211;line:34:build_from_yaml_file] - CRITICAL: Load pipeline config class RunnerConfig\u001B[0m\n", + "\u001B[31;1m2025-04-06 05:42:18,957 - runner_config.py[pid:60211;line:151:update_config] - CRITICAL: train model NHP using CPU with torch backend\u001B[0m\n", + "\u001B[38;20m2025-04-06 05:42:18,970 - runner_config.py[pid:60211;line:35:__init__] - INFO: Save the config to ./checkpoints/60211_133800770625536_250406-054218/NHP_train_output.yaml\u001B[0m\n", + "\u001B[38;20m2025-04-06 05:42:18,972 - base_runner.py[pid:60211;line:176:save_log] - INFO: Save the log to ./checkpoints/60211_133800770625536_250406-054218/log\u001B[0m\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.11/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning: \n", + "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", + "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", + "You will be able to reuse this secret in all of your notebooks.\n", + "Please note that authentication is recommended but still optional to access public models or datasets.\n", + " warnings.warn(\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "0.2244252199397379 0.29228809611195583\n", + "min_dt: 0.000277777777777\n", + "max_dt: 5.721388888888889\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "WARNING:tensorflow:From /usr/local/lib/python3.11/dist-packages/tensorflow/python/compat/v2_compat.py:98: disable_resource_variables (from tensorflow.python.ops.resource_variables_toggle) is deprecated and will be removed in a future version.\n", + "Instructions for updating:\n", + "non-resource variables are not supported in the long term\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u001B[38;20m2025-04-06 05:42:34,617 - tpp_runner.py[pid:60211;line:60:_init_model] - INFO: Num of model parameters 15252\u001B[0m\n", + "\u001B[38;20m2025-04-06 05:42:36,901 - base_runner.py[pid:60211;line:98:train] - INFO: Data 'taxi' loaded...\u001B[0m\n", + "\u001B[38;20m2025-04-06 05:42:36,902 - base_runner.py[pid:60211;line:103:train] - INFO: Start NHP training...\u001B[0m\n", + "\u001B[38;20m2025-04-06 05:42:39,574 - tpp_runner.py[pid:60211;line:96:_train_model] - INFO: [ Epoch 0 (train) ]: train loglike is -1.755373397054743, num_events is 50454\u001B[0m\n", + "\u001B[38;20m2025-04-06 05:42:45,689 - tpp_runner.py[pid:60211;line:107:_train_model] - INFO: [ Epoch 0 (valid) ]: valid loglike is -1.6691416010202664, num_events is 7204, acc is 0.4415602443087174, rmse is 0.33315836060539783\u001B[0m\n", + "\u001B[38;20m2025-04-06 05:42:54,655 - tpp_runner.py[pid:60211;line:122:_train_model] - INFO: [ Epoch 0 (test) ]: test loglike is -1.6577474861303745, num_events is 14420, acc is 0.4467406380027739, rmse is 0.34015134195006963\u001B[0m\n", + "\u001B[31;1m2025-04-06 05:42:54,657 - tpp_runner.py[pid:60211;line:124:_train_model] - CRITICAL: current best loglike on valid set is -1.6691 (updated at epoch-0), best updated at this epoch\u001B[0m\n", + "\u001B[38;20m2025-04-06 05:42:57,202 - tpp_runner.py[pid:60211;line:96:_train_model] - INFO: [ Epoch 1 (train) ]: train loglike is -1.6284447567648255, num_events is 50454\u001B[0m\n", + "\u001B[38;20m2025-04-06 05:43:01,056 - tpp_runner.py[pid:60211;line:107:_train_model] - INFO: [ Epoch 1 (valid) ]: valid loglike is -1.5259201159945863, num_events is 7204, acc is 0.4582176568573015, rmse is 0.3376860494138715\u001B[0m\n", + "\u001B[38;20m2025-04-06 05:43:10,024 - tpp_runner.py[pid:60211;line:122:_train_model] - INFO: [ Epoch 1 (test) ]: test loglike is -1.5121817029299585, num_events is 14420, acc is 0.4597087378640777, rmse is 0.34172900829909414\u001B[0m\n", + "\u001B[31;1m2025-04-06 05:43:10,025 - tpp_runner.py[pid:60211;line:124:_train_model] - CRITICAL: current best loglike on valid set is -1.5259 (updated at epoch-1), best updated at this epoch\u001B[0m\n", + "\u001B[38;20m2025-04-06 05:43:10,028 - base_runner.py[pid:60211;line:110:train] - INFO: End NHP train! Cost time: 0.552m\u001B[0m\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Step 2: Evaluate a model\n", + "\n", + "After training completes, we can evaluate the model's performance on the test dataset. EasyTPP offers streamlined methods for loading trained models and conducting evaluations.\n", + "\n", + "First, we need to locate the saved model directory within the `checkpoints` folder specified in our configuration. For example, the model might be stored at `./checkpoints/60211_133800770625536_250406-054218/models/saved_model`.\n", + "\n", + "Next, we'll create a new configuration file specifically for evaluation. This evaluation process performs one-step prediction on the validation set defined in the `valid_dir` parameter.\n", + "\n", + "For demonstration purposes, we'll use the same dataset for our evaluation as we used during training.\n", + "\n", + "We write the YAML config for evaluation as below. This configuration specifies:\n", + "- The data sources for training, validation, and testing (only validation will be used, the other two are optional).\n", + "- The evaluation stage and backend framework\n", + "- Model parameters and batch size\n", + "- The path to our pretrained model from the previous training step\n", + "\n", + "Note: Make sure the `thinning_config` is included in the configuration. Also, ensure that 'rmse' and 'acc' are specified in the `metrics` list under `trainer_config` to properly compute and return these evaluation metrics. The error in the previous run was due to missing 'rmse' in the metrics configuration." + ], + "metadata": { + "id": "66aUobRvJktc" + } + }, + { + "cell_type": "code", + "source": [ + "eval_yaml =\"\"\"\n", + "pipeline_config_id: runner_config\n", + "\n", + "data:\n", + " taxi:\n", + " data_format: json\n", + " train_dir: easytpp/taxi # ./data/taxi/train.json\n", + " valid_dir: easytpp/taxi # ./data/taxi/dev.json\n", + " test_dir: easytpp/taxi # ./data/taxi/test.json\n", + " data_specs:\n", + " num_event_types: 10\n", + " pad_token_id: 10\n", + " padding_side: right\n", + "\n", + "NHP_eval:\n", + " base_config:\n", + " stage: eval\n", + " backend: torch\n", + " dataset_id: taxi\n", + " runner_id: std_tpp\n", + " base_dir: './checkpoints/'\n", + " model_id: NHP\n", + " trainer_config:\n", + " batch_size: 256\n", + " max_epoch: 1\n", + " metrics: [ 'acc', 'rmse' ]\n", + " model_config:\n", + " hidden_size: 32\n", + " use_ln: False\n", + " seed: 2019\n", + " gpu: -1\n", + " pretrained_model_dir: ./checkpoints/60211_133800770625536_250406-054218/models/saved_model\n", + " thinning:\n", + " num_seq: 10\n", + " num_sample: 1\n", + " num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm\n", + " look_ahead_time: 10\n", + " patience_counter: 5 # the maximum iteration used in adaptive thinning\n", + " over_sample_rate: 5\n", + " num_samples_boundary: 5\n", + " dtime_max: 5\n", + "\"\"\"\n", + "\n", + "# Save the content to a file named config.yaml\n", + "with open(\"eval_config.yaml\", \"w\") as file:\n", + " file.write(eval_yaml)" + ], + "metadata": { + "id": "8ADVVKg1JlG1" + }, + "execution_count": 6, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "We use the following script to evaluate the trained model. This will load the pretrained model from the specified directory and evaluate it on the test dataset\n", + "" + ], + "metadata": { + "id": "32WAmhBrClUD" + } + }, + { + "cell_type": "code", + "source": [ + "from easy_tpp.config_factory import Config\n", + "from easy_tpp.runner import Runner\n", + "\n", + "config = Config.build_from_yaml_file('./eval_config.yaml', experiment_id='NHP_eval')\n", + "\n", + "model_runner = Runner.build_from_config(config)\n", + "\n", + "model_runner.run()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "bwPC6DI9J2LT", + "outputId": "08030865-15eb-49f8-c2a7-f53363cac43d" + }, + "execution_count": 7, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u001B[31;1m2025-04-06 05:43:28,529 - config.py[pid:60211;line:34:build_from_yaml_file] - CRITICAL: Load pipeline config class RunnerConfig\u001B[0m\n", + "\u001B[31;1m2025-04-06 05:43:28,532 - runner_config.py[pid:60211;line:151:update_config] - CRITICAL: validate model NHP using CPU with torch backend\u001B[0m\n", + "\u001B[38;20m2025-04-06 05:43:28,543 - runner_config.py[pid:60211;line:35:__init__] - INFO: Save the config to ./checkpoints/60211_133800770625536_250406-054328/NHP_eval_output.yaml\u001B[0m\n", + "\u001B[38;20m2025-04-06 05:43:28,545 - base_runner.py[pid:60211;line:176:save_log] - INFO: Save the log to ./checkpoints/60211_133800770625536_250406-054328/log\u001B[0m\n", + "0.2244252199397379 0.29228809611195583\n", + "min_dt: 0.000277777777777\n", + "max_dt: 5.721388888888889\n", + "\u001B[38;20m2025-04-06 05:43:29,769 - tpp_runner.py[pid:60211;line:60:_init_model] - INFO: Num of model parameters 15252\u001B[0m\n", + "\u001B[31;1m2025-04-06 05:43:29,773 - tpp_runner.py[pid:60211;line:81:_load_model] - CRITICAL: Load model from ./checkpoints/60211_133800770625536_250406-054218/models/saved_model\u001B[0m\n", + "\u001B[38;20m2025-04-06 05:43:30,305 - base_runner.py[pid:60211;line:117:evaluate] - INFO: Data 'taxi' loaded...\u001B[0m\n", + "\u001B[38;20m2025-04-06 05:43:30,306 - base_runner.py[pid:60211;line:122:evaluate] - INFO: Start NHP evaluation...\u001B[0m\n", + "\u001B[31;1m2025-04-06 05:43:35,132 - tpp_runner.py[pid:60211;line:148:_evaluate_model] - CRITICAL: Evaluation result: loglike is -1.5259201159945863, num_events is 7204, acc is 0.4583564686285397, rmse is 0.3325426308842327\u001B[0m\n", + "\u001B[38;20m2025-04-06 05:43:35,133 - base_runner.py[pid:60211;line:128:evaluate] - INFO: End NHP evaluation! Cost time: 0.080m\u001B[0m\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "np.float64(0.3325426308842327)" + ] + }, + "metadata": {}, + "execution_count": 7 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Step 3: Generate predictions\n", + "\n", + "After training and evaluation, we can generate predictions for new events given a sequence as `context`.\n", + "\n", + "The generation process uses the trained model to perform the one-step (by default) or multi-step prediction.\n", + "\n", + "In order to evaluate the prediction accuracy, we automatically mask the last-n events (where n depends on the prediction step count) and use them as ground truth (golden events) to compare against our model's predictions.\n", + "\n", + "We'll use the same configuration file we used for evaluation, but we'll specify a different experiment ID for the prediction task.\n", + "\n" + ], + "metadata": { + "id": "ewCrMYQ0HxLU" + } + }, + { + "cell_type": "code", + "source": [ + "gen_yaml =\"\"\"\n", + "pipeline_config_id: runner_config\n", + "\n", + "data:\n", + " taxi:\n", + " data_format: json\n", + " train_dir: easytpp/taxi # ./data/taxi/train.json\n", + " valid_dir: easytpp/taxi # ./data/taxi/dev.json\n", + " test_dir: easytpp/taxi # ./data/taxi/test.json\n", + " data_specs:\n", + " num_event_types: 10\n", + " pad_token_id: 10\n", + " padding_side: right\n", + "\n", + "NHP_gen:\n", + " base_config:\n", + " stage: gen\n", + " backend: torch\n", + " dataset_id: taxi\n", + " runner_id: std_tpp\n", + " base_dir: './checkpoints/'\n", + " model_id: NHP\n", + " trainer_config:\n", + " batch_size: 256\n", + " max_epoch: 1\n", + " metrics: [ 'acc', 'rmse' ]\n", + " model_config:\n", + " hidden_size: 32\n", + " use_ln: False\n", + " seed: 2019\n", + " gpu: -1\n", + " pretrained_model_dir: ./checkpoints/60211_133800770625536_250406-054218/models/saved_model\n", + " thinning:\n", + " num_seq: 10\n", + " num_sample: 1\n", + " num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm\n", + " look_ahead_time: 10\n", + " patience_counter: 5 # the maximum iteration used in adaptive thinning\n", + " over_sample_rate: 5\n", + " num_samples_boundary: 5\n", + " dtime_max: 5\n", + " num_step_gen: 1\n", + "\"\"\"\n", + "\n", + "# Save the content to a file named config.yaml\n", + "with open(\"gen_config.yaml\", \"w\") as file:\n", + " file.write(gen_yaml)" + ], + "metadata": { + "id": "FIGDkivPLO3N" + }, + "execution_count": 8, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "We continue to use a similar configuration file to launch the generation process. Note that it is necessary to explicitly specify `num_gen_step` in the thinning config. The results will be saved in a pickle file.\n", + "\n", + "We acknowledge that the generation pipeline is still a work in progress. For instance, the multi-step sampling process can be further improved (ongoing work), and the output format could be enhanced to facilitate easier evaluation." + ], + "metadata": { + "id": "ymXwak6WoOGB" + } + }, + { + "cell_type": "code", + "source": [ + "from easy_tpp.config_factory import Config\n", + "from easy_tpp.runner import Runner\n", + "\n", + "config = Config.build_from_yaml_file('./gen_config.yaml', experiment_id='NHP_gen')\n", + "\n", + "model_runner = Runner.build_from_config(config)\n", + "\n", + "model_runner.run()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "kgJVjrh9HisJ", + "outputId": "f62fe3bc-8484-49f5-f814-166971d2a27a" + }, + "execution_count": 9, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u001B[31;1m2025-04-06 05:43:48,495 - config.py[pid:60211;line:34:build_from_yaml_file] - CRITICAL: Load pipeline config class RunnerConfig\u001B[0m\n", + "\u001B[31;1m2025-04-06 05:43:48,498 - runner_config.py[pid:60211;line:151:update_config] - CRITICAL: predict model NHP using CPU with torch backend\u001B[0m\n", + "\u001B[38;20m2025-04-06 05:43:48,508 - runner_config.py[pid:60211;line:35:__init__] - INFO: Save the config to ./checkpoints/60211_133800770625536_250406-054348/NHP_gen_output.yaml\u001B[0m\n", + "\u001B[38;20m2025-04-06 05:43:48,510 - base_runner.py[pid:60211;line:176:save_log] - INFO: Save the log to ./checkpoints/60211_133800770625536_250406-054348/log\u001B[0m\n", + "0.2244252199397379 0.29228809611195583\n", + "min_dt: 0.000277777777777\n", + "max_dt: 5.721388888888889\n", + "\u001B[38;20m2025-04-06 05:43:49,245 - tpp_runner.py[pid:60211;line:60:_init_model] - INFO: Num of model parameters 15252\u001B[0m\n", + "\u001B[31;1m2025-04-06 05:43:49,249 - tpp_runner.py[pid:60211;line:81:_load_model] - CRITICAL: Load model from ./checkpoints/60211_133800770625536_250406-054218/models/saved_model\u001B[0m\n", + "\u001B[38;20m2025-04-06 05:43:49,777 - base_runner.py[pid:60211;line:135:gen] - INFO: Data 'taxi' loaded...\u001B[0m\n", + "\u001B[38;20m2025-04-06 05:43:49,779 - base_runner.py[pid:60211;line:140:gen] - INFO: Start NHP evaluation...\u001B[0m\n", + "\u001B[31;1m2025-04-06 05:43:50,159 - tpp_runner.py[pid:60211;line:162:_gen_model] - CRITICAL: Save the prediction to pickle file pred.pkl\u001B[0m\n", + "\u001B[38;20m2025-04-06 05:43:50,161 - base_runner.py[pid:60211;line:146:gen] - INFO: End NHP generation! Cost time: 0.006m\u001B[0m\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [], + "metadata": { + "id": "pTreVIxWof_R" + } + } + ] +} diff --git a/notebooks/s2p2_preprocess_ehrshot_cpt4.ipynb b/notebooks/s2p2_preprocess_ehrshot_cpt4.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..1886d0939ab0c8c5a20cc397cd73af430e151815 --- /dev/null +++ b/notebooks/s2p2_preprocess_ehrshot_cpt4.ipynb @@ -0,0 +1,410 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "f0c61845", + "metadata": {}, + "source": [ + "# EHRSHOT Dataset Preprocessing (S2P2 Paper)\n", + "This notebook includes code for preparing the EHRSHOT event sequence dataset from the raw [EHRSHOT dataset](https://som-shahlab.github.io/ehrshot-website/), where medical services and procedures are treated as marks, as identified by _Current Procedural Terminology_ (CPT-4) codes.\n", + "\n", + "This version of dataset was originally used in evaluating the [State-Space Point Process (S2P2)](https://openreview.net/pdf?id=74SvE2GZwW) model. Note that we cannot distribute the raw data (or derivative dataset) under the terms of the original EHRSHOT dataset. The access to data can be applied [here](https://stanford.redivis.com/datasets/53gc-8rhx41kgt)." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "e3fb9cd3", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "from collections import defaultdict\n", + "import heapq\n", + "from tqdm import tqdm\n", + "from easy_tpp.utils import set_seed\n", + "import random\n", + "import json" + ] + }, + { + "cell_type": "markdown", + "id": "83a40695ae9eeed7", + "metadata": {}, + "source": [ + "### 0. Load data and check if it's complete" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "d2d2d13a79f2268c", + "metadata": { + "ExecuteTime": { + "end_time": "2025-05-22T21:52:03.248612Z", + "start_time": "2025-05-22T21:51:13.046926Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/fz/4gzrrkvs2_77xs43jry5yp9w0000gn/T/ipykernel_26867/1025842674.py:3: DtypeWarning: Columns (3,5) have mixed types. Specify dtype option on import or set low_memory=False.\n", + " df_dataset = pd.read_csv(path_to_data_csv)\n" + ] + } + ], + "source": [ + "path_to_data_csv = '../data/EHRSHOT/EHRSHOT_ASSETS/data/ehrshot.csv'\n", + "path_to_splits_csv = '../data/EHRSHOT/EHRSHOT_ASSETS/splits/person_id_map.csv'\n", + "df_dataset = pd.read_csv(path_to_data_csv)\n", + "df_split = pd.read_csv(path_to_splits_csv)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "18ef70b0ec958120", + "metadata": { + "ExecuteTime": { + "end_time": "2025-05-22T21:52:04.952046Z", + "start_time": "2025-05-22T21:52:04.031668Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "# of events: 41661637\n", + "# of patients: 6739\n", + "# of visits: 921499\n", + "# of train patients 2295\n", + "# of val patients 2232\n", + "# of test patients 2212\n" + ] + } + ], + "source": [ + "# check if the same data as the original repo: https://github.com/som-shahlab/ehrshot-benchmark/blob/main/ehrshot/stats.ipynb\n", + "print(\"# of events:\", df_dataset.shape[0])\n", + "print(\"# of patients:\", df_dataset['patient_id'].nunique())\n", + "print(\"# of visits:\", df_dataset['visit_id'].nunique())\n", + "print(\"# of train patients\", df_split[df_split['split'] == 'train']['omop_person_id'].nunique())\n", + "print(\"# of val patients\", df_split[df_split['split'] == 'val']['omop_person_id'].nunique())\n", + "print(\"# of test patients\", df_split[df_split['split'] == 'test']['omop_person_id'].nunique())" + ] + }, + { + "cell_type": "markdown", + "id": "bd3aa280bbe0d3e", + "metadata": {}, + "source": [ + "### 1. Get event times for visit occurrence" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "55e9a82ed0910e84", + "metadata": { + "ExecuteTime": { + "end_time": "2025-05-22T21:55:43.462174Z", + "start_time": "2025-05-22T21:55:40.272675Z" + } + }, + "outputs": [], + "source": [ + "df_visit = df_dataset[df_dataset['omop_table'] == 'visit_occurrence']\n", + "df_visit.loc[:, 'start'] = pd.to_datetime(df_visit['start']).apply(lambda x: int(round(x.timestamp())))\n", + "df_visit_time = df_visit[['patient_id', 'start']].drop_duplicates(keep=False)\n", + "df_visit_time = df_visit_time.groupby(['patient_id'])['start'].apply(lambda x: sorted(list(set(x)))).reset_index(name='timestamp')\n", + "visit_dict = pd.Series(df_visit_time.timestamp.values, index=df_visit_time.patient_id).to_dict()\n", + "patient_visit = df_visit_time['patient_id'].to_numpy()" + ] + }, + { + "cell_type": "markdown", + "id": "7eddf9ac67745f1a", + "metadata": {}, + "source": [ + "### 2. Get CPT4 codes that have at least 100 frequencies" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "23d14c81fec5f549", + "metadata": { + "ExecuteTime": { + "end_time": "2025-05-22T21:59:11.017201Z", + "start_time": "2025-05-22T21:59:00.594692Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of marks after filtering: 668\n" + ] + } + ], + "source": [ + "df_cpt4 = df_dataset[df_dataset['code'].str.contains('CPT4', case=False, na=False)]\n", + "mark_val, mark_count = np.unique(df_cpt4.loc[:,'code'].to_numpy(), return_counts=True)\n", + "\n", + "mark_mask = (mark_count >= 100)\n", + "print(f'Number of marks after filtering: {sum(mark_mask)}')" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "370a230ac0d5213", + "metadata": { + "ExecuteTime": { + "end_time": "2025-05-22T21:59:12.675372Z", + "start_time": "2025-05-22T21:59:11.025996Z" + } + }, + "outputs": [], + "source": [ + "mark_val = mark_val[mark_mask]\n", + "mark_val_set = set(mark_val)\n", + "df_cpt4_subset = df_cpt4[df_cpt4['code'].isin(mark_val_set)][['patient_id', 'start', 'code']]\n", + "df_cpt4_subset['start'] = pd.to_datetime(df_cpt4_subset.loc[:,'start']).apply(lambda x: int(round(x.timestamp())))\n", + "df_cpt4_subset['code'] = df_cpt4_subset['code'].astype('category').cat.codes\n", + "mark_val_subset, mark_count_subset = np.unique(df_cpt4_subset.loc[:,'code'].to_numpy(), return_counts=True)\n", + "mark_count_dict = dict(zip(mark_val_subset, mark_count_subset))\n", + "patient_cpt4 = df_cpt4_subset['patient_id'].unique()" + ] + }, + { + "cell_type": "markdown", + "id": "10020d5895b3afc9", + "metadata": {}, + "source": [ + "### 3. Generate event sequences" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "a744ef372bd6dd72", + "metadata": { + "ExecuteTime": { + "end_time": "2025-05-22T21:59:37.489196Z", + "start_time": "2025-05-22T21:59:37.485658Z" + } + }, + "outputs": [], + "source": [ + "def sample_event_times(real_event_time, std, size):\n", + " sampled_times = np.random.normal(real_event_time, scale=std, size=size)\n", + " # resample if not all non-negative, might be updated\n", + " while not np.all(sampled_times > 0):\n", + " sampled_times = np.random.normal(real_event_time, scale=std, size=size)\n", + " return sampled_times" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "5ff4de761c2847c6", + "metadata": { + "ExecuteTime": { + "end_time": "2025-05-22T22:01:05.315820Z", + "start_time": "2025-05-22T22:00:56.872143Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 6634/6634 [00:08<00:00, 786.48it/s] " + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6183\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "time_norm = 60 * 60 # in seconds\n", + "min_events = 5\n", + "max_marks_per_time = 10\n", + "padding_events = 668\n", + "all_sequences = []\n", + "idx = 0\n", + "set_seed(123)\n", + "\n", + "\n", + "for patient in tqdm(patient_cpt4):\n", + " patient = int(patient)\n", + " data = df_cpt4_subset[df_cpt4_subset.patient_id == patient]\n", + " if len(data) < 5 or len(data.start.unique()) < 2:\n", + " continue\n", + " events = list(zip(data['start'], data['code']))\n", + " sorted_unique_times = sorted(data.start.unique())\n", + " if not len(np.diff(sorted_unique_times)):\n", + " print(len(data))\n", + " print(len(events))\n", + " min_diff = min(np.diff(sorted_unique_times)) # minimum time between two consecutive events\n", + "\n", + " event_dict = defaultdict(list)\n", + " base_time = int(sorted_unique_times[0])\n", + " for t, m in events:\n", + " event_dict[(t - base_time)/time_norm].append(m)\n", + "\n", + " std = min(min_diff/time_norm, 1) / 10 # std. for Normal distribution to jitter event times\n", + " event_times = []\n", + " event_marks = []\n", + " for t in sorted_unique_times:\n", + " t = (t - base_time)/time_norm\n", + " v = event_dict[t]\n", + " if len(v) > max_marks_per_time: # choose mark by frequency\n", + " h = []\n", + " for mark in v:\n", + " if len(h) < max_marks_per_time:\n", + " heapq.heappush(h, (mark_count_dict[mark], mark))\n", + " else:\n", + " heapq.heappushpop(h, (mark_count_dict[mark], mark))\n", + " v = [x[1] for x in h]\n", + " else:\n", + " np.random.shuffle(v)\n", + "\n", + " sampled_times = sample_event_times(t, std, min(max_marks_per_time, len(v)) - 1)\n", + " times = sorted([t] + list(sampled_times))\n", + " times = [float(t) for t in times]\n", + " event_times.extend(times)\n", + " event_marks.extend(v)\n", + " assert len(v) <= max_marks_per_time\n", + " assert(len(event_times) == len(event_marks))\n", + " assert(min_events <= len(event_times))\n", + "\n", + " # padding the start and end of sequences to have padding events\n", + " event_marks[0] = padding_events\n", + " event_marks[-1] = padding_events\n", + "\n", + " all_sequences.append(\n", + " {\n", + " 'dim_process': padding_events,\n", + " 'seq_idx': idx,\n", + " 'seq_len': len(event_times),\n", + " 'time_since_start': event_times,\n", + " 'time_since_last_event': [0] + [event_times[i+1] - event_times[i] for i in range(len(event_times) - 1)],\n", + " 'type_event': event_marks,\n", + " }\n", + " )\n", + " idx += 1\n", + "print(len(all_sequences)) # 6183" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "61b6d934ea56003b", + "metadata": { + "ExecuteTime": { + "end_time": "2025-05-22T22:01:15.774095Z", + "start_time": "2025-05-22T22:01:15.766903Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "test: 927\n", + "valid: 927\n", + "train: 4329\n" + ] + } + ], + "source": [ + "test_pct, valid_pct, train_pct = 0.15, 0.15, 0.7\n", + "test_seqs, valid_seqs, train_seqs = [], [], []\n", + "\n", + "random.shuffle(all_sequences)\n", + "for i, seq in enumerate(all_sequences):\n", + " progress = (i + 1) / len(all_sequences)\n", + " if progress <= test_pct:\n", + " test_seqs.append(seq)\n", + " elif progress <= test_pct + valid_pct:\n", + " valid_seqs.append(seq)\n", + " else:\n", + " train_seqs.append(seq)\n", + "\n", + "print(f'test: {len(test_seqs)}')\n", + "print(f'valid: {len(valid_seqs)}')\n", + "print(f'train: {len(train_seqs)}')" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "f9c1d7894a5792bb", + "metadata": { + "ExecuteTime": { + "end_time": "2025-05-22T22:02:34.872771Z", + "start_time": "2025-05-22T22:02:32.829332Z" + } + }, + "outputs": [], + "source": [ + "# # Save results\n", + "# with open('./ehrshot_cpt4/train.json', 'w') as f:\n", + "# json.dump(train_seqs, f)\n", + "#\n", + "# with open('./ehrshot_cpt4/dev.json', 'w') as f:\n", + "# json.dump(valid_seqs, f)\n", + "#\n", + "# with open('./ehrshot_cpt4/test.json', 'w') as f:\n", + "# json.dump(test_seqs, f)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "60c3455349b9fb3d", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "easytpp", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/requirements-doc.txt b/requirements-doc.txt new file mode 100644 index 0000000000000000000000000000000000000000..1c9831bc16e6730a366dac10b8f6f33ff9389618 --- /dev/null +++ b/requirements-doc.txt @@ -0,0 +1,9 @@ +sphinx +sphinx_rtd_theme +myst_parser +nbsphinx +nbsphinx_link +sphinx_fontawesome +sphinx-autobuild +recommonmark +sphinx_markdown_tables \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..25978a9290260c6a52d90c55ccb837aed1d0004a --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +PyYAML>=5.1 +numpy +pandas +torch +tensorboard +packaging +datasets +omegaconf diff --git a/requirements_compute_metrics.txt b/requirements_compute_metrics.txt new file mode 100644 index 0000000000000000000000000000000000000000..7af93c87702c5209e2a16bcc4c87ad4d5826cc76 --- /dev/null +++ b/requirements_compute_metrics.txt @@ -0,0 +1,15 @@ +# 计算级联指标所需的依赖包 + +# 深度学习框架 +torch>=1.9.0 + +# Transformers库(用于BERT、情感分析、语言模型) +transformers>=4.20.0 + +# 数值计算 +numpy>=1.21.0 + +# 进度条 +tqdm>=4.62.0 + +# JSON处理(Python标准库,无需安装) diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000000000000000000000000000000000000..224a77957f5db48dfa25c8bb4a35f535202da203 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,2 @@ +[metadata] +description-file = README.md \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..7439d09cac431c38fc62becd3156ce1c0da1a683 --- /dev/null +++ b/setup.py @@ -0,0 +1,61 @@ +import codecs +import os +import re +from setuptools import find_packages +from setuptools import setup + + +def readme(): + with codecs.open('README.md', encoding='utf-8') as f: + content = f.read() + return content + + +def get_version(): + version_file = os.path.join(os.path.dirname(__file__), "version.py") + version_regex = r"__version__ = ['\"]([^'\"]*)['\"]" + with open(version_file, "r") as f: + version = re.search(version_regex, f.read(), re.M).group(1) + return version + + +def parse_requirements(fname='requirements.txt'): + """Parse the package dependencies listed in a requirements file.""" + + def parse_line(line): + """Parse information from a line in a requirements text file.""" + if line.startswith('-r '): + # Allow specifying requirements in other files + target = line.split(' ')[1] + for line in parse_require_file(target): + yield line + else: + yield line + + def parse_require_file(fpath): + with codecs.open(fpath, 'r') as f: + for line in f.readlines(): + line = line.strip() + if line and not line.startswith('#'): + for ll in parse_line(line): + yield ll + + packages = list(parse_require_file(fname)) + return packages + + +setup( + name='easy_tpp', + version=get_version(), + description='An easy and flexible tool for neural temporal point process', + url = 'https://github.com/ant-research/EasyTemporalPointProcess/', + # long_description = 'Our EasyTPP makes several unique contributions to this area: a unified interface of using existing datasets and adding new datasets; a wide range of evaluation programs that are easy to use and extend as well as facilitate reproducible research; implementations of popular neural TPPs, together with a rich library of modules by composing which one could quickly build complex models. ', + # long_description=open('README.md').read(), + # long_description_content_type='text/markdown', + author='Alipay', + packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]), + include_package_data=True, + classifiers=[ + 'Programming Language :: Python :: 3' + ], + install_requires=parse_requirements('requirements.txt')) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/synthetic_data.json b/tests/synthetic_data.json new file mode 100644 index 0000000000000000000000000000000000000000..cf243b0fd26bbe1a58335bc15301f3382e6084ed --- /dev/null +++ b/tests/synthetic_data.json @@ -0,0 +1,984 @@ +[ + { + "dim_process": 3, + "seq_idx": 0, + "seq_len": 40, + "time_since_start": [ + 0.0, + 0.07334207193560538, + 0.14867635742857066, + 0.1726252021389103, + 0.25901494747659404, + 0.47459203000652384, + 0.6681905106050308, + 0.9983091339994778, + 1.0812915957716949, + 1.1675091987931212, + 2.161986940110504, + 2.345042474004808, + 2.772762086199987, + 3.2001245612506404, + 3.518062428718216, + 3.74812753207482, + 4.841234092227733, + 5.071240817272003, + 5.1528419872758, + 6.474922871578715, + 6.487923821154683, + 6.934037727910714, + 7.143573377945178, + 7.24373338614723, + 8.434572536866543, + 9.261220466026185, + 9.337874271297713, + 9.665559785402946, + 9.92721401517502, + 10.15462043817297, + 10.179297078155905, + 10.901075779494658, + 11.719869024869833, + 12.897963679523235, + 13.51229749935315, + 13.63295973707896, + 13.754497835991206, + 14.064371388582165, + 14.137326170860174, + 14.282006093027979 + ], + "time_since_last_event": [ + 0.21483377973696283, + 0.07334207193560538, + 0.3635101371655335, + 0.023948844710339645, + 0.47384872721355686, + 0.30196682786761353, + 0.4091755631284367, + 0.523717103992954, + 1.0079495238360896, + 0.49931868818809044, + 1.1636778061110262, + 1.1775332752116867, + 0.42771961219517873, + 0.42736247505065394, + 1.3560754886077122, + 0.23006510335660346, + 1.6411095309770922, + 0.23000672504427033, + 0.08160117000379685, + 1.322080884302915, + 2.7397962890798633, + 5.852746132139019, + 0.668650506366463, + 0.10016000820205218, + 1.5005348089558304, + 2.017487079878955, + 0.07665380527152799, + 3.177635964248264, + 1.4926414783084763, + 0.4890606527700232, + 0.25208306298088523, + 0.7464553413216883, + 0.8187932453751756, + 2.71866660136733, + 0.6143338198299144, + 4.295085465781247, + 0.12153809891224654, + 0.3098735525909593, + 0.07295478227800878, + 2.5621370681581457 + ], + "type_event": [ + 2, + 2, + 0, + 0, + 1, + 0, + 1, + 0, + 2, + 1, + 0, + 1, + 1, + 1, + 0, + 0, + 1, + 1, + 1, + 1, + 0, + 2, + 1, + 1, + 2, + 1, + 1, + 0, + 2, + 0, + 2, + 0, + 0, + 2, + 2, + 1, + 1, + 1, + 1, + 0 + ] + }, + { + "dim_process": 3, + "seq_idx": 1, + "seq_len": 40, + "time_since_start": [ + 0.0, + 0.028181241924562173, + 0.5438498619372929, + 0.5569154982531774, + 0.953577493500255, + 1.1895510505354885, + 1.35639094666106, + 1.6092436402993702, + 1.958192751836167, + 1.980686087510831, + 2.889246627279709, + 3.0616695980970317, + 3.2227777644702744, + 3.476072199928547, + 3.518627533754609, + 3.7535732893134632, + 3.8559502708921674, + 3.9753134964929924, + 4.673629329089364, + 4.968085817816862, + 4.988967859806754, + 5.3650310274164905, + 6.151042050344506, + 6.264091463216575, + 6.515758199672076, + 6.73935587212671, + 6.7443365890479345, + 7.812801643375826, + 8.198863245563638, + 8.540383641658318, + 9.193627341970883, + 9.323682844729449, + 9.328530517286238, + 9.65555621469707, + 9.800406242796422, + 9.831660603326418, + 11.757155356034657, + 12.272708667350456, + 12.294194069669851, + 13.588682400715355 + ], + "time_since_last_event": [ + 1.0133819387017393, + 0.4165345091192769, + 0.5156686200127307, + 0.5569154982531774, + 1.1972508385271645, + 0.23597355703523348, + 0.7994754484078825, + 0.2528526936383102, + 0.34894911153679686, + 0.7911350369753425, + 2.345396765342416, + 0.1724229708173226, + 1.2420916769594434, + 1.5178794480923798, + 0.45695793565757725, + 0.5307955248431888, + 0.33732273713755845, + 0.4992412965644455, + 0.6983158325963714, + 0.29445648872749786, + 0.02088204198989274, + 0.3760631676097361, + 2.3974687610310426, + 0.8990604358000844, + 0.2516667364555012, + 0.5883138217822044, + 0.004980716921224371, + 1.2970434437037497, + 4.342912974671471, + 0.34152039609467977, + 1.380825698595057, + 2.5793462556815143, + 0.7881468756279197, + 0.3270256974108321, + 0.6067789008255389, + 0.17610438862934785, + 1.9254947527082393, + 0.5155533113157986, + 2.4937878268734295, + 4.264999555985906 + ], + "type_event": [ + 2, + 1, + 1, + 2, + 0, + 0, + 2, + 2, + 2, + 0, + 1, + 1, + 0, + 2, + 1, + 0, + 1, + 2, + 2, + 2, + 2, + 2, + 0, + 2, + 2, + 0, + 0, + 2, + 1, + 1, + 2, + 0, + 1, + 1, + 2, + 1, + 1, + 1, + 2, + 0 + ] + }, + { + "dim_process": 3, + "seq_idx": 2, + "seq_len": 40, + "time_since_start": [ + 0.0, + 0.8626648050686825, + 0.8680882740053697, + 1.05109163932989, + 1.2276095523526713, + 1.2489564646835873, + 1.4314374955158407, + 1.4697436092017355, + 1.7414063661812058, + 2.0690025658846203, + 2.0970914134891956, + 2.337709567442122, + 2.5472764721362786, + 2.5788766429322436, + 3.2173591674211615, + 3.4577025709438907, + 4.127082361135013, + 4.7681694887836095, + 4.88009727937165, + 5.2925139530938985, + 5.676594094182082, + 6.314692083311858, + 6.422565102661483, + 6.635636608750982, + 7.329775215567427, + 7.442698542772288, + 7.575824809661523, + 7.731426036543983, + 7.942799063490742, + 8.98730625826526, + 8.991457451699521, + 9.025829600136863, + 9.32689817112168, + 9.436400472649101, + 9.689023321112789, + 10.175170884402366, + 10.195467685952504, + 10.92788619283061, + 11.073509660718738, + 11.254506421753153 + ], + "time_since_last_event": [ + 0.5984554465524745, + 2.777093984986056, + 0.8680882740053697, + 2.9440354169278677, + 0.35952127834730163, + 0.19786482535369743, + 0.5687726904471582, + 0.2421340568490642, + 0.2716627569794703, + 0.32759619970341447, + 0.6656539179733549, + 0.2406181539529264, + 1.2983200074526913, + 0.24116707549012162, + 0.6700826952848828, + 1.3887000050592704, + 0.9097231937138517, + 1.3104669178397188, + 0.7530149182366372, + 2.713637310161655, + 0.9084246053984728, + 0.6380979891297756, + 1.542467823289833, + 0.3209445254391241, + 0.6941386068164448, + 1.0201334401108042, + 0.133126266889235, + 0.15560122688246025, + 0.6130238479233157, + 3.6947923051713616, + 0.004151193434260847, + 1.29440356359288, + 0.30106857098481754, + 0.44494302094958016, + 0.25262284846368743, + 2.2323718209116237, + 0.8685695148308241, + 1.2388628717178207, + 0.8780419747662336, + 1.0793355373507865 + ], + "type_event": [ + 0, + 1, + 0, + 2, + 0, + 2, + 1, + 0, + 0, + 0, + 1, + 1, + 2, + 1, + 2, + 0, + 2, + 0, + 2, + 1, + 0, + 0, + 2, + 0, + 0, + 2, + 2, + 2, + 0, + 1, + 1, + 2, + 2, + 1, + 1, + 0, + 2, + 1, + 2, + 0 + ] + }, + { + "dim_process": 3, + "seq_idx": 3, + "seq_len": 40, + "time_since_start": [ + 0.0, + 0.0830056964788497, + 0.36352744832758077, + 0.47496998614037267, + 0.9352622719863746, + 1.2990351084673577, + 1.3008666638451203, + 1.486446505062382, + 1.4947542677616852, + 2.2066585838066004, + 2.7166919283912776, + 2.809258335521214, + 3.129178005728015, + 3.364228899973199, + 3.3642795912552046, + 3.514973469202154, + 3.847116542588836, + 4.208627659659371, + 4.410776839770513, + 4.694164599573796, + 5.710079046207959, + 6.222272520135903, + 6.320797756527597, + 6.537466510620732, + 6.732998601183482, + 6.876774259114889, + 9.137481643364438, + 9.444531206258219, + 9.711834532816695, + 10.43723153229039, + 10.49278885193226, + 10.636585861871026, + 10.768756382782179, + 11.118291742836007, + 11.935622938190178, + 12.005206066140914, + 12.555326314586907, + 12.573655472439867, + 12.637171410925298, + 13.11589137456081 + ], + "time_since_last_event": [ + 0.2301202051892588, + 0.0830056964788497, + 0.28052175184873107, + 0.8507136592177602, + 0.9843857161412188, + 0.824065122326985, + 0.3656043918587457, + 0.1855798412172618, + 0.19571915929432748, + 1.8431311354790196, + 0.5100333445846772, + 1.314504067759529, + 0.3199196702068008, + 1.8777823949108168, + 5.069128200574369e-05, + 0.38579546347413896, + 1.1304246141975582, + 0.6936541904572167, + 0.5636602971816771, + 1.3298850083185911, + 1.015914446634163, + 0.5121934739279439, + 0.0985252363916942, + 0.2166687540931349, + 2.5243709415241113, + 0.1437756579314069, + 4.726704803593925, + 2.9070646956374873, + 0.26730332655847633, + 0.7253969994736948, + 3.616014592817372, + 0.14379700993876554, + 0.3315248504917889, + 0.349535360053828, + 0.8173311953541713, + 1.368620204269888, + 0.6197033763967283, + 0.01832915785296052, + 0.06351593848543047, + 3.978409731196372 + ], + "type_event": [ + 2, + 2, + 2, + 1, + 0, + 1, + 0, + 0, + 1, + 2, + 2, + 1, + 1, + 0, + 0, + 1, + 2, + 1, + 2, + 0, + 0, + 0, + 0, + 0, + 1, + 1, + 2, + 0, + 0, + 0, + 1, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 2 + ] + }, + { + "dim_process": 3, + "seq_idx": 4, + "seq_len": 40, + "time_since_start": [ + 0.0, + 0.4416247807831368, + 0.48750433084403255, + 0.7020416092936372, + 0.992536327687894, + 1.577621836266509, + 1.8225788436351849, + 3.0315639457271146, + 3.1397235399295127, + 3.1723165325911964, + 3.230405841676678, + 3.4828647213245247, + 3.6990155247822756, + 3.895838687216745, + 4.388293180042794, + 4.5129490815578635, + 4.8821207309319945, + 4.980566122341848, + 5.193309557546044, + 5.804945229813747, + 6.054265530966305, + 6.39113447111437, + 6.409096551563955, + 6.5577082509565585, + 7.149028919174562, + 7.196210315191166, + 8.941353283351432, + 9.210090032643876, + 9.271970353435108, + 10.007142810342941, + 10.379932186383911, + 10.544939925600602, + 10.67153459481024, + 10.977147993846728, + 11.253388072124537, + 11.51166277174984, + 12.75290712514959, + 12.914811290660694, + 13.157590494415402, + 13.176441193910719 + ], + "time_since_last_event": [ + 1.4584237878499877, + 0.4416247807831368, + 0.8352428102741243, + 0.21453727844960468, + 1.8189947707534984, + 1.1359970554833723, + 0.24495700736867576, + 1.2089851020919298, + 0.10815959420239807, + 2.470274923297559, + 0.09068230174716518, + 0.31054818873332835, + 0.21615080345775084, + 0.6654328455400673, + 0.4924544928260488, + 0.12465590151506944, + 3.8895844032441005, + 0.09844539140985376, + 1.4942940327637686, + 1.2919961482558833, + 1.0736994086244565, + 1.1978249135683257, + 0.3548310205976506, + 0.1486116993926032, + 0.7578944480601919, + 1.3912650853774196, + 2.383645032394874, + 2.0610611134693144, + 0.3306170700836759, + 0.7351724569078328, + 1.169842153740035, + 0.5377971152576606, + 3.4753242796190733, + 0.5972158074628169, + 0.5818534773142972, + 0.9667228461492385, + 1.499519053025054, + 1.4031485189108537, + 0.40468336926581117, + 0.2616299032500251 + ], + "type_event": [ + 1, + 1, + 2, + 2, + 0, + 1, + 1, + 1, + 1, + 2, + 1, + 2, + 2, + 1, + 1, + 1, + 0, + 0, + 2, + 1, + 0, + 2, + 0, + 0, + 2, + 1, + 0, + 2, + 0, + 0, + 2, + 0, + 1, + 2, + 1, + 0, + 1, + 0, + 1, + 0 + ] + }, + { + "dim_process": 3, + "seq_idx": 5, + "seq_len": 40, + "time_since_start": [ + 0.0, + 0.28851592414360994, + 0.3872771053751194, + 0.4804979888514396, + 0.49576145909792046, + 0.8412555565300295, + 1.0680214776406558, + 1.6515256583037683, + 1.7372038970994765, + 1.9311621459163888, + 1.9898274085689422, + 2.3588925681228545, + 2.646953270722747, + 2.6668282350050134, + 4.0453592279707635, + 4.197044976109865, + 4.394663470861971, + 5.218155920714338, + 5.311340407103003, + 5.467370844180138, + 5.831070696484645, + 6.27364337895952, + 6.665141954470641, + 6.68620245450127, + 7.0374232809360535, + 7.2486112505714715, + 7.3489895225118715, + 7.453810103012415, + 7.683278054372536, + 8.030016071647097, + 8.051464550628396, + 8.130961216766352, + 8.19664348117027, + 9.342313724511527, + 9.451599619821906, + 9.678103273638172, + 9.791248393136385, + 9.884237006418275, + 10.039387949180323, + 10.195879316902918 + ], + "time_since_last_event": [ + 2.821179735233528, + 0.910402459313147, + 0.3872771053751194, + 1.121235223516294, + 0.20724553495431053, + 0.345494097432109, + 0.5875234887892162, + 1.264248552928649, + 0.08567823879570824, + 0.863140668275733, + 0.2526235114694657, + 1.517637011592825, + 0.7157911248063584, + 0.019874964282266205, + 1.37853099296575, + 1.8381524079870104, + 0.34930424289120765, + 3.2283285121453957, + 0.9166769362410321, + 0.15603043707713482, + 0.3636998523045065, + 2.0765984028496547, + 0.39149857551112177, + 0.8551317580166256, + 0.35122082643478336, + 2.0304553298571335, + 0.6838475680412301, + 0.20519885244094382, + 0.229467951360121, + 0.9925927907110434, + 0.3681864962558592, + 0.10094514511925468, + 0.14517893054187425, + 1.1456702433412573, + 0.10928589531037858, + 2.3291137511263003, + 1.6602871763700335, + 0.20613373278010272, + 0.5877883293584176, + 0.40463092376653265 + ], + "type_event": [ + 2, + 0, + 2, + 1, + 0, + 0, + 1, + 2, + 2, + 1, + 2, + 0, + 1, + 1, + 1, + 0, + 1, + 2, + 1, + 1, + 1, + 0, + 0, + 1, + 1, + 2, + 0, + 2, + 2, + 1, + 2, + 1, + 2, + 2, + 2, + 0, + 1, + 0, + 2, + 1 + ] + }, + { + "dim_process": 3, + "seq_idx": 6, + "seq_len": 40, + "time_since_start": [ + 0.0, + 0.08290359490719368, + 0.39676029935586143, + 0.8784427497039502, + 1.0007089397488613, + 1.2092446316704155, + 1.2499874080330642, + 1.9752092196807922, + 2.328769317535304, + 2.488829515036727, + 2.525031396760383, + 2.744489328327461, + 2.996445980208236, + 4.721865302870469, + 4.773414271674682, + 5.47082934708385, + 5.881268833595357, + 6.025480050991604, + 6.332449501308346, + 6.449904244275885, + 6.809820983639597, + 7.176603788208595, + 7.639879246227835, + 8.398663249617698, + 8.4926464335567, + 8.729336668682564, + 9.102722294108716, + 9.600785980259559, + 10.125398398793791, + 10.198443178017982, + 10.214317753170377, + 10.389669534405414, + 10.807247841881917, + 11.237678145377458, + 11.437027124103906, + 11.635355323319146, + 11.902844731418185, + 12.054633594812103, + 12.171647205719722, + 12.351725674306309 + ], + "time_since_last_event": [ + 0.2923544530340365, + 0.21876668021863566, + 0.39676029935586143, + 0.7955391547967565, + 0.12226619004491113, + 0.2085356919215542, + 0.04074277636264867, + 1.5784489203249308, + 1.07878190950224, + 0.5136202953559348, + 0.19626207922507888, + 0.25565981329073395, + 0.4714145834478529, + 1.7254193226622334, + 0.051548968804212336, + 2.726340018756389, + 6.328774229391442, + 0.554650703907754, + 1.5590352296336647, + 0.42442419328428116, + 0.92855215004424, + 0.7266995439327104, + 1.3074297449194887, + 0.758784003389863, + 1.3160426453481051, + 1.9195156850429669, + 0.704059044491018, + 0.49806368615084295, + 1.396061730111228, + 0.5976571977584229, + 0.01587457515239521, + 0.2642711356116223, + 0.4175783074765036, + 2.745031711820758, + 0.629779282221989, + 0.3976771779416879, + 1.6885269782478076, + 0.15178886339391795, + 0.11701361090761964, + 0.7163703509871624 + ], + "type_event": [ + 2, + 1, + 2, + 1, + 1, + 1, + 1, + 2, + 1, + 2, + 1, + 2, + 1, + 1, + 1, + 2, + 0, + 2, + 1, + 2, + 0, + 2, + 1, + 1, + 2, + 0, + 1, + 1, + 0, + 1, + 1, + 0, + 0, + 2, + 0, + 2, + 1, + 1, + 1, + 2 + ] + }, + { + "dim_process": 3, + "seq_idx": 7, + "seq_len": 18, + "time_since_start": [ + 0.0, + 0.36287007802970095, + 0.7632289843324003, + 0.8783532318219756, + 1.388673155434276, + 1.5621573264109827, + 2.143576918420962, + 3.0850936715304016, + 3.1020250362732043, + 3.2750748827400997, + 3.957865751965926, + 5.751451655400771, + 6.118380301766848, + 6.239705211505068, + 6.320001547224251, + 6.39454692091077, + 6.737816962725873, + 7.211118745063402 + ], + "time_since_last_event": [ + 0.7163475979636473, + 1.9939162261957506, + 0.7632289843324003, + 0.11512424748957528, + 0.5103199236123004, + 2.4585833929612164, + 0.5814195920099792, + 1.6964205160961257, + 0.016931364742802657, + 0.17304984646689547, + 0.6827908692258262, + 3.6078747369798094, + 0.3669286463660768, + 0.1213249097382203, + 2.3621357952583253, + 0.0745453736865187, + 6.374946884696172, + 0.4733017823375292 + ], + "type_event": [ + 2, + 0, + 2, + 2, + 2, + 1, + 1, + 2, + 2, + 2, + 2, + 1, + 1, + 1, + 2, + 2, + 0, + 0 + ] + } +] \ No newline at end of file diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..3efc4a81460f32a0596f2f88fdd289f22946d3ee --- /dev/null +++ b/tests/test_data_loader.py @@ -0,0 +1,55 @@ +import unittest + +from easy_tpp.config_factory import DataSpecConfig +from easy_tpp.utils import load_json +from easy_tpp.preprocess.dataset import TPPDataset, EventTokenizer, get_data_loader + + +class TestDataLoader(unittest.TestCase): + def setUp(self): + # Assuming the data is already generated and saved in 'synthetic_hf_data.json' + self.data_file = 'synthetic_data.json' + self.batch_size = 4 + self.input_data = self._make_json_2_dict(self.data_file) + self.dataset = TPPDataset(self.input_data) + + config = DataSpecConfig.parse_from_yaml_config({'num_event_types': 3, + 'batch_size': self.batch_size, + 'pad_token_id': 3}) + + self.tokenizer = EventTokenizer(config) + + self.data_loader = get_data_loader(self.dataset, 'torch', self.tokenizer, batch_size=self.batch_size) + + def _make_json_2_dict(self, json_dir): + json_data = load_json(json_dir) + res = dict() + res['time_seqs'] = [x['time_since_start'] for x in json_data] + res['time_delta_seqs'] = [x['time_since_last_event'] for x in json_data] + res['type_seqs'] = [x['type_event'] for x in json_data] + return res + + def test_data_loading(self): + """Test if data is loaded correctly.""" + self.assertIsNotNone(self.input_data) + self.assertIsInstance(self.input_data, dict) + self.assertGreater(len(self.input_data), 0) + + def test_batch_generation(self): + """Test if batches are generated correctly.""" + self.assertGreater(len(self.data_loader), 0) + for batch in self.data_loader: + self.assertLessEqual(batch['time_seqs'].shape[0], self.batch_size) + self.assertIn('time_seqs', batch) + self.assertIn('time_delta_seqs', batch) + self.assertIn('type_seqs', batch) + + def test_batch_content(self): + """Test if batch content is correct.""" + for batch in self.data_loader: + self.assertEqual(len(batch['time_seqs']), len(batch['time_delta_seqs'])) + self.assertEqual(len(batch['time_seqs']), len(batch['type_seqs'])) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_nhp.py b/tests/test_nhp.py new file mode 100644 index 0000000000000000000000000000000000000000..f946a111aa0e8bbbea4d744faf0299cfbcc60ca3 --- /dev/null +++ b/tests/test_nhp.py @@ -0,0 +1,91 @@ +import unittest + +import numpy as np +import torch +import os +import sys + +# Get the directory of the current file +current_file_path = os.path.abspath(__file__) +sys.path.append(os.path.dirname(os.path.dirname(current_file_path))) + +from easy_tpp.model import TorchNHP as NHP +from easy_tpp.preprocess.dataset import get_data_loader +from easy_tpp.config_factory import DataSpecConfig, ModelConfig +from easy_tpp.utils import load_json +from easy_tpp.preprocess.dataset import TPPDataset, EventTokenizer + + +class TestNeuralHawkesProcess(unittest.TestCase): + def setUp(self): + """Set up the test environment.""" + # Assuming the data is already generated and saved in 'synthetic_hf_data.json' + self.data_file = 'synthetic_data.json' + self.batch_size = 4 + self.input_data = self._make_json_2_dict(self.data_file) + self.dataset = TPPDataset(self.input_data) + + config = DataSpecConfig.parse_from_yaml_config({'num_event_types': 3, + 'batch_size': self.batch_size, + 'pad_token_id': 3}) + + self.tokenizer = EventTokenizer(config) + + self.data_loader = get_data_loader(self.dataset, 'torch', self.tokenizer, batch_size=self.batch_size) + + model_config = ModelConfig.parse_from_yaml_config({'hidden_size': 32, + 'loss_integral_num_sample_per_step': 20, + 'num_event_types': 3, + 'num_event_types_pad': 4, + 'event_pad_index': 3}) + self.model = NHP(model_config) + + def _make_json_2_dict(self, json_dir): + json_data = load_json(json_dir) + res = dict() + res['time_seqs'] = [x['time_since_start'] for x in json_data] + res['time_delta_seqs'] = ([np.array(x['time_since_last_event'], dtype=np.float32) for x in json_data]) + res['type_seqs'] = [x['type_event'] for x in json_data] + return res + + def test_model_initialization(self): + """Test if the model is initialized correctly.""" + self.assertIsInstance(self.model, NHP) + self.assertEqual(self.model.hidden_size, 32) + + def test_forward_pass(self): + """Test the forward pass of the model.""" + batch = next(iter(self.data_loader)).values() + output = self.model(batch) + self.assertIsInstance(output[0], torch.Tensor) + self.assertIsInstance(output[1], torch.Tensor) + + def test_loss_computation(self): + """Test if the model computes loss correctly.""" + batch = next(iter(self.data_loader)).values() + loss = self.model.loglike_loss(batch) + self.assertGreater(loss[0].item(), 0) # Loss should be positive + + def test_backward_pass(self): + """Test if the model can perform a backward pass.""" + batch = next(iter(self.data_loader)).values() + loss = self.model.loglike_loss(batch) + loss[0].backward() + for param in self.model.parameters(): + self.assertIsNotNone(param.grad) # Ensure gradients are computed + + def test_training_step(self): + """Test a single training step.""" + optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001) + self.model.train() + for batch in self.data_loader: + optimizer.zero_grad() + loss = self.model.loglike_loss(batch.values()) + loss[0].backward() + optimizer.step() + self.assertIsNotNone(loss[0]) # Ensure loss is computed + break # Only run one step for the test + + +if __name__ == '__main__': + unittest.main() diff --git a/version.py b/version.py new file mode 100644 index 0000000000000000000000000000000000000000..fc79d63d5430b972ac6ec1c4bfea9af80922da4d --- /dev/null +++ b/version.py @@ -0,0 +1 @@ +__version__ = '0.2.1'