Spaces:
Runtime error
Runtime error
| import torch | |
| from loguru import logger | |
| from src.model import LitEfficientNet | |
| from src.dataloader import MNISTDataModule | |
| from torchmetrics.classification import Accuracy | |
| from pathlib import Path | |
| from src.utils.aws_s3_services import S3Handler | |
| # Configure Loguru to save logs to the logs/ directory | |
| logger.add("logs/test.log", rotation="1 MB", level="INFO") | |
| def infer(checkpoint_path, image): | |
| """ | |
| Perform inference on a single image using the model checkpoint. | |
| Args: | |
| checkpoint_path (str): Path to the model checkpoint. | |
| image (torch.Tensor): Image tensor to predict (shape: [1, 28, 28] for MNIST). | |
| Returns: | |
| int: Predicted class (0-9). | |
| """ | |
| logger.info(f"Loading model from checkpoint: {checkpoint_path} for inference...") | |
| if not Path(checkpoint_path).exists(): | |
| logger.error(f"Checkpoint not found: {checkpoint_path}") | |
| raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") | |
| # Detect device | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| logger.info(f"Inference will run on device: {device}") | |
| # Load the model | |
| model = LitEfficientNet.load_from_checkpoint(checkpoint_path).to(device) | |
| model.eval() | |
| # Perform inference | |
| with torch.no_grad(): | |
| if image.dim() == 3: | |
| image = image.unsqueeze(0) # Add batch dimension if needed | |
| image = image.to(device) # Ensure the image is on the same device as the model | |
| prediction = model(image) | |
| predicted_class = torch.argmax(prediction, dim=1).item() | |
| logger.info(f"Predicted class: {predicted_class}") | |
| return predicted_class | |
| def test_model(checkpoint_path): | |
| """ | |
| Test the model using the test dataset and log metrics. | |
| Args: | |
| checkpoint_path (str): Path to the model checkpoint. | |
| Returns: | |
| float: Final test accuracy. | |
| """ | |
| logger.info(f"Loading model from checkpoint: {checkpoint_path} for testing...") | |
| if not Path(checkpoint_path).exists(): | |
| logger.error(f"Checkpoint not found: {checkpoint_path}") | |
| raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") | |
| # Detect device | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| logger.info(f"Testing will run on device: {device}") | |
| # Load the model | |
| model = LitEfficientNet.load_from_checkpoint(checkpoint_path).to(device) | |
| model.eval() | |
| # Set up data module and load test data | |
| data_module = MNISTDataModule() | |
| data_module.setup(stage="test") | |
| test_loader = data_module.test_dataloader() | |
| # Initialize accuracy metric | |
| test_acc = Accuracy(num_classes=10, task="multiclass").to(device) | |
| # Evaluate model on test data | |
| logger.info("Evaluating on test dataset...") | |
| with torch.no_grad(): | |
| for images, labels in test_loader: | |
| images, labels = images.to(device), labels.to( | |
| device | |
| ) # Move data to the same device | |
| outputs = model(images) | |
| test_acc.update(outputs, labels) | |
| accuracy = test_acc.compute().item() | |
| logger.info(f"Final Test Accuracy (TorchMetrics): {accuracy:.2%}") | |
| return accuracy | |
| if __name__ == "__main__": | |
| # downloading from s3 | |
| s3_handler = S3Handler(bucket_name="deep-bucket-s3") | |
| s3_handler.download_folder( | |
| "checkpoints_test", | |
| "checkpoints", | |
| ) | |
| checkpoint_path = "./checkpoints/best_model.ckpt" | |
| try: | |
| # Perform testing | |
| test_accuracy = test_model(checkpoint_path) | |
| logger.info(f"Test completed successfully with accuracy: {test_accuracy:.2%}") | |
| # Example inference | |
| logger.info("Running inference on a single test image...") | |
| dummy_image = torch.randn(1, 28, 28) # Replace with actual test image | |
| predicted_class = infer(checkpoint_path, dummy_image) | |
| logger.info(f"Inference result: Predicted class {predicted_class}") | |
| except Exception as e: | |
| logger.error(f"An error occurred: {e}") | |