# File size: 1,807 Bytes
# 4f25e1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
import gradio as gr
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForSequenceClassification
)

# Single location handed to every `from_pretrained` call below, so the
# tokenizer, config and weights are all resolved from the same place.
model_dir = "my-bert-model"

# The tokenizer only needs the checkpoint location.
tokenizer = AutoTokenizer.from_pretrained(model_dir)

# Load the config (tagged with the fine-tuning task) and pass it explicitly to
# the model loader instead of letting the loader resolve its own.
config = AutoConfig.from_pretrained(model_dir, finetuning_task="text-classification")
model = AutoModelForSequenceClassification.from_pretrained(model_dir, config=config)

def inference(input_text):
    """Classify a single piece of text and return its predicted label.

    Args:
        input_text: The raw string to classify.

    Returns:
        The entry of ``model.config.id2label`` for the highest-scoring class.
    """
    # Call the tokenizer directly rather than through the legacy
    # `batch_encode_plus` entry point, and request padding only via
    # `padding=`: the old `pad_to_max_length` flag is deprecated and merely
    # duplicated that setting.
    inputs = tokenizer(
        [input_text],
        max_length=512,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Reduce over the last (class) axis of the one-row batch instead of over
    # the flattened tensor, so the result is always a class index.
    predicted_class_id = logits.argmax(dim=-1)[0].item()
    return model.config.id2label[predicted_class_id]


# Build the demo page: one prompt box, one answer box and a button that runs
# `inference` on the prompt. The CSS targets Gradio-generated class names/ids.
with gr.Blocks(css="""
.message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
#component-21 > div.wrap.svelte-w6rprc {height: 600px;}
""") as demo:
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                placeholder="Insert your prompt here:",
                scale=2,
                container=False
            )
            answer = gr.Textbox(lines=0, label="Answer")
            generate_bt = gr.Button("Generate", scale=1)

    inputs = [input_text]
    outputs = [answer]

    # Clicking the button sends the prompt through the classifier and writes
    # the returned label into the answer box.
    generate_bt.click(
        fn=inference,
        inputs=inputs,
        outputs=outputs,
        show_progress=True
    )

    # Sample prompts. Each row is (text, number); the number looks like an
    # expected class id -- NOTE(review): confirm against the model's labels.
    examples = [
        ["My last two weather pics from the storm on August 2nd. People packed up real fast after the temp dropped and winds picked up.", 1],
        ["Lying Clinton sinking! Donald Trump singing: Let's Make America Great Again!", 0],
    ]

    # The list above was previously built but never shown. Expose the text
    # column as clickable examples; the numeric column is dropped because the
    # interface has a single input component.
    gr.Examples(
        examples=[[text] for text, _label in examples],
        inputs=inputs,
    )

demo.queue()
demo.launch()