# congr-visualizer / src/poa_graph.py
# Shahzaib98's picture
# Upload 11 files
# d2ff6a7 verified
"""
Adapted from Jonathan Dursi
https://github.com/ljdursi/poapy
"""
import collections
import textwrap
from typing import Dict, List, Optional, Union
import numpy
from .alignment import SeqGraphAlignment
class Node(object):
    """One vertex of the alignment graph: a piece of text plus its edges.

    Attributes:
        ID: integer identifier of the node.
        text: the text carried by the node.
        inEdges: maps a predecessor's node ID to the Edge arriving from it.
        outEdges: maps a successor's node ID to the Edge leaving towards it.
        alignedTo: list of node IDs recorded as aligned with this node; this
            class only creates the list and never modifies it.
    """

    def __init__(self, nodeID: int = -1, text: str = ""):
        self.ID = nodeID
        self.text = text
        self.inEdges = {}
        self.outEdges = {}
        self.alignedTo = []

    def __str__(self):
        return "(%d:%s)" % (self.ID, self.text)

    def _add_edge(
        self,
        edgeset: Dict[int, "Edge"],
        neighbourID: Optional[int],
        label: Optional[Union[int, List[int]]],
        from_neighbour: bool,
        weight: int = 1,
    ):
        """Create the edge to/from ``neighbourID`` in ``edgeset``, or update it.

        Args:
            edgeset: ``self.inEdges`` or ``self.outEdges``.
            neighbourID: ID of the node at the other end; ``None`` is a no-op.
            label: one label, a list of labels, or ``None`` for no label.
            from_neighbour: True when the edge runs neighbour -> self,
                False when it runs self -> neighbour.
            weight: amount added to the edge weight (initial weight for a
                new edge).
        """
        if neighbourID is None:
            return

        if neighbourID in edgeset:
            # Edge already present: accumulate weight and merge the labels.
            edge = edgeset[neighbourID]
            edge.weight += weight
            if isinstance(label, list):
                edge.labels.extend(label)
            elif label is not None:
                # ``None`` means "no label" (as in Edge.__init__), so it must
                # not be stored as if it were a label.
                edge.labels.append(label)
            # remove duplicates
            edge.labels = list(set(edge.labels))
            return

        # Hand the new edge its own list so that the in-edge, the out-edge and
        # the caller never end up sharing one mutable label list.
        if isinstance(label, list):
            label = list(label)
        if from_neighbour:
            edge = Edge(outNodeID=neighbourID, inNodeID=self.ID, label=label, weight=weight)
        else:
            edge = Edge(outNodeID=self.ID, inNodeID=neighbourID, label=label, weight=weight)
        edgeset[neighbourID] = edge

    def addInEdge(self, neighbourID: int, label: Optional[Union[int, List[int]]], weight: int = 1):
        """Register (or reinforce) an edge arriving from ``neighbourID``."""
        self._add_edge(self.inEdges, neighbourID, label, from_neighbour=True, weight=weight)

    def addOutEdge(self, neighbourID: int, label: Optional[Union[int, List[int]]], weight: int = 1):
        """Register (or reinforce) an edge leaving towards ``neighbourID``."""
        self._add_edge(self.outEdges, neighbourID, label, from_neighbour=False, weight=weight)

    def nextNode(self, label: int):
        """Return the ID of the outward neighbour whose edge carries ``label``.

        Returns ``None`` when no out-edge has the label.  If several out-edges
        carry it (presumably never the case), the one added last wins.
        """
        nextID = None
        for neighbourID, edge in self.outEdges.items():
            if label in edge.labels:
                nextID = neighbourID
        return nextID

    @property
    def inDegree(self):
        """Number of distinct predecessors."""
        return len(self.inEdges)

    @property
    def outDegree(self):
        """Number of distinct successors."""
        return len(self.outEdges)

    @property
    def weightedInDegree(self):
        """Sum of the weights of all in-edges."""
        return sum(edge.weight for edge in self.inEdges.values())

    @property
    def weightedOutDegree(self):
        """Sum of the weights of all out-edges."""
        return sum(edge.weight for edge in self.outEdges.values())

    @property
    def labels(self):
        """Returns all the labels associated with an in-edge or an out edge."""
        labelset = set()
        for edge in self.inEdges.values():
            labelset.update(edge.labels)
        for edge in self.outEdges.values():
            labelset.update(edge.labels)
        return list(labelset)
