File size: 2,509 Bytes
94184e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83

import torch
import torchvision.transforms as transforms
import pandas as pd
from PIL import Image
import os
import numpy as np

def prepare_data(csv_path: str):
    """
    Load images and GPS coordinates from CSV.

    Column names are matched case-insensitively, ignoring surrounding
    whitespace. Rows are skipped with a printed warning when:
    - the image path is missing,
    - the coordinates are not finite numbers, or
    - the image cannot be opened or transformed.
    Every image kept in ``X`` therefore has its matching row in ``y``.

    Args:
        csv_path: Path to CSV file with columns: image_path/filepath/image/path/file_name,
                  Latitude/latitude/lat, Longitude/longitude/lon. Relative image
                  paths are resolved against the directory containing the CSV.

    Returns:
        X: torch.Tensor of shape (N, 3, 224, 224) - normalized images
        y: torch.Tensor of shape (N, 2) - raw lat/lon in degrees

    Raises:
        ValueError: If a required column is missing, or if no row could be loaded.
    """
    df = pd.read_csv(csv_path)

    # Find the correct column names (case-insensitive). When several headers
    # match the same role, the last one in CSV order wins.
    image_col = lat_col = lon_col = None
    for col in df.columns:
        # str() guards against non-string labels; strip() tolerates " Latitude ".
        col_lower = str(col).strip().lower()
        if col_lower in ('image_path', 'filepath', 'image', 'path', 'file_name'):
            image_col = col
        elif col_lower in ('latitude', 'lat'):
            lat_col = col
        elif col_lower in ('longitude', 'lon'):
            lon_col = col

    missing = [
        role
        for role, col in (('image', image_col), ('latitude', lat_col), ('longitude', lon_col))
        if col is None
    ]
    if missing:
        raise ValueError(
            f"Could not find {', '.join(missing)} column(s) in CSV; "
            f"available columns: {list(df.columns)}; file: {csv_path}"
        )

    # Define transform (same as used during training)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    images = []
    gps_coords = []
    csv_dir = os.path.dirname(csv_path)

    for raw_path, raw_lat, raw_lon in zip(df[image_col], df[lat_col], df[lon_col]):
        # An empty cell is read as NaN (a float); os.path would raise TypeError on it.
        if pd.isna(raw_path):
            print("Warning: Skipping row with missing image path")
            continue
        img_path = str(raw_path)

        # Handle relative paths
        if not os.path.isabs(img_path):
            img_path = os.path.join(csv_dir, img_path)

        # Parse coordinates BEFORE loading the image. Otherwise a bad coordinate
        # could leave an image appended with no matching target.
        try:
            lat = float(raw_lat)
            lon = float(raw_lon)
        except (TypeError, ValueError) as e:
            print(f"Warning: Invalid coordinates for {img_path}: {e}")
            continue
        # float(NaN) succeeds, so check explicitly; NaN targets would poison training.
        if not (np.isfinite(lat) and np.isfinite(lon)):
            print(f"Warning: Missing coordinates for {img_path}: ({lat}, {lon})")
            continue

        try:
            # The context manager releases the file handle once pixels are converted.
            with Image.open(img_path) as img:
                image = transform(img.convert('RGB'))
        except Exception as e:  # deliberate best-effort: skip unreadable images
            print(f"Warning: Could not load image {img_path}: {e}")
            continue

        # Append both together so X and y stay row-aligned.
        images.append(image)
        gps_coords.append([lat, lon])

    if not images:
        # torch.stack([]) would otherwise fail with an opaque RuntimeError.
        raise ValueError(f"No valid rows could be loaded from CSV: {csv_path}")

    # Convert to tensors
    X = torch.stack(images)  # Shape: (N, 3, 224, 224)
    y = torch.tensor(gps_coords, dtype=torch.float32)  # Shape: (N, 2)

    return X, y