|
|
|
|
|
""" |
|
|
Multi-Agent Dataset Loader |
|
|
|
|
|
This module provides comprehensive support for loading and processing multi-agent datasets |
|
|
with two supported patterns: |
|
|
A) Single folder with JSONLs that include an "agent" field |
|
|
B) Per-agent subfolders (agent name == folder name) |
|
|
|
|
|
Supports agent balancing, dataset validation, and integration with existing training pipelines. |
|
|
""" |
|
|
|
|
|
import os |
|
|
import json |
|
|
import yaml |
|
|
import logging |
|
|
from pathlib import Path |
|
|
from typing import Dict, List, Optional, Union, Tuple, Any |
|
|
from collections import Counter, defaultdict |
|
|
from dataclasses import dataclass |
|
|
|
|
|
import torch |
|
|
from datasets import load_dataset, DatasetDict, Dataset, concatenate_datasets |
|
|
from transformers import AutoTokenizer |
|
|
|
|
|
# Module-level logger named after this module; handlers and levels are left to
# the application (the __main__ demo below calls logging.basicConfig).
logger = logging.getLogger(__name__)
|
|
|
|
|
@dataclass
class MultiAgentDatasetConfig:
    """Configuration for multi-agent dataset loading.

    Consumed by ``MultiAgentDatasetLoader``; only ``dataset_path`` is required.
    """

    # Root directory of the dataset (single-folder or per-agent subfolder layout).
    dataset_path: str
    # Optional path to a YAML file whose top-level "agents" key lists agent names;
    # when unset or unreadable, agents are inferred from the data.
    agents_file: Optional[str] = None
    # Text placed before the agent name when building the per-sample prefix.
    agent_prefix: str = "<|agent:"
    # Text placed after the agent name when building the per-sample prefix.
    agent_suffix: str = "|>"
    # When True, the train split is balanced across agents by upsampling.
    balance_agents: bool = False
    # Optional upper bound on the per-agent upsampling target.
    balance_cap: Optional[int] = None
    # NOTE(review): not referenced by the loader code in this file — presumably
    # consumed by the training pipeline; confirm.
    max_seq_length: int = 2048
    # NOTE(review): not referenced by the loader code in this file — presumably
    # a held-out fraction used elsewhere; confirm.
    validation_split: float = 0.1
    # Seed passed to every dataset shuffle performed while balancing.
    seed: int = 42
