Parallelisation / Execution on TPUs
Hi guys!
I was trying to run multiple ensemble members and thinking about how best to optimise for parallel execution. AWS is not great with "single-GPU" instances, so I was wondering what strategy you guys use to cost-optimise the execution of the 50 ensemble members?
Also, has anyone tried running this model on Google TPUs?
Thank you so much!
Jane
Hi, check this out: https://huggingface.co/ecmwf/aifs-ens-1.0/discussions/17#692c03ed1221c480154e065f. It works on a 24GB GPU, which opens the door to running each member as a Cloud Run microservice. TPU could be great, but the PyTorch flash attention that anemoi uses would be a blocker.
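For the Cloud Run route, a minimal sketch of what one service instance could look like, wrapping a single member's inference behind an HTTP endpoint. The `run_member()` helper is hypothetical, standing in for whatever your actual anemoi inference call is:

```python
import os
from flask import Flask, jsonify, request

app = Flask(__name__)

def run_member(member_id: int) -> str:
    # Hypothetical wrapper: load the checkpoint, run inference for one
    # ensemble member, write the output to object storage, return its path.
    raise NotImplementedError

@app.route("/run", methods=["POST"])
def run():
    member_id = int(request.get_json()["member"])
    output_path = run_member(member_id)
    return jsonify({"member": member_id, "output": output_path})

if __name__ == "__main__":
    # Cloud Run injects the port to listen on via $PORT.
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", "8080")))
```

You'd then fan out 50 POST requests (one per member) and let Cloud Run scale the instances.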
> Also, has anyone tried running this model on Google TPUs?
I have tried, without any success. TPU v3 and lower don't support Torch's DistributedDataParallel module, which anemoi uses. Newer TPUs weren't (as of Feb 2025) supported by Slurm, which we were using for our runs (via Google's Cluster Toolkit service). The only option is to use some serverless execution model like GCP Vertex, and I didn't have time to port my setup.
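For reference, submitting a containerised run as a Vertex AI custom job looks roughly like the sketch below. The project, bucket, image URI, and TPU machine type are placeholders, and I haven't verified this against the current anemoi stack:

```python
from google.cloud import aiplatform

aiplatform.init(
    project="my-project",                    # placeholder project ID
    location="us-central1",
    staging_bucket="gs://my-staging-bucket",  # placeholder bucket
)

# One custom job per ensemble member; "ct5lp-hightpu-4t" is a TPU v5e
# machine type, and the image would need the (ported) inference code baked in.
job = aiplatform.CustomJob(
    display_name="aifs-ens-member-00",
    worker_pool_specs=[{
        "machine_spec": {"machine_type": "ct5lp-hightpu-4t"},
        "replica_count": 1,
        "container_spec": {
            "image_uri": "us-docker.pkg.dev/my-project/aifs/inference:latest",
            "args": ["--member", "0"],
        },
    }],
)
job.run()  # blocks until the job finishes; job.submit() is fire-and-forget
```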
Indeed, the lack of flash attention would be a problem, but perhaps there is a sliding-window attention function implemented in JAX that one could switch to. Who knows how many other issues there could be :D
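The starting point would be something like this naive masked version of sliding-window attention in JAX. It materialises the full score matrix, so it has none of flash attention's memory savings; the shapes and window semantics are my assumptions:

```python
import jax
import jax.numpy as jnp

def sliding_window_attention(q, k, v, window: int):
    """Banded attention: query i attends to keys j with |i - j| <= window.

    q, k, v: arrays of shape (seq_len, num_heads, head_dim).
    """
    seq_len, _, head_dim = q.shape
    scores = jnp.einsum("qhd,khd->hqk", q, k) / jnp.sqrt(head_dim)
    # Band mask keeping only positions within the window; the diagonal is
    # always kept, so every softmax row has at least one finite entry.
    i = jnp.arange(seq_len)[:, None]
    j = jnp.arange(seq_len)[None, :]
    band = jnp.abs(i - j) <= window
    scores = jnp.where(band, scores, -jnp.inf)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("hqk,khd->qhd", weights, v)
```

Whether XLA compiles this into something memory-efficient enough on TPU is a separate question.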
> cost-optimise the execution of the 50 ensemble members
Ensemble members are trivially parallel during inference, so one could set up batch jobs running on older-generation spot GPUs. A sketch follows below.
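For example, with an AWS Batch array job of size 50 on a spot GPU compute environment, each child task picks its member from the array index. `AWS_BATCH_JOB_ARRAY_INDEX` is the documented env var; `run_member.py` is a hypothetical wrapper around your inference entry point:

```python
import os
import subprocess

# AWS Batch exposes the array child index via this env var; submit the job
# with array size 50 so each child task handles one ensemble member.
member = int(os.environ["AWS_BATCH_JOB_ARRAY_INDEX"])

# run_member.py is a hypothetical script that runs inference for a single
# ensemble member and uploads its output to object storage.
subprocess.run(["python", "run_member.py", "--member", str(member)], check=True)
```

Spot interruptions only cost you the members that were running at the time, and Batch can retry those tasks individually.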