import argparse
import json
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer


# ==============================================================
# 1. Data structure
# ==============================================================
@dataclass
class SRLSample:
    """One predicate instance: a tokenized sentence plus BIO labels for one predicate."""

    words: List[str]  # tokens of the sentence
    predicate_word_idx: int  # index into ``words`` of the predicate token
    labels: List[str]  # one BIO tag per word (same length as ``words``)
    # Surface form of the predicate; SRLDataset falls back to
    # words[predicate_word_idx] when this is None or empty.
    predicate_form: Optional[str] = None


# ==============================================================
# 2. Bracket -> BIO conversion
# ==============================================================
def _bio_from_brackets(tags):
    """Convert the bracketed SRL tags of one predicate column into BIO tags.

    Each element of ``tags`` is one token's bracket tag, e.g. ``"(ARG0*"``,
    ``"*"``, ``"*)"``, ``"(ARG1*)"`` or ``"(V*)"``. Returns a list of the same
    length containing ``B-<role>``, ``I-<role>``, ``B-V`` or ``O``.

    With nested spans, a token that opens spans is tagged ``B-`` with the first
    (outermost) role it opens; continuation tokens are tagged ``I-`` with the
    innermost role that is still open.
    """
    bio, stack = [], []
    for t in tags:
        if "(V*" in t:
            # NOTE(review): closing brackets on the predicate token are not
            # processed; correct for the usual "(V*)" tag, but a tag that also
            # closes an enclosing span would leave that span open.
            bio.append("B-V")
            continue
        # Collect every role opened on this token: each "(" up to the next "*".
        opens = []
        i = 0
        while True:
            s = t.find("(", i)
            if s == -1:
                break
            e = t.find("*", s)
            if e == -1:
                break
            opens.append(t[s + 1 : e])
            i = e + 1
        closes = t.count(")")
        if opens:
            bio.append(f"B-{opens[0]}")
            stack.extend(opens)
        elif stack:
            bio.append(f"I-{stack[-1]}")
        else:
            bio.append("O")
        # Closing brackets are applied after labeling so "(ARG0*)" is B-ARG0.
        for _ in range(closes):
            if stack:
                stack.pop()
    return bio


# Document boundary markers found in CoNLL-2012 .gold_conll files.
_DOC_MARKERS = ("#begin document", "#end document")


def _read_sentences(path):
    """Yield sentences from a CoNLL-style file as lists of whitespace-split rows.

    Sentences are separated by blank or whitespace-only lines. Document marker
    lines (``#begin document ...`` / ``#end document``) are skipped rather than
    being treated as token rows.
    """
    sent = []
    with open(path, "r", encoding="utf-8", errors="replace") as f:
        for line in f:
            stripped = line.strip()
            if not stripped:
                # Blank or whitespace-only line terminates the current sentence.
                if sent:
                    yield sent
                    sent = []
                continue
            if stripped.startswith(_DOC_MARKERS):
                continue
            sent.append(stripped.split())
    if sent:
        yield sent


# ==============================================================
# 3. CoNLL -> SRLSample objects (in-memory)
# ==============================================================
def load_conll_samples(in_path, word_col_idx=3, srl_first_col_idx=11):
    """Read a .gold_conll file and return list[SRLSample], one per predicate column.

    Args:
        in_path: path to the CoNLL file.
        word_col_idx: column index holding the word form.
        srl_first_col_idx: first column holding bracketed SRL tags; every column
            from here up to the widest row of the sentence is scanned.

    Columns with no ``"(V*"`` tag are skipped. Rows shorter than a scanned
    column contribute ``"*"`` for that column.
    """
    samples = []
    for sent in _read_sentences(in_path):
        words = [row[word_col_idx] for row in sent]
        max_cols = max(len(row) for row in sent)
        for srl_col in range(srl_first_col_idx, max_cols):
            tags = [row[srl_col] if srl_col < len(row) else "*" for row in sent]
            pred_idx = next((i for i, t in enumerate(tags) if "(V*" in t), None)
            if pred_idx is None:
                continue
            labels = _bio_from_brackets(tags)
            samples.append(SRLSample(words, pred_idx, labels, words[pred_idx]))
    print(f"[SRL_preprocessing] Loaded {len(samples)} predicate instances from {in_path}")
    return samples