class Edge(object):
    """A directed, weighted edge carrying a list of labels.

    ``outNodeID`` is the node the edge leaves (its source) and ``inNodeID``
    the node it enters (its target); this matches how ``Node._add_edge``
    fills the two fields.

    Attributes:
        inNodeID: ID of the target node.
        outNodeID: ID of the source node.
        weight: numeric weight of the edge.
        labels: list of labels attached to the edge (never ``None``).
    """

    def __init__(
        self,
        inNodeID: int = -1,
        outNodeID: int = -1,
        label: Optional[Union[int, List[int]]] = None,
        weight: int = 1,
    ):
        self.inNodeID = inNodeID
        self.outNodeID = outNodeID
        self.weight = weight
        if label is None:
            self.labels = []
        elif isinstance(label, list):
            # Copy, so later addLabel()/extend() calls cannot mutate the
            # caller's list or another edge built from the same list.
            self.labels = list(label)
        else:
            self.labels = [label]

    def addLabel(self, newlabel):
        """Append ``newlabel`` to the edge's labels (no de-duplication)."""
        self.labels.append(newlabel)

    def __str__(self):
        # Print source -> target so the arrow matches the edge direction.
        return "(%d) -> (%d) " % (self.outNodeID, self.inNodeID) + str(self.labels)
