AAIT-86M / example_inference.py
gcoderw's picture
Repackage AAIT-86M as combined TE-86M plus anchor head
82e3c2b verified
from __future__ import annotations
import json
import torch
from load_aait86m import load_components
def main() -> None:
    """Run the AAIT-86M anchor head once on random inputs and print output shapes.

    Loads the packaged components via ``load_components()`` and builds a small
    synthetic batch of ``batch`` rows, each with ``num_candidates`` candidates.
    It runs ``components["anchor_model"]`` without gradient tracking and
    prints a JSON summary. The summary holds the sorted base-checkpoint keys
    and the shapes of the input semantic vector and of each model output.

    The input tensors are random, so only the reported shapes are meaningful.
    """
    components = load_components()
    cfg = components["config"]
    model = components["anchor_model"]
    # torch.no_grad() below only disables autograd. eval() additionally
    # switches dropout / batch-norm style layers to inference behaviour.
    # It is a no-op if load_components() already did this.
    if isinstance(model, torch.nn.Module):
        model.eval()

    batch = 2
    num_candidates = 3
    semantic_dim = int(cfg["semantic_vector_dim"])
    # NOTE(review): width of each candidate's feature vector. It is presumably
    # fixed by the anchor head architecture; confirm against the model/config.
    candidate_feature_dim = 7

    semantic_vector = torch.randn(batch, semantic_dim)
    recent_context_vector = torch.randn(batch, semantic_dim)
    active_anchor_state = torch.randn(batch, semantic_dim)

    # Per-row metadata literals: each tensor must have exactly `batch` entries.
    modality_id = torch.tensor([0, 2], dtype=torch.long)
    time_delta = torch.tensor([0.0, 4.0], dtype=torch.float32)
    source_id = torch.tensor([0, 1], dtype=torch.long)

    candidate_semantic = torch.randn(batch, num_candidates, semantic_dim)
    candidate_features = torch.randn(batch, num_candidates, candidate_feature_dim)
    # Row 0 masks out its third candidate slot; row 1 keeps all three.
    candidate_mask = torch.tensor(
        [[True, True, False], [True, True, True]], dtype=torch.bool
    )

    with torch.no_grad():
        outputs = model(
            semantic_vector=semantic_vector,
            recent_context_vector=recent_context_vector,
            active_anchor_state=active_anchor_state,
            modality_id=modality_id,
            time_delta=time_delta,
            source_id=source_id,
            candidate_semantic=candidate_semantic,
            candidate_features=candidate_features,
            candidate_mask=candidate_mask,
        )

    # Report shapes only; the values come from random inputs.
    payload = {
        "base_checkpoint_keys": sorted(components["base_checkpoint"].keys()),
        "semantic_vector": list(semantic_vector.shape),
        "anchor_key": list(outputs["anchor_key"].shape),
        "anchor_action_logits": list(outputs["action_logits"].shape),
        "anchor_confidence": list(outputs["anchor_confidence"].shape),
        "salience_delta": list(outputs["salience_delta"].shape),
        "bind_logits": list(outputs["bind_logits"].shape),
    }
    print(json.dumps(payload, indent=2, sort_keys=True))
# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()