File size: 41,722 Bytes
efaef67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ff4a07
efaef67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8a1503
 
efaef67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ff4a07
 
 
1f69fb6
 
 
 
 
 
 
 
 
7ff4a07
49a8c05
 
 
 
 
 
 
 
 
 
 
 
efaef67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8a1503
 
 
 
 
efaef67
 
 
 
 
 
 
d1b7eeb
efaef67
d1b7eeb
 
 
 
efaef67
 
 
 
 
 
 
d1b7eeb
 
efaef67
d1b7eeb
 
 
 
 
2a71103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de82703
 
 
 
 
 
 
 
efaef67
d1b7eeb
 
 
 
 
 
 
 
 
 
 
efaef67
d1b7eeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
efaef67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4bd03ca
efaef67
 
 
 
e8a1503
efaef67
 
 
 
 
e8a1503
 
 
 
 
 
efaef67
 
 
 
e8a1503
efaef67
 
e8a1503
 
 
 
 
 
 
 
 
 
 
 
 
efaef67
 
 
 
 
 
 
 
 
e8a1503
 
 
 
efaef67
e8a1503
efaef67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aef3662
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49a8c05
 
1f69fb6
 
 
 
 
49a8c05
 
 
 
 
 
aef3662
49a8c05
 
 
 
 
 
 
 
 
 
 
aef3662
1f69fb6
49a8c05
 
 
 
 
aef3662
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
efaef67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66aa6ed
efaef67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e08a816
 
 
 
 
 
 
 
 
efaef67
 
 
 
 
 
 
 
 
8273d80
efaef67
 
 
 
 
47ed39d
efaef67
 
 
 
 
 
 
dd6df30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
efaef67
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
"""
TraceScene β€” Gradio ZeroGPU Application

Serves the custom TraceScene frontend + REST API with GPU-accelerated inference.
Architecture:
  - Gradio demo at / (primary β€” required for ZeroGPU)
  - Custom FastAPI routes added to Gradio's internal app for REST API
  - Custom HTML/CSS/JS frontend served alongside
  - @spaces.GPU wraps inference for dynamic GPU allocation
"""

import os
from pathlib import Path

import torch
import gradio as gr
import spaces

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse

# ── Backend Imports ────────────────────────────────────────────────────
from backend.app.config import settings
from backend.app.db.database import db
from backend.app.core.inference import inference_engine, chat_engine, SCENE_ANALYSIS_PROMPT
from backend.app.core.scene_analyzer import SceneAnalyzer
from backend.app.core.rule_matcher import RuleMatcher
from backend.app.core.fault_deducer import FaultDeducer
from backend.app.core.report_generator import ReportGenerator
from backend.app.rules.rule_loader import rule_loader
from backend.app.utils.logger import get_logger
from backend.app.api.routes import router

# Module-level logger used by the handlers and GPU wrappers in this file.
logger = get_logger("app")

# Pipeline components, created once at import time and reused by the
# analysis and report handlers defined further down.
scene_analyzer = SceneAnalyzer()
rule_matcher = RuleMatcher()
fault_deducer = FaultDeducer()
report_generator = ReportGenerator()

# Reference cases; the handlers below look entries up with
# ``REFERENCE_CASES.get(int(case_id))``.
from backend.app.core.reference_data import REFERENCE_CASES

# ── ZeroGPU: Top-level decorated function ──────────────────────────────
# This MUST be a top-level function wired to a Gradio event handler.

# Capture the engine's original bound method BEFORE it is replaced below, so
# the GPU wrapper can delegate to it instead of calling itself.
_original_run_inference = inference_engine._run_inference  # bound method


@spaces.GPU(duration=120)
def gpu_run_inference(image, prompt):
    """Run vision inference on ``image`` with ``prompt`` under ``spaces.GPU``.

    Delegates to the inference engine's original ``_run_inference`` method
    (captured at import time) and returns its result unchanged.
    """
    return _original_run_inference(image, prompt)


# Monkey-patch so the entire pipeline uses GPU
inference_engine._run_inference = gpu_run_inference

_original_chat = chat_engine.chat


@spaces.GPU(duration=60)
def gpu_run_chat(system_context, user_message):
    """Answer one chat turn under ``spaces.GPU``.

    Calls the chat engine's ``chat`` method captured at import time. Any
    exception is logged and turned into a ``"Worker Error: ..."`` string so
    the caller always receives text.
    """
    try:
        reply = _original_chat(system_context, user_message)
    except Exception as exc:
        logger.error(f"ZeroGPU Chat Worker Error: {exc}")
        return f"Worker Error: {exc}"
    return reply

_original_chat_stream = chat_engine.chat_stream


@spaces.GPU(duration=60)
def gpu_run_chat_stream(system_context, user_message):
    """Stream one chat turn under ``spaces.GPU``.

    Re-yields every item produced by the chat engine's ``chat_stream``
    method captured at import time. If the stream raises, the error is
    logged and a final ``"Worker Error: ..."`` item is yielded instead.
    """
    try:
        yield from _original_chat_stream(system_context, user_message)
    except Exception as exc:
        logger.error(f"ZeroGPU Chat Stream Worker Error: {exc}")
        yield f"Worker Error: {exc}"


# ── Async helpers ──────────────────────────────────────────────────────

def run_async(coro):
    """Run ``coro`` to completion from synchronous code and return its result.

    Three situations are handled:

    * An event loop is already running in this thread: the loop cannot be
      re-entered, so the coroutine is executed with ``asyncio.run`` in a
      helper thread and this call blocks until it finishes.
    * The thread has a usable (open, idle) event loop: that loop is reused
      via ``run_until_complete``.
    * The thread has no usable loop: ``asyncio.run`` creates a temporary one.

    Exceptions raised by the coroutine propagate to the caller unchanged.
    The loop-detection ``try`` blocks deliberately exclude the coroutine's
    execution, so a ``RuntimeError`` raised by the coroutine itself is never
    mistaken for "no event loop" and the coroutine is never awaited twice.
    """
    import asyncio
    import concurrent.futures

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        pass  # no loop is running in this thread; fall through
    else:
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, coro).result()

    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # Raised when the current thread has no event loop set.
        loop = None

    if loop is None or loop.is_closed():
        return asyncio.run(coro)
    return loop.run_until_complete(coro)


# ── Initialize backend ────────────────────────────────────────────────

_initialized = False


