WinstonDeng commited on
Commit
5789f7a
·
verified ·
1 Parent(s): b1e8330

add step-3.7-flash fp8 model libs

Browse files
Files changed (5) hide show
  1. config.json +862 -4
  2. configuration_step3p7.py +219 -0
  3. modeling_step3p7.py +1395 -0
  4. processing_step3.py +464 -0
  5. vision_encoder.py +452 -0
config.json CHANGED
@@ -1,11 +1,12 @@
1
  {
2
  "architectures": [
3
- "MMGPTStepRoboticsForCausalLM"
4
  ],
5
  "auto_map": {
6
- "AutoConfig": "configuration_step_robotics.StepRoboticsConfig"
 
7
  },
8
- "model_type": "step3p5v",
9
  "im_end_token": "<im_end>",
10
  "im_patch_token": "<im_patch>",
11
  "im_start_token": "<im_start>",
@@ -347,6 +348,863 @@
347
  "weight_block_size": [
348
  128,
349
  128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350
  ]
351
  }
352
- }
 
1
  {
2
  "architectures": [
3
+ "Step3p7ForConditionalGeneration"
4
  ],
5
  "auto_map": {
6
+ "AutoConfig": "configuration_step3p7.Step3p7Config",
7
+ "AutoModelForCausalLM": "modeling_step3p7.Step3p7ForConditionalGeneration"
8
  },
9
+ "model_type": "step3p7",
10
  "im_end_token": "<im_end>",
11
  "im_patch_token": "<im_patch>",
12
  "im_start_token": "<im_start>",
 
348
  "weight_block_size": [
349
  128,
350
  128
351
+ ],
352
+ "modules_to_not_convert": [
353
+ "lm_head",
354
+ "model.embed_tokens",
355
+ "model.norm",
356
+ "model.layers.0.self_attn.g_proj",
357
+ "model.layers.0.self_attn.qkv_proj",
358
+ "model.layers.0.self_attn.q_proj",
359
+ "model.layers.0.self_attn.k_proj",
360
+ "model.layers.0.self_attn.v_proj",
361
+ "model.layers.0.self_attn.o_proj",
362
+ "model.layers.0.mlp.gate_up_proj",
363
+ "model.layers.0.mlp.gate_proj",
364
+ "model.layers.0.mlp.up_proj",
365
+ "model.layers.0.mlp.down_proj",
366
+ "model.layers.1.self_attn.g_proj",
367
+ "model.layers.1.self_attn.qkv_proj",
368
+ "model.layers.1.self_attn.q_proj",
369
+ "model.layers.1.self_attn.k_proj",
370
+ "model.layers.1.self_attn.v_proj",
371
+ "model.layers.1.self_attn.o_proj",
372
+ "model.layers.1.mlp.gate_up_proj",
373
+ "model.layers.1.mlp.gate_proj",
374
+ "model.layers.1.mlp.up_proj",
375
+ "model.layers.1.mlp.down_proj",
376
+ "model.layers.2.self_attn.g_proj",
377
+ "model.layers.2.self_attn.qkv_proj",
378
+ "model.layers.2.self_attn.q_proj",
379
+ "model.layers.2.self_attn.k_proj",
380
+ "model.layers.2.self_attn.v_proj",
381
+ "model.layers.2.self_attn.o_proj",
382
+ "model.layers.2.mlp.gate_up_proj",
383
+ "model.layers.2.mlp.gate_proj",
384
+ "model.layers.2.mlp.up_proj",
385
+ "model.layers.2.mlp.down_proj",
386
+ "model.layers.3.self_attn.g_proj",
387
+ "model.layers.3.self_attn.qkv_proj",
388
+ "model.layers.3.self_attn.q_proj",
389
+ "model.layers.3.self_attn.k_proj",
390
+ "model.layers.3.self_attn.v_proj",
391
+ "model.layers.3.self_attn.o_proj",
392
+ "model.layers.3.moe.gate",
393
+ "model.layers.3.share_expert.gate_up_proj",
394
+ "model.layers.3.share_expert.gate_proj",
395
+ "model.layers.3.share_expert.up_proj",
396
+ "model.layers.3.share_expert.down_proj",
397
+ "model.layers.4.self_attn.g_proj",
398
+ "model.layers.4.self_attn.qkv_proj",
399
+ "model.layers.4.self_attn.q_proj",
400
+ "model.layers.4.self_attn.k_proj",
401
+ "model.layers.4.self_attn.v_proj",
402
+ "model.layers.4.self_attn.o_proj",
403
+ "model.layers.4.moe.gate",
404
+ "model.layers.4.share_expert.gate_up_proj",
405
+ "model.layers.4.share_expert.gate_proj",
406
+ "model.layers.4.share_expert.up_proj",
407
+ "model.layers.4.share_expert.down_proj",
408
+ "model.layers.5.self_attn.g_proj",
409
+ "model.layers.5.self_attn.qkv_proj",
410
+ "model.layers.5.self_attn.q_proj",
411
+ "model.layers.5.self_attn.k_proj",
412
+ "model.layers.5.self_attn.v_proj",
413
+ "model.layers.5.self_attn.o_proj",
414
+ "model.layers.5.moe.gate",
415
+ "model.layers.5.share_expert.gate_up_proj",
416
+ "model.layers.5.share_expert.gate_proj",
417
+ "model.layers.5.share_expert.up_proj",
418
+ "model.layers.5.share_expert.down_proj",
419
+ "model.layers.6.self_attn.g_proj",
420
+ "model.layers.6.self_attn.qkv_proj",
421
+ "model.layers.6.self_attn.q_proj",
422
+ "model.layers.6.self_attn.k_proj",
423
+ "model.layers.6.self_attn.v_proj",
424
+ "model.layers.6.self_attn.o_proj",
425
+ "model.layers.6.moe.gate",
426
+ "model.layers.6.share_expert.gate_up_proj",
427
+ "model.layers.6.share_expert.gate_proj",
428
+ "model.layers.6.share_expert.up_proj",
429
+ "model.layers.6.share_expert.down_proj",
430
+ "model.layers.7.self_attn.g_proj",
431
+ "model.layers.7.self_attn.qkv_proj",
432
+ "model.layers.7.self_attn.q_proj",
433
+ "model.layers.7.self_attn.k_proj",
434
+ "model.layers.7.self_attn.v_proj",
435
+ "model.layers.7.self_attn.o_proj",
436
+ "model.layers.7.moe.gate",
437
+ "model.layers.7.share_expert.gate_up_proj",
438
+ "model.layers.7.share_expert.gate_proj",
439
+ "model.layers.7.share_expert.up_proj",
440
+ "model.layers.7.share_expert.down_proj",
441
+ "model.layers.8.self_attn.g_proj",
442
+ "model.layers.8.self_attn.qkv_proj",
443
+ "model.layers.8.self_attn.q_proj",
444
+ "model.layers.8.self_attn.k_proj",
445
+ "model.layers.8.self_attn.v_proj",
446
+ "model.layers.8.self_attn.o_proj",
447
+ "model.layers.8.moe.gate",
448
+ "model.layers.8.share_expert.gate_up_proj",
449
+ "model.layers.8.share_expert.gate_proj",
450
+ "model.layers.8.share_expert.up_proj",
451
+ "model.layers.8.share_expert.down_proj",
452
+ "model.layers.9.self_attn.g_proj",
453
+ "model.layers.9.self_attn.qkv_proj",
454
+ "model.layers.9.self_attn.q_proj",
455
+ "model.layers.9.self_attn.k_proj",
456
+ "model.layers.9.self_attn.v_proj",
457
+ "model.layers.9.self_attn.o_proj",
458
+ "model.layers.9.moe.gate",
459
+ "model.layers.9.share_expert.gate_up_proj",
460
+ "model.layers.9.share_expert.gate_proj",
461
+ "model.layers.9.share_expert.up_proj",
462
+ "model.layers.9.share_expert.down_proj",
463
+ "model.layers.10.self_attn.g_proj",
464
+ "model.layers.10.self_attn.qkv_proj",
465
+ "model.layers.10.self_attn.q_proj",
466
+ "model.layers.10.self_attn.k_proj",
467
+ "model.layers.10.self_attn.v_proj",
468
+ "model.layers.10.self_attn.o_proj",
469
+ "model.layers.10.moe.gate",
470
+ "model.layers.10.share_expert.gate_up_proj",
471
+ "model.layers.10.share_expert.gate_proj",
472
+ "model.layers.10.share_expert.up_proj",
473
+ "model.layers.10.share_expert.down_proj",
474
+ "model.layers.11.self_attn.g_proj",
475
+ "model.layers.11.self_attn.qkv_proj",
476
+ "model.layers.11.self_attn.q_proj",
477
+ "model.layers.11.self_attn.k_proj",
478
+ "model.layers.11.self_attn.v_proj",
479
+ "model.layers.11.self_attn.o_proj",
480
+ "model.layers.11.moe.gate",
481
+ "model.layers.11.share_expert.gate_up_proj",
482
+ "model.layers.11.share_expert.gate_proj",
483
+ "model.layers.11.share_expert.up_proj",
484
+ "model.layers.11.share_expert.down_proj",
485
+ "model.layers.12.self_attn.g_proj",
486
+ "model.layers.12.self_attn.qkv_proj",
487
+ "model.layers.12.self_attn.q_proj",
488
+ "model.layers.12.self_attn.k_proj",
489
+ "model.layers.12.self_attn.v_proj",
490
+ "model.layers.12.self_attn.o_proj",
491
+ "model.layers.12.moe.gate",
492
+ "model.layers.12.share_expert.gate_up_proj",
493
+ "model.layers.12.share_expert.gate_proj",
494
+ "model.layers.12.share_expert.up_proj",
495
+ "model.layers.12.share_expert.down_proj",
496
+ "model.layers.13.self_attn.g_proj",
497
+ "model.layers.13.self_attn.qkv_proj",
498
+ "model.layers.13.self_attn.q_proj",
499
+ "model.layers.13.self_attn.k_proj",
500
+ "model.layers.13.self_attn.v_proj",
501
+ "model.layers.13.self_attn.o_proj",
502
+ "model.layers.13.moe.gate",
503
+ "model.layers.13.share_expert.gate_up_proj",
504
+ "model.layers.13.share_expert.gate_proj",
505
+ "model.layers.13.share_expert.up_proj",
506
+ "model.layers.13.share_expert.down_proj",
507
+ "model.layers.14.self_attn.g_proj",
508
+ "model.layers.14.self_attn.qkv_proj",
509
+ "model.layers.14.self_attn.q_proj",
510
+ "model.layers.14.self_attn.k_proj",
511
+ "model.layers.14.self_attn.v_proj",
512
+ "model.layers.14.self_attn.o_proj",
513
+ "model.layers.14.moe.gate",
514
+ "model.layers.14.share_expert.gate_up_proj",
515
+ "model.layers.14.share_expert.gate_proj",
516
+ "model.layers.14.share_expert.up_proj",
517
+ "model.layers.14.share_expert.down_proj",
518
+ "model.layers.15.self_attn.g_proj",
519
+ "model.layers.15.self_attn.qkv_proj",
520
+ "model.layers.15.self_attn.q_proj",
521
+ "model.layers.15.self_attn.k_proj",
522
+ "model.layers.15.self_attn.v_proj",
523
+ "model.layers.15.self_attn.o_proj",
524
+ "model.layers.15.moe.gate",
525
+ "model.layers.15.share_expert.gate_up_proj",
526
+ "model.layers.15.share_expert.gate_proj",
527
+ "model.layers.15.share_expert.up_proj",
528
+ "model.layers.15.share_expert.down_proj",
529
+ "model.layers.16.self_attn.g_proj",
530
+ "model.layers.16.self_attn.qkv_proj",
531
+ "model.layers.16.self_attn.q_proj",
532
+ "model.layers.16.self_attn.k_proj",
533
+ "model.layers.16.self_attn.v_proj",
534
+ "model.layers.16.self_attn.o_proj",
535
+ "model.layers.16.moe.gate",
536
+ "model.layers.16.share_expert.gate_up_proj",
537
+ "model.layers.16.share_expert.gate_proj",
538
+ "model.layers.16.share_expert.up_proj",
539
+ "model.layers.16.share_expert.down_proj",
540
+ "model.layers.17.self_attn.g_proj",
541
+ "model.layers.17.self_attn.qkv_proj",
542
+ "model.layers.17.self_attn.q_proj",
543
+ "model.layers.17.self_attn.k_proj",
544
+ "model.layers.17.self_attn.v_proj",
545
+ "model.layers.17.self_attn.o_proj",
546
+ "model.layers.17.moe.gate",
547
+ "model.layers.17.share_expert.gate_up_proj",
548
+ "model.layers.17.share_expert.gate_proj",
549
+ "model.layers.17.share_expert.up_proj",
550
+ "model.layers.17.share_expert.down_proj",
551
+ "model.layers.18.self_attn.g_proj",
552
+ "model.layers.18.self_attn.qkv_proj",
553
+ "model.layers.18.self_attn.q_proj",
554
+ "model.layers.18.self_attn.k_proj",
555
+ "model.layers.18.self_attn.v_proj",
556
+ "model.layers.18.self_attn.o_proj",
557
+ "model.layers.18.moe.gate",
558
+ "model.layers.18.share_expert.gate_up_proj",
559
+ "model.layers.18.share_expert.gate_proj",
560
+ "model.layers.18.share_expert.up_proj",
561
+ "model.layers.18.share_expert.down_proj",
562
+ "model.layers.19.self_attn.g_proj",
563
+ "model.layers.19.self_attn.qkv_proj",
564
+ "model.layers.19.self_attn.q_proj",
565
+ "model.layers.19.self_attn.k_proj",
566
+ "model.layers.19.self_attn.v_proj",
567
+ "model.layers.19.self_attn.o_proj",
568
+ "model.layers.19.moe.gate",
569
+ "model.layers.19.share_expert.gate_up_proj",
570
+ "model.layers.19.share_expert.gate_proj",
571
+ "model.layers.19.share_expert.up_proj",
572
+ "model.layers.19.share_expert.down_proj",
573
+ "model.layers.20.self_attn.g_proj",
574
+ "model.layers.20.self_attn.qkv_proj",
575
+ "model.layers.20.self_attn.q_proj",
576
+ "model.layers.20.self_attn.k_proj",
577
+ "model.layers.20.self_attn.v_proj",
578
+ "model.layers.20.self_attn.o_proj",
579
+ "model.layers.20.moe.gate",
580
+ "model.layers.20.share_expert.gate_up_proj",
581
+ "model.layers.20.share_expert.gate_proj",
582
+ "model.layers.20.share_expert.up_proj",
583
+ "model.layers.20.share_expert.down_proj",
584
+ "model.layers.21.self_attn.g_proj",
585
+ "model.layers.21.self_attn.qkv_proj",
586
+ "model.layers.21.self_attn.q_proj",
587
+ "model.layers.21.self_attn.k_proj",
588
+ "model.layers.21.self_attn.v_proj",
589
+ "model.layers.21.self_attn.o_proj",
590
+ "model.layers.21.moe.gate",
591
+ "model.layers.21.share_expert.gate_up_proj",
592
+ "model.layers.21.share_expert.gate_proj",
593
+ "model.layers.21.share_expert.up_proj",
594
+ "model.layers.21.share_expert.down_proj",
595
+ "model.layers.22.self_attn.g_proj",
596
+ "model.layers.22.self_attn.qkv_proj",
597
+ "model.layers.22.self_attn.q_proj",
598
+ "model.layers.22.self_attn.k_proj",
599
+ "model.layers.22.self_attn.v_proj",
600
+ "model.layers.22.self_attn.o_proj",
601
+ "model.layers.22.moe.gate",
602
+ "model.layers.22.share_expert.gate_up_proj",
603
+ "model.layers.22.share_expert.gate_proj",
604
+ "model.layers.22.share_expert.up_proj",
605
+ "model.layers.22.share_expert.down_proj",
606
+ "model.layers.23.self_attn.g_proj",
607
+ "model.layers.23.self_attn.qkv_proj",
608
+ "model.layers.23.self_attn.q_proj",
609
+ "model.layers.23.self_attn.k_proj",
610
+ "model.layers.23.self_attn.v_proj",
611
+ "model.layers.23.self_attn.o_proj",
612
+ "model.layers.23.moe.gate",
613
+ "model.layers.23.share_expert.gate_up_proj",
614
+ "model.layers.23.share_expert.gate_proj",
615
+ "model.layers.23.share_expert.up_proj",
616
+ "model.layers.23.share_expert.down_proj",
617
+ "model.layers.24.self_attn.g_proj",
618
+ "model.layers.24.self_attn.qkv_proj",
619
+ "model.layers.24.self_attn.q_proj",
620
+ "model.layers.24.self_attn.k_proj",
621
+ "model.layers.24.self_attn.v_proj",
622
+ "model.layers.24.self_attn.o_proj",
623
+ "model.layers.24.moe.gate",
624
+ "model.layers.24.share_expert.gate_up_proj",
625
+ "model.layers.24.share_expert.gate_proj",
626
+ "model.layers.24.share_expert.up_proj",
627
+ "model.layers.24.share_expert.down_proj",
628
+ "model.layers.25.self_attn.g_proj",
629
+ "model.layers.25.self_attn.qkv_proj",
630
+ "model.layers.25.self_attn.q_proj",
631
+ "model.layers.25.self_attn.k_proj",
632
+ "model.layers.25.self_attn.v_proj",
633
+ "model.layers.25.self_attn.o_proj",
634
+ "model.layers.25.moe.gate",
635
+ "model.layers.25.share_expert.gate_up_proj",
636
+ "model.layers.25.share_expert.gate_proj",
637
+ "model.layers.25.share_expert.up_proj",
638
+ "model.layers.25.share_expert.down_proj",
639
+ "model.layers.26.self_attn.g_proj",
640
+ "model.layers.26.self_attn.qkv_proj",
641
+ "model.layers.26.self_attn.q_proj",
642
+ "model.layers.26.self_attn.k_proj",
643
+ "model.layers.26.self_attn.v_proj",
644
+ "model.layers.26.self_attn.o_proj",
645
+ "model.layers.26.moe.gate",
646
+ "model.layers.26.share_expert.gate_up_proj",
647
+ "model.layers.26.share_expert.gate_proj",
648
+ "model.layers.26.share_expert.up_proj",
649
+ "model.layers.26.share_expert.down_proj",
650
+ "model.layers.27.self_attn.g_proj",
651
+ "model.layers.27.self_attn.qkv_proj",
652
+ "model.layers.27.self_attn.q_proj",
653
+ "model.layers.27.self_attn.k_proj",
654
+ "model.layers.27.self_attn.v_proj",
655
+ "model.layers.27.self_attn.o_proj",
656
+ "model.layers.27.moe.gate",
657
+ "model.layers.27.share_expert.gate_up_proj",
658
+ "model.layers.27.share_expert.gate_proj",
659
+ "model.layers.27.share_expert.up_proj",
660
+ "model.layers.27.share_expert.down_proj",
661
+ "model.layers.28.self_attn.g_proj",
662
+ "model.layers.28.self_attn.qkv_proj",
663
+ "model.layers.28.self_attn.q_proj",
664
+ "model.layers.28.self_attn.k_proj",
665
+ "model.layers.28.self_attn.v_proj",
666
+ "model.layers.28.self_attn.o_proj",
667
+ "model.layers.28.moe.gate",
668
+ "model.layers.28.share_expert.gate_up_proj",
669
+ "model.layers.28.share_expert.gate_proj",
670
+ "model.layers.28.share_expert.up_proj",
671
+ "model.layers.28.share_expert.down_proj",
672
+ "model.layers.29.self_attn.g_proj",
673
+ "model.layers.29.self_attn.qkv_proj",
674
+ "model.layers.29.self_attn.q_proj",
675
+ "model.layers.29.self_attn.k_proj",
676
+ "model.layers.29.self_attn.v_proj",
677
+ "model.layers.29.self_attn.o_proj",
678
+ "model.layers.29.moe.gate",
679
+ "model.layers.29.share_expert.gate_up_proj",
680
+ "model.layers.29.share_expert.gate_proj",
681
+ "model.layers.29.share_expert.up_proj",
682
+ "model.layers.29.share_expert.down_proj",
683
+ "model.layers.30.self_attn.g_proj",
684
+ "model.layers.30.self_attn.qkv_proj",
685
+ "model.layers.30.self_attn.q_proj",
686
+ "model.layers.30.self_attn.k_proj",
687
+ "model.layers.30.self_attn.v_proj",
688
+ "model.layers.30.self_attn.o_proj",
689
+ "model.layers.30.moe.gate",
690
+ "model.layers.30.share_expert.gate_up_proj",
691
+ "model.layers.30.share_expert.gate_proj",
692
+ "model.layers.30.share_expert.up_proj",
693
+ "model.layers.30.share_expert.down_proj",
694
+ "model.layers.31.self_attn.g_proj",
695
+ "model.layers.31.self_attn.qkv_proj",
696
+ "model.layers.31.self_attn.q_proj",
697
+ "model.layers.31.self_attn.k_proj",
698
+ "model.layers.31.self_attn.v_proj",
699
+ "model.layers.31.self_attn.o_proj",
700
+ "model.layers.31.moe.gate",
701
+ "model.layers.31.share_expert.gate_up_proj",
702
+ "model.layers.31.share_expert.gate_proj",
703
+ "model.layers.31.share_expert.up_proj",
704
+ "model.layers.31.share_expert.down_proj",
705
+ "model.layers.32.self_attn.g_proj",
706
+ "model.layers.32.self_attn.qkv_proj",
707
+ "model.layers.32.self_attn.q_proj",
708
+ "model.layers.32.self_attn.k_proj",
709
+ "model.layers.32.self_attn.v_proj",
710
+ "model.layers.32.self_attn.o_proj",
711
+ "model.layers.32.moe.gate",
712
+ "model.layers.32.share_expert.gate_up_proj",
713
+ "model.layers.32.share_expert.gate_proj",
714
+ "model.layers.32.share_expert.up_proj",
715
+ "model.layers.32.share_expert.down_proj",
716
+ "model.layers.33.self_attn.g_proj",
717
+ "model.layers.33.self_attn.qkv_proj",
718
+ "model.layers.33.self_attn.q_proj",
719
+ "model.layers.33.self_attn.k_proj",
720
+ "model.layers.33.self_attn.v_proj",
721
+ "model.layers.33.self_attn.o_proj",
722
+ "model.layers.33.moe.gate",
723
+ "model.layers.33.share_expert.gate_up_proj",
724
+ "model.layers.33.share_expert.gate_proj",
725
+ "model.layers.33.share_expert.up_proj",
726
+ "model.layers.33.share_expert.down_proj",
727
+ "model.layers.34.self_attn.g_proj",
728
+ "model.layers.34.self_attn.qkv_proj",
729
+ "model.layers.34.self_attn.q_proj",
730
+ "model.layers.34.self_attn.k_proj",
731
+ "model.layers.34.self_attn.v_proj",
732
+ "model.layers.34.self_attn.o_proj",
733
+ "model.layers.34.moe.gate",
734
+ "model.layers.34.share_expert.gate_up_proj",
735
+ "model.layers.34.share_expert.gate_proj",
736
+ "model.layers.34.share_expert.up_proj",
737
+ "model.layers.34.share_expert.down_proj",
738
+ "model.layers.35.self_attn.g_proj",
739
+ "model.layers.35.self_attn.qkv_proj",
740
+ "model.layers.35.self_attn.q_proj",
741
+ "model.layers.35.self_attn.k_proj",
742
+ "model.layers.35.self_attn.v_proj",
743
+ "model.layers.35.self_attn.o_proj",
744
+ "model.layers.35.moe.gate",
745
+ "model.layers.35.share_expert.gate_up_proj",
746
+ "model.layers.35.share_expert.gate_proj",
747
+ "model.layers.35.share_expert.up_proj",
748
+ "model.layers.35.share_expert.down_proj",
749
+ "model.layers.36.self_attn.g_proj",
750
+ "model.layers.36.self_attn.qkv_proj",
751
+ "model.layers.36.self_attn.q_proj",
752
+ "model.layers.36.self_attn.k_proj",
753
+ "model.layers.36.self_attn.v_proj",
754
+ "model.layers.36.self_attn.o_proj",
755
+ "model.layers.36.moe.gate",
756
+ "model.layers.36.share_expert.gate_up_proj",
757
+ "model.layers.36.share_expert.gate_proj",
758
+ "model.layers.36.share_expert.up_proj",
759
+ "model.layers.36.share_expert.down_proj",
760
+ "model.layers.37.self_attn.g_proj",
761
+ "model.layers.37.self_attn.qkv_proj",
762
+ "model.layers.37.self_attn.q_proj",
763
+ "model.layers.37.self_attn.k_proj",
764
+ "model.layers.37.self_attn.v_proj",
765
+ "model.layers.37.self_attn.o_proj",
766
+ "model.layers.37.moe.gate",
767
+ "model.layers.37.share_expert.gate_up_proj",
768
+ "model.layers.37.share_expert.gate_proj",
769
+ "model.layers.37.share_expert.up_proj",
770
+ "model.layers.37.share_expert.down_proj",
771
+ "model.layers.38.self_attn.g_proj",
772
+ "model.layers.38.self_attn.qkv_proj",
773
+ "model.layers.38.self_attn.q_proj",
774
+ "model.layers.38.self_attn.k_proj",
775
+ "model.layers.38.self_attn.v_proj",
776
+ "model.layers.38.self_attn.o_proj",
777
+ "model.layers.38.moe.gate",
778
+ "model.layers.38.share_expert.gate_up_proj",
779
+ "model.layers.38.share_expert.gate_proj",
780
+ "model.layers.38.share_expert.up_proj",
781
+ "model.layers.38.share_expert.down_proj",
782
+ "model.layers.39.self_attn.g_proj",
783
+ "model.layers.39.self_attn.qkv_proj",
784
+ "model.layers.39.self_attn.q_proj",
785
+ "model.layers.39.self_attn.k_proj",
786
+ "model.layers.39.self_attn.v_proj",
787
+ "model.layers.39.self_attn.o_proj",
788
+ "model.layers.39.moe.gate",
789
+ "model.layers.39.share_expert.gate_up_proj",
790
+ "model.layers.39.share_expert.gate_proj",
791
+ "model.layers.39.share_expert.up_proj",
792
+ "model.layers.39.share_expert.down_proj",
793
+ "model.layers.40.self_attn.g_proj",
794
+ "model.layers.40.self_attn.qkv_proj",
795
+ "model.layers.40.self_attn.q_proj",
796
+ "model.layers.40.self_attn.k_proj",
797
+ "model.layers.40.self_attn.v_proj",
798
+ "model.layers.40.self_attn.o_proj",
799
+ "model.layers.40.moe.gate",
800
+ "model.layers.40.share_expert.gate_up_proj",
801
+ "model.layers.40.share_expert.gate_proj",
802
+ "model.layers.40.share_expert.up_proj",
803
+ "model.layers.40.share_expert.down_proj",
804
+ "model.layers.41.self_attn.g_proj",
805
+ "model.layers.41.self_attn.qkv_proj",
806
+ "model.layers.41.self_attn.q_proj",
807
+ "model.layers.41.self_attn.k_proj",
808
+ "model.layers.41.self_attn.v_proj",
809
+ "model.layers.41.self_attn.o_proj",
810
+ "model.layers.41.moe.gate",
811
+ "model.layers.41.share_expert.gate_up_proj",
812
+ "model.layers.41.share_expert.gate_proj",
813
+ "model.layers.41.share_expert.up_proj",
814
+ "model.layers.41.share_expert.down_proj",
815
+ "model.layers.42.self_attn.g_proj",
816
+ "model.layers.42.self_attn.qkv_proj",
817
+ "model.layers.42.self_attn.q_proj",
818
+ "model.layers.42.self_attn.k_proj",
819
+ "model.layers.42.self_attn.v_proj",
820
+ "model.layers.42.self_attn.o_proj",
821
+ "model.layers.42.moe.gate",
822
+ "model.layers.42.share_expert.gate_up_proj",
823
+ "model.layers.42.share_expert.gate_proj",
824
+ "model.layers.42.share_expert.up_proj",
825
+ "model.layers.42.share_expert.down_proj",
826
+ "model.layers.43.self_attn.g_proj",
827
+ "model.layers.43.self_attn.qkv_proj",
828
+ "model.layers.43.self_attn.q_proj",
829
+ "model.layers.43.self_attn.k_proj",
830
+ "model.layers.43.self_attn.v_proj",
831
+ "model.layers.43.self_attn.o_proj",
832
+ "model.layers.43.moe.gate",
833
+ "model.layers.43.share_expert.gate_up_proj",
834
+ "model.layers.43.share_expert.gate_proj",
835
+ "model.layers.43.share_expert.up_proj",
836
+ "model.layers.43.share_expert.down_proj",
837
+ "model.layers.44.self_attn.g_proj",
838
+ "model.layers.44.self_attn.qkv_proj",
839
+ "model.layers.44.self_attn.q_proj",
840
+ "model.layers.44.self_attn.k_proj",
841
+ "model.layers.44.self_attn.v_proj",
842
+ "model.layers.44.self_attn.o_proj",
843
+ "model.layers.44.moe.gate",
844
+ "model.layers.44.share_expert.gate_up_proj",
845
+ "model.layers.44.share_expert.gate_proj",
846
+ "model.layers.44.share_expert.up_proj",
847
+ "model.layers.44.share_expert.down_proj",
848
+ "model.layers.45.mtp_block.self_attn.g_proj",
849
+ "model.layers.45.mtp_block.self_attn.qkv_proj",
850
+ "model.layers.45.mtp_block.self_attn.q_proj",
851
+ "model.layers.45.mtp_block.self_attn.k_proj",
852
+ "model.layers.45.mtp_block.self_attn.v_proj",
853
+ "model.layers.45.mtp_block.self_attn.o_proj",
854
+ "model.layers.45.mtp_block.mlp.gate_up_proj",
855
+ "model.layers.45.mtp_block.mlp.gate_proj",
856
+ "model.layers.45.mtp_block.mlp.up_proj",
857
+ "model.layers.45.mtp_block.mlp.down_proj",
858
+ "model.layers.46.mtp_block.self_attn.g_proj",
859
+ "model.layers.46.mtp_block.self_attn.qkv_proj",
860
+ "model.layers.46.mtp_block.self_attn.q_proj",
861
+ "model.layers.46.mtp_block.self_attn.k_proj",
862
+ "model.layers.46.mtp_block.self_attn.v_proj",
863
+ "model.layers.46.mtp_block.self_attn.o_proj",
864
+ "model.layers.46.mtp_block.mlp.gate_up_proj",
865
+ "model.layers.46.mtp_block.mlp.gate_proj",
866
+ "model.layers.46.mtp_block.mlp.up_proj",
867
+ "model.layers.46.mtp_block.mlp.down_proj",
868
+ "model.layers.47.mtp_block.self_attn.g_proj",
869
+ "model.layers.47.mtp_block.self_attn.qkv_proj",
870
+ "model.layers.47.mtp_block.self_attn.q_proj",
871
+ "model.layers.47.mtp_block.self_attn.k_proj",
872
+ "model.layers.47.mtp_block.self_attn.v_proj",
873
+ "model.layers.47.mtp_block.self_attn.o_proj",
874
+ "model.layers.47.mtp_block.mlp.gate_up_proj",
875
+ "model.layers.47.mtp_block.mlp.gate_proj",
876
+ "model.layers.47.mtp_block.mlp.up_proj",
877
+ "model.layers.47.mtp_block.mlp.down_proj",
878
+ "vision_model.transformer.resblocks.0.attn.qkv_proj",
879
+ "vision_model.transformer.resblocks.0.attn.q_proj",
880
+ "vision_model.transformer.resblocks.0.attn.k_proj",
881
+ "vision_model.transformer.resblocks.0.attn.v_proj",
882
+ "vision_model.transformer.resblocks.0.attn.out_proj",
883
+ "vision_model.transformer.resblocks.0.mlp.fc1",
884
+ "vision_model.transformer.resblocks.0.mlp.fc2",
885
+ "vision_model.transformer.resblocks.1.attn.qkv_proj",
886
+ "vision_model.transformer.resblocks.1.attn.q_proj",
887
+ "vision_model.transformer.resblocks.1.attn.k_proj",
888
+ "vision_model.transformer.resblocks.1.attn.v_proj",
889
+ "vision_model.transformer.resblocks.1.attn.out_proj",
890
+ "vision_model.transformer.resblocks.1.mlp.fc1",
891
+ "vision_model.transformer.resblocks.1.mlp.fc2",
892
+ "vision_model.transformer.resblocks.2.attn.qkv_proj",
893
+ "vision_model.transformer.resblocks.2.attn.q_proj",
894
+ "vision_model.transformer.resblocks.2.attn.k_proj",
895
+ "vision_model.transformer.resblocks.2.attn.v_proj",
896
+ "vision_model.transformer.resblocks.2.attn.out_proj",
897
+ "vision_model.transformer.resblocks.2.mlp.fc1",
898
+ "vision_model.transformer.resblocks.2.mlp.fc2",
899
+ "vision_model.transformer.resblocks.3.attn.qkv_proj",
900
+ "vision_model.transformer.resblocks.3.attn.q_proj",
901
+ "vision_model.transformer.resblocks.3.attn.k_proj",
902
+ "vision_model.transformer.resblocks.3.attn.v_proj",
903
+ "vision_model.transformer.resblocks.3.attn.out_proj",
904
+ "vision_model.transformer.resblocks.3.mlp.fc1",
905
+ "vision_model.transformer.resblocks.3.mlp.fc2",
906
+ "vision_model.transformer.resblocks.4.attn.qkv_proj",
907
+ "vision_model.transformer.resblocks.4.attn.q_proj",
908
+ "vision_model.transformer.resblocks.4.attn.k_proj",
909
+ "vision_model.transformer.resblocks.4.attn.v_proj",
910
+ "vision_model.transformer.resblocks.4.attn.out_proj",
911
+ "vision_model.transformer.resblocks.4.mlp.fc1",
912
+ "vision_model.transformer.resblocks.4.mlp.fc2",
913
+ "vision_model.transformer.resblocks.5.attn.qkv_proj",
914
+ "vision_model.transformer.resblocks.5.attn.q_proj",
915
+ "vision_model.transformer.resblocks.5.attn.k_proj",
916
+ "vision_model.transformer.resblocks.5.attn.v_proj",
917
+ "vision_model.transformer.resblocks.5.attn.out_proj",
918
+ "vision_model.transformer.resblocks.5.mlp.fc1",
919
+ "vision_model.transformer.resblocks.5.mlp.fc2",
920
+ "vision_model.transformer.resblocks.6.attn.qkv_proj",
921
+ "vision_model.transformer.resblocks.6.attn.q_proj",
922
+ "vision_model.transformer.resblocks.6.attn.k_proj",
923
+ "vision_model.transformer.resblocks.6.attn.v_proj",
924
+ "vision_model.transformer.resblocks.6.attn.out_proj",
925
+ "vision_model.transformer.resblocks.6.mlp.fc1",
926
+ "vision_model.transformer.resblocks.6.mlp.fc2",
927
+ "vision_model.transformer.resblocks.7.attn.qkv_proj",
928
+ "vision_model.transformer.resblocks.7.attn.q_proj",
929
+ "vision_model.transformer.resblocks.7.attn.k_proj",
930
+ "vision_model.transformer.resblocks.7.attn.v_proj",
931
+ "vision_model.transformer.resblocks.7.attn.out_proj",
932
+ "vision_model.transformer.resblocks.7.mlp.fc1",
933
+ "vision_model.transformer.resblocks.7.mlp.fc2",
934
+ "vision_model.transformer.resblocks.8.attn.qkv_proj",
935
+ "vision_model.transformer.resblocks.8.attn.q_proj",
936
+ "vision_model.transformer.resblocks.8.attn.k_proj",
937
+ "vision_model.transformer.resblocks.8.attn.v_proj",
938
+ "vision_model.transformer.resblocks.8.attn.out_proj",
939
+ "vision_model.transformer.resblocks.8.mlp.fc1",
940
+ "vision_model.transformer.resblocks.8.mlp.fc2",
941
+ "vision_model.transformer.resblocks.9.attn.qkv_proj",
942
+ "vision_model.transformer.resblocks.9.attn.q_proj",
943
+ "vision_model.transformer.resblocks.9.attn.k_proj",
944
+ "vision_model.transformer.resblocks.9.attn.v_proj",
945
+ "vision_model.transformer.resblocks.9.attn.out_proj",
946
+ "vision_model.transformer.resblocks.9.mlp.fc1",
947
+ "vision_model.transformer.resblocks.9.mlp.fc2",
948
+ "vision_model.transformer.resblocks.10.attn.qkv_proj",
949
+ "vision_model.transformer.resblocks.10.attn.q_proj",
950
+ "vision_model.transformer.resblocks.10.attn.k_proj",
951
+ "vision_model.transformer.resblocks.10.attn.v_proj",
952
+ "vision_model.transformer.resblocks.10.attn.out_proj",
953
+ "vision_model.transformer.resblocks.10.mlp.fc1",
954
+ "vision_model.transformer.resblocks.10.mlp.fc2",
955
+ "vision_model.transformer.resblocks.11.attn.qkv_proj",
956
+ "vision_model.transformer.resblocks.11.attn.q_proj",
957
+ "vision_model.transformer.resblocks.11.attn.k_proj",
958
+ "vision_model.transformer.resblocks.11.attn.v_proj",
959
+ "vision_model.transformer.resblocks.11.attn.out_proj",
960
+ "vision_model.transformer.resblocks.11.mlp.fc1",
961
+ "vision_model.transformer.resblocks.11.mlp.fc2",
962
+ "vision_model.transformer.resblocks.12.attn.qkv_proj",
963
+ "vision_model.transformer.resblocks.12.attn.q_proj",
964
+ "vision_model.transformer.resblocks.12.attn.k_proj",
965
+ "vision_model.transformer.resblocks.12.attn.v_proj",
966
+ "vision_model.transformer.resblocks.12.attn.out_proj",
967
+ "vision_model.transformer.resblocks.12.mlp.fc1",
968
+ "vision_model.transformer.resblocks.12.mlp.fc2",
969
+ "vision_model.transformer.resblocks.13.attn.qkv_proj",
970
+ "vision_model.transformer.resblocks.13.attn.q_proj",
971
+ "vision_model.transformer.resblocks.13.attn.k_proj",
972
+ "vision_model.transformer.resblocks.13.attn.v_proj",
973
+ "vision_model.transformer.resblocks.13.attn.out_proj",
974
+ "vision_model.transformer.resblocks.13.mlp.fc1",
975
+ "vision_model.transformer.resblocks.13.mlp.fc2",
976
+ "vision_model.transformer.resblocks.14.attn.qkv_proj",
977
+ "vision_model.transformer.resblocks.14.attn.q_proj",
978
+ "vision_model.transformer.resblocks.14.attn.k_proj",
979
+ "vision_model.transformer.resblocks.14.attn.v_proj",
980
+ "vision_model.transformer.resblocks.14.attn.out_proj",
981
+ "vision_model.transformer.resblocks.14.mlp.fc1",
982
+ "vision_model.transformer.resblocks.14.mlp.fc2",
983
+ "vision_model.transformer.resblocks.15.attn.qkv_proj",
984
+ "vision_model.transformer.resblocks.15.attn.q_proj",
985
+ "vision_model.transformer.resblocks.15.attn.k_proj",
986
+ "vision_model.transformer.resblocks.15.attn.v_proj",
987
+ "vision_model.transformer.resblocks.15.attn.out_proj",
988
+ "vision_model.transformer.resblocks.15.mlp.fc1",
989
+ "vision_model.transformer.resblocks.15.mlp.fc2",
990
+ "vision_model.transformer.resblocks.16.attn.qkv_proj",
991
+ "vision_model.transformer.resblocks.16.attn.q_proj",
992
+ "vision_model.transformer.resblocks.16.attn.k_proj",
993
+ "vision_model.transformer.resblocks.16.attn.v_proj",
994
+ "vision_model.transformer.resblocks.16.attn.out_proj",
995
+ "vision_model.transformer.resblocks.16.mlp.fc1",
996
+ "vision_model.transformer.resblocks.16.mlp.fc2",
997
+ "vision_model.transformer.resblocks.17.attn.qkv_proj",
998
+ "vision_model.transformer.resblocks.17.attn.q_proj",
999
+ "vision_model.transformer.resblocks.17.attn.k_proj",
1000
+ "vision_model.transformer.resblocks.17.attn.v_proj",
1001
+ "vision_model.transformer.resblocks.17.attn.out_proj",
1002
+ "vision_model.transformer.resblocks.17.mlp.fc1",
1003
+ "vision_model.transformer.resblocks.17.mlp.fc2",
1004
+ "vision_model.transformer.resblocks.18.attn.qkv_proj",
1005
+ "vision_model.transformer.resblocks.18.attn.q_proj",
1006
+ "vision_model.transformer.resblocks.18.attn.k_proj",
1007
+ "vision_model.transformer.resblocks.18.attn.v_proj",
1008
+ "vision_model.transformer.resblocks.18.attn.out_proj",
1009
+ "vision_model.transformer.resblocks.18.mlp.fc1",
1010
+ "vision_model.transformer.resblocks.18.mlp.fc2",
1011
+ "vision_model.transformer.resblocks.19.attn.qkv_proj",
1012
+ "vision_model.transformer.resblocks.19.attn.q_proj",
1013
+ "vision_model.transformer.resblocks.19.attn.k_proj",
1014
+ "vision_model.transformer.resblocks.19.attn.v_proj",
1015
+ "vision_model.transformer.resblocks.19.attn.out_proj",
1016
+ "vision_model.transformer.resblocks.19.mlp.fc1",
1017
+ "vision_model.transformer.resblocks.19.mlp.fc2",
1018
+ "vision_model.transformer.resblocks.20.attn.qkv_proj",
1019
+ "vision_model.transformer.resblocks.20.attn.q_proj",
1020
+ "vision_model.transformer.resblocks.20.attn.k_proj",
1021
+ "vision_model.transformer.resblocks.20.attn.v_proj",
1022
+ "vision_model.transformer.resblocks.20.attn.out_proj",
1023
+ "vision_model.transformer.resblocks.20.mlp.fc1",
1024
+ "vision_model.transformer.resblocks.20.mlp.fc2",
1025
+ "vision_model.transformer.resblocks.21.attn.qkv_proj",
1026
+ "vision_model.transformer.resblocks.21.attn.q_proj",
1027
+ "vision_model.transformer.resblocks.21.attn.k_proj",
1028
+ "vision_model.transformer.resblocks.21.attn.v_proj",
1029
+ "vision_model.transformer.resblocks.21.attn.out_proj",
1030
+ "vision_model.transformer.resblocks.21.mlp.fc1",
1031
+ "vision_model.transformer.resblocks.21.mlp.fc2",
1032
+ "vision_model.transformer.resblocks.22.attn.qkv_proj",
1033
+ "vision_model.transformer.resblocks.22.attn.q_proj",
1034
+ "vision_model.transformer.resblocks.22.attn.k_proj",
1035
+ "vision_model.transformer.resblocks.22.attn.v_proj",
1036
+ "vision_model.transformer.resblocks.22.attn.out_proj",
1037
+ "vision_model.transformer.resblocks.22.mlp.fc1",
1038
+ "vision_model.transformer.resblocks.22.mlp.fc2",
1039
+ "vision_model.transformer.resblocks.23.attn.qkv_proj",
1040
+ "vision_model.transformer.resblocks.23.attn.q_proj",
1041
+ "vision_model.transformer.resblocks.23.attn.k_proj",
1042
+ "vision_model.transformer.resblocks.23.attn.v_proj",
1043
+ "vision_model.transformer.resblocks.23.attn.out_proj",
1044
+ "vision_model.transformer.resblocks.23.mlp.fc1",
1045
+ "vision_model.transformer.resblocks.23.mlp.fc2",
1046
+ "vision_model.transformer.resblocks.24.attn.qkv_proj",
1047
+ "vision_model.transformer.resblocks.24.attn.q_proj",
1048
+ "vision_model.transformer.resblocks.24.attn.k_proj",
1049
+ "vision_model.transformer.resblocks.24.attn.v_proj",
1050
+ "vision_model.transformer.resblocks.24.attn.out_proj",
1051
+ "vision_model.transformer.resblocks.24.mlp.fc1",
1052
+ "vision_model.transformer.resblocks.24.mlp.fc2",
1053
+ "vision_model.transformer.resblocks.25.attn.qkv_proj",
1054
+ "vision_model.transformer.resblocks.25.attn.q_proj",
1055
+ "vision_model.transformer.resblocks.25.attn.k_proj",
1056
+ "vision_model.transformer.resblocks.25.attn.v_proj",
1057
+ "vision_model.transformer.resblocks.25.attn.out_proj",
1058
+ "vision_model.transformer.resblocks.25.mlp.fc1",
1059
+ "vision_model.transformer.resblocks.25.mlp.fc2",
1060
+ "vision_model.transformer.resblocks.26.attn.qkv_proj",
1061
+ "vision_model.transformer.resblocks.26.attn.q_proj",
1062
+ "vision_model.transformer.resblocks.26.attn.k_proj",
1063
+ "vision_model.transformer.resblocks.26.attn.v_proj",
1064
+ "vision_model.transformer.resblocks.26.attn.out_proj",
1065
+ "vision_model.transformer.resblocks.26.mlp.fc1",
1066
+ "vision_model.transformer.resblocks.26.mlp.fc2",
1067
+ "vision_model.transformer.resblocks.27.attn.qkv_proj",
1068
+ "vision_model.transformer.resblocks.27.attn.q_proj",
1069
+ "vision_model.transformer.resblocks.27.attn.k_proj",
1070
+ "vision_model.transformer.resblocks.27.attn.v_proj",
1071
+ "vision_model.transformer.resblocks.27.attn.out_proj",
1072
+ "vision_model.transformer.resblocks.27.mlp.fc1",
1073
+ "vision_model.transformer.resblocks.27.mlp.fc2",
1074
+ "vision_model.transformer.resblocks.28.attn.qkv_proj",
1075
+ "vision_model.transformer.resblocks.28.attn.q_proj",
1076
+ "vision_model.transformer.resblocks.28.attn.k_proj",
1077
+ "vision_model.transformer.resblocks.28.attn.v_proj",
1078
+ "vision_model.transformer.resblocks.28.attn.out_proj",
1079
+ "vision_model.transformer.resblocks.28.mlp.fc1",
1080
+ "vision_model.transformer.resblocks.28.mlp.fc2",
1081
+ "vision_model.transformer.resblocks.29.attn.qkv_proj",
1082
+ "vision_model.transformer.resblocks.29.attn.q_proj",
1083
+ "vision_model.transformer.resblocks.29.attn.k_proj",
1084
+ "vision_model.transformer.resblocks.29.attn.v_proj",
1085
+ "vision_model.transformer.resblocks.29.attn.out_proj",
1086
+ "vision_model.transformer.resblocks.29.mlp.fc1",
1087
+ "vision_model.transformer.resblocks.29.mlp.fc2",
1088
+ "vision_model.transformer.resblocks.30.attn.qkv_proj",
1089
+ "vision_model.transformer.resblocks.30.attn.q_proj",
1090
+ "vision_model.transformer.resblocks.30.attn.k_proj",
1091
+ "vision_model.transformer.resblocks.30.attn.v_proj",
1092
+ "vision_model.transformer.resblocks.30.attn.out_proj",
1093
+ "vision_model.transformer.resblocks.30.mlp.fc1",
1094
+ "vision_model.transformer.resblocks.30.mlp.fc2",
1095
+ "vision_model.transformer.resblocks.31.attn.qkv_proj",
1096
+ "vision_model.transformer.resblocks.31.attn.q_proj",
1097
+ "vision_model.transformer.resblocks.31.attn.k_proj",
1098
+ "vision_model.transformer.resblocks.31.attn.v_proj",
1099
+ "vision_model.transformer.resblocks.31.attn.out_proj",
1100
+ "vision_model.transformer.resblocks.31.mlp.fc1",
1101
+ "vision_model.transformer.resblocks.31.mlp.fc2",
1102
+ "vision_model.transformer.resblocks.32.attn.qkv_proj",
1103
+ "vision_model.transformer.resblocks.32.attn.q_proj",
1104
+ "vision_model.transformer.resblocks.32.attn.k_proj",
1105
+ "vision_model.transformer.resblocks.32.attn.v_proj",
1106
+ "vision_model.transformer.resblocks.32.attn.out_proj",
1107
+ "vision_model.transformer.resblocks.32.mlp.fc1",
1108
+ "vision_model.transformer.resblocks.32.mlp.fc2",
1109
+ "vision_model.transformer.resblocks.33.attn.qkv_proj",
1110
+ "vision_model.transformer.resblocks.33.attn.q_proj",
1111
+ "vision_model.transformer.resblocks.33.attn.k_proj",
1112
+ "vision_model.transformer.resblocks.33.attn.v_proj",
1113
+ "vision_model.transformer.resblocks.33.attn.out_proj",
1114
+ "vision_model.transformer.resblocks.33.mlp.fc1",
1115
+ "vision_model.transformer.resblocks.33.mlp.fc2",
1116
+ "vision_model.transformer.resblocks.34.attn.qkv_proj",
1117
+ "vision_model.transformer.resblocks.34.attn.q_proj",
1118
+ "vision_model.transformer.resblocks.34.attn.k_proj",
1119
+ "vision_model.transformer.resblocks.34.attn.v_proj",
1120
+ "vision_model.transformer.resblocks.34.attn.out_proj",
1121
+ "vision_model.transformer.resblocks.34.mlp.fc1",
1122
+ "vision_model.transformer.resblocks.34.mlp.fc2",
1123
+ "vision_model.transformer.resblocks.35.attn.qkv_proj",
1124
+ "vision_model.transformer.resblocks.35.attn.q_proj",
1125
+ "vision_model.transformer.resblocks.35.attn.k_proj",
1126
+ "vision_model.transformer.resblocks.35.attn.v_proj",
1127
+ "vision_model.transformer.resblocks.35.attn.out_proj",
1128
+ "vision_model.transformer.resblocks.35.mlp.fc1",
1129
+ "vision_model.transformer.resblocks.35.mlp.fc2",
1130
+ "vision_model.transformer.resblocks.36.attn.qkv_proj",
1131
+ "vision_model.transformer.resblocks.36.attn.q_proj",
1132
+ "vision_model.transformer.resblocks.36.attn.k_proj",
1133
+ "vision_model.transformer.resblocks.36.attn.v_proj",
1134
+ "vision_model.transformer.resblocks.36.attn.out_proj",
1135
+ "vision_model.transformer.resblocks.36.mlp.fc1",
1136
+ "vision_model.transformer.resblocks.36.mlp.fc2",
1137
+ "vision_model.transformer.resblocks.37.attn.qkv_proj",
1138
+ "vision_model.transformer.resblocks.37.attn.q_proj",
1139
+ "vision_model.transformer.resblocks.37.attn.k_proj",
1140
+ "vision_model.transformer.resblocks.37.attn.v_proj",
1141
+ "vision_model.transformer.resblocks.37.attn.out_proj",
1142
+ "vision_model.transformer.resblocks.37.mlp.fc1",
1143
+ "vision_model.transformer.resblocks.37.mlp.fc2",
1144
+ "vision_model.transformer.resblocks.38.attn.qkv_proj",
1145
+ "vision_model.transformer.resblocks.38.attn.q_proj",
1146
+ "vision_model.transformer.resblocks.38.attn.k_proj",
1147
+ "vision_model.transformer.resblocks.38.attn.v_proj",
1148
+ "vision_model.transformer.resblocks.38.attn.out_proj",
1149
+ "vision_model.transformer.resblocks.38.mlp.fc1",
1150
+ "vision_model.transformer.resblocks.38.mlp.fc2",
1151
+ "vision_model.transformer.resblocks.39.attn.qkv_proj",
1152
+ "vision_model.transformer.resblocks.39.attn.q_proj",
1153
+ "vision_model.transformer.resblocks.39.attn.k_proj",
1154
+ "vision_model.transformer.resblocks.39.attn.v_proj",
1155
+ "vision_model.transformer.resblocks.39.attn.out_proj",
1156
+ "vision_model.transformer.resblocks.39.mlp.fc1",
1157
+ "vision_model.transformer.resblocks.39.mlp.fc2",
1158
+ "vision_model.transformer.resblocks.40.attn.qkv_proj",
1159
+ "vision_model.transformer.resblocks.40.attn.q_proj",
1160
+ "vision_model.transformer.resblocks.40.attn.k_proj",
1161
+ "vision_model.transformer.resblocks.40.attn.v_proj",
1162
+ "vision_model.transformer.resblocks.40.attn.out_proj",
1163
+ "vision_model.transformer.resblocks.40.mlp.fc1",
1164
+ "vision_model.transformer.resblocks.40.mlp.fc2",
1165
+ "vision_model.transformer.resblocks.41.attn.qkv_proj",
1166
+ "vision_model.transformer.resblocks.41.attn.q_proj",
1167
+ "vision_model.transformer.resblocks.41.attn.k_proj",
1168
+ "vision_model.transformer.resblocks.41.attn.v_proj",
1169
+ "vision_model.transformer.resblocks.41.attn.out_proj",
1170
+ "vision_model.transformer.resblocks.41.mlp.fc1",
1171
+ "vision_model.transformer.resblocks.41.mlp.fc2",
1172
+ "vision_model.transformer.resblocks.42.attn.qkv_proj",
1173
+ "vision_model.transformer.resblocks.42.attn.q_proj",
1174
+ "vision_model.transformer.resblocks.42.attn.k_proj",
1175
+ "vision_model.transformer.resblocks.42.attn.v_proj",
1176
+ "vision_model.transformer.resblocks.42.attn.out_proj",
1177
+ "vision_model.transformer.resblocks.42.mlp.fc1",
1178
+ "vision_model.transformer.resblocks.42.mlp.fc2",
1179
+ "vision_model.transformer.resblocks.43.attn.qkv_proj",
1180
+ "vision_model.transformer.resblocks.43.attn.q_proj",
1181
+ "vision_model.transformer.resblocks.43.attn.k_proj",
1182
+ "vision_model.transformer.resblocks.43.attn.v_proj",
1183
+ "vision_model.transformer.resblocks.43.attn.out_proj",
1184
+ "vision_model.transformer.resblocks.43.mlp.fc1",
1185
+ "vision_model.transformer.resblocks.43.mlp.fc2",
1186
+ "vision_model.transformer.resblocks.44.attn.qkv_proj",
1187
+ "vision_model.transformer.resblocks.44.attn.q_proj",
1188
+ "vision_model.transformer.resblocks.44.attn.k_proj",
1189
+ "vision_model.transformer.resblocks.44.attn.v_proj",
1190
+ "vision_model.transformer.resblocks.44.attn.out_proj",
1191
+ "vision_model.transformer.resblocks.44.mlp.fc1",
1192
+ "vision_model.transformer.resblocks.44.mlp.fc2",
1193
+ "vision_model.transformer.resblocks.45.attn.qkv_proj",
1194
+ "vision_model.transformer.resblocks.45.attn.q_proj",
1195
+ "vision_model.transformer.resblocks.45.attn.k_proj",
1196
+ "vision_model.transformer.resblocks.45.attn.v_proj",
1197
+ "vision_model.transformer.resblocks.45.attn.out_proj",
1198
+ "vision_model.transformer.resblocks.45.mlp.fc1",
1199
+ "vision_model.transformer.resblocks.45.mlp.fc2",
1200
+ "vision_model.transformer.resblocks.46.attn.qkv_proj",
1201
+ "vision_model.transformer.resblocks.46.attn.q_proj",
1202
+ "vision_model.transformer.resblocks.46.attn.k_proj",
1203
+ "vision_model.transformer.resblocks.46.attn.v_proj",
1204
+ "vision_model.transformer.resblocks.46.attn.out_proj",
1205
+ "vision_model.transformer.resblocks.46.mlp.fc1",
1206
+ "vision_model.transformer.resblocks.46.mlp.fc2",
1207
+ "vit_large_projector"
1208
  ]
1209
  }
