SiennaClarke commited on
Commit
5196574
·
verified ·
1 Parent(s): cd09b92

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -43
app.py CHANGED
@@ -2,89 +2,120 @@ import streamlit as st
2
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
3
  from threading import Thread
4
  import torch
 
5
 
6
- # UI Setup (No Sidebar as requested)
7
- st.set_page_config(page_title="Qwen 2.5 32B Chat", page_icon="🐘", layout="centered", initial_sidebar_state="collapsed")
8
- st.markdown("<style>[data-testid='collapsedControl'] { display: none; }</style>", unsafe_allow_html=True)
 
 
 
 
9
 
10
- # 1. Model Configuration (Quantized to fit on 24GB VRAM or 32GB RAM)
11
- MODEL_ID = "Qwen/Qwen2.5-32B-Instruct"
 
 
 
 
 
12
 
13
- @st.cache_resource
 
 
 
14
  def load_llm():
15
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
16
-
17
- # 4-bit config allows this 64GB model to fit in ~18-20GB of memory
18
- quant_config = BitsAndBytesConfig(
19
- load_in_4bit=True,
20
- bnb_4bit_compute_dtype=torch.float16,
21
- bnb_4bit_quant_type="nf4"
22
- )
23
-
24
- model = AutoModelForCausalLM.from_pretrained(
25
- MODEL_ID,
26
- quantization_config=quant_config,
27
- device_map="auto" # Automatically splits between GPU and CPU
28
- )
29
- return tokenizer, model
 
 
 
 
 
 
 
 
 
30
 
31
  tokenizer, model = load_llm()
32
 
33
- # 2. Chat Interface
34
- st.title("🐘 Qwen 2.5 32B")
35
- st.caption("Running high-parameter model with 4-bit quantization")
36
-
37
  if "messages" not in st.session_state:
38
  st.session_state.messages = []
39
 
40
- # Action Button
41
- if st.button("Clear History"):
 
 
 
42
  st.session_state.messages = []
43
  st.rerun()
44
 
45
- # Display history
46
  for msg in st.session_state.messages:
47
- with st.chat_message(msg["role"]):
 
 
48
  st.markdown(msg["content"])
49
 
50
- # 3. Chat Logic with your exact Template Code
51
- if prompt := st.chat_input("Message Qwen 2.5 32B..."):
 
52
  st.session_state.messages.append({"role": "user", "content": prompt})
53
  with st.chat_message("user"):
54
  st.markdown(prompt)
55
 
 
56
  with st.chat_message("assistant"):
57
- # Setup Streamer
58
- streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
59
 
60
- # YOUR EXACT LOGIC: Applying the chat template
61
  inputs = tokenizer.apply_chat_template(
 
62
  st.session_state.messages,
63
  add_generation_prompt=True,
64
  tokenize=True,
65
  return_dict=True,
66
- return_tensors="pt",
67
  ).to(model.device)
68
 
69
- # Threading for live streaming
 
 
 
70
  generation_kwargs = dict(
71
  **inputs,
72
  streamer=streamer,
73
- max_new_tokens=512,
74
  do_sample=True,
75
  temperature=0.7,
 
76
  pad_token_id=tokenizer.eos_token_id
77
  )
78
-
79
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
80
  thread.start()
81
 
82
- # Word-by-word UI update
83
- placeholder = st.empty()
84
- full_response = ""
85
  for new_text in streamer:
86
  full_response += new_text
87
  placeholder.markdown(full_response + "▌")
88
-
89
  placeholder.markdown(full_response)
90
- st.session_state.messages.append({"role": "assistant", "content": full_response})
 
 
 
2
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
3
  from threading import Thread
4
  import torch
5
+ import sys
6
 
7
+ # --- UI Configuration ---
8
+ st.set_page_config(
9
+ page_title="Klove AI ChatBox",
10
+ page_icon="🐘",
11
+ layout="centered",
12
+ initial_sidebar_state="collapsed"
13
+ )
14
 