async def _ensure_init():
    """Perform one-time backend start-up; later calls are no-ops.

    Connects the database, loads the traffic rules and attempts to load the
    vision model. A model-load failure is logged rather than raised, and the
    backend is still marked as initialized afterwards.
    """
    global _initialized
    if not _initialized:
        await db.connect()
        rule_loader.load_rules()
        try:
            inference_engine.load_model()
        except Exception as exc:
            logger.error(f"Vision model load failed: {exc}")
        _initialized = True


def ensure_init():
    """Synchronous entry point: run ``_ensure_init`` to completion via ``run_async``."""
    run_async(_ensure_init())


# ── Gradio Handlers ───────────────────────────────────────────────────

def gradio_analyze_photo(image):
    """Run scene analysis on one uploaded image and return the raw model output.

    ``image`` may already be a PIL image; anything else is converted with
    ``PIL.Image.fromarray``. Returns a prompt to upload an image when
    ``image`` is ``None``.
    """
    if image is None:
        return "Please upload an image."

    from PIL import Image as PILImage

    pil_image = image if isinstance(image, PILImage.Image) else PILImage.fromarray(image)

    ensure_init()
    # Load the model here if start-up did not leave it loaded.
    if not inference_engine.is_loaded:
        inference_engine.load_model()

    return gpu_run_inference(pil_image, SCENE_ANALYSIS_PROMPT)


import json
import hashlib
import time
from PIL import Image


def create_case_fn(case_number, officer_name, location, incident_date, notes):
    """Create a case record and return ``(status_message, refreshed_case_rows)``.

    A blank case number is rejected. Optional text fields are stripped and
    stored as ``None`` when empty. Any failure is reported in the status
    message instead of being raised.
    """
    if not (case_number and case_number.strip()):
        return "❌ Case number is required.", list_cases_fn()
    ensure_init()

    def _clean(text):
        # Empty / missing text is stored as None rather than "".
        return text.strip() if text else None

    try:
        new_id = run_async(db.create_case(
            case_number=case_number.strip(),
            officer_name=_clean(officer_name),
            location=_clean(location),
            incident_date=incident_date if incident_date else None,
            notes=_clean(notes),
        ))
        return f"βœ… Case **{case_number}** created (ID: {new_id})", list_cases_fn()
    except Exception as exc:
        return f"❌ {exc}", list_cases_fn()


def list_cases_fn():
    """Return all cases as table rows, or ``[]`` when there are none or on error.

    Each row is ``[id, case_number, officer_name, location, incident_date,
    status, photo_count]``.
    """
    ensure_init()
    try:
        all_cases = run_async(db.list_cases())
        table = []
        for case in all_cases or []:
            photo_count = len(run_async(db.get_photos_by_case(case["id"])))
            table.append([
                case["id"],
                case["case_number"],
                case.get("officer_name", "β€”"),
                case.get("location", "β€”"),
                case.get("incident_date", "β€”"),
                case["status"],
                photo_count,
            ])
        return table
    except Exception:
        return []


def delete_case_fn(case_id):
    """Delete the case with ``case_id`` and return ``(status_message, refreshed_case_rows)``.

    A missing ID is rejected; any failure (including a non-numeric ID) is
    reported in the status message instead of being raised.
    """
    if not case_id:
        return "❌ Enter a Case ID.", list_cases_fn()
    ensure_init()
    try:
        numeric_id = int(case_id)
        run_async(db.delete_case(numeric_id))
        return f"βœ… Case {numeric_id} deleted.", list_cases_fn()
    except Exception as exc:
        return f"❌ {exc}", list_cases_fn()


def upload_photos_fn(case_id, files):
    """Copy uploaded photos into the case's upload directory and register them.

    Args:
        case_id: ID of the target case (anything ``int()`` accepts).
        files: Iterable of filesystem paths to the uploaded files.

    Returns:
        A status message. Files whose extension is not in
        ``settings.allowed_extensions_list`` are skipped and not counted.
        Failures are reported in the message instead of being raised.
    """
    if not case_id:
        return "❌ Enter a Case ID."
    if not files:
        return "❌ Select photos to upload."
    ensure_init()
    try:
        cid = int(case_id)
        if not run_async(db.get_case(cid)):
            return f"❌ Case {cid} not found."

        case_dir = settings.upload_path / f"case_{cid}"
        case_dir.mkdir(parents=True, exist_ok=True)
        allowed_extensions = settings.allowed_extensions_list

        count = 0
        for fp in files:
            filename = Path(fp).name
            ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else ""
            # Check the extension first so rejected files are never read.
            if ext not in allowed_extensions:
                continue

            content = Path(fp).read_bytes()
            # Short content hash keeps same-named uploads from colliding.
            fhash = hashlib.md5(content).hexdigest()[:12]
            dest = case_dir / f"{fhash}_{filename}"
            dest.write_bytes(content)

            # Dimensions are best-effort: an unreadable image is still stored.
            w, h = None, None
            try:
                # Context manager closes the file handle PIL keeps open.
                with Image.open(dest) as img:
                    w, h = img.size
            except Exception as exc:
                logger.warning(f"Could not read image dimensions for {dest}: {exc}")

            run_async(db.add_photo(
                case_id=cid, filename=filename,
                filepath=str(dest), file_size=len(content),
                width=w, height=h,
            ))
            count += 1
        return f"βœ… Uploaded {count} photo(s) to Case {cid}."
    except Exception as e:
        return f"❌ {e}"


def get_case_photos_fn(case_id):
    """Return ``(filepath, filename)`` gallery pairs for a case.

    Stored photos are listed only if their file exists on disk. When the
    case has no stored photos, the matching reference case's photos (if any)
    are returned instead. Returns ``[]`` for a missing ID or on any error.
    """
    if not case_id:
        return []
    ensure_init()
    try:
        cid = int(case_id)
        stored = run_async(db.get_photos_by_case(cid))
        if stored:
            return [
                (entry["filepath"], entry["filename"])
                for entry in stored
                if Path(entry["filepath"]).exists()
            ]
        # Nothing stored: fall back to the bundled reference cases.
        reference = REFERENCE_CASES.get(cid)
        if reference:
            return [(entry["filepath"], entry["filename"]) for entry in reference["photos"]]
        return []
    except Exception:
        return []


def _json_or_none(parsed, key):
    """Return ``parsed[key]`` serialized as JSON, or ``None`` when missing or empty."""
    value = parsed.get(key)
    return json.dumps(value) if value else None


