| from __future__ import annotations
|
| from typing import Dict, List
|
| import torch
|
| import torch.nn as nn
|
| from transformers import AutoModel, PretrainedConfig, PreTrainedModel
|
| from dataclasses import dataclass
|
| try:
|
| from .config import id2label_bio, id2label_rel, id2label_cls
|
| except ImportError:
|
| from config import id2label_bio, id2label_rel, id2label_cls
|
|
|
| try:
|
| from .configuration_joint_causal import JointCausalConfig
|
| except ImportError:
|
| from configuration_joint_causal import JointCausalConfig
|
|
|
|
|
|
|
|
|
# Inverted lookup tables (label string -> integer id) for each task head,
# derived from the project-level id -> label maps imported above.
label2id_bio = dict(map(reversed, id2label_bio.items()))
label2id_rel = dict(map(reversed, id2label_rel.items()))
label2id_cls = dict(map(reversed, id2label_cls.items()))
|
|
|
|
|
|
|
|
|
| """Joint Causal Extraction Model (softmax)
|
| ============================================================================
|
|
|
| A PyTorch module for joint causal extraction using softmax decoding for BIO tagging.
|
| The model supports class weights for handling imbalanced data.
|
|
|
| ```python
|
| >>> model = JointCausalModel() # softmax-based model
|
| """
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class Span:
    """A contiguous token span carrying a causal role.

    Produced by ``JointCausalModel._merge_spans`` and consumed by the
    relation-pairing logic in ``JointCausalModel.predict``.
    """

    # Role with the BIO prefix stripped: "C" (cause), "E" (effect) or
    # "CE" (both), as emitted by _merge_spans.
    role: str
    # Inclusive index of the first token of the span.
    start_tok: int
    # Inclusive index of the last token of the span.
    end_tok: int
    # Detokenized surface form (output of tokenizer.convert_tokens_to_string).
    text: str
    # NOTE(review): never set to True anywhere in this file — presumably
    # marks spans synthesized by external post-processing; confirm before
    # relying on it.
    is_virtual: bool = False
|
|
|
|
|
|
|
|
|
|
|
|
|
class JointCausalConfig(PretrainedConfig):
    """Configuration for :class:`JointCausalModel`.

    Holds the encoder checkpoint name, the label-space sizes for the three
    heads (sentence classification, BIO tagging, relation classification)
    and the shared dropout probability.

    NOTE(review): a ``JointCausalConfig`` is already imported at the top of
    this file from ``configuration_joint_causal``; this local definition
    shadows that import — confirm which one is meant to be canonical.
    """

    # Registry key used by the Hugging Face auto classes.
    model_type = "joint_causal"

    def __init__(
        self,
        encoder_name="bert-base-uncased",  # HF hub id of the encoder backbone
        num_cls_labels=2,                  # sentence-level classes
        num_bio_labels=7,                  # size of the BIO tag inventory
        num_rel_labels=2,                  # relation classes for span pairs
        dropout=0.2,                       # dropout prob shared by all heads
        **kwargs,
    ):
        # Store the task-specific hyper-parameters, then hand the generic HF
        # options (id2label, etc.) to PretrainedConfig.
        self.encoder_name = encoder_name
        self.num_cls_labels = num_cls_labels
        self.num_bio_labels = num_bio_labels
        self.num_rel_labels = num_rel_labels
        self.dropout = dropout
        super().__init__(**kwargs)
