\documentclass{article}
\usepackage{graphicx}
\usepackage{hyperref}
\usepackage{amsmath}
\usepackage{caption}
\usepackage{tgtermes}
\usepackage{float}
\usepackage[a4paper, margin=1in]{geometry}
\usepackage{booktabs}
\usepackage{algorithm}
\usepackage{algorithmicx}
\usepackage{algpseudocode}
\date{}

\begin{document}

{\LARGE \bfseries Parallelize Muon with FSDP2 \par}
\vspace{1em}

\section*{Motivation}

\begin{figure}[H]
\centering
\includegraphics[width=0.8\textwidth]{distributed_muon.png}
\caption*{Distributed Muon by Moonlight}
\end{figure}

While a distributed version of Muon is available, it performs redundant computation across GPUs.

\begin{figure}[H]
\centering
\includegraphics[width=1.0\textwidth]{distributed_muon_execution.png}
\caption*{Execution timeline of Distributed Muon}
\end{figure}

\begin{itemize}
\item \texttt{C[i]}: compute Newton--Schulz$(G)$ for the $i$-th gradient
\item \texttt{AG[i]}: AllGather the $i$-th gradient
\item \texttt{G[i]}: Gather the $i$-th gradient
\item \texttt{SC[i]}: Scatter the $i$-th gradient
\end{itemize}
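The \texttt{C[i]} step is the orthogonalization iteration itself. A minimal single-GPU sketch in NumPy follows; the quintic coefficients come from the public Muon reference implementation and are an assumption here, not something this document specifies.

```python
import numpy as np

def newton_schulz(G, steps=5, eps=1e-7):
    """Approximately orthogonalize G with a quintic Newton-Schulz iteration.

    Coefficients follow the public Muon reference implementation (an
    assumption; this document does not list them).
    """
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (np.linalg.norm(G) + eps)    # scale so all singular values are <= 1
    transposed = X.shape[0] > X.shape[1]
    if transposed:                       # iterate on the wide orientation
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * (A @ A)
        X = a * X + B @ X                # X <- a X + b (XX^T) X + c (XX^T)^2 X
    return X.T if transposed else X
```

After a few iterations the singular values of the output cluster near 1, i.e.\ the result is approximately semi-orthogonal.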
\clearpage
\section*{Algorithm}

\subsection*{Parallel Muon}

\begin{algorithm}
\caption{Parallel Muon}
\textbf{Require:} DP-partitioned gradient $\mathbf{g}$, DP-partitioned momentum $\mathbf{m}$, DP-partitioned parameter $\mathbf{p}$, momentum coefficient $\mu$, local rank $\mathbf{r}$
\begin{algorithmic}[1]
\State \texttt{// Apply momentum to $\mathbf{g}$ using the local partitioned momentum $\mathbf{m}$}
\State $\mathbf{g'} \gets \text{update\_with\_momentum}(\mathbf{g}, \mathbf{m}, \mu)$
\State \texttt{// Assign $\mathbf{g'}$ to a rank $\mathbf{R}$}
\State $\mathbf{R} \gets \text{schedule}(\mathbf{g'}, \text{dp\_group})$
\State \texttt{// Gather $\mathbf{g'}$ across DP into a full matrix $\mathbf{G}$ on rank $\mathbf{R}$}
\State $\mathbf{G} \gets \text{gather}(\mathbf{g'}, \text{dp\_group}, \text{dst=}\mathbf{R})$
\State \texttt{// Run Newton--Schulz only on rank $\mathbf{R}$}
\If{$\mathbf{r} = \mathbf{R}$}
\State $\mathbf{u} \gets \text{Newton--Schulz}(\mathbf{G})$
\Else
\State $\mathbf{u} \gets \text{None}$
\EndIf
\State \texttt{// Scatter the full matrix $\mathbf{u}$ across DP}
\State $\mathbf{u'} \gets \text{scatter}(\mathbf{u}, \text{dp\_group}, \text{src=}\mathbf{R})$
\State \texttt{// Apply the DP-partitioned update $\mathbf{u'}$ to $\mathbf{p}$}
\State $\mathbf{p'} \gets \text{apply\_update}(\mathbf{p}, \mathbf{u'})$
\State \textbf{return} $\mathbf{p'}$
\end{algorithmic}
\end{algorithm}
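To make the data flow concrete, here is a single-process NumPy simulation of Algorithm 1. List entry $i$ plays the role of rank $i$'s DP shard, concatenate/split stand in for the gather/scatter collectives, and an exact SVD orthogonalization stands in for Newton--Schulz; all names are illustrative, not the code in \texttt{muon.py}.

```python
import numpy as np

def orthogonalize(G):
    # Exact SVD-based stand-in for Newton-Schulz, used here for brevity.
    U, _, Vt = np.linalg.svd(G, full_matrices=False)
    return U @ Vt

def parallel_muon_step(p_shards, g_shards, m_shards, mu=0.95, lr=0.02, R=0):
    """Single-process simulation of Algorithm 1 (simple heavy-ball momentum)."""
    world = len(g_shards)
    # Momentum update on each rank's local shard.
    g_prime = [mu * m + g for m, g in zip(m_shards, g_shards)]
    # Gather the shards into the full matrix G "on rank R".
    G = np.concatenate(g_prime, axis=0)           # dist.gather(..., dst=R)
    # Newton-Schulz runs only on rank R; all other ranks would hold None.
    u_full = orthogonalize(G)
    # Scatter the full update back so each rank receives its own shard.
    u_shards = np.split(u_full, world, axis=0)    # dist.scatter(..., src=R)
    # Apply the DP-partitioned update locally on every rank.
    return [p - lr * u for p, u in zip(p_shards, u_shards)]
```

In the real implementation the two commented collectives are \texttt{torch.distributed} calls over the DP process group, and each rank holds only its own list entry.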

We eliminate redundant computation by assigning each parameter to a specific GPU.

However, without proper scheduling, this optimization can lead to poor GPU utilization. Although assigning each parameter to a single rank avoids redundant computation, it introduces idle time: every other rank must wait for the scatter communication to complete before proceeding.

\begin{figure}[H]
\centering
\includegraphics[width=1.0\textwidth]{naive_execution.png}
\caption*{Execution timeline of Parallel Muon}
\end{figure}

\subsection*{Scheduling Sub-Operations}

We can reschedule the sub-operations freely, for two reasons:
\begin{itemize}
\item There are no dependencies between parameters.
\item GPUs can execute computation and communication concurrently.
\end{itemize}

\begin{figure}[H]
\centering
\includegraphics[width=1.0\textwidth]{pipelined.png}
\caption*{Execution timeline of the re-scheduled Parallel Muon}
\end{figure}

We define the chunk size $C$ as the number of GPUs and schedule each sub-operation in batches of size $C$. This allows each GPU to continue computing even while waiting for collective communication to complete.
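One way to express this pipeline (a sketch with illustrative names, not the code in \texttt{muon.py}): at each step, the compute of chunk $k$ overlaps with the gather of chunk $k+1$ and the scatter of chunk $k-1$.

```python
def pipelined_schedule(num_params, num_gpus):
    """Return a list of steps; each step is a list of (op, param_index)
    pairs that can run concurrently (compute on one stream, collectives
    on another)."""
    C = num_gpus  # chunk size = number of GPUs, as in the text
    chunks = [list(range(i, min(i + C, num_params)))
              for i in range(0, num_params, C)]
    steps = [[("gather", p) for p in chunks[0]]]            # prologue
    for k, chunk in enumerate(chunks):
        step = [("compute", p) for p in chunk]              # NS on chunk k
        if k + 1 < len(chunks):
            step += [("gather", p) for p in chunks[k + 1]]  # prefetch k+1
        if k > 0:
            step += [("scatter", p) for p in chunks[k - 1]] # drain k-1
        steps.append(step)
    steps.append([("scatter", p) for p in chunks[-1]])      # epilogue
    return steps
```

Every parameter is still gathered before it is computed and computed before it is scattered, but apart from the prologue and epilogue, no step is communication-only.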

\textbf{[Algorithm]} (To be written)
\clearpage
\subsection*{Load Balancing}

If the parameters in a chunk have imbalanced computational loads, idle bubbles may occur. \\
To mitigate this, we apply load balancing based on per-parameter FLOPs.
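A common way to do this is greedy longest-processing-time assignment: sort parameters by estimated FLOPs and always place the next one on the currently least-loaded GPU. The sketch below illustrates the idea; the cost model actually used in \texttt{muon.py} may differ.

```python
import heapq

def balance_by_flops(flops, num_gpus):
    """Assign parameter indices to GPUs, heaviest first, each to the
    currently least-loaded GPU (greedy LPT bin packing)."""
    heap = [(0.0, gpu) for gpu in range(num_gpus)]  # (load, gpu)
    heapq.heapify(heap)
    assignment = {}
    for idx in sorted(range(len(flops)), key=lambda i: -flops[i]):
        load, gpu = heapq.heappop(heap)             # least-loaded GPU
        assignment[idx] = gpu
        heapq.heappush(heap, (load + flops[idx], gpu))
    return assignment
```

LPT is a classic makespan heuristic: it is not optimal in general, but it keeps the per-GPU loads within a small factor of the best possible and is cheap enough to run every step.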

\vspace{1em}
\textbf{Imbalanced (Round Robin)}

\begin{figure}[H]
\centering
\includegraphics[width=1.0\textwidth]{imbalance.png}
\end{figure}

\textbf{After Load Balancing}

\begin{figure}[H]
\centering
\includegraphics[width=1.0\textwidth]{balanced.png}
\end{figure}

\section*{Implementation}

The full implementation is available in \texttt{optimizer/torch-ext/optimizer/muon.py}.
To enable concurrent computation and communication, we use separate compute and communication streams (\texttt{torch.cuda.Stream}) and synchronize sub-operations with \texttt{torch.cuda.Event}.

Thanks to the simplicity of \texttt{torch.DTensor} and \texttt{torch.distributed}, the implementation remains straightforward.

\section*{Evaluation}
We evaluated performance on a 10B-parameter model currently in development, achieving 151 TFLOPS per GPU during the optimizer step.

\begin{table}[H]
\centering
\begin{tabular}{@{}lllll@{}}
\toprule
Model size & Muon TFLOPs & GPUs & Elapsed time & TFLOPS/GPU \\
\midrule
10B & 847.45 & 4$\times$MI250 (8 devices) & 1.4 s & 151 \\
\bottomrule
\end{tabular}
\end{table}
Based on the breakdown, 7\% of the time is spent updating sharded gradients and parameters, 78\% on GEMM operations, and the remaining 15\% on non-overlapped communication overhead.

\end{document}