# Source: Hugging Face Space by openfree — "Update app.py" (commit 5e43dc8, verified)
import streamlit as st
import logging
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Logging setup: module-level logger reused by the loader and the chat handler below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Page setup. NOTE: st.set_page_config must be the first Streamlit command
# executed in the script, so keep this ordering.
st.set_page_config(
    page_title="DeepSeek Chatbot - ruslanmv.com",
    page_icon="πŸ€–",
    layout="centered"
)

# Initialize the chat history in session state so it survives Streamlit reruns.
if "messages" not in st.session_state:
    st.session_state.messages = []
# μ‚¬μ΄λ“œλ°” μ„€μ •
with st.sidebar:
st.header("Model Configuration")
st.markdown("λͺ¨λΈμ„ λ‘œμ»¬μ—μ„œ 직접 λ‘œλ“œν•©λ‹ˆλ‹€.")
# λͺ¨λΈ 선택 λ“œλ‘­λ‹€μš΄ (ν•„μš” μ‹œ λ‹€λ₯Έ λͺ¨λΈ μΆ”κ°€)
model_options = [
"deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
]
selected_model = st.selectbox("Select Model", model_options, index=0)
# μ‹œμŠ€ν…œ λ©”μ‹œμ§€ μ„€μ •
system_message = st.text_area(
"System Message",
value="You are a friendly chatbot created by ruslanmv.com. Provide clear, accurate, and brief answers. Keep responses polite, engaging, and to the point. If unsure, politely suggest alternatives.",
height=100
)
# 생성 νŒŒλΌλ―Έν„° μ„€μ •
max_tokens = st.slider("Max Tokens", 10, 4000, 1000)
temperature = st.slider("Temperature", 0.1, 4.0, 0.3)
top_p = st.slider("Top-p", 0.1, 1.0, 0.6)
# λͺ¨λΈκ³Ό ν† ν¬λ‚˜μ΄μ €λ₯Ό λ‘œλ“œν•˜λŠ” ν•¨μˆ˜ (μΊμ‹±ν•˜μ—¬ ν•œ 번만 λ‘œλ“œ)
@st.cache_resource
def load_model_and_tokenizer(model_name: str):
logger.info(f"Loading model and tokenizer for {model_name} ...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
# device_map="auto" μ˜΅μ…˜μ€ μ‚¬μš© κ°€λŠ₯ν•œ GPU/CPU에 맞게 λͺ¨λΈμ„ μžλ™μœΌλ‘œ ν• λ‹Ήν•©λ‹ˆλ‹€.
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
return tokenizer, model
# Materialize the selected model (cached after the first run).
tokenizer, model = load_model_and_tokenizer(selected_model)

# Main chat interface.
st.title("πŸ€– DeepSeek Chatbot")
st.caption("Powered by local model - Configure in sidebar")

# Replay the stored conversation so it is visible after each rerun.
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])
# Handle a new user message.
if prompt := st.chat_input("Type your message..."):
    # Record and echo the user's turn.
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    try:
        with st.spinner("Generating response..."):
            # Build the prompt from the system message plus the stored
            # conversation (which already ends with the new user turn), so the
            # model sees prior context instead of only the latest message.
            # On the first turn this is identical to the old single-turn prompt.
            turns = "\n".join(
                f"{'User' if m['role'] == 'user' else 'Assistant'}: {m['content']}"
                for m in st.session_state.messages
            )
            full_prompt = f"{system_message}\n\n{turns}\nAssistant:"

            # Tokenize via __call__ so an attention_mask is produced; generate()
            # needs it to be unambiguous when pad_token_id == eos_token_id.
            encoded = tokenizer(full_prompt, return_tensors="pt").to(model.device)

            # Generate the response with the sidebar sampling parameters.
            output_tokens = model.generate(
                **encoded,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )

            # Decode only the newly generated tokens. Slicing by input length is
            # robust even when the prompt or reply itself contains "Assistant:",
            # which broke the old split("Assistant:") extraction.
            new_tokens = output_tokens[0][encoded["input_ids"].shape[-1]:]
            assistant_response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

            logger.info(f"Generated response: {assistant_response}")

        # Display and persist the assistant's turn.
        with st.chat_message("assistant"):
            st.markdown(assistant_response)
        st.session_state.messages.append({"role": "assistant", "content": assistant_response})
    except Exception as e:
        # Top-level boundary: log with traceback and surface the error in the
        # UI instead of crashing the app.
        logger.error(f"Application Error: {str(e)}", exc_info=True)
        st.error(f"Application Error: {str(e)}")