| | """ |
| | AIFinder Data Loader |
| | Downloads and parses HuggingFace datasets, extracts assistant responses, |
| | and labels them with is_ai, provider, and model. |
| | """ |
| |
|
| | import re |
| | import time |
| | from datasets import load_dataset |
| | from tqdm import tqdm |
| |
|
| | from config import ( |
| | DATASET_REGISTRY, |
| | DEEPSEEK_AM_DATASETS, |
| | ) |
| |
|
| |
|
def _parse_msg(msg):
    """Normalize one conversation message into a dict.

    A dict is returned unchanged (same object).  A string is decoded as JSON
    and returned only when the decoded value is itself a dict.  Every other
    case -- undecodable text, JSON that is not an object, or an input that is
    neither dict nor str -- yields an empty dict.
    """
    if isinstance(msg, dict):
        return msg
    if not isinstance(msg, str):
        return {}

    import json

    try:
        decoded = json.loads(msg)
    except (json.JSONDecodeError, ValueError):
        return {}
    return decoded if isinstance(decoded, dict) else {}
| |
|
| |
|
def _extract_assistant_texts_from_conversations(rows):
    """Collect the assistant-side text of each conversation row.

    Each row is looked up under 'conversations' first and 'messages' second;
    the first one that is neither None nor empty is used.  Every message is
    passed through _parse_msg, so entries may be dicts or JSON strings.
    Messages whose role is 'assistant', 'gpt' or 'model' and whose content is
    truthy are joined with a blank line into one string per row.  Rows that
    contribute no assistant content are omitted from the result.
    """
    assistant_roles = ("assistant", "gpt", "model")

    def _is_absent(value):
        # len() is used instead of truthiness so that array-like containers
        # (whose bool() may be ambiguous) can be tested safely.
        return value is None or (hasattr(value, "__len__") and len(value) == 0)

    collected = []
    for record in rows:
        turns = record.get("conversations")
        if _is_absent(turns):
            turns = record.get("messages")
            if _is_absent(turns):
                turns = []

        replies = []
        for turn in turns:
            fields = _parse_msg(turn)
            speaker = fields.get("role", "")
            body = fields.get("content", "")
            if speaker not in assistant_roles or not body:
                continue
            replies.append(body)

        if replies:
            collected.append("\n\n".join(replies))
    return collected
| |
|
| |
|
def _extract_from_am_dataset(row):
    """Return the assistant text of one a-m-team style row.

    The message list is read from 'messages', falling back to 'conversations'
    and then to an empty list.  Only dict entries whose role is exactly
    'assistant' and whose content is truthy are kept; they are joined with a
    blank line.  An empty string is returned when nothing qualifies.
    """
    turns = row.get("messages") or row.get("conversations") or []
    role_and_body = (
        (turn.get("role", ""), turn.get("content", ""))
        for turn in turns
        if isinstance(turn, dict)
    )
    replies = [
        body for speaker, body in role_and_body if speaker == "assistant" and body
    ]
    return "\n\n".join(replies)
| |
|
| |
|
def load_teichai_dataset(dataset_id, provider, model_name, kwargs):
    """Load a single conversation-format dataset and return (texts, providers, models).

    Args:
        dataset_id: HuggingFace dataset identifier passed to load_dataset.
        provider: Label repeated once per returned text.
        model_name: Label repeated once per returned text.
        kwargs: Per-dataset options.  'name' selects the dataset config and
            'max_samples' caps the number of rows that are processed.

    Returns:
        Three parallel lists (texts, providers, models).  All three are empty
        when the dataset cannot be loaded or yields no text longer than 50
        characters.
    """
    max_samples = kwargs.get("max_samples")
    load_kwargs = {}
    if "name" in kwargs:
        load_kwargs["name"] = kwargs["name"]

    try:
        ds = load_dataset(dataset_id, split="train", **load_kwargs)
        rows = list(ds)
    except Exception as e:
        # Best-effort fallback: read the auto-converted parquet file directly.
        # Only the first shard (0.parquet) of the train split is fetched.
        try:
            import pandas as pd

            # Honor the requested config instead of always asking for
            # 'default', which does not exist for named-config datasets.
            config_name = kwargs.get("name") or "default"
            url = (
                f"https://huggingface.co/api/datasets/{dataset_id}"
                f"/parquet/{config_name}/train/0.parquet"
            )
            df = pd.read_parquet(url)
            rows = df.to_dict(orient="records")
        except Exception as e2:
            print(f" [SKIP] {dataset_id}: {e} / parquet fallback: {e2}")
            return [], [], []

    if max_samples and len(rows) > max_samples:
        import random

        # A private generator gives the same reproducible sample as seeding
        # the module-level one, without resetting global random state for
        # the rest of the program.
        rows = random.Random(42).sample(rows, max_samples)

    texts = _extract_assistant_texts_from_conversations(rows)

    # Very short responses are dropped.
    kept = [text for text in texts if len(text) > 50]
    if not kept:
        print(f" [SKIP] {dataset_id}: no valid texts extracted")
        return [], [], []

    return kept, [provider] * len(kept), [model_name] * len(kept)
