from __future__ import annotations

import random
from typing import Any, Dict, List, Sequence


def sample_rows(
    rows: Sequence[Dict[str, Any]],
    *,
    count: int,
    rng: random.Random,
    offset: int,
) -> List[Dict[str, Any]]:
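    """Return ``count`` shallow copies of rows drawn from ``rows``.

    If ``count`` does not exceed ``len(rows)``, rows are sampled without
    replacement using a seeded shuffle. Otherwise the rows are reused
    cyclically starting at ``offset``, so each row appears either
    ``count // len(rows)`` times or one more than that, and the result is
    shuffled to randomize the order.
    """
    total = len(rows)
    count = int(count)
    if total == 0 or count <= 0:
        return []
    if count <= total:
        # Sample without replacement: shuffle all indices, keep the first `count`.
        indices = list(range(total))
        rng.shuffle(indices)
        return [dict(rows[idx]) for idx in indices[:count]]
    # Oversample: walk the rows cyclically from `offset`, then shuffle.
    out: List[Dict[str, Any]] = [
        dict(rows[(int(offset) + item_idx) % total]) for item_idx in range(count)
    ]
    rng.shuffle(out)
    return out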
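

# Illustrative sketch, not part of the original module: `oversample_multiplicity`
# is a hypothetical helper that predicts how many copies of each row index the
# oversampling branch of `sample_rows` (count > len(rows) > 0) produces. It
# mirrors that branch's cyclic walk from `offset`; the final shuffle changes
# the order of rows but not how many times each one appears.
def oversample_multiplicity(total: int, count: int, offset: int) -> List[int]:
    """Hypothetical helper: copies per row index for a cyclic walk of length ``count``."""
    if total <= 0:
        raise ValueError("total must be positive")
    # Full cycles visit every index count // total times; the leftover
    # count % total steps land on the indices starting at offset.
    base, extra = divmod(count, total)
    start = offset % total
    multiplicity = [base] * total
    for step in range(extra):
        multiplicity[(start + step) % total] += 1
    return multiplicity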


def annotate_mixed_curriculum_row(
    row: Dict[str, Any],
    *,
    source_stage: int,
    target_stage: int,
    fraction: float,
) -> Dict[str, Any]:
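    """Return a copy of ``row`` whose metadata records its mixed-curriculum origin."""
    out = dict(row)
    # Copy the metadata dict so the caller's row is not mutated; a missing or
    # None "metadata" value is treated as empty.
    metadata = dict(out.get("metadata") or {})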
    metadata["mixed_curriculum_target_stage"] = int(target_stage)
    metadata["mixed_curriculum_source_stage"] = int(source_stage)
    metadata["mixed_curriculum_fraction"] = float(fraction)
    metadata["mixed_curriculum_runtime"] = True
    out["metadata"] = metadata
    return out


def build_two_stage_mixed_rows(
    stage1_rows: Sequence[Dict[str, Any]],
    stage2_rows: Sequence[Dict[str, Any]],
    *,
    stage1_ratio: float,
    stage2_ratio: float,
    seed: int,
    target_stage: int,
    total_rows: int = 0,
) -> List[Dict[str, Any]]:
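    """Mix stage-1 and stage-2 rows in proportion to the two ratios.

    Negative ratios are treated as zero and the remaining weights are
    normalized, so only their relative size matters. The target size is
    ``total_rows`` when positive, otherwise the size of the larger input
    (at least 1). If one stage has no rows, its share is omitted, so the
    result can be shorter than the target. Every row is annotated with its
    source stage and the combined list is shuffled with ``seed``.
    """
    weight1 = max(0.0, float(stage1_ratio))
    weight2 = max(0.0, float(stage2_ratio))
    weight_sum = weight1 + weight2
    if weight_sum <= 0.0:
        raise ValueError("At least one mixed curriculum ratio must be positive.")

    # Size of the mixed set: explicit `total_rows`, else the larger stage.
    if int(total_rows) > 0:
        target_count = int(total_rows)
    else:
        target_count = max(len(stage1_rows), len(stage2_rows))
    target_count = max(1, int(target_count))

    # Split the target between stages; stage 2 takes the rounding remainder.
    count1 = int(round(target_count * (weight1 / weight_sum)))
    count1 = min(max(count1, 0), target_count)
    count2 = int(target_count - count1)
    # One RNG drives both samples and the final shuffle, so a seed fully
    # determines the output for fixed inputs.
    rng = random.Random(int(seed))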

    mixed: List[Dict[str, Any]] = []
    for row in sample_rows(stage1_rows, count=count1, rng=rng, offset=1009):
        mixed.append(
            annotate_mixed_curriculum_row(
                row,
                source_stage=1,
                target_stage=int(target_stage),
                fraction=(weight1 / weight_sum),
            )
        )
    for row in sample_rows(stage2_rows, count=count2, rng=rng, offset=2003):
        mixed.append(
            annotate_mixed_curriculum_row(
                row,
                source_stage=2,
                target_stage=int(target_stage),
                fraction=(weight2 / weight_sum),
            )
        )
    rng.shuffle(mixed)
    return mixed
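

# Illustrative sketch, not part of the original module: `split_stage_counts` is a
# hypothetical helper that repeats the per-stage count arithmetic of
# `build_two_stage_mixed_rows`, so a configuration can be checked without
# sampling any rows. Because Python's round() rounds halves to even, an exact
# 50/50 split of an odd target can favor either stage: 5 rows become 2 + 3,
# while 7 rows become 4 + 3.
def split_stage_counts(
    target_count: int, stage1_ratio: float, stage2_ratio: float
) -> tuple[int, int]:
    """Hypothetical helper: (stage-1 count, stage-2 count) for a target size."""
    weight1 = max(0.0, float(stage1_ratio))
    weight2 = max(0.0, float(stage2_ratio))
    weight_sum = weight1 + weight2
    if weight_sum <= 0.0:
        raise ValueError("At least one mixed curriculum ratio must be positive.")
    # Same rounding and clamping as the original; stage 2 takes the remainder.
    count1 = int(round(target_count * (weight1 / weight_sum)))
    count1 = min(max(count1, 0), target_count)
    return count1, target_count - count1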


def training_stage_i_for_row(row: Dict[str, Any], default_stage_i: int) -> int:
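    """Return the training stage (at least 1) to use for ``row``.

    Precedence: the mixed-curriculum source stage recorded in the row's
    metadata, then the row's own ``stage_i``, then ``default_stage_i``.
    """
    metadata = dict(row.get("metadata") or {})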
    source_stage = metadata.get("mixed_curriculum_source_stage")
    if source_stage is not None:
        return max(1, int(source_stage))
    row_stage = row.get("stage_i")
    if row_stage is not None:
        return max(1, int(row_stage))
    return max(1, int(default_stage_i))
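

# Illustrative sketch, not part of the original module: `realized_source_fractions`
# is a hypothetical helper for sanity-checking a mixed set. The
# "mixed_curriculum_fraction" annotation records the configured weight share,
# not the share actually drawn; for example, a stage with no rows contributes
# nothing even when its ratio is positive. This helper measures the drawn share
# from the "mixed_curriculum_source_stage" metadata key.
def realized_source_fractions(rows: Sequence[Dict[str, Any]]) -> Dict[int, float]:
    """Hypothetical helper: fraction of annotated rows per source stage."""
    counts: Dict[int, int] = {}
    for row in rows:
        stage = (row.get("metadata") or {}).get("mixed_curriculum_source_stage")
        if stage is None:
            continue  # Row was not produced by the mixed curriculum.
        counts[int(stage)] = counts.get(int(stage), 0) + 1
    annotated = sum(counts.values())
    return {stage: n / annotated for stage, n in sorted(counts.items())}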