File size: 2,282 Bytes
6c081cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62977ad
 
 
 
 
 
 
6c081cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import gradio as gr
import requests
import json
import os
from screenshot import (
    before_prompt,
    prompt_to_generation,
    after_generation,
    js_save,
    js_load_script,
)

def inference(input_sentence, max_length, seed=42):
    """Request a greedy continuation of *input_sentence* and format the result.

    Builds a text-generation payload (``do_sample=False``, ``eos_token_id``
    set to ``None``, ``use_cache`` disabled) and sends it through ``query``.

    Args:
        input_sentence: Prompt text to continue.
        max_length: Value sent as the ``max_new_tokens`` parameter.
        seed: Value sent as the ``seed`` parameter.

    Returns:
        A 3-tuple in the order of the Submit button's outputs
        ``(display_out, text_out, text_error)``:

        * on success: the prompt and continuation concatenated with the
          ``before_prompt`` / ``prompt_to_generation`` / ``after_generation``
          fragments imported from ``screenshot``; the raw
          ``generated_text``; and ``""``;
        * when ``"error" in data``: ``(None, None, <red ERROR span>)``.
    """
    parameters = {
        "max_new_tokens": max_length,
        "do_sample": False,
        "seed": seed,
        "early_stopping": False,
        "length_penalty": 0.0,
        "eos_token_id": None,
    }
    payload = {
        "inputs": input_sentence,
        "parameters": parameters,
        "options": {"use_cache": False},
    }

    # NOTE(review): ``query`` is not defined or imported anywhere in this
    # file; presumably it sends ``payload`` to the inference endpoint and
    # returns the decoded JSON response. TODO confirm / restore it.
    data = query(payload)

    if "error" in data:
        return (None, None, f"<span style='color:red'>ERROR: {data['error']} </span>")

    full_text = data[0]["generated_text"]
    # The response may echo the prompt before the continuation; strip it when
    # present. If it is absent (previously an IndexError) or the prompt is
    # empty (str.split raises ValueError on an empty separator), treat the
    # whole text as the continuation instead of crashing.
    if input_sentence and input_sentence in full_text:
        generation = full_text.split(input_sentence, 1)[1]
    else:
        generation = full_text

    return (
        before_prompt
        + input_sentence
        + prompt_to_generation
        + generation
        + after_generation,
        full_text,
        "",
    )


if __name__ == "__main__":
    # Build and launch the Gradio UI: prompt + token-count inputs on the left,
    # log / raw-text / rendered outputs on the right, wired to ``inference``.
    demo = gr.Blocks()
    with demo:
        with gr.Row():
            # NOTE(review): ``description`` is not defined anywhere in this
            # file — TODO restore or import it before launching.
            gr.Markdown(value=description)
        with gr.Row():
            with gr.Column():
                text = gr.Textbox(
                    label="Input",
                    value=" ",  # should be set to " " when plugged into a real API
                )
                tokens = gr.Slider(1, 64, value=32, step=1, label="Tokens to generate")

                with gr.Row():
                    submit = gr.Button("Submit")
            with gr.Column():
                text_error = gr.Markdown(label="Log information")
                text_out = gr.Textbox(label="Output")
                # ``display_out`` receives the markup string built by
                # ``inference`` and must exist before its load trigger is
                # attached (it was referenced here without being created).
                display_out = gr.HTML()
                display_out.set_event_trigger(
                    "load",
                    fn=None,
                    inputs=None,
                    outputs=None,
                    no_target=True,
                    js=js_load_script,
                )
        # ``inference(input_sentence, max_length, seed=42)`` takes the prompt
        # and token count; ``sampling``/``sampling2`` are not created anywhere
        # in this file and ``inference`` has no parameters for them, so only
        # the two existing inputs are wired.
        with gr.Row():
            # NOTE(review): ``examples`` is not defined in this file; each
            # example row must supply values for (text, tokens). TODO restore.
            gr.Examples(examples=examples, inputs=[text, tokens])

        submit.click(
            inference,
            inputs=[text, tokens],
            outputs=[display_out, text_out, text_error],
        )

    demo.launch()