# -*- coding: utf-8 -*-
"""satellite_app.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1HCITtw0z2BJO0z9GmsspLM1NzrLdft27

Gradio app that tags an uploaded satellite image with zero or more of 17
labels using a multi-label timm Swin classifier whose fine-tuned weights are
read from a local ``model.safetensors`` file.
"""

# Setup (Colab):
#   pip install gradio --quiet
#   pip install -Uq transformers datasets timm accelerate evaluate
import subprocess
# subprocess.run('pip3 install datasets timm cv2 huggingface_hub torch pillow matplotlib', shell=True)

import gradio as gr
from timm import create_model
from timm.models import load_checkpoint
from huggingface_hub import hf_hub_download
from datasets import load_dataset
import torch
import torchvision.transforms as T
import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

# Label vocabulary; position i corresponds to output logit i of the model.
# These are runtime strings and must match the training labels exactly
# (including the "artisinal_mine" spelling).
CLASS_NAMES = [
    'conventional_mine', 'habitation', 'primary', 'water', 'agriculture',
    'bare_ground', 'cultivation', 'blow_down', 'road', 'cloudy', 'blooming',
    'partly_cloudy', 'selective_logging', 'artisinal_mine', 'slash_burn',
    'clear', 'haze',
]
ID2LABEL = dict(enumerate(CLASS_NAMES))

# A label is reported when its sigmoid probability exceeds this value.
THRESHOLD = 0.5
IMG_SIZE = (224, 224)

safe_tensors = "model.safetensors"
# Alternatively fetch the weights from the Hub:
# hf_hub_download(repo_id="subhuatharva/swim-224-base-satellite-image-classification", filename="model.safetensors")

model_name = 'swin_s3_base_224'

# Initialize the architecture with one output per label, then load the
# fine-tuned weights. (The original called an unimported `load_model`;
# timm's loader handles .safetensors checkpoints and loads strictly.)
model = create_model(model_name, num_classes=len(CLASS_NAMES))
load_checkpoint(model, safe_tensors)
# Inference only: switch off train-time behaviour (dropout, drop-path, ...).
model.eval()

# Preprocessing is request-independent, so build it once.
test_tfms = T.Compose([
    T.Resize(IMG_SIZE),
    T.ToTensor(),
])


def one_hot_decoding(labels):
    """Convert a multi-hot indicator vector into label names.

    Args:
        labels: Sequence of per-class indicators, one per entry of
            CLASS_NAMES, where a value equal to 1 marks the class present.

    Returns:
        List of class names whose indicator equals 1, in class-index order.
    """
    return [ID2LABEL[idx] for idx, flag in enumerate(labels) if flag == 1]


def model_output(image):
    """Predict the labels present in an image supplied by Gradio.

    Args:
        image: Image array from the Gradio ``"image"`` input (by default an
            RGB ``uint8`` array of shape (H, W, 3)), or ``None`` when the
            user submits without an image.

    Returns:
        Space-separated predicted label names; empty string if no image was
        given or no label passes THRESHOLD.
    """
    if image is None:
        return ""

    # Gradio already hands us the decoded pixels; normalise to 3-channel RGB
    # in case an RGBA or grayscale array arrives.
    pil_image = Image.fromarray(np.asarray(image).astype('uint8')).convert('RGB')
    batch = test_tfms(pil_image).unsqueeze(0)

    with torch.no_grad():
        logits = model(batch)

    # Multi-label head: independent sigmoid per class, not a softmax.
    predictions = (logits.sigmoid() > THRESHOLD).float().numpy().flatten()
    return " ".join(one_hot_decoding(predictions))


app = gr.Interface(fn=model_output, inputs="image", outputs="text")
app.launch(debug=True)