| |
| """satellite_app.ipynb |
| |
| Automatically generated by Colab. |
| |
| Original file is located at |
| https://colab.research.google.com/drive/1HCITtw0z2BJO0z9GmsspLM1NzrLdft27 |
| """ |
|
|
| |
| |
|
|
| import subprocess |
| |
|
|
| import gradio as gr |
| from huggingface_hub import hf_hub_download |
| from datasets import load_dataset |
| import torch |
| import torchvision.transforms as T |
| import cv2 |
| import matplotlib.pyplot as plt |
| import numpy as np |
| from PIL import Image |
| from timm import create_model |
|
|
|
|
|
|
# `load_model` was never imported, so the original start-up line raised
# NameError. timm (already a dependency of this file) can read
# `.safetensors` checkpoints itself.
# NOTE(review): this requires timm >= 0.9 with the `safetensors` package
# installed — confirm against the deployment requirements.
from timm.models import load_checkpoint  # noqa: E402

# Fine-tuned weights for the classifier, stored in safetensors format.
safe_tensors = "model.safetensors"

# timm architecture name; the checkpoint must have been trained on this backbone.
model_name = 'swin_s3_base_224'

# One output per class; 17 matches the class-name table in one_hot_decoding.
model = create_model(
    model_name,
    num_classes=17,
)

# Load the weights in place. timm's load_checkpoint matches keys strictly by
# default, so a checkpoint/architecture mismatch fails loudly here.
load_checkpoint(model, safe_tensors)

# The model is used only for inference: switch off training-mode behaviour
# (dropout / stochastic-depth style layers) so predictions are deterministic.
model.eval()
|
|
def one_hot_decoding(labels):
    """Turn a multi-hot label vector into the matching class names.

    Args:
        labels: sequence whose i-th entry equals 1 when class i is present
            (ints or floats such as 0.0 / 1.0 both work).

    Returns:
        list of class-name strings for every entry equal to 1, in index order.

    Raises:
        KeyError: if an entry beyond the 17 known classes equals 1.
    """
    class_names = (
        'conventional_mine', 'habitation', 'primary', 'water', 'agriculture',
        'bare_ground', 'cultivation', 'blow_down', 'road', 'cloudy',
        'blooming', 'partly_cloudy', 'selective_logging', 'artisinal_mine',
        'slash_burn', 'clear', 'haze',
    )
    name_by_index = dict(enumerate(class_names))

    return [
        name_by_index[position]
        for position, flag in enumerate(labels)
        if flag == 1
    ]
|
|
def model_output(image):
    """Classify an uploaded satellite image and return its predicted labels.

    This is the callback wired into the Gradio interface.

    Args:
        image: the uploaded picture as an array of pixel values (Gradio's
            ``"image"`` input delivers a NumPy array by default), or ``None``
            when the user submits without an image.

    Returns:
        str: predicted label names joined by single spaces. Empty when no
        image was supplied or when no class clears the 0.5 threshold.
    """
    # The previous code re-read the image with ``cv2.imread(name)``, but
    # ``name`` is undefined (NameError on every request). imread would also
    # have produced BGR channel order. Gradio already hands over the decoded
    # pixels, so they are used directly.
    if image is None:
        return ""

    # convert("RGB") also turns grayscale or RGBA uploads into the 3-channel
    # RGB image the original code assumed.
    pil_image = Image.fromarray(np.asarray(image).astype('uint8')).convert("RGB")

    # NOTE(review): no mean/std normalisation is applied — confirm this
    # matches the preprocessing used when the checkpoint was trained.
    img_size = (224, 224)
    test_tfms = T.Compose([
        T.Resize(img_size),
        T.ToTensor(),
    ])
    img = test_tfms(pil_image)

    with torch.no_grad():
        logits = model(img.unsqueeze(0))  # add a batch dimension of 1

    # Multi-label decision: each class is independently "on" when its
    # sigmoid probability exceeds 0.5.
    predictions = (logits.sigmoid() > 0.5).float().numpy().flatten()
    pred_labels = one_hot_decoding(predictions)
    return " ".join(pred_labels)
|
|
# Serve model_output as a web UI: one image in, the space-separated predicted
# labels out. debug=True keeps launch() blocking so errors show up in the
# console / notebook output.
app = gr.Interface(model_output, "image", "text")
app.launch(debug=True)
|
|
|
|