Vrda committed on
Commit
6ca9407
·
verified ·
1 Parent(s): a238159

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +118 -30
app.py CHANGED
@@ -9,15 +9,89 @@ import json
9
  import os
10
  import tempfile
11
  import threading
 
12
  from datetime import datetime
13
  from pathlib import Path
14
- from backend import translate_to_english, call_model_a, call_model_b
15
 
16
  FEEDBACK_FILE = Path(__file__).parent / "feedback_data.json"
17
  HF_DATASET_REPO = "Vrda/im-error-check-data"
18
  HF_DATASET_FILE = "feedback_data.json"
19
 
20
- _DEEPSEEK_RESULTS: dict[str, dict] = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  # -------------------------------------------------------------------------
23
  # Feedback persistence (local + HF Hub sync)
@@ -186,12 +260,39 @@ for key, default in [
186
  ("model_b_result", None),
187
  ("translation_latency", 0),
188
  ("total_elapsed", 0),
 
 
189
  ("run_analysis", False),
190
  ("physician_id", ""),
191
  ]:
192
  if key not in st.session_state:
193
  st.session_state[key] = default
194
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
 
196
  def load_sample():
197
  st.session_state.input_text = SAMPLE
@@ -252,24 +353,13 @@ st.button("Analyze", type="primary", on_click=trigger_analysis)
252
  # Run analysis (progressive: show GPT-OSS first, DeepSeek when ready)
253
  # -------------------------------------------------------------------------
254
 
255
- def _run_deepseek_background(session_key: str, english_text: str):
256
- """Background thread: calls DeepSeek and stores result in module-level dict."""
257
- result = call_model_a(english_text)
258
- _DEEPSEEK_RESULTS[session_key] = result
259
-
260
- if "session_key" not in st.session_state:
261
- import uuid
262
- st.session_state.session_key = str(uuid.uuid4())
263
-
264
  if st.session_state.run_analysis and st.session_state.input_text.strip():
265
  st.session_state.run_analysis = False
266
  st.session_state.model_a_result = None
267
  st.session_state.model_b_result = None
268
  st.session_state.total_elapsed = 0
269
- st.session_state._analysis_start = time.time()
270
-
271
- skey = st.session_state.session_key
272
- _DEEPSEEK_RESULTS.pop(skey, None)
273
 
274
  with st.spinner("Translating discharge letter..."):
275
  t0 = time.time()
@@ -278,16 +368,17 @@ if st.session_state.run_analysis and st.session_state.input_text.strip():
278
 
279
  english = st.session_state.translated_text
280
 
281
- thread = threading.Thread(
282
- target=_run_deepseek_background, args=(skey, english), daemon=True
283
- )
284
- thread.start()
285
 
286
  with st.spinner("GPT-OSS-120B responding (~5s)..."):
287
  st.session_state.model_b_result = call_model_b(english)
288
 
289
  st.rerun()
290
 
 
 
291
  # -------------------------------------------------------------------------
292
  # Helper: render a model's output
293
  # -------------------------------------------------------------------------
@@ -405,16 +496,13 @@ if has_any_result:
405
  if st.session_state.model_a_result is not None:
406
  render_model_output(st.session_state.model_a_result, "model-header-a")
407
  else:
408
- @st.fragment(run_every=5)
409
- def _poll_deepseek():
410
- skey = st.session_state.session_key
411
- if skey in _DEEPSEEK_RESULTS:
412
- st.session_state.model_a_result = _DEEPSEEK_RESULTS.pop(skey)
413
- st.session_state.total_elapsed = round(
414
- time.time() - st.session_state._analysis_start, 2
415
- )
416
- st.rerun()
417
- elapsed = round(time.time() - st.session_state._analysis_start)
418
  st.markdown(
419
  '<div style="background:#1e293b; border:2px dashed #475569; '
420
  'border-radius:8px; padding:2rem; text-align:center; color:#e2e8f0;">'
@@ -424,7 +512,7 @@ if has_any_result:
424
  "</div>",
425
  unsafe_allow_html=True,
426
  )
427
- _poll_deepseek()
428
 
429
  # -----------------------------------------------------------------
430
  # Feedback
 
9
  import os
10
  import tempfile
11
  import threading
12
+ from concurrent.futures import ThreadPoolExecutor
13
  from datetime import datetime
14
  from pathlib import Path
15
+ from backend import ModelResult, translate_to_english, call_model_a, call_model_b
16
 
17
  FEEDBACK_FILE = Path(__file__).parent / "feedback_data.json"
18
  HF_DATASET_REPO = "Vrda/im-error-check-data"
19
  HF_DATASET_FILE = "feedback_data.json"
20
 
21
+
22
@st.cache_resource
def get_deepseek_job_manager():
    """Process-wide registry for background DeepSeek jobs.

    Cached with ``st.cache_resource`` so every Streamlit session shares a
    single executor, job table, and lock.
    """
    manager = {}
    manager["executor"] = ThreadPoolExecutor(max_workers=2)
    manager["jobs"] = {}
    manager["lock"] = threading.Lock()
    return manager
29
+
30
+
31
def cleanup_deepseek_jobs(max_age_seconds: int = 1800):
    """Drop job entries older than *max_age_seconds* from the shared registry.

    Run on every script rerun so jobs abandoned by closed sessions do not
    accumulate forever.
    """
    manager = get_deepseek_job_manager()
    cutoff = time.time() - max_age_seconds
    with manager["lock"]:
        expired = [
            job_id
            for job_id, job in manager["jobs"].items()
            if job["created_at"] < cutoff
        ]
        for job_id in expired:
            manager["jobs"].pop(job_id, None)