|
|
|
|
|
class MultiAgentDatasetLoader:
    """
    Load, validate and pre-process multi-agent datasets.

    Two on-disk layouts are supported:

    1. Single folder: ``train.jsonl`` / ``test.jsonl`` whose rows carry an
       ``agent`` field.
    2. Per-agent subfolders: ``<agent>/train.jsonl`` / ``<agent>/test.jsonl``;
       the folder name becomes the value of the ``agent`` column.
    """

    def __init__(self, config: MultiAgentDatasetConfig):
        """Store the configuration and initialise empty result caches.

        Args:
            config: Loader settings (paths, prefix format, balancing options).
        """
        self.config = config
        # Filled in by resolve_agents_list().
        self.agents: List[str] = []
        # Filled in by validate_dataset().
        self.dataset_stats: Dict[str, Any] = {}

    @staticmethod
    def _parse_agents_yaml(yml_path: str) -> List[str]:
        """Parse a YAML file and return the entries under its ``agents`` key.

        Raises:
            OSError: The file cannot be opened or read.
            yaml.YAMLError: The file is not valid YAML.
            ValueError: The document is not a mapping, or ``agents`` is not a
                list (also covers undecodable bytes via UnicodeDecodeError).
        """
        with open(yml_path, "r", encoding="utf-8") as f:
            obj = yaml.safe_load(f) or {}
        if not isinstance(obj, dict):
            raise ValueError(
                f"expected a mapping at the top level, got {type(obj).__name__}"
            )
        # `or []` also covers an explicit `agents: null`.
        entries = obj.get("agents") or []
        if not isinstance(entries, (list, tuple)):
            raise ValueError(
                f"'agents' must be a list, got {type(entries).__name__}"
            )
        return [str(a) for a in entries]

    def read_agents_yaml(self, path: str) -> List[str]:
        """Read the agent list from ``<path>/agents.yaml``.

        Args:
            path: Directory expected to contain ``agents.yaml``.

        Returns:
            The agent names as strings, or an empty list when the file is
            missing or cannot be parsed (a warning is logged in that case).
        """
        yml_path = os.path.join(path, "agents.yaml")
        if not os.path.isfile(yml_path):
            return []
        try:
            agents = self._parse_agents_yaml(yml_path)
        except (OSError, ValueError, yaml.YAMLError) as e:
            logger.warning(f"Failed to read agents.yaml: {e}")
            return []
        logger.info(f"Loaded {len(agents)} agents from YAML: {agents}")
        return agents

    def list_agent_subdirs(self, path: str) -> List[Tuple[str, str]]:
        """List agent subdirectories that contain train.jsonl and/or test.jsonl.

        Args:
            path: Dataset root directory.

        Returns:
            ``(agent_name, subdir_path)`` tuples sorted by directory name;
            empty when ``path`` is not a directory.
        """
        items: List[Tuple[str, str]] = []
        if not os.path.isdir(path):
            return items

        # Sorted so the concatenation order of agents is deterministic.
        for name in sorted(os.listdir(path)):
            subdir_path = os.path.join(path, name)
            if not os.path.isdir(subdir_path):
                continue
            train_file = os.path.join(subdir_path, "train.jsonl")
            test_file = os.path.join(subdir_path, "test.jsonl")
            if os.path.isfile(train_file) or os.path.isfile(test_file):
                items.append((name, subdir_path))
                logger.debug(f"Found agent subdirectory: {name}")

        return items

    def load_single_folder_dataset(self, dataset_path: str) -> DatasetDict:
        """Load ``train.jsonl`` / ``test.jsonl`` located directly in a folder.

        Args:
            dataset_path: Directory holding the JSONL files.

        Returns:
            A ``DatasetDict`` with a "train" and/or "test" split.

        Raises:
            FileNotFoundError: Neither file exists in ``dataset_path``.
            ValueError: A loaded split has no ``agent`` column.
        """
        data_files = {}
        for split in ("train", "test"):
            split_file = os.path.join(dataset_path, f"{split}.jsonl")
            if os.path.isfile(split_file):
                data_files[split] = split_file

        if not data_files:
            raise FileNotFoundError(f"No dataset files found in {dataset_path}")

        logger.info(f"Loading single folder dataset from {data_files}")
        dataset = load_dataset("json", data_files=data_files)

        # In this layout the agent must come from the rows themselves.
        for split_name, split_data in dataset.items():
            if "agent" not in split_data.column_names:
                raise ValueError(f"Agent field not found in {split_name} split")

        return dataset

    @staticmethod
    def _load_agent_split(file_path: str, split: str, agent_name: str) -> Dataset:
        """Load one JSONL file and set its ``agent`` column to ``agent_name``."""

        def add_agent_field(example):
            # Overwrites any "agent" value already present in the row: in the
            # subfolder layout the folder name is authoritative.
            example["agent"] = agent_name
            return example

        split_data = load_dataset("json", data_files={split: file_path})[split]
        return split_data.map(add_agent_field)

    def load_subfolder_dataset(self, dataset_path: str) -> DatasetDict:
        """Load a dataset laid out as one subfolder per agent.

        Args:
            dataset_path: Root directory containing the agent subfolders.

        Returns:
            A ``DatasetDict`` whose "train"/"test" splits are the concatenation
            of every agent's corresponding file, each row tagged with ``agent``.

        Raises:
            FileNotFoundError: No agent subdirectory was found.
            ValueError: No split could be assembled.
        """
        subdirs = self.list_agent_subdirs(dataset_path)
        if not subdirs:
            raise FileNotFoundError(f"No agent subdirectories found in {dataset_path}")

        parts: Dict[str, List[Dataset]] = {"train": [], "test": []}
        for agent_name, agent_dir in subdirs:
            for split in ("train", "test"):
                split_file = os.path.join(agent_dir, f"{split}.jsonl")
                if os.path.isfile(split_file):
                    logger.debug(f"Loading {split} data for agent: {agent_name}")
                    parts[split].append(
                        self._load_agent_split(split_file, split, agent_name)
                    )

        dataset_dict = {
            split: concatenate_datasets(chunks)
            for split, chunks in parts.items()
            if chunks
        }
        if not dataset_dict:
            raise ValueError("No data splits found in agent subdirectories")

        return DatasetDict(dataset_dict)

    def load_multiagent_dataset(self) -> DatasetDict:
        """
        Load the dataset at ``config.dataset_path``, auto-detecting the layout.

        The single-folder layout is used when the root contains ``train.jsonl``,
        or when it contains only ``test.jsonl`` and no agent subfolders exist
        (previously that case failed even though the single-folder loader
        supports it). Everything else is treated as per-agent subfolders.

        Raises:
            FileNotFoundError: No usable files were found in either layout.
            ValueError: See the layout-specific loaders.
        """
        dataset_path = self.config.dataset_path
        has_train = os.path.isfile(os.path.join(dataset_path, "train.jsonl"))
        has_test = os.path.isfile(os.path.join(dataset_path, "test.jsonl"))

        if has_train or (has_test and not self.list_agent_subdirs(dataset_path)):
            logger.info("Loading dataset using single folder pattern")
            return self.load_single_folder_dataset(dataset_path)

        logger.info("Loading dataset using subfolder pattern")
        return self.load_subfolder_dataset(dataset_path)

    def infer_agents_from_dataset(self, dataset: DatasetDict) -> List[str]:
        """Collect the distinct non-None ``agent`` values across all splits.

        Args:
            dataset: Splits to scan; splits without an ``agent`` column are
                ignored.

        Returns:
            The sorted list of agent names.
        """
        agents = set()

        for split_name, split_data in dataset.items():
            if "agent" not in split_data.column_names:
                continue
            agent_values = [a for a in set(split_data["agent"]) if a is not None]
            agents.update(agent_values)
            logger.debug(f"Found agents in {split_name}: {agent_values}")

        agents_list = sorted(agents)
        logger.info(f"Inferred {len(agents_list)} agents from dataset: {agents_list}")
        return agents_list

    def resolve_agents_list(self, dataset: DatasetDict) -> List[str]:
        """Resolve the agent list from ``config.agents_file`` or from the data.

        The YAML file wins when it exists and yields a non-empty list; on any
        read/parse problem a warning is logged and the agents are inferred
        from the dataset instead. The result is also stored on ``self.agents``.

        Args:
            dataset: Loaded dataset used for the inference fallback.

        Returns:
            The resolved agent names.
        """
        agents: List[str] = []

        agents_file = self.config.agents_file
        if agents_file and os.path.isfile(agents_file):
            try:
                agents = self._parse_agents_yaml(agents_file)
                logger.info(f"Loaded agents from file: {agents}")
            except (OSError, ValueError, yaml.YAMLError) as e:
                logger.warning(f"Failed to load agents from file: {e}")

        if not agents:
            agents = self.infer_agents_from_dataset(dataset)

        self.agents = agents
        return agents

    def balance_by_agent(self, dataset: Dataset, agent_col: str = "agent") -> Dataset:
        """
        Balance a dataset by upsampling minority agents.

        The target per-agent count is the largest agent count, limited by
        ``config.balance_cap`` when that is set. Agents below the target are
        extended with whole copies of their rows plus a seeded random sample
        for the remainder. Agents already at or above the target are kept
        unchanged — the cap bounds upsampling only, it never downsamples.

        Args:
            dataset: Split to balance.
            agent_col: Name of the column holding the agent label.

        Returns:
            The balanced, shuffled dataset, or ``dataset`` itself when the
            column is missing or the dataset is empty.
        """
        if agent_col not in dataset.column_names:
            logger.warning(f"Agent column '{agent_col}' not found, skipping balancing")
            return dataset

        agent_column = dataset[agent_col]
        counts = Counter(agent_column)
        if not counts:
            logger.warning("No agent counts found, skipping balancing")
            return dataset

        max_count = max(counts.values())
        if self.config.balance_cap:
            max_count = min(max_count, self.config.balance_cap)

        logger.info(f"Balancing agents. Current counts: {dict(counts)}")
        logger.info(f"Target count per agent: {max_count}")

        # Group row indices per agent in a single pass instead of running one
        # full filter() over the dataset for every agent. Insertion order
        # (first occurrence) matches the iteration order of `counts`.
        indices_by_agent = defaultdict(list)
        for row_idx, agent in enumerate(agent_column):
            indices_by_agent[agent].append(row_idx)

        parts = []
        for agent, indices in indices_by_agent.items():
            agent_subset = dataset.select(indices)
            parts.append(agent_subset)

            needed = max_count - len(indices)
            if needed <= 0:
                continue

            # len(indices) >= 1 here: every key was created by an append.
            reps, remainder = divmod(needed, len(indices))
            parts.extend([agent_subset] * reps)
            if remainder:
                remainder_subset = agent_subset.shuffle(
                    seed=self.config.seed
                ).select(range(remainder))
                parts.append(remainder_subset)

        balanced_dataset = concatenate_datasets(parts).shuffle(seed=self.config.seed)

        final_counts = Counter(balanced_dataset[agent_col])
        logger.info(f"Balanced dataset counts: {dict(final_counts)}")

        return balanced_dataset

    def apply_agent_prefix(self, dataset: Dataset, tokenizer: AutoTokenizer) -> Dataset:
        """
        Build the ``text`` column and prepend the agent prefix to it.

        Per row, the text source is chosen in this order:

        1. ``messages`` rendered through ``tokenizer.apply_chat_template``
           (falls back to ``str(messages)`` if templating fails);
        2. an existing ``text`` value;
        3. ``prompt`` + newline + ``response`` (missing or None values are
           treated as empty strings).

        The prefix is ``agent_prefix + agent + agent_suffix + "\\n"``; rows
        with an empty/None agent get no prefix. Every column other than
        ``text`` and ``agent`` is dropped from the result.

        Args:
            dataset: Split to process.
            tokenizer: Tokenizer providing ``apply_chat_template``.

        Returns:
            The processed dataset.
        """
        agent_prefix = self.config.agent_prefix
        agent_suffix = self.config.agent_suffix

        def add_agent_prefix(example):
            agent = example.get("agent", None)
            prefix = f"{agent_prefix}{agent}{agent_suffix}\n" if agent else ""

            if example.get("messages") is not None:
                try:
                    text = tokenizer.apply_chat_template(
                        example["messages"],
                        tokenize=False,
                        add_generation_prompt=False,
                    )
                except Exception as e:
                    # Best-effort: tokenizers fail in many ways (no template,
                    # malformed messages); keep the sample rather than abort.
                    logger.warning(f"Failed to apply chat template: {e}")
                    text = str(example["messages"])
                example["text"] = prefix + text

            elif example.get("text") is not None:
                example["text"] = prefix + example["text"]

            else:
                # The column may exist but hold None for this row (columns are
                # shared by all rows), so `.get(key, "")` is not enough.
                prompt = example.get("prompt") or ""
                response = example.get("response") or ""
                example["text"] = prefix + prompt + ("\n" if response else "") + response

            return example

        features_to_remove = [
            f for f in dataset.features if f not in ("text", "agent")
        ]

        logger.info("Applying agent prefixes to dataset")
        processed_dataset = dataset.map(
            add_agent_prefix,
            remove_columns=features_to_remove,
            desc="Adding agent prefixes",
        )

        return processed_dataset

    def validate_dataset(self, dataset: DatasetDict) -> Dict[str, Any]:
        """Compute per-split and per-agent sample counts.

        Args:
            dataset: Splits to inspect.

        Returns:
            A dict with ``total_samples``, ``agents`` (agent -> total count),
            ``splits`` (split -> samples / agents / columns) and
            ``validation_errors`` (one entry per split lacking an ``agent``
            column). The same dict is stored on ``self.dataset_stats``.
        """
        stats: Dict[str, Any] = {
            "total_samples": 0,
            "agents": {},
            "splits": {},
            "validation_errors": [],
        }

        for split_name, split_data in dataset.items():
            num_samples = len(split_data)
            split_stats = {
                "samples": num_samples,
                "agents": {},
                "columns": split_data.column_names,
            }
            stats["total_samples"] += num_samples

            if "agent" in split_data.column_names:
                agent_counts = Counter(split_data["agent"])
                split_stats["agents"] = dict(agent_counts)
                for agent, count in agent_counts.items():
                    stats["agents"][agent] = stats["agents"].get(agent, 0) + count
            else:
                stats["validation_errors"].append(
                    f"Missing 'agent' column in {split_name}"
                )

            stats["splits"][split_name] = split_stats

        self.dataset_stats = stats
        logger.info(f"Dataset validation complete. Stats: {stats}")

        return stats

    def load_and_process(self, tokenizer: AutoTokenizer) -> Tuple[DatasetDict, List[str], Dict[str, Any]]:
        """
        Run the complete loading and processing pipeline.

        Steps, in order: load the dataset, resolve the agent list, compute
        statistics on the raw data, apply agent prefixes to the "train" and
        "test" splits, then (if ``config.balance_agents``) balance "train".
        The returned statistics therefore describe the data before balancing.

        Args:
            tokenizer: Tokenizer used to render chat templates.

        Returns:
            ``(dataset, agents, stats)``.
        """
        logger.info(f"Loading multi-agent dataset from {self.config.dataset_path}")

        dataset = self.load_multiagent_dataset()
        agents = self.resolve_agents_list(dataset)
        stats = self.validate_dataset(dataset)

        for split in ("train", "test"):
            if split in dataset:
                dataset[split] = self.apply_agent_prefix(dataset[split], tokenizer)

        if self.config.balance_agents and "train" in dataset:
            dataset["train"] = self.balance_by_agent(dataset["train"])

        logger.info(f"Dataset processing complete. Loaded {len(agents)} agents with {stats['total_samples']} total samples")

        return dataset, agents, stats
