Spaces:
Running
Running
| import hashlib | |
| import pickle | |
| from pathlib import Path | |
| from itertools import zip_longest | |
| import gradio as gr | |
| import torch | |
| from sentence_transformers import SentenceTransformer, util | |
| import numpy as np | |
| import ruptures as rpt | |
| from util import sent_tokenize | |
# Directory holding the pickled sentence-embedding caches; the __main__ entry
# point creates it before the app is launched.
CACHE_DIR = ".cache"

# Run the embedding models on the GPU whenever torch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Sentence-Transformers checkpoints offered in the UI; text_segmentation rejects
# any other name.
_ST_MODELS = [
    "all-mpnet-base-v2",
    "multi-qa-mpnet-base-dot-v1",
    "all-MiniLM-L12-v2",
]
def embed_sentences(sentences, embedder_fn, cache_path):
    """Embed ``sentences`` with ``embedder_fn``, memoising the result on disk.

    Args:
        sentences: Sequence of sentence strings to embed.
        embedder_fn: Callable that maps the whole list of sentences to a
            sequence holding one embedding per sentence.
        cache_path: Path of the pickle file used as the cache. Missing parent
            directories are created.

    Returns:
        The embeddings, either loaded from ``cache_path`` or freshly computed
        (and then written to ``cache_path``).

    Raises:
        ValueError: If ``embedder_fn`` does not return exactly one embedding
            per sentence.
    """
    cache_path = Path(cache_path)
    if cache_path.exists():
        print(f'Loading embeddings from cache: {cache_path}')
        # NOTE(review): pickle.load can execute arbitrary code from the file; this
        # is only acceptable because the cache files are written by this function.
        with open(cache_path, 'rb') as file:
            embedded_sents = pickle.load(file)
        # A cache entry of the wrong length would misalign embeddings and
        # sentences downstream, so recompute instead of trusting it.
        if len(embedded_sents) == len(sentences):
            return embedded_sents
        print(f'Cached embeddings do not match the input; recomputing: {cache_path}')

    print(f'Embedding sentences and saving to cache: {cache_path}')
    embedded_sents = embedder_fn(sentences)
    # Explicit check rather than ``assert``, which is stripped under ``python -O``.
    if len(embedded_sents) != len(sentences):
        raise ValueError(
            f'embedder_fn returned {len(embedded_sents)} embeddings '
            f'for {len(sentences)} sentences'
        )

    import tempfile  # local import: only needed on the cache-miss path

    # The cache directory may not exist yet (e.g. when this module is imported
    # rather than run as __main__), which made the write below fail.
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    # Write to a unique temporary file, then rename it into place. An interrupted
    # or concurrent write therefore never leaves a truncated pickle behind.
    fd, tmp_name = tempfile.mkstemp(dir=cache_path.parent, suffix='.tmp')
    try:
        with open(fd, 'wb') as file:
            pickle.dump(embedded_sents, file)
        Path(tmp_name).replace(cache_path)
    except BaseException:
        Path(tmp_name).unlink(missing_ok=True)
        raise
    return embedded_sents
