# File: geolocation/src/g3/pe/projection_rff.py
# Commit eff2be4 by 3v324v23: "init prj"
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import itertools
from transformers import CLIPTokenizer, CLIPImageProcessor, CLIPModel
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from ..rff.layers import GaussianEncoding
from pyproj import Proj, Transformer
# Equal Earth scaling constant; ProjectionRFF's "eep" branch builds its
# normaliser from 180 / SF (~2.7066).
# NOTE(review): presumably the GeoCLIP-style factor for unit-sphere Equal Earth
# coordinates (180 / SF being the map's x half-width on a unit sphere) — confirm.
SF = 66.50336
class ProjectionRFF(nn.Module):
    """Encode (lat, lon) coordinates with multi-scale random Fourier features.

    Each coordinate pair is projected from WGS84 (EPSG:4326) into a target
    coordinate system with pyproj and scaled to roughly unit range. It is
    then passed through one ``GaussianEncoding`` per entry of ``sigma``.

    Args:
        projection: Target projection (case-insensitive):
            ``"mercator"`` (EPSG:3857, 2-D), ``"eep"`` (Equal Earth,
            EPSG:8857, 2-D) or ``"ecef"`` (EPSG:4978, 3-D).
        sigma: Bandwidth of each Gaussian encoding, one per hierarchy level.
        hparams: Unused; accepted for interface compatibility.
        device: Device the projected coordinates are moved to before encoding.

    Raises:
        ValueError: If ``projection`` is not one of the supported names.
    """

    # Web Mercator's northing diverges at the poles. Latitudes are clamped to
    # the limit at which the EPSG:3857 map is square, so |y| <= normaliser.
    _MERCATOR_MAX_LAT = 85.05112878

    # Authalic (equal-area) radius of the WGS84 ellipsoid, in metres.
    # EPSG:8857 output is in metres, so the unit-sphere half-width 180 / SF
    # must be scaled by this radius to bring x into roughly [-1, 1].
    _WGS84_AUTHALIC_RADIUS = 6371007.1809

    def __init__(self, projection="ecef", sigma=(2**0, 2**4, 2**8),
                 hparams=None, device='cuda'):
        super().__init__()
        self.device = device
        # Copy into a list: no mutable default shared between instances, and
        # ``self.sigma`` stays a list as before.
        self.sigma = list(sigma)
        self.num_hierarchies = len(self.sigma)
        self.projection = projection.lower()
        self.embedding_dim = [512] * self.num_hierarchies

        if self.projection == "mercator":
            target_crs = 'epsg:3857'
            input_dim = 2
            # Half-width of the EPSG:3857 map, in metres.
            self.normalizer = 20037508.3427892
        elif self.projection == "eep":
            target_crs = 'epsg:8857'
            input_dim = 2
            # Half-width of the EPSG:8857 map in metres (~1.7244e7).
            self.normalizer = (180 / SF) * self._WGS84_AUTHALIC_RADIUS
        elif self.projection == "ecef":
            target_crs = 'epsg:4978'
            input_dim = 3
            self.normalizer = 6378137.0  # radius of Earth, not exact for ECEF but usable
        else:
            raise ValueError(f"Unsupported projection: {self.projection}")

        # always_xy=True: transform() takes (lon, lat) and returns (x, y[, z]).
        self.transformer = Transformer.from_crs('epsg:4326', target_crs, always_xy=True)
        for i, s in enumerate(self.sigma):
            self.add_module(f'LocEnc{i}',
                            GaussianEncoding(sigma=s, input_size=input_dim, encoded_size=256))

    def forward(self, input):
        """Encode a batch of coordinates.

        Args:
            input: Tensor whose column 0 is latitude and column 1 is longitude,
                in degrees. Any extra columns are ignored.

        Returns:
            The per-hierarchy encodings stacked along dim 0, i.e.
            ``(num_hierarchies, batch, feature_dim)``. ``feature_dim`` is
            whatever ``GaussianEncoding`` emits (presumably 2 * 256 = 512).
        """
        lat = input[:, 0].float().detach().cpu().numpy()
        lon = input[:, 1].float().detach().cpu().numpy()

        if self.projection == "ecef":
            # Points are placed on the ellipsoid surface (zero height).
            x, y, z = self.transformer.transform(lon, lat, np.zeros_like(lat))
            coords = np.stack([x, y, z], axis=-1)  # (batch, 3): X, Y, Z
        else:
            if self.projection == "mercator":
                # np.clip returns a copy; ``lat`` may share memory with ``input``.
                lat = np.clip(lat, -self._MERCATOR_MAX_LAT, self._MERCATOR_MAX_LAT)
            x, y = self.transformer.transform(lon, lat)
            # Columns are (northing, easting), as in the original layout.
            coords = np.stack([y, x], axis=-1)  # (batch, 2)

        # np.stack keeps a (0, D) shape for an empty batch.
        scaled = (np.asarray(coords, dtype=np.float64) / self.normalizer).astype(np.float32)
        location = torch.from_numpy(scaled).to(self.device)

        out = [getattr(self, f'LocEnc{i}')(location) for i in range(self.num_hierarchies)]
        return torch.stack(out, dim=0)