|
|
|
|
|
class MultiAgentDatasetValidator:
    """Static checks for multi-agent datasets stored as JSONL files on disk."""

    @staticmethod
    def validate_jsonl_file(file_path: str, require_agent: bool = True) -> List[str]:
        """Validate the format and content of one JSONL file.

        Blank lines are skipped. Every other line must be a JSON object that
        has at least one of ``text``, ``messages`` or ``prompt``, and — when
        ``require_agent`` is true — an ``agent`` field.

        Args:
            file_path: Path of the JSONL file.
            require_agent: Whether each row must carry an ``agent`` field.
                Pass False for the per-agent subfolder layout, where the
                agent is taken from the folder name rather than from the rows.

        Returns:
            Human-readable error messages; empty when the file is valid.
        """
        errors: List[str] = []

        if not os.path.isfile(file_path):
            errors.append(f"File not found: {file_path}")
            return errors

        try:
            # JSON text is UTF-8; do not depend on the platform default encoding.
            with open(file_path, "r", encoding="utf-8") as f:
                for line_num, line in enumerate(f, 1):
                    line = line.strip()
                    if not line:
                        continue

                    try:
                        data = json.loads(line)
                    except json.JSONDecodeError as e:
                        errors.append(f"Line {line_num}: JSON decode error - {e}")
                        continue

                    if not isinstance(data, dict):
                        errors.append(f"Line {line_num}: Not a JSON object")
                        continue

                    if require_agent and "agent" not in data:
                        errors.append(f"Line {line_num}: Missing 'agent' field")

                    if not any(field in data for field in ("text", "messages", "prompt")):
                        errors.append(f"Line {line_num}: No text content found")

        except (OSError, UnicodeError) as e:
            errors.append(f"File read error: {e}")

        return errors

    @staticmethod
    def validate_dataset_structure(dataset_path: str) -> Dict[str, Any]:
        """Validate the complete on-disk structure of a dataset.

        The layout is reported as "single_folder" when the root contains
        ``train.jsonl``, or only ``test.jsonl`` with no agent subfolders;
        otherwise as "subfolders". Each JSONL file found is validated with
        ``validate_jsonl_file`` (rows in agent subfolders are not required to
        carry an ``agent`` field). A directory with no data files at all is
        reported as an error.

        Args:
            dataset_path: Dataset root directory.

        Returns:
            A dict with ``valid`` (True when ``errors`` is empty), ``errors``,
            ``warnings`` and ``structure`` (``pattern``, ``files`` or
            ``agents``, and ``has_agents_yaml``).
        """
        validation_result: Dict[str, Any] = {
            "valid": True,
            "errors": [],
            "warnings": [],
            "structure": {},
        }
        errors = validation_result["errors"]
        structure = validation_result["structure"]

        if not os.path.isdir(dataset_path):
            validation_result["valid"] = False
            errors.append(f"Dataset path is not a directory: {dataset_path}")
            return validation_result

        root_files = [
            name
            for name in ("train.jsonl", "test.jsonl")
            if os.path.isfile(os.path.join(dataset_path, name))
        ]

        # Sorted so the reported agents and errors come out in a stable order.
        agent_dirs = [
            item
            for item in sorted(os.listdir(dataset_path))
            if os.path.isdir(os.path.join(dataset_path, item))
            and any(
                os.path.isfile(os.path.join(dataset_path, item, name))
                for name in ("train.jsonl", "test.jsonl")
            )
        ]

        if "train.jsonl" in root_files or (root_files and not agent_dirs):
            structure["pattern"] = "single_folder"
            structure["files"] = []
            for name in root_files:
                structure["files"].append(name)
                errors.extend(
                    MultiAgentDatasetValidator.validate_jsonl_file(
                        os.path.join(dataset_path, name)
                    )
                )
        else:
            structure["pattern"] = "subfolders"
            structure["agents"] = []
            if not agent_dirs:
                # Nothing loadable in either layout: must not be reported valid.
                errors.append(f"No agent subdirectories found in {dataset_path}")
            for item in agent_dirs:
                structure["agents"].append(item)
                for name in ("train.jsonl", "test.jsonl"):
                    agent_file = os.path.join(dataset_path, item, name)
                    if not os.path.isfile(agent_file):
                        continue
                    # The folder name is the agent, so rows need no 'agent' field.
                    file_errors = MultiAgentDatasetValidator.validate_jsonl_file(
                        agent_file, require_agent=False
                    )
                    errors.extend(f"{item}/{name}: {e}" for e in file_errors)

        agents_yaml = os.path.join(dataset_path, "agents.yaml")
        structure["has_agents_yaml"] = os.path.isfile(agents_yaml)
        if structure["has_agents_yaml"]:
            try:
                with open(agents_yaml, "r", encoding="utf-8") as f:
                    yaml.safe_load(f)
            except Exception as e:
                # A broken agents.yaml is only a warning: presumably the agent
                # list can still be inferred from the data — confirm with the
                # loader in use.
                validation_result["warnings"].append(f"Invalid agents.yaml: {e}")

        validation_result["valid"] = not errors

        return validation_result
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo entry point: load a dataset with balancing enabled and print a
    # short summary. The dataset path is a placeholder to be edited.
    logging.basicConfig(level=logging.INFO)

    demo_settings = {
        "dataset_path": "/path/to/dataset",
        "balance_agents": True,
        "balance_cap": 1000,
    }
    loader = MultiAgentDatasetLoader(MultiAgentDatasetConfig(**demo_settings))

    # The tokenizer is only needed to render chat templates.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-MoE-instruct")

    try:
        _, agent_names, summary = loader.load_and_process(tokenizer)

        # Header line followed by one bullet per agent.
        report = [f"Loaded dataset with {len(agent_names)} agents:"]
        report.extend(f" - {name}" for name in agent_names)
        print("\n".join(report))

        print(f"Dataset stats: {summary}")

    except Exception as e:
        print(f"Error loading dataset: {e}")
|
|
|