Spaces:
Build error
# AUTOGENERATED! DO NOT EDIT! File to edit: app.ipynb.

# %% auto 0
__all__ = ['MODEL_PATH', 'model', 'image', 'label', 'processed_image', 'intf', 'predict']

# %% app.ipynb 2
# Standard library
import sys
from pathlib import Path

# Third-party (relative order kept as originally written)
import torch
import numpy as np
import gradio as gr
from PIL import Image

# Debug aid: print entire NumPy arrays instead of a "..." summary.
# This only affects NumPy printing; torch tensors follow torch.set_printoptions.
np.set_printoptions(threshold=sys.maxsize)
# %% app.ipynb 4
# The result of torch.load below is used directly as the model, so the
# checkpoint appears to be a whole pickled module. Pickle needs its class to be
# importable, which is why LeNet5 is imported even though it is not referenced.
from lenet import LeNet5  # noqa: F401 -- required by torch.load's unpickler

# Relative path: resolved against the working directory the app is launched from.
MODEL_PATH = Path("models/lenet5-cpu.pt")

# weights_only=False uses the full (unrestricted) pickle loader. It is needed to
# restore a whole module object rather than a state_dict.
# SECURITY: unpickling can execute arbitrary code -- only load trusted files here.
# map_location="cpu" lets a checkpoint saved from a GPU session load on a
# CPU-only host; predict() feeds the model CPU tensors.
model = torch.load(MODEL_PATH, map_location="cpu", weights_only=False)
model.eval()  # inference mode (affects layers such as dropout/batch-norm, if present)
def predict(img):
    """Classify a hand-drawn digit from the Sketchpad.

    Args:
        img: The Sketchpad value -- a dict whose ``"composite"`` entry is the
            flattened drawing as a PIL image -- or ``None`` when Gradio has no
            drawing to pass (e.g. after the canvas is cleared).

    Returns:
        A ``(label_probs, debug_image)`` tuple for the ``gr.Label`` and
        ``gr.Image`` outputs:

        - ``label_probs`` maps each output index, as a string, to its softmax
          probability.
        - ``debug_image`` is the inverted 28x28 uint8 array that was fed to
          the model.

        Returns ``(None, None)`` (clearing both outputs) when there is nothing
        to classify.
    """
    composite = None if img is None else img["composite"]
    if composite is None:
        # Nothing to classify; previously this raised TypeError on subscript.
        return None, None

    # Downscale to 28x28. Converting to RGBA first guarantees an alpha band,
    # which paste() below uses as the transparency mask.
    img_pil = composite.convert("RGBA").resize((28, 28))

    # Flatten onto an opaque white grayscale canvas so undrawn (transparent)
    # pixels become white.
    background = Image.new("L", (28, 28), 255)
    background.paste(img_pil, (0, 0), img_pil)

    # Invert colors (MNIST has white digits on black).
    img_array = 255 - np.array(background)

    # Copy of exactly what the model sees, returned for display in the UI.
    inverted_debug = img_array.astype(np.uint8)

    # NOTE(review): pixels are fed as raw 0-255 floats; confirm this matches
    # the scaling/normalisation used when the model was trained.
    img_tensor = torch.tensor(img_array, dtype=torch.float32)
    img_tensor = img_tensor.unsqueeze(0).unsqueeze(0)  # -> (1, 1, 28, 28): batch, channel

    # Debug: Print the shape and values of the input tensor
    print(f"Input tensor shape: {img_tensor.shape}")
    print(f"Input tensor values: {img_tensor}")

    with torch.no_grad():
        output = model(img_tensor)
        probabilities = torch.nn.functional.softmax(output, dim=1)[0]

    print(f"Output shape: {output.shape}")
    print(f"Probabilities shape: {probabilities.shape}")
    print(f"Probabilities: {probabilities}")

    # Dictionary of label -> probability, the format gr.Label expects.
    return {str(i): float(prob) for i, prob in enumerate(probabilities)}, inverted_debug
# Drawing input: a fixed black brush on a 280x280 canvas (10x the 28x28 image
# predict() builds), with upload sources, layers and transforms disabled.
image = gr.Sketchpad(
    type="pil",
    sources=(),
    canvas_size=(280, 280),
    brush=gr.Brush(colors=["#000000"], color_mode="fixed", default_size=20),
    layers=False,
    transforms=[],
)

# Outputs: per-class probabilities, plus the preprocessed image for debugging.
label = gr.Label()
processed_image = gr.Image(label="What the Model Sees (28x28)")

intf = gr.Interface(
    fn=predict,
    inputs=image,
    outputs=[label, processed_image],
    title="Draw a digit",
    description="And let me identify it for you...",
    clear_btn=None,  # hide the Clear button
)

# inline=False: do not embed in a notebook; debug=True: block the main thread
# and print errors to the console.
intf.launch(inline=False, debug=True)