| from __future__ import annotations |
|
|
| import random |
| from typing import Any, Dict, List, Sequence |
|
|
|
|
def sample_rows(
    rows: Sequence[Dict[str, Any]],
    *,
    count: int,
    rng: random.Random,
    offset: int,
) -> List[Dict[str, Any]]:
    """Draw ``count`` shallow-copied rows from ``rows`` in random order.

    Args:
        rows: Pool of row dicts to draw from. It is never mutated; every
            returned row is a fresh ``dict`` copy.
        count: Number of rows requested. Non-positive values yield ``[]``.
        rng: Random generator used for all shuffling.
        offset: Start position for the wrap-around walk that is used only
            when more rows are requested than the pool contains.

    Returns:
        ``count`` rows. When ``count`` fits in the pool the rows are drawn
        without replacement; otherwise the pool is cycled from ``offset``
        (so rows repeat) and the result is shuffled. An empty pool always
        yields ``[]``.
    """
    pool_size = len(rows)
    if pool_size == 0:
        return []
    wanted = int(count)
    if wanted <= 0:
        return []

    if wanted > pool_size:
        # Oversampling: walk the pool cyclically from ``offset`` so every
        # row is reused as evenly as possible, then randomise the order.
        start = int(offset)
        cycled = [
            dict(rows[(start + step) % pool_size]) for step in range(wanted)
        ]
        rng.shuffle(cycled)
        return cycled

    # Sampling without replacement: shuffle every index and keep a prefix.
    order = list(range(pool_size))
    rng.shuffle(order)
    return [dict(rows[position]) for position in order[:wanted]]
|
|
|
|
def annotate_mixed_curriculum_row(
    row: Dict[str, Any],
    *,
    source_stage: int,
    target_stage: int,
    fraction: float,
) -> Dict[str, Any]:
    """Return a copy of ``row`` tagged with mixed-curriculum metadata.

    The input row and its ``metadata`` mapping are left untouched; both are
    shallow-copied before the annotation keys are written.

    Args:
        row: Row to annotate. Its ``"metadata"`` entry may be missing or
            ``None``; both cases are treated as empty metadata.
        source_stage: Stage the row was drawn from.
        target_stage: Stage the mixed set is being built for.
        fraction: Mixing fraction recorded for the source stage.

    Returns:
        A new dict with the row's fields and a ``"metadata"`` dict holding
        the existing metadata plus the four ``mixed_curriculum_*`` keys.
    """
    existing = row.get("metadata")
    # ``.get("metadata", {})`` only covers a missing key; an explicit
    # ``"metadata": None`` would make ``dict(None)`` raise TypeError.
    metadata = dict(existing) if existing is not None else {}
    metadata["mixed_curriculum_target_stage"] = int(target_stage)
    metadata["mixed_curriculum_source_stage"] = int(source_stage)
    metadata["mixed_curriculum_fraction"] = float(fraction)
    metadata["mixed_curriculum_runtime"] = True

    annotated = dict(row)
    annotated["metadata"] = metadata
    return annotated
|
|
|
|
def build_two_stage_mixed_rows(
    stage1_rows: Sequence[Dict[str, Any]],
    stage2_rows: Sequence[Dict[str, Any]],
    *,
    stage1_ratio: float,
    stage2_ratio: float,
    seed: int,
    target_stage: int,
    total_rows: int = 0,
) -> List[Dict[str, Any]]:
    """Blend rows from two stages according to the given ratios.

    Args:
        stage1_rows: Pool of stage-1 rows.
        stage2_rows: Pool of stage-2 rows.
        stage1_ratio: Relative weight of stage 1; negatives count as 0.
        stage2_ratio: Relative weight of stage 2; negatives count as 0.
        seed: Seed for the private ``random.Random`` that drives sampling
            and the final shuffle.
        target_stage: Stage recorded in every row's metadata.
        total_rows: Desired output size. When not positive, the size of the
            larger pool is used (with a floor of 1).

    Returns:
        Annotated row copies from both stages in shuffled order. An empty
        pool contributes no rows, so the result can be shorter than the
        target size.

    Raises:
        ValueError: If neither ratio is positive.
    """
    weight1 = max(0.0, float(stage1_ratio))
    weight2 = max(0.0, float(stage2_ratio))
    weight_sum = weight1 + weight2
    if not weight_sum > 0.0:
        raise ValueError("At least one mixed curriculum ratio must be positive.")

    requested = int(total_rows)
    if requested > 0:
        target_count = requested
    else:
        target_count = max(1, len(stage1_rows), len(stage2_rows))

    share1 = weight1 / weight_sum
    share2 = weight2 / weight_sum

    # Stage 1 gets the rounded share (clamped to a valid count); stage 2
    # takes the remainder so the two quotas always add up to the target.
    quota1 = int(round(target_count * share1))
    quota1 = max(0, min(quota1, target_count))
    quota2 = target_count - quota1

    rng = random.Random(int(seed))

    # (source stage, pool, quota, wrap-around offset for oversampling, share)
    plan = (
        (1, stage1_rows, quota1, 1009, share1),
        (2, stage2_rows, quota2, 2003, share2),
    )

    mixed: List[Dict[str, Any]] = []
    for stage_id, pool, quota, start_offset, share in plan:
        # Stage 1 is sampled before stage 2 so the shared rng is consumed
        # in a fixed order and results stay reproducible for a given seed.
        picked = sample_rows(pool, count=quota, rng=rng, offset=start_offset)
        mixed.extend(
            annotate_mixed_curriculum_row(
                row,
                source_stage=stage_id,
                target_stage=int(target_stage),
                fraction=share,
            )
            for row in picked
        )

    rng.shuffle(mixed)
    return mixed
|
|
|
|
def training_stage_i_for_row(row: Dict[str, Any], default_stage_i: int) -> int:
    """Resolve the training stage index to use for ``row``.

    Lookup order: ``metadata["mixed_curriculum_source_stage"]``, then the
    row's own ``"stage_i"``, then ``default_stage_i``. A source that is
    missing or ``None`` is skipped.

    Args:
        row: Row to inspect. Its ``"metadata"`` entry may be missing or
            ``None``; both cases are treated as empty metadata.
        default_stage_i: Fallback when the row carries no stage information.

    Returns:
        The resolved stage as an ``int``, never less than 1.
    """
    raw_metadata = row.get("metadata")
    # ``.get("metadata", {})`` only covers a missing key; an explicit
    # ``"metadata": None`` would make ``dict(None)`` raise TypeError.
    metadata = dict(raw_metadata) if raw_metadata is not None else {}

    candidates = (
        metadata.get("mixed_curriculum_source_stage"),
        row.get("stage_i"),
    )
    for candidate in candidates:
        if candidate is not None:
            return max(1, int(candidate))
    return max(1, int(default_stage_i))
|
|