new_website / app.py
kajonation's picture
initial commit
faeb6ce verified
import os
import warnings

import gradio as gr
import torch
from safetensors.torch import load_file
from transformers import BertForSequenceClassification, BertTokenizer
# ---------------------------------------------------------------------------
# Model loading: build IndoBERT with a 3-way classification head, then overlay
# the fine-tuned weights stored in a safetensors file.
# ---------------------------------------------------------------------------
MODEL_NAME = "indobenchmark/indobert-base-p2"

# Fine-tuned weights. The default is the Kaggle dataset path this app was
# developed against; set STRESS_MODEL_PATH to load the file from elsewhere
# (e.g. when the app is deployed outside Kaggle).
model_path = os.environ.get(
    "STRESS_MODEL_PATH",
    "/kaggle/input/model_12k/other/default/1/model (5).safetensors",
)
state_dict = load_file(model_path)

model = BertForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=3)
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)

# strict=False tolerates key-name differences between the saved file and the
# current model, but it also silently skips any weight that does not match.
# Surface what was skipped so a wrong/mismatched checkpoint is noticed instead
# of quietly leaving layers (e.g. the newly created classifier head) untrained.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    warnings.warn(
        "Checkpoint did not match the model exactly: "
        f"missing keys={load_result.missing_keys}, "
        f"unexpected keys={load_result.unexpected_keys}",
        stacklevel=1,
    )

# Inference only: switch off training-mode behaviour such as dropout.
model.eval()
# Display metadata per predicted class index: (level name, card colour, advice).
STRESS_LABELS = {
    0: ("Not Stress", "#8BC34A", "Currently you are not experiencing stress. Stay on top of your health!"),
    1: ("Mild Stress", "#FF7F00", "Saat ini anda sedang mengalami stres ringan. Luangkan waktu untuk relaksasi."),
    2: ("High Stress", "#F44336", "Currently you are experiencing high stress. Take time to relax."),
}


def _result_card(color, body_html):
    """Return *body_html* wrapped in the coloured, centred result card."""
    return (
        f"<div style='background-color:{color}; color:white; text-align:center; "
        "padding:15px; border-radius:10px; font-size:16px; height:200px; "
        f"width: 500px; margin:auto;'>{body_html}</div>"
    )


def detect_stress(input_text):
    """Classify *input_text* into a stress level and render it as HTML.

    Uses the module-level ``tokenizer`` and ``model``; input longer than
    128 tokens is truncated before classification.

    Args:
        input_text: Free text from the UI textbox. Empty, whitespace-only or
            ``None`` input is not sent to the model.

    Returns:
        An HTML ``<div>`` showing the predicted level and advice, or a prompt
        asking the user to enter text when there is nothing to classify.
    """
    # Classifying an empty string still yields a label, but it is meaningless.
    if not (input_text and input_text.strip()):
        return _result_card("#9E9E9E", "Please enter some text to analyze.")

    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True, max_length=128)
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class = torch.argmax(logits, dim=1).item()

    level, color, message = STRESS_LABELS[predicted_class]
    return _result_card(color, f"Level stres Anda: {level}<br>{message}")
# Alternative detect_stress for an SVM / ensemble-learning pipeline (uncomment to use; also requires `import joblib`)
# pipeline = joblib.load("/kaggle/input/svm_model/other/default/1/svm_hybrid_pipeline.pkl")
# def detect_stress(input_text):
# predicted_class = pipeline.predict([input_text])[0]
# probs = pipeline.predict_proba([input_text])[0]
# confidence = max(probs)
# labels = {
# 0: ("Not Stress", "#8BC34A", "Currently you are not experiencing stress. Stay on top of your health!"),
# 1: ("Mild Stress", "#FF7F00", "Saat ini anda sedang mengalami stres ringan. Luangkan waktu untuk relaksasi."),
# 2: ("High Stress", "#F44336", "Currently you are experiencing mild stress. Take time to relax.")
# }
# level, color, message = labels[predicted_class]
# return f"<div style='background-color:{color}; color:white; text-align:center; padding:15px; border-radius:10px; font-size:16px; heigth:200px; width: 500px; margin:auto;'>" \
# f"Level stress anda : {level}<br>{message}" \
# f"</div>"
# Page stylesheet passed to gr.Blocks(css=...).
# - body: theme-aware colours via CSS variables (defined per colour scheme at
#   the bottom); padding-top leaves room for the fixed #title banner.
# - #title: fixed banner stretched with left/right instead of `width: 100vw`,
#   which (plus padding / the vertical scrollbar) overflowed the viewport and
#   produced a horizontal scrollbar.
# - .button_detect: the height declaration was misspelled ("heigth") and
#   therefore ignored by browsers.
custom_css = """
body {
    margin: 0;
    padding: 0;
    padding-top: 80px;
    font-family: Arial, sans-serif;
    background-color: var(--background);
    color: var(--text);
    transition: background-color 0.3s, color 0.3s;
}
#title {
    position: fixed;
    top: 0;
    left: 0;
    right: 0;
    box-sizing: border-box;
    padding: 20px;
    background-color: #ff7a33;
    color: white;
    font-size: 28px;
    font-weight: bold;
    text-align: center;
    z-index: 1000;
}
#container {
    display: flex;
    flex-direction: column;
    align-items: center;
    justify-content: center;
    min-height: calc(100vh - 80px);
    padding: 20px;
}
textarea {
    background-color: var(--textarea-bg);
    color: var(--textarea-text);
    border: none;
    border-radius: 5px;
    padding: 10px;
    font-size: 16px;
    box-sizing: border-box;
    resize: none;
}
textarea:focus {
    outline: 2px solid #ff7a33;
}
.button_detect {
    background-color: #ff7a33;
    color: white;
    border: none;
    border-radius: 5px;
    padding: 15px 30px;
    font-size: 16px;
    cursor: pointer;
    margin-top: 10px;
    width: 200px;
    height: 100px;
}
.button_detect:hover {
    background-color: #e5662c;
}
@media (prefers-color-scheme: dark) {
    :root {
        --background: #121212;
        --text: white;
        --textarea-bg: #2c2c2c;
        --textarea-text: white;
    }
}
@media (prefers-color-scheme: light) {
    :root {
        --background: #ffffff;
        --text: black;
        --textarea-bg: #f0f0f0;
        --textarea-text: black;
    }
}
"""
# ---------------------------------------------------------------------------
# UI layout: a banner, then a column with the input box, the "Detect" button
# and an HTML area that receives the snippet returned by detect_stress().
# ---------------------------------------------------------------------------
with gr.Blocks(css=custom_css) as demo:
    # Banner on top; its positioning comes from the #title rule in custom_css.
    gr.HTML(value="<div id='title'>Stress Detector</div>")

    with gr.Column(elem_id="container"):
        input_text = gr.Textbox(
            lines=5,
            label="Input text",
            placeholder="Tell us your complaint here...",
        )
        btn_submit = gr.Button(value="Detect", elem_classes=["button_detect"])
        output_label = gr.HTML(label="Detection Results")

    # Wire the button: textbox value -> detect_stress -> HTML output area.
    btn_submit.click(
        fn=detect_stress,
        inputs=input_text,
        outputs=output_label,
    )

demo.launch()