stefhooy
Setting up the training model first draft and also updating the requirements.txt for the second part of this project.
28e0f73
Raw
History Blame Contribute Delete
2.38 kB
from __future__ import annotations
import torch
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizerBase
from typing import List
from data.schema import Debate
# Label vocabulary: each name's position in this tuple is its class id.
LABEL2ID = {
    name: idx
    for idx, name in enumerate(("claim", "counter_claim", "premise", "unknown"))
}
# Inverse mapping (class id -> label name), e.g. for decoding predictions.
ID2LABEL = {idx: name for name, idx in LABEL2ID.items()}
NUM_LABELS = len(ID2LABEL)
class ArgumentDataset(Dataset):
    """Converts a list of Debate objects into tokenized (parent, argument) pairs.

    Each example feeds the parent comment as text_a and the reply as text_b so
    the model sees relational context — critical for distinguishing counter_claims
    from plain claims. Root arguments (no parent) are encoded as single sequences.

    Args:
        debates: Debates whose ``arguments`` are flattened into one example each.
        tokenizer: Tokenizer used to encode every example on access.
        max_length: Every encoded example is truncated/padded to exactly this
            many tokens.
    """

    def __init__(
        self,
        debates: List[Debate],
        tokenizer: PreTrainedTokenizerBase,
        max_length: int = 256,
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Examples are kept as raw text; tokenization happens lazily in __getitem__.
        self.examples = self._extract(debates)

    def _extract(self, debates: List[Debate]) -> list[dict]:
        """Flatten debates into ``{"text_a", "text_b", "label"}`` dicts.

        ``text_a`` is the parent argument's text, or ``""`` when the argument has
        no ``parent_id`` or its ``parent_id`` matches no argument in the same
        debate. ``text_b`` is the argument's own text and ``label`` its id in
        ``LABEL2ID``.

        Raises:
            KeyError: if an argument's ``arg_type`` is not a key of ``LABEL2ID``.
        """
        examples = []
        for debate in debates:
            # Parents are resolved only among arguments of the same debate.
            arg_map = {a.id: a for a in debate.arguments}
            for arg in debate.arguments:
                # Explicit None check so a falsy-but-valid id (e.g. 0) still resolves.
                parent = (
                    arg_map.get(arg.parent_id) if arg.parent_id is not None else None
                )
                try:
                    label = LABEL2ID[arg.arg_type]
                except KeyError:
                    raise KeyError(
                        f"argument {arg.id!r} has unknown arg_type {arg.arg_type!r}; "
                        f"expected one of {list(LABEL2ID)}"
                    ) from None
                examples.append({
                    "text_a": parent.text if parent is not None else "",
                    "text_b": arg.text,
                    "label": label,
                })
        return examples

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx: int) -> dict:
        """Tokenize example ``idx`` and return model-ready tensors.

        Returns:
            A dict with ``input_ids`` and ``attention_mask`` (each with the
            leading batch dim of size 1 squeezed away), ``token_type_ids`` when
            the tokenizer returns them, and a scalar ``label`` of dtype
            ``torch.long``.
        """
        ex = self.examples[idx]
        # Pair-encode (parent, reply) when there is parent text; otherwise encode
        # the reply alone.
        texts = (ex["text_a"], ex["text_b"]) if ex["text_a"] else (ex["text_b"],)
        enc = self.tokenizer(
            *texts,
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        item = {
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
        }
        # Segment ids mark which tokens belong to the parent vs. the reply;
        # discarding them loses the pair structure for models that use them.
        # Forwarded only when the tokenizer emits them, so the key set stays
        # consistent across every item of this dataset.
        if "token_type_ids" in enc:
            item["token_type_ids"] = enc["token_type_ids"].squeeze(0)
        item["label"] = torch.tensor(ex["label"], dtype=torch.long)
        return item