|
|
import streamlit as st |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer |
|
|
from threading import Thread |
|
|
import torch |
|
|
import re |
|
|
|
|
|
|
|
|
# Hugging Face model repo: a small (0.5B) instruct-tuned coding model,
# chosen so the app can run on commodity hardware.
MODEL_ID = "Qwen/Qwen2.5-Coder-0.5B-Instruct"

# Inference device; pinned to CPU so no GPU is required.
DEVICE = "cpu"
|
|
|
|
|
# Configure the page before any other Streamlit call: wide layout for the
# two-pane (chat | artifact) UI.
st.set_page_config(page_title="Smol Claude", layout="wide", page_icon="π€")

# Inject CSS overrides: rounded chat bubbles and a tall scrollable artifact
# pane. unsafe_allow_html is required for a raw <style> tag to take effect.
st.markdown("""
<style>
.stChatMessage { border-radius: 10px; padding: 10px; margin-bottom: 10px; }
.artifact-container { background-color: #f8f9fa; border-left: 2px solid #ddd; padding: 20px; height: 100vh; overflow-y: auto; }
</style>
""", unsafe_allow_html=True)
|
|
|
|
|
|
|
|
@st.cache_resource
def load_llm():
    """Load the tokenizer and model once per server process.

    st.cache_resource memoizes the pair across Streamlit reruns, so the
    (slow) download/weights load happens only on the first request.
    """
    tok = AutoTokenizer.from_pretrained(MODEL_ID)
    llm = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
    return tok, llm.to(DEVICE)
|
|
|
|
|
# Shared tokenizer/model handles; cached by st.cache_resource, so this line
# is cheap on every rerun after the first.
tokenizer, model = load_llm()
|
|
|
|
|
|
|
|
def extract_code(text):
    """Return the body of the first fenced code block in *text*, or None.

    Fixes over the previous version:
    - accepts any opening-fence language tag (e.g. "c++", not just \\w+);
    - tolerates a closing fence with no preceding newline, which happens
      when the model's output is truncated mid-block.
    """
    # [^\n]* consumes an optional language tag on the opening fence line;
    # DOTALL lets the lazy group span multiple lines up to the first ```.
    match = re.search(r"```[^\n]*\n(.*?)```", text, re.DOTALL)
    if not match:
        return None
    # A well-formed block has a newline before the closing fence; strip it
    # so standard markdown blocks yield the same value as before.
    return match.group(1).rstrip("\n")
|
|
|
|
|
|
|
|
# Seed session-state slots on first run; later code can then read them
# without existence checks. Existing values survive reruns untouched.
for _key, _default in (("messages", []), ("current_code", "")):
    if _key not in st.session_state:
        st.session_state[_key] = _default
|
|
|
|
|
# Two equal-width panes: conversation on the left, generated code on the right.
col_chat, col_artifact = st.columns([1, 1])
|
|
|
|
|
|
|
|
with col_chat:
    # Left pane: chat transcript.
    st.subheader("π¬ Chat")
    # Streamlit re-executes the whole script on every interaction, so the
    # full history is replayed from session state each rerun.
    for msg in st.session_state.messages:
        with st.chat_message(msg["role"]):
            st.markdown(msg["content"])
|
|
|
|
|
if prompt := st.chat_input("What should we build?"):
    # Record and echo the user's turn.
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        # Prompt = fixed system instruction + the full conversation so far.
        input_msgs = [{"role": "system", "content": "You are a professional coder. Output code in markdown blocks."}] + st.session_state.messages
        text = tokenizer.apply_chat_template(input_msgs, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer([text], return_tensors="pt").to(DEVICE)

        # Generate on a worker thread and stream tokens back so the UI
        # stays responsive during decoding.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            inputs,
            streamer=streamer,
            max_new_tokens=1024,
            # Fix: temperature is silently ignored (with a warning) unless
            # sampling is enabled; do_sample=True makes the 0.1 setting real.
            do_sample=True,
            temperature=0.1,
            # Silence the per-call "pad_token_id not set" warning.
            pad_token_id=tokenizer.eos_token_id,
        )

        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        # Render partial output with a cursor glyph while tokens arrive.
        full_response = ""
        message_placeholder = st.empty()
        for new_text in streamer:
            full_response += new_text
            message_placeholder.markdown(full_response + "β")
        # Fix: join the worker so generate() has fully finished (and isn't
        # left running) before we persist and act on the response.
        thread.join()

        message_placeholder.markdown(full_response)
        st.session_state.messages.append({"role": "assistant", "content": full_response})

        # If the reply contains a fenced code block, publish it to the
        # artifact pane and rerun so the right column refreshes.
        code_snippet = extract_code(full_response)
        if code_snippet:
            st.session_state.current_code = code_snippet
            st.rerun()
|
|
|
|
|
|
|
|
with col_artifact:
    # Right pane: the most recently extracted code block.
    st.subheader("π Code Artifact")
    code = st.session_state.current_code
    if not code:
        # Nothing extracted yet — show a hint instead of an empty pane.
        st.info("Ask the AI to write some code to see it appear here.")
    else:
        st.code(code, language="python", line_numbers=True)
        st.download_button("Download File", code, file_name="artifact.py")