# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import cv2
import torch
import numpy as np
import gradio as gr
import sys
import shutil
from datetime import datetime
import glob
import gc
import time
import open3d as o3d
import open_clip
from open_clip import tokenizer
import trimesh
import matplotlib.pyplot as plt
import subprocess
import tempfile
import contextlib
from huggingface_hub import hf_hub_download

try:
    import gdown
except Exception:
    gdown = None

# Defensive patch: some gradio_client versions crash on JSON schema with boolean additionalProperties.
try:
    import gradio_client.utils as _gcu

    if hasattr(_gcu, "_json_schema_to_python_type"):
        _orig = _gcu._json_schema_to_python_type

        def _json_schema_to_python_type_patched(schema, defs=None):
            if isinstance(schema, bool):
                return "Any"
            return _orig(schema, defs)

        _gcu._json_schema_to_python_type = _json_schema_to_python_type_patched
except Exception:
    pass

os.environ.setdefault("MAX_JOBS", "1")
REPO_ROOT = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(REPO_ROOT, "vggt"))
MK_PATH = os.path.join(REPO_ROOT, "MaskClustering")


# Ensure local detectron2 is installed at runtime if postBuild didn't run
# try:
#     import detectron2  # noqa: F401
# except Exception:
#     print("[runtime] detectron2 not found. Installing local detectron2 (editable, no build isolation)...")
#     os.system("python -m pip install --no-build-isolation -e ./MaskClustering/third_party/detectron2")
#     import importlib
#     importlib.invalidate_caches()
#     import detectron2  # noqa: F401

# If detectron2 isn't installed as a package, allow importing from vendored source.


# Writable workdir (HF Spaces: prefer /tmp)
WORK_DIR = os.environ.get("ZOO3D_WORKDIR", os.path.join(tempfile.gettempdir(), "zoo3d"))
os.makedirs(WORK_DIR, exist_ok=True)
from visual_util import predictions_to_glb
from vggt.models.vggt import VGGT
from vggt.utils.load_fn import load_and_preprocess_images
from vggt.utils.pose_enc import pose_encoding_to_extri_intri
from vggt.utils.geometry import unproject_depth_map_to_point_map

device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Using device: {device}")

# CPU debug / compatibility knobs:
# - On CPU, VGGT-1B inference is usually impractical. For debugging, we fall back to a lightweight
#   dummy pipeline that produces a minimal predictions dict compatible with `predictions_to_glb`.
ZOO3D_ALLOW_CPU = os.environ.get("ZOO3D_ALLOW_CPU", "1") == "1"
ZOO3D_CPU_DUMMY = os.environ.get("ZOO3D_CPU_DUMMY", "1") == "1"
ZOO3D_SKIP_DOWNLOADS = os.environ.get("ZOO3D_SKIP_DOWNLOADS", "0") == "1"
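# Example (local CPU smoke test; shell invocation and file name are assumptions):
#   ZOO3D_ALLOW_CPU=1 ZOO3D_CPU_DUMMY=1 ZOO3D_SKIP_DOWNLOADS=1 python app.py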


_VGGT_MODEL = None
_METRIC3D_MODEL = None
_CLIP_MODEL = None


_MASK2FORMER_GDRIVE_FILE_ID = "10G7s6bVMwN__bcrR2fBal3goo69Y5Do4"


def _ensure_mask2former_weights(dst_path: str) -> str:
    """
    Ensure Mask2Former/CropFormer weights exist at dst_path.
    Priority:
    1) Use existing file (if present)
    2) Download from Google Drive (user-provided link / file id)
    3) Fallback: download from HF dataset (qqlu1992/Adobe_EntitySeg)
    """
    if os.path.exists(dst_path) and os.path.getsize(dst_path) > 0:
        return dst_path

    os.makedirs(os.path.dirname(dst_path), exist_ok=True)

    # Allow user override via local path
    override_path = os.environ.get("MASK2FORMER_WEIGHTS_PATH")
    if override_path and os.path.exists(override_path) and os.path.getsize(override_path) > 0:
        shutil.copyfile(override_path, dst_path)
        return dst_path

    # 2) Google Drive
    if gdown is not None:
        url = f"https://drive.google.com/uc?id={_MASK2FORMER_GDRIVE_FILE_ID}"
        out = gdown.download(url, dst_path, quiet=False)
        if out is not None and os.path.exists(dst_path) and os.path.getsize(dst_path) > 0:
            return dst_path
        print("Warning: gdown download failed for Mask2Former weights; falling back to HF dataset...")
    else:
        print("Warning: gdown is not available; falling back to HF dataset for Mask2Former weights...")

    # 3) HF fallback
    cached = hf_hub_download(
        repo_id="qqlu1992/Adobe_EntitySeg",
        repo_type="dataset",
        filename="CropFormer_model/Entity_Segmentation/Mask2Former_hornet_3x/Mask2Former_hornet_3x_576d0b.pth",
    )
    shutil.copyfile(cached, dst_path)
    return dst_path
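# To skip all downloads and use locally cached weights (path is hypothetical):
#   MASK2FORMER_WEIGHTS_PATH=/data/Mask2Former_hornet_3x_576d0b.pth python app.py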


def _init_models():
    """
    Lazy-load heavy models so the UI can start quickly on HF Spaces.
    """
    global _VGGT_MODEL, _METRIC3D_MODEL, _CLIP_MODEL

    if not torch.cuda.is_available():
        # CPU-friendly mode for debugging: skip heavy models.
        if not ZOO3D_ALLOW_CPU:
            raise RuntimeError("CUDA недоступна. Для этого Space нужен GPU (CUDA).")
        # We still can load CLIP on CPU if needed, but skip VGGT/Metric3D.
        if _CLIP_MODEL is None:
            print("[INFO] loading CLIP model (CPU)...")
            cm, _, _ = open_clip.create_model_and_transforms("ViT-H-14", pretrained="laion2b_s32b_b79k")
            cm.to("cpu")
            cm.eval()
            print("[INFO] finish loading CLIP model (CPU)...")
            globals()["_CLIP_MODEL"] = cm
        return None, None, _CLIP_MODEL

    if _VGGT_MODEL is None:
        print("Initializing and loading VGGT model...")
        # Prefer Hugging Face weights for VGGT
        try:
            m = VGGT.from_pretrained("facebook/VGGT-1B")
        except Exception:
            m = VGGT()
            _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
            m.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
        m.eval()
        _VGGT_MODEL = m.to(device)

    if _METRIC3D_MODEL is None:
        print("Initializing and loading Metric3D model...")
        try:
            mm = torch.hub.load("yvanyin/metric3d", "metric3d_vit_small", pretrain=True, trust_repo=True)
        except TypeError:
            mm = torch.hub.load("yvanyin/metric3d", "metric3d_vit_small", pretrain=True)
        mm.to(device)
        mm.eval()
        _METRIC3D_MODEL = mm

    if _CLIP_MODEL is None:
        print("[INFO] loading CLIP model...")
        cm, _, _ = open_clip.create_model_and_transforms("ViT-H-14", pretrained="laion2b_s32b_b79k")
        cm.to(device)
        cm.eval()
        print("[INFO] finish loading CLIP model...")
        _CLIP_MODEL = cm

    return _VGGT_MODEL, _METRIC3D_MODEL, _CLIP_MODEL
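# Note: _init_models() is lazy and idempotent; call it at the start of any heavy
# handler. On CPU it returns (None, None, clip_model), so callers must branch
# into the dummy pipeline instead of using VGGT/Metric3D.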

cropformer_name = "Mask2Former_hornet_3x_576d0b.pth"

def check_weights():
    if ZOO3D_SKIP_DOWNLOADS:
        print("[INFO] ZOO3D_SKIP_DOWNLOADS=1: skipping Mask2Former weights download.")
        return
    if not os.path.exists(os.path.join(MK_PATH, cropformer_name)):
        print(f"Downloading {cropformer_name}...")
        os.makedirs(MK_PATH, exist_ok=True)
        dst = os.path.join(MK_PATH, cropformer_name)
        _ensure_mask2former_weights(dst)
        print(f"Downloaded {cropformer_name}...")
    else:
        print(f"{cropformer_name} already exists...")
#
# IMPORTANT (HF Spaces):
# Do NOT download large weights at import time (startup). We'll download lazily
# when running detection/reconstruction that actually needs them.
#

def extract_text_feature(descriptions, clip_model, target_path):
    text_tokens = tokenizer.tokenize(descriptions).to(device)
    with torch.no_grad():
        text_features = clip_model.encode_text(text_tokens).float()
        text_features /= text_features.norm(dim=-1, keepdim=True)
        text_features = text_features.cpu().numpy()

    text_features_dict = {}
    for i, description in enumerate(descriptions):
        text_features_dict[description] = text_features[i]
    
    np.save(os.path.join(target_path, "text_features.npy"), text_features_dict)
    return text_features_dict
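# Sketch of a call (assumes CLIP is already loaded via _init_models() and
# `target_dir` exists; the label strings are hypothetical):
#   feats = extract_text_feature(["chair", "table"], clip_model, target_dir)
#   feats["chair"]  # L2-normalized (1024,) text embedding for ViT-H-14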


clip_model = None


# -------------------------------------------------------------------------
# 1) Core model inference
# -------------------------------------------------------------------------
def run_model(target_dir, model, metric3d_model=None) -> dict:
    """
    Run the VGGT model on images in the 'target_dir/images' folder and return predictions.
    """
    print(f"Processing images from {target_dir}")

    # Device selection
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device != "cuda":
        if not ZOO3D_ALLOW_CPU:
            raise RuntimeError("CUDA is not available. This Space requires a GPU (CUDA).")
        if not ZOO3D_CPU_DUMMY:
            raise RuntimeError(
                "CPU mode is enabled but ZOO3D_CPU_DUMMY=0. "
                "For debugging, set ZOO3D_CPU_DUMMY=1 or enable a GPU."
            )

    # Load and preprocess images (needed for both the GPU and CPU-dummy paths)
    image_names = glob.glob(os.path.join(target_dir, "images", "*"))
    image_names = sorted(image_names)
    print(f"Found {len(image_names)} images")
    if len(image_names) == 0:
        raise ValueError("No images found. Check your upload.")

    # For CPU dummy mode we want the original HxW for `predictions_to_glb` coloring.
    cpu_images_u8 = None
    if device == "cpu":
        imgs = []
        for p in image_names:
            im = cv2.imread(p, cv2.IMREAD_COLOR)
            if im is None:
                continue
            im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
            imgs.append(im)
        if len(imgs) == 0:
            raise ValueError("No readable images found. Check your upload.")
        # Make all images same size for stacking
        H, W = imgs[0].shape[:2]
        imgs2 = []
        for im in imgs:
            if im.shape[:2] != (H, W):
                im = cv2.resize(im, (W, H))
            imgs2.append(im)
        cpu_images_u8 = np.stack(imgs2, axis=0)  # (S,H,W,3) uint8
        print(f"CPU dummy: loaded images shape: {cpu_images_u8.shape}")

    images = load_and_preprocess_images(image_names)
    print(f"Preprocessed images shape: {tuple(images.shape)}")
    if device == "cuda":
        images = images.to(device)

    if device == "cpu":
        # Dummy predictions for CPU debugging: minimal keys needed by `predictions_to_glb`
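        # Minimal dict shapes produced below: images (S,H,W,3) uint8,
        # depth (S,H,W,1), depth_conf (S,H,W), extrinsic (S,3,4),
        # intrinsic (S,3,3), pose (S,4,4), world_points_from_depth (S,H,W,3).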
        S, H, W = cpu_images_u8.shape[0], cpu_images_u8.shape[1], cpu_images_u8.shape[2]
        # Simple planar point cloud in camera space
        uu, vv = np.meshgrid(np.arange(W), np.arange(H))
        x = (uu - (W / 2.0)) / float(max(W, 1))
        y = -(vv - (H / 2.0)) / float(max(W, 1))
        z = np.ones_like(x, dtype=np.float32) * 1.0
        pts = np.stack([x, y, z], axis=-1).astype(np.float32)  # (H,W,3)
        world_points_from_depth = np.repeat(pts[None, ...], S, axis=0)  # (S,H,W,3)
        depth = np.ones((S, H, W, 1), dtype=np.float32)
        depth_conf = np.ones((S, H, W), dtype=np.float32)
        extrinsic = np.tile(np.array([[1, 0, 0, 0],
                                      [0, 1, 0, 0],
                                      [0, 0, 1, 0]], dtype=np.float32)[None, ...], (S, 1, 1))
        intrinsic = np.tile(np.eye(3, dtype=np.float32)[None, ...], (S, 1, 1))
        pose = np.tile(np.eye(4, dtype=np.float32)[None, ...], (S, 1, 1))
        return {
            "images": cpu_images_u8,
            "extrinsic": extrinsic,
            "intrinsic": intrinsic,
            "pose": pose,
            "depth": depth,
            "depth_conf": depth_conf,
            "world_points_from_depth": world_points_from_depth,
        }

    # GPU inference
    # Move model to device
    model = model.to(device)
    model.eval()

    print("Running inference...")
    dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
    amp_ctx = torch.cuda.amp.autocast(dtype=dtype) if device == "cuda" else contextlib.nullcontext()

    with torch.no_grad():
        with amp_ctx:
            predictions = model(images)

    scale_factor = torch.tensor(1.0, device=device)

    # Metric3D inference
    if metric3d_model is not None:
        print("Running Metric3D inference...")
        # images is (B, 3, H, W) in [0, 1]
        # Metric3D usually expects [0, 255] if input is tensor via inference dict
        metric3d_input = images * 255.0
        
        m_depths = []
        # Process one by one to avoid potential batch issues if inference doesn't support batch
        for i in range(metric3d_input.shape[0]):
             img = metric3d_input[i:i+1] # (1, 3, H, W)
             
             # Pad image to be divisible by 32 (standard for HourGlass/UNet architectures)
             _, _, h, w = img.shape
             ph = ((h - 1) // 32 + 1) * 32
             pw = ((w - 1) // 32 + 1) * 32
             
             padding = (0, pw - w, 0, ph - h) # left, right, top, bottom
             if ph != h or pw != w:
                 img = torch.nn.functional.pad(img, padding, mode='constant', value=0)
             
             with torch.no_grad():
                 pred_depth, confidence, _ = metric3d_model.inference({'input': img})
             
             # Crop back to original size
             if ph != h or pw != w:
                 pred_depth = pred_depth[:, :, :h, :w]
                 
             m_depths.append(pred_depth)
        
        predictions["metric3d_depth"] = torch.cat(m_depths, dim=0)

        # Scale alignment: scale = median(Depths_VGGT / Depths_Metric3D)
        # We need to make sure we use valid depths (e.g. > 0) to avoid numerical issues
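        # E.g. if Metric3D puts a wall at 3.0 m while VGGT's up-to-scale depth
        # reads 1.5, the per-pixel ratio is 2.0; the median over valid pixels
        # becomes the global metric scale applied to depths and translations.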
        vggt_depth = predictions["depth"][0]  # (B, H, W, 1) or similar
        metric_depth = predictions["metric3d_depth"]  # (B, 1, H, W) presumably

        # Align shapes: VGGT depth is typically (B, H, W, 1), while Metric3D may
        # return (B, 1, H, W) or (B, H, W); bring Metric3D to (B, H, W, 1).
        if metric_depth.dim() == 4 and metric_depth.shape[1] == 1:
            metric_depth = metric_depth.permute(0, 2, 3, 1)  # -> (B, H, W, 1)
        elif metric_depth.dim() == 3:
            metric_depth = metric_depth.unsqueeze(-1)  # -> (B, H, W, 1)

        # Move to same device/dtype
        vggt_depth = vggt_depth.to(metric_depth.device).float()
        metric_depth = metric_depth.float()
        
        # Note: this assumes the two depth maps share the same spatial resolution
        # (both are computed from the same preprocessed images), so no resizing
        # is performed here.

        # Mask out non-positive values before computing the median ratio.
        print(f"Metric3D depth shape: {metric_depth.shape}")
        print(f"VGGT depth shape: {vggt_depth.shape}")
        valid_mask = (metric_depth > 1e-6) & (vggt_depth > 1e-6)

        if valid_mask.sum() > 0:
            ratio = metric_depth[valid_mask] / vggt_depth[valid_mask]
            scale_factor = torch.median(ratio)
            print(f"Computed scale factor (VGGT / Metric3D): {scale_factor.item():.4f}")
        else:
            print("Warning: could not compute scale factor; falling back to 1.0")
    print("Converting pose encoding to extrinsic and intrinsic matrices...")
    extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:])
    extrinsic = extrinsic[0]
    add = torch.zeros_like(extrinsic[:, 2:])
    add[..., -1] = 1
    extrinsic = torch.cat([extrinsic, add], dim=-2)
    zero_extrinsic = extrinsic[0]
    for i, e in enumerate(extrinsic):
        extrinsic[i] = zero_extrinsic @ torch.linalg.inv(e)
        extrinsic[i, :3, 3] *= scale_factor
    extrinsic_inv = torch.linalg.inv(extrinsic)
    print(f"Extrinsic: {extrinsic.shape}")
    extrinsic_inv = extrinsic_inv[None, ..., :3, :]
    predictions["extrinsic"] = extrinsic_inv
    predictions["pose"] = extrinsic[None]
    print(f"Extrinsic: {extrinsic.shape} {extrinsic}")
    predictions["intrinsic"] = intrinsic

    # Convert tensors to numpy
    for key in predictions.keys():
        if isinstance(predictions[key], torch.Tensor):
            try:
                predictions[key] = predictions[key].cpu().numpy().squeeze(0)  # remove batch dimension
            except ValueError:
                pass

    # Generate world points from depth map
    print("Computing world points from depth map...")
    predictions["depth"] = predictions["depth"] * float(scale_factor.item())
    depth_map = predictions["depth"]
    world_points = unproject_depth_map_to_point_map(depth_map, predictions["extrinsic"], predictions["intrinsic"])
    predictions["world_points_from_depth"] = world_points

    # Clean up
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return predictions


# -------------------------------------------------------------------------
# 2) Handle uploaded video/images --> produce target_dir + images
# -------------------------------------------------------------------------
def handle_uploads(input_video, input_images):
    """
    Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
    images or extracted frames from video into it. Return (target_dir, image_paths).
    """
    start_time = time.time()
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Create a unique folder name
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    target_dir = os.path.join(WORK_DIR, "input", timestamp)
    target_dir_images = os.path.join(target_dir, "images")

    # Clean up if somehow that folder already exists
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir)
    os.makedirs(target_dir_images)

    image_paths = []

    # --- Handle images ---
    if input_images is not None:
        for file_data in input_images:
            if isinstance(file_data, dict) and "name" in file_data:
                file_path = file_data["name"]
            else:
                file_path = file_data
            dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
            shutil.copy(file_path, dst_path)
            image_paths.append(dst_path)

    # --- Handle video ---
    if input_video is not None:
        if isinstance(input_video, dict) and "name" in input_video:
            video_path = input_video["name"]
        else:
            video_path = input_video

        vs = cv2.VideoCapture(video_path)
        fps = vs.get(cv2.CAP_PROP_FPS)
        # Sample ~1 frame/sec; some containers report fps=0, so clamp to >= 1
        # to avoid a modulo-by-zero below.
        frame_interval = max(int(fps), 1)

        count = 0
        video_frame_num = 0
        while True:
            gotit, frame = vs.read()
            if not gotit:
                break
            count += 1
            if count % frame_interval == 0:
                image_path = os.path.join(target_dir_images, f"{video_frame_num:06}.jpg")
                cv2.imwrite(image_path, frame)
                image_paths.append(image_path)
                video_frame_num += 1

    # Sort final images for gallery
    image_paths = sorted(image_paths)

    end_time = time.time()
    print(f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds")
    return target_dir, image_paths


# -------------------------------------------------------------------------
# 3) Update gallery on upload
# -------------------------------------------------------------------------
def update_gallery_on_upload(input_video, input_images):
    """
    Whenever user uploads or changes files, immediately handle them
    and show in the gallery. Return (target_dir, image_paths).
    If nothing is uploaded, returns "None" and empty list.
    """
    if not input_video and not input_images:
        return None, None, None, None
    target_dir, image_paths = handle_uploads(input_video, input_images)
    return None, target_dir, image_paths, "Upload complete. Click 'Detect Objects' to begin 3D processing."


# -------------------------------------------------------------------------
# 4) Reconstruction: uses the target_dir plus any viz parameters
# -------------------------------------------------------------------------
def reconstruct(
    target_dir,
    conf_thres=50.0,
    frame_filter="All",
    mask_black_bg=False,
    mask_white_bg=False,
    show_cam=True,
    mask_sky=False,
    prediction_mode="Depthmap and Camera Branch",
    text_labels="",
):
    """
    Perform reconstruction using the already-created target_dir/images.
    """
    prediction_mode = "Depthmap and Camera Branch" # Force prediction mode
    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return None, "No valid target directory found. Please upload first.", None

    start_time = time.time()
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Prepare frame_filter dropdown
    target_dir_images = os.path.join(target_dir, "images")
    all_files = sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else []
    image_names = [f.split(".")[0] for f in all_files]
    all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
    frame_filter_choices = ["All"] + all_files

    print("Running run_model...")
    with torch.no_grad():
        # Ensure CropFormer weights exist if downstream pipeline is enabled
        try:
            check_weights()
        except Exception as e:
            print(f"Warning: could not ensure Mask2Former weights at startup: {e}")
        vggt_model, metric3d_model, _ = _init_models()
        predictions = run_model(target_dir, vggt_model, metric3d_model=metric3d_model)


    # Save predictions
    prediction_save_path = os.path.join(target_dir, "predictions.npz")
    try:
        np.savez(prediction_save_path, **predictions)
    except Exception as e:
        print(f"Warning: could not save predictions to npz: {e}")

    depth_path = os.path.join(target_dir, "depth")
    pose_path = os.path.join(target_dir, "pose")
    intrinsic_path = os.path.join(target_dir, "intrinsic")
    os.makedirs(depth_path, exist_ok=True)
    os.makedirs(pose_path, exist_ok=True)
    os.makedirs(intrinsic_path, exist_ok=True)
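    # Depth maps are written as 16-bit PNGs in millimeters (meters * 1000), a
    # common convention for ScanNet/ARKit-style depth files.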
    for i, d in enumerate(predictions["depth"]):
        print(d.shape)
        cv2.imwrite(os.path.join(depth_path, f"{image_names[i]}.png"), (d[..., 0] * 1000).astype(np.uint16))
    intr = np.eye(4)
    intr[:3, :3] = np.mean(predictions["intrinsic"], axis=0)
    np.savetxt(os.path.join(intrinsic_path, "intrinsic_depth.txt"), intr)

    for i, p in enumerate(predictions["pose"]):
        np.savetxt(os.path.join(pose_path, f"{image_names[i]}.txt"), p)

    # Handle None frame_filter
    if frame_filter is None:
        frame_filter = "All"

    # Build a GLB file name
    glbfile = os.path.join(
        target_dir,
        f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}.glb",
    )

    # Convert predictions to GLB
    glbscene, point_cloud_data = predictions_to_glb(
        predictions,
        conf_thres=conf_thres,
        filter_by_frames=frame_filter,
        mask_black_bg=mask_black_bg,
        mask_white_bg=mask_white_bg,
        show_cam=show_cam,
        mask_sky=mask_sky,
        target_dir=target_dir,
        prediction_mode=prediction_mode,
    )
    
    # Ensure colors are RGB (remove alpha if present) for Open3D
    v = np.asarray(point_cloud_data.vertices)
    c = np.asarray(point_cloud_data.colors) / 255.0
    if c.shape[1] == 4:
        c = c[:, :3]
        
    glbscene.export(file_obj=glbfile)
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(v)
    pcd.colors = o3d.utility.Vector3dVector(c)

    pcd = pcd.voxel_down_sample(voxel_size=0.01)
    o3d.io.write_point_cloud(os.path.join(target_dir, "point_cloud.ply"), pcd)


    # Cleanup
    del predictions
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    end_time = time.time()
    print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")
    log_msg = f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."
    # External pipelines are fragile in Spaces (often require compiled ops).
    # We try to run them, but do not fail the whole app if they error.
    root_input_dir = os.path.dirname(target_dir)
    seq_name = os.path.basename(target_dir)
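    # Apparent stages, judging by the script names: mask_predict.py produces
    # per-frame 2D instance masks (CropFormer); main.py clusters masks into 3D
    # objects (MaskClustering); get_open-voc_features.py extracts per-object
    # open-vocabulary CLIP features.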
    try:
        subprocess.run(
            [
                sys.executable,
                os.path.join(
                    MK_PATH,
                    "third_party",
                    "detectron2",
                    "projects",
                    "CropFormer",
                    "demo_cropformer",
                    "mask_predict.py",
                ),
                "--config-file",
                os.path.join(
                    MK_PATH,
                    "third_party",
                    "detectron2",
                    "projects",
                    "CropFormer",
                    "configs",
                    "entityv2",
                    "entity_segmentation",
                    "mask2former_hornet_3x.yaml",
                ),
                "--root",
                root_input_dir,
                "--image_path_pattern",
                "images/*.jpg",
                "--dataset",
                "arkit_gt",
                "--seq_name_list",
                seq_name,
                "--opts",
                "MODEL.WEIGHTS",
                os.path.join(MK_PATH, cropformer_name),
            ],
            check=True,
            env={
                **os.environ,
                # Use installed detectron2; avoid shadowing it with partial local tree
                "PYTHONPATH": MK_PATH
                + (os.pathsep + os.environ["PYTHONPATH"] if os.environ.get("PYTHONPATH") else ""),
            },
        )

        subprocess.run(
            [
                sys.executable,
                os.path.join(MK_PATH, "main.py"),
                "--config",
                "wild",
                "--root",
                root_input_dir,
                "--seq_name_list",
                seq_name,
            ],
            check=True,
        )

        env = dict(os.environ)
        env["PYTHONPATH"] = MK_PATH + (os.pathsep + env["PYTHONPATH"] if env.get("PYTHONPATH") else "")
        subprocess.run(
            [
                sys.executable,
                os.path.join(MK_PATH, "semantics", "get_open-voc_features.py"),
                "--config",
                "wild",
                "--root",
                root_input_dir,
                "--seq_name_list",
                seq_name,
            ],
            env=env,
            check=True,
        )
    except Exception as e:
        print(f"Warning: external MaskClustering pipeline failed: {e}")

    return glbfile, log_msg, gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True)

def visualize_detections(target_dir, conf_thres, frame_filter="All", mask_black_bg=False, mask_white_bg=False, show_cam=True, mask_sky=False, prediction_mode="Depthmap and Camera Branch"):
    """
    Generate a GLB scene with bounding boxes for detected objects.
    """
    if not target_dir or not os.path.exists(target_dir):
         return None, "Target directory not found."

    ply_path = os.path.join(target_dir, "point_cloud.ply")
    npz_path = os.path.join(target_dir, "output", "object", "prediction.npz")

    # 1. Load the point cloud as the base of the scene
    if not os.path.exists(ply_path):
        return None, f"Point cloud not found at {ply_path}. Please run detection first."

    pcd = o3d.io.read_point_cloud(ply_path)
    points = np.asarray(pcd.points)
    colors = np.asarray(pcd.colors)

    if points.size == 0:
        return None, "Point cloud is empty."

    # Build the base scene from the point cloud
    scene = trimesh.Scene()

    if colors.size == 0:
        t_colors = np.ones((len(points), 4), dtype=np.uint8) * 255
    else:
        if colors.max() <= 1.0:
            t_colors = (colors * 255).astype(np.uint8)
        else:
            t_colors = colors.astype(np.uint8)
        if t_colors.shape[1] == 3:
            t_colors = np.hstack([t_colors, np.ones((len(t_colors), 1), dtype=np.uint8) * 255])

    base_pc = trimesh.PointCloud(vertices=points, colors=t_colors)
    scene.add_geometry(base_pc)

    # 2. Add bounding boxes from detection results, if available
    legend_md = ""
    if os.path.exists(npz_path):
        try:
            loaded = np.load(npz_path, allow_pickle=True)
            # Check for detection keys
            if 'pred_masks' in loaded:
                masks = loaded['pred_masks'].T 
                labels = loaded['pred_classes']
                confs = loaded['pred_score']

                # Load text features to map labels to names
                text_features_path = os.path.join(target_dir, "text_features.npy")
                label_to_name = {}
                if os.path.exists(text_features_path):
                     try:
                         text_features_dict = np.load(text_features_path, allow_pickle=True).item()
                         feature_keys = list(text_features_dict.keys())
                         for i, name in enumerate(feature_keys):
                             label_to_name[i] = name
                     except Exception as e:
                         print(f"Warning: Could not load text features for label mapping: {e}")

                # Filter
                if isinstance(confs, (list, tuple)):
                    confs = np.array(confs)
                
                valid_indices = np.where(confs > conf_thres)[0]
                
                if len(valid_indices) > 0:
                    legend_items = {}
                    cmap = plt.get_cmap("tab10") 
                    
                    detected_labels = np.unique(labels[valid_indices])
                    label_to_color = {label: cmap(i % 10) for i, label in enumerate(detected_labels)}

                    for idx in valid_indices:
                        mask = masks[idx]
                        if hasattr(mask, "toarray"): 
                            mask = mask.toarray().flatten()
                        mask = mask.astype(bool)
                        
                        # Verify mask size. Masks are generated on the full point
                        # cloud, so if the visualized points were filtered or
                        # downsampled the sizes may disagree; in that case we skip
                        # the detection rather than draw a misaligned box.

                        if len(mask) == len(points):
                            obj_points = points[mask]
                            if len(obj_points) >= 4:
                                obj_pcd = trimesh.PointCloud(obj_points)
                                try:
                                    bbox = obj_pcd.bounding_box_oriented
                                except Exception:
                                    bbox = obj_pcd.bounding_box

                                # Skip implausibly large boxes: longest side > 2.5 m
                                try:
                                    ext = np.asarray(bbox.extents).astype(float)
                                    if float(np.max(ext)) > 2.5:
                                        continue
                                except Exception:
                                    pass

                                # Build only the box "wireframe" from the 8 vertices and the transform:
                                # connect vertex pairs whose local signs differ along exactly one axis.
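                                # E.g. the corner with signs (+1,+1,+1) connects to
                                # (-1,+1,+1), (+1,-1,+1) and (+1,+1,-1):
                                # 8 vertices * 3 axes / 2 = 12 unique edges.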
                                verts = np.asarray(bbox.vertices)
                                if verts.shape[0] != 8:
                                    continue
                                T = np.asarray(bbox.transform)
                                center = T[:3, 3]
                                R = T[:3, :3]
                                # Local coordinates (in the box's own axes)
                                local = (verts - center) @ R
                                # Assign each vertex a sign triple (+/-1 per axis)
                                signs = np.where(local >= 0.0, 1, -1).astype(int)
                                sign_to_idx = {tuple(s): i for i, s in enumerate(signs)}
                                # Generate the 12 edges: vertex pairs differing in sign along exactly one axis
                                edges_idx = set()
                                for sx in (-1, 1):
                                    for sy in (-1, 1):
                                        for sz in (-1, 1):
                                            s = (sx, sy, sz)
                                            if s not in sign_to_idx:
                                                continue
                                            for axis in range(3):
                                                s2 = list(s)
                                                s2[axis] *= -1
                                                s2 = tuple(s2)
                                                if s2 in sign_to_idx:
                                                    i0 = sign_to_idx[s]
                                                    i1 = sign_to_idx[s2]
                                                    if i0 != i1:
                                                        edges_idx.add(tuple(sorted((i0, i1))))
                                if not edges_idx:
                                    continue
                                segments = np.array([[verts[i], verts[j]] for (i, j) in edges_idx], dtype=float)

                                lbl_idx = labels[idx]
                                lbl_name = label_to_name.get(lbl_idx, f"Class {lbl_idx}")
                                color = label_to_color.get(lbl_idx, (1, 0, 0, 1))

                                color_u8 = (np.array(color) * 255).astype(np.uint8)
                                # Constant wireframe thickness: 3 cm diameter (cylinder radius 0.015)
                                radius = 0.015
                                for seg in segments:
                                    p1, p2 = seg[0], seg[1]
                                    v = p2 - p1
                                    length = float(np.linalg.norm(v))
                                    if length <= 1e-8:
                                        continue
                                    direction = v / length
                                    try:
                                        cyl = trimesh.creation.cylinder(radius=radius, height=length, sections=12)
                                    except Exception:
                                        continue
                                    # Rotate the cylinder's Z axis onto the edge direction, then move it to the midpoint
                                    try:
                                        align = trimesh.geometry.align_vectors([0, 0, 1], direction)
                                        cyl.apply_transform(align)
                                    except Exception:
                                        pass
                                    midpoint = (p1 + p2) / 2.0
                                    cyl.apply_translation(midpoint)
                                    # Lighting-independent material (emulate "unlit" via emissive)
                                    try:
                                        emissive = (color_u8[:3] / 255.0).tolist()
                                        mat = trimesh.visual.material.PBRMaterial(
                                            baseColorFactor=(0.0, 0.0, 0.0, 1.0),
                                            metallicFactor=0.0,
                                            roughnessFactor=1.0,
                                            emissiveFactor=emissive,
                                            doubleSided=True,
                                        )
                                        cyl.visual.material = mat
                                    except Exception:
                                        cyl.visual.face_colors = np.tile(color_u8[None, :], (len(cyl.faces), 1))
                                    scene.add_geometry(cyl)
                                legend_items[lbl_name] = color

                    legend_md = "### Legend\n"
                    for lbl_name, color in legend_items.items():
                        c_u8 = (np.array(color) * 255).astype(np.uint8)
                        hex_c = "#{:02x}{:02x}{:02x}".format(c_u8[0], c_u8[1], c_u8[2])
                        legend_md += f"- <span style='color:{hex_c}'>■</span> {lbl_name}\n"

        except Exception as e:
            print(f"Error loading detections: {e}")
            legend_md = f"Error loading detections: {e}"

    # Export the combined scene (point cloud + boxes)
    out_path = os.path.join(target_dir, f"combined_viz_{conf_thres}.glb")
    scene.export(file_obj=out_path)
    
    return out_path, legend_md

def detect_objects(text_labels, target_dir, conf_thres, *viz_args):
    """
    Detect objects from text labels and return the detected objects.
    """
    # Require non-empty text labels
    if not text_labels or not isinstance(text_labels, str) or len([l.strip() for l in text_labels.split(";") if l.strip()]) == 0:
        return None, "Please enter at least one text label (separated by ';')."

    # Ensure CropFormer weights exist (if detection pipeline uses them)
    if torch.cuda.is_available() or not ZOO3D_SKIP_DOWNLOADS:
        try:
            check_weights()
        except Exception as e:
            print(f"Warning: could not ensure Mask2Former weights: {e}")
    
    # 1. Run reconstruction first if predictions do not exist yet. `reconstruct`
    # is heavy but handles the full pipeline (VGGT + Metric3D + external mask
    # clustering) and writes predictions.npz into target_dir, so we reuse it here
    # instead of calling run_model directly.
    predictions_path = os.path.join(target_dir, "predictions.npz")
    if not os.path.exists(predictions_path):
        print("Predictions not found, running reconstruction first...")
        # viz_args order (from the click handler): frame_filter, mask_black_bg,
        # mask_white_bg, show_cam, mask_sky, prediction_mode.
        reconstruct(target_dir, 50.0, *viz_args, text_labels=text_labels)
    
    
    # Extract text features if provided
    if text_labels:
        labels = [l.strip() for l in text_labels.split(";") if l.strip()]
        if labels:
            print(f"Extracting features for labels: {labels}")
            _, _, clip_model = _init_models()
            text_features = extract_text_feature(labels, clip_model, target_dir)
            print(f"Text features: {text_features}")
            try:
                env = dict(os.environ)
                env["PYTHONPATH"] = (
                    MK_PATH
                    + (os.pathsep + env["PYTHONPATH"] if env.get("PYTHONPATH") else "")
                )
                root_input_dir = os.path.dirname(target_dir)
                seq_name = os.path.basename(target_dir)
                subprocess.run(
                    [
                        sys.executable,
                        os.path.join(MK_PATH, "semantics", "wopen-voc_query.py"),
                        "--config",
                        "wild",
                        "--root",
                        root_input_dir,
                        "--seq_name",
                        seq_name,
                    ],
                    env=env,
                    check=True,
                )
            except Exception as e:
                print(f"Warning: open-voc query failed: {e}")

    return visualize_detections(target_dir, conf_thres, *viz_args)


# -------------------------------------------------------------------------
# 5) Helper functions for UI resets + re-visualization
# -------------------------------------------------------------------------
def clear_fields():
    """
    Clears the 3D viewer, the stored target_dir, and empties the gallery.
    """
    return None


def update_log():
    """
    Display a quick log message while waiting.
    """
    return "Loading and Reconstructing..."


def update_visualization(
    target_dir, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode, is_example
):
    """
    Reload saved predictions from npz, create (or reuse) the GLB for new parameters,
    and return it for the 3D viewer. If is_example == "True", skip.
    """

    # If it's an example click, skip as requested
    if is_example == "True":
        return None, "No reconstruction available. Please click the Reconstruct button first."

    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return None, "No reconstruction available. Please click the Reconstruct button first."

    predictions_path = os.path.join(target_dir, "predictions.npz")
    if not os.path.exists(predictions_path):
        return None, f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first."

    key_list = [
        "pose_enc",
        "depth",
        "depth_conf",
        "world_points",
        "world_points_conf",
        "images",
        "extrinsic",
        "intrinsic",
        "world_points_from_depth",
    ]

    loaded = np.load(predictions_path, allow_pickle=True)
    # Not every key is present in every run (e.g. the CPU-dummy path), so load
    # only the keys that were actually saved.
    predictions = {key: np.array(loaded[key]) for key in key_list if key in loaded}

    glbfile = os.path.join(
        target_dir,
        f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}.glb",
    )

    glbscene = predictions_to_glb(
        predictions,
        conf_thres=conf_thres,
        filter_by_frames=frame_filter,
        mask_black_bg=mask_black_bg,
        mask_white_bg=mask_white_bg,
        show_cam=show_cam,
        mask_sky=mask_sky,
        target_dir=target_dir,
        prediction_mode=prediction_mode,
    )
    glbscene.export(file_obj=glbfile)

    return glbfile, "Updating Visualization"


# -------------------------------------------------------------------------
# Example images
# -------------------------------------------------------------------------

great_wall_video = "examples/videos/great_wall.mp4"
colosseum_video = "examples/videos/Colosseum.mp4"
room_video = "examples/videos/room.mp4"
kitchen_video = "examples/videos/kitchen.mp4"
fern_video = "examples/videos/fern.mp4"
single_cartoon_video = "examples/videos/single_cartoon.mp4"
single_oil_painting_video = "examples/videos/single_oil_painting.mp4"
pyramid_video = "examples/videos/pyramid.mp4"


# -------------------------------------------------------------------------
# 6) Build Gradio UI
# -------------------------------------------------------------------------
theme = gr.themes.Ocean()
theme.set(
    checkbox_label_background_fill_selected="*button_primary_background_fill",
    checkbox_label_text_color_selected="*button_primary_text_color",
)

with gr.Blocks(
    theme=theme,
    css="""
    .custom-log * {
        font-style: italic;
        font-size: 22px !important;
        background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
        -webkit-background-clip: text;
        background-clip: text;
        font-weight: bold !important;
        color: transparent !important;
        text-align: center !important;
    }
    
    .example-log * {
        font-style: italic;
        font-size: 16px !important;
        background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
        -webkit-background-clip: text;
        background-clip: text;
        color: transparent !important;
    }
    
    #my_radio .wrap {
        display: flex;
        flex-wrap: nowrap;
        justify-content: center;
        align-items: center;
    }

    #my_radio .wrap label {
        display: flex;
        width: 50%;
        justify-content: center;
        align-items: center;
        margin: 0;
        padding: 10px 0;
        box-sizing: border-box;
    }
    """,
) as demo:
    # Instead of gr.State, we use a hidden Textbox:
    is_example = gr.Textbox(label="is_example", visible=False, value="None")
    num_images = gr.Textbox(label="num_images", visible=False, value="None")

    gr.HTML(
        """
    <h1>🦁 Zoo3D: Zero-Shot 3D Object Detection at Scene Level 🐼</h1>
    <p>
    <a href="https://github.com/col14m/zoo3d">GitHub Repository</a>
    </p>

    <div style="font-size: 16px; line-height: 1.5;">
    <p>Upload a video or a set of images to create a 3D reconstruction and run open‑vocabulary 3D object detection from your text labels. The app builds a point cloud and draws colored wireframe bounding boxes for the detected objects.</p>

    <h3>Getting Started:</h3>
    <ol>
        <li><strong>Upload Your Data:</strong> Use "Upload Video" or "Upload Images". Videos are sampled at 1 frame/sec.</li>
        <li><strong>Enter Text Labels (Required):</strong> Provide one or more labels separated by semicolons, e.g. <code>chair; table; plant</code>.</li>
        <li><strong>Detect:</strong> Click <strong>"Detect Objects"</strong>. The app will reconstruct the scene (if needed) and then run detection.</li>
        <li><strong>Threshold (Optional):</strong> Tune the <em>Detection Threshold</em> (0–1). Higher = fewer, more confident detections.</li>
        <li><strong>Visualize & Download:</strong> A single 3D view shows the point cloud and colored wireframe boxes. A legend maps colors to labels. You can download the GLB.</li>
    </ol>
    <p><strong style="color: #0ea5e9;">Notes:</strong> <span style="color: #0ea5e9; font-weight: bold;">Reconstruction is triggered automatically on first run. If no labels are provided, you'll see an error: </span><code>Please enter at least one text label (separated by ';').</code></p>
    </div>
    """
    )
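
    # A minimal, hypothetical sketch of the label-parsing convention described
    # above: detect_objects (defined earlier) expects semicolon-separated
    # labels, so splitting on ";" and dropping empties is the assumed behavior.
    # The name _parse_labels is illustrative, not part of the app's API.
    def _parse_labels(text: str) -> list[str]:
        return [t.strip() for t in text.split(";") if t.strip()]
    # e.g. _parse_labels("chair; table; plant") -> ["chair", "table", "plant"]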

    target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")

    with gr.Row():
        with gr.Column(scale=2):
            input_video = gr.Video(label="Upload Video", interactive=True)
            input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)

            def _safe_gallery(**kwargs):
                """Create a gr.Gallery, dropping kwargs the installed Gradio rejects.

                The Gallery API differs between Gradio versions (HF Spaces may
                run Gradio 6.x), so on TypeError we parse the offending keyword
                out of the message and retry without it. Each retry removes one
                kwarg, so the loop always terminates.
                """
                import re

                while True:
                    try:
                        return gr.Gallery(**kwargs)
                    except TypeError as e:
                        # Typical: "got an unexpected keyword argument 'show_download_button'"
                        m = re.search(r"unexpected keyword argument '([^']+)'", str(e))
                        bad = m.group(1) if m else None
                        if bad and bad in kwargs:
                            kwargs.pop(bad)
                            continue
                        # Fallback: drop known version-sensitive args one at a time
                        for k in ("show_download_button", "preview", "object_fit", "columns", "height"):
                            if k in kwargs:
                                kwargs.pop(k)
                                break
                        else:
                            raise

            image_gallery = _safe_gallery(
                label="Preview",
                columns=4,
                height="300px",
                show_download_button=True,
                object_fit="contain",
                preview=True,
            )

        with gr.Column(scale=4):
            text_labels = gr.Textbox(label="Text Labels (separated by ;)", placeholder="cat; dog; car")
            with gr.Column():
                gr.Markdown("**3D Reconstruction & Detection (Point Cloud and Bounding Boxes)**")
                log_output = gr.Markdown(
                    "Please upload a video or images, then click Detect Objects.", elem_classes=["custom-log"]
                )
                reconstruction_output = gr.Model3D(height=520, zoom_speed=0.5, pan_speed=0.5)

            with gr.Row():
                detect_btn = gr.Button("Detect Objects", scale=1, variant="primary")
                clear_btn = gr.ClearButton(
                    [input_video, input_images, reconstruction_output, log_output, target_dir_output, image_gallery, text_labels],
                    scale=1,
                )
            # Hidden component so the event handlers don't break
            prediction_mode = gr.Textbox(value="Depthmap and Camera Branch", visible=False)

            # Core reconstruction-visualization controls (hidden by default)
            with gr.Row():
                conf_thres = gr.Slider(
                    minimum=0,
                    maximum=100,
                    value=50,
                    step=0.1,
                    label="Confidence Threshold (%)",
                    visible=False,
                )
                frame_filter = gr.Dropdown(
                    choices=["All"],
                    value="All",
                    label="Show Points from Frame",
                    visible=False,
                )
                with gr.Column():
                    show_cam = gr.Checkbox(label="Show Camera", value=True, visible=False)
                    mask_sky = gr.Checkbox(label="Filter Sky", value=False, visible=False)
                    mask_black_bg = gr.Checkbox(label="Filter Black Background", value=False, visible=False)
                    mask_white_bg = gr.Checkbox(label="Filter White Background", value=False, visible=False)

            # Detection confidence threshold and the box-color legend
            detection_conf_thres = gr.Slider(
                minimum=0,
                maximum=1,
                value=0.6,
                step=0.01,
                label="Detection Threshold",
            )
            detection_legend = gr.Markdown("Legend will appear here")

    # ---------------------- Examples section ----------------------
    # Populate with entries matching example_pipeline's input signature
    # (e.g. the example videos defined above); see the gr.Examples sketch
    # after example_pipeline below.
    examples = []

    def example_pipeline(
        input_video,
        num_images_str,
        input_images,
        conf_thres,
        mask_black_bg,
        mask_white_bg,
        show_cam,
        mask_sky,
        prediction_mode,
        is_example_str,
        text_labels,
    ):
        """
        1) Copy example images to new target_dir
        2) Reconstruct (and Detect if labels present)
        3) Return model3D + logs + new_dir + updated dropdown + gallery
        We do NOT return is_example. It's just an input.
        """
        target_dir, image_paths = handle_uploads(input_video, input_images)
        # Always use "All" for frame_filter in examples
        frame_filter = "All"
        
        detection_conf = 0.85
        
        glbfile, legend_md = detect_objects(
            text_labels, 
            target_dir, 
            detection_conf, 
            frame_filter, 
            mask_black_bg, 
            mask_white_bg, 
            show_cam, 
            mask_sky, 
            prediction_mode
        )
        
        log_msg = "Example loaded and processed."

        return (
            glbfile,
            log_msg + "\n\n" + legend_md,
            target_dir,
            gr.Dropdown(choices=["All"], value="All", interactive=True),
            image_paths,
        )
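
    # A minimal sketch of how the (currently empty) examples list could be
    # wired up once populated -- assuming each entry matches example_pipeline's
    # input signature. Kept commented out so app behavior is unchanged; the
    # entry values here (e.g. "20" frames, "True" for is_example) are
    # illustrative only.
    #
    # examples = [
    #     [room_video, "20", None, 50.0, False, False, True, False,
    #      "Depthmap and Camera Branch", "True", "chair; table; plant"],
    # ]
    # gr.Examples(
    #     examples=examples,
    #     inputs=[
    #         input_video, num_images, input_images, conf_thres,
    #         mask_black_bg, mask_white_bg, show_cam, mask_sky,
    #         prediction_mode, is_example, text_labels,
    #     ],
    #     outputs=[reconstruction_output, log_output, target_dir_output,
    #              frame_filter, image_gallery],
    #     fn=example_pipeline,
    #     cache_examples=False,
    # )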

    detect_btn.click(fn=clear_fields, inputs=[], outputs=[]).then(
        fn=detect_objects, 
        inputs=[
            text_labels, 
            target_dir_output, 
            detection_conf_thres, 
            frame_filter,
            mask_black_bg,
            mask_white_bg,
            show_cam,
            mask_sky,
            prediction_mode
        ], 
        outputs=[reconstruction_output, detection_legend]
    ).then(
        fn=lambda: "False", inputs=[], outputs=[is_example]  # set is_example to "False"
    )

    detection_conf_thres.change(
        fn=visualize_detections,
        inputs=[
            target_dir_output, 
            detection_conf_thres,
            frame_filter,
            mask_black_bg,
            mask_white_bg,
            show_cam,
            mask_sky,
            prediction_mode
        ],
        outputs=[reconstruction_output, detection_legend]
    )

    # -------------------------------------------------------------------------
    # Real-time Visualization Updates
    # -------------------------------------------------------------------------
    conf_thres.change(
        update_visualization,
        [
            target_dir_output,
            conf_thres,
            frame_filter,
            mask_black_bg,
            mask_white_bg,
            show_cam,
            mask_sky,
            prediction_mode,
            is_example,
        ],
        [reconstruction_output, log_output],
    )
    frame_filter.change(
        update_visualization,
        [
            target_dir_output,
            conf_thres,
            frame_filter,
            mask_black_bg,
            mask_white_bg,
            show_cam,
            mask_sky,
            prediction_mode,
            is_example,
        ],
        [reconstruction_output, log_output],
    )
    mask_black_bg.change(
        update_visualization,
        [
            target_dir_output,
            conf_thres,
            frame_filter,
            mask_black_bg,
            mask_white_bg,
            show_cam,
            mask_sky,
            prediction_mode,
            is_example,
        ],
        [reconstruction_output, log_output],
    )
    mask_white_bg.change(
        update_visualization,
        [
            target_dir_output,
            conf_thres,
            frame_filter,
            mask_black_bg,
            mask_white_bg,
            show_cam,
            mask_sky,
            prediction_mode,
            is_example,
        ],
        [reconstruction_output, log_output],
    )
    show_cam.change(
        update_visualization,
        [
            target_dir_output,
            conf_thres,
            frame_filter,
            mask_black_bg,
            mask_white_bg,
            show_cam,
            mask_sky,
            prediction_mode,
            is_example,
        ],
        [reconstruction_output, log_output],
    )
    prediction_mode.change(
        update_visualization,
        [
            target_dir_output,
            conf_thres,
            frame_filter,
            mask_black_bg,
            mask_white_bg,
            show_cam,
            mask_sky,
            prediction_mode,
            is_example,
        ],
        [reconstruction_output, log_output],
    )
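    # mask_sky is the one visualization toggle without a change handler above;
    # mirror the other checkboxes so flipping "Filter Sky" also refreshes the view.
    mask_sky.change(
        update_visualization,
        [
            target_dir_output,
            conf_thres,
            frame_filter,
            mask_black_bg,
            mask_white_bg,
            show_cam,
            mask_sky,
            prediction_mode,
            is_example,
        ],
        [reconstruction_output, log_output],
    )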

    # -------------------------------------------------------------------------
    # Auto-update gallery whenever user uploads or changes their files
    # -------------------------------------------------------------------------
    input_video.change(
        fn=update_gallery_on_upload,
        inputs=[input_video, input_images],
        outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
    )
    input_images.change(
        fn=update_gallery_on_upload,
        inputs=[input_video, input_images],
        outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
    )

def main():
    demo.queue(max_size=20).launch(show_error=True, share=False, show_api=False)


if __name__ == "__main__":
    main()