| lat_mean = 39.951614360789364 | |
| lat_std = 0.0007384844437841076 | |
| lon_mean = -75.19140262762761 | |
| lon_std = 0.0007284591160342192 | |
| **Load the model:** | |
| ``` | |
| from huggingface_hub import hf_hub_download | |
| import torch | |
| repo_id = "thestalkers/ImageToGPSproject_base_resnet18_v2" | |
| filename = "resnet_gps_regressor_complete.pth" | |
| model_path = hf_hub_download(repo_id=repo_id, filename=filename) | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| # The checkpoint is a complete pickled model (not a state_dict), so | |
| # weights_only=False is required on PyTorch >= 2.6, where the default became True. | |
| # Unpickling can run arbitrary code: only load checkpoints from sources you trust. | |
| # map_location places the weights on the available device; if the checkpoint | |
| # was saved on a GPU, it still loads on a CPU-only machine. | |
| model_test = torch.load(model_path, map_location=device, weights_only=False) | |
| model_test.eval() # Set the model to evaluation mode | |
| ``` | |
| **Load a Hugging Face dataset:** | |
| ``` | |
| from datasets import load_dataset | |
| from torch.utils.data import DataLoader | |
| from torchvision import transforms | |
| # Normalization statistics (same values as listed above) | |
| lat_mean = 39.951614360789364 | |
| lat_std = 0.0007384844437841076 | |
| lon_mean = -75.19140262762761 | |
| lon_std = 0.0007284591160342192 | |
| dataset_test = load_dataset("gydou/released_img", split="train") | |
| # Resize to 224x224, then normalize with the standard ImageNet channel mean/std | |
| inference_transform = transforms.Compose([ | |
|     transforms.Resize((224, 224)), | |
|     transforms.ToTensor(), | |
|     transforms.Normalize(mean=[0.485, 0.456, 0.406], | |
|                          std=[0.229, 0.224, 0.225]) | |
| ]) | |
| # NOTE: GPSImageDataset is not defined in this snippet. Define or import it before | |
| # running; the inference code expects it to yield (image, normalized [lat, lon]) pairs. | |
| test_dataset = GPSImageDataset( | |
|     hf_dataset=dataset_test, | |
|     transform=inference_transform, | |
|     lat_mean=lat_mean, | |
|     lat_std=lat_std, | |
|     lon_mean=lon_mean, | |
|     lon_std=lon_std | |
| ) | |
| test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False) | |
| ``` | |
| **Perform inference:** | |
| ``` | |
| from sklearn.metrics import mean_absolute_error, mean_squared_error | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| print(f'Using device: {device}') | |
| # The model must live on the same device as the input batches. | |
| model_test = model_test.to(device) | |
| # Denormalization constants, built once. float64 avoids float32 rounding | |
| # (~4e-6 to ~8e-6 degrees, about half a metre) when adding back means near 40 / -75. | |
| gps_std = torch.tensor([lat_std, lon_std], dtype=torch.float64) | |
| gps_mean = torch.tensor([lat_mean, lon_mean], dtype=torch.float64) | |
| # Initialize lists to store predictions and actual values | |
| all_preds = [] | |
| all_actuals = [] | |
| with torch.no_grad(): | |
|     for images, gps_coords in test_dataloader: | |
|         outputs = model_test(images.to(device)) | |
|         # Denormalize predictions and actual values back to degrees | |
|         preds = outputs.cpu().double() * gps_std + gps_mean | |
|         actuals = gps_coords.cpu().double() * gps_std + gps_mean | |
|         all_preds.append(preds) | |
|         all_actuals.append(actuals) | |
| # Concatenate all batches | |
| all_preds = torch.cat(all_preds).numpy() | |
| all_actuals = torch.cat(all_actuals).numpy() | |
| # Compute error metrics (in degrees). `squared=False` was removed from | |
| # mean_squared_error in scikit-learn 1.6; taking the root of each column's MSE and | |
| # averaging reproduces its old result (per-column RMSE averaged over lat and lon). | |
| mae = mean_absolute_error(all_actuals, all_preds) | |
| rmse = (mean_squared_error(all_actuals, all_preds, multioutput="raw_values") ** 0.5).mean() | |
| print(f'Mean Absolute Error: {mae}') | |
| print(f'Root Mean Squared Error: {rmse}') | |
| ``` | |