/* Stockfish, a UCI chess playing engine derived from Glaurung 2.1 Copyright (C) 2004-2026 The Stockfish developers (see AUTHORS file) Stockfish is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. Stockfish is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see . */ #ifndef NETWORK_H_INCLUDED #define NETWORK_H_INCLUDED #include #include #include #include #include #include #include #include #include #include "../misc.h" #include "../types.h" #include "nnue_accumulator.h" #include "nnue_architecture.h" #include "nnue_common.h" #include "nnue_feature_transformer.h" #include "nnue_misc.h" namespace Stockfish { class Position; } namespace Stockfish::Eval::NNUE { enum class EmbeddedNNUEType { BIG, SMALL, }; using NetworkOutput = std::tuple; // The network must be a trivial type, i.e. the memory must be in-line. // This is required to allow sharing the network via shared memory, as // there is no way to run destructors. template class Network { static constexpr IndexType FTDimensions = Arch::TransformedFeatureDimensions; public: Network(EvalFile file, EmbeddedNNUEType type) : evalFile(file), embeddedType(type) {} Network(const Network& other) = default; Network(Network&& other) = default; Network& operator=(const Network& other) = default; Network& operator=(Network&& other) = default; void load(const std::string& rootDirectory, std::string evalfilePath); bool save(const std::optional& filename) const; std::size_t get_content_hash() const; NetworkOutput evaluate(const Position& pos, AccumulatorStack& accumulatorStack, AccumulatorCaches::Cache& cache) const; void verify(std::string evalfilePath, const std::function&) const; NnueEvalTrace trace_evaluate(const Position& pos, AccumulatorStack& accumulatorStack, AccumulatorCaches::Cache& cache) const; private: void load_user_net(const std::string&, const std::string&); void load_internal(); void initialize(); bool save(std::ostream&, const std::string&, const std::string&) const; std::optional load(std::istream&); bool read_header(std::istream&, std::uint32_t*, std::string*) const; bool write_header(std::ostream&, std::uint32_t, const std::string&) const; bool read_parameters(std::istream&, std::string&); bool write_parameters(std::ostream&, const std::string&) const; // Input feature converter Transformer featureTransformer; // Evaluation function Arch network[LayerStacks]; EvalFile evalFile; EmbeddedNNUEType embeddedType; bool initialized = false; // Hash value of evaluation function structure static constexpr std::uint32_t hash = Transformer::get_hash_value() ^ Arch::get_hash_value(); template friend struct AccumulatorCaches::Cache; }; // Definitions of the network types using SmallFeatureTransformer = FeatureTransformer; using SmallNetworkArchitecture = NetworkArchitecture; using BigFeatureTransformer = FeatureTransformer; using BigNetworkArchitecture = NetworkArchitecture; using NetworkBig = Network; using NetworkSmall = Network; struct Networks { Networks(EvalFile bigFile, EvalFile smallFile) : big(bigFile, EmbeddedNNUEType::BIG), small(smallFile, EmbeddedNNUEType::SMALL) {} NetworkBig big; NetworkSmall small; }; } // namespace Stockfish template struct std::hash> { std::size_t operator()( const Stockfish::Eval::NNUE::Network& network) const noexcept { return network.get_content_hash(); } }; template<> struct std::hash { std::size_t operator()(const Stockfish::Eval::NNUE::Networks& networks) const noexcept { std::size_t h = 0; Stockfish::hash_combine(h, networks.big); Stockfish::hash_combine(h, networks.small); return h; } }; #endif