Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -107,9 +107,6 @@ def predict_single_text(text):
|
|
| 107 |
# Calculate the probabilities
|
| 108 |
probabilities = torch.sigmoid(logits).squeeze()
|
| 109 |
|
| 110 |
-
# Define the threshold for prediction
|
| 111 |
-
threshold = 0.3
|
| 112 |
-
|
| 113 |
# Get the predicted labels
|
| 114 |
predicted_labels_ = (probabilities.cpu().numpy() > threshold).tolist()
|
| 115 |
|
|
@@ -162,12 +159,13 @@ def predict_single_text(text):
|
|
| 162 |
|
| 163 |
# Create Gradio interface for single text
|
| 164 |
iface2 = gr.Interface(fn=predict_single_text,
|
| 165 |
-
inputs=gr.Textbox(lines=7, label="Paste or type text here"),
|
|
|
|
| 166 |
outputs=[gr.Label(label="Top Predictions", show_label=True),
|
| 167 |
gr.Plot(label="Likelihood of all labels", show_label=True)],
|
| 168 |
title="Single Text Prediction",
|
| 169 |
-
|
| 170 |
-
)
|
| 171 |
|
| 172 |
|
| 173 |
# UPLOAD CSV
|
|
@@ -245,8 +243,8 @@ def predict_from_csv(file, column_name, progress=gr.Progress()):
|
|
| 245 |
# Calculate the probabilities
|
| 246 |
predictions = torch.sigmoid(logits).squeeze()
|
| 247 |
|
| 248 |
-
# Define the threshold for prediction
|
| 249 |
-
threshold = 0.3
|
| 250 |
|
| 251 |
# Get the predicted labels
|
| 252 |
predicted_labels_ = (predictions.cpu().numpy() > threshold).tolist()
|
|
@@ -352,13 +350,17 @@ def predict_from_csv(file, column_name, progress=gr.Progress()):
|
|
| 352 |
# Define the input component
|
| 353 |
file_input = gr.File(label="Upload CSV or Excel file here", show_label=True, file_types=[".csv", ".xls", ".xlsx"])
|
| 354 |
column_name_input = gr.Textbox(label="Enter the column name containing the text to be analyzed", show_label=True)
|
|
|
|
| 355 |
|
| 356 |
# Create the Gradio interface
|
| 357 |
iface3 = gr.Interface(fn=predict_from_csv,
|
| 358 |
-
inputs=[file_input, column_name_input
|
|
|
|
| 359 |
outputs=gr.File(label='Download output CSV', show_label=True),
|
| 360 |
title="Multi-text Prediction",
|
| 361 |
-
description='**
|
|
|
|
|
|
|
| 362 |
|
| 363 |
# Create a tabbed interface
|
| 364 |
demo = gr.TabbedInterface(interface_list=[iface1, iface2, iface3],
|
|
|
|
| 107 |
# Calculate the probabilities
|
| 108 |
probabilities = torch.sigmoid(logits).squeeze()
|
| 109 |
|
|
|
|
|
|
|
|
|
|
| 110 |
# Get the predicted labels
|
| 111 |
predicted_labels_ = (probabilities.cpu().numpy() > threshold).tolist()
|
| 112 |
|
|
|
|
| 159 |
|
| 160 |
# Create Gradio interface for single text
|
| 161 |
iface2 = gr.Interface(fn=predict_single_text,
|
| 162 |
+
inputs=[gr.Textbox(lines=7, label="Paste or type text here"),
|
| 163 |
+
gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.01, label="Threshold value (default=0.3)")],
|
| 164 |
outputs=[gr.Label(label="Top Predictions", show_label=True),
|
| 165 |
gr.Plot(label="Likelihood of all labels", show_label=True)],
|
| 166 |
title="Single Text Prediction",
|
| 167 |
+
description="**Threshold value:** The threshold value determines the minimum probability required for a label to be predicted. A higher threshold value will result in fewer labels being predicted, while a lower threshold value will result in more labels being predicted. The default threshold value is 0.3.",
|
| 168 |
+
article="**Note:** The quality of model predictions may depend on the quality of the information provided.")
|
| 169 |
|
| 170 |
|
| 171 |
# UPLOAD CSV
|
|
|
|
| 243 |
# Calculate the probabilities
|
| 244 |
predictions = torch.sigmoid(logits).squeeze()
|
| 245 |
|
| 246 |
+
# # Define the threshold for prediction
|
| 247 |
+
# threshold = 0.3
|
| 248 |
|
| 249 |
# Get the predicted labels
|
| 250 |
predicted_labels_ = (predictions.cpu().numpy() > threshold).tolist()
|
|
|
|
| 350 |
# Define the input component
|
| 351 |
file_input = gr.File(label="Upload CSV or Excel file here", show_label=True, file_types=[".csv", ".xls", ".xlsx"])
|
| 352 |
column_name_input = gr.Textbox(label="Enter the column name containing the text to be analyzed", show_label=True)
|
| 353 |
+
threshold_input = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.01, label="Threshold value (default=0.3)")
|
| 354 |
|
| 355 |
# Create the Gradio interface
|
| 356 |
iface3 = gr.Interface(fn=predict_from_csv,
|
| 357 |
+
inputs=[file_input, column_name_input,
|
| 358 |
+
gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.01, label="Threshold value (default=0.3)")],
|
| 359 |
outputs=gr.File(label='Download output CSV', show_label=True),
|
| 360 |
title="Multi-text Prediction",
|
| 361 |
+
description='''**Threshold value:** The threshold value determines the minimum probability required
|
| 362 |
+
for a label to be predicted. A higher threshold value will result in fewer labels being predicted,
|
| 363 |
+
while a lower threshold value will result in more labels being predicted. The default threshold value is 0.3''')
|
| 364 |
|
| 365 |
# Create a tabbed interface
|
| 366 |
demo = gr.TabbedInterface(interface_list=[iface1, iface2, iface3],
|