File size: 49,561 Bytes
793d027
de81512
 
793d027
 
91b591f
6bee373
3f3989a
 
a2ccb82
6bee373
de81512
793d027
 
 
 
 
6bee373
 
 
 
 
 
 
 
 
 
5a0599c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de81512
3f3989a
 
 
 
de81512
 
3f3989a
b1f80be
 
ac6f3d5
b1f80be
ac6f3d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de81512
ba2715f
674163f
 
 
 
 
 
 
 
 
 
 
 
 
74922ae
674163f
 
 
 
 
 
 
 
 
 
 
 
793d027
 
a2ccb82
 
 
793d027
 
 
 
de81512
a2ccb82
de81512
a2ccb82
de81512
a2ccb82
3d5c92f
a2ccb82
3d5c92f
 
 
 
a2ccb82
3d5c92f
 
 
 
 
a2ccb82
3d5c92f
a2ccb82
3d5c92f
 
 
a2ccb82
 
3d5c92f
de81512
3d5c92f
de81512
a2ccb82
3d5c92f
 
 
 
 
 
 
 
a2ccb82
3d5c92f
 
a2ccb82
3d5c92f
a2ccb82
3d5c92f
 
de81512
a2ccb82
3d5c92f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2ccb82
3d5c92f
a2ccb82
3d5c92f
 
 
 
a2ccb82
3d5c92f
 
 
 
 
 
 
 
 
 
 
 
 
a2ccb82
3d5c92f
a2ccb82
3d5c92f
 
a2ccb82
de81512
793d027
de81512
a2ccb82
de81512
 
 
 
a2ccb82
 
91b591f
de81512
 
a2ccb82
de81512
a2ccb82
de81512
5d82131
 
 
 
 
 
 
 
 
 
de81512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d82131
 
de81512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d82131
de81512
5d82131
de81512
5d82131
de81512
 
 
 
5d82131
de81512
 
5d82131
 
de81512
5d82131
de81512
 
 
5d82131
 
de81512
5d82131
 
 
de81512
5d82131
de81512
3d5c92f
de81512
 
5d82131
 
de81512
 
 
 
 
 
 
5d82131
 
de81512
 
 
 
 
 
 
 
 
 
5d82131
 
de81512
5d82131
 
 
de81512
5d82131
 
de81512
5d82131
de81512
 
5d82131
de81512
 
 
5d82131
 
de81512
 
 
 
 
 
 
 
ba2715f
91b591f
de81512
91b591f
de81512
91b591f
 
de81512
 
 
 
91b591f
 
 
 
de81512
91b591f
 
 
a2ccb82
 
91b591f
 
a2ccb82
 
 
de81512
 
a2ccb82
91b591f
 
 
 
 
de81512
91b591f
 
 
 
 
 
a2ccb82
91b591f
 
 
 
de81512
 
91b591f
 
793d027
 
de81512
793d027
de81512
 
 
 
 
 
 
 
 
 
a2ccb82
793d027
 
de81512
 
793d027
 
de81512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2ccb82
de81512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d71e88
de81512
 
 
 
674163f
6bee373
 
 
de81512
6bee373
 
de81512
 
91b591f
de81512
 
 
 
793d027
de81512
793d027
 
de81512
793d027
de81512
 
 
91b591f
a2ccb82
de81512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2ccb82
de81512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2ccb82
793d027
de81512
 
 
 
 
 
d902c1a
 
 
de81512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d71e88
de81512
6d71e88
 
 
 
 
 
 
 
 
de81512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d5c92f
 
 
 
de81512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
793d027
de81512
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
"""
AMR-Guard — Gradio Interface (ZeroGPU compatible)
Infection Lifecycle Orchestrator · Multi-Agent Clinical Decision Support
"""

import json
import logging
import os
import subprocess
import sys
import traceback
from io import BytesIO
from pathlib import Path

# Repository root (the directory holding this file). It is prepended to
# sys.path so project packages such as ``src`` import correctly regardless of
# the working directory the app is launched from.
PROJECT_ROOT = Path(__file__).parent
sys.path.insert(0, os.fspath(PROJECT_ROOT))

# Send all log records to stdout at INFO before the remaining imports run.
# ``force=True`` drops any handlers an earlier import attached to the root
# logger, so this configuration always takes effect.
_LOG_FORMAT = "%(asctime)s %(levelname)s %(name)s: %(message)s"
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format=_LOG_FORMAT,
    force=True,
)
logger = logging.getLogger(__name__)

# ── huggingface_hub compatibility shim ───────────────────────────────────────
# Older gradio versions (pre-5.7) import HfFolder from huggingface_hub in
# oauth.py. HfFolder was removed in huggingface_hub >= 0.25. Patch it back
# in-memory before importing gradio so the old oauth.py can find it.
# Only token *reading* is functional; save/delete are no-ops, so this shim never
# writes to or removes anything from the token store.
# If huggingface_hub itself is not installed, the import in the except branch
# raises and startup fails at this point.
try:
    from huggingface_hub import HfFolder as _check  # noqa: F401
except ImportError:
    import huggingface_hub as _hfh

    class _HfFolder:
        """Minimal stand-in for the removed ``huggingface_hub.HfFolder``."""

        @staticmethod
        def get_token():
            # Prefer an explicit HF_TOKEN env var; otherwise defer to
            # huggingface_hub.get_token().
            return os.environ.get("HF_TOKEN") or _hfh.get_token()

        @staticmethod
        def save_token(token: str) -> None:  # noqa: ARG004
            # Intentionally a no-op: the shim never persists tokens.
            pass

        @staticmethod
        def delete_token() -> None:
            # Intentionally a no-op: the shim never removes tokens.
            pass

    # Install the stand-in where the old oauth.py looks it up.
    _hfh.HfFolder = _HfFolder

# ── HuggingFace Spaces: auto-build knowledge base on first boot ───────────────
# Knowledge-base location. MEDIC_DATA_DIR overrides the data directory; a
# relative value is resolved under PROJECT_ROOT, while an absolute one replaces
# it (pathlib ``/`` semantics).
_DB_PATH = PROJECT_ROOT / os.getenv("MEDIC_DATA_DIR", "data") / "amr_guard.db"
if os.environ.get("SPACE_ID") and not _DB_PATH.exists():
    # First boot on a HuggingFace Space: build the knowledge base by running
    # setup_demo.py in a child interpreter. check=False keeps a failed build
    # from aborting startup, but the exit code is logged so the failure is not
    # silently swallowed.
    logger.info("Knowledge base not found at %s; running setup_demo.py", _DB_PATH)
    _setup_proc = subprocess.run(
        [sys.executable, str(PROJECT_ROOT / "setup_demo.py")], check=False
    )
    if _setup_proc.returncode != 0:
        logger.warning(
            "setup_demo.py exited with code %d; knowledge base at %s may be missing or incomplete",
            _setup_proc.returncode,
            _DB_PATH,
        )

import gradio as gr
import pandas as pd

# ── Gradio boolean-schema safety patch ───────────────────────────────────────
# Gradio <5.7 walks JSON Schemas and does `if "const" in schema:` without
# guarding against boolean schemas (valid in JSON Schema spec but not a dict).
# sdk_version is now >=5.25.0 (bug fixed upstream) but keep this as a guard.
#
# Both patches below are best-effort. Each wraps a schema helper so that any
# non-dict schema short-circuits to "other" instead of reaching code that
# assumes a dict. Failures are deliberately ignored so that an unexpected
# gradio layout never blocks app startup.
# NOTE(review): confirm the pinned gradio version resolves these helpers via
# `gradio.utils` / `gradio.route_utils` at call time. If the schema walker
# lives (and is looked up) elsewhere, these patches are silent no-ops.
try:
    import gradio.utils as _gr_utils
    # Patch only if this gradio version exposes `get_type` in gradio.utils.
    _orig_get_type = getattr(_gr_utils, "get_type", None)
    if _orig_get_type:
        def _safe_get_type(schema, *a, **kw):
            if not isinstance(schema, dict):
                return "other"
            return _orig_get_type(schema, *a, **kw)
        _gr_utils.get_type = _safe_get_type
except Exception:
    pass
