import hashlib
import pickle
from pathlib import Path
from itertools import zip_longest
import gradio as gr
import torch
from sentence_transformers import SentenceTransformer, util
import numpy as np
import ruptures as rpt
from util import sent_tokenize
# Directory where sentence-embedding pickles are cached between runs.
CACHE_DIR = '.cache'
# Prefer GPU for SentenceTransformer inference when available.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Sentence-Transformers checkpoints offered in the model dropdown.
_ST_MODELS = ['all-mpnet-base-v2', 'multi-qa-mpnet-base-dot-v1', 'all-MiniLM-L12-v2']
def embed_sentences(sentences, embedder_fn, cache_path):
    """Return embeddings for *sentences*, reusing a pickle cache when present.

    On a cache hit the embeddings are unpickled from *cache_path*; otherwise
    *embedder_fn* is called once and its result is persisted for next time.
    """
    cache_file = Path(cache_path)
    if cache_file.exists():
        print(f'Loading embeddings from cache: {cache_path}')
        with open(cache_path, 'rb') as file:
            return pickle.load(file)
    print(f'Embedding sentences and saving to cache: {cache_path}')
    embeddings = embedder_fn(sentences)
    # One embedding per input sentence is a hard invariant for downstream code.
    assert len(embeddings) == len(sentences)
    with open(cache_path, 'wb') as file:
        pickle.dump(embeddings, file)
    return embeddings
def calculate_cosine_similarities(embedded_sents, k=1, pool='mean'):
    """Score each candidate boundary between adjacent sentences.

    For boundary i (between sentence i and i+1), compares up to *k*
    embeddings on the left with up to *k* on the right and pools the
    resulting cosine-similarity matrix down to one scalar.

    Args:
        embedded_sents: sequence of sentence embeddings accepted by
            ``util.cos_sim`` (tensors or arrays).
        k: context window size on each side of the boundary.
        pool: 'mean', 'max', or 'min' pooling over the similarity matrix.

    Returns:
        List of floats of length ``len(embedded_sents) - 1``.

    Raises:
        ValueError: if *pool* is not a supported strategy.
    """
    def cosine_similarity(a, b):
        # cos_sim returns the full pairwise matrix; pool it to a scalar.
        sim = util.cos_sim(a, b)
        if pool == 'mean':
            return sim.mean().item()
        elif pool == 'max':
            return sim.max().item()
        elif pool == 'min':
            return sim.min().item()
        else:
            raise ValueError(f'Invalid pooling method: {pool}')
    cosine_sims = []
    for i in range(len(embedded_sents) - 1):
        lctx = embedded_sents[max(0, i-k+1) : i+1]
        rctx = embedded_sents[i+1 : i+k+1]
        sim = cosine_similarity(lctx, rctx)
        cosine_sims.append(sim)
    # One score per inter-sentence boundary (message previously compared the
    # wrong quantity: it printed len(embedded_sents) instead of len - 1).
    assert len(cosine_sims) == len(embedded_sents) - 1, f'{len(cosine_sims)} != {len(embedded_sents) - 1}'
    return cosine_sims
def predict_boundaries(cosine_sims, threshold):
    """Convert similarity scores into boundary probabilities and labels.

    A boundary probability is one minus the cosine similarity; a boundary
    (label 1) is predicted wherever that probability reaches *threshold*.
    Returns (preds, probs), both the same length as *cosine_sims*.
    """
    preds, probs = [], []
    for sim in cosine_sims:
        prob = 1.0 - sim
        probs.append(prob)
        preds.append(int(prob >= threshold))
    return preds, probs
def output_segments(sents, preds, probs):
    """Assemble segmentation output: plain text, a JSON-able dict, and a plot.

    Args:
        sents: list of N sentences.
        preds: N-1 binary boundary labels (1 = boundary after sentence i).
        probs: N-1 boundary probabilities, aligned with *preds*.

    Returns:
        (out_text, out, fig): segments joined by a divider line, a dict with
        'metadata' and 'chunks', and a matplotlib figure — or None for the
        figure when there is no boundary signal (fewer than two sentences).

    Raises:
        ValueError: if preds/probs lengths do not match len(sents) - 1.
    """
    if not (len(sents) - 1 == len(preds) == len(probs)):
        raise ValueError(f'{len(sents)} - 1 != {len(preds)} != {len(probs)}')
    def iter_segments(sents, preds, probs):
        # Accumulate sentences, closing a segment after each pred == 1.
        segment = []
        for sent, pred, prob in zip_longest(sents, preds, probs):
            segment.append({
                'text': sent,
                # prob is None for the last sentence (no boundary after it).
                'prob': round(prob, 4) if prob is not None else None,
            })
            if pred == 1:
                yield segment
                segment = []
        if segment:
            yield segment
    out = {
        'metadata': {},
        'chunks': [],
    }
    n_sents = 0
    for segment in iter_segments(sents, preds, probs):
        out['chunks'].append(segment)
        n_sents += len(segment)
    out['metadata'] = {
        'n_chunks': len(out['chunks']),
        'n_sents': n_sents,
        # Guard the stats: probs is empty for single-sentence input, where
        # np.mean warns and min()/max() raise ValueError.
        'prob_mean': round(float(np.mean(probs)), 4) if probs else None,
        'prob_std': round(float(np.std(probs)), 4) if probs else None,
        'prob_min': round(min(probs), 4) if probs else None,
        'prob_max': round(max(probs), 4) if probs else None,
    }
    out_text = "\n-------------------------\n".join(["\n".join([sent['text'] for sent in segment]) for segment in out['chunks']])
    def plot_regimes(signal, preds):
        # Visualize boundary probabilities with ruptures' change-point display.
        def get_bkps_from_labels(labels):
            # ruptures expects 1-based breakpoint indices.
            return [i+1 for i, l in enumerate(labels) if l == 1]
        preds = preds + [1]  # ruptures requires a terminal breakpoint
        bkps = get_bkps_from_labels(preds)
        fig, [ax] = rpt.display(np.array(signal), bkps, figsize=(10, 5), dpi=250)
        # Clamp the y-range to [0, 1] with a small margin around the data.
        y_min = max(0.0, min(signal) - 0.1)
        y_max = min(1.0, max(signal) + 0.1)
        ax.set_ylim(y_min, y_max)
        ax.set_title("Segment Regimes")
        ax.set_xlabel("Sentence Index")
        ax.set_ylabel("Semantic Shift Probability")
        fig.tight_layout()
        return fig
    # No signal to plot when there are no boundary probabilities.
    fig = plot_regimes(probs, preds) if probs else None
    return out_text, out, fig
