Spaces:
Sleeping
Sleeping
Commit ·
ca9a14d
1
Parent(s): 84a6867
adjust
Browse files
app.py
CHANGED
|
@@ -18,7 +18,6 @@ def inference(input_text):
|
|
| 18 |
inputs = tokenizer.batch_encode_plus(
|
| 19 |
[input_text],
|
| 20 |
max_length=512,
|
| 21 |
-
pad_to_max_length=True,
|
| 22 |
truncation=True,
|
| 23 |
padding="max_length",
|
| 24 |
return_tensors="pt",
|
|
@@ -27,7 +26,7 @@ def inference(input_text):
|
|
| 27 |
with torch.no_grad():
|
| 28 |
logits = model(**inputs).logits
|
| 29 |
|
| 30 |
-
predicted_class_id = logits.argmax().item()
|
| 31 |
output = model.config.id2label[predicted_class_id]
|
| 32 |
return output
|
| 33 |
|
|
@@ -36,8 +35,8 @@ demo = gr.Interface(
|
|
| 36 |
inputs=gr.Textbox(label="Input Text", scale=2, container=False),
|
| 37 |
outputs=gr.Textbox(label="Output Label"),
|
| 38 |
examples = [
|
| 39 |
-
["My last two weather pics from the storm on August 2nd.
|
| 40 |
-
["Lying Clinton sinking! Donald Trump singing
|
| 41 |
],
|
| 42 |
title="Tutorial: BERT-based Text Classificatioin",
|
| 43 |
)
|
|
|
|
| 18 |
inputs = tokenizer.batch_encode_plus(
|
| 19 |
[input_text],
|
| 20 |
max_length=512,
|
|
|
|
| 21 |
truncation=True,
|
| 22 |
padding="max_length",
|
| 23 |
return_tensors="pt",
|
|
|
|
| 26 |
with torch.no_grad():
|
| 27 |
logits = model(**inputs).logits
|
| 28 |
|
| 29 |
+
predicted_class_id = logits.argmax(dim=-1).item()
|
| 30 |
output = model.config.id2label[predicted_class_id]
|
| 31 |
return output
|
| 32 |
|
|
|
|
| 35 |
inputs=gr.Textbox(label="Input Text", scale=2, container=False),
|
| 36 |
outputs=gr.Textbox(label="Output Label"),
|
| 37 |
examples = [
|
| 38 |
+
["My last two weather pics from the storm on August 2nd. "],
|
| 39 |
+
["Lying Clinton sinking! Donald Trump singing."],
|
| 40 |
],
|
| 41 |
title="Tutorial: BERT-based Text Classificatioin",
|
| 42 |
)
|