# language-extractor-demo / sentence_sampling.py
# Author: DerivedFunction1 (commit 8a63f11: "update")
from __future__ import annotations
import random
from typing import Any, Callable
import pandas as pd
# Signature of the per-row converter: it receives one row of a sentence frame
# (as yielded by ``DataFrame.iterrows``) and returns that sentence's record as
# a plain dict.
RowToSentence = Callable[
    [pd.Series],
    dict[str, Any],
]
def _nonempty_frame(frame: pd.DataFrame, *, text_column: str = "text") -> pd.DataFrame:
if text_column not in frame.columns:
raise RuntimeError(f"Expected a {text_column!r} column in the cache frame.")
return frame[frame[text_column].astype(str).str.strip().ne("")]
def build_sentence_bundle(rows: pd.DataFrame, row_to_sentence: RowToSentence) -> dict[str, Any]:
    """Convert every row of *rows* into a sentence record and merge them into one bundle.

    Rows are converted in iteration order with *row_to_sentence*. The bundle
    starts as a shallow copy of the first record. It then gets (or has
    overridden) these keys:

    - ``text`` / ``raw_text``: the truthy ``text`` values of all records,
      joined with a blank line (``"\\n\\n"``);
    - ``sentences``: the list of all records;
    - ``lang_count``: the number of records (not the number of distinct languages);
    - ``langs`` / ``lang_iso3s``: each record's ``lang_iso2`` / ``lang_iso3``
      value, or ``""`` when the key is absent.

    Raises:
        RuntimeError: If *rows* yields no records.
    """
    records: list[dict[str, Any]] = []
    for _index, series in rows.iterrows():
        records.append(row_to_sentence(series))
    if len(records) == 0:
        raise RuntimeError("Unable to sample a sentence bundle.")

    # Records without text (missing or empty) are skipped when joining.
    joined = "\n\n".join([rec["text"] for rec in records if rec.get("text")])
    bundle = dict(records[0])
    bundle.update(
        text=joined,
        raw_text=joined,
        sentences=records,
        lang_count=len(records),
        langs=[rec.get("lang_iso2", "") for rec in records],
        lang_iso3s=[rec.get("lang_iso3", "") for rec in records],
    )
    return bundle
def sample_single_group_bundle(
    frame: pd.DataFrame,
    *,
    group_column: str,
    row_to_sentence: RowToSentence,
    attempts: int = 8,
    min_sentences: int = 1,
    max_sentences: int = 3,
    multi_sentence_probability: float = 0.55,
    text_column: str = "text",
    allowed_groups: set[str] | frozenset[str] | None = None,
) -> dict[str, Any]:
    """Sample a bundle of sentences that all come from one randomly chosen group.

    Rows whose text is blank are discarded first via ``_nonempty_frame``. If
    *allowed_groups* is given, rows whose ``group_column`` value is not in it
    are discarded too. Then, up to *attempts* times (at least once):

    1. A group is picked uniformly from the distinct labels in *group_column*.
       Missing and falsy labels such as ``""`` are never picked.
    2. Some of the group's rows are sampled and turned into a bundle with
       :func:`build_sentence_bundle`.

    The first bundle with non-empty ``text`` is returned.

    How many rows are sampled:

    - By default, ``min_sentences`` rows are sampled.
    - If the group has more than one row, then with probability
      *multi_sentence_probability* the size is instead drawn uniformly between
      ``max(min_sentences, min(2, max_sentences))`` and ``max_sentences``.
    - Both bounds are capped at the group's size, so the sample never exceeds
      the rows available.

    Args:
        frame: Candidate rows; must contain *text_column* and *group_column*.
        group_column: Column whose values define the groups.
        row_to_sentence: Converter from one row to a sentence record.
        attempts: Maximum number of draws before giving up (clamped to >= 1).
        min_sentences: Lower bound on sentences per bundle (clamped to >= 1).
        max_sentences: Upper bound on sentences per bundle (clamped to >= min_sentences).
        multi_sentence_probability: Chance of drawing a multi-sentence sample.
        text_column: Column holding the sentence text.
        allowed_groups: Optional whitelist of group labels.

    Returns:
        The bundle dict produced by :func:`build_sentence_bundle`.

    Raises:
        RuntimeError: If no usable group label exists, or no attempt produced
            a bundle with non-empty text.
    """
    candidate_frame = _nonempty_frame(frame, text_column=text_column)
    if allowed_groups is not None:
        candidate_frame = candidate_frame[candidate_frame[group_column].isin(allowed_groups)]
    distinct_groups = [value for value in candidate_frame[group_column].dropna().unique().tolist() if value]
    if not distinct_groups:
        raise RuntimeError(f"No usable values were found in {group_column!r}.")

    min_sentences = max(1, int(min_sentences))
    max_sentences = max(min_sentences, int(max_sentences))

    for _ in range(max(1, attempts)):
        group = random.choice(distinct_groups)
        group_rows = candidate_frame[candidate_frame[group_column] == group]
        if group_rows.empty:
            continue
        available = len(group_rows)
        sample_size = min_sentences
        if available > 1 and random.random() < multi_sentence_probability:
            upper = min(max_sentences, available)
            # A multi-sentence draw takes at least 2 rows (when max_sentences
            # allows), but never fewer than min_sentences. Clamping to `upper`
            # keeps randint's range non-empty for groups smaller than min_sentences.
            lower = min(max(min_sentences, min(2, max_sentences)), upper)
            sample_size = random.randint(lower, upper)
        rows = group_rows.sample(
            n=min(sample_size, available),
            # pandas samples with NumPy's RNG; seeding it from `random` makes the
            # whole draw reproducible under a single `random.seed()` call.
            random_state=random.randrange(2**32),
        )
        bundle = build_sentence_bundle(rows, row_to_sentence)
        if bundle["text"]:
            return bundle
    raise RuntimeError(f"Unable to sample a random bundle from {group_column!r}.")
