Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -4,10 +4,12 @@ from torch.utils.data import Dataset, DataLoader
|
|
| 4 |
import pandas as pd
|
| 5 |
import numpy as np
|
| 6 |
import joblib
|
|
|
|
|
|
|
| 7 |
from PIL import Image
|
| 8 |
from torchvision import transforms,models
|
| 9 |
from sklearn.preprocessing import LabelEncoder,MinMaxScaler
|
| 10 |
-
from gradio import Interface, Image, Label
|
| 11 |
from huggingface_hub import snapshot_download
|
| 12 |
|
| 13 |
# Retrieve the token from the environment variables
|
|
@@ -85,9 +87,19 @@ def predict(input_img):
|
|
| 85 |
prediction = MMS.inverse_transform(res.cpu().numpy()).flatten()
|
| 86 |
return prediction
|
| 87 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
# Create label output function
|
| 89 |
def create_label_output(predictions):
|
| 90 |
-
|
|
|
|
|
|
|
| 91 |
|
| 92 |
# Predict and plot function
|
| 93 |
def predict_and_plot(input_img):
|
|
@@ -99,7 +111,7 @@ gradio_app = Interface(
|
|
| 99 |
fn=predict_and_plot,
|
| 100 |
inputs=Image(label="Upload an Image", type="pil"),
|
| 101 |
examples=["GB.PNG", "IT.PNG", "NL.PNG", "NZ.PNG"],
|
| 102 |
-
outputs=
|
| 103 |
title="Predict the Location of this Image"
|
| 104 |
)
|
| 105 |
|
|
|
|
| 4 |
import pandas as pd
|
| 5 |
import numpy as np
|
| 6 |
import joblib
|
| 7 |
+
from io import BytesIO
|
| 8 |
+
import folium
|
| 9 |
from PIL import Image
|
| 10 |
from torchvision import transforms,models
|
| 11 |
from sklearn.preprocessing import LabelEncoder,MinMaxScaler
|
| 12 |
+
from gradio import Interface, Image, Label, HTML
|
| 13 |
from huggingface_hub import snapshot_download
|
| 14 |
|
| 15 |
# Retrieve the token from the environment variables
|
|
|
|
| 87 |
prediction = MMS.inverse_transform(res.cpu().numpy()).flatten()
|
| 88 |
return prediction
|
| 89 |
|
| 90 |
+
# Function to generate HTML for map
|
| 91 |
+
def create_map_html(lat, lon):
|
| 92 |
+
m = folium.Map(location=[lat, lon], zoom_start=12)
|
| 93 |
+
folium.Marker([lat, lon]).add_to(m)
|
| 94 |
+
data = BytesIO()
|
| 95 |
+
m.save(data, close_file=False)
|
| 96 |
+
return data.getvalue().decode()
|
| 97 |
+
|
| 98 |
# Create label output function
|
| 99 |
def create_label_output(predictions):
|
| 100 |
+
lat, lon = predictions
|
| 101 |
+
map_html = create_map_html(lat, lon)
|
| 102 |
+
return f"<div><h3>Predicted coordinates: ({lat:.6f}, {lon:.6f})</h3>{map_html}</div>"
|
| 103 |
|
| 104 |
# Predict and plot function
|
| 105 |
def predict_and_plot(input_img):
|
|
|
|
| 111 |
fn=predict_and_plot,
|
| 112 |
inputs=Image(label="Upload an Image", type="pil"),
|
| 113 |
examples=["GB.PNG", "IT.PNG", "NL.PNG", "NZ.PNG"],
|
| 114 |
+
outputs=HTML(),
|
| 115 |
title="Predict the Location of this Image"
|
| 116 |
)
|
| 117 |
|