Spaces:
Sleeping
Sleeping
| """ | |
| Inference script for classifying a single card image. | |
| """ | |
| import torch | |
| from torchvision import transforms | |
| from PIL import Image | |
| from pathlib import Path | |
| # Import from training module | |
| import sys | |
| sys.path.insert(0, str(Path(__file__).parent.parent.parent)) | |
| from src.train.classifier import ( | |
| SetCardClassifier, | |
| NUMBER_NAMES, COLOR_NAMES, SHAPE_NAMES, FILL_NAMES | |
| ) | |
# Checkpoint directory, resolved relative to this file: the same root that is
# put on sys.path for the `src.train` import, plus "weights".
WEIGHTS_DIR = Path(__file__).parent.parent.parent.joinpath("weights")


def load_model(weights_path: Path = None, device: str = None):
    """Build the card classifier and restore its trained weights.

    Args:
        weights_path: Checkpoint file to read. When omitted, defaults to
            ``WEIGHTS_DIR / "classifier_best.pt"``. The file must be a dict
            holding the weights under the ``"model_state_dict"`` key.
        device: Torch device string. When omitted, the first available of
            ``"cuda"``, ``"mps"`` and ``"cpu"`` is used.

    Returns:
        ``(model, device)``: the classifier moved to ``device`` and switched
        to eval mode, plus the device string that was actually used.
    """
    checkpoint_file = WEIGHTS_DIR / "classifier_best.pt" if weights_path is None else weights_path

    if device is None:
        if torch.cuda.is_available():
            device = "cuda"
        elif torch.backends.mps.is_available():
            device = "mps"
        else:
            device = "cpu"

    # Built without pretrained weights: everything is overwritten by the checkpoint.
    classifier = SetCardClassifier(pretrained=False)
    # NOTE: torch.load unpickles the file -- only point this at trusted checkpoints.
    state = torch.load(checkpoint_file, map_location=device)
    classifier.load_state_dict(state["model_state_dict"])
    # nn.Module.to() and .eval() both return the module itself.
    classifier.to(device).eval()
    return classifier, device
def classify_card(image: Image.Image, model: torch.nn.Module, device) -> dict:
    """Predict the four Set attributes (number, color, shape, fill) of one card.

    Args:
        image: PIL image of a single card. Any mode is accepted; non-RGB
            images (RGBA, grayscale, palette, ...) are converted to RGB first.
        model: Classifier as returned by ``load_model``. It is called on a
            ``(1, 3, 224, 224)`` batch and must return a mapping with
            ``"number"``, ``"color"``, ``"shape"`` and ``"fill"`` score
            tensors of shape ``(1, num_classes)`` (softmaxed here).
        device: Device the model lives on; the input tensor is moved there.

    Returns:
        A dict keyed by attribute name. Each value is
        ``{"value": predicted_name, "confidence": float,
        "all_probs": {class_name: float, ...}}``.
    """
    # ToTensor produces one channel per image band, so RGBA/L/P inputs would
    # not line up with the 3-channel mean/std used below.
    if image.mode != "RGB":
        image = image.convert("RGB")

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # Standard ImageNet mean/std -- presumably what the training pipeline used.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    img_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(img_tensor)

    # Get predictions and confidences
    result = {}
    for key, names in [
        ("number", NUMBER_NAMES),
        ("color", COLOR_NAMES),
        ("shape", SHAPE_NAMES),
        ("fill", FILL_NAMES),
    ]:
        # Move each head's probabilities to the host once, instead of paying
        # one device sync per class via repeated .item() calls.
        probs = torch.softmax(outputs[key], dim=1)[0].cpu()
        pred_idx = probs.argmax().item()
        prob_values = probs.tolist()
        result[key] = {
            "value": names[pred_idx],
            "confidence": prob_values[pred_idx],
            # Indexing (rather than zip) keeps a loud IndexError if the
            # class-name list is longer than the model's output.
            "all_probs": {name: prob_values[i] for i, name in enumerate(names)},
        }
    return result
def main():
    """Command-line entry point: classify one card image and print the result.

    Usage: ``python <script> IMAGE [--weights PATH] [--device DEVICE]``.
    ``--weights`` and ``--device`` are forwarded to ``load_model``; when they
    are omitted, ``load_model`` picks its own defaults.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Classify a Set card image")
    parser.add_argument("image", type=str, help="Path to card image")
    parser.add_argument(
        "--weights",
        type=Path,
        default=None,
        help="Checkpoint file to load (default: weights/classifier_best.pt)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Torch device such as cpu, cuda or mps (default: auto-detect)",
    )
    args = parser.parse_args()

    print("Loading model...")
    model, device = load_model(weights_path=args.weights, device=args.device)

    print(f"Classifying {args.image}...")
    # The context manager closes the file handle deterministically; convert()
    # returns an independent, fully loaded image that outlives the `with`.
    with Image.open(args.image) as raw_image:
        image = raw_image.convert("RGB")
    result = classify_card(image, model, device)

    attributes = ("number", "color", "shape", "fill")

    print("\nPrediction:")
    for key in attributes:
        pred = result[key]
        print(f" {key.capitalize()}: {pred['value']} ({pred['confidence']:.1%})")

    # Human-readable card name, ordered as: number fill color shape.
    n, c, s, f = (result[key]["value"] for key in attributes)
    print(f"\nCard: {n} {f} {c} {s}(s)")


if __name__ == "__main__":
    main()