mentallama-demo / src /streamlit_app.py
machinelearnAn's picture
Update src/streamlit_app.py
4b7d152 verified
# app.py
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# --- Tối ưu hóa: Tải mô hình chỉ một lần duy nhất ---
# @st.cache_resource là một "bảo bối" của Streamlit.
# Nó đảm bảo hàm này chỉ chạy một lần khi ứng dụng khởi động.
# Lần sau khi người dùng tương tác, mô hình đã có sẵn trong bộ nhớ.
@st.cache_resource
def load_model():
model_id = "SteveKGYang/MentaLLaMA-chat-7B"
# Cấu hình Quantization 4-bit để giảm RAM
bnb_config = {
"load_in_4bit": True,
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_compute_dtype": torch.bfloat16,
"bnb_4bit_use_double_quant": False,
}
# Tải mô hình với cấu hình đã thiết lập
model = AutoModelForCausalLM.from_pretrained(
model_id,
quantization_config=bnb_config,
device_map="auto", # Tự động phân bổ lên các thiết bị có sẵn (CPU)
trust_remote_code=True
)
# Tải tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
return model, tokenizer
# --- Hàm để chạy suy luận ---
def run_inference(post_text, model, tokenizer):
# Đây là cấu trúc prompt mà MentaLLaMA được huấn luyện để tuân theo.
# Việc tuân thủ đúng prompt format là RẤT QUAN TRỌNG.
prompt = f"""### Human: Analyze the following post and provide a diagnosis along with a detailed reasoning. Post: {post_text} ### Assistant:"""
# Mã hóa prompt thành các token
inputs = tokenizer(prompt, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
# Chạy mô hình để tạo ra kết quả
outputs = model.generate(
**inputs,
max_new_tokens=256, # Giới hạn độ dài của câu trả lời
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.eos_token_id
)
# Giải mã kết quả từ token về lại văn bản
result_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Xử lý chuỗi để chỉ lấy phần trả lời của Assistant
# Kết quả thô sẽ có dạng: "### Human: ... ### Assistant: Answer: ... Reasoning: ..."
# Chúng ta cần cắt bỏ phần prompt đi.
try:
assistant_response = result_text.split("### Assistant:")[1].strip()
return assistant_response
except IndexError:
return "Lỗi: Không thể phân tích cú pháp đầu ra của mô hình."
# --- Giao diện ứng dụng Streamlit ---
st.set_page_config(layout="wide", page_title="MentaLLaMA Live Demo")
st.title("🔬 MentaLLaMA Live Demo - Chạy mô hình thật")
st.markdown("""
Chào mừng đến với bản demo **chạy thật** của **MentaLLaMA**!
- **Nhập** một đoạn văn bản (bằng tiếng Anh) vào ô bên dưới.
- **Nhấn nút** để mô hình MentaLLaMA-7B phân tích trực tiếp.
- **Lưu ý:** Vì chạy trên CPU miễn phí, quá trình suy luận có thể mất **1-2 phút**. Vui lòng kiên nhẫn!
""")
# Tải mô hình (sẽ hiển thị thanh tiến trình lần đầu)
with st.spinner("Đang tải mô hình MentaLLaMA-7B (lần đầu có thể mất vài phút)..."):
model, tokenizer = load_model()
st.success("Mô hình đã sẵn sàng!")
# Ô nhập liệu cho người dùng
user_input = st.text_area("Nhập bài đăng của bạn vào đây (tiếng Anh):", "Lately, I just feel so empty and numb inside. Nothing brings me joy anymore...", height=150)
# Nút chạy suy luận
if st.button('Chạy Phân Tích MentaLLaMA', type="primary"):
if user_input:
with st.spinner("MentaLLaMA đang suy luận, quá trình này có thể mất 1-2 phút..."):
result = run_inference(user_input, model, tokenizer)
st.subheader("Kết quả phân tích từ MentaLLaMA:")
st.success(result)
else:
st.warning("Vui lòng nhập nội dung bài đăng.")
# Footer
st.markdown("---")
st.markdown("Demo được xây dựng dựa trên bài báo: *MentaLLaMA: Interpretable Mental Health Analysis on Social Media with Large Language Models*.")