| from __future__ import annotations | |
| import json | |
| import torch | |
| from load_aait86m import load_components | |
def main() -> None:
    """Run the anchor model once on random dummy inputs and print shapes as JSON.

    Loads the project components via ``load_components()``, builds a small
    synthetic batch whose vector width is ``cfg["semantic_vector_dim"]``, calls
    ``components["anchor_model"]`` under ``torch.no_grad()``, and prints a JSON
    report with the sorted keys of ``components["base_checkpoint"]`` plus the
    shapes of the semantic input and of selected model outputs. Only shapes are
    reported, so the random tensor values themselves do not matter.
    """
    components = load_components()
    cfg = components["config"]
    model = components["anchor_model"]
    # torch.no_grad() only disables autograd; dropout / batch-norm would still
    # run in training mode. Switch to inference mode when this is an nn.Module.
    if isinstance(model, torch.nn.Module):
        model.eval()

    # Per-sample metadata for the dummy batch. The batch size and the number of
    # candidates are derived from these literals so they cannot drift apart.
    modality_ids = [0, 2]
    time_deltas = [0.0, 4.0]
    source_ids = [0, 1]
    candidate_valid = [[True, True, False], [True, True, True]]
    batch = len(modality_ids)
    num_candidates = len(candidate_valid[0])
    if not (len(time_deltas) == len(source_ids) == len(candidate_valid) == batch):
        raise ValueError("per-sample metadata lists must all have length `batch`")
    if any(len(row) != num_candidates for row in candidate_valid):
        raise ValueError("every candidate mask row must have the same length")

    # NOTE(review): 7 is a hard-coded candidate feature width; presumably it
    # must match the model's expected candidate feature size — confirm.
    candidate_feature_dim = 7

    semantic_dim = int(cfg["semantic_vector_dim"])
    semantic_vector = torch.randn(batch, semantic_dim)
    recent_context_vector = torch.randn(batch, semantic_dim)
    active_anchor_state = torch.randn(batch, semantic_dim)
    modality_id = torch.tensor(modality_ids, dtype=torch.long)
    time_delta = torch.tensor(time_deltas, dtype=torch.float32)
    source_id = torch.tensor(source_ids, dtype=torch.long)
    candidate_semantic = torch.randn(batch, num_candidates, semantic_dim)
    candidate_features = torch.randn(batch, num_candidates, candidate_feature_dim)
    candidate_mask = torch.tensor(candidate_valid, dtype=torch.bool)

    with torch.no_grad():
        outputs = model(
            semantic_vector=semantic_vector,
            recent_context_vector=recent_context_vector,
            active_anchor_state=active_anchor_state,
            modality_id=modality_id,
            time_delta=time_delta,
            source_id=source_id,
            candidate_semantic=candidate_semantic,
            candidate_features=candidate_features,
            candidate_mask=candidate_mask,
        )

    payload = {
        "base_checkpoint_keys": sorted(components["base_checkpoint"].keys()),
        "semantic_vector": list(semantic_vector.shape),
        "anchor_key": list(outputs["anchor_key"].shape),
        # Reported under a prefixed name; the model's output key is "action_logits".
        "anchor_action_logits": list(outputs["action_logits"].shape),
        "anchor_confidence": list(outputs["anchor_confidence"].shape),
        "salience_delta": list(outputs["salience_delta"].shape),
        "bind_logits": list(outputs["bind_logits"].shape),
    }
    print(json.dumps(payload, indent=2, sort_keys=True))
| if __name__ == "__main__": | |
| main() | |