import time
import random

import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer

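# NOTE: set_page_config must be the first Streamlit command in the script.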
st.set_page_config(
    page_title="ChatGPT-124M",
    layout="wide",
    initial_sidebar_state="expanded",
    page_icon="🤖",
)

st.title("🤖 ChatGPT-124M")

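# Seed default settings once. These keys double as the widget keys below,
# so the sidebar controls read and write the same session-state values.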
if "messages" not in st.session_state:
    st.session_state.messages = []

if "max_length" not in st.session_state:
    st.session_state.max_length = 256

if "do_sample" not in st.session_state:
    st.session_state.do_sample = True

if "top_k" not in st.session_state:
    st.session_state.top_k = 20

if "top_p" not in st.session_state:
    st.session_state.top_p = 0.90

if "temperature" not in st.session_state:
    st.session_state.temperature = 0.9

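# Streamlit reruns this script top to bottom on every interaction,
# so replay the saved conversation before handling new input.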
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

MODEL_NAME = "GPT_124M"


@st.cache_resource
def load_model(name: str):
    """Load the tokenizer and model once and reuse them across reruns."""
    tokenizer = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True)
    return tokenizer, model


tokenizer, model = load_model(MODEL_NAME)


def string_to_generator(text):
    """Yields text one character at a time for a streaming effect."""
    for char in text:
        time.sleep(0.005)
        yield char


def format_prompt(prompt: str) -> str:
    """Wrap the user's message in the instruction/response template the
    model expects."""
    return f"""### Instruction:
{prompt.strip()}

### Response:
"""

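# For example, format_prompt("What is AI?") yields:
#   ### Instruction:
#   What is AI?
#
#   ### Response:
#
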
st.sidebar.header("⚙️ Generation Settings")

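# Note: max_length is the total budget for generate(), covering prompt
# plus completion tokens, so long prompts leave less room for the reply.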
max_length = st.sidebar.slider(
    "Max Length", min_value=1, max_value=512, key="max_length"
)

do_sample = st.sidebar.toggle("Enable Sampling", key="do_sample")

top_k = st.sidebar.slider(
    "Top-K", min_value=1, max_value=100, disabled=not do_sample, key="top_k"
)

top_p = st.sidebar.slider(
    "Top-P",
    min_value=0.0,
    max_value=1.0,
    step=0.01,
    disabled=not do_sample,
    key="top_p",
)

# Temperature must be strictly positive when sampling, so the slider
# floor is 0.01 rather than 0.0.
temperature = st.sidebar.slider(
    "Temperature",
    min_value=0.01,
    max_value=1.0,
    step=0.01,
    disabled=not do_sample,
    key="temperature",
)

# Drop the saved settings and history; the defaults block at the top
# re-creates them on the rerun triggered below.
if st.sidebar.button("Reset"):
    for st_key in [
        "messages",
        "do_sample",
        "max_length",
        "top_k",
        "top_p",
        "temperature",
    ]:
        del st.session_state[st_key]
    st.rerun()

loading_messages = [
    "Generating your response, please wait...",
    "Working on your response...",
    "Processing, this will just take a moment...",
    "Creating your response, hold on...",
    "Loading your answer, please be patient...",
]

if prompt := st.chat_input(
    "Type your question here…", max_chars=400, key="chat_input"
):
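    # Record the user's turn and echo it in the chat pane.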
    st.session_state.messages.append({"role": "user", "content": prompt})

    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        tokens = tokenizer.encode(format_prompt(prompt), return_tensors="pt")

        # When sampling is disabled, pass neutral values so transformers
        # does not warn about unused sampling flags during greedy decoding.
        with st.spinner(random.choice(loading_messages)):
            generated_tokens = model.generate(
                tokens,
                max_length=max_length,
                do_sample=do_sample,
                top_k=top_k if do_sample else 1,
                top_p=top_p if do_sample else 1.0,
                temperature=temperature if do_sample else 1.0,
            )

        # generate() returns prompt + completion; decode only the new tokens
        # from the first (and only) sequence in the batch.
        response_text = tokenizer.decode(
            generated_tokens[0, tokens.shape[1]:], skip_special_tokens=True
        ).strip()
        response = st.write_stream(string_to_generator(response_text))

    st.session_state.messages.append({"role": "assistant", "content": response})
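
# To launch the app (assuming this file is saved as app.py):
#   streamlit run app.py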
|