Rajan Sharma committed on
Commit
8f6e031
·
verified ·
1 Parent(s): aa3bd5f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -4
app.py CHANGED
@@ -4,6 +4,9 @@ from functools import lru_cache
4
  import gradio as gr
5
  import torch
6
 
 
 
 
7
  from settings import SNAPSHOT_PATH, PERSIST_CONTENT
8
  from audit_log import log_event, hash_summary
9
  from privacy import redact_text
@@ -91,12 +94,20 @@ def _history_to_prompt(message, history):
91
  parts.append("Assistant:")
92
  return "\n".join(parts)
93
 
 
 
 
 
 
 
 
 
94
  # ---------- Cohere (default path) ----------
95
  def cohere_chat(message, history):
96
  if not USE_HOSTED_COHERE:
97
  return None
98
  try:
99
- # Create client on demand to avoid init errors on some builds
100
  client = cohere.Client(api_key=COHERE_API_KEY)
101
  prompt = _history_to_prompt(message, history)
102
  resp = client.chat(
@@ -198,19 +209,37 @@ def clarityops_reply(user_msg, history, tz, uploaded_files_paths):
198
  try:
199
  log_event("user_message", None, {"sizes": {"chars": len(user_msg or "")}})
200
 
 
201
  safe_in, blocked_in, reason_in = safety_filter(user_msg, mode="input")
202
  if blocked_in:
203
  ans = refusal_reply(reason_in)
204
  return history + [(user_msg, ans)]
205
 
 
206
  if is_identity_query(safe_in, history):
207
  ans = "I am ClarityOps, your strategic decision making AI partner."
208
  return history + [(user_msg, ans)]
209
 
210
- # ---------- Ingest uploads: now returns chunks + artifacts ----------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  if uploaded_files_paths:
212
  ing = extract_text_from_files(uploaded_files_paths)
213
- chunks = ing.get("chunks", []) if isinstance(ing, dict) else (ing or [])
214
  artifacts = ing.get("artifacts", []) if isinstance(ing, dict) else []
215
  if chunks:
216
  _session_rag.add_docs(chunks)
@@ -218,12 +247,22 @@ def clarityops_reply(user_msg, history, tz, uploaded_files_paths):
218
  _session_rag.register_artifacts(artifacts)
219
  log_event("uploads_added", None, {"chunks": len(chunks), "artifacts": len(artifacts)})
220
 
221
- # ---------- Deterministic CSV "columns/headers" handler ----------
222
  if re.search(r"\b(columns?|headers?)\b", (safe_in or "").lower()):
223
  cols = _session_rag.get_latest_csv_columns()
224
  if cols:
225
  return history + [(user_msg, "Here are the column names from your most recent CSV upload:\n\n- " + "\n- ".join(cols))]
226
 
 
 
 
 
 
 
 
 
 
 
227
  # Retrieve from session uploads (text chunks)
228
  session_snips = "\n---\n".join(_session_rag.retrieve(
229
  "diabetes screening Indigenous Métis mobile program cost throughput outcomes logistics bed flow staffing discharge forecast",
@@ -261,15 +300,19 @@ def clarityops_reply(user_msg, history, tz, uploaded_files_paths):
261
  inputs = build_inputs(tokenizer, augmented_user, history)
262
  out = local_generate(model, tokenizer, inputs, max_new_tokens=MAX_NEW_TOKENS)
263
 
 
264
  if isinstance(out, str):
265
  for tag in ("Assistant:", "System:", "User:"):
266
  if out.startswith(tag):
267
  out = out[len(tag):].strip()
 
268
 
 
269
  safe_out, blocked_out, reason_out = safety_filter(out, mode="output")
270
  if blocked_out:
271
  safe_out = refusal_reply(reason_out)
272
 
 
273
  log_event("assistant_reply", None, {
274
  **hash_summary("prompt", augmented_user if not PERSIST_CONTENT else ""),
275
  **hash_summary("reply", safe_out if not PERSIST_CONTENT else ""),
@@ -360,3 +403,4 @@ if __name__ == "__main__":
360
 
361
 
362
 
 
 
4
  import gradio as gr
5
  import torch
6
 
7
+ # NEW: robust control-char sanitizer (requires `regex` package)
8
+ import regex as re2 # pip install regex
9
+
10
  from settings import SNAPSHOT_PATH, PERSIST_CONTENT
11
  from audit_log import log_event, hash_summary
12
  from privacy import redact_text
 
94
  parts.append("Assistant:")
95
  return "\n".join(parts)
96
 
97
+ def _sanitize_text(s: str) -> str:
98
+ """
99
+ Strip control characters (except newline/tab) to avoid garbled UI output.
100
+ """
101
+ if not isinstance(s, str):
102
+ return s
103
 + return re2.sub(r'(?V1)[\p{C}--[\n\t]]+', '', s)
104
+
105
  # ---------- Cohere (default path) ----------
106
  def cohere_chat(message, history):
107
  if not USE_HOSTED_COHERE:
108
  return None
109
  try:
110
+ # Create client on demand to avoid init errors in some environments
111
  client = cohere.Client(api_key=COHERE_API_KEY)
112
  prompt = _history_to_prompt(message, history)
113
  resp = client.chat(
 
209
  try:
210
  log_event("user_message", None, {"sizes": {"chars": len(user_msg or "")}})
211
 
212
+ # Safety (input)
213
  safe_in, blocked_in, reason_in = safety_filter(user_msg, mode="input")
214
  if blocked_in:
215
  ans = refusal_reply(reason_in)
216
  return history + [(user_msg, ans)]
217
 
218
+ # Identity short-circuit
219
  if is_identity_query(safe_in, history):
220
  ans = "I am ClarityOps, your strategic decision making AI partner."
221
  return history + [(user_msg, ans)]
222
 
223
+ # Debug slash command: /diag
224
+ if (safe_in or "").strip().lower().startswith("/diag"):
225
+ try:
226
+ chunk_count = len(getattr(_session_rag, "texts", []) or [])
227
+ cols = _session_rag.get_latest_csv_columns()
228
+ sample = _session_rag.retrieve("the", k=2)
229
+ msg = [
230
+ f"Chunks in session: {chunk_count}",
231
+ f"Latest CSV columns: {', '.join(cols) if cols else '<none>'}",
232
+ "Sample retrieved snippets:",
233
+ *(sample or ["<no snippets>"])
234
+ ]
235
+ return history + [(user_msg, "\n\n".join(msg))]
236
+ except Exception as e:
237
+ return history + [(user_msg, f"Diag error: {e}")]
238
+
239
+ # Ingest uploads: returns chunks + artifacts
240
  if uploaded_files_paths:
241
  ing = extract_text_from_files(uploaded_files_paths)
242
 + chunks = ing.get("chunks", []) if isinstance(ing, dict) else (ing or [])
243
  artifacts = ing.get("artifacts", []) if isinstance(ing, dict) else []
244
  if chunks:
245
  _session_rag.add_docs(chunks)
 
247
  _session_rag.register_artifacts(artifacts)
248
  log_event("uploads_added", None, {"chunks": len(chunks), "artifacts": len(artifacts)})
249
 
250
+ # Deterministic CSV "columns/headers" handler
251
  if re.search(r"\b(columns?|headers?)\b", (safe_in or "").lower()):
252
  cols = _session_rag.get_latest_csv_columns()
253
  if cols:
254
  return history + [(user_msg, "Here are the column names from your most recent CSV upload:\n\n- " + "\n- ".join(cols))]
255
 
256
+ # Heuristic: scenario mode nudge if a long case study was pasted
257
+ plain = (safe_in or "").strip().lower()
258
+ looks_like_case = ("background" in plain and "objective" in plain) or ("case study" in plain)
259
+ if looks_like_case and len(plain) > 600:
260
+ safe_in += (
261
+ "\n\nPlease analyze the scenario above using the Expected Output Format: "
262
+ "produce structured recommendations, estimates and assumptions, include tables and bullet points, "
263
+ "and explicitly state how uploaded files (CSV/docs) influenced your estimates."
264
+ )
265
+
266
  # Retrieve from session uploads (text chunks)
267
  session_snips = "\n---\n".join(_session_rag.retrieve(
268
  "diabetes screening Indigenous Métis mobile program cost throughput outcomes logistics bed flow staffing discharge forecast",
 
300
  inputs = build_inputs(tokenizer, augmented_user, history)
301
  out = local_generate(model, tokenizer, inputs, max_new_tokens=MAX_NEW_TOKENS)
302
 
303
+ # Tidy echoes and sanitize
304
  if isinstance(out, str):
305
  for tag in ("Assistant:", "System:", "User:"):
306
  if out.startswith(tag):
307
  out = out[len(tag):].strip()
308
+ out = _sanitize_text(out)
309
 
310
+ # Safety (output)
311
  safe_out, blocked_out, reason_out = safety_filter(out, mode="output")
312
  if blocked_out:
313
  safe_out = refusal_reply(reason_out)
314
 
315
+ # Audit (content-free fingerprints)
316
  log_event("assistant_reply", None, {
317
  **hash_summary("prompt", augmented_user if not PERSIST_CONTENT else ""),
318
  **hash_summary("reply", safe_out if not PERSIST_CONTENT else ""),
 
403
 
404
 
405
 
406
+