def sample_multi_group_bundle(
    frame: pd.DataFrame,
    *,
    group_column: str,
    row_to_sentence: RowToSentence,
    min_groups: int = 2,
    max_groups: int = 3,
    min_sentences_per_group: int = 1,
    max_sentences_per_group: int = 2,
    text_column: str = "text",
    allowed_groups: set[str] | frozenset[str] | None = None,
) -> dict[str, Any]:
    """Sample sentences from several random groups and combine them into one bundle.

    Rows whose text is blank are discarded first via ``_nonempty_frame``. If
    *allowed_groups* is given, rows whose ``group_column`` value is not in it
    are discarded too. Missing and falsy group labels are never used. Then:

    1. A group count is drawn uniformly from ``[min_groups, max_groups]``.
    2. That many distinct groups are picked at random, or all of them if
       fewer exist.
    3. From each chosen group, a random number of rows between
       ``min_sentences_per_group`` and ``max_sentences_per_group`` is sampled,
       capped at the group's size. A group smaller than the minimum therefore
       contributes all of its rows.

    The sampled rows are kept together per group, in the shuffled group order,
    and passed to :func:`build_sentence_bundle`.

    Args:
        frame: Candidate rows; must contain *text_column* and *group_column*.
        group_column: Column whose values define the groups.
        row_to_sentence: Converter from one row to a sentence record.
        min_groups: Minimum number of groups (clamped to >= 1).
        max_groups: Maximum number of groups (clamped to >= min_groups).
        min_sentences_per_group: Minimum rows per group (clamped to >= 1).
        max_sentences_per_group: Maximum rows per group (clamped to >= the minimum).
        text_column: Column holding the sentence text.
        allowed_groups: Optional whitelist of group labels.

    Returns:
        The bundle dict produced by :func:`build_sentence_bundle`.

    Raises:
        RuntimeError: If no usable group label exists or no rows were sampled.
    """
    candidate_frame = _nonempty_frame(frame, text_column=text_column)
    if allowed_groups is not None:
        candidate_frame = candidate_frame[candidate_frame[group_column].isin(allowed_groups)]
    distinct_groups = [value for value in candidate_frame[group_column].dropna().unique().tolist() if value]
    if not distinct_groups:
        raise RuntimeError(f"No usable values were found in {group_column!r}.")

    min_groups = max(1, int(min_groups))
    max_groups = max(min_groups, int(max_groups))
    min_sentences_per_group = max(1, int(min_sentences_per_group))
    max_sentences_per_group = max(min_sentences_per_group, int(max_sentences_per_group))

    group_count = random.randint(min_groups, max_groups)
    random.shuffle(distinct_groups)
    # Slicing past the end simply yields every available group.
    chosen_groups = distinct_groups[:group_count]

    sampled_frames: list[pd.DataFrame] = []
    for group in chosen_groups:
        group_rows = candidate_frame[candidate_frame[group_column] == group]
        if group_rows.empty:
            continue
        upper = min(max_sentences_per_group, len(group_rows))
        # Clamp the lower bound: a group with fewer rows than
        # min_sentences_per_group would otherwise make randint() raise
        # ValueError on an empty range.
        lower = min(min_sentences_per_group, upper)
        row_count = random.randint(lower, upper)
        # pandas samples with NumPy's RNG; seeding it from `random` makes the
        # whole draw reproducible under a single `random.seed()` call.
        sampled_frames.append(group_rows.sample(n=row_count, random_state=random.randrange(2**32)))
    if not sampled_frames:
        raise RuntimeError(f"Unable to sample a multi-group bundle from {group_column!r}.")
    # Concatenate the sampled frames directly instead of rebuilding a frame
    # from row Series, so the original column dtypes are preserved.
    return build_sentence_bundle(pd.concat(sampled_frames), row_to_sentence)