---
tags:
- model_hub_mixin
- pytorch_model_hub_mixin
---
# Load the fine-tuned GPS-regression model from the Hugging Face Hub and
# prepare it for inference.
#
# NOTE(review): several imports below (nn, models, transforms, Dataset,
# AutoImageProcessor, AutoModelForImageClassification, PyTorchModelHubMixin)
# are not used in this snippet; they are presumably needed by the
# `CustomResNetModel` / `GPSImageDataset` definitions, which are not shown here.
from geopy.distance import geodesic
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from transformers import AutoImageProcessor, AutoModelForImageClassification
from huggingface_hub import PyTorchModelHubMixin

# NOTE(review): `CustomResNetModel` must be defined (or imported) before this
# line; it should be the same class that was used to push "5190final/model1".
model = CustomResNetModel.from_pretrained("5190final/model1")

# Per-coordinate statistics used to normalize the GPS targets; the evaluation
# code below de-normalizes with `z * std + mean`. Presumably these are the
# same values that were used at training time — confirm before changing them.
lat_mean = 39.951611366653395
lat_std = 0.0006686190927448403
lon_mean = -75.19145880459313
lon_std = 0.0006484111794126842

# Prefer a CUDA GPU, then Apple-silicon MPS, and fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

model.to(device)
model.eval()  # inference mode: affects layers such as dropout / batch norm
# Wrap the held-out split in the project's dataset class so images get the
# inference transform and GPS targets are normalized with the statistics above.
# NOTE(review): `GPSImageDataset`, `dataset_test` and `inference_transform` are
# not defined in this snippet; they must come from the training code.
test_dataset = GPSImageDataset(
    hf_dataset=dataset_test,
    transform=inference_transform,
    lat_mean=lat_mean,
    lat_std=lat_std,
    lon_mean=lon_mean,
    lon_std=lon_std,
)
# shuffle=False keeps predictions in the same order as the dataset rows.
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)
# Run the model over the whole test set and collect normalized predictions and
# normalized ground-truth coordinates, one array per batch.
all_preds = []
all_actuals = []

with torch.no_grad():  # no autograd bookkeeping is needed during inference
    for images, gps_coords in test_dataloader:
        images = images.to(device)
        outputs = model(images)
        # A transformers model returns an output object with `.logits`, while a
        # plain nn.Module returns the prediction tensor itself; accept both.
        logits = getattr(outputs, "logits", outputs)
        all_preds.append(logits.cpu().numpy())
        # Targets are only compared on the host, so they never need to visit
        # the accelerator.
        all_actuals.append(gps_coords.cpu().numpy())

# Stack the per-batch arrays into (num_samples, 2) arrays of (lat, lon). An
# empty test set yields (0, 2) so the later per-column arithmetic still works.
all_preds = np.concatenate(all_preds) if all_preds else np.empty((0, 2))
all_actuals = np.concatenate(all_actuals) if all_actuals else np.empty((0, 2))
# Undo the target normalization column-wise (column 0 = latitude,
# column 1 = longitude): x = z * std + mean. The statistics are scalars, so
# they are packed into length-2 arrays that broadcast across every row.
coord_std = np.array([lat_std, lon_std])
coord_mean = np.array([lat_mean, lon_mean])
all_preds_denorm = all_preds * coord_std + coord_mean
all_actuals_denorm = all_actuals * coord_std + coord_mean
# Squared geodesic distance, in meters, between each actual and predicted
# coordinate; geopy takes points as (latitude, longitude) pairs.
squared_errors = [
    geodesic((actual[0], actual[1]), (pred[0], pred[1])).meters ** 2
    for pred, actual in zip(all_preds_denorm, all_actuals_denorm)
]
if not squared_errors:
    # np.mean([]) would silently produce NaN (with only a RuntimeWarning).
    raise ValueError("No test samples were evaluated; cannot compute RMSE.")

# Root-mean-square of the per-sample distance errors, in meters.
rmse = np.sqrt(np.mean(squared_errors))
print(f"Test RMSE: {rmse:.2f} m")
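This model has been pushed to the Hub using the [PyTorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration: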
- Library: [More Information Needed]
- Docs: [More Information Needed]