Spaces:
Running
Running
| import json | |
| import numpy as np | |
| import torch | |
| from torch.utils.data import Dataset | |
| import random | |
| import os | |
| import pickle | |
| from typing import Dict, List, Union, Optional, Tuple | |
| from pathlib import Path | |
| class ChatTSTimeRCDPretrainDataset(Dataset): | |
| def __init__(self, | |
| dataset_dir: str, | |
| filename: str, | |
| split: str = 'train', | |
| train_ratio: float = 0.95, | |
| seed: int = 42): | |
| file_path = os.path.join(dataset_dir, filename) | |
| with open(file_path, 'rb') as f: | |
| dataset = pickle.load(f) | |
| random.seed(seed) | |
| indices = list(range(len(dataset))) | |
| random.shuffle(indices) | |
| num_train = int(len(dataset) * train_ratio) | |
| if split == 'train': | |
| selected_indices = indices[:num_train] | |
| elif split == 'test': | |
| selected_indices = indices[num_train:] | |
| else: | |
| raise ValueError("split must be 'train' or 'test'") | |
| self.data = [dataset[i] for i in selected_indices] | |
| def __len__(self): | |
| return len(self.data) | |
| def __getitem__(self, idx): | |
| sample = self.data[idx] | |
| time_series = torch.tensor(sample['time_series'], dtype=torch.float32) | |
| normal_time_series = torch.tensor(sample['normal_time_series'], dtype=torch.float32) | |
| labels = torch.tensor(sample['labels'], dtype=torch.long) | |
| attribute = sample['attribute'] | |
| return time_series, normal_time_series, labels, attribute | |
| class ChatTSTimeRCDQADataset(Dataset): | |
| """Dataset class for time series anomaly detection with QA pairs. | |
| This dataset loads time series data and corresponding question-answer pairs | |
| for anomaly detection tasks. It supports train/val split and efficient loading | |
| of series data from the time_rcd_datasets format. | |
| Attributes: | |
| split (str): Dataset split, either 'train' or 'val' | |
| series_dir (Path): Directory containing series JSON files | |
| metadata (Dict): Dataset metadata loaded from metadata.json | |
| series_files (List[str]): List of series file paths | |
| window_size_range (Tuple[int, int]): Range of window sizes used in the dataset | |
| """ | |
| def __init__( | |
| self, | |
| dataset_dir: str, | |
| split: str = 'train', | |
| train_ratio: float = 0.95, | |
| seed: int = 42, | |
| cache_size: int = 1000 | |
| ) -> None: | |
| """Initialize the dataset. | |
| Args: | |
| dataset_dir: Path to the dataset directory containing metadata.json and series/ | |
| split: Dataset split, either 'train' or 'val' | |
| train_ratio: Ratio of training samples (default: 0.8) | |
| seed: Random seed for reproducibility (default: 42) | |
| cache_size: Number of series files to keep in memory (default: 1000) | |
| """ | |
| self.split = split | |
| self.series_dir = Path(dataset_dir) / 'series' | |
| # Get all series files and shuffle them | |
| self.series_files = sorted(self.series_dir.glob('series_*.json')) | |
| random.seed(seed) | |
| random.shuffle(self.series_files) | |
| # Split into train/val | |
| split_idx = int(len(self.series_files) * train_ratio) | |
| self.series_files = self.series_files[:split_idx] if split == 'train' else self.series_files[split_idx:] | |
| # Initialize LRU cache for series data | |
| self._cache = {} | |
| self._cache_size = cache_size | |
| self._cache_order = [] | |
| def _load_series(self, file_path: Path) -> Dict: | |
| """Load a series file with caching. | |
| Args: | |
| file_path: Path to the series JSON file | |
| Returns: | |
| Dictionary containing the series data | |
| """ | |
| if file_path in self._cache: | |
| # Update cache order | |
| self._cache_order.remove(file_path) | |
| self._cache_order.append(file_path) | |
| return self._cache[file_path] | |
| # Load new file | |
| with open(file_path, 'r') as f: | |
| data = json.load(f) | |
| # Update cache | |
| if len(self._cache) >= self._cache_size: | |
| # Remove oldest item | |
| oldest = self._cache_order.pop(0) | |
| del self._cache[oldest] | |
| self._cache[file_path] = data | |
| self._cache_order.append(file_path) | |
| return data | |
| def __len__(self) -> int: | |
| """Return the number of samples in the dataset.""" | |
| return len(self.series_files) | |
| def __getitem__(self, idx: int) -> Dict[str, Union[torch.Tensor, List[Dict]]]: | |
| """Get a sample from the dataset. | |
| Args: | |
| idx: Index of the sample to retrieve | |
| Returns: | |
| Dictionary containing: | |
| - time_series: Time series data as torch.Tensor | |
| - windows: List of window data containing QA pairs | |
| - sample_id: Unique identifier for the sample | |
| """ | |
| file_path = self.series_files[idx] | |
| data = self._load_series(file_path) | |
| # Convert time series to tensor | |
| time_series = np.array(data['original_data']['time_series']) | |
| time_series_tensor = torch.FloatTensor(time_series) | |
| return { | |
| 'time_series': time_series_tensor, | |
| 'analysis_data': data['windows'] | |
| } | |