|
|
|
|
|
|
|
class JointCausalModel(PreTrainedModel):
    """
    The updated JointCausalModel, inheriting from PreTrainedModel.

    Wraps a pretrained transformer encoder with three heads:

    * ``cls_head`` — sentence-level logits over ``num_cls_labels`` classes,
      computed from the first ([CLS]) token's hidden state.
    * ``bio_head`` — per-token BIO emission logits over ``num_bio_labels``.
    * ``rel_head`` — logits over ``num_rel_labels`` for a (cause, effect)
      span pair, computed from the concatenated mean-pooled span vectors.
    """

    config_class = JointCausalConfig

    def __init__(self, config: JointCausalConfig):
        """Build the encoder backbone and the three task heads from *config*."""
        super().__init__(config)
        self.config = config

        # Transformer backbone; its hidden size drives every head's shape.
        self.enc = AutoModel.from_pretrained(config.encoder_name)
        self.hidden_size = self.enc.config.hidden_size
        self.dropout = nn.Dropout(config.dropout)
        self.layer_norm = nn.LayerNorm(self.hidden_size)

        # Sentence-level classification head (fed the [CLS] position).
        self.cls_head = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(self.hidden_size // 2, config.num_cls_labels),
        )
        # Token-level BIO tagging head.
        self.bio_head = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(self.hidden_size, self.hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(self.hidden_size // 2, config.num_bio_labels),
        )
        # Relation head; input is [cause_vec ; effect_vec], hence the 2x width.
        self.rel_head = nn.Sequential(
            nn.Linear(self.hidden_size * 2, self.hidden_size),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(self.hidden_size, self.hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(self.hidden_size // 2, config.num_rel_labels),
        )

    def encode(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Run the encoder and return dropout + layer-normed hidden states."""
        hidden_states = self.enc(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        return self.layer_norm(self.dropout(hidden_states))

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        bio_labels: torch.Tensor | None = None,
        pair_batch: torch.Tensor | None = None,     # batch row index for each candidate pair
        cause_starts: torch.Tensor | None = None,   # inclusive token bounds per pair
        cause_ends: torch.Tensor | None = None,
        effect_starts: torch.Tensor | None = None,
        effect_ends: torch.Tensor | None = None,
        **kwargs,
    ) -> Dict[str, torch.Tensor | None]:
        """Joint forward pass.

        Returns a dict with ``cls_logits``, ``bio_emissions``, ``tag_loss``
        and ``rel_logits`` (the latter two may be None when their inputs are
        not supplied).
        """
        hidden = self.encode(input_ids, attention_mask)
        # Sentence logits from the first ([CLS]) token of each sequence.
        cls_logits = self.cls_head(hidden[:, 0])
        emissions = self.bio_head(hidden)
        # NOTE(review): this is a constant 0.0 placeholder — no actual tagging
        # loss is computed even when bio_labels is provided; confirm whether
        # the real loss lives in the training loop.
        tag_loss = torch.tensor(0.0, device=emissions.device) if bio_labels is not None else None

        rel_logits: torch.Tensor | None = None
        # Relation scoring only runs when a complete pair specification is given.
        if pair_batch is not None and cause_starts is not None and cause_ends is not None \
                and effect_starts is not None and effect_ends is not None:
            # Gather the hidden states of the sentence each pair belongs to.
            bio_states_for_rel = hidden[pair_batch]
            seq_len_rel = bio_states_for_rel.size(1)
            pos_rel = torch.arange(seq_len_rel, device=bio_states_for_rel.device).unsqueeze(0)
            # Boolean masks selecting the (inclusive) cause/effect token ranges.
            c_mask = ((cause_starts.unsqueeze(1) <= pos_rel) & (pos_rel <= cause_ends.unsqueeze(1))).unsqueeze(2)
            e_mask = ((effect_starts.unsqueeze(1) <= pos_rel) & (pos_rel <= effect_ends.unsqueeze(1))).unsqueeze(2)
            # Mean-pool each span; clamp(min=1) guards against division by zero
            # for an (unexpected) empty span.
            c_vec = (bio_states_for_rel * c_mask).sum(1) / c_mask.sum(1).clamp(min=1)
            e_vec = (bio_states_for_rel * e_mask).sum(1) / e_mask.sum(1).clamp(min=1)
            rel_logits = self.rel_head(torch.cat([c_vec, e_vec], dim=1))

        return {
            "cls_logits": cls_logits,
            "bio_emissions": emissions,
            "tag_loss": tag_loss,
            "rel_logits": rel_logits,
        }

    def predict(self, sents: List[str], tokenizer, rel_mode="auto", rel_threshold=0.4, cause_decision="cls+span") -> list:
        """Decode causal relations for raw sentences (one at a time).

        Args:
            sents: sentences to analyse.
            tokenizer: HF tokenizer matching the encoder backbone.
            rel_mode: "head" scores every candidate (cause, effect) pair with
                ``rel_head`` up front; "auto" pairs spans heuristically and
                only falls back to ``rel_head`` when both sides are ambiguous
                (more than one cause AND more than one effect).
            rel_threshold: probability cut-off on the positive relation class
                (index 1) for ``rel_head`` decisions.
            cause_decision: forwarded to :meth:`_decide_causal`
                ("cls_only", "span_only" or "cls+span").

        Returns:
            One dict per sentence: ``{"text", "causal", "relations"}``.
        """
        device = next(self.parameters()).device
        outs = []
        for txt in sents:
            enc = tokenizer([txt], return_tensors="pt", truncation=True, max_length=512)
            enc = {k: v.to(device) for k, v in enc.items()}
            with torch.no_grad():
                rel_args = {}
                rel_pair_spans = []
                if rel_mode == "head":
                    # Preliminary pass: tag the sentence so candidate
                    # (cause, effect) pairs can be fed to rel_head below.
                    res_tmp = self(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])
                    bio_tmp = res_tmp["bio_emissions"].squeeze(0).argmax(-1).tolist()
                    tok_tmp = tokenizer.convert_ids_to_tokens(enc["input_ids"].squeeze(0))
                    lab_tmp = [id2label_bio[i] for i in bio_tmp]
                    fixed_tmp = JointCausalModel._apply_bio_rules(tok_tmp, lab_tmp)

                    spans_tmp = JointCausalModel._merge_spans(tok_tmp, fixed_tmp, tokenizer)

                    c_spans = [s for s in spans_tmp if s.role in ("C", "CE")]
                    e_spans = [s for s in spans_tmp if s.role in ("E", "CE")]
                    pair_batch = []
                    cause_starts = []
                    cause_ends = []
                    effect_starts = []
                    effect_ends = []
                    # Cross-product of cause/effect candidates, skipping
                    # self-pairs (identical token range on both sides).
                    for c in c_spans:
                        for e in e_spans:
                            if c.start_tok == e.start_tok and c.end_tok == e.end_tok:
                                continue
                            # Single sentence per forward pass -> batch row 0.
                            pair_batch.append(0)
                            cause_starts.append(c.start_tok)
                            cause_ends.append(c.end_tok)
                            effect_starts.append(e.start_tok)
                            effect_ends.append(e.end_tok)
                            rel_pair_spans.append((c, e))
                    if pair_batch:
                        rel_args = {
                            "pair_batch": torch.tensor(pair_batch, device=device),
                            "cause_starts": torch.tensor(cause_starts, device=device),
                            "cause_ends": torch.tensor(cause_ends, device=device),
                            "effect_starts": torch.tensor(effect_starts, device=device),
                            "effect_ends": torch.tensor(effect_ends, device=device),
                        }

                # Main pass (re-tags the sentence; includes rel scoring in
                # "head" mode when pairs were found).
                res = self(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], **rel_args)
                cls = res["cls_logits"].squeeze(0)
                bio = res["bio_emissions"].squeeze(0).argmax(-1).tolist()
                tok = tokenizer.convert_ids_to_tokens(enc["input_ids"].squeeze(0))
                lab = [id2label_bio[i] for i in bio]
                fixed = JointCausalModel._apply_bio_rules(tok, lab)

                spans = JointCausalModel._merge_spans(tok, fixed, tokenizer)

                causal = JointCausalModel._decide_causal(cls, spans, cause_decision)
                if not causal:
                    outs.append({"text": txt, "causal": False, "relations": []})
                    continue
                rels = []
                rel_logits = res.get("rel_logits")
                rel_probs = None
                if rel_logits is not None:
                    rel_probs = torch.softmax(rel_logits, dim=-1)
                if rel_mode == "head":
                    # Keep pairs whose positive-class probability clears the bar.
                    for idx, (csp, esp) in enumerate(rel_pair_spans):
                        if rel_probs[idx, 1].item() > rel_threshold:
                            rels.append({"cause": csp.text, "effect": esp.text, "type": "Rel_CE"})
                elif rel_mode == "auto":
                    c_spans = [s for s in spans if s.role in ("C", "CE")]
                    e_spans = [s for s in spans if s.role in ("E", "CE")]
                    if not c_spans or not e_spans:
                        rels = []
                    elif len(c_spans) == 1 and len(e_spans) >= 1:
                        # Unambiguous cause: link it to every effect.
                        for e_val in e_spans:
                            rels.append({"cause": c_spans[0].text, "effect": e_val.text, "type": "Rel_CE"})
                    elif len(e_spans) == 1 and len(c_spans) >= 1:
                        # Unambiguous effect: link every cause to it.
                        for c_val in c_spans:
                            rels.append({"cause": c_val.text, "effect": e_spans[0].text, "type": "Rel_CE"})
                    elif len(c_spans) > 1 and len(e_spans) > 1:
                        # Ambiguous on both sides: fall back to rel_head scoring.
                        pair_batch_auto = []
                        cause_starts_auto = []
                        cause_ends_auto = []
                        effect_starts_auto = []
                        effect_ends_auto = []
                        rel_pair_spans_auto = []
                        for c_val in c_spans:
                            for e_val in e_spans:
                                if (c_val.start_tok == e_val.start_tok and c_val.end_tok == e_val.end_tok):
                                    continue
                                pair_batch_auto.append(0)
                                cause_starts_auto.append(c_val.start_tok)
                                cause_ends_auto.append(c_val.end_tok)
                                effect_starts_auto.append(e_val.start_tok)
                                effect_ends_auto.append(e_val.end_tok)
                                rel_pair_spans_auto.append((c_val, e_val))
                        if pair_batch_auto:
                            rel_args_auto = {
                                "pair_batch": torch.tensor(pair_batch_auto, device=device),
                                "cause_starts": torch.tensor(cause_starts_auto, device=device),
                                "cause_ends": torch.tensor(cause_ends_auto, device=device),
                                "effect_starts": torch.tensor(effect_starts_auto, device=device),
                                "effect_ends": torch.tensor(effect_ends_auto, device=device),
                            }
                            # Extra forward pass just for relation scoring.
                            res_rel = self(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], **rel_args_auto)
                            rel_logits_auto = res_rel.get("rel_logits")
                            if rel_logits_auto is not None:
                                rel_probs_auto = torch.softmax(rel_logits_auto, dim=-1)
                                for idx, (csp, esp) in enumerate(rel_pair_spans_auto):
                                    if rel_probs_auto[idx, 1].item() > rel_threshold:
                                        rels.append({"cause": csp.text, "effect": esp.text, "type": "Rel_CE"})

                # A sentence with no surviving relations is demoted to
                # non-causal. NOTE(review): the `or not rels` condition below
                # already covers this, so this reset is redundant.
                is_causal_final = JointCausalModel._decide_causal(cls, spans, cause_decision)
                if not rels and is_causal_final:
                    is_causal_final = False

                if not is_causal_final or not rels:
                    outs.append({"text": txt, "causal": False, "relations": []})
                else:
                    outs.append({"text": txt, "causal": True, "relations": rels})
        return outs

    @staticmethod
    def _apply_bio_rules(tok, lab):
        """Heuristic cleanup of raw BIO predictions; returns a new label list."""
        _PUNCT = {".",",",";",":","?","!","(",")","[","]","{","}"}
        # Short (<=2 char) single-token spans are dropped unless the token is
        # one of these pronouns.
        _STOPWORD_KEEP = {"this","that","these","those","it","they"}

        rep, n = lab.copy(), len(tok)

        def blocks():
            # Yield (start, end) of each maximal run of non-"O" labels.
            i=0
            while i<n:
                if rep[i]=="O": i+=1; continue
                s_idx=i
                while i+1<n and rep[i+1]!="O": i+=1
                yield s_idx,i; i+=1

        # Pass 1: a contiguous block mixing two roles where one side is a
        # single token is relabelled entirely with the majority role.
        for s_idx,e_idx in list(blocks()):
            roles=[rep[j].split("-")[-1] for j in range(s_idx,e_idx+1)]
            if len(set(roles))>1:
                split=next((j for j in range(s_idx+1,e_idx+1) if roles[j-s_idx]!=roles[j-s_idx-1]),None)
                # Truthiness check is safe: split is either None or >= 1,
                # because the range starts at s_idx+1 and s_idx >= 0.
                if split:
                    if 1 in {split-s_idx,e_idx-split+1}:
                        maj=roles[0] if split-s_idx>e_idx-split+1 else roles[-1]
                        for j in range(s_idx,e_idx+1): rep[j]=f"B-{maj}" if j==s_idx else f"I-{maj}"

        # Pass 2: punctuation tokens never carry a role.
        for i,t in enumerate(tok):
            if rep[i]!="O" and t in _PUNCT: rep[i]="O"

        def labeled(v):
            # Collect (start, end, role) per non-"O" block; the role is taken
            # from the block's first label.
            i=0; out=[]
            while i<n:
                if v[i]=="O": i+=1; continue
                s_idx=i; role=v[i].split("-")[-1]
                while i+1<n and v[i+1]!="O": i+=1
                out.append((s_idx,i,role)); i+=1
            return out

        bl=labeled(rep)
        # Pass 3: if "CE" blocks exist but one of the plain roles is entirely
        # missing, convert the CE blocks into that missing role.
        if any(r=="CE" for *_,r in bl):
            cntc=sum(1 for *_,r in bl if r=="C"); cnte=sum(1 for *_,r in bl if r=="E")
            if cntc==0 or cnte==0:
                tr="C" if cntc==0 else "E"
                for s_idx,e_idx,r in bl:
                    if r=="CE":
                        for idx in range(s_idx,e_idx+1): rep[idx]=f"B-{tr}" if idx==s_idx else f"I-{tr}"

        # Pass 4: strip trailing punctuation and drop very short single-token
        # spans (unless whitelisted above).
        bl=labeled(rep)
        for s_idx,e_idx,_ in bl:
            if tok[e_idx] in _PUNCT: rep[e_idx]="O"
            if e_idx==s_idx and len(tok[s_idx])<=2 and tok[s_idx].lower() not in _STOPWORD_KEEP: rep[s_idx]="O"
        return rep

    @staticmethod
    def _merge_spans(tok: List[str], lab: List[str], tokenizer):
        """Group cleaned BIO labels into :class:`Span` objects.

        Two same-role spans separated by exactly one connector token
        ("of", "to", "with", "for", "the") are merged into a single span.
        """
        _CONNECTORS = {"of","to","with","for","the"}
        spans_list = []
        i=0
        # First pass: one Span per maximal non-"O" run.
        while i<len(tok):
            if lab[i]=="O": i+=1; continue
            role=lab[i].split("-")[-1]; s_idx=i
            while i+1<len(tok) and lab[i+1]!="O": i+=1

            spans_list.append(Span(role,s_idx,i,tokenizer.convert_tokens_to_string(tok[s_idx:i+1])))
            i+=1

        # Second pass: bridge spans that are adjacent except for one
        # connector token between them.
        merged=[spans_list[0]] if spans_list else []
        for sp_item in spans_list[1:]:
            prv=merged[-1]
            if sp_item.role==prv.role and sp_item.start_tok==prv.end_tok+2 and tok[prv.end_tok+1].lower() in _CONNECTORS:
                merged[-1]=Span(prv.role,prv.start_tok,sp_item.end_tok,tokenizer.convert_tokens_to_string(tok[prv.start_tok:sp_item.end_tok+1]),prv.is_virtual)
            else: merged.append(sp_item)
        return merged

    @staticmethod
    def _decide_causal(cls, spans, mode):
        """Combine classifier output and span evidence into a causal verdict.

        ``cls`` is the 1-D sentence logits tensor; index 1 is treated as the
        "causal" class. ``mode`` selects which evidence counts.
        """
        # Span evidence: at least one cause-ish and one effect-ish span.
        span_ok = any(s.role == "C" or s.role == "CE" for s in spans) and \
            any(s.role == "E" or s.role == "CE" for s in spans)

        # NOTE(review): this branch recomputes exactly the condition above,
        # so it is redundant — kept as-is for behavior preservation.
        if any(s.role == "CE" for s in spans):

            has_cause_element = any(s.role == "C" or s.role == "CE" for s in spans)
            has_effect_element = any(s.role == "E" or s.role == "CE" for s in spans)
            span_ok = has_cause_element and has_effect_element

        cls_c = cls.argmax(-1).item() == 1
        if mode == "cls_only":
            return cls_c
        if mode == "span_only":
            return span_ok
        if mode == "cls+span":
            return cls_c and span_ok
        raise ValueError(f"Unknown cause_decision mode: {mode}")
|
|
|
|
|