"""Load Amazon polarity reviews from Hugging Face datasets."""

from __future__ import annotations

import random
from collections import Counter
from statistics import mean, median
from typing import Any, Callable, Iterable, Literal

from feedback_intelligence.types import ReviewRecord

AmazonSplit = Literal["train", "test"]

AMAZON_POLARITY_DATASET = "SetFit/amazon_polarity"
AMAZON_LABEL_MAP = {0: "negative", 1: "positive"}


def load_amazon_polarity_reviews(
    split: AmazonSplit,
    sample_size: int | None = None,
    seed: int = 42,
    include_title: bool = True,
    dataset_name: str = AMAZON_POLARITY_DATASET,
    shuffle_buffer_size: int = 10_000,
    load_dataset_fn: Callable[..., Any] | None = None,
) -> list[ReviewRecord]:
    """Load Amazon polarity reviews with deterministic balanced sampling.

    Args:
        split: Dataset split to load.
        sample_size: Total number of reviews to return. ``None`` loads the
            whole split without sampling. Otherwise ``sample_size // 2`` rows
            are collected per label, so an odd value yields one record fewer
            than requested.
        seed: Seed passed to the dataset shuffle (when available) and used
            for the final in-memory shuffle of the sampled records.
        include_title: Whether a row's title is combined into the record text.
        dataset_name: Dataset identifier handed to ``load_dataset_fn``; also
            stored as the ``source`` of every record.
        shuffle_buffer_size: ``buffer_size`` passed to the dataset's
            ``shuffle`` method in the sampling path.
        load_dataset_fn: Loader called as ``load_dataset_fn(name, split=...)``
            (plus ``streaming=True`` when sampling). Defaults to
            ``datasets.load_dataset``, imported lazily so the dependency is
            only required when no loader is injected.

    Returns:
        The converted review records.

    Raises:
        ValueError: If ``sample_size`` is below 2, if the dataset does not
            yield enough rows of each label, or if a row contains no text.
    """
    if load_dataset_fn is None:
        from datasets import load_dataset as load_dataset_fn

    if sample_size is None:
        dataset = load_dataset_fn(dataset_name, split=split)
        rows = list(dataset)
        return _rows_to_review_records(
            rows=rows,
            split=split,
            include_title=include_title,
            source=dataset_name,
        )

    per_label = sample_size // 2
    if per_label <= 0:
        raise ValueError("sample_size must be at least 2 for balanced sampling.")

    dataset = load_dataset_fn(dataset_name, split=split, streaming=True)
    # Injected loaders may return a plain iterable without a shuffle method.
    if hasattr(dataset, "shuffle"):
        dataset = dataset.shuffle(seed=seed, buffer_size=shuffle_buffer_size)

    sampled_rows = _sample_balanced_rows(dataset=dataset, per_label=per_label)
    if len(sampled_rows) < per_label * 2:
        raise ValueError(
            f"Unable to collect {per_label} rows per label from {dataset_name} split={split}."
        )

    records = _rows_to_review_records(
        rows=sampled_rows,
        split=split,
        include_title=include_title,
        source=dataset_name,
    )
    # Mix the labels with a seeded RNG so the returned order is reproducible.
    rng = random.Random(seed)
    rng.shuffle(records)
    return records


def summarize_reviews(records: list[ReviewRecord]) -> dict[str, object]:
    """Compute a compact dataset summary for CLI inspection.

    Args:
        records: Records to summarize; their ``word_count`` and ``label``
            attributes are read.

    Returns:
        A dict with ``rows``, ``label_distribution`` and ``word_count``
        (``min``, ``median``, ``mean``, ``max``). The median is truncated to
        an int and the mean is rounded to two decimals. An empty input
        produces an all-zero summary.
    """
    if not records:
        return {
            "rows": 0,
            "label_distribution": {},
            "word_count": {"min": 0, "median": 0, "mean": 0, "max": 0},
        }

    word_counts = [record.word_count for record in records]
    labels = Counter(record.label for record in records)
    return {
        "rows": len(records),
        "label_distribution": dict(labels),
        "word_count": {
            "min": min(word_counts),
            "median": int(median(word_counts)),
            "mean": round(mean(word_counts), 2),
            "max": max(word_counts),
        },
    }