def run_analysis_fn(case_id, progress=gr.Progress()):
    """Run the full analysis pipeline for a case.

    Steps: analyze every photo with the vision model and store each result,
    identify the parties, match traffic-rule violations, deduce fault, and
    mark the case as complete.

    Args:
        case_id: ID of the case to analyze (anything ``int()`` accepts).
        progress: Gradio progress tracker, called with a fraction and a
            description.

    Returns:
        ``(status_message, per_photo_analysis_markdown, violations_markdown)``.
        On an unexpected error the second element holds the traceback text.
    """
    import traceback

    try:
        if not case_id:
            return "❌ Enter a Case ID.", "", ""
        ensure_init()
        cid = int(case_id)

        case = run_async(db.get_case(cid))
        if not case:
            return "❌ Case not found.", "", ""
        photos = run_async(db.get_photos_by_case(cid))
        if not photos:
            return "❌ No photos uploaded.", "", ""

        if not inference_engine.is_loaded:
            inference_engine.load_model()

        # Step 1: Analyze each photo (first half of the progress bar)
        analysis_results = []
        total_photos = len(photos)
        for index, photo in enumerate(photos, start=1):
            progress(index / total_photos * 0.5, desc=f"Analyzing photo {index}/{total_photos}...")
            try:
                # Context manager closes the image file once inference is done.
                with Image.open(photo["filepath"]) as img:
                    start = time.perf_counter()
                    raw = gpu_run_inference(img, SCENE_ANALYSIS_PROMPT)
                    elapsed_ms = (time.perf_counter() - start) * 1000
                parsed = scene_analyzer._parse_analysis(raw)
                run_async(db.add_scene_analysis(
                    photo_id=photo["id"], raw_analysis=raw,
                    vehicles_json=_json_or_none(parsed, "vehicles"),
                    road_conditions_json=_json_or_none(parsed, "road_conditions"),
                    evidence_json=_json_or_none(parsed, "evidence"),
                    environmental_json=_json_or_none(parsed, "environmental"),
                    positions_json=_json_or_none(parsed, "positions"),
                    model_id=settings.model_id, inference_time_ms=elapsed_ms,
                ))
                analysis_results.append({"filename": photo["filename"], "analysis": raw, "time_ms": round(elapsed_ms)})
            except Exception as e:
                # A failing photo is recorded with its error text; the run continues.
                err_msg = f"Error: {e}"
                run_async(db.add_scene_analysis(
                    photo_id=photo["id"],
                    raw_analysis=err_msg,
                    model_id=settings.model_id,
                    inference_time_ms=0,
                ))
                analysis_results.append({"filename": photo["filename"], "analysis": err_msg, "time_ms": 0})

        # Identify parties (existing party rows are replaced)
        progress(0.55, desc="Identifying parties...")
        all_analyses = run_async(db.get_analyses_by_case(cid))
        parties_data = scene_analyzer._identify_parties(all_analyses)
        run_async(db.clear_parties(cid))
        for p in parties_data:
            run_async(db.add_party(
                case_id=cid, label=p.get("label", "Unknown"),
                vehicle_type=p.get("vehicle_type"), vehicle_color=p.get("vehicle_color"),
                vehicle_description=p.get("description"),
            ))

        # Step 2: Rule matching
        progress(0.65, desc="Matching traffic rules...")
        violations = run_async(rule_matcher.match_violations(cid))

        # Step 3: Fault deduction
        progress(0.8, desc="Deducing fault...")
        fault_result = run_async(fault_deducer.deduce_fault(cid))
        run_async(db.update_case_status(cid, "complete"))

        # Format output
        total_time = sum(r["time_ms"] for r in analysis_results)
        analysis_text = "".join(
            f"### πŸ“· {r['filename']} ({r['time_ms']}ms)\n```\n{r['analysis']}\n```\n---\n\n"
            for r in analysis_results
        )

        violation_parts = [f"Found {len(violations)} violation(s):\n"]
        violation_parts.extend(
            f"\nβ€’ **{v.get('rule_title', '?')}** ({v.get('severity', '?')}) β€” {v.get('confidence', 0):.0%}"
            for v in violations
        )
        violation_parts.append(f"\n\n### Fault: {fault_result.get('primary_fault_party', 'N/A')}")
        violation_parts.append(f"\nConfidence: {fault_result.get('overall_confidence', 0):.0%}")
        violation_parts.append(f"\n\n{fault_result.get('analysis_summary', '')}")
        violations_text = "".join(violation_parts)

        progress(1.0, desc="Complete!")
        return f"βœ… Done! {len(photos)} photos in {total_time/1000:.1f}s", analysis_text, violations_text
    except Exception as e:
        return f"❌ Python Error: {e}", traceback.format_exc(), ""


def generate_report_fn(case_id):
    """Build the incident report for a case and render it as Markdown.

    Returns an error message when no ID is given, when report generation
    raises, or when the generated report contains an ``"error"`` key.
    """
    if not case_id:
        return "❌ Enter a Case ID."
    ensure_init()
    try:
        report = run_async(report_generator.generate_report(int(case_id)))
    except Exception as exc:
        return f"❌ {exc}"
    if "error" in report:
        return f"❌ {report['error']}"

    case_info = report.get("case", {})
    stats = report.get("statistics", {})
    fault = report.get("fault_analysis", {})

    header = f"""# πŸš” TraceScene Report
> Case: {case_info.get('case_number', 'β€”')} | Officer: {case_info.get('officer_name', 'β€”')}
> Location: {case_info.get('location', 'β€”')} | Date: {case_info.get('incident_date', 'β€”')}

*{report.get('disclaimer', '')}*

| Metric | Value |
|---|---|
| Photos | {stats.get('analyzed_photos', 0)} |
| Violations | {stats.get('total_violations', 0)} |
| Critical | {stats.get('critical_violations', 0)} |
| Parties | {stats.get('parties_identified', 0)} |

## Scene Summary
{report.get('scene_summary', 'N/A')}

## Violations
"""
    sections = [header]
    # One bullet per violation.
    sections.extend(
        f"- **{item.get('title', '?')}** [{item.get('severity', '?')}] β€” {item.get('party', '?')} ({item.get('confidence', 0):.0%})\n"
        for item in report.get("violations", {}).get("list", [])
    )
    sections.append("\n## Fault Analysis\n")
    # Fault details appear only when the report marks fault as determined.
    if fault.get("determined"):
        sections.append(f"**Primary Fault:** {fault.get('primary_fault_party', '?')}\n")
        sections.append(f"**Confidence:** {fault.get('overall_confidence', 0):.0%}\n")
        sections.append(f"\n{fault.get('probable_cause', '')}\n")
    return "".join(sections)


