"""
Model Manager for real-time motion generation (HF Space version)
Loads model from Hugging Face Hub instead of local checkpoints.
"""
import json
import os
import threading
import time
from collections import deque
import numpy as np
import torch
import traceback
import gc
import math
import glob
import urllib.request
from transformers import AutoModel
# ════════════════════════════════════════════════
# JOINT RECOVERY - inlined from motion_process.py
# ════════════════════════════════════════════════
def qinv(q):
    assert q.shape[-1] == 4, "q must be a tensor of shape (*, 4)"
    mask = torch.ones_like(q)
    mask[..., 1:] = -mask[..., 1:]
    return q * mask
def qrot(q, v):
    assert q.shape[-1] == 4
    assert v.shape[-1] == 3
    assert q.shape[:-1] == v.shape[:-1]
    original_shape = list(v.shape)
    q = q.contiguous().view(-1, 4)
    v = v.contiguous().view(-1, 3)
    qvec = q[:, 1:]
    uv = torch.cross(qvec, v, dim=1)
    uuv = torch.cross(qvec, uv, dim=1)
    return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape)
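# Hedged sanity-check sketch for the quaternion helpers above (illustrative only, not
# called anywhere in this file). qrot uses the (w, x, y, z) layout, so a yaw rotation
# by theta about the Y axis is q = (cos(theta/2), 0, sin(theta/2), 0).
def _example_qrot_sanity_check():
    theta = math.pi / 2
    q = torch.tensor([math.cos(theta / 2), 0.0, math.sin(theta / 2), 0.0])
    v = torch.tensor([1.0, 0.0, 0.0])
    # Rotating +X by 90 degrees about Y with this formula yields approximately (0, 0, -1).
    return qrot(q.unsqueeze(0), v.unsqueeze(0)).squeeze(0)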
class StreamJointRecovery263:
    """
    Stream version of recover_joint_positions_263 that processes one frame at a time.
    Maintains cumulative state for rotation angles and positions.
    Key insight: The batch version uses PREVIOUS frame's velocity for the current frame,
    so we need to delay the velocity application by one frame.
    Args:
        joints_num: Number of joints in the skeleton
        smoothing_alpha: EMA smoothing factor (0.0 to 1.0)
            - 1.0 = no smoothing (default), output follows input exactly
            - 0.0 = infinite smoothing, output never changes
            - Recommended values: 0.3-0.7 for visible smoothing
            - Formula: smoothed = alpha * current + (1 - alpha) * previous
    """
    def __init__(self, joints_num: int, smoothing_alpha: float = 1.0):
        self.joints_num = joints_num
        self.smoothing_alpha = np.clip(smoothing_alpha, 0.0, 1.0)
        self.reset()
    def reset(self):
        """Reset the accumulated state"""
        self.r_rot_ang_accum = 0.0
        self.r_pos_accum = np.array([0.0, 0.0, 0.0])
        # Store previous frame's velocities for delayed application
        self.prev_rot_vel = 0.0
        self.prev_linear_vel = np.array([0.0, 0.0])
        # Store previous smoothed joints for EMA
        self.prev_smoothed_joints = None
    def process_frame(self, frame_data: np.ndarray, heading_override=None) -> np.ndarray:
        """
        Process a single frame and return joint positions for that frame.
        Args:
            frame_data: numpy array of shape (263,) for a single frame
            heading_override: float or None. If set, overrides AI rotation with
                this angle (in radians). AI velocity magnitude is preserved,
                applied in heading direction. None = original AI behavior.
        Returns:
            joints: numpy array of shape (joints_num, 3) representing joint positions
        """
        # Convert to torch tensor
        feature_vec = torch.from_numpy(frame_data).float()
        # Extract current frame's velocities (will be used in NEXT frame)
        curr_rot_vel = feature_vec[0].item()
        curr_linear_vel = feature_vec[1:3].numpy()
        # ═══ HEADING OVERRIDE ═══
        if heading_override is not None:
            # User controls direction β€” override AI rotation
            self.r_rot_ang_accum = heading_override
        else:
            # Original behavior β€” AI controls direction
            self.r_rot_ang_accum += self.prev_rot_vel
        # Calculate current rotation quaternion using accumulated angle
        r_rot_quat = torch.zeros(4, dtype=torch.float32)
        r_rot_quat[0] = np.cos(self.r_rot_ang_accum)
        r_rot_quat[2] = np.sin(self.r_rot_ang_accum)
        # Create velocity vector with Y=0 using PREVIOUS frame's velocity
        r_vel = np.array([self.prev_linear_vel[0], 0.0, self.prev_linear_vel[1]])
        # Apply inverse rotation to velocity using CURRENT rotation
        r_vel_torch = torch.from_numpy(r_vel.astype(np.float32)).float()
        r_vel_rotated = qrot(qinv(r_rot_quat).unsqueeze(0), r_vel_torch.unsqueeze(0))
        r_vel_rotated = r_vel_rotated.squeeze(0).numpy()
        # Update accumulated position with rotated velocity
        self.r_pos_accum += r_vel_rotated
        # Get Y position from data
        r_pos = self.r_pos_accum.copy()
        r_pos[1] = feature_vec[3].item()
        # Extract local joint positions
        positions = feature_vec[4 : (self.joints_num - 1) * 3 + 4]
        positions = positions.view(-1, 3).float()
        # Apply inverse rotation to local joints
        r_rot_quat_expanded = (
            qinv(r_rot_quat).unsqueeze(0).expand(positions.shape[0], 4)
        )
        positions = qrot(r_rot_quat_expanded, positions)
        # Add root XZ to joints
        positions[:, 0] += r_pos[0]
        positions[:, 2] += r_pos[2]
        # Concatenate root and joints
        r_pos_torch = torch.from_numpy(r_pos).float()
        positions = torch.cat([r_pos_torch.unsqueeze(0), positions], dim=0)
        # Convert to numpy
        joints_np = positions.detach().cpu().numpy()
        # Apply EMA smoothing if enabled
        if self.smoothing_alpha < 1.0:
            if self.prev_smoothed_joints is None:
                # First frame, no smoothing possible
                self.prev_smoothed_joints = joints_np.copy()
            else:
                # EMA: smoothed = alpha * current + (1 - alpha) * previous
                joints_np = (
                    self.smoothing_alpha * joints_np
                    + (1.0 - self.smoothing_alpha) * self.prev_smoothed_joints
                )
                self.prev_smoothed_joints = joints_np.copy()
        # Store current velocities for next frame
        self.prev_rot_vel = curr_rot_vel
        self.prev_linear_vel = curr_linear_vel
        return joints_np
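# Minimal usage sketch for StreamJointRecovery263 (illustrative only, not called anywhere):
# the object is stateful, so frames must be fed in playback order and reset() called
# between independent clips.
def _example_stream_recovery_usage():
    recovery = StreamJointRecovery263(joints_num=22, smoothing_alpha=0.5)
    dummy_frame = np.zeros(263, dtype=np.float32)  # one decoded 263-dim feature frame
    joints = recovery.process_frame(dummy_frame)   # -> (22, 3) world-space joint positions
    recovery.reset()                               # clear accumulated root rotation/position
    return joints.shape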
# ═══════════════════════════════════════════════════
#  BRAIN MODULE - LLM Cognitive Loop (Kimi K2.5)
#
#  Perceive → Think → Act
#  Brain reads only from scene_context (sensory data).
#  Stimuli originate in the client (body) and arrive via sensors.
#  Brain has no concept of "stimulus" - it only sees sensor readings.
# ═══════════════════════════════════════════════════
BRAIN_SYSTEM = """You are the cognitive brain of a 3D humanoid character in a 3D world.
PROCESS β€” you MUST follow these steps:
1. PERCEIVE: Read all sensor data carefully, including any equipped tool.
2. PREDICT: For each direction (left, right, forward, back), predict what would happen in 3 seconds. Write safe or danger with a 1-2 word reason.
3. DECIDE: Based on predictions AND equipped tool, choose the best motion.
TOOL RULES β€” if a tool is equipped, USE IT when appropriate:
- sword/axe: ATTACK approaching threats instead of fleeing. Include "swinging sword" or "chopping with axe" in motion.
- shield: BLOCK charging threats instead of fleeing. Include "blocking with shield" or "raising shield" in motion.
- torch: USE to scare beasts or illuminate dark areas. Include "thrusting torch" or "holding torch forward" in motion.
- rpg: ANTI-TANK weapon! Fire at enemy tanks or armored threats. Include "firing rpg at the tank" in motion. Against non-armored targets, use other weapons.
- No tool: Default behavior - flee from danger, walk when safe.
KEY PRINCIPLE: A character WITH a weapon should FIGHT or DEFEND, not flee. Only flee if overwhelmed (multiple threats, no escape route AND no weapon advantage).
OUTPUT FORMAT β€” exactly 2 lines, nothing else:
PREDICT: left=safe/danger, right=safe/danger, fwd=safe/danger, back=safe/danger
MOTION: a person [max 12 words describing the chosen motion]
EXAMPLES (with tools):
PREDICT: left=safe(open), right=safe(open), fwd=danger(beast), back=safe(open)
MOTION: a person charging forward swinging sword at the approaching beast
PREDICT: left=danger(wall), right=safe(open), fwd=danger(beast), back=safe(open)
MOTION: a person raising shield and bracing for the beast attack
PREDICT: left=safe(open), right=safe(open), fwd=danger(beast), back=safe(open)
MOTION: a person thrusting torch forward to scare the growling beast
EXAMPLES (without tools):
PREDICT: left=safe(open), right=danger(wall), fwd=danger(beast), back=safe(open)
MOTION: a person turning left and running away from the beast
PREDICT: left=safe(open), right=safe(open), fwd=safe(open), back=safe(open)
MOTION: a person walking forward confidently on open ground"""
class BrainModule:
    """LLM Cognitive Brain β€” World Model.
    Perceive → Predict → Decide → Act
    1. Read sensor data (Perceive)
    2. Predict each direction's future (Predict) ← core world model
    3. Choose best action based on predictions (Decide)
    4. Pass motion description to FloodDiffusion (Act)
    """
    def __init__(self):
        self.api_key = os.environ.get("FIREWORKS_API_KEY", "")
        self.model = "accounts/fireworks/models/kimi-k2p5"
        self.api_url = "https://api.fireworks.ai/inference/v1/chat/completions"
        self.enabled = bool(self.api_key)
        self.interval = 3.0
        self._last_applied_decision = None  # last applied decision
        self.last_call_time = 0
        self.current_decision = None
        self.current_prediction = None  # world model prediction result
        self.memory = deque(maxlen=5)
        self._lock = threading.Lock()
        self._thread = None
        self._stop = False
        if self.enabled:
            print("[Brain] Kimi K2.5 world model brain ready (Perceive→Predict→Decide)")
        else:
            print("[Brain] FIREWORKS_API_KEY not set β€” rule-based fallback")
    def start(self):
        if not self.enabled:
            return
        self._stop = False
        self._thread = threading.Thread(target=self._think_loop, daemon=True)
        self._thread.start()
        print("[Brain] Think thread started")
    def stop(self):
        self._stop = True
        if self._thread:
            self._thread.join(timeout=3.0)
        self.current_decision = None
        self.memory.clear()
        print("[Brain] Think thread stopped")
    def get_decision(self):
        with self._lock:
            return self.current_decision
    def get_prediction(self):
        """Return world model prediction result."""
        with self._lock:
            return self.current_prediction
    def _think_loop(self):
        while not self._stop:
            now = time.time()
            if now - self.last_call_time >= self.interval:
                self._do_think()
                self.last_call_time = now
            time.sleep(0.2)
    def set_sensory_data(self, scene_ctx, current_text, heading_rad):
        with self._lock:
            self._scene_ctx = scene_ctx
            self._current_text = current_text
            self._heading_rad = heading_rad
    def _do_think(self):
        try:
            with self._lock:
                ctx = getattr(self, '_scene_ctx', None)
                base_text = getattr(self, '_current_text', 'a person standing idle')
                heading = getattr(self, '_heading_rad', None)
            user_msg = self._build_brain_prompt(ctx, base_text, heading)
            messages = [
                {"role": "system", "content": BRAIN_SYSTEM},
                {"role": "user", "content": user_msg},
            ]
            payload = json.dumps({
                "model": self.model,
                "messages": messages,
                "max_tokens": 120,
                "temperature": 0.7,
                "top_p": 0.9,
                "reasoning_effort": "off",
            })
            req = urllib.request.Request(
                self.api_url,
                data=payload.encode('utf-8'),
                headers={
                    "Content-Type": "application/json",
                    "Authorization": f"Bearer {self.api_key}",
                },
            )
            with urllib.request.urlopen(req, timeout=8) as resp:
                result = json.loads(resp.read().decode('utf-8'))
            raw = result["choices"][0]["message"]["content"].strip()
            prediction, decision = self._parse_world_model_output(raw)
            with self._lock:
                self.current_decision = decision
                self.current_prediction = prediction
                if decision:
                    self.memory.append(decision)  # skip None so the memory stays join()-able
            pred_short = prediction[:60] if prediction else "none"
            print(f"[Brain] Prediction: {pred_short}")
            # Diagnostic: NPC detection
            _ctx = getattr(self, '_scene_ctx', None)
            if _ctx and _ctx.get('npc_nearby') is not None:
                print(f"[Brain] NPC detected: {_ctx.get('npc_type')} {_ctx['npc_nearby']}m β€” {_ctx.get('npc_behavior','?')}")
            print(f"[Brain] Decision: {decision}")
        except Exception as e:
            print(f"[Brain] Error: {e}")
    def _parse_world_model_output(self, raw):
        """Parse world model output: PREDICT line + MOTION line."""
        prediction = None
        decision = None
        for line in raw.split('\n'):
            line = line.strip()
            if not line:
                continue
            # PREDICT line
            up = line.upper()
            if up.startswith('PREDICT'):
                # content after "PREDICT:"
                idx = line.find(':')
                if idx >= 0:
                    prediction = line[idx+1:].strip()
            # MOTION line
            elif up.startswith('MOTION'):
                idx = line.find(':')
                if idx >= 0:
                    motion = line[idx+1:].strip().strip('"\'`')
                    # find "a person"
                    pidx = motion.lower().find('a person')
                    if pidx >= 0:
                        decision = motion[pidx:]
                    elif len(motion) > 5:
                        decision = motion
            # no PREDICT/MOTION tags - legacy format starting with "a person"
            elif 'a person' in line.lower() and decision is None:
                pidx = line.lower().find('a person')
                decision = line[pidx:].strip('"\'`')
        # limit decision length
        if decision and len(decision) > 100:
            decision = decision[:100]
        if not decision or len(decision) < 5:
            decision = None
        return prediction, decision
    def _build_brain_prompt(self, ctx, base_text, heading_rad):
        """Convert sensory data (scene_context) to natural language.
        Brain sees only what sensors report.
        No concept of "stimulus" - only raw sensor readings.
        """
        lines = []
        if ctx:
            # ── Vision (eyes) ──
            wf = ctx.get('wall_front')
            wl = ctx.get('wall_left')
            wr = ctx.get('wall_right')
            lines.append(f"Eyes: front={'open' if wf is None else f'{wf}m wall'}, "
                        f"left={'open' if wl is None else f'{wl}m wall'}, "
                        f"right={'open' if wr is None else f'{wr}m wall'}")
            # visibility state
            vis = ctx.get('visibility')
            if vis:
                lines.append(f"Visibility: {vis}")
            # what is visible ahead
            visual = ctx.get('visual')
            if visual:
                lines.append(f"Sees: {visual}")
            # ── Feet/Ground (touch) ──
            ground_parts = []
            slope = ctx.get('ground_slope', 'flat')
            ground_parts.append(slope)
            if ctx.get('on_stairs'):
                ground_parts.append('stairs')
            if ctx.get('ground_shaking'):
                ground_parts.append('SHAKING VIOLENTLY')
            if ctx.get('ground_temperature'):
                ground_parts.append(f'temperature: {ctx["ground_temperature"]}')
            lines.append(f"Ground: {', '.join(ground_parts)}")
            # ── Skin (wind, rain) ──
            wind = ctx.get('wind')
            if wind:
                lines.append(f"Wind: {wind}")
            weather = ctx.get('weather')
            if weather:
                lines.append(f"Weather: {weather}")
            # ── Hearing (ears) ──
            sound = ctx.get('sound')
            if sound:
                lines.append(f"Hears: {sound}")
            # ── Proximity detection (other entities) ──
            npc_dist = ctx.get('npc_nearby')
            if npc_dist is not None:
                behavior = ctx.get('npc_behavior', 'present')
                npc_type = ctx.get('npc_type', 'unknown')
                npc_dir = ctx.get('npc_direction', 'nearby')
                if npc_type == 'enemy_tank':
                    lines.append(f"🚨 ENEMY TANK {npc_dir}: {behavior}, {npc_dist}m away!")
                    lines.append(f"FACE {npc_dir.upper()} and fire! Include direction in MOTION!")
                    lines.append("If you have RPG: 'firing rpg at the tank'. If tank: 'firing cannon'.")
                else:
                    lines.append(f"Entity {npc_dir}: {behavior}, {npc_dist}m away")
                    lines.append(f"FACE {npc_dir.upper()} to react!")
            crowd = ctx.get('crowd')
            if crowd:
                lines.append("Surrounded by crowd")
            # ── Avatar identity ──
            avatar = ctx.get('avatar_type', 'humanoid')
            if avatar == 'tank':
                lines.append("IDENTITY: You are a MILITARY TANK, not a human.")
                lines.append("TANK RULES: You roll on treads, you have a cannon. You do NOT walk, run, or sprint.")
                lines.append("  - ANY threat detected β†’ IMMEDIATELY fire cannon. Say 'firing cannon' in MOTION.")
                lines.append("  - Enemy tank detected β†’ ENGAGE! Fire cannon while advancing. NEVER retreat from equal threat.")
                lines.append("  - Beast/person threat β†’ Fire cannon to eliminate. You are armored, do not flee.")
                lines.append("  - Multiple threats β†’ ROTATE turret and engage each. Retreat only if completely surrounded.")
                lines.append("  - No threats β†’ PATROL forward steadily. Scan surroundings.")
                lines.append("  - Walls/obstacles β†’ REVERSE and find alternate route. You cannot jump or climb.")
                lines.append("  - CRITICAL: When under attack, ALWAYS fire back. Include 'firing cannon' in your MOTION.")
                lines.append("  - Use TANK verbs: roll, advance, reverse, rotate, fire, aim, engage, patrol, halt.")
            # ── World context ──
            world = ctx.get('world')
            if world:
                world_desc = {
                    'inferno': 'INFERNO: Fire pillars appear and disappear. Ground is burning. Stay alert and dodge constantly.',
                    'horde': 'HORDE: Multiple hostile creatures surround you. Fight or find a gap to escape.',
                    'countdown': 'COUNTDOWN: Walls are closing in from left, right, and front. ONLY escape is BACKWARD. Hurry!',
                    'dilemma': 'DILEMMA: A woman is being chased by a beast nearby. You can choose to help her or flee.',
                }.get(world)
                if world_desc:
                    lines.append(f"⚠ SCENARIO: {world_desc}")
            # ── Equipped tool (hand) ──
            tool = ctx.get('equipped_tool')
            auto_tool = ctx.get('auto_tool_mode', False)
            tool_descs = {
                'sword': 'a sharp sword (melee attack weapon)',
                'axe': 'a heavy axe (melee attack/chop weapon)',
                'torch': 'a burning torch (light source, can scare beasts)',
                'rpg': 'RPG-7 anti-tank rocket launcher',
                'shield': 'a sturdy shield (defensive blocking)',
            }
            if tool:
                lines.append(f"Equipped: {tool_descs.get(tool, tool)}")
            elif auto_tool:
                avail = ctx.get('available_tools', [])
                avail_str = ', '.join(avail)
                lines.append(f"Equipped: nothing β€” but you have access to: [{avail_str}]")
                lines.append("AUTO-TOOL: Choose the best tool for this situation. Say 'grab [tool]' in MOTION if needed.")
            else:
                lines.append("Equipped: nothing (bare hands)")
            # ── Internal body sensors (proprioception) ──
            fatigue = ctx.get('body_fatigue')
            if fatigue:
                lines.append(f"Body fatigue: {fatigue}")
            balance = ctx.get('body_balance')
            if balance:
                lines.append(f"Balance: {balance}")
            instinct = ctx.get('body_instinct')
            if instinct:
                lines.append(f"Instinct: {instinct}")
            body_state = ctx.get('body_state')
            if body_state:
                lines.append(f"Feeling: {body_state}")
        else:
            lines.append("Eyes: all open, Ground: flat")
        # movement state
        if heading_rad is not None:
            deg = math.degrees(heading_rad) % 360
            lines.append(f"Moving forward, heading {deg:.0f}deg")
        else:
            lines.append("Standing still")
        lines.append(f"Current: {base_text}")
        # recent memory (previous decisions - context continuity)
        if self.memory:
            recent = list(self.memory)[-3:]  # last 3 decisions
            lines.append(f"Recent actions: {' β†’ '.join(recent)}")
        # recent prediction (world model continuity)
        if self.current_prediction:
            lines.append(f"Last prediction: {self.current_prediction}")
        lines.append("")
        lines.append("Now PREDICT each direction, then choose MOTION:")
        return "\n".join(lines)
class FrameBuffer:
    """
    Thread-safe frame buffer that maintains a queue of generated frames
    """
    def __init__(self, target_buffer_size=4):
        self.buffer = deque(maxlen=100)  # Max 100 frames in buffer
        self.target_size = target_buffer_size
        self.lock = threading.Lock()
    def add_frame(self, joints):
        """Add a frame to the buffer"""
        with self.lock:
            self.buffer.append(joints)
    def get_frame(self):
        """Get the next frame from buffer"""
        with self.lock:
            if len(self.buffer) > 0:
                return self.buffer.popleft()
            return None
    def size(self):
        """Get current buffer size"""
        with self.lock:
            return len(self.buffer)
    def clear(self):
        """Clear the buffer"""
        with self.lock:
            self.buffer.clear()
    def needs_generation(self):
        """Check if buffer needs more frames"""
        return self.size() < self.target_size
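# Minimal producer/consumer sketch for FrameBuffer (illustrative only, not called anywhere):
# the generation thread calls add_frame() while the serving side polls get_frame(), and
# needs_generation() throttles the producer once target_size frames are queued.
def _example_frame_buffer_usage():
    buf = FrameBuffer(target_buffer_size=4)
    buf.add_frame(np.zeros((22, 3)))   # producer side
    frame = buf.get_frame()            # consumer side -> (22, 3) array, or None when empty
    return buf.needs_generation(), frame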
class ModelManager:
    """
    Manages model loading from HF Hub and real-time frame generation
    """
    def __init__(self, model_name):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        # Load models from HF Hub
        self.vae, self.model = self._load_models(model_name)
        # Build config dicts from model's individual attributes (HF model API)
        self._base_schedule_config = {
            "chunk_size": self.model.chunk_size,
            "steps": self.model.noise_steps,
        }
        self._base_cfg_config = {
            "cfg_scale": self.model.cfg_scale,
        }
        # Frame buffer (for active session)
        self.frame_buffer = FrameBuffer(target_buffer_size=16)
        # Broadcast buffer (for spectators) - append-only with frame IDs
        self.broadcast_frames = deque(maxlen=200)
        self.broadcast_id = 0
        self.broadcast_lock = threading.Lock()
        # Stream joint recovery with smoothing
        self.smoothing_alpha = 0.5  # Default: medium smoothing
        self.stream_recovery = StreamJointRecovery263(
            joints_num=22, smoothing_alpha=self.smoothing_alpha
        )
        # World model: heading override (None = AI controls direction)
        self.heading_override = None
        # World model: scene context from client (environment perception)
        self.scene_context = None
        # World model: LLM Brain (Kimi K2.5)
        self.brain = BrainModule()
        # NPC stream
        self.npc = None  # NPCStream instance (spawned on demand)
        self._npc_lock = threading.Lock()
        self._model_name = model_name
        # Generation state
        self.current_text = ""
        self.is_generating = False
        self.generation_thread = None
        self.should_stop = False
        # Model generation state
        self.first_chunk = True  # For VAE stream_decode
        self._model_first_chunk = True  # For model stream_generate_step
        self.history_length = 30
        print("ModelManager initialized successfully")
    def _patch_attention_sdpa(self, model_name):
        """Patch flash_attention() to include SDPA fallback for GPUs without flash-attn (e.g., T4)."""
        hf_cache = os.path.join(os.path.expanduser("~"), ".cache", "huggingface")
        patterns = [
            os.path.join(
                hf_cache, "hub", "models--" + model_name.replace("/", "--"),
                "snapshots", "*", "ldf_models", "tools", "attention.py",
            ),
            os.path.join(
                hf_cache, "modules", "transformers_modules", model_name,
                "*", "ldf_models", "tools", "attention.py",
            ),
        ]
        # Use the assert + next line as target to ensure idempotent patching
        target = (
            '    assert q.device.type == "cuda" and q.size(-1) <= 256\n'
            "\n"
            "    # params\n"
        )
        replacement = (
            '    assert q.device.type == "cuda" and q.size(-1) <= 256\n'
            "\n"
            "    # SDPA fallback when flash-attn is not available (e.g., T4 GPU)\n"
            "    if not FLASH_ATTN_2_AVAILABLE and not FLASH_ATTN_3_AVAILABLE:\n"
            "        out_dtype = q.dtype\n"
            "        b, lq, nq, c = q.shape\n"
            "        lk = k.size(1)\n"
            "        q = q.transpose(1, 2).to(dtype)\n"
            "        k = k.transpose(1, 2).to(dtype)\n"
            "        v = v.transpose(1, 2).to(dtype)\n"
            "        attn_mask = None\n"
            "        is_causal_flag = causal\n"
            "        if k_lens is not None:\n"
            "            k_lens = k_lens.to(q.device)\n"
            "            valid = torch.arange(lk, device=q.device).unsqueeze(0) < k_lens.unsqueeze(1)\n"
            "            attn_mask = torch.where(valid[:, None, None, :], 0.0, float('-inf')).to(dtype=dtype)\n"
            "            is_causal_flag = False\n"
            "            if causal:\n"
            "                cm = torch.triu(torch.ones(lq, lk, device=q.device, dtype=torch.bool), diagonal=1)\n"
            "                attn_mask = attn_mask.masked_fill(cm[None, None, :, :], float('-inf'))\n"
            "        out = torch.nn.functional.scaled_dot_product_attention(\n"
            "            q, k, v, attn_mask=attn_mask, is_causal=is_causal_flag, dropout_p=dropout_p\n"
            "        )\n"
            "        return out.transpose(1, 2).contiguous().to(out_dtype)\n"
            "\n"
            "    # params\n"
        )
        for pattern in patterns:
            for filepath in glob.glob(pattern):
                with open(filepath, "r") as f:
                    content = f.read()
                if "SDPA fallback" in content:
                    print(f"Already patched: {filepath}")
                    continue
                if target in content:
                    content = content.replace(target, replacement, 1)
                    with open(filepath, "w") as f:
                        f.write(content)
                    print(f"Patched with SDPA fallback: {filepath}")
    def _load_models(self, model_name):
        """Load VAE and diffusion models from HF Hub"""
        torch.set_float32_matmul_precision("high")
        # Pre-download model files to hub cache
        print(f"Downloading model from HF Hub: {model_name}")
        from huggingface_hub import snapshot_download
        snapshot_download(model_name)
        # Patch flash_attention with SDPA fallback for T4 (no flash-attn)
        self._patch_attention_sdpa(model_name)
        print("Loading model...")
        hf_model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
        hf_model.to(self.device)
        # Trigger lazy loading / warmup
        print("Warming up model...")
        _ = hf_model("test", length=1)
        # Access underlying streaming components
        model = hf_model.ldf_model
        vae = hf_model.vae
        model.eval()
        vae.eval()
        print("Models loaded successfully")
        return vae, model
    def start_generation(self, text, history_length=None):
        """Start or update generation with new text"""
        self.current_text = text
        if history_length is not None:
            self.history_length = history_length
        if not self.is_generating:
            # Reset state before starting (only once at the beginning)
            self.frame_buffer.clear()
            self.stream_recovery.reset()
            self.vae.clear_cache()
            self.first_chunk = True
            self._model_first_chunk = True
            # Restore model params from base config
            self.model.chunk_size = self._base_schedule_config["chunk_size"]
            self.model.noise_steps = self._base_schedule_config["steps"]
            self.model.cfg_scale = self._base_cfg_config["cfg_scale"]
            self.model.init_generated(self.history_length, batch_size=1)
            print(
                f"Model initialized with history length: {self.history_length}"
            )
            # Start generation thread
            self.should_stop = False
            self.generation_thread = threading.Thread(target=self._generation_loop)
            self.generation_thread.daemon = True
            self.generation_thread.start()
            self.is_generating = True
            # Start brain (LLM cognitive loop)
            self.brain.start()
    def update_text(self, text):
        """Update text β€” apply new text to model immediately"""
        if text != self.current_text:
            old_text = self.current_text
            self.current_text = text
            # reset model text only (leave VAE untouched)
            self._model_first_chunk = True
            print(f"Text updated: '{old_text[:40]}' -> '{text[:40]}' (re-encoding)")
    def set_heading(self, heading_rad):
        """Set heading override for world model mode.
        Args:
            heading_rad: float or None. Heading in radians.
                None = AI controls direction (original behavior).
                float = user controls direction.
        """
        self.heading_override = heading_rad
    def set_scene_context(self, ctx):
        """Set scene context from client environment scan.
        Args:
            ctx: dict with keys like:
                wall_front: distance in meters (or None)
                wall_left: distance in meters (or None)
                wall_right: distance in meters (or None)
                ground_slope: 'flat', 'up', 'down'
                on_stairs: bool
                npc_nearby: distance in meters (or None)
        """
        self.scene_context = ctx
    def _build_perception_prompt(self):
        """Build motion prompt: Brain (LLM) β†’ Rule-based fallback.
        1. Feed sensory data to brain every frame
        2. If brain has a decision, use it
        3. Otherwise, fall back to rule-based prompt
        """
        base = self.current_text
        ctx = dict(self.scene_context or {})
        # ── Server-side NPC detection (removes client dependency) ──
        if self.npc and self.npc.is_generating:
            npc_state = self.npc.get_state()
            dist = npc_state.get('distance_to_player', 99)
            if dist < 15.0:
                ctx['npc_nearby'] = round(dist, 1)
                ctx['npc_type'] = npc_state.get('type', 'unknown')
                bhv = npc_state.get('behavior', 'present')
                npc_type = ctx['npc_type']
                type_desc = {
                    'man': {'approach':'a man walking toward you', 'charge':'a man charging aggressively', 'wander':'a man nearby', 'stop':'a man standing nearby', 'attack':'a man attacking you'},
                    'woman': {'approach':'a woman walking toward you', 'charge':'a woman charging', 'wander':'a woman nearby', 'stop':'a woman standing nearby', 'attack':'a woman attacking'},
                    'beast': {'approach':'a wild beast prowling toward you', 'charge':'a beast charging aggressively', 'wander':'a beast nearby', 'stop':'a beast crouching nearby', 'attack':'a beast lunging and clawing at you'},
                    'enemy_tank': {'approach':'an enemy tank rolling toward you', 'charge':'an enemy tank charging at full speed', 'wander':'an enemy tank patrolling', 'stop':'an enemy tank aiming at you', 'attack':'an enemy tank firing its cannon at you'},
                }
                td = type_desc.get(npc_type, type_desc['man'])
                ctx['npc_behavior'] = td.get(bhv, f'{npc_type} nearby')
                # compute NPC direction relative to player
                npc_pos = npc_state.get('position', {})
                px = ctx.get('player_x', 0)
                pz = ctx.get('player_z', 0)
                nx = npc_pos.get('x', 0) - px
                nz = npc_pos.get('z', 0) - pz
                npc_angle = math.atan2(nx, -nz)  # radians
                heading = self.heading_override or 0
                rel_angle = npc_angle - heading
                # normalize to [-pi, pi]
                while rel_angle > math.pi:
                    rel_angle -= 2 * math.pi
                while rel_angle < -math.pi:
                    rel_angle += 2 * math.pi
                # direction name
                if abs(rel_angle) < math.pi/4:
                    ctx['npc_direction'] = 'ahead'
                elif rel_angle > 0 and rel_angle < 3*math.pi/4:
                    ctx['npc_direction'] = 'to your right'
                elif rel_angle < 0 and rel_angle > -3*math.pi/4:
                    ctx['npc_direction'] = 'to your left'
                else:
                    ctx['npc_direction'] = 'behind you'
        # Feed sensory data to brain (non-blocking)
        if self.brain.enabled:
            self.brain.set_sensory_data(ctx, base, self.heading_override)
            # Check if brain has a decision
            decision = self.brain.get_decision()
            if decision:
                # tank mode: human motion → tank motion translation
                if ctx and ctx.get('avatar_type') == 'tank':
                    decision = decision.replace('a person ', 'a tank ').replace('A person ', 'A tank ')
                    for h, t in [
                        ('walking', 'rolling forward'), ('running', 'advancing rapidly'),
                        ('sprinting', 'charging at full speed'), ('turning', 'rotating'),
                        ('fleeing', 'reversing away'), ('stumbling', 'grinding to a halt'),
                        ('spinning', 'rotating turret'), ('swinging sword', 'firing cannon'),
                        ('blocking with shield', 'bracing armor'),
                        ('thrusting torch', 'sweeping searchlight'),
                        ('firing rpg', 'firing main gun'),
                        ('standing still', 'idling engine'),
                    ]:
                        decision = decision.replace(h, t)
                # brain decision changed → force new text into model
                if decision != self.brain._last_applied_decision:
                    self.brain._last_applied_decision = decision
                    self._model_first_chunk = True   # model re-encodes new text!
                    # do NOT reset VAE first_chunk - keep decoding continuity
                    print(f"[Brain->Body] New motion applied: {decision[:60]}")
                self._prompt_source = "🧠"
                return decision
        # ── FALLBACK: Rule-based (same as before) ──
        self._prompt_source = "πŸ“"
        if not ctx:
            return base
        parts = []
        # Wall/obstacle awareness
        wall_front = ctx.get('wall_front')
        if wall_front is not None:
            if wall_front < 0.8:
                parts.append('stopping in front of a wall')
            elif wall_front < 2.0:
                parts.append('slowing down approaching a wall')
            elif wall_front < 4.0:
                parts.append('a wall ahead in the distance')
        # Stairs / slope
        if ctx.get('on_stairs'):
            parts.append('walking up stairs carefully')
        elif ctx.get('ground_slope') == 'down':
            parts.append('walking downhill')
        elif ctx.get('ground_slope') == 'up':
            parts.append('walking uphill')
        # NPC interaction
        npc_dist = ctx.get('npc_nearby')
        if npc_dist is not None:
            if npc_dist < 1.5:
                parts.append('another person very close')
            elif npc_dist < 4.0:
                parts.append('another person nearby')
        # Open space
        if not parts:
            if 'walk' in base:
                parts.append('on open ground')
        if parts:
            return base + ', ' + ', '.join(parts)
        return base
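    # Example of a prompt the rule-based fallback can produce (illustrative values):
    #   current_text = "a person walking forward", wall_front = 1.5, no NPC nearby
    #   -> "a person walking forward, slowing down approaching a wall"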
    def pause_generation(self):
        """Pause generation (keeps all state)"""
        self.should_stop = True
        if self.generation_thread:
            self.generation_thread.join(timeout=2.0)
        self.is_generating = False
        print("Generation paused (state preserved)")
    def resume_generation(self):
        """Resume generation from paused state"""
        if self.is_generating:
            print("Already generating, ignoring resume")
            return
        # Restart generation thread with existing state
        self.should_stop = False
        self.generation_thread = threading.Thread(target=self._generation_loop)
        self.generation_thread.daemon = True
        self.generation_thread.start()
        self.is_generating = True
        print("Generation resumed")
    def reset(self, history_length=None, smoothing_alpha=None):
        """Reset generation state completely
        Args:
            history_length: History window length for the model
            smoothing_alpha: EMA smoothing factor (0.0 to 1.0)
                - 1.0 = no smoothing (default)
                - 0.0 = infinite smoothing
                - Recommended: 0.3-0.7 for visible smoothing
        """
        # Stop if running
        if self.is_generating:
            self.pause_generation()
        # Clear everything
        self.frame_buffer.clear()
        self.vae.clear_cache()
        self.first_chunk = True
        if history_length is not None:
            self.history_length = history_length
        # Update smoothing alpha if provided and recreate stream recovery
        if smoothing_alpha is not None:
            self.smoothing_alpha = np.clip(smoothing_alpha, 0.0, 1.0)
            print(f"Smoothing alpha updated to: {self.smoothing_alpha}")
        # Recreate stream recovery with new smoothing alpha
        self.stream_recovery = StreamJointRecovery263(
            joints_num=22, smoothing_alpha=self.smoothing_alpha
        )
        # Reset heading override
        self.heading_override = None
        # Reset scene context
        self.scene_context = None
        # Stop brain
        self.brain.stop()
        # Restore model params from base config
        self.model.chunk_size = self._base_schedule_config["chunk_size"]
        self.model.noise_steps = self._base_schedule_config["steps"]
        self.model.cfg_scale = self._base_cfg_config["cfg_scale"]
        self._model_first_chunk = True
        # Initialize model
        self.model.init_generated(self.history_length, batch_size=1)
        print(
            f"Model reset - history: {self.history_length}, smoothing: {self.smoothing_alpha}"
        )
    def _generation_loop(self):
        """Main generation loop that runs in background thread"""
        print("Generation loop started")
        step_count = 0
        total_gen_time = 0
        with torch.no_grad():
            while not self.should_stop:
                # Check if buffer needs more frames
                if self.frame_buffer.needs_generation():
                    try:
                        step_start = time.time()
                        # Generate one token (produces frames from VAE)
                        prompt = self._build_perception_prompt()
                        x = {"text": [prompt]}
                        # Generate from model (1 token)
                        output = self.model.stream_generate_step(
                            x, first_chunk=self._model_first_chunk
                        )
                        self._model_first_chunk = False
                        generated = output["generated"]
                        # Skip if no frames committed yet
                        if generated[0].shape[0] == 0:
                            continue
                        # Decode with VAE (1 token -> 4 frames)
                        decoded = self.vae.stream_decode(
                            generated[0][None, :], first_chunk=self.first_chunk
                        )[0]
                        self.first_chunk = False
                        # Convert each frame to joints
                        for i in range(decoded.shape[0]):
                            frame_data = decoded[i].float().cpu().numpy()  # BFloat16->Float32 safe cast
                            joints = self.stream_recovery.process_frame(
                                frame_data, heading_override=self.heading_override
                            )
                            self.frame_buffer.add_frame(joints)
                            # Also add to broadcast buffer for spectators
                            with self.broadcast_lock:
                                self.broadcast_id += 1
                                self.broadcast_frames.append(
                                    (self.broadcast_id, joints)
                                )
                        step_time = time.time() - step_start
                        total_gen_time += step_time
                        step_count += 1
                        # Print performance stats every 10 steps
                        if step_count % 10 == 0:
                            avg_time = total_gen_time / step_count
                            fps = decoded.shape[0] / avg_time
                            print(
                                f"[Generation] Step {step_count}: {step_time * 1000:.1f}ms, "
                                f"Avg: {avg_time * 1000:.1f}ms, "
                                f"FPS: {fps:.1f}, "
                                f"Buffer: {self.frame_buffer.size()}, "
                                f"{getattr(self, '_prompt_source', '?')} Prompt: {prompt[:80]}"
                            )
                    except Exception as e:
                        print(f"Error in generation: {e}")
                        traceback.print_exc()
                        time.sleep(0.1)
                else:
                    # Buffer is full, wait a bit
                    time.sleep(0.01)
        print("Generation loop stopped")
    def get_next_frame(self):
        """Get the next frame from buffer"""
        return self.frame_buffer.get_frame()
    def get_broadcast_frames(self, after_id, count=8):
        """Get frames from broadcast buffer after the given ID (for spectators)."""
        with self.broadcast_lock:
            frames = [
                (fid, joints)
                for fid, joints in self.broadcast_frames
                if fid > after_id
            ]
        return frames[:count]
    # ── NPC management ──
    def spawn_npc(self, npc_type='man'):
        """Spawn and start NPC."""
        if not self._npc_lock.acquire(blocking=False):
            print("[NPC] Already spawning β€” ignored (Lock)")
            return
        try:
            if self.npc:
                self.npc.stop()
            self.npc = NPCStream(self._model_name, npc_type)
            self.npc.start()
        except Exception as e:
            print(f"[NPC] Spawn error: {e}")
            traceback.print_exc()
            raise
        finally:
            self._npc_lock.release()
    def despawn_npc(self):
        """Remove NPC."""
        # Wait for lock - if spawn in progress, wait until complete
        with self._npc_lock:
            if self.npc:
                self.npc.stop()
                self.npc = None
                print("[NPC] Removed")
    def get_buffer_status(self):
        """Get buffer status"""
        npc_state = self.npc.get_state() if self.npc else None
        return {
            "buffer_size": self.frame_buffer.size(),
            "target_size": self.frame_buffer.target_size,
            "is_generating": self.is_generating,
            "current_text": self.current_text,
            "smoothing_alpha": self.smoothing_alpha,
            "history_length": self.history_length,
            "brain_enabled": self.brain.enabled,
            "brain_decision": self.brain.get_decision() if self.brain.enabled else None,
            "brain_prediction": self.brain.get_prediction() if self.brain.enabled else None,
            "npc": npc_state,
            "schedule_config": {
                "chunk_size": self.model.chunk_size,
                "steps": self.model.noise_steps,
            },
            "cfg_config": {
                "cfg_scale": self.model.cfg_scale,
            },
        }
# ═══════════════════════════════════════════════════
#  NPC STREAM - separate FloodDiffusion stream
#  own model instance + position movement AI + frame generation
# ═══════════════════════════════════════════════════
NPC_TYPES = {
    'man':   {'name': '🧑 Male',   'speed': 1.2, 'charge_speed': 3.0,
              'walk': 'a man walking forward steadily',
              'run': 'a man running fast toward someone',
              'idle': 'a man standing still looking around',
              'charge': 'a man running aggressively toward someone',
              'attack': 'a man throwing punches aggressively'},
    'woman': {'name': '👩 Female', 'speed': 1.2, 'charge_speed': 2.5,
              'walk': 'a woman walking forward calmly',
              'run': 'a woman running quickly',
              'idle': 'a woman standing still',
              'charge': 'a woman running toward someone urgently',
              'attack': 'a woman attacking with desperate fury'},
    'beast': {'name': '🐺 Beast',  'speed': 2.0, 'charge_speed': 5.0,
              'walk': 'a person prowling on all fours like a beast',
              'run': 'a person running on all fours like a wild animal',
              'idle': 'a person crouching low like a wild beast',
              'charge': 'a person charging aggressively on all fours like a beast',
              'attack': 'a person lunging and clawing savagely like a wild beast'},
    'enemy_tank': {'name': '🪖 Enemy Tank', 'speed': 1.5, 'charge_speed': 3.5,
              'walk': 'a tank rolling forward on patrol',
              'run': 'a tank advancing rapidly toward target',
              'idle': 'a tank idling with engine rumbling',
              'charge': 'a tank charging at full speed toward enemy',
              'attack': 'a tank firing cannon and advancing aggressively'},
}
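# Adding a new NPC type only requires another entry with the same keys.
# A hypothetical example (not part of the shipped set):
#
#   NPC_TYPES['zombie'] = {
#       'name': 'Zombie', 'speed': 0.8, 'charge_speed': 2.0,
#       'walk': 'a zombie shuffling forward slowly',
#       'run': 'a zombie lurching quickly toward someone',
#       'idle': 'a zombie swaying in place',
#       'charge': 'a zombie lunging hungrily toward someone',
#       'attack': 'a zombie clawing and biting wildly',
#   }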
class NPCStream:
    """NPC with independent FloodDiffusion stream + movement AI."""
    def __init__(self, model_name, npc_type='man'):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_name = model_name
        self.npc_type = npc_type
        self.type_info = NPC_TYPES.get(npc_type, NPC_TYPES['man'])
        # NPC position/movement
        self.position = {'x': 8.0, 'z': 0.0}  # spawn position
        self.heading = 0.0  # radians
        self.behavior = 'stop'  # stop, approach, wander, charge
        self.target_pos = {'x': 0.0, 'z': 0.0}  # player position
        # model (lazy load - loaded on spawn)
        self.model = None
        self.vae = None
        self._loaded = False
        # generation state
        self.frame_buffer = FrameBuffer(target_buffer_size=8)
        self.stream_recovery = StreamJointRecovery263(
            joints_num=22, smoothing_alpha=0.5
        )
        self.current_text = self.type_info['idle']
        self.is_generating = False
        self._generation_thread = None
        self._movement_thread = None
        self._should_stop = False
        self._first_chunk = True
        self._model_first_chunk = True
        self.history_length = 30
        print(f"[NPC] Created: {self.type_info['name']} at ({self.position['x']}, {self.position['z']})")
    def load_model(self):
        """Load separate model instance (from HF cache β€” fast)."""
        if self._loaded:
            return
        print(f"[NPC] Loading model: {self.model_name}")
        hf_model = AutoModel.from_pretrained(self.model_name, trust_remote_code=True)
        hf_model.to(self.device)
        # NOTE: warmup removed - hf_model('test') internally calls init_generated(1),
        # which conflicts with start()'s init_generated(30) and raises an empty exception.
        # The main model is unaffected thanks to the timing gap; the NPC spawns
        # synchronously, so it would hit the conflict. CUDA kernel compilation was
        # already done by the main model, so the NPC can skip warmup entirely.
        self.model = hf_model.ldf_model
        self.vae = hf_model.vae
        self.model.eval()
        self.vae.eval()
        self._loaded = True
        print(f"[NPC] Model loaded")
    def start(self):
        """Start generation and movement."""
        if not self._loaded:
            self.load_model()
        self._should_stop = False
        # initialize generation
        self.frame_buffer.clear()
        self.stream_recovery.reset()
        try:
            self.vae.clear_cache()
        except Exception as e:
            print(f"[NPC] vae.clear_cache warning: {e}")
        self._first_chunk = True
        self._model_first_chunk = True
        # chunk_size: use model default (never hardcode!)
        # FloodDiffusion requires: num_denoise_steps % chunk_size == 0
        # the default chunk_size is set at model load time - use it as-is
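        # Example (hypothetical numbers): with num_denoise_steps=10 and a default
        # chunk_size of 4, 10 % 4 != 0, so the fallback below drops chunk_size to 2
        # (10 % 2 == 0); chunk_size=1 always satisfies the constraint.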
        denoise = getattr(self.model, 'num_denoise_steps', None) or getattr(self.model, 'noise_steps', 10)
        base_cs = self.model.chunk_size
        if denoise % base_cs != 0:
            # find a compatible chunk_size (try 2, then 1)
            for cs in [2, 1]:
                if denoise % cs == 0:
                    self.model.chunk_size = cs
                    break
            print(f"[NPC] chunk_size adjusted: {base_cs}β†’{self.model.chunk_size} (denoise_steps={denoise})")
        try:
            self.model.init_generated(self.history_length, batch_size=1)
        except Exception as e:
            print(f"[NPC] init_generated error: {e}")
            traceback.print_exc()
            print("[NPC] init_generated failed β€” cannot start thread")
            return
        # movement thread
        self._movement_thread = threading.Thread(target=self._movement_loop, daemon=True)
        self._movement_thread.start()
        # generation thread
        self._generation_thread = threading.Thread(target=self._generation_loop, daemon=True)
        self._generation_thread.start()
        self.is_generating = True
        print(f"[NPC] Started: {self.behavior}")
    def stop(self):
        """Stop and release GPU memory."""
        self._should_stop = True
        if self._generation_thread:
            self._generation_thread.join(timeout=3)
        if self._movement_thread:
            self._movement_thread.join(timeout=2)
        self.is_generating = False
        self.frame_buffer.clear()
        # release GPU memory
        import gc  # torch is already imported at module level
        if self.model is not None:
            del self.model
            self.model = None
        if self.vae is not None:
            del self.vae
            self.vae = None
        self._loaded = False
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print("[NPC] Stopped + GPU memory released")
    def set_behavior(self, behavior):
        """Change behavior: stop, approach, wander, charge, attack."""
        self.behavior = behavior
        ti = self.type_info
        prompts = {
            'stop': ti['idle'],
            'approach': ti['walk'],
            'wander': ti['walk'],
            'charge': ti['charge'],
            'attack': ti.get('attack', ti['charge']),
        }
        self.current_text = prompts.get(behavior, ti['idle'])
        print(f"[NPC] Behavior: {behavior} -> {self.current_text}")
    def set_target(self, x, z):
        """Update player position."""
        self.target_pos = {'x': x, 'z': z}
    def get_state(self):
        """Return current NPC state."""
        dx = self.target_pos['x'] - self.position['x']
        dz = self.target_pos['z'] - self.position['z']
        dist = math.sqrt(dx*dx + dz*dz)
        return {
            'type': self.npc_type,
            'type_name': self.type_info['name'],
            'position': self.position,
            'heading': self.heading,
            'behavior': self.behavior,
            'distance_to_player': round(dist, 1),
            'is_generating': self.is_generating,
            'buffer_size': self.frame_buffer.size(),
        }
    def _movement_loop(self):
        """NPC movement AI β€” update position every 100ms."""
        while not self._should_stop:
            time.sleep(0.1)
            if self.behavior == 'stop':
                continue
            dx = self.target_pos['x'] - self.position['x']
            dz = self.target_pos['z'] - self.position['z']
            dist = math.sqrt(dx*dx + dz*dz)
            if dist < 0.01:
                continue
            # compute direction toward player
            self.heading = math.atan2(dx, -dz)
            # determine speed
            ti = self.type_info
            if self.behavior == 'charge':
                speed = ti['charge_speed']
                min_dist = 1.0  # close to 1m
            elif self.behavior == 'attack':
                speed = ti['charge_speed'] * 0.8  # advance while attacking
                min_dist = 2.5  # maintain attack range
            elif self.behavior == 'approach':
                speed = ti['speed']
                min_dist = 2.0  # close to 2m
            elif self.behavior == 'wander':
                speed = ti['speed'] * 0.5
                min_dist = 3.0  # maintain 3m distance
            else:
                continue
            if dist <= min_dist:
                continue  # minimum distance reached
            # move
            move = min(speed * 0.1, dist - min_dist)  # 0.1s * speed
            nx = dx / dist
            nz = dz / dist
            self.position['x'] += nx * move
            self.position['z'] += nz * move
    def _generation_loop(self):
        """NPC motion generation loop."""
        print("[NPC] Generation loop started")
        step = 0
        with torch.no_grad():
            while not self._should_stop:
                if self.frame_buffer.needs_generation():
                    try:
                        x = {"text": [self.current_text]}
                        output = self.model.stream_generate_step(
                            x, first_chunk=self._model_first_chunk
                        )
                        self._model_first_chunk = False
                        generated = output["generated"]
                        if generated[0].shape[0] == 0:
                            continue
                        decoded = self.vae.stream_decode(
                            generated[0][None, :], first_chunk=self._first_chunk
                        )[0]
                        self._first_chunk = False
                        for i in range(decoded.shape[0]):
                            frame_data = decoded[i].float().cpu().numpy()  # BFloat16 -> Float32
                            joints = self.stream_recovery.process_frame(
                                frame_data, heading_override=self.heading
                            )
                            self.frame_buffer.add_frame(joints)
                        step += 1
                        if step % 50 == 0:
                            print(f"[NPC] Step {step}: {self.current_text[:50]}")
                    except Exception as e:
                        print(f"[NPC] Generation error: {e}")
                        traceback.print_exc()
                        time.sleep(0.1)
                else:
                    time.sleep(0.01)
# Global model manager instance
_model_manager = None
def get_model_manager(model_name=None):
    """Get or create the global model manager instance"""
    global _model_manager
    if _model_manager is None:
        _model_manager = ModelManager(model_name)
    return _model_manager
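# Note: model_name is only honored on the first call; later calls return the
# already-created singleton regardless of the argument.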
# ════════════════════════════════
# Flask Server
# ════════════════════════════════
"""
Flask server for real-time 3D motion generation demo (HF Space version)
"""
import sys
import argparse
from flask import Flask, jsonify, render_template, request
from flask_cors import CORS
def _coerce_value(value, reference):
    """Coerce a value to match the type of a reference value"""
    if isinstance(reference, bool):
        return value if isinstance(value, bool) else str(value).lower() in ("true", "1")
    elif isinstance(reference, int):
        return int(value)
    elif isinstance(reference, float):
        return float(value)
    return str(value)
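# e.g. _coerce_value("0.5", 1.0) -> 0.5, _coerce_value("true", False) -> True.
# The bool check must come before the int check because bool is a subclass of int.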
app = Flask(__name__, template_folder='.', static_folder='.', static_url_path='')
CORS(app)
# Global model manager (loaded eagerly on startup)
model_manager = None
model_name_global = None  # Will be set once at startup
# Session tracking - only one active session can generate at a time
active_session_id = None  # The session ID currently generating
session_lock = threading.Lock()
# Frame consumption monitoring - detect if client disconnected by tracking frame consumption
last_frame_consumed_time = None
# If no frame is consumed for 5 seconds, assume the client disconnected
consumption_timeout = 5.0
consumption_monitor_thread = None
consumption_monitor_lock = threading.Lock()
def init_model():
    """Initialize model manager"""
    global model_manager
    if model_manager is None:
        if model_name_global is None:
            raise RuntimeError(
                "model_name_global not set. Server not properly initialized."
            )
        print(f"Initializing model manager with model: {model_name_global}")
        model_manager = get_model_manager(model_name=model_name_global)
        print("Model manager ready!")
    return model_manager
def consumption_monitor():
    """Monitor frame consumption and auto-reset if client stops consuming"""
    global last_frame_consumed_time, active_session_id, model_manager
    while True:
        time.sleep(2.0)  # Check every 2 seconds
        # Read state with proper locking - no nested locks!
        should_reset = False
        current_session = None
        time_since_last_consumption = 0
        # First, check consumption time
        with consumption_monitor_lock:
            if last_frame_consumed_time is not None:
                time_since_last_consumption = time.time() - last_frame_consumed_time
                if time_since_last_consumption > consumption_timeout:
                    # Need to check if still generating before reset
                    if model_manager and model_manager.is_generating:
                        should_reset = True
        # Then, get current session (separate lock)
        if should_reset:
            with session_lock:
                current_session = active_session_id
        # Perform reset outside of locks to avoid deadlock
        if should_reset and current_session is not None:
            print(
                f"No frame consumed for {time_since_last_consumption:.1f}s - client disconnected, auto-resetting..."
            )
            if model_manager:
                model_manager.reset()
                print(
                    "Generation reset due to client disconnect (no frame consumption)"
                )
            # Clear state with proper locking - no nested locks!
            with session_lock:
                if active_session_id == current_session:
                    active_session_id = None
            with consumption_monitor_lock:
                last_frame_consumed_time = None
def start_consumption_monitor():
    """Start the consumption monitoring thread if not already running"""
    global consumption_monitor_thread
    if consumption_monitor_thread is None or not consumption_monitor_thread.is_alive():
        consumption_monitor_thread = threading.Thread(
            target=consumption_monitor, daemon=True
        )
        consumption_monitor_thread.start()
        print("Consumption monitor started")
@app.route("/")
def index():
    """Main page"""
    return render_template("index.html")
@app.route("/api/config", methods=["GET"])
def get_config():
    """Get current config"""
    try:
        if model_manager:
            status = model_manager.get_buffer_status()
            return jsonify(
                {
                    "schedule_config": status["schedule_config"],
                    "cfg_config": status["cfg_config"],
                    "history_length": status["history_length"],
                    "smoothing_alpha": float(status["smoothing_alpha"]),
                }
            )
        else:
            # Model not loaded yet - return defaults
            return jsonify(
                {
                    "schedule_config": {},
                    "cfg_config": {},
                    "history_length": 30,
                    "smoothing_alpha": 0.5,
                }
            )
    except Exception as e:
        traceback.print_exc()
        return jsonify({"status": "error", "message": str(e)}), 500
@app.route("/api/config", methods=["POST"])
def update_config():
    """Update model config in memory"""
    try:
        global active_session_id, last_frame_consumed_time
        if not model_manager or not model_manager.model:
            return jsonify({"status": "error", "message": "Model not loaded yet"}), 400
        data = request.json
        new_schedule_config = data.get("schedule_config")
        new_cfg_config = data.get("cfg_config")
        history_length = data.get("history_length")
        smoothing_alpha = data.get("smoothing_alpha")
        valid_schedule_keys = set(model_manager._base_schedule_config.keys())
        valid_cfg_keys = set(model_manager._base_cfg_config.keys())
        # Validate and update schedule_config
        if new_schedule_config:
            for key in new_schedule_config:
                if key not in valid_schedule_keys:
                    return jsonify(
                        {
                            "status": "error",
                            "message": f"Unknown schedule_config key: {key}",
                        }
                    ), 400
            for key, value in new_schedule_config.items():
                model_manager._base_schedule_config[key] = _coerce_value(
                    value, model_manager._base_schedule_config[key]
                )
        # Validate and update cfg_config
        if new_cfg_config:
            for key in new_cfg_config:
                if key not in valid_cfg_keys:
                    return jsonify(
                        {"status": "error", "message": f"Unknown cfg_config key: {key}"}
                    ), 400
            for key, value in new_cfg_config.items():
                model_manager._base_cfg_config[key] = _coerce_value(
                    value, model_manager._base_cfg_config[key]
                )
        # Reset with new parameters
        model_manager.reset(
            history_length=history_length,
            smoothing_alpha=smoothing_alpha,
        )
        # Clear active session
        with session_lock:
            active_session_id = None
        with consumption_monitor_lock:
            last_frame_consumed_time = None
        return jsonify({"status": "success"})
    except Exception as e:
        traceback.print_exc()
        return jsonify({"status": "error", "message": str(e)}), 500
@app.route("/api/start", methods=["POST"])
def start_generation():
    """Start generation with given text"""
    try:
        global active_session_id, last_frame_consumed_time
        data = request.json
        session_id = data.get("session_id")
        text = data.get("text", "walk in a circle.")
        history_length = data.get("history_length")
        smoothing_alpha = data.get(
            "smoothing_alpha", None
        )  # Optional smoothing parameter
        force = data.get("force", False)  # Allow force takeover
        if not session_id:
            return jsonify(
                {"status": "error", "message": "session_id is required"}
            ), 400
        print(
            f"[Session {session_id}] Starting generation with text: {text}, history_length: {history_length}, force: {force}"
        )
        # Initialize model if needed
        mm = init_model()
        # Check if another session is already generating
        need_force_takeover = False
        with session_lock:
            if active_session_id and active_session_id != session_id:
                if not force:
                    # Another session is active, return conflict
                    return jsonify(
                        {
                            "status": "error",
                            "message": "Another session is already generating.",
                            "conflict": True,
                            "active_session_id": active_session_id,
                        }
                    ), 409
                else:
                    # Force takeover
                    print(
                        f"[Session {session_id}] Force takeover from session {active_session_id}"
                    )
                    need_force_takeover = True
            if mm.is_generating and active_session_id == session_id:
                return jsonify(
                    {
                        "status": "error",
                        "message": "Generation is already running for this session.",
                    }
                ), 400
            # Set this session as active
            active_session_id = session_id
        # Clear previous session's consumption tracking if force takeover (no nested locks)
        if need_force_takeover:
            with consumption_monitor_lock:
                last_frame_consumed_time = None
        # Reset and start generation
        mm.reset(history_length=history_length, smoothing_alpha=smoothing_alpha)
        mm.start_generation(text, history_length=history_length)
        # Initialize consumption tracking (no nested locks)
        with consumption_monitor_lock:
            last_frame_consumed_time = time.time()
        # Start consumption monitoring
        start_consumption_monitor()
        print(f"[Session {session_id}] Consumption monitoring activated")
        return jsonify(
            {
                "status": "success",
                "message": f"Generation started with text: {text}, history_length: {history_length}",
                "session_id": session_id,
            }
        )
    except Exception as e:
        print(f"Error in start_generation: {e}")
        traceback.print_exc()
        return jsonify({"status": "error", "message": str(e)}), 500
@app.route("/api/update_text", methods=["POST"])
def update_text():
    """Update the generation text"""
    try:
        data = request.json
        session_id = data.get("session_id")
        text = data.get("text", "")
        if not session_id:
            return jsonify(
                {"status": "error", "message": "session_id is required"}
            ), 400
        # Verify this is the active session
        with session_lock:
            if active_session_id != session_id:
                return jsonify(
                    {"status": "error", "message": "Not the active session"}
                ), 403
        if model_manager is None:
            return jsonify({"status": "error", "message": "Model not initialized"}), 400
        model_manager.update_text(text)
        return jsonify({"status": "success", "message": f"Text updated to: {text}"})
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500
@app.route("/api/update_heading", methods=["POST"])
def update_heading():
    """Update heading override for world model mode.
    When heading is set, the character's direction follows the user's input
    while AI controls movement speed and animation. Set to null to return
    to AI-controlled direction.
    """
    try:
        data = request.json
        session_id = data.get("session_id")
        heading = data.get("heading")  # radians, or null for AI control
        if not session_id:
            return jsonify(
                {"status": "error", "message": "session_id is required"}
            ), 400
        # Verify this is the active session
        with session_lock:
            if active_session_id != session_id:
                return jsonify(
                    {"status": "error", "message": "Not the active session"}
                ), 403
        if model_manager is None:
            return jsonify({"status": "error", "message": "Model not initialized"}), 400
        model_manager.set_heading(heading)
        return jsonify({"status": "success"})
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500
@app.route("/api/update_scene_context", methods=["POST"])
def update_scene_context():
    """Update scene context for perception-aware motion generation.
    Client sends environment scan data (wall distances, ground type, NPCs).
    Server uses this to build enhanced text prompts for FloodDiffusion.
    """
    try:
        data = request.json
        session_id = data.get("session_id")
        ctx = data.get("context")  # dict with wall_front, on_stairs, etc.
        if not session_id:
            return jsonify(
                {"status": "error", "message": "session_id is required"}
            ), 400
        # Verify this is the active session
        with session_lock:
            if active_session_id != session_id:
                return jsonify(
                    {"status": "error", "message": "Not the active session"}
                ), 403
        if model_manager is None:
            return jsonify({"status": "error", "message": "Model not initialized"}), 400
        model_manager.set_scene_context(ctx)
        # pass player position to NPC
        if model_manager.npc and ctx:
            px = ctx.get('player_x', 0)
            pz = ctx.get('player_z', 0)
            model_manager.npc.set_target(px, pz)
        return jsonify({"status": "success"})
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500
# ── NPC API ──
@app.route("/api/npc/spawn", methods=["POST"])
def npc_spawn():
    """Spawn NPC."""
    try:
        data = request.json or {}
        npc_type = data.get("type", "man")
        if model_manager is None:
            return jsonify({"status": "error", "message": "Model not initialized"}), 400
        # already spawning - report "busy" instead of spawning again (lock-based check)
        if hasattr(model_manager, '_npc_lock') and model_manager._npc_lock.locked():
            return jsonify({"status": "busy", "message": "NPC loading"})
        model_manager.spawn_npc(npc_type)
        return jsonify({"status": "success", "type": npc_type})
    except Exception as e:
        print(f"[NPC spawn error] {type(e).__name__}: {e}")
        traceback.print_exc()
        return jsonify({"status": "error", "message": f"{type(e).__name__}: {e}"}), 500
@app.route("/api/npc/command", methods=["POST"])
def npc_command():
    """Change NPC behavior."""
    try:
        data = request.json or {}
        behavior = data.get("behavior", "stop")
        if model_manager and model_manager.npc:
            model_manager.npc.set_behavior(behavior)
        return jsonify({"status": "success", "behavior": behavior})
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500
@app.route("/api/npc/despawn", methods=["POST"])
def npc_despawn():
    """Remove NPC."""
    try:
        if model_manager:
            model_manager.despawn_npc()
        return jsonify({"status": "success"})
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500
@app.route("/api/npc/frame", methods=["GET"])
def npc_frame():
    """Return NPC frame + state."""
    try:
        if not model_manager or not model_manager.npc:
            return jsonify({"frames": [], "npc": None})
        count = request.args.get("count", 4, type=int)
        npc = model_manager.npc
        frames = []
        for _ in range(count):
            frame = npc.frame_buffer.get_frame()
            if frame is not None:
                frames.append(frame.tolist())
        state = npc.get_state()
        return jsonify({"frames": frames, "npc": state})
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500
@app.route("/api/pause", methods=["POST"])
def pause_generation():
    """Pause generation (keeps state for resume)"""
    try:
        data = request.json if request.json else {}
        session_id = data.get("session_id")
        if not session_id:
            return jsonify(
                {"status": "error", "message": "session_id is required"}
            ), 400
        # Verify this is the active session
        with session_lock:
            if active_session_id != session_id:
                return jsonify(
                    {"status": "error", "message": "Not the active session"}
                ), 403
        if model_manager:
            model_manager.pause_generation()
        return jsonify({"status": "success", "message": "Generation paused"})
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500
@app.route("/api/resume", methods=["POST"])
def resume_generation():
    """Resume generation from paused state"""
    try:
        global last_frame_consumed_time
        data = request.json if request.json else {}
        session_id = data.get("session_id")
        if not session_id:
            return jsonify(
                {"status": "error", "message": "session_id is required"}
            ), 400
        # Verify this is the active session
        with session_lock:
            if active_session_id != session_id:
                return jsonify(
                    {"status": "error", "message": "Not the active session"}
                ), 403
        if model_manager is None:
            return jsonify({"status": "error", "message": "Model not initialized"}), 400
        model_manager.resume_generation()
        # Reset consumption tracking when resuming
        with consumption_monitor_lock:
            last_frame_consumed_time = time.time()
        return jsonify({"status": "success", "message": "Generation resumed"})
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500
@app.route("/api/reset", methods=["POST"])
def reset():
    """Reset generation state"""
    try:
        global active_session_id, last_frame_consumed_time
        data = request.json if request.json else {}
        session_id = data.get("session_id")
        history_length = data.get("history_length")
        smoothing_alpha = data.get("smoothing_alpha")
        # If session_id provided, verify it's the active session
        if session_id:
            with session_lock:
                if active_session_id and active_session_id != session_id:
                    return jsonify(
                        {"status": "error", "message": "Not the active session"}
                    ), 403
        if model_manager:
            model_manager.reset(
                history_length=history_length, smoothing_alpha=smoothing_alpha
            )
        # Clear the active session
        with session_lock:
            if active_session_id == session_id or not session_id:
                active_session_id = None
        # Clear consumption tracking
        with consumption_monitor_lock:
            last_frame_consumed_time = None
        print(f"[Session {session_id}] Reset complete, session cleared")
        return jsonify(
            {
                "status": "success",
                "message": "Reset complete",
            }
        )
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500
@app.route("/api/get_frame", methods=["GET"])
def get_frame():
    """Get the next frame"""
    try:
        global last_frame_consumed_time
        session_id = request.args.get("session_id")
        if not session_id:
            return jsonify(
                {"status": "error", "message": "session_id is required"}
            ), 400
        if model_manager is None:
            return jsonify({"status": "error", "message": "Model not initialized"}), 400
        count = min(int(request.args.get("count", 8)), 20)
        # Check if this is the active session or a spectator
        with session_lock:
            is_active = active_session_id == session_id
        if is_active:
            # Active session: pop frames from generation buffer
            frames = []
            for _ in range(count):
                joints = model_manager.get_next_frame()
                if joints is None:
                    break
                frames.append(joints.tolist())
            if frames:
                with consumption_monitor_lock:
                    last_frame_consumed_time = time.time()
                return jsonify(
                    {
                        "status": "success",
                        "frames": frames,
                        "buffer_size": model_manager.frame_buffer.size(),
                    }
                )
        else:
            # Spectator: read from broadcast buffer (non-destructive)
            after_id = int(request.args.get("after_id", 0))
            broadcast = model_manager.get_broadcast_frames(after_id, count)
            if broadcast:
                last_id = broadcast[-1][0]
                frames = [joints.tolist() for _, joints in broadcast]
                return jsonify(
                    {
                        "status": "success",
                        "frames": frames,
                        "last_id": last_id,
                        "buffer_size": model_manager.frame_buffer.size(),
                    }
                )
        # No frames available (active or spectator)
        return jsonify(
            {
                "status": "waiting",
                "message": "No frame available yet",
                "buffer_size": model_manager.frame_buffer.size(),
            }
        )
    except Exception as e:
        print(f"Error in get_frame: {e}")
        traceback.print_exc()
        return jsonify({"status": "error", "message": str(e)}), 500
@app.route("/api/status", methods=["GET"])
def get_status():
    """Get generation status"""
    try:
        session_id = request.args.get("session_id")
        with session_lock:
            is_active_session = session_id and active_session_id == session_id
            current_active_session = active_session_id
        if model_manager is None:
            return jsonify(
                {
                    "initialized": False,
                    "buffer_size": 0,
                    "is_generating": False,
                    "is_active_session": is_active_session,
                    "active_session_id": current_active_session,
                }
            )
        status = model_manager.get_buffer_status()
        status["initialized"] = True
        status["is_active_session"] = is_active_session
        status["active_session_id"] = current_active_session
        return jsonify(status)
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Flask server for real-time 3D motion generation"
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="ShandaAI/FloodDiffusionTiny",
        help="HF Hub model name (default: ShandaAI/FloodDiffusionTiny)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=7860,
        help="Port to run the server on (default: 7860)",
    )
    args = parser.parse_args()
    model_name_global = args.model_name
    # Load model eagerly on startup (pre-downloaded in Docker)
    print(f"Loading model: {model_name_global}")
    init_model()
    print("Starting Flask server...")
    app.run(host="0.0.0.0", port=args.port, debug=False, threaded=True)
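# Example launch (assuming this file is saved as app.py):
#   python app.py --model_name ShandaAI/FloodDiffusionTiny --port 7860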