Spaces:
Sleeping
Sleeping
stefhooy
Setting up the training model first draft and also updating the requirements.txt for the second part of this project.
28e0f73 | from __future__ import annotations | |
| import torch | |
| from torch.utils.data import Dataset | |
| from transformers import PreTrainedTokenizerBase | |
| from typing import List | |
| from data.schema import Debate | |
# Label vocabulary for argument-type classification: label name -> class index.
# These indices are what each dataset example stores under its "label" key.
LABEL2ID = {"claim": 0, "counter_claim": 1, "premise": 2, "unknown": 3}

# Reverse lookup (class index -> label name), derived from LABEL2ID so the two
# mappings cannot drift apart.
ID2LABEL = {index: name for name, index in LABEL2ID.items()}

# Number of distinct label classes.
NUM_LABELS = len(LABEL2ID)
class ArgumentDataset(Dataset):
    """Converts a list of Debate objects into tokenized (parent, argument) pairs.

    Each example feeds the parent comment as text_a and the reply as text_b so
    the model sees relational context, which is critical for distinguishing
    counter_claims from plain claims. Root arguments (no parent) are encoded
    as single sequences.

    Args:
        debates: Debates to flatten; every entry of ``debate.arguments``
            becomes exactly one example.
        tokenizer: Tokenizer invoked lazily, once per ``__getitem__`` call.
        max_length: Length passed to the tokenizer for truncation and
            ``"max_length"`` padding.

    Raises:
        KeyError: If an argument's ``arg_type`` is not a key of ``LABEL2ID``.
    """

    def __init__(
        self,
        debates: List[Debate],
        tokenizer: PreTrainedTokenizerBase,
        max_length: int = 256,
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Only raw text and integer labels are materialized here; tokenization
        # is deferred to __getitem__.
        self.examples = self._extract(debates)

    def _extract(self, debates: List[Debate]) -> List[dict]:
        """Flatten debates into ``{"text_a", "text_b", "label"}`` records.

        ``text_a`` is the parent argument's text, or ``""`` when the argument
        has no parent or its ``parent_id`` does not resolve within the same
        debate. ``text_b`` is the argument's own text and ``label`` is its
        class index from ``LABEL2ID``.

        Raises:
            KeyError: If ``arg.arg_type`` is missing from ``LABEL2ID``.
        """
        examples = []
        for debate in debates:
            # Parents are resolved per debate, so ids only need to be unique
            # within a single debate.
            arg_map = {a.id: a for a in debate.arguments}
            for arg in debate.arguments:
                # Compare against None rather than relying on truthiness so a
                # legitimate falsy id (e.g. 0) is still looked up. Unresolvable
                # ids fall back to None via dict.get.
                if arg.parent_id is not None:
                    parent = arg_map.get(arg.parent_id)
                else:
                    parent = None
                examples.append({
                    "text_a": parent.text if parent else "",
                    "text_b": arg.text,
                    "label": LABEL2ID[arg.arg_type],
                })
        return examples

    def __len__(self) -> int:
        """Return the number of extracted examples."""
        return len(self.examples)

    def __getitem__(self, idx: int) -> dict:
        """Tokenize example ``idx`` and return its tensors.

        Returns:
            A dict with ``"input_ids"`` and ``"attention_mask"`` (the
            tokenizer's batch-of-one output with the leading dimension
            squeezed away) and ``"label"`` (a scalar ``torch.long`` tensor).
        """
        ex = self.examples[idx]
        # Encode as a (parent, reply) pair when parent text is available;
        # otherwise encode the reply alone. An empty parent text is treated
        # the same as having no parent.
        if ex["text_a"]:
            texts = (ex["text_a"], ex["text_b"])
        else:
            texts = (ex["text_b"],)
        # Single call site so both cases share identical encoding options.
        enc = self.tokenizer(
            *texts,
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        return {
            # return_tensors="pt" yields a leading batch dimension of size 1;
            # drop it so a DataLoader can stack examples itself.
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "label": torch.tensor(ex["label"], dtype=torch.long),
        }