""" Enhanced version of POAGraph for text alignment """ import pickle import textwrap from typing import Dict, Optional import numpy as np from tqdm import tqdm from src.text_poa_graph_utils import path_sim_llm from src.global_edit_utils import clean_up_text from .new_text_alignment import TextSeqGraphAlignment from .poa_graph import Node, POAGraph class TextNode(Node): def __init__(self, nodeID=-1, text=""): super().__init__(nodeID, text) self.variations = {} # Track alternate phrasings self.sequences = [] # Track sequences that contain this node self.influenceScore = 0 self.num_tokens_used = 0 def add_variation(self, text, sequence_id): self.variations[sequence_id] = text @property def is_stable(self): """A node is stable if it appears frequently enough relative to total sequences""" return self.frequency >= self.graph.stability_threshold class TextPOAGraph(POAGraph): def __init__(self, text=None, label=-1): self.consensus_node_ids = [] self._seq_paths = {} self.end_id = -1 self.start_id = -1 self.failed = False self.num_input_tokens_used = 0 self.num_output_tokens_used = 0 super().__init__(text, label) def addNode(self, text): """Override to use TextNode""" nid = self._nextnodeID newnode = TextNode(nid, text) self.nodedict[nid] = newnode self.nodeidlist.append(nid) self._nnodes += 1 self._nextnodeID += 1 self._needsSort = True return nid def addUnmatchedSeq(self, text, label=-1, updateSequences=True): """Modified to handle text sequences""" if text is None: return # Handle both string and list input if isinstance(text, str): words = text.split() else: words = text firstID, lastID = None, None neededSort = self.needsSort path = [] for word in words: nodeID = self.addNode(word) if firstID is None: firstID = nodeID if lastID is not None: self.addEdge(lastID, nodeID, label=label) lastID = nodeID path.append(nodeID) self._needsort = neededSort if updateSequences: self._seqs.append(words) self._labels.append(label) self._starts.append(firstID) self._seq_paths[label] = path return firstID, lastID def add_text(self, text, label=-1): """Main method to add new text to the alignment""" if len(self._seqs) == 0: # First sequence - just add it self.addUnmatchedSeq(text, label) else: # Align to existing graph alignment = TextSeqGraphAlignment( text, self, matchscore=2, mismatchscore=-1, gapscore=-2 ) self.incorporateSeqAlignment(alignment, text, label) # Update node frequencies self._update_frequencies() def removeNode(self, nodeID): """Override to handle text nodes""" node = self.nodedict[nodeID] if node is None: return # Remove all edges to this node out_edges = node.outEdges.copy() in_edges = node.inEdges.copy() for edge in out_edges: self.removeEdge(node.ID, edge) for edge in in_edges: self.removeEdge(edge, node.ID) # Remove from graph del self.nodedict[nodeID] self.nodeidlist.remove(nodeID) for path in self._seq_paths.values(): if nodeID in path: path.remove(nodeID) self._nnodes -= 1 self._needsSort = True def removeEdge(self, nodeID1, nodeID2): """Override to handle text nodes""" node1 = self.nodedict[nodeID1] node2 = self.nodedict[nodeID2] if node1 is None or node2 is None: return # Remove from graph del node1.outEdges[nodeID2] del node2.inEdges[nodeID1] def merge_consensus_nodes(self, verbose: bool = False): self.toposort() # reset consensus node ids self.consensus_node_ids = [] nodes = list(self.nodeiterator()()) consensus_segments = [] i = 0 while i < len(nodes): node = nodes[i] out_weight = sum(e.weight for e in node.outEdges.values()) in_weight = sum(e.weight for e in 
    def find_paths_between(self, start_node_id: int, end_node_id: int):
        """Find all paths between start_node_id and end_node_id taken by the original sequences.

        Returns a list of dictionaries with the following keys:
        - path: list of node ids in the path (excluding the start node, including the end node)
        - body_text: text of the path (excluding the start and end nodes)
        - begin_text / end_text: text of the start and end nodes
        - weight: number of sequence labels shared by every edge in the path
        - labels: intersection of all edge labels in the path
        """
        path_dicts = []
        # keep track of visited paths to avoid duplicates
        visited_paths = set()
        for _, path in self._seq_paths.items():
            start_index = path.index(start_node_id) if start_node_id in path else None
            end_index = path.index(end_node_id) if end_node_id in path else None
            if (
                start_index is not None
                and end_index is not None
                and end_index - start_index > 1
                and tuple(path[start_index + 1 : end_index + 1]) not in visited_paths
            ):
                # intersection of all edge labels in the path
                path_labels = set.intersection(
                    *[
                        set(self.nodedict[next_node_id].inEdges[node_id].labels)
                        for node_id, next_node_id in zip(
                            path[start_index:end_index], path[start_index + 1 : end_index + 1]
                        )
                    ]
                )
                path_weight = len(path_labels)
                path_dicts.append(
                    {
                        "path": path[start_index + 1 : end_index + 1],
                        "body_text": " ".join(
                            [
                                self.nodedict[node_id].text
                                for node_id in path[start_index + 1 : end_index]
                            ]
                        ),
                        "begin_text": self.nodedict[path[start_index]].text,
                        "end_text": self.nodedict[path[end_index]].text,
                        "weight": path_weight,
                        "labels": path_labels,
                    }
                )
                visited_paths.add(tuple(path[start_index + 1 : end_index + 1]))
        return path_dicts

    def _follow_path(self, start_id):
        """Follow all possible paths from a node."""
        paths = []
        visited = set()

        def dfs(node_id, current_path):
            if node_id in visited:
                return
            visited.add(node_id)
            node = self.nodedict[node_id]
            if not node.outEdges:
                paths.append(current_path + [node_id])
                return
            for next_id in node.outEdges:
                dfs(next_id, current_path + [node_id])

        dfs(start_id, [])
        return paths
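    # Illustrative sketch of what find_paths_between returns (node ids, texts, and
    # labels are invented; the keys match the dict built above):
    #
    #   graph.find_paths_between(start_node_id=3, end_node_id=9)
    #   # [
    #   #   {"path": [5, 6, 9], "body_text": "sat on", "begin_text": "cat",
    #   #    "end_text": "the", "weight": 2, "labels": {0, 2}},
    #   #   {"path": [7, 8, 9], "body_text": "slept on", "begin_text": "cat",
    #   #    "end_text": "the", "weight": 1, "labels": {1}},
    #   # ]
    #
    # Each alternative wording between two consensus nodes shows up as one dict,
    # which merge_paths_between below groups into semantic equivalence classes.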
    def merge_paths_between(
        self,
        start_node_id: int,
        end_node_id: int,
        path_sim_type: str = "llm",
        verbose: bool = False,
        **kwargs,
    ):
        path_dicts = self.find_paths_between(start_node_id, end_node_id)

        if path_sim_type == "llm":
            api = kwargs.get("api", "openai")
            model = kwargs.get("model", "gpt-4o-mini")
            domain = kwargs.get("domain", None)
            similarity_judge_prompt = kwargs.get("similarity_judge_prompt", None)

            def path_sim_func(path1_text, path2_text):
                return path_sim_llm(
                    path1_text,
                    path2_text,
                    api=api,
                    model=model,
                    domain=domain,
                    custom_similarity_judge_prompt=similarity_judge_prompt,
                )

        elif path_sim_type == "cosine":
            # embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
            # threshold = kwargs.get("threshold", 0.9)
            # path_sim_func = path_sim_cosine(embedding_model, threshold)
            raise NotImplementedError("cosine path similarity is not implemented yet")
        else:
            raise ValueError(f"Invalid path similarity type: {path_sim_type}")

        # group paths into equivalence classes based on semantic similarity
        path_equivalence_classes = {}
        class_count = 0
        for path_dict in path_dicts:
            if verbose:
                print(path_dict)
            found_class = False
            for _, eq_class in path_equivalence_classes.items():
                # check whether this path belongs to an existing equivalence class
                path1_text = (
                    path_dict["begin_text"]
                    + " "
                    + path_dict["body_text"]
                    + " "
                    + path_dict["end_text"]
                )
                path2_text = (
                    eq_class[0]["begin_text"]
                    + " "
                    + eq_class[0]["body_text"]
                    + " "
                    + eq_class[0]["end_text"]
                )
                judgement, num_input_tokens, num_output_tokens = path_sim_func(
                    path1_text, path2_text
                )
                self.num_input_tokens_used += num_input_tokens
                self.num_output_tokens_used += num_output_tokens
                if judgement:
                    eq_class.append(path_dict)
                    found_class = True
                    break
            if not found_class:
                class_count += 1
                path_equivalence_classes[class_count] = [path_dict]

        nodes_to_remove = set()  # Track nodes to remove
        for _, eq_class in path_equivalence_classes.items():
            path_dict = eq_class[0]
            if verbose:
                print(eq_class)

            # add new node with the merged text
            new_node_id = self.addNode(path_dict["body_text"])
            for sequence_id in path_dict["labels"]:
                self.nodedict[new_node_id].variations[sequence_id] = path_dict["body_text"]

            # collect nodes to remove from the first path
            nodes_to_remove.update(path_dict["path"][:-1])

            # process data regarding weights and labels
            labels = list(path_dict["labels"])
            weight = path_dict["weight"]
            self.addEdge(start_node_id, new_node_id, label=labels, weight=weight)
            # update seq_paths for all labels to include new_node between start_node and end_node
            for label in labels:
                index = self._seq_paths[label].index(start_node_id)
                if (
                    index + 1 < len(self._seq_paths[label])
                    and self._seq_paths[label][index + 1] != new_node_id
                ):
                    self._seq_paths[label].insert(index + 1, new_node_id)
            self.addEdge(new_node_id, end_node_id, label=labels, weight=weight)
            self.nodedict[new_node_id].sequences = labels

            # process additional paths in the same equivalence class
            for path_dict in eq_class[1:]:
                for sequence_id in path_dict["labels"]:
                    self.nodedict[new_node_id].variations[sequence_id] = path_dict["body_text"]
                nodes_to_remove.update(path_dict["path"][:-1])

                # copy incoming edges to the new node
                labels = list(path_dict["labels"])
                weight = path_dict["weight"]
                self.addEdge(start_node_id, new_node_id, label=labels, weight=weight)
                # update seq_paths for all labels to include new_node between start_node and end_node
                for label in labels:
                    index = self._seq_paths[label].index(start_node_id)
                    if (
                        index + 1 < len(self._seq_paths[label])
                        and self._seq_paths[label][index + 1] != new_node_id
                    ):
                        self._seq_paths[label].insert(index + 1, new_node_id)
                self.addEdge(new_node_id, end_node_id, label=labels, weight=weight)
                self.nodedict[new_node_id].sequences.extend(labels)
                self.nodedict[new_node_id].sequences = list(set(self.nodedict[new_node_id].sequences))

        # Remove all collected nodes after processing
        for node_id in nodes_to_remove:
            if node_id in self.nodedict:
                if verbose:
                    print(f"Removing node {node_id}")
                self.removeNode(node_id)
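    # Illustrative call sketch (node ids and the domain value are invented; the
    # keyword names mirror the kwargs read by merge_paths_between above):
    #
    #   graph.merge_paths_between(
    #       3, 9,
    #       path_sim_type="llm",
    #       api="openai",
    #       model="gpt-4o-mini",
    #       domain="news summaries",  # optional, forwarded to path_sim_llm
    #       verbose=True,
    #   )
    #   # graph.num_input_tokens_used / num_output_tokens_used accumulate the LLM cost.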
    def merge_divergent_paths(self, path_sim_type: str = "llm", verbose: bool = False, **kwargs):
        # make sure consensus nodes exist, then add dummy START/END nodes so that
        # every divergent region sits between two consensus nodes
        if not self.consensus_node_ids:
            self.merge_consensus_nodes(verbose=verbose)
        self.toposort()

        if self.start_id == -1:
            if verbose:
                print("Adding start node")
            self.start_id = self.addNode(text="START")
            self._nextnodeID += 1
            self.consensus_node_ids.insert(0, self.start_id)
            for label, path in self._seq_paths.items():
                self.addEdge(self.start_id, path[0], label=label, weight=1)
                path.insert(0, self.start_id)

        if self.end_id == -1:
            if verbose:
                print("Adding end node")
            self.end_id = self.addNode(text="END")
            self._nextnodeID += 1
            self.consensus_node_ids = self.consensus_node_ids + [self.end_id]
            for label, path in self._seq_paths.items():
                self.addEdge(path[-1], self.end_id, label=label, weight=1)
                path.append(self.end_id)

        for i in tqdm(range(len(self.consensus_node_ids) - 1)):
            if verbose:
                print(self.consensus_node_ids[i], self.consensus_node_ids[i + 1])
            self.merge_paths_between(
                self.consensus_node_ids[i],
                self.consensus_node_ids[i + 1],
                path_sim_type=path_sim_type,
                verbose=verbose,
                **kwargs,
            )

    def get_variable_node_ids(self):
        return [
            node.ID for node in self.nodedict.values() if node.ID not in self.consensus_node_ids
        ]

    def compress_paths_between(self, start_node_id: int, end_node_id: int):
        pass

    def compress_graph(self):
        pass

    def update_influence_scores(self, outcome: Dict[int, float], discount_factor: float = 0.2):
        self.toposort()
        direct_scores = []
        for node in self.nodedict.values():
            out_weight = sum(e.weight for e in node.outEdges.values())
            in_weight = sum(e.weight for e in node.inEdges.values())
            if out_weight == self.num_sequences and in_weight == self.num_sequences:
                out_list = []
                for edge in node.outEdges.values():
                    for _ in range(len(set(edge.labels))):
                        out_list.append(np.mean([outcome[label] for label in set(edge.labels)]))
                direct_scores.append((node.ID, np.var(out_list)))

        scores = direct_scores.copy()
        # Start from the end and propagate influence backward
        for i in range(len(scores) - 2, -1, -1):
            # Current node gets its direct influence plus discounted influence of next node
            current_direct = scores[i][1]
            next_total = scores[i + 1][1]
            scores[i] = (scores[i][0], current_direct + discount_factor * next_total)

        scores.sort(key=lambda x: x[1], reverse=True)
        return scores
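    # Illustrative sketch of the backward influence propagation above (numbers invented).
    # With direct variance scores [(A, 0.4), (B, 0.1), (C, 0.3)] in topological order and
    # discount_factor = 0.2, the loop runs from the last entry backwards:
    #
    #   C: 0.3
    #   B: 0.1 + 0.2 * 0.3  = 0.16
    #   A: 0.4 + 0.2 * 0.16 = 0.432
    #
    # so each consensus node's score folds in a discounted share of its successor's total
    # before the list is sorted from most to least influential.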
    def jsOutput(
        self,
        verbose: bool = False,
        annotate_consensus: bool = True,
        color_annotations: Dict[int, str] = None,
    ):
        """Returns a list of strings containing a description of the graph for viz.js, http://visjs.org"""
        # get the consensus sequence, which we'll use as the "spine" of the graph
        # (note: pathdict is never populated below, so the fixed-position consensus
        # branch further down currently has no effect)
        pathdict = {}
        if annotate_consensus:
            path, __, __ = self.consensus()

        lines = ["var nodes = ["]

        ni = self.nodeiterator()
        count = 0
        for node in ni():
            title_text = ""
            if node.sequences:
                title_text += f"Sequences: {node.sequences}"
            if node.variations:
                title_text += ";;;".join(
                    [f"{sequence_id}: {text}" for sequence_id, text in node.variations.items()]
                )
            title_text = title_text.replace('"', "'")
            line = (
                " {id:"
                + str(node.ID)
                + ', label: "'
                + str(node.ID)
                + ": "
                + node.text.replace('"', "'")
                + '", title: '
                + '"'
                + title_text
                + '",'
            )
            if color_annotations and node.ID in color_annotations:
                line += f" color: '{color_annotations[node.ID]}', "
            if node.ID in pathdict and count % 5 == 0 and annotate_consensus:
                line += (
                    ", x: "
                    + str(pathdict[node.ID])
                    + ", y: 0 , fixed: { x:true, y:false},"
                    + "color: '#7BE141', is_consensus:true},"
                )
            else:
                line += "},"
            lines.append(line)

        lines[-1] = lines[-1][:-1]
        lines.append("];")
        lines.append(" ")
        lines.append("var edges = [ ")

        ni = self.nodeiterator()
        for node in ni():
            nodeID = str(node.ID)
            for edge in node.outEdges:
                target = str(edge)
                weight = str(node.outEdges[edge].weight + 1.5)
                lines.append(
                    " {from: "
                    + nodeID
                    + ", to: "
                    + target
                    + ", value: "
                    + weight
                    + ", color: '#4b72b0', arrows: 'to'},"
                )
            if verbose:
                for alignededge in node.alignedTo:
                    # These edges indicate alignment to different bases, and are
                    # undirected; thus make sure we only plot them once:
                    if node.ID > alignededge:
                        continue
                    target = str(alignededge)
                    lines.append(
                        " {from: "
                        + nodeID
                        + ", to: "
                        + target
                        + ', value: 1, style: "dash-line", color: "red"},'
                    )

        lines[-1] = lines[-1][:-1]
        lines.append("];")
        return lines

    def htmlOutput(
        self,
        outfile,
        verbose: bool = False,
        annotate_consensus: bool = True,
        color_annotations: Dict[int, str] = None,
    ):
        header = """