# Aurora_Wave_App / aurora_predictor.py
# Last commit by Pr1ce1551: "Initial commit" (f518dc3)
import math
from datetime import datetime, timedelta

import numpy as np
import torch
from aurora import AuroraWave, Batch, Metadata, rollout
class WavePredictor:
    """Point wave forecasts for a fixed set of tropical coastal locations.

    An ``AuroraWave`` model is loaded once at construction.  For a requested
    location, a small latitude/longitude grid of randomly generated input
    fields is built around it, the model is rolled forward, and the value at
    the central grid cell is reported for every variable in ``variables``.
    """

    # Spacing, in hours, between the timestamps of consecutive forecast steps.
    _STEP_HOURS = 6

    # Inclusive (low, high) clamp applied to reported values, per variable.
    _VALUE_LIMITS = {
        "mwp": (3.0, 25.0),   # wave periods 3-25 s
        "pp1d": (3.0, 25.0),  # wave periods 3-25 s
        "wind": (0.0, 30.0),  # wind speed 0-30 m/s
    }

    def __init__(self):
        """Load the Aurora Wave model and set up the location/variable tables."""
        # Load the Aurora Wave model and put it in evaluation mode.
        self.model = AuroraWave()
        self.model.load_checkpoint()
        self.model.eval()

        # Coastal locations similar to the Maldives: name -> (lat, lon), degrees.
        self.locations = {
            "Maldives": (4.1755, 73.5093),
            "Phu Quoc, Vietnam": (10.227, 103.963),
            "Con Dao, Vietnam": (8.6833, 106.5833),
            "Andaman Islands": (12.000, 92.900),
            "Nicobar Islands": (7.000, 93.700),
            "Lakshadweep": (10.5667, 72.6167),
            "Palawan, Philippines": (9.8349, 118.7384),
            "Koh Phi Phi, Thailand": (7.7407, 98.7784),
            "Seychelles": (-4.6796, 55.4919),
            "Zanzibar, Tanzania": (-6.1659, 39.2026),
        }

        # Variables reported by get_predictions: variable code -> display label.
        self.variables = {
            "mwp": "Mean Wave Period (s)",
            "mwd": "Mean Wave Direction (°)",
            "pp1d": "Peak Wave Period (s)",
            "wind": "Wind Speed (m/s)",
        }

    def create_wave_batch(self, center_lat, center_lon, grid_size=15):
        """Create a batch of synthetic input fields around a location.

        The grid spans 5 degrees either side of the centre along both axes.
        Every surface and atmospheric field is a random draw from a normal
        distribution (placeholder data, not observations); the static fields
        are all zeros.

        Args:
            center_lat: Latitude of the grid centre, in degrees.
            center_lon: Longitude of the grid centre, in degrees.
            grid_size: Number of grid points along each axis.

        Returns:
            A ``Batch`` whose surface tensors have shape
            ``(1, 2, grid_size, grid_size)``, whose atmospheric tensors have
            shape ``(1, 2, 4, grid_size, grid_size)``, and whose metadata is
            stamped with the current local time.
        """
        # Latitudes run from north to south, longitudes from west to east.
        lat_range = torch.linspace(center_lat + 5, center_lat - 5, grid_size)
        lon_range = torch.linspace(center_lon - 5, center_lon + 5, grid_size)
        # NOTE(review): Aurora's docs appear to require longitudes in
        # [0, 360) and grid dimensions compatible with the model's patch
        # size -- confirm that centres with lon < 5 and the default
        # grid_size are accepted by the model.

        surf_shape = (1, 2, grid_size, grid_size)
        atmos_shape = (1, 2, 4, grid_size, grid_size)
        static_shape = (grid_size, grid_size)

        # (mean, std) of the normal distribution each field is drawn from.
        # Insertion order is kept so the sequence of random draws is stable.
        surf_stats = {
            "mwp": (8.0, 2.0),     # 8 s +/- 2 s
            "mwd": (180.0, 45.0),  # 180 deg +/- 45 deg
            "pp1d": (12.0, 3.0),   # 12 s +/- 3 s
            "wind": (5.0, 2.0),    # 5 m/s +/- 2 m/s
            # Other wave variables, filled with default-like values.
            "swh": (2.0, 0.5),     # significant wave height
            "10u_wind": (3.0, 2.0),
            "10v_wind": (2.0, 2.0),
        }
        atmos_stats = {
            "z": (0.0, 1000.0),
            "u": (3.0, 3.0),
            "v": (2.0, 3.0),
            "t": (285.0, 10.0),
            "q": (0.012, 0.003),
        }

        return Batch(
            surf_vars={
                name: torch.normal(mean, std, surf_shape)
                for name, (mean, std) in surf_stats.items()
            },
            static_vars={
                "lsm": torch.zeros(static_shape),  # ocean mask
                "z": torch.zeros(static_shape),    # sea level
                "slt": torch.zeros(static_shape),  # soil type (irrelevant for ocean)
            },
            atmos_vars={
                name: torch.normal(mean, std, atmos_shape)
                for name, (mean, std) in atmos_stats.items()
            },
            metadata=Metadata(
                lat=lat_range,
                lon=lon_range,
                time=(datetime.now(),),
                atmos_levels=(100, 250, 500, 850),
            ),
        )

    @staticmethod
    def _postprocess_value(var_code, value):
        """Bring one raw point prediction into its reporting range.

        Directions (``"mwd"``) are wrapped into ``[0, 360)``; variables listed
        in ``_VALUE_LIMITS`` are clamped to their limits; everything is
        rounded to two decimals.

        Args:
            var_code: Variable code, e.g. ``"mwp"``.
            value: Raw prediction as a Python float.

        Returns:
            The post-processed float, or NaN unchanged when ``value`` is NaN.
        """
        # Comparisons with NaN are always False, so min()/max() clamping
        # would silently turn a missing prediction into the upper limit.
        if math.isnan(value):
            return value

        if var_code == "mwd":
            value = round(value % 360, 2)
            # Rounding (or a tiny negative input) can land exactly on 360.
            return 0.0 if value >= 360.0 else value

        limits = WavePredictor._VALUE_LIMITS.get(var_code)
        if limits is not None:
            low, high = limits
            value = max(low, min(high, value))
        return round(value, 2)

    def get_predictions(self, location_name, steps=8):
        """Get wave predictions for a location.

        Args:
            location_name: A key of ``self.locations``.
            steps: Number of forecast steps to produce.

        Returns:
            ``None`` when ``location_name`` is unknown.  Otherwise a list with
            one dict per step holding ``"timestamp"`` (batch time plus
            6 hours per step), ``"step"`` (1-based) and ``"predictions"``
            (variable code -> post-processed value at the central grid cell).
        """
        if location_name not in self.locations:
            return None

        lat, lon = self.locations[location_name]
        batch = self.create_wave_batch(lat, lon)

        # Anchor every timestamp to the batch's own time so the forecast
        # times do not drift with how long inference takes.
        start_time = batch.metadata.time[0]
        # Index of the central grid cell (the location of interest).
        center_idx = len(batch.metadata.lat) // 2

        predictions = []
        with torch.inference_mode():
            # Use the library roll-out rather than feeding each prediction
            # straight back in: a prediction alone lacks the second input
            # step the initial batch provides (see Aurora roll-out docs).
            for step, pred in enumerate(
                rollout(self.model, batch, steps=steps), start=1
            ):
                point_values = {}
                for var_code in self.variables:
                    raw = pred.surf_vars[var_code][
                        0, 0, center_idx, center_idx
                    ].item()
                    point_values[var_code] = self._postprocess_value(
                        var_code, raw
                    )
                predictions.append(
                    {
                        "timestamp": start_time
                        + timedelta(hours=self._STEP_HOURS * step),
                        "step": step,
                        "predictions": point_values,
                    }
                )
        return predictions