def get_rules_fn():
    """Render the loaded traffic rules as Markdown, one table per category.

    Returns ``"No rules loaded."`` when the rule data has no categories.
    """
    ensure_init()
    categories = rule_loader.get_all_rules().get("categories", [])
    if not categories:
        return "No rules loaded."

    chunks = ["# πŸ“œ Traffic Rules\n\n"]
    for category in categories:
        chunks.append(f"## {category.get('name', '?')} ({category.get('rule_count', 0)})\n")
        chunks.append("| ID | Title | Severity | Weight |\n|---|---|---|---|\n")
        chunks.extend(
            f"| {rule.get('id', '')} | {rule.get('title', '')} | {rule.get('severity', '')} | {rule.get('fault_weight', '')} |\n"
            for rule in category.get("rules", [])
        )
        chunks.append("\n")
    return "".join(chunks)


# ── JSON API functions (for custom frontend via @gradio/client) ────────

def health_fn():
    """Return system health as a JSON string.

    Keys: ``status``, ``model_loaded``, ``model_id`` and ``device`` (both
    ``None`` while the model is not loaded), and ``rules_loaded``.

    ``rules_loaded`` is the number of rules summed over all categories.
    ``rule_loader.get_all_rules()`` is read elsewhere in this file as a dict
    with a ``"categories"`` list, so taking ``len()`` of that dict would
    count its top-level keys rather than the rules.
    """
    ensure_init()
    model_loaded = inference_engine.is_loaded
    categories = rule_loader.get_all_rules().get("categories", [])
    rules_loaded = sum(len(cat.get("rules", [])) for cat in categories)
    return json.dumps({
        "status": "ok",
        "model_loaded": model_loaded,
        "model_id": settings.model_id if model_loaded else None,
        "device": inference_engine._device if model_loaded else None,
        "rules_loaded": rules_loaded,
    })


def list_cases_json():
    """Return ``{"cases": [...]}`` as JSON: reference cases first, then stored cases.

    Each stored case is annotated with ``photo_count`` and
    ``is_reference = False``.
    """
    ensure_init()
    stored_cases = run_async(db.list_cases())
    for entry in stored_cases:
        entry["photo_count"] = len(run_async(db.get_photos_by_case(entry["id"])))
        entry["is_reference"] = False

    # Reference cases are listed ahead of the stored ones.
    reference_cases = [ref["case"] for ref in REFERENCE_CASES.values()]
    return json.dumps({"cases": reference_cases + stored_cases})


def get_case_json(case_id):
    """Return the full details of one case as a JSON string.

    Reference cases are answered from ``REFERENCE_CASES`` without touching
    the database. Otherwise the case, its photos, analyses, parties,
    violations and fault analysis are loaded from the database. An
    ``{"error": ...}`` object is returned for a missing ID or unknown case.
    """
    if not case_id:
        return json.dumps({"error": "No case ID"})

    cid = int(case_id)

    # Reference cases take precedence over stored ones.
    reference = REFERENCE_CASES.get(cid)
    if reference:
        payload = reference.copy()
        payload["stats"] = {
            "total_photos": len(payload["photos"]),
            "analyzed_photos": len(payload["analyses"]),
            "violations_found": len(payload["violations"]),
            "parties_identified": len(payload["parties"]),
        }
        return json.dumps(payload)

    ensure_init()
    case = run_async(db.get_case(cid))
    if not case:
        return json.dumps({"error": f"Case {cid} not found"})

    photos = run_async(db.get_photos_by_case(cid))
    analyses = run_async(db.get_analyses_by_case(cid))
    parties = run_async(db.get_parties_by_case(cid))
    violations = run_async(db.get_violations_by_case(cid))
    fault = run_async(db.get_fault_analysis(cid))

    case_dict = dict(case)
    case_dict["is_reference"] = False

    stats = {
        "total_photos": len(photos),
        "analyzed_photos": len(analyses),
        "violations_found": len(violations),
        "parties_identified": len(parties),
    }
    return json.dumps({
        "case": case_dict,
        "photos": photos,
        "analyses": analyses,
        "parties": parties,
        "violations": violations,
        "fault_analysis": fault,
        "stats": stats,
    })


def get_report_json(case_id):
    """Return the generated report for ``case_id`` as a JSON string.

    Returns ``{"error": "No case ID"}`` when no ID is given.
    """
    if case_id:
        ensure_init()
        generated = run_async(report_generator.generate_report(int(case_id)))
        return json.dumps(generated)
    return json.dumps({"error": "No case ID"})


def get_rules_json():
    """Return everything ``rule_loader.get_all_rules()`` provides, as a JSON string."""
    ensure_init()
    all_rules = rule_loader.get_all_rules()
    return json.dumps(all_rules)


def _format_rules_text(rules_data):
    """Render rule categories and their rules as the indented list used in chat prompts."""
    lines = []
    for cat in rules_data.get("categories", []):
        lines.append(f"\nCategory: {cat.get('name', '')}\n")
        lines.extend(
            f"  - {r.get('id', '')}: {r.get('title', '')} (Severity: {r.get('severity', '')})\n"
            for r in cat.get("rules", [])
        )
    return "".join(lines)


def load_chat_context(case_id):
    """Build the system context for the chat assistant.

    Args:
        case_id: ID of the case to load, or a falsy value for general mode.

    Returns:
        ``(context_text, status_markdown)``. In general mode the context is
        a generic assistant prompt plus the traffic rules. In case mode it
        holds the case header, every photo analysis, the parties,
        violations and fault analysis (each only when present), and the
        traffic rules. If the case does not exist the context is ``""``.
    """
    if not case_id:
        default_ctx = "You are TraceScene AI assistant. You help insurers and investigating officers analyze accident cases, traffic rules, and insurance clauses. Answer concisely and accurately.\n\n"
        # Load traffic rules as general context
        ensure_init()
        rules_text = _format_rules_text(rule_loader.get_all_rules())
        ctx = default_ctx + "TRAFFIC RULES:\n" + rules_text
        return ctx, "*General mode: traffic rules loaded. Ask any question!*"

    ensure_init()
    cid = int(case_id)
    case = run_async(db.get_case(cid))
    if not case:
        return "", f"❌ Case {cid} not found."

    analyses = run_async(db.get_analyses_by_case(cid))
    parties = run_async(db.get_parties_by_case(cid))
    violations = run_async(db.get_violations_by_case(cid))
    fault = run_async(db.get_fault_analysis(cid))
    rules_data = rule_loader.get_all_rules()

    header = f"""You are TraceScene AI assistant analyzing Case #{case.get('case_number', '')}.
Location: {case.get('location', 'Unknown')}
Date: {case.get('incident_date', 'Unknown')}
Officer: {case.get('officer_name', 'Unknown')}
Status: {case.get('status', 'Unknown')}

SCENE ANALYSES:\n"""
    parts = [header]
    parts.extend(
        f"\n--- Photo Analysis ---\n{a.get('raw_analysis', '')}\n" for a in analyses
    )

    if parties:
        parts.append("\nPARTIES IDENTIFIED:\n")
        parts.extend(
            f"  - {p.get('label', '')}: {p.get('vehicle_type', '')} {p.get('vehicle_color', '')} β€” {p.get('vehicle_description', '')}\n"
            for p in parties
        )

    if violations:
        parts.append("\nVIOLATIONS FOUND:\n")
        parts.extend(
            f"  - {v.get('rule_title', '')} (Severity: {v.get('severity', '')}, Confidence: {v.get('confidence', 0):.0%})\n"
            for v in violations
        )

    if fault:
        parts.append(
            f"\nFAULT ANALYSIS:\n  Primary Fault: {fault.get('primary_fault_party', 'N/A')}\n  Confidence: {fault.get('overall_confidence', 0):.0%}\n  Summary: {fault.get('analysis_summary', '')}\n"
        )

    # Append traffic rules
    parts.append("\nTRAFFIC RULES:\n" + _format_rules_text(rules_data))
    ctx = "".join(parts)

    return ctx, f"βœ… Case **{case.get('case_number', '')}** loaded with {len(analyses)} analyses, {len(violations)} violations."


