# curriculum-cot-code/mixed_curriculum_cot/runtime_mixed_curriculum.py
# Commit 76de008 (verified) by Avra98: Initial code dump (rebuttal-ready snapshot)
from __future__ import annotations
import random
from typing import Any, Dict, List, Sequence
def sample_rows(
    rows: Sequence[Dict[str, Any]],
    *,
    count: int,
    rng: random.Random,
    offset: int,
) -> List[Dict[str, Any]]:
    """Draw ``count`` rows from ``rows``, returning shallow ``dict`` copies.

    Args:
        rows: Pool of row dicts to draw from.
        count: Number of rows wanted (coerced with ``int``).
        rng: Random generator used for every shuffle, so the result is
            reproducible for a given generator state.
        offset: Starting index of the cyclic walk; only consulted when
            ``count`` exceeds ``len(rows)``.

    Returns:
        If ``count`` fits in the pool, rows taken without replacement from a
        shuffled index permutation. Otherwise the pool is walked cyclically
        from ``offset`` until ``count`` rows are gathered, and that list is
        shuffled. An empty pool or a non-positive ``count`` gives ``[]``.
    """
    pool_size = len(rows)
    if pool_size == 0:
        return []
    wanted = int(count)
    if wanted <= 0:
        return []

    if wanted > pool_size:
        # Oversampling: the cyclic walk makes each row appear either
        # floor(wanted / pool_size) or ceil(wanted / pool_size) times.
        start = int(offset)
        repeated = [dict(rows[(start + step) % pool_size]) for step in range(wanted)]
        rng.shuffle(repeated)
        return repeated

    # Without replacement: shuffle the whole index range and keep a prefix.
    permutation = list(range(pool_size))
    rng.shuffle(permutation)
    return [dict(rows[pos]) for pos in permutation[:wanted]]
def annotate_mixed_curriculum_row(
    row: Dict[str, Any],
    *,
    source_stage: int,
    target_stage: int,
    fraction: float,
) -> Dict[str, Any]:
    """Return a copy of ``row`` tagged with mixed-curriculum provenance metadata.

    The input row and its ``metadata`` dict are not mutated. Both are
    shallow-copied before the ``mixed_curriculum_*`` keys are written.

    Args:
        row: Row dict, optionally carrying a ``"metadata"`` mapping.
        source_stage: Stage the row was drawn from (stored as ``int``).
        target_stage: Stage the mixed dataset is built for (stored as ``int``).
        fraction: Normalized mixing weight of ``source_stage`` (stored as
            ``float``).

    Returns:
        A new dict equal to ``row`` whose ``"metadata"`` holds the existing
        metadata entries plus the four ``mixed_curriculum_*`` keys.
    """
    out = dict(row)
    # ``or {}`` also covers a "metadata" key that is present but None, which
    # ``dict(None)`` would reject with a TypeError.
    metadata = dict(out.get("metadata") or {})
    metadata["mixed_curriculum_target_stage"] = int(target_stage)
    metadata["mixed_curriculum_source_stage"] = int(source_stage)
    metadata["mixed_curriculum_fraction"] = float(fraction)
    metadata["mixed_curriculum_runtime"] = True
    out["metadata"] = metadata
    return out
def build_two_stage_mixed_rows(
    stage1_rows: Sequence[Dict[str, Any]],
    stage2_rows: Sequence[Dict[str, Any]],
    *,
    stage1_ratio: float,
    stage2_ratio: float,
    seed: int,
    target_stage: int,
    total_rows: int = 0,
) -> List[Dict[str, Any]]:
    """Build a shuffled mix of stage-1 and stage-2 rows.

    Negative ratios are clamped to zero, and the ratios are normalized to
    shares summing to 1. The target size is ``total_rows`` when positive.
    Otherwise it is the size of the larger stage pool, with a minimum of 1.
    The stage-1 count is the rounded stage-1 share of that size, and stage 2
    gets the remainder. Each stage is sampled with ``sample_rows`` from one
    RNG seeded by ``seed`` (stage 1 first). Every row is annotated with its
    source stage and normalized share, and the combined list is shuffled.

    NOTE(review): an empty stage pool returns no rows, and its share is not
    reassigned to the other stage, so the result can be shorter than the
    target size.

    Raises:
        ValueError: If neither ratio is positive.
    """
    w1 = max(0.0, float(stage1_ratio))
    w2 = max(0.0, float(stage2_ratio))
    denom = w1 + w2
    if denom <= 0.0:
        raise ValueError("At least one mixed curriculum ratio must be positive.")

    requested = int(total_rows)
    fallback = max(len(stage1_rows), len(stage2_rows))
    target_count = max(1, int(requested if requested > 0 else fallback))

    share1 = w1 / denom
    share2 = w2 / denom
    n1 = min(max(int(round(target_count * share1)), 0), target_count)
    n2 = int(target_count - n1)

    rng = random.Random(int(seed))
    # (source stage, pool, count, cyclic offset, share). Stage 1 is sampled
    # before stage 2 so the RNG is consumed in a fixed order.
    plan = (
        (1, stage1_rows, n1, 1009, share1),
        (2, stage2_rows, n2, 2003, share2),
    )
    mixed: List[Dict[str, Any]] = [
        annotate_mixed_curriculum_row(
            picked,
            source_stage=stage,
            target_stage=int(target_stage),
            fraction=share,
        )
        for stage, pool, n, off, share in plan
        for picked in sample_rows(pool, count=n, rng=rng, offset=off)
    ]
    rng.shuffle(mixed)
    return mixed
def training_stage_i_for_row(row: Dict[str, Any], default_stage_i: int) -> int:
    """Resolve which training stage index (>= 1) a row should be trained as.

    Precedence: the row's ``metadata["mixed_curriculum_source_stage"]``, then
    the row's own ``"stage_i"``, then ``default_stage_i``. The first value
    that is not None is coerced with ``int`` and clamped to at least 1.

    Args:
        row: Row dict, optionally carrying ``"metadata"`` and ``"stage_i"``.
        default_stage_i: Fallback stage used when the row names none.

    Returns:
        The resolved stage index, never smaller than 1.
    """
    # ``or {}`` also covers a "metadata" key that is present but None, which
    # ``dict(None)`` would reject with a TypeError.
    metadata = dict(row.get("metadata") or {})
    source_stage = metadata.get("mixed_curriculum_source_stage")
    if source_stage is not None:
        return max(1, int(source_stage))
    row_stage = row.get("stage_i")
    if row_stage is not None:
        return max(1, int(row_stage))
    return max(1, int(default_stage_i))