try:
    import gradio.route_utils as _gr_ru
    for _fn_name in ("get_type", "_json_schema_to_python_type", "json_schema_to_python_type"):
        _fn = getattr(_gr_ru, _fn_name, None)
        if _fn:
            # `_f=_fn` binds the current function when the wrapper is defined.
            # This avoids the late-binding closure bug in which every wrapper
            # would call the last `_fn` seen by the loop.
            def _safe_fn(schema, *a, _f=_fn, **kw):
                if not isinstance(schema, dict):
                    return "other"
                return _f(schema, *a, **kw)
            setattr(_gr_ru, _fn_name, _safe_fn)
except Exception:
    pass

from src.config import get_settings
from src.form_config import CREATININE_PROMINENT_SITES, SITE_SPECIFIC_FIELDS, SUSPECTED_SOURCE_OPTIONS
from src.loader import run_inference  # noqa: F401 – triggers spaces import / ZeroGPU registration at startup

# ── Single GPU session for the full multi-agent pipeline ──────────────────────
# run_inference reuses lru_cache'd model weights, and ZeroGPU releases GPU
# memory whenever a @spaces.GPU function returns. Giving each inference call
# its own GPU session would therefore invalidate the cached model between
# agents, and model.generate() would hang on freed CUDA memory. Running the
# *whole* pipeline inside one session keeps the CUDA context alive for all four
# agents, so the model is loaded once and stays valid for the entire run.


def _run_pipeline_gpu(patient_data: dict, labs_raw_text):
    """Run the full multi-agent pipeline via ``src.graph.run_pipeline``."""
    from src.graph import run_pipeline
    return run_pipeline(patient_data, labs_raw_text)


# On a HuggingFace Space, wrap the pipeline in one ZeroGPU session when the
# `spaces` package is importable. Anywhere else (or if `spaces` is missing) the
# plain function above is used unchanged.
if os.environ.get("SPACE_ID"):
    try:
        import spaces as _spaces_ui
        _run_pipeline_gpu = _spaces_ui.GPU(duration=200)(_run_pipeline_gpu)
    except ImportError:
        pass

from src.tools import (
    calculate_mic_trend,
    get_empirical_therapy_guidance,
    get_most_effective_antibiotics,
    interpret_mic_value,
    screen_antibiotic_safety,
    search_clinical_guidelines,
)

# ── CSS ────────────────────────────────────────────────────────────────────────

# Global stylesheet for the app. The HTML builders in this module emit these
# class names (med-banner, badge-high/-moderate/-low/-info, rx-card, ...).
CSS = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');
body, .gradio-container { font-family: 'Inter', sans-serif !important; }