def text_segmentation(input_text, model_name, k, pool, threshold):
    """End-to-end pipeline: tokenize, embed (cached), score, and segment.

    Args:
        input_text: raw text to segment.
        model_name: one of ``_ST_MODELS``.
        k: context window size for boundary scoring.
        pool: pooling strategy ('mean' / 'max' / 'min').
        threshold: boundary probability threshold in [0, 1].

    Returns:
        The (out_text, out_json, fig) triple from ``output_segments``.

    Raises:
        ValueError: for an unknown *model_name*.
    """
    if model_name not in _ST_MODELS:
        raise ValueError(f'Invalid model name: {model_name}')
    model = SentenceTransformer(model_name, device=DEVICE)
    embedder_fn = model.encode
    sents = sent_tokenize(input_text, method='nltk', initial_split_sep='\n')
    # Key the cache on BOTH model and text: different models produce
    # incompatible embeddings, so keying on the text alone served stale
    # embeddings from whichever model ran first.
    cache_id = hashlib.md5(f'{model_name}\n{input_text}'.encode()).hexdigest()
    cache_path = Path(CACHE_DIR) / f'{cache_id}.pkl'
    embedded_sents = embed_sentences(sents, embedder_fn, cache_path=cache_path)
    cosine_sims = calculate_cosine_similarities(embedded_sents, k=k, pool=pool)
    preds, probs = predict_boundaries(cosine_sims, threshold=threshold)
    return output_segments(sents, preds, probs)
# Gradio UI: input/controls on the left, tabbed outputs on the right.
with gr.Blocks() as app:
    gr.Markdown("""
    # LLM TextTiling Demo
    An **extended** approach to text segmentation that combines **TextTiling** with **LLM embeddings**. Simply provide your text, choose an embedding model, and adjust segmentation parameters (window size, pooling, threshold). The demo will split your text into coherent segments based on **semantic shifts**. Refer to the [README](https://huggingface.co/spaces/saeedabc/llm-text-tiling-demo/blob/main/README.md) for more details.
    """)
    with gr.Row():
        # Left column: raw text plus the four segmentation parameters.
        with gr.Column():
            input_text = gr.Textbox(label="Input Text", placeholder="Enter your text here...", lines=15)
            with gr.Row():
                with gr.Column():
                    # model_name = gr.Radio(choices=_ST_MODELS, label="Embedding Model", value=_ST_MODELS[0])
                    model_name = gr.Dropdown(choices=_ST_MODELS, label="Embedding Model", value=_ST_MODELS[0])
                with gr.Column():
                    # Pooling over the k x k context similarity matrix.
                    pool = gr.Dropdown(choices=['max', 'mean', 'min'], label="Pooling Strategy", value='max')
            with gr.Row():
                with gr.Column():
                    # Boundary probability cutoff used by predict_boundaries.
                    threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Threshold", value=0.5)
                with gr.Column():
                    # Context window size k used by calculate_cosine_similarities.
                    k = gr.Slider(minimum=1, maximum=10, step=1, label="Window Size", value=3)
            submit_button = gr.Button("Chunk Text")
        # Right column: the three views returned by text_segmentation.
        with gr.Column():
            with gr.Tabs():
                with gr.Tab("Output Text"):
                    output_text = gr.Textbox(label="Output Text", placeholder="Chunks will appear here...", lines=22)
                with gr.Tab("Output Json"):
                    output_json = gr.Json(label="Output Json", open=False, max_height=500)
                with gr.Tab("Output Visualization"):
                    output_fig = gr.Plot(label="Output Visualization")
    # Wire the button to the pipeline; outputs map 1:1 onto the three tabs.
    submit_button.click(text_segmentation, inputs=[input_text, model_name, k, pool, threshold], outputs=[output_text, output_json, output_fig])
    # Pre-filled example so users can try the demo with one click.
    examples = gr.Examples(
        examples=[
            ["Rib Mountain is a census-designated place (CDP) in the town of Rib Mountain in Marathon County, Wisconsin, United States. "
             "The population was 5,651 at the 2010 census. "
             "The community is named for Rib Mountain. "
             "According to the United States Census Bureau, the CDP has a total area of 33.8 km² (13.0 mi²). "
             "31.4 km² (12.1 mi²) of it is land and 2.4 km² (0.9 mi²) of it (6.98%) is water. "
             "As of the census of 2000, there were 6,059 people, 2,211 households, and 1,782 families residing in the CDP. "
             "The population density was 193.0/km² (499.8/mi²). "
             "There were 2,278 housing units at an average density of 72.6/km² (187.9/mi²).", "all-mpnet-base-v2", 3, 'max', 0.52],
        ],
        inputs=[input_text, model_name, k, pool, threshold],
    )
if __name__ == '__main__':
    # Ensure the embedding cache directory exists; parents=True keeps this
    # working if CACHE_DIR is ever pointed at a nested path.
    Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
    # Launch the app
    app.launch()