namngo commited on
Commit
cfa5afc
·
verified ·
1 Parent(s): da6adda

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +42 -36
src/streamlit_app.py CHANGED
@@ -1,79 +1,85 @@
1
  import os
2
  import streamlit as st
3
- from transformers import DistilBertTokenizer, TFDistilBertModel
4
- from tensorflow.keras.models import load_model
5
  import numpy as np
6
  import tensorflow as tf
 
 
 
7
 
8
  # =======================
9
- # ĐẶT LẠI CACHE ĐỂ TRÁNH LỖI TRÊN SPACES
10
  # =======================
11
- os.environ['TRANSFORMERS_CACHE'] = './cache'
 
 
 
12
 
13
  # =======================
14
- # CẤU HÌNH
15
  # =======================
16
- MAX_LEN = 400
17
- MODEL_PATH = "src/model_Adam.h5"
18
- TOKENIZER_PATH = "src/"
19
  # =======================
20
- # LOAD TOKENIZER & TRANSFORMER
21
  # =======================
22
  @st.cache_resource
23
  def load_tokenizer():
24
  return DistilBertTokenizer.from_pretrained(TOKENIZER_PATH)
25
 
26
- @st.cache_resource
27
- def load_transformer():
28
- return TFDistilBertModel.from_pretrained(TOKENIZER_PATH)
29
-
30
- # =======================
31
- # ĐỊNH NGHĨA transformer_layer
32
- # =======================
33
- def transformer_layer(inputs):
34
- input_ids, attention_mask = inputs
35
- transformer = load_transformer()
36
- outputs = transformer(input_ids=input_ids, attention_mask=attention_mask)
37
- return outputs.last_hidden_state[:, 0, :]
38
 
39
  # =======================
40
- # LOAD MÔ HÌNH PHÂN LOẠI CẢM XÚC
41
  # =======================
42
  @st.cache_resource
43
- def load_sentiment_model():
44
- return load_model(MODEL_PATH, custom_objects={'transformer_layer': transformer_layer})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  # =======================
47
- # TIỀN XỬ LÝ VĂN BẢN
48
  # =======================
49
- tokenizer = load_tokenizer()
50
-
51
  def preprocess(text):
52
  tokens = tokenizer(
53
  text,
54
  max_length=MAX_LEN,
55
- padding='max_length',
56
  truncation=True,
57
- return_tensors='tf'
58
  )
59
  return {
60
- 'input_ids': tokens['input_ids'],
61
- 'attention_mask': tokens['attention_mask']
62
  }
63
 
64
  # =======================
65
- # STREAMLIT APP
66
  # =======================
67
- st.title("🎬 Sentiment Analysis Đánh giá phim")
68
 
69
- user_input = st.text_area("Nhập nội dung đánh giá phim của bạn:", height=150)
70
 
71
  if st.button("Dự đoán cảm xúc"):
72
  if not user_input.strip():
73
- st.warning("Vui lòng nhập nội dung trước khi dự đoán.")
74
  else:
75
  with st.spinner("Đang xử lý..."):
76
- model = load_sentiment_model()
77
  inputs = preprocess(user_input)
78
  prob = model.predict(inputs)[0][0]
79
  label = "TÍCH CỰC 😊" if prob >= 0.5 else "TIÊU CỰC 😞"
 
1
  import os
2
  import streamlit as st
 
 
3
  import numpy as np
4
  import tensorflow as tf
5
+ from tensorflow.keras.layers import Input, Lambda, Dense
6
+ from tensorflow.keras.models import Model
7
+ from transformers import DistilBertTokenizer, TFDistilBertModel
8
 
9
  # =======================
10
+ # CẤU HÌNH
11
  # =======================
12
+ MAX_LEN = 400
13
+ WEIGHTS_PATH = "src/model_Adam.h5"
14
+ TOKENIZER_PATH = "src"
15
+ CACHE_DIR = "./cache"
16
 
17
  # =======================
18
+ # TRÁNH LỖI GHI CACHE
19
  # =======================
20
+ os.environ['TRANSFORMERS_CACHE'] = CACHE_DIR
21
+
 
22
  # =======================
23
+ # TẢI TOKENIZER
24
  # =======================
25
  @st.cache_resource
26
  def load_tokenizer():
27
  return DistilBertTokenizer.from_pretrained(TOKENIZER_PATH)
28
 
29
+ tokenizer = load_tokenizer()
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  # =======================
32
+ # TẠO MÔ HÌNH (PHẢI GIỐNG KHI TRAIN)
33
  # =======================
34
  @st.cache_resource
35
+ def create_model_and_load_weights():
36
+ transformer = TFDistilBertModel.from_pretrained("distilbert-base-uncased", cache_dir=CACHE_DIR)
37
+
38
+ input_ids = Input(shape=(MAX_LEN,), dtype=tf.int32, name="input_ids")
39
+ attention_mask = Input(shape=(MAX_LEN,), dtype=tf.int32, name="attention_mask")
40
+
41
+ def transformer_layer(inputs):
42
+ ids, mask = inputs
43
+ outputs = transformer(input_ids=ids, attention_mask=mask)
44
+ return outputs.last_hidden_state[:, 0, :] # Lấy CLS token
45
+
46
+ cls_output = Lambda(transformer_layer)([input_ids, attention_mask])
47
+ output = Dense(1, activation='sigmoid')(cls_output)
48
+
49
+ model = Model(inputs=[input_ids, attention_mask], outputs=output)
50
+ model.load_weights(WEIGHTS_PATH)
51
+ return model
52
+
53
+ model = create_model_and_load_weights()
54
 
55
  # =======================
56
+ # TIỀN XỬ LÝ
57
  # =======================
 
 
58
  def preprocess(text):
59
  tokens = tokenizer(
60
  text,
61
  max_length=MAX_LEN,
62
+ padding="max_length",
63
  truncation=True,
64
+ return_tensors="tf"
65
  )
66
  return {
67
+ "input_ids": tokens["input_ids"],
68
+ "attention_mask": tokens["attention_mask"]
69
  }
70
 
71
  # =======================
72
+ # GIAO DIỆN STREAMLIT
73
  # =======================
74
+ st.title("🎬 Phân tích cảm xúc đánh giá phim")
75
 
76
+ user_input = st.text_area("Nhập đánh giá phim của bạn:", height=150)
77
 
78
  if st.button("Dự đoán cảm xúc"):
79
  if not user_input.strip():
80
+ st.warning("Vui lòng nhập nội dung.")
81
  else:
82
  with st.spinner("Đang xử lý..."):
 
83
  inputs = preprocess(user_input)
84
  prob = model.predict(inputs)[0][0]
85
  label = "TÍCH CỰC 😊" if prob >= 0.5 else "TIÊU CỰC 😞"