15
+ # Professional CSS injection for cleaner UI
16
+ st.markdown("""
17
+ <style>
18
+ [data-testid='collapsedControl'] { display: none; }
19
+ .stChatMessage { border-radius: 10px; margin-bottom: 10px; }
20
+ </style>
21
+ """, unsafe_allow_html=True)
22
 
23
+ # --- Model Constants ---
24
+ MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
25
+
26
+ @st.cache_resource(show_spinner="Initializing Model Engine...")
27
  def load_llm():
28
+ try:
29
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
30
+
31
+ # Expert Config: nf4 quantization with bfloat16 for better stability if hardware supports it
32
+ compute_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
33
+
34
+ quant_config = BitsAndBytesConfig(
35
+ load_in_4bit=True,
36
+ bnb_4bit_compute_dtype=compute_dtype,
37
+ bnb_4bit_quant_type="nf4",
38
+ bnb_4bit_use_double_quant=True # Expert addition: Saves extra VRAM
39
+ )
40
+
41
+ model = AutoModelForCausalLM.from_pretrained(
42
+ MODEL_ID,
43
+ quantization_config=quant_config,
44
+ device_map="auto",
45
+ trust_remote_code=True,
46
+ low_cpu_mem_usage=True
47
+ )
48
+ return tokenizer, model
49
+ except Exception as e:
50
+ st.error(f"Failed to load model: {e}")
51
+ st.stop()
52
 
53
  tokenizer, model = load_llm()
54
 
55
+ # --- Chat Session State ---
 
 
 
56
  if "messages" not in st.session_state:
57
  st.session_state.messages = []
58
 
59
+ # --- Header ---
60
+ st.title("🐘 Qwen 2.5 Chat")
61
+ st.caption(f"Backend: {MODEL_ID} (4-bit NF4 Quantized)")
62
+
63
+ if st.button("Clear Conversation", type="primary"):
64
  st.session_state.messages = []
65
  st.rerun()
66
 
67
+ # --- Message Rendering ---
68
  for msg in st.session_state.messages:
69
+ # Handle the 'coder' role mapping to 'assistant' for UI consistency
70
+ role = "assistant" if msg["role"] == "coder" else msg["role"]
71
+ with st.chat_message(role):
72
  st.markdown(msg["content"])
73
 
74
+ # --- Generation Logic ---
75
+ if prompt := st.chat_input("Message to Qwen..."):
76
+ # Append User Message
77
  st.session_state.messages.append({"role": "user", "content": prompt})
78
  with st.chat_message("user"):
79
  st.markdown(prompt)
80
 
81
+ # Generate Assistant Response
82
  with st.chat_message("assistant"):
83
+ placeholder = st.empty()
84
+ full_response = ""
85
 
86
+ # 1. Prepare Inputs
87
  inputs = tokenizer.apply_chat_template(
88
+ # Filter history to only include user/coder roles for the template
89
  st.session_state.messages,
90
  add_generation_prompt=True,
91
  tokenize=True,
92
  return_dict=True,
93
+ return_tensors="pt"
94
  ).to(model.device)
95
 
96
+ # 2. Setup Streamer
97
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
98
+
99
+ # 3. Execution (Expert Note: use inference_mode for speed/memory)
100
  generation_kwargs = dict(
101
  **inputs,
102
  streamer=streamer,
103
+ max_new_tokens=1024, # Increased for more robust answers
104
  do_sample=True,
105
  temperature=0.7,
106
+ top_p=0.9, # Added for higher quality sampling
107
  pad_token_id=tokenizer.eos_token_id
108
  )
109
+
110
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
111
  thread.start()
112
 
113
+ # 4. Stream Handling
 
 
114
  for new_text in streamer:
115
  full_response += new_text
116
  placeholder.markdown(full_response + "▌")
117
+
118
  placeholder.markdown(full_response)
119
+
120
+ # Store as 'coder' per original logic requirement
121
+ st.session_state.messages.append({"role": "coder", "content": full_response})