import os import streamlit as st import numpy as np import tensorflow as tf from tensorflow.keras.layers import Input, Lambda, Dense from tensorflow.keras.models import Model from transformers import DistilBertTokenizer, TFDistilBertModel # ======================= # CẤU HÌNH # ======================= MAX_LEN = 400 WEIGHTS_PATH = "src/model_Adam.h5" TOKENIZER_PATH = "src" CACHE_DIR = "./cache" # ======================= # TRÁNH LỖI GHI CACHE # ======================= os.environ['TRANSFORMERS_CACHE'] = CACHE_DIR # ======================= # TẢI TOKENIZER # ======================= @st.cache_resource def load_tokenizer(): return DistilBertTokenizer.from_pretrained(TOKENIZER_PATH) tokenizer = load_tokenizer() # ======================= # TẠO MÔ HÌNH (PHẢI GIỐNG KHI TRAIN) # ======================= @st.cache_resource def create_model_and_load_weights(): transformer = TFDistilBertModel.from_pretrained("distilbert-base-uncased", cache_dir=CACHE_DIR) input_ids = Input(shape=(MAX_LEN,), dtype=tf.int32, name="input_ids") attention_mask = Input(shape=(MAX_LEN,), dtype=tf.int32, name="attention_mask") def transformer_layer(inputs): ids, mask = inputs outputs = transformer(input_ids=ids, attention_mask=mask) return outputs.last_hidden_state[:, 0, :] # Lấy CLS token cls_output = Lambda(transformer_layer)([input_ids, attention_mask]) output = Dense(1, activation='sigmoid')(cls_output) model = Model(inputs=[input_ids, attention_mask], outputs=output) model.load_weights(WEIGHTS_PATH) return model model = create_model_and_load_weights() # ======================= # TIỀN XỬ LÝ # ======================= def preprocess(text): tokens = tokenizer( text, max_length=MAX_LEN, padding="max_length", truncation=True, return_tensors="tf" ) return { "input_ids": tokens["input_ids"], "attention_mask": tokens["attention_mask"] } # ======================= # GIAO DIỆN STREAMLIT # ======================= st.title("🎬 Phân tích cảm xúc đánh giá phim") user_input = st.text_area("Nhập đánh giá phim của bạn:", height=150) if st.button("Dự đoán cảm xúc"): if not user_input.strip(): st.warning("Vui lòng nhập nội dung.") else: with st.spinner("Đang xử lý..."): inputs = preprocess(user_input) prob = model.predict(inputs)[0][0] label = "TÍCH CỰC 😊" if prob >= 0.5 else "TIÊU CỰC 😞" confidence = float(prob) if prob >= 0.5 else 1 - float(prob) st.markdown(f"### ✅ Dự đoán: **{label}**") st.markdown(f"**Độ tin cậy:** {confidence:.2%}")