MovieSentiment / src /streamlit_app.py
namngo's picture
Update src/streamlit_app.py
cfa5afc verified
import os
import streamlit as st
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Lambda, Dense
from tensorflow.keras.models import Model
from transformers import DistilBertTokenizer, TFDistilBertModel
# =======================
# CẤU HÌNH
# =======================
MAX_LEN = 400
WEIGHTS_PATH = "src/model_Adam.h5"
TOKENIZER_PATH = "src"
CACHE_DIR = "./cache"
# =======================
# TRÁNH LỖI GHI CACHE
# =======================
os.environ['TRANSFORMERS_CACHE'] = CACHE_DIR
# =======================
# TẢI TOKENIZER
# =======================
@st.cache_resource
def load_tokenizer():
return DistilBertTokenizer.from_pretrained(TOKENIZER_PATH)
tokenizer = load_tokenizer()
# =======================
# TẠO MÔ HÌNH (PHẢI GIỐNG KHI TRAIN)
# =======================
@st.cache_resource
def create_model_and_load_weights():
transformer = TFDistilBertModel.from_pretrained("distilbert-base-uncased", cache_dir=CACHE_DIR)
input_ids = Input(shape=(MAX_LEN,), dtype=tf.int32, name="input_ids")
attention_mask = Input(shape=(MAX_LEN,), dtype=tf.int32, name="attention_mask")
def transformer_layer(inputs):
ids, mask = inputs
outputs = transformer(input_ids=ids, attention_mask=mask)
return outputs.last_hidden_state[:, 0, :] # Lấy CLS token
cls_output = Lambda(transformer_layer)([input_ids, attention_mask])
output = Dense(1, activation='sigmoid')(cls_output)
model = Model(inputs=[input_ids, attention_mask], outputs=output)
model.load_weights(WEIGHTS_PATH)
return model
model = create_model_and_load_weights()
# =======================
# TIỀN XỬ LÝ
# =======================
def preprocess(text):
tokens = tokenizer(
text,
max_length=MAX_LEN,
padding="max_length",
truncation=True,
return_tensors="tf"
)
return {
"input_ids": tokens["input_ids"],
"attention_mask": tokens["attention_mask"]
}
# =======================
# GIAO DIỆN STREAMLIT
# =======================
st.title("🎬 Phân tích cảm xúc đánh giá phim")
user_input = st.text_area("Nhập đánh giá phim của bạn:", height=150)
if st.button("Dự đoán cảm xúc"):
if not user_input.strip():
st.warning("Vui lòng nhập nội dung.")
else:
with st.spinner("Đang xử lý..."):
inputs = preprocess(user_input)
prob = model.predict(inputs)[0][0]
label = "TÍCH CỰC 😊" if prob >= 0.5 else "TIÊU CỰC 😞"
confidence = float(prob) if prob >= 0.5 else 1 - float(prob)
st.markdown(f"### ✅ Dự đoán: **{label}**")
st.markdown(f"**Độ tin cậy:** {confidence:.2%}")