---
license: mit
language:
- en
base_model:
- google-bert/bert-base-cased
tags:
- BERT
- temporal
- time-aware language modeling
---
# BiTimeBERT
BiTimeBERT is pretrained on the New York Times Annotated Corpus with two temporal objectives: TAMLM (Time-aware Masked Language Modeling) and DD (Document Dating). The DD task uses monthly granularity: each document is classified into one of 246 month labels spanning the corpus timeline, so the `seq_relationship` head outputs 246-class temporal predictions instead of the usual 2-class NSP output.
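
To make the 246-class label space concrete, here is a minimal sketch of converting between a DD class index and a calendar month. It assumes that labels are ordered chronologically and that index 0 corresponds to January 1987, the first month covered by the NYT Annotated Corpus (which runs through June 2007, i.e. 246 months). This README does not state the label ordering, so treat that mapping, and the helper names `label_to_month` and `month_to_label`, as hypothetical and verify them against the training code before relying on them.

```python
# Assumption (not stated in this README): DD labels are chronological and
# label 0 is January 1987, the first month of the NYT Annotated Corpus.
START_YEAR = 1987
NUM_MONTH_LABELS = 246  # month classes, as stated above


def label_to_month(label: int) -> str:
    """Convert a DD class index into a 'YYYY-MM' string."""
    if not 0 <= label < NUM_MONTH_LABELS:
        raise ValueError(f"label must be in [0, {NUM_MONTH_LABELS}), got {label}")
    years, month_index = divmod(label, 12)
    return f"{START_YEAR + years}-{month_index + 1:02d}"


def month_to_label(year: int, month: int) -> int:
    """Convert a calendar month into a DD class index."""
    if not 1 <= month <= 12:
        raise ValueError(f"month must be in [1, 12], got {month}")
    label = (year - START_YEAR) * 12 + (month - 1)
    if not 0 <= label < NUM_MONTH_LABELS:
        raise ValueError(f"{year}-{month:02d} is outside the assumed label range")
    return label


print(label_to_month(0))        # 1987-01
print(label_to_month(245))      # 2007-06
print(month_to_label(2001, 9))  # 176
```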
## 🎯 Model Details
| Property | Value |
|----------|-------|
| **Base Model** | `bert-base-cased` |
| **Pretraining Tasks** | TAMLM + DD |
| **Temporal Granularity** | Month-level |
| **DD Labels** | 246 month classes |
| **Training Corpus** | NYT Annotated Corpus |
| **Framework** | PyTorch / Transformers |
| **Language** | English |
## 🚀 How to Load This Model
### ⚠️ Important: Custom Loading Required
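Because the `seq_relationship` head has 246 outputs instead of the 2 outputs of the standard NSP head, the default `from_pretrained()` call **cannot** load this model on its own. Use the helper function below, which loads the encoder first and then restores the 246-class head from the checkpoint.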
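
The shape conflict can be illustrated with plain PyTorch, without downloading anything. This is only a sketch of the underlying problem, not the actual Transformers loading path: it uses zero-filled stand-in tensors with the shapes of the DD head (768 is the hidden size of `bert-base-cased`), and `try_load` is a hypothetical helper written for this example.

```python
import torch
import torch.nn as nn

HIDDEN_SIZE = 768  # hidden size of bert-base-cased
NUM_MONTHS = 246   # DD label count from this README

# Stand-in tensors shaped like the checkpoint's DD head (values are dummies).
dd_head_weights = {
    "weight": torch.zeros(NUM_MONTHS, HIDDEN_SIZE),
    "bias": torch.zeros(NUM_MONTHS),
}


def try_load(head: nn.Linear, state_dict: dict) -> str:
    """Attempt a strict load and report whether the shapes were compatible."""
    try:
        head.load_state_dict(state_dict)
        return "loaded"
    except RuntimeError:
        # PyTorch raises RuntimeError when parameter shapes do not match.
        return "size mismatch"


# A standard 2-class NSP head rejects the 246-class weights ...
mismatch_result = try_load(nn.Linear(HIDDEN_SIZE, 2), dd_head_weights)
# ... while a head resized to 246 outputs accepts them.
resized_result = try_load(nn.Linear(HIDDEN_SIZE, NUM_MONTHS), dd_head_weights)

print(mismatch_result, "/", resized_result)
```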
---
### Helper Function to Load BiTimeBERT
```python
import torch
import torch.nn as nn
from transformers import BertForPreTraining, BertTokenizer, BertConfig
from huggingface_hub import hf_hub_download
import safetensors.torch as safetensors_lib


def load_bitembert(model_id="JasonWang1/BiTimeBERT", device=None, num_temporal_labels=246):
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load config and tokenizer
    config = BertConfig.from_pretrained(model_id)
    tokenizer = BertTokenizer.from_pretrained(model_id)

    # Load model with mismatched sizes ignored
    # (the 2-class NSP head is skipped and randomly initialized)
    model = BertForPreTraining.from_pretrained(
        model_id,
        config=config,
        ignore_mismatched_sizes=True
    )

    # Replace DD head with correct dimension
    model.cls.seq_relationship = nn.Linear(config.hidden_size, num_temporal_labels)

    # Download and load DD head weights from safetensors
    weights_path = hf_hub_download(repo_id=model_id, filename="model.safetensors")
    state_dict = safetensors_lib.load_file(weights_path, device='cpu')
    if 'cls.seq_relationship.weight' in state_dict:
        model.cls.seq_relationship.weight.data = state_dict['cls.seq_relationship.weight']
        model.cls.seq_relationship.bias.data = state_dict['cls.seq_relationship.bias']

    model.eval()
    return model.to(device), tokenizer

# ================= Usage =================
model, tokenizer = load_bitembert("JasonWang1/BiTimeBERT")