File size: 4,017 Bytes
94d49c9
 
23b096f
 
94d49c9
23b096f
94d49c9
 
 
23b096f
94d49c9
 
 
 
 
 
23b096f
94d49c9
 
 
23b096f
94d49c9
 
23b096f
 
 
94d49c9
5e43dc8
94d49c9
 
23b096f
 
94d49c9
 
7a16b3c
94d49c9
 
23b096f
 
8aa1427
23b096f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94d49c9
23b096f
94d49c9
23b096f
94d49c9
 
 
 
23b096f
94d49c9
 
23b096f
94d49c9
 
23b096f
94d49c9
 
23b096f
649e09f
23b096f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94d49c9
23b096f
 
 
 
 
 
 
 
 
 
94d49c9
 
2e44d20
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import streamlit as st
import logging
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Logging setup: module-level logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Streamlit page configuration (must run before any other st.* call).
st.set_page_config(
    layout="centered",
    page_icon="πŸ€–",
    page_title="DeepSeek Chatbot - ruslanmv.com",
)

# μ„Έμ…˜ μƒνƒœμ— μ±„νŒ… 기둝 μ΄ˆκΈ°ν™”
if "messages" not in st.session_state:
    st.session_state.messages = []

# μ‚¬μ΄λ“œλ°” μ„€μ •
with st.sidebar:
    st.header("Model Configuration")
    st.markdown("λͺ¨λΈμ„ λ‘œμ»¬μ—μ„œ 직접 λ‘œλ“œν•©λ‹ˆλ‹€.")
    
    # λͺ¨λΈ 선택 λ“œλ‘­λ‹€μš΄ (ν•„μš” μ‹œ λ‹€λ₯Έ λͺ¨λΈ μΆ”κ°€)
    model_options = [
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
    ]
    selected_model = st.selectbox("Select Model", model_options, index=0)
    
    # μ‹œμŠ€ν…œ λ©”μ‹œμ§€ μ„€μ •
    system_message = st.text_area(
        "System Message",
        value="You are a friendly chatbot created by ruslanmv.com. Provide clear, accurate, and brief answers. Keep responses polite, engaging, and to the point. If unsure, politely suggest alternatives.",
        height=100
    )
    
    # 생성 νŒŒλΌλ―Έν„° μ„€μ •
    max_tokens = st.slider("Max Tokens", 10, 4000, 1000)
    temperature = st.slider("Temperature", 0.1, 4.0, 0.3)
    top_p = st.slider("Top-p", 0.1, 1.0, 0.6)

# λͺ¨λΈκ³Ό ν† ν¬λ‚˜μ΄μ €λ₯Ό λ‘œλ“œν•˜λŠ” ν•¨μˆ˜ (μΊμ‹±ν•˜μ—¬ ν•œ 번만 λ‘œλ“œ)
@st.cache_resource
def load_model_and_tokenizer(model_name: str):
    logger.info(f"Loading model and tokenizer for {model_name} ...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # device_map="auto" μ˜΅μ…˜μ€ μ‚¬μš© κ°€λŠ₯ν•œ GPU/CPU에 맞게 λͺ¨λΈμ„ μžλ™μœΌλ‘œ ν• λ‹Ήν•©λ‹ˆλ‹€.
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
    return tokenizer, model

tokenizer, model = load_model_and_tokenizer(selected_model)

# Chat interface header.
st.title("πŸ€– DeepSeek Chatbot")
st.caption("Powered by local model - Configure in sidebar")

# Replay the conversation so far, one chat bubble per stored message.
for past in st.session_state.messages:
    role, content = past["role"], past["content"]
    with st.chat_message(role):
        st.markdown(content)

# μ‚¬μš©μž μž…λ ₯ 처리
if prompt := st.chat_input("Type your message..."):
    st.session_state.messages.append({"role": "user", "content": prompt})
    
    with st.chat_message("user"):
        st.markdown(prompt)
    
    try:
        with st.spinner("Generating response..."):
            # μ‹œμŠ€ν…œ λ©”μ‹œμ§€μ™€ μ‚¬μš©μž μž…λ ₯을 κ²°ν•©ν•˜μ—¬ 전체 ν”„λ‘¬ν”„νŠΈ ꡬ성
            full_prompt = f"{system_message}\n\nUser: {prompt}\nAssistant:"
            inputs = tokenizer.encode(full_prompt, return_tensors="pt").to(model.device)
            
            # λͺ¨λΈμ„ μ‚¬μš©ν•˜μ—¬ 응닡 생성
            output_tokens = model.generate(
                inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )
            
            # μƒμ„±λœ 토큰 λ””μ½”λ”©
            output_text = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
            # "Assistant:" μ΄ν›„μ˜ ν…μŠ€νŠΈλ§Œ μΆ”μΆœ (μ—†μœΌλ©΄ 전체 ν…μŠ€νŠΈ μ‚¬μš©)
            if "Assistant:" in output_text:
                assistant_response = output_text.split("Assistant:")[-1].strip()
            else:
                assistant_response = output_text.strip()
            
            logger.info(f"Generated response: {assistant_response}")
            
            # μƒμ„±λœ 응닡을 μ±„νŒ…μ— 좜λ ₯
            with st.chat_message("assistant"):
                st.markdown(assistant_response)
            
            st.session_state.messages.append({"role": "assistant", "content": assistant_response})
    
    except Exception as e:
        logger.error(f"Application Error: {str(e)}", exc_info=True)
        st.error(f"Application Error: {str(e)}")