import numpy as np
import torch
import torch.nn.functional as F
import random as rd

from utils.app import PeptideAnalyzer
from utils.timer import StepTimer
from scoring.scoring_functions import ScoringFunctions

import noise_schedule


def dominates(a, b):
    """True iff `a` Pareto-dominates `b`: at least as good in every objective
    and strictly better in at least one (maximization convention)."""
    a = np.asarray(a); b = np.asarray(b)
    return np.all(a >= b) and np.any(a > b)


def dominated_by(a, b):
    """True iff `a` is Pareto-dominated by `b`."""
    return dominates(b, a)
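
# A quick illustration of the dominance convention (scores are maximized):
#   dominates([2, 1], [1, 1])     # True: no worse anywhere, better somewhere
#   dominates([2, 0], [1, 1])     # False: worse in the second objective
#   dominated_by([1, 1], [2, 1])  # True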


def updateParetoFront(paretoFront, node, scoreVector, totalSize=None, eps=1e-12):
    """
    Maintain a non-dominated set (Pareto front) of (node -> scoreVector).

    - Accept 'node' iff it is NOT dominated by any node in the set.
    - Remove any nodes that ARE dominated by 'node'.
    - Skip insertion if an equal point already exists (within eps).
    - If totalSize is given and the archive exceeds it, drop the item
      with the smallest sum(scoreVector) as a simple tie-breaker.

    Args:
        paretoFront (dict): {node: scoreVector}
        node: candidate node (used as dict key)
        scoreVector (array-like): candidate scores (to be maximized)
        totalSize (int|None): optional max size for the archive
        eps (float): tolerance for equality/inequality checks

    Returns:
        dict: updated paretoFront
    """
    s = np.asarray(scoreVector, dtype=float)

    def dominates(a, b):
        # eps-tolerant dominance; intentionally shadows the module-level helper
        return np.all(a >= b - eps) and np.any(a > b + eps)

    def equal(a, b):
        return np.all(np.abs(a - b) <= eps)

    # Reject the candidate if any incumbent dominates it; remember whether an
    # (eps-)equal point is already present so we can skip a duplicate insert.
    has_equal = False
    for v in paretoFront.values():
        v = np.asarray(v, dtype=float)
        if dominates(v, s):
            return paretoFront
        if equal(v, s):
            has_equal = True

    # Keep only incumbents that the candidate does not dominate.
    survivors = {}
    for k, v in paretoFront.items():
        v_arr = np.asarray(v, dtype=float)
        if dominates(s, v_arr):
            continue
        survivors[k] = v_arr

    if has_equal:
        return survivors

    survivors[node] = s

    # Enforce the optional size cap with a simple sum-of-scores tie-breaker.
    if totalSize is not None and totalSize > 0 and len(survivors) > totalSize:
        keys = list(survivors.keys())
        sums = np.array([np.sum(np.asarray(survivors[k], dtype=float)) for k in keys])
        drop_idx = int(np.argmin(sums))
        del survivors[keys[drop_idx]]

    return survivors
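
# Usage sketch: the archive keeps only mutually non-dominated points.
#   pf = {}
#   pf = updateParetoFront(pf, "a", [0.9, 0.1])
#   pf = updateParetoFront(pf, "b", [0.1, 0.9])   # kept: a trade-off with "a"
#   pf = updateParetoFront(pf, "c", [0.95, 0.2])  # kept, and it evicts "a"
#   pf = updateParetoFront(pf, "d", [0.1, 0.1])   # rejected: dominated by "b"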


class Node:
    """
    Node class: a partially unmasked sequence in the search tree.
    - parentNode: Node object at the previous time step
    - childNodes: list of up to M Node objects generated from sampling M distinct unmasking schemes
    - totalReward: vector of cumulative rewards for all K objectives
    - visits: number of times the node has been visited by an iteration
    - tokens: dict holding the node's (partially unmasked) token sequence and attention mask
    - log_rnd: accumulated log-ratio (pretrained vs. policy) along the path from the root
    - log_policy_step / log_pretrained_step: per-step log-probs under the policy / pretrained model
    - timestep: the time step at which the sequence was sampled
    """
    def __init__(self, args, tokens=None, log_rnd=None, log_policy_step=None, log_pretrained_step=None, parentNode=None, childNodes=None, totalReward=None, timestep=None):
        self.args = args
        self.parentNode = parentNode
        self.childNodes = [] if childNodes is None else childNodes

        self.log_rnd = log_rnd
        self.log_policy_step = log_policy_step
        self.log_pretrained_step = log_pretrained_step

        if totalReward is not None:
            self.totalReward = totalReward
        else:
            self.totalReward = np.zeros(self.args.num_obj)

        self.visits = 1
        self.timestep = timestep
        self.tokens = tokens

    def selectNode(self):
        """
        Selects a node to move to among the children nodes based on select score.
        """
        nodeStatus = self.getExpandStatus()

        if (nodeStatus == 3):
            # Build a Pareto front over the expandable children and sample
            # uniformly from the non-dominated set.
            paretoFront = {}
            for childNode in self.childNodes:
                childStatus = childNode.getExpandStatus()
                if childStatus == 2 or childStatus == 3:
                    selectScore = childNode.calcSelectScore()
                    paretoFront = updateParetoFront(paretoFront, childNode, selectScore)

            if not paretoFront:
                # Defensive fallback: every child is terminal, so none entered
                # the front; pick any child so selection can still terminate.
                selected = rd.choice(self.childNodes)
            else:
                selected = rd.choice(list(paretoFront.keys()))

            return selected, selected.getExpandStatus()

        return self, nodeStatus

    def addChildNode(self, tokens, log_rnd, log_policy_step, log_pretrained_step, totalReward):
        """
        Adds a child node:
        log_rnd: log_rnd of the path up to the added child node
        log_policy_step: scalar log-prob of sampling the step under the policy
        log_pretrained_step: scalar log-prob of sampling the step under the pretrained model
        """
        child = Node(args=self.args,
                     tokens=tokens,
                     log_rnd=log_rnd,
                     log_policy_step=log_policy_step,
                     log_pretrained_step=log_pretrained_step,
                     parentNode=self,
                     childNodes=[],
                     totalReward=totalReward,
                     timestep=self.timestep + 1)
        self.childNodes.append(child)
        return child

    def update_logrnd(self, log_policy_step, log_rnd):
        self.log_policy_step = log_policy_step
        self.log_rnd = log_rnd

    def updateNode(self, rewards):
        """
        Updates the cumulative rewards vector with the reward vector at a descendent leaf node.
        Increments the number of visits to the node.
        """
        self.visits += 1
        self.totalReward += rewards
    def calcSelectScore(self):
        """
        Calculates the select score vector for the node (PUCT-style):
            score = totalReward / visits
                    + scaling * log_policy_step * sqrt(parentNode.visits) / visits
        - scaling: determines the degree of exploration
        - log_policy_step acts as the prior term in the exploration bonus
        """
        scaling = 0.1

        normRewards = self.totalReward / self.visits

        return normRewards + (scaling * self.log_policy_step.detach().cpu().item()
                              * np.sqrt(self.parentNode.visits) / self.visits)

    def getExpandStatus(self):
        """
        Returns an integer indicating whether the node is a:
        1. terminal node (sequence is fully unmasked)
        2. legal leaf node (partially unmasked sequence that can be expanded)
        3. legal non-leaf node (already expanded sequence with M child nodes)
        """
        if self.timestep == self.args.total_num_steps:
            return 1
        elif (self.timestep < self.args.total_num_steps) and (len(self.childNodes) == 0):
            return 2
        return 3
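
# Hedged sketch: Node reads only `num_obj` and `total_num_steps` from `args`,
# so a lightweight namespace suffices to exercise the status logic.
#   from types import SimpleNamespace
#   args = SimpleNamespace(num_obj=2, total_num_steps=8)
#   root = Node(args, timestep=0)
#   root.getExpandStatus()  # -> 2: legal leaf (not terminal, no children)
#   root.addChildNode(tokens=None, log_rnd=None, log_policy_step=None,
#                     log_pretrained_step=None, totalReward=np.zeros(2))
#   root.getExpandStatus()  # -> 3: expanded non-leaf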


class MCTS:
    def __init__(self, args, config, policy_model, pretrained, score_func_names=None, prot_seqs=None, rootNode=None):
        self.timer = StepTimer(policy_model.device)
        self.device = policy_model.device

        self.args = args
        self.config = config
        self.noise = noise_schedule.get_noise(config)
        self.time_conditioning = args.time_conditioning

        score_func_names = [] if score_func_names is None else score_func_names
        self.num_obj = len(score_func_names)

        self.mask_index = policy_model.mask_index
        self.rootNode = self._makeMaskedRoot() if rootNode is None else rootNode

        self.buffer = []
        self.buffer_size = args.buffer_size
        self.num_steps = args.total_num_steps

        self.pretrained = pretrained
        self.policy_model = policy_model
        self.sequence_length = args.seq_length
        self.num_iter = args.num_iter
        self.num_children = args.num_children

        self.rewardFunc = ScoringFunctions(score_func_names, prot_seqs, device=args.device)

        self.iter_num = 0

        self.reward_log = []
        self.logrnd_log = []

        self.valid_fraction_log = []
        self.affinity1_log = []
        self.affinity2_log = []
        self.permeability_log = []
        self.sol_log = []
        self.hemo_log = []
        self.nf_log = []

        self.policy_model.eval()
        self.pretrained.eval()

        self.analyzer = PeptideAnalyzer()
        self.tokenizer = policy_model.tokenizer

    def _makeMaskedRoot(self):
        """Builds a fully masked root node at timestep 0."""
        masked_seq = torch.ones((self.args.seq_length), device=self.device) * self.mask_index
        masked_tokens = {'seqs': masked_seq.to(dtype=torch.long),
                         'attention_mask': torch.ones_like(masked_seq)}
        return Node(self.args, tokens=masked_tokens,
                    log_rnd=torch.zeros((), device=self.device),
                    log_policy_step=torch.zeros((), device=self.device),
                    log_pretrained_step=torch.zeros((), device=self.device),
                    totalReward=np.zeros(self.num_obj), timestep=0)
    def reset(self, resetTree):
        self.iter_num = 0
        self.buffer = []
        self.reward_log = []
        self.logrnd_log = []

        self.valid_fraction_log = []
        self.affinity1_log = []
        self.affinity2_log = []
        self.permeability_log = []
        self.sol_log = []
        self.hemo_log = []
        self.nf_log = []

        if resetTree:
            self.rootNode = self._makeMaskedRoot()

    def forward(self, resetTree=False):
        """Runs num_iter select/expand iterations and returns the consolidated buffer."""
        self.reset(resetTree)

        while (self.iter_num < self.num_iter):
            self.iter_num += 1

            with self.timer.section("select"):
                leafNode, _ = self.select(self.rootNode)

            with self.timer.section("expand"):
                self.expand(leafNode)

        final_x, log_rnd, final_rewards, score_vectors, sequences = self.consolidateBuffer()

        rows = self.timer.summary()
        print("\n=== Timing summary (by total time) ===")
        for name, cnt, total, mean, p50, p95 in rows:
            print(f"{name:30s} n={cnt:5d} total={total:8.3f}s mean={mean*1e3:7.2f}ms "
                  f"p50={p50*1e3:7.2f}ms p95={p95*1e3:7.2f}ms")

        return final_x, log_rnd, final_rewards, score_vectors, sequences

    def _debug_buffer_decision(self, sv, reason, extra=None):
        if extra is None:
            extra = {}
        print(f"[BUFFER] reason={reason} sv={np.round(sv, 4)} "
              f"buf_len={len(self.buffer)} extra={extra}")

    def updateBuffer(self, x_final, log_rnd, score_vectors, validSequences):
        B = x_final.shape[0]
        traj_log_rnds, scalar_rewards = [], []

        for i in range(B):
            sv = np.asarray(score_vectors[i], dtype=float)

            # Scalarize the score vector. Only the plain sum is implemented;
            # "normalized" and "weighted" are reserved modes and fail loudly
            # instead of silently leaving scalar_reward undefined.
            if self.args.scalarization in ("normalized", "weighted"):
                raise NotImplementedError(
                    f"scalarization='{self.args.scalarization}' is not implemented")
            scalar_reward = float(np.sum(sv))

            traj_log_rnd = log_rnd[i] + (scalar_reward / self.args.alpha)

            item = {
                "x_final": x_final[i].clone(),
                "log_rnd": traj_log_rnd.clone(),
                "final_reward": scalar_reward,
                "score_vector": sv.copy(),
                "seq": validSequences[i],
            }

            # Reject candidates dominated by anything already in the buffer.
            if any(dominated_by(sv, bi["score_vector"]) for bi in self.buffer):
                self._debug_buffer_decision(sv, "rejected_dominated")
                continue

            # Drop buffer items that the candidate dominates.
            self.buffer = [bi for bi in self.buffer
                           if not dominates(sv, bi["score_vector"])]

            # Insert, evicting the smallest-sum item when the buffer is full.
            if len(self.buffer) < self.buffer_size:
                self.buffer.append(item)
            else:
                worst_i = int(np.argmin([np.sum(bi["score_vector"]) for bi in self.buffer]))
                self.buffer[worst_i] = item

            self._debug_buffer_decision(sv, "inserted", {"new_len": len(self.buffer)})

            traj_log_rnds.append(traj_log_rnd)
            scalar_rewards.append(scalar_reward)

        traj_log_rnds = torch.stack(traj_log_rnds, dim=0) if traj_log_rnds else torch.empty(0, device=self.device)
        scalar_rewards = np.asarray(scalar_rewards, dtype=float)
        return traj_log_rnds, scalar_rewards
    def consolidateBuffer(self):
        """
        Returns the buffer as batched tensors/arrays:
        x_final (B, L) long, log_rnd (B,) float32, final_rewards (B,) float32,
        score_vectors (B, num_obj) float32, and the list of B sequences.
        """
        if not self.buffer:
            # Defensive guard: an empty buffer would make torch.stack raise.
            return (torch.empty(0, dtype=torch.long, device=self.device),
                    torch.empty(0, device=self.device),
                    np.empty(0, dtype=np.float32),
                    np.empty((0, self.num_obj), dtype=np.float32),
                    [])

        x_final = []
        log_rnd = []
        final_rewards = []
        score_vectors = []
        sequences = []
        for item in self.buffer:
            x_final.append(item["x_final"])
            log_rnd.append(item["log_rnd"])
            final_rewards.append(item["final_reward"])
            score_vectors.append(item["score_vector"])
            sequences.append(item["seq"])

        x_final = torch.stack(x_final, dim=0)
        log_rnd = torch.stack(log_rnd, dim=0).to(dtype=torch.float32)
        final_rewards = np.stack(final_rewards, axis=0).astype(np.float32)
        score_vectors = np.stack(score_vectors, axis=0).astype(np.float32)

        return x_final, log_rnd, final_rewards, score_vectors, sequences

    def isPathEnd(self, path, maxDepth):
        """
        Checks if the node is completely unmasked (i.e. end of path)
        or if the path is at the max depth.
        """
        if (path[-1] != self.mask_index).all():
            return True
        elif len(path) >= maxDepth:
            return True
        return False

    def select(self, currNode, eps=1e-5):
        """
        Traverse the tree from the root node until reaching a legal leaf node,
        recomputing each step's log-prob under the current policy on the way down.
        """
        updated_log_rnd = torch.zeros((), device=self.device)
        while True:
            currNode, nodeStatus = currNode.selectNode()

            if currNode.parentNode is not None:
                child_tokens = currNode.tokens['seqs'].to(self.device)
                parent_tokens = currNode.parentNode.tokens['seqs'].to(self.device)
                t = torch.ones(1, device=self.device)
                dt = (1 - eps) / self.num_steps
                with torch.no_grad():
                    with self.timer.section("select.compute_log_policy"):
                        updated_log_policy_step = self.policy_model.compute_log_policy(
                            parent_tokens, child_tokens, t=t, dt=dt)
                updated_log_rnd += updated_log_policy_step

                currNode.update_logrnd(updated_log_policy_step, updated_log_rnd)

            if nodeStatus != 3:
                return currNode, nodeStatus

    def expand(self, parentNode, eps=1e-5):
        """
        Sample unmasking steps from the pre-trained MDLM and add num_children
        partially unmasked sequences to the children of parentNode, rolling
        each child out to a complete sequence for scoring.
        """
        num_children = self.num_children

        num_rollout_steps = self.num_steps - parentNode.timestep

        rollout_t = torch.linspace(1, eps, self.num_steps + 1, device=self.device)
        dt = (1 - eps) / self.num_steps

        x = parentNode.tokens['seqs'].to(self.device)
        attn_mask = parentNode.tokens['attention_mask'].to(self.device)
        parent_log_rnd = parentNode.log_rnd

        t = rollout_t[parentNode.timestep] * torch.ones(1, 1, device=self.device)

        # One reverse step from the parent yields num_children distinct children.
        with torch.no_grad():
            with self.timer.section("expand.batch_mcts_reverse_step"):
                _, x_children, child_log_policy_step, child_log_pretrained_step = \
                    self.policy_model.batch_mcts_reverse_step(token_array=x,
                                                              t=t, dt=dt,
                                                              batch_size=num_children,
                                                              pretrained=self.pretrained)

        # Accumulated log-ratio (pretrained vs. policy) along each child's path.
        child_log_rnd = (parent_log_rnd + (child_log_pretrained_step - child_log_policy_step)).to(self.device)

        x_rollout = x_children
        traj_log_rnd = child_log_rnd

        # Roll each child out to the final timestep.
        with self.timer.section("expand.rollout_total"):
            for i in range(1, num_rollout_steps):
                t = rollout_t[parentNode.timestep + i] * torch.ones(num_children, 1, device=self.device)

                with torch.no_grad():
                    _, x_next, log_policy_step, log_pretrained_step = \
                        self.policy_model.mcts_reverse_step(x_rollout,
                                                            t=t, dt=dt,
                                                            pretrained=self.pretrained)

                traj_log_rnd += log_pretrained_step - log_policy_step
                x_rollout = x_next

        # Force-unmask any positions that remain masked after the rollout.
        mask_positions = (x_rollout == self.mask_index)
        any_mask_global = mask_positions.any().item()
        if any_mask_global:
            with torch.no_grad():
                with self.timer.section("expand.noise_removal"):
                    _, x_next, log_policy_step, log_pretrained_step = \
                        self.policy_model.mcts_noise_removal(x_rollout,
                                                             t=t, dt=dt,
                                                             pretrained=self.pretrained)

            traj_log_rnd += log_pretrained_step - log_policy_step
            x_rollout = x_next

        with self.timer.section("expand.decode"):
            childSequences = self.tokenizer.batch_decode(x_rollout)

        # Split children into valid peptides (scored below) and invalid ones
        # (still added to the tree, but with zero reward). valid_idx records
        # each valid child's position in the original batch so that the
        # per-child log terms stay aligned after filtering.
        valid_idx = []
        valid_x_children = []
        valid_x_final = []
        validSequences = []
        valid_traj_log_rnd = []

        with self.timer.section("expand.filter_is_peptide"):
            for i in range(num_children):
                childSeq = childSequences[i]

                if self.analyzer.is_peptide(childSeq):
                    valid_idx.append(i)
                    valid_x_children.append(x_children[i])
                    valid_x_final.append(x_rollout[i])
                    validSequences.append(childSeq)
                    valid_traj_log_rnd.append(traj_log_rnd[i])
                else:
                    childTokens = {'seqs': x_children[i].to(dtype=torch.long), 'attention_mask': attn_mask}
                    parentNode.addChildNode(tokens=childTokens,
                                            log_rnd=child_log_rnd[i],
                                            log_policy_step=child_log_policy_step[i],
                                            log_pretrained_step=child_log_pretrained_step[i],
                                            totalReward=np.zeros(self.num_obj))

        del traj_log_rnd

        score_vectors = None
        if (len(validSequences) != 0):
            with self.timer.section("expand.scoring_functions"):
                score_vectors = self.rewardFunc(input_seqs=validSequences)

            # scores_by_obj[k][j] is objective k for valid child j.
            scores_by_obj = score_vectors.T

            self.affinity1_log.append(scores_by_obj[0])
            self.sol_log.append(scores_by_obj[1])
            self.hemo_log.append(scores_by_obj[2])
            self.nf_log.append(scores_by_obj[3])
            self.permeability_log.append(scores_by_obj[4])
        else:
            # No valid child: log per-child zeros so these entries have the
            # same 1-D shape as the per-objective rows in the valid branch.
            self.affinity1_log.append(np.zeros(num_children))
            self.sol_log.append(np.zeros(num_children))
            self.hemo_log.append(np.zeros(num_children))
            self.nf_log.append(np.zeros(num_children))
            self.permeability_log.append(np.zeros(num_children))

        if len(valid_x_final) == 0:
            self.valid_fraction_log.append(0.0)
            return

        valid_x_final = torch.stack(valid_x_final, dim=0)
        valid_traj_log_rnd = torch.stack(valid_traj_log_rnd, dim=0)

        with self.timer.section("expand.update_buffer"):
            traj_log_rnds, scalar_rewards = self.updateBuffer(valid_x_final, valid_traj_log_rnd, score_vectors, validSequences)

        # Add the valid children to the tree and accumulate their rewards.
        allChildReward = np.zeros_like(score_vectors[0])
        for i in range(len(score_vectors)):
            reward = score_vectors[i]
            allChildReward += reward

            j = valid_idx[i]  # index back into the original child batch
            childTokens = {'seqs': valid_x_children[i].to(dtype=torch.long), 'attention_mask': attn_mask}
            parentNode.addChildNode(tokens=childTokens,
                                    log_rnd=child_log_rnd[j],
                                    log_policy_step=child_log_policy_step[j],
                                    log_pretrained_step=child_log_pretrained_step[j],
                                    totalReward=reward)

        valid_fraction = len(validSequences) / num_children
        self.valid_fraction_log.append(valid_fraction)

        print(f"[EXPAND] iter={self.iter_num} parent_t={parentNode.timestep} "
              f"num_children={num_children} valid={len(validSequences)} any_mask={any_mask_global}")
        if score_vectors is not None:
            print(f"[SCORES] min={np.min(score_vectors, 0)} max={np.max(score_vectors, 0)} "
                  f"nan_any={np.isnan(score_vectors).any()}")

        self.reward_log.append(scalar_rewards)
        self.logrnd_log.append(traj_log_rnds.detach().cpu().numpy())

        allChildReward = allChildReward / len(validSequences)

        with self.timer.section("expand.backprop"):
            self.backprop(parentNode, allChildReward)

    def backprop(self, node, allChildReward):
        # Propagate the mean child reward from the expanded node up to the root.
        while node:
            node.updateNode(allChildReward)
            node = node.parentNode
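

# Minimal smoke test (a sketch): a runnable counterpart to the comment sketches
# above, exercising only the pure-NumPy Pareto helpers; it does not touch the
# models, tokenizer, or scoring functions, though running this file directly
# still requires the repo imports at the top of the module to resolve.
if __name__ == "__main__":
    pf = {}
    pf = updateParetoFront(pf, "a", [0.9, 0.1])
    pf = updateParetoFront(pf, "b", [0.1, 0.9])   # kept: trade-off with "a"
    pf = updateParetoFront(pf, "c", [0.95, 0.2])  # kept, and it evicts "a"
    pf = updateParetoFront(pf, "d", [0.1, 0.1])   # rejected: dominated by "b"
    assert set(pf.keys()) == {"b", "c"}, pf
    assert dominates([2, 1], [1, 1]) and dominated_by([1, 1], [2, 1])
    print("Pareto helper smoke test passed:", {k: list(v) for k, v in pf.items()})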