Rajan Sharma commited on
Commit
ae93cdb
·
verified ·
1 Parent(s): 3d6a03f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +135 -40
app.py CHANGED
@@ -15,8 +15,10 @@ os.environ.setdefault("GRADIO_TEMP_DIR", "/data/gradio")
15
  os.environ.setdefault("GRADIO_CACHE_DIR", "/data/gradio")
16
  os.environ.pop("TRANSFORMERS_CACHE", None)
17
  for p in ["/data/.cache/huggingface/hub", "/data/gradio"]:
18
- try: os.makedirs(p, exist_ok=True)
19
- except Exception: pass
 
 
20
 
21
  # Optional timezone
22
  try:
@@ -42,31 +44,49 @@ from upload_ingest import extract_text_from_files
42
  from session_rag import SessionRAG
43
  from mdsi_analysis import capacity_projection, cost_estimate, outcomes_summary
44
 
45
- MODEL_ID = os.getenv("MODEL_ID", "CohereLabs/c4ai-command-r7b-12-2024")
 
 
46
  HF_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN") or os.getenv("HF_TOKEN")
 
47
  COHERE_API_KEY = os.getenv("COHERE_API_KEY")
48
  USE_HOSTED_COHERE = bool(COHERE_API_KEY and _HAS_COHERE)
 
 
 
49
 
50
  # ---------- Helpers ----------
51
  def pick_dtype_and_map():
52
- if torch.cuda.is_available(): return torch.float16, "auto"
53
- if torch.backends.mps.is_available(): return torch.float16, {"": "mps"}
 
 
54
  return torch.float32, "cpu"
55
 
56
  def is_identity_query(message, history):
57
  patterns = [
58
- r"\bwho\s+are\s+you\b", r"\bwhat\s+are\s+you\b", r"\bwhat\s+is\s+your\s+name\b",
59
- r"\bwho\s+is\s+this\b", r"\bidentify\s+yourself\b", r"\btell\s+me\s+about\s+yourself\b",
60
- r"\bdescribe\s+yourself\b", r"\band\s+you\s*\?\b", r"\byour\s+name\b", r"\bwho\s+am\s+i\s+chatting\s+with\b"
 
 
 
 
 
 
 
61
  ]
62
  def match(t): return any(re.search(p, (t or "").strip().lower()) for p in patterns)
63
- if match(message): return True
 
64
  if history:
65
  last_user = history[-1][0] if isinstance(history[-1], (list, tuple)) else None
66
- if match(last_user): return True
 
67
  return False
68
 
69
  def _iter_user_assistant(history):
 
70
  for item in (history or []):
71
  if isinstance(item, (list, tuple)):
72
  u = item[0] if len(item) > 0 else ""
@@ -82,28 +102,36 @@ def _history_to_prompt(message, history):
82
  parts.append("Assistant:")
83
  return "\n".join(parts)
84
 
85
- # ---------- Cohere path ----------
86
  _co_client = None
87
  if USE_HOSTED_COHERE:
88
- _co_client = cohere.Client(api_key=COHERE_API_KEY)
89
 
90
  def cohere_chat(message, history):
 
 
 
 
 
91
  try:
92
  prompt = _history_to_prompt(message, history)
93
  resp = _co_client.chat(
94
  model="command-r7b-12-2024",
95
  message=prompt,
96
  temperature=0.3,
97
- max_tokens=900,
98
  )
99
- if hasattr(resp, "text") and resp.text: return resp.text.strip()
100
- if hasattr(resp, "reply") and resp.reply: return resp.reply.strip()
101
- if hasattr(resp, "generations") and resp.generations: return resp.generations[0].text.strip()
102
- return "Sorry, I couldn't parse the response from Cohere."
103
- except Exception as e:
104
- return f"Error calling Cohere API: {e}"
 
 
 
105
 
106
- # ---------- Local model (with accelerate fallback) ----------
107
  @lru_cache(maxsize=1)
108
  def load_local_model():