class POAGraph(object):
    """Graph of ``Node`` objects built up from sequences and their alignments.

    Internal state:
        nodedict: node ID -> Node.
        nodeidlist: node IDs in the current (ideally topological) order.
        _needsSort: True when ``nodeidlist`` may no longer be topologically
            sorted; exposed read-only through ``needsSort``.
        _seqs / _labels / _starts: per inserted sequence, the sequence itself,
            its label and the ID of its first node.
        _seq_paths: label -> list of node IDs visited by the aligned part of a
            sequence; only filled by ``incorporateSeqAlignment``.
    """

    def __init__(self, seq=None, label: Optional[Union[int, List[int]]] = None):
        self._nextnodeID = 0
        self._nnodes = 0
        self._nedges = 0
        self.nodedict = {}
        self.nodeidlist = []  # allows a (partial) order to be imposed on the nodes
        # One spelling everywhere: the flag is read by ``needsSort`` and written
        # by addNode/addEdge/toposort/addUnmatchedSeq.
        self._needsSort = False
        self._labels = []
        self._seqs = []
        self._starts = []
        self._seq_paths = {}
        if seq is not None:
            self.addUnmatchedSeq(seq, label)

    def addUnmatchedSeq(self, seq, label: int = -1, updateSequences=True):
        """Add a completely independent (sub)string to the graph.

        Every element of ``seq`` becomes a new node, chained to the previous
        one by an edge labelled ``label``.

        Returns:
            ``(firstID, lastID)`` of the new chain, or ``(None, None)`` when
            ``seq`` is ``None`` or empty.
        """
        if seq is None:
            # Callers tuple-unpack the result, so never return a bare None.
            return None, None

        firstID, lastID = None, None
        neededSort = self.needsSort
        for text in seq:
            nodeID = self.addNode(text)
            if firstID is None:
                firstID = nodeID
            if lastID is not None:
                self.addEdge(lastID, nodeID, label)
            lastID = nodeID

        # A fresh chain appended in order cannot break an existing ordering,
        # so restore whatever the flag was before.
        self._needsSort = neededSort
        if updateSequences:
            self._seqs.append(seq)
            self._labels.append(label)
            self._starts.append(firstID)
        return firstID, lastID

    def nodeIdxToBase(self, idx):
        """Return the text of the node at position ``idx`` of ``nodeidlist``."""
        return self.nodedict[self.nodeidlist[idx]].text

    def addNode(self, text):
        """Create a node holding ``text``, append it to the order, return its ID."""
        nid = self._nextnodeID
        self.nodedict[nid] = Node(nid, text)
        self.nodeidlist.append(nid)
        self._nnodes += 1
        self._nextnodeID += 1
        self._needsSort = True
        return nid

    def addEdge(self, start, end, label, weight: int = 1):
        """Add (or reinforce) the edge ``start -> end``.

        Does nothing when either endpoint is ``None``.

        Raises:
            KeyError: if ``start`` or ``end`` is not a node of the graph.
        """
        if start is None or end is None:
            return
        if start not in self.nodedict:
            raise KeyError("addEdge: Start node not in graph: " + str(start))
        if end not in self.nodedict:
            raise KeyError("addEdge: End node not in graph: " + str(end))

        startNode, endNode = self.nodedict[start], self.nodedict[end]
        oldNodeEdges = startNode.outDegree + endNode.inDegree
        startNode.addOutEdge(end, label, weight)
        endNode.addInEdge(start, label, weight)
        # The degrees only change when the edge did not exist before.
        if startNode.outDegree + endNode.inDegree != oldNodeEdges:
            self._nedges += 1
        self._needsSort = True

    @property
    def needsSort(self):
        """True when ``nodeidlist`` may be out of topological order."""
        return self._needsSort

    @property
    def nNodes(self):
        """Number of nodes created through ``addNode``."""
        return self._nnodes

    @property
    def nEdges(self):
        """Number of distinct edges created through ``addEdge``."""
        return self._nedges

    @property
    def num_sequences(self):
        """Number of sequences recorded in the graph."""
        return len(self._seqs)

    def get_sequences(self):
        """Return the internal list of recorded sequences (not a copy)."""
        return self._seqs

    def _simplified_graph_rep(self):
        """Collapse mutually aligned nodes into 'pseudonodes'.

        Returns a list of ``Pseudonode`` tuples indexed by pseudonode ID, each
        holding the member node IDs and the pseudonode IDs of all predecessors
        and successors of its members.
        """
        node_to_pn = {}
        pn_to_nodes = {}

        # Find the mappings from nodes to pseudonodes
        cur_pnid = 0
        for node in self.nodedict.values():
            if node.ID in node_to_pn:
                continue
            node_ids = [node.ID] + node.alignedTo
            pn_to_nodes[cur_pnid] = node_ids
            for nid in node_ids:
                node_to_pn[nid] = cur_pnid
            cur_pnid += 1

        # create the pseudonodes
        Pseudonode = collections.namedtuple(
            "Pseudonode", ["pnode_id", "predecessors", "successors", "node_ids"]
        )
        pseudonodes = []
        for pnid in range(cur_pnid):
            nids, preds, succs = pn_to_nodes[pnid], [], []
            for nid in nids:
                node = self.nodedict[nid]
                preds += [node_to_pn[edge.outNodeID] for edge in node.inEdges.values()]
                succs += [node_to_pn[edge.inNodeID] for edge in node.outEdges.values()]
            pseudonodes.append(
                Pseudonode(pnode_id=pnid, predecessors=preds, successors=succs, node_ids=nids)
            )
        return pseudonodes

    def toposort(self):
        """Sorts node list so that all incoming edges come from nodes earlier in the list."""
        #
        # The topological sort of this graph is complicated by the alignedTo edges;
        # we want nodes connected by such edges to remain near each other in the
        # topological sort.
        #
        # Here we'll create a simple version of the graph that merges nodes that
        # are alignedTo each other, performs the sort, and then decomposes the
        # 'pseudonodes'.
        #
        # The need for this suggests that the way the graph is currently represented
        # isn't quite right and needs some rethinking.
        #
        pseudonodes = self._simplified_graph_rep()
        completed = set()
        # Node IDs in order of completion; reversing it once at the end gives the
        # same list as inserting every ID at the front.
        finished = []

        def dfs(start):
            stack, started = [start], set()
            while stack:
                pnodeID = stack.pop()
                if pnodeID in completed:
                    continue
                if pnodeID in started:
                    # Seen for the second time: everything pushed above it has
                    # been popped, so the pseudonode can be emitted.
                    completed.add(pnodeID)
                    finished.extend(pseudonodes[pnodeID].node_ids)
                    started.remove(pnodeID)
                    continue
                started.add(pnodeID)
                stack.append(pnodeID)
                stack.extend(pseudonodes[pnodeID].successors)

        while len(finished) < self.nNodes:
            # Start from the first not-yet-emitted pseudonode without predecessors.
            found = next(
                (
                    pn.pnode_id
                    for pn in pseudonodes
                    if pn.pnode_id not in completed and len(pn.predecessors) == 0
                ),
                None,
            )
            assert found is not None
            dfs(found)

        assert len(finished) == self.nNodes
        finished.reverse()
        self.nodeidlist = finished
        self._needsSort = False

    def testsort(self):
        """Test the nodeidlist to make sure it is topologically sorted:
        eg, all predecessors of a node precede the node in the list"""
        if self.nodeidlist is None:
            return
        seen_nodes = set()
        for nodeidx in self.nodeidlist:
            for in_neighbour in self.nodedict[nodeidx].inEdges:
                assert in_neighbour in seen_nodes
            seen_nodes.add(nodeidx)

    def nodeiterator(self):
        """Return a generator function yielding the nodes in ``nodeidlist`` order.

        The graph is re-sorted first when ``needsSort`` is set.
        """
        if self.needsSort:
            self.toposort()

        def nodegenerator():
            for nodeidx in self.nodeidlist:
                yield self.nodedict[nodeidx]

        return nodegenerator

    def __str__(self):
        parts = []
        for node in self.nodeiterator()():
            parts.append(str(node) + "\n")
            for edge in node.outEdges.values():
                parts.append(" " + str(edge) + "\n")
        return "".join(parts)

    def incorporateSeqAlignment(self, alignment: "SeqGraphAlignment", seq, label: int = -1):
        """Incorporate a SeqGraphAlignment into the graph.

        Reads ``alignment.sequence``, ``alignment.stringidxs`` and
        ``alignment.nodeidxs``; the latter two are walked in lock-step, a
        ``None`` string index being skipped and a ``None`` node ID meaning the
        element is unmatched.  ``seq`` is what gets recorded in the sequence
        list, and ``label`` is attached to every edge that is created or
        reinforced.
        """
        newseq = alignment.sequence
        stringidxs = alignment.stringidxs
        nodeidxs = alignment.nodeidxs

        firstID = None
        headID = None
        tailID = None
        path = []

        # head, tail of sequence may be unaligned; just add those into the
        # graph directly
        validstringidxs = [si for si in stringidxs if si is not None]
        startSeqIdx, endSeqIdx = validstringidxs[0], validstringidxs[-1]
        if startSeqIdx > 0:
            firstID, headID = self.addUnmatchedSeq(
                newseq[0:startSeqIdx], label, updateSequences=False
            )
        if endSeqIdx + 1 < len(newseq):
            tailID, __ = self.addUnmatchedSeq(newseq[endSeqIdx + 1 :], label, updateSequences=False)

        # now we march along the aligned part. For each text, we find or create
        # a node in the graph:
        #   - if unmatched, the corresponding node is a new node
        #   - if matched:
        #       - if matched to a node with the same text, the node is that node
        #       - if matched to a node with a different text which is in turn
        #         aligned to a node with the same text, that aligned node is
        #         the node
        #       - otherwise, we create a new node.
        # In all cases, we create edges (or add labels) threading through the
        # nodes.
        for sindex, matchID in zip(stringidxs, nodeidxs):
            if sindex is None:
                continue
            text = newseq[sindex]
            if matchID is None:
                nodeID = self.addNode(text)
            elif self.nodedict[matchID].text == text:
                nodeID = matchID
            else:
                otherAligns = self.nodedict[matchID].alignedTo
                foundNode = None
                for otherNodeID in otherAligns:
                    if self.nodedict[otherNodeID].text == text:
                        foundNode = otherNodeID
                if foundNode is None:
                    nodeID = self.addNode(text)
                    # The concatenation builds a new list, so the loop below
                    # can safely append to the existing alignedTo lists.
                    alignedGroup = [matchID] + otherAligns
                    self.nodedict[nodeID].alignedTo = alignedGroup
                    for otherNodeID in alignedGroup:
                        self.nodedict[otherNodeID].alignedTo.append(nodeID)
                else:
                    nodeID = foundNode

            self.addEdge(headID, nodeID, label)
            headID = nodeID
            if firstID is None:
                firstID = headID
            path.append(nodeID)

        # finished the aligned portion: now add an edge from the current headID
        # to the tailID (a no-op when there is no unaligned tail).
        self.addEdge(headID, tailID, label)

        # resort
        self.toposort()

        self._seqs.append(seq)
        self._labels.append(label)
        self._starts.append(firstID)
        self._seq_paths[label] = path

    def consensus(self, excludeLabels=None):
        """Follow the heaviest out-edges to build one consensus path.

        Nodes are scored in reverse topological order: a node's score is the
        weight of its best out-edge plus the score of that edge's target (ties
        on weight are broken by the target's score).  The path starts at the
        highest-scoring node and its last element is dropped.

        Returns:
            ``(path, bases, labels)``: node IDs, node texts and per-node label
            lists.  All three are empty for an empty graph.

        NOTE(review): ``excludeLabels`` is accepted but does not influence the
        scoring, so repeated calls from ``allConsenses`` see the same weights.
        """
        if self.needsSort:
            self.toposort()
        if not self.nodeidlist:
            # max() below would raise on an empty graph.
            return [], [], []

        nodesInReverse = self.nodeidlist[::-1]
        maxnodeID = max(nodesInReverse) + 1
        nextInPath = [-1] * maxnodeID
        scores = numpy.zeros(maxnodeID)

        for nodeID in nodesInReverse:
            bestWeightScoreEdge = (-1, -1, None)
            for neighbourID, e in self.nodedict[nodeID].outEdges.items():
                weightScoreEdge = (e.weight, scores[neighbourID], neighbourID)
                if weightScoreEdge > bestWeightScoreEdge:
                    bestWeightScoreEdge = weightScoreEdge
            scores[nodeID] = sum(bestWeightScoreEdge[0:2])
            nextInPath[nodeID] = bestWeightScoreEdge[2]

        # Plain int, so every element of ``path`` has the same type.
        pos = int(numpy.argmax(scores))
        path = []
        bases = []
        labels = []
        while pos is not None and pos > -1:
            node = self.nodedict[pos]
            path.append(pos)
            bases.append(node.text)
            labels.append(node.labels)
            pos = nextInPath[pos]

        # ignore END node
        # (presumably every sequence ends with a sentinel token -- confirm)
        return path[:-1], bases[:-1], labels[:-1]

    def allConsenses(self, maxfraction=0.5):
        """Collect consensus paths over at most ten passes.

        After each pass, every sequence whose label occurs on at least
        ``maxfraction * len(seq)`` nodes of the path is added to the exclusion
        list handed to the next ``consensus`` call.  The loop also stops once
        a path is shorter than 10 nodes or the exclusion list is as long as
        the list of labels.

        Returns a list of ``(path, bases, labels)`` tuples.
        """
        allpaths = []
        allbases = []
        alllabels = []
        exclusions = []

        passno = 0
        lastlen = 1000
        maxpasses = 10

        while len(exclusions) < len(self._labels) and lastlen >= 10 and passno < maxpasses:
            path, bases, labellists = self.consensus(exclusions)
            if len(path) > 0:
                allpaths.append(path)
                allbases.append(bases)
                alllabels.append(labellists)

                labelcounts = collections.defaultdict(int)
                for ll in labellists:
                    for label in ll:
                        labelcounts[label] += 1

                for label, seq in zip(self._labels, self._seqs):
                    if label in labelcounts and labelcounts[label] >= maxfraction * len(seq):
                        exclusions.append(label)

            lastlen = len(path)
            passno += 1

        return list(zip(allpaths, allbases, alllabels))

    def generateAlignmentStrings(self):
        """Return a list of strings corresponding to the alignments in the graph"""
        # Step 1: assign node IDs to columns in the output
        #  column_index[node.ID] is the position in the toposorted node list
        #  of the node itself, or the earliest node it is aligned to.
        column_index = {}
        current_column = 0

        # go through nodes in toposort order
        for node in self.nodeiterator()():
            other_columns = [
                column_index[other] for other in node.alignedTo if other in column_index
            ]
            if other_columns:
                found_idx = min(other_columns)
            else:
                found_idx = current_column
                current_column += 1
            column_index[node.ID] = found_idx

        ncolumns = current_column

        # Step 2: given the column indexes, populate the strings
        #  corresponding to the sequences inserted in the graph
        seqnames = []
        alignstrings = []
        for label, start in zip(self._labels, self._starts):
            seqnames.append(label)
            curnode_id = start
            charlist = ["-"] * ncolumns
            while curnode_id is not None:
                node = self.nodedict[curnode_id]
                charlist[column_index[curnode_id]] = node.text
                curnode_id = node.nextNode(label)
            alignstrings.append("".join(charlist))

        # Step 3: Same as step 2, but with consensus sequences
        for i, consensus in enumerate(self.allConsenses()):
            seqnames.append("Consensus" + str(i))
            charlist = ["-"] * ncolumns
            for path, text in zip(consensus[0], consensus[1]):
                charlist[column_index[path]] = text
            alignstrings.append("".join(charlist))

        return list(zip(seqnames, alignstrings))

    def jsOutput(self, verbose: bool = False, annotate_consensus: bool = True):
        """returns a list of strings containing a description of the graph for viz.js, http://visjs.org

        Args:
            verbose: also emit dashed red edges between aligned nodes.
            annotate_consensus: see the review note in the body.
        """
        # NOTE(review): nothing fills ``pathdict``, so the consensus-annotation
        # branch below never fires and ``annotate_consensus`` currently has no
        # effect.  Presumably it was meant to map consensus-path node IDs to x
        # positions -- confirm the intended layout before enabling it.  Also
        # note the key emitted here is ``is_consensus`` while the HTML footer
        # written by htmlOutput checks ``node.isConsensus``.
        pathdict = {}

        node_entries = []
        for node in self.nodeiterator()():
            line = " {id:" + str(node.ID) + ', label: "' + str(node.ID) + ": " + node.text + '"'
            if annotate_consensus and node.ID in pathdict:
                line += (
                    ", x: "
                    + str(pathdict[node.ID])
                    + ", y: 0 , fixed: { x:true, y:false},"
                    + "color: '#7BE141', is_consensus:true},"
                )
            else:
                line += "},"
            node_entries.append(line)

        edge_entries = []
        for node in self.nodeiterator()():
            nodeID = str(node.ID)
            for target_id, edge in node.outEdges.items():
                target = str(target_id)
                weight = str(len(edge.labels) + 1.5)
                edge_entries.append(
                    " {from: "
                    + nodeID
                    + ", to: "
                    + target
                    + ", value: "
                    + weight
                    + ", color: '#4b72b0', arrows: 'to'},"
                )
            if verbose:
                for alignededge in node.alignedTo:
                    # These edges indicate alignment to different bases, and are
                    # undirected; thus make sure we only plot them once:
                    if node.ID > alignededge:
                        continue
                    target = str(alignededge)
                    edge_entries.append(
                        " {from: "
                        + nodeID
                        + ", to: "
                        + target
                        + ', value: 1, style: "dash-line", color: "red"},'
                    )

        # Drop the trailing comma of the last entry of each array -- but only
        # when there is an entry, otherwise the "[" of the header line would be
        # cut off (e.g. the edge array of a single-node graph).
        for entries in (node_entries, edge_entries):
            if entries:
                entries[-1] = entries[-1][:-1]

        lines = ["var nodes = ["]
        lines.extend(node_entries)
        lines.append("];")
        lines.append(" ")
        lines.append("var edges = [")
        lines.extend(edge_entries)
        lines.append("];")
        return lines

    def htmlOutput(self, outfile, verbose: bool = False, annotate_consensus: bool = True):
        """Write a standalone HTML page visualising the graph to ``outfile``.

        ``outfile`` only needs a ``write(str)`` method.  The page consists of a
        fixed header, the lines produced by ``jsOutput`` and a fixed footer
        that sets up the vis-network view.
        """
        header = """
        <!doctype html>
        <html>
        <head>
        <title>POA Graph Alignment</title>
        <script type="text/javascript" src="https://unpkg.com/vis-network@9.0.4/standalone/umd/vis-network.min.js"></script>
        </head>
        <body>
        <div id="loadingProgress">0%</div>
        <div id="mynetwork"></div>
        <script type="text/javascript">
        // create a network
        """
        # header[1:] drops the newline that follows the opening quotes.
        outfile.write(textwrap.dedent(header[1:]))
        lines = self.jsOutput(verbose=verbose, annotate_consensus=annotate_consensus)
        for line in lines:
            outfile.write(line + "\n")
        footer = """
        var container = document.getElementById('mynetwork');
        var data= {
        nodes: nodes,
        edges: edges,
        };
        var options = {
        width: '100%',
        height: '800px',
        physics: {
        enabled: false,
        stabilization: {
        updateInterval: 10,
        },
        hierarchicalRepulsion: {
        avoidOverlap: 0.9,
        },
        },
        edges: {
        color: {
        inherit: false
        }
        },
        layout: {
        hierarchical: {
        direction: "UD",
        sortMethod: "directed",
        shakeTowards: "roots",
        levelSeparation: 150, // Adjust as needed
        nodeSpacing: 100, // Adjust as needed
        treeSpacing: 200, // Adjust as needed
        parentCentralization: true,
        }
        }
        };
        var network = new vis.Network(container, data, options);
        network.on('beforeDrawing', function(ctx) {
        nodes.forEach(function(node) {
        if (node.isConsensus) {
        // Set the level of spine nodes to the bottom
        network.body.data.nodes.update({
        id: node.id,
        level: 0 // Set level to 0 for spine nodes
        });
        }
        });
        });
        network.on("stabilizationProgress", function (params) {
        document.getElementById("loadingProgress").innerText = Math.round(params.iterations / params.total * 100) + "%";
        });
        network.once("stabilizationIterationsDone", function () {
        document.getElementById("loadingProgress").innerText = "100%";
        setTimeout(function () {
        document.getElementById("loadingProgress").style.display = "none";
        }, 500);
        });
        </script>
        </body>
        </html>
        """
        outfile.write(textwrap.dedent(footer))