| """Load Amazon polarity reviews from Hugging Face datasets.""" |
|
|
| from __future__ import annotations |
|
|
| import random |
| from collections import Counter |
| from statistics import mean, median |
| from typing import Any, Iterable, Literal |
|
|
| from feedback_intelligence.types import ReviewRecord |
|
|
# Split names accepted by the loader's ``split`` parameter.
AmazonSplit = Literal["train", "test"]

# Dataset identifier passed to the loader function when the caller does not
# supply ``dataset_name``.
AMAZON_POLARITY_DATASET: str = "SetFit/amazon_polarity"

# Translates the integer ``label`` column into the label text stored on records.
AMAZON_LABEL_MAP: dict[int, str] = {0: "negative", 1: "positive"}
|
|
|
|
def load_amazon_polarity_reviews(
    split: AmazonSplit,
    sample_size: int | None = None,
    seed: int = 42,
    include_title: bool = True,
    dataset_name: str = AMAZON_POLARITY_DATASET,
    shuffle_buffer_size: int = 10_000,
    load_dataset_fn=None,
) -> list[ReviewRecord]:
    """Load Amazon polarity reviews, optionally as a balanced, seeded sample.

    With ``sample_size=None`` the whole split is loaded eagerly and converted
    in dataset order.  Otherwise the split is opened with ``streaming=True``,
    shuffled with ``seed`` when the returned object exposes ``shuffle``, and
    scanned until ``sample_size // 2`` rows of each label have been collected.
    The converted records are then shuffled once more with ``seed``.

    Args:
        split: Dataset split to read.
        sample_size: Total number of reviews wanted, or ``None`` for all of
            them.  Half of the sample is drawn per label, so an odd value
            yields ``sample_size - 1`` records.
        seed: Seed used for both the streaming shuffle and the final shuffle.
        include_title: Whether a review's title is prepended to its body.
        dataset_name: Identifier handed to the loader function; it is also
            stored as the ``source`` of every record.
        shuffle_buffer_size: ``buffer_size`` given to the streaming shuffle.
        load_dataset_fn: Substitute for ``datasets.load_dataset``.  When it is
            ``None`` that function is imported on demand.

    Returns:
        The list of converted review records.

    Raises:
        ValueError: If ``sample_size`` is below 2, if the stream ends before
            enough rows of each label were found, or if a row has no text.
    """
    loader = load_dataset_fn
    if loader is None:
        # Imported lazily so the dependency is only needed when no loader is injected.
        from datasets import load_dataset as loader

    conversion_options = {
        "split": split,
        "include_title": include_title,
        "source": dataset_name,
    }

    if sample_size is None:
        every_row = list(loader(dataset_name, split=split))
        return _rows_to_review_records(rows=every_row, **conversion_options)

    quota = sample_size // 2
    if quota <= 0:
        raise ValueError("sample_size must be at least 2 for balanced sampling.")

    stream = loader(dataset_name, split=split, streaming=True)
    if hasattr(stream, "shuffle"):
        stream = stream.shuffle(seed=seed, buffer_size=shuffle_buffer_size)

    picked = _sample_balanced_rows(dataset=stream, per_label=quota)
    if len(picked) < 2 * quota:
        raise ValueError(
            f"Unable to collect {quota} rows per label from {dataset_name} split={split}."
        )

    records = _rows_to_review_records(rows=picked, **conversion_options)
    # Mix the labels in a reproducible order for the given seed.
    random.Random(seed).shuffle(records)
    return records
|
|
|
|
def summarize_reviews(records: list[ReviewRecord]) -> dict[str, object]:
    """Build a small summary of ``records`` for CLI inspection.

    The result has three keys: ``"rows"`` (number of records),
    ``"label_distribution"`` (label -> count) and ``"word_count"`` (``min``,
    ``median``, ``mean`` and ``max`` of each record's ``word_count``).  The
    median is truncated to an int and the mean is rounded to two decimals.
    Empty input produces zeros and an empty distribution.
    """
    if records:
        lengths = [item.word_count for item in records]
        label_tally = Counter(item.label for item in records)
        length_stats = {
            "min": min(lengths),
            "median": int(median(lengths)),
            "mean": round(mean(lengths), 2),
            "max": max(lengths),
        }
        return {
            "rows": len(records),
            "label_distribution": dict(label_tally),
            "word_count": length_stats,
        }

    # Nothing to aggregate: the statistics functions would raise on empty data.
    return {
        "rows": 0,
        "label_distribution": {},
        "word_count": {"min": 0, "median": 0, "mean": 0, "max": 0},
    }