109
  if not HF_TOKEN:
@@ -111,9 +139,10 @@ def load_local_model():
111
  login(token=HF_TOKEN, add_to_git_credential=False)
112
  dtype, device_map = pick_dtype_and_map()
113
  tok = AutoTokenizer.from_pretrained(
114
- MODEL_ID, token=HF_TOKEN, use_fast=True, model_max_length=8192, padding_side="left", trust_remote_code=True,
 
115
  )
116
- # Try device_map path (needs accelerate). Fallback to manual .to(device) if it fails.
117
  try:
118
  mdl = AutoModelForCausalLM.from_pretrained(
119
  MODEL_ID, token=HF_TOKEN, device_map=device_map,
@@ -130,19 +159,25 @@ def load_local_model():
130
  return mdl, tok
131
 
132
  def build_inputs(tokenizer, message, history):
 
133
  msgs = []
134
  for u, a in _iter_user_assistant(history):
135
  if u: msgs.append({"role": "user", "content": u})
136
  if a: msgs.append({"role": "assistant", "content": a})
137
  msgs.append({"role": "user", "content": message})
138
- return tokenizer.apply_chat_template(msgs, tokenize=True, add_generation_prompt=True, return_tensors="pt")
 
 
139
 
140
- def local_generate(model, tokenizer, input_ids, max_new_tokens=900):
141
  input_ids = input_ids.to(model.device)
142
  with torch.no_grad():
143
  out = model.generate(
144
- input_ids=input_ids, max_new_tokens=max_new_tokens, do_sample=True, temperature=0.3, top_p=0.9,
145
- repetition_penalty=1.15, pad_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id,
 
 
 
146
  )
147
  gen_only = out[0, input_ids.shape[-1]:]
148
  return tokenizer.decode(gen_only, skip_special_tokens=True).strip()
@@ -153,6 +188,7 @@ def _load_snapshot(path=SNAPSHOT_PATH):
153
  with open(path, "r", encoding="utf-8") as f:
154
  return json.load(f)
155
  except Exception:
 
156
  return {
157
  "timestamp": None, "beds_total": 400, "staffed_ratio": 1.0, "occupied_pct": 0.97,
158
  "ed_census": 62, "ed_admits_waiting": 19, "avg_ed_wait_hours": 8,
@@ -164,7 +200,7 @@ def _load_snapshot(path=SNAPSHOT_PATH):
164
 
165
  # ---------- Init retrieval engines ----------
166
  init_retriever()
167
- _session_rag = SessionRAG() # in-memory only
168
 
169
  # ---------- Executive pre-compute (MDSi block) ----------
170
  def _mdsi_block():
@@ -179,17 +215,25 @@ def _mdsi_block():
179
  "outcomes_summary": outcomes
180
  }, indent=2)
181
 
182
- # ---------- Core chat logic ----------
183
  def clarityops_reply(user_msg, history, tz, uploaded_files_paths):
 
 
 
 
 
 
184
  try:
185
  # Audit (content-free)
186
  log_event("user_message", None, {"sizes": {"chars": len(user_msg or "")}})
187
 
 
188
  safe_in, blocked_in, reason_in = safety_filter(user_msg, mode="input")
189
  if blocked_in:
190
  ans = refusal_reply(reason_in)
191
  return history + [(user_msg, ans)]
192
 
 
193
  if is_identity_query(safe_in, history):
194
  ans = "I am ClarityOps, your strategic decision making AI partner."
195
  return history + [(user_msg, ans)]
@@ -199,24 +243,25 @@ def clarityops_reply(user_msg, history, tz, uploaded_files_paths):
199
  items = extract_text_from_files(uploaded_files_paths)
200
  if items:
201
  _session_rag.add_docs(items)
202
- # Audit upload names & sizes only
203
  log_event("uploads_added", None, {"count": len(items)})
204
 
205
  # Retrieve from session uploads
