|
|
--- |
|
|
library_name: transformers |
|
|
tags: [] |
|
|
--- |
|
|
|
|
|
# Model Card for ViT_Attempt_4 (use this model for the leaderboard)
|
|
|
|
|
This is a fine-tuned Vision Transformer (ViT) model from Google. The model was loaded and fine-tuned on the training data collected. |
|
|
Compared to Attempt 1, we are using the expanded dataset, trained for 20 epochs instead of 5, and only updated the classifier parameters at training time. |
|
|
Compared to Attempt 3, this was trained for 20 epochs and used a different learning rate.
|
|
|
|
|
Base model: https://huggingface.co/google/vit-base-patch16-224-in21k

The following normalization statistics were computed on the training set; model outputs are normalized coordinates and must be denormalized with these values:
|
|
|
|
|
lat_mean = 39.951640614844095 |
|
|
|
|
|
lat_std = 0.0007502796001097172 |
|
|
|
|
|
lon_mean = -75.19143196896502 |
|
|
|
|
|
lon_std = 0.0007452186171662059 |
|
|
|
|
|
|
|
|
```python |
|
|
# Inference example: load the fine-tuned ViT regressor and evaluate it on a
# validation dataloader. The model's two "logits" are normalized latitude and
# longitude; denormalize them with the training-set statistics below.
import torch
from transformers import AutoConfig, AutoModelForImageClassification

# Normalization statistics computed on the training set (see model card).
lat_mean = 39.951640614844095
lat_std = 0.0007502796001097172
lon_mean = -75.19143196896502
lon_std = 0.0007452186171662059

model_name = "AppliedMLReedShreya/ViT_Attempt_4"

config = AutoConfig.from_pretrained(model_name)

config.num_labels = 2  # We need two outputs: latitude and longitude

# Load the pre-trained ViT model
vit_model = AutoModelForImageClassification.from_pretrained(model_name, config=config)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f'Using device: {device}')

vit_model = vit_model.to(device)

# Denormalization tensors are constant across batches — build them once
# instead of re-creating them inside the loop.
denorm_scale = torch.tensor([lat_std, lon_std])
denorm_offset = torch.tensor([lat_mean, lon_mean])

# Initialize lists to store predictions and actual values
all_preds = []

all_actuals = []

vit_model.eval()

with torch.no_grad():

    # NOTE: val_dataloader must be defined beforehand; it is assumed to yield
    # (images, gps_coords) batches with gps_coords normalized like training.
    for images, gps_coords in val_dataloader:

        # Only the images need to be on the model's device; gps_coords is
        # consumed on the CPU, so skip the pointless device round trip.
        images = images.to(device)

        outputs = vit_model(images).logits

        # Denormalize predictions and actual values
        preds = outputs.cpu() * denorm_scale + denorm_offset

        actuals = gps_coords.cpu() * denorm_scale + denorm_offset

        all_preds.append(preds)

        all_actuals.append(actuals)

# Concatenate all batches
all_preds = torch.cat(all_preds).numpy()

all_actuals = torch.cat(all_actuals).numpy()
|
|
``` |
|
|
|