chatbot fix sherika
Browse files- phase/Student_view/chatbot.py +31 -35
phase/Student_view/chatbot.py
CHANGED
|
@@ -17,7 +17,11 @@ TUTOR_PROMPT = (
|
|
| 17 |
"Teach step-by-step with tiny examples. Avoid giving personal financial advice."
|
| 18 |
)
|
| 19 |
|
|
|
|
|
|
|
|
|
|
| 20 |
def _format_history_for_flan(messages: list[dict]) -> str:
|
|
|
|
| 21 |
lines = []
|
| 22 |
for m in messages:
|
| 23 |
txt = (m.get("text") or "").strip()
|
|
@@ -26,14 +30,8 @@ def _format_history_for_flan(messages: list[dict]) -> str:
|
|
| 26 |
lines.append(("Tutor" if m.get("sender") == "assistant" else "User") + f": {txt}")
|
| 27 |
return "\n".join(lines)
|
| 28 |
|
| 29 |
-
def _trim_turn(text: str) -> str:
|
| 30 |
-
for cp in ["\nUser:", "\nTutor:", "\nAssistant:", "\n###"]:
|
| 31 |
-
if cp in text:
|
| 32 |
-
return text.split(cp, 1)[0].strip()
|
| 33 |
-
return text.strip()
|
| 34 |
-
|
| 35 |
def _history_as_chat_messages(messages: list[dict]) -> list[dict]:
|
| 36 |
-
|
| 37 |
msgs = [{"role": "system", "content": TUTOR_PROMPT}]
|
| 38 |
for m in messages:
|
| 39 |
txt = (m.get("text") or "").strip()
|
|
@@ -44,55 +42,53 @@ def _history_as_chat_messages(messages: list[dict]) -> list[dict]:
|
|
| 44 |
return msgs
|
| 45 |
|
| 46 |
def _extract_chat_text(chat_resp) -> str:
|
| 47 |
-
|
| 48 |
try:
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
except Exception:
|
| 53 |
-
# fallback for dict payloads
|
| 54 |
try:
|
| 55 |
return chat_resp["choices"][0]["message"]["content"]
|
| 56 |
except Exception:
|
| 57 |
return str(chat_resp)
|
| 58 |
|
|
|
|
|
|
|
|
|
|
| 59 |
def _reply_with_hf():
|
| 60 |
if "client" not in globals():
|
| 61 |
raise RuntimeError("HF client not initialized")
|
| 62 |
|
| 63 |
-
# Text-generation prompt (for providers that support it)
|
| 64 |
-
convo = _format_history_for_flan(st.session_state.get("messages", []))
|
| 65 |
-
tg_prompt = f"{TUTOR_PROMPT}\n\n{convo}\n\nTutor:"
|
| 66 |
-
|
| 67 |
try:
|
| 68 |
-
# 1)
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
|
|
|
|
|
|
| 72 |
temperature=0.2,
|
| 73 |
top_p=0.9,
|
| 74 |
-
repetition_penalty=1.1,
|
| 75 |
-
return_full_text=True,
|
| 76 |
-
stream=False,
|
| 77 |
)
|
| 78 |
-
|
| 79 |
-
return _trim_turn(str(text or "").strip())
|
| 80 |
|
| 81 |
except ValueError as ve:
|
| 82 |
-
# 2)
|
| 83 |
-
if "Supported task:
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
temperature=0.2,
|
| 90 |
top_p=0.9,
|
|
|
|
|
|
|
|
|
|
| 91 |
)
|
| 92 |
-
return
|
| 93 |
|
| 94 |
-
#
|
| 95 |
-
raise
|
| 96 |
|
| 97 |
except Exception as e:
|
| 98 |
err_text = ''.join(traceback.format_exception_only(type(e), e)).strip()
|
|
|
|
| 17 |
"Teach step-by-step with tiny examples. Avoid giving personal financial advice."
|
| 18 |
)
|
| 19 |
|
| 20 |
+
# -------------------------------
|
| 21 |
+
# History helpers
|
| 22 |
+
# -------------------------------
|
| 23 |
def _format_history_for_flan(messages: list[dict]) -> str:
|
| 24 |
+
"""Format history for text-generation style models."""
|
| 25 |
lines = []
|
| 26 |
for m in messages:
|
| 27 |
txt = (m.get("text") or "").strip()
|
|
|
|
| 30 |
lines.append(("Tutor" if m.get("sender") == "assistant" else "User") + f": {txt}")
|
| 31 |
return "\n".join(lines)
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
def _history_as_chat_messages(messages: list[dict]) -> list[dict]:
|
| 34 |
+
"""Convert history to chat-completion style messages."""
|
| 35 |
msgs = [{"role": "system", "content": TUTOR_PROMPT}]
|
| 36 |
for m in messages:
|
| 37 |
txt = (m.get("text") or "").strip()
|
|
|
|
| 42 |
return msgs
|
| 43 |
|
| 44 |
def _extract_chat_text(chat_resp) -> str:
|
| 45 |
+
"""Extract text from HF chat response."""
|
| 46 |
try:
|
| 47 |
+
return chat_resp.choices[0].message["content"] if isinstance(
|
| 48 |
+
chat_resp.choices[0].message, dict
|
| 49 |
+
) else chat_resp.choices[0].message.content
|
| 50 |
except Exception:
|
|
|
|
| 51 |
try:
|
| 52 |
return chat_resp["choices"][0]["message"]["content"]
|
| 53 |
except Exception:
|
| 54 |
return str(chat_resp)
|
| 55 |
|
| 56 |
+
# -------------------------------
|
| 57 |
+
# Reply logic
|
| 58 |
+
# -------------------------------
|
| 59 |
def _reply_with_hf():
|
| 60 |
if "client" not in globals():
|
| 61 |
raise RuntimeError("HF client not initialized")
|
| 62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
try:
|
| 64 |
+
# 1) Prefer chat API
|
| 65 |
+
msgs = _history_as_chat_messages(st.session_state.get("messages", []))
|
| 66 |
+
chat = client.chat.completions.create(
|
| 67 |
+
model=GEN_MODEL,
|
| 68 |
+
messages=msgs,
|
| 69 |
+
max_tokens=300, # give enough room
|
| 70 |
temperature=0.2,
|
| 71 |
top_p=0.9,
|
|
|
|
|
|
|
|
|
|
| 72 |
)
|
| 73 |
+
return _extract_chat_text(chat).strip()
|
|
|
|
| 74 |
|
| 75 |
except ValueError as ve:
|
| 76 |
+
# 2) Fallback to text-generation if chat unsupported
|
| 77 |
+
if "Supported task: text-generation" in str(ve):
|
| 78 |
+
convo = _format_history_for_flan(st.session_state.get("messages", []))
|
| 79 |
+
tg_prompt = f"{TUTOR_PROMPT}\n\n{convo}\n\nTutor:"
|
| 80 |
+
resp = client.text_generation(
|
| 81 |
+
tg_prompt,
|
| 82 |
+
max_new_tokens=300,
|
| 83 |
temperature=0.2,
|
| 84 |
top_p=0.9,
|
| 85 |
+
repetition_penalty=1.1,
|
| 86 |
+
return_full_text=True,
|
| 87 |
+
stream=False,
|
| 88 |
)
|
| 89 |
+
return (resp.get("generated_text") if isinstance(resp, dict) else resp).strip()
|
| 90 |
|
| 91 |
+
raise # rethrow anything else
|
|
|
|
| 92 |
|
| 93 |
except Exception as e:
|
| 94 |
err_text = ''.join(traceback.format_exception_only(type(e), e)).strip()
|