FortuneXia commited on
Commit
ca9a14d
·
1 Parent(s): 84a6867
Files changed (1) hide show
  1. app.py +3 -4
app.py CHANGED
@@ -18,7 +18,6 @@ def inference(input_text):
18
  inputs = tokenizer.batch_encode_plus(
19
  [input_text],
20
  max_length=512,
21
- pad_to_max_length=True,
22
  truncation=True,
23
  padding="max_length",
24
  return_tensors="pt",
@@ -27,7 +26,7 @@ def inference(input_text):
27
  with torch.no_grad():
28
  logits = model(**inputs).logits
29
 
30
- predicted_class_id = logits.argmax().item()
31
  output = model.config.id2label[predicted_class_id]
32
  return output
33
 
@@ -36,8 +35,8 @@ demo = gr.Interface(
36
  inputs=gr.Textbox(label="Input Text", scale=2, container=False),
37
  outputs=gr.Textbox(label="Output Label"),
38
  examples = [
39
- ["My last two weather pics from the storm on August 2nd. People packed up real fast after the temp dropped and winds picked up.", 1],
40
- ["Lying Clinton sinking! Donald Trump singing: Let's Make America Great Again!", 0],
41
  ],
42
  title="Tutorial: BERT-based Text Classificatioin",
43
  )
 
18
  inputs = tokenizer.batch_encode_plus(
19
  [input_text],
20
  max_length=512,
 
21
  truncation=True,
22
  padding="max_length",
23
  return_tensors="pt",
 
26
  with torch.no_grad():
27
  logits = model(**inputs).logits
28
 
29
+ predicted_class_id = logits.argmax(dim=-1).item()
30
  output = model.config.id2label[predicted_class_id]
31
  return output
32
 
 
35
  inputs=gr.Textbox(label="Input Text", scale=2, container=False),
36
  outputs=gr.Textbox(label="Output Label"),
37
  examples = [
38
+ ["My last two weather pics from the storm on August 2nd. "],
39
+ ["Lying Clinton sinking! Donald Trump singing."],
40
  ],
41
  title="Tutorial: BERT-based Text Classificatioin",
42
  )