subhuatharva's picture
Update app.py
2c31052 verified
raw
history blame
2.67 kB
# -*- coding: utf-8 -*-
"""satellite_app.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1HCITtw0z2BJO0z9GmsspLM1NzrLdft27
"""
#!pip install gradio --quiet
#!pip install -Uq transformers datasets timm accelerate evaluate
import subprocess
# subprocess.run('pip3 install datasets timm cv2 huggingface_hub torch pillow matplotlib' ,shell=True)
import gradio as gr
from huggingface_hub import hf_hub_download
from safetensors.torch import load_model
from datasets import load_dataset
import torch
import torchvision.transforms as T
import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from timm import create_model
# Local path to the pretrained weights; the hf_hub_download call that would
# fetch them from the Hub is kept commented out for reference (the Space repo
# is expected to ship model.safetensors alongside this file).
safe_tensors = "model.safetensors" #hf_hub_download(repo_id="subhuatharva/swim-224-base-satellite-image-classification", filename="model.safetensors")
# timm identifier for the Swin-S3 base architecture at 224x224 input.
model_name = 'swin_s3_base_224'
# initialize the model architecture with 17 output classes
# (multi-label Amazon-rainforest tags — see class_names in one_hot_decoding)
model = create_model(
model_name,
num_classes=17
)
# Load the safetensors checkpoint into the freshly created architecture.
# NOTE(review): the model is left in train mode here; inference code should
# call model.eval() before predicting — confirm dropout/batch-norm behavior.
load_model(model,safe_tensors)
def one_hot_decoding(labels):
    """Convert a multi-hot label vector into a list of class-name strings.

    Args:
        labels: iterable of 0/1 flags (ints or floats), one per class, in the
            fixed 17-class order below.

    Returns:
        list[str]: names of the classes whose flag equals 1, in class-index
        order. Empty list when no flag is set.
    """
    class_names = ['conventional_mine', 'habitation', 'primary', 'water', 'agriculture', 'bare_ground', 'cultivation', 'blow_down', 'road', 'cloudy', 'blooming', 'partly_cloudy', 'selective_logging', 'artisinal_mine', 'slash_burn', 'clear', 'haze']
    # Single pass: keep the name wherever the corresponding flag is set.
    # zip also guards against `labels` being longer than the class list,
    # which the original index-based lookup would have crashed on.
    return [name for name, flag in zip(class_names, labels) if flag == 1]
# User-facing copy for the Gradio interface.
title = "Satellite Image Classification for Landscape Analysis"
# Fixed typo in the original text: "agriculure" -> "agriculture".
description = """The bot was primarily trained to classify satellite images of the entire Amazon rainforest. You will need to upload satellite images and the bot will classify roads, forest, agriculture areas and much more!"""
def model_output(image):
    """Run multi-label inference on a single satellite image.

    Args:
        image: HxWx3 numpy array as delivered by the Gradio "image" input.

    Returns:
        str: space-separated predicted class names; empty string when no
        class sigmoid probability exceeds 0.5.
    """
    #image = cv2.imread(image)
    PIL_image = Image.fromarray(image.astype('uint8'), 'RGB')
    # Resize to the 224x224 input the Swin-S3 checkpoint was trained with.
    img_size = (224,224)
    test_tfms = T.Compose([
        T.Resize(img_size),
        T.ToTensor(),
    ])
    img = test_tfms(PIL_image)
    # Bug fix: timm's create_model leaves the network in train mode; switch
    # to eval so dropout/batch-norm behave deterministically at inference.
    model.eval()
    with torch.no_grad():
        logits = model(img.unsqueeze(0))
    # Multi-label head: independent sigmoid per class, thresholded at 0.5.
    predictions = logits.sigmoid() > 0.5
    predictions = predictions.float().numpy().flatten()
    pred_labels = one_hot_decoding(predictions)
    output_text = " ".join(pred_labels)
    return output_text
# Wire the classifier into a minimal Gradio UI: image upload in, predicted
# label text out, with a few bundled sample images as clickable examples.
app = gr.Interface(fn=model_output, inputs="image", outputs="text", title=title,
description=description, examples=[["sample_images/train_142.jpg"], ["sample_images/train_75.jpg"],["sample_images/train_32.jpg"], ["sample_images/train_59.jpg"], ["sample_images/train_67.jpg"], ["sample_images/train_92.jpg"]])
# share=True additionally exposes a temporary public tunnel URL.
app.launch(share=True)