File size: 2,432 Bytes
94184e8 dba8f48 94184e8 dba8f48 94184e8 dba8f48 94184e8 dba8f48 94184e8 dba8f48 94184e8 dba8f48 94184e8 dba8f48 94184e8 dba8f48 94184e8 dba8f48 94184e8 dba8f48 94184e8 dba8f48 94184e8 dba8f48 94184e8 dba8f48 94184e8 dba8f48 94184e8 dba8f48 94184e8 dba8f48 94184e8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
import torch
import torchvision.transforms as transforms
import pandas as pd
from PIL import Image
import os
import numpy as np
def _find_columns(df):
    """Locate the image, latitude and longitude columns of ``df``.

    Header names are matched case-insensitively with surrounding whitespace
    ignored (pandas keeps the space after a comma in headers such as
    ``image, lat, lon``). If several headers match the same role, the last
    one wins.

    Returns:
        (image_col, lat_col, lon_col), each exactly as it appears in
        ``df.columns``.

    Raises:
        ValueError: if any of the three columns is missing.
    """
    image_col = None
    lat_col = None
    lon_col = None
    for col in df.columns:
        col_lower = str(col).strip().lower()
        if col_lower in ('image_path', 'filepath', 'image', 'path', 'file_name'):
            image_col = col
        elif col_lower in ('latitude', 'lat'):
            lat_col = col
        elif col_lower in ('longitude', 'lon'):
            lon_col = col

    missing = [
        role
        for role, col in (('image', image_col), ('latitude', lat_col), ('longitude', lon_col))
        if col is None
    ]
    if missing:
        raise ValueError(
            "Could not find image, latitude, or longitude columns in CSV "
            f"(missing: {', '.join(missing)}; available: {list(df.columns)})"
        )
    return image_col, lat_col, lon_col


def _collect_records(df, csv_dir, image_col, lat_col, lon_col):
    """Return a list of (image_path, lat, lon) for rows with usable metadata.

    Relative image paths are resolved against ``csv_dir``. Rows with a
    missing image path or with non-numeric / non-finite coordinates are
    skipped with a printed warning. This happens before any image is loaded,
    so a bad coordinate can never leave an image without a matching label.
    """
    records = []
    for idx, raw_path, raw_lat, raw_lon in zip(
        df.index, df[image_col], df[lat_col], df[lon_col]
    ):
        # Empty CSV cells arrive as NaN (a float), which os.path cannot handle.
        if pd.isna(raw_path) or not str(raw_path).strip():
            print(f"Warning: Skipping row {idx}: missing image path")
            continue
        # str() also covers file names pandas parsed as numbers.
        img_path = str(raw_path)
        if not os.path.isabs(img_path):
            img_path = os.path.join(csv_dir, img_path)

        try:
            lat = float(raw_lat)
            lon = float(raw_lon)
        except (TypeError, ValueError) as e:
            print(f"Warning: Skipping {img_path}: invalid coordinates: {e}")
            continue
        # float() accepts NaN/inf (e.g. empty cells); they would poison training targets.
        if not (np.isfinite(lat) and np.isfinite(lon)):
            print(f"Warning: Skipping {img_path}: missing or non-finite coordinates ({lat}, {lon})")
            continue

        records.append((img_path, lat, lon))
    return records


def _empty_tensors(image_size):
    """Return (X, y) with zero rows but the documented trailing shapes."""
    return (
        torch.empty((0, 3, image_size, image_size), dtype=torch.float32),
        torch.empty((0, 2), dtype=torch.float32),
    )


def prepare_data(csv_path: str, image_size: int = 224):
    """
    Load images and GPS coordinates from CSV.

    Args:
        csv_path: Path to CSV file with columns (matched case-insensitively,
            surrounding whitespace ignored):
            image_path/filepath/image/path/file_name,
            Latitude/latitude/lat, Longitude/longitude/lon.
            Relative image paths are resolved against the CSV's directory.
        image_size: Side length every image is resized to (default 224).

    Returns:
        X: torch.Tensor of shape (N, 3, image_size, image_size) - normalized images
        y: torch.Tensor of shape (N, 2) - raw lat/lon in degrees (float32)

        Rows with a missing image path, non-numeric or non-finite coordinates,
        or an unreadable image are skipped with a printed warning. Every
        surviving row adds exactly one entry to both X and y, so X[i] and y[i]
        always come from the same row. If no row survives, empty tensors with
        the shapes above (N = 0) are returned.

    Raises:
        ValueError: if the image, latitude, or longitude column is missing.
    """
    df = pd.read_csv(csv_path)
    image_col, lat_col, lon_col = _find_columns(df)

    records = _collect_records(df, os.path.dirname(csv_path), image_col, lat_col, lon_col)
    if not records:
        return _empty_tensors(image_size)

    # Define transform (same as used during training). The mean/std are the
    # ImageNet statistics expected by torchvision's pretrained models.
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    images = []
    gps_coords = []
    for img_path, lat, lon in records:
        try:
            # The context manager closes the file handle once the image is transformed.
            with Image.open(img_path) as img:
                image = transform(img.convert('RGB'))
        except Exception as e:  # best-effort per row: corrupt/unreadable files are skipped
            print(f"Warning: Could not load image {img_path}: {e}")
            continue
        # Append image and label together so X and y stay aligned.
        images.append(image)
        gps_coords.append([lat, lon])

    if not images:
        # torch.stack rejects an empty list; return correctly shaped empties instead.
        return _empty_tensors(image_size)

    X = torch.stack(images)  # Shape: (N, 3, image_size, image_size)
    y = torch.tensor(gps_coords, dtype=torch.float32)  # Shape: (N, 2)
    return X, y
|