# ==============================================================
# 4. Dataset + Collate
# ==============================================================
class SRLDataset(Dataset):
    """Encodes SRLSample objects as ``[CLS] sentence [SEP] predicate [SEP]`` tensors."""

    def __init__(self, samples: List[SRLSample], tokenizer: AutoTokenizer,
                 label2id: Dict[str, int], max_length: int = 256, debug_print=False):
        self.samples = samples
        self.tokenizer = tokenizer
        self.label2id = label2id
        self.id2label = {v: k for k, v in label2id.items()}
        self.max_length = max_length
        self.debug_print = debug_print  # stored but not used in this class

    def __len__(self):
        return len(self.samples)

    def _tokenize_sentence(self, words):
        """Tokenize pre-split words without special tokens (word_ids() requires a fast tokenizer)."""
        return self.tokenizer(words, is_split_into_words=True, add_special_tokens=False,
                              return_attention_mask=False, return_token_type_ids=False)

    def _tokenize_predicate(self, form):
        """Tokenize the predicate surface form without special tokens."""
        return self.tokenizer(form, add_special_tokens=False,
                              return_attention_mask=False, return_token_type_ids=False)

    def __getitem__(self, idx):
        """Encode one predicate instance.

        ``input_ids`` is ``[CLS] sentence [SEP] predicate [SEP]`` with token type
        0 for the first segment (including both specials) and 1 for the
        predicate segment. ``word_first_wp_fullidx[w]`` is the index in
        ``input_ids`` of the first wordpiece of word ``w``.

        If the sequence exceeds ``max_length``, only the sentence pieces are
        truncated so the predicate segment is kept; word positions past the cut
        are clamped to the last kept sentence piece. Only when the predicate
        segment alone cannot fit is the whole sequence right-truncated.
        """
        instance = self.samples[idx]
        words = instance.words
        n_words = len(words)
        pred_idx = instance.predicate_word_idx
        pred_form = instance.predicate_form or words[pred_idx]

        enc_sent = self._tokenize_sentence(words)
        sent_wp_ids = list(enc_sent["input_ids"])
        pred_wp_ids = list(self._tokenize_predicate(pred_form)["input_ids"])

        # Positions come from the same encoding that builds input_ids, shifted by
        # 1 for the leading [CLS], so they stay aligned whatever special-token
        # template the tokenizer would apply on its own.
        first_pos_by_wid = {}
        for pos, wid in enumerate(enc_sent.word_ids()):
            if wid is not None and wid not in first_pos_by_wid:
                first_pos_by_wid[wid] = pos + 1

        word_first_wp_fullidx = []
        prev_pos = 0  # [CLS]
        for w in range(n_words):
            # A word that produced no wordpieces reuses the previous word's
            # position (or [CLS]) instead of raising KeyError.
            prev_pos = first_pos_by_wid.get(w, prev_pos)
            word_first_wp_fullidx.append(prev_pos)

        n_pred_seg = len(pred_wp_ids) + 1  # predicate pieces + closing [SEP]
        sent_budget = self.max_length - 2 - n_pred_seg  # room for sentence pieces
        clamp_pos = None
        if len(sent_wp_ids) > sent_budget >= 0:
            # Truncate only the sentence so the predicate segment survives.
            sent_wp_ids = sent_wp_ids[:sent_budget]
            clamp_pos = sent_budget  # index of last kept sentence piece ([CLS] if none)

        cls_id = self.tokenizer.cls_token_id
        sep_id = self.tokenizer.sep_token_id
        input_ids = [cls_id] + sent_wp_ids + [sep_id] + pred_wp_ids + [sep_id]
        ttids = [0] * (len(sent_wp_ids) + 2) + [1] * n_pred_seg

        if len(input_ids) > self.max_length:
            # Only reachable when the predicate segment alone cannot fit:
            # fall back to plain right-truncation.
            input_ids = input_ids[: self.max_length]
            ttids = ttids[: self.max_length]
            clamp_pos = self.max_length - 1
        attention_mask = [1] * len(input_ids)

        if clamp_pos is not None:
            word_first_wp_fullidx = [min(p, clamp_pos) for p in word_first_wp_fullidx]

        label_ids = [self.label2id[lab] for lab in instance.labels]
        indicator = [0] * n_words
        indicator[pred_idx] = 1

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "token_type_ids": torch.tensor(ttids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "word_first_wp_fullidx": torch.tensor(word_first_wp_fullidx, dtype=torch.long),
            "labels": torch.tensor(label_ids, dtype=torch.long),
            "indicator": torch.tensor(indicator, dtype=torch.long),
            "sent_len": torch.tensor(len(words), dtype=torch.long),
            "pred_word_idx": torch.tensor(pred_idx, dtype=torch.long),
        }


