mattpowell committed on
Commit
9f44e68
·
1 Parent(s): 2168e29

Initial commit

Browse files
Files changed (2) hide show
  1. app.py +64 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from wtpsplit import SaT, WtP
3
+
4
# WtP checkpoints supported by this demo (loaded via wtpsplit.WtP).
WTP_MODELS = [
    'wtp-bert-mini',
    'wtp-canine-s-3l',
]

# SaT ("Segment any Text") checkpoints (loaded via wtpsplit.SaT).
# NOTE: order matters — MODELS[0] below is the dropdown's default selection.
SAT_MODELS = [
    'sat-3l',
    'sat-3l-sm',
    'sat-12l-sm',
    'sat-6l-sm',
    'sat-12l',
    'sat-1l-sm',
    'sat-6l',
    'sat-1l',
    'sat-9l',
    'sat-9l-sm',
    'sat-9l-no-limited-lookahead',
    'sat-3l-no-limited-lookahead',
]

# All dropdown choices; SaT first so 'sat-3l' stays the default.
MODELS = SAT_MODELS + WTP_MODELS
25
+
26
def split(text, model, threshold, lang_code='en', do_paragraph_segmentation=False,
          paragraph_threshold=0.5, use_onnx=False, strip_whitespace=False,
          remove_whitespace_before_inference=False):
    """Segment *text* into sentences (or paragraphs) with a wtpsplit model.

    Args:
        text: Input text to segment.
        model: Checkpoint name; names in SAT_MODELS load SaT, anything else WtP.
        threshold: Boundary threshold in [0, 1]; 0 means "use the model
            default" (passed through to wtpsplit as None).
        lang_code: Language hint. NOTE(review): currently unused — see the
            TODO before the return statement.
        do_paragraph_segmentation: If True, group sentences into paragraphs.
        paragraph_threshold: Paragraph boundary threshold.
        use_onnx: Reserved; ONNX execution is not wired up yet (see TODO).
        strip_whitespace: Strip whitespace from returned segments.
        remove_whitespace_before_inference: Normalize whitespace before
            running the model.

    Returns:
        The result of wtpsplit's ``split``: a list of sentences, or a list of
        sentence lists when do_paragraph_segmentation is True.
    """
    import torch  # local import: only needed for the device-availability check

    # TODO: not currently used — kept for when ONNX support is enabled.
    ort_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if use_onnx else []

    is_sat_model = model in SAT_MODELS

    gr.Info(f"Loading model: {model} ({ort_providers})")
    sat = SaT(model) if is_sat_model else WtP(model)

    # Bug fix: the original unconditionally called .half().to("cuda"), which
    # raises on CPU-only hosts. Use fp16 + GPU only when CUDA is actually
    # available (fp16 on CPU is slow/unsupported for many ops anyway).
    if torch.cuda.is_available():
        sat.half().to("cuda")

    # TODO: figure out how to support passing lang_code to all models regardless of if they're LoRA
    return sat.split(text, threshold=threshold if threshold > 0 else None,
                     do_paragraph_segmentation=do_paragraph_segmentation,
                     paragraph_threshold=paragraph_threshold,
                     strip_whitespace=strip_whitespace,
                     remove_whitespace_before_inference=remove_whitespace_before_inference)
40
+
41
# Demo UI: input widgets on top, a run button, and a JSON view of the result.
with gr.Blocks() as demo:
    gr.Markdown("Split sentences using segment-any-text wtpsplit")

    # Input widgets (creation order here is the on-screen order).
    text = gr.Textbox(label="Text")
    model = gr.Dropdown(MODELS, value=MODELS[0], label='Model')
    threshold = gr.Slider(
        label="Segmentation threshold", minimum=0, maximum=1, step=0.01, value=0
    )
    lang_code = gr.Textbox(label="Language code", value="en")
    strip_whitespace = gr.Checkbox(label="Strip whitespace?", value=False)
    remove_whitespace_before_inference = gr.Checkbox(
        label="Remove whitespace before inference?", value=False
    )
    do_paragraph_segmentation = gr.Checkbox(label="Segment paragraphs?", value=False)
    paragraph_threshold = gr.Slider(
        label="Paragraph threshold", minimum=0, maximum=1, step=0.1, value=0.5
    )
    # Not interactive: ONNX execution is not wired up yet in split().
    use_onnx = gr.Checkbox(label="Use ONNX?", value=False, interactive=False)
    run_button = gr.Button("Split")
    output = gr.JSON(label="Sentences")

    # Fire on either Enter in the textbox or the button click; also exposed
    # as the "split" endpoint of the app's API.
    gr.on(
        triggers=[text.submit, run_button.click],
        fn=split,
        inputs=[
            text,
            model,
            threshold,
            lang_code,
            do_paragraph_segmentation,
            paragraph_threshold,
            use_onnx,
            strip_whitespace,
            remove_whitespace_before_inference,
        ],
        outputs=output,
        api_name="split",
    )

if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ wtpsplit
2
+ wtpsplit[onnx-gpu]
3
+ wtpsplit[onnx-cpu]
4
+ onnxruntime-gpu
5
+ onnxruntime
6
+ gradio>=5.0
7
+ torch
8
+ spaces