# -*- coding: utf-8 -*-
"""satellite_app.ipynb

Gradio app that classifies satellite images of the Amazon rainforest with a
Swin transformer fine-tuned for 17-way multi-label classification.

Automatically generated by Colab. Original file is located at
    https://colab.research.google.com/drive/1HCITtw0z2BJO0z9GmsspLM1NzrLdft27
"""

# Notebook-only setup, kept for reference:
# !pip install gradio --quiet
# !pip install -Uq transformers datasets timm accelerate evaluate
import subprocess  # noqa: F401  (notebook leftover; unused at runtime)
# subprocess.run('pip3 install datasets timm cv2 huggingface_hub torch pillow matplotlib', shell=True)

import cv2  # noqa: F401  (unused; kept from the notebook)
import gradio as gr
import matplotlib.pyplot as plt  # noqa: F401  (unused; kept from the notebook)
import numpy as np  # noqa: F401  (unused; kept from the notebook)
import torch
import torchvision.transforms as T
from datasets import load_dataset  # noqa: F401  (unused; kept from the notebook)
from huggingface_hub import hf_hub_download  # noqa: F401  (see note below)
from PIL import Image
from safetensors.torch import load_model
from timm import create_model

# Local checkpoint path. To fetch the weights from the Hub instead, use:
# hf_hub_download(repo_id="subhuatharva/swim-224-base-satellite-image-classification",
#                 filename="model.safetensors")
safe_tensors = "model.safetensors"

model_name = 'swin_s3_base_224'

# Initialize the backbone with a 17-way multi-label head and load the
# fine-tuned weights from the safetensors checkpoint.
model = create_model(model_name, num_classes=17)
load_model(model, safe_tensors)
# BUG FIX: the original never switched to inference mode, leaving dropout /
# stochastic-depth layers active and making predictions nondeterministic.
model.eval()


def one_hot_decoding(labels):
    """Map a multi-hot prediction vector to the list of predicted class names.

    Args:
        labels: iterable of per-class indicator values (0/1, int or float),
            ordered to match ``class_names`` below.

    Returns:
        list[str]: names of the classes whose indicator equals 1.
    """
    class_names = ['conventional_mine', 'habitation', 'primary', 'water',
                   'agriculture', 'bare_ground', 'cultivation', 'blow_down',
                   'road', 'cloudy', 'blooming', 'partly_cloudy',
                   'selective_logging', 'artisinal_mine', 'slash_burn',
                   'clear', 'haze']
    # Single zip-comprehension replaces the original two-pass
    # index-collection + id2label lookup loops (identical output).
    return [name for name, flag in zip(class_names, labels) if flag == 1]


title = "Satellite Image Classification for Landscape Analysis"
description = """The bot was primarily trained to classify satellite images of the entire Amazon rainforest. 
You will need to upload satellite images and the bot will classify roads, forest, agriculture areas and much more!"""


def model_output(image):
    """Classify one uploaded image and return its labels as a space-joined string.

    Args:
        image: HxWx3 uint8-compatible numpy array supplied by the Gradio
            ``"image"`` input component.

    Returns:
        str: predicted class names separated by single spaces.
    """
    pil_image = Image.fromarray(image.astype('uint8'), 'RGB')

    # Match the 224x224 input size the model was trained with.
    img_size = (224, 224)
    test_tfms = T.Compose([
        T.Resize(img_size),
        T.ToTensor(),
    ])
    img = test_tfms(pil_image)

    with torch.no_grad():
        logits = model(img.unsqueeze(0))

    # Multi-label head: independent sigmoid per class, thresholded at 0.5.
    predictions = (logits.sigmoid() > 0.5).float().numpy().flatten()
    pred_labels = one_hot_decoding(predictions)
    return " ".join(pred_labels)


app = gr.Interface(
    fn=model_output,
    inputs="image",
    outputs="text",
    title=title,
    description=description,
    examples=[["sample_images/train_142.jpg"], ["sample_images/train_75.jpg"],
              ["sample_images/train_32.jpg"], ["sample_images/train_59.jpg"],
              ["sample_images/train_67.jpg"], ["sample_images/train_92.jpg"]],
)

if __name__ == "__main__":
    app.launch(share=True)