import * as ort from "onnxruntime-web";
import { ModelInferencer } from "./inc/modelInferencer";

// Special token ids used by the model's character-level vocabulary.
const START_TOKEN = 1;
const END_TOKEN = 2;
const PAD_TOKEN = 0;

// Fixed input sequence length the model was trained with. Single source of
// truth for both the inferencer config and the tokenizer padding, so the two
// can no longer drift apart.
const SEQ_LEN = 256;

/**
 * Encode a TGN string as the model's fixed-length token sequence:
 * [START, ...UTF-16 char codes, END] padded with PAD_TOKEN to `seqLen`.
 *
 * @param tgn    TGN game text; each UTF-16 code unit becomes one token.
 * @param seqLen Target sequence length (defaults to the model's SEQ_LEN).
 * @returns Exactly `seqLen` token ids.
 * @throws Error if the encoded sequence would exceed `seqLen` — previously
 *         an over-long input was silently passed through at the wrong length.
 */
function tokenizeTgn(tgn: string, seqLen: number = SEQ_LEN): number[] {
    const tokens: number[] = [START_TOKEN];
    // charCodeAt over indices (not for..of) to keep UTF-16 unit semantics.
    for (let i = 0; i < tgn.length; i++) {
        tokens.push(tgn.charCodeAt(i));
    }
    tokens.push(END_TOKEN);
    if (tokens.length > seqLen) {
        throw new Error(
            `TGN too long: ${tokens.length} tokens exceed seqLen ${seqLen}`,
        );
    }
    while (tokens.length < seqLen) {
        tokens.push(PAD_TOKEN);
    }
    return tokens;
}

/**
 * Load the value-head ONNX model and print its value prediction for a few
 * hand-picked 5x5 board positions.
 */
async function main(): Promise<void> {
    const modelPath = "/home/camus/work/trigo/trigo-web/public/onnx/20251204-trigo-value-gpt2-l6-h64-251125-lr500/GPT2CausalLM_ep0019_evaluation.onnx";

    const inferencer = new ModelInferencer(ort.Tensor, { seqLen: SEQ_LEN });
    const session = await ort.InferenceSession.create(modelPath);
    inferencer.setSession(session);

    // Probe positions, from the empty board to a short mid-game line.
    const testPositions = [
        { tgn: "[Board 5x5]\n\n", desc: "Empty board" },
        { tgn: "[Board 5x5]\n\n1. Pass ", desc: "After Black Pass" },
        { tgn: "[Board 5x5]\n\n1. aa ", desc: "After Black aa" },
        { tgn: "[Board 5x5]\n\n1. Pass Pass\n", desc: "Both pass" },
        { tgn: "[Board 5x5]\n\n1. aa zz\n2. ", desc: "After aa zz" },
    ];

    // Predictions run sequentially on purpose: a single session shouldn't be
    // hit with overlapping runs, and ordered output is easier to read.
    for (const pos of testPositions) {
        const tokens = tokenizeTgn(pos.tgn);
        const value = await inferencer.runValuePrediction(tokens);
        console.log(pos.desc + ": value = " + value);
    }
}

main().catch(console.error);