File size: 8,053 Bytes
dd4b76a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c849723
dd4b76a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d331bf4
dd4b76a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13ecc63
dd4b76a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13ecc63
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
import hashlib
import pickle
from pathlib import Path
from itertools import zip_longest

import gradio as gr
import torch
from sentence_transformers import SentenceTransformer, util
import numpy as np
import ruptures as rpt

from util import sent_tokenize


# Directory in which text_segmentation stores pickled sentence embeddings.
CACHE_DIR = '.cache'
# Device string handed to SentenceTransformer: CUDA when torch reports it as
# available, otherwise the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Model names accepted by text_segmentation and offered in the UI dropdown.
_ST_MODELS = ['all-mpnet-base-v2', 'multi-qa-mpnet-base-dot-v1', 'all-MiniLM-L12-v2']


def embed_sentences(sentences, embedder_fn, cache_path):
    """Return embeddings for ``sentences``, reusing a pickle cache if possible.

    Args:
        sentences: Sequence of sentences to embed.
        embedder_fn: Callable that receives the whole sequence of sentences
            and returns one embedding per sentence.
        cache_path: File (``str`` or ``Path``) in which the embeddings are
            pickled. Its parent directory is created on demand.

    Returns:
        The value produced by ``embedder_fn``, either freshly computed or
        unpickled from ``cache_path``.

    Raises:
        ValueError: If ``embedder_fn`` returns a result whose length differs
            from ``len(sentences)``.
    """
    cache_path = Path(cache_path)

    if cache_path.exists():
        print(f'Loading embeddings from cache: {cache_path}')
        try:
            # NOTE(security): unpickling runs arbitrary code from the file;
            # this is only acceptable while the cache directory is written
            # exclusively by this application.
            with cache_path.open('rb') as file:
                cached = pickle.load(file)
            # Only trust the cache when it matches the current input.
            if len(cached) == len(sentences):
                return cached
        except (pickle.UnpicklingError, EOFError, TypeError):
            # Truncated or otherwise unusable cache file: recompute below.
            pass
        print(f'Ignoring stale or corrupt cache: {cache_path}')

    print(f'Embedding sentences and saving to cache: {cache_path}')
    embedded_sents = embedder_fn(sentences)
    if len(embedded_sents) != len(sentences):
        raise ValueError(
            f'Embedder returned {len(embedded_sents)} embeddings '
            f'for {len(sentences)} sentences'
        )

    # The cache directory may not exist yet (e.g. when this module is
    # imported rather than run as a script).
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    with cache_path.open('wb') as file:
        pickle.dump(embedded_sents, file)

    return embedded_sents


def calculate_cosine_similarities(embedded_sents, k=1, pool='mean'):
    """Score the similarity across every gap between adjacent sentences.

    For gap ``i`` (between sentence ``i`` and sentence ``i + 1``) the up to
    ``k`` embeddings ending at ``i`` are compared with the up to ``k``
    embeddings starting at ``i + 1`` using ``util.cos_sim``; the resulting
    pairwise similarity matrix is reduced to a single number with ``pool``.

    Args:
        embedded_sents: Sliceable sequence of sentence embeddings, one per
            sentence, in a form accepted by ``util.cos_sim``.
        k: Number of sentences taken on each side of a gap (at least 1).
        pool: Reduction applied to the similarity matrix: ``'mean'``,
            ``'max'`` or ``'min'``.

    Returns:
        A list with ``len(embedded_sents) - 1`` floats, one per gap. The list
        is empty when fewer than two embeddings are given.

    Raises:
        ValueError: If ``pool`` is not a supported method or ``k`` is < 1.
    """
    # Validate up front so bad arguments are reported even when there is
    # nothing to compare.
    if pool not in ('mean', 'max', 'min'):
        raise ValueError(f'Invalid pooling method: {pool}')
    k = int(k)  # slice bounds must be integers
    if k < 1:
        raise ValueError(f'Window size k must be >= 1, got {k}')

    n_gaps = max(len(embedded_sents) - 1, 0)
    cosine_sims = []
    for i in range(n_gaps):
        lctx = embedded_sents[max(0, i - k + 1): i + 1]
        rctx = embedded_sents[i + 1: i + k + 1]
        sim = util.cos_sim(lctx, rctx)
        # ``pool`` names the reduction method to call on the matrix.
        cosine_sims.append(getattr(sim, pool)().item())

    return cosine_sims


