File size: 2,672 Bytes
c846a27
 
 
 
 
 
 
 
 
64b7df8
 
c846a27
7de21da
23ac470
7de21da
c846a27
 
c28bf1c
c846a27
 
 
 
 
 
 
bba43fc
c846a27
7de21da
b6ffd9a
f77bc3e
c846a27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd6a71c
2c31052
 
c846a27
 
fa2f4c1
c846a27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a65dcc
 
923cfc8
c846a27
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# -*- coding: utf-8 -*-
"""satellite_app.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1HCITtw0z2BJO0z9GmsspLM1NzrLdft27
"""

#!pip install gradio --quiet
#!pip install -Uq transformers datasets timm accelerate evaluate

import subprocess
# subprocess.run('pip3 install datasets timm cv2 huggingface_hub torch pillow matplotlib' ,shell=True)

import gradio as gr
from huggingface_hub import hf_hub_download
from safetensors.torch import load_model
from datasets import load_dataset
import torch
import torchvision.transforms as T
import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from timm import create_model



# Local path to the fine-tuned weights. To fetch them from the Hugging Face Hub
# instead, use:
#   hf_hub_download(repo_id="subhuatharva/swim-224-base-satellite-image-classification",
#                   filename="model.safetensors")
safe_tensors = "model.safetensors"

model_name = 'swin_s3_base_224'
# Initialize the architecture with a 17-way head (one logit per label decoded
# by one_hot_decoding), then load the fine-tuned weights into it.
model = create_model(
    model_name,
    num_classes=17
)

load_model(model, safe_tensors)
# torch modules start in training mode; switch to evaluation mode so any
# train-only behaviour (e.g. dropout / drop-path) is disabled at inference.
model.eval()

# Label names in model-output order: position i of the prediction vector maps
# to CLASS_NAMES[i]. NOTE(review): must match the label order used in training.
CLASS_NAMES = (
    'conventional_mine', 'habitation', 'primary', 'water', 'agriculture',
    'bare_ground', 'cultivation', 'blow_down', 'road', 'cloudy', 'blooming',
    'partly_cloudy', 'selective_logging', 'artisinal_mine', 'slash_burn',
    'clear', 'haze',
)


def one_hot_decoding(labels):
    """Map a multi-hot flag vector to the names of the active classes.

    Args:
        labels: Iterable of 0/1 flags (ints, floats or a 1-D NumPy array), one
            per class in ``CLASS_NAMES`` order. Positions equal to 1 are kept.

    Returns:
        List of class names for every position whose flag equals 1, in index
        order. An all-zero or empty input yields an empty list.

    Raises:
        IndexError: if a flag equal to 1 appears past the last known class.
    """
    return [CLASS_NAMES[idx] for idx, flag in enumerate(labels) if flag == 1]
    
# Header text for the Gradio interface (passed to gr.Interface below).
title = "Satellite Image Classification for Landscape Analysis"
description = """The bot was primarily trained to classify satellite images of the entire Amazon rainforest. You will need to upload satellite images and the bot will classify roads, forest, agriculture areas and much more!"""

# Inference preprocessing: resize to the 224x224 input size and convert to a
# float tensor (ToTensor scales uint8 pixels to [0, 1]). Built once at import
# time instead of on every request.
# NOTE(review): there is no T.Normalize step; confirm this matches the
# preprocessing used during training.
_IMG_SIZE = (224, 224)
_TEST_TFMS = T.Compose([
    T.Resize(_IMG_SIZE),
    T.ToTensor(),
])


def model_output(image):
    """Classify a satellite image and return its predicted labels.

    Args:
        image: Image as a NumPy array, as supplied by the Gradio ``"image"``
            input. Grayscale and RGBA arrays are converted to RGB.

    Returns:
        Space-separated names of every class whose sigmoid probability
        exceeds 0.5, or an empty string when no image is provided.
    """
    if image is None:
        # Gradio can pass None when the form is submitted without an image.
        return ""

    # Let PIL infer the mode from the array shape, then normalize to 3-channel
    # RGB so RGBA/grayscale uploads are not reinterpreted as raw RGB bytes.
    pil_image = Image.fromarray(image.astype('uint8')).convert('RGB')

    img = _TEST_TFMS(pil_image)

    with torch.no_grad():
        logits = model(img.unsqueeze(0))

    # Multi-label task: each class is an independent sigmoid, thresholded at 0.5.
    predictions = (logits.sigmoid() > 0.5).float().numpy().flatten()
    return " ".join(one_hot_decoding(predictions))

# Sample images shipped with the app, shown as clickable examples in the UI.
_EXAMPLE_IDS = (142, 75, 32, 59, 67, 92)

app = gr.Interface(
    fn=model_output,
    inputs="image",
    outputs="text",
    title=title,
    description=description,
    examples=[[f"sample_images/train_{i}.jpg"] for i in _EXAMPLE_IDS],
)

# Start the server only when run as a script, so importing this module
# (e.g. from tests or another app) does not launch Gradio as a side effect.
if __name__ == "__main__":
    app.launch(share=True)