Image-to-Text
Transformers
PyTorch
English
Geo-Localization
kevinloeffler commited on
Commit
3c67fc0
·
verified ·
1 Parent(s): 91ab0ec

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +1 -2
README.md CHANGED
@@ -48,14 +48,12 @@ Example inference:
48
  # imports
49
  import torch
50
  from PIL import Image
51
-
52
  from model import LocationDecoder # ReGeo model class: https://github.com/TobiasRothlin/GeoLocalization/blob/main/src/DGX1/src/RegressionPretraining/Model.py
53
  from transformers import CLIPProcessor
54
 
55
  # load custom config (do not use AutoConfig), an example can be found in this repo
56
  config = { ... }
57
 
58
- #
59
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
60
  preprocessor = CLIPProcessor.from_pretrained('openai/clip-vit-large-patch14-336')
61
  model = LocationDecoder.from_pretrained('OSTswiss/ReGeo', config=config)
@@ -70,6 +68,7 @@ image = Image.open(image_path)
70
  model_input = preprocessor(images=image, return_tensors="pt")
71
  pixel_values = model_input['pixel_values'].to(device)
72
 
 
73
  with torch.no_grad():
74
  output = model(pixel_values)
75
  normal_coordinates = output.squeeze().tolist()
 
48
  # imports
49
  import torch
50
  from PIL import Image
 
51
  from model import LocationDecoder # ReGeo model class: https://github.com/TobiasRothlin/GeoLocalization/blob/main/src/DGX1/src/RegressionPretraining/Model.py
52
  from transformers import CLIPProcessor
53
 
54
  # load custom config (do not use AutoConfig), an example can be found in this repo
55
  config = { ... }
56
 
 
57
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
58
  preprocessor = CLIPProcessor.from_pretrained('openai/clip-vit-large-patch14-336')
59
  model = LocationDecoder.from_pretrained('OSTswiss/ReGeo', config=config)
 
68
  model_input = preprocessor(images=image, return_tensors="pt")
69
  pixel_values = model_input['pixel_values'].to(device)
70
 
71
+ # run inference
72
  with torch.no_grad():
73
  output = model(pixel_values)
74
  normal_coordinates = output.squeeze().tolist()