/* ── Banner ── */
.med-banner {
    background: linear-gradient(135deg, #020d1f 0%, #0b2545 55%, #102f62 100%);
    padding: 24px 32px; border-radius: 14px; margin-bottom: 22px;
    border: 1px solid #1e4a80;
    box-shadow: 0 0 48px rgba(26,74,138,0.45), inset 0 1px 0 rgba(158,196,240,0.12);
}
.med-banner h1 {
    color: #ffffff; font-size: 1.95rem; font-weight: 700; margin: 0;
    text-shadow: 0 0 24px rgba(96,196,255,0.45);
}
.med-banner p  { color: #7eb8e8; font-size: 0.95rem; margin: 5px 0 0; }

/* ── Section titles ── */
.section-title {
    font-size: 0.8rem; font-weight: 700; color: #60b4ff;
    border-bottom: 1px solid #1e3f72; padding-bottom: 6px; margin: 18px 0 13px;
    text-transform: uppercase; letter-spacing: 0.1em;
}

/* ── Stat cards ── */
.stat-cards {
    display: grid; grid-template-columns: repeat(4, 1fr); gap: 16px; margin-bottom: 22px;
}
.stat-card {
    background: linear-gradient(160deg, #0b1e3d 0%, #0e2a56 100%);
    border: 1px solid #1e4a80; border-top: 3px solid #3b82f6;
    border-radius: 11px; padding: 18px 20px; text-align: center;
    box-shadow: 0 4px 18px rgba(0,0,0,0.35);
}
.stat-card .label {
    color: #7eaadb; font-size: 0.78rem; font-weight: 600;
    text-transform: uppercase; letter-spacing: 0.05em;
}
.stat-card .value { color: #60c8ff; font-size: 1.65rem; font-weight: 700; margin-top: 5px; }
.stat-card .sub   { color: #a8cce8; font-size: 0.75rem; margin-top: 3px; }

/* ── Agent steps ── */
.agent-step {
    background: linear-gradient(135deg, #091a36 0%, #0d2450 100%);
    border: 1px solid #1e4278; border-left: 4px solid #3b82f6;
    border-radius: 8px; padding: 14px 16px; margin-bottom: 10px;
}
.agent-step .num  { color: #60b4ff; font-weight: 700; font-size: 0.82rem; letter-spacing: 0.04em; }
.agent-step .name { color: #dceeff; font-weight: 600; }
.agent-step .desc { color: #8ab4d8; font-size: 0.85rem; margin-top: 4px; }

/* ── Status badges — dark backgrounds, high-contrast text ── */
.badge-high {
    background: #1e0707; border-left: 4px solid #dc2626; color: #fca5a5;
    padding: 10px 14px; border-radius: 7px; margin-bottom: 6px;
}
.badge-moderate {
    background: #1c1200; border-left: 4px solid #d97706; color: #fcd34d;
    padding: 10px 14px; border-radius: 7px; margin-bottom: 6px;
}
.badge-low {
    background: #021a0e; border-left: 4px solid #16a34a; color: #86efac;
    padding: 10px 14px; border-radius: 7px; margin-bottom: 6px;
}
.badge-info {
    background: #071428; border-left: 4px solid #2563eb; color: #93c5fd;
    padding: 10px 14px; border-radius: 7px; margin-bottom: 6px;
}

/* ── Prescription card ── */
.rx-card {
    background: linear-gradient(145deg, #081730 0%, #0c2248 100%);
    border: 1px solid #1e4a80; border-radius: 12px;
    padding: 24px 26px; font-size: 0.9rem; line-height: 1.75; color: #cce3ff;
    box-shadow: 0 6px 28px rgba(0,0,0,0.45), 0 0 0 1px rgba(59,130,246,0.18);
}
.rx-card .rx-symbol {
    font-size: 2.2rem; color: #60c8ff; font-weight: 700;
    text-shadow: 0 0 14px rgba(96,200,255,0.55);
}
.rx-card .rx-drug   { font-size: 1.25rem; font-weight: 700; color: #ffffff; }
.rx-card strong     { color: #a8d4ff; }
.rx-card ul         { color: #cce3ff; }

/* ── Badge child elements inherit color ── */
.badge-high strong,   .badge-high em,   .badge-high span   { color: inherit; }
.badge-moderate strong,.badge-moderate em,.badge-moderate span { color: inherit; }
.badge-low strong,    .badge-low em,    .badge-low span    { color: inherit; }
.badge-info strong,   .badge-info em,   .badge-info span   { color: inherit; }

/* ── Disclaimer ── */
.disclaimer {
    background: #150e00; border: 1px solid #78450e; border-radius: 8px;
    padding: 12px 16px; font-size: 0.78rem; color: #fbbf24; margin-top: 20px;
}
"""

# Static page header, styled by the `.med-banner` rules in CSS.
BANNER_HTML = """
<div class="med-banner">
  <div>
    <h1>⚕ AMR-Guard</h1>
    <p>Infection Lifecycle Orchestrator &nbsp;·&nbsp; Multi-Agent Clinical Decision Support</p>
  </div>
</div>
"""

# Selectable infection sites. update_site_ui emits one visibility update per
# entry, in this order, so the order must match the per-site field groups
# wired as its outputs.
INFECTION_SITES = ["urinary", "respiratory", "bloodstream", "skin", "intra-abdominal", "CNS", "other"]


# ── HTML result builders ───────────────────────────────────────────────────────

def _parse_notes(raw):
    """Decode an agent's notes field into Python data, or return None.

    Empty values and the two known "nothing to report" placeholder strings map
    to None. Values that are already a dict or list pass through untouched.
    Anything else is JSON-decoded, and any decoding failure also yields None.
    """
    if not raw:
        return None
    if raw in ("No lab data provided", "No MIC data available for trend analysis"):
        return None
    if isinstance(raw, (dict, list)):
        return raw
    try:
        decoded = json.loads(raw)
    except Exception:
        decoded = None
    return decoded


def _build_rec_html(result: dict) -> str:
    """Render the pharmacologist's recommendation as the ℞ card.

    Reads ``result["recommendation"]`` and returns an info badge when it is
    missing or empty. Blank or missing dose, route, frequency and duration are
    shown as "—". Every model-supplied value is HTML-escaped, so text such as
    ``<`` in a rationale is displayed literally rather than parsed as markup.
    A single reference given as a plain string is listed as one item instead
    of being split into characters.
    """
    from html import escape

    rec = result.get("recommendation") or {}
    if not rec:
        return '<div class="badge-info">No recommendation generated.</div>'

    def text(key, default="—"):
        # Escaped display text for rec[key]; None / "" fall back to *default*.
        value = rec.get(key)
        return default if value is None or value == "" else escape(str(value))

    primary   = text("primary_antibiotic")
    dose      = text("dose")
    route     = text("route")
    freq      = text("frequency")
    duration  = text("duration")
    alt       = text("backup_antibiotic", "")
    rationale = text("rationale", "")
    refs      = rec.get("references") or []
    if isinstance(refs, str):
        refs = [refs]

    alt_html = f"<br><strong>Alternative:</strong> {alt}" if alt else ""
    rat_html = f"<br><br><strong>Clinical rationale</strong><br>{rationale}" if rationale else ""
    ref_html = ""
    if refs:
        items = "".join(f"<li>{escape(str(r))}</li>" for r in refs)
        ref_html = f"<br><strong>References</strong><ul style='margin:4px 0 0 16px'>{items}</ul>"
    return f"""
<div class="rx-card">
  <div class="rx-symbol">℞</div>
  <div class="rx-drug">{primary}</div><br>
  <strong>Dose:</strong> {dose} &nbsp;·&nbsp;
  <strong>Route:</strong> {route} &nbsp;·&nbsp;
  <strong>Frequency:</strong> {freq} &nbsp;·&nbsp;
  <strong>Duration:</strong> {duration}
  {alt_html}{rat_html}{ref_html}
</div>"""


def _build_intake_html(result: dict) -> str:
    """Render the Intake Historian panel.

    Uses ``result["intake_notes"]`` (decoded via ``_parse_notes``) when it is a
    dict. Otherwise it falls back to showing only the top-level
    ``creatinine_clearance_ml_min``, or an info badge when neither is present.
    All model-supplied text is HTML-escaped. A non-numeric CrCl is shown as
    text instead of crashing. Non-string severity/pathway values and a risk
    factor list given as a single string are tolerated.
    """
    from html import escape

    intake = _parse_notes(result.get("intake_notes", ""))
    crcl   = result.get("creatinine_clearance_ml_min")

    def fmt_crcl(value):
        # Numeric CrCl is formatted to one decimal; anything else is shown as-is.
        try:
            return f"{float(value):.1f} mL/min"
        except (TypeError, ValueError):
            return escape(str(value))

    if not isinstance(intake, dict):
        if crcl:
            return f"<strong>CrCl:</strong> {fmt_crcl(crcl)}"
        return '<div class="badge-info">Intake summary not available.</div>'

    html    = ""
    v       = crcl or intake.get("creatinine_clearance_ml_min", 0)
    sev     = intake.get("infection_severity", "")
    pathway = intake.get("recommended_stage", "")
    cells   = ""
    if v:
        cells += f"<td style='padding:8px 16px 8px 0'><strong>CrCl</strong><br>{fmt_crcl(v)}</td>"
    if sev:
        cells += f"<td style='padding:8px 16px'><strong>Severity</strong><br>{escape(str(sev).capitalize())}</td>"
    if pathway:
        cells += f"<td style='padding:8px 16px'><strong>Pathway</strong><br>{escape(str(pathway).capitalize())}</td>"
    if cells:
        html += f"<table style='margin-bottom:12px'><tr>{cells}</tr></table>"

    summary = intake.get("patient_summary")
    if summary:
        html += f'<div class="badge-info">{escape(str(summary))}</div>'
    if intake.get("renal_dose_adjustment_needed"):
        html += '<div class="badge-moderate" style="margin-top:8px">⚠ Renal dose adjustment required</div>'

    risk_factors = intake.get("identified_risk_factors")
    if risk_factors:
        if isinstance(risk_factors, str):
            risk_factors = [risk_factors]  # one factor, not one per character
        items = "".join(f"<li>{escape(str(rf))}</li>" for rf in risk_factors)
        html += f"<br><strong>Identified risk factors</strong><ul style='margin:4px 0 0 16px'>{items}</ul>"
    return html


def _build_lab_html_and_df(result: dict) -> tuple[str, pd.DataFrame]:
    """Render the lab-results panel and the susceptibility table.

    Reads ``result["vision_notes"]`` (extracted lab report) and
    ``result["trend_notes"]`` (MIC trend assessment), each decoded via
    ``_parse_notes``. Returns ``(html, df)``, where ``df`` holds one row per
    susceptibility entry and is empty when none were extracted.

    Model-derived text is HTML-escaped. Malformed entries (non-dict organisms or
    susceptibility rows, non-numeric confidence, missing risk level) are skipped
    or shown as text rather than raising. A trend item whose risk level is not
    HIGH, MODERATE or LOW gets a neutral info badge, never the green "low risk"
    style.
    """
    from html import escape

    vision = _parse_notes(result.get("vision_notes", ""))
    trend  = _parse_notes(result.get("trend_notes", ""))
    html   = ""
    df     = pd.DataFrame()

    if vision is None:
        html += '<div class="badge-info">No lab data processed. Provide lab results to activate the targeted pathway.</div>'
    else:
        v = vision if isinstance(vision, dict) else {}
        if v.get("specimen_type"):
            html += f"<strong>Specimen:</strong> {escape(str(v['specimen_type']).capitalize())}<br>"

        conf = None
        if v.get("extraction_confidence") is not None:
            try:
                conf = float(v["extraction_confidence"])
            except (TypeError, ValueError):
                conf = None  # non-numeric confidence: omit the badge rather than crash
        if conf is not None:
            color = "#86efac" if conf >= 0.85 else "#fcd34d" if conf >= 0.6 else "#fca5a5"
            html += (f'<div class="badge-info">Extraction confidence: '
                     f'<span style="color:{color};font-weight:700">{conf:.0%}</span></div>')

        orgs = v.get("identified_organisms") or []
        if isinstance(orgs, (str, dict)):
            orgs = [orgs]  # a single organism, not an iterable of them
        org_items = []
        for o in orgs:
            if isinstance(o, dict):
                name, significance = o.get("organism_name", "?"), o.get("significance", "")
            else:
                name, significance = o, ""
            suffix = f" — {escape(str(significance))}" if significance else ""
            org_items.append(f"<li><strong>{escape(str(name))}</strong>{suffix}</li>")
        if org_items:
            items = "".join(org_items)
            html += f"<br><strong>Identified organisms</strong><ul style='margin:4px 0 0 16px'>{items}</ul>"

        sus = v.get("susceptibility_results") or []
        if isinstance(sus, dict):
            sus = [sus]
        rows = [
            {
                "Organism": e.get("organism", ""),
                "Antibiotic": e.get("antibiotic", ""),
                "MIC (mg/L)": "" if e.get("mic_value") is None else str(e.get("mic_value")),
                "Result": e.get("interpretation", ""),
            }
            for e in sus
            if isinstance(e, dict)
        ]
        if rows:
            df = pd.DataFrame(rows)

    # Explicit risk level -> (badge class, icon). Anything else, including the
    # "UNKNOWN" fallback for a missing risk_level, is shown neutrally.
    risk_styles = {
        "HIGH": ("badge-high", "🚨"),
        "MODERATE": ("badge-moderate", "⚠"),
        "LOW": ("badge-low", "✓"),
    }
    if trend:
        html += "<hr><strong>MIC Trend Analysis</strong><br>"
        items = trend if isinstance(trend, list) else [trend]
        for item in items:
            if not isinstance(item, dict):
                html += f"<p>{escape(str(item))}</p>"
                continue
            risk = str(item.get("risk_level") or "UNKNOWN").upper()
            css, icon = risk_styles.get(risk, ("badge-info", "?"))
            org = item.get("organism") or ""
            ab = item.get("antibiotic") or ""
            label = f"{org} / {ab} — " if (org or ab) else ""
            advice = item.get("recommendation") or ""
            html += (f'<div class="{css}">{icon} <strong>{escape(label + risk)}</strong><br>'
                     f'<span style="font-size:0.88rem">{escape(str(advice))}</span></div>')
    return html, df


def _build_safety_html(result: dict) -> str:
    """Render safety warnings and pipeline errors as badges.

    With no warnings, a green "no safety concerns" badge is shown instead. Each
    entry of ``result["errors"]`` is appended as a red "Error:" badge. Missing
    or ``None`` lists are treated as empty, a single string is treated as one
    entry, and all text is HTML-escaped.
    """
    from html import escape

    warnings = result.get("safety_warnings") or []
    errors   = result.get("errors") or []
    if isinstance(warnings, str):
        warnings = [warnings]
    if isinstance(errors, str):
        errors = [errors]

    if warnings:
        html = "".join(f'<div class="badge-high">⚠ {escape(str(w))}</div>' for w in warnings)
    else:
        html = '<div class="badge-low">✓ No safety concerns identified.</div>'
    html += "".join(
        f'<div class="badge-high" style="margin-top:6px">Error: {escape(str(e))}</div>' for e in errors
    )
    return html


def _demo_result(patient_data: dict, has_labs: bool) -> dict:
    """Build a canned pipeline result shaped like the real pipeline's output.

    The intake summary is personalised from *patient_data*. Everything else is
    a fixed Ciprofloxacin/E. coli scenario. When *has_labs* is true, the result
    is tagged "targeted" and also carries canned ``vision_notes`` and
    ``trend_notes`` JSON strings; otherwise it is tagged "empirical".
    """
    stage = "targeted" if has_labs else "empirical"
    demo_crcl = 58.3

    summary = "{}-year-old {} · {}".format(
        patient_data.get("age_years"),
        patient_data.get("sex"),
        patient_data.get("suspected_source", "infection"),
    )
    intake = {
        "patient_summary": summary,
        "creatinine_clearance_ml_min": demo_crcl,
        "renal_dose_adjustment_needed": True,
        "identified_risk_factors": patient_data.get("comorbidities", []),
        "infection_severity": "moderate",
        "recommended_stage": stage,
    }
    recommendation = {
        "primary_antibiotic": "Ciprofloxacin",
        "dose": "500 mg",
        "route": "Oral",
        "frequency": "Every 12 hours",
        "duration": "7 days",
        "backup_antibiotic": "Nitrofurantoin 100 mg MR BD × 5 days",
        "rationale": "".join((
            "Community-acquired UTI with moderate renal impairment (CrCl 58 mL/min). ",
            "Ciprofloxacin provides broad Gram-negative coverage. ",
            "No dose adjustment required above CrCl 30 mL/min.",
        )),
        "references": ["IDSA UTI Guidelines 2024", "EUCAST Breakpoint Tables v16.0"],
    }
    result = {
        "stage": stage,
        "creatinine_clearance_ml_min": demo_crcl,
        "intake_notes": json.dumps(intake),
        "recommendation": recommendation,
        "safety_warnings": [],
        "errors": [],
    }
    if not has_labs:
        return result

    susceptibility = [
        {"organism": "E. coli", "antibiotic": "Ciprofloxacin", "mic_value": 0.25, "interpretation": "S"},
        {"organism": "E. coli", "antibiotic": "Nitrofurantoin", "mic_value": 16, "interpretation": "S"},
        {"organism": "E. coli", "antibiotic": "Ampicillin", "mic_value": ">32", "interpretation": "R"},
    ]
    vision = {
        "specimen_type": "urine",
        "identified_organisms": [{"organism_name": "Escherichia coli", "significance": "pathogen"}],
        "susceptibility_results": susceptibility,
        "extraction_confidence": 0.95,
    }
    trend = [{
        "organism": "E. coli", "antibiotic": "Ciprofloxacin",
        "risk_level": "LOW", "recommendation": "No MIC creep detected.",
    }]
    result["vision_notes"] = json.dumps(vision)
    result["trend_notes"] = json.dumps(trend)
    return result


# ── Site change / lab method handlers ─────────────────────────────────────────

def update_site_ui(site):
    """Reconfigure the form after the infection site changes.

    Returns, in order: one visibility update per INFECTION_SITES entry (only
    the selected site's group is shown), the suspected-source dropdown
    refreshed with that site's options (falling back to ["Other"]), then
    visibility updates for the main creatinine field, the renal flag, and the
    optional creatinine field (which is always hidden again).
    """
    show_main_creatinine = site in CREATININE_PROMINENT_SITES
    source_options = SUSPECTED_SOURCE_OPTIONS.get(site, []) or ["Other"]
    group_visibility = [gr.update(visible=candidate == site) for candidate in INFECTION_SITES]
    source_update = gr.update(choices=source_options, value=source_options[0])
    return (
        *group_visibility,
        source_update,
        gr.update(visible=show_main_creatinine),      # creatinine_main
        gr.update(visible=not show_main_creatinine),  # renal_flag
        gr.update(visible=False),                     # creatinine_optional
    )


def toggle_optional_creatinine(flag):
    """Show the optional creatinine input only while the renal flag is set."""
    show_field = bool(flag)
    return gr.update(visible=show_field)


def toggle_lab_inputs(method):
    """Return visibility updates for (file upload, paste box) for *method*."""
    uploading = method == "Upload file (PDF / image)"
    pasting = method == "Paste lab text"
    return gr.update(visible=uploading), gr.update(visible=pasting)


# ── Pipeline function ──────────────────────────────────────────────────────────
# Site-specific field order (matches component creation order in the Blocks):
#   urinary      : sf0  sf1  sf2                          (3 fields)
#   respiratory  : sf3  sf4  sf5  sf6                     (4 fields)
#   bloodstream  : sf7  sf8  sf9  sf10  sf11  sf12  sf13  (7 fields)
#   skin         : sf14 sf15 sf16 sf17                    (4 fields)
#   intra-abdom  : sf18 sf19 sf20 sf21                    (4 fields)
#   CNS          : sf22 sf23 sf24 sf25                    (4 fields)

def _collect_site_vitals(infection_site, sf) -> dict:
    """Translate the site-specific form values into the ``vitals`` dict.

    ``sf`` is the tuple ``(sf0, …, sf25)`` in component-creation order. Only the
    fields belonging to *infection_site* are read. Sites without dedicated
    fields (e.g. "other") yield ``{}``.
    """
    def text(i):
        # Scalar value as a string, with falsy values rendered as "".
        return str(sf[i] or "")

    def yes_no(i):
        # Truthy flag rendered as "Yes"/"No".
        return "Yes" if sf[i] else "No"

    def joined(i):
        # Multi-select value (iterable of strings) as a comma-separated list.
        return ", ".join(sf[i]) if sf[i] else ""

    if infection_site == "urinary":
        return {
            "catheter_status": text(0),
            "urinary_symptoms": joined(1),
            "urine_appearance": text(2),
        }
    if infection_site == "respiratory":
        return {
            "o2_saturation": text(3),
            "ventilation_status": text(4),
            "cough_type": text(5),
            "sputum_character": text(6),
        }
    if infection_site == "bloodstream":
        return {
            "central_line_present": yes_no(7),
            "temperature_c": text(8),
            "heart_rate_bpm": text(9),
            "respiratory_rate": text(10),
            "wbc_count": text(11),
            "lactate_mmol": text(12),
            "shock_status": text(13),
        }
    if infection_site == "skin":
        return {
            "wound_type": text(14),
            "cellulitis_extent": text(15),
            "abscess_present": yes_no(16),
            "foreign_body": yes_no(17),
        }
    if infection_site == "intra-abdominal":
        return {
            "abdominal_pain_location": text(18),
            "peritonitis_signs": joined(19),
            "perforation_suspected": yes_no(20),
            "ascites": yes_no(21),
        }
    if infection_site == "CNS":
        return {
            "csf_obtained": yes_no(22),
            "neuro_symptoms": joined(23),
            "recent_neurosurgery": yes_no(24),
            "gcs_score": text(25),
        }
    return {}


def _read_lab_input(lab_method, lab_file, lab_paste):
    """Return ``(labs_raw_text, labs_image_bytes)`` for the chosen lab input.

    For an uploaded PDF with a text layer, the extracted text is returned.
    Other uploads, and PDFs whose extraction fails or yields no text, are
    returned as raw bytes. Pasted text is stripped, and empty text becomes
    None. At most one element of the tuple is non-None.
    """
    if lab_method == "Upload file (PDF / image)" and lab_file is not None:
        file_path = Path(lab_file if isinstance(lab_file, str) else lab_file.name)
        file_bytes = file_path.read_bytes()
        if file_path.suffix.lower() != ".pdf":
            return None, file_bytes
        try:
            import pypdf
            reader = pypdf.PdfReader(BytesIO(file_bytes))
            extracted = "\n".join(page.extract_text() or "" for page in reader.pages).strip()
        except Exception:
            # Best effort: an unreadable PDF is still forwarded as bytes.
            logger.warning("PDF text extraction failed; forwarding the file as bytes.", exc_info=True)
            return None, file_bytes
        return (extracted, None) if extracted else (None, file_bytes)
    if lab_method == "Paste lab text" and lab_paste:
        return lab_paste.strip() or None, None
    return None, None


def run_pipeline_ui(
    age, weight, height, sex,
    creatinine_main, renal_flag, creatinine_optional,
    infection_site, suspected_source,
    # urinary
    sf0, sf1, sf2,
    # respiratory
    sf3, sf4, sf5, sf6,
    # bloodstream
    sf7, sf8, sf9, sf10, sf11, sf12, sf13,
    # skin
    sf14, sf15, sf16, sf17,
    # intra-abdominal
    sf18, sf19, sf20, sf21,
    # CNS
    sf22, sf23, sf24, sf25,
    # medical history
    medications, allergies, comorbidities, risk_factors,
    # lab
    lab_method, lab_file, lab_paste,
    progress=gr.Progress(),
):
    """Gradio callback: collect the form, run the multi-agent pipeline, render results.

    Returns ``(rec_html, intake_html, lab_html, lab_df, safety_html,
    visibility_update)``. If the pipeline raises, the exception is logged and
    a demo result is rendered with its recommendation suppressed and the
    error shown in the safety panel.
    """
    # Prominent sites always use the main creatinine field. Elsewhere the
    # optional field is read only when the renal flag is ticked.
    if infection_site in CREATININE_PROMINENT_SITES:
        creatinine = creatinine_main
    else:
        creatinine = creatinine_optional if renal_flag else None

    site_vitals = _collect_site_vitals(
        infection_site,
        (sf0, sf1, sf2,
         sf3, sf4, sf5, sf6,
         sf7, sf8, sf9, sf10, sf11, sf12, sf13,
         sf14, sf15, sf16, sf17,
         sf18, sf19, sf20, sf21,
         sf22, sf23, sf24, sf25),
    )
    labs_raw_text, labs_image_bytes = _read_lab_input(lab_method, lab_file, lab_paste)

    # Age 0 is a valid age (neonates), so only a missing value falls back to
    # the default. A weight or height of 0 is impossible and keeps the
    # falsy-default behaviour.
    age_years = 65.0 if age in (None, "") else float(age)

    patient_data = {
        "age_years":            age_years,
        "weight_kg":            float(weight or 70),
        "height_cm":            float(height or 170),
        "sex":                  sex or "male",
        "serum_creatinine_mg_dl": float(creatinine) if creatinine else None,
        "infection_site":       infection_site,
        "suspected_source":     suspected_source or f"{infection_site} infection",
        "medications":          [m.strip() for m in (medications or "").split("\n") if m.strip()],
        "allergies":            [a.strip() for a in (allergies or "").split("\n") if a.strip()],
        "comorbidities":        list(comorbidities or []) + list(risk_factors or []),
        "vitals":               site_vitals,
        "labs_image_bytes":     labs_image_bytes,
    }

    has_labs = bool(labs_raw_text or labs_image_bytes)
    stages   = (
        ["Intake Historian", "Vision Specialist", "Trend Analyst", "Clinical Pharmacologist"]
        if has_labs else ["Intake Historian", "Clinical Pharmacologist"]
    )

    # All agents run inside one blocking call, so per-stage progress cannot be
    # reported. Announce the whole chain up front rather than flashing through
    # per-stage messages before any work has started.
    progress(0.0, desc=f"Running: {' → '.join(stages)}…")

    try:
        result = _run_pipeline_gpu(patient_data, labs_raw_text)
    except Exception as e:
        logger.exception("Pipeline failed — falling back to demo result.")
        result = _demo_result(patient_data, has_labs)
        result["errors"].append(f"Pipeline error: {e}")
        result["recommendation"] = {}  # suppress the hardcoded drug from showing

    progress(1.0, desc="Complete")

    rec_html          = _build_rec_html(result)
    intake_html       = _build_intake_html(result)
    lab_html, lab_df  = _build_lab_html_and_df(result)
    safety_html       = _build_safety_html(result)

    return rec_html, intake_html, lab_html, lab_df, safety_html, gr.update(visible=True)


# ── Clinical Tools handlers ────────────────────────────────────────────────────

def switch_tool(tool):
    """Return visibility updates for the four Clinical Tools panels.

    Only the panel whose name equals *tool* becomes visible. The list order
    matches the ``outputs=[grp_ea, grp_mi, grp_mt, grp_ds]`` wiring of the
    tool selector.
    """
    panel_names = (
        "Empirical Advisor",
        "MIC Interpreter",
        "MIC Trend Analysis",
        "Drug Safety Check",
    )
    updates = []
    for panel in panel_names:
        updates.append(gr.update(visible=panel == tool))
    return updates


def run_empirical(infection_type, pathogen, risk):
    """Render empirical-therapy guidance (and optional resistance data) as HTML.

    Up to three guideline excerpts from ``get_empirical_therapy_guidance`` are
    shown with their relevance score and source. When *pathogen* is given,
    the top six antibiotics from ``get_most_effective_antibiotics`` (queried
    with ``min_susceptibility=70``) are listed, or a "no data" badge is shown.
    An entirely empty result yields a "No results found." badge.
    """
    guidance = get_empirical_therapy_guidance(infection_type, list(risk or []))

    chunks = []
    for rank, excerpt in enumerate(guidance.get("recommendations", [])[:3], 1):
        relevance = excerpt.get("relevance_score", 0)
        body = excerpt.get("content", "")
        origin = excerpt.get("source", "IDSA Guidelines 2024")
        chunks.append(
            f'<div class="badge-info"><strong>Excerpt {rank}</strong>'
            f' (relevance {relevance:.2f})<br>{body}<br><em>Source: {origin}</em></div>'
        )

    if pathogen:
        top_agents = get_most_effective_antibiotics(pathogen, min_susceptibility=70)
        if not top_agents:
            chunks.append('<div class="badge-info">No resistance data available for this pathogen.</div>')
        else:
            bullets = "".join(
                f"<li><strong>{agent.get('antibiotic')}</strong>"
                f" — {agent.get('avg_susceptibility', 0):.1f}% susceptible</li>"
                for agent in top_agents[:6]
            )
            chunks.append(
                f"<br><strong>Resistance data — {pathogen}</strong>"
                f"<ul style='margin:4px 0 0 16px'>{bullets}</ul>"
            )

    rendered = "".join(chunks)
    if not rendered:
        return '<div class="badge-info">No results found.</div>'
    return rendered


def run_mic_interpret(pathogen, antibiotic, mic):
    """Render an S / I / R badge for a single pathogen–antibiotic MIC reading.

    Args:
        pathogen: Organism name from the MIC Interpreter textbox.
        antibiotic: Antibiotic name from the MIC Interpreter textbox.
        mic: MIC value in mg/L; a falsy value (``None``/0) falls back to 1.0.

    Returns:
        An HTML snippet. ``interpret_mic_value`` is only called when both
        names are non-blank. Only an explicit Intermediate result is labelled
        "Intermediate (I)". Any other unrecognised interpretation, such as the
        ``"UNKNOWN"`` default used when the key is missing, is reported as
        "Unable to interpret" rather than being presented as a clinical
        category.
    """
    # Whitespace-only entries count as missing rather than being sent to
    # the breakpoint lookup.
    pathogen = (pathogen or "").strip()
    antibiotic = (antibiotic or "").strip()
    if not pathogen or not antibiotic:
        return '<div class="badge-info">Enter pathogen and antibiotic.</div>'

    result = interpret_mic_value(pathogen, antibiotic, float(mic or 1.0))
    # Normalise case/whitespace so "Resistant" and "RESISTANT" match alike.
    interp = str(result.get("interpretation") or "UNKNOWN").strip().upper()
    msg    = result.get("message", "")

    if interp in ("SUSCEPTIBLE", "S"):
        return f'<div class="badge-low"><strong>Susceptible (S)</strong> — {msg}</div>'
    if interp in ("RESISTANT", "R"):
        return f'<div class="badge-high"><strong>Resistant (R)</strong> — {msg}</div>'
    if interp in ("INTERMEDIATE", "I"):
        return f'<div class="badge-moderate"><strong>Intermediate (I)</strong> — {msg}</div>'
    # Previously every other value (including "UNKNOWN") was shown as
    # Intermediate. Surface the raw label instead of inventing a category.
    return f'<div class="badge-info"><strong>Unable to interpret ({interp})</strong> — {msg}</div>'


def update_mic_inputs(n):
    """Show the first *n* of the six MIC number inputs and hide the others."""
    shown = int(n)
    return [gr.update(visible=idx < shown) for idx in range(6)]


def run_mic_trend(n, m0, m1, m2, m3, m4, m5):
    """Summarise MIC creep over the first *n* historical readings as HTML.

    Each used reading becomes ``{"date": "T<i>", "mic_value": float}``, where
    a falsy entry (``None``/0) falls back to 1.0. The list is passed to
    ``calculate_mic_trend``. Its risk level, alert, baseline/current MIC and
    fold change are rendered as a coloured badge plus a three-column table.
    """
    count = int(n)
    series = [
        {"date": f"T{pos}", "mic_value": float(raw or 1.0)}
        for pos, raw in enumerate((m0, m1, m2, m3, m4, m5)[:count])
    ]
    trend = calculate_mic_trend(series)

    level   = trend.get("risk_level", "UNKNOWN")
    message = trend.get("alert", "")
    if level == "HIGH":
        badge, symbol = "badge-high", "🚨"
    elif level == "MODERATE":
        badge, symbol = "badge-moderate", "⚠"
    else:
        badge, symbol = "badge-low", "✓"

    baseline = trend.get("baseline_mic", "—")
    current  = trend.get("current_mic", "—")
    fold     = trend.get("ratio", "—")
    return f"""
<div class="{badge}">{symbol} <strong>{level} RISK</strong> — {message}</div>
<br>
<table><tr>
<td style='padding:8px 24px 8px 0'><strong>Baseline MIC</strong><br>{baseline} mg/L</td>
<td style='padding:8px 24px'><strong>Current MIC</strong><br>{current} mg/L</td>
<td style='padding:8px 24px'><strong>Fold change</strong><br>{fold}×</td>
</tr></table>"""


def run_drug_safety(ab, meds, allergies_txt):
    """Screen one antibiotic against medications and allergies; return HTML.

    *meds* and *allergies_txt* are newline-separated text. Blank lines are
    dropped and entries are stripped. The result of
    ``screen_antibiotic_safety`` is shown as one overall badge followed by
    one badge per alert message.
    """
    if not ab:
        return '<div class="badge-info">Enter an antibiotic to check.</div>'

    def _non_blank_lines(text):
        # One entry per line, ignoring empty / whitespace-only lines.
        return [line.strip() for line in (text or "").split("\n") if line.strip()]

    screening = screen_antibiotic_safety(
        ab, _non_blank_lines(meds), _non_blank_lines(allergies_txt)
    )
    header = (
        '<div class="badge-low">✓ No critical safety concerns identified.</div>'
        if screening.get("safe_to_use")
        else '<div class="badge-high">⚠ Safety concerns identified — review required.</div>'
    )
    alert_badges = "".join(
        f'<div class="badge-moderate" style="margin-top:8px">⚠ {alert.get("message","")}</div>'
        for alert in screening.get("alerts", [])
    )
    return header + alert_badges


def run_guidelines_search(query, pathogen_filter):
    """Search the clinical-guideline knowledge base and render hits as HTML.

    ``"All"`` disables the pathogen filter. At most five results are
    requested. Each hit shows its rank, relevance score, content and (when
    present) its source.
    """
    if not query:
        return '<div class="badge-info">Enter a search query.</div>'

    hits = search_clinical_guidelines(
        query,
        pathogen_filter=None if pathogen_filter == "All" else pathogen_filter,
        n_results=5,
    )
    if not hits:
        return ('<div class="badge-info">No results found. Try broader search terms or '
                'check that the knowledge base has been initialised.</div>')

    cards = []
    for rank, hit in enumerate(hits, start=1):
        relevance = hit.get("relevance_score", 0)
        body      = hit.get("content", "")
        origin    = hit.get("source", "")
        origin_html = f"<br><em>Source: {origin}</em>" if origin else ""
        cards.append(
            f'<div class="badge-info"><strong>Result {rank}</strong>'
            f' · relevance {relevance:.2f}<br>{body}{origin_html}</div>'
        )
    return "".join(cards)


# ── Widget factory for site-specific fields ────────────────────────────────────

def _make_site_widget(field):
    """Build the Gradio input component described by one site-specific field.

    *field* must provide ``"type"`` and ``"label"``. Depending on the type it
    may also use ``"options"`` (selectbox / multiselect), ``"default"``
    (number_input / checkbox) and ``"min"`` / ``"max"`` (number_input).
    Unrecognised types fall back to a plain textbox.
    """
    kind = field["type"]
    caption = field["label"]

    builders = {
        "selectbox": lambda: gr.Dropdown(
            choices=field["options"], value=field["options"][0], label=caption,
        ),
        "multiselect": lambda: gr.CheckboxGroup(choices=field["options"], label=caption),
        "number_input": lambda: gr.Number(
            value=field.get("default", 0), label=caption,
            minimum=field.get("min"), maximum=field.get("max"),
        ),
        "checkbox": lambda: gr.Checkbox(value=field.get("default", False), label=caption),
    }
    build = builders.get(kind, lambda: gr.Textbox(label=caption))
    return build()


# ── Models table (build-time) ─────────────────────────────────────────────────

_s = get_settings()

# Markdown table of the models used per agent, shown on the Overview tab.
# Each generation model id falls back to a default Hugging Face model id
# when the corresponding setting is unset or empty.
OVERVIEW_MODELS_MD = "\n" + "\n".join([
    "| Agent | Role | Model |",
    "|---|---|---|",
    f"| 1, 2, 4 | Clinical reasoning | `{_s.medgemma_4b_model or 'google/medgemma-4b-it'}` |",
    f"| 3 | Trend analysis | `{_s.medgemma_27b_model or 'google/medgemma-27b-text-it'}` |",
    f"| 4 (safety) | Pharmacology check | `{_s.txgemma_9b_model or 'google/txgemma-9b-predict'}` |",
    f"| — | Semantic retrieval | `{_s.embedding_model_name}` |",
    f"| — | Inference backend | HuggingFace Transformers · {_s.quantization} quant |",
]) + "\n"

# ── Gradio Blocks ─────────────────────────────────────────────────────────────

# Top-level UI with four tabs: Overview, Patient Analysis, Clinical Tools and
# Guidelines. The layout is built at import time and exposed as `demo`, which
# the __main__ guard below launches.
with gr.Blocks(theme=gr.themes.Soft(), css=CSS, title="AMR-Guard") as demo:
    gr.HTML(BANNER_HTML)

    with gr.Tabs():

        # ── Tab 1: Overview ────────────────────────────────────────────────────
        with gr.Tab("Overview"):
            gr.HTML("""
<div class="section-title">System Overview</div>
<div class="stat-cards">
  <div class="stat-card">
    <div class="label">WHO AWaRe</div><div class="value">264</div><div class="sub">antibiotics classified</div>
  </div>
  <div class="stat-card">
    <div class="label">EUCAST</div><div class="value">v16.0</div><div class="sub">breakpoint tables</div>
  </div>
  <div class="stat-card">
    <div class="label">IDSA</div><div class="value">2024</div><div class="sub">treatment guidelines</div>
  </div>
  <div class="stat-card">
    <div class="label">DDInter</div><div class="value">191K+</div><div class="sub">drug interactions</div>
  </div>
</div>
<div class="section-title">Agent Pipeline</div>
""")
            with gr.Row():
                with gr.Column():
                    gr.HTML("""
<p><strong>Stage 1 — Empirical</strong> <em>(no lab results yet)</em></p>
<div class="agent-step"><div class="num">Agent 01</div><div class="name">Intake Historian</div>
<div class="desc">Parses patient data, calculates CrCl, identifies MDR risk factors</div></div>
<div class="agent-step"><div class="num">Agent 04</div><div class="name">Clinical Pharmacologist</div>
<div class="desc">Empirical antibiotic selection · WHO AWaRe · safety screening</div></div>
""")
                with gr.Column():
                    gr.HTML("""
<p><strong>Stage 2 — Targeted</strong> <em>(culture / sensitivity available)</em></p>
<div class="agent-step"><div class="num">Agent 01</div><div class="name">Intake Historian</div>
<div class="desc">Same as Stage 1</div></div>
<div class="agent-step"><div class="num">Agent 02</div><div class="name">Vision Specialist</div>
<div class="desc">Extracts structured data from lab reports (any language / format)</div></div>
<div class="agent-step"><div class="num">Agent 03</div><div class="name">Trend Analyst</div>
<div class="desc">Detects MIC creep · calculates resistance velocity</div></div>
<div class="agent-step"><div class="num">Agent 04</div><div class="name">Clinical Pharmacologist</div>
<div class="desc">Targeted recommendation informed by susceptibility data</div></div>
""")
            gr.HTML('<div class="section-title">AI Models (Local)</div>')
            gr.Markdown(OVERVIEW_MODELS_MD)
            gr.HTML(
                '<div class="disclaimer">⚠ <strong>Research demo only.</strong> '
                "Not validated for clinical use. All recommendations must be reviewed "
                "by a licensed clinician before any patient-care decision.</div>"
            )

        # ── Tab 2: Patient Analysis ────────────────────────────────────────────
        with gr.Tab("Patient Analysis"):
            gr.HTML('<div class="section-title">Patient Analysis Pipeline</div>')

            # Demographics row
            with gr.Row():
                with gr.Column(scale=1):
                    age    = gr.Number(value=65,   label="Age (years)",   minimum=0,   maximum=120, precision=0)
                    weight = gr.Number(value=70.0, label="Weight (kg)",   minimum=1.0, maximum=300.0)
                    height = gr.Number(value=170.0,label="Height (cm)",   minimum=50.0,maximum=250.0)
                with gr.Column(scale=1):
                    sex               = gr.Dropdown(choices=["male", "female"], value="male", label="Biological sex")
                    # Two creatinine inputs. creatinine_main starts visible, while
                    # renal_flag / creatinine_optional start hidden. All three are
                    # outputs of update_site_ui, so visibility is driven per site.
                    creatinine_main   = gr.Number(value=1.2, label="Serum Creatinine (mg/dL)",
                                                  minimum=0.1, maximum=20.0, visible=True)
                    renal_flag        = gr.Checkbox(label="Known renal impairment / CKD?", visible=False)
                    creatinine_optional = gr.Number(value=1.2, label="Serum Creatinine (mg/dL)",
                                                    minimum=0.1, maximum=20.0, visible=False)
                with gr.Column(scale=1):
                    infection_site  = gr.Dropdown(choices=INFECTION_SITES, value="urinary",
                                                  label="Primary infection site")
                    _init_src = SUSPECTED_SOURCE_OPTIONS.get("urinary", [])
                    suspected_source = gr.Dropdown(choices=_init_src,
                                                   value=_init_src[0] if _init_src else None,
                                                   label="Suspected source")

            # Site-specific field groups (pre-rendered, one per site). Only the
            # group for the default site ("urinary") starts visible.
            site_groups: dict = {}
            # Component lists per site (in field declaration order)
            u_comps:  list = []  # 3 components
            r_comps:  list = []  # 4 components
            b_comps:  list = []  # 7 components
            sk_comps: list = []  # 4 components
            ia_comps: list = []  # 4 components
            cn_comps: list = []  # 4 components

            for site in INFECTION_SITES:
                fields = SITE_SPECIFIC_FIELDS.get(site, [])
                with gr.Group(visible=(site == "urinary")) as grp:
                    if fields:
                        # str.title() would render the acronym "CNS" as "Cns",
                        # so all-caps site names are shown unchanged.
                        site_label = site if site.isupper() else site.title()
                        gr.HTML(f'<div class="section-title">{site_label} — Assessment</div>')
                        with gr.Row():
                            for field in fields:
                                comp = _make_site_widget(field)
                                if site == "urinary":
                                    u_comps.append(comp)
                                elif site == "respiratory":
                                    r_comps.append(comp)
                                elif site == "bloodstream":
                                    b_comps.append(comp)
                                elif site == "skin":
                                    sk_comps.append(comp)
                                elif site == "intra-abdominal":
                                    ia_comps.append(comp)
                                elif site == "CNS":
                                    cn_comps.append(comp)
                site_groups[site] = grp

            # Flatten all site components in fixed order for fn inputs. This
            # order is independent of INFECTION_SITES and must match how the
            # pipeline handler unpacks its positional site arguments.
            all_site_inputs = u_comps + r_comps + b_comps + sk_comps + ia_comps + cn_comps

            # Medical history
            gr.HTML('<div class="section-title">Medical History</div>')
            with gr.Row():
                with gr.Column():
                    medications = gr.Textbox(
                        label="Current medications (one per line)",
                        placeholder="Metformin\nLisinopril", lines=4,
                    )
                    allergies = gr.Textbox(
                        label="Drug allergies (one per line)",
                        placeholder="Penicillin\nSulfa", lines=3,
                    )
                with gr.Column():
                    comorbidities = gr.CheckboxGroup(
                        choices=["Diabetes", "CKD", "Heart Failure", "COPD",
                                 "Immunocompromised", "Recent Surgery", "Malignancy", "Liver Disease"],
                        label="Comorbidities",
                    )
                    risk_factors = gr.CheckboxGroup(
                        choices=["Prior MRSA", "Recent antibiotics (<90 d)", "Healthcare-associated",
                                 "Recent hospitalisation", "Nursing home", "Prior MDR infection"],
                        label="MDR risk factors",
                    )

            # Lab input. The upload / paste widgets start hidden and are shown
            # via toggle_lab_inputs when lab_method changes.
            gr.HTML('<div class="section-title">Lab / Culture Results '
                    '<small>(optional — triggers targeted pathway)</small></div>')
            lab_method = gr.Radio(
                choices=["None — empirical pathway only", "Upload file (PDF / image)", "Paste lab text"],
                value="None — empirical pathway only",
                label="Input method",
            )
            lab_file  = gr.File(
                label="Lab report",
                file_types=[".pdf", ".png", ".jpg", ".jpeg", ".tiff", ".bmp"],
                visible=False,
            )
            lab_paste = gr.Textbox(
                label="Lab report text",
                placeholder=(
                    "Organism: Escherichia coli\n"
                    "Ciprofloxacin: S  MIC 0.25\n"
                    "Nitrofurantoin: S  MIC 16\n"
                    "Ampicillin: R  MIC >32"
                ),
                lines=5, visible=False,
            )

            run_btn = gr.Button("Run Agent Pipeline", variant="primary")

            # Results (hidden until pipeline completes)
            with gr.Group(visible=False) as results_group:
                gr.HTML('<div class="section-title">Results</div>')
                with gr.Tabs():
                    with gr.Tab("Recommendation"):
                        rec_out    = gr.HTML()
                    with gr.Tab("Patient Summary"):
                        intake_out = gr.HTML()
                    with gr.Tab("Lab Analysis"):
                        lab_html_out = gr.HTML()
                        lab_df_out   = gr.DataFrame(label="Susceptibility Table", wrap=True)
                    with gr.Tab("Safety"):
                        safety_out = gr.HTML()

            # ── Wiring ──
            infection_site.change(
                fn=update_site_ui,
                inputs=[infection_site],
                outputs=[
                    *[site_groups[s] for s in INFECTION_SITES],
                    suspected_source,
                    creatinine_main,
                    renal_flag,
                    creatinine_optional,
                ],
            )
            renal_flag.change(
                fn=toggle_optional_creatinine,
                inputs=[renal_flag],
                outputs=[creatinine_optional],
            )
            lab_method.change(
                fn=toggle_lab_inputs,
                inputs=[lab_method],
                outputs=[lab_file, lab_paste],
            )
            _loading_html = '<div class="badge-info" style="padding:16px;text-align:center;">⏳ Pipeline running — please wait…</div>'
            # Two-step click. An unqueued handler first fills every result pane
            # with a placeholder and reveals the results group. The pipeline
            # then runs as a chained (.then) event.
            run_btn.click(
                fn=lambda: (
                    _loading_html, _loading_html, _loading_html,
                    pd.DataFrame(), _loading_html,
                    gr.update(visible=True),
                ),
                inputs=[],
                outputs=[rec_out, intake_out, lab_html_out, lab_df_out, safety_out, results_group],
                queue=False,
            ).then(
                fn=run_pipeline_ui,
                # Gradio passes these positionally, so this order must match
                # run_pipeline_ui's parameter list (defined elsewhere in this file).
                inputs=[
                    age, weight, height, sex,
                    creatinine_main, renal_flag, creatinine_optional,
                    infection_site, suspected_source,
                    *all_site_inputs,
                    medications, allergies, comorbidities, risk_factors,
                    lab_method, lab_file, lab_paste,
                ],
                outputs=[rec_out, intake_out, lab_html_out, lab_df_out, safety_out, results_group],
            )

        # ── Tab 3: Clinical Tools ──────────────────────────────────────────────
        with gr.Tab("Clinical Tools"):
            gr.HTML('<div class="section-title">Clinical Tools</div>')
            # Choices must stay in sync with the panel list in switch_tool().
            tool_sel = gr.Dropdown(
                choices=["Empirical Advisor", "MIC Interpreter", "MIC Trend Analysis", "Drug Safety Check"],
                value="Empirical Advisor",
                label="Select tool",
            )

            # Empirical Advisor
            with gr.Group(visible=True) as grp_ea:
                with gr.Row():
                    with gr.Column(scale=3):
                        ea_infection = gr.Dropdown(
                            choices=["Urinary Tract Infection", "Pneumonia", "Sepsis",
                                     "Skin / Soft Tissue", "Intra-abdominal", "Meningitis"],
                            value="Urinary Tract Infection", label="Infection type",
                        )
                        ea_pathogen = gr.Textbox(
                            label="Suspected pathogen (optional)",
                            placeholder="e.g., Klebsiella pneumoniae",
                        )
                        ea_risk = gr.CheckboxGroup(
                            choices=["Prior MRSA", "Recent antibiotics (<90 d)", "Healthcare-associated",
                                     "Immunocompromised", "Renal impairment", "Prior MDR"],
                            label="Risk factors",
                        )
                    with gr.Column(scale=1):
                        gr.HTML("""
<div class="badge-info"><strong style="color:#dceeff">WHO AWaRe</strong><br>
<span style="color:#86efac">●</span> Access — first-line<br>
<span style="color:#fcd34d">●</span> Watch — second-line<br>
<span style="color:#fca5a5">●</span> Reserve — last resort</div>""")
                ea_btn = gr.Button("Get recommendation", variant="primary")
                ea_out = gr.HTML()

            # MIC Interpreter
            with gr.Group(visible=False) as grp_mi:
                with gr.Row():
                    with gr.Column():
                        mi_pathogen  = gr.Textbox(label="Pathogen",   placeholder="e.g., Escherichia coli")
                        mi_antibiotic= gr.Textbox(label="Antibiotic", placeholder="e.g., Ciprofloxacin")
                        mi_mic       = gr.Number(value=1.0, label="MIC value (mg/L)", minimum=0.001, maximum=1024.0)
                    with gr.Column():
                        gr.HTML("""
<div class="badge-info" style="margin-top:28px"><strong>Interpretation guide</strong><br><br>
<strong>S</strong> Susceptible — antibiotic is effective<br>
<strong>I</strong> Intermediate — effective at higher doses<br>
<strong>R</strong> Resistant — do not use</div>""")
                mi_btn = gr.Button("Interpret", variant="primary")
                mi_out = gr.HTML()

            # MIC Trend Analysis: six inputs are pre-built, and the slider
            # controls how many are visible (update_mic_inputs).
            with gr.Group(visible=False) as grp_mt:
                mt_n = gr.Slider(minimum=2, maximum=6, value=3, step=1,
                                 label="Number of historical readings")
                with gr.Row():
                    mt_m = [
                        gr.Number(value=float(2 ** i), label=f"MIC {i+1} (mg/L)",
                                  minimum=0.001, maximum=256.0, visible=(i < 3))
                        for i in range(6)
                    ]
                mt_btn = gr.Button("Analyse trend", variant="primary")
                mt_out = gr.HTML()
                mt_n.change(fn=update_mic_inputs, inputs=[mt_n], outputs=mt_m)

            # Drug Safety Check
            with gr.Group(visible=False) as grp_ds:
                with gr.Row():
                    with gr.Column():
                        ds_ab   = gr.Textbox(label="Antibiotic to check",
                                             placeholder="e.g., Ciprofloxacin")
                        ds_meds = gr.Textbox(label="Concurrent medications",
                                             placeholder="Warfarin\nMetformin\nAmlodipine", lines=4)
                    with gr.Column():
                        ds_allergies = gr.Textbox(label="Known allergies",
                                                  placeholder="Penicillin\nSulfa", lines=3)
                ds_btn = gr.Button("Check safety", variant="primary")
                ds_out = gr.HTML()

            tool_sel.change(
                fn=switch_tool, inputs=[tool_sel],
                outputs=[grp_ea, grp_mi, grp_mt, grp_ds],
            )
            ea_btn.click(fn=run_empirical,     inputs=[ea_infection, ea_pathogen, ea_risk], outputs=[ea_out])
            mi_btn.click(fn=run_mic_interpret, inputs=[mi_pathogen, mi_antibiotic, mi_mic], outputs=[mi_out])
            mt_btn.click(fn=run_mic_trend,     inputs=[mt_n, *mt_m],                       outputs=[mt_out])
            ds_btn.click(fn=run_drug_safety,   inputs=[ds_ab, ds_meds, ds_allergies],      outputs=[ds_out])

        # ── Tab 4: Guidelines ──────────────────────────────────────────────────
        with gr.Tab("Guidelines"):
            gr.HTML('<div class="section-title">Clinical Guidelines Search</div>')
            with gr.Row():
                gl_query  = gr.Textbox(
                    label="Search query",
                    placeholder="e.g., ESBL E. coli UTI treatment carbapenems",
                    scale=3,
                )
                gl_filter = gr.Dropdown(
                    choices=["All", "ESBL-E", "CRE", "CRAB", "DTR-PA"],
                    value="All", label="Filter by pathogen", scale=1,
                )
            gl_btn = gr.Button("Search", variant="primary")
            gl_out = gr.HTML()
            gr.HTML(
                '<div class="disclaimer">Sources: IDSA Treatment Guidelines 2024 · '
                "EUCAST Breakpoint Tables v16.0 · WHO EML · DDInter drug interaction database.</div>"
            )
            gl_btn.click(fn=run_guidelines_search, inputs=[gl_query, gl_filter], outputs=[gl_out])

if __name__ == "__main__":
    # Start the Gradio server only when run as a script. Importing this module
    # just builds the `demo` Blocks app without launching it.
    demo.launch()