JumaRubea committed on
Commit
6ea5e8d
·
verified ·
1 Parent(s): b9c9890

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +22 -59
src/streamlit_app.py CHANGED
@@ -58,19 +58,15 @@ if user_input:
58
 
59
  with st.spinner("Thinking..."):
60
  try:
61
- # Format messages for chat template
62
- messages = [
63
- {"role": "system", "content": system_prompt()},
64
- {"role": "user", "content": user_input}
65
- ]
66
-
67
- # Tokenize input using chat template
68
- inputs = tokenizer.apply_chat_template(
69
- messages,
70
- add_generation_prompt=True,
71
- tokenize=True,
72
- return_dict=True,
73
  return_tensors="pt",
 
74
  ).to(device)
75
 
76
  # Generate tokens
@@ -80,7 +76,8 @@ if user_input:
80
  # Stream tokens
81
  generated = inputs["input_ids"]
82
  outputs = model.generate(
83
- **inputs,
 
84
  max_new_tokens=200,
85
  do_sample=False,
86
  temperature=0.5,
@@ -92,56 +89,22 @@ if user_input:
92
  )
93
  sequence = outputs.sequences[0]
94
 
95
- # Decode only new tokens one by one
96
  for i in range(generated.shape[-1], sequence.shape[-1]):
97
  token_id = sequence[i].unsqueeze(0)
98
- text = tokenizer.decode(token_id, skip_special_tokens=True)
99
- if text.strip():
100
  full_response += text
101
  placeholder.markdown(full_response)
102
 
103
- st.chat_message("assistant").markdown(full_response)
104
- save_message(selected_chat_id, "assistant", full_response)
 
 
 
 
 
 
105
 
106
  except Exception as e:
107
- st.error(f"Error: {str(e)}")
108
- # Fallback to manual formatting if apply_chat_template fails
109
- try:
110
- system_message = system_prompt()
111
- prompt = f"<|SYSTEM|> {system_message} <|USER|> {user_input} <|ASSISTANT>"
112
- inputs = tokenizer(
113
- prompt,
114
- return_tensors="pt",
115
- add_special_tokens=True
116
- ).to(device)
117
-
118
- full_response = ""
119
- placeholder = st.empty()
120
-
121
- generated = inputs["input_ids"]
122
- outputs = model.generate(
123
- input_ids=inputs["input_ids"],
124
- attention_mask=inputs["attention_mask"],
125
- max_new_tokens=200,
126
- do_sample=False,
127
- temperature=0.5,
128
- top_p=0.9,
129
- eos_token_id=tokenizer.eos_token_id,
130
- pad_token_id=tokenizer.eos_token_id,
131
- return_dict_in_generate=True,
132
- output_scores=False
133
- )
134
- sequence = outputs.sequences[0]
135
-
136
- for i in range(generated.shape[-1], sequence.shape[-1]):
137
- token_id = sequence[i].unsqueeze(0)
138
- text = tokenizer.decode(token_id, skip_special_tokens=True)
139
- if text.strip():
140
- full_response += text
141
- placeholder.markdown(full_response)
142
-
143
- st.chat_message("assistant").markdown(full_response)
144
- save_message(selected_chat_id, "assistant", full_response)
145
-
146
- except Exception as fallback_e:
147
- st.error(f"Fallback Error: {str(fallback_e)}")
 
58
 
59
  with st.spinner("Thinking..."):
60
  try:
61
+ # Manually format the chat prompt
62
+ system_message = system_prompt()
63
+ prompt = f"<|SYSTEM|> {system_message} <|USER|> {user_input} <|ASSISTANT>"
64
+
65
+ # Tokenize the formatted prompt
66
+ inputs = tokenizer(
67
+ prompt,
 
 
 
 
 
68
  return_tensors="pt",
69
+ add_special_tokens=True
70
  ).to(device)
71
 
72
  # Generate tokens
 
76
  # Stream tokens
77
  generated = inputs["input_ids"]
78
  outputs = model.generate(
79
+ input_ids=inputs["input_ids"],
80
+ attention_mask=inputs["attention_mask"],
81
  max_new_tokens=200,
82
  do_sample=False,
83
  temperature=0.5,
 
89
  )
90
  sequence = outputs.sequences[0]
91
 
92
+ # Decode tokens one by one, preserving spaces
93
  for i in range(generated.shape[-1], sequence.shape[-1]):
94
  token_id = sequence[i].unsqueeze(0)
95
+ text = tokenizer.decode(token_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
96
+ if text:
97
  full_response += text
98
  placeholder.markdown(full_response)
99
 
100
+ # Final response, decoding only new tokens
101
+ final_response = tokenizer.decode(
102
+ sequence[generated.shape[-1]:],
103
+ skip_special_tokens=True,
104
+ clean_up_tokenization_spaces=True
105
+ ).strip()
106
+ st.chat_message("assistant").markdown(final_response)
107
+ save_message(selected_chat_id, "assistant", final_response)
108
 
109
  except Exception as e:
110
+ st.error(f"Error: {str(e)}")