Spaces:
Running
Running
/*
  Stockfish, a UCI chess playing engine derived from Glaurung 2.1
  Copyright (C) 2004-2026 The Stockfish developers (see AUTHORS file)

  Stockfish is free software: you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation, either version 3 of the License, or
  (at your option) any later version.

  Stockfish is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
// Forward declaration only: this header refers to Position exclusively by
// const reference (see Network::evaluate / Network::trace_evaluate), so the
// full class definition is not required here.
namespace Stockfish {

class Position;

}
| namespace Stockfish::Eval::NNUE { | |
| enum class EmbeddedNNUEType { | |
| BIG, | |
| SMALL, | |
| }; | |
| using NetworkOutput = std::tuple<Value, Value>; | |
| // The network must be a trivial type, i.e. the memory must be in-line. | |
| // This is required to allow sharing the network via shared memory, as | |
| // there is no way to run destructors. | |
| template<typename Arch, typename Transformer> | |
| class Network { | |
| static constexpr IndexType FTDimensions = Arch::TransformedFeatureDimensions; | |
| public: | |
| Network(EvalFile file, EmbeddedNNUEType type) : | |
| evalFile(file), | |
| embeddedType(type) {} | |
| Network(const Network& other) = default; | |
| Network(Network&& other) = default; | |
| Network& operator=(const Network& other) = default; | |
| Network& operator=(Network&& other) = default; | |
| void load(const std::string& rootDirectory, std::string evalfilePath); | |
| bool save(const std::optional<std::string>& filename) const; | |
| std::size_t get_content_hash() const; | |
| NetworkOutput evaluate(const Position& pos, | |
| AccumulatorStack& accumulatorStack, | |
| AccumulatorCaches::Cache<FTDimensions>& cache) const; | |
| void verify(std::string evalfilePath, const std::function<void(std::string_view)>&) const; | |
| NnueEvalTrace trace_evaluate(const Position& pos, | |
| AccumulatorStack& accumulatorStack, | |
| AccumulatorCaches::Cache<FTDimensions>& cache) const; | |
| private: | |
| void load_user_net(const std::string&, const std::string&); | |
| void load_internal(); | |
| void initialize(); | |
| bool save(std::ostream&, const std::string&, const std::string&) const; | |
| std::optional<std::string> load(std::istream&); | |
| bool read_header(std::istream&, std::uint32_t*, std::string*) const; | |
| bool write_header(std::ostream&, std::uint32_t, const std::string&) const; | |
| bool read_parameters(std::istream&, std::string&); | |
| bool write_parameters(std::ostream&, const std::string&) const; | |
| // Input feature converter | |
| Transformer featureTransformer; | |
| // Evaluation function | |
| Arch network[LayerStacks]; | |
| EvalFile evalFile; | |
| EmbeddedNNUEType embeddedType; | |
| bool initialized = false; | |
| // Hash value of evaluation function structure | |
| static constexpr std::uint32_t hash = Transformer::get_hash_value() ^ Arch::get_hash_value(); | |
| template<IndexType Size> | |
| friend struct AccumulatorCaches::Cache; | |
| }; | |
| // Definitions of the network types | |
| using SmallFeatureTransformer = FeatureTransformer<TransformedFeatureDimensionsSmall>; | |
| using SmallNetworkArchitecture = | |
| NetworkArchitecture<TransformedFeatureDimensionsSmall, L2Small, L3Small>; | |
| using BigFeatureTransformer = FeatureTransformer<TransformedFeatureDimensionsBig>; | |
| using BigNetworkArchitecture = NetworkArchitecture<TransformedFeatureDimensionsBig, L2Big, L3Big>; | |
| using NetworkBig = Network<BigNetworkArchitecture, BigFeatureTransformer>; | |
| using NetworkSmall = Network<SmallNetworkArchitecture, SmallFeatureTransformer>; | |
| struct Networks { | |
| Networks(EvalFile bigFile, EvalFile smallFile) : | |
| big(bigFile, EmbeddedNNUEType::BIG), | |
| small(smallFile, EmbeddedNNUEType::SMALL) {} | |
| NetworkBig big; | |
| NetworkSmall small; | |
| }; | |
| } // namespace Stockfish | |
| template<typename ArchT, typename FeatureTransformerT> | |
| struct std::hash<Stockfish::Eval::NNUE::Network<ArchT, FeatureTransformerT>> { | |
| std::size_t operator()( | |
| const Stockfish::Eval::NNUE::Network<ArchT, FeatureTransformerT>& network) const noexcept { | |
| return network.get_content_hash(); | |
| } | |
| }; | |
| template<> | |
| struct std::hash<Stockfish::Eval::NNUE::Networks> { | |
| std::size_t operator()(const Stockfish::Eval::NNUE::Networks& networks) const noexcept { | |
| std::size_t h = 0; | |
| Stockfish::hash_combine(h, networks.big); | |
| Stockfish::hash_combine(h, networks.small); | |
| return h; | |
| } | |
| }; | |