Spaces:
Build error
Build error
| """ | |
| Enhanced version of POAGraph for text alignment | |
| """ | |
| import pickle | |
| import textwrap | |
| from typing import Dict, Optional | |
| import numpy as np | |
| from tqdm import tqdm | |
| from src.text_poa_graph_utils import path_sim_llm | |
| from src.global_edit_utils import clean_up_text | |
| from .new_text_alignment import TextSeqGraphAlignment | |
| from .poa_graph import Node, POAGraph | |
class TextNode(Node):
    """POA graph node whose payload is a word or phrase of text.

    In addition to what ``Node`` stores, it keeps the per-sequence wording that
    was folded into it, the sequences routed through it, and two counters used
    by downstream scoring code.
    """

    def __init__(self, nodeID=-1, text=""):
        super().__init__(nodeID, text)
        # Alternate phrasings keyed by sequence id: {sequence_id: text}.
        self.variations = dict()
        # Ids of the sequences whose path contains this node.
        self.sequences = list()
        self.influenceScore = 0
        self.num_tokens_used = 0

    def add_variation(self, text, sequence_id):
        """Remember ``text`` as the wording used by ``sequence_id`` at this node."""
        self.variations.update({sequence_id: text})

    def is_stable(self):
        """Return whether this node's frequency reaches the graph's stability threshold.

        NOTE(review): ``self.frequency`` and ``self.graph`` are never assigned in
        this file; they must be provided by ``Node`` or attached externally,
        otherwise this raises AttributeError — confirm.
        """
        frequency = self.frequency
        return frequency >= self.graph.stability_threshold
