# dataset.py
import csv
import json
import logging
import os
import re
from functools import wraps
from time import time
from typing import Any, Dict, List, Optional, Union

import torch
from torch.utils.data import Dataset

from preprocess import Preprocessor

logger = logging.getLogger(__name__)

# Matches flat ``{...}`` objects (no nested braces). Used only for best-effort
# salvage of records from a JSON file that failed to parse.
_FLAT_JSON_OBJECT_RE = re.compile(r'\{[^{}]*\}')


def safe_file_operation(func):
    """Decorator that logs file-operation failures and degrades gracefully.

    The wrapped callable must be a method whose instance has a ``file_path``
    attribute (it is used in the log messages).

    Error handling:
      * ``OSError`` (``IOError``): logged; methods whose name starts with
        ``_load_`` return ``[]``, any other method re-raises.
      * ``json.JSONDecodeError`` / ``csv.Error``: logged, returns ``[]``.
      * any other exception: logged and re-raised.

    NOTE: no timeout is enforced. The wrapped call always runs to completion;
    a warning is logged afterwards if it took longer than 300 seconds.
    """

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        start_time = time()
        timeout_seconds = 300  # 5 minutes; warning threshold only
        try:
            result = func(self, *args, **kwargs)
        except OSError as e:  # IOError is an alias of OSError
            logger.error(f"File operation error in {func.__name__}: {str(e)}")
            # Loader methods degrade to an empty result instead of failing.
            if func.__name__.startswith('_load_'):
                return []
            raise
        except json.JSONDecodeError as e:
            logger.error(f"JSON decode error in {self.file_path}: {str(e)}")
            return []
        except csv.Error as e:
            logger.error(f"CSV error in {self.file_path}: {str(e)}")
            return []
        except Exception as e:
            logger.error(f"Unexpected error in {func.__name__}: {str(e)}")
            raise

        if time() - start_time > timeout_seconds:
            logger.warning(
                f"File operation {func.__name__} took more than {timeout_seconds} seconds"
            )
        return result

    return wrapper


class TensorDataset(Dataset):
    """Dataset pairing features and labels index-by-index.

    Args:
        features: Indexable feature collection (e.g. a tensor). Its length
            defines the dataset length.
        labels: Indexable label collection aligned with ``features``.
    """

    def __init__(self, features, labels):
        self.features = features
        self.labels = labels

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]