def predict_boundaries(cosine_sims, threshold):
    """Turn gap similarities into boundary labels and shift probabilities.

    Each similarity ``s`` becomes a probability ``1.0 - s``; a gap is labelled
    ``1`` when that probability is at least ``threshold`` and ``0`` otherwise.

    Returns:
        A ``(preds, probs)`` pair of equally long lists.
    """
    preds = []
    probs = []
    for similarity in cosine_sims:
        shift_prob = 1.0 - similarity
        probs.append(shift_prob)
        preds.append(1 if shift_prob >= threshold else 0)
    return preds, probs
    

def _iter_segments(sents, preds, probs):
    """Yield segments as lists of ``{'text', 'prob'}`` dicts."""
    segment = []
    # preds/probs describe the gaps *between* sentences, so they are one item
    # shorter than sents; zip_longest pads the final sentence with None.
    for sent, pred, prob in zip_longest(sents, preds, probs):
        segment.append({
            'text': sent,
            'prob': round(prob, 4) if prob is not None else None,
        })
        if pred == 1:
            yield segment
            segment = []
    if segment:
        yield segment


def _plot_regimes(signal, preds):
    """Plot ``signal`` with the predicted segments shaded via ``rpt.display``."""
    # A trailing 1 closes the last segment; breakpoints are the 1-based
    # positions of all boundary labels.
    labels = list(preds) + [1]
    bkps = [i + 1 for i, label in enumerate(labels) if label == 1]

    fig, [ax] = rpt.display(np.array(signal), bkps, figsize=(10, 5), dpi=250)
    # Pad the y-range slightly around the data while staying within [0, 1].
    y_min = max(0.0, min(signal) - 0.1)
    y_max = min(1.0, max(signal) + 0.1)
    ax.set_ylim(y_min, y_max)
    ax.set_title("Segment Regimes")
    ax.set_xlabel("Sentence Index")
    ax.set_ylabel("Semantic Shift Probability")
    fig.tight_layout()

    return fig


def output_segments(sents, preds, probs):
    """Assemble the text, JSON and plot outputs for a segmentation.

    Args:
        sents: Sentences in document order.
        preds: One label per gap between adjacent sentences; ``1`` ends the
            current segment after the sentence on the left of the gap.
        probs: One boundary probability per gap, aligned with ``preds``.

    Returns:
        A tuple ``(out_text, out, fig)`` where ``out_text`` is the segments
        joined by a dashed separator line, ``out`` is a dict with
        ``'metadata'`` (counts and probability statistics) and ``'chunks'``
        (the segments), and ``fig`` is the figure from ``rpt.display``. When
        there are no gaps (zero or one sentence) the probability statistics
        are ``None`` and ``fig`` is ``None``.

    Raises:
        ValueError: If ``preds`` and ``probs`` do not each hold exactly one
            entry per gap.
    """
    n_gaps = max(len(sents) - 1, 0)
    if not (len(preds) == len(probs) == n_gaps):
        raise ValueError(
            f'Expected {n_gaps} predictions and probabilities for '
            f'{len(sents)} sentences, got {len(preds)} and {len(probs)}'
        )

    chunks = list(_iter_segments(sents, preds, probs))

    if n_gaps > 0:
        # float() keeps the metadata made of plain Python numbers.
        prob_stats = {
            'prob_mean': round(float(np.mean(probs)), 4),
            'prob_std': round(float(np.std(probs)), 4),
            'prob_min': round(float(min(probs)), 4),
            'prob_max': round(float(max(probs)), 4),
        }
    else:
        # No gaps means there is nothing to summarise.
        prob_stats = {
            'prob_mean': None,
            'prob_std': None,
            'prob_min': None,
            'prob_max': None,
        }

    out = {
        'metadata': {
            'n_chunks': len(chunks),
            'n_sents': sum(len(segment) for segment in chunks),
            **prob_stats,
        },
        'chunks': chunks,
    }

    out_text = "\n-------------------------\n".join(
        "\n".join(sent['text'] for sent in segment) for segment in chunks
    )

    # An empty signal cannot be plotted, so skip the figure in that case.
    fig = _plot_regimes(probs, preds) if n_gaps > 0 else None

    return out_text, out, fig


