Anirudh Esthuri commited on
Commit
4e58ae3
·
1 Parent(s): 14aec40

Fix: Add provider parameter to rewrite_message function to fix NameError

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -13,7 +13,7 @@ os.environ["STREAMLIT_SERVER_ADDRESS"] = "0.0.0.0"
13
 
14
 
15
  def rewrite_message(
16
- msg: str, persona_name: str, show_rationale: bool, skip_rewrite: bool
17
  ) -> str:
18
  if skip_rewrite:
19
  rewritten_msg = msg
@@ -140,13 +140,13 @@ if msg:
140
  # rewritten_msg = "Use the persona profile to personalize your naswer only when applicable.\n"
141
  if compare_personas:
142
  all_answers = {}
143
- rewritten_msg = rewrite_message(msg, persona_name, show_rationale, False)
144
  msgs = clean_history(st.session_state.history, persona_name)
145
  msgs = append_user_turn(msgs, rewritten_msg)
146
  txt, lat, tok, tps = chat(msgs, persona_name)
147
  all_answers[persona_name] = txt
148
 
149
- rewritten_msg_control = rewrite_message(msg, "Control", show_rationale, True)
150
  msgs_control = clean_history(st.session_state.history, "Control")
151
  msgs_control = append_user_turn(msgs_control, rewritten_msg_control)
152
  txt_control, lat, tok, tps = chat(msgs_control, "Arnold")
@@ -156,7 +156,7 @@ if msg:
156
  {"role": "assistant_all", "axis": "role", "content": all_answers}
157
  )
158
  else:
159
- rewritten_msg = rewrite_message(msg, persona_name, show_rationale, skip_rewrite)
160
  msgs = clean_history(st.session_state.history, persona_name)
161
  msgs = append_user_turn(msgs, rewritten_msg)
162
  txt, lat, tok, tps = chat(
 
13
 
14
 
15
  def rewrite_message(
16
+ msg: str, persona_name: str, show_rationale: bool, skip_rewrite: bool, provider: str = "openai"
17
  ) -> str:
18
  if skip_rewrite:
19
  rewritten_msg = msg
 
140
  # rewritten_msg = "Use the persona profile to personalize your naswer only when applicable.\n"
141
  if compare_personas:
142
  all_answers = {}
143
+ rewritten_msg = rewrite_message(msg, persona_name, show_rationale, False, provider)
144
  msgs = clean_history(st.session_state.history, persona_name)
145
  msgs = append_user_turn(msgs, rewritten_msg)
146
  txt, lat, tok, tps = chat(msgs, persona_name)
147
  all_answers[persona_name] = txt
148
 
149
+ rewritten_msg_control = rewrite_message(msg, "Control", show_rationale, True, provider)
150
  msgs_control = clean_history(st.session_state.history, "Control")
151
  msgs_control = append_user_turn(msgs_control, rewritten_msg_control)
152
  txt_control, lat, tok, tps = chat(msgs_control, "Arnold")
 
156
  {"role": "assistant_all", "axis": "role", "content": all_answers}
157
  )
158
  else:
159
+ rewritten_msg = rewrite_message(msg, persona_name, show_rationale, skip_rewrite, provider)
160
  msgs = clean_history(st.session_state.history, persona_name)
161
  msgs = append_user_turn(msgs, rewritten_msg)
162
  txt, lat, tok, tps = chat(