1210
+ }
configuration_step3p7.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Optional, Sequence, Union
2
+
3
+ from transformers.configuration_utils import PretrainedConfig
4
+
5
+ class StepRoboticsVisionEncoderConfig(PretrainedConfig):
6
+ model_type = "perception_encoder"
7
+
8
+ def __init__(
9
+ self,
10
+ width=1536,
11
+ layers=47,
12
+ heads=16,
13
+ num_channels=3,
14
+ image_size=728,
15
+ mlp_ratio = 8960/1536,
16
+ patch_size=14,
17
+ hidden_act="quick_gelu",
18
+ layer_norm_eps=1e-5,
19
+ ues_cls_token=False,
20
+ use_cls_token: Optional[bool] = None,
21
+ use_ln_pre=True,
22
+ use_ln_post=False,
23
+ use_abs_posemb=True,
24
+ use_rope2d=True,
25
+ ls_init_value=0.1,
26
+ **kwargs,
27
+ ):
28
+ self.width = width
29
+ self.layers = layers
30
+ self.heads = heads
31
+ self.num_channels = num_channels
32
+ self.patch_size = patch_size
33
+ self.image_size = image_size
34
+ self.mlp_ratio = mlp_ratio
35
+ self.layer_norm_eps = layer_norm_eps
36
+ self.hidden_act = hidden_act
37
+ if use_cls_token is None:
38
+ use_cls_token = ues_cls_token
39
+ self.ues_cls_token = use_cls_token
40
+ self.use_cls_token = use_cls_token
41
+ self.use_ln_pre = use_ln_pre
42
+ self.ls_init_value = ls_init_value
43
+ self.use_ln_post = use_ln_post
44
+ self.use_abs_posemb = use_abs_posemb
45
+ self.use_rope2d = use_rope2d
46
+ super().__init__(**kwargs)
47
+
48
+
49
+ class Step3p7TextConfig(PretrainedConfig):
50
+ model_type = "step3p5"
51
+ architectures = ["Step3p5ForCausalLM"]
52
+
53
+ def __init__(
54
+ self,
55
+ hidden_size: int = 4096,
56
+ intermediate_size: int = 11264,
57
+ num_attention_heads: int = 64,
58
+ num_attention_groups: int = 8,
59
+ num_hidden_layers: int = 45,
60
+ max_seq_len: int = 128000,
61
+ vocab_size: int = 128815,
62
+ rms_norm_eps: float = 1e-5,
63
+ moe_intermediate_size: int = 1280,
64
+ moe_num_experts: int = 288,
65
+ moe_top_k: int = 8,
66
+ rope_theta: float = 10000,
67
+ rope_scaling: Optional[dict[str, Any]] = None,
68
+ max_position_embeddings: int = 128000,
69
+ share_expert_dims: int = 1280,
70
+ share_expert_dim: Optional[int] = None,
71
+ head_dim: int = 128,
72
+ norm_expert_weight: bool = True,
73
+ layer_types: list[str] = None,
74
+ sliding_window: Optional[int] = None,
75
+ pad_token_id: int = 1,
76
+ attention_dropout: float = 0.0,
77
+ use_head_wise_attn_gate: bool = False,
78
+ use_moe_router_bias: bool = False,
79
+ moe_router_activation: str = "softmax",
80
+ moe_router_scaling_factor: float = 1.0,
81
+ need_fp32_gate: bool = False,
82
+ attention_other_setting: Optional[dict[str, Any]] = None,
83
+ swiglu_limits: Optional[list[Optional[float]]] = None,
84
+ swiglu_limits_shared: Optional[list[Optional[float]]] = None,
85
+ use_rope_layers: Optional[list[bool]] = None,
86
+ yarn_only_types: Optional[list[str]] = None,
87
+ moe_layers_enum: tuple[int] = (3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
88
+ 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
89
+ 25, 26, 27, 28, 29, 30, 31, 32, 33, 34,
90
+ 35, 36, 37, 38, 39, 40, 41, 42, 43, 44),
91
+ **kwargs,
92
+ ) -> None:
93
+ torch_dtype = kwargs.get("torch_dtype")
94
+ layer_types = _normalize_per_layer_values(layer_types,
95
+ num_hidden_layers)
96
+ swiglu_limits = _normalize_per_layer_values(swiglu_limits,
97
+ num_hidden_layers)
98
+ swiglu_limits_shared = _normalize_per_layer_values(
99
+ swiglu_limits_shared, num_hidden_layers)
100
+ partial_rotary_factors = kwargs.get("partial_rotary_factors")
101
+ kwargs["partial_rotary_factors"] = _normalize_per_layer_values(
102
+ partial_rotary_factors, num_hidden_layers)
103
+ if isinstance(rope_theta, list):
104
+ rope_theta = _normalize_per_layer_values(rope_theta,
105
+ num_hidden_layers)
106
+ if isinstance(rope_scaling, dict):
107
+ rope_scaling = dict(rope_scaling)
108
+ if use_rope_layers:
109
+ use_rope_layers = _normalize_per_layer_values(
110
+ use_rope_layers, num_hidden_layers)
111
+ if share_expert_dim is None:
112
+ share_expert_dim = share_expert_dims
113
+ self.hidden_size = hidden_size
114
+ self.intermediate_size = intermediate_size
115
+ self.num_attention_heads = num_attention_heads
116
+ self.num_attention_groups = num_attention_groups
117
+ self.num_hidden_layers = num_hidden_layers
118
+ self.max_seq_len = max_seq_len
119
+ self.vocab_size = vocab_size
120
+ self.rms_norm_eps = rms_norm_eps
121
+ self.moe_intermediate_size = moe_intermediate_size
122
+ self.moe_num_experts = moe_num_experts
123
+ self.moe_top_k = moe_top_k
124
+ self.rope_theta = rope_theta
125
+ self.rope_scaling = rope_scaling
126
+ self.max_position_embeddings = max_position_embeddings
127
+ self.share_expert_dim = share_expert_dim
128
+ self.head_dim = head_dim
129
+ self.norm_expert_weight = norm_expert_weight
130
+ self.moe_layers_enum = moe_layers_enum
131
+ self.layer_types = layer_types
132
+ self.sliding_window = sliding_window
133
+ self.pad_token_id = pad_token_id
134
+ self.attention_dropout = attention_dropout
135
+ self.use_head_wise_attn_gate = use_head_wise_attn_gate
136
+ self.use_moe_router_bias = use_moe_router_bias
137
+ self.moe_router_activation = moe_router_activation
138
+ self.moe_router_scaling_factor = moe_router_scaling_factor
139
+ self.need_fp32_gate = need_fp32_gate
140
+ self.attention_other_setting = attention_other_setting
141
+ self.swiglu_limits = swiglu_limits
142
+ self.swiglu_limits_shared = swiglu_limits_shared
143
+ self.use_rope_layers = use_rope_layers
144
+ self.yarn_only_types = yarn_only_types
145
+ super().__init__(**kwargs)
146
+ if torch_dtype is not None:
147
+ self.torch_dtype = torch_dtype
148
+
149
+ def to_dict(self):
150
+ output = super().to_dict()
151
+ torch_dtype = getattr(self, "torch_dtype", None)
152
+ if torch_dtype is not None:
153
+ output["torch_dtype"] = torch_dtype
154
+ return output
155
+
156
+
157
+ def _normalize_per_layer_values(
158
+ values: Optional[Sequence[Any]],
159
+ num_hidden_layers: int,
160
+ ) -> Optional[list[Any]]:
161
+ if values is None:
162
+ return None
163
+ normalized = list(values)
164
+ if not normalized:
165
+ return normalized
166
+ if len(normalized) < num_hidden_layers:
167
+ normalized.extend([normalized[-1]] *
168
+ (num_hidden_layers - len(normalized)))
169
+ # Some checkpoints keep MTP/spec layer entries after the decoder layers.
170
+ # This config only builds num_hidden_layers decoder layers, and HF strict
171
+ # validation requires per-layer fields to match that decoder count.
172
+ return normalized[:num_hidden_layers]
173
+
174
+ class Step3p7Config(PretrainedConfig):
175
+ # This loader is a compatibility shim for original Step VL checkpoints
176
+ # whose top-level config model_type is `step3p7`.
177
+ model_type = "step3p7"
178
+
179
+ def __init__(
180
+ self,
181
+ vision_config: Optional[Union[dict, StepRoboticsVisionEncoderConfig]] = None,
182
+ text_config: Optional[Union[dict, Step3p7TextConfig]] = None,
183
+ understand_projector_stride: int = 2,
184
+ projector_bias: bool = False,
185
+ image_token_id: int = 151679,
186
+ **kwargs,
187
+ ) -> None:
188
+ shared_rope_scaling = kwargs.get("rope_scaling")
189
+ if isinstance(shared_rope_scaling, dict):
190
+ shared_rope_scaling = dict(shared_rope_scaling)
191
+
192
+ if vision_config is None:
193
+ vision_config = StepRoboticsVisionEncoderConfig()
194
+ elif isinstance(vision_config, dict):
195
+ vision_config = StepRoboticsVisionEncoderConfig(**vision_config)
196
+ self.vision_config = vision_config
197
+
198
+ if text_config is None:
199
+ text_config = Step3p7TextConfig(rope_scaling=shared_rope_scaling)
200
+ elif isinstance(text_config, dict):
201
+ text_config = dict(text_config)
202
+ if shared_rope_scaling is not None and "rope_scaling" not in text_config:
203
+ text_config["rope_scaling"] = shared_rope_scaling
204
+ text_config = Step3p7TextConfig(**text_config)
205
+ elif shared_rope_scaling is not None and text_config.rope_scaling is None:
206
+ text_config.rope_scaling = dict(shared_rope_scaling)
207
+ self.text_config = text_config
208
+
209
+ rope_scaling = kwargs.get("rope_scaling")
210
+ if isinstance(rope_scaling, dict):
211
+ kwargs["rope_scaling"] = dict(rope_scaling)
212
+
213
+ self.understand_projector_stride = understand_projector_stride
214
+ self.projector_bias = projector_bias
215
+ self.hidden_size = text_config.hidden_size
216
+ self.max_position_embeddings = text_config.max_position_embeddings
217
+ self.image_token_id = image_token_id
218
+ # Help Auto classes find the correct implementation when saving/loading.
219
+ super().__init__(**kwargs)
modeling_step3p7.py ADDED
@@ -0,0 +1,1395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import copy
16
+ import inspect
17
+ from dataclasses import dataclass
18
+ from typing import Callable, Literal, Optional, Tuple, TypedDict, Union
19
+
20
+ from PIL import Image
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+ from transformers.activations import ACT2FN
26
+ from transformers.cache_utils import Cache, DynamicCache
27
+ from transformers.generation import GenerationMixin
28
+ from transformers.masking_utils import (
29
+ create_causal_mask,
30
+ create_sliding_window_causal_mask,
31
+ )
32
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
33
+ from transformers.modeling_layers import GradientCheckpointingLayer
34
+ from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
35
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
36
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
37
+ from transformers.processing_utils import Unpack
38
+ from transformers.utils import TransformersKwargs, can_return_tuple, logging
39
+ from .configuration_step3p7 import Step3p7Config, Step3p7TextConfig
40
+ from .vision_encoder import StepRoboticsVisionEncoder
41
+
42
+
43
+ logger = logging.get_logger(__name__)
44
+ _MASK_INPUT_EMBEDS_ARG = (
45
+ "inputs_embeds"
46
+ if "inputs_embeds" in inspect.signature(create_causal_mask).parameters
47
+ else "input_embeds"
48
+ )
49
+
50
+ __all__ = [
51
+ "Step3p7Model",
52
+ ]
53
+
54
+
55
+ class StepVLImagePixelInputs(TypedDict):
56
+ type: Literal["pixel_values"]
57
+ pixel_values: torch.Tensor
58
+ patch_pixel_values: Optional[torch.Tensor]
59
+ num_patches: list[int]
60
+
61
+
62
+ class StepVLImageEmbeddingInputs(TypedDict):
63
+ type: Literal["image_embeds"]
64
+ image_embeds: torch.Tensor
65
+
66
+
67
+ StepVLImageInputs = Union[StepVLImagePixelInputs, StepVLImageEmbeddingInputs]
68
+
69
+
70
+ def _flatten_embeddings(embeddings) -> torch.Tensor:
71
+ """
72
+ Recursively flattens and concatenates NestedTensors on all but the last
73
+ dimension.
74
+ """
75
+
76
+ if isinstance(embeddings, torch.Tensor):
77
+ # Flatten all but the last dimension.
78
+ return embeddings.flatten(0, -2)
79
+
80
+ return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))
81
+
82
+ def _embedding_count_expression(embeddings) -> str:
83
+ """
84
+ Constructs a debugging representation of the number of embeddings in the
85
+ NestedTensors.
86
+ """
87
+
88
+ if isinstance(embeddings, torch.Tensor):
89
+ return " x ".join([str(dim) for dim in embeddings.shape[:-1]])
90
+
91
+ return " + ".join(_embedding_count_expression(inner) for inner in embeddings)
92
+
93
+
94
+ def _merge_multimodal_embeddings(
95
+ inputs_embeds: torch.Tensor,
96
+ is_multimodal: torch.Tensor,
97
+ multimodal_embeddings,
98
+ ) -> torch.Tensor:
99
+ """
100
+ Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
101
+ positions in ``inputs_embeds`` corresponding to placeholder tokens in
102
+ ``input_ids``.
103
+ Note:
104
+ This updates ``inputs_embeds`` in place.
105
+ """
106
+ num_expected_tokens = is_multimodal.sum().item()
107
+ assert isinstance(num_expected_tokens, int)
108
+
109
+ flattened = _flatten_embeddings(multimodal_embeddings)
110
+ if flattened.shape[0] != num_expected_tokens:
111
+ expr = _embedding_count_expression(multimodal_embeddings)
112
+ raise ValueError(
113
+ f"Attempted to assign {expr} = {flattened.shape[0]} "
114
+ f"multimodal tokens to {num_expected_tokens} placeholders"
115
+ )
116
+
117
+ is_multimodal = is_multimodal.to(inputs_embeds.device)
118
+ flattened = flattened.to(inputs_embeds.device)
119
+ inputs_embeds[is_multimodal] = flattened
120
+ return inputs_embeds
121
+
122
+ def merge_multimodal_embeddings(
123
+ input_ids: torch.Tensor,
124
+ inputs_embeds: torch.Tensor,
125
+ multimodal_embeddings,
126
+ placeholder_token_id: Union[int, list[int]],
127
+ ) -> torch.Tensor:
128
+ """
129
+ Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
130
+ positions in ``inputs_embeds`` corresponding to placeholder tokens in
131
+ ``input_ids``.
132
+
133
+ ``placeholder_token_id`` can be a list of token ids (e.g, token ids
134
+ of img_start, img_break, and img_end tokens) when needed: This means
135
+ the order of these tokens in the ``input_ids`` MUST MATCH the order of
136
+ their embeddings in ``multimodal_embeddings`` since we need to
137
+ slice-merge instead of individually scattering.
138
+ For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where
139
+ - T is text token
140
+ - S is image start token
141
+ - I is image embedding token
142
+ - B is image break token
143
+ - E is image end token.
144
+
145
+ Then the image embeddings (that correspond to I's) from vision encoder
146
+ must be padded with embeddings of S, B, and E in the same order of
147
+ input_ids for a correct embedding merge.
148
+ Note:
149
+ This updates ``inputs_embeds`` in place.
150
+ """
151
+ if isinstance(placeholder_token_id, list):
152
+ placeholder_token_id = torch.tensor(
153
+ placeholder_token_id, device=input_ids.device
154
+ )
155
+ return _merge_multimodal_embeddings(
156
+ inputs_embeds,
157
+ torch.isin(input_ids, placeholder_token_id),
158
+ multimodal_embeddings,
159
+ )
160
+
161
+ return _merge_multimodal_embeddings(
162
+ inputs_embeds,
163
+ (input_ids == placeholder_token_id),
164
+ multimodal_embeddings,
165
+ )
166
+
167
+
168
+ class Step3p7PreTrainedModel(PreTrainedModel):
169
+ # Link this model family to its configuration class so PreTrainedModel.from_pretrained
170
+ # can load the config instead of failing with a NoneType error.
171
+ config_class = Step3p7Config
172
+ supports_gradient_checkpointing = True
173
+ _skip_keys_device_placement = ["past_key_values"]
174
+ _keys_to_ignore_on_load_unexpected = [
175
+ r"model\.layers\.45\.*",
176
+ r"model\.layers\.46\.*",
177
+ r"model\.layers\.47\.*",
178
+ ]
179
+ _supports_flash_attn = False
180
+ _supports_sdpa = True
181
+ _supports_flex_attn = True
182
+ _supports_static_cache = True
183
+ _supports_attention_backend = True
184
+
185
+ @classmethod
186
+ def from_pretrained(
187
+ cls, pretrained_model_name_or_path, *model_args, **kwargs
188
+ ):
189
+ key_mapping = getattr(cls, "_checkpoint_conversion_mapping", None)
190
+ if key_mapping is not None and kwargs.get("key_mapping") is None:
191
+ # Transformers only applies checkpoint renaming when key_mapping is
192
+ # passed explicitly; inheriting the class attribute alone is not enough.
193
+ kwargs["key_mapping"] = copy.deepcopy(key_mapping)
194
+ return super().from_pretrained(
195
+ pretrained_model_name_or_path, *model_args, **kwargs
196
+ )
197
+
198
+
199
+ class Step3p7RotaryEmbedding(nn.Module):
200
+ def __init__(self, config: Step3p7TextConfig, device=None, layer_idx=None):
201
+ super().__init__()
202
+ # BC: "rope_type" was originally "type"
203
+ self.layer_idx = layer_idx
204
+ self.original_rope_parameters = None
205
+ if config.rope_parameters is not None:
206
+ self.original_rope_parameters = config.rope_parameters
207
+ config.rope_parameters = dict(config.rope_parameters)
208
+ self.rope_type = config.rope_parameters.get(
209
+ "rope_type", config.rope_parameters.get("type")
210
+ )
211
+ else:
212
+ self.rope_type = "default"
213
+ self.max_seq_len_cached = config.max_position_embeddings
214
+ self.original_max_seq_len = config.max_position_embeddings
215
+
216
+ partial_rotary_factors = getattr(
217
+ config, "partial_rotary_factors", None
218
+ )
219
+ if partial_rotary_factors is not None:
220
+ config.partial_rotary_factor = partial_rotary_factors[self.layer_idx]
221
+ else:
222
+ config.partial_rotary_factor = 1.0
223
+
224
+ self.rope_theta = config.rope_theta
225
+ if isinstance(config.rope_theta, list):
226
+ self.rope_theta = config.rope_theta.copy()
227
+ config.rope_theta = self.rope_theta[self.layer_idx]
228
+
229
+ self.config = copy.copy(config)
230
+ if config.rope_parameters is not None:
231
+ self.config.rope_parameters = dict(config.rope_parameters)
232
+ self.rope_init_fn = self.compute_default_rope_parameters
233
+ if self.rope_type != "default":
234
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
235
+ inv_freq, self.attention_scaling = self.rope_init_fn(
236
+ self.config, device
237
+ )
238
+
239
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
240
+ self.original_inv_freq = self.inv_freq
241
+ config.rope_theta = self.rope_theta
242
+ config.rope_parameters = self.original_rope_parameters
243
+
244
+ @torch.no_grad()
245
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
246
+ def forward(self, x, position_ids):
247
+ inv_freq_expanded = (
248
+ self.inv_freq[None, :, None]
249
+ .float()
250
+ .expand(position_ids.shape[0], -1, 1)
251
+ .to(x.device)
252
+ )
253
+ position_ids_expanded = position_ids[:, None, :].float().to(x.device)
254
+
255
+ device_type = (
256
+ x.device.type
257
+ if isinstance(x.device.type, str) and x.device.type != "mps"
258
+ else "cpu"
259
+ )
260
+ with torch.autocast(
261
+ device_type=device_type, enabled=False
262
+ ): # Force float32
263
+ freqs = (
264
+ inv_freq_expanded.float() @ position_ids_expanded.float()
265
+ ).transpose(1, 2)
266
+ emb = torch.cat((freqs, freqs), dim=-1)
267
+ cos = emb.cos() * self.attention_scaling
268
+ sin = emb.sin() * self.attention_scaling
269
+
270
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
271
+
272
+ @staticmethod
273
+ def compute_default_rope_parameters(
274
+ config: Step3p7TextConfig | None = None,
275
+ device: Optional["torch.device"] = None,
276
+ ) -> tuple["torch.Tensor", float]:
277
+ """
278
+ Computes the inverse frequencies according to the original RoPE implementation
279
+ Args:
280
+ config ([`~transformers.PreTrainedConfig`]):
281
+ The model configuration.
282
+ device (`torch.device`):
283
+ The device to use for initialization of the inverse frequencies.
284
+ seq_len (`int`, *optional*):
285
+ The current sequence length. Unused for this type of RoPE.
286
+ Returns:
287
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
288
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
289
+ """
290
+ base = config.rope_theta
291
+ dim = (
292
+ getattr(config, "head_dim", None)
293
+ or config.hidden_size // config.num_attention_heads
294
+ )
295
+
296
+ attention_factor = 1.0 # Unused in this type of RoPE
297
+
298
+ # Compute the inverse frequencies
299
+ inv_freq = 1.0 / (
300
+ base
301
+ ** (
302
+ torch.arange(0, dim, 2, dtype=torch.int64).to(
303
+ device=device, dtype=torch.float
304
+ )
305
+ / dim
306
+ )
307
+ )
308
+ return inv_freq, attention_factor
309
+
310
+ def rotate_half(x):
311
+ """Rotates half the hidden dims of the input."""
312
+ x1 = x[..., :x.shape[-1] // 2]
313
+ x2 = x[..., x.shape[-1] // 2:]
314
+ return torch.cat((-x2, x1), dim=-1)
315
+
316
+
317
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
318
+ """Applies Rotary Position Embedding to the query and key tensors.
319
+
320
+ Args:
321
+ q (`torch.Tensor`): The query tensor.
322
+ k (`torch.Tensor`): The key tensor.
323
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
324
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
325
+ position_ids (`torch.Tensor`, *optional*):
326
+ Deprecated and unused.
327
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
328
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
329
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
330
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
331
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
332
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
333
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
334
+ Returns:
335
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
336
+ """
337
+ rotary_dim = cos.shape[-1]
338
+ q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
339
+ k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
340
+
341
+ # Apply rotary embeddings on the first half or full tensor
342
+ q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
343
+ k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
344
+
345
+ # Concatenate back to full shape
346
+ q_embed = torch.cat([q_embed, q_pass], dim=-1)
347
+ k_embed = torch.cat([k_embed, k_pass], dim=-1)
348
+ return q_embed, k_embed
349
+
350
+
351
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
352
+ """
353
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
354
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
355
+ """
356
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
357
+ if n_rep == 1:
358
+ return hidden_states
359
+ hidden_states = hidden_states[:, :, None, :, :].expand(
360
+ batch, num_key_value_heads, n_rep, slen, head_dim
361
+ )
362
+ return hidden_states.reshape(
363
+ batch, num_key_value_heads * n_rep, slen, head_dim
364
+ )
365
+
366
+
367
+ # Adapted from transformers.models.llama.modeling_llama.eager_attention_forward.
368
+ # Llama4 does not cast attention weights to fp32 here.
369
+ def eager_attention_forward(
370
+ module: nn.Module,
371
+ query: torch.Tensor,
372
+ key: torch.Tensor,
373
+ value: torch.Tensor,
374
+ attention_mask: Optional[torch.Tensor],
375
+ scaling: float,
376
+ dropout: float = 0.0,
377
+ **kwargs,
378
+ ):
379
+ key_states = repeat_kv(key, module.num_key_value_groups)
380
+ value_states = repeat_kv(value, module.num_key_value_groups)
381
+ # breakpoint()
382
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
383
+ if attention_mask is not None:
384
+ causal_mask = attention_mask[:, :, :, :key_states.shape[-2]]
385
+ attn_weights = attn_weights + causal_mask
386
+
387
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
388
+ attn_weights = nn.functional.dropout(
389
+ attn_weights, p=dropout, training=module.training
390
+ )
391
+ attn_output = torch.matmul(attn_weights, value_states)
392
+ attn_output = attn_output.transpose(1, 2).contiguous()
393
+
394
+ return attn_output, attn_weights
395
+
396
+
397
+ @dataclass
398
+ class Step3p7CausalLMOutputWithPast(ModelOutput):
399
+ r"""
400
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
401
+ Language modeling loss (for next-token prediction).
402
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
403
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
404
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
405
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
406
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
407
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
408
+ `past_key_values` input) to speed up sequential decoding.
409
+ """
410
+
411
+ loss: Optional[torch.FloatTensor] = None
412
+ last_hidden_state: Optional[torch.FloatTensor] = None
413
+ logits: torch.FloatTensor = None
414
+ past_key_values: Optional[list[torch.FloatTensor]] = None
415
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
416
+ attentions: Optional[tuple[torch.FloatTensor]] = None
417
+
418
+
419
+ class Step3p7MLP(nn.Module):
420
+ def __init__(self, config, intermediate_size=None, swiglu_limit=None):
421
+ super().__init__()
422
+ self.config = config
423
+ self.hidden_size = config.hidden_size
424
+ self.intermediate_size = (
425
+ intermediate_size
426
+ if intermediate_size is not None
427
+ else config.intermediate_size
428
+ )
429
+ self.gate_proj = nn.Linear(self.hidden_size,
430
+ self.intermediate_size,
431
+ bias=False)
432
+ self.up_proj = nn.Linear(self.hidden_size,
433
+ self.intermediate_size,
434
+ bias=False)
435
+ self.down_proj = nn.Linear(self.intermediate_size,
436
+ self.hidden_size,
437
+ bias=False)
438
+ self.act_fn = ACT2FN["silu"]
439
+ self.limit = swiglu_limit
440
+
441
+ def forward(self, x):
442
+ up = self.up_proj(x)
443
+ gate = self.act_fn(self.gate_proj(x))
444
+ if self.limit is not None:
445
+ gate = gate.clamp(min=None, max=self.limit)
446
+ up = up.clamp(min=-self.limit, max=self.limit)
447
+
448
+ return self.down_proj(gate * up)
449
+
450
+
451
+ def sigmoid_routing_function(gating_output: torch.Tensor, topk: int,
452
+ renormalize: bool):
453
+ gating_output = gating_output.float()
454
+ gate_prob = torch.sigmoid(gating_output)
455
+ gate_prob = gate_prob / gate_prob.sum(dim=-1, keepdim=True)
456
+ topk_prob, indices = torch.topk(gate_prob, k=topk, dim=1)
457
+ expert_topk_weight = topk_prob
458
+ if renormalize:
459
+ expert_topk_weight = expert_topk_weight / torch.sum(
460
+ expert_topk_weight, dim=-1, keepdim=True)
461
+ return expert_topk_weight, indices
462
+
463
+
464
+ def softmax_routing_function(gating_output: torch.Tensor, top_k: int,
465
+ renormalize: bool):
466
+ gating_output = gating_output.float()
467
+ gate_prob = torch.softmax(gating_output, dim=-1)
468
+ gate_prob = gate_prob / gate_prob.sum(dim=-1, keepdim=True)
469
+ topk_prob, indices = torch.topk(gate_prob, k=top_k, dim=1)
470
+ expert_topk_weight = topk_prob
471
+ if renormalize:
472
+ expert_topk_weight = expert_topk_weight / torch.sum(
473
+ expert_topk_weight, dim=-1, keepdim=True)
474
+ return expert_topk_weight, indices.to(torch.int32)
475
+
476
+
477
+ class MoELinear(nn.Module):
478
+
479
+ def __init__(self, num_experts, in_features, out_features):
480
+ super().__init__()
481
+ self.num_experts = num_experts
482
+ self.in_features = in_features
483
+ self.out_features = out_features
484
+ self.weight = nn.Parameter(
485
+ torch.empty(num_experts, out_features, in_features))
486
+
487
+ def forward(self, x, expert_id):
488
+ x = F.linear(x.float(), self.weight[expert_id].float())
489
+ return x
490
+
491
+
492
+ class Step3p7MoEMLP(nn.Module):
493
+
494
+ def __init__(self, config, swiglu_limit=None):
495
+ super().__init__()
496
+ self.num_experts = config.moe_num_experts
497
+ self.top_k = config.moe_top_k
498
+ self.hidden_size = config.hidden_size
499
+ self.moe_intermediate_size = config.moe_intermediate_size
500
+
501
+ self.use_moe_router_bias = config.use_moe_router_bias
502
+ if self.use_moe_router_bias:
503
+ self.router_bias = nn.Parameter(torch.zeros(config.moe_num_experts,
504
+ dtype=torch.float32),
505
+ requires_grad=False)
506
+ self.custom_routing_function = self.router_bias_func
507
+ elif config.moe_router_activation == "sigmoid":
508
+ self.custom_routing_function = sigmoid_routing_function
509
+ else:
510
+ self.custom_routing_function = None
511
+ self.need_fp32_gate = config.need_fp32_gate
512
+ self.routed_scaling_factor = getattr(config,
513
+ "moe_router_scaling_factor", 1.0)
514
+
515
+ # gating
516
+ self.gate = nn.Linear(self.hidden_size, self.num_experts, bias=False)
517
+
518
+ self.act_fn = ACT2FN["silu"]
519
+ self.limit = swiglu_limit
520
+
521
+ self.up_proj = MoELinear(self.num_experts, self.hidden_size,
522
+ self.moe_intermediate_size)
523
+ self.gate_proj = MoELinear(self.num_experts, self.hidden_size,
524
+ self.moe_intermediate_size)
525
+ self.down_proj = MoELinear(self.num_experts,
526
+ self.moe_intermediate_size,
527
+ self.hidden_size)
528
+
529
+ def router_bias_func(self, gating_output: torch.Tensor, topk: int,
530
+ renormalize: bool):
531
+ gate_prob = torch.sigmoid(gating_output.float())
532
+ gate_prob_with_bias = gate_prob + self.router_bias.unsqueeze(0)
533
+ _, indices = torch.topk(gate_prob_with_bias, k=topk, dim=1)
534
+ topk_prob = torch.gather(gate_prob, 1, indices)
535
+ expert_topk_weight = topk_prob
536
+ if renormalize:
537
+ expert_topk_weight = expert_topk_weight / (
538
+ torch.sum(expert_topk_weight, dim=-1, keepdim=True) + 1e-20)
539
+ return expert_topk_weight, indices
540
+
541
+ def get_expert_output(self, inputs: torch.Tensor, expert_id):
542
+ #if self.limit is None:
543
+ up = self.up_proj(inputs, expert_id)
544
+ gate = self.act_fn(self.gate_proj(inputs, expert_id))
545
+ if self.limit is not None:
546
+ gate = gate.clamp(min=None, max=self.limit)
547
+ up = up.clamp(min=-self.limit, max=self.limit)
548
+
549
+ return self.down_proj(gate * up, expert_id)
550
+
551
+ def forward(self, hidden_states):
552
+ """ """
553
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
554
+ hidden_states = hidden_states.view(-1, hidden_dim)
555
+ if self.need_fp32_gate:
556
+ router_logits = torch.matmul(
557
+ hidden_states.to(torch.float32),
558
+ self.gate.weight.t().to(torch.float32),
559
+ )
560
+ else:
561
+ # router_logits: (batch * sequence_length, n_experts)
562
+ router_logits = self.gate(hidden_states)
563
+
564
+ if self.custom_routing_function:
565
+ routing_weights, selected_experts = self.custom_routing_function(
566
+ router_logits, self.top_k, renormalize=True)
567
+ else:
568
+ routing_weights = F.softmax(router_logits,
569
+ dim=1,
570
+ dtype=torch.float)
571
+ routing_weights, selected_experts = torch.topk(routing_weights,
572
+ self.top_k,
573
+ dim=-1)
574
+
575
+ routing_weights = routing_weights * self.routed_scaling_factor
576
+
577
+ final_hidden_states = torch.zeros(
578
+ (batch_size * sequence_length, hidden_dim),
579
+ dtype=hidden_states.dtype,
580
+ device=hidden_states.device)
581
+
582
+ # One hot encode the selected experts to create an expert mask
583
+ # this will be used to easily index which expert is going to be sollicitated
584
+ expert_mask = torch.nn.functional.one_hot(
585
+ selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
586
+
587
+ # Loop over all available experts in the model and perform the computation on each expert
588
+ for expert_idx in range(self.num_experts):
589
+ idx, top_x = torch.where(expert_mask[expert_idx])
590
+
591
+ # Index the correct hidden states and compute the expert hidden state for
592
+ # the current expert. We need to make sure to multiply the output hidden
593
+ # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
594
+ current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
595
+ current_hidden_states = (
596
+ self.get_expert_output(current_state, expert_idx) *
597
+ routing_weights[top_x, idx, None])
598
+
599
+ # However `index_add_` only support torch tensors for indexing so we'll use
600
+ # the `top_x` tensor here.
601
+ final_hidden_states.index_add_(
602
+ 0, top_x, current_hidden_states.to(hidden_states.dtype))
603
+ final_hidden_states = final_hidden_states.reshape(
604
+ batch_size, sequence_length, hidden_dim)
605
+ return final_hidden_states
606
+
607
+
608
+ class Step3p7RMSNorm(nn.Module):
609
+
610
+ def __init__(
611
+ self,
612
+ hidden_size: int,
613
+ eps: float = 1e-5,
614
+ ) -> None:
615
+ super().__init__()
616
+ self.weight = nn.Parameter(torch.ones(hidden_size))
617
+ self.variance_epsilon = eps
618
+
619
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
620
+ dtype = x.dtype
621
+ x = x.float()
622
+ variance = x.pow(2).mean(dim=-1, keepdim=True)
623
+ normed = x * torch.rsqrt(variance + self.variance_epsilon)
624
+ normed = normed * (self.weight.float() + 1)
625
+ return normed.to(dtype)
626
+ class Step3p7Attention(nn.Module):
627
+
628
+ def __init__(self, config: Step3p7TextConfig, layer_idx):
629
+ super().__init__()
630
+ self.config = config
631
+ self.layer_idx = layer_idx
632
+ self.num_attention_heads = config.num_attention_heads
633
+ self.num_key_value_heads = config.num_attention_groups
634
+
635
+ layer_types = getattr(config, "layer_types", [])
636
+ if layer_types:
637
+ enable_sliding_window = layer_types[
638
+ self.layer_idx] == "sliding_attention"
639
+ else:
640
+ enable_sliding_window = self.layer_idx % 2 == 0
641
+
642
+ yarn_only_types = getattr(config, "yarn_only_types", None)
643
+ if yarn_only_types and layer_types[
644
+ self.layer_idx] not in yarn_only_types:
645
+ config.rope_parameters = None
646
+ else:
647
+ config.rope_parameters = getattr(config, "rope_scaling", None)
648
+
649
+ self.sliding_window = config.sliding_window
650
+ if enable_sliding_window:
651
+ self.num_attention_heads = config.attention_other_setting[
652
+ "num_attention_heads"]
653
+ self.num_key_value_heads = config.attention_other_setting[
654
+ "num_attention_groups"]
655
+
656
+ if self.sliding_window is not None and enable_sliding_window:
657
+ self.sliding_window = (self.sliding_window)
658
+ else:
659
+ self.sliding_window = None
660
+ self.head_dim = getattr(config, "head_dim",
661
+ config.hidden_size // self.num_attention_heads)
662
+ self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
663
+
664
+ self.rotary_emb = Step3p7RotaryEmbedding(config, layer_idx=layer_idx)
665
+
666
+ self.q_size = self.num_attention_heads * self.head_dim
667
+ self.kv_size = self.num_key_value_heads * self.head_dim
668
+ self.scaling = self.head_dim**-0.5
669
+
670
+ self.q_proj = nn.Linear(config.hidden_size, self.q_size, bias=False)
671
+ self.k_proj = nn.Linear(config.hidden_size, self.kv_size, bias=False)
672
+ self.v_proj = nn.Linear(config.hidden_size, self.kv_size, bias=False)
673
+ self.o_proj = nn.Linear(self.q_size, config.hidden_size, bias=False)
674
+ self.attention_dropout = getattr(config, "attention_dropout", 0.0)
675
+ self.q_norm = Step3p7RMSNorm(self.head_dim,
676
+ eps=config.rms_norm_eps)
677
+ self.k_norm = Step3p7RMSNorm(self.head_dim,
678
+ eps=config.rms_norm_eps)
679
+
680
+ self.use_head_wise_attn_gate = config.use_head_wise_attn_gate
681
+ if self.use_head_wise_attn_gate:
682
+ self.g_proj = nn.Linear(config.hidden_size,
683
+ self.num_attention_heads,
684
+ bias=False)
685
+
686
+ self.use_rope = True
687
+ use_rope_layers = getattr(config, "use_rope_layers", None)
688
+ if use_rope_layers:
689
+ self.use_rope = use_rope_layers[self.layer_idx]
690
+
691
+ def forward(
692
+ self,
693
+ hidden_states: torch.Tensor,
694
+ attention_mask: Optional[torch.Tensor],
695
+ past_key_value: Optional[Cache] = None,
696
+ cache_position: Optional[torch.LongTensor] = None,
697
+ position_ids: Optional[torch.LongTensor] = None,
698
+ **kwargs: Unpack[FlashAttentionKwargs],
699
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
700
+ Optional[Tuple[torch.Tensor]]]:
701
+ input_shape = hidden_states.shape[:-1]
702
+ hidden_shape = (*input_shape, -1, self.head_dim)
703
+
704
+ query_states = self.q_norm(
705
+ self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
706
+ key_states = self.k_norm(
707
+ self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
708
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(
709
+ 1, 2)
710
+ if self.use_head_wise_attn_gate:
711
+ gate_states = self.g_proj(hidden_states)
712
+ cos, sin = self.rotary_emb(hidden_states, position_ids)
713
+
714
+ # cos, sin = position_embeddings
715
+ query_states, key_states = apply_rotary_pos_emb(
716
+ query_states, key_states, cos, sin)
717
+
718
+ # query_states, key_states = apply_rotary_pos_emb(query_norm_states, key_norm_states, cos, sin)
719
+ if past_key_value is not None:
720
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
721
+ cache_kwargs = {
722
+ "sin": sin,
723
+ "cos": cos,
724
+ "cache_position": cache_position
725
+ }
726
+ key_states, value_states = past_key_value.update(
727
+ key_states, value_states, self.layer_idx, cache_kwargs)
728
+
729
+ attention_interface: Callable = eager_attention_forward
730
+ # TODO: considering FP8;
731
+ # RuntimeError: Expected attn_mask dtype to be bool or float or to match query dtype,
732
+ # but got attn_mask.dtype: long int and query.dtype: c10::BFloat16 instead.
733
+ if self.config._attn_implementation != "eager":
734
+ attention_interface = ALL_ATTENTION_FUNCTIONS[
735
+ self.config._attn_implementation]
736
+
737
+ attn_output, attn_weights = attention_interface(
738
+ self,
739
+ query_states,
740
+ key_states,
741
+ value_states,
742
+ attention_mask,
743
+ dropout=0.0 if not self.training else self.attention_dropout,
744
+ scaling=self.scaling,
745
+ sliding_window=self.sliding_window, # main diff with Llama
746
+ **kwargs,
747
+ )
748
+ attn_output = attn_output.reshape(*input_shape, -1)
749
+ if self.use_head_wise_attn_gate:
750
+ output = attn_output.view(
751
+ *attn_output.shape[:-1], self.num_attention_heads,
752
+ self.head_dim) * gate_states.unsqueeze(-1).sigmoid()
753
+ attn_output = output.view(*attn_output.shape)
754
+ attn_output = self.o_proj(attn_output)
755
+
756
+ return attn_output, attn_weights
757
+
758
+
759
+ class Step3p7DecoderLayer(GradientCheckpointingLayer):
760
+
761
+ def __init__(self, config, layer_idx):
762
+ super().__init__()
763
+ self.hidden_size = config.hidden_size
764
+ self.layer_idx = layer_idx
765
+ self.self_attn = Step3p7Attention(config, layer_idx)
766
+ layer_types = getattr(config, "layer_types", None) or []
767
+ if layer_types:
768
+ self.attention_type = layer_types[layer_idx]
769
+ else:
770
+ self.attention_type = (
771
+ "sliding_attention" if layer_idx % 2 == 0 else "full_attention"
772
+ )
773
+
774
+ moe_layers_enum = getattr(config, "moe_layers_enum", None)
775
+ if moe_layers_enum is not None:
776
+ if isinstance(moe_layers_enum, str):
777
+ moe_layers_idx = [
778
+ int(i) for i in moe_layers_enum.split(',') if i.strip()
779
+ ]
780
+ else:
781
+ moe_layers_idx = [int(i) for i in moe_layers_enum]
782
+ else:
783
+ moe_layers_idx = [i for i in range(1, config.num_hidden_layers)]
784
+ self.is_moe_layer = layer_idx in moe_layers_idx
785
+ self.use_moe = False
786
+
787
+ if config.swiglu_limits_shared and config.swiglu_limits_shared[
788
+ layer_idx] is not None and config.swiglu_limits_shared[
789
+ layer_idx] != 0:
790
+ swiglu_limit_shared = config.swiglu_limits_shared[layer_idx]
791
+ else:
792
+ swiglu_limit_shared = None
793
+ if config.swiglu_limits and config.swiglu_limits[
794
+ layer_idx] is not None and config.swiglu_limits[layer_idx] != 0:
795
+ swiglu_limit = config.swiglu_limits[layer_idx]
796
+ else:
797
+ swiglu_limit = None
798
+ if self.is_moe_layer:
799
+ self.moe = Step3p7MoEMLP(config, swiglu_limit=swiglu_limit) #
800
+ self.share_expert = Step3p7MLP(
801
+ config,
802
+ intermediate_size=config.share_expert_dim,
803
+ swiglu_limit=swiglu_limit_shared)
804
+ self.use_moe = True
805
+ else:
806
+ self.mlp = Step3p7MLP(config,
807
+ intermediate_size=config.intermediate_size,
808
+ swiglu_limit=swiglu_limit_shared)
809
+
810
+ self.input_layernorm = Step3p7RMSNorm(
811
+ config.hidden_size,
812
+ eps=config.rms_norm_eps)
813
+ self.post_attention_layernorm = Step3p7RMSNorm(
814
+ config.hidden_size,
815
+ eps=config.rms_norm_eps)
816
+
817
+ def forward(
818
+ self,
819
+ hidden_states: torch.Tensor,
820
+ attention_mask: Optional[torch.Tensor] = None,
821
+ position_ids: Optional[torch.LongTensor] = None,
822
+ past_key_value: Optional[tuple[torch.Tensor]] = None,
823
+ cache_position: Optional[torch.LongTensor] = None,
824
+ **kwargs: Unpack[FlashAttentionKwargs],
825
+ ) -> torch.FloatTensor:
826
+ residual = hidden_states
827
+ hidden_states = self.input_layernorm(hidden_states)
828
+ hidden_states, _ = self.self_attn(
829
+ hidden_states=hidden_states,
830
+ attention_mask=attention_mask,
831
+ position_ids=position_ids,
832
+ past_key_value=past_key_value,
833
+ cache_position=cache_position,
834
+ **kwargs,
835
+ )
836
+ hidden_states = residual + hidden_states
837
+
838
+ # Fully Connected
839
+ residual = hidden_states
840
+ hidden_states = self.post_attention_layernorm(hidden_states)
841
+ if self.use_moe:
842
+ share_output = self.share_expert(hidden_states)
843
+ moe_output = self.moe(hidden_states)
844
+ ffn_output = moe_output + share_output
845
+ else:
846
+ ffn_output = self.mlp(hidden_states)
847
+ if isinstance(ffn_output, tuple):
848
+ hidden_states, _ = ffn_output
849
+ else:
850
+ hidden_states = ffn_output
851
+
852
+ hidden_states = residual + hidden_states
853
+ return hidden_states
854
+
855
+
856
+ class Step3p7TextPreTrainedModel(PreTrainedModel):
857
+ # Link this model family to its configuration class so PreTrainedModel.from_pretrained
858
+ # can load the config instead of failing with a NoneType error.
859
+ config_class = Step3p7TextConfig
860
+ supports_gradient_checkpointing = True
861
+ _skip_keys_device_placement = ["past_key_values"]
862
+ _keys_to_ignore_on_load_unexpected = [
863
+ r"model\.layers\.45\.*",
864
+ r"model\.layers\.46\.*",
865
+ r"model\.layers\.47\.*",
866
+ ]
867
+ _supports_flash_attn = False
868
+ _supports_sdpa = True
869
+ _supports_flex_attn = True
870
+ _supports_static_cache = True
871
+ _supports_attention_backend = True
872
+
873
+
874
+ class Step3p7TextModel(Step3p7TextPreTrainedModel, GenerationMixin):
875
+ _no_split_modules = ["Step3p7DecoderLayer"]
876
+ base_model_prefix = "model"
877
+ _tied_weights_keys = ["lm_head.weight"]
878
+ config: Step3p7TextConfig
879
+
880
+ def __init__(self, config: Step3p7TextConfig):
881
+ super().__init__(config)
882
+ self.padding_idx = config.pad_token_id
883
+ self.vocab_size = config.vocab_size
884
+
885
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size,
886
+ self.padding_idx)
887
+ self.layers = nn.ModuleList([
888
+ Step3p7DecoderLayer(config, layer_idx)
889
+ for layer_idx in range(config.num_hidden_layers)
890
+ ])
891
+ self.norm = Step3p7RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
892
+ self.gradient_checkpointing = False
893
+ layer_types = self.config.layer_types or []
894
+ self.has_sliding_layers = (not layer_types or
895
+ "sliding_attention" in layer_types)
896
+
897
+ # Initialize weights and apply final processing
898
+ self.post_init()
899
+
900
+
901
+ def get_input_embeddings(self, input_ids):
902
+ return self.embed_tokens(input_ids)
903
+
904
+ @can_return_tuple
905
+ def forward(
906
+ self,
907
+ input_ids: torch.LongTensor = None,
908
+ attention_mask: Optional[torch.Tensor] = None,
909
+ position_ids: Optional[torch.LongTensor] = None,
910
+ past_key_values: Optional[Cache] = None,
911
+ inputs_embeds: Optional[torch.FloatTensor] = None,
912
+ use_cache: Optional[bool] = None,
913
+ output_attentions: Optional[bool] = None,
914
+ output_hidden_states: Optional[bool] = None,
915
+ return_dict: Optional[bool] = None,
916
+ cache_position: Optional[torch.LongTensor] = None,
917
+ **kwargs: Unpack[TransformersKwargs],
918
+ ) -> Union[tuple, BaseModelOutputWithPast]:
919
+ output_attentions = (
920
+ output_attentions
921
+ if output_attentions is not None
922
+ else self.config.output_attentions
923
+ )
924
+ output_hidden_states = (
925
+ output_hidden_states
926
+ if output_hidden_states is not None
927
+ else self.config.output_hidden_states
928
+ )
929
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
930
+ return_dict = (
931
+ return_dict
932
+ if return_dict is not None
933
+ else getattr(self.config, "return_dict", True)
934
+ )
935
+ if (input_ids is None) ^ (inputs_embeds is not None):
936
+ raise ValueError(
937
+ "You must specify exactly one of input_ids or inputs_embeds")
938
+
939
+ if self.gradient_checkpointing and self.training and use_cache:
940
+ logger.warning_once(
941
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
942
+ )
943
+ use_cache = False
944
+
945
+ if inputs_embeds is None:
946
+ inputs_embeds = self.embed_tokens(
947
+ input_ids.to(self.embed_tokens.weight.device))
948
+
949
+ if use_cache and past_key_values is None:
950
+ past_key_values = DynamicCache()
951
+
952
+ if cache_position is None:
953
+ past_seen_tokens = past_key_values.get_seq_length(
954
+ ) if past_key_values is not None else 0
955
+ cache_position = torch.arange(past_seen_tokens,
956
+ past_seen_tokens +
957
+ inputs_embeds.shape[1],
958
+ device=inputs_embeds.device)
959
+
960
+ if position_ids is None:
961
+ position_ids = cache_position.unsqueeze(0)
962
+
963
+ hidden_states = inputs_embeds
964
+
965
+ # It may already have been prepared by e.g. `generate`
966
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
967
+ # Prepare mask arguments
968
+ mask_kwargs = {
969
+ "config": self.config,
970
+ "attention_mask": attention_mask,
971
+ "cache_position": cache_position,
972
+ "past_key_values": past_key_values,
973
+ "position_ids": position_ids,
974
+ }
975
+ mask_kwargs[_MASK_INPUT_EMBEDS_ARG] = inputs_embeds
976
+ # Create the masks
977
+ causal_mask_mapping = {
978
+ "full_attention": create_causal_mask(**mask_kwargs),
979
+ }
980
+
981
+ # The sliding window alternating layers are not always activated depending on the config
982
+ if self.has_sliding_layers:
983
+ causal_mask_mapping[
984
+ "sliding_attention"] = create_sliding_window_causal_mask(
985
+ **mask_kwargs)
986
+
987
+ # # create position embeddings to be shared across the decoder layers
988
+ # decoder layers
989
+ all_hidden_states = () if output_hidden_states else None
990
+ all_self_attns = () if output_attentions else None
991
+ for decoder_layer in self.layers[:self.config.num_hidden_layers]:
992
+ if output_hidden_states:
993
+ all_hidden_states += (hidden_states, )
994
+
995
+ layer_outputs = decoder_layer(
996
+ hidden_states,
997
+ attention_mask=causal_mask_mapping[
998
+ decoder_layer.attention_type],
999
+ position_ids=position_ids,
1000
+ past_key_value=past_key_values,
1001
+ output_attentions=output_attentions,
1002
+ use_cache=use_cache,
1003
+ cache_position=cache_position,
1004
+ **kwargs,
1005
+ )
1006
+
1007
+ hidden_states = layer_outputs
1008
+
1009
+ hidden_states = self.norm(hidden_states)
1010
+
1011
+ return BaseModelOutputWithPast(
1012
+ last_hidden_state=hidden_states,
1013
+ past_key_values=past_key_values if use_cache else None,
1014
+ hidden_states=all_hidden_states,
1015
+ attentions=all_self_attns,
1016
+ )
1017
+
1018
+
1019
+ class Step3p7Model(Step3p7PreTrainedModel, GenerationMixin):
1020
+ config: Step3p7Config
1021
+ _tied_weights_keys = ["lm_head.weight"]
1022
+ base_model_prefix = ""
1023
+
1024
+ def __init__(self, config: Step3p7Config):
1025
+ super().__init__(config)
1026
+ self.vision_model = StepRoboticsVisionEncoder(config.vision_config)
1027
+ self.language_model = Step3p7TextModel(config.text_config)
1028
+ self.vocab_size = config.text_config.vocab_size
1029
+ self.vit_large_projector = nn.Linear(
1030
+ config.vision_config.width * 4,
1031
+ config.text_config.hidden_size,
1032
+ bias=config.projector_bias)
1033
+ self.image_placeholder_token_id = config.image_token_id
1034
+
1035
+ # Initialize weights and apply final processing
1036
+ self.post_init()
1037
+
1038
+ def get_input_embeddings(
1039
+ self,
1040
+ input_ids: torch.Tensor,
1041
+ multimodal_embeddings = None,
1042
+ ) -> torch.Tensor:
1043
+ # breakpoint()
1044
+ input_ids = input_ids.squeeze(0)
1045
+ if multimodal_embeddings is None:
1046
+ inputs_embeds = self.language_model.get_input_embeddings(input_ids)
1047
+ else:
1048
+ is_text = input_ids != self.config.image_token_id
1049
+ text_ids = input_ids[is_text]
1050
+ text_embeds = self.language_model.get_input_embeddings(text_ids)
1051
+
1052
+ inputs_embeds = torch.empty(input_ids.shape[0],
1053
+ text_embeds.shape[-1],
1054
+ dtype=text_embeds.dtype,
1055
+ device=text_embeds.device)
1056
+ inputs_embeds[is_text] = text_embeds
1057
+ inputs_embeds = merge_multimodal_embeddings(
1058
+ input_ids, inputs_embeds, multimodal_embeddings,
1059
+ self.config.image_token_id)
1060
+ inputs_embeds = inputs_embeds.unsqueeze(0)
1061
+ return inputs_embeds
1062
+
1063
+
1064
+ def set_input_embeddings(self, value):
1065
+ return self.language_model.set_input_embeddings(value)
1066
+
1067
+ def set_decoder(self, decoder):
1068
+ self.language_model = decoder
1069
+
1070
+ def get_decoder(self):
1071
+ return self.language_model
1072
+
1073
+ def _parse_and_validate_image_input(
1074
+ self, **kwargs: object) -> Optional[StepVLImageInputs]:
1075
+ pixel_values = kwargs.pop("pixel_values", None)
1076
+ patch_pixel_values = kwargs.pop("patch_pixel_values", None)
1077
+ num_patches = kwargs.pop("num_patches", None)
1078
+ image_embeds = kwargs.pop("image_embeds", None)
1079
+
1080
+ if pixel_values is None and image_embeds is None:
1081
+ return None
1082
+
1083
+ if pixel_values is not None:
1084
+ # pixel_values = flatten_bn(pixel_values, concat=True)
1085
+ if pixel_values.dim() >= 3:
1086
+ pixel_values = pixel_values.view(-1, *pixel_values.shape[-3:])
1087
+ if patch_pixel_values is not None:
1088
+ # patch_pixel_values = flatten_bn(patch_pixel_values,
1089
+ # concat=True)
1090
+ patch_pixel_values = patch_pixel_values.view(
1091
+ -1, *patch_pixel_values.shape[-3:])
1092
+ # Handle empty patch_pixel_values by setting to None
1093
+ if patch_pixel_values.shape[0] == 0:
1094
+ patch_pixel_values = None
1095
+
1096
+ return StepVLImagePixelInputs(
1097
+ type="pixel_values",
1098
+ pixel_values=pixel_values.to(self.dtype).to(self.device),
1099
+ patch_pixel_values=patch_pixel_values.to(self.dtype).to(
1100
+ self.device) if patch_pixel_values is not None else None,
1101
+ num_patches=num_patches,
1102
+ )
1103
+
1104
+ if image_embeds is not None:
1105
+ if image_embeds.dim() == 2 or image_embeds.dim() >= 3:
1106
+ image_embeds = image_embeds.view(-1, image_embeds.shape[-1])
1107
+ else:
1108
+ raise ValueError(
1109
+ f"Unexpected shape for image_embeds: {image_embeds.shape}")
1110
+
1111
+ return StepVLImageEmbeddingInputs(
1112
+ type="image_embeds",
1113
+ image_embeds=image_embeds.to(self.dtype).to(self.device),
1114
+ )
1115
+ return None
1116
+
1117
+ def _process_image_features(self,
1118
+ image_features: torch.Tensor) -> torch.Tensor:
1119
+ B, P = image_features.shape[:2]
1120
+ HW = int(P ** 0.5)
1121
+ image_features = image_features.permute(0, 2, 1).view(B, -1, HW, HW)
1122
+ image_features = self.vision_model.vit_downsampler1(image_features)
1123
+ image_features = self.vision_model.vit_downsampler2(image_features)
1124
+
1125
+ B, C, HW, HW = image_features.shape
1126
+ image_features = image_features.view(B, -1, HW * HW).permute(0, 2, 1)
1127
+ image_features = self.vit_large_projector(image_features)
1128
+ return image_features
1129
+
1130
+ def _get_vision_model_output(self,
1131
+ input_tensor: torch.Tensor) -> torch.Tensor:
1132
+ return self.vision_model(input_tensor)
1133
+
1134
+ def _process_image_input(
1135
+ self, image_input: StepVLImageInputs) -> tuple[torch.Tensor, ...]:
1136
+
1137
+ if image_input["type"] == "image_embeds":
1138
+ image_features = image_input["image_embeds"]
1139
+ else:
1140
+ image_features = self._get_vision_model_output(
1141
+ image_input["pixel_values"])
1142
+ patch_image_features = self._get_vision_model_output(
1143
+ image_input["patch_pixel_values"]
1144
+ ) if image_input["patch_pixel_values"] is not None else None
1145
+ num_patches = image_input["num_patches"]
1146
+
1147
+ image_features = self._process_image_features(image_features)
1148
+ patch_image_features = self._process_image_features(
1149
+ patch_image_features) if patch_image_features is not None else None
1150
+
1151
+ merged_image_features = []
1152
+ cur_patch_idx = 0
1153
+ for i, num_patch in enumerate(num_patches):
1154
+ cur_feature = []
1155
+ if num_patch > 0:
1156
+ patch_slice = patch_image_features[
1157
+ cur_patch_idx:cur_patch_idx + num_patch]
1158
+ cur_feature.append(patch_slice.view(-1, patch_slice.shape[-1]))
1159
+ cur_feature.append(image_features[i].view(
1160
+ -1, image_features.shape[-1]))
1161
+ cur_patch_idx += num_patch
1162
+ merged_image_features.append(
1163
+ torch.cat(cur_feature) if len(cur_feature) >
1164
+ 1 else cur_feature[0])
1165
+
1166
+ return merged_image_features
1167
+
1168
+ def get_multimodal_embeddings(self, **kwargs):
1169
+ # breakpoint()
1170
+ image_input = self._parse_and_validate_image_input(**kwargs)
1171
+ if image_input is None:
1172
+ return None
1173
+ vision_embeddings = self._process_image_input(image_input)
1174
+ return vision_embeddings
1175
+
1176
+ @can_return_tuple
1177
+ def forward(
1178
+ self,
1179
+ input_ids: torch.LongTensor = None,
1180
+ attention_mask: Optional[torch.Tensor] = None,
1181
+ position_ids: Optional[torch.LongTensor] = None,
1182
+ past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
1183
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1184
+ labels: Optional[torch.LongTensor] = None,
1185
+ use_cache: Optional[bool] = None,
1186
+ output_attentions: Optional[bool] = None,
1187
+ output_hidden_states: Optional[bool] = None,
1188
+ return_dict: Optional[bool] = None,
1189
+ cache_position: Optional[torch.LongTensor] = None,
1190
+ logits_to_keep: Union[int, torch.Tensor] = 0,
1191
+ images: Optional[list[Image.Image]] = None,
1192
+ **kwargs: Unpack[TransformersKwargs],
1193
+ ) -> Union[tuple, Step3p7CausalLMOutputWithPast]:
1194
+ r"""
1195
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1196
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1197
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1198
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1199
+ Example:
1200
+ ```python
1201
+ >>> from transformers import AutoTokenizer, Llama4ForCausalLM
1202
+ >>> model = Llama4ForCausalLM.from_pretrained("meta-llama4/Llama4-2-7b-hf")
1203
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama4/Llama4-2-7b-hf")
1204
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1205
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1206
+ >>> # Generate
1207
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1208
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1209
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1210
+ ```"""
1211
+ output_attentions = (
1212
+ output_attentions
1213
+ if output_attentions is not None
1214
+ else self.config.output_attentions
1215
+ )
1216
+ output_hidden_states = (
1217
+ output_hidden_states
1218
+ if output_hidden_states is not None
1219
+ else self.config.output_hidden_states
1220
+ )
1221
+ return_dict = (
1222
+ return_dict if return_dict is not None else self.config.use_return_dict
1223
+ )
1224
+
1225
+ if inputs_embeds is None:
1226
+ input_ids = input_ids
1227
+ vision_embeddings = self.get_multimodal_embeddings(**kwargs)
1228
+ inputs_embeds = self.get_input_embeddings(input_ids,
1229
+ vision_embeddings)
1230
+ input_ids = None
1231
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1232
+ outputs = self.language_model(
1233
+ input_ids=None,
1234
+ position_ids=position_ids,
1235
+ attention_mask=attention_mask,
1236
+ past_key_values=past_key_values,
1237
+ inputs_embeds=inputs_embeds,
1238
+ use_cache=use_cache,
1239
+ output_attentions=output_attentions,
1240
+ output_hidden_states=output_hidden_states,
1241
+ return_dict=True,
1242
+ cache_position=cache_position,
1243
+ **kwargs,
1244
+ )
1245
+
1246
+ output = Step3p7CausalLMOutputWithPast(
1247
+ last_hidden_state=outputs.last_hidden_state,
1248
+ past_key_values=outputs.past_key_values,
1249
+ attentions=outputs.attentions,
1250
+ )
1251
+ return output if return_dict else output.to_tuple()
1252
+
1253
+
1254
+ class Step3p7ForConditionalGeneration(Step3p7PreTrainedModel, GenerationMixin):
1255
+ _checkpoint_conversion_mapping = {
1256
+ "^vision_model": "model.vision_model",
1257
+ r"^model(?!\.(language_model|vision_model))": "model.language_model",
1258
+ "^vit_large_projector": "model.vit_large_projector",
1259
+ }
1260
+ _tied_weights_keys = ["lm_head.weight"]
1261
+ config: Step3p7Config
1262
+
1263
+ def __init__(self, config: Step3p7Config):
1264
+ super().__init__(config)
1265
+ self.model = Step3p7Model(config)
1266
+ self.lm_head = nn.Linear(config.hidden_size,
1267
+ config.text_config.vocab_size,
1268
+ bias=False)
1269
+
1270
+ self.post_init()
1271
+
1272
+ def get_input_embeddings(self):
1273
+ return self.model.get_input_embeddings()
1274
+
1275
+ def set_input_embeddings(self, value):
1276
+ self.model.set_input_embeddings(value)
1277
+
1278
+ def get_output_embeddings(self):
1279
+ return self.model.get_output_embeddings()
1280
+
1281
+ def set_output_embeddings(self, new_embeddings):
1282
+ self.model.set_output_embeddings(new_embeddings)
1283
+
1284
+ def set_decoder(self, decoder):
1285
+ self.model.set_decoder(decoder)
1286
+
1287
+ def get_decoder(self):
1288
+ return self.model.get_decoder()
1289
+
1290
+ @property
1291
+ def language_model(self):
1292
+ return self.model.language_model
1293
+
1294
+ @property
1295
+ def visual(self):
1296
+ return self.model.vision_model
1297
+
1298
+ def forward(
1299
+ self,
1300
+ input_ids: torch.LongTensor = None,
1301
+ pixel_values: Optional[torch.Tensor] = None,
1302
+ num_patches=None,
1303
+ patch_pixel_values=None,
1304
+ patch_newline_mask=None,
1305
+ image_embeds: Optional[torch.FloatTensor] = None,
1306
+ attention_mask: Optional[torch.Tensor] = None,
1307
+ position_ids: Optional[torch.LongTensor] = None,
1308
+ past_key_values: Optional[Cache] = None,
1309
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1310
+ labels: Optional[torch.LongTensor] = None,
1311
+ use_cache: Optional[bool] = None,
1312
+ output_attentions: Optional[bool] = None,
1313
+ output_hidden_states: Optional[bool] = None,
1314
+ return_dict: Optional[bool] = None,
1315
+ cache_position: Optional[torch.LongTensor] = None,
1316
+ **kwargs: Unpack[TransformersKwargs],
1317
+ ) -> Union[tuple, Step3p7CausalLMOutputWithPast]:
1318
+ output_attentions = (
1319
+ output_attentions
1320
+ if output_attentions is not None
1321
+ else self.config.output_attentions
1322
+ )
1323
+ output_hidden_states = (
1324
+ output_hidden_states
1325
+ if output_hidden_states is not None
1326
+ else self.config.output_hidden_states
1327
+ )
1328
+
1329
+ outputs = self.model(
1330
+ input_ids=input_ids,
1331
+ num_patches=num_patches,
1332
+ patch_pixel_values=patch_pixel_values,
1333
+ patch_newline_mask=patch_newline_mask,
1334
+ position_ids=position_ids,
1335
+ attention_mask=attention_mask,
1336
+ past_key_values=past_key_values,
1337
+ inputs_embeds=inputs_embeds,
1338
+ use_cache=use_cache,
1339
+ output_attentions=output_attentions,
1340
+ output_hidden_states=output_hidden_states,
1341
+ return_dict=return_dict,
1342
+ cache_position=cache_position,
1343
+ **kwargs,
1344
+ )
1345
+
1346
+ hidden_states = outputs.last_hidden_state
1347
+ logits = self.lm_head(hidden_states)
1348
+
1349
+ los = None
1350
+ if labels is not None:
1351
+ loss = self.loss_function(
1352
+ logits=logits, labels=labels, vocab_size=self.config.vocab_size
1353
+ )
1354
+
1355
+ return Step3p7CausalLMOutputWithPast(
1356
+ logits=logits,
1357
+ )
1358
+
1359
+
1360
+ def prepare_inputs_for_generation(
1361
+ self,
1362
+ input_ids,
1363
+ past_key_values=None,
1364
+ inputs_embeds=None,
1365
+ pixel_values=None,
1366
+ patch_pixel_values=None,
1367
+ num_patches=None,
1368
+ image_embeds=None,
1369
+ attention_mask=None,
1370
+ cache_position=None,
1371
+ logits_to_keep=None,
1372
+ **kwargs,
1373
+ ):
1374
+ model_inputs = super().prepare_inputs_for_generation(
1375
+ input_ids,
1376
+ past_key_values=past_key_values,
1377
+ inputs_embeds=inputs_embeds,
1378
+ attention_mask=attention_mask,
1379
+ cache_position=cache_position,
1380
+ logits_to_keep=logits_to_keep,
1381
+ **kwargs,
1382
+ )
1383
+
1384
+ if cache_position[0] == 0:
1385
+ # During cached decoding, input ids no longer contain image tokens,
1386
+ # so pixel values should only be passed at the first step.
1387
+ model_inputs["pixel_values"] = pixel_values
1388
+
1389
+ return model_inputs
1390
+
1391
+ def _fix_state_dict_key_on_load(self, key: str) -> tuple[str, bool]:
1392
+ if key.startswith("language_model."):
1393
+ return key[len("language_model.") :], True
1394
+
1395
+ return key, False
processing_step3.py ADDED
@@ -0,0 +1,464 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BaseImageProcessor, ImageProcessingMixin
2
+ from transformers.processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
3
+ import math
4
+ from typing import Iterable, Optional, Tuple, List, TypedDict, Literal, Union, overload
5
+
6
+ from PIL import Image
7
+ import torch
8
+ import numpy as np
9
+ import torchvision
10
+ from torch import nn
11
+ from torch.nn import functional as F, LayerNorm
12
+ from torchvision.transforms.functional import InterpolationMode
13
+ from transformers.activations import ACT2FN
14
+ from torchvision import transforms
15
+ from torchvision.transforms.functional import InterpolationMode
16
+ from transformers.feature_extraction_utils import BatchFeature, TensorType
17
+ from transformers.image_utils import ImageInput
18
+ from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
19
+ from math import ceil
20
+ from itertools import product
21
+
22
+
23
+
24
+ MAX_IMAGE_SIZE: int = 3024
25
+
26
+ class Step3VLImagePixelInputs(TypedDict):
27
+ type: Literal["pixel_values"]
28
+ pixel_values: torch.Tensor
29
+ patch_pixel_values: Optional[torch.Tensor]
30
+ num_patches: list[int]
31
+
32
+
33
+ class Step3VLImageEmbeddingInputs(TypedDict):
34
+ type: Literal["image_embeds"]
35
+ image_embeds: torch.Tensor
36
+
37
+
38
+ ImageWithPatches = tuple[Image.Image, list[Image.Image], list[int] | None]
39
+
40
+
41
+ class GPUToTensor(torch.nn.Module):
42
+
43
+ def forward(self, raw_image: Union[np.ndarray,
44
+ Image.Image]) -> torch.Tensor:
45
+ if isinstance(raw_image, Image.Image):
46
+ return transforms.ToTensor()(raw_image)
47
+ if raw_image.ndim == 2:
48
+ raw_image = raw_image[:, :, None].repeat(3, -1)
49
+ if torch.cuda.is_available():
50
+ device = torch.device("cuda")
51
+ else:
52
+ device = torch.device("cpu")
53
+ image_tensor = torch.from_numpy(raw_image).to(device)
54
+ image_tensor = torch.permute(image_tensor, (2, 0, 1)).contiguous()
55
+ if image_tensor.dtype == torch.uint8:
56
+ image_tensor = image_tensor.to(torch.float32).div(255)
57
+ return image_tensor
58
+
59
+ class Step3VisionProcessor(BaseImageProcessor):
60
+
61
+ def __init__(self, size, interpolation_mode="bicubic", patch_size=None):
62
+ mean = [0.48145466, 0.4578275, 0.40821073]
63
+ std = [0.26862954, 0.26130258, 0.27577711]
64
+ patch_size = patch_size if patch_size is not None else size
65
+
66
+ self.transform = transforms.Compose([
67
+ GPUToTensor(),
68
+ transforms.Normalize(mean, std),
69
+ transforms.Resize(
70
+ (size, size),
71
+ interpolation=InterpolationMode.BICUBIC if interpolation_mode
72
+ == "bicubic" else InterpolationMode.BILINEAR,
73
+ antialias=True),
74
+ ])
75
+
76
+ self.patch_transform = transforms.Compose([
77
+ GPUToTensor(),
78
+ transforms.Normalize(mean, std),
79
+ transforms.Resize(
80
+ (patch_size, patch_size),
81
+ interpolation=InterpolationMode.BICUBIC if interpolation_mode
82
+ == "bicubic" else InterpolationMode.BILINEAR,
83
+ antialias=True),
84
+ ]) if patch_size is not None else None
85
+
86
+ def __call__(self, image, is_patch=False):
87
+ if is_patch:
88
+ return {"pixel_values": self.patch_transform(image).unsqueeze(0)}
89
+ else:
90
+ return {"pixel_values": self.transform(image).unsqueeze(0)}
91
+
92
+ class ImagePatcher:
93
+ def determine_window_size(self, long: int, short: int) -> int:
94
+ if long <= 728:
95
+ return short if long / short > 1.5 else 0
96
+ return min(short, 504) if long / short > 4 else 504
97
+ def slide_window(
98
+ self,
99
+ width: int,
100
+ height: int,
101
+ sizes: list[tuple[int, int]],
102
+ steps: list[tuple[int, int]],
103
+ img_rate_thr: float = 0.6,
104
+ ) -> tuple[list[tuple[int, int, int, int]], tuple[int, int]]:
105
+ assert 1 >= img_rate_thr >= 0, "The `in_rate_thr` should lie in 0~1"
106
+ windows = []
107
+ # Sliding windows.
108
+ for size, step in zip(sizes, steps):
109
+ size_w, size_h = size
110
+ step_w, step_h = step
111
+
112
+ x_num = 1 if width <= size_w else ceil((width - size_w) / step_w +
113
+ 1)
114
+ x_start = [step_w * i for i in range(x_num)]
115
+ if len(x_start) > 1 and x_start[-1] + size_w > width:
116
+ x_start[-1] = width - size_w
117
+
118
+ y_num = 1 if height <= size_h else ceil((height - size_h) /
119
+ step_h + 1)
120
+ y_start = [step_h * i for i in range(y_num)]
121
+ if len(y_start) > 1 and y_start[-1] + size_h > height:
122
+ y_start[-1] = height - size_h
123
+
124
+ start = np.array(list(product(y_start, x_start)), dtype=int)
125
+ start[:, [0, 1]] = start[:, [1, 0]]
126
+ windows.append(np.concatenate([start, start + size], axis=1))
127
+ windows = np.concatenate(windows, axis=0)
128
+
129
+ return [(int(box[0]), int(box[1]), int(box[2] - box[0]),
130
+ int(box[3] - box[1])) for box in windows], (x_num, y_num)
131
+
132
+ def square_pad(self, img: Image.Image) -> Image.Image:
133
+ w, h = img.size
134
+ if w == h:
135
+ return img
136
+ size = max(w, h)
137
+ padded = Image.new(img.mode, (size, size), 0)
138
+ padded.paste(img, (0, 0))
139
+ return padded
140
+
141
+ def get_image_size_for_padding(self, img_width: int,
142
+ img_height: int) -> tuple[int, int]:
143
+ ratio = img_width / img_height
144
+ if min(img_height, img_width) < 32 and (ratio > 4 or ratio < 1 / 4):
145
+ new_size = max(img_height, img_width)
146
+ return new_size, new_size
147
+ return img_width, img_height
148
+
149
+ def get_image_size_for_preprocess(self, img_width: int,
150
+ img_height: int) -> tuple[int, int]:
151
+
152
+ if max(img_height, img_width) > MAX_IMAGE_SIZE:
153
+ scale_factor = MAX_IMAGE_SIZE / max(img_height, img_width)
154
+ img_width = int(img_width * scale_factor)
155
+ img_height = int(img_height * scale_factor)
156
+ return img_width, img_height
157
+
158
+ def get_image_size_for_crop(self, img_width: int, img_height: int,
159
+ window_size: int):
160
+ w_ratio = img_width / window_size
161
+ h_ratio = img_height / window_size
162
+
163
+ if w_ratio < 1:
164
+ width_new = img_width
165
+ else:
166
+ decimal_w = w_ratio - img_width // window_size
167
+ w_ratio = int(w_ratio) + 1 if decimal_w > 0.2 else int(w_ratio)
168
+ width_new = window_size * w_ratio
169
+ if h_ratio < 1:
170
+ height_new = img_height
171
+ else:
172
+ decimal_h = h_ratio - img_height // window_size
173
+ h_ratio = int(h_ratio) + 1 if decimal_h > 0.2 else int(h_ratio)
174
+ height_new = window_size * h_ratio
175
+ return int(width_new), int(height_new)
176
+
177
+ def patch_crop(self, img: Image.Image, i: int, j: int, th: int, tw: int):
178
+ target = img.crop((j, i, j + tw, i + th))
179
+ return target
180
+
181
+ def get_num_patches(self, img_width: int,
182
+ img_height: int) -> tuple[int, int]:
183
+ img_width, img_height = self.get_image_size_for_padding(
184
+ img_width, img_height)
185
+ img_width, img_height = self.get_image_size_for_preprocess(
186
+ img_width, img_height)
187
+ window_size = self.determine_window_size(max(img_height, img_width),
188
+ min(img_height, img_width))
189
+ if window_size == 0:
190
+ return 0, 0
191
+ else:
192
+ img_width, img_height = self.get_image_size_for_crop(
193
+ img_width, img_height, window_size)
194
+ center_list, (x_num, y_num) = self.slide_window(
195
+ img_width, img_height, [(window_size, window_size)],
196
+ [(window_size, window_size)])
197
+ full_rows = (len(center_list) - 1) // x_num + 1
198
+ if len(center_list) > 0 and len(center_list) % x_num == 0:
199
+ full_rows -= 1
200
+ return len(center_list), full_rows
201
+
202
+ def __call__(
203
+ self, img: Image.Image
204
+ ) -> tuple[Image.Image, list[Image.Image], list[bool] | None]:
205
+ img_width, img_height = img.size
206
+ new_img_width, new_img_height = self.get_image_size_for_padding(
207
+ img_width, img_height)
208
+ if new_img_width != img_width or new_img_height != img_height:
209
+ img = self.square_pad(img)
210
+ img_width, img_height = img.size
211
+
212
+ new_img_width, new_img_height = self.get_image_size_for_preprocess(
213
+ img_width, img_height)
214
+ img = img.resize((new_img_width, new_img_height),
215
+ Image.Resampling.BILINEAR)
216
+ window_size = self.determine_window_size(
217
+ max(new_img_height, new_img_width),
218
+ min(new_img_height, new_img_width))
219
+ # return img, [], None
220
+ if window_size == 0:
221
+ return img, [], None
222
+ else:
223
+ new_img_width, new_img_height = self.get_image_size_for_crop(
224
+ new_img_width, new_img_height, window_size)
225
+ if (new_img_width, new_img_height) != (img_width, img_height):
226
+ img_for_crop = img.resize((new_img_width, new_img_height),
227
+ Image.Resampling.BILINEAR)
228
+ else:
229
+ img_for_crop = img
230
+
231
+ patches = []
232
+ newlines = []
233
+ center_list, (x_num, y_num) = self.slide_window(
234
+ new_img_width, new_img_height, [(window_size, window_size)],
235
+ [(window_size, window_size)])
236
+ for patch_id, center_lf_point in enumerate(center_list):
237
+ x, y, patch_w, patch_h = center_lf_point
238
+ big_patch = self.patch_crop(img_for_crop, y, x, patch_h,
239
+ patch_w)
240
+ patches.append(big_patch)
241
+ if (patch_id + 1) % x_num == 0:
242
+ newlines.append(patch_id)
243
+
244
+ if newlines and newlines[-1] == len(patches) - 1:
245
+ newlines.pop()
246
+
247
+ return img, patches, [i in newlines for i in range(len(patches))] if len(patches) > 0 else None
248
+
249
+
250
+
251
+
252
+ class Step3VLProcessor(ProcessorMixin):
253
+ # Align ProcessorMixin with our custom components.
254
+ # We only have an image processor (not a feature extractor) plus a tokenizer.
255
+ attributes = ["tokenizer"]
256
+ tokenizer_class = "AutoTokenizer"
257
+
258
+ def __init__(
259
+ self,
260
+ tokenizer=None,
261
+ chat_template=None,
262
+ **kwargs
263
+ ) -> None:
264
+ self.image_size = 728
265
+ self.patch_size = 504
266
+
267
+ self.image_preprocessor = Step3VisionProcessor(self.image_size,
268
+ "bilinear",
269
+ self.patch_size)
270
+
271
+ self.num_image_feature_size = 169
272
+ self.num_patch_feature_size = 81
273
+ self.image_token = "<im_patch>"
274
+ self.image_feature_placeholder = (self.image_token *
275
+ self.num_image_feature_size)
276
+ self.patch_feature_placeholder = (self.image_token *
277
+ self.num_patch_feature_size)
278
+ super().__init__(tokenizer=tokenizer, chat_template=chat_template, **kwargs)
279
+ self.patcher = ImagePatcher()
280
+
281
+ @property
282
+ def image_token_id(self) -> int:
283
+ return self.tokenizer.get_vocab()[self.image_token]
284
+
285
+ def get_num_image_tokens(self, img_width: int, img_height: int) -> int:
286
+ num_patches, num_newlines = self.patcher.get_num_patches(
287
+ img_width, img_height)
288
+
289
+ return num_patches * (
290
+ self.num_patch_feature_size +
291
+ 2) + self.num_image_feature_size + 2 + num_newlines
292
+
293
+ def _split_images(self,
294
+ images: list[Image.Image]) -> list[ImageWithPatches]:
295
+ result = []
296
+ for img in images:
297
+ result.append(self.patcher(img))
298
+ return result
299
+
300
+ def _convert_images_to_pixel_values(
301
+ self,
302
+ images: list[Image.Image],
303
+ is_patch: bool = False,
304
+ ) -> list[torch.Tensor]:
305
+ return [
306
+ self.image_preprocessor(img, is_patch=is_patch)["pixel_values"]
307
+ for img in images
308
+ ]
309
+
310
+ def _get_patch_repl(
311
+ self,
312
+ num_patches: int,
313
+ patch_newline_mask: list[bool] | None,
314
+ ) -> tuple[str, list[int]]:
315
+ text = ""
316
+ token_ids = []
317
+ for i in range(num_patches):
318
+ assert len(patch_newline_mask) == num_patches
319
+ text += f"<patch_start>{self.patch_feature_placeholder}<patch_end>"
320
+ token_ids.extend(
321
+ [self.tokenizer.convert_tokens_to_ids("<patch_start>")] +
322
+ [self.image_token_id] * self.num_patch_feature_size +
323
+ [self.tokenizer.convert_tokens_to_ids("<patch_end>")])
324
+ if patch_newline_mask and patch_newline_mask[i]:
325
+ text += "<patch_newline>"
326
+ token_ids.append(
327
+ self.tokenizer.convert_tokens_to_ids("<patch_newline>"))
328
+ return text, token_ids
329
+
330
+ def _get_image_repl(
331
+ self,
332
+ num_images: int,
333
+ ) -> tuple[str, list[int]]:
334
+ text = f"<im_start>{self.image_feature_placeholder}<im_end>"
335
+ token_ids = [
336
+ self.tokenizer.convert_tokens_to_ids("<im_start>")
337
+ ] + [self.image_token_id] * self.num_image_feature_size + [
338
+ self.tokenizer.convert_tokens_to_ids("<im_end>")
339
+ ]
340
+ return text * num_images, token_ids * num_images
341
+
342
+ def _get_image_repl_features(
343
+ self,
344
+ num_images: int,
345
+ num_patches: int,
346
+ patch_new_line_idx: Optional[list[bool]],
347
+ ) -> tuple[str, list[int]]:
348
+ if num_patches > 0:
349
+ patch_repl, patch_repl_ids = self._get_patch_repl(
350
+ num_patches, patch_new_line_idx)
351
+ else:
352
+ patch_repl = ""
353
+ patch_repl_ids = []
354
+ image_repl, image_repl_ids = self._get_image_repl(num_images)
355
+ return patch_repl + image_repl, patch_repl_ids + image_repl_ids
356
+
357
+ def replace_placeholder(self, text: str, placeholder: str,
358
+ repls: list[str]) -> str:
359
+ parts = text.split(placeholder)
360
+
361
+ if len(parts) - 1 != len(repls):
362
+ raise ValueError(
363
+ "The number of placeholders does not match the number of replacements." # noqa: E501
364
+ )
365
+
366
+ result = [parts[0]]
367
+ for i, repl in enumerate(repls):
368
+ result.append(repl)
369
+ result.append(parts[i + 1])
370
+
371
+ return "".join(result)
372
+
373
+ def __call__(
374
+ self,
375
+ text: Optional[Union[str, list[str]]] = None,
376
+ images: ImageInput | None = None,
377
+ return_tensors: Optional[Union[str, TensorType]] = None,
378
+ **kwargs,
379
+ ) -> BatchFeature:
380
+
381
+ if images is not None:
382
+ images = self.image_preprocessor.fetch_images(images)
383
+ if text is None:
384
+ text = []
385
+ if not isinstance(text, list):
386
+ text = [text]
387
+ if images is None:
388
+ images = []
389
+ elif not isinstance(images, list):
390
+ images = [images]
391
+ elif isinstance(images[0], list):
392
+ images = images[0]
393
+
394
+ if len(images) == 0:
395
+ image_inputs = {}
396
+ text_inputs = self.tokenizer(text)
397
+ else:
398
+ splitted_images_data = self._split_images(images)
399
+ pixel_values_lst = []
400
+ patch_pixel_values_lst = []
401
+ patch_newline_mask_lst = []
402
+ image_repl_str_lst = []
403
+ image_repl_ids_lst = []
404
+ num_patches = []
405
+ for raw_img, img_patches, patch_newline_mask in splitted_images_data: # noqa: E501
406
+ pixel_values_lst.extend(
407
+ self._convert_images_to_pixel_values([raw_img]))
408
+
409
+ if len(img_patches) > 0:
410
+ patch_pixel_values_lst.extend(
411
+ self._convert_images_to_pixel_values(img_patches,
412
+ is_patch=True))
413
+ num_patches.append(len(img_patches))
414
+
415
+ image_repl_str, image_repl_ids = self._get_image_repl_features(
416
+ 1, len(img_patches), patch_newline_mask)
417
+ image_repl_str_lst.append(image_repl_str)
418
+ image_repl_ids_lst.extend(image_repl_ids)
419
+
420
+ if patch_newline_mask is not None:
421
+ patch_newline_mask_lst.extend(patch_newline_mask)
422
+
423
+ image_inputs = {
424
+ "pixel_values": torch.cat(pixel_values_lst),
425
+ "num_patches": num_patches,
426
+ }
427
+ if patch_pixel_values_lst:
428
+ image_inputs["patch_pixel_values"] = torch.cat(
429
+ patch_pixel_values_lst)
430
+ if patch_newline_mask_lst:
431
+ image_inputs["patch_newline_mask"] = torch.tensor(
432
+ patch_newline_mask_lst, dtype=torch.bool)
433
+
434
+ text = [
435
+ self.replace_placeholder(t, self.image_token,
436
+ image_repl_str_lst) for t in text
437
+ ]
438
+ text_inputs = self.tokenizer(text)
439
+
440
+ return BatchFeature(
441
+ {
442
+ **text_inputs,
443
+ **image_inputs,
444
+ },
445
+ tensor_type=return_tensors,
446
+ )
447
+
448
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma
449
+ def batch_decode(self, *args, **kwargs):
450
+ """
451
+ This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
452
+ refer to the docstring of this method for more information.
453
+ """
454
+ return self.tokenizer.batch_decode(*args, **kwargs)
455
+
456
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Gemma
457
+ def decode(self, *args, **kwargs):
458
+ """
459
+ This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
460
+ the docstring of this method for more information.
461
+ """
462
+ return self.tokenizer.decode(*args, **kwargs)
463
+
464
+ __all__ = ["Step3VLProcessor"]
vision_encoder.py ADDED
@@ -0,0 +1,452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Literal, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from transformers.activations import ACT2FN
7
+
8
+
9
+ from .configuration_step3p7 import StepRoboticsVisionEncoderConfig
10
+
11
+
12
+
13
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
14
+ """Rotate last dimension halves (used by RoPE)."""
15
+ x = x.reshape(*x.shape[:-1], -1, 2)
16
+ x1, x2 = x.unbind(dim=-1)
17
+ x = torch.stack((-x2, x1), dim=-1)
18
+ return x.reshape(*x.shape[:-2], -1)
19
+
20
+
21
+ def apply_rotary_emb(freqs: torch.Tensor,
22
+ t: torch.Tensor,
23
+ start_index: int = 0,
24
+ scale: float = 1.0,
25
+ seq_dim: int = -2) -> torch.Tensor:
26
+ """Apply 2D rotary embeddings to queries / keys."""
27
+ dtype = t.dtype
28
+
29
+ if t.ndim == 3:
30
+ seq_len = t.shape[seq_dim]
31
+ freqs = freqs[-seq_len:]
32
+
33
+ rot_dim = freqs.shape[-1]
34
+ end_index = start_index + rot_dim
35
+ assert rot_dim <= t.shape[-1], (
36
+ f"feature dimension {t.shape[-1]} is too small for rot_dim {rot_dim}")
37
+
38
+ t_left, t, t_right = (
39
+ t[..., :start_index],
40
+ t[..., start_index:end_index],
41
+ t[..., end_index:],
42
+ )
43
+ t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
44
+ out = torch.cat((t_left, t, t_right), dim=-1)
45
+ return out.type(dtype)
46
+
47
+
48
+ class EncoderRope2D(nn.Module):
49
+ """Cacheable 2D rotary positional embedding."""
50
+
51
+ def __init__(
52
+ self,
53
+ dim: int,
54
+ max_grid_height: int,
55
+ max_grid_width: int,
56
+ use_cls_token: bool = False,
57
+ theta: Union[int, float] = 10000,
58
+ max_freq: int = 10,
59
+ num_freqs: int = 1,
60
+ theta_rescale_factor: float = 1.0,
61
+ ):
62
+ super().__init__()
63
+ self.dim = dim
64
+ self.max_grid_height = max_grid_height
65
+ self.max_grid_width = max_grid_width
66
+ self.use_cls_token = use_cls_token
67
+ self.theta = theta * theta_rescale_factor**(dim / (dim - 2))
68
+ self.max_freq = max_freq
69
+ self.num_freqs = num_freqs
70
+ cache = self._compute_2d_freqs()
71
+ self.register_buffer("freqs_cache", cache, persistent=False)
72
+
73
+ def _compute_inv_freq(self, base: Union[int, float],
74
+ dim: int) -> torch.Tensor:
75
+
76
+ freqs = 1.0 / (base**(
77
+ torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
78
+ return freqs
79
+
80
+ def _compute_freqs(self, t: torch.Tensor, inv_freq: torch.Tensor):
81
+ freqs = torch.einsum("..., f -> ... f", t.type(inv_freq.dtype),
82
+ inv_freq)
83
+ freqs = freqs.repeat_interleave(2, dim=-1)
84
+ return freqs
85
+
86
+ def _compute_2d_freqs(self) -> torch.Tensor:
87
+ grid_h_range = torch.arange(self.max_grid_height, dtype=torch.float)
88
+ grid_w_range = torch.arange(self.max_grid_width, dtype=torch.float)
89
+ if self.use_cls_token:
90
+ grid_h_range += 1
91
+ grid_w_range += 1
92
+ inv_freq = self._compute_inv_freq(self.theta, self.dim // 2)
93
+ freqs_h = self._compute_freqs(grid_h_range, inv_freq)[:, None].expand(
94
+ self.max_grid_height, self.max_grid_width, -1)
95
+ freqs_w = self._compute_freqs(grid_w_range, inv_freq)[None, :].expand(
96
+ self.max_grid_height, self.max_grid_width, -1)
97
+ freqs = torch.cat([freqs_w, freqs_h], dim=-1).reshape(
98
+ self.max_grid_height * self.max_grid_width, -1)
99
+ if self.use_cls_token:
100
+ freqs = torch.cat([torch.zeros(1, freqs.shape[-1]), freqs], dim=0)
101
+ freqs = freqs[None, None, ...]
102
+ return freqs
103
+
104
+ def forward(self, q: torch.Tensor, k: torch.Tensor,
105
+ grid_hw: tuple[int, int]):
106
+ # If grid matches cached shape we reuse directly to avoid recomputation.
107
+ if grid_hw[0] != self.max_grid_height or grid_hw[1] != self.max_grid_width:
108
+ rows = torch.arange(grid_hw[0], device=q.device).view(-1, 1)
109
+ cols = torch.arange(grid_hw[1], device=q.device).view(1, -1)
110
+ positions = (rows * self.max_grid_width + cols).reshape(-1).to(
111
+ torch.long)
112
+ if self.use_cls_token:
113
+ positions = torch.cat(
114
+ [torch.zeros(1, device=q.device), positions + 1], dim=0)
115
+ freqs = self.freqs_cache.index_select(2, positions)
116
+ else:
117
+ freqs = self.freqs_cache
118
+ q = apply_rotary_emb(freqs, q)
119
+ k = apply_rotary_emb(freqs, k)
120
+ return q, k
121
+
122
+
123
+ class EncoderLayerScale(nn.Module):
124
+ """Per-channel residual scaling used when ls_init_value is set."""
125
+
126
+ def __init__(self, dim: int, init_values: float):
127
+ super().__init__()
128
+ self.gamma = nn.Parameter(torch.full((dim,), init_values))
129
+
130
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # (B, L, D)
131
+ return hidden_states * self.gamma
132
+
133
+
134
+ class EncoderMLP(nn.Module):
135
+ """Feed-forward network used inside each transformer block."""
136
+
137
+ def __init__(self, hidden_size: int, intermediate_size: int,
138
+ hidden_act: str):
139
+ super().__init__()
140
+ self.c_fc = nn.Linear(hidden_size, intermediate_size, bias=True)
141
+ self.act_fn = ACT2FN[hidden_act]
142
+ self.c_proj = nn.Linear(intermediate_size, hidden_size, bias=True)
143
+
144
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
145
+
146
+ hidden_states = self.c_proj(self.act_fn(self.c_fc(hidden_states)))
147
+ return hidden_states
148
+
149
+
150
+ class EncoderVisionAttention(nn.Module):
151
+ """Multi-head self attention with optional 2D RoPE."""
152
+
153
+ def __init__(
154
+ self,
155
+ hidden_size: int,
156
+ num_heads: int,
157
+ max_grid_height: int,
158
+ max_grid_width: int,
159
+ use_cls_token: bool = False,
160
+ use_rope2d: bool = True,
161
+ rope_theta: Union[int, float] = 10000,
162
+ rope_max_freq: int = 10,
163
+ rope_num_freqs: int = 1,
164
+ rope_theta_rescale_factor: float = 1.0,
165
+ rope_freqs_for: Literal["lang", "pixel", "constant"] = "lang",
166
+ ):
167
+ super().__init__()
168
+ if hidden_size % num_heads != 0:
169
+ raise ValueError(
170
+ f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})."
171
+ )
172
+ self.num_heads = num_heads
173
+ self.head_dim = hidden_size // num_heads
174
+ self.scale = self.head_dim**-0.5
175
+ self.in_proj_weight = nn.Parameter(torch.zeros(hidden_size * 3, hidden_size))
176
+ self.in_proj_bias = nn.Parameter(torch.zeros(hidden_size * 3))
177
+ self.out_proj = nn.Linear(hidden_size, hidden_size, bias=True)
178
+
179
+ self.rope = None
180
+ if use_rope2d:
181
+ self.rope = EncoderRope2D(
182
+ dim=self.head_dim,
183
+ max_grid_height=max_grid_height,
184
+ max_grid_width=max_grid_width,
185
+ use_cls_token=use_cls_token,
186
+ theta=rope_theta,
187
+ max_freq=rope_max_freq,
188
+ num_freqs=rope_num_freqs,
189
+ theta_rescale_factor=rope_theta_rescale_factor,
190
+ )
191
+
192
+ def forward(self, hidden_states: torch.Tensor, grid_hw: tuple[int, int]) -> torch.Tensor:
193
+ bsz, seq_len, _ = hidden_states.shape
194
+ qkv = F.linear(
195
+ hidden_states,
196
+ self.in_proj_weight,
197
+ self.in_proj_bias,
198
+ )
199
+ q, k, v = qkv.chunk(3, dim=-1)
200
+
201
+ q = q.view(bsz, seq_len, self.num_heads,
202
+ self.head_dim).transpose(1, 2)
203
+ k = k.view(bsz, seq_len, self.num_heads,
204
+ self.head_dim).transpose(1, 2)
205
+ if self.rope is not None:
206
+ q, k = self.rope(q, k, grid_hw=grid_hw)
207
+ v = v.view(bsz, seq_len, self.num_heads,
208
+ self.head_dim).transpose(1, 2)
209
+
210
+ attn_output = F.scaled_dot_product_attention(
211
+ q, k, v, is_causal=False, scale=self.scale)
212
+ attn_output = attn_output.transpose(1, 2).reshape(
213
+ bsz, seq_len, self.num_heads * self.head_dim)
214
+ return self.out_proj(attn_output)
215
+
216
+
217
+ class EncoderVisionBlock(nn.Module):
218
+ """A single Vision Transformer block (self-attention + MLP)."""
219
+
220
+ def __init__(
221
+ self,
222
+ hidden_size: int,
223
+ num_heads: int,
224
+ mlp_ratio: float,
225
+ hidden_act: str,
226
+ layer_norm_eps: float,
227
+ ls_init_value: Optional[float] = None,
228
+ max_grid_height: Optional[int] = None,
229
+ max_grid_width: Optional[int] = None,
230
+ use_cls_token: bool = False,
231
+ use_rope2d: bool = True,
232
+ rope_kwargs: Optional[dict] = None,
233
+ ):
234
+ super().__init__()
235
+ rope_kwargs = rope_kwargs or {}
236
+ self.attn = EncoderVisionAttention(
237
+ hidden_size,
238
+ num_heads,
239
+ max_grid_height=max_grid_height,
240
+ max_grid_width=max_grid_width,
241
+ use_cls_token=use_cls_token,
242
+ use_rope2d=use_rope2d,
243
+ **rope_kwargs,
244
+ )
245
+ self.ln_1 = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
246
+ self.ln_2 = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
247
+
248
+ intermediate = int(hidden_size * mlp_ratio)
249
+ self.mlp = EncoderMLP(hidden_size, intermediate, hidden_act)
250
+
251
+ self.ls_1 = EncoderLayerScale(hidden_size, ls_init_value)
252
+ self.ls_2 = EncoderLayerScale(hidden_size, ls_init_value)
253
+
254
+ def forward(self, hidden_states: torch.Tensor,
255
+ grid_hw: tuple[int, int]) -> torch.Tensor:
256
+ # breakpoint()
257
+ residual = hidden_states
258
+ hidden_states = self.ln_1(hidden_states)
259
+ hidden_states = self.attn(hidden_states, grid_hw=grid_hw)
260
+ hidden_states = residual + self.ls_1(hidden_states)
261
+
262
+ residual = hidden_states
263
+ hidden_states = self.ln_2(hidden_states)
264
+ hidden_states = self.mlp(hidden_states)
265
+ hidden_states = residual + self.ls_2(hidden_states)
266
+ return hidden_states
267
+
268
+
269
+ class EncoderVisionTransformer(nn.Module):
270
+ """Stack of encoder blocks parameterised by Step35VisionEncoderConfig."""
271
+
272
+ def __init__(
273
+ self,
274
+ embed_dim: int,
275
+ depth: int,
276
+ num_heads: int,
277
+ mlp_ratio: float,
278
+ hidden_act: str,
279
+ layer_norm_eps: float,
280
+ ls_init_value: Optional[float] = None,
281
+ max_grid_height: Optional[int] = None,
282
+ max_grid_width: Optional[int] = None,
283
+ use_cls_token: bool = False,
284
+ use_rope2d: bool = True,
285
+ rope_kwargs: Optional[dict] = None,
286
+ ):
287
+ super().__init__()
288
+ self.layers = depth
289
+ rope_kwargs = rope_kwargs or {}
290
+ self.resblocks = nn.ModuleList([
291
+ EncoderVisionBlock(embed_dim, num_heads, mlp_ratio, hidden_act,
292
+ layer_norm_eps,
293
+ max_grid_height=max_grid_height,
294
+ max_grid_width=max_grid_width,
295
+ use_cls_token=use_cls_token,
296
+ use_rope2d=use_rope2d,
297
+ ls_init_value=ls_init_value,
298
+ rope_kwargs=rope_kwargs)
299
+ for _ in range(depth)
300
+ ])
301
+
302
+ def forward(self,
303
+ hidden_states: torch.Tensor,
304
+ grid_hw: tuple[int, int]) -> torch.Tensor:
305
+ for block in self.resblocks:
306
+ hidden_states = block(hidden_states, grid_hw=grid_hw)
307
+ return hidden_states
308
+
309
+
310
+ class StepRoboticsVisionEncoder(nn.Module):
311
+ """
312
+ Vision encoder built from StepRoboticsVisionEncoderConfig.
313
+
314
+ The encoder performs patch embedding followed by a stack of transformer
315
+ blocks. Only the config fields defined in StepRoboticsVisionEncoderConfig (and
316
+ StepRoboticVLConfig.vision_config) are expected.
317
+ """
318
+
319
+ def __init__(self, config: StepRoboticsVisionEncoderConfig):
320
+ super().__init__()
321
+ self.config = config
322
+
323
+ # Align commonly used attributes so downstream code (e.g. StepRoboticVL)
324
+ # can access them without extra renaming.
325
+ self.hidden_size = config.width
326
+ self.num_heads = config.heads
327
+ self.num_hidden_layers = config.layers
328
+ self.patch_size = config.patch_size
329
+ self.image_size = config.image_size
330
+ self.use_cls_token = getattr(config, "use_cls_token", False)
331
+ self.use_rope2d = getattr(config, "use_rope2d", True)
332
+ self.use_abs_posemb = getattr(config, "use_abs_posemb", True)
333
+ self.layer_norm_eps = config.layer_norm_eps
334
+ self.mlp_ratio = getattr(config, "mlp_ratio", 8960 / 1536)
335
+ self.ls_init_value = getattr(config, "ls_init_value", None)
336
+ self.hidden_act = config.hidden_act
337
+ self.use_ln_pre = getattr(config, "use_ln_pre", False)
338
+ self.use_ln_post = getattr(config, "use_ln_post", True)
339
+
340
+ # Patch embedding.
341
+ self.conv1 = nn.Conv2d(in_channels=config.num_channels,
342
+ out_channels=self.hidden_size,
343
+ kernel_size=self.patch_size,
344
+ stride=self.patch_size,
345
+ bias=False)
346
+
347
+ self.ln_pre = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) if self.use_ln_pre else nn.Identity()
348
+ self.ln_post = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) if self.use_ln_post else nn.Identity()
349
+
350
+ grid_size = self.image_size // self.patch_size
351
+ self.base_grid = (grid_size, grid_size)
352
+
353
+ if self.use_cls_token:
354
+ self.class_embedding = nn.Parameter(
355
+ torch.randn(self.hidden_size) * (self.hidden_size**-0.5))
356
+ else:
357
+ self.class_embedding = None
358
+
359
+ if self.use_abs_posemb:
360
+ self.posemb_grid_size = self.image_size // self.patch_size
361
+ self.positional_embedding = nn.Parameter(
362
+ (self.hidden_size**-0.5) * torch.randn(
363
+ int(self.use_cls_token) + self.posemb_grid_size**2,
364
+ self.hidden_size,
365
+ ))
366
+
367
+ self.transformer = EncoderVisionTransformer(
368
+ embed_dim=self.hidden_size,
369
+ depth=self.num_hidden_layers,
370
+ num_heads=self.num_heads,
371
+ mlp_ratio=self.mlp_ratio,
372
+ hidden_act=self.hidden_act,
373
+ layer_norm_eps=self.layer_norm_eps,
374
+ ls_init_value=self.ls_init_value,
375
+ max_grid_height=self.base_grid[0],
376
+ max_grid_width=self.base_grid[1],
377
+ use_cls_token=self.use_cls_token,
378
+ use_rope2d=self.use_rope2d,
379
+ rope_kwargs={
380
+ "rope_theta": getattr(config, "rope_theta", 10000),
381
+ "rope_max_freq": getattr(config, "rope_max_freq", 10),
382
+ "rope_num_freqs": getattr(config, "rope_num_freqs", 1),
383
+ "rope_theta_rescale_factor":
384
+ getattr(config, "rope_theta_rescale_factor", 1.0),
385
+ "rope_freqs_for": getattr(config, "rope_freqs_for", "lang"),
386
+ },
387
+ )
388
+ self.vit_downsampler1 = nn.Conv2d(self.hidden_size,
389
+ self.hidden_size * 2,
390
+ kernel_size=3,
391
+ stride=2,
392
+ padding=1)
393
+ self.vit_downsampler2 = nn.Conv2d(self.hidden_size * 2,
394
+ self.hidden_size * 4,
395
+ kernel_size=3,
396
+ stride=2,
397
+ padding=1)
398
+
399
+
400
+ def sample_abs_posemb(self, grid_h: int, grid_w: int):
401
+ if self.posemb_grid_size == grid_h and self.posemb_grid_size == grid_w:
402
+ return self.positional_embedding[None, ...]
403
+
404
+ pos_embed = self.positional_embedding
405
+ if self.use_cls_token:
406
+ cls_token_embed, pos_embed = pos_embed[:1], pos_embed[1:]
407
+
408
+ pos_embed = (pos_embed.reshape(1, self.posemb_grid_size,
409
+ self.posemb_grid_size,
410
+ -1).permute(0, 3, 1, 2).contiguous())
411
+ pos_embed = F.interpolate(pos_embed,
412
+ size=(grid_h, grid_w),
413
+ mode="bilinear",
414
+ align_corners=False)
415
+ pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(-1, self.hidden_size)
416
+
417
+ if self.use_cls_token:
418
+ pos_embed = torch.cat([cls_token_embed, pos_embed], dim=0)
419
+
420
+ return pos_embed[None, ...]
421
+
422
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
423
+ """
424
+ Args:
425
+ pixel_values: Image tensor of shape (B, C, H, W).
426
+ layer_idx: Negative indices stop after a given block (e.g., -1 uses all blocks).
427
+ strip_cls_token: If True and cls token is used, remove it from output.
428
+ """
429
+ bsz, _, height, width = pixel_values.shape
430
+ grid_h, grid_w = height // self.patch_size, width // self.patch_size
431
+
432
+ hidden_state = self.conv1(pixel_values) # (B, D, Gh, Gw)
433
+ hidden_state = hidden_state.flatten(2).transpose(1, 2) # (B, Gh*Gw, D)
434
+
435
+ if self.use_cls_token:
436
+ cls_token = self.class_embedding.view(1, 1,
437
+ -1).expand(bsz, -1, -1)
438
+ hidden_state = torch.cat([cls_token, hidden_state], dim=1)
439
+
440
+ if self.use_abs_posemb:
441
+ pos_emb = self.sample_abs_posemb(grid_h, grid_w)
442
+ hidden_state = hidden_state + pos_emb
443
+ hidden_state = self.ln_pre(hidden_state)
444
+ hidden_state = self.transformer(hidden_state, grid_hw=(grid_h, grid_w))
445
+
446
+ if self.use_ln_post:
447
+ hidden_state = self.ln_post(hidden_state)
448
+
449
+ if self.use_cls_token:
450
+ hidden_state = hidden_state[:, 1:, :]
451
+
452
+ return hidden_state