# Spaces: Runtime error
import json
import tempfile
from pathlib import Path

import folium
import gradio as gr
import reverse_geocoder as rg
import spaces
import torch
from geoclip import GeoCLIP
from geopy.exc import GeocoderServiceError, GeocoderTimedOut
from geopy.geocoders import Nominatim
from PIL import Image
from torchvision import transforms
from transformers import CLIPProcessor, CLIPModel

from models.huggingface import Geolocalizer
# GeoCLIP is built once at import time and moved to the GPU only when one is available.
geoclip_model = GeoCLIP()
if torch.cuda.is_available():
    geoclip_model = geoclip_model.to("cuda")

# Nominatim client used by both the forward (geocode) and reverse lookups in this file.
geolocator = Nominatim(user_agent="predictGeolocforImage")

# StreetCLIP processor and model, both loaded from the same checkpoint name.
streetclip_processor = CLIPProcessor.from_pretrained("geolocal/StreetCLIP")
streetclip_model = CLIPModel.from_pretrained("geolocal/StreetCLIP")
# Candidate country names passed as the text prompts to StreetCLIP.
# The order matters: a prediction index is mapped back to a name via this list.
labels = [
    'Albania', 'Andorra', 'Argentina', 'Australia', 'Austria', 'Bangladesh',
    'Belgium', 'Bermuda', 'Bhutan', 'Bolivia', 'Botswana', 'Brazil',
    'Bulgaria', 'Cambodia', 'Canada', 'Chile', 'China', 'Colombia',
    'Croatia', 'Czech Republic', 'Denmark', 'Dominican Republic', 'Ecuador',
    'Estonia', 'Finland', 'France', 'Germany', 'Ghana', 'Greece',
    'Greenland', 'Guam', 'Guatemala', 'Hungary', 'Iceland', 'India',
    'Indonesia', 'Iran', 'Ireland', 'Israel', 'Italy', 'Japan', 'Jordan',
    'Kenya', 'Kyrgyzstan', 'Laos', 'Latvia', 'Lesotho', 'Lithuania',
    'Luxembourg', 'Macedonia', 'Madagascar', 'Malaysia', 'Malta', 'Mexico',
    'Monaco', 'Mongolia', 'Montenegro', 'Netherlands', 'New Zealand',
    'Nigeria', 'Norway', 'Pakistan', 'Palestine', 'Peru', 'Philippines',
    'Poland', 'Portugal', 'Puerto Rico', 'Romania', 'Russia', 'Rwanda',
    'Senegal', 'Serbia', 'Singapore', 'Slovakia', 'Slovenia',
    'South Africa', 'South Korea', 'Spain', 'Sri Lanka', 'Swaziland',
    'Sweden', 'Switzerland', 'Taiwan', 'Thailand', 'Tunisia', 'Turkey',
    'Uganda', 'Ukraine', 'United Arab Emirates', 'United Kingdom',
    'United States', 'Uruguay',
]
# Checkpoint identifier handed to Geolocalizer.from_pretrained below.
GEOLOC_MODEL_NAME = "osv5m/baseline"
# Size passed to transforms.Resize when preparing images for geoloc_model.
IMAGE_SIZE = (224, 224)

# Load the OSV-5M baseline once at import time and switch it to evaluation mode.
geoloc_model = Geolocalizer.from_pretrained(GEOLOC_MODEL_NAME)
geoloc_model.eval()
def transform_image(image):
    """Turn a PIL image into a normalized tensor with a leading batch axis.

    The image is converted to RGB if needed, resized to ``IMAGE_SIZE``,
    converted to a tensor, normalized per channel with the mean/std given
    below, and finally unsqueezed at dimension 0.

    Args:
        image: A PIL image in any mode.

    Returns:
        The tensor produced by the pipeline with one extra leading dimension.
    """
    # Normalize below is configured with three channel statistics, so RGBA,
    # grayscale or palette images must be converted first or it raises.
    rgb_image = image if image.mode == "RGB" else image.convert("RGB")

    pipeline = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return pipeline(rgb_image).unsqueeze(0)
def create_map(lat, lon):
    """Return the HTML of a folium map centered on (lat, lon) with one marker.

    Args:
        lat: Latitude of the map center and marker.
        lon: Longitude of the map center and marker.

    Returns:
        The map's HTML representation as produced by ``_repr_html_()``.
    """
    world_map = folium.Map(location=[lat, lon], zoom_start=4)
    marker = folium.Marker(location=[lat, lon])
    marker.add_to(world_map)
    return world_map._repr_html_()
def get_country_coordinates(country_name):
    """Look up the coordinates of a country by name via the shared geolocator.

    This is a best-effort lookup: any geocoder service failure yields ``None``
    instead of propagating, so callers can fall back to a placeholder.

    Args:
        country_name: Name to send to the geocoder.

    Returns:
        A ``(latitude, longitude)`` tuple, or ``None`` when the geocoder has
        no result or the request fails.
    """
    try:
        location = geolocator.geocode(country_name, timeout=10)
    except GeocoderServiceError:
        # GeocoderTimedOut is a subclass of GeocoderServiceError, so timeouts
        # are still covered; this also covers an unavailable or rate-limited
        # service, which previously crashed the caller.
        return None
    if location is None:
        return None
    return location.latitude, location.longitude