def calculate_cosine_similarities(embedded_sents, k=1, pool='mean'):
    """Score the semantic similarity across every gap between adjacent sentences.

    For gap ``i`` (between sentences ``i`` and ``i + 1``), the left window holds
    up to ``k`` embeddings ending at sentence ``i``. The right window holds up
    to ``k`` embeddings starting at sentence ``i + 1``. The cosine-similarity
    matrix of the two windows (``util.cos_sim``) is reduced to one float by
    ``pool``.

    Args:
        embedded_sents: Sliceable sequence of sentence embeddings, one per sentence.
        k: Window size per side. Must be >= 1. It is converted with ``int()``,
            so numeric values such as ``3.0`` are accepted.
        pool: Reduction over the similarity matrix: ``'mean'``, ``'max'`` or ``'min'``.

    Returns:
        A list of ``max(len(embedded_sents) - 1, 0)`` floats, one per gap.

    Raises:
        ValueError: If ``pool`` is unknown or ``k`` is smaller than 1.
    """
    reducers = {
        'mean': lambda sim: sim.mean().item(),
        'max': lambda sim: sim.max().item(),
        'min': lambda sim: sim.min().item(),
    }
    # Validate up front. Previously a bad pool/k went unnoticed whenever the
    # text had fewer than two sentences, because the loop body never ran.
    if pool not in reducers:
        raise ValueError(f'Invalid pooling method: {pool}')
    reduce_fn = reducers[pool]
    k = int(k)  # slicing below requires an int
    if k < 1:
        raise ValueError(f'Window size k must be >= 1, got {k}')

    # Empty input has zero gaps, not -1 (the old trailing assert failed on it).
    n_gaps = max(len(embedded_sents) - 1, 0)
    cosine_sims = []
    for i in range(n_gaps):
        lctx = embedded_sents[max(0, i - k + 1): i + 1]
        rctx = embedded_sents[i + 1: i + k + 1]
        cosine_sims.append(reduce_fn(util.cos_sim(lctx, rctx)))
    return cosine_sims
def predict_boundaries(cosine_sims, threshold):
    """Turn per-gap similarities into shift scores and hard boundary labels.

    Each similarity ``s`` becomes the score ``1.0 - s``. A gap is labelled a
    boundary (1) when its score reaches ``threshold``; otherwise it gets 0.

    Returns:
        ``(preds, probs)``: two lists aligned with ``cosine_sims``.
    """
    probs = []
    preds = []
    for similarity in cosine_sims:
        shift = 1.0 - similarity
        probs.append(shift)
        preds.append(int(shift >= threshold))
    return preds, probs
def _iter_segments(sents, preds, probs):
    """Yield lists of ``{'text', 'prob'}`` records, closing a list after each predicted boundary."""
    # zip_longest pads the shorter preds/probs with None. The last sentence
    # (which has no following gap) is therefore still emitted, with prob None.
    segment = []
    for sent, pred, prob in zip_longest(sents, preds, probs):
        segment.append({
            'text': sent,
            'prob': round(prob, 4) if prob is not None else None,
        })
        if pred == 1:
            yield segment
            segment = []
    if segment:
        yield segment


def _plot_regimes(signal, preds):
    """Plot the per-gap shift scores with one shaded span per predicted segment."""
    # The x axis is the sentence index. The trailing label 1 adds a breakpoint
    # after the last sentence, which has no gap score of its own, so the final
    # span covers it.
    bkps = [i + 1 for i, label in enumerate(list(preds) + [1]) if label == 1]
    # NOTE(review): the returned figure is never closed here; if rpt.display
    # registers it with pyplot, figures accumulate per request. Confirm this.
    fig, [ax] = rpt.display(np.array(signal), bkps, figsize=(10, 5), dpi=250)
    lo, hi = min(signal), max(signal)
    # Scores are 1 - cosine similarity, i.e. roughly within [0, 2]. Pad by 0.1
    # and clamp to [0, 1] only when the data fits inside it, so points are never
    # cut off and the axis can never be inverted.
    y_min = max(0.0, lo - 0.1) if lo >= 0.0 else lo - 0.1
    y_max = min(1.0, hi + 0.1) if hi <= 1.0 else hi + 0.1
    ax.set_ylim(y_min, y_max)
    ax.set_title("Segment Regimes")
    ax.set_xlabel("Sentence Index")
    ax.set_ylabel("Semantic Shift Probability")
    fig.tight_layout()
    return fig


