Spaces: Running
| import os | |
| import torch | |
| from torch.utils.data import Dataset, DataLoader | |
| import pandas as pd | |
| import numpy as np | |
| import joblib | |
| import gradio as gr | |
| import plotly.graph_objects as go | |
| from io import BytesIO | |
| from PIL import Image | |
| from torchvision import transforms, models | |
| from sklearn.preprocessing import LabelEncoder, MinMaxScaler | |
| from gradio import Interface, Image, Label, HTML | |
| from huggingface_hub import snapshot_download | |
| import torch_xla.utils.serialization as xser | |
| import s2sphere | |
| import folium | |
# Download (or reuse an existing copy of) the model repository into ./SVD.
# snapshot_download returns the folder path, which is used below instead of
# re-hard-coding "SVD".
local_dir = snapshot_download(
    repo_id="robocan/GeoG_23k",
    repo_type="model",
    local_dir="SVD",
)
device = 'cpu'
# Fitted label encoder that decodes class indices back into the labels
# (S2 cell ids) used by the map code.
# NOTE(review): joblib.load unpickles the file, so it can execute code; only
# load it from a trusted repository.
le = joblib.load(os.path.join(local_dir, "le.gz"))
# The classifier head has one more output than the encoder has labels.
# Presumably this is an extra class reserved at training time (TODO confirm).
# It must match the checkpoint's final Linear layer size.
len_classes = len(le.classes_) + 1
class ModelPre(torch.nn.Module):
    """ConvNeXt-Small image backbone followed by a three-layer MLP classifier.

    All layers live in a single ``torch.nn.Sequential`` named ``embedding``.
    This keeps state-dict keys of the form ``embedding.<index>.*``.
    """

    def __init__(self):
        super().__init__()
        backbone = models.convnext_small(
            weights=models.ConvNeXt_Small_Weights.IMAGENET1K_V1
        )
        # Keep every top-level child except the last (torchvision's classifier).
        trunk = list(backbone.children())[:-1]
        head = [
            torch.nn.Flatten(),
            torch.nn.Linear(in_features=768, out_features=1024),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=1024, out_features=1024),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=1024, out_features=len_classes),
        ]
        # The trunk comes first, then the head, so layer indices match the
        # saved checkpoint.
        self.embedding = torch.nn.Sequential(*trunk, *head)

    def forward(self, data):
        """Return raw class logits for a batch of images."""
        logits = self.embedding(data)
        return logits
# Load the pretrained model
model = ModelPre()
# The checkpoint is read with torch_xla's serializer (xser.load), so it must
# be loaded with the same API.
model_w = xser.load(os.path.join(local_dir, "GeoG.pth"))
model.load_state_dict(model_w['model'])
model.to(device)
# Switch to evaluation mode. torchvision's ConvNeXt contains stochastic-depth
# layers that randomly drop residual branches while the module is in training
# mode, which is the default after construction. torch.inference_mode() does
# NOT change this, so without eval() predictions would be non-deterministic.
model.eval()

# Preprocessing: PIL image -> float tensor, resized to 224x224, then
# normalized with the ImageNet mean/std.
cmp = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=(224, 224), antialias=True),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def predict(input_img):
    """Classify an image and return its ten most likely labels.

    Args:
        input_img: Image to classify. The Gradio input supplies a PIL image
            (``type="pil"``); any other type is passed to ``cmp`` unchanged.

    Returns:
        A tuple ``(results, top_10_predictions)``:

        - ``results`` maps each decoded label to its softmax probability,
          as a Python float.
        - ``top_10_predictions`` is the array of decoded labels, most
          probable first.
    """
    # Normalize() in ``cmp`` expects exactly three channels. Coerce
    # RGBA/palette/greyscale PIL images to RGB in case a caller passes one.
    if getattr(input_img, "mode", "RGB") != "RGB":
        input_img = input_img.convert("RGB")
    with torch.inference_mode():
        img = cmp(input_img).unsqueeze(0)
        res = model(img.to(device))
        probabilities = torch.softmax(res, dim=1).cpu().numpy().flatten()
    # The head has len_classes = len(le.classes_) + 1 outputs, but only the
    # first len(le.classes_) indices can be decoded by the label encoder.
    # inverse_transform raises on the extra index, so rank decodable indices only.
    n_known = len(le.classes_)
    k = min(10, n_known)
    top_10_indices = np.argsort(probabilities[:n_known])[-k:][::-1]
    top_10_probabilities = probabilities[top_10_indices]
    top_10_predictions = le.inverse_transform(top_10_indices)
    results = {
        label: float(prob)
        for label, prob in zip(top_10_predictions, top_10_probabilities)
    }
    return results, top_10_predictions
def get_s2_cell_polygon(cell_id):
    """Return the outline of an S2 cell as a closed ring of (lat, lng) degrees.

    Args:
        cell_id: Integer S2 cell id.

    Returns:
        A list of five ``(lat, lng)`` tuples: the cell's four vertices in
        ``get_vertex`` order, followed by the first vertex again to close
        the polygon.
    """
    cell = s2sphere.Cell(s2sphere.CellId(cell_id))
    corners = [s2sphere.LatLng.from_point(cell.get_vertex(k)) for k in range(4)]
    ring = [(corner.lat().degrees, corner.lng().degrees) for corner in corners]
    ring.append(ring[0])  # Repeat the first vertex so the outline is closed
    return ring
