Time_RCD / models /time_rcd /dataset.py
Oliver Le
Initial commit
d03866e
import json
import numpy as np
import torch
from torch.utils.data import Dataset
import random
import os
import pickle
from typing import Dict, List, Union, Optional, Tuple
from pathlib import Path
class ChatTSTimeRCDPretrainDataset(Dataset):
def __init__(self,
dataset_dir: str,
filename: str,
split: str = 'train',
train_ratio: float = 0.95,
seed: int = 42):
file_path = os.path.join(dataset_dir, filename)
with open(file_path, 'rb') as f:
dataset = pickle.load(f)
random.seed(seed)
indices = list(range(len(dataset)))
random.shuffle(indices)
num_train = int(len(dataset) * train_ratio)
if split == 'train':
selected_indices = indices[:num_train]
elif split == 'test':
selected_indices = indices[num_train:]
else:
raise ValueError("split must be 'train' or 'test'")
self.data = [dataset[i] for i in selected_indices]
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
sample = self.data[idx]
time_series = torch.tensor(sample['time_series'], dtype=torch.float32)
normal_time_series = torch.tensor(sample['normal_time_series'], dtype=torch.float32)
labels = torch.tensor(sample['labels'], dtype=torch.long)
attribute = sample['attribute']
return time_series, normal_time_series, labels, attribute
class ChatTSTimeRCDQADataset(Dataset):
"""Dataset class for time series anomaly detection with QA pairs.
This dataset loads time series data and corresponding question-answer pairs
for anomaly detection tasks. It supports train/val split and efficient loading
of series data from the time_rcd_datasets format.
Attributes:
split (str): Dataset split, either 'train' or 'val'
series_dir (Path): Directory containing series JSON files
metadata (Dict): Dataset metadata loaded from metadata.json
series_files (List[str]): List of series file paths
window_size_range (Tuple[int, int]): Range of window sizes used in the dataset
"""
def __init__(
self,
dataset_dir: str,
split: str = 'train',
train_ratio: float = 0.95,
seed: int = 42,
cache_size: int = 1000
) -> None:
"""Initialize the dataset.
Args:
dataset_dir: Path to the dataset directory containing metadata.json and series/
split: Dataset split, either 'train' or 'val'
train_ratio: Ratio of training samples (default: 0.8)
seed: Random seed for reproducibility (default: 42)
cache_size: Number of series files to keep in memory (default: 1000)
"""
self.split = split
self.series_dir = Path(dataset_dir) / 'series'
# Get all series files and shuffle them
self.series_files = sorted(self.series_dir.glob('series_*.json'))
random.seed(seed)
random.shuffle(self.series_files)
# Split into train/val
split_idx = int(len(self.series_files) * train_ratio)
self.series_files = self.series_files[:split_idx] if split == 'train' else self.series_files[split_idx:]
# Initialize LRU cache for series data
self._cache = {}
self._cache_size = cache_size
self._cache_order = []
def _load_series(self, file_path: Path) -> Dict:
"""Load a series file with caching.
Args:
file_path: Path to the series JSON file
Returns:
Dictionary containing the series data
"""
if file_path in self._cache:
# Update cache order
self._cache_order.remove(file_path)
self._cache_order.append(file_path)
return self._cache[file_path]
# Load new file
with open(file_path, 'r') as f:
data = json.load(f)
# Update cache
if len(self._cache) >= self._cache_size:
# Remove oldest item
oldest = self._cache_order.pop(0)
del self._cache[oldest]
self._cache[file_path] = data
self._cache_order.append(file_path)
return data
def __len__(self) -> int:
"""Return the number of samples in the dataset."""
return len(self.series_files)
def __getitem__(self, idx: int) -> Dict[str, Union[torch.Tensor, List[Dict]]]:
"""Get a sample from the dataset.
Args:
idx: Index of the sample to retrieve
Returns:
Dictionary containing:
- time_series: Time series data as torch.Tensor
- windows: List of window data containing QA pairs
- sample_id: Unique identifier for the sample
"""
file_path = self.series_files[idx]
data = self._load_series(file_path)
# Convert time series to tensor
time_series = np.array(data['original_data']['time_series'])
time_series_tensor = torch.FloatTensor(time_series)
return {
'time_series': time_series_tensor,
'analysis_data': data['windows']
}