---
datasets:
- CausalNewsCorpus
language: en
library_name: transformers
license: mit
metrics:
- f1
- precision
- recall
tags:
- token-classification
- roberta
- causal-narrative
- cause-effect-extraction
- span-extraction
- ner
---

# RoBERTa Causal Span Extractor

This model is a fine-tuned version of `roberta-base` for **causal span extraction**
(token classification). It identifies **cause** and **effect** text spans in sentences.

## Model Description

- **Base Model**: roberta-base
- **Task**: Token classification (BIO tagging)
- **Labels**: O, B-CAUSE, I-CAUSE, B-EFFECT, I-EFFECT
- **Training Data**: CausalNewsCorpus V2 (sentences containing exactly one causal pair)
- **Training Samples**: 1105
- **Dev Samples**: 133
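
The BIO scheme above can be illustrated with a hand-tagged sentence (illustrative only, not actual model output):

```python
# Hand-tagged illustration of the BIO scheme (not actual model output).
# Every word gets exactly one tag; a span opens with B- and continues with I-.
words  = ["The", "heavy", "rain", "caused", "flooding", "in", "the", "city", "."]
labels = ["B-CAUSE", "I-CAUSE", "I-CAUSE", "O",
          "B-EFFECT", "I-EFFECT", "I-EFFECT", "I-EFFECT", "O"]

assert len(words) == len(labels)
```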

## Training Results

See the training notebook for detailed metrics.

## Usage

```python
from transformers import RobertaTokenizerFast, RobertaForTokenClassification
import torch

model_name = "causal-narrative/roberta-causal-span-extractor"
tokenizer = RobertaTokenizerFast.from_pretrained(model_name, add_prefix_space=True)
model = RobertaForTokenClassification.from_pretrained(model_name)

text = "The heavy rain caused flooding in the city."
words = text.split()
inputs = tokenizer(words, is_split_into_words=True, return_tensors="pt",
                   truncation=True, padding=True)

with torch.no_grad():
    outputs = model(**inputs)
    preds = torch.argmax(outputs.logits, dim=2)[0]

id2label = model.config.id2label
word_ids = inputs.word_ids(0)  # reuse the encoding that produced the logits
prev = None
for idx, wid in enumerate(word_ids):
    if wid is not None and wid != prev:  # label each word by its first subword
        print(f"{words[wid]:20s} {id2label[preds[idx].item()]}")
        prev = wid
```

## Labels

| Label | Description |
|-------|-------------|
| O | Non-causal token |
| B-CAUSE | Beginning of cause span |
| I-CAUSE | Inside cause span |
| B-EFFECT | Beginning of effect span |
| I-EFFECT | Inside effect span |
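
To turn per-word BIO tags (as printed by the usage snippet above) into contiguous cause/effect spans, a minimal grouping helper could look like the sketch below (`bio_to_spans` is an illustrative name, not part of this repository):

```python
def bio_to_spans(words, labels):
    """Group word-level BIO labels into (span_type, text) pairs."""
    spans, current_type, current_words = [], None, []
    for word, label in zip(words, labels):
        if label.startswith("B-"):
            if current_type:  # close any open span before starting a new one
                spans.append((current_type, " ".join(current_words)))
            current_type, current_words = label[2:], [word]
        elif label.startswith("I-") and current_type == label[2:]:
            current_words.append(word)  # continue the current span
        else:
            if current_type:  # O tag (or mismatched I-) closes the span
                spans.append((current_type, " ".join(current_words)))
            current_type, current_words = None, []
    if current_type:  # flush a span that runs to the end of the sentence
        spans.append((current_type, " ".join(current_words)))
    return spans

words = "The heavy rain caused flooding in the city .".split()
labels = ["B-CAUSE", "I-CAUSE", "I-CAUSE", "O",
          "B-EFFECT", "I-EFFECT", "I-EFFECT", "I-EFFECT", "O"]
print(bio_to_spans(words, labels))
# → [('CAUSE', 'The heavy rain'), ('EFFECT', 'flooding in the city')]
```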