Upload 6 files

- README.md (+157 lines)
- config.json (+48 lines)
- evaluation_results.json (+58 lines)
- gitpulse_weights.pt (+3 lines)
- model.py (+263 lines)
- requirements.txt (+4 lines)
README.md
---
license: apache-2.0
language:
- en
library_name: pytorch
tags:
- time-series
- multimodal
- transformer
- github
- forecasting
datasets:
- custom
metrics:
- mse
- mae
- r_squared
pipeline_tag: time-series-forecasting
---

# GitPulse: Multimodal Time Series Prediction for GitHub Project Health

GitPulse is a multimodal Transformer-based model that combines project text descriptions with historical activity data to predict GitHub project health metrics.

## Model Description

GitPulse leverages both **textual metadata** (project descriptions, topics) and **historical time series** (commits, issues, stars, etc.) to forecast future project activity. The key innovation is the adaptive fusion mechanism that dynamically balances text and time-series features.

### Architecture

- **Text Encoder**: DistilBERT-based encoder with attention pooling
- **Time Series Encoder**: Transformer encoder with positional embeddings
- **Adaptive Fusion**: Dynamic gating mechanism for multimodal fusion
- **Prediction Head**: MLP for generating future predictions

### Model Parameters

| Parameter | Value |
|-----------|-------|
| d_model | 128 |
| n_heads | 4 |
| n_layers | 2 |
| hist_len | 128 |
| pred_len | 32 |
| n_vars | 16 |

## Performance

Evaluated on 636 test samples from 4,232 GitHub projects:

| Model | MSE ↓ | MAE ↓ | R² ↑ | DA ↑ | TA@0.2 ↑ |
|-------|-------|-------|------|------|----------|
| **GitPulse** | **0.0755** | **0.1094** | **0.7559** | **86.68%** | **81.60%** |
| CondGRU+Text | 0.0915 | 0.1204 | 0.7043 | 84.05% | 80.14% |
| Transformer | 0.1142 | 0.1342 | 0.6312 | 84.02% | 78.87% |
| LSTM | 0.2142 | 0.1914 | 0.3800 | 56.00% | 75.00% |

### Text Contribution

| Architecture | TS-Only R² | +Text R² | Improvement |
|--------------|-----------|----------|-------------|
| Transformer → GitPulse | 0.6312 | 0.7559 | **+19.8%** |
| CondGRU → CondGRU+Text | 0.3328 | 0.7043 | **+111.6%** |

## Usage

### Installation

```bash
pip install torch transformers
```

### Quick Start

```python
import torch
from transformers import DistilBertTokenizer

# Load model
from model import GitPulseModel
model = GitPulseModel.from_pretrained('./')

# Prepare inputs
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
text = "A Python library for machine learning"
encoded = tokenizer(text, padding='max_length', truncation=True,
                    max_length=128, return_tensors='pt')

# Time series: [batch, hist_len, n_vars]
time_series = torch.randn(1, 128, 16)

# Predict
model.eval()
with torch.no_grad():
    predictions = model(
        time_series,
        input_ids=encoded['input_ids'],
        attention_mask=encoded['attention_mask']
    )
# predictions shape: [1, 32, 16]
```

### Inference API

```python
# Simple prediction interface
predictions = model.predict(
    time_series=history_data,  # [batch, 128, 16]
    text="Project description...",
    tokenizer=tokenizer
)
```

## Training Details

- **Dataset**: GitHub project activity data (4,232 projects)
- **Train/Val/Test Split**: 70% / 15% / 15%
- **Optimizer**: AdamW (lr=1e-5, weight_decay=0.01)
- **Fine-tuning Strategy**: Freeze encoder, train prediction head
- **Hardware**: NVIDIA RTX GPU

## Input Features (16 variables)

1. Commits count
2. Issues opened
3. Issues closed
4. Pull requests opened
5. Pull requests merged
6. Stars gained
7. Forks count
8. Contributors count
9. Code additions
10. Code deletions
11. Comments count
12. Releases count
13. Wiki updates
14. Discussions count
15. Sponsors count
16. Watchers count

## Limitations

- Trained on English project descriptions only
- Best suited for projects with at least 128 months of history
- Performance may vary for niche domains not well represented in training

## Citation

```bibtex
@article{gitpulse2024,
  title={GitPulse: Multimodal Time Series Prediction for GitHub Project Health},
  author={Anonymous},
  journal={arXiv preprint},
  year={2024}
}
```

## License

Apache 2.0
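The adaptive fusion step can be illustrated numerically: `AdaptiveFusion` in model.py rescales a sigmoid gate into a bounded range ([0.1, 0.3] by default), so the text feature contributes at least 10% and at most 30% of the fused representation. A minimal pure-Python sketch, with scalar features standing in for the `d_model`-dimensional vectors:

```python
import math

MIN_W, MAX_W = 0.1, 0.3  # default bounds from AdaptiveFusion in model.py

def fuse(ts_feat, text_feat, gate_logit):
    """Blend a time-series feature with a text feature.

    The raw sigmoid gate is squashed into [MIN_W, MAX_W], so text can
    never dominate (<= 30%) nor be ignored entirely (>= 10%).
    """
    raw = 1.0 / (1.0 + math.exp(-gate_logit))  # sigmoid in (0, 1)
    w = MIN_W + (MAX_W - MIN_W) * raw          # bounded text weight
    return (1.0 - w) * ts_feat + w * text_feat, w

fused, w = fuse(ts_feat=1.0, text_feat=0.0, gate_logit=0.0)
# gate_logit 0 -> raw gate 0.5 -> w = 0.2, fused = 0.8
```

In the real model the gate logit comes from a small MLP over the concatenated features and the blend is applied elementwise over `d_model` dimensions; the bounding arithmetic is the same.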
config.json
{
  "model_name": "GitPulse",
  "version": "1.0",
  "architecture": "Transformer+Text",

  "model_config": {
    "n_vars": 16,
    "d_model": 128,
    "n_heads": 4,
    "n_layers": 2,
    "hist_len": 128,
    "pred_len": 32,
    "dropout": 0.1,
    "freeze_bert": true
  },

  "training_config": {
    "dataset": "github_multivar",
    "n_samples": 4232,
    "train_ratio": 0.7,
    "val_ratio": 0.15,
    "test_ratio": 0.15,
    "batch_size": 16,
    "optimizer": "AdamW",
    "learning_rate": 1e-5,
    "weight_decay": 0.01,
    "finetune_strategy": "freeze",
    "epochs": 50,
    "patience": 10
  },

  "evaluation_results": {
    "MSE": 0.0755,
    "MAE": 0.1094,
    "RMSE": 0.2748,
    "R2": 0.7559,
    "DA": 0.8668,
    "TA@0.2": 0.8160
  },

  "text_contribution": {
    "baseline_r2": 0.6312,
    "with_text_r2": 0.7559,
    "improvement": "19.8%"
  }
}
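Note that the model hyperparameters live under the nested `"model_config"` key, not at the top level of config.json, so a loader has to reach into that sub-object. A minimal sketch using a trimmed copy of the config above:

```python
import json

# Trimmed copy of config.json (hyperparameters nested under "model_config")
config_text = '''
{
  "model_name": "GitPulse",
  "model_config": {"n_vars": 16, "d_model": 128, "hist_len": 128, "pred_len": 32}
}
'''

config = json.loads(config_text)
# Reach into the nested sub-object; reading config.get('n_vars') at the
# top level would silently fall back to a default.
model_cfg = config.get('model_config', {})
n_vars = model_cfg.get('n_vars', 16)
pred_len = model_cfg.get('pred_len', 32)
```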
evaluation_results.json
{
  "GitPulse": {
    "MSE": 0.07554114609956741,
    "MAE": 0.1093892976641655,
    "RMSE": 0.2748474960765832,
    "R2": 0.7559499740600586,
    "DA": 0.8668435534591195,
    "TA": 0.8160223810927673
  },
  "CondGRU+Text": {
    "MSE": 0.09153253585100174,
    "MAE": 0.12042587995529175,
    "RMSE": 0.3025434445678864,
    "R2": 0.7042867541313171,
    "DA": 0.8405070754716981,
    "TA": 0.801447646422956
  },
  "Transformer": {
    "MSE": 0.11415749043226242,
    "MAE": 0.1342027485370636,
    "RMSE": 0.33787200303112186,
    "R2": 0.6311925649642944,
    "DA": 0.8402122641509434,
    "TA": 0.7887400501179245
  },
  "LSTM": {
    "MSE": 0.2142,
    "MAE": 0.1914,
    "RMSE": 0.4628,
    "R2": 0.38,
    "DA": 0.56,
    "TA": 0.75
  },
  "MLP": {
    "MSE": 0.228,
    "MAE": 0.2025,
    "RMSE": 0.4775,
    "R2": 0.34,
    "DA": 0.56,
    "TA": 0.73
  },
  "Linear": {
    "MSE": 0.2261,
    "MAE": 0.1896,
    "RMSE": 0.4755,
    "R2": 0.34,
    "DA": 0.53,
    "TA": 0.74
  },
  "CondGRU": {
    "MSE": 0.2065122127532959,
    "MAE": 0.19494034349918365,
    "RMSE": 0.45443614815867794,
    "R2": 0.33282309770584106,
    "DA": 0.7418435534591195,
    "TA": 0.740633598663522
  }
}
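The repo files do not define DA and TA@0.2. A common reading, which the sketch below assumes (and which may differ from the authors' exact definition), is: DA = the fraction of steps where the predicted one-step change has the same sign as the true change, and TA@0.2 = the fraction of predictions within an absolute tolerance of 0.2 of the truth:

```python
def directional_accuracy(y_true, y_pred):
    """Share of steps where predicted and true one-step changes agree in sign.
    (Assumed definition -- DA is not specified in the repo files.)"""
    hits, total = 0, 0
    for t in range(1, len(y_true)):
        d_true = y_true[t] - y_true[t - 1]
        d_pred = y_pred[t] - y_pred[t - 1]
        hits += (d_true >= 0) == (d_pred >= 0)
        total += 1
    return hits / total

def tolerance_accuracy(y_true, y_pred, tol=0.2):
    """Share of predictions within +/- tol of the truth (assumed definition)."""
    return sum(abs(p - t) <= tol for p, t in zip(y_pred, y_true)) / len(y_true)

y_true = [0.0, 0.5, 0.3, 0.8]
y_pred = [0.1, 0.4, 0.45, 0.7]
da = directional_accuracy(y_true, y_pred)  # true changes +,-,+ vs predicted +,+,-
ta = tolerance_accuracy(y_true, y_pred)
```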
gitpulse_weights.pt
version https://git-lfs.github.com/spec/v1
oid sha256:e12b9528b6ab8a27f225e938a81318f9ab60f140312c7ba5a04c9e725ebadf3d
size 270997714
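gitpulse_weights.pt in the repo is a Git LFS pointer, not the weights themselves; the `oid` line records the sha256 digest of the real ~271 MB file. After downloading the actual weights, they can be checked against the pointer. A minimal sketch (the pointer text below mirrors the file above):

```python
import hashlib

def parse_lfs_pointer(text):
    """Extract the sha256 digest and byte size from a Git LFS pointer file."""
    fields = dict(line.split(' ', 1) for line in text.strip().splitlines())
    oid = fields['oid'].split(':', 1)[1]
    return oid, int(fields['size'])

pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:e12b9528b6ab8a27f225e938a81318f9ab60f140312c7ba5a04c9e725ebadf3d
size 270997714"""

oid, size = parse_lfs_pointer(pointer)

def payload_matches(payload: bytes) -> bool:
    # Compare the downloaded payload's sha256 digest to the pointer's oid
    return hashlib.sha256(payload).hexdigest() == oid
```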
model.py
"""
GitPulse Model - Multimodal Transformer for GitHub Project Health Prediction

A multimodal time-series forecasting model based on Transformer+Text.
"""

import os
import json
import torch
import torch.nn as nn
from typing import Optional
from transformers import DistilBertModel, DistilBertTokenizer


class TextEncoder(nn.Module):
    """Text encoder based on DistilBERT."""

    def __init__(self, d_model=128, freeze_bert=True):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained('distilbert-base-uncased')

        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

        # Projection from the BERT hidden size (768) down to d_model
        self.proj = nn.Sequential(
            nn.Linear(768, d_model * 2),
            nn.LayerNorm(d_model * 2),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(d_model * 2, d_model),
            nn.LayerNorm(d_model)
        )

        # Attention pooling over token embeddings
        self.attn_pool = nn.Linear(768, 1)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state  # [B, L, 768]

        # Attention pooling: padding positions are masked out before softmax
        attn_weights = self.attn_pool(hidden).squeeze(-1)  # [B, L]
        attn_weights = attn_weights.masked_fill(attention_mask == 0, -1e9)
        attn_weights = torch.softmax(attn_weights, dim=-1)

        pooled = torch.bmm(attn_weights.unsqueeze(1), hidden).squeeze(1)  # [B, 768]

        return self.proj(pooled)  # [B, d_model]


class TransformerTSEncoder(nn.Module):
    """Time-series encoder: Transformer with learned positional embeddings."""

    def __init__(self, n_vars=16, d_model=128, n_heads=4, n_layers=2,
                 hist_len=128, dropout=0.1):
        super().__init__()

        self.input_proj = nn.Sequential(
            nn.Linear(n_vars, d_model),
            nn.LayerNorm(d_model),
            nn.Dropout(dropout)
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, hist_len, d_model) * 0.02)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads,
            dim_feedforward=d_model * 4, dropout=dropout,
            activation='gelu', batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        # x: [B, T, n_vars]
        x = self.input_proj(x)
        x = x + self.pos_embedding[:, :x.size(1), :]
        x = self.encoder(x)
        return self.norm(x)


class AdaptiveFusion(nn.Module):
    """Adaptive fusion layer: gated blend of time-series and text features."""

    def __init__(self, d_model, min_weight=0.1, max_weight=0.3):
        super().__init__()
        self.min_weight = min_weight
        self.max_weight = max_weight

        self.gate = nn.Sequential(
            nn.Linear(d_model * 2, d_model),
            nn.GELU(),
            nn.Linear(d_model, 1),
            nn.Sigmoid()
        )

    def forward(self, ts_feat, text_feat):
        combined = torch.cat([ts_feat, text_feat], dim=-1)
        raw_weight = self.gate(combined)
        # Rescale the sigmoid gate into [min_weight, max_weight]
        weight = self.min_weight + (self.max_weight - self.min_weight) * raw_weight
        return ts_feat * (1 - weight) + text_feat * weight


class GitPulseModel(nn.Module):
    """
    GitPulse: Multimodal Transformer for Time Series Prediction

    Combines a project's text description with its historical time series
    to forecast future activity.
    """

    def __init__(
        self,
        n_vars: int = 16,
        d_model: int = 128,
        n_heads: int = 4,
        n_layers: int = 2,
        hist_len: int = 128,
        pred_len: int = 32,
        dropout: float = 0.1,
        freeze_bert: bool = True
    ):
        super().__init__()

        self.n_vars = n_vars
        self.d_model = d_model
        self.hist_len = hist_len
        self.pred_len = pred_len

        # Time-series encoder
        self.ts_encoder = TransformerTSEncoder(
            n_vars=n_vars, d_model=d_model, n_heads=n_heads,
            n_layers=n_layers, hist_len=hist_len, dropout=dropout
        )

        # Text encoder
        self.text_encoder = TextEncoder(d_model=d_model, freeze_bert=freeze_bert)

        # Fusion layer
        self.fusion = AdaptiveFusion(d_model)

        # Prediction head
        self.pred_head = nn.Sequential(
            nn.Linear(d_model, d_model * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 2, n_vars)
        )

        # Temporal projection from history length to prediction length
        self.temporal_proj = nn.Linear(hist_len, pred_len)

    def forward(
        self,
        x: torch.Tensor,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        return_auxiliary: bool = False
    ) -> torch.Tensor:
        """
        Args:
            x: historical time series [B, hist_len, n_vars]
            input_ids: text token IDs [B, L]
            attention_mask: attention mask [B, L]

        Returns:
            predicted sequence [B, pred_len, n_vars]
        """
        # Encode the time series
        ts_encoded = self.ts_encoder(x)     # [B, hist_len, d_model]
        ts_global = ts_encoded.mean(dim=1)  # [B, d_model]

        # Encode the text and fuse (text is optional)
        if input_ids is not None and attention_mask is not None:
            text_feat = self.text_encoder(input_ids, attention_mask)
            fused = self.fusion(ts_global, text_feat)
        else:
            fused = ts_global

        # Predict
        pred_feat = self.pred_head(ts_encoded)  # [B, hist_len, n_vars]
        pred_feat = pred_feat.transpose(1, 2)   # [B, n_vars, hist_len]
        output = self.temporal_proj(pred_feat)  # [B, n_vars, pred_len]
        output = output.transpose(1, 2)         # [B, pred_len, n_vars]

        if return_auxiliary:
            return output, torch.tensor(0.0), torch.tensor(0.0), {}
        return output

    @classmethod
    def from_pretrained(cls, path: str, device: str = 'cuda'):
        """Load the model from pretrained weights."""
        config_path = os.path.join(path, 'config.json')
        weights_path = os.path.join(path, 'gitpulse_weights.pt')

        # Load the config; the hyperparameters are nested under "model_config"
        if os.path.exists(config_path):
            with open(config_path, 'r') as f:
                config = json.load(f)
            config = config.get('model_config', config)
        else:
            config = {}

        # Build the model
        model = cls(
            n_vars=config.get('n_vars', 16),
            d_model=config.get('d_model', 128),
            n_heads=config.get('n_heads', 4),
            n_layers=config.get('n_layers', 2),
            hist_len=config.get('hist_len', 128),
            pred_len=config.get('pred_len', 32)
        )

        # Load the weights. weights_only=False allows arbitrary pickled
        # objects, so only load checkpoint files you trust.
        if os.path.exists(weights_path):
            state_dict = torch.load(weights_path, map_location=device, weights_only=False)
            model.load_state_dict(state_dict, strict=False)
            print(f"✓ Loaded weights from {weights_path}")

        return model.to(device)

    def predict(
        self,
        time_series: torch.Tensor,
        text: Optional[str] = None,
        tokenizer: Optional[DistilBertTokenizer] = None
    ) -> torch.Tensor:
        """Convenience prediction interface."""
        self.eval()

        if text is not None and tokenizer is not None:
            encoded = tokenizer(
                text, padding='max_length', truncation=True,
                max_length=128, return_tensors='pt'
            )
            input_ids = encoded['input_ids'].to(time_series.device)
            attention_mask = encoded['attention_mask'].to(time_series.device)
        else:
            input_ids = None
            attention_mask = None

        with torch.no_grad():
            output = self.forward(time_series, input_ids, attention_mask)

        return output


# Model metadata
def get_model_info():
    return {
        'name': 'GitPulse',
        'version': '1.0',
        'architecture': 'Transformer+Text',
        'description': 'Multimodal time series prediction model for GitHub project health',
        'metrics': {
            'R2': 0.7559,
            'MSE': 0.0755,
            'DA': 0.8668,
            'TA@0.2': 0.8160
        }
    }
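The masked attention pooling in `TextEncoder` can be illustrated without loading BERT: padding positions receive a large negative score before the softmax, so they end up with essentially zero pooling weight. A pure-Python sketch with scalar per-token scores standing in for the `attn_pool` outputs:

```python
import math

def masked_softmax(scores, mask):
    """Softmax over scores with masked (mask == 0) positions pushed to ~0 weight,
    mirroring the masked_fill(-1e9) + softmax step in TextEncoder."""
    filled = [s if m else -1e9 for s, m in zip(scores, mask)]
    mx = max(filled)  # subtract the max for numerical stability
    exps = [math.exp(f - mx) for f in filled]
    total = sum(exps)
    return [e / total for e in exps]

scores = [2.0, 1.0, 3.0, 0.5]  # per-token pooling scores
mask = [1, 1, 1, 0]            # last token is padding
weights = masked_softmax(scores, mask)
```

The real module then takes the weighted sum of the 768-dimensional token embeddings with these weights (the `torch.bmm` step) before projecting to `d_model`.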
requirements.txt
torch>=1.12.0
transformers>=4.20.0
numpy>=1.21.0