"""
Data processing module for KerdosAI.

Provides the DataProcessor class, which loads training data from CSV or JSON
files, normalizes it, and optionally tokenizes it into a Hugging Face Dataset.
"""

import json
import logging
from pathlib import Path
from typing import Dict, List, Optional

import pandas as pd
from datasets import Dataset
from transformers import PreTrainedTokenizer

logger = logging.getLogger(__name__)


class DataProcessor:
    """
    Handles data processing and preparation for training.
    """

    def __init__(
        self,
        data_path: Path,
        max_length: int = 512,
        text_column: str = "text",
        **kwargs
    ):
        """
        Initialize the data processor.

        Args:
            data_path: Path to the training data
            max_length: Maximum sequence length
            text_column: Name of the text column in the data
            **kwargs: Additional configuration parameters
        """
        self.data_path = Path(data_path)
        self.max_length = max_length
        self.text_column = text_column
        self.config = kwargs

        if not self.data_path.exists():
            raise FileNotFoundError(f"Data path {data_path} does not exist")

    def prepare_dataset(
        self,
        tokenizer: Optional[PreTrainedTokenizer] = None,
        **kwargs
    ) -> Dataset:
        """
        Prepare the dataset for training.

        Args:
            tokenizer: Tokenizer for text processing
            **kwargs: Additional processing parameters

        Returns:
            HuggingFace Dataset object
        """
        # Load the raw data from disk.
        data = self._load_data()

        # Clean and normalize the raw texts.
        processed_data = self._process_data(data, **kwargs)

        # Build a Hugging Face Dataset from the processed texts.
        dataset = Dataset.from_dict(processed_data)

        # Tokenize only if a tokenizer was provided.
        if tokenizer is not None:
            dataset = self._tokenize_dataset(dataset, tokenizer)

        return dataset

    def _load_data(self) -> Dict[str, List[str]]:
        """
        Load data from file.

        Returns:
            Dictionary containing the loaded data
        """
        if self.data_path.suffix == ".csv":
            df = pd.read_csv(self.data_path)
            return {"text": df[self.text_column].tolist()}
        elif self.data_path.suffix == ".json":
            # Expects a JSON array of objects, each containing the text column.
            with open(self.data_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            return {"text": [item[self.text_column] for item in data]}
        else:
            raise ValueError(f"Unsupported file format: {self.data_path.suffix}")

    def _process_data(
        self,
        data: Dict[str, List[str]],
        **kwargs
    ) -> Dict[str, List[str]]:
        """
        Process the loaded data.

        Args:
            data: Dictionary containing the loaded data
            **kwargs: Additional processing parameters

        Returns:
            Processed data dictionary
        """
        texts = data["text"]
        processed_texts = []

        for text in texts:
            # Collapse runs of whitespace (including newlines) into single spaces.
            text = " ".join(text.split())
            processed_texts.append(text)

        return {"text": processed_texts}

    def _tokenize_dataset(
        self,
        dataset: Dataset,
        tokenizer: PreTrainedTokenizer
    ) -> Dataset:
        """
        Tokenize the dataset.

        Args:
            dataset: HuggingFace Dataset object
            tokenizer: Tokenizer for text processing

        Returns:
            Tokenized dataset
        """
        def tokenize_function(examples):
            # Return plain lists rather than tensors: Dataset.map stores the
            # output in Arrow, and the training framework can set the tensor
            # format later (e.g. dataset.set_format("torch")).
            return tokenizer(
                examples["text"],
                padding="max_length",
                truncation=True,
                max_length=self.max_length,
            )

        tokenized_dataset = dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=dataset.column_names
        )

        return tokenized_dataset

    def validate_data(self) -> bool:
        """
        Validate the training data.

        Returns:
            True if data is valid, False otherwise
        """
        try:
            data = self._load_data()

            if not data["text"]:
                logger.warning("No training examples found in the data")
                return False

            if len(data["text"]) < 1000:
                logger.warning(
                    f"Only {len(data['text'])} training examples found. "
                    "At least 1000 examples are recommended."
                )

            # Word count is only a rough proxy for token count, but it is a
            # cheap way to flag texts that are likely to be truncated.
            lengths = [len(text.split()) for text in data["text"]]
            avg_length = sum(lengths) / len(lengths)

            if avg_length > self.max_length:
                logger.warning(
                    f"Average text length ({avg_length:.1f} words) "
                    f"exceeds max_length ({self.max_length})"
                )

            return True

        except Exception as e:
            logger.error(f"Error validating data: {str(e)}")
            return False
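

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the data file path and the
    # tokenizer checkpoint below are assumptions, not part of this module.
    from transformers import AutoTokenizer

    logging.basicConfig(level=logging.INFO)

    processor = DataProcessor(Path("data/train.csv"), max_length=256)
    if processor.validate_data():
        tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
        dataset = processor.prepare_dataset(tokenizer=tokenizer)
        logger.info("Prepared %d tokenized examples", len(dataset))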