Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| from pathlib import Path | |
| from loguru import logger | |
| from src.model import LitEfficientNet | |
| from src.utils.aws_s3_services import S3Handler | |
# Send INFO-and-above log records to a file sink; loguru starts a new file once the
# current one exceeds 1 MB.
logger.add("logs/inference.log", level="INFO", rotation="1 MB")
class MNISTClassifier:
    """Load a ``LitEfficientNet`` checkpoint and classify handwritten-digit images.

    The model is restored once at construction, put in eval mode and placed on
    CUDA when available (CPU otherwise). ``predict`` is the Gradio callback.
    """

    def __init__(self, checkpoint_path="./checkpoints/best_model.ckpt"):
        """Load the model and build the preprocessing pipeline.

        Args:
            checkpoint_path: Path to the checkpoint file to restore.

        Raises:
            FileNotFoundError: If ``checkpoint_path`` does not exist.
        """
        self.checkpoint_path = checkpoint_path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Inference will run on device: {self.device}")

        # Load the model and switch off training-only behaviour (dropout, batch-norm updates).
        self.model = self.load_model()
        self.model.eval()

        # Resize to 28x28, convert to a [0, 1] tensor, then map to [-1, 1].
        # NOTE(review): assumes training used the same Normalize((0.5,), (0.5,)) — confirm.
        self.transform = transforms.Compose(
            [
                transforms.Resize((28, 28)),
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)),
            ]
        )
        self.labels = [str(i) for i in range(10)]  # MNIST labels are 0-9

    def load_model(self):
        """Restore the model from ``self.checkpoint_path`` onto ``self.device``.

        Returns:
            The ``LitEfficientNet`` instance loaded from the checkpoint.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
        """
        if not Path(self.checkpoint_path).exists():
            logger.error(f"Checkpoint not found: {self.checkpoint_path}")
            raise FileNotFoundError(f"Checkpoint not found: {self.checkpoint_path}")
        logger.info(f"Loading model from checkpoint: {self.checkpoint_path}")
        # map_location remaps tensors saved on another device (e.g. a GPU training
        # machine) so the checkpoint also loads on CPU-only hosts.
        # NOTE(review): assumes LitEfficientNet is a LightningModule — confirm.
        return LitEfficientNet.load_from_checkpoint(
            self.checkpoint_path, map_location=self.device
        ).to(self.device)

    def predict(self, image):
        """Classify a single digit image.

        Args:
            image: Input image in PIL format. Images not in grayscale ("L") mode
                are converted to it first.

        Returns:
            dict | None: Mapping of each label ("0"-"9") to its softmax
            probability, or ``None`` when no image is provided.
        """
        if image is None:
            logger.error("No image provided for prediction.")
            return None

        # The Gradio input (image_mode="L") always delivers grayscale images; convert
        # other modes so direct callers get the same single-channel preprocessing.
        if getattr(image, "mode", "L") != "L":
            image = image.convert("L")

        # Add a batch dimension and move the tensor to the model's device.
        img_tensor = self.transform(image).unsqueeze(0).to(self.device)

        # Inference only: skip autograd graph construction to save memory and time.
        with torch.inference_mode():
            output = self.model(img_tensor)
            # output[0] is the single batch item's logits
            # (assumes the model returns (batch, classes) logits — confirm).
            probabilities = torch.nn.functional.softmax(output[0], dim=0)

        # One device-to-host transfer instead of a sync per element.
        return {
            self.labels[idx]: prob
            for idx, prob in enumerate(probabilities.cpu().tolist())
        }
# Instantiate the classifier
checkpoint_path = "./checkpoints/best_model.ckpt"

# Download the checkpoint folder from S3 only when the checkpoint file is missing
# locally (e.g. on a fresh container). Delete the local file to force a refresh.
# The local target folder is derived from checkpoint_path so the two cannot drift apart.
s3_handler = S3Handler(bucket_name="deep-bucket-s3")
if not Path(checkpoint_path).exists():
    s3_handler.download_folder(
        "checkpoints_test",
        str(Path(checkpoint_path).parent),
    )

classifier = MNISTClassifier(checkpoint_path=checkpoint_path)

# Define Gradio interface: a 160x160 grayscale PIL image in, the top-1 label out.
demo = gr.Interface(
    fn=classifier.predict,
    inputs=gr.Image(height=160, width=160, image_mode="L", type="pil"),
    outputs=gr.Label(num_top_classes=1),
    title="MNIST Classifier",
    description="Upload a handwritten digit image to classify it (0-9).",
)

if __name__ == "__main__":
    demo.launch(share=True)