---
datasets:
- CausalNewsCorpus
language: en
library_name: transformers
license: mit
metrics:
- f1
- precision
- recall
tags:
- token-classification
- roberta
- causal-narrative
- cause-effect-extraction
- span-extraction
- ner
---
# RoBERTa Causal Span Extractor
This model is a fine-tuned version of `roberta-base` for **causal span extraction**
(token classification). It identifies **cause** and **effect** text spans in sentences.
## Model Description
- **Base Model**: roberta-base
- **Task**: Token classification (BIO tagging)
- **Labels**: O, B-CAUSE, I-CAUSE, B-EFFECT, I-EFFECT
- **Training Data**: CausalNewsCorpus V2 (sentences with exactly 1 causal pair)
- **Training Samples**: 1105
- **Dev Samples**: 133
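For illustration, here is what word-level BIO tags might look like on a simple sentence. This is a hypothetical, hand-written tagging (not model output); actual predictions may differ.

```python
# Hypothetical BIO tagging of "The heavy rain caused flooding in the city."
# (hand-written for illustration; not actual model output)
tagged = [
    ("The", "B-CAUSE"), ("heavy", "I-CAUSE"), ("rain", "I-CAUSE"),
    ("caused", "O"),
    ("flooding", "B-EFFECT"), ("in", "I-EFFECT"),
    ("the", "I-EFFECT"), ("city.", "I-EFFECT"),
]
```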
## Training Results
See the training notebook for detailed metrics.
## Usage
```python
from transformers import RobertaTokenizerFast, RobertaForTokenClassification
import torch

model_name = "causal-narrative/roberta-causal-span-extractor"
tokenizer = RobertaTokenizerFast.from_pretrained(model_name, add_prefix_space=True)
model = RobertaForTokenClassification.from_pretrained(model_name)

text = "The heavy rain caused flooding in the city."
words = text.split()

inputs = tokenizer(words, is_split_into_words=True, return_tensors="pt",
                   truncation=True)
with torch.no_grad():
    outputs = model(**inputs)
preds = torch.argmax(outputs.logits, dim=2)[0]

id2label = model.config.id2label
word_ids = inputs.word_ids()  # maps each sub-token back to its source word

# Print the label predicted for the first sub-token of each word
prev = None
for idx, wid in enumerate(word_ids):
    if wid is not None and wid != prev:
        print(f"{words[wid]:20s} {id2label[preds[idx].item()]}")
    prev = wid
```
## Labels
| Label | Description |
|-------|-------------|
| O | Non-causal token |
| B-CAUSE | Beginning of cause span |
| I-CAUSE | Inside cause span |
| B-EFFECT | Beginning of effect span |
| I-EFFECT | Inside effect span |
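To turn word-level BIO tags into cause/effect phrases, the tags can be grouped into contiguous spans. Below is a minimal, self-contained sketch of such a grouping helper (`bio_to_spans` is a hypothetical name, not part of this model's API):

```python
def bio_to_spans(words, tags):
    """Group word-level BIO tags into (label, text) spans.

    A span starts at a B- tag and extends over following I- tags
    with the same label; anything else closes the open span.
    """
    spans = []
    current_label, current_words = None, []
    for word, tag in zip(words, tags):
        if tag.startswith("B-"):
            if current_label is not None:
                spans.append((current_label, " ".join(current_words)))
            current_label, current_words = tag[2:], [word]
        elif tag.startswith("I-") and current_label == tag[2:]:
            current_words.append(word)
        else:
            if current_label is not None:
                spans.append((current_label, " ".join(current_words)))
            current_label, current_words = None, []
    if current_label is not None:
        spans.append((current_label, " ".join(current_words)))
    return spans
```

For example, tags of `B-CAUSE I-CAUSE I-CAUSE O B-EFFECT I-EFFECT I-EFFECT I-EFFECT` over "The heavy rain caused flooding in the city." would yield one CAUSE span ("The heavy rain") and one EFFECT span ("flooding in the city.").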