206
  session_snips = "\n---\n".join(_session_rag.retrieve(
207
- "diabetes screening Indigenous Métis mobile program cost throughput outcomes logistics bed flow staffing discharge forecast", k=6
 
208
  ))
209
 
 
210
  snapshot = _load_snapshot()
211
  policy_context = retrieve_context(
212
  "mobile diabetes screening Indigenous community outreach logistics referral pathways cultural safety data governance cost effectiveness outcomes bed management discharge acceleration ambulance offload"
213
  )
214
  computed = compute_operational_numbers(snapshot)
215
 
 
216
  user_lower = (safe_in or "").lower()
217
  mdsi_extra = _mdsi_block() if ("diabetes" in user_lower or "mdsi" in user_lower or "mobile screening" in user_lower) else ""
218
 
219
- # Optionally include long scenario text; redact if persisting later (we don't persist by default)
220
  scenario_block = safe_in if len(safe_in) > 400 else ""
221
  system_preamble = build_system_preamble(
222
  snapshot=snapshot,
@@ -228,13 +273,16 @@ def clarityops_reply(user_msg, history, tz, uploaded_files_paths):
228
 
229
  augmented_user = system_preamble + "\n\nUser question or request:\n" + safe_in
230
 
231
- # Generate
 
232
  if USE_HOSTED_COHERE:
233
  out = cohere_chat(augmented_user, history)
234
- else:
 
 
235
  model, tokenizer = load_local_model()
236
  inputs = build_inputs(tokenizer, augmented_user, history)
237
- out = local_generate(model, tokenizer, inputs, max_new_tokens=900)
238
 
239
  # Tidy echoes
240
  if isinstance(out, str):
@@ -242,6 +290,7 @@ def clarityops_reply(user_msg, history, tz, uploaded_files_paths):
242
  if out.startswith(tag):
243
  out = out[len(tag):].strip()
244
 
 
245
  safe_out, blocked_out, reason_out = safety_filter(out, mode="output")
246
  if blocked_out:
247
  safe_out = refusal_reply(reason_out)
@@ -261,20 +310,43 @@ theme = gr.themes.Soft(primary_hue="teal", neutral_hue="slate", radius_size=gr.t
261
  custom_css = """
262
  :root { --brand-bg: #e6f7f8; --brand-accent: #0d9488; --brand-text: #0f172a; --brand-text-light: #ffffff; }
263
  .gradio-container { background: var(--brand-bg); }
 
 
264
  h1 { color: var(--brand-text); font-weight: 700; font-size: 28px !important; }
265
- .chatbot header, .chatbot .label, .chatbot .label-wrap, .chatbot .top, .chatbot .header, .chatbot > .wrap > header { display: none !important; }
266
- .message.user, .message.bot { background: var(--brand-accent) !important; color: var(--brand-text-light) !important; border-radius: 12px !important; padding: 8px 12px !important; }
 
 
 
 
 
 
 
 
 
 
 
 
 
267
  textarea, input, .gr-input { border-radius: 12px !important; }
268
  """
269
 
270
  # ---------- UI ----------
271
  with gr.Blocks(theme=theme, css=custom_css) as demo:
272
  tz_box = gr.Textbox(visible=False)
273
- demo.load(lambda tz: tz, inputs=[tz_box], outputs=[tz_box], js="() => Intl.DateTimeFormat().resolvedOptions().timeZone")
 
 
 
 
 
 
274
  gr.Markdown("# ClarityOps Augmented Decision AI")
275
 
 
276
  chat = gr.Chatbot(label="", show_label=False, height=700)
277
 
 
278
  with gr.Row():
279
  uploads = gr.Files(
280
  label="Upload docs/images (PDF, DOCX, CSV, PNG, JPG)",
@@ -282,13 +354,20 @@ with gr.Blocks(theme=theme, css=custom_css) as demo:
282
  )
283
 
284
  with gr.Row():
285
- msg = gr.Textbox(label="", show_label=False, placeholder="Type a message… (paste scenarios here too; ClarityOps will adapt)", scale=10)
 
 
 
 
 
286
  send = gr.Button("Send", scale=1)
287
  clear = gr.Button("Clear chat", scale=1)
288
 
 
289
  state_history = gr.State(value=[])
290
  state_uploaded = gr.State(value=[])
291
 
 
292
  def _store_uploads(files, current):
293
  paths = []
294
  for f in (files or []):
@@ -297,17 +376,33 @@ with gr.Blocks(theme=theme, css=custom_css) as demo:
297
 
298
  uploads.change(fn=_store_uploads, inputs=[uploads, state_uploaded], outputs=state_uploaded)
299
 
 
300
  def _on_send(user_msg, history, tz, up_paths):
301
  if not user_msg or not user_msg.strip():
302
  return history, "", history
303
  new_history = clarityops_reply(user_msg.strip(), history or [], tz, up_paths or [])
304
  return new_history, "", new_history
305
 
306
- send.click(fn=_on_send, inputs=[msg, state_history, tz_box, state_uploaded], outputs=[chat, msg, state_history], queue=True)
307
- msg.submit(fn=_on_send, inputs=[msg, state_history, tz_box, state_uploaded], outputs=[chat, msg, state_history], queue=True)
 
 
 
 
 
 
 
 
 
 
308
 
 
309
  clear.click(lambda: ([], "", []), None, [chat, msg, state_history])
310
 
 
 
 
311
  if __name__ == "__main__":
312
  port = int(os.environ.get("PORT", "7860"))
313
  demo.launch(server_name="0.0.0.0", server_port=port, show_api=False, max_threads=8)
 
 
15
  os.environ.setdefault("GRADIO_CACHE_DIR", "/data/gradio")
16
  os.environ.pop("TRANSFORMERS_CACHE", None)
17
  for p in ["/data/.cache/huggingface/hub", "/data/gradio"]:
18
+ try:
19
+ os.makedirs(p, exist_ok=True)
20
+ except Exception:
21
+ pass
22
 
23
  # Optional timezone
24
  try:
 
44
  from session_rag import SessionRAG
45
  from mdsi_analysis import capacity_projection, cost_estimate, outcomes_summary
46
 
47
+ # ---------- Config ----------
48
+ # Local fallback model (lightweight by default). You can override via env.
49
+ MODEL_ID = os.getenv("MODEL_ID", "microsoft/Phi-3-mini-4k-instruct")
50
  HF_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN") or os.getenv("HF_TOKEN")
51
+
52
  COHERE_API_KEY = os.getenv("COHERE_API_KEY")
53
  USE_HOSTED_COHERE = bool(COHERE_API_KEY and _HAS_COHERE)
54
+ COHERE_TIMEOUT_SEC = float(os.getenv("COHERE_TIMEOUT_SEC", "30"))
55
+
56
+ MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "512")) # faster defaults; adjust as needed
57
 
