Update README.md
Browse files
README.md
CHANGED
|
@@ -3,53 +3,4 @@ library_name: transformers
|
|
| 3 |
tags: []
|
| 4 |
---
|
| 5 |
|
| 6 |
-
#
|
| 7 |
-
|
| 8 |
-
This is a fine-tuned Vision Transformer (ViT) model from Google. The base model was loaded and fine-tuned on the collected training data.
|
| 9 |
-
Compared to Attempt 1, we used the expanded dataset, trained for 7 epochs instead of 5, and updated only the classifier parameters during training.
|
| 10 |
-
|
| 11 |
-
Link: https://huggingface.co/google/vit-base-patch16-224-in21k
|
| 12 |
-
|
| 13 |
-
lat_mean = 39.951640614844095
|
| 14 |
-
|
| 15 |
-
lat_std = 0.0007502796001097172
|
| 16 |
-
|
| 17 |
-
lon_mean = -75.19143196896502
|
| 18 |
-
|
| 19 |
-
lon_std = 0.0007452186171662059
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
```python
|
| 23 |
-
model_name = "AppliedMLReedShreya/ViT_Attempt_3"
|
| 24 |
-
config = AutoConfig.from_pretrained(model_name)
|
| 25 |
-
config.num_labels = 2 # We need two outputs: latitude and longitude
|
| 26 |
-
|
| 27 |
-
# Load the pre-trained ViT model
|
| 28 |
-
vit_model = AutoModelForImageClassification.from_pretrained(model_name, config=config)
|
| 29 |
-
|
| 30 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 31 |
-
print(f'Using device: {device}')
|
| 32 |
-
vit_model = vit_model.to(device)
|
| 33 |
-
|
| 34 |
-
# Initialize lists to store predictions and actual values
|
| 35 |
-
all_preds = []
|
| 36 |
-
all_actuals = []
|
| 37 |
-
|
| 38 |
-
vit_model.eval()
|
| 39 |
-
with torch.no_grad():
|
| 40 |
-
for images, gps_coords in val_dataloader:
|
| 41 |
-
images, gps_coords = images.to(device), gps_coords.to(device)
|
| 42 |
-
|
| 43 |
-
outputs = vit_model(images).logits
|
| 44 |
-
|
| 45 |
-
# Denormalize predictions and actual values
|
| 46 |
-
preds = outputs.cpu() * torch.tensor([lat_std, lon_std]) + torch.tensor([lat_mean, lon_mean])
|
| 47 |
-
actuals = gps_coords.cpu() * torch.tensor([lat_std, lon_std]) + torch.tensor([lat_mean, lon_mean])
|
| 48 |
-
|
| 49 |
-
all_preds.append(preds)
|
| 50 |
-
all_actuals.append(actuals)
|
| 51 |
-
|
| 52 |
-
# Concatenate all batches
|
| 53 |
-
all_preds = torch.cat(all_preds).numpy()
|
| 54 |
-
all_actuals = torch.cat(all_actuals).numpy()
|
| 55 |
-
```
|
|
|
|
| 3 |
tags: []
|
| 4 |
---
|
| 5 |
|
| 6 |
+
# DO NOT USE THIS MODEL
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|