The-Adnan-Syed commited on
Commit
3097f15
·
verified ·
1 Parent(s): a1a7b94

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -0
app.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import tensorflow as tf
3
+ from transformers import BertTokenizer, TFBertForSequenceClassification
4
+ from tensorflow.keras.preprocessing.sequence import pad_sequences
5
+
6
+ # Load the BERT tokenizer and model
7
+ tokenizer = BertTokenizer.from_pretrained('config.json') # Path to folder containing config.json
8
+ model = TFBertForSequenceClassification.from_pretrained('tf_model.h5', from_pt=True) # Path to folder containing tf_model.h5
9
+
10
+ def predict(text):
11
+ # Encode the text using the BERT tokenizer
12
+ input_ids = tokenizer.encode(text, add_special_tokens=True, max_length=128, truncation=True)
13
+ input_ids = pad_sequences([input_ids], maxlen=128, truncating='post', padding='post')
14
+
15
+ # Convert to tensors
16
+ input_ids = tf.convert_to_tensor(input_ids)
17
+
18
+ # Get predictions
19
+ logits = model(input_ids)[0]
20
+
21
+ # Apply softmax to calculate probabilities
22
+ probabilities = tf.nn.softmax(logits, axis=1).numpy()[0]
23
+
24
+ return probabilities
25
+
26
+ # Streamlit UI
27
+ st.title("Stress Categorization with BERT")
28
+ st.write("Enter the text to analyze for stress levels:")
29
+
30
+ # Text input
31
+ user_input = st.text_area("Text", height=150)
32
+
33
+ if st.button("Predict"):
34
+ # Make prediction
35
+ probabilities = predict(user_input)
36
+
37
+ # Display probabilities
38
+ st.write("Probabilities:")
39
+ st.write(f"Stressed: {probabilities[1]:.4f}")
40
+ st.write(f"Not Stressed: {probabilities[0]:.4f}")
41
+
42
+ # Display the most likely class
43
+ if probabilities[0] > probabilities[1]:
44
+ st.success("Prediction: Not Stressed")
45
+ else:
46
+ st.error("Prediction: Stressed")
47
+
48
+ # Assuming you have an accuracy metric available (replace with actual accuracy if available)
49
+ accuracy = 0.95 # Example accuracy
50
+ st.write(f"Model Accuracy: {accuracy * 100:.2f}%")