| |
|
| | import os
|
| | import csv
|
| | import json
|
| | import torch
|
| | import logging
|
| | from time import time
|
| | from functools import wraps
|
| | from preprocess import Preprocessor
|
| | from torch.utils.data import Dataset
|
| | from typing import List, Dict, Any, Optional, Union
|
| |
|
# Module-level logger shared by every class and decorator in this file.
logger = logging.getLogger(__name__)
|
| |
|
| |
|
def safe_file_operation(func):
    """Decorator that wraps file-reading methods with uniform error handling.

    Behavior:
        - Logs a warning when the wrapped call took longer than 300 seconds.
          This is a post-hoc warning measured after completion, NOT an
          enforced timeout — the call is never interrupted.
        - On ``IOError``/``OSError``: loader methods (named ``_load_*``)
          degrade gracefully to an empty list; any other method re-raises.
        - On JSON or CSV parse errors: always returns an empty list.
        - Any other exception is logged and re-raised.

    Assumes it decorates instance methods; ``self.file_path`` is used only
    for log messages and is read defensively via ``getattr`` so instances
    without that attribute do not crash inside the error handler.
    """
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        start_time = time()
        timeout_seconds = 300

        try:
            result = func(self, *args, **kwargs)

            # Post-hoc duration check only — the operation already finished.
            if time() - start_time > timeout_seconds:
                logger.warning(f"File operation {func.__name__} took more than {timeout_seconds} seconds")

            return result
        except (IOError, OSError) as e:
            logger.error(f"File operation error in {func.__name__}: {str(e)}")
            # Loaders degrade to an empty dataset; other operations re-raise.
            if func.__name__.startswith('_load_'):
                return []
            raise
        except json.JSONDecodeError as e:
            # getattr: do not assume the instance has file_path when logging.
            logger.error(f"JSON decode error in {getattr(self, 'file_path', '<unknown>')}: {str(e)}")
            return []
        except csv.Error as e:
            logger.error(f"CSV error in {getattr(self, 'file_path', '<unknown>')}: {str(e)}")
            return []
        except Exception as e:
            logger.error(f"Unexpected error in {func.__name__}: {str(e)}")
            raise

    return wrapper
|
| |
|
class TensorDataset(Dataset):
    """Pairs pre-built feature tensors with their label tensors.

    Length and indexing delegate directly to the underlying containers,
    so any sized, indexable pair of sequences works.
    """

    def __init__(self, features, labels):
        """
        Args:
            features (Tensor): Feature tensors.
            labels (Tensor): Label tensors.
        """
        self.features = features
        self.labels = labels

    def __len__(self):
        # The dataset is exactly as long as its feature container.
        return len(self.features)

    def __getitem__(self, idx):
        feature = self.features[idx]
        label = self.labels[idx]
        return feature, label
