# STRESS-ANALYSIS / app.py
# Last commit: ba2a27d "Update app.py" by shukdevdattaEX (verified)
import gradio as gr
from transformers import DistilBertTokenizer, TFDistilBertForSequenceClassification
import tensorflow as tf
# Hugging Face model repo ID; both the TF DistilBERT sequence classifier and
# its matching tokenizer are loaded from this same repo.
model_path = 'shukdevdatta123/Dreaddit_DistillBert_Stress_Model'

# Loaded once at import time so every prediction call reuses the same objects.
# Tuple elements are evaluated left to right: model first, then tokenizer.
loaded_model, loaded_tokenizer = (
    TFDistilBertForSequenceClassification.from_pretrained(model_path),
    DistilBertTokenizer.from_pretrained(model_path),
)
# Define the prediction function
def predict_with_loaded_model(in_sentence):
    """Classify *in_sentence* as "non-stress" or "stress" with the loaded model.

    Args:
        in_sentence: Text from the input textbox; may be empty, whitespace-only
            or None.

    Returns:
        A ``(message, scores)`` tuple feeding the two UI outputs. ``message`` is
        ``"Prediction: <label>"``. ``scores`` maps each label to its softmax
        probability. When there is no usable input, a prompt message and
        ``None`` are returned instead.
    """
    # Whitespace-only text carries no content; treat it like an empty box
    # rather than running the model on an effectively empty sequence.
    if not in_sentence or not in_sentence.strip():
        return "Please enter a sentence to predict.", None

    # Index i of the model's logits corresponds to labels[i].
    labels = ["non-stress", "stress"]
    inputs = loaded_tokenizer(
        [in_sentence],
        return_tensors="tf",
        padding=True,
        truncation=True,
        max_length=512,
    )
    logits = loaded_model(inputs).logits

    # Batch size is 1, so take row 0 of each result.
    predicted_index = int(tf.argmax(logits, axis=-1).numpy()[0])
    probabilities = tf.nn.softmax(logits, axis=-1).numpy()[0].tolist()

    # {label: probability}, the mapping shape a gr.Label output displays.
    result = dict(zip(labels, probabilities))
    return f"Prediction: {labels[predicted_index]}", result
# Gradio UI: an input column (textbox + Predict button) beside an output
# column (prediction text + per-class confidence scores).
with gr.Blocks(title="Stress Prediction with DistilBERT") as demo:
    gr.Markdown("# Stress Prediction with DistilBERT")
    gr.Markdown("Enter a sentence or text to predict whether it indicates stress or non-stress.")

    with gr.Row():
        # Left: user text and the button that triggers classification.
        with gr.Column():
            input_text = gr.Textbox(
                lines=5,
                placeholder="Type your text here...",
                label="Enter a sentence or text:",
            )
            predict_btn = gr.Button("Predict", variant="primary")

        # Right: the two components filled from the prediction function.
        with gr.Column():
            output_text = gr.Textbox(lines=2, label="Result")
            output_confidence = gr.Label(num_top_classes=2, label="Confidence Scores")

    # The prediction function returns (message, scores); the order of
    # `outputs` must line up with that tuple.
    predict_btn.click(
        fn=predict_with_loaded_model,
        outputs=[output_text, output_confidence],
        inputs=input_text,
    )

    # Clickable sample sentences that populate the input textbox.
    sample_sentences = [
        ["I feel overwhelmed with all the work I have to do."],
        ["I had a great day at the park with my friends."],
        ["The deadline is tomorrow and I haven't even started."],
    ]
    gr.Examples(examples=sample_sentences, inputs=input_text)

# Start the server only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()