# dataset.py
import os
import re
import csv
import json
import torch
import logging
from time import time
from functools import wraps
from preprocess import Preprocessor
from torch.utils.data import Dataset
from typing import List, Dict, Any, Optional, Union
logger = logging.getLogger(__name__)
def safe_file_operation(func):
"""Decorator to safely handle file operations with timeout"""
@wraps(func)
def wrapper(self, *args, **kwargs):
start_time = time()
        timeout_seconds = 300  # soft limit: log a warning if the call takes longer than 5 minutes (the call is not aborted)
try:
# Try to perform the operation
result = func(self, *args, **kwargs)
# Check if operation took too long
if time() - start_time > timeout_seconds:
logger.warning(f"File operation {func.__name__} took more than {timeout_seconds} seconds")
return result
        except OSError as e:  # IOError is an alias of OSError in Python 3
logger.error(f"File operation error in {func.__name__}: {str(e)}")
# Return empty result based on function type
if func.__name__.startswith('_load_'):
return []
raise
except json.JSONDecodeError as e:
logger.error(f"JSON decode error in {self.file_path}: {str(e)}")
return []
except csv.Error as e:
logger.error(f"CSV error in {self.file_path}: {str(e)}")
return []
except Exception as e:
logger.error(f"Unexpected error in {func.__name__}: {str(e)}")
raise
return wrapper
class TensorDataset(Dataset):
"""Dataset class for handling tensor data with features and labels."""
def __init__(self, features, labels):
"""
Initialize TensorDataset.
Args:
features (Tensor): Feature tensors.
labels (Tensor): Label tensors.
"""
self.features = features
self.labels = labels
def __len__(self):
return len(self.features)
def __getitem__(self, idx):
return self.features[idx], self.labels[idx]
class CustomDataset(Dataset):
"""A dataset that supports loading JSON, CSV, and TXT formats.
It auto-detects the file type (if not specified) and filters out any
records that are not dictionaries. If a preprocessor is provided, it
applies it to each record. Additionally, it can standardize sample keys
dynamically using a provided header mapping. For example, you can define a
mapping like:
mapping = {
"title": ["Title", "Headline", "Article Title"],
"content": ["Content", "Body", "Text"],
}
    so that, regardless of the CSV's header names, your trainer always sees a
    standardized set of keys. A runnable usage sketch follows this class."""
def __init__(
self,
file_path: Optional[str] = None,
tokenizer = None,
max_length: Optional[int] = None,
file_format: Optional[str] = None,
preprocessor: Optional[Preprocessor] = None,
header_mapping: Optional[Dict[str, List[str]]] = None,
        data: Optional[List[Dict[str, Any]]] = None,  # direct data input, bypassing file loading
        specialization: Optional[str] = None  # optional specialization tag applied to every item
):
"""Args:
file_path (Optional[str]): Path to the dataset file.
tokenizer: Tokenizer instance to process the text.
max_length (Optional[int]): Maximum sequence length.
file_format (Optional[str]): Format of the file; inferred from the extension if not provided.
preprocessor (Optional[Preprocessor]): Preprocessor to apply to each sample.
            header_mapping (Optional[Dict[str, List[str]]]): Maps each standardized key to its possible source header names.
data (Optional[List[Dict[str, Any]]]): Direct data input instead of loading from file.
specialization (Optional[str]): Specialization field for the dataset."""
self.file_path = file_path
self.tokenizer = tokenizer
self.max_length = max_length
self.preprocessor = preprocessor
self.header_mapping = header_mapping
        self.specialization = specialization
        self.file_format = file_format  # may be refined below once inferred from the file extension
# Initialize samples either from data or file
if data is not None:
self.samples = data
else:
# Determine the file format if not specified and file_path is provided
if file_path is not None:
if file_format is None:
_, ext = os.path.splitext(file_path)
ext = ext.lower()
                    if ext == '.json':
                        file_format = 'json'
                    elif ext == '.csv':
                        file_format = 'csv'
                    elif ext == '.txt':
                        file_format = 'txt'
else:
logger.error(f"Unsupported file extension: {ext}")
raise ValueError(f"Unsupported file extension: {ext}")
self.file_format = file_format
self.samples = self._load_file()
else:
self.samples = []
        # Validation: keep only samples that are dictionaries.
initial_sample_count = len(self.samples)
self.samples = [sample for sample in self.samples if isinstance(sample, dict)]
if len(self.samples) < initial_sample_count:
logger.warning(f"Filtered out {initial_sample_count - len(self.samples)} samples that were not dicts.")
# If a preprocessor is provided, apply preprocessing to each record.
if self.preprocessor:
preprocessed_samples = []
for sample in self.samples:
try:
processed = self.preprocessor.preprocess_record(sample)
preprocessed_samples.append(processed)
except Exception as e:
logger.error(f"Error preprocessing record {sample}: {e}")
self.samples = preprocessed_samples
def _load_file(self) -> List[Dict[str, Any]]:
try:
if self.file_format == 'json':
return self._load_json()
elif self.file_format == 'csv':
return self._load_csv()
elif self.file_format == 'txt':
return self._load_txt()
else:
logger.error(f"Unrecognized file format: {self.file_format}")
raise ValueError(f"Unrecognized file format: {self.file_format}")
except Exception as e:
logger.error(f"Error loading file {self.file_path}: {e}")
raise
@safe_file_operation
def _load_json(self) -> List[Dict[str, Any]]:
"""Load JSON file with better error handling and validation"""
try:
with open(self.file_path, 'r', encoding='utf-8') as f:
data = json.load(f)
# Validate data structure
if isinstance(data, list):
valid_records = [record for record in data if isinstance(record, dict)]
if len(valid_records) < len(data):
logger.warning(f"{len(data) - len(valid_records)} records were not dictionaries in {self.file_path}")
return valid_records
elif isinstance(data, dict):
# Handle single record case
logger.warning(f"JSON file contains a single dictionary, not a list: {self.file_path}")
return [data]
else:
logger.error(f"JSON file does not contain a list or dictionary: {self.file_path}")
return []
except json.JSONDecodeError as e:
line_col = f"line {e.lineno}, column {e.colno}"
logger.error(f"JSON decode error at {line_col} in {self.file_path}: {e.msg}")
# Try to recover partial content
try:
with open(self.file_path, 'r', encoding='utf-8') as f:
content = f.read()
# Try parsing up to the error
valid_part = content[:e.pos]
# Find complete objects (rough approach)
matches = re.findall(r'\{[^{}]*\}', valid_part)
if matches:
logger.info(f"Recovered {len(matches)} complete records from {self.file_path}")
parsed_records = []
for match in matches:
try:
parsed_records.append(json.loads(match))
                        except json.JSONDecodeError:
                            continue  # skip fragments that still fail to parse
return parsed_records
            except Exception:
                pass  # recovery is best-effort; fall through and return an empty list
return []
@safe_file_operation
def _load_csv(self) -> List[Dict[str, Any]]:
"""Load CSV with better error handling"""
samples = []
try:
with open(self.file_path, 'r', encoding='utf-8') as csvfile:
# Try detecting dialect first
try:
dialect = csv.Sniffer().sniff(csvfile.read(1024))
csvfile.seek(0)
reader = csv.DictReader(csvfile, dialect=dialect)
            except csv.Error:
                # Sniffing failed; fall back to the default excel dialect
csvfile.seek(0)
reader = csv.DictReader(csvfile, dialect='excel')
for i, row in enumerate(reader):
if not isinstance(row, dict):
logger.warning(f"Row {i} is not a dict: {row} -- skipping.")
continue
samples.append(row)
if not samples:
logger.warning(f"No valid rows found in CSV file: {self.file_path}")
except csv.Error as e:
logger.error(f"Error reading CSV file {self.file_path}: {e}")
return samples
    @safe_file_operation
    def _load_txt(self) -> List[Dict[str, Any]]:
        """Load a plain-text file, one sample per non-empty line."""
        samples = []
        with open(self.file_path, 'r', encoding='utf-8') as txtfile:
            for line in txtfile:
                line = line.strip()
                if line:
                    # Wrap each non-empty line in a dictionary.
                    samples.append({"text": line})
        return samples
def _standardize_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Remaps the sample's keys to a set of standardized keys using self.header_mapping.
For each standardized key, the first matching header from the sample is used.
If none is found, a default empty string is assigned."""
standardized = {}
for std_field, possible_keys in self.header_mapping.items():
for key in possible_keys:
if key in sample:
standardized[std_field] = sample[key]
break
if std_field not in standardized:
standardized[std_field] = ""
return standardized
def __len__(self) -> int:
return len(self.samples)
def __getitem__(self, index: int) -> Dict[str, Any]:
sample = self.samples[index]
# If a header mapping is provided, standardize the sample keys.
if self.header_mapping is not None:
sample = self._standardize_sample(sample)
        # Determine the text to tokenize:
        # if standardized keys "title" or "content" exist, combine them.
        title = None
        content = None
        if 'title' in sample or 'content' in sample:
            title = sample.get('title', '')
            content = sample.get('content', '')
# Convert non-string fields to strings
if not isinstance(title, str):
title = str(title)
if not isinstance(content, str):
content = str(content)
text = (title + " " + content).strip()
elif "text" in sample:
text = sample["text"] if isinstance(sample["text"], str) else str(sample["text"])
else:
# Fallback: join all values (cast to str)
text = " ".join(str(v) for v in sample.values())
# Tokenize the combined text.
tokenized = self.tokenizer.encode_plus(
text,
max_length=self.max_length,
padding='max_length',
truncation=True,
return_tensors='pt'
)
# Get specialization from sample or use class default
specialization = None
if isinstance(sample, dict) and "specialization" in sample:
specialization = sample["specialization"]
elif self.specialization:
specialization = self.specialization
# Return a standardized dictionary for training.
result = {
"input_ids": tokenized["input_ids"].squeeze(0),
"attention_mask": tokenized["attention_mask"].squeeze(0),
"token_type_ids": tokenized.get("token_type_ids", torch.zeros_like(tokenized["input_ids"])).squeeze(0),
}
# Add specialization if available
if specialization:
result["specialization"] = specialization
        # Optionally include the standardized text fields when they were extracted
        if title is not None:
            result["title"] = title
        if content is not None:
            result["content"] = content
return result
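

# Illustrative usage sketch for CustomDataset. It assumes the `transformers`
# package is installed and that a local "articles.csv" with "Headline"/"Body"
# columns exists; both are placeholders for this example, not requirements of
# the module itself.
if __name__ == "__main__":
    if os.path.exists("articles.csv"):
        from transformers import AutoTokenizer

        example_mapping = {
            "title": ["Title", "Headline", "Article Title"],
            "content": ["Content", "Body", "Text"],
        }
        example_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        example_dataset = CustomDataset(
            file_path="articles.csv",
            tokenizer=example_tokenizer,
            max_length=128,
            header_mapping=example_mapping,
        )
        print(f"Loaded {len(example_dataset)} samples; first item keys: {list(example_dataset[0].keys())}")
    else:
        print("articles.csv not found; skipping CustomDataset usage sketch")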
# dataset.py - Simple dataset module to fix initialization dependency issues
import logging
import os
import json
import tempfile
from typing import Dict, List, Any, Optional, Union
logger = logging.getLogger(__name__)
class DatasetManager:
"""
Simple dataset manager to provide basic functionality for model_manager
without requiring external dataset dependencies
"""
def __init__(self, data_dir: Optional[str] = None):
self.data_dir = data_dir or os.path.join(os.path.dirname(__file__), "data")
self.datasets = {}
self._ensure_data_dir()
def _ensure_data_dir(self):
"""Ensure data directory exists"""
try:
if not os.path.exists(self.data_dir):
os.makedirs(self.data_dir, exist_ok=True)
logger.info(f"Created dataset directory at {self.data_dir}")
except (PermissionError, OSError) as e:
logger.warning(f"Could not create data directory: {e}")
            # Fall back to a per-user temp directory (portable across platforms)
            self.data_dir = os.path.join(tempfile.gettempdir(), "wildnerve_data")
            os.makedirs(self.data_dir, exist_ok=True)
logger.info(f"Using fallback data directory at {self.data_dir}")
def load_dataset(self, name: str) -> List[Dict[str, Any]]:
"""Load dataset by name"""
if name in self.datasets:
return self.datasets[name]
# Check for dataset file
filepath = os.path.join(self.data_dir, f"{name}.json")
if os.path.exists(filepath):
try:
with open(filepath, 'r', encoding='utf-8') as f:
data = json.load(f)
self.datasets[name] = data
return data
except Exception as e:
logger.error(f"Error loading dataset {name}: {e}")
        # Return an empty dataset if the file is missing or could not be loaded
        logger.warning(f"Dataset {name} could not be loaded, returning empty dataset")
return []
def get_dataset_names(self) -> List[str]:
"""Get list of available datasets"""
try:
            return [os.path.splitext(f)[0] for f in os.listdir(self.data_dir)
                    if f.endswith('.json')]
except Exception as e:
logger.error(f"Error listing datasets: {e}")
return []
def create_sample_dataset(self, name: str, samples: int = 10) -> List[Dict[str, Any]]:
"""Create a sample dataset for testing"""
data = [
{
"id": i,
"text": f"Sample text {i} for model training",
"label": i % 2 # Binary label
}
for i in range(samples)
]
# Save to file
filepath = os.path.join(self.data_dir, f"{name}.json")
try:
with open(filepath, 'w', encoding='utf-8') as f:
json.dump(data, f, indent=2)
self.datasets[name] = data
logger.info(f"Created sample dataset {name} with {samples} samples")
except Exception as e:
logger.error(f"Error creating sample dataset: {e}")
return data
def _load_and_process_dataset(self, path_or_paths: Union[str, List[str]], specialization: str) -> TensorDataset:
# …existing code up to reading the file…
import pandas as pd
# Handle multiple JSON files by concatenation
if isinstance(path_or_paths, list):
frames = [pd.read_json(p) for p in path_or_paths]
data = pd.concat(frames, ignore_index=True)
else:
data = pd.read_json(path_or_paths)
# …existing code that splits into features/labels and returns TensorDataset…
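
# Hypothetical illustration only; it is not the elided implementation above. It
# sketches one way a concatenated DataFrame with assumed "text" and "label"
# columns could become a TensorDataset, given a Hugging Face tokenizer and the
# TensorDataset wrapper defined in the companion dataset module in this upload.
def _frame_to_tensor_dataset_example(frame, tokenizer, max_length: int = 128):
    import torch  # assumed to be installed, as in the companion module
    encoded = tokenizer(
        [str(t) for t in frame["text"]],  # tokenize every row's text
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    labels = torch.tensor(list(frame["label"]))  # assumed integer class labels
    return TensorDataset(encoded["input_ids"], labels)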
# Create a default dataset manager instance
dataset_manager = DatasetManager()
def get_dataset(name: str) -> List[Dict[str, Any]]:
"""Helper function to get a dataset by name"""
return dataset_manager.load_dataset(name)
# Create some minimal sample data if running as main
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
dm = DatasetManager()
dm.create_sample_dataset("test_dataset", samples=20)
print(f"Available datasets: {dm.get_dataset_names()}")
test_data = dm.load_dataset("test_dataset")
print(f"Loaded {len(test_data)} samples from test_dataset")