import os
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# --- Use a dedicated cache directory to avoid PermissionError ---
os.environ["TRANSFORMERS_CACHE"] = "./hf_cache"
os.environ["HF_HOME"] = "./hf_cache"
st.title("🤖 Fine-tuned Qwen3 Chatbot")
# --- Model paths (Hugging Face Hub repo ids) ---
BASE_MODEL = "unsloth/Qwen3-4B-Instruct-2507"  # base model; its tokenizer is used (see load_model)
FINE_TUNED = "phuphan1310/Fine-tuned-model-test"  # fine-tuned weights loaded for generation
device = "cuda" if torch.cuda.is_available() else "cpu"  # device the input tensors are moved to
@st.cache_resource(show_spinner=True)
def load_model():
    """Load the tokenizer and fine-tuned model once per Streamlit session.

    The tokenizer is taken from the base (Unsloth) model because the
    fine-tuned repo's tokenizer has a broken format (per original note).
    Returns a ``(tokenizer, model)`` tuple.
    """
    # Half precision only when a GPU is present; CPU stays in float32.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    tok = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
    mdl = AutoModelForCausalLM.from_pretrained(
        FINE_TUNED,
        trust_remote_code=True,
        torch_dtype=dtype,
        device_map="auto",  # let accelerate place layers automatically
    )
    return tok, mdl
# Loaded once at startup; st.cache_resource makes this a no-op on reruns.
tokenizer, model = load_model()
def generate_response(prompt):
    """Generate a reply for *prompt* and return only the new text.

    Fixes: the original decoded ``outputs[0]`` in full, which includes the
    echoed prompt tokens, so the user's own message appeared inside every
    bot reply. We slice off the prompt before decoding. Generation also
    runs under ``torch.no_grad()`` — inference needs no autograd state.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=200,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
        )
    # generate() returns prompt + completion; keep only the completion.
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
# Persist chat history across Streamlit reruns.
if "messages" not in st.session_state:
    st.session_state.messages = []

user_input = st.text_input("Enter your message:")

# BUG FIX: st.text_input keeps its value across reruns, so the original code
# re-appended the same message and re-ran generation on every interaction.
# Only process an input we have not already handled.
if user_input and user_input != st.session_state.get("last_input"):
    st.session_state.last_input = user_input
    st.session_state.messages.append({"role": "user", "content": user_input})
    response = generate_response(user_input)
    st.session_state.messages.append({"role": "assistant", "content": response})

# Render the full conversation, oldest first.
for msg in st.session_state.messages:
    if msg["role"] == "user":
        st.markdown(f"**You:** {msg['content']}")
    else:
        st.markdown(f"**Bot:** {msg['content']}")
|