| class TextPOAGraph(POAGraph): | |
| def __init__(self, text=None, label=-1): | |
| self.consensus_node_ids = [] | |
| self._seq_paths = {} | |
| self.end_id = -1 | |
| self.start_id = -1 | |
| self.failed = False | |
| self.num_input_tokens_used = 0 | |
| self.num_output_tokens_used = 0 | |
| super().__init__(text, label) | |
| def addNode(self, text): | |
| """Override to use TextNode""" | |
| nid = self._nextnodeID | |
| newnode = TextNode(nid, text) | |
| self.nodedict[nid] = newnode | |
| self.nodeidlist.append(nid) | |
| self._nnodes += 1 | |
| self._nextnodeID += 1 | |
| self._needsSort = True | |
| return nid | |
| def addUnmatchedSeq(self, text, label=-1, updateSequences=True): | |
| """Modified to handle text sequences""" | |
| if text is None: | |
| return | |
| # Handle both string and list input | |
| if isinstance(text, str): | |
| words = text.split() | |
| else: | |
| words = text | |
| firstID, lastID = None, None | |
| neededSort = self.needsSort | |
| path = [] | |
| for word in words: | |
| nodeID = self.addNode(word) | |
| if firstID is None: | |
| firstID = nodeID | |
| if lastID is not None: | |
| self.addEdge(lastID, nodeID, label=label) | |
| lastID = nodeID | |
| path.append(nodeID) | |
| self._needsort = neededSort | |
| if updateSequences: | |
| self._seqs.append(words) | |
| self._labels.append(label) | |
| self._starts.append(firstID) | |
| self._seq_paths[label] = path | |
| return firstID, lastID | |
| def add_text(self, text, label=-1): | |
| """Main method to add new text to the alignment""" | |
| if len(self._seqs) == 0: | |
| # First sequence - just add it | |
| self.addUnmatchedSeq(text, label) | |
| else: | |
| # Align to existing graph | |
| alignment = TextSeqGraphAlignment( | |
| text, self, matchscore=2, mismatchscore=-1, gapscore=-2 | |
| ) | |
| self.incorporateSeqAlignment(alignment, text, label) | |
| # Update node frequencies | |
| self._update_frequencies() | |
| def removeNode(self, nodeID): | |
| """Override to handle text nodes""" | |
| node = self.nodedict[nodeID] | |
| if node is None: | |
| return | |
| # Remove all edges to this node | |
| out_edges = node.outEdges.copy() | |
| in_edges = node.inEdges.copy() | |
| for edge in out_edges: | |
| self.removeEdge(node.ID, edge) | |
| for edge in in_edges: | |
| self.removeEdge(edge, node.ID) | |
| # Remove from graph | |
| del self.nodedict[nodeID] | |
| self.nodeidlist.remove(nodeID) | |
| for path in self._seq_paths.values(): | |
| if nodeID in path: | |
| path.remove(nodeID) | |
| self._nnodes -= 1 | |
| self._needsSort = True | |
| def removeEdge(self, nodeID1, nodeID2): | |
| """Override to handle text nodes""" | |
| node1 = self.nodedict[nodeID1] | |
| node2 = self.nodedict[nodeID2] | |
| if node1 is None or node2 is None: | |
| return | |
| # Remove from graph | |
| del node1.outEdges[nodeID2] | |
| del node2.inEdges[nodeID1] | |
| def merge_consensus_nodes(self, verbose: bool = False): | |
| self.toposort() | |
| # reset consensus node ids | |
| self.consensus_node_ids = [] | |
| nodes = list(self.nodeiterator()()) | |
| consensus_segments = [] | |
| i = 0 | |
| while i < len(nodes): | |
| node = nodes[i] | |
| out_weight = sum(e.weight for e in node.outEdges.values()) | |
| in_weight = sum(e.weight for e in node.inEdges.values()) | |
| if out_weight in [0, self.num_sequences] and in_weight in [0, self.num_sequences]: | |
| consensus_segment = [(node.ID, node.text)] | |
| next_node = node | |
| while (i + 1) < len(nodes) and len(next_node.outEdges) == 1: | |
| next_node = nodes[i + 1] | |
| next_out_weight = sum(e.weight for e in next_node.outEdges.values()) | |
| next_in_weight = sum(e.weight for e in next_node.inEdges.values()) | |
| if ( | |
| next_out_weight != self.num_sequences | |
| or next_in_weight != self.num_sequences | |
| ): | |
| break | |
| consensus_segment.append((next_node.ID, next_node.text)) | |
| i += 1 | |
| consensus_segments.append(consensus_segment) | |
| i += 1 | |
| # merge consensus nodes into a single node | |
| for segment in consensus_segments: | |
| if len(segment) == 1: | |
| self.consensus_node_ids.append(segment[0][0]) | |
| continue | |
| merged_text = " ".join([text for _, text in segment]) | |
| first_node_id = segment[0][0] | |
| last_node_id = segment[-1][0] | |
| self.nodedict[last_node_id].text = merged_text | |
| self.consensus_node_ids.append(last_node_id) | |
| # attach all incoming edges to first node to last node | |
| for id, edge in self.nodedict[first_node_id].inEdges.items(): | |
| weight = edge.weight | |
| for _ in range(weight): | |
| self.addEdge(id, last_node_id, label=edge.labels) | |
| # delete all nodes except last node | |
| for node_id, _ in segment[:-1]: | |
| self.removeNode(node_id) | |
| if verbose: | |
| print(self.consensus_node_ids) | |
| """ | |
| find all paths between start_node_id and end_node_id from original sequences | |
| return a list of dictionaries with the following keys: | |
| - path: list of node ids in the path (excluding start and including end) | |
| - text: text of the path (excluding start and end) | |
| - weight: minimal edge weight across all edges in the path | |
| - labels: intersection of all edge labels in the path | |
| """ | |
| def find_paths_between(self, start_node_id: int, end_node_id: int): | |
| # find all paths between start_node_id and end_node_id from original sequences | |
| path_dicts = [] | |
| # keep track of visited paths to avoid duplicates | |
| visited_paths = set() | |
| for _, path in self._seq_paths.items(): | |
| start_index = path.index(start_node_id) if start_node_id in path else None | |
| end_index = path.index(end_node_id) if end_node_id in path else None | |
| # print(start_index, end_index) | |
| # print(path) | |
| if ( | |
| start_index is not None | |
| and end_index is not None | |
| and end_index - start_index > 1 | |
| and tuple(path[start_index + 1 : end_index + 1]) not in visited_paths | |
| ): | |
| # intersection of all edge labels in the path | |
| path_labels = set.intersection( | |
| *[ | |
| set(self.nodedict[next_node_id].inEdges[node_id].labels) | |
| for node_id, next_node_id in zip( | |
| path[start_index:end_index], path[start_index + 1 : end_index + 1] | |
| ) | |
| ] | |
| ) | |
| path_weight = len(path_labels) | |
| path_dicts.append( | |
| { | |
| "path": path[start_index + 1 : end_index + 1], | |
| "body_text": " ".join( | |
| [ | |
| self.nodedict[node_id].text | |
| for node_id in path[start_index + 1 : end_index] | |
| ] | |
| ), | |
| "begin_text": self.nodedict[path[start_index]].text, | |
| "end_text": self.nodedict[path[end_index]].text, | |
| "weight": path_weight, | |
| "labels": path_labels, | |
| } | |
| ) | |
| visited_paths.add(tuple(path[start_index + 1 : end_index + 1])) | |
| return path_dicts | |
| def _follow_path(self, start_id): | |
| """Follow all possible paths from a node""" | |
| paths = [] | |
| visited = set() | |
| def dfs(node_id, current_path): | |
| if node_id in visited: | |
| return | |
| visited.add(node_id) | |
| node = self.nodedict[node_id] | |
| if not node.outEdges: | |
| paths.append(current_path + [node_id]) | |
| return | |
| for next_id in node.outEdges: | |
| dfs(next_id, current_path + [node_id]) | |
| dfs(start_id, []) | |
| return paths | |
| def merge_paths_between( | |
| self, | |
| start_node_id: int, | |
| end_node_id: int, | |
| path_sim_type: str = "llm", | |
| verbose: bool = False, | |
| **kwargs, | |
| ): | |
| path_dicts = self.find_paths_between(start_node_id, end_node_id) | |
| if path_sim_type == "llm": | |
| api = kwargs.get("api", "openai") | |
| model = kwargs.get("model", "gpt-4o-mini") | |
| domain = kwargs.get("domain", None) | |
| similarity_judge_prompt = kwargs.get("similarity_judge_prompt", None) | |
| def path_sim_func(path1_text, path2_text): | |
| return path_sim_llm( | |
| path1_text, | |
| path2_text, | |
| api=api, | |
| model=model, | |
| domain=domain, | |
| custom_similarity_judge_prompt=similarity_judge_prompt, | |
| ) | |
| elif path_sim_type == "cosine": | |
| pass | |
| # embedding_model = SentenceTransformer("all-MiniLM-L6-v2") | |
| # threshold = kwargs.get("threshold", 0.9) | |
| # path_sim_func = path_sim_cosine(embedding_model, threshold) | |
| else: | |
| raise ValueError(f"Invalid path similarity type: {path_sim_type}") | |
| # merge paths based on semantic similarity | |
| path_equivalence_classes = {} | |
| class_count = 0 | |
| for path_dict in path_dicts: | |
| if verbose: | |
| print(path_dict) | |
| found_class = False | |
| for _, eq_class in path_equivalence_classes.items(): | |
| # check if path dict is already in an equivalence class | |
| path1_text = ( | |
| path_dict["begin_text"] | |
| + " " | |
| + path_dict["body_text"] | |
| + " " | |
| + path_dict["end_text"] | |
| ) | |
| path2_text = ( | |
| eq_class[0]["begin_text"] | |
| + " " | |
| + eq_class[0]["body_text"] | |
| + " " | |
| + eq_class[0]["end_text"] | |
| ) | |
| judgement, num_input_tokens, num_output_tokens = path_sim_func( | |
| path1_text, path2_text | |
| ) | |
| self.num_input_tokens_used += num_input_tokens | |
| self.num_output_tokens_used += num_output_tokens | |
| if judgement: | |
| eq_class.append(path_dict) | |
| found_class = True | |
| break | |
| if not found_class: | |
| class_count += 1 | |
| path_equivalence_classes[class_count] = [path_dict] | |
| nodes_to_remove = set() # Track nodes to remove | |
| for _, eq_class in path_equivalence_classes.items(): | |
| path_dict = eq_class[0] | |
| if verbose: | |
| print(eq_class) | |
| # add new node with merged text | |
| new_node_id = self.addNode(path_dict["body_text"]) | |
| for sequence_id in path_dict["labels"]: | |
| self.nodedict[new_node_id].variations[sequence_id] = path_dict["body_text"] | |
| # collect nodes to remove from first path | |
| nodes_to_remove.update(path_dict["path"][:-1]) | |
| # process data regarding weights and labels | |
| labels = list(path_dict["labels"]) | |
| weight = path_dict["weight"] | |
| self.addEdge(start_node_id, new_node_id, label=labels, weight=weight) | |
| # Updated seq_paths for all labels to include new_node betwwen start_node and end_node | |
| for label in labels: | |
| index = self._seq_paths[label].index(start_node_id) | |
| if ( | |
| index + 1 < len(self._seq_paths[label]) | |
| and self._seq_paths[label][index + 1] != new_node_id | |
| ): | |
| self._seq_paths[label].insert(index + 1, new_node_id) | |
| self.addEdge(new_node_id, end_node_id, label=labels, weight=weight) | |
| self.nodedict[new_node_id].sequences = labels | |
| # process additional paths | |
| for path_dict in eq_class[1:]: | |
| for sequence_id in path_dict["labels"]: | |
| self.nodedict[new_node_id].variations[sequence_id] = path_dict["body_text"] | |
| nodes_to_remove.update(path_dict["path"][:-1]) | |
| # copy incoming edges to new node | |
| labels = list(path_dict["labels"]) | |
| weight = path_dict["weight"] | |
| self.addEdge(start_node_id, new_node_id, label=labels, weight=weight) | |
| # Updated seq_paths for all labels to include new_node betwwen start_node and end_node | |
| for label in labels: | |
| index = self._seq_paths[label].index(start_node_id) | |
| if ( | |
| index + 1 < len(self._seq_paths[label]) | |
| and self._seq_paths[label][index + 1] != new_node_id | |
| ): | |
| self._seq_paths[label].insert(index + 1, new_node_id) | |
| self.addEdge(new_node_id, end_node_id, label=labels, weight=weight) | |
| self.nodedict[new_node_id].sequences.extend(labels) | |
| self.nodedict[new_node_id].sequences = list(set(self.nodedict[new_node_id].sequences)) | |
| # Remove all collected nodes after processing | |
| for node_id in nodes_to_remove: | |
| if node_id in self.nodedict: | |
| if verbose: | |
| print(f"Removing node {node_id}") | |
| self.removeNode(node_id) | |
| def merge_divergent_paths(self, path_sim_type: str = "llm", verbose: bool = False, **kwargs): | |
| # add dummy end node to the end of the graph | |
| if not self.consensus_node_ids: | |
| self.merge_consensus_nodes(verbose=verbose) | |
| self.toposort() | |
| if self.start_id == -1: | |
| if verbose: | |
| print("Adding start node") | |
| self.start_id = self.addNode(text="START") | |
| self._nextnodeID += 1 | |
| self.consensus_node_ids.insert(0, self.start_id) | |
| for label, path in self._seq_paths.items(): | |
| self.addEdge(self.start_id, path[0], label=label, weight=1) | |
| path.insert(0, self.start_id) | |
| if self.end_id == -1: | |
| if verbose: | |
| print("Adding end node") | |
| self.end_id = self.addNode(text="END") | |
| self._nextnodeID += 1 | |
| self.consensus_node_ids = self.consensus_node_ids + [self.end_id] | |
| for label, path in self._seq_paths.items(): | |
| self.addEdge(path[-1], self.end_id, label=label, weight=1) | |
| path.append(self.end_id) | |
| for i in tqdm(range(len(self.consensus_node_ids) - 1)): | |
| if verbose: | |
| print(self.consensus_node_ids[i], self.consensus_node_ids[i + 1]) | |
| self.merge_paths_between( | |
| self.consensus_node_ids[i], | |
| self.consensus_node_ids[i + 1], | |
| path_sim_type=path_sim_type, | |
| verbose=verbose, | |
| **kwargs, | |
| ) | |
| def get_variable_node_ids(self): | |
| return [ | |
| node.ID for node in self.nodedict.values() if node.ID not in self.consensus_node_ids | |
| ] | |
| def compress_paths_between(self, start_node_id: int, end_node_id: int): | |
| pass | |
| def compress_graph(self): | |
| pass | |
| def update_influence_scores(self, outcome: Dict[int, float], discount_factor: float = 0.2): | |
| self.toposort() | |
| direct_scores = [] | |
| for node in self.nodedict.values(): | |
| next_out_weight = sum(e.weight for e in node.outEdges.values()) | |
| next_in_weight = sum(e.weight for e in node.inEdges.values()) | |
| if next_out_weight == self.num_sequences and next_in_weight == self.num_sequences: | |
| out_list = [] | |
| for edge in node.outEdges.values(): | |
| for _ in range(len(set(edge.labels))): | |
| out_list.append(np.mean([outcome[label] for label in set(edge.labels)])) | |
| direct_scores.append((node.ID, np.var(out_list))) | |
| scores = direct_scores.copy() | |
| # Start from the end and propagate influence backward | |
| for i in range(len(scores) - 2, -1, -1): | |
| # Current node gets its direct influence plus discounted influence of next node | |
| current_direct = scores[i][1] | |
| next_total = scores[i + 1][1] | |
| scores[i] = (scores[i][0], current_direct + discount_factor * next_total) | |
| scores.sort(key=lambda x: x[1], reverse=True) | |
| return scores | |
| def jsOutput( | |
| self, | |
| verbose: bool = False, | |
| annotate_consensus: bool = True, | |
| color_annotations: Dict[int, str] = None, | |
| ): | |
| """returns a list of strings containing a a description of the graph for viz.js, http://visjs.org""" | |
| # get the consensus sequence, which we'll use as the "spine" of the | |
| # graph | |
| pathdict = {} | |
| if annotate_consensus: | |
| path, __, __ = self.consensus() | |
| lines = ["var nodes = ["] | |
| ni = self.nodeiterator() | |
| count = 0 | |
| for node in ni(): | |
| title_text = "" | |
| if node.sequences: | |
| title_text += f"Sequences: {node.sequences}" | |
| if node.variations: | |
| title_text += ";;;".join( | |
| [f"{sequence_id}: {text}" for sequence_id, text in node.variations.items()] | |
| ) | |
| title_text = title_text.replace('"', "'") | |
| line = ( | |
| " {id:" | |
| + str(node.ID) | |
| + ', label: "' | |
| + str(node.ID) | |
| + ": " | |
| + node.text.replace('"', "'") | |
| + '", title: ' | |
| + '"' | |
| + title_text | |
| + '",' | |
| ) | |
| if color_annotations and node.ID in color_annotations: | |
| line += f" color: '{color_annotations[node.ID]}', " | |
| if node.ID in pathdict and count % 5 == 0 and annotate_consensus: | |
| line += ( | |
| ", x: " | |
| + str(pathdict[node.ID]) | |
| + ", y: 0 , fixed: { x:true, y:false}," | |
| + "color: '#7BE141', is_consensus:true}," | |
| ) | |
| else: | |
| line += "}," | |
| lines.append(line) | |
| lines[-1] = lines[-1][:-1] | |
| lines.append("];") | |
| lines.append(" ") | |
| lines.append("var edges = [ ") | |
| ni = self.nodeiterator() | |
| for node in ni(): | |
| nodeID = str(node.ID) | |
| for edge in node.outEdges: | |
| target = str(edge) | |
| weight = str(node.outEdges[edge].weight + 1.5) | |
| lines.append( | |
| " {from: " | |
| + nodeID | |
| + ", to: " | |
| + target | |
| + ", value: " | |
| + weight | |
| + ", color: '#4b72b0', arrows: 'to'}," | |
| ) | |
| if verbose: | |
| for alignededge in node.alignedTo: | |
| # These edges indicate alignment to different bases, and are | |
| # undirected; thus make sure we only plot them once: | |
| if node.ID > alignededge: | |
| continue | |
| target = str(alignededge) | |
| lines.append( | |
| " {from: " | |
| + nodeID | |
| + ", to: " | |
| + target | |
| + ', value: 1, style: "dash-line", color: "red"},' | |
| ) | |
| lines[-1] = lines[-1][:-1] | |
| lines.append("];") | |
| return lines | |
| def htmlOutput( | |
| self, | |
| outfile, | |
| verbose: bool = False, | |
| annotate_consensus: bool = True, | |
| color_annotations: Dict[int, str] = None, | |
| ): | |
| header = """ | |
| <!doctype html> | |
| <html> | |
| <head> | |
| <title>POA Graph Alignment</title> | |
| <script type="text/javascript" src="https://unpkg.com/vis-network@9.0.4/standalone/umd/vis-network.min.js"></script> | |
| </head> | |
| <body> | |
| <div id="loadingProgress">0%</div> | |
| <div id="mynetwork"></div> | |
| <script type="text/javascript"> | |
| // create a network | |
| """ | |
| outfile.write(textwrap.dedent(header[1:])) | |
| lines = self.jsOutput( | |
| verbose=verbose, | |
| annotate_consensus=annotate_consensus, | |
| color_annotations=color_annotations, | |
| ) | |
| for line in lines: | |
| outfile.write(line + "\n") | |
| footer = """ | |
| var container = document.getElementById('mynetwork'); | |
| var data= { | |
| nodes: nodes, | |
| edges: edges, | |
| }; | |
| var options = { | |
| width: '100%', | |
| height: '800px', | |
| physics: { | |
| enabled: false, | |
| stabilization: { | |
| updateInterval: 10, | |
| }, | |
| }, | |
| edges: { | |
| color: { | |
| inherit: false | |
| } | |
| }, | |
| layout: { | |
| hierarchical: { | |
| direction: "UD", | |
| sortMethod: "directed", | |
| shakeTowards: "roots", | |
| levelSeparation: 150, // Adjust as needed | |
| nodeSpacing: 800, // Adjust as needed | |
| treeSpacing: 200, // Adjust as needed | |
| parentCentralization: true, | |
| } | |
| } | |
| }; | |
| var network = new vis.Network(container, data, options); | |
| network.on('beforeDrawing', function(ctx) { | |
| nodes.forEach(function(node) { | |
| if (node.isConsensus) { | |
| // Set the level of spine nodes to the bottom | |
| network.body.data.nodes.update({ | |
| id: node.id, | |
| level: 0 // Set level to 0 for spine nodes | |
| }); | |
| } | |
| }); | |
| }); | |
| network.on("stabilizationProgress", function (params) { | |
| document.getElementById("loadingProgress").innerText = Math.round(params.iterations / params.total * 100) + "%"; | |
| }); | |
| network.once("stabilizationIterationsDone", function () { | |
| document.getElementById("loadingProgress").innerText = "100%"; | |
| setTimeout(function () { | |
| document.getElementById("loadingProgress").style.display = "none"; | |
| }, 500); | |
| }); | |
| </script> | |
| </body> | |
| </html> | |
| """ | |
| outfile.write(textwrap.dedent(footer)) | |
| def multi_consensus_response(self, abstention_threshold: Optional[float] = None, filter: bool = True): | |
| self.toposort() | |
| nodesInReverse = self.nodeidlist[::-1] | |
| maxnodeID = self.end_id | |
| nextInPath = [-1] * maxnodeID | |
| scores = np.zeros(len(self.nodeidlist)) | |
| id_to_index = {node_id: index for index, node_id in enumerate(self.nodeidlist)} | |
| index_to_id = {index: node_id for index, node_id in enumerate(self.nodeidlist)} | |
| for nodeID in nodesInReverse: | |
| bestWeightScoreEdges = [(-1, -1, None)] | |
| for neighbourID in self.nodedict[nodeID].outEdges: | |
| # print(f"nodeID: {nodeID}, neighbourID: {neighbourID}") | |
| e = self.nodedict[nodeID].outEdges[neighbourID] | |
| weightScoreEdge = (e.weight, scores[id_to_index[neighbourID]], neighbourID) | |
| if weightScoreEdge > bestWeightScoreEdges[0]: | |
| bestWeightScoreEdges = [weightScoreEdge] | |
| elif weightScoreEdge == bestWeightScoreEdges[0] and filter: | |
| bestWeightScoreEdges.append(weightScoreEdge) | |
| scores[id_to_index[nodeID]] = sum(bestWeightScoreEdges[0][0:2]) | |
| if bestWeightScoreEdges[0][2] is not None: | |
| nextInPath[id_to_index[nodeID]] = id_to_index[bestWeightScoreEdges[0][2]] | |
| else: | |
| nextInPath[id_to_index[nodeID]] = None | |
| pos = np.argmax(scores) | |
| path = [] | |
| text = [] | |
| labels = [] | |
| while pos is not None and pos > -1: | |
| if abstention_threshold is not None and self.nodedict[index_to_id[pos]].variations: | |
| if ( | |
| len(self.nodedict[index_to_id[pos]].labels) / self.num_sequences | |
| >= abstention_threshold | |
| ): | |
| path.append(index_to_id[pos]) | |
| labels.append(self.nodedict[index_to_id[pos]].labels) | |
| text.append(self.nodedict[index_to_id[pos]].text) | |
| else: | |
| path.append(index_to_id[pos]) | |
| labels.append(self.nodedict[index_to_id[pos]].labels) | |
| text.append(self.nodedict[index_to_id[pos]].text) | |
| pos = nextInPath[pos] | |
| # ignore END node | |
| path = path[:-1] | |
| # ignore END node | |
| text = text[:-1] | |
| # ignore START in text | |
| text[0] = text[0].replace("START", "") | |
| labels = labels[:-1] | |
| return " ".join(text) | |
| def consensus_response( | |
| self, selection_threshold: Optional[float] = 0.5, api: str = "openai" , model: str = "gpt-4o-mini", task: str = "bio", **kwargs | |
| ) -> str: | |
| self.toposort() | |
| consensus_node_ids = self.consensus_node_ids | |
| print(consensus_node_ids) | |
| selected_node_ids = [] | |
| for node_id in consensus_node_ids: | |
| if node_id == self.start_id or node_id == self.end_id: | |
| continue | |
| selected_node_ids.append(node_id) | |
| for neighbor_id in self.nodedict[node_id].outEdges: | |
| if neighbor_id in consensus_node_ids: | |
| continue | |
| if ( | |
| len(self.nodedict[neighbor_id].labels) / self.num_sequences | |
| >= selection_threshold | |
| ): | |
| selected_node_ids.append(neighbor_id) | |
| text = " ".join([self.nodedict[node_id].text for node_id in selected_node_ids]) | |
| print(text) | |
| cleaned_text = clean_up_text(text, task=task, api=api, model=model, **kwargs) | |
| return cleaned_text | |
| def save_to_pickle(self, filename): | |
| with open(filename, "wb+") as f: | |
| pickle.dump(self, f) | |
| def refine_graph( | |
| self, | |
| verbose: bool = False, | |
| save_intermediate_file: str = None, | |
| final_merge: bool = True, | |
| **kwargs, | |
| ): | |
| self.merge_consensus_nodes(verbose=verbose) | |
| if save_intermediate_file: | |
| with open(save_intermediate_file, "w+") as f: | |
| self.htmlOutput(f, annotate_consensus=False) | |
| if not self.consensus_node_ids: | |
| self.failed = True | |
| return | |
| else: | |
| self.merge_divergent_paths(verbose=verbose, **kwargs) | |
| if final_merge: | |
| try: | |
| self.merge_consensus_nodes(verbose=verbose) | |
| except Exception as e: | |
| print(e) | |
| self.failed = True | |