# subhuatharva's picture
# Update app.py
# 23ac470 verified
# raw
# history blame
# 2.09 kB
# -*- coding: utf-8 -*-
"""satellite_app.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1HCITtw0z2BJO0z9GmsspLM1NzrLdft27
"""
#!pip install gradio --quiet
#!pip install -Uq transformers datasets timm accelerate evaluate
import subprocess
# subprocess.run('pip3 install datasets timm cv2 huggingface_hub torch pillow matplotlib' ,shell=True)
import gradio as gr
from timm import create_model
from huggingface_hub import hf_hub_download
from datasets import load_dataset
import torch
import torchvision.transforms as T
import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
# Local path to the fine-tuned weights shipped alongside this app. Presumably
# the same file can be fetched from the Hub instead with:
#   hf_hub_download(repo_id="subhuatharva/swim-224-base-satellite-image-classification",
#                   filename="model.safetensors")
safe_tensors = "model.safetensors"
model_name = 'swin_s3_base_224'
# Initialize the architecture with a 17-way head (one logit per tag in the
# class list used by one_hot_decoding) and load the fine-tuned weights.
# timm's ``checkpoint_path`` hook reads ``.safetensors`` files and loads the
# state dict strictly (missing/unexpected keys raise). The earlier bare
# ``load_model(...)`` call was never imported and failed with NameError.
model = create_model(
    model_name,
    num_classes=17,
    checkpoint_path=safe_tensors,
)
# Inference only: disable training-mode behaviour (e.g. dropout) so the same
# image always yields the same prediction.
model.eval()
def one_hot_decoding(labels):
    """Map a multi-hot prediction vector to the names of its active classes.

    Args:
        labels: A sequence whose i-th entry equals 1 when class i is
            predicted (for example, the flattened 0/1 float array that
            ``model_output`` builds from the thresholded logits).

    Returns:
        A list of class names, in index order, for every entry equal to 1.
        A 1 at a position past the last known class raises ``KeyError``.
    """
    # List position == model output index for that tag.
    class_names = ['conventional_mine', 'habitation', 'primary', 'water', 'agriculture', 'bare_ground', 'cultivation', 'blow_down', 'road', 'cloudy', 'blooming', 'partly_cloudy', 'selective_logging', 'artisinal_mine', 'slash_burn', 'clear', 'haze']
    index_to_name = dict(enumerate(class_names))
    return [index_to_name[position] for position, flag in enumerate(labels) if flag == 1]
def model_output(image):
    """Classify an uploaded satellite image and return its predicted tags.

    Args:
        image: The value delivered by the Gradio ``"image"`` input. With
            Gradio's default ``gr.Image`` settings this is an already-decoded
            NumPy array, or ``None`` when the form is submitted without an
            upload.

    Returns:
        A space-separated string of the class names whose sigmoid
        probability exceeds 0.5 (an empty string if none do).

    Raises:
        gr.Error: If no image was provided; Gradio shows the message in the UI.
    """
    # Gradio passes None for an empty submission; report it instead of
    # crashing on the array conversion below.
    if image is None:
        raise gr.Error("Please upload an image.")
    # The input is already a decoded pixel array, so no file read is needed
    # (the previous ``cv2.imread(name)`` referenced an undefined name).
    # ``.convert("RGB")`` also normalises grayscale / RGBA arrays to 3 channels.
    pil_image = Image.fromarray(np.asarray(image).astype('uint8')).convert("RGB")
    img_size = (224, 224)
    # NOTE(review): only resize + [0, 1] scaling, no mean/std normalisation;
    # confirm this matches the preprocessing used during fine-tuning.
    test_tfms = T.Compose([
        T.Resize(img_size),
        T.ToTensor(),
    ])
    img = test_tfms(pil_image)
    with torch.no_grad():
        logits = model(img.unsqueeze(0))  # add a batch dimension of 1
    # Multi-label head: each tag is independently "on" above 0.5 probability.
    predictions = logits.sigmoid() > 0.5
    predictions = predictions.float().numpy().flatten()
    pred_labels = one_hot_decoding(predictions)
    return " ".join(pred_labels)
# Minimal Gradio UI: one image upload feeding model_output, one text box
# showing the space-separated predicted tags.
app = gr.Interface(
    fn=model_output,
    inputs="image",
    outputs="text",
)
# debug=True blocks the main thread and surfaces errors in the console /
# notebook output while the app runs.
app.launch(debug=True)