File size: 1,946 Bytes
9cf1c17
 
 
 
 
 
82e3c2b
9cf1c17
 
 
82e3c2b
 
 
9cf1c17
 
 
82e3c2b
9cf1c17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82e3c2b
 
 
 
 
 
 
9cf1c17
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from __future__ import annotations

import json

import torch

from load_aait86m import load_components


def main() -> None:
    """Smoke-test the anchor model with one forward pass on synthetic inputs.

    Loads the project components with ``load_components()``, builds a small
    random batch whose feature width comes from ``cfg["semantic_vector_dim"]``,
    runs ``components["anchor_model"]`` under ``torch.no_grad()``, and prints a
    JSON summary. The summary contains the sorted keys of
    ``components["base_checkpoint"]`` plus the shapes of the input semantic
    vector and of selected model outputs. Only tensor shapes are printed, never
    tensor values.

    Raises:
        ValueError: if the hand-written per-example tensors disagree on the
            batch size.
    """
    components = load_components()
    cfg = components["config"]
    model = components["anchor_model"]
    # no_grad() only disables autograd bookkeeping; eval() is what turns off
    # training-time behaviour (e.g. dropout) in an nn.Module. Guarded because
    # the model's type is not visible here.
    # NOTE(review): confirm anchor_model is always an nn.Module.
    if isinstance(model, torch.nn.Module):
        model.eval()

    semantic_dim = int(cfg["semantic_vector_dim"])
    # NOTE(review): width of each candidate's feature vector. It presumably
    # must match the model architecture; read it from cfg if such a key exists.
    candidate_feature_dim = 7

    # Hand-written per-example metadata, one entry per batch element.
    modality_id = torch.tensor([0, 2], dtype=torch.long)
    time_delta = torch.tensor([0.0, 4.0], dtype=torch.float32)
    source_id = torch.tensor([0, 1], dtype=torch.long)
    # The first example has its last candidate slot set to False in the mask.
    candidate_mask = torch.tensor(
        [[True, True, False], [True, True, True]], dtype=torch.bool
    )

    # Derive the sizes from the literals above instead of repeating them, so
    # adding or removing an example/candidate slot cannot desynchronise them.
    batch, num_candidates = candidate_mask.shape
    per_example = (
        ("modality_id", modality_id),
        ("time_delta", time_delta),
        ("source_id", source_id),
    )
    for tensor_name, tensor in per_example:
        if tensor.shape[0] != batch:
            raise ValueError(
                f"{tensor_name} has {tensor.shape[0]} entries, expected {batch}"
            )

    semantic_vector = torch.randn(batch, semantic_dim)
    recent_context_vector = torch.randn(batch, semantic_dim)
    active_anchor_state = torch.randn(batch, semantic_dim)
    candidate_semantic = torch.randn(batch, num_candidates, semantic_dim)
    candidate_features = torch.randn(batch, num_candidates, candidate_feature_dim)

    with torch.no_grad():
        outputs = model(
            semantic_vector=semantic_vector,
            recent_context_vector=recent_context_vector,
            active_anchor_state=active_anchor_state,
            modality_id=modality_id,
            time_delta=time_delta,
            source_id=source_id,
            candidate_semantic=candidate_semantic,
            candidate_features=candidate_features,
            candidate_mask=candidate_mask,
        )

    # Note: the JSON key "anchor_action_logits" reports outputs["action_logits"].
    payload = {
        "base_checkpoint_keys": sorted(components["base_checkpoint"].keys()),
        "semantic_vector": list(semantic_vector.shape),
        "anchor_key": list(outputs["anchor_key"].shape),
        "anchor_action_logits": list(outputs["action_logits"].shape),
        "anchor_confidence": list(outputs["anchor_confidence"].shape),
        "salience_delta": list(outputs["salience_delta"].shape),
        "bind_logits": list(outputs["bind_logits"].shape),
    }
    print(json.dumps(payload, indent=2, sort_keys=True))


# Run the smoke test only when executed as a script, not when imported.
if __name__ == "__main__":
    main()