DexerK commited on
Commit
7873ea4
·
verified ·
1 Parent(s): c043e22

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +1 -50
README.md CHANGED
@@ -3,53 +3,4 @@ library_name: transformers
3
  tags: []
4
  ---
5
 
6
- # Model Card for Model ID (Use this model for Leaderboard)
7
-
8
- This is a fine-tuned Vision Transformer (ViT) model from Google. The model was loaded and fine-tuned on the training data collected.
9
- Compared to Attempt 1, we are using the expanded dataset, trained for 7 epochs instead of 5, and only updated the classifier parameters at training time.
10
-
11
- Link: https://huggingface.co/google/vit-base-patch16-224-in21k
12
-
13
- lat_mean = 39.951640614844095
14
-
15
- lat_std = 0.0007502796001097172
16
-
17
- lon_mean = -75.19143196896502
18
-
19
- lon_std = 0.0007452186171662059
20
-
21
-
22
- ```python
23
- model_name = "AppliedMLReedShreya/ViT_Attempt_3"
24
- config = AutoConfig.from_pretrained(model_name)
25
- config.num_labels = 2 # We need two outputs: latitude and longitude
26
-
27
- # Load the pre-trained ViT model
28
- vit_model = AutoModelForImageClassification.from_pretrained(model_name, config=config)
29
-
30
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
- print(f'Using device: {device}')
32
- vit_model = vit_model.to(device)
33
-
34
- # Initialize lists to store predictions and actual values
35
- all_preds = []
36
- all_actuals = []
37
-
38
- vit_model.eval()
39
- with torch.no_grad():
40
- for images, gps_coords in val_dataloader:
41
- images, gps_coords = images.to(device), gps_coords.to(device)
42
-
43
- outputs = vit_model(images).logits
44
-
45
- # Denormalize predictions and actual values
46
- preds = outputs.cpu() * torch.tensor([lat_std, lon_std]) + torch.tensor([lat_mean, lon_mean])
47
- actuals = gps_coords.cpu() * torch.tensor([lat_std, lon_std]) + torch.tensor([lat_mean, lon_mean])
48
-
49
- all_preds.append(preds)
50
- all_actuals.append(actuals)
51
-
52
- # Concatenate all batches
53
- all_preds = torch.cat(all_preds).numpy()
54
- all_actuals = torch.cat(all_actuals).numpy()
55
- ```
 
3
  tags: []
4
  ---
5
 
6
+ # DO NOT USE THIS MODEL