Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from backend import get_message_single, get_message_spam, send_single, send_spam | |
| from defaults import ( | |
| ADDRESS_BETTERTRANSFORMER, | |
| ADDRESS_VANILLA, | |
| defaults_bt_single, | |
| defaults_bt_spam, | |
| defaults_vanilla_single, | |
| defaults_vanilla_spam, | |
| ) | |
# Header banner rendered (via gr.HTML) at the top of the demo page: a
# horizontally centered, half-width container holding the header image.
TITLE_IMAGE = """
<div
    style="
        display: block;
        margin-left: auto;
        margin-right: auto;
        width: 50%;
    "
>
    <img
        src="https://huggingface.co/spaces/fxmarty/bettertransformer-demo/resolve/main/header.webp"
        alt="BetterTransformer demo header"
    />
</div>
"""
# Backward-compatible alias: the constant was originally published under this
# misspelled name, and existing code may still reference it.
TTILE_IMAGE = TITLE_IMAGE
# Page heading markup: one <h1> inside an inline-flex wrapper whose items and
# text are centered. Assembled line by line; the leading and trailing empty
# entries reproduce the surrounding newlines of the original triple-quoted
# literal.
TITLE = "\n".join(
    [
        "",
        "<div",
        'style="',
        "display: inline-flex;",
        "align-items: center;",
        "text-align: center;",
        "max-width: 1400px;",
        "gap: 0.8rem;",
        "font-size: 2.2rem;",
        '"',
        ">",
        '<h1 style="font-weight: 700; margin-bottom: 10px; margin-top: 10px;">',
        "Speed up your inference and support more workload with PyTorch's BetterTransformer 🤗",
        "</h1>",
        "</div>",
        "",
    ]
)
# Introductory copy rendered under the page title.
INTRO_MARKDOWN = """
Let's try out TorchServe + BetterTransformer!

BetterTransformer is a stable feature made available with [PyTorch 1.13](https://pytorch.org/blog/PyTorch-1.13-release/) that enables a fastpath execution for encoder attention blocks.

As a one-liner, you can convert your 🤗 Transformers models to use BetterTransformer thanks to the [🤗 Optimum](https://huggingface.co/docs/optimum/main/en/index) library:

```
from optimum.bettertransformer import BetterTransformer
better_model = BetterTransformer.transform(model)
```

This Space is a demo of an **end-to-end** deployment of PyTorch eager-mode models, both with and without BetterTransformer. The goal is to compare the server-side and client-side benefits of using BetterTransformer.

## Inference using...
"""


def _build_inference_column(
    heading, address_label, address, output_tag, single_defaults, spam_defaults
):
    """Build one half-width column of controls bound to a single backend address.

    Must be called inside an active ``gr.Blocks`` (and ``gr.Row``) context,
    because the components attach to whatever layout context is open.

    Args:
        heading: Markdown heading shown at the top of the column.
        address_label: Label of the hidden textbox that stores ``address``.
        address: Backend address forwarded to ``send_single`` / ``send_spam``.
        output_tag: Suffix used in the two output components' labels.
        single_defaults: Keyword arguments for ``get_message_single``; its
            result pre-fills the single-request output.
        spam_defaults: Keyword arguments for ``get_message_spam``; its result
            pre-fills the spam-request output.

    Returns:
        ``(address_input, text_input, btn_single, output_single, btn_spam,
        output_spam)``: the components created, in creation order.
    """
    with gr.Column(scale=50):
        gr.Markdown(heading)
        # Hidden: carries the backend address into the click handlers.
        address_input = gr.Textbox(
            max_lines=1, label=address_label, value=address, visible=False
        )
        text_input = gr.Textbox(
            max_lines=1,
            label="Text",
            value="Expectations were low, enjoyment was high",
        )
        btn_single = gr.Button("Send single text request")
        output_single = gr.Markdown(
            label=f"Output single {output_tag}",
            value=get_message_single(**single_defaults),
        )
        btn_spam = gr.Button("Spam text requests (from sst2 validation set)")
        output_spam = gr.Markdown(
            label=f"Output spam {output_tag}",
            value=get_message_spam(**spam_defaults),
        )
        btn_single.click(
            fn=send_single,
            inputs=[text_input, address_input],
            outputs=output_single,
        )
        btn_spam.click(
            fn=send_spam,
            inputs=[address_input],
            outputs=output_spam,
        )
    return address_input, text_input, btn_single, output_single, btn_spam, output_spam


with gr.Blocks() as demo:
    gr.HTML(TTILE_IMAGE)
    gr.HTML(TITLE)
    gr.Markdown(INTRO_MARKDOWN)
    with gr.Row():
        # Both columns are built by the same helper so they cannot drift apart;
        # results are still bound to the original module-level names.
        (
            address_input_vanilla,
            input_model_vanilla,
            btn_single_vanilla,
            output_single_vanilla,
            btn_spam_vanilla,
            output_spam_vanilla,
        ) = _build_inference_column(
            heading="### Vanilla Transformers + TorchServe",
            address_label="ip vanilla",
            address=ADDRESS_VANILLA,
            output_tag="vanilla",
            single_defaults=defaults_vanilla_single,
            spam_defaults=defaults_vanilla_spam,
        )
        (
            address_input_bettertransformer,
            input_model_bettertransformer,
            btn_single_bt,
            output_single_bt,
            btn_spam_bt,
            output_spam_bt,
        ) = _build_inference_column(
            heading="### BetterTransformer + TorchServe",
            address_label="ip bettertransformer",
            address=ADDRESS_BETTERTRANSFORMER,
            output_tag="bt",
            single_defaults=defaults_bt_single,
            spam_defaults=defaults_bt_spam,
        )

# Process queued events one at a time. Gradio 4 replaced `concurrency_count`
# with `default_concurrency_limit` and raises if the old argument is passed.
# Gradio 3 rejects the new keyword with a TypeError. So try the new API first
# and fall back to the old one.
try:
    demo.queue(default_concurrency_limit=1)
except TypeError:
    demo.queue(concurrency_count=1)
demo.launch()