41
+
42
+
43
def submit_deepseek_job(job_id: str, english_text: str):
    """Schedule a DeepSeek (model A) call on the shared executor under *job_id*."""
    manager = get_deepseek_job_manager()
    job = {
        "future": manager["executor"].submit(call_model_a, english_text),
        "created_at": time.time(),
    }
    with manager["lock"]:
        manager["jobs"][job_id] = job
51
+
52
+
53
def get_deepseek_job_info(job_id: str):
    """Return ``{"created_at", "done"}`` for *job_id*, or ``None`` if unknown.

    Only a snapshot of the job entry is read under the lock; ``Future.done()``
    is a non-blocking status check.
    """
    if job_id:
        manager = get_deepseek_job_manager()
        with manager["lock"]:
            job = manager["jobs"].get(job_id)
        if job:
            return {
                "created_at": job["created_at"],
                "done": job["future"].done(),
            }
    return None
65
+
66
+
67
def consume_deepseek_job_result(job_id: str) -> ModelResult | None:
    """Return the finished DeepSeek result for *job_id* and remove the job.

    Returns ``None`` when the job is unknown, already consumed, or the future
    has not finished yet. A background exception is converted into an
    unsuccessful ``ModelResult`` so the UI can render the error instead of
    crashing the rerun.
    """
    if not job_id:
        return None

    manager = get_deepseek_job_manager()
    # Check-and-remove must happen under a SINGLE lock acquisition: the
    # original split (get under one lock, pop under a later one) let two
    # concurrent reruns both see the future as done and both consume the same
    # job, and let cleanup_deepseek_jobs pop the entry in between.
    with manager["lock"]:
        job = manager["jobs"].get(job_id)
        if not job or not job["future"].done():
            return None
        manager["jobs"].pop(job_id, None)

    # The future is already done, so result() returns (or raises) immediately;
    # no need to hold the lock while retrieving it.
    try:
        return job["future"].result()
    except Exception as exc:  # surface background failures to the caller
        return ModelResult(
            model_name="DeepSeek Reasoner",
            raw_response="",
            success=False,
            error_message=f"Background job failed: {exc}",
            latency_seconds=0.0,
        )
95
 
96
  # -------------------------------------------------------------------------
97
  # Feedback persistence (local + HF Hub sync)
 
260
  ("model_b_result", None),
261
  ("translation_latency", 0),
262
  ("total_elapsed", 0),
263
+ ("analysis_started_at", 0.0),
264
+ ("deepseek_job_id", None),
265
  ("run_analysis", False),
266
  ("physician_id", ""),
267
  ]:
268
  if key not in st.session_state:
269
  st.session_state[key] = default
270
 
271
+ if "session_key" not in st.session_state:
272
+ import uuid
273
+
274
+ st.session_state.session_key = str(uuid.uuid4())
275
+
276
+ cleanup_deepseek_jobs()
277
+
278
+
279
+ @st.fragment(run_every=5)
280
+ def poll_deepseek_job():
281
+ job_id = st.session_state.deepseek_job_id
282
+ if not job_id or st.session_state.model_a_result is not None:
283
+ return
284
+
285
+ result = consume_deepseek_job_result(job_id)
286
+ if result is None:
287
+ return
288
+
289
+ st.session_state.model_a_result = result
290
+ st.session_state.deepseek_job_id = None
291
+ st.session_state.total_elapsed = round(
292
+ time.time() - st.session_state.analysis_started_at, 2
293
+ )
294
+ st.rerun()
295
+
296
 
297
  def load_sample():
298
  st.session_state.input_text = SAMPLE
 
353
  # Run analysis (progressive: show GPT-OSS first, DeepSeek when ready)
354
  # -------------------------------------------------------------------------
355
 
 
 
 
 
 
 
 
 
 
356
  if st.session_state.run_analysis and st.session_state.input_text.strip():
357
  st.session_state.run_analysis = False
358
  st.session_state.model_a_result = None
359
  st.session_state.model_b_result = None
360
  st.session_state.total_elapsed = 0
361
+ st.session_state.analysis_started_at = time.time()
362
+ st.session_state.deepseek_job_id = None
 
 
363
 
364
  with st.spinner("Translating discharge letter..."):
365
  t0 = time.time()
 
368
 
369
  english = st.session_state.translated_text
370
 
371
+ job_id = f"{st.session_state.session_key}:{int(time.time() * 1000)}"
372
+ submit_deepseek_job(job_id, english)
373
+ st.session_state.deepseek_job_id = job_id
 
374
 
375
  with st.spinner("GPT-OSS-120B responding (~5s)..."):
376
  st.session_state.model_b_result = call_model_b(english)
377
 
378
  st.rerun()
379
 
380
+ poll_deepseek_job()
381
+
382
  # -------------------------------------------------------------------------
383
  # Helper: render a model's output
384
  # -------------------------------------------------------------------------
 
496
  if st.session_state.model_a_result is not None:
497
  render_model_output(st.session_state.model_a_result, "model-header-a")
498
  else:
499
+ job_info = get_deepseek_job_info(st.session_state.deepseek_job_id)
500
+ if job_info is None:
501
+ st.warning(
502
+ "DeepSeek job is no longer active. Click `Analyze` to run it again."
503
+ )
504
+ else:
505
+ elapsed = round(time.time() - job_info["created_at"])
 
 
 
506
  st.markdown(
507
  '<div style="background:#1e293b; border:2px dashed #475569; '
508
  'border-radius:8px; padding:2rem; text-align:center; color:#e2e8f0;">'
 
512
  "</div>",
513
  unsafe_allow_html=True,
514
  )
515
+ st.caption("Checking DeepSeek status every 5 seconds.")
516
 
517
  # -----------------------------------------------------------------
518
  # Feedback