File size: 2,432 Bytes
94184e8
 
 
 
 
 
 
 
 
 
 
dba8f48
94184e8
dba8f48
94184e8
dba8f48
94184e8
 
 
 
dba8f48
94184e8
 
dba8f48
94184e8
 
 
 
dba8f48
94184e8
 
 
 
 
 
 
 
dba8f48
94184e8
 
dba8f48
94184e8
 
 
 
 
 
 
dba8f48
94184e8
 
 
dba8f48
94184e8
dba8f48
94184e8
 
 
dba8f48
94184e8
 
 
dba8f48
94184e8
 
 
 
 
dba8f48
94184e8
 
 
 
 
 
 
dba8f48
94184e8
 
 
dba8f48
94184e8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83

import torch
import torchvision.transforms as transforms
import pandas as pd
from PIL import Image
import os
import numpy as np

def prepare_data(csv_path: str):
    """
    Load images and GPS coordinates listed in a CSV file.

    A row is skipped (with a printed warning) when its image path is missing,
    its coordinates are missing or non-numeric, or its image cannot be opened
    or transformed. An image and its coordinates are only ever recorded
    together, so row ``i`` of ``X`` always corresponds to row ``i`` of ``y``.

    Args:
        csv_path: Path to CSV file with columns: image_path/filepath/image/path/file_name,
                  Latitude/latitude/lat, Longitude/longitude/lon. Column names
                  are matched case-insensitively, ignoring surrounding
                  whitespace; if several columns match the same role, the last
                  one is used. Relative image paths are resolved against the
                  directory containing the CSV.

    Returns:
        X: torch.Tensor of shape (N, 3, 224, 224) - normalized images
        y: torch.Tensor of shape (N, 2) - raw lat/lon in degrees

        N is the number of rows that loaded successfully. When no row loads,
        both tensors are empty (N == 0) but keep the shapes above.

    Raises:
        ValueError: If the image, latitude, or longitude column is not found.
    """

    # Read CSV
    df = pd.read_csv(csv_path)

    # Map every accepted (lower-cased) header name to the role it plays.
    role_by_name = {
        'image_path': 'image',
        'filepath': 'image',
        'image': 'image',
        'path': 'image',
        'file_name': 'image',
        'latitude': 'lat',
        'lat': 'lat',
        'longitude': 'lon',
        'lon': 'lon',
    }

    found = {}
    for col in df.columns:
        role = role_by_name.get(str(col).strip().lower())
        if role is not None:
            found[role] = col  # a later matching column overrides an earlier one

    if len(found) < 3:
        raise ValueError("Could not find image, latitude, or longitude columns in CSV")

    image_col = found['image']
    lat_col = found['lat']
    lon_col = found['lon']

    # Define transform (same as used during training)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    # Load images and GPS coordinates
    images = []
    gps_coords = []

    csv_dir = os.path.dirname(csv_path)

    for raw_path, raw_lat, raw_lon in zip(df[image_col], df[lat_col], df[lon_col]):
        # An empty cell is read as NaN (a float); skip it instead of letting
        # os.path raise TypeError and abort the whole load.
        if not isinstance(raw_path, str):
            print(f"Warning: Could not load image {raw_path}: image path is missing or not a string")
            continue

        # Handle relative paths
        if os.path.isabs(raw_path):
            img_path = raw_path
        else:
            img_path = os.path.join(csv_dir, raw_path)

        # Validate the coordinates BEFORE touching the image so a bad row can
        # never leave an image in `images` without a matching entry in
        # `gps_coords`.
        try:
            lat = float(raw_lat)
            lon = float(raw_lon)
        except (TypeError, ValueError) as e:
            print(f"Warning: Skipping {img_path}: invalid GPS coordinates: {e}")
            continue

        if not (np.isfinite(lat) and np.isfinite(lon)):
            print(f"Warning: Skipping {img_path}: missing or non-finite GPS coordinates")
            continue

        try:
            # The context manager closes the underlying file once the pixel
            # data has been converted.
            with Image.open(img_path) as img:
                image = transform(img.convert('RGB'))
        except Exception as e:
            # Deliberately broad: any unreadable/corrupt image is skipped,
            # not fatal.
            print(f"Warning: Could not load image {img_path}: {e}")
            continue

        images.append(image)
        gps_coords.append([lat, lon])  # raw, in degrees

    # torch.stack rejects an empty list, and torch.tensor([]) would have shape
    # (0,), so build correctly-shaped empty tensors explicitly.
    if not images:
        return (
            torch.empty((0, 3, 224, 224), dtype=torch.float32),
            torch.empty((0, 2), dtype=torch.float32),
        )

    # Convert to tensors
    X = torch.stack(images)  # Shape: (N, 3, 224, 224)
    y = torch.tensor(gps_coords, dtype=torch.float32)  # Shape: (N, 2)

    return X, y