from __future__ import annotations import json import torch from load_aait86m import load_components def main() -> None: components = load_components() cfg = components["config"] model = components["anchor_model"] batch = 2 num_candidates = 3 semantic_dim = int(cfg["semantic_vector_dim"]) semantic_vector = torch.randn(batch, semantic_dim) recent_context_vector = torch.randn(batch, semantic_dim) active_anchor_state = torch.randn(batch, semantic_dim) modality_id = torch.tensor([0, 2], dtype=torch.long) time_delta = torch.tensor([0.0, 4.0], dtype=torch.float32) source_id = torch.tensor([0, 1], dtype=torch.long) candidate_semantic = torch.randn(batch, num_candidates, semantic_dim) candidate_features = torch.randn(batch, num_candidates, 7) candidate_mask = torch.tensor([[True, True, False], [True, True, True]], dtype=torch.bool) with torch.no_grad(): outputs = model( semantic_vector=semantic_vector, recent_context_vector=recent_context_vector, active_anchor_state=active_anchor_state, modality_id=modality_id, time_delta=time_delta, source_id=source_id, candidate_semantic=candidate_semantic, candidate_features=candidate_features, candidate_mask=candidate_mask, ) payload = { "base_checkpoint_keys": sorted(components["base_checkpoint"].keys()), "semantic_vector": list(semantic_vector.shape), "anchor_key": list(outputs["anchor_key"].shape), "anchor_action_logits": list(outputs["action_logits"].shape), "anchor_confidence": list(outputs["anchor_confidence"].shape), "salience_delta": list(outputs["salience_delta"].shape), "bind_logits": list(outputs["bind_logits"].shape), } print(json.dumps(payload, indent=2, sort_keys=True)) if __name__ == "__main__": main()