File size: 3,085 Bytes
fc93158
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import type { SkynetCausalValenceLabel } from "../causal-valence/episode-ledger.js";
import {
  encodeSkynetRuntimeTrajectoryFeatures,
  type SkynetRuntimeTrajectorySample,
} from "./trajectory-builder.js";

/**
 * Nearest-centroid model produced by `trainSkynetRuntimeObserverModel`.
 */
export type SkynetRuntimeObserverModel = {
  /** Labels that had at least one training sample; only these are scored at prediction time. */
  labels: SkynetCausalValenceLabel[];
  /**
   * Per-label mean of the encoded feature vectors seen during training.
   * Labels with no training samples hold an all-zero vector.
   */
  centroids: Record<SkynetCausalValenceLabel, number[]>;
  /** Number of training samples observed for each label (0 for labels never seen). */
  counts: Record<SkynetCausalValenceLabel, number>;
};

/**
 * Result of `predictSkynetRuntimeObserverLabel`.
 */
export type SkynetRuntimeObserverPrediction = {
  /** The best-scoring label, or `"stall"` when the model has no labels to score. */
  label: SkynetCausalValenceLabel;
  /**
   * Cosine similarity between the sample and each label centroid.
   * NOTE: only labels listed in the model's `labels` array receive an entry,
   * so other keys may be absent at runtime despite the `Record` type.
   */
  scores: Record<SkynetCausalValenceLabel, number>;
};

/**
 * Every label the observer model tracks. Training initialises per-label sums
 * and counts for each entry, and the trained model's `labels` array keeps
 * this relative order.
 */
const LABELS: SkynetCausalValenceLabel[] = ["progress", "relief", "stall", "frustration", "damage"];

/**
 * Creates a fresh numeric array of the requested size with every slot set to 0.
 *
 * @param length - Number of elements in the returned array.
 * @returns A new array containing `length` zeros.
 */
function zeroVector(length: number): number[] {
  const zeros = Array.from<number>({ length });
  return zeros.fill(0);
}

/**
 * Computes the cosine similarity between two numeric vectors.
 *
 * Vectors of different lengths are compared as if the shorter one were padded
 * with zeros, so the result does not depend on argument order.
 *
 * @param a - First vector.
 * @param b - Second vector.
 * @returns The cosine of the angle between `a` and `b`, or `0` when either
 *   vector has zero magnitude (including empty vectors).
 */
function cosineSimilarity(a: number[], b: number[]): number {
  // Walk the longer vector so trailing components of either input contribute
  // to its norm; components missing from the shorter vector count as 0.
  const length = Math.max(a.length, b.length);
  let dot = 0;
  let normA = 0;
  let normB = 0;
  for (let index = 0; index < length; index += 1) {
    const av = a[index] ?? 0;
    const bv = b[index] ?? 0;
    dot += av * bv;
    normA += av * av;
    normB += bv * bv;
  }
  // A zero-magnitude vector has no direction; return 0 rather than divide by zero.
  if (normA === 0 || normB === 0) {
    return 0;
  }
  return dot / (Math.sqrt(normA) * Math.sqrt(normB));
}

/**
 * Fits a nearest-centroid model from labelled trajectory samples.
 *
 * Each sample is encoded with `encodeSkynetRuntimeTrajectoryFeatures`, and the
 * encoded vectors are averaged per `targetLabel` to form that label's centroid.
 *
 * @param samples - Labelled trajectory samples to learn from.
 * @returns The trained model, or `null` when `samples` is empty. Labels with
 *   no samples keep a zero centroid and a count of 0, and are left out of the
 *   model's `labels` array.
 */
export function trainSkynetRuntimeObserverModel(
  samples: SkynetRuntimeTrajectorySample[],
): SkynetRuntimeObserverModel | null {
  if (samples.length === 0) {
    return null;
  }

  // The first sample's encoding fixes the dimensionality of every accumulator.
  const dimension = encodeSkynetRuntimeTrajectoryFeatures(samples[0]).length;

  const sums = {} as Record<SkynetCausalValenceLabel, number[]>;
  for (const label of LABELS) {
    sums[label] = zeroVector(dimension);
  }
  const counts = {} as Record<SkynetCausalValenceLabel, number>;
  for (const label of LABELS) {
    counts[label] = 0;
  }

  // Accumulate a per-label running total of the encoded features.
  for (const sample of samples) {
    const features = encodeSkynetRuntimeTrajectoryFeatures(sample);
    const target = sample.targetLabel;
    counts[target] += 1;
    for (let position = 0; position < features.length; position += 1) {
      sums[target][position] += features[position] ?? 0;
    }
  }

  // Turn totals into means, and remember which labels were actually observed.
  const centroids = {} as Record<SkynetCausalValenceLabel, number[]>;
  const observedLabels: SkynetCausalValenceLabel[] = [];
  for (const label of LABELS) {
    const seen = counts[label];
    if (seen > 0) {
      centroids[label] = sums[label].map((total) => total / seen);
      observedLabels.push(label);
    } else {
      centroids[label] = zeroVector(dimension);
    }
  }

  return { labels: observedLabels, centroids, counts };
}

/**
 * Classifies a trajectory sample by comparing it with each label centroid.
 *
 * The sample is encoded with `encodeSkynetRuntimeTrajectoryFeatures` and scored
 * against every label in `model.labels` using cosine similarity.
 *
 * @param model - Trained observer model supplying the labels and centroids.
 * @param sample - Trajectory sample to classify.
 * @returns The highest-scoring label together with the per-label scores. When
 *   the model has no labels, the label falls back to `"stall"` and `scores`
 *   is empty.
 */
export function predictSkynetRuntimeObserverLabel(
  model: SkynetRuntimeObserverModel,
  sample: SkynetRuntimeTrajectorySample,
): SkynetRuntimeObserverPrediction {
  const features = encodeSkynetRuntimeTrajectoryFeatures(sample);

  const scores = {} as Record<SkynetCausalValenceLabel, number>;
  model.labels.forEach((candidate) => {
    scores[candidate] = cosineSimilarity(features, model.centroids[candidate]);
  });

  // Labels without a score rank last.
  const scoreOf = (candidate: SkynetCausalValenceLabel): number =>
    scores[candidate] ?? Number.NEGATIVE_INFINITY;

  // Sort a copy in descending score order so the model's own array is untouched.
  const ranked = [...model.labels].sort((left, right) => scoreOf(right) - scoreOf(left));
  const [best] = ranked;

  return { label: best ?? "stall", scores };
}