58
  # ---------- Helpers ----------
59
  def pick_dtype_and_map():
60
+ if torch.cuda.is_available():
61
+ return torch.float16, "auto"
62
+ if torch.backends.mps.is_available():
63
+ return torch.float16, {"": "mps"}
64
  return torch.float32, "cpu"
65
 
66
  def is_identity_query(message, history):
67
  patterns = [
68
+ r"\bwho\s+are\s+you\b",
69
+ r"\bwhat\s+are\s+you\b",
70
+ r"\bwhat\s+is\s+your\s+name\b",
71
+ r"\bwho\s+is\s+this\b",
72
+ r"\bidentify\s+yourself\b",
73
+ r"\btell\s+me\s+about\s+yourself\b",
74
+ r"\bdescribe\s+yourself\b",
75
+ r"\band\s+you\s*\?\b",
76
+ r"\byour\s+name\b",
77
+ r"\bwho\s+am\s+i\s+chatting\s+with\b",
78
  ]
79
  def match(t): return any(re.search(p, (t or "").strip().lower()) for p in patterns)
80
+ if match(message):
81
+ return True
82
  if history:
83
  last_user = history[-1][0] if isinstance(history[-1], (list, tuple)) else None
84
+ if match(last_user):
85
+ return True
86
  return False
87
 
