File size: 1,728 Bytes
0b57da2
38780f1
 
 
 
d71c714
5ae6ac4
 
0b57da2
 
5ae6ac4
0b57da2
 
d71c714
0b57da2
 
aa9ba14
 
 
5ae6ac4
aa9ba14
cdbd6d3
38780f1
0b57da2
 
 
 
 
 
768d514
0b57da2
 
 
aa9ba14
768d514
 
38780f1
0b57da2
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from distutils.command.upload import upload
import streamlit as st
from io import StringIO
from PIL import Image
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import cv2

# Page header and the single file picker. The picker's value is None until the
# user selects a file; the handler further down branches on that.
st.title(body='Predict Images')
uploaded_file = st.file_uploader(label='Select File')

# Maps the classifier's output index (0..20, one per logit of the 21-way head)
# to a human-readable scene label. Position in this list is the class index.
# NOTE(review): 'harbor' is last (index 20) rather than in alphabetical order;
# presumably this mirrors the label encoding used at training time — confirm.
id2class = dict(enumerate([
    'agricultural',
    'airplane',
    'baseballdiamond',
    'beach',
    'buildings',
    'chaparral',
    'denseresidential',
    'forest',
    'freeway',
    'golfcourse',
    'intersection',
    'mediumresidential',
    'mobilehomepark',
    'overpass',
    'parkinglot',
    'river',
    'runway',
    'sparseresidential',
    'storagetanks',
    'tenniscourt',
    'harbor',
]))

# Streamlit re-runs this whole script on every user interaction. Cache the
# loaded network across runs when st.cache_resource is available. On older
# Streamlit versions, fall back to a pass-through decorator, i.e. reload on
# every run exactly as before.
_cache_resource = getattr(st, 'cache_resource', None) or (lambda fn: fn)


@_cache_resource
def _load_model(weights_path='resnet_best.pth', num_classes=len(id2class)):
    """Build a ResNet-50 with a ``num_classes``-way head and load CPU weights.

    Args:
        weights_path: Path to a saved ``state_dict`` for this architecture.
        num_classes: Output width of the final linear layer. It must match the
            checkpoint, because ``strict=True`` rejects any key/shape mismatch.

    Returns:
        The model, switched to eval mode.
    """
    net = models.resnet50(weights=None)
    # Replace the stock head with one sized to our label set. The input width
    # is taken from the existing layer instead of hard-coding 2048.
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
    net.load_state_dict(state_dict, strict=True)
    net.eval()
    return net


model = _load_model()

# Extensions routed to the image classifier (compared case-insensitively).
_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

if uploaded_file is not None:
    # Lower-case once so both the image and the CSV checks are case-insensitive.
    # endswith() avoids false matches such as "photo.jpg.csv".
    lower_name = uploaded_file.name.lower()
    if lower_name.endswith(_IMAGE_EXTENSIONS):
        st.write(uploaded_file.name)
        pil_img = Image.open(uploaded_file)
        st.image(pil_img)
        # Force 3-channel RGB. PNGs can be RGBA, palette or grayscale, which
        # previously crashed the colour conversion and cannot feed a 3-channel
        # model. PIL already decodes to RGB, so no BGR->RGB swap is applied.
        # The old swap actually produced BGR. This assumes the model expects
        # RGB input, as the original BGR2RGB call intended; confirm against
        # the training pipeline.
        rgb = np.array(pil_img.convert('RGB'))
        cust_transform = A.Compose(
            [A.Resize(height=256, width=256, p=1.0), ToTensorV2(p=1.0)],
            p=1.0,
        )
        # ToTensorV2 yields a CHW tensor. Add the batch axis with unsqueeze
        # instead of the deprecated Tensor.resize.
        tensor = cust_transform(image=rgb)['image'].float().unsqueeze(0)

        # Call the module (not .forward) so hooks run. Disable autograd
        # bookkeeping, since this is inference only.
        with torch.no_grad():
            custom_pred = model(tensor).numpy()

        st.write(f'Predicted: {id2class[int(np.argmax(custom_pred))]}')
    elif lower_name.endswith('.csv'):
        dataframe = pd.read_csv(uploaded_file)
        st.write(dataframe)