class CustomDataset(Dataset):
    """Dataset that loads records from JSON, CSV or TXT files, or from a list.

    The file format is inferred from the file extension when not given.
    Records that are not dictionaries are dropped. If a preprocessor is
    provided, it is applied to every remaining record.

    Sample keys can be standardized on access with ``header_mapping``, which
    maps each standardized key to its candidate source headers, e.g.::

        mapping = {
            "title": ["Title", "Headline", "Article Title"],
            "content": ["Content", "Body", "Text"],
        }

    so that the trainer always sees the same keys regardless of the CSV's
    header names.
    """

    # Lower-cased file extension -> loader format name used by _load_file.
    _EXTENSION_FORMATS = {'.json': 'json', '.csv': 'csv', '.txt': 'txt'}

    def __init__(
        self,
        file_path: Optional[str] = None,
        tokenizer=None,
        max_length: Optional[int] = None,
        file_format: Optional[str] = None,
        preprocessor: Optional[Preprocessor] = None,
        header_mapping: Optional[Dict[str, List[str]]] = None,
        data: Optional[List[Dict[str, Any]]] = None,
        specialization: Optional[str] = None,
    ):
        """
        Args:
            file_path: Path to the dataset file. Ignored when ``data`` is given.
            tokenizer: Tokenizer used in ``__getitem__``; must provide
                ``encode_plus``.
            max_length: Maximum sequence length passed to the tokenizer.
            file_format: ``'json'``, ``'csv'`` or ``'txt'`` (case-insensitive).
                Inferred from the extension of ``file_path`` when omitted.
            preprocessor: Optional object whose ``preprocess_record`` is
                applied to each sample. Samples that raise are logged and
                dropped.
            header_mapping: Maps standardized keys to candidate source keys.
            data: In-memory samples used instead of loading ``file_path``.
            specialization: Default specialization attached to every item
                that does not carry its own ``"specialization"`` key.

        Raises:
            ValueError: If the file extension or format is not supported.
        """
        self.file_path = file_path
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.preprocessor = preprocessor
        self.header_mapping = header_mapping
        self.specialization = specialization
        # Always defined, even when samples come from ``data`` or no file is given.
        self.file_format = file_format.lower() if isinstance(file_format, str) else file_format

        if data is not None:
            self.samples = data
        elif file_path is not None:
            if self.file_format is None:
                self.file_format = self._infer_format(file_path)
            self.samples = self._load_file()
        else:
            self.samples = []

        # Keep only dict records; loaders and caller-supplied data may contain others.
        initial_sample_count = len(self.samples)
        self.samples = [sample for sample in self.samples if isinstance(sample, dict)]
        dropped = initial_sample_count - len(self.samples)
        if dropped:
            logger.warning(f"Filtered out {dropped} samples that were not dicts.")

        if self.preprocessor:
            self.samples = self._preprocess_samples(self.samples)

    @classmethod
    def _infer_format(cls, file_path: str) -> str:
        """Map ``file_path``'s extension to a format name, or raise ValueError."""
        ext = os.path.splitext(file_path)[1].lower()
        file_format = cls._EXTENSION_FORMATS.get(ext)
        if file_format is None:
            logger.error(f"Unsupported file extension: {ext}")
            raise ValueError(f"Unsupported file extension: {ext}")
        return file_format

    def _preprocess_samples(self, samples: List[Dict[str, Any]]) -> List[Any]:
        """Apply the preprocessor to each sample, logging and dropping failures."""
        processed = []
        for sample in samples:
            try:
                processed.append(self.preprocessor.preprocess_record(sample))
            except Exception as e:
                logger.error(f"Error preprocessing record {sample}: {e}")
        return processed

    def _load_file(self) -> List[Dict[str, Any]]:
        """Dispatch to the loader matching ``self.file_format``.

        Raises:
            ValueError: If ``self.file_format`` is not recognized.
        """
        try:
            if self.file_format == 'json':
                return self._load_json()
            elif self.file_format == 'csv':
                return self._load_csv()
            elif self.file_format == 'txt':
                return self._load_txt()
            else:
                logger.error(f"Unrecognized file format: {self.file_format}")
                raise ValueError(f"Unrecognized file format: {self.file_format}")
        except Exception as e:
            logger.error(f"Error loading file {self.file_path}: {e}")
            raise

    @safe_file_operation
    def _load_json(self) -> List[Dict[str, Any]]:
        """Load a JSON file containing a list of records or a single record.

        Non-dict list entries are dropped with a warning. On a syntax error,
        flat objects preceding the error position are salvaged on a
        best-effort basis.
        """
        try:
            with open(self.file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except json.JSONDecodeError as e:
            line_col = f"line {e.lineno}, column {e.colno}"
            logger.error(f"JSON decode error at {line_col} in {self.file_path}: {e.msg}")
            return self._recover_json_records(e.pos)

        if isinstance(data, list):
            valid_records = [record for record in data if isinstance(record, dict)]
            if len(valid_records) < len(data):
                logger.warning(f"{len(data) - len(valid_records)} records were not dictionaries in {self.file_path}")
            return valid_records
        if isinstance(data, dict):
            logger.warning(f"JSON file contains a single dictionary, not a list: {self.file_path}")
            return [data]
        logger.error(f"JSON file does not contain a list or dictionary: {self.file_path}")
        return []

    def _recover_json_records(self, error_pos: int) -> List[Dict[str, Any]]:
        """Salvage flat ``{...}`` objects appearing before ``error_pos``.

        Objects containing nested braces are not matched by the regex and are
        therefore not recovered. Returns ``[]`` when nothing can be salvaged.
        """
        try:
            with open(self.file_path, 'r', encoding='utf-8') as f:
                content = f.read()
        except (OSError, UnicodeDecodeError) as read_err:
            logger.error(f"Could not re-read {self.file_path} for JSON recovery: {read_err}")
            return []

        recovered = []
        for match in _FLAT_JSON_OBJECT_RE.findall(content[:error_pos]):
            try:
                recovered.append(json.loads(match))
            except json.JSONDecodeError:
                continue  # fragment was not valid JSON on its own; skip it
        if recovered:
            logger.info(f"Recovered {len(recovered)} complete records from {self.file_path}")
        return recovered

    @safe_file_operation
    def _load_csv(self) -> List[Dict[str, Any]]:
        """Load CSV rows as dicts keyed by the header row.

        The dialect is sniffed from the first 1024 characters, falling back to
        Excel's dialect if sniffing fails.
        """
        samples = []
        try:
            # newline='' is required by the csv module so quoted fields with
            # embedded newlines are parsed correctly.
            with open(self.file_path, 'r', encoding='utf-8', newline='') as csvfile:
                try:
                    dialect = csv.Sniffer().sniff(csvfile.read(1024))
                except csv.Error:
                    dialect = 'excel'
                csvfile.seek(0)
                samples.extend(csv.DictReader(csvfile, dialect=dialect))
            if not samples:
                logger.warning(f"No valid rows found in CSV file: {self.file_path}")
        except csv.Error as e:
            logger.error(f"Error reading CSV file {self.file_path}: {e}")
        return samples

    def _load_txt(self) -> List[Dict[str, Any]]:
        """Load a text file, wrapping each non-blank stripped line as ``{"text": line}``."""
        with open(self.file_path, 'r', encoding='utf-8') as txtfile:
            stripped_lines = (line.strip() for line in txtfile)
            return [{"text": line} for line in stripped_lines if line]

    def _standardize_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Remap ``sample`` to the standardized keys of ``self.header_mapping``.

        For each standardized key, the value of the first candidate key found
        in ``sample`` is used. If none is found, an empty string is assigned.
        Keys not covered by the mapping are dropped.
        """
        return {
            std_field: next((sample[key] for key in possible_keys if key in sample), "")
            for std_field, possible_keys in self.header_mapping.items()
        }

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        """Tokenize sample ``index`` and return model-ready tensors.

        Text source, in priority order:
          1. ``title`` and/or ``content``, joined with a space;
          2. ``text``;
          3. all values joined with spaces.

        Returns:
            Dict with ``input_ids``, ``attention_mask`` and ``token_type_ids``
            (each squeezed along dim 0; zeros when the tokenizer does not
            return ``token_type_ids``). It also contains ``specialization``
            when the sample or dataset has a truthy one, and ``title`` /
            ``content`` when source 1 was used.
        """
        sample = self.samples[index]
        if self.header_mapping is not None:
            sample = self._standardize_sample(sample)

        # Set only when the title/content branch is taken; echoed in the result.
        title: Optional[str] = None
        content: Optional[str] = None
        if 'title' in sample or 'content' in sample:
            title = sample.get('title', '')
            content = sample.get('content', '')
            if not isinstance(title, str):
                title = str(title)
            if not isinstance(content, str):
                content = str(content)
            text = (title + " " + content).strip()
        elif "text" in sample:
            text = sample["text"] if isinstance(sample["text"], str) else str(sample["text"])
        else:
            text = " ".join(str(v) for v in sample.values())

        tokenized = self.tokenizer.encode_plus(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

        # A per-sample specialization overrides the dataset-level default.
        specialization = None
        if isinstance(sample, dict) and "specialization" in sample:
            specialization = sample["specialization"]
        elif self.specialization:
            specialization = self.specialization

        # Build the zero fallback only when the tokenizer did not provide one.
        token_type_ids = tokenized.get("token_type_ids")
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(tokenized["input_ids"])

        result = {
            "input_ids": tokenized["input_ids"].squeeze(0),
            "attention_mask": tokenized["attention_mask"].squeeze(0),
            "token_type_ids": token_type_ids.squeeze(0),
        }
        if specialization:
            result["specialization"] = specialization
        if title is not None:
            result["title"] = title
        if content is not None:
            result["content"] = content
        return result


# ---------------------------------------------------------------------------
# Simple dataset manager (originally headed "Simple dataset module to fix
# initialization dependency issues").
# NOTE(review): this section lives in the same module as CustomDataset above
# and therefore shares its imports, including the project-local `preprocess`
# import, so it is not dependency-free as its original header suggested.
# Confirm whether it was meant to be a separate module.
# ---------------------------------------------------------------------------


class DatasetManager:
    """Simple dataset manager providing basic functionality for model_manager
    without requiring external dataset dependencies.

    Datasets are stored as ``<data_dir>/<name>.json`` and cached in memory
    after the first successful load.
    """

    def __init__(self, data_dir: Optional[str] = None):
        self.data_dir = data_dir or os.path.join(os.path.dirname(__file__), "data")
        self.datasets = {}  # name -> loaded data (in-memory cache)
        self._ensure_data_dir()

    def _ensure_data_dir(self):
        """Create ``self.data_dir`` if missing, falling back to /tmp/wildnerve_data."""
        try:
            if not os.path.exists(self.data_dir):
                os.makedirs(self.data_dir, exist_ok=True)
                logger.info(f"Created dataset directory at {self.data_dir}")
        except (PermissionError, OSError) as e:
            logger.warning(f"Could not create data directory: {e}")
            # Fall back to temp directory
            self.data_dir = os.path.join("/tmp", "wildnerve_data")
            os.makedirs(self.data_dir, exist_ok=True)
            logger.info(f"Using fallback data directory at {self.data_dir}")

    def load_dataset(self, name: str) -> List[Dict[str, Any]]:
        """Load dataset ``name`` from the cache or ``<data_dir>/<name>.json``.

        Returns ``[]`` (and logs) if the file is missing or cannot be loaded.
        Failed loads are not cached.
        """
        if name in self.datasets:
            return self.datasets[name]

        filepath = os.path.join(self.data_dir, f"{name}.json")
        if not os.path.exists(filepath):
            logger.warning(f"Dataset {name} not found, returning empty dataset")
            return []

        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except Exception as e:
            # Best-effort: report the real failure instead of "not found".
            logger.error(f"Error loading dataset {name}: {e}")
            return []
        self.datasets[name] = data
        return data

    def get_dataset_names(self) -> List[str]:
        """Return the names of the ``*.json`` files in ``data_dir``, without extension.

        Only the final ``.json`` suffix is removed, so ``my.data.json`` yields
        ``my.data``, which round-trips through ``load_dataset``.
        """
        try:
            return [
                os.path.splitext(f)[0]
                for f in os.listdir(self.data_dir)
                if f.endswith('.json')
            ]
        except OSError as e:
            logger.error(f"Error listing datasets: {e}")
            return []

    def create_sample_dataset(self, name: str, samples: int = 10) -> List[Dict[str, Any]]:
        """Create, cache and save a synthetic dataset of ``samples`` records.

        Each record has ``id``, ``text`` and a binary ``label`` (``id % 2``).
        The data is returned even if writing the file fails.
        """
        data = [
            {
                "id": i,
                "text": f"Sample text {i} for model training",
                "label": i % 2,  # Binary label
            }
            for i in range(samples)
        ]

        filepath = os.path.join(self.data_dir, f"{name}.json")
        try:
            with open(filepath, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2)
            self.datasets[name] = data
            logger.info(f"Created sample dataset {name} with {samples} samples")
        except OSError as e:
            logger.error(f"Error creating sample dataset: {e}")
        return data

    def _load_and_process_dataset(
        self, path_or_paths: Union[str, List[str]], specialization: str
    ) -> Optional[TensorDataset]:
        """Read one or more JSON files into a single pandas DataFrame.

        NOTE(review): unfinished stub. The original only had placeholder
        comments for the step that splits the frame into features/labels and
        builds a ``TensorDataset``. That code is absent, so this method
        currently returns ``None`` and ``specialization`` is unused.
        TODO: implement the feature/label split and return a TensorDataset.
        """
        import pandas as pd

        # Multiple JSON files are concatenated into one frame.
        if isinstance(path_or_paths, list):
            frames = [pd.read_json(p) for p in path_or_paths]
            data = pd.concat(frames, ignore_index=True)
        else:
            data = pd.read_json(path_or_paths)
        # TODO: split `data` into features/labels and return TensorDataset(features, labels).
        del data
        return None


# Module-level default manager (creates the data directory on import).
dataset_manager = DatasetManager()


def get_dataset(name: str) -> List[Dict[str, Any]]:
    """Load dataset ``name`` via the module-level ``dataset_manager``."""
    return dataset_manager.load_dataset(name)


# Create some minimal sample data if running as main
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    dm = DatasetManager()
    dm.create_sample_dataset("test_dataset", samples=20)
    print(f"Available datasets: {dm.get_dataset_names()}")
    test_data = dm.load_dataset("test_dataset")
    print(f"Loaded {len(test_data)} samples from test_dataset")