SymTorch: Symbolic Distillation of Neural Networks
What mathematical functions do neural network components learn? Symbolic distillation addresses this question by expressing neural network components as interpretable, closed-form mathematical expressions that expose the functional structure learned during training. We develop symbolic distillation as a systematic, architecture-agnostic methodology and release it as the open-source SymTorch package, a PySR-powered library built natively for the PyTorch ecosystem. Applying this methodology across diverse architectures, we find that SymTorch enables the automated discovery of physical laws. Specifically, our approach (1) recovers pairwise interaction forces from graph neural networks trained on empirical n-body observations, (2) distills the exact closed-form PDE/ODE solutions of multiple physical systems, including the values of physical constants, from physics-informed neural networks trained on sparse data, and (3) uncovers the chaotic dynamics of the Lorenz system from high-dimensional data, ultimately outperforming the base neural network on downstream prediction tasks. We further demonstrate the utility of our framework for model interpretability by providing an optimized implementation of SLIME, a symbolic extension of the LIME explainability method. SLIME consistently outperforms LIME on predictive metrics across eight popular classification and regression benchmarks, while still providing an interpretable local symbolic model. Lastly, we investigate replacing transformer MLP layers with symbolic surrogates: replacing 1-7 layers with symbolic approximations yields 2-19\% throughput improvements and up to 18.7\% VRAM reduction, with the resulting hybrid models lying on the Pareto front of throughput versus perplexity among open-source LLMs of comparable scale.
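To make the core idea of symbolic distillation concrete, the sketch below illustrates it in miniature; this is not SymTorch's actual API, but a minimal, assumption-laden stand-in. A black-box function (here a simple polynomial standing in for a trained network component) is sampled, and a sparse linear fit over a hypothetical library of candidate terms recovers a closed-form expression, in the spirit of sparse symbolic regression.

```python
import numpy as np

# Illustrative sketch only (NOT SymTorch's API): distill a black-box
# function into a sparse closed-form expression via least squares over
# a candidate term library, then prune small coefficients.

def blackbox(x):
    # Stand-in for a trained neural-network component to be distilled.
    return 1.5 * x**2 - 0.7 * x

def distill(f, n_samples=200, threshold=0.05, seed=0):
    rng = np.random.default_rng(seed)
    x = rng.uniform(-2.0, 2.0, n_samples)
    y = f(x)
    # Hypothetical candidate library: 1, x, x^2, x^3.
    names = ["1", "x", "x^2", "x^3"]
    library = np.column_stack([np.ones_like(x), x, x**2, x**3])
    coeffs, *_ = np.linalg.lstsq(library, y, rcond=None)
    coeffs[np.abs(coeffs) < threshold] = 0.0  # enforce sparsity
    expr = " + ".join(f"{c:.2f}*{n}" for c, n in zip(coeffs, names) if c != 0.0)
    return coeffs, expr

coeffs, expr = distill(blackbox)
print(expr)  # recovers approximately -0.70*x + 1.50*x^2
```

SymTorch itself uses PySR's evolutionary symbolic regression rather than a fixed linear library, which lets it discover nonlinear functional forms (divisions, exponentials, compositions) that a linear fit over predefined terms cannot express.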
