|
|
|
|
|
"""satellite_app.ipynb |
|
|
|
|
|
Automatically generated by Colab. |
|
|
|
|
|
Original file is located at |
|
|
https://colab.research.google.com/drive/1HCITtw0z2BJO0z9GmsspLM1NzrLdft27 |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import subprocess |
|
|
|
|
|
|
|
|
import gradio as gr |
|
|
from timm import create_model |
|
|
from huggingface_hub import hf_hub_download |
|
|
from datasets import load_dataset |
|
|
import torch |
|
|
import torchvision.transforms as T |
|
|
import cv2 |
|
|
import matplotlib.pyplot as plt |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
|
|
|
# Path to the fine-tuned weights (safetensors format) loaded below.
safe_tensors = "model.safetensors"

# timm architecture name; must match the architecture the weights were saved from.
model_name = 'swin_s3_base_224'

# Build the backbone with a 17-way head (one output per class in
# `one_hot_decoding`) and load the fine-tuned weights in the same call.
# timm's `checkpoint_path` loads the checkpoint after construction and
# understands `.safetensors` files, so no separate loader import is needed.
model = create_model(
    model_name,
    num_classes=17,
    checkpoint_path=safe_tensors,
)

# Inference only: put the model in eval mode so training-only layers
# (e.g. dropout / drop-path) are disabled and predictions are deterministic.
model.eval()
|
|
|
|
|
def one_hot_decoding(labels):
    """Translate a multi-hot indicator vector into class-name strings.

    Args:
        labels: Iterable of per-class indicators; position ``k`` refers to the
            ``k``-th entry of ``class_names`` below. Entries equal to 1 mark a
            present class; every other value is ignored.

    Returns:
        List of the class names whose indicator equals 1, in index order.

    Raises:
        KeyError: If a position beyond the 17 known classes equals 1.
    """
    class_names = ['conventional_mine', 'habitation', 'primary', 'water', 'agriculture', 'bare_ground', 'cultivation', 'blow_down', 'road', 'cloudy', 'blooming', 'partly_cloudy', 'selective_logging', 'artisinal_mine', 'slash_burn', 'clear', 'haze']
    index_to_name = dict(enumerate(class_names))

    # First gather the "hot" positions, then map each through the lookup table
    # (an unknown position raises KeyError from the dict lookup).
    hot_positions = [pos for pos, flag in enumerate(labels) if flag == 1]
    return [index_to_name[pos] for pos in hot_positions]
|
|
|
|
|
def model_output(image):
    """Classify a satellite image and return its predicted labels as text.

    Args:
        image: The image supplied by the Gradio ``"image"`` input. With
            Gradio's default image settings this is a NumPy array of pixels;
            a PIL image is accepted as well. ``None`` (nothing submitted)
            yields an empty string.

    Returns:
        Space-separated names of every class whose sigmoid probability
        exceeds 0.5, in class-index order; an empty string if none do.
    """
    if image is None:
        return ""

    # Gradio already provides the decoded pixels, so no disk read is needed.
    # Normalise to 3-channel RGB so grayscale or RGBA uploads still match the
    # model's expected input channels.
    if isinstance(image, Image.Image):
        pil_image = image.convert("RGB")
    else:
        pil_image = Image.fromarray(np.asarray(image).astype("uint8")).convert("RGB")

    # Resize to 224x224 (the input size of the swin_s3_base_224 backbone) and
    # convert to a float tensor in [0, 1].
    # NOTE(review): no mean/std normalization is applied here — confirm this
    # matches the preprocessing used when the weights were trained.
    img_size = (224, 224)
    test_tfms = T.Compose([
        T.Resize(img_size),
        T.ToTensor(),
    ])
    img = test_tfms(pil_image)

    with torch.no_grad():
        # Add a batch dimension: (C, H, W) -> (1, C, H, W).
        logits = model(img.unsqueeze(0))

    # Multi-label decision: each class is independently predicted present
    # when its sigmoid probability exceeds 0.5.
    predictions = (logits.sigmoid() > 0.5).float().numpy().flatten()
    pred_labels = one_hot_decoding(predictions)

    return " ".join(pred_labels)
|
|
|
|
|
# Gradio UI: a single image input mapped through `model_output` to a text box
# showing the predicted label names.
app = gr.Interface(fn=model_output, inputs="image", outputs="text")

# Start the (blocking) server only when executed as a script or notebook
# cell, so importing this module does not launch the UI as a side effect.
if __name__ == "__main__":
    app.launch(debug=True)
|
|
|
|
|
|