Richard CHEAM
Deploy customer feedback intelligence demo
73b0303
"""Load Amazon polarity reviews from Hugging Face datasets."""
from __future__ import annotations
import random
from collections import Counter
from statistics import mean, median
from typing import Any, Iterable, Literal
from feedback_intelligence.types import ReviewRecord
# Dataset split names accepted by the loader in this module.
AmazonSplit = Literal["train", "test"]
# Default Hugging Face dataset identifier used when the caller does not pass one.
AMAZON_POLARITY_DATASET = "SetFit/amazon_polarity"
# Translates the integer ``label`` column into the string label stored on records.
AMAZON_LABEL_MAP = {0: "negative", 1: "positive"}
def load_amazon_polarity_reviews(
    split: AmazonSplit,
    sample_size: int | None = None,
    seed: int = 42,
    include_title: bool = True,
    dataset_name: str = AMAZON_POLARITY_DATASET,
    shuffle_buffer_size: int = 10_000,
    load_dataset_fn=None,
) -> list[ReviewRecord]:
    """Load Amazon polarity reviews, optionally as a balanced sample.

    Args:
        split: Dataset split forwarded to the loader.
        sample_size: Total number of reviews wanted. ``None`` loads the whole
            split. Otherwise ``sample_size // 2`` rows are taken per label, so
            an odd value yields one record fewer than requested.
        seed: Seed for the streaming shuffle and for the final in-memory
            shuffle of the sampled records.
        include_title: Whether the review title is merged into the text.
        dataset_name: Dataset identifier; also stored as each record's source.
        shuffle_buffer_size: Buffer size handed to the stream's ``shuffle``.
        load_dataset_fn: Replacement for ``datasets.load_dataset``. When
            ``None`` the real function is imported lazily.

    Returns:
        The converted review records.

    Raises:
        ValueError: If ``sample_size`` is below 2, or if the stream ends
            before enough rows of each label were collected.
    """
    loader = load_dataset_fn
    if loader is None:
        # Imported lazily so the dependency is only needed when no loader is injected.
        from datasets import load_dataset as loader

    # Conversion arguments shared by the full-split and the sampled path.
    conversion = {
        "split": split,
        "include_title": include_title,
        "source": dataset_name,
    }

    if sample_size is None:
        every_row = list(loader(dataset_name, split=split))
        return _rows_to_review_records(rows=every_row, **conversion)

    per_label = sample_size // 2
    if per_label <= 0:
        raise ValueError("sample_size must be at least 2 for balanced sampling.")

    stream = loader(dataset_name, split=split, streaming=True)
    if hasattr(stream, "shuffle"):
        # Injected loaders may return plain iterables without a shuffle method.
        stream = stream.shuffle(seed=seed, buffer_size=shuffle_buffer_size)

    picked = _sample_balanced_rows(dataset=stream, per_label=per_label)
    if len(picked) < per_label * 2:
        raise ValueError(
            f"Unable to collect {per_label} rows per label from {dataset_name} split={split}."
        )

    sampled_records = _rows_to_review_records(rows=picked, **conversion)
    # Shuffle the collected records reproducibly with the caller's seed.
    random.Random(seed).shuffle(sampled_records)
    return sampled_records
def summarize_reviews(records: list[ReviewRecord]) -> dict[str, object]:
    """Build a small summary of a set of reviews for CLI inspection.

    Returns a dict with the row count (``rows``), the number of records per
    label (``label_distribution``) and min / median / mean / max of the
    records' ``word_count`` values (``word_count``). The median is truncated
    to an ``int`` and the mean is rounded to two decimals. An empty input
    yields zeros and an empty distribution.
    """
    if not records:
        zero_stats = {"min": 0, "median": 0, "mean": 0, "max": 0}
        return {"rows": 0, "label_distribution": {}, "word_count": zero_stats}

    lengths = [entry.word_count for entry in records]

    tally = Counter()
    for entry in records:
        tally[entry.label] += 1

    length_stats = {
        "min": min(lengths),
        "median": int(median(lengths)),
        "mean": round(mean(lengths), 2),
        "max": max(lengths),
    }
    return {
        "rows": len(records),
        "label_distribution": dict(tally),
        "word_count": length_stats,
    }
