Spaces:
Sleeping
Sleeping
| // C++11 | |
| const int DEBUG_MODE = 0; | |
| namespace tree { | |
| class CNode { | |
| public: | |
| int visit_count, to_play, current_latent_state_index, batch_index, best_action, is_reset; | |
| float value_prefix, prior, value_sum; | |
| float parent_value_prefix; | |
| std::vector<int> children_index; | |
| std::map<int, CNode> children; | |
| std::vector<int> legal_actions; | |
| CNode(); | |
| CNode(float prior, std::vector<int> &legal_actions); | |
| ~CNode(); | |
| void expand(int to_play, int current_latent_state_index, int batch_index, float value_prefix, const std::vector<float> &policy_logits); | |
| void add_exploration_noise(float exploration_fraction, const std::vector<float> &noises); | |
| float compute_mean_q(int isRoot, float parent_q, float discount_factor); | |
| void print_out(); | |
| int expanded(); | |
| float value(); | |
| std::vector<int> get_trajectory(); | |
| std::vector<int> get_children_distribution(); | |
| CNode* get_child(int action); | |
| }; | |
| class CRoots{ | |
| public: | |
| int root_num; | |
| std::vector<CNode> roots; | |
| std::vector<std::vector<int> > legal_actions_list; | |
| CRoots(); | |
| CRoots(int root_num, std::vector<std::vector<int> > &legal_actions_list); | |
| ~CRoots(); | |
| void prepare(float root_noise_weight, const std::vector<std::vector<float> > &noises, const std::vector<float> &value_prefixs, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch); | |
| void prepare_no_noise(const std::vector<float> &value_prefixs, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch); | |
| void clear(); | |
| std::vector<std::vector<int> > get_trajectories(); | |
| std::vector<std::vector<int> > get_distributions(); | |
| std::vector<float> get_values(); | |
| CNode* get_root(int index); | |
| }; | |
| class CSearchResults{ | |
| public: | |
| int num; | |
| std::vector<int> latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, search_lens; | |
| std::vector<int> virtual_to_play_batchs; | |
| std::vector<CNode*> nodes; | |
| std::vector<std::vector<CNode*> > search_paths; | |
| CSearchResults(); | |
| CSearchResults(int num); | |
| ~CSearchResults(); | |
| }; | |
| //********************************************************* | |
| void update_tree_q(CNode* root, tools::CMinMaxStats &min_max_stats, float discount_factor, int players); | |
| void cbackpropagate(std::vector<CNode*> &search_path, tools::CMinMaxStats &min_max_stats, int to_play, float value, float discount_factor); | |
| void cbatch_backpropagate(int current_latent_state_index, float discount_factor, const std::vector<float> &value_prefixs, const std::vector<float> &values, const std::vector<std::vector<float> > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> is_reset_list, std::vector<int> &to_play_batch); | |
| int cselect_child(CNode* root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players); | |
| float cucb_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, int is_reset, float total_children_visit_counts, float parent_value_prefix, float pb_c_base, float pb_c_init, float discount_factor, int players); | |
| void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &virtual_to_play_batch); | |
| } | |