File size: 900 Bytes
0d964c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
import numpy as np
from sports.common.team import TeamClassifier

# Load the trained TeamClassifier from file.
#
# The checkpoint is used as a complete object (predict_teams() calls
# ``model.predict`` on it), so it has to be unpickled as a whole rather than
# read as a plain state dict. Since PyTorch 2.6 ``torch.load`` defaults to
# ``weights_only=True``, which refuses to unpickle arbitrary classes, so the
# flag is passed explicitly here.
#
# SECURITY: ``weights_only=False`` runs the full pickle machinery and can
# execute arbitrary code embedded in the file. Only load checkpoints that come
# from a trusted source.
#
# NOTE(review): the ``TeamClassifier`` import above is presumably what lets
# pickle resolve the saved class -- confirm before removing it as "unused".
model = torch.load(
    "team_classifier.pth",
    map_location="cpu",
    weights_only=False,
)

# Example: Predict team for a list of player crops (numpy arrays)
def predict_teams(crops):
    """
    Run the module-level classifier on a batch of player crops.

    This is a thin wrapper that forwards ``crops`` unchanged to
    ``model.predict`` and hands back whatever that call returns.

    Args:
        crops (List[np.ndarray]): Player crops, each given as a numpy array.
    Returns:
        np.ndarray: Predicted team labels (0 or 1)
    """
    team_labels = model.predict(crops)
    return team_labels

if __name__ == "__main__":
    # Demo entry point, executed only when the file is run as a script.
    # An all-zero 224x224 3-channel uint8 array stands in for a real player
    # crop; swap in your own image loading logic here.
    placeholder_crop = np.zeros(shape=(224, 224, 3), dtype=np.uint8)
    team_labels = predict_teams([placeholder_crop])
    print("Predicted team labels:", team_labels)