| import spaces |
| import gradio as gr |
| from wtpsplit import SaT, WtP |
|
|
# Model names that are loaded through the `WtP` class.
WTP_MODELS = ['wtp-bert-mini', 'wtp-canine-s-3l']

# Model names that are loaded through the `SaT` class.
# The order is kept as-is because it determines the order shown in the UI.
SAT_MODELS = [
    'sat-3l', 'sat-3l-sm',
    'sat-12l-sm', 'sat-6l-sm',
    'sat-12l', 'sat-1l-sm',
    'sat-6l', 'sat-1l',
    'sat-9l', 'sat-9l-sm',
    'sat-9l-no-limited-lookahead',
    'sat-3l-no-limited-lookahead',
]

# Every selectable model: SaT entries first, then WtP entries.
# The first element is used as the default selection of the model dropdown.
MODELS = [*SAT_MODELS, *WTP_MODELS]
|
|
@spaces.GPU(duration=60)
def split(text, model, threshold, lang_code='en', do_paragraph_segmentation=False, paragraph_threshold=0.5, use_onnx=False, strip_whitespace=False, remove_whitespace_before_inference=False):
    """Segment ``text`` with the selected wtpsplit model on the GPU.

    Args:
        text: The text to segment.
        model: Model name. Names listed in ``SAT_MODELS`` are loaded with
            ``SaT``; every other name is loaded with ``WtP``.
        threshold: Segmentation threshold. Values ``<= 0`` are forwarded as
            ``None`` so the choice of threshold is left to wtpsplit.
        lang_code: Language code of ``text``. Forwarded to ``WtP.split`` only;
            an empty string is forwarded as ``None``.
        do_paragraph_segmentation: Forwarded unchanged to the model's ``split``.
        paragraph_threshold: Forwarded unchanged to the model's ``split``.
        use_onnx: Only changes the provider list shown in the info message;
            see the note in the body.
        strip_whitespace: Forwarded unchanged to the model's ``split``.
        remove_whitespace_before_inference: Forwarded unchanged to the
            model's ``split``.

    Returns:
        Whatever the model's ``split`` method returns for ``text``.
    """
    # NOTE: the provider list is only displayed in the info message below. It
    # is not passed to the model constructor, so inference does not go through
    # ONNX Runtime even when `use_onnx` is true.
    ort_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if use_onnx else []

    is_sat_model = model in SAT_MODELS

    gr.Info(f"Loading model: {model} ({ort_providers})")
    splitter = SaT(model) if is_sat_model else WtP(model)

    # Run in half precision on the GPU provided by the `spaces.GPU` decorator.
    splitter.half().to("cuda")

    split_kwargs = {
        # The UI slider uses 0 to mean "no explicit threshold".
        "threshold": threshold if threshold > 0 else None,
        "do_paragraph_segmentation": do_paragraph_segmentation,
        "paragraph_threshold": paragraph_threshold,
        "strip_whitespace": strip_whitespace,
        "remove_whitespace_before_inference": remove_whitespace_before_inference,
    }
    if not is_sat_model:
        # `WtP.split` takes the language per call (adapter-based WtP
        # checkpoints need it); `SaT.split` has no such parameter, so it is
        # only added for WtP models.
        split_kwargs["lang_code"] = lang_code or None

    return splitter.split(text, **split_kwargs)
|
|
# Build the UI. Components are rendered in the order they are created.
with gr.Blocks() as demo:
    gr.Markdown("Split sentences using segment-any-text wtpsplit")

    # --- Inputs ---------------------------------------------------------
    text = gr.Textbox(label="Text")
    model = gr.Dropdown(MODELS, value=MODELS[0], label='Model')
    # A value of 0 is turned into "no explicit threshold" by `split`.
    threshold = gr.Slider(
        label="Segmentation threshold",
        minimum=0,
        maximum=1,
        step=0.01,
        value=0,
    )
    lang_code = gr.Textbox(label="Language code", value="en")
    strip_whitespace = gr.Checkbox(label="Strip whitespace?", value=False)
    remove_whitespace_before_inference = gr.Checkbox(
        label="Remove whitespace before inference?",
        value=False,
    )
    do_paragraph_segmentation = gr.Checkbox(label="Segment paragraphs?", value=False)
    paragraph_threshold = gr.Slider(
        label="Paragraph threshold",
        minimum=0,
        maximum=1,
        step=0.1,
        value=0.5,
    )
    # Shown but not editable: the checkbox always reports False.
    use_onnx = gr.Checkbox(label="Use ONNX?", value=False, interactive=False)

    # --- Action and output ----------------------------------------------
    run_button = gr.Button("Split")
    output = gr.JSON(label="Sentences")

    # Submitting the text box and clicking the button both call `split`.
    # The input order must match the positional parameter order of `split`.
    gr.on(
        triggers=[text.submit, run_button.click],
        fn=split,
        inputs=[
            text,
            model,
            threshold,
            lang_code,
            do_paragraph_segmentation,
            paragraph_threshold,
            use_onnx,
            strip_whitespace,
            remove_whitespace_before_inference,
        ],
        outputs=output,
        api_name="split",
    )


# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|
|