"""
data/dataset.py
===============
Phase 1: PyTorch DataLoader with Sliding-Window Chunking
Provides ``StressDataset`` — a ``torch.utils.data.Dataset`` that splits
long texts into overlapping chunks of ``chunk_size`` tokens with a
configurable ``stride``, preventing truncation loss for long Reddit posts.
Each chunk is treated as an independent sample during training / inference,
and results can be aggregated per-document at evaluation time via
``doc_index``.
Usage
-----
>>> from data.dataset import StressDataset, create_dataloaders
>>> dataset = StressDataset(texts, labels, domains)
>>> train_dl, val_dl, test_dl = create_dataloaders(texts, labels, domains)
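
Aggregating chunk-level predictions back to documents (a minimal sketch;
``model`` is a hypothetical classifier returning per-chunk logits):

>>> from collections import defaultdict
>>> doc_probs = defaultdict(list)
>>> for batch in test_dl:
...     probs = model(batch["input_ids"]).softmax(dim=-1)[:, 1]
...     for doc_idx, p in zip(batch["doc_indices"], probs.tolist()):
...         doc_probs[doc_idx].append(p)
>>> doc_preds = {d: int(sum(p) / len(p) > 0.5) for d, p in doc_probs.items()}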
"""
from __future__ import annotations
import random
import torch
from torch.utils.data import DataLoader, Dataset
# ---------------------------------------------------------------------------
# Vocabulary builder (simple word-level tokenizer)
# ---------------------------------------------------------------------------
_PAD_TOKEN = "<PAD>"
_UNK_TOKEN = "<UNK>"
class SimpleVocab:
"""Minimal word-level vocabulary for the CNN model.
Assigns a unique integer to each token seen during ``build()``.
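
    Example
    -------
    >>> vocab = SimpleVocab().build(["the cat sat", "the dog sat"])
    >>> len(vocab)  # <PAD>, <UNK>, 'the', 'sat'
    4
    >>> vocab.encode("the cat")  # 'cat' is below min_freq, so it maps to <UNK>
    [2, 1]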
"""
def __init__(self) -> None:
self.token2idx: dict[str, int] = {_PAD_TOKEN: 0, _UNK_TOKEN: 1}
self.idx2token: dict[int, str] = {0: _PAD_TOKEN, 1: _UNK_TOKEN}
self.pad_idx: int = 0
self.unk_idx: int = 1
def build(self, texts: list[str], min_freq: int = 2) -> "SimpleVocab":
"""Build vocabulary from a list of texts.
Parameters
----------
texts : list[str]
Raw text strings.
min_freq : int
Minimum token frequency to be included.
Returns
-------
SimpleVocab
self, for method chaining.
"""
freq: dict[str, int] = {}
for text in texts:
for token in text.lower().split():
freq[token] = freq.get(token, 0) + 1
for token, count in freq.items():
if count >= min_freq and token not in self.token2idx:
idx = len(self.token2idx)
self.token2idx[token] = idx
self.idx2token[idx] = token
return self
def encode(self, text: str) -> list[int]:
"""Convert a text string to a list of token indices."""
return [
self.token2idx.get(t, self.unk_idx) for t in text.lower().split()
]
def __len__(self) -> int:
return len(self.token2idx)
# ---------------------------------------------------------------------------
# Sliding-Window Chunking Dataset
# ---------------------------------------------------------------------------
DEFAULT_CHUNK_SIZE: int = 200
DEFAULT_STRIDE: int = 50
class StressDataset(Dataset):
"""PyTorch Dataset with sliding-window chunking for long texts.
Parameters
----------
texts : list[str]
Raw text strings.
labels : list[int]
Binary labels (0 = no stress, 1 = stress).
domains : list[str]
Domain tags (e.g. ``'reddit_long'``, ``'twitter_short'``).
vocab : SimpleVocab, optional
Pre-built vocabulary. If ``None``, one is built from ``texts``.
chunk_size : int
Maximum number of tokens per chunk.
stride : int
Step size between consecutive chunks.
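
    Example
    -------
    With the default ``chunk_size=200`` and ``stride=50``, a 300-token
    document produces chunks starting at offsets 0, 50, and 100; the chunk
    at offset 100 reaches the end of the text, so three chunks are emitted:

    >>> ds = StressDataset(
    ...     texts=[" ".join(["word"] * 300)],
    ...     labels=[1],
    ...     domains=["reddit_long"],
    ... )
    >>> len(ds)
    3
    >>> ds[0]["doc_index"] == ds[2]["doc_index"] == 0
    True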
"""
def __init__(
self,
texts: list[str],
labels: list[int],
domains: list[str],
vocab: SimpleVocab | None = None,
chunk_size: int = DEFAULT_CHUNK_SIZE,
stride: int = DEFAULT_STRIDE,
) -> None:
        if not (len(texts) == len(labels) == len(domains)):
            raise ValueError(
                "texts, labels, and domains must have the same length"
            )
        if chunk_size <= 0 or stride <= 0:
            raise ValueError("chunk_size and stride must be positive")
        self.chunk_size = chunk_size
        self.stride = stride
# Build or reuse vocabulary
if vocab is None:
self.vocab = SimpleVocab().build(texts)
else:
self.vocab = vocab
# Pre-compute all chunks
self._chunks: list[torch.Tensor] = []
self._labels: list[int] = []
self._domains: list[str] = []
self._doc_indices: list[int] = [] # maps chunk → original doc
for doc_idx, (text, label, domain) in enumerate(
zip(texts, labels, domains)
):
token_ids = self.vocab.encode(text)
if len(token_ids) == 0:
# Empty text → single padded chunk
chunk = torch.zeros(chunk_size, dtype=torch.long)
self._chunks.append(chunk)
self._labels.append(label)
self._domains.append(domain)
self._doc_indices.append(doc_idx)
continue
# Generate sliding-window chunks
for start in range(0, len(token_ids), stride):
end = start + chunk_size
chunk_ids = token_ids[start:end]
# Pad if shorter than chunk_size
if len(chunk_ids) < chunk_size:
chunk_ids = chunk_ids + [self.vocab.pad_idx] * (
chunk_size - len(chunk_ids)
)
self._chunks.append(torch.tensor(chunk_ids, dtype=torch.long))
self._labels.append(label)
self._domains.append(domain)
self._doc_indices.append(doc_idx)
# Stop if we've consumed the entire text
if end >= len(token_ids):
break
def __len__(self) -> int:
return len(self._chunks)
def __getitem__(self, idx: int) -> dict[str, torch.Tensor | int | str]:
return {
"input_ids": self._chunks[idx],
"label": self._labels[idx],
"domain": self._domains[idx],
"doc_index": self._doc_indices[idx],
}
# ---------------------------------------------------------------------------
# DataLoader factory
# ---------------------------------------------------------------------------
def collate_fn(batch: list[dict]) -> dict[str, torch.Tensor | list]:
"""Custom collate function for ``StressDataset``.
Stacks ``input_ids`` and ``label`` into tensors; keeps ``domain``
and ``doc_index`` as lists.
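
    Example
    -------
    >>> items = [
    ...     {"input_ids": torch.zeros(4, dtype=torch.long), "label": 1,
    ...      "domain": "reddit_long", "doc_index": 0},
    ...     {"input_ids": torch.ones(4, dtype=torch.long), "label": 0,
    ...      "domain": "twitter_short", "doc_index": 1},
    ... ]
    >>> batch = collate_fn(items)
    >>> batch["input_ids"].shape
    torch.Size([2, 4])
    >>> batch["labels"].tolist(), batch["doc_indices"]
    ([1, 0], [0, 1])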
"""
input_ids = torch.stack([item["input_ids"] for item in batch])
labels = torch.tensor([item["label"] for item in batch], dtype=torch.long)
domains = [item["domain"] for item in batch]
doc_indices = [item["doc_index"] for item in batch]
return {
"input_ids": input_ids,
"labels": labels,
"domains": domains,
"doc_indices": doc_indices,
}
def create_dataloaders(
texts: list[str],
labels: list[int],
domains: list[str],
vocab: SimpleVocab | None = None,
chunk_size: int = DEFAULT_CHUNK_SIZE,
stride: int = DEFAULT_STRIDE,
batch_size: int = 32,
train_ratio: float = 0.8,
val_ratio: float = 0.1,
seed: int = 42,
) -> tuple[DataLoader, DataLoader, DataLoader, SimpleVocab]:
"""Create train / validation / test DataLoaders.
Parameters
----------
texts, labels, domains : list
Raw data arrays.
vocab : SimpleVocab, optional
Pre-built vocabulary; built from training split if ``None``.
chunk_size, stride : int
Sliding-window parameters.
batch_size : int
Batch size for all loaders.
train_ratio, val_ratio : float
Proportions for the train and validation splits.
Test ratio = ``1 - train_ratio - val_ratio``.
seed : int
Random seed for reproducibility.
Returns
-------
tuple[DataLoader, DataLoader, DataLoader, SimpleVocab]
``(train_loader, val_loader, test_loader, vocab)``
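
    Example
    -------
    A minimal sketch, assuming ``texts``, ``labels``, and ``domains`` have
    already been loaded (e.g. by the preprocessing pipeline):

    >>> train_dl, val_dl, test_dl, vocab = create_dataloaders(
    ...     texts, labels, domains, batch_size=32,
    ... )
    >>> batch = next(iter(train_dl))
    >>> batch["input_ids"].shape  # (batch_size, chunk_size), e.g. (32, 200)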
"""
    if not 0.0 < train_ratio + val_ratio < 1.0:
        raise ValueError("train_ratio + val_ratio must lie in (0, 1)")
    n = len(texts)
indices = list(range(n))
random.seed(seed)
random.shuffle(indices)
n_train = int(n * train_ratio)
n_val = int(n * val_ratio)
train_idx = indices[:n_train]
val_idx = indices[n_train : n_train + n_val]
test_idx = indices[n_train + n_val :]
def _select(idx_list: list[int]) -> tuple[list[str], list[int], list[str]]:
return (
[texts[i] for i in idx_list],
[labels[i] for i in idx_list],
[domains[i] for i in idx_list],
)
train_texts, train_labels, train_domains = _select(train_idx)
val_texts, val_labels, val_domains = _select(val_idx)
test_texts, test_labels, test_domains = _select(test_idx)
# Build vocab from training data only
if vocab is None:
vocab = SimpleVocab().build(train_texts)
train_ds = StressDataset(
train_texts, train_labels, train_domains,
vocab=vocab, chunk_size=chunk_size, stride=stride,
)
val_ds = StressDataset(
val_texts, val_labels, val_domains,
vocab=vocab, chunk_size=chunk_size, stride=stride,
)
test_ds = StressDataset(
test_texts, test_labels, test_domains,
vocab=vocab, chunk_size=chunk_size, stride=stride,
)
train_loader = DataLoader(
train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_fn,
)
val_loader = DataLoader(
val_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn,
)
test_loader = DataLoader(
test_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn,
)
return train_loader, val_loader, test_loader, vocab
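

if __name__ == "__main__":
    # Smoke test with tiny synthetic data (illustrative only; real inputs
    # come from the Dreaddit / Reddit / Twitter preprocessing pipeline).
    demo_texts = ["I feel overwhelmed " * 60, "what a calm sunny day"] * 20
    demo_labels = [1, 0] * 20
    demo_domains = ["reddit_long", "twitter_short"] * 20
    train_dl, val_dl, test_dl, vocab = create_dataloaders(
        demo_texts, demo_labels, demo_domains, batch_size=8
    )
    batch = next(iter(train_dl))
    print(f"vocab size: {len(vocab)}")
    print(f"input_ids: {tuple(batch['input_ids'].shape)}")
    print(f"labels: {batch['labels'].tolist()}")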