88
  def _iter_user_assistant(history):
89
+ # history is a list of (user, assistant) tuples (Chatbot default format)
90
  for item in (history or []):
91
  if isinstance(item, (list, tuple)):
92
  u = item[0] if len(item) > 0 else ""
 
102
  parts.append("Assistant:")
103
  return "\n".join(parts)
104
 
105
+ # ---------- Cohere path (default first; fallback to local on failure) ----------
106
  _co_client = None
107
  if USE_HOSTED_COHERE:
108
+ _co_client = cohere.Client(api_key=COHERE_API_KEY, timeout=COHERE_TIMEOUT_SEC)
109
 
110
  def cohere_chat(message, history):
111
+ """
112
+ Returns text on success, or None to signal fallback to local model.
113
+ """
114
+ if not _co_client:
115
+ return None
116
  try:
117
  prompt = _history_to_prompt(message, history)
118
  resp = _co_client.chat(
119
  model="command-r7b-12-2024",
120
  message=prompt,
121
  temperature=0.3,
122
+ max_tokens=MAX_NEW_TOKENS,
123
  )
124
+ if hasattr(resp, "text") and resp.text:
125
+ return resp.text.strip()
126
+ if hasattr(resp, "reply") and resp.reply:
127
+ return resp.reply.strip()
128
+ if hasattr(resp, "generations") and resp.generations:
129
+ return resp.generations[0].text.strip()
130
+ return None
131
+ except Exception:
132
+ return None
133
 
134
+ # ---------- Local model (accelerate-safe fallback) ----------
135
  @lru_cache(maxsize=1)
136
  def load_local_model():
137
  if not HF_TOKEN:
 
139
  login(token=HF_TOKEN, add_to_git_credential=False)
140
  dtype, device_map = pick_dtype_and_map()
141
  tok = AutoTokenizer.from_pretrained(
142
+ MODEL_ID, token=HF_TOKEN, use_fast=True, model_max_length=8192,
143
+ padding_side="left", trust_remote_code=True,
144
  )
145
+ # Try device_map (needs accelerate); fallback to manual .to(device) if it fails.
146
  try:
147
  mdl = AutoModelForCausalLM.from_pretrained(
148
  MODEL_ID, token=HF_TOKEN, device_map=device_map,
 
159
  return mdl, tok
160
 
161
  def build_inputs(tokenizer, message, history):
162
+ # Convert tuple history to chat template input for HF models
163
  msgs = []
164
  for u, a in _iter_user_assistant(history):
165
  if u: msgs.append({"role": "user", "content": u})
166
  if a: msgs.append({"role": "assistant", "content": a})
167
  msgs.append({"role": "user", "content": message})
168
+ return tokenizer.apply_chat_template(
169
+ msgs, tokenize=True, add_generation_prompt=True, return_tensors="pt"
170
+ )
171
 
172
+ def local_generate(model, tokenizer, input_ids, max_new_tokens=MAX_NEW_TOKENS):
173
  input_ids = input_ids.to(model.device)
174
  with torch.no_grad():
175
  out = model.generate(
176
+ input_ids=input_ids, max_new_tokens=max_new_tokens,
177
+ do_sample=True, temperature=0.3, top_p=0.9,
178
+ repetition_penalty=1.15,
179
+ pad_token_id=tokenizer.eos_token_id,
180
+ eos_token_id=tokenizer.eos_token_id,
181
  )
182
  gen_only = out[0, input_ids.shape[-1]:]
183
  return tokenizer.decode(gen_only, skip_special_tokens=True).strip()
 
188
  with open(path, "r", encoding="utf-8") as f:
189
  return json.load(f)
190
  except Exception:
191
+ # Safe fallback if no snapshot present
192
  return {
193
  "timestamp": None, "beds_total": 400, "staffed_ratio": 1.0, "occupied_pct": 0.97,
194
  "ed_census": 62, "ed_admits_waiting": 19, "avg_ed_wait_hours": 8,
 
200
 
201
  # ---------- Init retrieval engines ----------
202
  init_retriever()
203
+ _session_rag = SessionRAG() # in-memory only; lazy-loads embeddings
204
 
205
  # ---------- Executive pre-compute (MDSi block) ----------
206
  def _mdsi_block():
 
215
  "outcomes_summary": outcomes
216
  }, indent=2)
