AkashKhamkar committed on
Commit
89b579b
·
1 Parent(s): 03ec754

Upload segmentation.py

Browse files
Files changed (1) hide show
  1. segmentation.py +96 -0
segmentation.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import attr
2
+ import pandas as pd
3
+ import numpy as np
4
+ from .utils import load_sentence_transformer, load_spacy
5
+ from nltk.tokenize.texttiling import TextTilingTokenizer
6
+ from sklearn.metrics.pairwise import cosine_similarity
7
+
8
# Embedding model and NLP pipeline are created once, at import time, and are
# read as module-level globals by SemanticTextSegmentation (``model.encode``
# for sentence embeddings, ``nlp(...).sents`` for sentence splitting).
model = load_sentence_transformer()
nlp = load_spacy()
10
+
11
+
12
@attr.s
class SemanticTextSegmentation:
    """
    Segment a call transcript based on the topics discussed in the call.

    The utterances are first split into tiles with NLTK's TextTiling.
    Adjacent tiles whose mean sentence embeddings are similar enough are
    then merged back together into coherent segments.

    Parameters
    ----------
    data: pd.DataFrame
        The transcript in dataframe format.

    utterance: str
        Name of the column in ``data`` that holds the utterance text.

    Raises
    ------
    KeyError
        If ``utterance`` is not a column of ``data``.
    """

    data = attr.ib()
    utterance = attr.ib(default='utterance')

    # Sentences with fewer space-separated tokens than this are ignored when
    # embedding a segment. The original code used two different cut-offs for
    # the two sides of a comparison; a single shared value keeps a segment's
    # embedding independent of which side it is compared on.
    # NOTE(review): 2 matches the cut-off used for the first text; confirm
    # that this, and not 3, is the intended value.
    _MIN_SENTENCE_WORDS = 2

    def __attrs_post_init__(self):
        """Fail early if the configured utterance column is missing."""
        columns = self.data.columns.tolist()
        if self.utterance not in columns:
            raise KeyError(
                'column {!r} not found in transcript dataframe; '
                'available columns: {}'.format(self.utterance, columns)
            )

    def get_segments(self, threshold=0.7):
        """
        Return the transcript segments computed with TextTiling and the
        sentence transformer.

        Parameters
        ----------
        threshold: float
            Sentence similarity threshold. Adjacent tiles whose similarity
            is at least this value are merged into one segment.

        Returns
        -------
        new_segments: list
            List of segment strings; empty when there is nothing to segment.
        """
        segments = self._text_tilling()
        merge_index = self._merge_segments(segments, threshold)
        return [
            ' '.join(segments[position] for position in group)
            for group in merge_index
        ]

    def _merge_segments(self, segments, threshold):
        """
        Group consecutive segment indices that should be merged.

        Each segment is compared with its predecessor; a similarity below
        ``threshold`` starts a new group. Returns a list of index lists.
        """
        if not segments:
            # Without this guard the seed entry below would yield [[0]] and
            # the caller would index into an empty list.
            return []
        # 0 = "continue the current group", 1 = "start a new group".
        # The first segment always belongs to the first group.
        segment_map = [0]
        for text1, text2 in zip(segments[:-1], segments[1:]):
            # cosine_similarity yields a 1x1 array; reduce it to a plain
            # float so the comparison does not depend on array truthiness.
            sim = float(np.squeeze(self._get_similarity(text1, text2)))
            segment_map.append(0 if sim >= threshold else 1)
        return self._index_mapping(segment_map)

    def _index_mapping(self, segment_map):
        """
        Turn a 0/1 boundary map into lists of grouped positions.

        A ``1`` at position ``p`` closes the group collected so far and
        opens a new one starting with ``p``; any other value appends ``p``
        to the current group.
        """
        index_list = []
        current = []
        for index, flag in enumerate(segment_map):
            if flag == 1:
                index_list.append(current)
                current = [index]
            else:
                current.append(index)
        # Flush the group that was still open when the map ended.
        index_list.append(current)
        return index_list

    def _segment_embedding(self, text):
        """Return the mean sentence embedding of ``text`` as a single row."""
        sentences = [
            sent.text.strip()
            for sent in nlp(text).sents
            if len(sent.text.split(' ')) >= self._MIN_SENTENCE_WORDS
        ]
        if not sentences:
            # Every sentence was filtered out; embed the whole text instead
            # of averaging over an empty array (which produces NaN).
            sentences = [text.strip()]
        return np.mean(model.encode(sentences), axis=0).reshape(1, -1)

    def _get_similarity(self, text1, text2):
        """
        Return the cosine similarity between the mean sentence embeddings
        of ``text1`` and ``text2``, as returned by ``cosine_similarity``.
        """
        embeding_1 = self._segment_embedding(text1)
        embeding_2 = self._segment_embedding(text2)
        return cosine_similarity(embeding_1, embeding_2)

    def _text_tilling(self):
        """
        Split the transcript into tiles with ``TextTilingTokenizer``.

        Utterances are joined with a paragraph-style separator before
        tokenizing; the separator is replaced by a space in each returned
        tile. Missing utterances are skipped and non-string values are
        converted to ``str`` so the join cannot fail on them.
        """
        utterances = self.data[self.utterance].dropna().astype(str).tolist()
        if not utterances:
            return []
        tt = TextTilingTokenizer(w=15, k=10)
        text = '\n\n\t'.join(utterances)
        return [tile.replace("\n\n\t", ' ') for tile in tt.tokenize(text)]