File size: 373 Bytes
3d66487
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch
import torch.nn as nn

# Environment smoke test: report the PyTorch build and whether CUDA is
# visible, then run one forward pass of a tiny linear layer on the CPU.

cuda_available = torch.cuda.is_available()  # query once, reuse below
device = torch.device("cuda" if cuda_available else "cpu")

print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {cuda_available}")
print(f"Device: {device}")

# Test a model.
# The model and input are pinned to the CPU explicitly so this check does not
# depend on CUDA availability or on a globally configured default device.
IN_FEATURES = 10
OUT_FEATURES = 5
BATCH_SIZE = 3

model = nn.Linear(IN_FEATURES, OUT_FEATURES, device="cpu")
x = torch.randn(BATCH_SIZE, IN_FEATURES, device="cpu")
with torch.no_grad():  # inference only; no autograd graph is needed
    y = model(x)
print(f"Input on: {x.device}")
print(f"Output on: {y.device}")

# Only claim success after confirming the forward pass produced what we
# expect. An explicit raise (rather than `assert`) still fires under `python -O`.
if tuple(y.shape) != (BATCH_SIZE, OUT_FEATURES) or y.device.type != "cpu":
    raise RuntimeError(
        f"CPU inference check failed: shape={tuple(y.shape)}, device={y.device}"
    )
print("✓ CPU inference works!")