File size: 1,842 Bytes
a82c32b
aaa7c43
9916019
02e1881
aaa7c43
da5cac2
e3b157b
1ce81ff
a82c32b
da5cac2
 
 
 
 
aaa7c43
9916019
aaa7c43
a82c32b
 
da5cac2
02e1881
da5cac2
02e1881
 
 
da5cac2
 
02e1881
da5cac2
02e1881
a82c32b
 
 
aaa7c43
9916019
 
a82c32b
 
 
 
 
 
 
9916019
aaa7c43
9916019
 
aaa7c43
9916019
 
 
 
 
aaa7c43
9916019
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# --- Use a dedicated cache directory to avoid PermissionError ---
# NOTE(review): transformers is already imported above, so these variables
# may be read too late to take effect — confirm, or set them before the
# import (HF_HOME is the current variable; TRANSFORMERS_CACHE is legacy).
os.environ["TRANSFORMERS_CACHE"] = "./hf_cache"
os.environ["HF_HOME"] = "./hf_cache"

st.title("🤖 Fine-tuned Qwen3 Chatbot")

# --- Model paths ---
BASE_MODEL = "unsloth/Qwen3-4B-Instruct-2507"  # tokenizer source (base checkpoint)
FINE_TUNED = "phuphan1310/Fine-tuned-model-test"  # fine-tuned weights

# Device the tokenized inputs are moved to; the model itself is placed
# by device_map="auto" in load_model().
device = "cuda" if torch.cuda.is_available() else "cpu"

@st.cache_resource(show_spinner=True)
def load_model():
    """Load and cache the tokenizer and fine-tuned model (once per session).

    The tokenizer is taken from the base Unsloth checkpoint because the
    fine-tuned repo's tokenizer has a broken format; the weights come from
    the fine-tuned checkpoint.
    """
    tok = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
    # Half precision on GPU, full precision on CPU.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    mdl = AutoModelForCausalLM.from_pretrained(
        FINE_TUNED,
        trust_remote_code=True,
        torch_dtype=dtype,
        device_map="auto",
    )
    return tok, mdl

# Fetch the (cached) tokenizer and model; downloads weights on first run.
tokenizer, model = load_model()

def generate_response(prompt):
    """Generate a sampled reply for *prompt* and return only the new text.

    model.generate returns prompt tokens + completion tokens; the prompt
    portion is sliced off before decoding so the reply does not echo the
    user's input back at them.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # inference_mode: no autograd bookkeeping during generation.
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=200,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
        )
    # Keep only the newly generated tokens (everything after the prompt).
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

# --- Chat history state ---
if "messages" not in st.session_state:
    st.session_state.messages = []

user_input = st.text_input("Enter your message:")
# Streamlit reruns the whole script on every widget interaction while
# text_input keeps its value, so without a guard the same message would be
# re-appended (and regenerated) on every rerun. Only process a message the
# first time it is seen.
if user_input and user_input != st.session_state.get("last_input"):
    st.session_state.last_input = user_input
    st.session_state.messages.append({"role": "user", "content": user_input})
    response = generate_response(user_input)
    st.session_state.messages.append({"role": "assistant", "content": response})

# Render the full conversation, oldest first.
for msg in st.session_state.messages:
    if msg["role"] == "user":
        st.markdown(f"**You:** {msg['content']}")
    else:
        st.markdown(f"**Bot:** {msg['content']}")