def chat_respond(user_message, history, system_ctx):
    """Stream the assistant's reply to one chat message.

    Generator yielding ``(history, "", system_ctx)`` tuples; the empty
    string is the second output value. ``history`` is a list of
    ``{"role": ..., "content": ...}`` dicts.

    Args:
        user_message: Text typed by the user; blank input yields the
            unchanged history once and stops.
        history: Previous messages, or ``None``. The caller's list is not
            modified; updated copies are yielded instead.
        system_ctx: System context passed to the chat model and echoed back
            unchanged.
    """
    if not user_message or not user_message.strip():
        yield history or [], "", system_ctx
        return

    # One-time backend start-up (DB connection, rules, vision-model load attempt).
    ensure_init()

    logger.info(f"Chat request: {user_message[:50]}...")
    message = user_message.strip()

    # Work on a copy so the list passed in by the caller is never mutated.
    history = list(history or [])
    history.append({"role": "user", "content": message})

    partial_text = ""
    try:
        # Each streamed item is the full response so far (cumulative), so it
        # replaces rather than extends the previous partial text.
        for token_text in gpu_run_chat_stream(system_ctx, message):
            partial_text = token_text
            yield history + [{"role": "assistant", "content": partial_text}], "", system_ctx

        # Commit the complete assistant response to the history.
        history.append({"role": "assistant", "content": partial_text})
        logger.info(f"Stream complete. Final response length: {len(partial_text)}")
    except Exception as e:
        logger.error(f"Chat failed: {e}")
        history.append({"role": "assistant", "content": f"Error: {e}"})

    # Single final yield for both the success and the error path.
    yield history, "", system_ctx


