"""
Data loader factory and utilities for transformer models.
"""
|
|
|
import json
import logging
import math
import os
import zlib
from pathlib import Path
from typing import Dict, List, Optional, Union, Any, Tuple

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset

from config import app_config
from tokenizer import TokenizerWrapper
from datagrower.Crawl4MyAI import AdvancedWebCrawler
from datagrower.Webconverter import WebConverter
from dataset import DatasetManager
|
|
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TransformerDataset(Dataset):
    """Base dataset for transformer models that handles multiple input formats.

    Records are loaded eagerly from a CSV, JSON or plain-text file into
    ``self.data`` and are tokenized one at a time in ``__getitem__``.
    """

    # Lower-cased file extension (without the dot) -> format name.
    _EXTENSION_FORMATS = {'csv': 'csv', 'json': 'json', 'txt': 'txt', 'text': 'txt'}

    # Wrapper keys probed, in this order, when the JSON top level is an object.
    _JSON_CONTAINER_KEYS = ('data', 'examples', 'user_inputs')

    def __init__(
        self,
        data_path: str,
        tokenizer: "TokenizerWrapper",
        max_length: int = 512,
        format_type: Optional[str] = None
    ):
        """
        Initialize dataset.

        Args:
            data_path: Path to the data file
            tokenizer: Tokenizer to use for encoding
            max_length: Maximum sequence length
            format_type: Format of data file ('csv', 'json', 'txt'); when
                omitted it is detected from the extension of ``data_path``

        Raises:
            FileNotFoundError: If ``data_path`` does not exist.
            ValueError: If ``format_type`` is not one of the supported formats.
        """
        self.data_path = data_path
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.format_type = format_type or self._detect_format(data_path)

        self.data = self._load_data()
        logger.info(f"Loaded {len(self.data)} samples from {data_path}")

    def _detect_format(self, path: str) -> str:
        """Detect file format from extension, falling back to 'csv'."""
        ext = os.path.splitext(path)[1].lower().lstrip('.')
        detected = self._EXTENSION_FORMATS.get(ext)
        if detected is None:
            logger.warning(f"Unknown file extension: {ext}, defaulting to CSV")
            return 'csv'
        return detected

    def _load_data(self) -> List[Dict[str, Any]]:
        """Load data based on format type.

        Raises:
            FileNotFoundError: If ``self.data_path`` does not exist.
            ValueError: If ``self.format_type`` is unsupported.
        """
        if not os.path.exists(self.data_path):
            raise FileNotFoundError(f"Data file not found: {self.data_path}")

        loaders = {
            'csv': self._load_csv,
            'json': self._load_json,
            'txt': self._load_txt,
        }
        try:
            loader = loaders.get(self.format_type)
            if loader is None:
                raise ValueError(f"Unsupported format type: {self.format_type}")
            return loader()
        except Exception as e:
            # Log with the offending path, then let the caller see the error.
            logger.error(f"Error loading data from {self.data_path}: {e}")
            raise

    def _load_csv(self) -> List[Dict[str, Any]]:
        """Load data from CSV file.

        If there is no 'text' column, the first column whose name contains
        'text' or 'content' (or else the first column) is renamed to 'text'.
        If there is no 'label' column, the first column other than 'text' is
        renamed to 'label'.
        """
        df = pd.read_csv(self.data_path)

        if 'text' not in df.columns:
            candidates = [
                col for col in df.columns
                if 'text' in str(col).lower() or 'content' in str(col).lower()
            ]
            text_source = candidates[0] if candidates else df.columns[0]
            df = df.rename(columns={text_source: 'text'})

        if 'label' not in df.columns:
            # Never pick the text column itself: blindly using column index 1
            # would turn 'text' into 'label' whenever text is the second column.
            label_source = next((col for col in df.columns if col != 'text'), None)
            if label_source is not None:
                df = df.rename(columns={label_source: 'label'})

        return df.to_dict('records')

    def _load_json(self) -> List[Dict[str, Any]]:
        """Load data from JSON file.

        A top-level list is returned as is. For a top-level object, the value
        under 'data', 'examples' or 'user_inputs' (first match) is returned;
        otherwise each key/value pair becomes ``{'text': str(value), 'id': key}``.

        Raises:
            ValueError: If the top-level JSON value is neither a list nor an object.
        """
        with open(self.data_path, 'r', encoding='utf-8') as f:
            payload = json.load(f)

        if isinstance(payload, list):
            return payload

        if isinstance(payload, dict):
            for key in self._JSON_CONTAINER_KEYS:
                if key in payload:
                    return payload[key]
            return [{'text': str(value), 'id': key} for key, value in payload.items()]

        raise ValueError(f"Unsupported JSON data structure: {type(payload)}")

    def _load_txt(self) -> List[Dict[str, Any]]:
        """Load data from text file, one sample per line.

        Blank lines are skipped; 'id' is the zero-based index of the line in
        the file (blank lines included in the count).
        """
        records = []
        with open(self.data_path, 'r', encoding='utf-8') as f:
            for line_no, line in enumerate(f):
                stripped = line.strip()
                if stripped:
                    records.append({'text': stripped, 'id': line_no})
        return records

    def __len__(self) -> int:
        """Get dataset length."""
        return len(self.data)

    @staticmethod
    def _encode_label(label: Any) -> Any:
        """Convert a raw label into a value ``torch.tensor(..., dtype=long)`` accepts.

        Missing (None) and non-finite float labels become 0, numeric strings and
        floats are truncated to int, and any other string is mapped into
        [0, 100) with a CRC32 digest.
        """
        if label is None:
            return 0
        if isinstance(label, str):
            try:
                label = float(label)
            except ValueError:
                # Built-in hash() of a str is salted per process, so it would
                # assign different ids on every run; CRC32 is stable.
                return zlib.crc32(label.encode('utf-8')) % 100
        if isinstance(label, float):
            # NaN (e.g. an empty CSV cell) and infinities have no integer value.
            return int(label) if math.isfinite(label) else 0
        return label

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Get an item from the dataset.

        Returns:
            Dict with 'input_ids' and 'attention_mask' (tokenizer output with
            the leading dimension squeezed) and 'labels' (a long scalar tensor).
        """
        record = self.data[idx]
        if not isinstance(record, dict):
            # e.g. a JSON file that is a plain list of strings.
            record = {'text': record}

        text = record.get('text', '')
        if text is None or (isinstance(text, float) and math.isnan(text)):
            text = ''
        elif not isinstance(text, str):
            text = str(text)
        if not text:
            # Feed a single space so the tokenizer never sees an empty string.
            text = " "

        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        label = self._encode_label(record.get('label', 0))

        return {
            'input_ids': encoding["input_ids"].squeeze(0),
            'attention_mask': encoding["attention_mask"].squeeze(0),
            'labels': torch.tensor(label, dtype=torch.long)
        }
|
|
|
|
|
|
def prepare_data_loaders_extended(
    data_path: Union[str, Dict[str, str]],
    tokenizer: Any,
    batch_size: int = 16,
    max_length: int = 512,
    val_split: float = 0.1,
    format_type: Optional[str] = None,
    num_workers: int = 0
) -> Dict[str, DataLoader]:
    """
    Build DataLoaders for one or more dataset splits.

    When ``data_path`` is a dict, one loader is built per entry and only the
    split named 'train' is shuffled. Otherwise the single file is loaded and
    randomly divided into 'train' and 'validation' parts; if the computed
    validation size is not positive, only a 'train' loader is returned.

    Args:
        data_path: Path to data file or dictionary mapping split to path
        tokenizer: Tokenizer to use for encoding
        batch_size: Batch size
        max_length: Maximum sequence length
        val_split: Validation split ratio when only one path is provided
        format_type: Format of data file
        num_workers: Number of workers for DataLoader

    Returns:
        Dictionary mapping split names to DataLoaders
    """
    def _build_dataset(source):
        # Every split shares the same tokenizer / length / format settings.
        return TransformerDataset(
            data_path=source,
            tokenizer=tokenizer,
            max_length=max_length,
            format_type=format_type
        )

    def _build_loader(samples, shuffle):
        return DataLoader(
            samples,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers
        )

    # Explicit per-split files: one loader per mapping entry.
    if isinstance(data_path, dict):
        return {
            split_name: _build_loader(_build_dataset(path), split_name == 'train')
            for split_name, path in data_path.items()
        }

    # Single file: carve a validation subset out of it.
    full_dataset = _build_dataset(data_path)
    n_val = int(len(full_dataset) * val_split)

    if n_val <= 0:
        return {'train': _build_loader(full_dataset, True)}

    n_train = len(full_dataset) - n_val
    train_part, val_part = torch.utils.data.random_split(
        full_dataset, [n_train, n_val]
    )
    return {
        'train': _build_loader(train_part, True),
        'validation': _build_loader(val_part, False),
    }
|
|
|
|
|
|
def prepare_data_loaders(
    data_path: str,
    tokenizer: Any,
    batch_size: int = 16,
    val_split: float = 0.1
) -> Tuple[DataLoader, Optional[DataLoader]]:
    """
    Convenience wrapper around ``prepare_data_loaders_extended``.

    Args:
        data_path: Path to data file
        tokenizer: Tokenizer to use for encoding
        batch_size: Batch size
        val_split: Validation split ratio

    Returns:
        Tuple of (train_loader, val_loader); an element is None when the
        corresponding 'train' / 'validation' entry was not produced.
    """
    by_split = prepare_data_loaders_extended(
        data_path=data_path,
        tokenizer=tokenizer,
        batch_size=batch_size,
        val_split=val_split
    )
    return by_split.get('train'), by_split.get('validation')
|
|
|
|
|
|
def load_dataset(
    specialization: str,
    tokenizer: Any = None,
    split: str = 'train'
) -> Dataset:
    """
    Load a dataset for a specific specialization.

    The data location is ``app_config.DATASET_PATHS[specialization]`` when that
    entry exists, otherwise ``<app_config.BASE_DATA_DIR>/<specialization>.csv``.
    A location starting with http:// or https:// is fetched through
    ``WebConverter``; anything else is loaded from disk.

    Args:
        specialization: Name of the specialization
        tokenizer: Tokenizer to use (optional); a default ``TokenizerWrapper()``
            is created when omitted
        split: Dataset split to load (accepted for API compatibility; the
            body does not currently read it)

    Returns:
        Dataset instance
    """
    if hasattr(app_config, 'DATASET_PATHS') and specialization in app_config.DATASET_PATHS:
        data_path = app_config.DATASET_PATHS[specialization]
    else:
        data_path = os.path.join(app_config.BASE_DATA_DIR, f"{specialization}.csv")

    # pathlib.Path values have no startswith(); normalize them to str first.
    if isinstance(data_path, os.PathLike):
        data_path = os.fspath(data_path)

    if tokenizer is None:
        tokenizer = TokenizerWrapper()

    # Same sequence length for both the web and the on-disk branch.
    max_length = app_config.TRANSFORMER_CONFIG.MAX_SEQ_LENGTH

    if data_path.startswith(("http://", "https://")):
        crawler = AdvancedWebCrawler()
        converter = WebConverter(crawler=crawler)
        raw_entries = converter.get_converted_web_data([data_path])

        # TransformerDataset.__init__ requires an existing local file and would
        # raise FileNotFoundError for a URL, so build the instance without
        # calling it and attach the crawled records directly.
        dataset = TransformerDataset.__new__(TransformerDataset)
        dataset.data_path = data_path
        dataset.tokenizer = tokenizer
        dataset.max_length = max_length
        dataset.format_type = 'web'
        # NOTE(review): assumes get_converted_web_data returns an iterable of
        # records (dicts with a 'text' key, or plain strings) -- confirm
        # against WebConverter.
        dataset.data = [
            entry if isinstance(entry, dict) else {'text': str(entry)}
            for entry in (raw_entries or [])
        ]
        logger.info(f"Loaded {len(dataset.data)} samples from {data_path}")
        return dataset

    return TransformerDataset(
        data_path=data_path,
        tokenizer=tokenizer,
        max_length=max_length
    )
|
|
|
|
|
|
def load_for_specialization(spec: str):
    """
    Load the data configured for a specialization through ``DatasetManager``.

    Args:
        spec: Specialization name, used as the key into the configured
            ``DATASET_PATHS`` mapping.

    Returns:
        The result of ``DatasetManager().load_dataset(paths, spec)``, where
        ``paths`` is the configured value wrapped in a list when it is a single
        path, or an empty list when nothing is configured for ``spec``.
    """
    # Prefer attribute access, as load_dataset does; fall back to a dict-style
    # .get() so either flavour of app_config works.
    dataset_paths = getattr(app_config, "DATASET_PATHS", None)
    if dataset_paths is None and hasattr(app_config, "get"):
        dataset_paths = app_config.get("DATASET_PATHS", {})

    paths = (dataset_paths or {}).get(spec, [])

    # A single path (str or path-like) is normalized to a one-element list.
    if isinstance(paths, (str, os.PathLike)):
        paths = [paths]

    manager = DatasetManager()
    return manager.load_dataset(paths, spec)
|
|
|
|
|
|
|
|
|
# Alias: exposes prepare_data_loaders under the name get_dataloader.
get_dataloader = prepare_data_loaders