Update app.py
Browse files
app.py
CHANGED
|
@@ -30,20 +30,25 @@ def predict_news_type(content):
|
|
| 30 |
content = pad_sequences(content, maxlen=MAX_LEN, padding='post')
|
| 31 |
content_predict = model.predict(content, verbose=0)
|
| 32 |
result = np.argmax(content_predict, axis=1)
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
# Create Gradio Interface
|
| 36 |
demo = gr.Interface(
|
| 37 |
fn=predict_news_type,
|
| 38 |
inputs=gr.Textbox(label="Enter the news content"),
|
| 39 |
-
outputs=
|
|
|
|
|
|
|
|
|
|
| 40 |
title="News Type Prediction",
|
| 41 |
-
description="Enter the news content to predict its category."
|
| 42 |
)
|
| 43 |
-
|
| 44 |
if __name__ == "__main__":
|
| 45 |
# Launch the Gradio Interface
|
| 46 |
-
host = "127.0.0.1"
|
| 47 |
-
port = 7860
|
| 48 |
-
print(f"Gradio app is running on {host}:{port}")
|
| 49 |
demo.launch()#server_name=host, server_port=port, share=True)
|
|
|
|
| 30 |
content = pad_sequences(content, maxlen=MAX_LEN, padding='post')
|
| 31 |
content_predict = model.predict(content, verbose=0)
|
| 32 |
result = np.argmax(content_predict, axis=1)
|
| 33 |
+
category = label_dict[str(result[0])]
|
| 34 |
+
|
| 35 |
+
# Get all categories and their probabilities
|
| 36 |
+
probabilities = content_predict[0].tolist()
|
| 37 |
+
category_probabilities = {label_dict[str(i)]: prob for i, prob in enumerate(probabilities)}
|
| 38 |
+
|
| 39 |
+
return category, category_probabilities
|
| 40 |
|
| 41 |
# Create Gradio Interface
|
| 42 |
demo = gr.Interface(
|
| 43 |
fn=predict_news_type,
|
| 44 |
inputs=gr.Textbox(label="Enter the news content"),
|
| 45 |
+
outputs=[
|
| 46 |
+
gr.Textbox(label="Predicted News Category"),
|
| 47 |
+
gr.JSON(label="Category Probabilities")
|
| 48 |
+
],
|
| 49 |
title="News Type Prediction",
|
| 50 |
+
description="Enter the news content to predict its category and see the probabilities for all categories."
|
| 51 |
)
|
|
|
|
| 52 |
if __name__ == "__main__":
|
| 53 |
# Launch the Gradio Interface
|
|
|
|
|
|
|
|
|
|
| 54 |
demo.launch()#server_name=host, server_port=port, share=True)
|