# AIFinder / data_loader.py
"""
AIFinder Data Loader
Downloads and parses HuggingFace datasets, extracts assistant responses,
and labels them with is_ai, provider, and model.
"""
import re
import time
from datasets import load_dataset
from tqdm import tqdm
from config import (
DATASET_REGISTRY,
DEEPSEEK_AM_DATASETS,
)
def _parse_msg(msg):
"""Parse a message that may be a dict or a JSON string."""
if isinstance(msg, dict):
return msg
if isinstance(msg, str):
try:
import json
parsed = json.loads(msg)
if isinstance(parsed, dict):
return parsed
except (json.JSONDecodeError, ValueError):
pass
return {}
def _extract_assistant_texts_from_conversations(rows):
    """Pull assistant-authored text out of conversation-style rows.

    Each row is expected to carry a 'conversations' (or fallback
    'messages') column holding {role, content} dicts or JSON strings
    of them. Assistant turns within one row are joined by blank
    lines; rows with no assistant content are dropped.
    """
    def _turns(row):
        # First non-empty candidate column wins; None/empty fall through.
        for key in ("conversations", "messages"):
            value = row.get(key)
            if value is not None and not (hasattr(value, "__len__") and len(value) == 0):
                return value
        return []

    out = []
    for row in rows:
        assistant_parts = [
            parsed.get("content", "")
            for parsed in map(_parse_msg, _turns(row))
            if parsed.get("role", "") in ("assistant", "gpt", "model")
            and parsed.get("content", "")
        ]
        if assistant_parts:
            out.append("\n\n".join(assistant_parts))
    return out
def _extract_from_am_dataset(row):
"""Extract assistant text from a-m-team format (messages list with role/content)."""
messages = row.get("messages") or row.get("conversations") or []
parts = []
for msg in messages:
role = msg.get("role", "") if isinstance(msg, dict) else ""
content = msg.get("content", "") if isinstance(msg, dict) else ""
if role == "assistant" and content:
parts.append(content)
return "\n\n".join(parts) if parts else ""
def load_teichai_dataset(dataset_id, provider, model_name, kwargs):
    """Load a single conversation-format dataset.

    Args:
        dataset_id: HF dataset repo id.
        provider: provider label applied to every extracted text.
        model_name: model label applied to every extracted text.
        kwargs: dict with optional 'name' (HF config name) and
            'max_samples' (random downsample cap).

    Returns:
        (texts, providers, models) — three parallel lists; all empty
        when the dataset cannot be loaded or yields no usable text.
    """
    max_samples = kwargs.get("max_samples")
    load_kwargs = {}
    if "name" in kwargs:
        load_kwargs["name"] = kwargs["name"]
    try:
        ds = load_dataset(dataset_id, split="train", **load_kwargs)
        rows = list(ds)
    except Exception as e:
        # Fallback: load from auto-converted parquet via HF API
        try:
            import pandas as pd
            url = f"https://huggingface.co/api/datasets/{dataset_id}/parquet/default/train/0.parquet"
            rows = pd.read_parquet(url).to_dict(orient="records")
        except Exception as e2:
            print(f" [SKIP] {dataset_id}: {e} / parquet fallback: {e2}")
            return [], [], []
    if max_samples and len(rows) > max_samples:
        import random
        # Dedicated RNG: same seeded sample as before, but without
        # clobbering the process-wide random state for other callers.
        rows = random.Random(42).sample(rows, max_samples)
    texts = _extract_assistant_texts_from_conversations(rows)
    # Filter out empty/too-short texts
    labeled = [(t, provider, model_name) for t in texts if len(t) > 50]
    if not labeled:
        print(f" [SKIP] {dataset_id}: no valid texts extracted")
        return [], [], []
    t, p, m = zip(*labeled)
    return list(t), list(p), list(m)
def load_am_deepseek_dataset(dataset_id, provider, model_name, kwargs):
    """Load an a-m-team DeepSeek dataset.

    Args:
        dataset_id: HF dataset repo id.
        provider: provider label applied to every extracted text.
        model_name: model label applied to every extracted text.
        kwargs: dict with optional 'name' and 'max_samples'.

    Returns:
        (texts, providers, models) parallel lists; all empty on failure.
    """
    max_samples = kwargs.get("max_samples")
    load_kwargs = {}
    if "name" in kwargs:
        load_kwargs["name"] = kwargs["name"]
    try:
        ds = load_dataset(dataset_id, split="train", **load_kwargs)
    except Exception:
        # Retry without the name kwarg, in streaming mode, so we stop
        # pulling rows as soon as max_samples is reached.
        try:
            ds = load_dataset(dataset_id, split="train", streaming=True)
            rows = []
            for row in ds:
                rows.append(row)
                if max_samples and len(rows) >= max_samples:
                    break
        except Exception as e2:
            print(f" [SKIP] {dataset_id}: {e2}")
            return [], [], []
    else:
        rows = list(ds)
        if max_samples and len(rows) > max_samples:
            rows = rows[:max_samples]
    # Keep only non-trivial assistant texts (> 50 chars).
    texts = [t for t in map(_extract_from_am_dataset, rows) if len(t) > 50]
    providers = [provider] * len(texts)
    models = [model_name] * len(texts)
    return texts, providers, models
def load_all_data():
    """Load all datasets and return combined lists.

    Returns:
        texts: list of str
        providers: list of str
        models: list of str
        is_ai: list of int (1=AI, 0=Human)
    """
    all_texts = []
    all_providers = []
    all_models = []
    # TeichAI datasets
    print("Loading TeichAI datasets...")
    _ingest_registry(DATASET_REGISTRY, load_teichai_dataset, "TeichAI",
                     all_texts, all_providers, all_models)
    # DeepSeek a-m-team datasets
    print("\nLoading DeepSeek (a-m-team) datasets...")
    _ingest_registry(DEEPSEEK_AM_DATASETS, load_am_deepseek_dataset, "DeepSeek-AM",
                     all_texts, all_providers, all_models)
    # Build is_ai labels — every source here is model-generated.
    is_ai = [1] * len(all_texts)
    print(f"\n=== Total: {len(all_texts)} samples ===")
    # Print per-provider counts, most frequent first
    from collections import Counter
    prov_counts = Counter(all_providers)
    for p, c in sorted(prov_counts.items(), key=lambda x: -x[1]):
        print(f" {p}: {c}")
    return all_texts, all_providers, all_models, is_ai


def _ingest_registry(registry, loader, desc, all_texts, all_providers, all_models):
    """Run *loader* over each (dataset_id, provider, model, kwargs) entry
    in *registry*, extending the three accumulator lists in place and
    printing a per-dataset timing line."""
    for dataset_id, provider, model_name, kwargs in tqdm(registry, desc=desc):
        t0 = time.time()
        texts, providers, models = loader(
            dataset_id, provider, model_name, kwargs
        )
        elapsed = time.time() - t0
        all_texts.extend(texts)
        all_providers.extend(providers)
        all_models.extend(models)
        print(f" {dataset_id}: {len(texts)} samples ({elapsed:.1f}s)")
if __name__ == "__main__":
    # Script entry point: run the full load; the bound locals are handy
    # for interactive inspection (e.g. `python -i data_loader.py`).
    texts, providers, models, is_ai = load_all_data()