mattpowell's picture
Add Zero GPU annotation
d74b6bd
import spaces
import gradio as gr
from wtpsplit import SaT, WtP
# Checkpoint names that split() loads with wtpsplit's ``WtP`` class
# (anything not listed in SAT_MODELS takes that path).
WTP_MODELS = ["wtp-bert-mini", "wtp-canine-s-3l"]

# Checkpoint names that split() loads with wtpsplit's ``SaT`` class.
SAT_MODELS = [
    "sat-3l",
    "sat-3l-sm",
    "sat-12l-sm",
    "sat-6l-sm",
    "sat-12l",
    "sat-1l-sm",
    "sat-6l",
    "sat-1l",
    "sat-9l",
    "sat-9l-sm",
    "sat-9l-no-limited-lookahead",
    "sat-3l-no-limited-lookahead",
]

# Every selectable model, SaT entries first. Order matters: the first entry
# is used as the default selection of the model dropdown.
MODELS = [*SAT_MODELS, *WTP_MODELS]
@spaces.GPU(duration=60)
def split(
    text,
    model,
    threshold,
    lang_code='en',
    do_paragraph_segmentation=False,
    paragraph_threshold=0.5,
    use_onnx=False,
    strip_whitespace=False,
    remove_whitespace_before_inference=False,
):
    """Segment ``text`` with the selected wtpsplit model on the GPU.

    The model is instantiated on every call, cast to half precision and moved
    to CUDA before splitting.

    Args:
        text: Input text to segment.
        model: Checkpoint name. Names listed in ``SAT_MODELS`` are loaded with
            ``SaT``; every other name is loaded with ``WtP``.
        threshold: Segmentation threshold. Values ``<= 0`` are passed on as
            ``None`` so the library picks its own default.
        lang_code: Language code. Forwarded to ``WtP`` models only; blank or
            ``None`` is forwarded as ``None``. Ignored for ``SaT`` models.
        do_paragraph_segmentation: Forwarded unchanged to the model's
            ``split``.
        paragraph_threshold: Forwarded unchanged to the model's ``split``.
        use_onnx: Only affects the provider list shown in the info message;
            ONNX inference is not wired up yet.
        strip_whitespace: Forwarded unchanged to the model's ``split``.
        remove_whitespace_before_inference: Forwarded unchanged to the
            model's ``split``.

    Returns:
        Whatever the model's ``split`` returns (presumably a list of
        sentences, or nested lists when paragraph segmentation is on --
        confirm against the wtpsplit documentation).
    """
    # TODO: not currently used
    ort_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if use_onnx else []
    is_sat_model = model in SAT_MODELS

    gr.Info(f"Loading model: {model} ({ort_providers})")
    segmenter = SaT(model) if is_sat_model else WtP(model)
    segmenter.half().to("cuda")

    split_kwargs = {
        # A slider value of 0 means "let the library choose the threshold".
        "threshold": threshold if threshold > 0 else None,
        "do_paragraph_segmentation": do_paragraph_segmentation,
        "paragraph_threshold": paragraph_threshold,
        "strip_whitespace": strip_whitespace,
        "remove_whitespace_before_inference": remove_whitespace_before_inference,
    }
    if not is_sat_model:
        # WtP.split() accepts lang_code, and WtP checkpoints with language
        # adapters are expected to need it (NOTE(review): confirm against
        # the installed wtpsplit version). Previously the value collected
        # from the UI was silently dropped.
        split_kwargs["lang_code"] = (lang_code or "").strip() or None
    # TODO: figure out how to support passing lang_code to SaT models as well
    # (regardless of whether they're LoRA); SaT.split() is not given it here.

    return segmenter.split(text, **split_kwargs)
# UI definition. Components are rendered in the order they are created, so
# the statement order below is also the on-page layout.
with gr.Blocks() as demo:
    gr.Markdown("Split sentences using segment-any-text wtpsplit")

    # --- Inputs ---------------------------------------------------------
    text = gr.Textbox(label="Text")
    model = gr.Dropdown(MODELS, value=MODELS[0], label='Model')
    # 0 is treated by split() as "use the library's default threshold".
    threshold = gr.Slider(
        label="Segmentation threshold",
        minimum=0,
        maximum=1,
        step=0.01,
        value=0,
    )
    lang_code = gr.Textbox(label="Language code", value="en")
    strip_whitespace = gr.Checkbox(label="Strip whitespace?", value=False)
    remove_whitespace_before_inference = gr.Checkbox(
        label="Remove whitespace before inference?",
        value=False,
    )
    do_paragraph_segmentation = gr.Checkbox(label="Segment paragraphs?", value=False)
    paragraph_threshold = gr.Slider(
        label="Paragraph threshold",
        minimum=0,
        maximum=1,
        step=0.1,
        value=0.5,
    )
    # Shown but locked: ONNX inference is not hooked up in split() yet.
    use_onnx = gr.Checkbox(label="Use ONNX?", value=False, interactive=False)

    # --- Action and output ----------------------------------------------
    run_button = gr.Button("Split")
    output = gr.JSON(label="Sentences")

    # Pressing Enter in the text box or clicking the button both run split().
    # The inputs list is positional and must follow split()'s parameter order.
    gr.on(
        triggers=[text.submit, run_button.click],
        fn=split,
        inputs=[
            text,
            model,
            threshold,
            lang_code,
            do_paragraph_segmentation,
            paragraph_threshold,
            use_onnx,
            strip_whitespace,
            remove_whitespace_before_inference,
        ],
        outputs=output,
        api_name="split",
    )

# Start the app only when executed as a script.
if __name__ == "__main__":
    demo.launch()