# test_UI / app.py
# Uploaded by: faustoont
# Last commit: "Update app.py" (768d514)
from distutils.command.upload import upload
import streamlit as st
from io import StringIO
from PIL import Image
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import cv2
# Page heading, followed by the upload widget. The widget returns None until
# the user has picked a file, which is what the `is not None` check relies on.
st.title(body='Predict Images')
uploaded_file = st.file_uploader(label='Select File')
# Maps each output index of the classifier head (21 logits) to its label.
# NOTE: the order is NOT alphabetical ('harbor' sits last, at index 20). It
# presumably mirrors the training-time label encoding, so do not re-sort it.
id2class = dict(enumerate((
    'agricultural',
    'airplane',
    'baseballdiamond',
    'beach',
    'buildings',
    'chaparral',
    'denseresidential',
    'forest',
    'freeway',
    'golfcourse',
    'intersection',
    'mediumresidential',
    'mobilehomepark',
    'overpass',
    'parkinglot',
    'river',
    'runway',
    'sparseresidential',
    'storagetanks',
    'tenniscourt',
    'harbor',
)))
# Streamlit re-executes this whole script on every interaction (e.g. each
# upload). Without caching, the ResNet-50 checkpoint would be rebuilt and
# re-read from disk every time. `st.cache_resource` (Streamlit >= 1.18) keeps
# one shared instance per process. On older Streamlit versions, fall back to
# the original uncached behaviour.
_cache_resource = getattr(st, 'cache_resource', lambda func: func)


@_cache_resource
def load_model(weights_path='resnet_best.pth', num_classes=21):
    """Build a ResNet-50 with a `num_classes`-way head and load its weights on CPU.

    Args:
        weights_path: path to a state_dict checkpoint for this architecture.
        num_classes: output width of the replacement `fc` layer; must match
            the checkpoint.

    Returns:
        The model with weights loaded, switched to eval mode.

    Raises:
        RuntimeError: from `load_state_dict` if checkpoint keys or shapes
            do not match (strict loading).
    """
    net = models.resnet50(weights=None)
    # Swap the ImageNet head for one sized to our label set. The input width
    # is taken from the existing layer (2048 for ResNet-50) rather than
    # hard-coded.
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    state = torch.load(weights_path, map_location=torch.device('cpu'))
    net.load_state_dict(state, strict=True)
    net.eval()
    return net


model = load_model(num_classes=len(id2class))
# File suffixes (compared lower-cased) that are routed to the image classifier.
IMAGE_SUFFIXES = ('.jpg', '.jpeg', '.png')

# Resize to 256x256 and convert the HWC uint8 array to a CHW tensor.
# No Normalize step, so pixel values stay in [0, 255], as in the original
# pipeline. This presumably matches training; confirm.
cust_transform = A.Compose(
    [A.Resize(height=256, width=256, p=1.0), ToTensorV2(p=1.0)],
    p=1.0,
)


def _to_model_input(pil_img):
    """Convert a PIL image into a (1, 3, 256, 256) float tensor for `model`."""
    # convert('RGB') turns RGBA / palette / grayscale PNGs into 3 channels.
    # Previously cv2.cvtColor crashed on those. PIL already yields RGB order,
    # so no BGR->RGB swap is applied: the old cvtColor(COLOR_BGR2RGB) call
    # actually turned RGB into BGR. This assumes the checkpoint was trained on
    # RGB input; confirm against the training pipeline.
    rgb = np.array(pil_img.convert('RGB'))
    chw = cust_transform(image=rgb)['image'].float()
    # Add the batch axis explicitly. The deprecated Tensor.resize() would
    # silently reinterpret memory if the channel count were ever not 3.
    return chw.unsqueeze(0)


def _predict_label(pil_img):
    """Return the `id2class` label with the highest logit for `pil_img`."""
    # Inference only: skip autograd bookkeeping. Call the module (not
    # .forward) so any registered hooks run.
    with torch.no_grad():
        logits = model(_to_model_input(pil_img))
    return id2class[int(logits.argmax(dim=1).item())]


if uploaded_file is not None:
    # Match on the real suffix, case-insensitively. A substring test would
    # also accept names like "scan.jpg.csv".
    lowered_name = uploaded_file.name.lower()
    if lowered_name.endswith(IMAGE_SUFFIXES):
        st.write(uploaded_file.name)
        img = Image.open(uploaded_file)
        st.image(img)
        st.write(f'Predicted: {_predict_label(img)}')
    elif lowered_name.endswith('.csv'):
        dataframe = pd.read_csv(uploaded_file)
        st.write(dataframe)
    else:
        st.warning(f'Unsupported file type: {uploaded_file.name}')