shukdevdattaEX commited on
Commit
ba2a27d
·
verified ·
1 Parent(s): 70103cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -35
app.py CHANGED
@@ -1,35 +1,65 @@
1
- import streamlit as st
2
- from transformers import DistilBertTokenizer, TFDistilBertForSequenceClassification
3
- import tensorflow as tf
4
-
5
- # Load the pre-trained model and tokenizer using the correct Hugging Face model repo ID
6
- model_path = 'shukdevdatta123/Dreaddit_DistillBert_Stress_Model'
7
- loaded_model = TFDistilBertForSequenceClassification.from_pretrained(model_path)
8
- loaded_tokenizer = DistilBertTokenizer.from_pretrained(model_path)
9
-
10
- # Define the prediction function
11
- def predict_with_loaded_model(in_sentences):
12
- labels = ["non-stress", "stress"]
13
- inputs = loaded_tokenizer(in_sentences, return_tensors="tf", padding=True, truncation=True, max_length=512)
14
- predictions = loaded_model(inputs)
15
- predicted_labels = tf.argmax(predictions.logits, axis=-1).numpy()
16
- predicted_probs = tf.nn.softmax(predictions.logits, axis=-1).numpy()
17
-
18
- return [{"text": sentence, "confidence": probs.tolist(), "label": labels[label]} for sentence, label, probs in zip(in_sentences, predicted_labels, predicted_probs)]
19
-
20
- # Streamlit interface
21
- st.title("Stress Prediction with DistilBERT")
22
-
23
- # Add a text input box for the user to enter a sentence
24
- user_input = st.text_area("Enter a sentence or text:", "")
25
-
26
- # When the user clicks "Predict", run the prediction function
27
- if st.button("Predict"):
28
- if user_input:
29
- # Make the prediction using the model
30
- prediction = predict_with_loaded_model([user_input])[0]
31
- st.write(f"Text: {prediction['text']}")
32
- st.write(f"Prediction: {prediction['label']}")
33
- # st.write(f"Confidence: {prediction['confidence']}")
34
- else:
35
- st.write("Please enter a sentence to predict.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import DistilBertTokenizer, TFDistilBertForSequenceClassification
3
+ import tensorflow as tf
4
+
5
+ # Load the pre-trained model and tokenizer using the correct Hugging Face model repo ID
6
+ model_path = 'shukdevdatta123/Dreaddit_DistillBert_Stress_Model'
7
+ loaded_model = TFDistilBertForSequenceClassification.from_pretrained(model_path)
8
+ loaded_tokenizer = DistilBertTokenizer.from_pretrained(model_path)
9
+
10
+ # Define the prediction function
11
+ def predict_with_loaded_model(in_sentence):
12
+ if not in_sentence:
13
+ return "Please enter a sentence to predict.", None
14
+
15
+ labels = ["non-stress", "stress"]
16
+ inputs = loaded_tokenizer([in_sentence], return_tensors="tf", padding=True, truncation=True, max_length=512)
17
+ predictions = loaded_model(inputs)
18
+ predicted_labels = tf.argmax(predictions.logits, axis=-1).numpy()
19
+ predicted_probs = tf.nn.softmax(predictions.logits, axis=-1).numpy()
20
+
21
+ label = labels[predicted_labels[0]]
22
+ confidence = predicted_probs[0].tolist()
23
+
24
+ # Create a dictionary for the label output
25
+ result = {labels[i]: confidence[i] for i in range(len(labels))}
26
+
27
+ return f"Prediction: {label}", result
28
+
29
+ # Create Gradio interface
30
+ with gr.Blocks(title="Stress Prediction with DistilBERT") as demo:
31
+ gr.Markdown("# Stress Prediction with DistilBERT")
32
+ gr.Markdown("Enter a sentence or text to predict whether it indicates stress or non-stress.")
33
+
34
+ with gr.Row():
35
+ with gr.Column():
36
+ input_text = gr.Textbox(
37
+ label="Enter a sentence or text:",
38
+ placeholder="Type your text here...",
39
+ lines=5
40
+ )
41
+ predict_btn = gr.Button("Predict", variant="primary")
42
+
43
+ with gr.Column():
44
+ output_text = gr.Textbox(label="Result", lines=2)
45
+ output_confidence = gr.Label(label="Confidence Scores", num_top_classes=2)
46
+
47
+ predict_btn.click(
48
+ fn=predict_with_loaded_model,
49
+ inputs=input_text,
50
+ outputs=[output_text, output_confidence]
51
+ )
52
+
53
+ # Add examples
54
+ gr.Examples(
55
+ examples=[
56
+ ["I feel overwhelmed with all the work I have to do."],
57
+ ["I had a great day at the park with my friends."],
58
+ ["The deadline is tomorrow and I haven't even started."]
59
+ ],
60
+ inputs=input_text
61
+ )
62
+
63
+ # Launch the app
64
+ if __name__ == "__main__":
65
+ demo.launch()