File size: 2,089 Bytes
c846a27
 
 
 
 
 
 
 
 
64b7df8
 
c846a27
7de21da
23ac470
7de21da
c846a27
 
 
 
 
 
 
 
 
 
 
7de21da
b6ffd9a
f77bc3e
c846a27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f6ff219
c846a27
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# -*- coding: utf-8 -*-
"""satellite_app.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1HCITtw0z2BJO0z9GmsspLM1NzrLdft27
"""

#!pip install gradio --quiet
#!pip install -Uq transformers datasets timm accelerate evaluate

import subprocess
# subprocess.run('pip3 install datasets timm cv2 huggingface_hub torch pillow matplotlib' ,shell=True)

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from PIL import Image
from timm import create_model
from timm.models import load_checkpoint



# Path to the fine-tuned Swin checkpoint. The commented-out hf_hub_download
# call records where the weights originally came from; a local copy is used.
safe_tensors = "model.safetensors" #hf_hub_download(repo_id="subhuatharva/swim-224-base-satellite-image-classification", filename="model.safetensors")

model_name = 'swin_s3_base_224'
# Initialize the Swin backbone with a 17-way head (multi-label, one logit
# per class in one_hot_decoding's class_names).
model = create_model(
    model_name,
    num_classes=17
)

# BUG FIX: the original called `load_model(model, safe_tensors)`, a name that
# is never defined or imported (NameError at startup). timm's load_checkpoint
# loads .safetensors files directly into the model.
load_checkpoint(model, safe_tensors)
# Inference only: freeze dropout/batch-norm behavior.
model.eval()

def one_hot_decoding(labels):
  """Map a multi-hot label vector to its class-name strings.

  Args:
    labels: iterable of 0/1 values (ints or floats), one entry per class,
      in the fixed order of ``class_names`` below (length 17).

  Returns:
    list[str]: names of the classes whose entry equals 1, in class order.
  """
  class_names = ['conventional_mine', 'habitation', 'primary', 'water', 'agriculture', 'bare_ground', 'cultivation', 'blow_down', 'road', 'cloudy', 'blooming', 'partly_cloudy', 'selective_logging', 'artisinal_mine', 'slash_burn', 'clear', 'haze']
  # Single pass with zip replaces the original's two loops and the
  # throwaway index->name dict: keep the name wherever the flag is set.
  return [name for name, flag in zip(class_names, labels) if flag == 1]

def model_output(image):
  """Classify one satellite image and return its predicted tags.

  Args:
    image: RGB image as a numpy array of shape (H, W, 3), as delivered by
      the Gradio ``"image"`` input component.

  Returns:
    str: space-separated predicted class names (multi-label; may be empty).
  """
  # BUG FIX: the original called cv2.imread(name) with an undefined variable
  # `name` (NameError on every call). Gradio already hands us the decoded
  # pixel array, so convert it straight to PIL.
  pil_image = Image.fromarray(image.astype('uint8'), 'RGB')

  # Resize to the 224x224 input the swin_s3_base_224 model expects.
  img_size = (224, 224)
  test_tfms = T.Compose([
    T.Resize(img_size),
    T.ToTensor(),
  ])
  img = test_tfms(pil_image)

  # Add a batch dimension of 1; no gradients needed at inference time.
  with torch.no_grad():
    logits = model(img.unsqueeze(0))

  # Multi-label: sigmoid each logit independently, threshold at 0.5,
  # then decode the resulting multi-hot vector to class names.
  predictions = (logits.sigmoid() > 0.5).float().numpy().flatten()
  pred_labels = one_hot_decoding(predictions)
  return " ".join(pred_labels)

# Wire the prediction function to a minimal Gradio UI: image in, text out.
app = gr.Interface(fn=model_output, inputs="image", outputs="text")
# debug=True surfaces tracebacks inline while the app runs (Colab-friendly).
app.launch(debug=True)