Commit ·
9f44e68
1
Parent(s): 2168e29
Initial commit
Browse files
- app.py +64 -0
- requirements.txt +8 -0
app.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from wtpsplit import SaT, WtP
|
| 3 |
+
|
| 4 |
+
# WtP model checkpoints (the earlier "Where's the Point" segmenters).
WTP_MODELS = [
    'wtp-bert-mini',
    'wtp-canine-s-3l',
]

# SaT ("Segment any Text") checkpoints; `-sm` variants are the smaller
# distilled ones, `-no-limited-lookahead` disable the lookahead restriction.
SAT_MODELS = [
    'sat-3l',
    'sat-3l-sm',
    'sat-12l-sm',
    'sat-6l-sm',
    'sat-12l',
    'sat-1l-sm',
    'sat-6l',
    'sat-1l',
    'sat-9l',
    'sat-9l-sm',
    'sat-9l-no-limited-lookahead',
    'sat-3l-no-limited-lookahead',
]

# All dropdown choices; SaT models come first so MODELS[0] (the UI default)
# is a SaT checkpoint.
MODELS = [*SAT_MODELS, *WTP_MODELS]
|
| 25 |
+
|
| 26 |
+
def split(text, model, threshold, lang_code='en', do_paragraph_segmentation=False, paragraph_threshold=0.5, use_onnx=False, strip_whitespace=False, remove_whitespace_before_inference=False):
    """Segment ``text`` into sentences (or paragraphs) with a wtpsplit model.

    Args:
        text: Raw input text to segment.
        model: Checkpoint name; one of ``SAT_MODELS`` or ``WTP_MODELS``.
        threshold: Sentence-boundary probability cutoff in [0, 1].
            A value of 0 is treated as "use the model default" (``None``).
        lang_code: Language code; currently unused (see TODO below).
        do_paragraph_segmentation: Also group sentences into paragraphs.
        paragraph_threshold: Boundary cutoff for paragraph segmentation.
        use_onnx: Reserved; ONNX execution is not wired up yet.
        strip_whitespace: Strip surrounding whitespace from each segment.
        remove_whitespace_before_inference: Normalize whitespace first.

    Returns:
        A list of sentence strings, or a list of paragraphs (each a list of
        sentences) when ``do_paragraph_segmentation`` is True.
    """
    import torch

    # TODO: not currently used
    ort_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if use_onnx else []

    is_sat_model = model in SAT_MODELS

    gr.Info(f"Loading model: {model} ({ort_providers})")
    sat = SaT(model) if is_sat_model else WtP(model)

    # Bug fix: only use fp16 + GPU when CUDA is actually available;
    # the unconditional .to("cuda") crashed on CPU-only hosts.
    if torch.cuda.is_available():
        sat.half().to("cuda")

    # TODO: figure out how to support passing lang_code to all models regardless of if they're LoRa
    return sat.split(text, threshold=threshold if threshold > 0 else None, do_paragraph_segmentation=do_paragraph_segmentation, paragraph_threshold=paragraph_threshold, strip_whitespace=strip_whitespace, remove_whitespace_before_inference=remove_whitespace_before_inference)
|
| 40 |
+
|
| 41 |
+
with gr.Blocks() as demo:
    # Minimal UI: one text box in, JSON list of segments out.
    gr.Markdown("Split sentences using segment-any-text wtpsplit")
    input_text = gr.Textbox(label="Text")
    model_choice = gr.Dropdown(MODELS, value=MODELS[0], label='Model')
    # Slider default of 0 means "no explicit threshold" (split() maps it to None).
    seg_threshold = gr.Slider(label="Segmentation threshold", minimum=0, maximum=1, step=0.01, value=0)
    language_code = gr.Textbox(label="Language code", value="en")
    strip_ws = gr.Checkbox(label="Strip whitespace?", value=False)
    remove_ws_before = gr.Checkbox(label="Remove whitespace before inference?", value=False)
    segment_paragraphs = gr.Checkbox(label="Segment paragraphs?", value=False)
    para_threshold = gr.Slider(label="Paragraph threshold", minimum=0, maximum=1, step=0.1, value=0.5)
    # ONNX execution is not implemented yet, hence non-interactive.
    onnx_toggle = gr.Checkbox(label="Use ONNX?", value=False, interactive=False)
    split_button = gr.Button("Split")
    result_json = gr.JSON(label="Sentences")

    # Run on Enter in the text box or on the button click; also exposed
    # as the "split" API endpoint.
    gr.on(
        triggers=[input_text.submit, split_button.click],
        fn=split,
        inputs=[input_text, model_choice, seg_threshold, language_code, segment_paragraphs, para_threshold, onnx_toggle, strip_ws, remove_ws_before],
        outputs=result_json,
        api_name="split",
    )
|
| 62 |
+
|
| 63 |
+
# Launch the app only when executed as a script; when the module is merely
# imported (e.g. by a hosting runner that serves `demo` itself — the
# requirements include `spaces`, which suggests this), nothing starts.
if __name__ == "__main__":
    demo.launch()
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
wtpsplit
|
| 2 |
+
wtpsplit[onnx-gpu]
|
| 3 |
+
wtpsplit[onnx-cpu]
|
| 4 |
+
onnxruntime-gpu
|
| 5 |
+
onnxruntime
|
| 6 |
+
gradio>=5.0
|
| 7 |
+
torch
|
| 8 |
+
spaces
|