rabukasim / tools /debug /debug_speed.py
trioskosmos's picture
chore: remove large files for HF Space
9bd4ce5
import sys
import time
from pathlib import Path
import torch
# Add project root to find alphazero module
root_dir = Path(__file__).resolve().parent.parent.parent
if str(root_dir) not in sys.path:
sys.path.insert(0, str(root_dir))
from alphazero.alphanet import AlphaNet
def debug_speed():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
print("Initializing model...")
model = AlphaNet(num_layers=10, embed_dim=512).to(device)
model.eval()
print("Creating tensors...")
x = torch.randn(1, 20500).to(device)
m = torch.ones(1, 22000, dtype=torch.bool).to(device)
print("Starting forward pass...")
torch.cuda.synchronize()
start = time.time()
with torch.no_grad():
p, v = model(x, mask=m)
torch.cuda.synchronize()
end = time.time()
print(f"Forward pass completed in {end - start:.4f}s")
print(f"Policy shape: {p.shape}, Value shape: {v.shape}")
if __name__ == "__main__":
debug_speed()