def srl_collate(batch: List[Dict], pad_token_id: int, pad_label_id: int = -100):
    """Pad a list of SRLDataset items into one batch dict.

    Wordpiece-level tensors are right-padded to the longest ``input_ids``
    (``input_ids`` with ``pad_token_id``, type ids and mask with 0). Word-level
    tensors are padded to the longest sentence: ``labels`` with
    ``pad_label_id``, ``word_first_wp_fullidx`` with -1, ``indicator`` with 0.
    ``sentence_mask`` is True for real words.
    """
    B = len(batch)
    max_L = max(item["input_ids"].size(0) for item in batch)
    input_ids = torch.full((B, max_L), pad_token_id, dtype=torch.long)
    token_type_ids = torch.zeros((B, max_L), dtype=torch.long)
    attention_mask = torch.zeros((B, max_L), dtype=torch.long)

    max_n = max(int(item["sent_len"]) for item in batch)
    word_first_wp_fullidx = torch.full((B, max_n), -1, dtype=torch.long)
    labels = torch.full((B, max_n), pad_label_id, dtype=torch.long)
    indicator = torch.zeros((B, max_n), dtype=torch.long)
    sent_lens = torch.zeros((B,), dtype=torch.long)
    pred_word_idx = torch.zeros((B,), dtype=torch.long)
    sentence_mask = torch.zeros((B, max_n), dtype=torch.bool)

    for i, item in enumerate(batch):
        L = item["input_ids"].size(0)
        input_ids[i, :L] = item["input_ids"]
        token_type_ids[i, :L] = item["token_type_ids"]
        attention_mask[i, :L] = item["attention_mask"]

        n = int(item["sent_len"])
        word_first_wp_fullidx[i, :n] = item["word_first_wp_fullidx"]
        labels[i, :n] = item["labels"]
        indicator[i, :n] = item["indicator"]
        sent_lens[i] = n
        pred_word_idx[i] = item["pred_word_idx"]
        sentence_mask[i, :n] = True

    return {
        "input_ids": input_ids,
        "token_type_ids": token_type_ids,
        "attention_mask": attention_mask,
        "word_first_wp_fullidx": word_first_wp_fullidx,
        "sentence_mask": sentence_mask,
        "labels": labels,
        "indicator": indicator,
        "sent_lens": sent_lens,
        "pred_word_idx": pred_word_idx,
    }


# ==============================================================
# 5. Helper for trainer
# ==============================================================
def data_processing_for_loader_conll(
    train_conll: str,
    dev_conll: Optional[str],
    tokenizer,
    word_col_idx: int = 3,
    srl_first_col_idx: int = 11,
    max_length: int = 256,
) -> Tuple[SRLDataset, Optional[SRLDataset], Dict[str, int], Dict[int, str]]:
    """Load train/dev .gold_conll files and wrap them as SRLDataset objects.

    Returns:
        ``(train_dataset, dev_dataset, label2id, id2label)``. ``dev_dataset`` is
        None when ``dev_conll`` is falsy or yields no predicate instances.

    The label vocabulary is the union of train and dev labels, with ids assigned
    in order of first appearance (train samples first, then dev).
    """
    train_samples = load_conll_samples(train_conll, word_col_idx, srl_first_col_idx)
    dev_samples = (
        load_conll_samples(dev_conll, word_col_idx, srl_first_col_idx) if dev_conll else []
    )

    # dict.fromkeys de-duplicates while preserving first-seen order.
    all_labels = dict.fromkeys(
        lab for s in train_samples + dev_samples for lab in s.labels
    )
    label2id = {lab: i for i, lab in enumerate(all_labels)}
    id2label = {v: k for k, v in label2id.items()}

    train_ds = SRLDataset(train_samples, tokenizer, label2id, max_length=max_length)
    dev_ds = (
        SRLDataset(dev_samples, tokenizer, label2id, max_length=max_length)
        if dev_samples
        else None
    )
    return train_ds, dev_ds, label2id, id2label