def create_map_figure(predictions, cell_ids, selected_index=None):
    """Draw each predicted S2 cell as a filled polygon on an OpenStreetMap figure.

    Args:
        predictions: Not used for drawing; accepted so existing callers keep
            working.
        cell_ids: Ranked S2 cell ids, best first. Each entry is converted with
            ``int(float(cell_id))``, so numeric strings and floats are accepted.
        selected_index: Zero-based rank of the cell to zoom into, or None.

    Returns:
        A ``plotly.graph_objects.Figure``.

        - Ranks 0-2 are filled green and all lower ranks yellow.
        - If ``selected_index`` matches a rank, the map is centred on that
          cell at zoom 10.
        - Otherwise it is centred on the top-ranked cell at zoom 1, or on
          (0, 0) when ``cell_ids`` is empty.
    """
    fig = go.Figure()
    zoom_level = 1  # Default zoom level
    default_center = (0.0, 0.0)
    selected_center = None
    for rank, cell_id in enumerate(cell_ids):
        polygon = get_s2_cell_polygon(int(float(cell_id)))
        lats, lons = zip(*polygon)
        # Colour is derived from the rank instead of indexed from a fixed
        # 10-entry list, so more than ten cells no longer raise IndexError.
        color = 'rgba(0, 255, 0, 0.2)' if rank < 3 else 'rgba(255, 255, 0, 0.2)'
        # Draw S2 cell polygon
        fig.add_trace(go.Scattermapbox(
            lat=lats,
            lon=lons,
            mode='lines',
            fill='toself',
            fillcolor=color,
            line=dict(color='blue'),
            name=f'Prediction {rank + 1}',
        ))
        # The last vertex repeats the first to close the ring. Leave it out so
        # the centre is not pulled toward that corner.
        cell_center = (float(np.mean(lats[:-1])), float(np.mean(lons[:-1])))
        if rank == 0:
            default_center = cell_center
        if selected_index is not None and rank == selected_index:
            zoom_level = 10
            selected_center = cell_center
    # Compare against None explicitly: a 0.0 latitude/longitude is a valid centre.
    center_lat, center_lon = (
        selected_center if selected_center is not None else default_center
    )
    # Update map layout
    fig.update_layout(
        mapbox_style="open-street-map",
        hovermode='closest',
        mapbox=dict(
            bearing=0,
            center=go.layout.mapbox.Center(lat=center_lat, lon=center_lon),
            pitch=0,
            zoom=zoom_level,  # Zoom in based on selection
        ),
    )
    return fig
def create_label_output(predictions):
    """Render the map for a ``predict`` result, with no zoom selection.

    Args:
        predictions: A two-element ``(results, cell_ids)`` pair as returned
            by ``predict``.

    Returns:
        The figure built by ``create_map_figure``.
    """
    label_scores, ranked_ids = predictions
    return create_map_figure(label_scores, ranked_ids)
def predict_and_plot(input_img, selected_prediction):
    """Run ``predict`` on an image and return a map zoomed to the chosen rank.

    Args:
        input_img: Image from the Gradio input, or None if nothing was uploaded.
        selected_prediction: Dropdown value such as ``"Prediction 3"``, or a
            falsy value for the default (un-zoomed) view.

    Returns:
        The figure built by ``create_map_figure``.

    Raises:
        gr.Error: If no image was provided. Gradio shows this to the user.
    """
    if input_img is None:
        # Clicking "Predict" before uploading passes None. Show a readable
        # message instead of a traceback from the image transforms.
        raise gr.Error("Please upload an image first.")
    predictions = predict(input_img)
    # Convert "Prediction X" (1-based) into a zero-based index. Fall back to
    # the default view if the value is empty or not in that form.
    selected_index = None
    if selected_prediction:
        try:
            selected_index = int(selected_prediction.split()[-1]) - 1
        except (ValueError, IndexError):
            selected_index = None
    return create_map_figure(predictions, predictions[1], selected_index=selected_index)
# Gradio app definition: image upload, zoom-rank selector, map output and a
# Predict button. Components render in the order they are created here.
# NOTE(review): the original indentation was lost in extraction; the nesting of
# the click binding and examples inside the Column is a best guess.
with gr.Blocks() as gradio_app:
    with gr.Column():
        input_image = gr.Image(label="Upload an Image", type="pil")
        # Choices are "Prediction 1".."Prediction 10". predict_and_plot parses
        # the trailing number as a 1-based rank to zoom into.
        selected_prediction = gr.Dropdown(
            choices=[f"Prediction {i+1}" for i in range(10)],
            label="Select Prediction to Zoom",
            value="Prediction 1"  # Set default to "Prediction 1"
        )
        output_map = gr.Plot(label="Predicted Location on Map")
        btn_predict = gr.Button("Predict")
        # Update click function to include selected prediction.
        # Only the button is wired to predict_and_plot here, so changing the
        # dropdown alone does not redraw the map.
        btn_predict.click(predict_and_plot, inputs=[input_image, selected_prediction], outputs=output_map)
        # Example images are given as relative paths. They are presumably
        # shipped next to this script; TODO confirm.
        examples = ["GB.PNG", "IT.PNG", "NL.PNG", "NZ.PNG"]
        gr.Examples(examples=examples, inputs=input_image)
# Start the web server.
gradio_app.launch()