# congr-visualizer / src / new_text_alignment.py
# Shahzaib98's picture
# Upload 11 files
# d2ff6a7 verified
from difflib import SequenceMatcher
import numpy as np
from .new_alignment import ScoreParam, SeqGraphAlignment
# Substrings whose presence in a word marks it as carrying punctuation.
# NOTE: "..." is redundant for substring tests (any word containing "..."
# also contains "."), but it is kept so the list contents stay unchanged.
PUNCTUATION_MARKS = [
    ".",
    "!",
    "?",
    ",",
    ":",
    ";",
    "...",
    "(",
    ")",
]
class TextSeqGraphAlignment(SeqGraphAlignment):
    """Align a sequence of words against a graph whose nodes carry text.

    Words are compared fuzzily (see ``matchscore``) instead of by strict
    equality, and gaps are scored with separate open/extend penalties.
    """

    def __init__(
        self,
        text,
        graph,
        fastMethod=True,
        globalAlign=True,
        matchscore=1,
        mismatchscore=-3,
        gap_open=-2,
        gap_extend=-1,
        position_weight=0.1,
        *args,
        **kwargs,
    ):
        """Tokenise ``text`` and hand the alignment over to the base class.

        Args:
            text: Either a string (split on whitespace into words) or an
                already tokenised sequence of words.
            graph: The graph to align against; passed through to the base
                class unchanged.
            fastMethod: Forwarded positionally to the base class.
            globalAlign: Forwarded to the base class as ``globalAlign``.
            matchscore: Score for a matching pair of words.
            mismatchscore: Score for a non-matching pair of words.
            gap_open: Penalty added when a gap is opened.
            gap_extend: Penalty added for every gap position.
            position_weight: Stored on the instance; not used by any method
                defined in this class.
            *args: Extra positional arguments for the base class.
            **kwargs: Extra keyword arguments for the base class.
        """
        score_params = ScoreParam(
            match=matchscore,
            mismatch=mismatchscore,
            gap_open=gap_open,
            gap_extend=gap_extend,
        )

        if isinstance(text, str):
            self.original_text = text
            self.sequence = text.split()
        else:
            self.sequence = text
            self.original_text = " ".join(text)

        self.position_weight = position_weight

        # ``*args`` is written before the keyword arguments because that is
        # how Python binds it anyway: extra positionals follow ``fastMethod``.
        # NOTE(review): the scoring methods below read ``self.score``, which
        # is presumably set by the base class from ``score_params`` - confirm.
        super().__init__(
            self.sequence,
            graph,
            fastMethod,
            *args,
            globalAlign=globalAlign,
            score_params=score_params,
            **kwargs,
        )

    def string_similarity(self, s1, s2):
        """Return the ``difflib.SequenceMatcher`` ratio of two strings.

        The ratio lies in ``[0, 1]``; 1.0 means the strings are identical.
        It is based on matching blocks and is not an edit distance.
        """
        return SequenceMatcher(None, s1, s2).ratio()

    def matchscore(self, word1: str, word2: str) -> float:
        """Score a pair of words from their string similarity.

        * similarity above 0.8: the full match score;
        * similarity above 0.5: the match score scaled by the similarity;
        * otherwise: the mismatch score, returned as is.

        A (partial) match is weighted by 1.5 when either word contains one of
        ``PUNCTUATION_MARKS``. Word positions are not taken into account.
        """
        similarity = self.string_similarity(word1, word2)

        if similarity > 0.8:  # Can tune this threshold
            score = self.score.match
        elif similarity > 0.5:  # Can tune this threshold too
            score = self.score.match * similarity
        else:
            # Mismatches are returned directly, without punctuation weighting.
            return self.score.mismatch

        # Add weight if any punctuation mark is present in either word.
        if any(mark in word1 or mark in word2 for mark in PUNCTUATION_MARKS):
            score = score * 1.5

        return score

    def alignmentStrings(self):
        """Return the aligned words and aligned node texts as two strings.

        Gap positions (``None`` indices) are rendered as ``"-"``; entries are
        joined with single spaces.
        """
        aligned_seq = [
            self.sequence[i] if i is not None else "-" for i in self.stringidxs
        ]
        aligned_graph = [
            self.graph.nodedict[j].text if j is not None else "-"
            for j in self.nodeidxs
        ]
        return " ".join(aligned_seq), " ".join(aligned_graph)

    def alignStringToGraphFast(self):
        """Fill the dynamic-programming tables and backtrack the alignment.

        Three score layers are kept per (graph node, word) cell:
        ``M`` (match at the last indices), ``X`` (gap at the last index of
        the graph) and ``Y`` (gap at the last index of the sequence). Each
        cell stores the best candidate of its layer together with the
        coordinates and layer it was reached from.

        Returns:
            Whatever ``self.backtrack`` returns for the filled tables.

        Raises:
            TypeError: If ``self.sequence`` is not a list.
        """
        if not isinstance(self.sequence, list):
            raise TypeError("Sequence must be a list of words")

        nodeIDtoIndex, nodeIndexToID, scores, backStrIdx, backGrphIdx, backMtxIdx = (
            self.initializeDynamicProgrammingData()
        )

        # M: Match at last indices, X: Gap at last index of graph,
        # Y: gap at last index of sequence
        M, X, Y = 0, 1, 2

        # Penalties are constant for the whole run. Opening a gap costs
        # open + extend; continuing one costs extend only.
        gap_extend = self.score.gap_extend
        gap_start = self.score.gap_open + gap_extend

        ni = self.graph.nodeiterator()
        for i, node in enumerate(ni()):
            gbase = node.text
            row = i + 1  # table rows are shifted by one w.r.t. node order

            for j, sbase in enumerate(self.sequence):
                col = j + 1  # table columns are shifted by one as well

                # The word-vs-node score depends only on this cell, so it is
                # computed once and not once per predecessor and layer.
                word_score = self.matchscore(sbase, gbase)

                # Stay on the same node, step back one word.
                candidates_X = [
                    (gap_start + scores[M, row, j], row, j, M),
                    (gap_extend + scores[X, row, j], row, j, X),
                    (gap_start + scores[Y, row, j], row, j, Y),
                ]

                candidates_Y, candidates_M = [], []
                for predIndex in self.prevIndices(node, nodeIDtoIndex):
                    prev_row = predIndex + 1
                    # Step back to a predecessor node, keep the same word.
                    candidates_Y += [
                        (gap_start + scores[M, prev_row, col], prev_row, col, M),
                        (gap_start + scores[X, prev_row, col], prev_row, col, X),
                        (gap_extend + scores[Y, prev_row, col], prev_row, col, Y),
                    ]
                    # Step back on both the graph and the sequence.
                    candidates_M += [
                        (word_score + scores[M, prev_row, j], prev_row, j, M),
                        (word_score + scores[X, prev_row, j], prev_row, j, X),
                        (word_score + scores[Y, prev_row, j], prev_row, j, Y),
                    ]

                # Candidates are (score, graph index, string index, layer);
                # max() picks the best score, ties broken by the indices.
                for layer, candidates in (
                    (M, candidates_M),
                    (X, candidates_X),
                    (Y, candidates_Y),
                ):
                    (
                        scores[layer, row, col],
                        backGrphIdx[layer, row, col],
                        backStrIdx[layer, row, col],
                        backMtxIdx[layer, row, col],
                    ) = max(candidates)

        return self.backtrack(scores, backStrIdx, backGrphIdx, backMtxIdx, nodeIndexToID)