|
|
|
|
def _sample_balanced_rows(
    dataset: Iterable[dict[str, Any]],
    per_label: int,
) -> list[dict[str, Any]]:
    """Collect the first ``per_label`` rows of each binary label from ``dataset``.

    Rows are visited in iteration order.  A row is kept (as a shallow ``dict``
    copy) while the quota for its label is still open; rows whose ``label`` is
    neither 0 nor 1 after ``int()`` conversion are skipped.  Iteration stops as
    soon as both quotas are met, so the rest of the iterable is left unread.

    Args:
        dataset: Iterable of row mappings, each with a ``"label"`` entry.
        per_label: Number of rows wanted for each of the labels 0 and 1.

    Returns:
        The kept rows in the order they were encountered.  The list is shorter
        than ``2 * per_label`` when the iterable runs out first, and empty when
        ``per_label`` is not positive.
    """
    if per_label <= 0:
        # Nothing to collect.  Without this guard every row would be skipped as
        # "quota already full" and the stop check, which only runs after a row
        # is kept, would never fire -- draining the entire stream for no result.
        return []

    selected_rows: list[dict[str, Any]] = []
    counts = {0: 0, 1: 0}

    for row in dataset:
        raw_label = int(row["label"])
        if raw_label not in counts or counts[raw_label] >= per_label:
            continue
        selected_rows.append(dict(row))
        counts[raw_label] += 1
        if all(count >= per_label for count in counts.values()):
            break

    return selected_rows
|
|
|
|
def _rows_to_review_records(
    rows: list[dict[str, Any]],
    split: str,
    include_title: bool,
    source: str,
) -> list[ReviewRecord]:
    """Convert raw rows into review records, preserving their order.

    Each row's position in ``rows`` is passed on as its ``row_index``; the
    remaining arguments are forwarded unchanged for every row.
    """
    converted: list[ReviewRecord] = []
    for position, raw_row in enumerate(rows):
        record = _review_record_from_row(
            row=raw_row,
            row_index=position,
            split=split,
            include_title=include_title,
            source=source,
        )
        converted.append(record)
    return converted
|
|
|
|
def _review_record_from_row(
    row: dict[str, Any],
    row_index: int,
    split: str,
    include_title: bool,
    source: str,
) -> ReviewRecord:
    """Build one ``ReviewRecord`` from a raw dataset row.

    Args:
        row: Row mapping.  ``"label"`` is required; ``"review_id"``/``"id"``,
            ``"title"`` and ``"content"``/``"text"`` are read when present.
        row_index: Position used to synthesise an id when the row has none.
        split: Split name stored on the record and used in the synthetic id.
        include_title: Whether the title is prepended to the review body.
        source: Value stored as the record's ``source``.

    Returns:
        The populated record.

    Raises:
        KeyError: If ``"label"`` is missing or is not a key of
            ``AMAZON_LABEL_MAP``.
        ValueError: If the row yields no text after stripping whitespace.
    """
    label = AMAZON_LABEL_MAP[int(row["label"])]

    # Test for None / "" explicitly rather than with ``or`` so that a valid but
    # falsy identifier such as 0 is kept instead of being replaced.
    review_id = f"{split}-{row_index}"
    for id_key in ("review_id", "id"):
        candidate = row.get(id_key)
        if candidate is not None and candidate != "":
            review_id = str(candidate)
            break

    # A title that is present but None must count as "no title"; str(None)
    # would otherwise put the literal text "None" in front of the review body.
    raw_title = row.get("title")
    title = "" if raw_title is None else str(raw_title).strip()
    body = str(row.get("content") or row.get("text") or "").strip()

    text = _combine_title_and_body(title=title, body=body, include_title=include_title)
    if not text:
        raise ValueError(f"Amazon review row {review_id} does not contain text.")

    return ReviewRecord(
        review_id=review_id,
        text=text,
        label=label,
        split=split,
        source=source,
    )
|
|
|
|
def _combine_title_and_body(title: str, body: str, include_title: bool) -> str:
    """Return the review text, with the title in front when requested.

    The title is used only if ``include_title`` is set and the title is
    non-empty; it is separated from a non-empty body by one blank line.  In
    every other case the body alone is returned.
    """
    if not (include_title and title):
        return body
    return f"{title}\n\n{body}" if body else title
|
|