from difflib import SequenceMatcher

import numpy as np

from .new_alignment import ScoreParam, SeqGraphAlignment

PUNCTUATION_MARKS = [".", "!", "?", ",", ":", ";", "...", "(", ")"]


class TextSeqGraphAlignment(SeqGraphAlignment):
    """Align a sequence of words against a graph whose nodes carry text.

    The alignment unit is a whole word rather than a single character, and
    two words are scored by fuzzy string similarity (see ``matchscore``)
    instead of strict equality.
    """

    def __init__(
        self,
        text,
        graph,
        fastMethod=True,
        globalAlign=True,
        matchscore=1,
        mismatchscore=-3,
        gap_open=-2,
        gap_extend=-1,
        position_weight=0.1,
        *args,
        **kwargs,
    ):
        """Tokenise ``text`` and delegate to the base-class constructor.

        Args:
            text: Either a string, which is split on whitespace, or an
                already tokenised sequence of words.
            graph: Graph to align against; forwarded to the base class.
            fastMethod: Forwarded positionally to the base class.
            globalAlign: Forwarded to the base class as ``globalAlign``.
            matchscore: Score given to a pair of matching words.
            mismatchscore: Score given to a pair of dissimilar words.
            gap_open: Score added when a gap is opened.
            gap_extend: Score added for every gap position.
            position_weight: Stored on the instance as ``position_weight``;
                nothing in this class reads it.
            *args: Extra positional arguments for the base class.
            **kwargs: Extra keyword arguments for the base class.
        """
        score_params = ScoreParam(
            match=matchscore,
            mismatch=mismatchscore,
            gap_open=gap_open,
            gap_extend=gap_extend,
        )

        if isinstance(text, str):
            self.original_text = text
            self.sequence = text.split()
        else:
            self.sequence = text
            self.original_text = " ".join(text)
        self.position_weight = position_weight

        # The attributes above are assigned before the base constructor is
        # called, exactly as in the original ordering.
        # NOTE(review): presumably the base constructor may start the
        # alignment and therefore needs them in place - confirm.
        super().__init__(
            self.sequence,
            graph,
            fastMethod,
            *args,
            globalAlign=globalAlign,
            score_params=score_params,
            **kwargs,
        )

    def string_similarity(self, s1, s2):
        """Return the ``difflib.SequenceMatcher`` ratio of two strings.

        The ratio lies in ``[0, 1]``; ``1.0`` means the strings are equal.
        """
        return SequenceMatcher(None, s1, s2).ratio()

    def matchscore(self, word1: str, word2: str) -> float:
        """Score a pair of words from their string similarity.

        Args:
            word1: First word.
            word2: Second word.

        Returns:
            ``self.score.match`` when the similarity is above 0.8,
            ``self.score.match * similarity`` when it is above 0.5, and
            ``self.score.mismatch`` otherwise.
        """
        similarity = self.string_similarity(word1, word2)

        # Both thresholds are tunable heuristics.
        if similarity > 0.8:
            # Near-identical words count as a full match.
            return self.score.match
        if similarity > 0.5:
            # Partially similar words get a proportionally reduced score.
            return self.score.match * similarity
        return self.score.mismatch
        # NOTE(review): a punctuation weighting (score * 1.5 when either word
        # contains an entry of PUNCTUATION_MARKS) used to follow an
        # unconditional ``return`` here and therefore never ran. It has been
        # dropped so that scores stay exactly as they were; reinstate it
        # deliberately if that weighting is wanted.

    def alignmentStrings(self):
        """Render the alignment as two space-separated strings.

        Returns:
            A ``(sequence_string, graph_string)`` tuple. Positions whose
            index is ``None`` are rendered as ``"-"``. Both strings use a
            single space between tokens.
        """
        aligned_seq = [
            self.sequence[i] if i is not None else "-" for i in self.stringidxs
        ]
        aligned_graph = [
            self.graph.nodedict[j].text if j is not None else "-"
            for j in self.nodeidxs
        ]
        return " ".join(aligned_seq), " ".join(aligned_graph)

    def alignStringToGraphFast(self):
        """Fill the three dynamic-programming matrices and backtrack.

        Three score layers are maintained per (graph node, word) cell, each
        with an affine gap penalty:

        * ``M`` - the cell ends in a word/node pairing,
        * ``X`` - the cell ends in a gap at the last index of the graph,
        * ``Y`` - the cell ends in a gap at the last index of the sequence.

        Returns:
            Whatever ``self.backtrack`` returns for the filled matrices.

        Raises:
            TypeError: If ``self.sequence`` is not a list.
        """
        if not isinstance(self.sequence, list):
            raise TypeError("Sequence must be a list of words")

        nodeIDtoIndex, nodeIndexToID, scores, backStrIdx, backGrphIdx, backMtxIdx = (
            self.initializeDynamicProgrammingData()
        )
        # NOTE(review): matchscore() can return fractional values; confirm
        # that ``scores`` is allocated with a floating-point dtype.

        # M: Match at last indices, X: Gap at last index of graph,
        # Y: gap at last index of sequence
        M, X, Y = 0, 1, 2

        # Loop-invariant gap penalties.
        gap_extend = self.score.gap_extend
        gap_start = self.score.gap_open + gap_extend

        ni = self.graph.nodeiterator()
        for i, node in enumerate(ni()):
            gbase = node.text
            # Resolved once per node; it does not depend on the word index.
            # NOTE(review): assumed non-empty for every node (otherwise the
            # max() calls below fail on empty candidate lists) - confirm.
            pred_rows = [
                predIndex + 1 for predIndex in self.prevIndices(node, nodeIDtoIndex)
            ]
            row = i + 1

            for j, sbase in enumerate(self.sequence):
                col = j + 1

                # Scored once per cell: it depends only on the word pair,
                # not on the predecessor or on the source layer.
                pair_score = self.matchscore(sbase, gbase)

                # Stay on the current node and step over one word.
                candidates_X = [
                    (gap_start + scores[M, row, j], row, j, M),
                    (gap_extend + scores[X, row, j], row, j, X),
                    (gap_start + scores[Y, row, j], row, j, Y),
                ]

                candidates_Y = []
                candidates_M = []
                for prev in pred_rows:
                    # Step over one node while staying on the current word.
                    candidates_Y += [
                        (gap_start + scores[M, prev, col], prev, col, M),
                        (gap_start + scores[X, prev, col], prev, col, X),
                        (gap_extend + scores[Y, prev, col], prev, col, Y),
                    ]
                    # Pair the current word with the current node.
                    candidates_M += [
                        (pair_score + scores[M, prev, j], prev, j, M),
                        (pair_score + scores[X, prev, j], prev, j, X),
                        (pair_score + scores[Y, prev, j], prev, j, Y),
                    ]

                # Each candidate is (score, graph index, string index, layer);
                # max() compares the tuples element by element.
                for layer, candidates in (
                    (M, candidates_M),
                    (X, candidates_X),
                    (Y, candidates_Y),
                ):
                    (
                        scores[layer, row, col],
                        backGrphIdx[layer, row, col],
                        backStrIdx[layer, row, col],
                        backMtxIdx[layer, row, col],
                    ) = max(candidates)

        return self.backtrack(scores, backStrIdx, backGrphIdx, backMtxIdx, nodeIndexToID)