def generate_animation_fn(case_id):
    """Build an animated top-down SVG of a case's collision as an HTML string.

    Scene details are pulled from the ``raw_analysis`` text of the case's
    first analysis using ``"<Field>: <value>"`` pattern matching. Any field
    that is not found falls back to ``"Unknown"``.

    Args:
        case_id: Case identifier; converted with ``int()`` before the lookup.
            A falsy value short-circuits with an error snippet.

    Returns:
        An HTML string: either a red error paragraph (no case id / no
        analyses) or a ``<div>`` wrapping the animated ``<svg>``.
    """
    import random
    import re
    from html import escape

    if not case_id:
        return "<p style='color:red;'>Enter a Case ID.</p>"
    ensure_init()
    analyses = run_async(db.get_analyses_by_case(int(case_id)))
    if not analyses:
        return "<p style='color:red;'>No analyses found. Run analysis first.</p>"

    # Parse scene details from the first analysis. A stored None is treated
    # like a missing value so the regex search below never sees a non-string.
    raw = analyses[0].get("raw_analysis") or ""

    def extract_field(field):
        """Return the text after ``"<field>:"`` in ``raw``, or "Unknown"."""
        m = re.search(rf"{re.escape(field)}:\s*(.+)", raw, re.IGNORECASE)
        return m.group(1).strip() if m else "Unknown"

    def extract_color(make_str):
        """Return the first known color name found in ``make_str`` (lowercase).

        Falls back to a default blue when no color word is present.
        """
        lowered = make_str.lower()
        for name in ("red", "blue", "white", "black", "silver", "grey",
                     "green", "yellow", "brown", "orange"):
            if name in lowered:
                return name
        return "#3b82f6"

    road_type = extract_field("Road Type")
    num_vehicles = extract_field("Vehicles Involved")
    impact = extract_field("Area of Impact")
    category = extract_field("Accident Category")
    v1_make = extract_field("Vehicle 1 Make/Model")

    # A second vehicle is considered present only if its make/model was found.
    v2_make = extract_field("Vehicle 2 Make/Model")
    has_v2 = v2_make != "Unknown"

    v1_color = extract_color(v1_make)
    v2_color = extract_color(v2_make) if has_v2 else "#ef4444"

    road_is_intersection = "intersection" in road_type.lower()

    # Accept values such as "2" or "2 vehicles"; anything without a leading
    # integer (including "Unknown") falls back to a single vehicle.
    count_match = re.match(r"\s*(\d+)", num_vehicles)
    num_v = int(count_match.group(1)) if count_match else 1

    # Unique ID to force Gradio to re-render on each click (enables replay)
    uid = random.randint(10000, 99999)

    # Severity drives both the animation duration and the HUD accent color;
    # anything other than mild/medium is rendered as the most severe style.
    severity = category.lower()
    dur = {"mild": "3s", "medium": "2s"}.get(severity, "1.5s")
    sev_color = {"mild": "#22c55e", "medium": "#f59e0b"}.get(severity, "#ef4444")

    # Text below comes from free-form analysis output, so escape it before
    # embedding it in markup.
    v1_label = escape(v1_make[:35])
    v2_label = escape(v2_make[:35]) if has_v2 else "Single vehicle accident"
    category_label = escape(category.upper())
    impact_label = escape(impact)
    road_label = escape(road_type)

    # Build SVG road
    if road_is_intersection:
        road_svg = '''
        <rect x="0" y="160" width="700" height="100" fill="#555"/>
        <rect x="300" y="0" width="100" height="420" fill="#555"/>
        <line x1="0" y1="210" x2="300" y2="210" stroke="#fbbf24" stroke-width="2" stroke-dasharray="20,15"/>
        <line x1="400" y1="210" x2="700" y2="210" stroke="#fbbf24" stroke-width="2" stroke-dasharray="20,15"/>
        <line x1="350" y1="0" x2="350" y2="160" stroke="#fbbf24" stroke-width="2" stroke-dasharray="20,15"/>
        <line x1="350" y1="260" x2="350" y2="420" stroke="#fbbf24" stroke-width="2" stroke-dasharray="20,15"/>
        '''
    else:
        road_svg = '''
        <rect x="0" y="150" width="700" height="120" fill="#555" rx="2"/>
        <line x1="0" y1="210" x2="700" y2="210" stroke="#fbbf24" stroke-width="2" stroke-dasharray="20,15"/>
        <line x1="0" y1="150" x2="700" y2="150" stroke="white" stroke-width="2"/>
        <line x1="0" y1="270" x2="700" y2="270" stroke="white" stroke-width="2"/>
        '''

    # Vehicle 2 SVG (if present): approaches from the top at an intersection,
    # otherwise from the right along the road.
    v2_svg = ""
    if has_v2:
        if road_is_intersection:
            v2_svg = f'''<g>
              <animateTransform attributeName="transform" type="translate" from="0,0" to="0,135" dur="{dur}" fill="freeze"/>
              <rect x="325" y="60" width="50" height="26" rx="5" fill="{v2_color}" stroke="#fff" stroke-width="1"/>
              <text x="350" y="78" fill="white" font-size="10" font-weight="bold" text-anchor="middle">V2</text>
            </g>'''
        else:
            v2_svg = f'''<g>
              <animateTransform attributeName="transform" type="translate" from="0,0" to="-200,0" dur="{dur}" fill="freeze"/>
              <rect x="560" y="215" width="50" height="26" rx="5" fill="{v2_color}" stroke="#fff" stroke-width="1"/>
              <text x="585" y="233" fill="white" font-size="10" font-weight="bold" text-anchor="middle">V2</text>
            </g>'''

    markup = f'''
<div style="text-align:center; font-family: Inter, Arial, sans-serif;">
<svg id="anim_{uid}" width="700" height="420" viewBox="0 0 700 420" xmlns="http://www.w3.org/2000/svg" style="border:1px solid #444; border-radius:10px; background:#1a1a2e;">
  <defs>
    <radialGradient id="glow_{uid}" cx="50%" cy="50%" r="50%">
      <stop offset="0%" stop-color="#fbbf24" stop-opacity="0.8"/>
      <stop offset="100%" stop-color="#fbbf24" stop-opacity="0"/>
    </radialGradient>
  </defs>

  {road_svg}

  <!-- Vehicle 1 -->
  <g>
    <animateTransform attributeName="transform" type="translate" from="0,0" to="200,0" dur="{dur}" fill="freeze"/>
    <rect x="80" y="190" width="50" height="26" rx="5" fill="{v1_color}" stroke="#fff" stroke-width="1"/>
    <text x="105" y="207" fill="white" font-size="10" font-weight="bold" text-anchor="middle">V1</text>
  </g>

  {v2_svg}

  <!-- Impact flash -->
  <circle cx="340" cy="210" r="0" fill="url(#glow_{uid})">
    <animate attributeName="r" values="0;0;0;0;0;0;0;45;55;0" dur="{dur}" fill="freeze"/>
    <animate attributeName="opacity" values="0;0;0;0;0;0;0;1;0.5;0" dur="{dur}" fill="freeze"/>
  </circle>

  <!-- Debris -->
  <circle cx="340" cy="210" r="3" fill="#fbbf24" opacity="0">
    <animate attributeName="opacity" values="0;0;0;0;0;0;0;1;0" dur="{dur}" fill="freeze"/>
    <animate attributeName="cx" values="340;340;340;340;340;340;340;310;290" dur="{dur}" fill="freeze"/>
    <animate attributeName="cy" values="210;210;210;210;210;210;210;185;170" dur="{dur}" fill="freeze"/>
  </circle>
  <circle cx="340" cy="210" r="2" fill="#ef4444" opacity="0">
    <animate attributeName="opacity" values="0;0;0;0;0;0;0;1;0" dur="{dur}" fill="freeze"/>
    <animate attributeName="cx" values="340;340;340;340;340;340;340;370;395" dur="{dur}" fill="freeze"/>
    <animate attributeName="cy" values="210;210;210;210;210;210;210;190;175" dur="{dur}" fill="freeze"/>
  </circle>
  <circle cx="340" cy="210" r="3" fill="#e2e8f0" opacity="0">
    <animate attributeName="opacity" values="0;0;0;0;0;0;0;1;0" dur="{dur}" fill="freeze"/>
    <animate attributeName="cx" values="340;340;340;340;340;340;340;320;305" dur="{dur}" fill="freeze"/>
    <animate attributeName="cy" values="210;210;210;210;210;210;210;235;255" dur="{dur}" fill="freeze"/>
  </circle>
  <circle cx="340" cy="210" r="2" fill="#f97316" opacity="0">
    <animate attributeName="opacity" values="0;0;0;0;0;0;0;1;0" dur="{dur}" fill="freeze"/>
    <animate attributeName="cx" values="340;340;340;340;340;340;340;365;385" dur="{dur}" fill="freeze"/>
    <animate attributeName="cy" values="210;210;210;210;210;210;210;230;250" dur="{dur}" fill="freeze"/>
  </circle>

  <!-- Collision label -->
  <text x="350" y="145" fill="#ef4444" font-size="18" font-weight="bold" text-anchor="middle" opacity="0" font-family="Inter, Arial, sans-serif">
    COLLISION
    <animate attributeName="opacity" values="0;0;0;0;0;0;0;1;1" dur="{dur}" fill="freeze"/>
  </text>

  <!-- HUD -->
  <rect x="10" y="350" width="680" height="60" rx="8" fill="rgba(0,0,0,0.6)"/>
  <text x="20" y="375" fill="#e2e8f0" font-size="12" font-family="Inter, Arial, sans-serif">{v1_label}</text>
  <text x="20" y="398" fill="#e2e8f0" font-size="12" font-family="Inter, Arial, sans-serif">{v2_label}</text>
  <text x="680" y="375" fill="{sev_color}" font-size="14" font-weight="bold" text-anchor="end" font-family="Inter, Arial, sans-serif">{category_label}</text>
  <text x="680" y="398" fill="#94a3b8" font-size="11" text-anchor="end" font-family="Inter, Arial, sans-serif">Impact: {impact_label} | Road: {road_label}</text>
</svg>
<div style="margin-top:8px; color:#94a3b8; font-size:12px;">
  Vehicles: {num_v} | Animation auto-plays on load
</div>
</div>
'''
    return markup

