Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline | |
| # λͺ¨λΈ λ‘λ (DeepSeek-R1-Distill-Qwen-1.5B μμ) | |
| def load_model(model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"): | |
| pipe = pipeline( | |
| "text-generation", | |
| model=model_name, | |
| device_map="auto", | |
| torch_dtype=torch.float16, | |
| trust_remote_code=True, | |
| truncation=True, | |
| max_new_tokens=2048 | |
| ) | |
| return pipe | |
| # μ± μ€ν ν¨μ | |
| def main(): | |
| st.set_page_config(page_title="DeepSeek-R1 Chatbot", page_icon="π€") | |
| st.title("DeepSeek-R1 κΈ°λ° λνν μ±λ΄") | |
| st.write("DeepSeek-R1-Distill-Qwen-1.5B λͺ¨λΈμ νμ©ν λν ν μ€νΈμ© λ°λͺ¨μ λλ€.") | |
| # μΈμ μ€ν μ΄νΈ μ΄κΈ°ν | |
| if "chat_history_ids" not in st.session_state: | |
| st.session_state["chat_history_ids"] = None | |
| if "past_user_inputs" not in st.session_state: | |
| st.session_state["past_user_inputs"] = [] | |
| if "generated_responses" not in st.session_state: | |
| st.session_state["generated_responses"] = [] | |
| # λͺ¨λΈκ³Ό ν ν¬λμ΄μ λΆλ¬μ€κΈ° | |
| pipe = load_model() | |
| # κΈ°μ‘΄ λν λ΄μ νμ | |
| if st.session_state["past_user_inputs"]: | |
| for user_text, bot_text in zip(st.session_state["past_user_inputs"], st.session_state["generated_responses"]): | |
| # μ¬μ©μ λ©μμ§ | |
| with st.chat_message("user"): | |
| st.write(user_text) | |
| # λ΄ λ©μμ§ | |
| with st.chat_message("assistant"): | |
| st.write(bot_text) | |
| # μ±ν μ λ ₯μ°½ | |
| user_input = st.chat_input("λ©μμ§λ₯Ό μμ΄λ‘ μ λ ₯νμΈμ...") | |
| if user_input: | |
| # μ¬μ©μ λ©μμ§ νμ | |
| with st.chat_message("user"): | |
| st.write(user_input) | |
| # ν둬ννΈ μμ± | |
| prompt = f"Human: {user_input}\n\nAssistant:" | |
| # λͺ¨λΈ μμ± | |
| response = pipe( | |
| prompt, | |
| max_new_tokens=2048, | |
| temperature=0.7, | |
| do_sample=True, | |
| truncation=True, | |
| pad_token_id=50256 | |
| ) | |
| bot_text = response[0]["generated_text"] | |
| # Assistant μλ΅λ§ μΆμΆ (κ°μ λ λ°©μ) | |
| try: | |
| bot_text = bot_text.split("Assistant:")[-1].strip() | |
| if "</think>" in bot_text: # λ΄λΆ μ¬κ³ κ³Όμ μ κ±° | |
| bot_text = bot_text.split("</think>")[-1].strip() | |
| except: | |
| bot_text = "μ£μ‘ν©λλ€. μλ΅μ μμ±νλ λ° λ¬Έμ κ° λ°μνμ΅λλ€." | |
| # μΈμ μ€ν μ΄νΈμ λν λ΄μ© μ λ°μ΄νΈ | |
| st.session_state["past_user_inputs"].append(user_input) | |
| st.session_state["generated_responses"].append(bot_text) | |
| # λ΄ λ©μμ§ νμ | |
| with st.chat_message("assistant"): | |
| st.write(bot_text) | |
| if __name__ == "__main__": | |
| main() | |