import re
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple

import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer
@dataclass
class SRLSample:
    """One predicate instance: a sentence, its predicate position, and BIO labels."""
    words: List[str]
    predicate_word_idx: int
    labels: List[str]                       # one BIO tag per word, aligned with `words`
    predicate_form: Optional[str] = None    # surface form; defaults to words[predicate_word_idx]
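# Illustrative sample (the sentence and roles below are made up, not taken
# from any corpus): for "She bought a car" with predicate "bought",
#
#     SRLSample(words=["She", "bought", "a", "car"],
#               predicate_word_idx=1,
#               labels=["B-ARG0", "B-V", "B-ARG1", "I-ARG1"],
#               predicate_form="bought")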
def _bio_from_brackets(tags):
    """Convert one CoNLL bracket column (e.g. "(ARG0*", "*", "*)") into BIO tags."""
    bio, stack = [], []
    for t in tags:
        # Collect every role opened on this token; the regex also handles
        # stacked openers such as "(ARGM-TMP(ARG1*".
        opens = re.findall(r"\(([^()*]+)", t)
        closes = t.count(")")
        if opens:
            # A token that opens one or more spans begins the outermost new one.
            bio.append(f"B-{opens[0]}")
            stack.extend(opens)
        elif stack:
            # No new span: we are inside the most recently opened, unclosed span.
            bio.append(f"I-{stack[-1]}")
        else:
            bio.append("O")
        # Pop one span per ")" so multi-token spans, including a multi-token
        # predicate "(V* ... *)", terminate correctly.
        for _ in range(closes):
            if stack:
                stack.pop()
    return bio
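# A minimal sanity check (illustrative, not part of the pipeline): each token's
# bracket tag flattens to exactly one BIO tag.
#
#     >>> _bio_from_brackets(["(ARG0*", "*)", "(V*)", "(ARG1*", "*)"])
#     ['B-ARG0', 'I-ARG0', 'B-V', 'B-ARG1', 'I-ARG1']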
def _read_sentences(path):
    """Yield one sentence at a time as a list of per-token column lists."""
    sent = []
    with open(path, "r", encoding="utf-8", errors="replace") as f:
        for line in f:
            line = line.strip()
            if line.startswith("#"):
                # Skip "#begin document" / "#end document" marker lines.
                continue
            if not line:
                if sent:
                    yield sent
                    sent = []
                continue
            sent.append(line.split())
    if sent:
        yield sent
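# Expected input shape (a sketch, assuming the CoNLL-2012 .gold_conll layout;
# the rows below are invented): whitespace-separated columns per token, blank
# lines between sentences, e.g.
#
#     bc/cctv/00/cctv_0000  0  0  She     PRP  ...  (ARG0*)  -
#     bc/cctv/00/cctv_0000  0  1  bought  VBD  ...  (V*)     -
#
# Column 3 holds the word form and columns from 11 onward hold one bracket
# column per predicate (see load_conll_samples below).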
def load_conll_samples(in_path, word_col_idx=3, srl_first_col_idx=11):
    """
    Reads a .gold_conll file and returns list[SRLSample],
    one sample per predicate column.
    """
    samples = []
    for sent in _read_sentences(in_path):
        words = [row[word_col_idx] for row in sent]
        max_cols = max(len(row) for row in sent)
        for srl_col in range(srl_first_col_idx, max_cols):
            # Rows can have ragged lengths; treat missing cells as "*".
            tags = [row[srl_col] if srl_col < len(row) else "*" for row in sent]
            try:
                pred_idx = next(i for i, t in enumerate(tags) if "(V*" in t)
            except StopIteration:
                # Columns without a "(V*" predicate marker (e.g., the trailing
                # coreference column of CoNLL-2012 files) are skipped.
                continue
            labels = _bio_from_brackets(tags)
            predicate_form = words[pred_idx]
            samples.append(SRLSample(words, pred_idx, labels, predicate_form))
    print(f"[SRL_preprocessing] Loaded {len(samples)} predicate instances from {in_path}")
    return samples
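# Usage sketch (the path is a placeholder): a sentence with two predicate
# columns yields two SRLSample instances, one per predicate.
#
#     samples = load_conll_samples("data/train.gold_conll")
#     print(samples[0].words, samples[0].labels)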
class SRLDataset(Dataset):
    def __init__(self, samples: List[SRLSample], tokenizer: AutoTokenizer,
                 label2id: Dict[str, int], max_length: int = 256, debug_print=False):
        self.samples = samples
        self.tokenizer = tokenizer
        self.label2id = label2id
        self.id2label = {v: k for k, v in label2id.items()}
        self.max_length = max_length
        self.debug_print = debug_print

    def __len__(self):
        return len(self.samples)

    def _tokenize_sentence(self, words):
        # Tokenize pre-split words without special tokens; [CLS]/[SEP] are
        # added manually in __getitem__.
        return self.tokenizer(words, is_split_into_words=True,
                              add_special_tokens=False, return_attention_mask=False,
                              return_token_type_ids=False)

    def _tokenize_predicate(self, form):
        return self.tokenizer(form, add_special_tokens=False,
                              return_attention_mask=False,
                              return_token_type_ids=False)

    def __getitem__(self, idx):
        instance = self.samples[idx]
        words = instance.words
        n_words = len(words)
        pred_form = instance.predicate_form or words[instance.predicate_word_idx]

        enc_sent = self._tokenize_sentence(words)
        enc_pred = self._tokenize_predicate(pred_form)
        sent_wp_ids = enc_sent["input_ids"]
        pred_wp_ids = enc_pred["input_ids"]

        # Build "[CLS] sentence [SEP] predicate [SEP]" (assumes a BERT-style
        # tokenizer that defines cls_token_id / sep_token_id).
        input_ids = [self.tokenizer.cls_token_id] + sent_wp_ids + [self.tokenizer.sep_token_id] \
                    + pred_wp_ids + [self.tokenizer.sep_token_id]
        ttids = [0] * (1 + len(sent_wp_ids) + 1) + [1] * (len(pred_wp_ids) + 1)

        # Index of each word's first word piece within the full sequence.
        # Positions come from the no-special-tokens encoding, shifted by 1 for
        # the leading [CLS], so they stay aligned with the manually built
        # input_ids above. Requires a fast tokenizer (for word_ids()) and
        # assumes every word yields at least one word piece.
        word_ids = enc_sent.word_ids()
        first_pos_by_wid = {}
        for pos, wid in enumerate(word_ids):
            if wid is not None and wid not in first_pos_by_wid:
                first_pos_by_wid[wid] = pos + 1  # +1 for [CLS]
        word_first_wp_fullidx = [first_pos_by_wid[w] for w in range(n_words)]

        label_ids = [self.label2id[l] for l in instance.labels]
        indicator = [0] * n_words
        indicator[instance.predicate_word_idx] = 1
        attention_mask = [1] * len(input_ids)

        if len(input_ids) > self.max_length:
            # Truncate the word-piece sequence; word-level tensors keep their
            # full length, so pointers of truncated words are clamped to the
            # last kept position.
            max_pos = self.max_length - 1
            input_ids = input_ids[:self.max_length]
            ttids = ttids[:self.max_length]
            attention_mask = attention_mask[:self.max_length]
            word_first_wp_fullidx = [min(p, max_pos) for p in word_first_wp_fullidx]

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "token_type_ids": torch.tensor(ttids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "word_first_wp_fullidx": torch.tensor(word_first_wp_fullidx, dtype=torch.long),
            "labels": torch.tensor(label_ids, dtype=torch.long),
            "indicator": torch.tensor(indicator, dtype=torch.long),
            "sent_len": torch.tensor(len(words), dtype=torch.long),
            "pred_word_idx": torch.tensor(instance.predicate_word_idx, dtype=torch.long),
        }
def srl_collate(batch: List[Dict], pad_token_id: int, pad_label_id: int = -100):
    """Pad a batch to the max word-piece length (max_L) and max word count (max_n)."""
    B = len(batch)
    max_L = max(item["input_ids"].size(0) for item in batch)
    input_ids = torch.full((B, max_L), pad_token_id, dtype=torch.long)
    token_type_ids = torch.zeros((B, max_L), dtype=torch.long)
    attention_mask = torch.zeros((B, max_L), dtype=torch.long)
    max_n = max(int(item["sent_len"]) for item in batch)
    # Pad word->word-piece pointers with 0 (a valid position) so the tensor is
    # safe to use with gather(); padded words are masked out via sentence_mask.
    word_first_wp_fullidx = torch.zeros((B, max_n), dtype=torch.long)
    labels = torch.full((B, max_n), pad_label_id, dtype=torch.long)
    indicator = torch.zeros((B, max_n), dtype=torch.long)
    sent_lens = torch.zeros((B,), dtype=torch.long)
    pred_word_idx = torch.zeros((B,), dtype=torch.long)
    sentence_mask = torch.zeros((B, max_n), dtype=torch.bool)

    for i, item in enumerate(batch):
        L = item["input_ids"].size(0)
        input_ids[i, :L] = item["input_ids"]
        token_type_ids[i, :L] = item["token_type_ids"]
        attention_mask[i, :L] = item["attention_mask"]
        n = int(item["sent_len"])
        word_first_wp_fullidx[i, :n] = item["word_first_wp_fullidx"]
        labels[i, :n] = item["labels"]
        indicator[i, :n] = item["indicator"]
        sent_lens[i] = n
        pred_word_idx[i] = item["pred_word_idx"]
        sentence_mask[i, :n] = True

    return {
        "input_ids": input_ids,
        "token_type_ids": token_type_ids,
        "attention_mask": attention_mask,
        "word_first_wp_fullidx": word_first_wp_fullidx,
        "sentence_mask": sentence_mask,
        "labels": labels,
        "indicator": indicator,
        "sent_lens": sent_lens,
        "pred_word_idx": pred_word_idx,
    }
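# How a model might consume the collated batch (a minimal sketch; `encoder`
# and its output shape are assumptions, not part of this module):
#
#     hidden = encoder(batch["input_ids"],
#                      attention_mask=batch["attention_mask"],
#                      token_type_ids=batch["token_type_ids"])     # [B, max_L, H]
#     idx = batch["word_first_wp_fullidx"].unsqueeze(-1).expand(-1, -1, hidden.size(-1))
#     word_reprs = hidden.gather(1, idx)                            # [B, max_n, H]
#     word_reprs = word_reprs * batch["sentence_mask"].unsqueeze(-1)  # zero padded words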
def data_processing_for_loader_conll(
        train_conll: str,
        dev_conll: Optional[str],
        tokenizer,
        word_col_idx: int = 3,
        srl_first_col_idx: int = 11,
        max_length: int = 256
) -> Tuple[SRLDataset, Optional[SRLDataset], Dict[str, int], Dict[int, str]]:
    """
    Reads train (and optionally dev) .gold_conll files and returns:
        train_dataset, dev_dataset, label2id, id2label

    * the label set is computed from the UNION of train and dev labels
    * dev_conll can be None, in which case dev_dataset is None
    """
    train_samples = load_conll_samples(train_conll, word_col_idx, srl_first_col_idx)
    dev_samples = load_conll_samples(dev_conll, word_col_idx, srl_first_col_idx) if dev_conll else []

    # Build the label vocabulary in first-seen order over all samples.
    all_samples = train_samples + dev_samples
    label2id = {}
    for s in all_samples:
        for lab in s.labels:
            if lab not in label2id:
                label2id[lab] = len(label2id)
    id2label = {v: k for k, v in label2id.items()}

    train_ds = SRLDataset(train_samples, tokenizer, label2id, max_length=max_length)
    dev_ds = SRLDataset(dev_samples, tokenizer, label2id, max_length=max_length) if dev_samples else None

    return train_ds, dev_ds, label2id, id2label
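# End-to-end wiring sketch (model name, paths, and batch size are
# placeholders, not prescribed by this module):
#
#     from functools import partial
#     from torch.utils.data import DataLoader
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
#     train_ds, dev_ds, label2id, id2label = data_processing_for_loader_conll(
#         "data/train.gold_conll", "data/dev.gold_conll", tokenizer)
#     train_loader = DataLoader(
#         train_ds, batch_size=16, shuffle=True,
#         collate_fn=partial(srl_collate, pad_token_id=tokenizer.pad_token_id))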