| |
|
| |
|
def load_am_deepseek_dataset(dataset_id, provider, model_name, kwargs):
    """Load an a-m-team DeepSeek dataset and return (texts, providers, models).

    Args:
        dataset_id: HuggingFace dataset identifier passed to load_dataset.
        provider: Label repeated once per returned text.
        model_name: Label repeated once per returned text.
        kwargs: Per-dataset options.  'name' selects the dataset config and
            'max_samples' keeps only the first N rows.

    Returns:
        Three parallel lists (texts, providers, models).  All three are empty
        when neither the regular nor the streaming load succeeds.
    """
    from itertools import islice

    max_samples = kwargs.get("max_samples")
    # None means "no cap"; islice rejects negative stops, so non-positive
    # values are treated as uncapped as well.
    limit = max_samples if max_samples and max_samples > 0 else None
    load_kwargs = {}
    if "name" in kwargs:
        load_kwargs["name"] = kwargs["name"]

    try:
        ds = load_dataset(dataset_id, split="train", **load_kwargs)
    except Exception as e1:
        # Fall back to streaming, keeping the same config name so the
        # fallback targets the same data as the first attempt.
        try:
            ds = load_dataset(
                dataset_id, split="train", streaming=True, **load_kwargs
            )
            rows = list(islice(ds, limit))
        except Exception as e2:
            # Report both failures; the first one is usually the root cause.
            print(f" [SKIP] {dataset_id}: {e1} / streaming fallback: {e2}")
            return [], [], []
    else:
        # Take only the leading rows instead of materializing the whole
        # dataset and slicing afterwards.
        rows = list(islice(ds, limit))

    # Very short responses are dropped.
    texts = [
        text for text in map(_extract_from_am_dataset, rows) if len(text) > 50
    ]

    providers = [provider] * len(texts)
    models = [model_name] * len(texts)
    return texts, providers, models
| |
|
| |
|
def load_all_data():
    """Load every registered dataset and merge the results.

    Iterates DATASET_REGISTRY with load_teichai_dataset, then
    DEEPSEEK_AM_DATASETS with load_am_deepseek_dataset, printing per-dataset
    sample counts and timings followed by a per-provider summary.

    Returns:
        texts: list of str
        providers: list of str
        models: list of str
        is_ai: list of int, one 1 per text (every loaded sample is labeled AI)
    """
    from collections import Counter

    merged_texts, merged_providers, merged_models = [], [], []

    def _ingest(registry, progress_label, loader):
        # Run one loader over one registry, appending into the merged lists.
        for dataset_id, provider, model_name, kwargs in tqdm(
            registry, desc=progress_label
        ):
            started = time.time()
            texts, providers, models = loader(
                dataset_id, provider, model_name, kwargs
            )
            elapsed = time.time() - started
            merged_texts.extend(texts)
            merged_providers.extend(providers)
            merged_models.extend(models)
            print(f" {dataset_id}: {len(texts)} samples ({elapsed:.1f}s)")

    print("Loading TeichAI datasets...")
    _ingest(DATASET_REGISTRY, "TeichAI", load_teichai_dataset)

    print("\nLoading DeepSeek (a-m-team) datasets...")
    _ingest(DEEPSEEK_AM_DATASETS, "DeepSeek-AM", load_am_deepseek_dataset)

    is_ai = [1] * len(merged_texts)

    print(f"\n=== Total: {len(merged_texts)} samples ===")
    # Providers are listed from most to least samples.
    for provider_name, sample_count in Counter(merged_providers).most_common():
        print(f" {provider_name}: {sample_count}")

    return merged_texts, merged_providers, merged_models, is_ai
| |
|
| |
|
# Script entry point: run the full load so the per-dataset and per-provider
# summaries are printed.  The returned lists are bound at module level but
# not used further here.
if __name__ == "__main__":
    texts, providers, models, is_ai = load_all_data()
| |
|