Spaces:
Sleeping
Sleeping
import logging
import os
from timeit import default_timer as timer

import gradio as gr
import lightning as pl
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from config import CONFIG
from data import LitMNISTDataModule
from model import LitMNISTModel
# -- Global torch / Lightning setup -------------------------------------------
# Let float32 matmuls use faster reduced-precision internal math where the
# hardware supports it ('medium' trades some precision for speed).
torch.set_float32_matmul_precision('medium')
# Mixed-precision autocast is intentionally not configured here: autocast only
# takes effect inside a `with torch.autocast(...)` block (or as a decorator),
# so a bare module-level `torch.cuda.amp.autocast(enabled=True)` call builds a
# context-manager object and then discards it without enabling anything.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Tensor factory calls allocate on `device` unless a device is given explicitly.
torch.set_default_device(device)
# Seed the Python / NumPy / torch RNGs (and dataloader workers) for reproducibility.
pl.seed_everything(123, workers=True)
# Commonly used MNIST training-set mean / standard deviation (single channel).
MNIST_MEAN = (0.1307,)
MNIST_STD = (0.3081,)

# Inference-time preprocessing: an (H, W[, C]) image (0-255 for uint8 input)
# becomes a (C, H, W) float tensor in [0, 1], then is normalised with the
# MNIST statistics above.
# NOTE(review): there is no Resize step; inputs are assumed to already be
# 28x28 like MNIST. Confirm this, or add transforms.Resize((28, 28)).
TEST_TRANSFORMS = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(MNIST_MEAN, MNIST_STD),
])
# -- MNIST data module --------------------------------------------------------
# Configured for the test stage only: no training transform is supplied.
# NOTE(review): in this file `dm` is only referenced by the commented-out
# evaluation code further down, yet prepare_data()/setup() still run on every
# app start-up. Consider dropping it if start-up time matters.
_dm_kwargs = dict(
    data_dir=CONFIG['data'].get('dir_path', '.'),
    batch_size=CONFIG.get('batch_size'),
    num_workers=CONFIG.get('num_workers'),
    test_transform=TEST_TRANSFORMS,
    train_transform=None,
)
dm = LitMNISTDataModule(**_dm_kwargs)
dm.prepare_data()
dm.setup('test')
# -- MNIST model --------------------------------------------------------------
# The checkpoint path is resolved relative to this file, so the app works from any cwd.
chkpoint_path = os.path.join(
    os.path.dirname(__file__), 'logs', 'chkpoints', 'epoch=13.ckpt'
)
# map_location restores the weights onto the runtime `device`, so a checkpoint
# saved on a GPU can still be loaded on a CPU-only host.
model = LitMNISTModel.load_from_checkpoint(chkpoint_path, map_location=device)
# Switch layers such as dropout / batch-norm to inference behaviour; Lightning's
# inference guide calls eval() explicitly after load_from_checkpoint.
model.eval()
# In this file `trainer` is only used by the commented-out `trainer.test(...)` below.
trainer = pl.Trainer(
    fast_dev_run=True,
    precision=32,
    enable_model_summary=False,
    enable_progress_bar=False,
)
| # trainer.test(model,datamodule=dm) | |
| # for X,y in dm.test_dataloader(): | |
| # for i in range(X.shape[0]): | |
| # plt.imsave( | |
| # fname=os.path.join('numbers',f'img_{i}.png'), | |
| # arr=np.clip( | |
| # torch.stack( | |
| # [X[i,...],X[i,...],X[i,...]], | |
| # dim=1 | |
| # ).squeeze(0).permute(1,2,0).detach().cpu().contiguous().numpy(),0,1)) | |
| # break | |
def _to_python(value):
    """Return ``value.item()`` for a single-element tensor, else ``value`` unchanged."""
    if isinstance(value, torch.Tensor) and value.numel() == 1:
        return value.item()
    return value


def predict_fn(img: Image.Image):
    """Classify one hand-written digit image for the Gradio UI.

    Args:
        img: Image supplied by ``gr.Image(type='pil')``.

    Returns:
        A ``(label_to_confidence, seconds)`` tuple feeding the ``gr.Label`` and
        ``gr.Number`` outputs. On any failure the error is logged, a warning
        toast is shown, and the placeholder ``({"Title ☠️": 0.0}, 0.0)`` is
        returned instead.
    """
    start_time = timer()
    try:
        pixels = TEST_TRANSFORMS(np.array(img))  # (C, H, W) float tensor
        # Average the channels into one grey-scale plane, then add the batch and
        # channel axes: (C, H, W) -> (1, 1, H, W).
        # NOTE(review): the image is not resized; the model presumably expects
        # 28x28 MNIST-sized input. Confirm this.
        batch = pixels.mean(dim=0).unsqueeze(0).unsqueeze(0).to(model.device)
        model.eval()  # cheap and idempotent: guarantees inference-mode layers
        with torch.inference_mode():  # no autograd graph is needed for serving
            y_preds = model.predict_step(batch)
        # Unwrap single-element tensors so the label reads "Title: 7" (not
        # "Title: tensor(7)") and gr.Label receives a plain float confidence.
        label = _to_python(y_preds['predict'][0])
        confidence = float(_to_python(y_preds['prob'][0]))
        res = {f"Title: {label}": confidence}
        pred_time = round(timer() - start_time, 5)
        return (res, pred_time)
    except Exception:
        # UI boundary: record the traceback for the operator and notify the
        # user. A bare `gr.Error(...)` only builds an exception object and must
        # be raised to be displayed; gr.Warning shows a toast and lets the
        # placeholder result below still be returned.
        logging.getLogger(__name__).exception("predict_fn failed")
        gr.Warning("An error occured 💥!", duration=5)
        return ({f"Title ☠️": 0.0}, 0.0)
# -- Gradio app ---------------------------------------------------------------
# Example inputs are the files in the `numbers/` folder next to this script.
# Each path is joined onto that same folder, so examples resolve no matter the
# cwd. Names are sorted because os.listdir order is arbitrary. A missing
# folder simply disables examples instead of crashing start-up.
_examples_dir = os.path.join(os.path.dirname(__file__), 'numbers')
_examples = (
    [[os.path.join(_examples_dir, name)] for name in sorted(os.listdir(_examples_dir))]
    if os.path.isdir(_examples_dir)
    else None
)

demo = gr.Interface(
    fn=predict_fn,
    inputs=gr.Image(type='pil'),
    outputs=[
        # First value returned by predict_fn: {label: confidence}, top-1 shown.
        gr.Label(num_top_classes=1, label="Predictions"),
        # Second value: elapsed seconds measured inside predict_fn.
        gr.Number(label="Prediction time (s)"),
    ],
    examples=_examples,
    title="The Unsolved MNIST 🔢",
    description="CNN-based Architecture for Fast and Accurate MNIST 🔢 Solution with Reproducible Logs",
    article="Created by muthukamalan.m ❤️",
)
demo.launch(share=False, debug=False)