| # Authors: The scikit-learn developers | |
| # SPDX-License-Identifier: BSD-3-Clause | |
| from cython cimport floating | |
| from libc.math cimport fabs | |
def _update_cdnmf_fast(floating[:, ::1] W, floating[:, :] HHt,
                       floating[:, :] XHt, Py_ssize_t[::1] permutation):
    """Run one coordinate-descent sweep over the columns of ``W``, in place.

    Columns are visited in the order ``t = permutation[s]`` for
    ``s in range(n_components)``. For every row ``i`` of column ``t``:

    - the partial gradient is
      ``grad = sum_r HHt[t, r] * W[i, r] - XHt[i, t]``. This matches
      ``np.dot(W, HHt) - XHt`` at ``[i, t]`` when ``HHt`` is symmetric, as
      the name (H @ H.T) suggests.
    - the projected gradient (``min(0, grad)`` when ``W[i, t] == 0``,
      otherwise ``grad``) is added in absolute value to the violation.
    - ``W[i, t]`` becomes ``max(W[i, t] - grad / HHt[t, t], 0)``. The
      update is skipped when ``HHt[t, t] == 0``.

    The same routine is presumably reused for the H update with transposed
    arguments (see the ``n_samples`` comment), in which case rows of ``W``
    are features.

    Parameters
    ----------
    W : floating[:, ::1]
        C-contiguous array of shape (n_samples, n_components). It is
        modified in place.
    HHt : floating[:, :]
        Indexed as ``HHt[t, r]`` for ``t, r < n_components``. Its diagonal
        gives the per-coordinate Hessian.
    XHt : floating[:, :]
        Indexed as ``XHt[i, t]`` for ``i < n_samples``, ``t < n_components``.
    permutation : Py_ssize_t[::1]
        Column visiting order. Its first ``n_components`` entries are used
        as column indices into ``W``.

    Returns
    -------
    violation : floating
        Sum of ``|projected gradient|`` over all visited entries. Each term
        is evaluated just before that entry is updated.
    """
    cdef:
        floating violation = 0
        Py_ssize_t n_components = W.shape[1]
        Py_ssize_t n_samples = W.shape[0]  # n_features for H update
        floating grad, pg, hess
        Py_ssize_t i, r, s, t

    with nogil:
        for s in range(n_components):
            t = permutation[s]

            # Hessian w.r.t. W[i, t]. It does not depend on i, and only W is
            # written below, so read it once per column instead of once per
            # row (assumes W does not share memory with HHt).
            hess = HHt[t, t]

            for i in range(n_samples):
                # gradient = GW[i, t] where GW = np.dot(W, HHt) - XHt
                grad = -XHt[i, t]
                for r in range(n_components):
                    grad += HHt[t, r] * W[i, r]

                # Projected gradient: at the lower bound W[i, t] == 0, a
                # positive gradient would only push W below zero (where it is
                # clipped, assuming hess > 0), so it is not a violation.
                pg = min(0., grad) if W[i, t] == 0 else grad
                violation += fabs(pg)

                # Newton step, projected back onto W >= 0.
                if hess != 0:
                    W[i, t] = max(W[i, t] - grad / hess, 0.)

    return violation