"""
Step-by-Step Guide to Building This LLM:
Medium Article: https://medium.com/@fareedkhandev/building-a-perfect-million-parameter-llm-from-scratch-in-python-3b16e26b4139
"""
import random
import re
import time
import numpy as np
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
st.set_page_config(page_title="30M-SFT-LLM", initial_sidebar_state="collapsed")
# Custom CSS to style buttons and layout
# (the original stylesheet was stripped from this listing; the rules below
# are a minimal placeholder that hides Streamlit's default menu and footer)
st.markdown("""
<style>
    #MainMenu {visibility: hidden;}
    footer {visibility: hidden;}
</style>
""", unsafe_allow_html=True)
# Model Configuration
# NOTE: model_path, avatar_url, and slogan were not defined in this snippet;
# the values below are placeholders -- point model_path at your own checkpoint.
model_path = "./model/30M-SFT-LLM"  # local directory of the fine-tuned model (placeholder)
avatar_url = "🤖"  # st.chat_message also accepts an emoji as the avatar (placeholder)
slogan = "Hi, I'm 30M-SFT-LLM"  # header slogan (placeholder)
system_prompt = []  # optional system messages prepended to every conversation
device = "cuda" if torch.cuda.is_available() else "cpu"
# Function to process assistant responses
def format_assistant_response(content):
    # The original regex was truncated in this listing; this reconstruction
    # collapses runs of blank lines so partially streamed markdown renders cleanly.
    content = re.sub(r'\n{3,}', '\n\n', content)
    return content.strip()

# Page header: slogan plus a pointer to the tutorial article
# (the original header HTML was stripped; this is a plain reconstruction)
st.markdown(f"""
<div style="text-align: center;">
    <h2>{slogan}</h2>
    <p>You can create your own 30 Million Parameter LLM using my Medium article.</p>
</div>
""", unsafe_allow_html=True)
def main():
    model, tokenizer = load_model_tokenizer(model_path)
    if "messages" not in st.session_state:
        st.session_state.messages = []       # full transcript for rendering
        st.session_state.chat_messages = []  # history actually sent to the model
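    # Sampling/history settings. The original sidebar controls were omitted
    # from this snippet, so these defaults are assumptions -- tune as needed.
    st.session_state.setdefault("history_chat_num", 6)
    st.session_state.setdefault("max_new_tokens", 512)
    st.session_state.setdefault("temperature", 0.85)
    st.session_state.setdefault("top_p", 0.85)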
    # Replay the conversation so far; each assistant turn gets a delete
    # button that removes the whole user/assistant exchange.
    for i, message in enumerate(st.session_state.messages):
        if message["role"] == "assistant":
            with st.chat_message("assistant", avatar=avatar_url):
                st.markdown(format_assistant_response(message["content"]), unsafe_allow_html=True)
                if st.button("🗑", key=f"delete_{i}"):
                    # Drop this reply and the user message that prompted it.
                    st.session_state.messages = st.session_state.messages[:i - 1]
                    st.session_state.chat_messages = st.session_state.chat_messages[:i - 1]
                    st.rerun()
        else:
            # The original right-aligned HTML bubble was stripped from this
            # listing; this is a plain reconstruction of the user turn.
            st.markdown(f'<div style="text-align: right;"><span style="display: inline-block; '
                        f'background: #ddd; border-radius: 10px; padding: 6px 12px;">{message["content"]}</span></div>',
                        unsafe_allow_html=True)
    user_input = st.chat_input(placeholder="Send a message to 30M-SFT-LLM")
    if user_input:
        # Echo the new user turn (same reconstructed bubble as above).
        st.markdown(f'<div style="text-align: right;"><span style="display: inline-block; '
                    f'background: #ddd; border-radius: 10px; padding: 6px 12px;">{user_input}</span></div>',
                    unsafe_allow_html=True)
        st.session_state.messages.append({"role": "user", "content": user_input})
        st.session_state.chat_messages.append({"role": "user", "content": user_input})
        with st.chat_message("assistant", avatar=avatar_url):
            placeholder = st.empty()
            # Re-seed per reply so regenerated answers can differ.
            setup_seed(random.randint(0, 2 ** 32 - 1))
            # Keep only the most recent turns, render them with the chat
            # template, then apply the original crude character-level cap
            # on prompt length.
            conversation_history = system_prompt + st.session_state.chat_messages[-(st.session_state.history_chat_num + 1):]
            formatted_prompt = tokenizer.apply_chat_template(
                conversation_history, tokenize=False, add_generation_prompt=True
            )[-(st.session_state.max_new_tokens - 1):]
            input_tensor = torch.tensor(tokenizer(formatted_prompt)['input_ids'], device=device).unsqueeze(0)
            full_response = ""
            with torch.no_grad():
                # NOTE: this assumes the custom model's streaming generate()
                # from the tutorial (eos id passed positionally, stream=True
                # yielding growing token sequences), not the stock
                # transformers generate().
                generated_responses = model.generate(
                    input_tensor, tokenizer.eos_token_id,
                    max_new_tokens=st.session_state.max_new_tokens,
                    temperature=st.session_state.temperature,
                    top_p=st.session_state.top_p, stream=True)
                for response in generated_responses:
                    decoded_text = tokenizer.decode(response[0].tolist(), skip_special_tokens=True)
                    # Skip chunks that end on an incomplete UTF-8 character.
                    if not decoded_text or decoded_text[-1] == '�':
                        continue
                    full_response = decoded_text.replace(formatted_prompt, "")
                    placeholder.markdown(format_assistant_response(full_response), unsafe_allow_html=True)
        st.session_state.messages.append({"role": "assistant", "content": full_response})
        st.session_state.chat_messages.append({"role": "assistant", "content": full_response})
if __name__ == "__main__":
main()
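
# Usage: launch the app with Streamlit rather than running the file directly,
# e.g.  streamlit run app.py  (assuming this file is saved as app.py).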