def _sample_balanced_rows(
    dataset: Iterable[dict[str, Any]],
    per_label: int,
) -> list[dict[str, Any]]:
    """Collect up to ``per_label`` rows for each of the labels 0 and 1.

    Rows are taken in iteration order. Rows with any other label, or whose
    label quota is already full, are skipped. Iteration stops as soon as both
    quotas are met, so a streaming dataset is not consumed further than
    needed. The result may be short if the dataset runs out first.
    """
    selected_rows: list[dict[str, Any]] = []
    counts = {0: 0, 1: 0}
    for row in dataset:
        raw_label = int(row["label"])
        if raw_label not in counts:
            continue
        if counts[raw_label] >= per_label:
            continue
        selected_rows.append(dict(row))
        counts[raw_label] += 1
        if counts[0] >= per_label and counts[1] >= per_label:
            break
    return selected_rows


def _rows_to_review_records(
    rows: list[dict[str, Any]],
    split: str,
    include_title: bool,
    source: str,
) -> list[ReviewRecord]:
    """Convert raw rows to records, passing each row's list position as its index."""
    return [
        _review_record_from_row(
            row=row,
            row_index=index,
            split=split,
            include_title=include_title,
            source=source,
        )
        for index, row in enumerate(rows)
    ]


def _first_present(row: dict[str, Any], keys: Iterable[str]) -> Any:
    """Return the first value among ``keys`` that is neither ``None`` nor ``""``.

    Falsy-but-meaningful values such as the integer ``0`` count as present.
    Returns ``None`` when no key has a usable value.
    """
    for key in keys:
        value = row.get(key)
        if value is not None and value != "":
            return value
    return None


def _review_record_from_row(
    row: dict[str, Any],
    row_index: int,
    split: str,
    include_title: bool,
    source: str,
) -> ReviewRecord:
    """Build one ``ReviewRecord`` from a raw dataset row.

    Args:
        row: Raw row. ``label`` is required; ``review_id``/``id``, ``title``
            and ``content``/``text`` are read when present.
        row_index: Position of the row, used for the fallback review id.
        split: Split name stored on the record and used in the fallback id.
        include_title: Whether the title is combined into the record text.
        source: Source identifier stored on the record.

    Returns:
        The populated record.

    Raises:
        KeyError: If ``label`` is missing or is not a key of
            ``AMAZON_LABEL_MAP``.
        ValueError: If the row yields no text.
    """
    raw_label = int(row["label"])
    label = AMAZON_LABEL_MAP[raw_label]

    # An id of 0 is a valid identifier; only None / "" fall back to the
    # positional id.
    explicit_id = _first_present(row, ("review_id", "id"))
    review_id = str(explicit_id) if explicit_id is not None else f"{split}-{row_index}"

    # `or ""` guards against a present-but-None title, which str() would
    # otherwise turn into the literal text "None".
    title = str(row.get("title") or "").strip()
    body = str(row.get("content") or row.get("text") or "").strip()
    text = _combine_title_and_body(title=title, body=body, include_title=include_title)
    if not text:
        raise ValueError(f"Amazon review row {review_id} does not contain text.")

    return ReviewRecord(
        review_id=review_id,
        text=text,
        label=label,
        split=split,
        source=source,
    )


def _combine_title_and_body(title: str, body: str, include_title: bool) -> str:
    """Join title and body into the review text.

    With ``include_title`` set, a non-empty title is placed before the body,
    separated by a blank line, or returned alone when the body is empty.
    Otherwise only the body is returned.
    """
    if include_title and title and body:
        return f"{title}\n\n{body}"
    if include_title and title:
        return title
    return body