File size: 4,231 Bytes
bce5f7b
aab52f7
4b7d152
 
aab52f7
4b7d152
 
 
 
 
 
 
 
 
 
 
 
 
 
bce5f7b
 
4b7d152
 
 
 
 
 
 
bce5f7b
4b7d152
 
bce5f7b
4b7d152
bce5f7b
4b7d152
 
 
 
 
bce5f7b
4b7d152
 
bce5f7b
4b7d152
 
 
 
 
 
 
bce5f7b
4b7d152
 
bce5f7b
4b7d152
 
 
 
 
 
 
 
bce5f7b
4b7d152
 
bce5f7b
4b7d152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bce5f7b
4b7d152
 
 
 
bce5f7b
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# app.py
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# --- Tối ưu hóa: Tải mô hình chỉ một lần duy nhất ---
# @st.cache_resource là một "bảo bối" của Streamlit.
# Nó đảm bảo hàm này chỉ chạy một lần khi ứng dụng khởi động.
# Lần sau khi người dùng tương tác, mô hình đã có sẵn trong bộ nhớ.
@st.cache_resource
def load_model():
    model_id = "SteveKGYang/MentaLLaMA-chat-7B"
    
    # Cấu hình Quantization 4-bit để giảm RAM
    bnb_config = {
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": torch.bfloat16,
        "bnb_4bit_use_double_quant": False,
    }

    # Tải mô hình với cấu hình đã thiết lập
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=bnb_config,
        device_map="auto", # Tự động phân bổ lên các thiết bị có sẵn (CPU)
        trust_remote_code=True
    )

    # Tải tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    return model, tokenizer

# --- Hàm để chạy suy luận ---
def run_inference(post_text, model, tokenizer):
    # Đây là cấu trúc prompt mà MentaLLaMA được huấn luyện để tuân theo.
    # Việc tuân thủ đúng prompt format là RẤT QUAN TRỌNG.
    prompt = f"""### Human: Analyze the following post and provide a diagnosis along with a detailed reasoning. Post: {post_text} ### Assistant:"""

    # Mã hóa prompt thành các token
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")

    # Chạy mô hình để tạo ra kết quả
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,  # Giới hạn độ dài của câu trả lời
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id
    )

    # Giải mã kết quả từ token về lại văn bản
    result_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    # Xử lý chuỗi để chỉ lấy phần trả lời của Assistant
    # Kết quả thô sẽ có dạng: "### Human: ... ### Assistant: Answer: ... Reasoning: ..."
    # Chúng ta cần cắt bỏ phần prompt đi.
    try:
        assistant_response = result_text.split("### Assistant:")[1].strip()
        return assistant_response
    except IndexError:
        return "Lỗi: Không thể phân tích cú pháp đầu ra của mô hình."

# --- Giao diện ứng dụng Streamlit ---
st.set_page_config(layout="wide", page_title="MentaLLaMA Live Demo")

st.title("🔬 MentaLLaMA Live Demo - Chạy mô hình thật")
st.markdown("""
Chào mừng đến với bản demo **chạy thật** của **MentaLLaMA**! 
- **Nhập** một đoạn văn bản (bằng tiếng Anh) vào ô bên dưới.
- **Nhấn nút** để mô hình MentaLLaMA-7B phân tích trực tiếp.
- **Lưu ý:** Vì chạy trên CPU miễn phí, quá trình suy luận có thể mất **1-2 phút**. Vui lòng kiên nhẫn!
""")

# Tải mô hình (sẽ hiển thị thanh tiến trình lần đầu)
with st.spinner("Đang tải mô hình MentaLLaMA-7B (lần đầu có thể mất vài phút)..."):
    model, tokenizer = load_model()

st.success("Mô hình đã sẵn sàng!")

# Ô nhập liệu cho người dùng
user_input = st.text_area("Nhập bài đăng của bạn vào đây (tiếng Anh):", "Lately, I just feel so empty and numb inside. Nothing brings me joy anymore...", height=150)

# Nút chạy suy luận
if st.button('Chạy Phân Tích MentaLLaMA', type="primary"):
    if user_input:
        with st.spinner("MentaLLaMA đang suy luận, quá trình này có thể mất 1-2 phút..."):
            result = run_inference(user_input, model, tokenizer)
        
        st.subheader("Kết quả phân tích từ MentaLLaMA:")
        st.success(result)
    else:
        st.warning("Vui lòng nhập nội dung bài đăng.")

# Footer
st.markdown("---")
st.markdown("Demo được xây dựng dựa trên bài báo: *MentaLLaMA: Interpretable Mental Health Analysis on Social Media with Large Language Models*.")