Spaces:
Build error
Build error
import math
from datetime import datetime, timedelta

import numpy as np
import torch
from aurora import AuroraWave, Batch, Metadata, rollout
class WavePredictor:
    """Run the Aurora wave model on synthetic input around preset coastal sites.

    The input fields are random draws from normal distributions (see
    ``create_wave_batch``), not observations, so the resulting numbers are
    only useful for exercising the pipeline end to end.
    """

    # Hours assumed between consecutive forecast steps when stamping results.
    STEP_HOURS = 6

    def __init__(self):
        """Load the Aurora wave model and define the locations and variables."""
        # Load Aurora Wave model
        self.model = AuroraWave()
        self.model.load_checkpoint()
        self.model.eval()

        # Coastal locations similar to the Maldives, as (latitude, longitude).
        self.locations = {
            "Maldives": (4.1755, 73.5093),
            "Phu Quoc, Vietnam": (10.227, 103.963),
            "Con Dao, Vietnam": (8.6833, 106.5833),
            "Andaman Islands": (12.000, 92.900),
            "Nicobar Islands": (7.000, 93.700),
            "Lakshadweep": (10.5667, 72.6167),
            "Palawan, Philippines": (9.8349, 118.7384),
            "Koh Phi Phi, Thailand": (7.7407, 98.7784),
            "Seychelles": (-4.6796, 55.4919),
            "Zanzibar, Tanzania": (-6.1659, 39.2026),
        }

        # Variables reported by get_predictions: code -> display label.
        self.variables = {
            "mwp": "Mean Wave Period (s)",
            "mwd": "Mean Wave Direction (°)",
            "pp1d": "Peak Wave Period (s)",
            "wind": "Wind Speed (m/s)",
        }

    @staticmethod
    def _constrain(var_code, value):
        """Wrap or clamp one raw model output into its reported range.

        Directions ("mwd") are wrapped into [0, 360); wave periods ("mwp",
        "pp1d") are clamped to 3-25 s; wind speed ("wind") to 0-30 m/s. Any
        other variable code is returned unchanged.

        NaN is returned as NaN. With the plain ``max(lo, min(hi, value))``
        form a NaN would silently become the upper bound, because every
        comparison against NaN is false.
        """
        if math.isnan(value):
            return value
        if var_code == "mwd":
            return value % 360  # Keep direction in 0-360 range
        if var_code in ("mwp", "pp1d"):
            return max(3.0, min(25.0, value))  # Wave periods 3-25s
        if var_code == "wind":
            return max(0.0, min(30.0, value))  # Wind speed 0-30 m/s
        return value

    def create_wave_batch(self, center_lat, center_lon, grid_size=15):
        """Create a synthetic batch on a 10° x 10° grid centred on a location.

        Args:
            center_lat: Latitude of the grid centre in degrees.
            center_lon: Longitude of the grid centre in degrees.
            grid_size: Number of grid points along each axis.

        Returns:
            A ``Batch`` whose surface and atmospheric fields are random normal
            draws with two time steps each, stamped with the current time.
        """
        # Latitudes run from north to south, longitudes from west to east.
        # NOTE(review): longitudes are not wrapped, so a centre within 5° of
        # the prime meridian yields negative values - confirm the model
        # accepts that before adding such a location.
        lat_range = torch.linspace(center_lat + 5, center_lat - 5, grid_size)
        lon_range = torch.linspace(center_lon - 5, center_lon + 5, grid_size)

        surf_shape = (1, 2, grid_size, grid_size)
        atmos_shape = (1, 2, 4, grid_size, grid_size)  # 4 = number of levels below
        static_shape = (grid_size, grid_size)

        # (mean, std) of the normal distribution each field is drawn from.
        surf_spec = {
            "mwp": (8.0, 2.0),  # 8s ± 2s
            "mwd": (180.0, 45.0),  # 180° ± 45°
            "pp1d": (12.0, 3.0),  # 12s ± 3s
            "wind": (5.0, 2.0),  # 5 m/s ± 2 m/s
            # Other wave variables with default values.
            # NOTE(review): confirm this set matches what the wave checkpoint
            # actually requires; it cannot be verified from this file.
            "swh": (2.0, 0.5),  # Significant wave height
            "10u_wind": (3.0, 2.0),
            "10v_wind": (2.0, 2.0),
        }
        atmos_spec = {
            "z": (0.0, 1000.0),
            "u": (3.0, 3.0),
            "v": (2.0, 3.0),
            "t": (285.0, 10.0),
            "q": (0.012, 0.003),
        }

        return Batch(
            surf_vars={
                name: torch.normal(mean, std, surf_shape)
                for name, (mean, std) in surf_spec.items()
            },
            static_vars={
                "lsm": torch.zeros(static_shape),  # Ocean mask
                "z": torch.zeros(static_shape),  # Sea level
                "slt": torch.zeros(static_shape),  # Soil type (irrelevant for ocean)
            },
            atmos_vars={
                name: torch.normal(mean, std, atmos_shape)
                for name, (mean, std) in atmos_spec.items()
            },
            metadata=Metadata(
                lat=lat_range,
                lon=lon_range,
                time=(datetime.now(),),
                atmos_levels=(100, 250, 500, 850),
            ),
        )

    def get_predictions(self, location_name, steps=8):
        """Get multi-step wave predictions for a named location.

        Args:
            location_name: A key of ``self.locations``.
            steps: Number of forecast steps to produce.

        Returns:
            ``None`` if the location is unknown. Otherwise a list with one
            dict per step holding "timestamp", "step" (1-based) and
            "predictions", the latter mapping each code in ``self.variables``
            to its value at the grid centre, rounded to two decimals.
        """
        if location_name not in self.locations:
            return None

        lat, lon = self.locations[location_name]
        batch = self.create_wave_batch(lat, lon)

        # Index of the location of interest on both grid axes.
        # NOTE(review): assumes any cropping the model applies only drops
        # trailing rows/columns, so this index still addresses the centre.
        center_idx = len(batch.metadata.lat) // 2

        # Stamp results relative to the batch's own time rather than calling
        # datetime.now() again, which would drift by the inference duration.
        base_time = batch.metadata.time[0]

        predictions = []
        with torch.inference_mode():
            # rollout() feeds each prediction back together with the previous
            # step as history. Passing the bare prediction back to forward()
            # would hand the model a single time step instead of two.
            for step, pred in enumerate(
                rollout(self.model, batch, steps=steps), start=1
            ):
                readings = {}
                for var_code in self.variables:
                    raw = pred.surf_vars[var_code][
                        0, 0, center_idx, center_idx
                    ].item()
                    readings[var_code] = round(self._constrain(var_code, raw), 2)

                predictions.append(
                    {
                        "timestamp": base_time
                        + timedelta(hours=self.STEP_HOURS * step),
                        "step": step,
                        "predictions": readings,
                    }
                )

        return predictions