def text_segmentation(input_text, model_name, k, pool, threshold):
    """Split ``input_text`` into segments at semantic shifts.

    The text is tokenized into sentences, the sentences are embedded (with an
    on-disk cache), adjacent windows of sentences are compared by cosine
    similarity, and gaps whose shift probability reaches ``threshold`` become
    segment boundaries.

    Args:
        input_text: Raw text to segment.
        model_name: One of the names in ``_ST_MODELS``.
        k: Window size passed to ``calculate_cosine_similarities``.
        pool: Pooling method passed to ``calculate_cosine_similarities``.
        threshold: Boundary threshold passed to ``predict_boundaries``.

    Returns:
        The ``(out_text, out, fig)`` tuple produced by ``output_segments``.

    Raises:
        ValueError: If ``model_name`` is not in ``_ST_MODELS``.
    """
    if model_name not in _ST_MODELS:
        raise ValueError(f'Invalid model name: {model_name}')

    def embedder_fn(sentences):
        # Instantiate the model lazily so a cache hit never pays for loading
        # it.
        model = SentenceTransformer(model_name, device=DEVICE)
        return model.encode(sentences)

    sents = sent_tokenize(input_text, method='nltk', initial_split_sep='\n')

    # Embeddings depend on the model as well as on the text, so both must be
    # part of the cache key; otherwise switching models would silently reuse
    # embeddings computed by another model.
    cache_key = f'{model_name}\n{input_text}'
    cache_id = hashlib.md5(cache_key.encode()).hexdigest()
    cache_path = Path(CACHE_DIR) / f'{cache_id}.pkl'
    embedded_sents = embed_sentences(sents, embedder_fn, cache_path=cache_path)

    cosine_sims = calculate_cosine_similarities(embedded_sents, k=k, pool=pool)

    preds, probs = predict_boundaries(cosine_sims, threshold=threshold)

    return output_segments(sents, preds, probs)
    

# Gradio UI definition: one row with an input/controls column and a tabbed
# output column, wired to text_segmentation.
with gr.Blocks() as app:
    gr.Markdown("""
# LLM TextTiling Demo

An **extended** approach to text segmentation that combines **TextTiling** with **LLM embeddings**. Simply provide your text, choose an embedding model, and adjust segmentation parameters (window size, pooling, threshold). The demo will split your text into coherent segments based on **semantic shifts**. Refer to the [README](https://huggingface.co/spaces/saeedabc/llm-text-tiling-demo/blob/main/README.md) for more details.
""")

    with gr.Row():
        # First column: the text to segment plus the segmentation parameters.
        with gr.Column():
            input_text = gr.Textbox(label="Input Text", placeholder="Enter your text here...", lines=15)
            
            with gr.Row():
                with gr.Column():
                    # model_name = gr.Radio(choices=_ST_MODELS, label="Embedding Model", value=_ST_MODELS[0])
                    model_name = gr.Dropdown(choices=_ST_MODELS, label="Embedding Model", value=_ST_MODELS[0])
                
                with gr.Column():
                    # Choices match the pooling methods accepted by
                    # calculate_cosine_similarities.
                    pool = gr.Dropdown(choices=['max', 'mean', 'min'], label="Pooling Strategy", value='max')
                    
            with gr.Row():
                with gr.Column():
                    threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Threshold", value=0.5)
                
                with gr.Column():
                    k = gr.Slider(minimum=1, maximum=10, step=1, label="Window Size", value=3)


            submit_button = gr.Button("Chunk Text")

        # Second column: one tab per value returned by text_segmentation.
        with gr.Column():
            with gr.Tabs():
                with gr.Tab("Output Text"):
                    output_text = gr.Textbox(label="Output Text", placeholder="Chunks will appear here...", lines=22)
                with gr.Tab("Output Json"):
                    output_json = gr.Json(label="Output Json", open=False, max_height=500)
                with gr.Tab("Output Visualization"):          
                    output_fig = gr.Plot(label="Output Visualization")

    # The inputs list follows text_segmentation's parameter order and the
    # outputs list follows the order of its returned tuple.
    submit_button.click(text_segmentation, inputs=[input_text, model_name, k, pool, threshold], outputs=[output_text, output_json, output_fig])

    # A single example row; its values line up with the `inputs` components.
    examples = gr.Examples(
        examples=[
            ["Rib Mountain is a census-designated place (CDP) in the town of Rib Mountain in Marathon County, Wisconsin, United States. "
             "The population was 5,651 at the 2010 census. "
             "The community is named for Rib Mountain. "
             "According to the United States Census Bureau, the CDP has a total area of 33.8 km² (13.0 mi²). "
             "31.4 km² (12.1 mi²) of it is land and 2.4 km² (0.9 mi²) of it (6.98%) is water. "
             "As of the census of 2000, there were 6,059 people, 2,211 households, and 1,782 families residing in the CDP. "
             "The population density was 193.0/km² (499.8/mi²). "
             "There were 2,278 housing units at an average density of 72.6/km² (187.9/mi²).", "all-mpnet-base-v2", 3, 'max', 0.52],
        ],
        inputs=[input_text, model_name, k, pool, threshold],
    )

if __name__ == "__main__":
    # Make sure the embedding cache directory is present before serving.
    cache_root = Path(CACHE_DIR)
    cache_root.mkdir(exist_ok=True)

    # Start the Gradio server.
    app.launch()