Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -8,17 +8,19 @@ import gradio as gr
|
|
| 8 |
import plotly.graph_objects as go
|
| 9 |
from io import BytesIO
|
| 10 |
from PIL import Image
|
| 11 |
-
from torchvision import transforms,models
|
| 12 |
-
from sklearn.preprocessing import LabelEncoder,MinMaxScaler
|
| 13 |
from gradio import Interface, Image, Label, HTML
|
| 14 |
from huggingface_hub import snapshot_download
|
|
|
|
|
|
|
| 15 |
|
| 16 |
# Retrieve the token from the environment variables
|
| 17 |
token = os.environ.get("token")
|
| 18 |
|
| 19 |
# Download the repository snapshot
|
| 20 |
local_dir = snapshot_download(
|
| 21 |
-
repo_id="robocan/
|
| 22 |
repo_type="model",
|
| 23 |
local_dir="SVD",
|
| 24 |
token=token
|
|
@@ -27,7 +29,6 @@ local_dir = snapshot_download(
|
|
| 27 |
device = 'cpu'
|
| 28 |
le = LabelEncoder()
|
| 29 |
le = joblib.load("SVD/le.gz")
|
| 30 |
-
MMS = joblib.load("SVD/MMS.gz")
|
| 31 |
len_classes = len(le.classes_) + 1
|
| 32 |
|
| 33 |
class ModelPre(torch.nn.Module):
|
|
@@ -36,9 +37,9 @@ class ModelPre(torch.nn.Module):
|
|
| 36 |
self.embedding = torch.nn.Sequential(
|
| 37 |
*list(models.convnext_small(weights=models.ConvNeXt_Small_Weights.IMAGENET1K_V1).children())[:-1],
|
| 38 |
torch.nn.Flatten(),
|
| 39 |
-
torch.nn.Linear(in_features=768,out_features=512),
|
| 40 |
torch.nn.ReLU(),
|
| 41 |
-
torch.nn.Linear(in_features=512,out_features=len_classes),
|
| 42 |
)
|
| 43 |
# Freeze all layers
|
| 44 |
|
|
@@ -47,30 +48,8 @@ class ModelPre(torch.nn.Module):
|
|
| 47 |
|
| 48 |
# Load the pretrained model
|
| 49 |
model = ModelPre()
|
| 50 |
-
#for param in model.parameters():
|
| 51 |
-
# param.requires_grad = False
|
| 52 |
-
class GeoGcord(torch.nn.Module):
|
| 53 |
-
def __init__(self):
|
| 54 |
-
super().__init__()
|
| 55 |
-
self.embedding = torch.nn.Sequential(
|
| 56 |
-
*list(model.children())[0][:-1],
|
| 57 |
-
torch.nn.Linear(in_features=512,out_features=256),
|
| 58 |
-
torch.nn.ReLU(),
|
| 59 |
-
torch.nn.Linear(in_features=256,out_features=128),
|
| 60 |
-
torch.nn.ReLU(),
|
| 61 |
-
torch.nn.Linear(in_features=128,out_features=2),
|
| 62 |
-
)
|
| 63 |
-
# Freeze all layers
|
| 64 |
|
| 65 |
-
def forward(self, data):
|
| 66 |
-
return self.embedding(data)
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
# Load the pre-trained model
|
| 71 |
-
model = GeoGcord()
|
| 72 |
model_w = torch.load("SVD/GeoG.pth", map_location=torch.device(device))
|
| 73 |
-
|
| 74 |
model.load_state_dict(model_w['model'])
|
| 75 |
|
| 76 |
cmp = transforms.Compose([
|
|
@@ -79,27 +58,45 @@ cmp = transforms.Compose([
|
|
| 79 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
| 80 |
])
|
| 81 |
|
| 82 |
-
# Predict function for the new regression model
|
| 83 |
def predict(input_img):
|
| 84 |
with torch.inference_mode():
|
| 85 |
img = cmp(input_img).unsqueeze(0)
|
| 86 |
res = model(img.to(device))
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
# Function to generate Plotly map figure
|
| 92 |
-
def create_map_figure(
|
| 93 |
-
fig = go.Figure(
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
|
| 104 |
fig.update_layout(
|
| 105 |
mapbox_style="open-street-map",
|
|
@@ -107,8 +104,8 @@ def create_map_figure(lat, lon):
|
|
| 107 |
mapbox=dict(
|
| 108 |
bearing=0,
|
| 109 |
center=go.layout.mapbox.Center(
|
| 110 |
-
lat=
|
| 111 |
-
lon=
|
| 112 |
),
|
| 113 |
pitch=0,
|
| 114 |
zoom=3
|
|
@@ -119,8 +116,8 @@ def create_map_figure(lat, lon):
|
|
| 119 |
|
| 120 |
# Create label output function
|
| 121 |
def create_label_output(predictions):
|
| 122 |
-
|
| 123 |
-
fig = create_map_figure(
|
| 124 |
return fig
|
| 125 |
|
| 126 |
# Predict and plot function
|
|
@@ -138,4 +135,4 @@ with gr.Blocks() as gradio_app:
|
|
| 138 |
btn_predict.click(predict_and_plot, inputs=input_image, outputs=output_map)
|
| 139 |
examples = ["GB.PNG", "IT.PNG", "NL.PNG", "NZ.PNG"]
|
| 140 |
gr.Examples(examples=examples, inputs=input_image)
|
| 141 |
-
gradio_app.launch()
|
|
|
|
| 8 |
import plotly.graph_objects as go
|
| 9 |
from io import BytesIO
|
| 10 |
from PIL import Image
|
| 11 |
+
from torchvision import transforms, models
|
| 12 |
+
from sklearn.preprocessing import LabelEncoder, MinMaxScaler
|
| 13 |
from gradio import Interface, Image, Label, HTML
|
| 14 |
from huggingface_hub import snapshot_download
|
| 15 |
+
import s2sphere
|
| 16 |
+
import folium
|
| 17 |
|
| 18 |
# Retrieve the token from the environment variables
|
| 19 |
token = os.environ.get("token")
|
| 20 |
|
| 21 |
# Download the repository snapshot
|
| 22 |
local_dir = snapshot_download(
|
| 23 |
+
repo_id="robocan/GeoG-GCP",
|
| 24 |
repo_type="model",
|
| 25 |
local_dir="SVD",
|
| 26 |
token=token
|
|
|
|
| 29 |
device = 'cpu'
|
| 30 |
le = LabelEncoder()
|
| 31 |
le = joblib.load("SVD/le.gz")
|
|
|
|
| 32 |
len_classes = len(le.classes_) + 1
|
| 33 |
|
| 34 |
class ModelPre(torch.nn.Module):
|
|
|
|
| 37 |
self.embedding = torch.nn.Sequential(
|
| 38 |
*list(models.convnext_small(weights=models.ConvNeXt_Small_Weights.IMAGENET1K_V1).children())[:-1],
|
| 39 |
torch.nn.Flatten(),
|
| 40 |
+
torch.nn.Linear(in_features=768, out_features=512),
|
| 41 |
torch.nn.ReLU(),
|
| 42 |
+
torch.nn.Linear(in_features=512, out_features=len_classes),
|
| 43 |
)
|
| 44 |
# Freeze all layers
|
| 45 |
|
|
|
|
| 48 |
|
| 49 |
# Load the pretrained model
|
| 50 |
model = ModelPre()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
model_w = torch.load("SVD/GeoG.pth", map_location=torch.device(device))
|
|
|
|
| 53 |
model.load_state_dict(model_w['model'])
|
| 54 |
|
| 55 |
cmp = transforms.Compose([
|
|
|
|
| 58 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
| 59 |
])
|
| 60 |
|
|
|
|
| 61 |
def predict(input_img):
|
| 62 |
with torch.inference_mode():
|
| 63 |
img = cmp(input_img).unsqueeze(0)
|
| 64 |
res = model(img.to(device))
|
| 65 |
+
probabilities = torch.softmax(res, dim=1).cpu().numpy().flatten()
|
| 66 |
+
top_10_indices = np.argsort(probabilities)[-10:][::-1]
|
| 67 |
+
top_10_probabilities = probabilities[top_10_indices]
|
| 68 |
+
top_10_predictions = le.inverse_transform(top_10_indices)
|
| 69 |
+
|
| 70 |
+
results = {top_10_predictions[i]: float(top_10_probabilities[i]) for i in range(10)}
|
| 71 |
+
return results, top_10_predictions
|
| 72 |
+
|
| 73 |
+
# Function to get S2 cell polygon
|
| 74 |
+
def get_s2_cell_polygon(cell_id):
|
| 75 |
+
cell = s2sphere.Cell(s2sphere.CellId(cell_id))
|
| 76 |
+
vertices = []
|
| 77 |
+
for i in range(4):
|
| 78 |
+
vertex = s2sphere.LatLng.from_point(cell.get_vertex(i))
|
| 79 |
+
vertices.append((vertex.lat().degrees, vertex.lng().degrees))
|
| 80 |
+
vertices.append(vertices[0]) # Close the polygon
|
| 81 |
+
return vertices
|
| 82 |
|
| 83 |
# Function to generate Plotly map figure
|
| 84 |
+
def create_map_figure(predictions, cell_ids):
|
| 85 |
+
fig = go.Figure()
|
| 86 |
+
|
| 87 |
+
for cell_id in cell_ids:
|
| 88 |
+
cell_id = int(cell_id)
|
| 89 |
+
polygon = get_s2_cell_polygon(cell_id)
|
| 90 |
+
lats, lons = zip(*polygon)
|
| 91 |
+
fig.add_trace(go.Scattermapbox(
|
| 92 |
+
lat=lats,
|
| 93 |
+
lon=lons,
|
| 94 |
+
mode='lines',
|
| 95 |
+
fill='toself',
|
| 96 |
+
fillcolor='rgba(0, 0, 255, 0.2)',
|
| 97 |
+
line=dict(color='blue'),
|
| 98 |
+
name=f'Cell ID: {cell_id}'
|
| 99 |
+
))
|
| 100 |
|
| 101 |
fig.update_layout(
|
| 102 |
mapbox_style="open-street-map",
|
|
|
|
| 104 |
mapbox=dict(
|
| 105 |
bearing=0,
|
| 106 |
center=go.layout.mapbox.Center(
|
| 107 |
+
lat=np.mean(lats),
|
| 108 |
+
lon=np.mean(lons)
|
| 109 |
),
|
| 110 |
pitch=0,
|
| 111 |
zoom=3
|
|
|
|
| 116 |
|
| 117 |
# Create label output function
|
| 118 |
def create_label_output(predictions):
|
| 119 |
+
results, cell_ids = predictions
|
| 120 |
+
fig = create_map_figure(results, cell_ids)
|
| 121 |
return fig
|
| 122 |
|
| 123 |
# Predict and plot function
|
|
|
|
| 135 |
btn_predict.click(predict_and_plot, inputs=input_image, outputs=output_map)
|
| 136 |
examples = ["GB.PNG", "IT.PNG", "NL.PNG", "NZ.PNG"]
|
| 137 |
gr.Examples(examples=examples, inputs=input_image)
|
| 138 |
+
gradio_app.launch()
|