Huangxin14 commited on
Commit
4f25e1c
·
verified ·
1 Parent(s): 2c0a561

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -0
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from transformers import (
4
+ AutoConfig,
5
+ AutoTokenizer,
6
+ AutoModelForSequenceClassification
7
+ )
8
+
9
+ model_dir = "my-bert-model"
10
+
11
+ config = AutoConfig.from_pretrained(
12
+ model_dir,
13
+ num_labels=3,
14
+ finetuning_task="text-classification"
15
+ )
16
+ tokenizer = AutoTokenizer.from_pretrained(model_dir)
17
+ model = AutoModelForSequenceClassification.from_pretrained(
18
+ model_dir,
19
+ config=config
20
+ )
21
+
22
+ def inference(input_text):
23
+ inputs = tokenizer.batch_encode_plus(
24
+ [input_text],
25
+ max_length=512,
26
+ pad_to_max_length=True,
27
+ truncation=True,
28
+ padding="max_length",
29
+ return_tensors="pt",
30
+ )
31
+
32
+ with torch.no_grad():
33
+ logits = model(**inputs).logits
34
+
35
+ predicted_class_id = logits.argmax().item()
36
+ output = model.config.id2label[predicted_class_id]
37
+ return output
38
+
39
+
40
+ with gr.Blocks(css="""
41
+ .message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
42
+ #component-21 > div.wrap.svelte-w6rprc {height: 600px;}
43
+ """) as demo:
44
+ with gr.Row():
45
+ with gr.Column():
46
+ input_text = gr.Textbox(
47
+ placeholder="Insert your prompt here:",
48
+ scale=2,
49
+ container=False
50
+ )
51
+ answer = gr.Textbox(lines=0, label="Answer")
52
+ generate_bt = gr.Button("Generate", scale=1)
53
+
54
+ inputs = [input_text]
55
+ outputs = [answer]
56
+
57
+ generate_bt.click(
58
+ fn=inference,
59
+ inputs=inputs,
60
+ outputs=outputs,
61
+ show_progress=True
62
+ )
63
+
64
+ examples = [
65
+ ["My last two weather pics from the storm on August 2nd. People packed up real fast after the temp dropped and winds picked up.", 1],
66
+ ["Lying Clinton sinking! Donald Trump singing: Let's Make America Great Again!", 0],
67
+ ]
68
+
69
+ demo.queue()
70
+ demo.launch()
71
+