def _sample_balanced_rows(
    dataset: Iterable[dict[str, Any]],
    per_label: int,
    labels: Iterable[int] = (0, 1),
) -> list[dict[str, Any]]:
    """Collect up to ``per_label`` rows for each wanted label, in stream order.

    Args:
        dataset: Iterable of row mappings; each row must have a ``"label"``
            entry that ``int()`` accepts.
        per_label: Maximum number of rows to keep per label.
        labels: Integer labels to collect. Rows with any other label are
            skipped. Defaults to the two polarity labels ``0`` and ``1``.

    Returns:
        Shallow copies of the selected rows, in the order they were seen. The
        list is shorter than ``per_label * len(labels)`` when the iterable is
        exhausted first.
    """
    counts = {int(label): 0 for label in labels}
    if per_label <= 0 or not counts:
        # Nothing can be selected; return before touching the iterable so a
        # large stream is not consumed for no result.
        return []

    selected_rows: list[dict[str, Any]] = []
    unfilled = len(counts)  # labels that have not reached their quota yet
    for row in dataset:
        raw_label = int(row["label"])
        taken = counts.get(raw_label)
        if taken is None or taken >= per_label:
            # Unknown label, or this label's quota is already full.
            continue
        selected_rows.append(dict(row))
        counts[raw_label] = taken + 1
        if taken + 1 == per_label:
            unfilled -= 1
            if unfilled == 0:
                # Every quota is met; stop reading the stream.
                break
    return selected_rows
def _rows_to_review_records(
    rows: list[dict[str, Any]],
    split: str,
    include_title: bool,
    source: str,
) -> list[ReviewRecord]:
    """Convert raw rows into review records, one record per row.

    Each row's position in ``rows`` is passed on as its ``row_index``; the
    remaining arguments are forwarded unchanged to the per-row converter.
    """
    converted: list[ReviewRecord] = []
    for position, raw_row in enumerate(rows):
        record = _review_record_from_row(
            row=raw_row,
            row_index=position,
            split=split,
            include_title=include_title,
            source=source,
        )
        converted.append(record)
    return converted
def _review_record_from_row(
    row: dict[str, Any],
    row_index: int,
    split: str,
    include_title: bool,
    source: str,
) -> ReviewRecord:
    """Convert one raw dataset row into a ``ReviewRecord``.

    Args:
        row: Raw row mapping. Must contain ``"label"``. The identifier is
            read from ``"review_id"`` or ``"id"``, the body from ``"content"``
            or ``"text"``, and the title from ``"title"``.
        row_index: Position used to build a fallback identifier
            (``"<split>-<row_index>"``) when the row carries none.
        split: Split name stored on the record and used in the fallback id.
        include_title: Whether the title is merged into the record text.
        source: Source name stored on the record.

    Returns:
        The populated review record.

    Raises:
        KeyError: If the row has no ``"label"`` or the label is not a key of
            ``AMAZON_LABEL_MAP``.
        ValueError: If the row yields no text at all.
    """
    raw_label = int(row["label"])
    label = AMAZON_LABEL_MAP[raw_label]

    # Only a missing/None/empty identifier triggers the fallback; a plain
    # truthiness test would also throw away a legitimate id of 0.
    review_id = f"{split}-{row_index}"
    for id_key in ("review_id", "id"):
        candidate = row.get(id_key)
        if candidate is not None and candidate != "":
            review_id = str(candidate)
            break

    # A title that is present but None must not become the literal text "None".
    raw_title = row.get("title")
    title = "" if raw_title is None else str(raw_title).strip()
    body = str(row.get("content") or row.get("text") or "").strip()

    text = _combine_title_and_body(title=title, body=body, include_title=include_title)
    if not text:
        raise ValueError(f"Amazon review row {review_id} does not contain text.")
    return ReviewRecord(
        review_id=review_id,
        text=text,
        label=label,
        split=split,
        source=source,
    )
def _combine_title_and_body(title: str, body: str, include_title: bool) -> str:
    """Return the review text, with the title prepended when requested.

    The title is used only if ``include_title`` is set and the title is
    non-empty. It is separated from a non-empty body by a blank line; with an
    empty body the title alone is returned. Otherwise the body is returned
    unchanged.
    """
    use_title = include_title and title
    if not use_title:
        return body
    if not body:
        return title
    return f"{title}\n\n{body}"