217
 
218
+ # ---------- Core chat logic (Cohere-first with fallback) ----------
219
  def clarityops_reply(user_msg, history, tz, uploaded_files_paths):
220
+ """
221
+ - user_msg: latest message text
222
+ - history: list[(user, assistant)]
223
+ - tz: timezone str (unused but kept for future features)
224
+ - uploaded_files_paths: list[str] absolute paths of uploaded files
225
+ """
226
  try:
227
  # Audit (content-free)
228
  log_event("user_message", None, {"sizes": {"chars": len(user_msg or "")}})
229
 
230
+ # Safety (input)
231
  safe_in, blocked_in, reason_in = safety_filter(user_msg, mode="input")
232
  if blocked_in:
233
  ans = refusal_reply(reason_in)
234
  return history + [(user_msg, ans)]
235
 
236
+ # Identity short-circuit
237
  if is_identity_query(safe_in, history):
238
  ans = "I am ClarityOps, your strategic decision making AI partner."
239
  return history + [(user_msg, ans)]
 
243
  items = extract_text_from_files(uploaded_files_paths)
244
  if items:
245
  _session_rag.add_docs(items)
 
246
  log_event("uploads_added", None, {"count": len(items)})
247
 
248
  # Retrieve from session uploads
249
  session_snips = "\n---\n".join(_session_rag.retrieve(
250
+ "diabetes screening Indigenous Métis mobile program cost throughput outcomes logistics bed flow staffing discharge forecast",
251
+ k=6
252
  ))
253
 
254
+ # Load daily snapshot + policies + computed ops numbers
255
  snapshot = _load_snapshot()
256
  policy_context = retrieve_context(
257
  "mobile diabetes screening Indigenous community outreach logistics referral pathways cultural safety data governance cost effectiveness outcomes bed management discharge acceleration ambulance offload"
258
  )
259
  computed = compute_operational_numbers(snapshot)
260
 
261
+ # Exec scenario detect (MDSi)
262
  user_lower = (safe_in or "").lower()
263
  mdsi_extra = _mdsi_block() if ("diabetes" in user_lower or "mdsi" in user_lower or "mobile screening" in user_lower) else ""
264
 
 
265
  scenario_block = safe_in if len(safe_in) > 400 else ""
