Rajan Sharma committed on
Commit 744c807 · verified · 1 Parent(s): f209ae5

Update app.py

Files changed (1)
  1. app.py +198 -65
app.py CHANGED
@@ -1,23 +1,28 @@
-\
 import os, re, json
 from functools import lru_cache

 import gradio as gr
 import torch

+# ---------- Env/cache (quiet deprecation) ----------
 os.environ.setdefault("HF_HOME", "/data/.cache/huggingface")
 os.environ.setdefault("HF_HUB_CACHE", "/data/.cache/huggingface/hub")
 os.environ.setdefault("GRADIO_TEMP_DIR", "/data/gradio")
 os.environ.setdefault("GRADIO_CACHE_DIR", "/data/gradio")
+os.environ.pop("TRANSFORMERS_CACHE", None)  # silence v5 deprecation note
 for p in ["/data/.cache/huggingface/hub", "/data/gradio"]:
-    try: os.makedirs(p, exist_ok=True)
-    except Exception: pass
+    try:
+        os.makedirs(p, exist_ok=True)
+    except Exception:
+        pass

+# ---------- Optional timezone ----------
 try:
-    from zoneinfo import ZoneInfo
+    from zoneinfo import ZoneInfo  # noqa: F401
 except Exception:
-    ZoneInfo = None
+    ZoneInfo = None  # noqa: N816

+# ---------- Optional Cohere ----------
 try:
     import cohere
     _HAS_COHERE = True
@@ -27,6 +32,7 @@ except Exception:
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from huggingface_hub import login

+# ---------- ClarityOps modules ----------
 from safety import safety_filter, refusal_reply
 from retriever import init_retriever, retrieve_context
 from decision_math import compute_operational_numbers
@@ -35,27 +41,40 @@ from upload_ingest import extract_text_from_files
 from session_rag import SessionRAG
 from mdsi_analysis import capacity_projection, cost_estimate, outcomes_summary

+# ---------- Config ----------
 MODEL_ID = os.getenv("MODEL_ID", "CohereLabs/c4ai-command-r7b-12-2024")
 HF_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN") or os.getenv("HF_TOKEN")
 COHERE_API_KEY = os.getenv("COHERE_API_KEY")
 USE_HOSTED_COHERE = bool(COHERE_API_KEY and _HAS_COHERE)

+# ---------- Helpers ----------
 def pick_dtype_and_map():
-    if torch.cuda.is_available(): return torch.float16, "auto"
-    if torch.backends.mps.is_available(): return torch.float16, {"": "mps"}
+    if torch.cuda.is_available():
+        return torch.float16, "auto"
+    if torch.backends.mps.is_available():
+        return torch.float16, {"": "mps"}
     return torch.float32, "cpu"

 def is_identity_query(message, history):
     patterns = [
-        r"\bwho\s+are\s+you\b", r"\bwhat\s+are\s+you\b", r"\bwhat\s+is\s+your\s+name\b",
-        r"\bwho\s+is\s+this\b", r"\bidentify\s+yourself\b", r"\btell\s+me\s+about\s+yourself\b",
-        r"\bdescribe\s+yourself\b", r"\band\s+you\s*\?\b", r"\byour\s+name\b", r"\bwho\s+am\s+i\s+chatting\s+with\b"
+        r"\bwho\s+are\s+you\b",
+        r"\bwhat\s+are\s+you\b",
+        r"\bwhat\s+is\s+your\s+name\b",
+        r"\bwho\s+is\s+this\b",
+        r"\bidentify\s+yourself\b",
+        r"\btell\s+me\s+about\s+yourself\b",
+        r"\bdescribe\s+yourself\b",
+        r"\band\s+you\s*\?\b",
+        r"\byour\s+name\b",
+        r"\bwho\s+am\s+i\s+chatting\s+with\b",
     ]
     def match(t): return any(re.search(p, (t or "").strip().lower()) for p in patterns)
-    if match(message): return True
+    if match(message):
+        return True
     if history:
         last_user = history[-1][0] if isinstance(history[-1], (list, tuple)) else None
-        if match(last_user): return True
+        if match(last_user):
+            return True
     return False

 def _iter_user_assistant(history):
@@ -74,6 +93,7 @@ def _history_to_prompt(message, history):
     parts.append("Assistant:")
     return "\n".join(parts)

+# ---------- Cohere path ----------
 _co_client = None
 if USE_HOSTED_COHERE:
     _co_client = cohere.Client(api_key=COHERE_API_KEY)
@@ -85,7 +105,7 @@ def cohere_chat(message, history):
         model="command-r7b-12-2024",
         message=prompt,
         temperature=0.3,
-        max_tokens=700,
+        max_tokens=900,
     )
     if hasattr(resp, "text") and resp.text: return resp.text.strip()
     if hasattr(resp, "reply") and resp.reply: return resp.reply.strip()
@@ -94,9 +114,11 @@
     except Exception as e:
         return f"Error calling Cohere API: {e}"

+# ---------- Local model ----------
 @lru_cache(maxsize=1)
 def load_local_model():
-    if not HF_TOKEN: raise RuntimeError("HUGGINGFACE_HUB_TOKEN is not set.")
+    if not HF_TOKEN:
+        raise RuntimeError("HUGGINGFACE_HUB_TOKEN is not set.")
     login(token=HF_TOKEN, add_to_git_credential=False)
     dtype, device_map = pick_dtype_and_map()
     tok = AutoTokenizer.from_pretrained(
@@ -116,23 +138,30 @@ def build_inputs(tokenizer, message, history):
         if u: msgs.append({"role": "user", "content": u})
         if a: msgs.append({"role": "assistant", "content": a})
     msgs.append({"role": "user", "content": message})
-    return tokenizer.apply_chat_template(msgs, tokenize=True, add_generation_prompt=True, return_tensors="pt")
+    return tokenizer.apply_chat_template(
+        msgs, tokenize=True, add_generation_prompt=True, return_tensors="pt"
+    )

 def local_generate(model, tokenizer, input_ids, max_new_tokens=900):
     input_ids = input_ids.to(model.device)
     with torch.no_grad():
         out = model.generate(
-            input_ids=input_ids, max_new_tokens=max_new_tokens, do_sample=True, temperature=0.3, top_p=0.9,
-            repetition_penalty=1.15, pad_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id,
+            input_ids=input_ids, max_new_tokens=max_new_tokens,
+            do_sample=True, temperature=0.3, top_p=0.9,
+            repetition_penalty=1.15,
+            pad_token_id=tokenizer.eos_token_id,
+            eos_token_id=tokenizer.eos_token_id,
         )
     gen_only = out[0, input_ids.shape[-1]:]
     return tokenizer.decode(gen_only, skip_special_tokens=True).strip()

+# ---------- Snapshot loader ----------
 def _load_snapshot(path="snapshots/current.json"):
     try:
         with open(path, "r", encoding="utf-8") as f:
             return json.load(f)
     except Exception:
+        # Safe fallback if no snapshot present
         return {
             "timestamp": None, "beds_total": 400, "staffed_ratio": 1.0, "occupied_pct": 0.97,
             "ed_census": 62, "ed_admits_waiting": 19, "avg_ed_wait_hours": 8,
@@ -142,11 +171,12 @@ def _load_snapshot(path="snapshots/current.json"):
             "isolation_needs_waiting": {"contact": 3, "airborne": 1}, "telemetry_needed_waiting": 5
         }

-# Init retriever & session RAG
+# ---------- Init retrieval engines ----------
 init_retriever()
-_session_rag = SessionRAG()
+_session_rag = SessionRAG()  # ephemeral per-session index for uploaded docs/images

-def _mdsi_block() -> str:
+# ---------- Executive pre-compute (MDSi block) ----------
+def _mdsi_block():
     base_capacity = capacity_projection(18, 48, 6)
     cons_capacity = capacity_projection(12, 48, 6)
     opt_capacity = capacity_projection(24, 48, 6)
@@ -160,42 +190,58 @@ def _mdsi_block() -> str:
         "outcomes_summary": outcomes
     }, indent=2)

-def chat_fn(message, history, user_tz, uploaded_files, scenario_text):
+# ---------- Core chat logic ----------
+def clarityops_reply(user_msg, history, tz, uploaded_files_paths):
+    """
+    - user_msg: latest message text
+    - history: list[(user, assistant)]
+    - tz: timezone str (unused but kept for future features)
+    - uploaded_files_paths: list[str] absolute paths of uploaded files
+    """
     try:
-        safe_in, blocked_in, reason_in = safety_filter(message, mode="input")
-        if blocked_in: return refusal_reply(reason_in)
-        if is_identity_query(safe_in, history):
-            return "I am ClarityOps, your strategic decision making AI partner."
-
-        # Ingest uploads
-        filepaths = [f.name if hasattr(f, "name") else f for f in (uploaded_files or [])]
-        if filepaths:
-            items = extract_text_from_files(filepaths)
-            if items: _session_rag.add_docs(items)
+        # Safety (input)
+        safe_in, blocked_in, reason_in = safety_filter(user_msg, mode="input")
+        if blocked_in:
+            return history + [(user_msg, refusal_reply(reason_in))]

-        # Retrieve snippets from session uploads
-        session_snips = "\\n---\\n".join(_session_rag.retrieve(
-            "diabetes screening Indigenous Métis mobile program cost throughput outcomes logistics", k=6
+        # Identity short-circuit
+        if is_identity_query(safe_in, history):
+            return history + [(user_msg, "I am ClarityOps, your strategic decision making AI partner.")]
+
+        # Ingest new uploads into session RAG (ephemeral for this chat)
+        if uploaded_files_paths:
+            items = extract_text_from_files(uploaded_files_paths)
+            if items:
+                _session_rag.add_docs(items)
+
+        # Pull session snippets from uploaded docs/images
+        session_snips = "\n---\n".join(_session_rag.retrieve(
+            "diabetes screening Indigenous Métis mobile program cost throughput outcomes logistics bed flow staffing discharge forecast",
+            k=6
         ))

+        # Load daily snapshot + policies + computed ops numbers
         snapshot = _load_snapshot()
         policy_context = retrieve_context(
-            "mobile diabetes screening Indigenous community outreach logistics referral pathways privacy cultural safety data governance cost effectiveness outcomes"
+            "mobile diabetes screening Indigenous community outreach logistics referral pathways cultural safety data governance cost effectiveness outcomes bed management discharge acceleration ambulance offload"
         )
         computed = compute_operational_numbers(snapshot)

-        mdsi_extra = _mdsi_block() if ("diabetes" in (scenario_text or "").lower() or "mdsi" in (scenario_text or "").lower()) else ""
+        # Smart scenario detection: if the user message itself looks like exec MDSi context, include the pre-compute block
+        user_lower = (safe_in or "").lower()
+        mdsi_extra = _mdsi_block() if ("diabetes" in user_lower or "mdsi" in user_lower or "mobile screening" in user_lower) else ""

         system_preamble = build_system_preamble(
             snapshot=snapshot,
             policy_context=policy_context,
             computed_numbers=computed,
-            scenario_text=(scenario_text or "" ) + (f"\\n\\nExecutive Pre-Computed Blocks:\\n{mdsi_extra}" if mdsi_extra else ""),
+            scenario_text=(safe_in if len(safe_in) > 400 else "") + (f"\n\nExecutive Pre-Computed Blocks:\n{mdsi_extra}" if mdsi_extra else ""),
             session_snips=session_snips
         )

-        augmented_user = system_preamble + "\\n\\nUser question or request:\\n" + safe_in
+        augmented_user = system_preamble + "\n\nUser question or request:\n" + safe_in

+        # Generate
         if USE_HOSTED_COHERE:
             out = cohere_chat(augmented_user, history)
         else:
@@ -203,56 +249,143 @@ def chat_fn(message, history, user_tz, uploaded_files, scenario_text):
             inputs = build_inputs(tokenizer, augmented_user, history)
             out = local_generate(model, tokenizer, inputs, max_new_tokens=900)

+        # Tidy echoes
         if isinstance(out, str):
             for tag in ("Assistant:", "System:", "User:"):
-                if out.startswith(tag): out = out[len(tag):].strip()
+                if out.startswith(tag):
+                    out = out[len(tag):].strip()

+        # Safety (output)
         safe_out, blocked_out, reason_out = safety_filter(out, mode="output")
-        if blocked_out: return refusal_reply(reason_out)
-        return safe_out
+        if blocked_out:
+            safe_out = refusal_reply(reason_out)
+
+        return history + [(user_msg, safe_out)]
     except Exception as e:
-        return f"Error: {e}"
+        return history + [(user_msg, f"Error: {e}")]

+# ---------- Theme & CSS ----------
 theme = gr.themes.Soft(primary_hue="teal", neutral_hue="slate", radius_size=gr.themes.sizes.radius_lg)
 custom_css = """
 :root { --brand-bg: #e6f7f8; --brand-accent: #0d9488; --brand-text: #0f172a; --brand-text-light: #ffffff; }
 .gradio-container { background: var(--brand-bg); }
+
+/* Title */
 h1 { color: var(--brand-text); font-weight: 700; font-size: 28px !important; }
-.chatbot header, .chatbot .label, .chatbot .label-wrap, .chatbot .top, .chatbot .header, .chatbot > .wrap > header { display: none !important; }
-.message.user, .message.bot { background: var(--brand-accent) !important; color: var(--brand-text-light) !important; border-radius: 12px !important; padding: 8px 12px !important; }
+
+/* Hide default Chatbot label */
+.chatbot header, .chatbot .label, .chatbot .label-wrap, .chatbot .top, .chatbot .header, .chatbot > .wrap > header {
+    display: none !important;
+}
+
+/* Chat bubbles */
+.message.user, .message.bot {
+    background: var(--brand-accent) !important;
+    color: var(--brand-text-light) !important;
+    border-radius: 12px !important;
+    padding: 8px 12px !important;
+}
+
+/* Inputs softer */
 textarea, input, .gr-input { border-radius: 12px !important; }
-.examples, .examples .grid { display: flex !important; justify-content: center !important; text-align: center !important; }
 """

+# ---------- UI (single integrated window; uploads at bottom) ----------
 with gr.Blocks(theme=theme, css=custom_css) as demo:
+    # timezone capture (hidden)
     tz_box = gr.Textbox(visible=False)
-    demo.load(lambda tz: tz, inputs=[tz_box], outputs=[tz_box],
-              js="() => Intl.DateTimeFormat().resolvedOptions().timeZone")
+    demo.load(
+        lambda tz: tz,
+        inputs=[tz_box],
+        outputs=[tz_box],
+        js="() => Intl.DateTimeFormat().resolvedOptions().timeZone",
+    )

+    # extra DOM cleanup for some gradio builds
     hide_label_sink = gr.HTML(visible=False)
-    demo.load(fn=lambda: "", inputs=None, outputs=hide_label_sink, js="""
-    () => { const sel = ['.chatbot header','.chatbot .label','.chatbot .label-wrap','.chatbot .top','.chatbot .header','.chatbot > .wrap > header'];
-    sel.forEach(s => document.querySelectorAll(s).forEach(el => el.style.display = 'none')); return ""; } """)
+    demo.load(
+        fn=lambda: "",
+        inputs=None,
+        outputs=hide_label_sink,
+        js="""
+        () => {
+          const sel = [
+            '.chatbot header','.chatbot .label','.chatbot .label-wrap',
+            '.chatbot .top','.chatbot .header','.chatbot > .wrap > header'
+          ];
+          sel.forEach(s => document.querySelectorAll(s).forEach(el => el.style.display = 'none'));
+          return "";
+        }
+        """,
+    )

     gr.Markdown("# ClarityOps Augmented Decision AI")

-    uploads = gr.Files(label="Upload docs/images (PDF, DOCX, CSV, PNG, JPG)", file_types=["file"], file_count="multiple")
-    scenario = gr.Textbox(label="Scenario Context (paste case studies or executive briefs here)",
-                          lines=10, placeholder="Paste scenario text...")
-
-    gr.ChatInterface(
-        fn=chat_fn,
-        type="messages",
-        additional_inputs=[tz_box, uploads, scenario],
-        chatbot=gr.Chatbot(label="", show_label=False, type="messages", height=700),
-        examples=[
-            ["What are the symptoms of hypertension?"],
-            ["What are common drug interactions with aspirin?"],
-            ["What are the warning signs of diabetes?"],
-        ],
-        cache_examples=False,
+    # Main chat area (tuple-format history, matching clarityops_reply)
+    chat = gr.Chatbot(label="", show_label=False, type="tuples", height=700)
+
+    # ---- Bottom bar: uploads + message box + send/clear ----
+    with gr.Row():
+        uploads = gr.Files(
+            label="Upload docs/images (PDF, DOCX, CSV, PNG, JPG)",
+            file_types=["file"],
+            file_count="multiple",
+            # keep compact footprint
+            height=68
+        )
+
+    with gr.Row():
+        msg = gr.Textbox(placeholder="Type a message… (paste scenarios here too; ClarityOps will adapt)", scale=10)
+        send = gr.Button("Send", scale=1)
+        clear = gr.Button("Clear chat", scale=1)
+
+    # States
+    state_history = gr.State(value=[])
+    state_uploaded = gr.State(value=[])
+
+    # When user selects files, store their paths in state (so they persist across turns)
+    def _store_uploads(files, current):
+        paths = []
+        for f in (files or []):
+            # gradio Files returns tempfile objects with .name
+            paths.append(getattr(f, "name", None) or f)
+        return (current or []) + paths
+
+    uploads.change(fn=_store_uploads, inputs=[uploads, state_uploaded], outputs=state_uploaded)
+
+    # Send message -> compute reply -> update chat
+    def _on_send(user_msg, history, tz, up_paths):
+        if not user_msg or not user_msg.strip():
+            return history, ""  # no-op
+        new_history = clarityops_reply(user_msg.strip(), history or [], tz, up_paths or [])
+        return new_history, ""
+
+    send.click(
+        fn=_on_send,
+        inputs=[msg, state_history, tz_box, state_uploaded],
+        outputs=[chat, msg],
+        queue=True,
    )

+    # Also allow pressing Enter inside the textbox
+    msg.submit(
+        fn=_on_send,
+        inputs=[msg, state_history, tz_box, state_uploaded],
+        outputs=[chat, msg],
+        queue=True,
+    )
+
+    # Keep Chatbot history state in sync whenever it updates
+    chat.change(lambda h: h, inputs=chat, outputs=state_history)
+
+    # Clear chat (keeps uploads so you can keep referencing docs)
+    def _clear_chat():
+        return [], []
+    clear.click(lambda: [], None, chat)
+    # If you also want to clear uploads, uncomment below:
+    # clear.click(_clear_chat, None, [chat, state_uploaded])
+
 if __name__ == "__main__":
     port = int(os.environ.get("PORT", "7860"))
     demo.launch(server_name="0.0.0.0", server_port=port, show_api=False, max_threads=8)
+
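
For a quick sanity check of the new reply path outside the UI, something along these lines works (a minimal sketch, not part of the commit; it assumes app.py's sibling modules such as safety and retriever are importable and that COHERE_API_KEY is set so the hosted path is taken and load_local_model() is never triggered; the timezone string is arbitrary since clarityops_reply ignores it):

# smoke_test.py — hypothetical check of the tuple-format history contract.
# Importing app runs init_retriever() and builds the Blocks UI, but does not
# start the server, because demo.launch() is guarded by __main__.
from app import clarityops_reply

history = []  # list of (user, assistant) tuples
history = clarityops_reply("Who are you?", history, "America/Edmonton", [])
print(history[-1][1])  # identity short-circuit answers without any model call

The identity check fires before retrieval and generation, so this exercises the input safety filter and the new history shape without any network call to Cohere.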