import gradio as gr
from transformers import DistilBertTokenizer, TFDistilBertForSequenceClassification
import tensorflow as tf

# Load the pre-trained model and tokenizer using the correct Hugging Face model repo ID
model_path = 'shukdevdatta123/Dreaddit_DistillBert_Stress_Model'
loaded_model = TFDistilBertForSequenceClassification.from_pretrained(model_path)
loaded_tokenizer = DistilBertTokenizer.from_pretrained(model_path)

# Class names, indexed by the argmax over the model's logits.
# NOTE(review): assumes logit 0 = non-stress and logit 1 = stress; confirm
# against the model's training setup / config.id2label.
LABELS = ["non-stress", "stress"]

# Token limit passed to the tokenizer; longer inputs are truncated.
MAX_LENGTH = 512


# Define the prediction function
def predict_with_loaded_model(in_sentence):
    """Classify a piece of text as stress / non-stress.

    Args:
        in_sentence: Text from the input textbox; may be None, empty, or
            whitespace-only.

    Returns:
        A ``(message, scores)`` tuple. ``message`` is the string shown in the
        result textbox. ``scores`` maps each label in ``LABELS`` to its softmax
        probability (fed to the ``gr.Label`` output), or is ``None`` when there
        is no text to classify.
    """
    # None, "" and whitespace-only input carry nothing to classify; without the
    # strip() check, "   " would be tokenized to special tokens only and still
    # receive a (meaningless) prediction.
    if not in_sentence or not in_sentence.strip():
        return "Please enter a sentence to predict.", None

    inputs = loaded_tokenizer(
        [in_sentence],
        return_tensors="tf",
        padding=True,
        truncation=True,
        max_length=MAX_LENGTH,
    )
    logits = loaded_model(inputs).logits

    # The batch holds a single sentence, so only row 0 is used.
    predicted_index = int(tf.argmax(logits, axis=-1).numpy()[0])
    probabilities = tf.nn.softmax(logits, axis=-1).numpy()[0].tolist()

    # Dictionary for the gr.Label output: label -> probability.
    scores = dict(zip(LABELS, probabilities))
    return f"Prediction: {LABELS[predicted_index]}", scores


# Create Gradio interface
with gr.Blocks(title="Stress Prediction with DistilBERT") as demo:
    gr.Markdown("# Stress Prediction with DistilBERT")
    gr.Markdown("Enter a sentence or text to predict whether it indicates stress or non-stress.")

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="Enter a sentence or text:",
                placeholder="Type your text here...",
                lines=5,
            )
            predict_btn = gr.Button("Predict", variant="primary")

        with gr.Column():
            output_text = gr.Textbox(label="Result", lines=2)
            # Show a score for every class the model can predict.
            output_confidence = gr.Label(label="Confidence Scores", num_top_classes=len(LABELS))

    predict_btn.click(
        fn=predict_with_loaded_model,
        inputs=input_text,
        outputs=[output_text, output_confidence],
    )

    # Add examples
    gr.Examples(
        examples=[
            ["I feel overwhelmed with all the work I have to do."],
            ["I had a great day at the park with my friends."],
            ["The deadline is tomorrow and I haven't even started."],
        ],
        inputs=input_text,
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()