def _reverse_country(lat, lon):
    """Return the country name reported for (lat, lon), or '' if unavailable."""
    try:
        # Same 10 s timeout as the forward lookup in get_country_coordinates.
        location = geolocator.reverse((lat, lon), exactly_one=True, timeout=10)
    except GeocoderServiceError:
        return ""
    # reverse() yields None when nothing is found for the point (e.g. open sea).
    if location is None:
        return ""
    return location.raw.get("address", {}).get("country", "")


def predict_geoclip(image):
    """Predict a GPS position for an image with GeoCLIP and map it.

    The image is written to a temporary JPEG because ``geoclip_model.predict``
    is called with a file path. Only the first returned candidate is reported.

    Args:
        image: A PIL image.

    Returns:
        A tuple ``(text, map_html)``: a one-line description with latitude,
        longitude and country (empty when reverse geocoding gives nothing),
        and the HTML of a map centered on the prediction.
    """
    # JPEG cannot store alpha or palette modes; saving those would raise.
    if image.mode != "RGB":
        image = image.convert("RGB")

    with tempfile.TemporaryDirectory() as tmp_dir:
        image_path = Path(tmp_dir) / "tmp.jpg"
        image.save(str(image_path))
        top_pred_gps, _ = geoclip_model.predict(str(image_path), top_k=50)

    # Plain floats keep formatting, geocoding and folium independent of the
    # container type the model returns.
    lat, lon = (float(coord) for coord in top_pred_gps[0])
    country = _reverse_country(lat, lon)

    prediction = f"Latitude: {lat:.6f}, Longitude: {lon:.6f} - Country: {country}"
    return prediction, create_map(lat, lon)
def classify_streetclip(image):
    """Pick the most likely country for an image with StreetCLIP and map it.

    Every entry of ``labels`` is scored against the image; the entry with the
    highest softmax probability wins (the earliest one on a tie).

    Args:
        image: The image passed to the StreetCLIP processor.

    Returns:
        A tuple ``(text, map_html)`` where ``map_html`` is the string
        "Map not available" when no coordinates were found for the country.
    """
    model_inputs = streetclip_processor(text=labels, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        model_outputs = streetclip_model(**model_inputs)

    # Probabilities for the single input image, one per label, as Python floats.
    probabilities = model_outputs.logits_per_image.softmax(dim=1)[0].tolist()
    best_index = max(range(len(labels)), key=probabilities.__getitem__)
    top_label = labels[best_index]

    coords = get_country_coordinates(top_label)
    if coords:
        map_html = create_map(*coords)
    else:
        map_html = "Map not available"
    return f"Country: {top_label}", map_html
def infer(image):
    """Predict a GPS position for an image with the OSV-5M baseline and map it.

    The model output is treated as radians and converted with
    ``torch.rad2deg``; the nearest place is then looked up offline with
    ``reverse_geocoder``.

    Args:
        image: A PIL image.

    Returns:
        A tuple ``(text, map_html)``. On any failure the text carries the
        error message and ``map_html`` is ``None``.
    """
    try:
        img_tensor = transform_image(image)
        # Inference only: skip autograd bookkeeping, as classify_streetclip does.
        with torch.no_grad():
            gps_radians = geoloc_model(img_tensor)
        gps_degrees = torch.rad2deg(gps_radians).squeeze(0).cpu().tolist()
        lat, lon = gps_degrees[0], gps_degrees[1]

        place = rg.search((lat, lon))[0]
        map_html = create_map(lat, lon)
        # The "Country" field shows the record's 'admin1' and 'cc' entries.
        return f"Latitude: {lat:.6f}, Longitude: {lon:.6f} - Country: {place['admin1']} - {place['cc']}", map_html
    except Exception as e:
        # UI boundary: report the problem in the textbox instead of raising.
        return f"Failed to predict the location: {e}", None
# GeoCLIP tab: takes a PIL image, shows the prediction text and an HTML map.
geoclip_interface = gr.Interface(
    fn=predict_geoclip,
    inputs=gr.Image(type="pil", label="Upload Image", elem_id="geoclip_image_input"),
    outputs=[
        gr.Textbox(label="Prediction", elem_id="geoclip_output"),
        gr.HTML(label="Map", elem_id="geoclip_map_output"),
    ],
    title="GeoCLIP",
)
# StreetCLIP tab: takes a PIL image, shows the predicted country and an HTML map.
streetclip_interface = gr.Interface(
    fn=classify_streetclip,
    inputs=gr.Image(type="pil", label="Upload Image", elem_id="streetclip_image_input"),
    outputs=[
        gr.Textbox(label="Prediction", elem_id="streetclip_output"),
        gr.HTML(label="Map", elem_id="streetclip_map_output"),
    ],
    title="StreetCLIP",
)
# OSV-5M tab: takes a PIL image, shows the prediction text and an HTML map.
osv5m_interface = gr.Interface(
    fn=infer,
    inputs=gr.Image(label="Upload Image", type="pil", elem_id="osv5m_image_input"),
    outputs=[
        gr.Textbox(label="Prediction", elem_id="result_text"),
        gr.HTML(label="Map", elem_id="map_output"),
    ],
    title="OSV-5M Baseline",
)
# Combine the three single-model interfaces into one tabbed app and start it.
demo = gr.TabbedInterface(
    interface_list=[geoclip_interface, streetclip_interface, osv5m_interface],
    tab_names=["GeoCLIP", "StreetCLIP", "OSV-5M Baseline"],
)
demo.launch()