# ── Build Gradio App ──────────────────────────────────────────────────

# CSS overrides intended for the Gradio UI: cap the container width at
# 1200px and hide the footer.
# NOTE(review): the gr.Blocks(...) call that follows does not receive this
# value — confirm it is applied somewhere (e.g. where the app is mounted or
# launched), otherwise these rules have no effect.
CUSTOM_CSS = """
.gradio-container { max-width: 1200px !important; }
footer { display: none !important; }
"""

# Declarative UI definition. Components are laid out in creation order, so
# the statement order below is the on-screen order. Event handlers wired here
# are defined earlier in this module; `api_name` exposes a handler as a named
# API endpoint for programmatic clients.
with gr.Blocks(
    title="TraceScene — AI Accident Analysis",
) as demo:
    gr.Markdown("""
    # 🚔 TraceScene
    ### AI-Powered Accident Scene Analysis
    *GPU-accelerated inference via ZeroGPU (NVIDIA H200)*
    ---
    """)

    with gr.Tabs():
        # Tab 1: Quick Analyze (single photo)
        with gr.TabItem("⚡ Quick Analyze"):
            gr.Markdown("Upload a photo for instant GPU-accelerated analysis.")
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(label="Upload Accident Photo", type="pil")
                    quick_btn = gr.Button("🚀 Analyze with GPU", variant="primary")
                with gr.Column():
                    quick_output = gr.Textbox(label="AI Analysis", lines=20)
            quick_btn.click(fn=gradio_analyze_photo, inputs=[input_image], outputs=[quick_output], api_name="analyze_photo")

        # Tab 2: Cases
        with gr.TabItem("📋 Cases"):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Create Case")
                    cn = gr.Textbox(label="Case Number *", placeholder="ACC-2026-001")
                    on = gr.Textbox(label="Officer Name")
                    loc = gr.Textbox(label="Location")
                    dt = gr.Textbox(label="Incident Date", placeholder="YYYY-MM-DD")
                    nt = gr.Textbox(label="Notes", lines=2)
                    create_btn = gr.Button("Create Case", variant="primary")
                    create_status = gr.Markdown()
                with gr.Column(scale=2):
                    gr.Markdown("### Existing Cases")
                    cases_tbl = gr.Dataframe(
                        headers=["ID", "Case #", "Officer", "Location", "Date", "Status", "Photos"],
                        interactive=False,
                    )
                    with gr.Row():
                        refresh_btn = gr.Button("🔄 Refresh")
                        del_id = gr.Number(label="Case ID to Delete", precision=0)
                        del_btn = gr.Button("🗑️ Delete", variant="stop")
                    del_status = gr.Markdown()
            create_btn.click(create_case_fn, inputs=[cn, on, loc, dt, nt], outputs=[create_status, cases_tbl], api_name="create_case")
            refresh_btn.click(list_cases_fn, outputs=[cases_tbl], api_name="list_cases")
            del_btn.click(delete_case_fn, inputs=[del_id], outputs=[del_status, cases_tbl], api_name="delete_case")

        # Tab 3: Upload Photos
        with gr.TabItem("📸 Photos"):
            with gr.Row():
                with gr.Column(scale=1):
                    up_case = gr.Number(label="Case ID", precision=0)
                    up_files = gr.File(label="Select Photos", file_count="multiple", file_types=["image"])
                    up_btn = gr.Button("Upload", variant="primary")
                    up_status = gr.Markdown()
                with gr.Column(scale=2):
                    pv_case = gr.Number(label="Preview Case ID", precision=0)
                    pv_btn = gr.Button("Load Photos")
                    gallery = gr.Gallery(label="Photos", columns=3)
            up_btn.click(upload_photos_fn, inputs=[up_case, up_files], outputs=[up_status], api_name="upload_photos")
            pv_btn.click(get_case_photos_fn, inputs=[pv_case], outputs=[gallery], api_name="get_case_photos")

        # Tab 4: Run Analysis
        with gr.TabItem("🧠 Analysis"):
            gr.Markdown("""
            ### Full Analysis Pipeline (GPU-accelerated)
            1. Scene Analysis → 2. Rule Matching → 3. Fault Deduction
            """)
            an_case = gr.Number(label="Case ID", precision=0)
            an_btn = gr.Button("🚀 Run Full Analysis", variant="primary", size="lg")
            an_status = gr.Markdown()
            with gr.Accordion("Scene Details", open=False):
                an_detail = gr.Markdown()
            an_violations = gr.Markdown(label="Violations & Fault")
            an_btn.click(run_analysis_fn, inputs=[an_case], outputs=[an_status, an_detail, an_violations], api_name="run_analysis")

        # Tab 5: Report
        with gr.TabItem("📄 Report"):
            rp_case = gr.Number(label="Case ID", precision=0)
            rp_btn = gr.Button("Generate Report", variant="primary")
            rp_out = gr.Markdown()
            rp_btn.click(generate_report_fn, inputs=[rp_case], outputs=[rp_out], api_name="generate_report")

        # Tab 6: Rules
        with gr.TabItem("📜 Rules"):
            ru_btn = gr.Button("Load Traffic Rules")
            ru_out = gr.Markdown()
            ru_btn.click(get_rules_fn, outputs=[ru_out], api_name="get_rules")

        # Tab 7: Chat Q&A
        with gr.TabItem("💬 Chat"):
            gr.Markdown("### Case Q&A Chatbot\nAsk questions about logged cases, traffic rules, or insurance clauses.")
            with gr.Row():
                chat_case_id = gr.Number(label="Case ID (optional)", precision=0)
                chat_load_btn = gr.Button("📂 Load Case Context", variant="secondary")
            chat_context_status = gr.Markdown(value="*No case loaded. You can still ask general traffic/insurance questions.*")
            chatbot = gr.Chatbot(label="Conversation", height=400)
            chat_input = gr.Textbox(label="Your Question", placeholder="e.g. What vehicles were involved? What rules were violated?", lines=2)
            with gr.Row():
                chat_send_btn = gr.Button("💬 Send", variant="primary")
                chat_clear_btn = gr.Button("🗑️ Clear")

            # Per-session system prompt; replaced with case-specific context
            # when "Load Case Context" is clicked.
            chat_system_ctx = gr.State(value="You are TraceScene AI assistant. You help insurers and investigating officers analyze accident cases, traffic rules, and insurance clauses. Answer concisely and accurately based on the context provided.")

            chat_load_btn.click(load_chat_context, inputs=[chat_case_id], outputs=[chat_system_ctx, chat_context_status])
            # The Send button and pressing Enter in the textbox share one
            # handler; only the button is published as a named API endpoint.
            chat_send_btn.click(chat_respond, inputs=[chat_input, chatbot, chat_system_ctx], outputs=[chatbot, chat_input, chat_system_ctx], api_name="chat")
            chat_input.submit(chat_respond, inputs=[chat_input, chatbot, chat_system_ctx], outputs=[chatbot, chat_input, chat_system_ctx])
            # Clear resets the conversation and the textbox but keeps the
            # loaded system context.
            chat_clear_btn.click(lambda: ([], ""), outputs=[chatbot, chat_input])

        # Tab 8: 2D Animation
        with gr.Tab("Simulation"):
            gr.Markdown("### 2D Accident Simulation\nVisualize the top-down perspective of the incident.")
            anim_case_id = gr.Number(label="Case ID", precision=0)
            anim_btn = gr.Button("Generate Animation", variant="primary")
            anim_output = gr.HTML(label="Animation View")

            anim_btn.click(generate_animation_fn, inputs=[anim_case_id], outputs=[anim_output])

        # Hidden API-only endpoints (for @gradio/client from custom frontend)
        with gr.TabItem("🔌 API", visible=False):
            api_health_btn = gr.Button("health")
            api_health_out = gr.Textbox()
            api_health_btn.click(health_fn, outputs=[api_health_out], api_name="health")

            api_cases_btn = gr.Button("list_cases_json")
            api_cases_out = gr.Textbox()
            api_cases_btn.click(list_cases_json, outputs=[api_cases_out], api_name="list_cases_json")

            api_case_id = gr.Number(precision=0)
            api_case_btn = gr.Button("get_case")
            api_case_out = gr.Textbox()
            api_case_btn.click(get_case_json, inputs=[api_case_id], outputs=[api_case_out], api_name="get_case")

            api_report_id = gr.Number(precision=0)
            api_report_btn = gr.Button("get_report")
            api_report_out = gr.Textbox()
            api_report_btn.click(get_report_json, inputs=[api_report_id], outputs=[api_report_out], api_name="get_report_json")

            api_rules_btn = gr.Button("get_rules_json")
            api_rules_out = gr.Textbox()
            api_rules_btn.click(get_rules_json, outputs=[api_rules_out], api_name="get_rules_json")

    gr.Markdown("---\n*TraceScene — Built by Siddharth Ravikumar | tracescene@zohomail.ae*")


