# srl_bert_model / SRL_preprocessing.py
import torch
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple
from transformers import AutoTokenizer
from torch.utils.data import Dataset
# ==============================================================
# 1. Data structure
# ==============================================================
@dataclass
class SRLSample:
    words: List[str]
    predicate_word_idx: int
    labels: List[str]
    predicate_form: Optional[str] = None
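# Illustrative instance (made-up sentence, not taken from the corpus):
#   SRLSample(words=["He", "ate", "it"], predicate_word_idx=1,
#             labels=["B-A0", "B-V", "B-A1"], predicate_form="ate")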
# ==============================================================
# 2. Bracket → BIO conversion
# ==============================================================
def _bio_from_brackets(tags):
    """Convert one predicate column of CoNLL bracket tags
    (e.g. "(A0*", "*)", "(V*)") into word-level BIO labels."""
    bio, stack = [], []
    for t in tags:
        # Predicate token; assumed to open and close on a single token, "(V*)".
        if "(V*" in t:
            bio.append("B-V")
            continue
        # Roles opened on this token: every "(ROLE" before the "*" placeholder.
        star = t.find("*")
        head = t[:star] if star != -1 else t
        opens = [r for r in head.split("(") if r]
        closes = t.count(")")
        if opens:
            bio.append(f"B-{opens[0]}")
            stack.extend(opens)
        elif stack:
            bio.append(f"I-{stack[-1]}")
        else:
            bio.append("O")
        for _ in range(closes):
            if stack:
                stack.pop()
    return bio
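# Quick sanity check (hypothetical tags, not taken from a real corpus file):
#   _bio_from_brackets(["(A0*", "*)", "(V*)", "(A1*)", "*"])
#   -> ["B-A0", "I-A0", "B-V", "B-A1", "O"]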
def _read_sentences(path):
    """Yield one sentence at a time as a list of column lists,
    splitting on blank lines."""
    sent = []
    with open(path, "r", encoding="utf-8", errors="replace") as f:
        for line in f:
            line = line.rstrip("\n")
            if not line:
                if sent:
                    yield sent
                    sent = []
                continue
            cols = line.split()
            if cols:
                sent.append(cols)
    if sent:
        yield sent
# ==============================================================
# 3. CoNLL → SRLSample objects (in-memory)
# ==============================================================
def load_conll_samples(in_path, word_col_idx=3, srl_first_col_idx=11):
    """
    Reads a .gold_conll file and returns list[SRLSample],
    one per predicate column.
    """
    samples = []
    for sent in _read_sentences(in_path):
        words = [row[word_col_idx] for row in sent]
        max_cols = max(len(row) for row in sent)
        for srl_col in range(srl_first_col_idx, max_cols):
            tags = [row[srl_col] if srl_col < len(row) else "*" for row in sent]
            try:
                pred_idx = next(i for i, t in enumerate(tags) if "(V*" in t)
            except StopIteration:
                continue  # column has no predicate; skip it
            labels = _bio_from_brackets(tags)
            predicate_form = words[pred_idx]
            samples.append(SRLSample(words, pred_idx, labels, predicate_form))
    print(f"[SRL_preprocessing] Loaded {len(samples)} predicate instances from {in_path}")
    return samples
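# Note on the defaults (assumed OntoNotes-style *.gold_conll layout): column 3
# holds the word form and the predicate-argument columns start at index 11.
# Adjust word_col_idx / srl_first_col_idx if your files use a different layout.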
# ==============================================================
# 4. Dataset + collate
# ==============================================================
class SRLDataset(Dataset):
    def __init__(self, samples: List[SRLSample], tokenizer: AutoTokenizer,
                 label2id: Dict[str, int], max_length: int = 256, debug_print=False):
        self.samples = samples
        self.tokenizer = tokenizer
        self.label2id = label2id
        self.id2label = {v: k for k, v in label2id.items()}
        self.max_length = max_length
        self.debug_print = debug_print

    def __len__(self):
        return len(self.samples)

    def _tokenize_sentence(self, words):
        return self.tokenizer(words, is_split_into_words=True,
                              add_special_tokens=False, return_attention_mask=False,
                              return_token_type_ids=False)

    def _tokenize_predicate(self, form):
        return self.tokenizer(form, add_special_tokens=False,
                              return_attention_mask=False,
                              return_token_type_ids=False)
    def __getitem__(self, idx):
        instance = self.samples[idx]
        words = instance.words
        n_words = len(words)
        pred_form = instance.predicate_form or words[instance.predicate_word_idx]
        enc_sent = self._tokenize_sentence(words)
        enc_pred = self._tokenize_predicate(pred_form)
        sent_wp_ids = enc_sent["input_ids"]
        pred_wp_ids = enc_pred["input_ids"]
        # Pair layout: [CLS] sentence [SEP] predicate [SEP]
        input_ids = ([self.tokenizer.cls_token_id] + sent_wp_ids
                     + [self.tokenizer.sep_token_id] + pred_wp_ids
                     + [self.tokenizer.sep_token_id])
        ttids = [0] * (1 + len(sent_wp_ids) + 1) + [1] * (len(pred_wp_ids) + 1)
        # Re-tokenize WITH special tokens so word_ids() positions line up with
        # the [CLS]-prefixed sequence built above.
        tmp = self.tokenizer(words, is_split_into_words=True)
        word_ids = tmp.word_ids()
        first_pos_by_wid = {}
        for pos, wid in enumerate(word_ids):
            if wid is not None and wid not in first_pos_by_wid:
                first_pos_by_wid[wid] = pos
        # First wordpiece position of every word in the full input sequence.
        word_first_wp_fullidx = [first_pos_by_wid[w] for w in range(n_words)]
        label_ids = [self.label2id[l] for l in instance.labels]
        indicator = [0] * n_words
        indicator[instance.predicate_word_idx] = 1
        attention_mask = [1] * len(input_ids)
        if len(input_ids) > self.max_length:
            # Truncate the wordpiece sequence and clamp word positions into range.
            max_pos = self.max_length - 1
            input_ids = input_ids[:self.max_length]
            ttids = ttids[:self.max_length]
            attention_mask = attention_mask[:self.max_length]
            word_first_wp_fullidx = [min(p, max_pos) for p in word_first_wp_fullidx]
        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "token_type_ids": torch.tensor(ttids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "word_first_wp_fullidx": torch.tensor(word_first_wp_fullidx, dtype=torch.long),
            "labels": torch.tensor(label_ids, dtype=torch.long),
            "indicator": torch.tensor(indicator, dtype=torch.long),
            "sent_len": torch.tensor(len(words), dtype=torch.long),
            "pred_word_idx": torch.tensor(instance.predicate_word_idx, dtype=torch.long)
        }
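# Per-item tensor shapes, for illustration (n = number of words, L = number of
# wordpieces including [CLS]/[SEP]s, capped at max_length):
#   input_ids / token_type_ids / attention_mask : (L,)
#   word_first_wp_fullidx / labels / indicator  : (n,)
#   sent_len / pred_word_idx                    : scalars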
def srl_collate(batch: List[Dict], pad_token_id: int, pad_label_id: int = -100):
    B = len(batch)
    max_L = max(item["input_ids"].size(0) for item in batch)
    input_ids = torch.full((B, max_L), pad_token_id, dtype=torch.long)
    token_type_ids = torch.zeros((B, max_L), dtype=torch.long)
    attention_mask = torch.zeros((B, max_L), dtype=torch.long)
    max_n = max(int(item["sent_len"]) for item in batch)
    word_first_wp_fullidx = torch.full((B, max_n), -1, dtype=torch.long)
    labels = torch.full((B, max_n), pad_label_id, dtype=torch.long)
    indicator = torch.zeros((B, max_n), dtype=torch.long)
    sent_lens = torch.zeros((B,), dtype=torch.long)
    pred_word_idx = torch.zeros((B,), dtype=torch.long)
    sentence_mask = torch.zeros((B, max_n), dtype=torch.bool)
    for i, item in enumerate(batch):
        L = item["input_ids"].size(0)
        input_ids[i, :L] = item["input_ids"]
        token_type_ids[i, :L] = item["token_type_ids"]
        attention_mask[i, :L] = item["attention_mask"]
        n = int(item["sent_len"])
        word_first_wp_fullidx[i, :n] = item["word_first_wp_fullidx"]
        labels[i, :n] = item["labels"]
        indicator[i, :n] = item["indicator"]
        sent_lens[i] = n
        pred_word_idx[i] = item["pred_word_idx"]
        sentence_mask[i, :n] = True
    return {
        "input_ids": input_ids,
        "token_type_ids": token_type_ids,
        "attention_mask": attention_mask,
        "word_first_wp_fullidx": word_first_wp_fullidx,
        "sentence_mask": sentence_mask,
        "labels": labels,
        "indicator": indicator,
        "sent_lens": sent_lens,
        "pred_word_idx": pred_word_idx,
    }
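# Batch shape contract (B = batch size, max_L = longest wordpiece sequence,
# max_n = longest word-level sentence in the batch):
#   input_ids / token_type_ids / attention_mask : (B, max_L)
#   word_first_wp_fullidx / labels / indicator / sentence_mask : (B, max_n)
#   sent_lens / pred_word_idx : (B,)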
# ==============================================================
# 5. Helper for trainer
# ==============================================================
def data_processing_for_loader_conll(
    train_conll: str,
    dev_conll: Optional[str],
    # test_conll: Optional[str],
    tokenizer,
    word_col_idx: int = 3,
    srl_first_col_idx: int = 11,
    max_length: int = 256
) -> Tuple[SRLDataset, Optional[SRLDataset], Dict[str, int], Dict[int, str]]:
    """
    Reads train/dev .gold_conll files and returns:
      train_dataset, dev_dataset, label2id, id2label
    * the label set is computed from the UNION of train and dev labels
    * dev_conll can be None (in which case dev_dataset is None)
    """
    # Load samples
    train_samples = load_conll_samples(train_conll, word_col_idx, srl_first_col_idx)
    dev_samples = load_conll_samples(dev_conll, word_col_idx, srl_first_col_idx) if dev_conll else []
    # test_samples = load_conll_samples(test_conll, word_col_idx, srl_first_col_idx) if test_conll else []
    # Build label maps from all loaded splits
    all_samples = train_samples + dev_samples
    label2id = {}
    for s in all_samples:
        for lab in s.labels:
            if lab not in label2id:
                label2id[lab] = len(label2id)
    id2label = {v: k for k, v in label2id.items()}
    # Datasets
    train_ds = SRLDataset(train_samples, tokenizer, label2id, max_length=max_length)
    dev_ds = SRLDataset(dev_samples, tokenizer, label2id, max_length=max_length) if dev_samples else None
    # test_ds = SRLDataset(test_samples, tokenizer, label2id, max_length=max_length) if test_samples else None
    return train_ds, dev_ds, label2id, id2label
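# ==============================================================
# 6. Example usage (illustrative sketch: the checkpoint name and
#    file paths below are placeholders, not part of this module)
# ==============================================================
if __name__ == "__main__":
    from functools import partial
    from torch.utils.data import DataLoader

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    train_ds, dev_ds, label2id, id2label = data_processing_for_loader_conll(
        "train.gold_conll", "dev.gold_conll", tokenizer)
    # srl_collate needs pad_token_id, so bind it before handing it to the loader.
    loader = DataLoader(
        train_ds, batch_size=16, shuffle=True,
        collate_fn=partial(srl_collate, pad_token_id=tokenizer.pad_token_id))
    batch = next(iter(loader))
    print({k: tuple(v.shape) for k, v in batch.items()})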