def output_segments(sents, preds, probs):
    """Group sentences into chunks and build the three UI outputs.

    Args:
        sents: The sentences, in order.
        preds: One 0/1 label per gap between consecutive sentences. A 1 closes
            the current chunk after that sentence.
        probs: One shift score per gap, aligned with ``preds``.

    Returns:
        A tuple ``(out_text, out, fig)``:
        - ``out_text``: the chunks joined by a dashed separator line.
        - ``out``: a dict with ``'metadata'`` (counts and score statistics; the
          statistics are ``None`` when there are no gaps) and ``'chunks'``
          (lists of ``{'text', 'prob'}`` records).
        - ``fig``: the regime plot, or ``None`` when there are fewer than two
          sentences and hence nothing to plot.

    Raises:
        ValueError: If ``preds``/``probs`` do not have exactly one entry per gap.
    """
    # Explicit check instead of ``assert``, which is stripped under -O. Empty
    # input has zero gaps, not -1.
    n_gaps = max(len(sents) - 1, 0)
    if not (len(preds) == len(probs) == n_gaps):
        raise ValueError(
            f'expected {n_gaps} preds/probs for {len(sents)} sentences, '
            f'got {len(preds)} preds and {len(probs)} probs'
        )

    chunks = list(_iter_segments(sents, preds, probs))
    # min()/max() raise on an empty list, which used to crash single-sentence inputs.
    has_gaps = len(probs) > 0
    out = {
        'metadata': {
            'n_chunks': len(chunks),
            'n_sents': sum(len(chunk) for chunk in chunks),
            'prob_mean': round(np.mean(probs), 4) if has_gaps else None,
            'prob_std': round(np.std(probs), 4) if has_gaps else None,
            'prob_min': round(min(probs), 4) if has_gaps else None,
            'prob_max': round(max(probs), 4) if has_gaps else None,
        },
        'chunks': chunks,
    }
    out_text = "\n-------------------------\n".join(
        "\n".join(record['text'] for record in chunk) for chunk in chunks
    )
    fig = _plot_regimes(probs, preds) if has_gaps else None
    return out_text, out, fig
# Loaded SentenceTransformer instances keyed by model name. Each checkpoint is
# loaded at most once per process instead of once per request; the key space is
# bounded by _ST_MODELS.
_MODEL_CACHE = {}


def _get_model(model_name):
    """Return the SentenceTransformer for ``model_name``, loading it on first use."""
    model = _MODEL_CACHE.get(model_name)
    if model is None:
        model = SentenceTransformer(model_name, device=DEVICE)
        _MODEL_CACHE[model_name] = model
    return model


def text_segmentation(input_text, model_name, k, pool, threshold):
    """Segment ``input_text`` into semantically coherent chunks.

    Pipeline: sentence-tokenize; embed (with an on-disk cache); score every gap
    between neighbouring sentences; threshold the scores; format the result.

    Args:
        input_text: Raw text to segment.
        model_name: One of ``_ST_MODELS``.
        k: Context window size, forwarded to ``calculate_cosine_similarities``.
        pool: Pooling method, forwarded to ``calculate_cosine_similarities``.
        threshold: Minimum shift score for a gap to become a boundary.

    Returns:
        The ``(out_text, out, fig)`` tuple produced by ``output_segments``.

    Raises:
        ValueError: If ``model_name`` is not in ``_ST_MODELS``.
    """
    if model_name not in _ST_MODELS:
        raise ValueError(f'Invalid model name: {model_name}')

    def embedder_fn(sentences):
        # Deferred so that a cache hit never pays for loading the model.
        return _get_model(model_name).encode(sentences)

    sents = sent_tokenize(input_text, method='nltk', initial_split_sep='\n')
    # Embeddings depend on the model as well as the text. Keying the cache on
    # the text alone served one model's embeddings when another was selected.
    cache_key = f'{model_name}\n{input_text}'
    cache_id = hashlib.md5(cache_key.encode('utf-8')).hexdigest()
    cache_path = Path(CACHE_DIR) / f'{cache_id}.pkl'
    embedded_sents = embed_sentences(sents, embedder_fn, cache_path=cache_path)
    cosine_sims = calculate_cosine_similarities(embedded_sents, k=k, pool=pool)
    preds, probs = predict_boundaries(cosine_sims, threshold=threshold)
    return output_segments(sents, preds, probs)
