|
|
import gradio as gr |
|
|
from transformers import DistilBertTokenizer, TFDistilBertForSequenceClassification |
|
|
import tensorflow as tf |
|
|
|
|
|
|
|
|
# Identifier handed to ``from_pretrained`` for both the model and the
# tokenizer (presumably a Hugging Face Hub repo id -- confirm).
model_path = 'shukdevdatta123/Dreaddit_DistillBert_Stress_Model'

# Both objects are created once, at import time, and then reused by every
# call to the prediction function below.
loaded_model = TFDistilBertForSequenceClassification.from_pretrained(model_path)
loaded_tokenizer = DistilBertTokenizer.from_pretrained(model_path)
|
|
|
|
|
|
|
|
def predict_with_loaded_model(in_sentence):
    """Classify a piece of text as "stress" or "non-stress".

    The text is tokenized with the module-level ``loaded_tokenizer``, run
    through the module-level ``loaded_model``, and the resulting logits are
    turned into a predicted label plus per-class softmax probabilities.

    Args:
        in_sentence: Text to classify. ``None``, an empty string and a
            whitespace-only string are all treated as "no input".

    Returns:
        A ``(message, scores)`` tuple. ``message`` is
        ``"Prediction: <label>"``, or a prompt asking for input when there
        is nothing to classify. ``scores`` maps each label to its softmax
        probability, or is ``None`` when there is nothing to classify.
    """
    # Whitespace-only text carries nothing to classify; give it the same
    # prompt as empty input instead of running the model on it.
    if not in_sentence or not in_sentence.strip():
        return "Please enter a sentence to predict.", None

    # Output index i is reported as labels[i].
    # NOTE(review): assumes the model's class order is 0 = non-stress,
    # 1 = stress -- confirm against the model's config.
    labels = ["non-stress", "stress"]

    # Single-item batch; truncation is requested at max_length=512.
    inputs = loaded_tokenizer(
        [in_sentence],
        return_tensors="tf",
        padding=True,
        truncation=True,
        max_length=512,
    )
    logits = loaded_model(inputs).logits

    # Only one sentence was submitted, so take row 0 of each result.
    predicted_index = int(tf.argmax(logits, axis=-1).numpy()[0])
    probabilities = tf.nn.softmax(logits, axis=-1).numpy()[0].tolist()

    result = dict(zip(labels, probabilities))
    return f"Prediction: {labels[predicted_index]}", result
|
|
|
|
|
|
|
|
# Build the web UI: a two-column row (input on the left, results on the
# right), a click handler wiring the button to the predictor, and a few
# clickable example sentences.
with gr.Blocks(title="Stress Prediction with DistilBERT") as demo:
    gr.Markdown("# Stress Prediction with DistilBERT")
    gr.Markdown("Enter a sentence or text to predict whether it indicates stress or non-stress.")

    with gr.Row():
        # Input side: free-text box plus the button that triggers prediction.
        with gr.Column():
            input_text = gr.Textbox(
                label="Enter a sentence or text:",
                placeholder="Type your text here...",
                lines=5,
            )
            predict_btn = gr.Button("Predict", variant="primary")

        # Output side: the textual verdict and the per-label scores.
        with gr.Column():
            output_text = gr.Textbox(label="Result", lines=2)
            output_confidence = gr.Label(label="Confidence Scores", num_top_classes=2)

    # The predictor returns (message, scores), matching the two outputs
    # in order.
    predict_btn.click(
        fn=predict_with_loaded_model,
        inputs=input_text,
        outputs=[output_text, output_confidence],
    )

    # Sample sentences that populate the input box when selected.
    gr.Examples(
        examples=[
            ["I feel overwhelmed with all the work I have to do."],
            ["I had a great day at the park with my friends."],
            ["The deadline is tomorrow and I haven't even started."],
        ],
        inputs=input_text,
    )
|
|
|
|
|
|
|
|
# Launch the app only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    demo.launch()