File size: 1,893 Bytes
2692ab4
b2d05ca
 
 
 
 
2692ab4
 
 
 
 
 
 
 
 
 
 
 
 
b2d05ca
2692ab4
 
 
 
 
 
b2d05ca
aed63ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from distutils.command.upload import upload
import streamlit as st
from io import StringIO
from PIL import Image
import pandas as pd

import torch
import torch.nn as nn
import torchvision.models as models

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

import numpy as np

import cv2

st.title('Dummy')
uploaded_file = st.file_uploader('Select File')

# Maps the classifier's output index (0-20) to a class name; there is one entry
# for each of the 21 outputs of the replaced `fc` layer below.
# NOTE(review): 'harbor' sits at index 20 rather than in alphabetical position —
# confirm this ordering matches the label encoding used when training the weights.
id2class = {0: 'agricultural', 1: 'airplane', 2: 'baseballdiamond', 3: 'beach', 4: 'buildings', 5: 'chaparral', 6: 'denseresidential', 7: 'forest', 8: 'freeway', 9: 'golfcourse', 10: 'intersection', 11: 'mediumresidential', 12: 'mobilehomepark', 13: 'overpass', 14: 'parkinglot', 15: 'river', 16: 'runway', 17: 'sparseresidential', 18: 'storagetanks', 19: 'tenniscourt', 20: 'harbor'}


@st.cache_resource
def _load_model():
    """Build the 21-class ResNet-50 and load its weights from 'resnet_best.pth'.

    Streamlit re-executes this script on every widget interaction; caching the
    model as a resource means the network is constructed and the checkpoint is
    read from disk only once per server process instead of on every rerun.

    Returns:
        The ResNet-50 with a 2048 -> 21 linear head, weights loaded on CPU,
        switched to evaluation mode.
    """
    net = models.resnet50(weights=None)
    # Replace the stock ImageNet head with a 21-way classifier.
    net.fc = nn.Linear(2048, 21)
    # NOTE(review): torch.load unpickles the file; only ship a checkpoint from a
    # trusted source (or consider `weights_only=True` if the torch version allows).
    state = torch.load('resnet_best.pth', map_location=torch.device('cpu'))
    net.load_state_dict(state, strict=True)
    net.eval()
    return net


model = _load_model()

# Handle the uploaded file: classify images with the model, display CSVs as tables.
if uploaded_file is not None:
    # Match on the lower-cased name so every supported extension is
    # case-insensitive, and on the suffix so names that merely contain
    # '.jpg' / '.csv' somewhere in the middle are not misrouted.
    file_name = uploaded_file.name.lower()
    if file_name.endswith(('.jpg', '.jpeg', '.png')):
        st.write(uploaded_file.name)
        pil_img = Image.open(uploaded_file)
        st.image(pil_img)
        # PIL already decodes to RGB, so no OpenCV BGR<->RGB swap is needed
        # (swapping here would hand the model BGR data). convert('RGB') also
        # normalises grayscale, palette and RGBA images to exactly 3 channels.
        img = np.array(pil_img.convert('RGB'))

        cust_transform = A.Compose([A.Resize(height=256, width=256, p=1.0), ToTensorV2(p=1.0)], p=1.0)
        # ToTensorV2 yields a channels-first tensor; add the batch dimension
        # to obtain the (1, 3, 256, 256) input the network is fed.
        tensor = cust_transform(image=img)['image'].float().unsqueeze(0)

        # Call the module itself (not .forward) so hooks run, and disable
        # autograd since this is inference only.
        with torch.no_grad():
            custom_pred = model(tensor).numpy()
        # Show the raw per-class scores explicitly instead of relying on
        # Streamlit "magic" for a bare expression.
        st.write(custom_pred)

        st.write(f'Predicted: {id2class[int(np.argmax(custom_pred))]}')
    elif file_name.endswith('.csv'):
        dataframe = pd.read_csv(uploaded_file)
        st.write(dataframe)