SiennaClarke committed on
Commit
9d1fd1f
·
verified ·
1 Parent(s): c80c5cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -24
app.py CHANGED
@@ -3,14 +3,15 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStream
3
  from threading import Thread
4
  import torch
5
 
6
- # Clean, centered layout without sidebar
7
- st.set_page_config(page_title="Qwen 3 4B Stream", page_icon="", layout="centered", initial_sidebar_state="collapsed")
8
 
9
- # 1. Model Configuration (Qwen 3 4B - 4-bit for speed)
10
- MODEL_ID = "unsloth/Qwen3-4B-Instruct-2507-bnb-4bit"
11
 
12
  @st.cache_resource
13
- def load_resource():
 
14
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
15
  model = AutoModelForCausalLM.from_pretrained(
16
  MODEL_ID,
@@ -19,61 +20,57 @@ def load_resource():
19
  )
20
  return tokenizer, model
21
 
22
- tokenizer, model = load_resource()
23
 
24
- # Custom CSS to hide the sidebar toggle
25
  st.markdown("<style>[data-testid='collapsedControl'] { display: none; }</style>", unsafe_allow_html=True)
26
 
27
- st.title(" Qwen 3 4B Stream")
28
- st.caption("Real-time local generation | No Sidebar")
29
 
30
  if "messages" not in st.session_state:
31
  st.session_state.messages = []
32
 
33
- # Action Buttons
34
- col1, col2 = st.columns([5, 1])
35
- with col2:
36
- if st.button("Reset"):
37
- st.session_state.messages = []
38
- st.rerun()
39
 
40
- # Display chat history
41
  for msg in st.session_state.messages:
42
  with st.chat_message(msg["role"]):
43
  st.markdown(msg["content"])
44
 
45
  # 2. Streaming Chat Input
46
- if prompt := st.chat_input("Ask Qwen 3..."):
47
  st.session_state.messages.append({"role": "user", "content": prompt})
48
  with st.chat_message("user"):
49
  st.markdown(prompt)
50
 
51
  with st.chat_message("assistant"):
52
- # Setup the Streamer
53
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
54
 
55
- # Prepare input
56
  input_text = tokenizer.apply_chat_template(st.session_state.messages, tokenize=False, add_generation_prompt=True)
57
  inputs = tokenizer([input_text], return_tensors="pt").to(model.device)
58
 
59
- # 3. Generation in a separate thread
60
  generation_kwargs = dict(
61
  **inputs,
62
  streamer=streamer,
63
  max_new_tokens=1024,
64
  do_sample=True,
65
  temperature=0.7,
66
- top_p=0.8,
67
- pad_token_id=tokenizer.eos_token_id
68
  )
69
 
70
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
71
  thread.start()
72
 
73
- # 4. Stream to UI
74
  placeholder = st.empty()
75
  full_response = ""
76
-
77
  for new_text in streamer:
78
  full_response += new_text
79
  placeholder.markdown(full_response + "▌")
 
3
  from threading import Thread
4
  import torch
5
 
6
+ # UI Setup - No Sidebar
7
+ st.set_page_config(page_title="Gemma 3 1B Fast Chat", page_icon="💎", layout="centered", initial_sidebar_state="collapsed")
8
 
9
+ # 1. Model ID for Gemma 3 1B Instruct
10
+ MODEL_ID = "google/gemma-3-1b-it"
11
 
12
  @st.cache_resource
13
+ def load_model():
14
+ # Gemma 3 1B is small enough to load in bfloat16 or float32 quickly
15
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
16
  model = AutoModelForCausalLM.from_pretrained(
17
  MODEL_ID,
 
20
  )
21
  return tokenizer, model
22
 
23
+ tokenizer, model = load_model()
24
 
25
+ # Custom CSS to keep the clean look
26
  st.markdown("<style>[data-testid='collapsedControl'] { display: none; }</style>", unsafe_allow_html=True)
27
 
28
+ st.title("💎 Gemma 3 1B")
29
+ st.caption("Lightweight Google AI | High-Speed Local Chat")
30
 
31
  if "messages" not in st.session_state:
32
  st.session_state.messages = []
33
 
34
+ # Action Button
35
+ if st.button("Clear Chat History"):
36
+ st.session_state.messages = []
37
+ st.rerun()
 
 
38
 
39
+ # Display history
40
  for msg in st.session_state.messages:
41
  with st.chat_message(msg["role"]):
42
  st.markdown(msg["content"])
43
 
44
  # 2. Streaming Chat Input
45
+ if prompt := st.chat_input("Message Gemma 3..."):
46
  st.session_state.messages.append({"role": "user", "content": prompt})
47
  with st.chat_message("user"):
48
  st.markdown(prompt)
49
 
50
  with st.chat_message("assistant"):
51
+ # Setup Streamer
52
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
53
 
54
+ # Gemma 3 uses a specific chat template format
55
  input_text = tokenizer.apply_chat_template(st.session_state.messages, tokenize=False, add_generation_prompt=True)
56
  inputs = tokenizer([input_text], return_tensors="pt").to(model.device)
57
 
58
+ # Threaded generation for real-time streaming
59
  generation_kwargs = dict(
60
  **inputs,
61
  streamer=streamer,
62
  max_new_tokens=1024,
63
  do_sample=True,
64
  temperature=0.7,
65
+ top_p=0.9,
 
66
  )
67
 
68
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
69
  thread.start()
70
 
71
+ # Update UI word-by-word
72
  placeholder = st.empty()
73
  full_response = ""
 
74
  for new_text in streamer:
75
  full_response += new_text
76
  placeholder.markdown(full_response + "▌")