File size: 900 Bytes
0d964c1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
import torch
import numpy as np
from sports.common.team import TeamClassifier
# Location of the serialized TeamClassifier checkpoint (relative to the CWD).
MODEL_PATH = "team_classifier.pth"

# Load the trained TeamClassifier from file.
# The checkpoint holds a whole pickled TeamClassifier object (not a state_dict);
# TeamClassifier is presumably imported at the top of the file so that
# unpickling can resolve the class.
# PyTorch >= 2.6 defaults torch.load to weights_only=True, which rejects
# arbitrary pickled objects. Full unpickling therefore has to be requested
# explicitly. (The weights_only keyword requires torch >= 1.13.)
# SECURITY: weights_only=False executes pickle code from the file. Only load
# checkpoints from a trusted source.
model = torch.load(MODEL_PATH, map_location="cpu", weights_only=False)
# Example: Predict team for a list of player crops (numpy arrays)
def predict_teams(crops):
    """Assign a team label to every player crop.

    The work is delegated to the module-level ``model`` (the loaded
    TeamClassifier) via its ``predict`` method.

    Args:
        crops (List[np.ndarray]): Player image crops, one numpy array each.

    Returns:
        np.ndarray: Predicted team labels (0 or 1), as returned by
        ``model.predict``.
    """
    labels = model.predict(crops)
    return labels
if __name__ == "__main__":
    # Demo run: classify one blank 224x224, 3-channel uint8 image.
    # Replace the blank sample with real player crops (your own image
    # loading logic) for actual predictions.
    blank_sample = np.zeros((224, 224, 3), dtype=np.uint8)
    team_labels = predict_teams([blank_sample])
    print("Predicted team labels:", team_labels)
|