|
|
|
|
|
import torch |
|
|
import torchvision.transforms as transforms |
|
|
import pandas as pd |
|
|
from PIL import Image |
|
|
import os |
|
|
import numpy as np |
|
|
|
|
|
def _find_columns(df):
    """Return (image_col, lat_col, lon_col) matched case-insensitively in df.

    If several columns match the same role, the last one wins.

    Raises:
        ValueError: if any of the three columns is missing.
    """
    image_col = lat_col = lon_col = None
    for col in df.columns:
        col_lower = str(col).lower()  # str() tolerates non-string headers
        if col_lower in ('image_path', 'filepath', 'image', 'path', 'file_name'):
            image_col = col
        elif col_lower in ('latitude', 'lat'):
            lat_col = col
        elif col_lower in ('longitude', 'lon'):
            lon_col = col

    if image_col is None or lat_col is None or lon_col is None:
        raise ValueError(
            "Could not find image, latitude, or longitude columns in CSV "
            f"(available columns: {list(df.columns)})"
        )
    return image_col, lat_col, lon_col


def _resolve_image_path(raw_path, csv_dir: str) -> str:
    """Return raw_path as a string, joined onto csv_dir when it is relative.

    Raises:
        ValueError: if the cell is empty (NaN).
    """
    if pd.isna(raw_path):
        raise ValueError("missing image path")
    # pandas may parse purely numeric file names as numbers.
    path = str(raw_path)
    return path if os.path.isabs(path) else os.path.join(csv_dir, path)


def prepare_data(csv_path: str, image_size: int = 224):
    """
    Load images and GPS coordinates from a CSV file.

    Args:
        csv_path: Path to a CSV file. Column names are matched case-insensitively:
            image:     image_path / filepath / image / path / file_name
            latitude:  latitude / lat
            longitude: longitude / lon
            Relative image paths are resolved against the CSV's directory.
        image_size: Side length (pixels) that every image is resized to.

    Returns:
        X: torch.Tensor of shape (N, 3, image_size, image_size). Images are
           converted to RGB and normalized with the ImageNet mean/std.
        y: torch.Tensor of shape (N, 2), float32. Raw (lat, lon) in degrees,
           row-aligned with X.

    Raises:
        ValueError: if a required column is missing, or no row could be loaded.

    Rows are skipped with a printed warning in these cases:
        - the image cannot be opened or transformed;
        - the path is missing;
        - the coordinates are not finite numbers.
    """
    df = pd.read_csv(csv_path)
    image_col, lat_col, lon_col = _find_columns(df)

    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    csv_dir = os.path.dirname(csv_path)
    images = []
    gps_coords = []

    for raw_path, raw_lat, raw_lon in zip(df[image_col], df[lat_col], df[lon_col]):
        img_path = raw_path  # shown in the warning if resolution itself fails
        try:
            img_path = _resolve_image_path(raw_path, csv_dir)

            # Validate the label before loading the image. Both lists are
            # appended together below, so a bad row can never leave an image
            # without its matching coordinates.
            lat = float(raw_lat)
            lon = float(raw_lon)
            if not (np.isfinite(lat) and np.isfinite(lon)):
                raise ValueError(f"non-finite coordinates ({lat}, {lon})")

            # Context manager closes the underlying file handle.
            with Image.open(img_path) as img:
                image = transform(img.convert('RGB'))
        except Exception as e:
            # Deliberate best-effort: skip unusable rows instead of aborting.
            print(f"Warning: Could not load image {img_path}: {e}")
            continue

        images.append(image)
        gps_coords.append([lat, lon])

    if not images:
        # torch.stack([]) would raise an opaque RuntimeError.
        raise ValueError(f"No valid samples could be loaded from {csv_path}")

    X = torch.stack(images)
    y = torch.tensor(gps_coords, dtype=torch.float32)

    return X, y
|
|
|