# Gradio UI. The input controls sit in the left column and the three output
# views (plain text, JSON, plot) in tabs in the right column. The button runs
# text_segmentation on the five controls.
# NOTE(review): the original indentation was lost when this file was scraped;
# the placement of the click wiring and of gr.Examples below is a
# reconstruction. Confirm it against the source repository.
with gr.Blocks() as app:
    gr.Markdown("""
# LLM TextTiling Demo
An **extended** approach to text segmentation that combines **TextTiling** with **LLM embeddings**. Simply provide your text, choose an embedding model, and adjust segmentation parameters (window size, pooling, threshold). The demo will split your text into coherent segments based on **semantic shifts**. Refer to the [README](https://huggingface.co/spaces/saeedabc/llm-text-tiling-demo/blob/main/README.md) for more details.
""")
    with gr.Row():
        # Left column: input text and segmentation parameters.
        with gr.Column():
            input_text = gr.Textbox(label="Input Text", placeholder="Enter your text here...", lines=15)
            with gr.Row():
                with gr.Column():
                    # model_name = gr.Radio(choices=_ST_MODELS, label="Embedding Model", value=_ST_MODELS[0])
                    # Only names in _ST_MODELS are accepted by text_segmentation.
                    model_name = gr.Dropdown(choices=_ST_MODELS, label="Embedding Model", value=_ST_MODELS[0])
                with gr.Column():
                    # UI default is 'max', while calculate_cosine_similarities defaults to 'mean'.
                    pool = gr.Dropdown(choices=['max', 'mean', 'min'], label="Pooling Strategy", value='max')
            with gr.Row():
                with gr.Column():
                    # Gaps whose score (1 - cosine similarity) is >= this value become boundaries.
                    threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Threshold", value=0.5)
                with gr.Column():
                    # Number of sentences on each side of a gap used for the similarity.
                    k = gr.Slider(minimum=1, maximum=10, step=1, label="Window Size", value=3)
            submit_button = gr.Button("Chunk Text")
        # Right column: one tab per element of text_segmentation's return tuple.
        with gr.Column():
            with gr.Tabs():
                with gr.Tab("Output Text"):
                    output_text = gr.Textbox(label="Output Text", placeholder="Chunks will appear here...", lines=22)
                with gr.Tab("Output Json"):
                    output_json = gr.Json(label="Output Json", open=False, max_height=500)
                with gr.Tab("Output Visualization"):
                    output_fig = gr.Plot(label="Output Visualization")
    # The input order must match text_segmentation's parameter order; the output
    # order matches its (out_text, out, fig) return tuple.
    submit_button.click(text_segmentation, inputs=[input_text, model_name, k, pool, threshold], outputs=[output_text, output_json, output_fig])
    # One clickable example that fills all five inputs, in the same order as the click wiring.
    examples = gr.Examples(
        examples=[
            ["Rib Mountain is a census-designated place (CDP) in the town of Rib Mountain in Marathon County, Wisconsin, United States. "
            "The population was 5,651 at the 2010 census. "
            "The community is named for Rib Mountain. "
            "According to the United States Census Bureau, the CDP has a total area of 33.8 km² (13.0 mi²). "
            "31.4 km² (12.1 mi²) of it is land and 2.4 km² (0.9 mi²) of it (6.98%) is water. "
            "As of the census of 2000, there were 6,059 people, 2,211 households, and 1,782 families residing in the CDP. "
            "The population density was 193.0/km² (499.8/mi²). "
            "There were 2,278 housing units at an average density of 72.6/km² (187.9/mi²).", "all-mpnet-base-v2", 3, 'max', 0.52],
        ],
        inputs=[input_text, model_name, k, pool, threshold],
    )
if __name__ == '__main__':
    # Ensure the on-disk embedding cache directory exists before serving requests.
    Path(CACHE_DIR).mkdir(parents=False, exist_ok=True)
    # Start the Gradio app defined above.
    app.launch()