# CausaLMiner / modeling_joint_causal.py
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, List

import torch
import torch.nn as nn
from transformers import AutoModel, PretrainedConfig, PreTrainedModel
# Support both package-relative and flat-script imports.
try:
    from .config import id2label_bio, id2label_rel, id2label_cls
except ImportError:
    from config import id2label_bio, id2label_rel, id2label_cls
# ---------------------------------------------------------------------------
# Type aliases & label maps
# ---------------------------------------------------------------------------
label2id_bio = {v: k for k, v in id2label_bio.items()}
label2id_rel = {v: k for k, v in id2label_rel.items()}
label2id_cls = {v: k for k, v in id2label_cls.items()}
# ---------------------------------------------------------------------------
# Module overview
# ---------------------------------------------------------------------------
"""Joint Causal Extraction Model (softmax)
============================================================================
A PyTorch module for joint causal extraction that uses softmax decoding for
BIO tagging. The model supports class weights for handling imbalanced data.

    >>> model = JointCausalModel(JointCausalConfig())  # softmax-based model
"""
# ---------------------------------------------------------------------------
# Span dataclass
# ---------------------------------------------------------------------------
@dataclass
class Span:
    """A contiguous tagged span; role is "C" (cause), "E" (effect) or "CE" (both)."""
    role: str
    start_tok: int  # index of the first token in the span
    end_tok: int    # index of the last token (inclusive)
    text: str
    is_virtual: bool = False
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
class JointCausalConfig(PretrainedConfig):
    """Configuration for JointCausalModel: encoder name, head sizes, and dropout."""

    model_type = "joint_causal"  # unique name identifying this model type
def __init__(
self,
encoder_name="bert-base-uncased",
num_cls_labels=2,
num_bio_labels=7,
num_rel_labels=2,
dropout=0.2,
**kwargs,
):
self.encoder_name = encoder_name
self.num_cls_labels = num_cls_labels
self.num_bio_labels = num_bio_labels
self.num_rel_labels = num_rel_labels
self.dropout = dropout
super().__init__(**kwargs)
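# Example: overriding defaults for a different encoder (values shown here are
# illustrative, not the shipped checkpoint's settings):
#
#     cfg = JointCausalConfig(encoder_name="roberta-base", dropout=0.1)
#     cfg.save_pretrained("./my_joint_causal")  # writes config.json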
# ---------------------------------------------------------------------------
# Main module
# ---------------------------------------------------------------------------
class JointCausalModel(PreTrainedModel):
    """Joint causal extraction model: a shared encoder feeding three heads for
    sentence classification, BIO tagging, and relation classification."""
    # Link the model to its config class so save/from_pretrained round-trip.
    config_class = JointCausalConfig

    def __init__(self, config: JointCausalConfig):
        super().__init__(config)
self.enc = AutoModel.from_pretrained(config.encoder_name)
self.hidden_size = self.enc.config.hidden_size
self.dropout = nn.Dropout(config.dropout)
self.layer_norm = nn.LayerNorm(self.hidden_size)
        # Build the three task heads from the config parameters.
self.cls_head = nn.Sequential(
nn.Linear(self.hidden_size, self.hidden_size // 2),
nn.ReLU(),
nn.Dropout(config.dropout),
nn.Linear(self.hidden_size // 2, config.num_cls_labels),
)
self.bio_head = nn.Sequential(
nn.Linear(self.hidden_size, self.hidden_size),
nn.ReLU(),
nn.Dropout(config.dropout),
nn.Linear(self.hidden_size, self.hidden_size // 2),
nn.ReLU(),
nn.Dropout(config.dropout),
nn.Linear(self.hidden_size // 2, config.num_bio_labels),
)
self.rel_head = nn.Sequential(
nn.Linear(self.hidden_size * 2, self.hidden_size),
nn.ReLU(),
nn.Dropout(config.dropout),
nn.Linear(self.hidden_size, self.hidden_size // 2),
nn.ReLU(),
nn.Dropout(config.dropout),
nn.Linear(self.hidden_size // 2, config.num_rel_labels),
)
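        # Head shape contracts (H = encoder hidden size, B = batch, T = tokens,
        # P = number of candidate cause/effect pairs):
        #   cls_head: [B, H]    -> [B, num_cls_labels]     (sentence classification)
        #   bio_head: [B, T, H] -> [B, T, num_bio_labels]  (token-level BIO tagging)
        #   rel_head: [P, 2H]   -> [P, num_rel_labels]     (relation classification)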
        # PreTrainedModel handles weight initialization; the encoder itself is
        # loaded with pretrained weights via AutoModel.from_pretrained above.
    def encode(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Encode tokens into normalized hidden states of shape [batch, seq_len, hidden]."""
        hidden_states = self.enc(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        return self.layer_norm(self.dropout(hidden_states))
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
bio_labels: torch.Tensor | None = None,
pair_batch: torch.Tensor | None = None,
cause_starts: torch.Tensor | None = None,
cause_ends: torch.Tensor | None = None,
effect_starts: torch.Tensor | None = None,
effect_ends: torch.Tensor | None = None,
**kwargs, # Accept extra arguments
) -> Dict[str, torch.Tensor | None]:
        hidden = self.encode(input_ids, attention_mask)
        cls_logits = self.cls_head(hidden[:, 0])  # [CLS] token drives the sentence classifier
        emissions = self.bio_head(hidden)         # per-token BIO logits
        # NOTE: the tagging loss is not computed here; a zero placeholder is
        # returned whenever labels are supplied so callers can always read it.
        tag_loss = torch.tensor(0.0, device=emissions.device) if bio_labels is not None else None
        rel_logits: torch.Tensor | None = None
        if pair_batch is not None and cause_starts is not None and cause_ends is not None \
                and effect_starts is not None and effect_ends is not None:
            # Gather the hidden states of the sentence each pair belongs to,
            # then mean-pool over the (inclusive) cause and effect token ranges.
            bio_states_for_rel = hidden[pair_batch]
            seq_len_rel = bio_states_for_rel.size(1)
            pos_rel = torch.arange(seq_len_rel, device=bio_states_for_rel.device).unsqueeze(0)
            c_mask = ((cause_starts.unsqueeze(1) <= pos_rel) & (pos_rel <= cause_ends.unsqueeze(1))).unsqueeze(2)
            e_mask = ((effect_starts.unsqueeze(1) <= pos_rel) & (pos_rel <= effect_ends.unsqueeze(1))).unsqueeze(2)
            c_vec = (bio_states_for_rel * c_mask).sum(1) / c_mask.sum(1).clamp(min=1)
            e_vec = (bio_states_for_rel * e_mask).sum(1) / e_mask.sum(1).clamp(min=1)
            # The concatenated [cause; effect] representation feeds the relation head.
            rel_logits = self.rel_head(torch.cat([c_vec, e_vec], dim=1))
return {
"cls_logits": cls_logits,
"bio_emissions": emissions,
"tag_loss": tag_loss,
"rel_logits": rel_logits,
}
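    # Sketch of scoring one candidate pair with forward() (``model`` and ``enc``
    # stand for a loaded model and a tokenized batch; token indices 1-3 and 6-9
    # are hypothetical, and all index ranges are inclusive):
    #
    #     out = model(
    #         input_ids=enc["input_ids"], attention_mask=enc["attention_mask"],
    #         pair_batch=torch.tensor([0]),  # pair 0 belongs to sentence 0
    #         cause_starts=torch.tensor([1]), cause_ends=torch.tensor([3]),
    #         effect_starts=torch.tensor([6]), effect_ends=torch.tensor([9]),
    #     )
    #     probs = torch.softmax(out["rel_logits"], dim=-1)  # [1, num_rel_labels]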
    def predict(
        self,
        sents: List[str],
        tokenizer,
        rel_mode: str = "auto",
        rel_threshold: float = 0.4,
        cause_decision: str = "cls+span",
    ) -> list:
        """Run end-to-end causal extraction on raw sentences. The tokenizer is
        passed in by the caller so it is loaded once, rather than instantiating
        a second AutoTokenizer on every call."""
        device = next(self.parameters()).device
outs = []
for txt in sents:
enc = tokenizer([txt], return_tensors="pt", truncation=True, max_length=512)
enc = {k: v.to(device) for k, v in enc.items()}
with torch.no_grad():
rel_args = {}
rel_pair_spans = []
if rel_mode == "head":
res_tmp = self(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])
bio_tmp = res_tmp["bio_emissions"].squeeze(0).argmax(-1).tolist()
tok_tmp = tokenizer.convert_ids_to_tokens(enc["input_ids"].squeeze(0))
lab_tmp = [id2label_bio[i] for i in bio_tmp]
fixed_tmp = JointCausalModel._apply_bio_rules(tok_tmp, lab_tmp)
# Pass tokenizer to _merge_spans
spans_tmp = JointCausalModel._merge_spans(tok_tmp, fixed_tmp, tokenizer)
# ... (rest of the head mode logic) ...
c_spans = [s for s in spans_tmp if s.role in ("C", "CE")]
e_spans = [s for s in spans_tmp if s.role in ("E", "CE")]
pair_batch = []
cause_starts = []
cause_ends = []
effect_starts = []
effect_ends = []
for c in c_spans:
for e in e_spans:
if c.start_tok == e.start_tok and c.end_tok == e.end_tok:
continue
pair_batch.append(0)
cause_starts.append(c.start_tok)
cause_ends.append(c.end_tok)
effect_starts.append(e.start_tok)
effect_ends.append(e.end_tok)
rel_pair_spans.append((c, e))
if pair_batch:
rel_args = {
"pair_batch": torch.tensor(pair_batch, device=device),
"cause_starts": torch.tensor(cause_starts, device=device),
"cause_ends": torch.tensor(cause_ends, device=device),
"effect_starts": torch.tensor(effect_starts, device=device),
"effect_ends": torch.tensor(effect_ends, device=device),
}
res = self(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], **rel_args)
cls = res["cls_logits"].squeeze(0)
bio = res["bio_emissions"].squeeze(0).argmax(-1).tolist()
tok = tokenizer.convert_ids_to_tokens(enc["input_ids"].squeeze(0))
lab = [id2label_bio[i] for i in bio]
fixed = JointCausalModel._apply_bio_rules(tok, lab)
                spans = JointCausalModel._merge_spans(tok, fixed, tokenizer)
causal = JointCausalModel._decide_causal(cls, spans, cause_decision)
if not causal:
outs.append({"text": txt, "causal": False, "relations": []})
continue
rels = []
rel_logits = res.get("rel_logits")
rel_probs = None
if rel_logits is not None:
rel_probs = torch.softmax(rel_logits, dim=-1)
if rel_mode == "head":
for idx, (csp, esp) in enumerate(rel_pair_spans):
if rel_probs[idx, 1].item() > rel_threshold: # Ensure rel_probs is not None
rels.append({"cause": csp.text, "effect": esp.text, "type": "Rel_CE"})
elif rel_mode == "auto":
c_spans = [s for s in spans if s.role in ("C", "CE")]
e_spans = [s for s in spans if s.role in ("E", "CE")]
if not c_spans or not e_spans:
rels = []
elif len(c_spans) == 1 and len(e_spans) >= 1:
for e_val in e_spans: # Renamed e to e_val to avoid conflict
rels.append({"cause": c_spans[0].text, "effect": e_val.text, "type": "Rel_CE"})
elif len(e_spans) == 1 and len(c_spans) >= 1:
for c_val in c_spans: # Renamed c to c_val
rels.append({"cause": c_val.text, "effect": e_spans[0].text, "type": "Rel_CE"})
                    elif len(c_spans) > 1 and len(e_spans) > 1:
                        # Ambiguous many-to-many case: fall back to the relation
                        # head to score every candidate cause/effect pair.
                        pair_batch_auto = []
                        cause_starts_auto = []
                        cause_ends_auto = []
                        effect_starts_auto = []
                        effect_ends_auto = []
                        rel_pair_spans_auto = []
                        for c_val in c_spans:
                            for e_val in e_spans:
                                if c_val.start_tok == e_val.start_tok and c_val.end_tok == e_val.end_tok:
                                    continue  # skip degenerate pairs where cause == effect
                                pair_batch_auto.append(0)
                                cause_starts_auto.append(c_val.start_tok)
                                cause_ends_auto.append(c_val.end_tok)
                                effect_starts_auto.append(e_val.start_tok)
                                effect_ends_auto.append(e_val.end_tok)
                                rel_pair_spans_auto.append((c_val, e_val))
                        if pair_batch_auto:
                            rel_args_auto = {
                                "pair_batch": torch.tensor(pair_batch_auto, device=device),
                                "cause_starts": torch.tensor(cause_starts_auto, device=device),
                                "cause_ends": torch.tensor(cause_ends_auto, device=device),
                                "effect_starts": torch.tensor(effect_starts_auto, device=device),
                                "effect_ends": torch.tensor(effect_ends_auto, device=device),
                            }
                            res_rel = self(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], **rel_args_auto)
                            rel_logits_auto = res_rel.get("rel_logits")
                            if rel_logits_auto is not None:
                                rel_probs_auto = torch.softmax(rel_logits_auto, dim=-1)
                                for idx, (csp, esp) in enumerate(rel_pair_spans_auto):
                                    if rel_probs_auto[idx, 1].item() > rel_threshold:
                                        rels.append({"cause": csp.text, "effect": esp.text, "type": "Rel_CE"})
                # Final decision: report the sentence as causal only if the
                # cls/span check passes AND at least one relation was extracted.
                is_causal_final = JointCausalModel._decide_causal(cls, spans, cause_decision)
                if is_causal_final and rels:
                    outs.append({"text": txt, "causal": True, "relations": rels})
                else:
                    outs.append({"text": txt, "causal": False, "relations": []})
        return outs
    @staticmethod
    def _apply_bio_rules(tok, lab):
        """Post-process raw BIO predictions with a small set of heuristics."""
        _PUNCT = {".", ",", ";", ":", "?", "!", "(", ")", "[", "]", "{", "}"}
        _STOPWORD_KEEP = {"this", "that", "these", "those", "it", "they"}
        rep, n = lab.copy(), len(tok)

        def blocks():
            # Yield (start, end) index pairs of contiguous non-"O" runs.
            i = 0
            while i < n:
                if rep[i] == "O":
                    i += 1
                    continue
                s_idx = i
                while i + 1 < n and rep[i + 1] != "O":
                    i += 1
                yield s_idx, i
                i += 1

        # Rule 1: if a block mixes roles and the segment on one side of the
        # first change point is a single token, relabel the whole block with
        # the longer side's role.
        for s_idx, e_idx in list(blocks()):
            roles = [rep[j].split("-")[-1] for j in range(s_idx, e_idx + 1)]
            if len(set(roles)) > 1:
                split = next((j for j in range(s_idx + 1, e_idx + 1)
                              if roles[j - s_idx] != roles[j - s_idx - 1]), None)
                if split and 1 in {split - s_idx, e_idx - split + 1}:
                    maj = roles[0] if split - s_idx > e_idx - split + 1 else roles[-1]
                    for j in range(s_idx, e_idx + 1):
                        rep[j] = f"B-{maj}" if j == s_idx else f"I-{maj}"

        # Rule 2: punctuation tokens never carry a role.
        for i, t in enumerate(tok):
            if rep[i] != "O" and t in _PUNCT:
                rep[i] = "O"

        def labeled(v):
            # Collect (start, end, role) triples for contiguous non-"O" runs.
            i, out = 0, []
            while i < n:
                if v[i] == "O":
                    i += 1
                    continue
                s_idx, role = i, v[i].split("-")[-1]
                while i + 1 < n and v[i + 1] != "O":
                    i += 1
                out.append((s_idx, i, role))
                i += 1
            return out

        # Rule 3: if CE spans exist but one of C/E is entirely missing,
        # convert the CE spans to the missing role.
        bl = labeled(rep)
        if any(r == "CE" for *_, r in bl):
            cntc = sum(1 for *_, r in bl if r == "C")
            cnte = sum(1 for *_, r in bl if r == "E")
            if cntc == 0 or cnte == 0:
                tr = "C" if cntc == 0 else "E"
                for s_idx, e_idx, r in bl:
                    if r == "CE":
                        for idx in range(s_idx, e_idx + 1):
                            rep[idx] = f"B-{tr}" if idx == s_idx else f"I-{tr}"
                bl = labeled(rep)

        # Rule 4: drop trailing punctuation and short, uninformative single tokens.
        for s_idx, e_idx, _ in bl:
            if tok[e_idx] in _PUNCT:
                rep[e_idx] = "O"
            if e_idx == s_idx and len(tok[s_idx]) <= 2 and tok[s_idx].lower() not in _STOPWORD_KEEP:
                rep[s_idx] = "O"
        return rep
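    # Worked example for _apply_bio_rules (illustrative): for tokens
    # ["flooding", ",", "so"] with raw labels ["B-C", "I-C", "O"], the
    # punctuation rule clears the comma, yielding ["B-C", "O", "O"].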
    @staticmethod
    def _merge_spans(tok: List[str], lab: List[str], tokenizer):
        """Convert post-processed BIO labels into Span objects, merging spans of
        the same role that are separated only by a single connector word. The
        tokenizer is passed in rather than loaded here."""
        _CONNECTORS = {"of", "to", "with", "for", "the"}
        spans_list = []
        i = 0
        while i < len(tok):
            if lab[i] == "O":
                i += 1
                continue
            role, s_idx = lab[i].split("-")[-1], i
            while i + 1 < len(tok) and lab[i + 1] != "O":
                i += 1
            spans_list.append(Span(role, s_idx, i, tokenizer.convert_tokens_to_string(tok[s_idx:i + 1])))
            i += 1
        merged = [spans_list[0]] if spans_list else []
        for sp_item in spans_list[1:]:
            prv = merged[-1]
            if (sp_item.role == prv.role and sp_item.start_tok == prv.end_tok + 2
                    and tok[prv.end_tok + 1].lower() in _CONNECTORS):
                merged[-1] = Span(prv.role, prv.start_tok, sp_item.end_tok,
                                  tokenizer.convert_tokens_to_string(tok[prv.start_tok:sp_item.end_tok + 1]),
                                  prv.is_virtual)
            else:
                merged.append(sp_item)
        return merged
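    # Worked example for _merge_spans (illustrative): tokens
    # ["rates", "of", "inflation"] labelled ["B-C", "O", "B-C"] produce two C
    # spans separated by the connector "of", which merge into a single span
    # with text "rates of inflation".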
    @staticmethod
    def _decide_causal(cls_logits, spans, mode):
        """Combine the sentence classifier and the tagged spans into a final
        causal/non-causal decision; a CE span counts as both cause and effect."""
        has_cause = any(s.role in ("C", "CE") for s in spans)
        has_effect = any(s.role in ("E", "CE") for s in spans)
        span_ok = has_cause and has_effect
        cls_c = cls_logits.argmax(-1).item() == 1  # index 1 is the "causal" class
        if mode == "cls_only":
            return cls_c
        if mode == "span_only":
            return span_ok
        if mode == "cls+span":
            return cls_c and span_ok
        raise ValueError(f"Unknown cause_decision mode: {mode}")
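
if __name__ == "__main__":
    # Minimal smoke test (a sketch: this builds the model from the default
    # config, so the heads are untrained and the output is only meaningful
    # with a real checkpoint; the example sentence is illustrative).
    from transformers import AutoTokenizer

    cfg = JointCausalConfig()
    model = JointCausalModel(cfg).eval()
    tok = AutoTokenizer.from_pretrained(cfg.encoder_name)
    print(model.predict(["Heavy rainfall caused severe flooding."], tok))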