266
  system_preamble = build_system_preamble(
267
  snapshot=snapshot,
 
273
 
274
  augmented_user = system_preamble + "\n\nUser question or request:\n" + safe_in
275
 
276
+ # --- Cohere first ---
277
+ out = None
278
  if USE_HOSTED_COHERE:
279
  out = cohere_chat(augmented_user, history)
280
+
281
+ # --- Fallback to local HF model if Cohere not set or fails ---
282
+ if not out:
283
  model, tokenizer = load_local_model()
284
  inputs = build_inputs(tokenizer, augmented_user, history)
285
+ out = local_generate(model, tokenizer, inputs, max_new_tokens=MAX_NEW_TOKENS)
286
 
287
  # Tidy echoes
288
  if isinstance(out, str):
 
290
  if out.startswith(tag):
291
  out = out[len(tag):].strip()
292
 
293
+ # Safety (output)
294
  safe_out, blocked_out, reason_out = safety_filter(out, mode="output")
295
  if blocked_out:
296
  safe_out = refusal_reply(reason_out)
 
310
  custom_css = """
311
  :root { --brand-bg: #e6f7f8; --brand-accent: #0d9488; --brand-text: #0f172a; --brand-text-light: #ffffff; }
312
  .gradio-container { background: var(--brand-bg); }
313
+
314
+ /* Title */
315
  h1 { color: var(--brand-text); font-weight: 700; font-size: 28px !important; }
316
+
317
+ /* Hide default Chatbot label */
318
+ .chatbot header, .chatbot .label, .chatbot .label-wrap, .chatbot .top, .chatbot .header, .chatbot > .wrap > header {
319
+ display: none !important;
320
+ }
321
+
322
+ /* Chat bubbles */
323
+ .message.user, .message.bot {
324
+ background: var(--brand-accent) !important;
325
+ color: var(--brand-text-light) !important;
326
+ border-radius: 12px !important;
327
+ padding: 8px 12px !important;
328
+ }
329
+
330
+ /* Inputs softer */
331
  textarea, input, .gr-input { border-radius: 12px !important; }
332
  """
333
 
334
  # ---------- UI ----------
335
  with gr.Blocks(theme=theme, css=custom_css) as demo:
336
  tz_box = gr.Textbox(visible=False)
337
+ demo.load(
338
+ lambda tz: tz,
339
+ inputs=[tz_box],
340
+ outputs=[tz_box],
341
+ js="() => Intl.DateTimeFormat().resolvedOptions().timeZone",
342
+ )
343
+
344
  gr.Markdown("# ClarityOps Augmented Decision AI")
345
 
346
+ # Main chat (tuple-format history)
347
  chat = gr.Chatbot(label="", show_label=False, height=700)
348
 
349
+ # Uploads above the input
350
  with gr.Row():
351
  uploads = gr.Files(
352
  label="Upload docs/images (PDF, DOCX, CSV, PNG, JPG)",
 
354
  )
355
 
356
  with gr.Row():
357
+ msg = gr.Textbox(
358
+ label="",
359
+ show_label=False,
360
+ placeholder="Type a message… (paste scenarios here too; ClarityOps will adapt)",
361
+ scale=10
362
+ )
363
  send = gr.Button("Send", scale=1)
364
  clear = gr.Button("Clear chat", scale=1)
365
 
366
+ # State
367
  state_history = gr.State(value=[])
368
  state_uploaded = gr.State(value=[])
369
 
370
+ # Store uploaded file paths in state (persist through session)
371
  def _store_uploads(files, current):
372
  paths = []
373
  for f in (files or []):
 
376
 
377
  uploads.change(fn=_store_uploads, inputs=[uploads, state_uploaded], outputs=state_uploaded)
378
 
379
+ # Send / Enter handlers
380
  def _on_send(user_msg, history, tz, up_paths):
381
  if not user_msg or not user_msg.strip():
382
  return history, "", history
383
  new_history = clarityops_reply(user_msg.strip(), history or [], tz, up_paths or [])
384
  return new_history, "", new_history
385
 
386
+ send.click(
387
+ fn=_on_send,
388
+ inputs=[msg, state_history, tz_box, state_uploaded],
389
+ outputs=[chat, msg, state_history],
390
+ queue=True,
391
+ )
392
+ msg.submit(
393
+ fn=_on_send,
394
+ inputs=[msg, state_history, tz_box, state_uploaded],
395
+ outputs=[chat, msg, state_history],
396
+ queue=True,
397
+ )
398
 
399
+ # Clear chat (keep uploads)
400
  clear.click(lambda: ([], "", []), None, [chat, msg, state_history])
401
 
402
+ # Enable queue to avoid websocket timeouts on first call / heavy loads
403
+ demo = demo.queue(concurrency_count=2, max_size=32)
404
+
405
  if __name__ == "__main__":
406
  port = int(os.environ.get("PORT", "7860"))
407
  demo.launch(server_name="0.0.0.0", server_port=port, show_api=False, max_threads=8)
408
+