|
| |
|
class CustomDataset(Dataset):
    """A dataset that supports loading JSON, CSV, and TXT formats.

    The file type is auto-detected from the extension when not specified,
    and any records that are not dictionaries are filtered out. If a
    preprocessor is provided, it is applied to each record. Sample keys can
    also be standardized dynamically via a header mapping, for example::

        mapping = {
            "title": ["Title", "Headline", "Article Title"],
            "content": ["Content", "Body", "Text"],
        }

    so that regardless of the CSV's header names your trainer always sees a
    standardized set of keys.
    """

    # Maps a lowercase file extension to the loader format name.
    _EXT_TO_FORMAT = {'.json': 'json', '.csv': 'csv', '.txt': 'txt'}

    def __init__(
        self,
        file_path: Optional[str] = None,
        tokenizer=None,
        max_length: Optional[int] = None,
        file_format: Optional[str] = None,
        preprocessor: Optional["Preprocessor"] = None,
        header_mapping: Optional[Dict[str, List[str]]] = None,
        data: Optional[List[Dict[str, Any]]] = None,
        specialization: Optional[str] = None
    ):
        """Args:
            file_path (Optional[str]): Path to the dataset file.
            tokenizer: Tokenizer instance to process the text; must provide
                an ``encode_plus`` method (required by ``__getitem__``).
            max_length (Optional[int]): Maximum sequence length.
            file_format (Optional[str]): 'json', 'csv' or 'txt'; inferred
                from the file extension if not provided.
            preprocessor (Optional[Preprocessor]): Preprocessor to apply to each sample.
            header_mapping (Optional[Dict[str, List[str]]]): Dictionary that maps standardized keys.
            data (Optional[List[Dict[str, Any]]]): Direct data input instead of loading from file.
            specialization (Optional[str]): Specialization field for the dataset.

        Raises:
            ValueError: If the file extension or format is unsupported.
        """
        self.file_path = file_path
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.preprocessor = preprocessor
        self.header_mapping = header_mapping
        self.specialization = specialization
        # Always set the attribute so it exists even when `data` is passed
        # directly (the original only set it on the file-loading path).
        self.file_format = file_format

        if data is not None:
            self.samples = data
        elif file_path is not None:
            if self.file_format is None:
                ext = os.path.splitext(file_path)[1].lower()
                self.file_format = self._EXT_TO_FORMAT.get(ext)
                if self.file_format is None:
                    logger.error(f"Unsupported file extension: {ext}")
                    raise ValueError(f"Unsupported file extension: {ext}")
            self.samples = self._load_file()
        else:
            self.samples = []

        # Drop anything that is not a dict so downstream code can rely on
        # dictionary access.
        initial_sample_count = len(self.samples)
        self.samples = [sample for sample in self.samples if isinstance(sample, dict)]
        if len(self.samples) < initial_sample_count:
            logger.warning(f"Filtered out {initial_sample_count - len(self.samples)} samples that were not dicts.")

        if self.preprocessor:
            preprocessed_samples = []
            for sample in self.samples:
                try:
                    processed = self.preprocessor.preprocess_record(sample)
                    preprocessed_samples.append(processed)
                except Exception as e:
                    # Best-effort: a record that fails preprocessing is
                    # dropped rather than aborting the whole dataset.
                    logger.error(f"Error preprocessing record {sample}: {e}")
            self.samples = preprocessed_samples

    def _load_file(self) -> List[Dict[str, Any]]:
        """Dispatch to the loader for ``self.file_format``; re-raises load errors."""
        try:
            loader = {
                'json': self._load_json,
                'csv': self._load_csv,
                'txt': self._load_txt,
            }.get(self.file_format)
            if loader is None:
                logger.error(f"Unrecognized file format: {self.file_format}")
                raise ValueError(f"Unrecognized file format: {self.file_format}")
            return loader()
        except Exception as e:
            logger.error(f"Error loading file {self.file_path}: {e}")
            raise

    @safe_file_operation
    def _load_json(self) -> List[Dict[str, Any]]:
        """Load a JSON file; attempts record recovery from truncated files."""
        try:
            with open(self.file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            if isinstance(data, list):
                valid_records = [record for record in data if isinstance(record, dict)]
                if len(valid_records) < len(data):
                    logger.warning(f"{len(data) - len(valid_records)} records were not dictionaries in {self.file_path}")
                return valid_records
            elif isinstance(data, dict):
                # A single top-level object is wrapped into a one-element list.
                logger.warning(f"JSON file contains a single dictionary, not a list: {self.file_path}")
                return [data]
            else:
                logger.error(f"JSON file does not contain a list or dictionary: {self.file_path}")
                return []
        except json.JSONDecodeError as e:
            line_col = f"line {e.lineno}, column {e.colno}"
            logger.error(f"JSON decode error at {line_col} in {self.file_path}: {e.msg}")
            # Best-effort recovery: salvage complete flat records that appear
            # before the decode-error position.
            try:
                with open(self.file_path, 'r', encoding='utf-8') as f:
                    content = f.read()
                valid_part = content[:e.pos]
                import re
                # Only matches flat (non-nested) objects by construction.
                matches = re.findall(r'\{[^{}]*\}', valid_part)
                if matches:
                    logger.info(f"Recovered {len(matches)} complete records from {self.file_path}")
                    parsed_records = []
                    for match in matches:
                        try:
                            parsed_records.append(json.loads(match))
                        except json.JSONDecodeError:
                            # Skip fragments that are not valid JSON on their own.
                            continue
                    return parsed_records
            except (IOError, OSError) as recovery_error:
                # Was a silent bare except; now narrowed and logged.
                logger.error(f"Recovery read failed for {self.file_path}: {recovery_error}")
            return []

    @safe_file_operation
    def _load_csv(self) -> List[Dict[str, Any]]:
        """Load CSV rows as dicts, sniffing the dialect when possible."""
        samples = []
        try:
            with open(self.file_path, 'r', encoding='utf-8') as csvfile:
                try:
                    dialect = csv.Sniffer().sniff(csvfile.read(1024))
                    csvfile.seek(0)
                    reader = csv.DictReader(csvfile, dialect=dialect)
                except csv.Error:
                    # Sniffing failed; fall back to the default Excel dialect.
                    csvfile.seek(0)
                    reader = csv.DictReader(csvfile, dialect='excel')

                for i, row in enumerate(reader):
                    # Defensive: DictReader normally yields dicts.
                    if not isinstance(row, dict):
                        logger.warning(f"Row {i} is not a dict: {row} -- skipping.")
                        continue
                    samples.append(row)

            if not samples:
                logger.warning(f"No valid rows found in CSV file: {self.file_path}")

        except csv.Error as e:
            logger.error(f"Error reading CSV file {self.file_path}: {e}")
        return samples

    @safe_file_operation
    def _load_txt(self) -> List[Dict[str, Any]]:
        """Load a text file: each non-empty line becomes a ``{"text": ...}`` record.

        Now decorated like the other loaders so I/O errors degrade to an
        empty list instead of propagating (consistency fix).
        """
        samples = []
        with open(self.file_path, 'r', encoding='utf-8') as txtfile:
            for line in txtfile:
                line = line.strip()
                if line:
                    samples.append({"text": line})
        return samples

    def _standardize_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Remaps the sample's keys to a set of standardized keys using self.header_mapping.

        For each standardized key, the first matching header from the sample
        is used. If none is found, a default empty string is assigned.
        """
        standardized = {}
        for std_field, possible_keys in self.header_mapping.items():
            for key in possible_keys:
                if key in sample:
                    standardized[std_field] = sample[key]
                    break
            if std_field not in standardized:
                standardized[std_field] = ""
        return standardized

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        """Tokenize one sample and return model-ready tensors.

        Returns a dict with ``input_ids``, ``attention_mask`` and
        ``token_type_ids`` (each squeezed to 1-D), plus optional
        ``specialization``, ``title`` and ``content`` pass-through fields.
        """
        sample = self.samples[index]

        if self.header_mapping is not None:
            sample = self._standardize_sample(sample)

        # Build the raw text to tokenize. title/content are tracked
        # explicitly instead of probing locals() (fragile) as before.
        title = None
        content = None
        if 'title' in sample or 'content' in sample:
            title = sample.get('title', '')
            content = sample.get('content', '')
            if not isinstance(title, str):
                title = str(title)
            if not isinstance(content, str):
                content = str(content)
            text = (title + " " + content).strip()
        elif "text" in sample:
            text = sample["text"] if isinstance(sample["text"], str) else str(sample["text"])
        else:
            # Fall back to concatenating every value in the record.
            text = " ".join(str(v) for v in sample.values())

        # NOTE(review): assumes a HuggingFace-style tokenizer exposing
        # encode_plus; self.tokenizer must not be None by the time samples
        # are fetched — confirm against callers.
        tokenized = self.tokenizer.encode_plus(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Per-sample specialization wins over the dataset-wide default.
        specialization = None
        if isinstance(sample, dict) and "specialization" in sample:
            specialization = sample["specialization"]
        elif self.specialization:
            specialization = self.specialization

        result = {
            "input_ids": tokenized["input_ids"].squeeze(0),
            "attention_mask": tokenized["attention_mask"].squeeze(0),
            # Not every tokenizer emits token_type_ids; default to zeros.
            "token_type_ids": tokenized.get("token_type_ids", torch.zeros_like(tokenized["input_ids"])).squeeze(0),
        }

        if specialization:
            result["specialization"] = specialization

        # Pass title/content through only when the title/content branch ran,
        # matching the original's behavior.
        if title is not None:
            result["title"] = title
        if content is not None:
            result["content"] = content

        return result
|
| |
|
| |
|
# NOTE: duplicate mid-file imports removed — os, json, logging and the
# typing names are already imported at the top of this module, and
# `logger` is already defined at module scope above.
|
| |
|
class DatasetManager:
    """
    Simple dataset manager to provide basic functionality for model_manager
    without requiring external dataset dependencies.

    Datasets are plain JSON files (``<name>.json``) living in ``data_dir``
    and are cached in memory after the first load.
    """

    def __init__(self, data_dir: Optional[str] = None):
        """
        Args:
            data_dir: Directory for dataset JSON files; defaults to a
                "data" directory next to this module.
        """
        self.data_dir = data_dir or os.path.join(os.path.dirname(__file__), "data")
        # In-memory cache: dataset name -> list of records.
        self.datasets = {}
        self._ensure_data_dir()

    def _ensure_data_dir(self):
        """Ensure data directory exists, falling back to a /tmp location."""
        try:
            if not os.path.exists(self.data_dir):
                os.makedirs(self.data_dir, exist_ok=True)
                logger.info(f"Created dataset directory at {self.data_dir}")
        except (PermissionError, OSError) as e:
            logger.warning(f"Could not create data directory: {e}")
            # NOTE(review): "/tmp" is POSIX-only — consider
            # tempfile.gettempdir() if Windows support is ever needed.
            self.data_dir = os.path.join("/tmp", "wildnerve_data")
            os.makedirs(self.data_dir, exist_ok=True)
            logger.info(f"Using fallback data directory at {self.data_dir}")

    def load_dataset(self, name: str) -> List[Dict[str, Any]]:
        """Load dataset by name, caching the result in memory.

        Returns an empty list when the dataset is missing or unreadable.
        """
        if name in self.datasets:
            return self.datasets[name]

        filepath = os.path.join(self.data_dir, f"{name}.json")
        if os.path.exists(filepath):
            try:
                with open(filepath, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                self.datasets[name] = data
                return data
            except Exception as e:
                logger.error(f"Error loading dataset {name}: {e}")

        logger.warning(f"Dataset {name} not found, returning empty dataset")
        return []

    def get_dataset_names(self) -> List[str]:
        """Get list of available datasets (basenames of ``*.json`` files)."""
        try:
            # splitext (not split('.')) so dataset names containing dots
            # are not truncated.
            return [os.path.splitext(f)[0] for f in os.listdir(self.data_dir)
                    if f.endswith('.json')]
        except Exception as e:
            logger.error(f"Error listing datasets: {e}")
            return []

    def create_sample_dataset(self, name: str, samples: int = 10) -> List[Dict[str, Any]]:
        """Create, persist, and cache a small synthetic dataset for testing."""
        data = [
            {
                "id": i,
                "text": f"Sample text {i} for model training",
                "label": i % 2
            }
            for i in range(samples)
        ]

        filepath = os.path.join(self.data_dir, f"{name}.json")
        try:
            with open(filepath, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2)
            self.datasets[name] = data
            logger.info(f"Created sample dataset {name} with {samples} samples")
        except Exception as e:
            # Best-effort: the in-memory data is still returned even when
            # persisting it failed.
            logger.error(f"Error creating sample dataset: {e}")

        return data

    def _load_and_process_dataset(self, path_or_paths: Union[str, List[str]], specialization: str) -> "TensorDataset":
        # NOTE(review): this method is incomplete — it loads the frame(s)
        # but never tokenizes them or builds the promised TensorDataset, so
        # it currently falls through and returns None. Left as-is pending
        # clarification of the intended feature/label construction.
        import pandas as pd

        if isinstance(path_or_paths, list):
            frames = [pd.read_json(p) for p in path_or_paths]
            data = pd.concat(frames, ignore_index=True)
        else:
            data = pd.read_json(path_or_paths)
|
| |
|
| |
|
| |
|
| |
|
# Module-level singleton. NOTE: instantiating at import time creates the
# data directory on disk as a side effect.
dataset_manager = DatasetManager()
|
| |
|
def get_dataset(name: str) -> List[Dict[str, Any]]:
    """Fetch the dataset registered under ``name`` via the shared manager."""
    manager = dataset_manager
    return manager.load_dataset(name)
|
| |
|
| |
|
if __name__ == "__main__":
    # Smoke test: build a sample dataset and read it back.
    logging.basicConfig(level=logging.INFO)
    manager = DatasetManager()
    manager.create_sample_dataset("test_dataset", samples=20)
    names = manager.get_dataset_names()
    print(f"Available datasets: {names}")
    loaded = manager.load_dataset("test_dataset")
    print(f"Loaded {len(loaded)} samples from test_dataset")