rasoultilburg committed on
Commit
d2ca2ac
·
verified ·
1 Parent(s): c3b7cef

Uploading joint causal model files into the Hugging Face Hub

Browse files
config.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "JointCausalModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_joint_causal.JointCausalConfig",
7
+ "AutoModel": "modeling_joint_causal.JointCausalModel"
8
+ },
9
+ "dropout": 0.2,
10
+ "encoder_name": "bert-base-uncased",
11
+ "model_type": "joint_causal",
12
+ "num_bio_labels": 7,
13
+ "num_cls_labels": 2,
14
+ "num_rel_labels": 2,
15
+ "torch_dtype": "float32",
16
+ "transformers_version": "4.51.3"
17
+ }
config.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Configuration settings for the joint causal learning model.
3
+ """
4
+
5
+ import torch
6
+
7
+ # Device configuration
8
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+
10
+ # Seed for reproducibility
11
+ SEED = 8642
12
+
13
+ # Model configuration
14
+ MODEL_CONFIG = {
15
+ "encoder_name": "bert-base-uncased", # Default encoder
16
+ "num_cls_labels": 2, # Binary classification for causal/non-causal
17
+ "num_bio_labels": 7, # BIO labels for span detection
18
+ "num_rel_labels": 2, # Relation labels (updated from 3 to 2)
19
+ "dropout": 0.2, # Dropout rate
20
+ }
21
+
22
+ # Training configuration
23
+ TRAINING_CONFIG = {
24
+ "batch_size": 16,
25
+ "num_epochs": 20,
26
+ "learning_rate": 1e-5,
27
+ "weight_decay": 0.1,
28
+ "gradient_clip_val": 1.0,
29
+ "patience_epochs": 10,
30
+ "model_save_path": "best_joint_causal_model.pt",
31
+ }
32
+
33
+ # Dataset configuration
34
+ DATASET_CONFIG = {
35
+ "max_length": 512, # Maximum sequence length for tokenization
36
+ "negative_relation_rate": 2.0, # Rate of negative relation samples to generate
37
+ "max_random_span_len": 5, # Maximum length for random negative spans
38
+ "ignore_id": -100, # ID to ignore in loss computation
39
+ }
40
+
41
+ # Label mappings
42
+ # BIO labels for span detection
43
+ id2label_bio = {
44
+ 0: "B-C", # Beginning of Cause
45
+ 1: "I-C", # Inside of Cause
46
+ 2: "B-E", # Beginning of Effect
47
+ 3: "I-E", # Inside of Effect
48
+ 4: "B-CE", # Beginning of Cause-Effect
49
+ 5: "I-CE", # Inside of Cause-Effect
50
+ 6: "O" # Outside
51
+ }
52
+ label2id_bio = {v: k for k, v in id2label_bio.items()}
53
+
54
+ # Entity label to BIO prefix mapping
55
+ entity_label_to_bio_prefix = {
56
+ "cause": "C",
57
+ "effect": "E",
58
+ "internal_CE": "CE",
59
+ "non-causal": "O"
60
+ }
61
+
62
+ # Relation labels
63
+ id2label_rel = {
64
+ 0: "Rel_None",
65
+ 1: "Rel_CE"
66
+ }
67
+ label2id_rel = {
68
+ "Rel_None": 0,
69
+ "Rel_CE": 1
70
+ }
71
+
72
+ # Classification labels
73
+ id2label_cls = {
74
+ 0: "non-causal",
75
+ 1: "causal"
76
+ }
77
+ label2id_cls = {v: k for k, v in id2label_cls.items()}
78
+
79
+ # Relation type mappings
80
+ POSITIVE_RELATION_TYPE_TO_ID = {
81
+ "Rel_CE": 1
82
+ }
83
+ NEGATIVE_SAMPLE_REL_ID = label2id_rel["Rel_None"]
84
+
85
+ # Inference configuration
86
+ INFERENCE_CONFIG = {
87
+ "cls_threshold": 0.5, # Threshold for causal/non-causal classification
88
+ }
configuration_joint_causal.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # configuration_joint_causal.py
2
+
3
+ from transformers import PretrainedConfig
4
+
5
+ class JointCausalConfig(PretrainedConfig):
6
+ """
7
+ This is the configuration class for JointCausalModel, following the tutorial's guidelines.
8
+ """
9
+ # The 'model_type' is crucial for AutoClass support, as mentioned in the tutorial.
10
+ model_type = "joint_causal"
11
+
12
+ def __init__(
13
+ self,
14
+ encoder_name="bert-base-uncased",
15
+ num_cls_labels=2,
16
+ num_bio_labels=7,
17
+ num_rel_labels=2,
18
+ dropout=0.2,
19
+ **kwargs,
20
+ ):
21
+ self.encoder_name = encoder_name
22
+ self.num_cls_labels = num_cls_labels
23
+ self.num_bio_labels = num_bio_labels
24
+ self.num_rel_labels = num_rel_labels
25
+ self.dropout = dropout
26
+ # As per the tutorial, we must pass any extra kwargs to the superclass.
27
+ super().__init__(**kwargs)
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ce74796d1e54f99ca8e791aebed3817b12c1aacbe0e3f060f1de599f77c1ac62
3
+ size 448604340
modeling_joint_causal.py ADDED
@@ -0,0 +1,503 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from typing import Dict, List, Optional
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import AutoModel, PreTrainedModel
6
+ from dataclasses import dataclass
7
+ try:
8
+ from .config import id2label_bio, id2label_rel, id2label_cls
9
+ except ImportError:
10
+ from config import id2label_bio, id2label_rel, id2label_cls
11
+
12
+ try:
13
+ from .configuration_joint_causal import JointCausalConfig
14
+ except ImportError:
15
+ from configuration_joint_causal import JointCausalConfig
16
+
17
+ # ---------------------------------------------------------------------------
18
+ # Type aliases & label maps
19
+ # ---------------------------------------------------------------------------
20
+ label2id_bio = {v: k for k, v in id2label_bio.items()}
21
+ label2id_rel = {v: k for k, v in id2label_rel.items()}
22
+ label2id_cls = {v: k for k, v in id2label_cls.items()}
23
+
24
+ # ---------------------------------------------------------------------------
25
+ # Main module
26
+ # ---------------------------------------------------------------------------
27
+ """Joint Causal Extraction Model (softmax)
28
+ ============================================================================
29
+
30
+ A PyTorch module for joint causal extraction using softmax decoding for BIO tagging.
31
+ The model supports class weights for handling imbalanced data.
32
+
33
+ ```python
34
+ >>> model = JointCausalModel() # softmax-based model
35
+ """
36
+
37
+
38
+ # ---------------------------------------------------------------------------
39
+ # Span dataclass
40
+ # ---------------------------------------------------------------------------
41
+ @dataclass
42
+ class Span:
43
+ role: str
44
+ start_tok: int
45
+ end_tok: int
46
+ text: str
47
+ is_virtual: bool = False
48
+
49
+
50
+ # ---------------------------------------------------------------------------
51
+ # Main module
52
+ # ---------------------------------------------------------------------------
53
+
54
+ class JointCausalModel(PreTrainedModel):
55
+
56
+ """Encoder + three heads with **optional CRF** BIO decoder.
57
+
58
+ This model integrates a pre-trained transformer encoder with three distinct
59
+ heads for:
60
+ 1. Classification (cls_head): Predicts a global label for the input.
61
+ 2. BIO tagging (bio_head): Performs sequence tagging using BIO scheme.
62
+ Can operate with a CRF layer or standard softmax.
63
+ 3. Relation extraction (rel_head): Identifies relations between entities
64
+ detected by the BIO tagging head.
65
+ """
66
+ # Link the model to its config class, as shown in the tutorial.
67
+ config_class = JointCausalConfig
68
+
69
+ # ------------------------------------------------------------------
70
+ # constructor
71
+ # -----------------------------------------------------------
72
+ def __init__(self, config: JointCausalConfig):
73
+
74
+ """Initializes the JointCausalModel.
75
+
76
+ Args:
77
+ encoder_name: Name of the pre-trained transformer model to use
78
+ (e.g., "bert-base-uncased").
79
+ num_cls_labels: Number of labels for the classification task.
80
+ num_bio_labels: Number of labels for the BIO tagging task.
81
+ num_rel_labels: Number of labels for the relation extraction task.
82
+ dropout: Dropout rate for regularization.
83
+ """
84
+
85
+ super().__init__(config)
86
+ self.config = config
87
+
88
+ self.enc = AutoModel.from_pretrained(config.encoder_name)
89
+ self.hidden_size = self.enc.config.hidden_size
90
+ self.dropout = nn.Dropout(config.dropout)
91
+ self.layer_norm = nn.LayerNorm(self.hidden_size)
92
+
93
+
94
+
95
+ self.cls_head = nn.Sequential(
96
+ nn.Linear(self.hidden_size, self.hidden_size // 2),
97
+ nn.ReLU(),
98
+ nn.Dropout(config.dropout),
99
+ nn.Linear(self.hidden_size // 2, config.num_cls_labels),
100
+ )
101
+ self.bio_head = nn.Sequential(
102
+ nn.Linear(self.hidden_size, self.hidden_size),
103
+ nn.ReLU(),
104
+ nn.Dropout(config.dropout),
105
+ nn.Linear(self.hidden_size, self.hidden_size // 2),
106
+ nn.ReLU(),
107
+ nn.Dropout(config.dropout),
108
+ nn.Linear(self.hidden_size // 2, config.num_bio_labels),
109
+ )
110
+ self.rel_head = nn.Sequential(
111
+ nn.Linear(self.hidden_size * 2, self.hidden_size),
112
+ nn.ReLU(),
113
+ nn.Dropout(config.dropout),
114
+ nn.Linear(self.hidden_size, self.hidden_size // 2),
115
+ nn.ReLU(),
116
+ nn.Dropout(config.dropout),
117
+ nn.Linear(self.hidden_size // 2, config.num_rel_labels),
118
+ )
119
+ self._init_new_layer_weights()
120
+
121
+ def get_config_dict(self) -> Dict:
122
+ """Returns the model's configuration as a dictionary."""
123
+ return {
124
+ "encoder_name": self.encoder_name,
125
+ "num_cls_labels": self.num_cls_labels,
126
+ "num_bio_labels": self.num_bio_labels,
127
+ "num_rel_labels": self.num_rel_labels,
128
+ "dropout": self.dropout_rate,
129
+ }
130
+
131
+ @classmethod
132
+ def from_config_dict(cls, config: Dict) -> "JointCausalModel":
133
+ """Creates a JointCausalModel instance from a configuration dictionary."""
134
+ return cls(**config)
135
+
136
+ def _init_new_layer_weights(self):
137
+ """Initializes the weights of the newly added linear layers.
138
+
139
+ Uses Xavier uniform initialization for weights and zeros for biases.
140
+ """
141
+ for mod in [self.cls_head, self.bio_head, self.rel_head]:
142
+ for sub_module in mod.modules():
143
+ if isinstance(sub_module, nn.Linear):
144
+ nn.init.xavier_uniform_(sub_module.weight)
145
+ if sub_module.bias is not None:
146
+ nn.init.zeros_(sub_module.bias)
147
+
148
+ def encode(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
149
+ """Encodes the input using the transformer model.
150
+
151
+ Args:
152
+ input_ids: Tensor of input token IDs.
153
+ attention_mask: Tensor indicating which tokens to attend to.
154
+
155
+ Returns:
156
+ Tensor of hidden states from the encoder, passed through dropout
157
+ and layer normalization.
158
+ """
159
+ hidden_states = self.enc(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
160
+ return self.layer_norm(self.dropout(hidden_states))
161
+
162
+ def forward(
163
+ self,
164
+ input_ids: torch.Tensor,
165
+ attention_mask: torch.Tensor,
166
+ *,
167
+ bio_labels: torch.Tensor | None = None,
168
+ pair_batch: torch.Tensor | None = None,
169
+ cause_starts: torch.Tensor | None = None,
170
+ cause_ends: torch.Tensor | None = None,
171
+ effect_starts: torch.Tensor | None = None,
172
+ effect_ends: torch.Tensor | None = None,
173
+ ) -> Dict[str, torch.Tensor | None]:
174
+ """Performs a forward pass through the model.
175
+
176
+ Args:
177
+ input_ids: Tensor of input token IDs.
178
+ attention_mask: Tensor indicating which tokens to attend to.
179
+ bio_labels: Optional tensor of BIO labels for training.
180
+ pair_batch: Optional tensor indicating which hidden states to use
181
+ for relation extraction.
182
+ cause_starts: Optional tensor of start indices for cause spans.
183
+ cause_ends: Optional tensor of end indices for cause spans.
184
+ effect_starts: Optional tensor of start indices for effect spans.
185
+ effect_ends: Optional tensor of end indices for effect spans.
186
+
187
+ Returns:
188
+ A dictionary containing:
189
+ - "cls_logits": Logits for the classification task.
190
+ - "bio_emissions": Emissions from the BIO tagging head.
191
+ - "tag_loss": Loss for the BIO tagging task (if bio_labels provided).
192
+ - "rel_logits": Logits for the relation extraction task (if
193
+ relation extraction inputs provided).
194
+ """
195
+ # Encode input
196
+ hidden = self.encode(input_ids, attention_mask)
197
+
198
+ # Classification head
199
+ cls_logits = self.cls_head(hidden[:, 0]) # Use [CLS] token representation
200
+
201
+ # BIO tagging head
202
+ emissions = self.bio_head(hidden)
203
+ tag_loss: Optional[torch.Tensor] = None
204
+
205
+ # Calculate BIO tagging loss if labels are provided
206
+ if bio_labels is not None:
207
+ # Softmax loss (typically handled by the training loop's loss function, e.g., CrossEntropyLoss)
208
+ # Here, we initialize it to 0.0 as a placeholder.
209
+ # The actual loss calculation for softmax would compare emissions with bio_labels.
210
+ tag_loss = torch.tensor(0.0, device=emissions.device)
211
+
212
+ # Relation extraction head
213
+ rel_logits: torch.Tensor | None = None
214
+ if pair_batch is not None and cause_starts is not None and cause_ends is not None \
215
+ and effect_starts is not None and effect_ends is not None:
216
+ # Select hidden states corresponding to the pairs for relation extraction
217
+ bio_states_for_rel = hidden[pair_batch]
218
+ seq_len_rel = bio_states_for_rel.size(1)
219
+ pos_rel = torch.arange(seq_len_rel, device=bio_states_for_rel.device).unsqueeze(0)
220
+
221
+ # Create masks for cause and effect spans
222
+ c_mask = ((cause_starts.unsqueeze(1) <= pos_rel) & (pos_rel <= cause_ends.unsqueeze(1))).unsqueeze(2)
223
+ e_mask = ((effect_starts.unsqueeze(1) <= pos_rel) & (pos_rel <= effect_ends.unsqueeze(1))).unsqueeze(2)
224
+
225
+ # Compute mean-pooled representations for cause and effect spans
226
+ c_vec = (bio_states_for_rel * c_mask).sum(1) / c_mask.sum(1).clamp(min=1) # Average pooling, clamp to avoid div by zero
227
+ e_vec = (bio_states_for_rel * e_mask).sum(1) / e_mask.sum(1).clamp(min=1) # Average pooling, clamp to avoid div by zero
228
+
229
+ # Concatenate cause and effect vectors and pass through relation head
230
+ rel_logits = self.rel_head(torch.cat([c_vec, e_vec], dim=1))
231
+
232
+ return {
233
+ "cls_logits": cls_logits,
234
+ "bio_emissions": emissions,
235
+ "tag_loss": tag_loss,
236
+ "rel_logits": rel_logits,
237
+ }
238
+
239
+ def predict(self, sents: List[str], tokenizer=None, rel_mode="auto", rel_threshold=0.4, cause_decision="cls+span") -> list:
240
+ """
241
+ HuggingFace-compatible prediction method for causal extraction.
242
+ Args:
243
+ sents (List[str]): List of input sentences.
244
+ tokenizer: Optional HuggingFace tokenizer. If None, uses self.encoder_name.
245
+ rel_mode (str): 'auto' or 'head'.
246
+ rel_threshold (float): Probability threshold for relation extraction.
247
+ cause_decision (str): 'cls_only', 'span_only', or 'cls+span'.
248
+ Returns:
249
+ List of dicts with 'text', 'causal', and 'relations' fields for each sentence.
250
+ """
251
+ # Use id2label_bio from the module-level import instead of importing here
252
+ if tokenizer is None:
253
+ from transformers import AutoTokenizer
254
+ tokenizer = AutoTokenizer.from_pretrained(self.encoder_name)
255
+ device = next(self.parameters()).device
256
+ outs = []
257
+ for txt in sents:
258
+ enc = tokenizer([txt], return_tensors="pt", truncation=True, max_length=512)
259
+ enc = {k: v.to(device) for k, v in enc.items()}
260
+ with torch.no_grad():
261
+ rel_args = {}
262
+ rel_pair_spans = []
263
+ # Always prepare relation extraction arguments if needed (for head mode or auto mode with multi C/E)
264
+ if rel_mode == "head":
265
+ res_tmp = self(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])
266
+ bio_tmp = res_tmp["bio_emissions"].squeeze(0).argmax(-1).tolist()
267
+ tok_tmp = tokenizer.convert_ids_to_tokens(enc["input_ids"].squeeze(0))
268
+ lab_tmp = [id2label_bio[i] for i in bio_tmp]
269
+ fixed_tmp = JointCausalModel._apply_bio_rules(tok_tmp, lab_tmp)
270
+ spans_tmp = JointCausalModel._merge_spans(tok_tmp, fixed_tmp)
271
+ c_spans = [s for s in spans_tmp if s.role in ("C", "CE")]
272
+ e_spans = [s for s in spans_tmp if s.role in ("E", "CE")]
273
+ pair_batch = []
274
+ cause_starts = []
275
+ cause_ends = []
276
+ effect_starts = []
277
+ effect_ends = []
278
+ for c in c_spans:
279
+ for e in e_spans:
280
+ if c.start_tok == e.start_tok and c.end_tok == e.end_tok:
281
+ continue
282
+ pair_batch.append(0)
283
+ cause_starts.append(c.start_tok)
284
+ cause_ends.append(c.end_tok)
285
+ effect_starts.append(e.start_tok)
286
+ effect_ends.append(e.end_tok)
287
+ rel_pair_spans.append((c, e))
288
+ if pair_batch:
289
+ rel_args = {
290
+ "pair_batch": torch.tensor(pair_batch, device=device),
291
+ "cause_starts": torch.tensor(cause_starts, device=device),
292
+ "cause_ends": torch.tensor(cause_ends, device=device),
293
+ "effect_starts": torch.tensor(effect_starts, device=device),
294
+ "effect_ends": torch.tensor(effect_ends, device=device),
295
+ }
296
+ res = self(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], **rel_args)
297
+ cls = res["cls_logits"].squeeze(0)
298
+ bio = res["bio_emissions"].squeeze(0).argmax(-1).tolist()
299
+ tok = tokenizer.convert_ids_to_tokens(enc["input_ids"].squeeze(0))
300
+ lab = [id2label_bio[i] for i in bio]
301
+ fixed = JointCausalModel._apply_bio_rules(tok, lab)
302
+ spans = JointCausalModel._merge_spans(tok, fixed)
303
+ causal = JointCausalModel._decide_causal(cls, spans, cause_decision)
304
+ if not causal:
305
+ outs.append({"text": txt, "causal": False, "relations": []})
306
+ continue
307
+ rels = []
308
+ rel_logits = res.get("rel_logits")
309
+ rel_probs = None
310
+ if rel_logits is not None:
311
+ rel_probs = torch.softmax(rel_logits, dim=-1)
312
+ if rel_mode == "head":
313
+ for idx, (csp, esp) in enumerate(rel_pair_spans):
314
+ if rel_probs[idx, 1].item() > rel_threshold:
315
+ rels.append({"cause": csp.text, "effect": esp.text, "type": "Rel_CE"})
316
+ elif rel_mode == "auto":
317
+ c_spans = [s for s in spans if s.role in ("C", "CE")]
318
+ e_spans = [s for s in spans if s.role in ("E", "CE")]
319
+ if not c_spans or not e_spans:
320
+ rels = []
321
+ elif len(c_spans) == 1 and len(e_spans) >= 1:
322
+ for e in e_spans:
323
+ rels.append({"cause": c_spans[0].text, "effect": e.text, "type": "Rel_CE"})
324
+ elif len(e_spans) == 1 and len(c_spans) >= 1:
325
+ for c in c_spans:
326
+ rels.append({"cause": c.text, "effect": e_spans[0].text, "type": "Rel_CE"})
327
+ elif len(c_spans) > 1 and len(e_spans) > 1:
328
+ pair_batch = []
329
+ cause_starts = []
330
+ cause_ends = []
331
+ effect_starts = []
332
+ effect_ends = []
333
+ rel_pair_spans = []
334
+ for c in c_spans:
335
+ for e in e_spans:
336
+ if (c.start_tok == e.start_tok and c.end_tok == e.end_tok):
337
+ continue
338
+ pair_batch.append(0)
339
+ cause_starts.append(c.start_tok)
340
+ cause_ends.append(c.end_tok)
341
+ effect_starts.append(e.start_tok)
342
+ effect_ends.append(e.end_tok)
343
+ rel_pair_spans.append((c, e))
344
+ if pair_batch:
345
+ rel_args = {
346
+ "pair_batch": torch.tensor(pair_batch, device=device),
347
+ "cause_starts": torch.tensor(cause_starts, device=device),
348
+ "cause_ends": torch.tensor(cause_ends, device=device),
349
+ "effect_starts": torch.tensor(effect_starts, device=device),
350
+ "effect_ends": torch.tensor(effect_ends, device=device),
351
+ }
352
+ res_rel = self(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], **rel_args)
353
+ rel_logits = res_rel.get("rel_logits")
354
+ if rel_logits is not None:
355
+ rel_probs = torch.softmax(rel_logits, dim=-1)
356
+ for idx, (csp, esp) in enumerate(rel_pair_spans):
357
+ if rel_probs[idx, 1].item() > rel_threshold:
358
+ rels.append({"cause": csp.text, "effect": esp.text, "type": "Rel_CE"})
359
+ if cause_decision == "cls_only":
360
+ causal = cls.argmax(-1).item() == 1
361
+ elif cause_decision == "span_only":
362
+ causal = any(x.role == "C" for x in spans) and any(x.role == "E" for x in spans)
363
+ elif cause_decision == "cls+span":
364
+ causal = (cls.argmax(-1).item() == 1) and (any(x.role == "C" for x in spans) and any(x.role == "E" for x in spans))
365
+ else:
366
+ raise ValueError(cause_decision)
367
+ if not rels:
368
+ outs.append({"text": txt, "causal": False, "relations": []})
369
+ else:
370
+ outs.append({"text": txt, "causal": causal, "relations": rels})
371
+ return outs
372
+
373
+ @staticmethod
374
+ def _apply_bio_rules(tok, lab):
375
+ """
376
+ Apply post-processing rules to BIO tags to fix inconsistencies and clean up spans.
377
+ - Fixes mixed-role spans, punctuation, short tokens, and CE disambiguation.
378
+ """
379
+ # Constants for punctuation, stopwords, and connectors
380
+ _PUNCT = {".",",",";",":","?","!","(",")","[","]","{","}"}
381
+ _STOPWORD_KEEP = {"this","that","these","those","it","they"}
382
+
383
+ rep, n = lab.copy(), len(tok)
384
+ def blocks():
385
+ i=0
386
+ while i<n:
387
+ if rep[i]=="O": i+=1; continue
388
+ s=i
389
+ while i+1<n and rep[i+1]!="O": i+=1
390
+ yield s,i; i+=1
391
+ # B‑1′: Fix mixed-role spans
392
+ for s,e in list(blocks()):
393
+ roles=[rep[j].split("-")[-1] for j in range(s,e+1)]
394
+ if len(set(roles))>1:
395
+ split=next((j for j in range(s+1,e+1) if roles[j-s]!=roles[j-s-1]),None)
396
+ if split:
397
+ if 1 in {split-s,e-split+1}:
398
+ maj=roles[0] if split-s>e-split+1 else roles[-1]
399
+ for j in range(s,e+1): rep[j]=f"B-{maj}" if j==s else f"I-{maj}"
400
+ # B‑2: Remove labels from punctuation
401
+ for i,t in enumerate(tok):
402
+ if rep[i]!="O" and t in _PUNCT: rep[i]="O"
403
+ # helper: extract labeled blocks
404
+ def labeled(v):
405
+ i=0; out=[]
406
+ while i<n:
407
+ if v[i]=="O": i+=1; continue
408
+ s=i; role=v[i].split("-")[-1]
409
+ while i+1<n and v[i+1]!="O": i+=1
410
+ out.append((s,i,role)); i+=1
411
+ return out
412
+ bl=labeled(rep)
413
+ # B‑4: Disambiguate CE to C or E if only one present
414
+ if any(r=="CE" for *_,r in bl):
415
+ cntc=sum(1 for *_,r in bl if r=="C"); cnte=sum(1 for *_,r in bl if r=="E")
416
+ if cntc==0 or cnte==0:
417
+ tr="C" if cntc==0 else "E"
418
+ for s,e,r in bl:
419
+ if r=="CE":
420
+ for idx in range(s,e+1): rep[idx]=f"B-{tr}" if idx==s else f"I-{tr}"
421
+ bl=labeled(rep)
422
+ # B‑5/6: Remove labels from short/stopword tokens and trailing punctuation
423
+ for s,e,_ in bl:
424
+ if tok[e] in _PUNCT: rep[e]="O"
425
+ if e==s and len(tok[s])<=2 and tok[s].lower() not in _STOPWORD_KEEP: rep[s]="O"
426
+ return rep
427
+
428
+ @staticmethod
429
+ def _merge_spans(tok, lab):
430
+ """
431
+ Merge contiguous labeled tokens into Span objects, gluing across connectors.
432
+ """
433
+ from transformers import AutoTokenizer
434
+ try:
435
+ from .config import MODEL_CONFIG
436
+ except ImportError:
437
+ from config import MODEL_CONFIG
438
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_CONFIG["encoder_name"])
439
+ _CONNECTORS = {"of", "to", "with", "for", "the", "and"}
440
+ _STOPWORDS = {"this", "that", "these", "those", "it", "they"}
441
+ spans = []
442
+ i = 0
443
+ while i < len(tok):
444
+ if lab[i] == "O":
445
+ i += 1
446
+ continue
447
+ role = lab[i].split("-")[-1]
448
+ s = i
449
+ while i + 1 < len(tok) and lab[i + 1] != "O":
450
+ i += 1
451
+ spans.append(Span(role, s, i, tokenizer.convert_tokens_to_string(tok[s:i + 1])))
452
+ i += 1
453
+
454
+ merged = [spans[0]] if spans else []
455
+ for sp in spans[1:]:
456
+ prv = merged[-1]
457
+ if sp.role == prv.role and sp.start_tok == prv.end_tok + 2 and tok[prv.end_tok + 1].lower() in _CONNECTORS:
458
+ # Check if the current span starts with a B tag and a connector is present
459
+ if lab[sp.start_tok].startswith("B") and tok[prv.end_tok + 1].lower() == "and":
460
+ merged.append(sp)
461
+ else:
462
+ merged[-1] = Span(
463
+ prv.role,
464
+ prv.start_tok,
465
+ sp.end_tok,
466
+ tokenizer.convert_tokens_to_string(tok[prv.start_tok:sp.end_tok + 1]),
467
+ prv.is_virtual
468
+ )
469
+ else:
470
+ merged.append(sp)
471
+
472
+ # Ensure spans are split when a new span starts with a B tag and a connector is present
473
+ final_spans = []
474
+ for span in merged:
475
+ tokens = tokenizer.tokenize(span.text)
476
+ if "and" in tokens:
477
+ split_idx = tokens.index("and")
478
+ first_part = tokenizer.convert_tokens_to_string(tokens[:split_idx])
479
+ second_part = tokenizer.convert_tokens_to_string(tokens[split_idx + 1:])
480
+ final_spans.append(Span(span.role, span.start_tok, span.start_tok + len(first_part.split()), first_part))
481
+ final_spans.append(Span(span.role, span.start_tok + len(first_part.split()) + 1, span.end_tok, second_part))
482
+ else:
483
+ # Trim stopwords from the start and end of the span only if the span length is greater than 1
484
+ if len(tokens) > 1:
485
+ trimmed_tokens = [t for t in tokens if t.lower() not in _STOPWORDS]
486
+ else:
487
+ trimmed_tokens = tokens
488
+ trimmed_text = tokenizer.convert_tokens_to_string(trimmed_tokens)
489
+ final_spans.append(Span(span.role, span.start_tok, span.end_tok, trimmed_text))
490
+
491
+ return final_spans
492
+
493
+ @staticmethod
494
+ def _decide_causal(cls, spans, mode):
495
+ if mode == "cls_only":
496
+ return cls.argmax(-1).item() == 1
497
+ elif mode == "span_only":
498
+ return any(x.role == "C" for x in spans) and any(x.role == "E" for x in spans)
499
+ elif mode == "cls+span":
500
+ return (cls.argmax(-1).item() == 1) and (any(x.role == "C" for x in spans) and any(x.role == "E" for x in spans))
501
+ else:
502
+ raise ValueError(mode)
503
+
special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "cls_token": "[CLS]",
3
+ "mask_token": "[MASK]",
4
+ "pad_token": "[PAD]",
5
+ "sep_token": "[SEP]",
6
+ "unk_token": "[UNK]"
7
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "100": {
12
+ "content": "[UNK]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "101": {
20
+ "content": "[CLS]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "102": {
28
+ "content": "[SEP]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "103": {
36
+ "content": "[MASK]",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "clean_up_tokenization_spaces": false,
45
+ "cls_token": "[CLS]",
46
+ "do_lower_case": true,
47
+ "extra_special_tokens": {},
48
+ "mask_token": "[MASK]",
49
+ "model_max_length": 512,
50
+ "pad_token": "[PAD]",
51
+ "sep_token": "[SEP]",
52
+ "strip_accents": null,
53
+ "tokenize_chinese_chars": true,
54
+ "tokenizer_class": "BertTokenizer",
55
+ "unk_token": "[UNK]"
56
+ }
vocab.txt ADDED
The diff for this file is too large to render. See raw diff