# app.py import streamlit as st import torch from transformers import AutoTokenizer, AutoModelForCausalLM # --- Tối ưu hóa: Tải mô hình chỉ một lần duy nhất --- # @st.cache_resource là một "bảo bối" của Streamlit. # Nó đảm bảo hàm này chỉ chạy một lần khi ứng dụng khởi động. # Lần sau khi người dùng tương tác, mô hình đã có sẵn trong bộ nhớ. @st.cache_resource def load_model(): model_id = "SteveKGYang/MentaLLaMA-chat-7B" # Cấu hình Quantization 4-bit để giảm RAM bnb_config = { "load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16, "bnb_4bit_use_double_quant": False, } # Tải mô hình với cấu hình đã thiết lập model = AutoModelForCausalLM.from_pretrained( model_id, quantization_config=bnb_config, device_map="auto", # Tự động phân bổ lên các thiết bị có sẵn (CPU) trust_remote_code=True ) # Tải tokenizer tokenizer = AutoTokenizer.from_pretrained(model_id) return model, tokenizer # --- Hàm để chạy suy luận --- def run_inference(post_text, model, tokenizer): # Đây là cấu trúc prompt mà MentaLLaMA được huấn luyện để tuân theo. # Việc tuân thủ đúng prompt format là RẤT QUAN TRỌNG. prompt = f"""### Human: Analyze the following post and provide a diagnosis along with a detailed reasoning. Post: {post_text} ### Assistant:""" # Mã hóa prompt thành các token inputs = tokenizer(prompt, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu") # Chạy mô hình để tạo ra kết quả outputs = model.generate( **inputs, max_new_tokens=256, # Giới hạn độ dài của câu trả lời eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.eos_token_id ) # Giải mã kết quả từ token về lại văn bản result_text = tokenizer.decode(outputs[0], skip_special_tokens=True) # Xử lý chuỗi để chỉ lấy phần trả lời của Assistant # Kết quả thô sẽ có dạng: "### Human: ... ### Assistant: Answer: ... Reasoning: ..." # Chúng ta cần cắt bỏ phần prompt đi. try: assistant_response = result_text.split("### Assistant:")[1].strip() return assistant_response except IndexError: return "Lỗi: Không thể phân tích cú pháp đầu ra của mô hình." # --- Giao diện ứng dụng Streamlit --- st.set_page_config(layout="wide", page_title="MentaLLaMA Live Demo") st.title("🔬 MentaLLaMA Live Demo - Chạy mô hình thật") st.markdown(""" Chào mừng đến với bản demo **chạy thật** của **MentaLLaMA**! - **Nhập** một đoạn văn bản (bằng tiếng Anh) vào ô bên dưới. - **Nhấn nút** để mô hình MentaLLaMA-7B phân tích trực tiếp. - **Lưu ý:** Vì chạy trên CPU miễn phí, quá trình suy luận có thể mất **1-2 phút**. Vui lòng kiên nhẫn! """) # Tải mô hình (sẽ hiển thị thanh tiến trình lần đầu) with st.spinner("Đang tải mô hình MentaLLaMA-7B (lần đầu có thể mất vài phút)..."): model, tokenizer = load_model() st.success("Mô hình đã sẵn sàng!") # Ô nhập liệu cho người dùng user_input = st.text_area("Nhập bài đăng của bạn vào đây (tiếng Anh):", "Lately, I just feel so empty and numb inside. Nothing brings me joy anymore...", height=150) # Nút chạy suy luận if st.button('Chạy Phân Tích MentaLLaMA', type="primary"): if user_input: with st.spinner("MentaLLaMA đang suy luận, quá trình này có thể mất 1-2 phút..."): result = run_inference(user_input, model, tokenizer) st.subheader("Kết quả phân tích từ MentaLLaMA:") st.success(result) else: st.warning("Vui lòng nhập nội dung bài đăng.") # Footer st.markdown("---") st.markdown("Demo được xây dựng dựa trên bài báo: *MentaLLaMA: Interpretable Mental Health Analysis on Social Media with Large Language Models*.")