File size: 2,427 Bytes
ba2a27d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import gradio as gr
from transformers import DistilBertTokenizer, TFDistilBertForSequenceClassification
import tensorflow as tf

# Hugging Face Hub repo ID of the fine-tuned DistilBERT stress classifier.
# The classifier weights and the matching tokenizer are both loaded from this
# same repo when the module is imported.
model_path = "shukdevdatta123/Dreaddit_DistillBert_Stress_Model"
loaded_model = TFDistilBertForSequenceClassification.from_pretrained(
    model_path
)
loaded_tokenizer = DistilBertTokenizer.from_pretrained(
    model_path
)

# Define the prediction function
def predict_with_loaded_model(in_sentence):
    """Classify *in_sentence* as stress / non-stress with the loaded DistilBERT model.

    Args:
        in_sentence: Text entered by the user. ``None``, empty, or
            whitespace-only input is rejected without running the model.

    Returns:
        A 2-tuple ``(message, scores)``:

        - ``message`` is ``"Prediction: <label>"``, or a prompt asking for
          input when the text was rejected.
        - ``scores`` maps each label name to its softmax probability, or is
          ``None`` when the text was rejected.
    """
    # Whitespace-only text is rejected too. Otherwise it would reach the
    # model as an essentially empty token sequence and produce a
    # meaningless prediction.
    if not in_sentence or not in_sentence.strip():
        return "Please enter a sentence to predict.", None

    # The code assumes logit index i corresponds to labels[i].
    labels = ["non-stress", "stress"]

    # truncation + max_length keep long inputs within the 512-token window.
    inputs = loaded_tokenizer(
        [in_sentence],
        return_tensors="tf",
        padding=True,
        truncation=True,
        max_length=512,
    )
    logits = loaded_model(inputs).logits

    # Batch of one: take row 0 of the argmax and softmax results.
    predicted_index = int(tf.argmax(logits, axis=-1).numpy()[0])
    probabilities = tf.nn.softmax(logits, axis=-1).numpy()[0].tolist()

    # Label -> probability mapping for the confidence display.
    result = dict(zip(labels, probabilities))

    return f"Prediction: {labels[predicted_index]}", result

# ---------------------------------------------------------------------------
# Gradio user interface
# ---------------------------------------------------------------------------
# Shared by the browser-tab title and the page heading.
_APP_TITLE = "Stress Prediction with DistilBERT"

# Clickable sample inputs: one inner list per example, one value per input.
_EXAMPLE_SENTENCES = [
    ["I feel overwhelmed with all the work I have to do."],
    ["I had a great day at the park with my friends."],
    ["The deadline is tomorrow and I haven't even started."],
]

with gr.Blocks(title=_APP_TITLE) as demo:
    gr.Markdown(f"# {_APP_TITLE}")
    gr.Markdown(
        "Enter a sentence or text to predict whether it "
        "indicates stress or non-stress."
    )

    with gr.Row():
        # Left column: free-text input and the trigger button.
        with gr.Column():
            input_text = gr.Textbox(
                lines=5,
                label="Enter a sentence or text:",
                placeholder="Type your text here...",
            )
            predict_btn = gr.Button("Predict", variant="primary")

        # Right column: predicted label text and per-class scores.
        with gr.Column():
            output_text = gr.Textbox(lines=2, label="Result")
            output_confidence = gr.Label(
                num_top_classes=2,
                label="Confidence Scores",
            )

    # The prediction function returns (message, scores), matching these outputs.
    predict_btn.click(
        predict_with_loaded_model,
        inputs=input_text,
        outputs=[output_text, output_confidence],
    )

    gr.Examples(inputs=input_text, examples=_EXAMPLE_SENTENCES)

# Only start the server when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()