|
|
|
|
|
"""satellite_app.ipynb |
|
|
|
|
|
Automatically generated by Colab. |
|
|
|
|
|
Original file is located at |
|
|
https://colab.research.google.com/drive/1HCITtw0z2BJO0z9GmsspLM1NzrLdft27 |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import subprocess |
|
|
|
|
|
|
|
|
import gradio as gr |
|
|
from huggingface_hub import hf_hub_download |
|
|
from safetensors.torch import load_model |
|
|
from datasets import load_dataset |
|
|
import torch |
|
|
import torchvision.transforms as T |
|
|
import cv2 |
|
|
import matplotlib.pyplot as plt |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
from timm import create_model |
|
|
|
|
|
|
|
|
|
|
|
# Path to the fine-tuned weights checkpoint, expected next to this script.
safe_tensors = "model.safetensors"

# Swin Transformer (S3 variant, base size, 224px input) backbone from timm,
# with a 17-way head matching the Planet Amazon multi-label class set.
model_name = 'swin_s3_base_224'

model = create_model(
    model_name,
    num_classes=17,
)

# Load the fine-tuned weights in place from the safetensors checkpoint.
load_model(model, safe_tensors)

# Fix: switch to inference mode so dropout / stochastic-depth layers are
# disabled — timm models are constructed in train mode by default, and the
# original script never called eval(), making predictions nondeterministic.
model.eval()
|
|
|
|
|
def one_hot_decoding(labels):
    """Map a multi-hot indicator vector to its class-name strings.

    Args:
        labels: iterable of 0/1 indicators (ints or floats), ordered to
            match the 17-entry Planet Amazon class list below.

    Returns:
        list[str]: names of the classes whose indicator equals 1,
        in index order; empty list when nothing is set.
    """
    class_names = ['conventional_mine', 'habitation', 'primary', 'water', 'agriculture', 'bare_ground', 'cultivation', 'blow_down', 'road', 'cloudy', 'blooming', 'partly_cloudy', 'selective_logging', 'artisinal_mine', 'slash_burn', 'clear', 'haze']

    # Single-pass comprehension replaces the original two manual loops
    # (collect set indices, then translate indices through an id2label
    # dict rebuilt on every call). `flag == 1` matches both int 1 and
    # float 1.0, as the original comparison did.
    return [name for name, flag in zip(class_names, labels) if flag == 1]
|
|
|
|
|
# UI strings shown in the Gradio interface header.
title = "Satellite Image Classification for Landscape Analysis"

# Fix: corrected the "agriculure" typo in this user-facing description.
description = """The bot was primarily trained to classify satellite images of the entire Amazon rainforest. You will need to upload satellite images and the bot will classify roads, forest, agriculture areas and much more!"""
|
|
|
|
|
def model_output(image):
    """Classify one satellite image and return its predicted labels.

    Args:
        image: numpy uint8-compatible array from the Gradio image input
            (H x W x 3 expected; grayscale/RGBA now tolerated — see below).

    Returns:
        str: space-separated class names whose sigmoid score exceeds 0.5
        (may be the empty string when no class passes the threshold).
    """
    # Robustness fix: convert('RGB') accepts grayscale or RGBA arrays,
    # whereas the original Image.fromarray(..., 'RGB') call crashed on
    # anything that was not already a 3-channel buffer.
    pil_image = Image.fromarray(image.astype('uint8')).convert('RGB')

    # Match the 224x224 input resolution the Swin model was trained with.
    test_tfms = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
    ])
    img = test_tfms(pil_image)

    # Single-image batch; no_grad avoids building the autograd graph.
    with torch.no_grad():
        logits = model(img.unsqueeze(0))

    # Multi-label task: each class is thresholded independently at 0.5.
    predictions = (logits.sigmoid() > 0.5).float().numpy().flatten()
    pred_labels = one_hot_decoding(predictions)

    return " ".join(pred_labels)
|
|
|
|
|
# Bundled sample satellite images offered as clickable examples in the UI.
example_paths = [
    ["sample_images/train_142.jpg"],
    ["sample_images/train_75.jpg"],
    ["sample_images/train_32.jpg"],
    ["sample_images/train_59.jpg"],
    ["sample_images/train_67.jpg"],
    ["sample_images/train_92.jpg"],
]

# Wire the classifier into a Gradio interface: image in, label text out.
app = gr.Interface(
    fn=model_output,
    inputs="image",
    outputs="text",
    title=title,
    description=description,
    examples=example_paths,
)

app.launch(share=True)
|
|
|
|
|
|