# ── Create FastAPI App & Mount Gradio ────────────────────────────────

app = FastAPI()

# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site make credentialed cross-origin requests. Left unchanged because the
# custom frontend may rely on it — confirm and restrict to known origins if
# possible.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Static files (frontend)
frontend_dir = Path(__file__).resolve().parent / "frontend"

# Mount asset subfolders at the root for easier relative pathing from
# index.html. Each folder is checked individually: StaticFiles validates its
# directory when constructed, so a single missing subfolder would otherwise
# abort the import of this module. A missing frontend/ directory simply
# results in no mounts.
for _asset_name in ("css", "js", "images", "static"):
    _asset_dir = frontend_dir / _asset_name
    if _asset_dir.is_dir():
        app.mount(f"/{_asset_name}", StaticFiles(directory=str(_asset_dir)), name=_asset_name)

# Serve uploads folder (only if it already exists at import time).
if settings.upload_path.exists():
    app.mount("/uploads", StaticFiles(directory=str(settings.upload_path)), name="uploads")

@app.get("/")
async def serve_frontend():
    """Serve the custom frontend's index.html, or a small JSON stub without it."""
    landing_page = frontend_dir / "index.html"
    if not landing_page.exists():
        # No bundled frontend: answer with a minimal API descriptor instead.
        return {"message": "TraceScene API", "docs": "/docs"}
    return FileResponse(str(landing_page))

# API Routes
# API Routes
app.include_router(router)

# Mount the Gradio app at the root path ("/"), not under "/gradio".
# NOTE(review): the explicit "/" route and the API router are registered
# before this mount; presumably they take precedence over the Gradio app for
# the paths they define — confirm, and keep this statement order.
# `app` is rebound to the value returned by mount_gradio_app, which is what
# the startup hook and uvicorn use afterwards.
app = gr.mount_gradio_app(
    app,
    demo,
    path="/"
)

# Startup event wrapper
# Startup event wrapper
@app.on_event("startup")
async def startup_event():
    """Initialise the app on startup and emit the ZeroGPU startup report.

    Runs ``_ensure_init()`` and then, when the optional ``spaces`` package
    reports ZeroGPU mode, triggers its startup hook. Failures in the ZeroGPU
    step are logged as warnings and never abort startup.
    """
    logger.info("Starting up FastAPI application...")
    await _ensure_init()

    # --- Hugging Face ZeroGPU Fix ---
    # Author's rationale: mounting via gr.mount_gradio_app on a custom FastAPI
    # app bypasses gr.Blocks.launch(). The `spaces` library hooks `.launch()`
    # to emit the `startup_report` that the ZeroGPU orchestrator needs in
    # order to verify `@spaces.GPU` functions exist; without it the Hub fails
    # with "No @spaces.GPU function detected". So the report is triggered
    # by hand here.
    try:
        from spaces import config

        zero_gpu_enabled = getattr(config.Config, "zero_gpu", False)
        if not zero_gpu_enabled:
            return

        import spaces.zero as zero

        # Prefer the dedicated startup entry point when the library has one.
        if hasattr(zero, "startup"):
            zero.startup()
            logger.info("Triggered ZeroGPU startup successfully.")
            return

        # Otherwise drive the client-level report directly.
        if hasattr(zero, "client"):
            zero.torch.pack()
            zero.client.startup_report()
            logger.info("Triggered ZeroGPU client startup manually.")
    except ImportError:
        # `spaces` is not installed (e.g. outside Hugging Face); nothing to do.
        pass
    except Exception as exc:
        logger.warning(f"Failed to manually trigger ZeroGPU startup report: {exc}")

# Direct execution: serve the combined FastAPI + Gradio app with uvicorn on
# all interfaces, port 7860. uvicorn is imported lazily so it is only needed
# when the module is run as a script.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)