Osaurus-AI committed
Commit 660e38c · verified · 1 Parent(s): c714b66

Add files using upload-large-folder tool

Files changed (50)
  1. README.md +100 -0
  2. chat_template.jinja +89 -0
  3. config.json +2650 -0
  4. configuration_step3p7.py +207 -0
  5. generation_config.json +10 -0
  6. hf_quant_config.json +145 -0
  7. jang_config.json +70 -0
  8. model-00001-of-00067.safetensors +3 -0
  9. model-00004-of-00067.safetensors +3 -0
  10. model-00005-of-00067.safetensors +3 -0
  11. model-00012-of-00067.safetensors +3 -0
  12. model-00013-of-00067.safetensors +3 -0
  13. model-00016-of-00067.safetensors +3 -0
  14. model-00017-of-00067.safetensors +3 -0
  15. model-00018-of-00067.safetensors +3 -0
  16. model-00019-of-00067.safetensors +3 -0
  17. model-00020-of-00067.safetensors +3 -0
  18. model-00021-of-00067.safetensors +3 -0
  19. model-00024-of-00067.safetensors +3 -0
  20. model-00025-of-00067.safetensors +3 -0
  21. model-00027-of-00067.safetensors +3 -0
  22. model-00032-of-00067.safetensors +3 -0
  23. model-00033-of-00067.safetensors +3 -0
  24. model-00036-of-00067.safetensors +3 -0
  25. model-00037-of-00067.safetensors +3 -0
  26. model-00038-of-00067.safetensors +3 -0
  27. model-00039-of-00067.safetensors +3 -0
  28. model-00040-of-00067.safetensors +3 -0
  29. model-00041-of-00067.safetensors +3 -0
  30. model-00044-of-00067.safetensors +3 -0
  31. model-00045-of-00067.safetensors +3 -0
  32. model-00052-of-00067.safetensors +3 -0
  33. model-00053-of-00067.safetensors +3 -0
  34. model-00056-of-00067.safetensors +3 -0
  35. model-00057-of-00067.safetensors +3 -0
  36. model-00058-of-00067.safetensors +3 -0
  37. model-00059-of-00067.safetensors +3 -0
  38. model-00060-of-00067.safetensors +3 -0
  39. model-00061-of-00067.safetensors +3 -0
  40. model-00064-of-00067.safetensors +3 -0
  41. model-00065-of-00067.safetensors +3 -0
  42. model-00066-of-00067.safetensors +3 -0
  43. model.safetensors.index.json +0 -0
  44. modeling_step3p7.py +1405 -0
  45. processing_step3.py +475 -0
  46. special_tokens_map.json +23 -0
  47. step3p7_mlx.py +39 -0
  48. tokenizer.json +0 -0
  49. tokenizer_config.json +0 -0
  50. vision_encoder.py +452 -0
README.md ADDED
@@ -0,0 +1,100 @@
+ ---
+ language:
+ - en
+ - zh
+ license: apache-2.0
+ base_model: stepfun-ai/Step-3.7-Flash-NVFP4
+ pipeline_tag: image-text-to-text
+ library_name: mlx
+ tags:
+ - mlx
+ - jang
+ - jang-2l
+ - stepfun
+ - vision-language
+ ---
+
+ # Step-3.7-Flash-JANG_2L
+
+ JANG_2L conversion of [stepfun-ai/Step-3.7-Flash-NVFP4](https://huggingface.co/stepfun-ai/Step-3.7-Flash-NVFP4).
+
+ This bundle was built from the public NVFP4 checkpoint. Routed MoE tensors were decoded from ModelOpt NVFP4 (`uint8` payload, `float8_e4m3fn` block scales, fp32 side scales) and then re-quantized into JANG affine `weight`/`scales`/`biases` tensors. BF16 attention, shared expert, dense, vision, and projector tensors were handled according to the JANG plan.
+
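As a rough illustration of the decode step described in the README paragraph above, the sketch below unpacks FP4 (E2M1) codes from a `uint8` payload and applies a block scale and a side scale. The E2M1 magnitude table is the standard one; the nibble order (low nibble first) and the helper names `fp4_to_float` and `decode_nvfp4_block` are assumptions made for illustration, and the block scale is passed as an already-decoded float rather than raw `float8_e4m3fn` bits. This is not the converter's actual code.

```python
# Magnitudes representable by FP4 E2M1 (1 sign bit, 2 exponent bits, 1 mantissa bit).
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def fp4_to_float(code: int) -> float:
    """Decode one 4-bit E2M1 code: bit 3 is the sign, bits 0-2 index the table."""
    magnitude = E2M1_MAGNITUDES[code & 0x7]
    return -magnitude if code & 0x8 else magnitude


def decode_nvfp4_block(payload: bytes, block_scale: float, side_scale: float) -> list:
    """Unpack two FP4 codes per byte and apply both scales.

    Assumption: the low nibble of each byte holds the earlier element.
    """
    values = []
    for byte in payload:
        for code in (byte & 0x0F, byte >> 4):
            values.append(fp4_to_float(code) * block_scale * side_scale)
    return values


# Hypothetical 2-byte payload holding the codes 0x1, 0x2, 0xF, 0x0.
decoded = decode_nvfp4_block(bytes([0x21, 0x0F]), block_scale=2.0, side_scale=0.5)
print(decoded)
```

The decoded floats would then be the input to the affine re-quantization step; keeping the two stages separate means the output index carries none of the NVFP4 scale tensors.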
+ ## Status
+
+ This artifact has passed a text-only local coherence check through the bundled `step3p7_mlx.py` bridge, which loads the nested Step3p5 text model with MLX and drops the vision tensors for text generation.
+
+ Verified locally:
+
+ - 67 safetensors shards
+ - 2,570 tensors in `model.safetensors.index.json`
+ - No missing shard references
+ - No raw NVFP4 `weight_scale`, `weight_scale_2`, or `input_scale` sidecars in the output index
+ - `jang_config.json` capability verification passes
+ - Text generation proof passes on a math prompt
+
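Two of the checks in the list above (no missing shard references, no leftover NVFP4 sidecars) could be reproduced along these lines. This is a minimal sketch, assuming the index follows the usual Hugging Face layout where `weight_map` maps tensor names to shard file names; `check_index` is a hypothetical helper, not part of the bundle.

```python
RAW_NVFP4_SUFFIXES = ("weight_scale", "weight_scale_2", "input_scale")


def check_index(weight_map: dict, present_files: set) -> dict:
    """Report shards that are referenced but absent, and leftover NVFP4 sidecar keys."""
    missing = sorted({shard for shard in weight_map.values() if shard not in present_files})
    leftovers = sorted(
        name for name in weight_map
        # Compare only the last dotted component, so JANG `scales`/`biases` are not flagged.
        if name.rsplit(".", 1)[-1] in RAW_NVFP4_SUFFIXES
    )
    return {"missing_shards": missing, "raw_nvfp4_keys": leftovers}


# Toy index; a real run would load `weight_map` from model.safetensors.index.json
# and list the shard files actually present in the bundle directory.
toy_map = {
    "model.layers.0.mlp.up_proj.weight": "model-00001-of-00067.safetensors",
    "model.layers.0.mlp.up_proj.scales": "model-00001-of-00067.safetensors",
    "model.layers.0.mlp.up_proj.weight_scale": "model-00002-of-00067.safetensors",
}
print(check_index(toy_map, {"model-00001-of-00067.safetensors"}))
```

An empty list under both keys corresponds to the two "No ..." items above.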
+ Text proof:
+
+ ```json
+ {
+   "prompt": "What is 2+2? Answer with only the number.",
+   "output": "The user asks \"What is 2+2? Answer with only the number.\" So the answer is 4. The user wants only the number. So we should output \"4\". There's no disallowed content. It's a simple arithmetic. So we comply.\\n</think>\\n4",
+   "prompt_tokens": 26,
+   "generated_tokens": 58,
+   "prefill_s": 9.161997079849243,
+   "total_s": 15.426342725753784,
+   "decode_tok_s": 9.25874836391233
+ }
+ ```
+
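The timing fields in the text proof are mutually consistent if the decode rate is taken as generated tokens divided by the time left after prefill. That formula is an inference from the numbers, not something the README states; the sketch below just checks the arithmetic.

```python
proof = {
    "generated_tokens": 58,
    "prefill_s": 9.161997079849243,
    "total_s": 15.426342725753784,
    "decode_tok_s": 9.25874836391233,
}


def decode_rate(generated_tokens: int, prefill_s: float, total_s: float) -> float:
    """Tokens per second over the decode phase only (total time minus prefill)."""
    return generated_tokens / (total_s - prefill_s)


rate = decode_rate(proof["generated_tokens"], proof["prefill_s"], proof["total_s"])
print(f"recomputed {rate:.3f} tok/s, reported {proof['decode_tok_s']:.3f} tok/s")
```

Roughly 6.26 s of the 15.43 s total is decode time, so prefill accounts for most of the wall clock on this short prompt.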
+ Still required before full VLM runtime claims:
+
+ - Step3p7 VLM wrapper in the target MLX/vMLX runtime
+ - image patch token expansion and vision projector path
+
+ ## Format
+
+ - Format: JANG affine
+ - Profile: `JANG_2L`
+ - Quantization backend: `mx.quantize`
+ - Default group size: `128`
+ - Bit widths used: `2`, `3`, `4`, `6`, `8`
+ - Vision/projector: BF16 source converted to F16 passthrough for this first artifact
+ - Output size: about `82G`
+ - Runtime bridge: `step3p7_mlx.py` wraps `mlx_lm.models.step3p5` for text-only proof
+
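For readers unfamiliar with affine group quantization, the idea behind the `weight`/`scales`/`biases` triple can be sketched in pure Python. Each group stores integer codes plus one scale and one bias, and a weight is reconstructed as `code * scale + bias`. The function names are hypothetical, and `mx.quantize` itself differs in details such as bit packing and storage dtypes, so treat this as a conceptual sketch only.

```python
def quantize_group(values: list, bits: int):
    """Map one group of floats to integer codes in [0, 2**bits - 1] plus scale and bias."""
    levels = (1 << bits) - 1
    lo, hi = min(values), max(values)
    scale = (hi - lo) / levels if hi > lo else 1.0  # avoid dividing by zero on flat groups
    codes = [round((v - lo) / scale) for v in values]
    return codes, scale, lo


def dequantize_group(codes: list, scale: float, bias: float) -> list:
    """Reconstruct approximate floats from codes with the affine map code * scale + bias."""
    return [c * scale + bias for c in codes]


# Tiny 4-element group for readability; the bundle's default group size is 128.
group = [-1.0, -0.25, 0.5, 1.0]
codes, scale, bias = quantize_group(group, bits=2)
restored = dequantize_group(codes, scale, bias)
worst = max(abs(a - b) for a, b in zip(group, restored))
print(codes, round(scale, 4), bias, round(worst, 4))
```

The worst-case reconstruction error per element is half a quantization step, which is why lower bit widths are reserved for tensors judged less sensitive.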
+ Important allocation choices:
+
+ - `self_attn.{q,k,v,o,g}_proj`: 8-bit
+ - `embed_tokens`: 6-bit
+ - routed experts: `gate_proj=4`, `down_proj=3`, `up_proj=2`
+ - true router/gate tensors: passthrough where present
+
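These allocations show up in the `quantization` block of `config.json` in this commit as per-module entries next to top-level default `bits` and `group_size` values. A loader might resolve them roughly as follows; `quant_params` is a hypothetical helper, the fall-back-to-defaults behaviour is an assumption about how a runtime would read the block, and the dictionary is a small excerpt of the real one.

```python
QUANTIZATION = {
    "bits": 2,
    "format": "jang",
    "group_size": 128,
    "model.embed_tokens": {"bits": 6, "group_size": 128},
    "model.layers.10.mlp.gate.gate": {"bits": 8, "group_size": 64},
    "model.layers.10.mlp.switch_mlp.gate_proj": {"bits": 4, "group_size": 128},
    "model.layers.10.mlp.switch_mlp.down_proj": {"bits": 3, "group_size": 128},
    "model.layers.10.mlp.switch_mlp.up_proj": {"bits": 2, "group_size": 128},
}


def quant_params(config: dict, module_path: str) -> tuple:
    """Return (bits, group_size) for a module, using the top-level defaults if unlisted."""
    override = config.get(module_path)
    if isinstance(override, dict):
        return override["bits"], override["group_size"]
    return config["bits"], config["group_size"]


for path in ("model.embed_tokens", "model.layers.10.mlp.switch_mlp.down_proj", "some.unlisted.module"):
    print(path, quant_params(QUANTIZATION, path))
```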
+ ## Runtime Metadata
+
+ `jang_config.json` stamps:
+
+ ```json
+ {
+   "reasoning_parser": "qwen3",
+   "tool_parser": "step3p5",
+   "think_in_template": true,
+   "supports_tools": true,
+   "supports_thinking": true,
+   "family": "step3p7",
+   "modality": "vision",
+   "cache_type": "kv"
+ }
+ ```
+
+ The source chat template opens the assistant generation prompt inside `<think>`. Runtimes should not add a second synthetic reasoning prefix.
+
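Because the generation prompt already ends with `<think>`, the generated text begins with reasoning and only the closing `</think>` appears in the output. A runtime could separate the two parts roughly like this; `split_reasoning` is a hypothetical helper, and treating output with no `</think>` as unfinished reasoning is a design choice of this sketch rather than documented behaviour of the `qwen3` parser.

```python
def split_reasoning(generated: str) -> tuple:
    """Split generated text into (reasoning, answer) at the first closing think tag."""
    reasoning, closer, answer = generated.partition("</think>")
    if not closer:
        # No closing tag yet: assume the model was cut off while still reasoning.
        return generated.strip(), ""
    return reasoning.strip(), answer.strip()


reasoning, answer = split_reasoning("So the answer is 4. So we comply.\n</think>\n4")
print(repr(reasoning))
print(repr(answer))
```

Prepending another `<think>` before decoding would leave the model inside two nested reasoning openers, which is the double prefix the README warns against.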
+ ## Vision And Audio
+
+ The source checkpoint contains the Step vision encoder and `vit_large_projector`. No audio tensors or audio tokenizer files were present in the downloaded checkpoint.
+
+ The source config mentions next-token prediction layers, but no MTP/nextn tensors were present in the NVFP4 source. This bundle does not synthesize MTP tensors from config fields.
+
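+ ## Summary
+
+ This bundle is the result of converting stepfun-ai/Step-3.7-Flash-NVFP4 to the JANG_2L format. The text path has passed local generation verification through the `step3p7_mlx.py` bridge. The vision weights are included, but the image input path still needs a separate runtime implementation and verification. No audio tensors were present in the original checkpoint.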
chat_template.jinja ADDED
@@ -0,0 +1,89 @@
+ {% macro render_message_content(message) %}{% if message.content is none %}{{- '' }}{% elif message.content is string %}{{- message.content }}{% elif message.content is mapping %}{{- message.content['value'] if 'value' in message.content else message.content['text'] }}{% elif message.content is iterable %}{% set ns = namespace(needs_text_separator=false) %}{% for item in message.content %}{% if item.type == 'text' %}{% if ns.needs_text_separator %}{{- ' ' }}{% endif %}{{- item['value'] if 'value' in item else item['text'] }}{% set ns.needs_text_separator = true %}{% elif item.type == 'image' %}<im_patch>{% set ns.needs_text_separator = false %}{% endif %}{% endfor %}{% endif %}{% endmacro %}
+ {{bos_token}}{%- if tools %}
+ {{- '<|im_start|>system\n' }}
+ {%- if reasoning_effort is defined %}
+ {{- "Reasoning: " + reasoning_effort + '\n\n' }}
+ {%- endif %}
+ {%- if messages[0].role == 'system' %}
+ {{- render_message_content(messages[0]) + '\n\n' }}
+ {%- endif %}
+ {{- "# Tools\n\nYou have access to the following functions in JSONSchema format:\n\n<tools>" }}
+ {%- for tool in tools %}
+ {{- "\n" }}
+ {{- tool | tojson(ensure_ascii=False) }}
+ {%- endfor %}
+ {{- "\n</tools>\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...>\n...\n</function> block must be nested within <tool_call>\n...\n</tool_call> XML tags\n- Required parameters MUST be specified\n</IMPORTANT><|im_end|>\n" }}
+ {%- else %}
+ {%- if messages[0].role == 'system' %}
+ {{- '<|im_start|>system\n' }}
+ {%- if reasoning_effort is defined %}
+ {{- "Reasoning: " + reasoning_effort + '\n\n' }}
+ {%- endif %}
+ {{- render_message_content(messages[0]) + '<|im_end|>\n' }}
+ {%- elif reasoning_effort is defined %}
+ {{- '<|im_start|>system\n' + "Reasoning: " + reasoning_effort + '\n\n' + '<|im_end|>\n' }}
+ {%- endif %}
+ {%- endif %}
+ {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
+ {%- for message in messages[::-1] %}
+ {%- set index = (messages|length - 1) - loop.index0 %}
+ {%- if ns.multi_step_tool and message.role == "user" and render_message_content(message) is string and not(render_message_content(message).startswith('<tool_response>') and render_message_content(message).endswith('</tool_response>')) %}
+ {%- set ns.multi_step_tool = false %}
+ {%- set ns.last_query_index = index %}
+ {%- endif %}
+ {%- endfor %}
+ {%- for message in messages %}
+ {%- set content = render_message_content(message) %}
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
+ {%- set role_name = 'observation' if (message.role == "system" and not loop.first and message.name == 'observation') else message.role %}
+ {{- '<|im_start|>' + role_name + '\n' + content + '<|im_end|>' + '\n' }}
+ {%- elif message.role == "assistant" %}
+ {%- if message.reasoning_content is string %}
+ {%- set reasoning_content = message.reasoning_content %}
+ {%- else %}
+ {%- if '</think>' in content %}
+ {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
+ {%- set content = content.split('</think>')[-1].lstrip('\n') %}
+ {%- else %}
+ {%- set reasoning_content = '' %}
+ {%- endif %}
+ {%- endif %}
+ {%- if loop.index0 > ns.last_query_index %}
+ {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content + '\n</think>\n' + content }}
+ {%- else %}
+ {{- '<|im_start|>' + message.role + '\n' + content }}
+ {%- endif %}
+ {%- if message.tool_calls %}
+ {%- for tool_call in message.tool_calls %}
+ {%- if tool_call.function is defined %}
+ {%- set tool_call = tool_call.function %}
+ {%- endif %}
+ {{- '<tool_call>\n<function=' + tool_call.name + '>\n' }}
+ {%- if tool_call.arguments is defined %}
+ {%- set arguments = tool_call.arguments | fromjson if tool_call.arguments is string else tool_call.arguments %}
+ {%- for args_name, args_value in arguments|items %}
+ {{- '<parameter=' + args_name + '>\n' }}
+ {%- set args_value = args_value | tojson(ensure_ascii=False) | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}
+ {{- args_value }}
+ {{- '\n</parameter>\n' }}
+ {%- endfor %}
+ {%- endif %}
+ {{- '</function>\n</tool_call>' }}
+ {%- endfor %}
+ {%- endif %}
+ {{- '<|im_end|>\n' }}
+ {%- elif message.role == "tool" %}
+ {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
+ {{- '<|im_start|>tool_response\n' }}
+ {%- endif %}
+ {{- '<tool_response>' }}
+ {{- content }}
+ {{- '</tool_response>' }}
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
+ {{- '<|im_end|>\n' }}
+ {%- endif %}
+ {%- endif %}
+ {%- endfor %}
+ {%- if add_generation_prompt %}
+ {{- '<|im_start|>assistant\n<think>\n' }}
+ {%- endif %}
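To make the tool-call branch of the template above easier to follow, here is a minimal Python sketch that produces the same textual shape for one call. `render_tool_call` is a hypothetical helper written for illustration; it mirrors the template's rule that mappings and non-string sequences are JSON-encoded while other values are stringified, and it assumes `json.dumps` matches the template's `tojson` filter closely enough for simple values.

```python
import json


def render_tool_call(name: str, arguments) -> str:
    """Render one tool call in the <tool_call>/<function=...>/<parameter=...> layout."""
    if isinstance(arguments, str):
        # The template parses string arguments with `fromjson` before iterating.
        arguments = json.loads(arguments)
    parts = ["<tool_call>\n<function=" + name + ">\n"]
    for key, value in arguments.items():
        if isinstance(value, (dict, list, tuple)):
            text = json.dumps(value, ensure_ascii=False)
        else:
            text = str(value)
        parts.append("<parameter=" + key + ">\n" + text + "\n</parameter>\n")
    parts.append("</function>\n</tool_call>")
    return "".join(parts)


# Hypothetical tool name and arguments, for illustration only.
print(render_tool_call("get_weather", {"city": "Paris", "days": [1, 2]}))
```

A parser on the runtime side would have to invert this layout, which is presumably the job of the `step3p5` tool parser named in `jang_config.json`.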
config.json ADDED
@@ -0,0 +1,2650 @@
+ {
+   "architectures": [
+     "Step3p7ForConditionalGeneration"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_step3p7.Step3p7Config",
+     "AutoModelForCausalLM": "modeling_step3p7.Step3p7ForConditionalGeneration",
+     "AutoProcessor": "processing_step3.Step3VLProcessor"
+   },
+   "dtype": "bfloat16",
+   "hidden_size": 4096,
+   "im_end_token": "<im_end>",
+   "im_patch_token": "<im_patch>",
+   "im_start_token": "<im_start>",
+   "image_token_id": 128001,
+   "image_token_len": 169,
+   "max_position_embeddings": 262144,
+   "model_file": "step3p7_mlx.py",
+   "model_type": "step3p7",
+   "patch_token_len": 81,
+   "projector_bias": false,
+   "quantization": {
+     "bits": 2,
+     "format": "jang",
+     "group_size": 128,
+     "lm_head": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.embed_tokens": {
+       "bits": 6,
+       "group_size": 128
+     },
+     "model.layers.0.mlp.down_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.0.mlp.gate_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.0.mlp.up_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.0.self_attn.g_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.0.self_attn.k_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.0.self_attn.o_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.0.self_attn.q_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.0.self_attn.v_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.1.mlp.down_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.1.mlp.gate_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.1.mlp.up_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.1.self_attn.g_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.1.self_attn.k_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.1.self_attn.o_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.1.self_attn.q_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.1.self_attn.v_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.10.mlp.gate.gate": {
+       "bits": 8,
+       "group_size": 64
+     },
+     "model.layers.10.mlp.share_expert.down_proj": {
+       "bits": 3,
+       "group_size": 128
+     },
+     "model.layers.10.mlp.share_expert.gate_proj": {
+       "bits": 4,
+       "group_size": 128
+     },
+     "model.layers.10.mlp.share_expert.up_proj": {
+       "bits": 2,
+       "group_size": 128
+     },
+     "model.layers.10.mlp.switch_mlp.down_proj": {
+       "bits": 3,
+       "group_size": 128
+     },
+     "model.layers.10.mlp.switch_mlp.gate_proj": {
+       "bits": 4,
+       "group_size": 128
+     },
+     "model.layers.10.mlp.switch_mlp.up_proj": {
+       "bits": 2,
+       "group_size": 128
+     },
+     "model.layers.10.self_attn.g_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.10.self_attn.k_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.10.self_attn.o_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.10.self_attn.q_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.10.self_attn.v_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.11.mlp.gate.gate": {
+       "bits": 8,
+       "group_size": 64
+     },
+     "model.layers.11.mlp.share_expert.down_proj": {
+       "bits": 3,
+       "group_size": 128
+     },
+     "model.layers.11.mlp.share_expert.gate_proj": {
+       "bits": 4,
+       "group_size": 128
+     },
+     "model.layers.11.mlp.share_expert.up_proj": {
+       "bits": 2,
+       "group_size": 128
+     },
+     "model.layers.11.mlp.switch_mlp.down_proj": {
+       "bits": 3,
+       "group_size": 128
+     },
+     "model.layers.11.mlp.switch_mlp.gate_proj": {
+       "bits": 4,
+       "group_size": 128
+     },
+     "model.layers.11.mlp.switch_mlp.up_proj": {
+       "bits": 2,
+       "group_size": 128
+     },
+     "model.layers.11.self_attn.g_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.11.self_attn.k_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.11.self_attn.o_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.11.self_attn.q_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.11.self_attn.v_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.12.mlp.gate.gate": {
+       "bits": 8,
+       "group_size": 64
+     },
+     "model.layers.12.mlp.share_expert.down_proj": {
+       "bits": 3,
+       "group_size": 128
+     },
+     "model.layers.12.mlp.share_expert.gate_proj": {
+       "bits": 4,
+       "group_size": 128
+     },
+     "model.layers.12.mlp.share_expert.up_proj": {
+       "bits": 2,
+       "group_size": 128
+     },
+     "model.layers.12.mlp.switch_mlp.down_proj": {
+       "bits": 3,
+       "group_size": 128
+     },
+     "model.layers.12.mlp.switch_mlp.gate_proj": {
+       "bits": 4,
+       "group_size": 128
+     },
+     "model.layers.12.mlp.switch_mlp.up_proj": {
+       "bits": 2,
+       "group_size": 128
+     },
+     "model.layers.12.self_attn.g_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.12.self_attn.k_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.12.self_attn.o_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.12.self_attn.q_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.12.self_attn.v_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.13.mlp.gate.gate": {
+       "bits": 8,
+       "group_size": 64
+     },
+     "model.layers.13.mlp.share_expert.down_proj": {
+       "bits": 3,
+       "group_size": 128
+     },
+     "model.layers.13.mlp.share_expert.gate_proj": {
+       "bits": 4,
+       "group_size": 128
+     },
+     "model.layers.13.mlp.share_expert.up_proj": {
+       "bits": 2,
+       "group_size": 128
+     },
+     "model.layers.13.mlp.switch_mlp.down_proj": {
+       "bits": 3,
+       "group_size": 128
+     },
+     "model.layers.13.mlp.switch_mlp.gate_proj": {
+       "bits": 4,
+       "group_size": 128
+     },
+     "model.layers.13.mlp.switch_mlp.up_proj": {
+       "bits": 2,
+       "group_size": 128
+     },
+     "model.layers.13.self_attn.g_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.13.self_attn.k_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.13.self_attn.o_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.13.self_attn.q_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.13.self_attn.v_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.14.mlp.gate.gate": {
+       "bits": 8,
+       "group_size": 64
+     },
+     "model.layers.14.mlp.share_expert.down_proj": {
+       "bits": 3,
+       "group_size": 128
+     },
+     "model.layers.14.mlp.share_expert.gate_proj": {
+       "bits": 4,
+       "group_size": 128
+     },
+     "model.layers.14.mlp.share_expert.up_proj": {
+       "bits": 2,
+       "group_size": 128
+     },
+     "model.layers.14.mlp.switch_mlp.down_proj": {
+       "bits": 3,
+       "group_size": 128
+     },
+     "model.layers.14.mlp.switch_mlp.gate_proj": {
+       "bits": 4,
+       "group_size": 128
+     },
+     "model.layers.14.mlp.switch_mlp.up_proj": {
+       "bits": 2,
+       "group_size": 128
+     },
+     "model.layers.14.self_attn.g_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.14.self_attn.k_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.14.self_attn.o_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.14.self_attn.q_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.14.self_attn.v_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.15.mlp.gate.gate": {
+       "bits": 8,
+       "group_size": 64
+     },
+     "model.layers.15.mlp.share_expert.down_proj": {
+       "bits": 3,
+       "group_size": 128
+     },
+     "model.layers.15.mlp.share_expert.gate_proj": {
+       "bits": 4,
+       "group_size": 128
+     },
+     "model.layers.15.mlp.share_expert.up_proj": {
+       "bits": 2,
+       "group_size": 128
+     },
+     "model.layers.15.mlp.switch_mlp.down_proj": {
+       "bits": 3,
+       "group_size": 128
+     },
+     "model.layers.15.mlp.switch_mlp.gate_proj": {
+       "bits": 4,
+       "group_size": 128
+     },
+     "model.layers.15.mlp.switch_mlp.up_proj": {
+       "bits": 2,
+       "group_size": 128
+     },
+     "model.layers.15.self_attn.g_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.15.self_attn.k_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.15.self_attn.o_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.15.self_attn.q_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.15.self_attn.v_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.16.mlp.gate.gate": {
+       "bits": 8,
+       "group_size": 64
+     },
+     "model.layers.16.mlp.share_expert.down_proj": {
+       "bits": 3,
+       "group_size": 128
+     },
+     "model.layers.16.mlp.share_expert.gate_proj": {
+       "bits": 4,
+       "group_size": 128
+     },
+     "model.layers.16.mlp.share_expert.up_proj": {
+       "bits": 2,
+       "group_size": 128
+     },
+     "model.layers.16.mlp.switch_mlp.down_proj": {
+       "bits": 3,
+       "group_size": 128
+     },
+     "model.layers.16.mlp.switch_mlp.gate_proj": {
+       "bits": 4,
+       "group_size": 128
+     },
+     "model.layers.16.mlp.switch_mlp.up_proj": {
+       "bits": 2,
+       "group_size": 128
+     },
+     "model.layers.16.self_attn.g_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.16.self_attn.k_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.16.self_attn.o_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.16.self_attn.q_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.16.self_attn.v_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.17.mlp.gate.gate": {
+       "bits": 8,
+       "group_size": 64
+     },
+     "model.layers.17.mlp.share_expert.down_proj": {
+       "bits": 3,
+       "group_size": 128
+     },
+     "model.layers.17.mlp.share_expert.gate_proj": {
+       "bits": 4,
+       "group_size": 128
+     },
+     "model.layers.17.mlp.share_expert.up_proj": {
+       "bits": 2,
+       "group_size": 128
+     },
+     "model.layers.17.mlp.switch_mlp.down_proj": {
+       "bits": 3,
+       "group_size": 128
+     },
+     "model.layers.17.mlp.switch_mlp.gate_proj": {
+       "bits": 4,
+       "group_size": 128
+     },
+     "model.layers.17.mlp.switch_mlp.up_proj": {
+       "bits": 2,
+       "group_size": 128
+     },
+     "model.layers.17.self_attn.g_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.17.self_attn.k_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.17.self_attn.o_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.17.self_attn.q_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.17.self_attn.v_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.18.mlp.gate.gate": {
+       "bits": 8,
+       "group_size": 64
+     },
+     "model.layers.18.mlp.share_expert.down_proj": {
+       "bits": 3,
+       "group_size": 128
+     },
+     "model.layers.18.mlp.share_expert.gate_proj": {
+       "bits": 4,
+       "group_size": 128
+     },
+     "model.layers.18.mlp.share_expert.up_proj": {
+       "bits": 2,
+       "group_size": 128
+     },
+     "model.layers.18.mlp.switch_mlp.down_proj": {
+       "bits": 3,
+       "group_size": 128
+     },
+     "model.layers.18.mlp.switch_mlp.gate_proj": {
+       "bits": 4,
+       "group_size": 128
+     },
+     "model.layers.18.mlp.switch_mlp.up_proj": {
+       "bits": 2,
+       "group_size": 128
+     },
+     "model.layers.18.self_attn.g_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.18.self_attn.k_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.18.self_attn.o_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.18.self_attn.q_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.18.self_attn.v_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.19.mlp.gate.gate": {
+       "bits": 8,
+       "group_size": 64
+     },
+     "model.layers.19.mlp.share_expert.down_proj": {
+       "bits": 3,
+       "group_size": 128
+     },
+     "model.layers.19.mlp.share_expert.gate_proj": {
+       "bits": 4,
+       "group_size": 128
+     },
+     "model.layers.19.mlp.share_expert.up_proj": {
+       "bits": 2,
+       "group_size": 128
+     },
+     "model.layers.19.mlp.switch_mlp.down_proj": {
+       "bits": 3,
+       "group_size": 128
+     },
+     "model.layers.19.mlp.switch_mlp.gate_proj": {
+       "bits": 4,
+       "group_size": 128
+     },
+     "model.layers.19.mlp.switch_mlp.up_proj": {
+       "bits": 2,
+       "group_size": 128
+     },
+     "model.layers.19.self_attn.g_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.19.self_attn.k_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.19.self_attn.o_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.19.self_attn.q_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.19.self_attn.v_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.2.mlp.down_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.2.mlp.gate_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.2.mlp.up_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.2.self_attn.g_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.2.self_attn.k_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.2.self_attn.o_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.2.self_attn.q_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
+     "model.layers.2.self_attn.v_proj": {
+       "bits": 8,
+       "group_size": 128
+     },
610
+ "model.layers.20.mlp.gate.gate": {
611
+ "bits": 8,
612
+ "group_size": 64
613
+ },
614
+ "model.layers.20.mlp.share_expert.down_proj": {
615
+ "bits": 3,
616
+ "group_size": 128
617
+ },
618
+ "model.layers.20.mlp.share_expert.gate_proj": {
619
+ "bits": 4,
620
+ "group_size": 128
621
+ },
622
+ "model.layers.20.mlp.share_expert.up_proj": {
623
+ "bits": 2,
624
+ "group_size": 128
625
+ },
626
+ "model.layers.20.mlp.switch_mlp.down_proj": {
627
+ "bits": 3,
628
+ "group_size": 128
629
+ },
630
+ "model.layers.20.mlp.switch_mlp.gate_proj": {
631
+ "bits": 4,
632
+ "group_size": 128
633
+ },
634
+ "model.layers.20.mlp.switch_mlp.up_proj": {
635
+ "bits": 2,
636
+ "group_size": 128
637
+ },
638
+ "model.layers.20.self_attn.g_proj": {
639
+ "bits": 8,
640
+ "group_size": 128
641
+ },
642
+ "model.layers.20.self_attn.k_proj": {
643
+ "bits": 8,
644
+ "group_size": 128
645
+ },
646
+ "model.layers.20.self_attn.o_proj": {
647
+ "bits": 8,
648
+ "group_size": 128
649
+ },
650
+ "model.layers.20.self_attn.q_proj": {
651
+ "bits": 8,
652
+ "group_size": 128
653
+ },
654
+ "model.layers.20.self_attn.v_proj": {
655
+ "bits": 8,
656
+ "group_size": 128
657
+ },
658
+ "model.layers.21.mlp.gate.gate": {
659
+ "bits": 8,
660
+ "group_size": 64
661
+ },
662
+ "model.layers.21.mlp.share_expert.down_proj": {
663
+ "bits": 3,
664
+ "group_size": 128
665
+ },
666
+ "model.layers.21.mlp.share_expert.gate_proj": {
667
+ "bits": 4,
668
+ "group_size": 128
669
+ },
670
+ "model.layers.21.mlp.share_expert.up_proj": {
671
+ "bits": 2,
672
+ "group_size": 128
673
+ },
674
+ "model.layers.21.mlp.switch_mlp.down_proj": {
675
+ "bits": 3,
676
+ "group_size": 128
677
+ },
678
+ "model.layers.21.mlp.switch_mlp.gate_proj": {
679
+ "bits": 4,
680
+ "group_size": 128
681
+ },
682
+ "model.layers.21.mlp.switch_mlp.up_proj": {
683
+ "bits": 2,
684
+ "group_size": 128
685
+ },
686
+ "model.layers.21.self_attn.g_proj": {
687
+ "bits": 8,
688
+ "group_size": 128
689
+ },
690
+ "model.layers.21.self_attn.k_proj": {
691
+ "bits": 8,
692
+ "group_size": 128
693
+ },
694
+ "model.layers.21.self_attn.o_proj": {
695
+ "bits": 8,
696
+ "group_size": 128
697
+ },
698
+ "model.layers.21.self_attn.q_proj": {
699
+ "bits": 8,
700
+ "group_size": 128
701
+ },
702
+ "model.layers.21.self_attn.v_proj": {
703
+ "bits": 8,
704
+ "group_size": 128
705
+ },
706
+ "model.layers.22.mlp.gate.gate": {
707
+ "bits": 8,
708
+ "group_size": 64
709
+ },
710
+ "model.layers.22.mlp.share_expert.down_proj": {
711
+ "bits": 3,
712
+ "group_size": 128
713
+ },
714
+ "model.layers.22.mlp.share_expert.gate_proj": {
715
+ "bits": 4,
716
+ "group_size": 128
717
+ },
718
+ "model.layers.22.mlp.share_expert.up_proj": {
719
+ "bits": 2,
720
+ "group_size": 128
721
+ },
722
+ "model.layers.22.mlp.switch_mlp.down_proj": {
723
+ "bits": 3,
724
+ "group_size": 128
725
+ },
726
+ "model.layers.22.mlp.switch_mlp.gate_proj": {
727
+ "bits": 4,
728
+ "group_size": 128
729
+ },
730
+ "model.layers.22.mlp.switch_mlp.up_proj": {
731
+ "bits": 2,
732
+ "group_size": 128
733
+ },
734
+ "model.layers.22.self_attn.g_proj": {
735
+ "bits": 8,
736
+ "group_size": 128
737
+ },
738
+ "model.layers.22.self_attn.k_proj": {
739
+ "bits": 8,
740
+ "group_size": 128
741
+ },
742
+ "model.layers.22.self_attn.o_proj": {
743
+ "bits": 8,
744
+ "group_size": 128
745
+ },
746
+ "model.layers.22.self_attn.q_proj": {
747
+ "bits": 8,
748
+ "group_size": 128
749
+ },
750
+ "model.layers.22.self_attn.v_proj": {
751
+ "bits": 8,
752
+ "group_size": 128
753
+ },
754
+ "model.layers.23.mlp.gate.gate": {
755
+ "bits": 8,
756
+ "group_size": 64
757
+ },
758
+ "model.layers.23.mlp.share_expert.down_proj": {
759
+ "bits": 3,
760
+ "group_size": 128
761
+ },
762
+ "model.layers.23.mlp.share_expert.gate_proj": {
763
+ "bits": 4,
764
+ "group_size": 128
765
+ },
766
+ "model.layers.23.mlp.share_expert.up_proj": {
767
+ "bits": 2,
768
+ "group_size": 128
769
+ },
770
+ "model.layers.23.mlp.switch_mlp.down_proj": {
771
+ "bits": 3,
772
+ "group_size": 128
773
+ },
774
+ "model.layers.23.mlp.switch_mlp.gate_proj": {
775
+ "bits": 4,
776
+ "group_size": 128
777
+ },
778
+ "model.layers.23.mlp.switch_mlp.up_proj": {
779
+ "bits": 2,
780
+ "group_size": 128
781
+ },
782
+ "model.layers.23.self_attn.g_proj": {
783
+ "bits": 8,
784
+ "group_size": 128
785
+ },
786
+ "model.layers.23.self_attn.k_proj": {
787
+ "bits": 8,
788
+ "group_size": 128
789
+ },
790
+ "model.layers.23.self_attn.o_proj": {
791
+ "bits": 8,
792
+ "group_size": 128
793
+ },
794
+ "model.layers.23.self_attn.q_proj": {
795
+ "bits": 8,
796
+ "group_size": 128
797
+ },
798
+ "model.layers.23.self_attn.v_proj": {
799
+ "bits": 8,
800
+ "group_size": 128
801
+ },
802
+ "model.layers.24.mlp.gate.gate": {
803
+ "bits": 8,
804
+ "group_size": 64
805
+ },
806
+ "model.layers.24.mlp.share_expert.down_proj": {
807
+ "bits": 3,
808
+ "group_size": 128
809
+ },
810
+ "model.layers.24.mlp.share_expert.gate_proj": {
811
+ "bits": 4,
812
+ "group_size": 128
813
+ },
814
+ "model.layers.24.mlp.share_expert.up_proj": {
815
+ "bits": 2,
816
+ "group_size": 128
817
+ },
818
+ "model.layers.24.mlp.switch_mlp.down_proj": {
819
+ "bits": 3,
820
+ "group_size": 128
821
+ },
822
+ "model.layers.24.mlp.switch_mlp.gate_proj": {
823
+ "bits": 4,
824
+ "group_size": 128
825
+ },
826
+ "model.layers.24.mlp.switch_mlp.up_proj": {
827
+ "bits": 2,
828
+ "group_size": 128
829
+ },
830
+ "model.layers.24.self_attn.g_proj": {
831
+ "bits": 8,
832
+ "group_size": 128
833
+ },
834
+ "model.layers.24.self_attn.k_proj": {
835
+ "bits": 8,
836
+ "group_size": 128
837
+ },
838
+ "model.layers.24.self_attn.o_proj": {
839
+ "bits": 8,
840
+ "group_size": 128
841
+ },
842
+ "model.layers.24.self_attn.q_proj": {
843
+ "bits": 8,
844
+ "group_size": 128
845
+ },
846
+ "model.layers.24.self_attn.v_proj": {
847
+ "bits": 8,
848
+ "group_size": 128
849
+ },
850
+ "model.layers.25.mlp.gate.gate": {
851
+ "bits": 8,
852
+ "group_size": 64
853
+ },
854
+ "model.layers.25.mlp.share_expert.down_proj": {
855
+ "bits": 3,
856
+ "group_size": 128
857
+ },
858
+ "model.layers.25.mlp.share_expert.gate_proj": {
859
+ "bits": 4,
860
+ "group_size": 128
861
+ },
862
+ "model.layers.25.mlp.share_expert.up_proj": {
863
+ "bits": 2,
864
+ "group_size": 128
865
+ },
866
+ "model.layers.25.mlp.switch_mlp.down_proj": {
867
+ "bits": 3,
868
+ "group_size": 128
869
+ },
870
+ "model.layers.25.mlp.switch_mlp.gate_proj": {
871
+ "bits": 4,
872
+ "group_size": 128
873
+ },
874
+ "model.layers.25.mlp.switch_mlp.up_proj": {
875
+ "bits": 2,
876
+ "group_size": 128
877
+ },
878
+ "model.layers.25.self_attn.g_proj": {
879
+ "bits": 8,
880
+ "group_size": 128
881
+ },
882
+ "model.layers.25.self_attn.k_proj": {
883
+ "bits": 8,
884
+ "group_size": 128
885
+ },
886
+ "model.layers.25.self_attn.o_proj": {
887
+ "bits": 8,
888
+ "group_size": 128
889
+ },
890
+ "model.layers.25.self_attn.q_proj": {
891
+ "bits": 8,
892
+ "group_size": 128
893
+ },
894
+ "model.layers.25.self_attn.v_proj": {
895
+ "bits": 8,
896
+ "group_size": 128
897
+ },
898
+ "model.layers.26.mlp.gate.gate": {
899
+ "bits": 8,
900
+ "group_size": 64
901
+ },
902
+ "model.layers.26.mlp.share_expert.down_proj": {
903
+ "bits": 3,
904
+ "group_size": 128
905
+ },
906
+ "model.layers.26.mlp.share_expert.gate_proj": {
907
+ "bits": 4,
908
+ "group_size": 128
909
+ },
910
+ "model.layers.26.mlp.share_expert.up_proj": {
911
+ "bits": 2,
912
+ "group_size": 128
913
+ },
914
+ "model.layers.26.mlp.switch_mlp.down_proj": {
915
+ "bits": 3,
916
+ "group_size": 128
917
+ },
918
+ "model.layers.26.mlp.switch_mlp.gate_proj": {
919
+ "bits": 4,
920
+ "group_size": 128
921
+ },
922
+ "model.layers.26.mlp.switch_mlp.up_proj": {
923
+ "bits": 2,
924
+ "group_size": 128
925
+ },
926
+ "model.layers.26.self_attn.g_proj": {
927
+ "bits": 8,
928
+ "group_size": 128
929
+ },
930
+ "model.layers.26.self_attn.k_proj": {
931
+ "bits": 8,
932
+ "group_size": 128
933
+ },
934
+ "model.layers.26.self_attn.o_proj": {
935
+ "bits": 8,
936
+ "group_size": 128
937
+ },
938
+ "model.layers.26.self_attn.q_proj": {
939
+ "bits": 8,
940
+ "group_size": 128
941
+ },
942
+ "model.layers.26.self_attn.v_proj": {
943
+ "bits": 8,
944
+ "group_size": 128
945
+ },
946
+ "model.layers.27.mlp.gate.gate": {
947
+ "bits": 8,
948
+ "group_size": 64
949
+ },
950
+ "model.layers.27.mlp.share_expert.down_proj": {
951
+ "bits": 3,
952
+ "group_size": 128
953
+ },
954
+ "model.layers.27.mlp.share_expert.gate_proj": {
955
+ "bits": 4,
956
+ "group_size": 128
957
+ },
958
+ "model.layers.27.mlp.share_expert.up_proj": {
959
+ "bits": 2,
960
+ "group_size": 128
961
+ },
962
+ "model.layers.27.mlp.switch_mlp.down_proj": {
963
+ "bits": 3,
964
+ "group_size": 128
965
+ },
966
+ "model.layers.27.mlp.switch_mlp.gate_proj": {
967
+ "bits": 4,
968
+ "group_size": 128
969
+ },
970
+ "model.layers.27.mlp.switch_mlp.up_proj": {
971
+ "bits": 2,
972
+ "group_size": 128
973
+ },
974
+ "model.layers.27.self_attn.g_proj": {
975
+ "bits": 8,
976
+ "group_size": 128
977
+ },
978
+ "model.layers.27.self_attn.k_proj": {
979
+ "bits": 8,
980
+ "group_size": 128
981
+ },
982
+ "model.layers.27.self_attn.o_proj": {
983
+ "bits": 8,
984
+ "group_size": 128
985
+ },
986
+ "model.layers.27.self_attn.q_proj": {
987
+ "bits": 8,
988
+ "group_size": 128
989
+ },
990
+ "model.layers.27.self_attn.v_proj": {
991
+ "bits": 8,
992
+ "group_size": 128
993
+ },
994
+ "model.layers.28.mlp.gate.gate": {
995
+ "bits": 8,
996
+ "group_size": 64
997
+ },
998
+ "model.layers.28.mlp.share_expert.down_proj": {
999
+ "bits": 3,
1000
+ "group_size": 128
1001
+ },
1002
+ "model.layers.28.mlp.share_expert.gate_proj": {
1003
+ "bits": 4,
1004
+ "group_size": 128
1005
+ },
1006
+ "model.layers.28.mlp.share_expert.up_proj": {
1007
+ "bits": 2,
1008
+ "group_size": 128
1009
+ },
1010
+ "model.layers.28.mlp.switch_mlp.down_proj": {
1011
+ "bits": 3,
1012
+ "group_size": 128
1013
+ },
1014
+ "model.layers.28.mlp.switch_mlp.gate_proj": {
1015
+ "bits": 4,
1016
+ "group_size": 128
1017
+ },
1018
+ "model.layers.28.mlp.switch_mlp.up_proj": {
1019
+ "bits": 2,
1020
+ "group_size": 128
1021
+ },
1022
+ "model.layers.28.self_attn.g_proj": {
1023
+ "bits": 8,
1024
+ "group_size": 128
1025
+ },
1026
+ "model.layers.28.self_attn.k_proj": {
1027
+ "bits": 8,
1028
+ "group_size": 128
1029
+ },
1030
+ "model.layers.28.self_attn.o_proj": {
1031
+ "bits": 8,
1032
+ "group_size": 128
1033
+ },
1034
+ "model.layers.28.self_attn.q_proj": {
1035
+ "bits": 8,
1036
+ "group_size": 128
1037
+ },
1038
+ "model.layers.28.self_attn.v_proj": {
1039
+ "bits": 8,
1040
+ "group_size": 128
1041
+ },
1042
+ "model.layers.29.mlp.gate.gate": {
1043
+ "bits": 8,
1044
+ "group_size": 64
1045
+ },
1046
+ "model.layers.29.mlp.share_expert.down_proj": {
1047
+ "bits": 3,
1048
+ "group_size": 128
1049
+ },
1050
+ "model.layers.29.mlp.share_expert.gate_proj": {
1051
+ "bits": 4,
1052
+ "group_size": 128
1053
+ },
1054
+ "model.layers.29.mlp.share_expert.up_proj": {
1055
+ "bits": 2,
1056
+ "group_size": 128
1057
+ },
1058
+ "model.layers.29.mlp.switch_mlp.down_proj": {
1059
+ "bits": 3,
1060
+ "group_size": 128
1061
+ },
1062
+ "model.layers.29.mlp.switch_mlp.gate_proj": {
1063
+ "bits": 4,
1064
+ "group_size": 128
1065
+ },
1066
+ "model.layers.29.mlp.switch_mlp.up_proj": {
1067
+ "bits": 2,
1068
+ "group_size": 128
1069
+ },
1070
+ "model.layers.29.self_attn.g_proj": {
1071
+ "bits": 8,
1072
+ "group_size": 128
1073
+ },
1074
+ "model.layers.29.self_attn.k_proj": {
1075
+ "bits": 8,
1076
+ "group_size": 128
1077
+ },
1078
+ "model.layers.29.self_attn.o_proj": {
1079
+ "bits": 8,
1080
+ "group_size": 128
1081
+ },
1082
+ "model.layers.29.self_attn.q_proj": {
1083
+ "bits": 8,
1084
+ "group_size": 128
1085
+ },
1086
+ "model.layers.29.self_attn.v_proj": {
1087
+ "bits": 8,
1088
+ "group_size": 128
1089
+ },
1090
+ "model.layers.3.mlp.gate.gate": {
1091
+ "bits": 8,
1092
+ "group_size": 64
1093
+ },
1094
+ "model.layers.3.mlp.share_expert.down_proj": {
1095
+ "bits": 3,
1096
+ "group_size": 128
1097
+ },
1098
+ "model.layers.3.mlp.share_expert.gate_proj": {
1099
+ "bits": 4,
1100
+ "group_size": 128
1101
+ },
1102
+ "model.layers.3.mlp.share_expert.up_proj": {
1103
+ "bits": 2,
1104
+ "group_size": 128
1105
+ },
1106
+ "model.layers.3.mlp.switch_mlp.down_proj": {
1107
+ "bits": 3,
1108
+ "group_size": 128
1109
+ },
1110
+ "model.layers.3.mlp.switch_mlp.gate_proj": {
1111
+ "bits": 4,
1112
+ "group_size": 128
1113
+ },
1114
+ "model.layers.3.mlp.switch_mlp.up_proj": {
1115
+ "bits": 2,
1116
+ "group_size": 128
1117
+ },
1118
+ "model.layers.3.self_attn.g_proj": {
1119
+ "bits": 8,
1120
+ "group_size": 128
1121
+ },
1122
+ "model.layers.3.self_attn.k_proj": {
1123
+ "bits": 8,
1124
+ "group_size": 128
1125
+ },
1126
+ "model.layers.3.self_attn.o_proj": {
1127
+ "bits": 8,
1128
+ "group_size": 128
1129
+ },
1130
+ "model.layers.3.self_attn.q_proj": {
1131
+ "bits": 8,
1132
+ "group_size": 128
1133
+ },
1134
+ "model.layers.3.self_attn.v_proj": {
1135
+ "bits": 8,
1136
+ "group_size": 128
1137
+ },
1138
+ "model.layers.30.mlp.gate.gate": {
1139
+ "bits": 8,
1140
+ "group_size": 64
1141
+ },
1142
+ "model.layers.30.mlp.share_expert.down_proj": {
1143
+ "bits": 3,
1144
+ "group_size": 128
1145
+ },
1146
+ "model.layers.30.mlp.share_expert.gate_proj": {
1147
+ "bits": 4,
1148
+ "group_size": 128
1149
+ },
1150
+ "model.layers.30.mlp.share_expert.up_proj": {
1151
+ "bits": 2,
1152
+ "group_size": 128
1153
+ },
1154
+ "model.layers.30.mlp.switch_mlp.down_proj": {
1155
+ "bits": 3,
1156
+ "group_size": 128
1157
+ },
1158
+ "model.layers.30.mlp.switch_mlp.gate_proj": {
1159
+ "bits": 4,
1160
+ "group_size": 128
1161
+ },
1162
+ "model.layers.30.mlp.switch_mlp.up_proj": {
1163
+ "bits": 2,
1164
+ "group_size": 128
1165
+ },
1166
+ "model.layers.30.self_attn.g_proj": {
1167
+ "bits": 8,
1168
+ "group_size": 128
1169
+ },
1170
+ "model.layers.30.self_attn.k_proj": {
1171
+ "bits": 8,
1172
+ "group_size": 128
1173
+ },
1174
+ "model.layers.30.self_attn.o_proj": {
1175
+ "bits": 8,
1176
+ "group_size": 128
1177
+ },
1178
+ "model.layers.30.self_attn.q_proj": {
1179
+ "bits": 8,
1180
+ "group_size": 128
1181
+ },
1182
+ "model.layers.30.self_attn.v_proj": {
1183
+ "bits": 8,
1184
+ "group_size": 128
1185
+ },
1186
+ "model.layers.31.mlp.gate.gate": {
1187
+ "bits": 8,
1188
+ "group_size": 64
1189
+ },
1190
+ "model.layers.31.mlp.share_expert.down_proj": {
1191
+ "bits": 3,
1192
+ "group_size": 128
1193
+ },
1194
+ "model.layers.31.mlp.share_expert.gate_proj": {
1195
+ "bits": 4,
1196
+ "group_size": 128
1197
+ },
1198
+ "model.layers.31.mlp.share_expert.up_proj": {
1199
+ "bits": 2,
1200
+ "group_size": 128
1201
+ },
1202
+ "model.layers.31.mlp.switch_mlp.down_proj": {
1203
+ "bits": 3,
1204
+ "group_size": 128
1205
+ },
1206
+ "model.layers.31.mlp.switch_mlp.gate_proj": {
1207
+ "bits": 4,
1208
+ "group_size": 128
1209
+ },
1210
+ "model.layers.31.mlp.switch_mlp.up_proj": {
1211
+ "bits": 2,
1212
+ "group_size": 128
1213
+ },
1214
+ "model.layers.31.self_attn.g_proj": {
1215
+ "bits": 8,
1216
+ "group_size": 128
1217
+ },
1218
+ "model.layers.31.self_attn.k_proj": {
1219
+ "bits": 8,
1220
+ "group_size": 128
1221
+ },
1222
+ "model.layers.31.self_attn.o_proj": {
1223
+ "bits": 8,
1224
+ "group_size": 128
1225
+ },
1226
+ "model.layers.31.self_attn.q_proj": {
1227
+ "bits": 8,
1228
+ "group_size": 128
1229
+ },
1230
+ "model.layers.31.self_attn.v_proj": {
1231
+ "bits": 8,
1232
+ "group_size": 128
1233
+ },
1234
+ "model.layers.32.mlp.gate.gate": {
1235
+ "bits": 8,
1236
+ "group_size": 64
1237
+ },
1238
+ "model.layers.32.mlp.share_expert.down_proj": {
1239
+ "bits": 3,
1240
+ "group_size": 128
1241
+ },
1242
+ "model.layers.32.mlp.share_expert.gate_proj": {
1243
+ "bits": 4,
1244
+ "group_size": 128
1245
+ },
1246
+ "model.layers.32.mlp.share_expert.up_proj": {
1247
+ "bits": 2,
1248
+ "group_size": 128
1249
+ },
1250
+ "model.layers.32.mlp.switch_mlp.down_proj": {
1251
+ "bits": 3,
1252
+ "group_size": 128
1253
+ },
1254
+ "model.layers.32.mlp.switch_mlp.gate_proj": {
1255
+ "bits": 4,
1256
+ "group_size": 128
1257
+ },
1258
+ "model.layers.32.mlp.switch_mlp.up_proj": {
1259
+ "bits": 2,
1260
+ "group_size": 128
1261
+ },
1262
+ "model.layers.32.self_attn.g_proj": {
1263
+ "bits": 8,
1264
+ "group_size": 128
1265
+ },
1266
+ "model.layers.32.self_attn.k_proj": {
1267
+ "bits": 8,
1268
+ "group_size": 128
1269
+ },
1270
+ "model.layers.32.self_attn.o_proj": {
1271
+ "bits": 8,
1272
+ "group_size": 128
1273
+ },
1274
+ "model.layers.32.self_attn.q_proj": {
1275
+ "bits": 8,
1276
+ "group_size": 128
1277
+ },
1278
+ "model.layers.32.self_attn.v_proj": {
1279
+ "bits": 8,
1280
+ "group_size": 128
1281
+ },
1282
+ "model.layers.33.mlp.gate.gate": {
1283
+ "bits": 8,
1284
+ "group_size": 64
1285
+ },
1286
+ "model.layers.33.mlp.share_expert.down_proj": {
1287
+ "bits": 3,
1288
+ "group_size": 128
1289
+ },
1290
+ "model.layers.33.mlp.share_expert.gate_proj": {
1291
+ "bits": 4,
1292
+ "group_size": 128
1293
+ },
1294
+ "model.layers.33.mlp.share_expert.up_proj": {
1295
+ "bits": 2,
1296
+ "group_size": 128
1297
+ },
1298
+ "model.layers.33.mlp.switch_mlp.down_proj": {
1299
+ "bits": 3,
1300
+ "group_size": 128
1301
+ },
1302
+ "model.layers.33.mlp.switch_mlp.gate_proj": {
1303
+ "bits": 4,
1304
+ "group_size": 128
1305
+ },
1306
+ "model.layers.33.mlp.switch_mlp.up_proj": {
1307
+ "bits": 2,
1308
+ "group_size": 128
1309
+ },
1310
+ "model.layers.33.self_attn.g_proj": {
1311
+ "bits": 8,
1312
+ "group_size": 128
1313
+ },
1314
+ "model.layers.33.self_attn.k_proj": {
1315
+ "bits": 8,
1316
+ "group_size": 128
1317
+ },
1318
+ "model.layers.33.self_attn.o_proj": {
1319
+ "bits": 8,
1320
+ "group_size": 128
1321
+ },
1322
+ "model.layers.33.self_attn.q_proj": {
1323
+ "bits": 8,
1324
+ "group_size": 128
1325
+ },
1326
+ "model.layers.33.self_attn.v_proj": {
1327
+ "bits": 8,
1328
+ "group_size": 128
1329
+ },
1330
+ "model.layers.34.mlp.gate.gate": {
1331
+ "bits": 8,
1332
+ "group_size": 64
1333
+ },
1334
+ "model.layers.34.mlp.share_expert.down_proj": {
1335
+ "bits": 3,
1336
+ "group_size": 128
1337
+ },
1338
+ "model.layers.34.mlp.share_expert.gate_proj": {
1339
+ "bits": 4,
1340
+ "group_size": 128
1341
+ },
1342
+ "model.layers.34.mlp.share_expert.up_proj": {
1343
+ "bits": 2,
1344
+ "group_size": 128
1345
+ },
1346
+ "model.layers.34.mlp.switch_mlp.down_proj": {
1347
+ "bits": 3,
1348
+ "group_size": 128
1349
+ },
1350
+ "model.layers.34.mlp.switch_mlp.gate_proj": {
1351
+ "bits": 4,
1352
+ "group_size": 128
1353
+ },
1354
+ "model.layers.34.mlp.switch_mlp.up_proj": {
1355
+ "bits": 2,
1356
+ "group_size": 128
1357
+ },
1358
+ "model.layers.34.self_attn.g_proj": {
1359
+ "bits": 8,
1360
+ "group_size": 128
1361
+ },
1362
+ "model.layers.34.self_attn.k_proj": {
1363
+ "bits": 8,
1364
+ "group_size": 128
1365
+ },
1366
+ "model.layers.34.self_attn.o_proj": {
1367
+ "bits": 8,
1368
+ "group_size": 128
1369
+ },
1370
+ "model.layers.34.self_attn.q_proj": {
1371
+ "bits": 8,
1372
+ "group_size": 128
1373
+ },
1374
+ "model.layers.34.self_attn.v_proj": {
1375
+ "bits": 8,
1376
+ "group_size": 128
1377
+ },
1378
+ "model.layers.35.mlp.gate.gate": {
1379
+ "bits": 8,
1380
+ "group_size": 64
1381
+ },
1382
+ "model.layers.35.mlp.share_expert.down_proj": {
1383
+ "bits": 3,
1384
+ "group_size": 128
1385
+ },
1386
+ "model.layers.35.mlp.share_expert.gate_proj": {
1387
+ "bits": 4,
1388
+ "group_size": 128
1389
+ },
1390
+ "model.layers.35.mlp.share_expert.up_proj": {
1391
+ "bits": 2,
1392
+ "group_size": 128
1393
+ },
1394
+ "model.layers.35.mlp.switch_mlp.down_proj": {
1395
+ "bits": 3,
1396
+ "group_size": 128
1397
+ },
1398
+ "model.layers.35.mlp.switch_mlp.gate_proj": {
1399
+ "bits": 4,
1400
+ "group_size": 128
1401
+ },
1402
+ "model.layers.35.mlp.switch_mlp.up_proj": {
1403
+ "bits": 2,
1404
+ "group_size": 128
1405
+ },
1406
+ "model.layers.35.self_attn.g_proj": {
1407
+ "bits": 8,
1408
+ "group_size": 128
1409
+ },
1410
+ "model.layers.35.self_attn.k_proj": {
1411
+ "bits": 8,
1412
+ "group_size": 128
1413
+ },
1414
+ "model.layers.35.self_attn.o_proj": {
1415
+ "bits": 8,
1416
+ "group_size": 128
1417
+ },
1418
+ "model.layers.35.self_attn.q_proj": {
1419
+ "bits": 8,
1420
+ "group_size": 128
1421
+ },
1422
+ "model.layers.35.self_attn.v_proj": {
1423
+ "bits": 8,
1424
+ "group_size": 128
1425
+ },
1426
+ "model.layers.36.mlp.gate.gate": {
1427
+ "bits": 8,
1428
+ "group_size": 64
1429
+ },
1430
+ "model.layers.36.mlp.share_expert.down_proj": {
1431
+ "bits": 3,
1432
+ "group_size": 128
1433
+ },
1434
+ "model.layers.36.mlp.share_expert.gate_proj": {
1435
+ "bits": 4,
1436
+ "group_size": 128
1437
+ },
1438
+ "model.layers.36.mlp.share_expert.up_proj": {
1439
+ "bits": 2,
1440
+ "group_size": 128
1441
+ },
1442
+ "model.layers.36.mlp.switch_mlp.down_proj": {
1443
+ "bits": 3,
1444
+ "group_size": 128
1445
+ },
1446
+ "model.layers.36.mlp.switch_mlp.gate_proj": {
1447
+ "bits": 4,
1448
+ "group_size": 128
1449
+ },
1450
+ "model.layers.36.mlp.switch_mlp.up_proj": {
1451
+ "bits": 2,
1452
+ "group_size": 128
1453
+ },
1454
+ "model.layers.36.self_attn.g_proj": {
1455
+ "bits": 8,
1456
+ "group_size": 128
1457
+ },
1458
+ "model.layers.36.self_attn.k_proj": {
1459
+ "bits": 8,
1460
+ "group_size": 128
1461
+ },
1462
+ "model.layers.36.self_attn.o_proj": {
1463
+ "bits": 8,
1464
+ "group_size": 128
1465
+ },
1466
+ "model.layers.36.self_attn.q_proj": {
1467
+ "bits": 8,
1468
+ "group_size": 128
1469
+ },
1470
+ "model.layers.36.self_attn.v_proj": {
1471
+ "bits": 8,
1472
+ "group_size": 128
1473
+ },
1474
+ "model.layers.37.mlp.gate.gate": {
1475
+ "bits": 8,
1476
+ "group_size": 64
1477
+ },
1478
+ "model.layers.37.mlp.share_expert.down_proj": {
1479
+ "bits": 3,
1480
+ "group_size": 128
1481
+ },
1482
+ "model.layers.37.mlp.share_expert.gate_proj": {
1483
+ "bits": 4,
1484
+ "group_size": 128
1485
+ },
1486
+ "model.layers.37.mlp.share_expert.up_proj": {
1487
+ "bits": 2,
1488
+ "group_size": 128
1489
+ },
1490
+ "model.layers.37.mlp.switch_mlp.down_proj": {
1491
+ "bits": 3,
1492
+ "group_size": 128
1493
+ },
1494
+ "model.layers.37.mlp.switch_mlp.gate_proj": {
1495
+ "bits": 4,
1496
+ "group_size": 128
1497
+ },
1498
+ "model.layers.37.mlp.switch_mlp.up_proj": {
1499
+ "bits": 2,
1500
+ "group_size": 128
1501
+ },
1502
+ "model.layers.37.self_attn.g_proj": {
1503
+ "bits": 8,
1504
+ "group_size": 128
1505
+ },
1506
+ "model.layers.37.self_attn.k_proj": {
1507
+ "bits": 8,
1508
+ "group_size": 128
1509
+ },
1510
+ "model.layers.37.self_attn.o_proj": {
1511
+ "bits": 8,
1512
+ "group_size": 128
1513
+ },
1514
+ "model.layers.37.self_attn.q_proj": {
1515
+ "bits": 8,
1516
+ "group_size": 128
1517
+ },
1518
+ "model.layers.37.self_attn.v_proj": {
1519
+ "bits": 8,
1520
+ "group_size": 128
1521
+ },
1522
+ "model.layers.38.mlp.gate.gate": {
1523
+ "bits": 8,
1524
+ "group_size": 64
1525
+ },
1526
+ "model.layers.38.mlp.share_expert.down_proj": {
1527
+ "bits": 3,
1528
+ "group_size": 128
1529
+ },
1530
+ "model.layers.38.mlp.share_expert.gate_proj": {
1531
+ "bits": 4,
1532
+ "group_size": 128
1533
+ },
1534
+ "model.layers.38.mlp.share_expert.up_proj": {
1535
+ "bits": 2,
1536
+ "group_size": 128
1537
+ },
1538
+ "model.layers.38.mlp.switch_mlp.down_proj": {
1539
+ "bits": 3,
1540
+ "group_size": 128
1541
+ },
1542
+ "model.layers.38.mlp.switch_mlp.gate_proj": {
1543
+ "bits": 4,
1544
+ "group_size": 128
1545
+ },
1546
+ "model.layers.38.mlp.switch_mlp.up_proj": {
1547
+ "bits": 2,
1548
+ "group_size": 128
1549
+ },
1550
+ "model.layers.38.self_attn.g_proj": {
1551
+ "bits": 8,
1552
+ "group_size": 128
1553
+ },
1554
+ "model.layers.38.self_attn.k_proj": {
1555
+ "bits": 8,
1556
+ "group_size": 128
1557
+ },
1558
+ "model.layers.38.self_attn.o_proj": {
1559
+ "bits": 8,
1560
+ "group_size": 128
1561
+ },
1562
+ "model.layers.38.self_attn.q_proj": {
1563
+ "bits": 8,
1564
+ "group_size": 128
1565
+ },
1566
+ "model.layers.38.self_attn.v_proj": {
1567
+ "bits": 8,
1568
+ "group_size": 128
1569
+ },
1570
+ "model.layers.39.mlp.gate.gate": {
1571
+ "bits": 8,
1572
+ "group_size": 64
1573
+ },
1574
+ "model.layers.39.mlp.share_expert.down_proj": {
1575
+ "bits": 3,
1576
+ "group_size": 128
1577
+ },
1578
+ "model.layers.39.mlp.share_expert.gate_proj": {
1579
+ "bits": 4,
1580
+ "group_size": 128
1581
+ },
1582
+ "model.layers.39.mlp.share_expert.up_proj": {
1583
+ "bits": 2,
1584
+ "group_size": 128
1585
+ },
1586
+ "model.layers.39.mlp.switch_mlp.down_proj": {
1587
+ "bits": 3,
1588
+ "group_size": 128
1589
+ },
1590
+ "model.layers.39.mlp.switch_mlp.gate_proj": {
1591
+ "bits": 4,
1592
+ "group_size": 128
1593
+ },
1594
+ "model.layers.39.mlp.switch_mlp.up_proj": {
1595
+ "bits": 2,
1596
+ "group_size": 128
1597
+ },
1598
+ "model.layers.39.self_attn.g_proj": {
1599
+ "bits": 8,
1600
+ "group_size": 128
1601
+ },
1602
+ "model.layers.39.self_attn.k_proj": {
1603
+ "bits": 8,
1604
+ "group_size": 128
1605
+ },
1606
+ "model.layers.39.self_attn.o_proj": {
1607
+ "bits": 8,
1608
+ "group_size": 128
1609
+ },
1610
+ "model.layers.39.self_attn.q_proj": {
1611
+ "bits": 8,
1612
+ "group_size": 128
1613
+ },
1614
+ "model.layers.39.self_attn.v_proj": {
1615
+ "bits": 8,
1616
+ "group_size": 128
1617
+ },
1618
+ "model.layers.4.mlp.gate.gate": {
1619
+ "bits": 8,
1620
+ "group_size": 64
1621
+ },
1622
+ "model.layers.4.mlp.share_expert.down_proj": {
1623
+ "bits": 3,
1624
+ "group_size": 128
1625
+ },
1626
+ "model.layers.4.mlp.share_expert.gate_proj": {
1627
+ "bits": 4,
1628
+ "group_size": 128
1629
+ },
1630
+ "model.layers.4.mlp.share_expert.up_proj": {
1631
+ "bits": 2,
1632
+ "group_size": 128
1633
+ },
1634
+ "model.layers.4.mlp.switch_mlp.down_proj": {
1635
+ "bits": 3,
1636
+ "group_size": 128
1637
+ },
1638
+ "model.layers.4.mlp.switch_mlp.gate_proj": {
1639
+ "bits": 4,
1640
+ "group_size": 128
1641
+ },
1642
+ "model.layers.4.mlp.switch_mlp.up_proj": {
1643
+ "bits": 2,
1644
+ "group_size": 128
1645
+ },
1646
+ "model.layers.4.self_attn.g_proj": {
1647
+ "bits": 8,
1648
+ "group_size": 128
1649
+ },
1650
+ "model.layers.4.self_attn.k_proj": {
1651
+ "bits": 8,
1652
+ "group_size": 128
1653
+ },
1654
+ "model.layers.4.self_attn.o_proj": {
1655
+ "bits": 8,
1656
+ "group_size": 128
1657
+ },
1658
+ "model.layers.4.self_attn.q_proj": {
1659
+ "bits": 8,
1660
+ "group_size": 128
1661
+ },
1662
+ "model.layers.4.self_attn.v_proj": {
1663
+ "bits": 8,
1664
+ "group_size": 128
1665
+ },
1666
+ "model.layers.40.mlp.gate.gate": {
1667
+ "bits": 8,
1668
+ "group_size": 64
1669
+ },
1670
+ "model.layers.40.mlp.share_expert.down_proj": {
1671
+ "bits": 3,
1672
+ "group_size": 128
1673
+ },
1674
+ "model.layers.40.mlp.share_expert.gate_proj": {
1675
+ "bits": 4,
1676
+ "group_size": 128
1677
+ },
1678
+ "model.layers.40.mlp.share_expert.up_proj": {
1679
+ "bits": 2,
1680
+ "group_size": 128
1681
+ },
1682
+ "model.layers.40.mlp.switch_mlp.down_proj": {
1683
+ "bits": 3,
1684
+ "group_size": 128
1685
+ },
1686
+ "model.layers.40.mlp.switch_mlp.gate_proj": {
1687
+ "bits": 4,
1688
+ "group_size": 128
1689
+ },
1690
+ "model.layers.40.mlp.switch_mlp.up_proj": {
1691
+ "bits": 2,
1692
+ "group_size": 128
1693
+ },
1694
+ "model.layers.40.self_attn.g_proj": {
1695
+ "bits": 8,
1696
+ "group_size": 128
1697
+ },
1698
+ "model.layers.40.self_attn.k_proj": {
1699
+ "bits": 8,
1700
+ "group_size": 128
1701
+ },
1702
+ "model.layers.40.self_attn.o_proj": {
1703
+ "bits": 8,
1704
+ "group_size": 128
1705
+ },
1706
+ "model.layers.40.self_attn.q_proj": {
1707
+ "bits": 8,
1708
+ "group_size": 128
1709
+ },
1710
+ "model.layers.40.self_attn.v_proj": {
1711
+ "bits": 8,
1712
+ "group_size": 128
1713
+ },
1714
+ "model.layers.41.mlp.gate.gate": {
1715
+ "bits": 8,
1716
+ "group_size": 64
1717
+ },
1718
+ "model.layers.41.mlp.share_expert.down_proj": {
1719
+ "bits": 3,
1720
+ "group_size": 128
1721
+ },
1722
+ "model.layers.41.mlp.share_expert.gate_proj": {
1723
+ "bits": 4,
1724
+ "group_size": 128
1725
+ },
1726
+ "model.layers.41.mlp.share_expert.up_proj": {
1727
+ "bits": 2,
1728
+ "group_size": 128
1729
+ },
1730
+ "model.layers.41.mlp.switch_mlp.down_proj": {
1731
+ "bits": 3,
1732
+ "group_size": 128
1733
+ },
1734
+ "model.layers.41.mlp.switch_mlp.gate_proj": {
1735
+ "bits": 4,
1736
+ "group_size": 128
1737
+ },
1738
+ "model.layers.41.mlp.switch_mlp.up_proj": {
1739
+ "bits": 2,
1740
+ "group_size": 128
1741
+ },
1742
+ "model.layers.41.self_attn.g_proj": {
1743
+ "bits": 8,
1744
+ "group_size": 128
1745
+ },
1746
+ "model.layers.41.self_attn.k_proj": {
1747
+ "bits": 8,
1748
+ "group_size": 128
1749
+ },
1750
+ "model.layers.41.self_attn.o_proj": {
1751
+ "bits": 8,
1752
+ "group_size": 128
1753
+ },
1754
+ "model.layers.41.self_attn.q_proj": {
1755
+ "bits": 8,
1756
+ "group_size": 128
1757
+ },
1758
+ "model.layers.41.self_attn.v_proj": {
1759
+ "bits": 8,
1760
+ "group_size": 128
1761
+ },
1762
+ "model.layers.42.mlp.gate.gate": {
1763
+ "bits": 8,
1764
+ "group_size": 64
1765
+ },
1766
+ "model.layers.42.mlp.share_expert.down_proj": {
1767
+ "bits": 3,
1768
+ "group_size": 128
1769
+ },
1770
+ "model.layers.42.mlp.share_expert.gate_proj": {
1771
+ "bits": 4,
1772
+ "group_size": 128
1773
+ },
1774
+ "model.layers.42.mlp.share_expert.up_proj": {
1775
+ "bits": 2,
1776
+ "group_size": 128
1777
+ },
1778
+ "model.layers.42.mlp.switch_mlp.down_proj": {
1779
+ "bits": 3,
1780
+ "group_size": 128
1781
+ },
1782
+ "model.layers.42.mlp.switch_mlp.gate_proj": {
1783
+ "bits": 4,
1784
+ "group_size": 128
1785
+ },
1786
+ "model.layers.42.mlp.switch_mlp.up_proj": {
1787
+ "bits": 2,
1788
+ "group_size": 128
1789
+ },
1790
+ "model.layers.42.self_attn.g_proj": {
1791
+ "bits": 8,
1792
+ "group_size": 128
1793
+ },
1794
+ "model.layers.42.self_attn.k_proj": {
1795
+ "bits": 8,
1796
+ "group_size": 128
1797
+ },
1798
+ "model.layers.42.self_attn.o_proj": {
1799
+ "bits": 8,
1800
+ "group_size": 128
1801
+ },
1802
+ "model.layers.42.self_attn.q_proj": {
1803
+ "bits": 8,
1804
+ "group_size": 128
1805
+ },
1806
+ "model.layers.42.self_attn.v_proj": {
1807
+ "bits": 8,
1808
+ "group_size": 128
1809
+ },
1810
+ "model.layers.43.mlp.gate.gate": {
1811
+ "bits": 8,
1812
+ "group_size": 64
1813
+ },
1814
+ "model.layers.43.mlp.share_expert.down_proj": {
1815
+ "bits": 3,
1816
+ "group_size": 128
1817
+ },
1818
+ "model.layers.43.mlp.share_expert.gate_proj": {
1819
+ "bits": 4,
1820
+ "group_size": 128
1821
+ },
1822
+ "model.layers.43.mlp.share_expert.up_proj": {
1823
+ "bits": 2,
1824
+ "group_size": 128
1825
+ },
1826
+ "model.layers.43.mlp.switch_mlp.down_proj": {
1827
+ "bits": 3,
1828
+ "group_size": 128
1829
+ },
1830
+ "model.layers.43.mlp.switch_mlp.gate_proj": {
1831
+ "bits": 4,
1832
+ "group_size": 128
1833
+ },
1834
+ "model.layers.43.mlp.switch_mlp.up_proj": {
1835
+ "bits": 2,
1836
+ "group_size": 128
1837
+ },
1838
+ "model.layers.43.self_attn.g_proj": {
1839
+ "bits": 8,
1840
+ "group_size": 128
1841
+ },
1842
+ "model.layers.43.self_attn.k_proj": {
1843
+ "bits": 8,
1844
+ "group_size": 128
1845
+ },
1846
+ "model.layers.43.self_attn.o_proj": {
1847
+ "bits": 8,
1848
+ "group_size": 128
1849
+ },
1850
+ "model.layers.43.self_attn.q_proj": {
1851
+ "bits": 8,
1852
+ "group_size": 128
1853
+ },
1854
+ "model.layers.43.self_attn.v_proj": {
1855
+ "bits": 8,
1856
+ "group_size": 128
1857
+ },
1858
+ "model.layers.44.mlp.gate.gate": {
1859
+ "bits": 8,
1860
+ "group_size": 64
1861
+ },
1862
+ "model.layers.44.mlp.share_expert.down_proj": {
1863
+ "bits": 3,
1864
+ "group_size": 128
1865
+ },
1866
+ "model.layers.44.mlp.share_expert.gate_proj": {
1867
+ "bits": 4,
1868
+ "group_size": 128
1869
+ },
1870
+ "model.layers.44.mlp.share_expert.up_proj": {
1871
+ "bits": 2,
1872
+ "group_size": 128
1873
+ },
1874
+ "model.layers.44.mlp.switch_mlp.down_proj": {
1875
+ "bits": 3,
1876
+ "group_size": 128
1877
+ },
1878
+ "model.layers.44.mlp.switch_mlp.gate_proj": {
1879
+ "bits": 4,
1880
+ "group_size": 128
1881
+ },
1882
+ "model.layers.44.mlp.switch_mlp.up_proj": {
1883
+ "bits": 2,
1884
+ "group_size": 128
1885
+ },
1886
+ "model.layers.44.self_attn.g_proj": {
1887
+ "bits": 8,
1888
+ "group_size": 128
1889
+ },
1890
+ "model.layers.44.self_attn.k_proj": {
1891
+ "bits": 8,
1892
+ "group_size": 128
1893
+ },
1894
+ "model.layers.44.self_attn.o_proj": {
1895
+ "bits": 8,
1896
+ "group_size": 128
1897
+ },
1898
+ "model.layers.44.self_attn.q_proj": {
1899
+ "bits": 8,
1900
+ "group_size": 128
1901
+ },
1902
+ "model.layers.44.self_attn.v_proj": {
1903
+ "bits": 8,
1904
+ "group_size": 128
1905
+ },
1906
+ "model.layers.5.mlp.gate.gate": {
1907
+ "bits": 8,
1908
+ "group_size": 64
1909
+ },
1910
+ "model.layers.5.mlp.share_expert.down_proj": {
1911
+ "bits": 3,
1912
+ "group_size": 128
1913
+ },
1914
+ "model.layers.5.mlp.share_expert.gate_proj": {
1915
+ "bits": 4,
1916
+ "group_size": 128
1917
+ },
1918
+ "model.layers.5.mlp.share_expert.up_proj": {
1919
+ "bits": 2,
1920
+ "group_size": 128
1921
+ },
1922
+ "model.layers.5.mlp.switch_mlp.down_proj": {
1923
+ "bits": 3,
1924
+ "group_size": 128
1925
+ },
1926
+ "model.layers.5.mlp.switch_mlp.gate_proj": {
1927
+ "bits": 4,
1928
+ "group_size": 128
1929
+ },
1930
+ "model.layers.5.mlp.switch_mlp.up_proj": {
1931
+ "bits": 2,
1932
+ "group_size": 128
1933
+ },
1934
+ "model.layers.5.self_attn.g_proj": {
1935
+ "bits": 8,
1936
+ "group_size": 128
1937
+ },
1938
+ "model.layers.5.self_attn.k_proj": {
1939
+ "bits": 8,
1940
+ "group_size": 128
1941
+ },
1942
+ "model.layers.5.self_attn.o_proj": {
1943
+ "bits": 8,
1944
+ "group_size": 128
1945
+ },
1946
+ "model.layers.5.self_attn.q_proj": {
1947
+ "bits": 8,
1948
+ "group_size": 128
1949
+ },
1950
+ "model.layers.5.self_attn.v_proj": {
1951
+ "bits": 8,
1952
+ "group_size": 128
1953
+ },
1954
+ "model.layers.6.mlp.gate.gate": {
1955
+ "bits": 8,
1956
+ "group_size": 64
1957
+ },
1958
+ "model.layers.6.mlp.share_expert.down_proj": {
1959
+ "bits": 3,
1960
+ "group_size": 128
1961
+ },
1962
+ "model.layers.6.mlp.share_expert.gate_proj": {
1963
+ "bits": 4,
1964
+ "group_size": 128
1965
+ },
1966
+ "model.layers.6.mlp.share_expert.up_proj": {
1967
+ "bits": 2,
1968
+ "group_size": 128
1969
+ },
1970
+ "model.layers.6.mlp.switch_mlp.down_proj": {
1971
+ "bits": 3,
1972
+ "group_size": 128
1973
+ },
1974
+ "model.layers.6.mlp.switch_mlp.gate_proj": {
1975
+ "bits": 4,
1976
+ "group_size": 128
1977
+ },
1978
+ "model.layers.6.mlp.switch_mlp.up_proj": {
1979
+ "bits": 2,
1980
+ "group_size": 128
1981
+ },
1982
+ "model.layers.6.self_attn.g_proj": {
1983
+ "bits": 8,
1984
+ "group_size": 128
1985
+ },
1986
+ "model.layers.6.self_attn.k_proj": {
1987
+ "bits": 8,
1988
+ "group_size": 128
1989
+ },
1990
+ "model.layers.6.self_attn.o_proj": {
1991
+ "bits": 8,
1992
+ "group_size": 128
1993
+ },
1994
+ "model.layers.6.self_attn.q_proj": {
1995
+ "bits": 8,
1996
+ "group_size": 128
1997
+ },
1998
+ "model.layers.6.self_attn.v_proj": {
1999
+ "bits": 8,
2000
+ "group_size": 128
2001
+ },
2002
+ "model.layers.7.mlp.gate.gate": {
2003
+ "bits": 8,
2004
+ "group_size": 64
2005
+ },
2006
+ "model.layers.7.mlp.share_expert.down_proj": {
2007
+ "bits": 3,
2008
+ "group_size": 128
2009
+ },
2010
+ "model.layers.7.mlp.share_expert.gate_proj": {
2011
+ "bits": 4,
2012
+ "group_size": 128
2013
+ },
2014
+ "model.layers.7.mlp.share_expert.up_proj": {
2015
+ "bits": 2,
2016
+ "group_size": 128
2017
+ },
2018
+ "model.layers.7.mlp.switch_mlp.down_proj": {
2019
+ "bits": 3,
2020
+ "group_size": 128
2021
+ },
2022
+ "model.layers.7.mlp.switch_mlp.gate_proj": {
2023
+ "bits": 4,
2024
+ "group_size": 128
2025
+ },
2026
+ "model.layers.7.mlp.switch_mlp.up_proj": {
2027
+ "bits": 2,
2028
+ "group_size": 128
2029
+ },
2030
+ "model.layers.7.self_attn.g_proj": {
2031
+ "bits": 8,
2032
+ "group_size": 128
2033
+ },
2034
+ "model.layers.7.self_attn.k_proj": {
2035
+ "bits": 8,
2036
+ "group_size": 128
2037
+ },
2038
+ "model.layers.7.self_attn.o_proj": {
2039
+ "bits": 8,
2040
+ "group_size": 128
2041
+ },
2042
+ "model.layers.7.self_attn.q_proj": {
2043
+ "bits": 8,
2044
+ "group_size": 128
2045
+ },
2046
+ "model.layers.7.self_attn.v_proj": {
2047
+ "bits": 8,
2048
+ "group_size": 128
2049
+ },
2050
+ "model.layers.8.mlp.gate.gate": {
2051
+ "bits": 8,
2052
+ "group_size": 64
2053
+ },
2054
+ "model.layers.8.mlp.share_expert.down_proj": {
2055
+ "bits": 3,
2056
+ "group_size": 128
2057
+ },
2058
+ "model.layers.8.mlp.share_expert.gate_proj": {
2059
+ "bits": 4,
2060
+ "group_size": 128
2061
+ },
2062
+ "model.layers.8.mlp.share_expert.up_proj": {
2063
+ "bits": 2,
2064
+ "group_size": 128
2065
+ },
2066
+ "model.layers.8.mlp.switch_mlp.down_proj": {
2067
+ "bits": 3,
2068
+ "group_size": 128
2069
+ },
2070
+ "model.layers.8.mlp.switch_mlp.gate_proj": {
2071
+ "bits": 4,
2072
+ "group_size": 128
2073
+ },
2074
+ "model.layers.8.mlp.switch_mlp.up_proj": {
2075
+ "bits": 2,
2076
+ "group_size": 128
2077
+ },
2078
+ "model.layers.8.self_attn.g_proj": {
2079
+ "bits": 8,
2080
+ "group_size": 128
2081
+ },
2082
+ "model.layers.8.self_attn.k_proj": {
2083
+ "bits": 8,
2084
+ "group_size": 128
2085
+ },
2086
+ "model.layers.8.self_attn.o_proj": {
2087
+ "bits": 8,
2088
+ "group_size": 128
2089
+ },
2090
+ "model.layers.8.self_attn.q_proj": {
2091
+ "bits": 8,
2092
+ "group_size": 128
2093
+ },
2094
+ "model.layers.8.self_attn.v_proj": {
2095
+ "bits": 8,
2096
+ "group_size": 128
2097
+ },
2098
+ "model.layers.9.mlp.gate.gate": {
2099
+ "bits": 8,
2100
+ "group_size": 64
2101
+ },
2102
+ "model.layers.9.mlp.share_expert.down_proj": {
2103
+ "bits": 3,
2104
+ "group_size": 128
2105
+ },
2106
+ "model.layers.9.mlp.share_expert.gate_proj": {
2107
+ "bits": 4,
2108
+ "group_size": 128
2109
+ },
2110
+ "model.layers.9.mlp.share_expert.up_proj": {
2111
+ "bits": 2,
2112
+ "group_size": 128
2113
+ },
2114
+ "model.layers.9.mlp.switch_mlp.down_proj": {
2115
+ "bits": 3,
2116
+ "group_size": 128
2117
+ },
2118
+ "model.layers.9.mlp.switch_mlp.gate_proj": {
2119
+ "bits": 4,
2120
+ "group_size": 128
2121
+ },
2122
+ "model.layers.9.mlp.switch_mlp.up_proj": {
2123
+ "bits": 2,
2124
+ "group_size": 128
2125
+ },
2126
+ "model.layers.9.self_attn.g_proj": {
2127
+ "bits": 8,
2128
+ "group_size": 128
2129
+ },
2130
+ "model.layers.9.self_attn.k_proj": {
2131
+ "bits": 8,
2132
+ "group_size": 128
2133
+ },
2134
+ "model.layers.9.self_attn.o_proj": {
2135
+ "bits": 8,
2136
+ "group_size": 128
2137
+ },
2138
+ "model.layers.9.self_attn.q_proj": {
2139
+ "bits": 8,
2140
+ "group_size": 128
2141
+ },
2142
+ "model.layers.9.self_attn.v_proj": {
2143
+ "bits": 8,
2144
+ "group_size": 128
2145
+ }
2146
+ },
2147
+ "quantization_config": {
2148
+ "config_groups": {
2149
+ "group_0": {
2150
+ "input_activations": {
2151
+ "dynamic": false,
2152
+ "group_size": 16,
2153
+ "num_bits": 4,
2154
+ "type": "float"
2155
+ },
2156
+ "targets": [
2157
+ "Linear"
2158
+ ],
2159
+ "weights": {
2160
+ "dynamic": false,
2161
+ "group_size": 16,
2162
+ "num_bits": 4,
2163
+ "type": "float"
2164
+ }
2165
+ }
2166
+ },
2167
+ "ignore": [
2168
+ "lm_head",
2169
+ "model.language_model.layers.0*",
2170
+ "model.language_model.layers.1.*",
2171
+ "model.language_model.layers.10.moe.gate",
2172
+ "model.language_model.layers.10.self_attn*",
2173
+ "model.language_model.layers.10.share_expert*",
2174
+ "model.language_model.layers.11.moe.gate",
2175
+ "model.language_model.layers.11.self_attn*",
2176
+ "model.language_model.layers.11.share_expert*",
2177
+ "model.language_model.layers.12.moe.gate",
2178
+ "model.language_model.layers.12.self_attn*",
2179
+ "model.language_model.layers.12.share_expert*",
2180
+ "model.language_model.layers.13.moe.gate",
2181
+ "model.language_model.layers.13.self_attn*",
2182
+ "model.language_model.layers.13.share_expert*",
2183
+ "model.language_model.layers.14.moe.gate",
2184
+ "model.language_model.layers.14.self_attn*",
2185
+ "model.language_model.layers.14.share_expert*",
2186
+ "model.language_model.layers.15.moe.gate",
2187
+ "model.language_model.layers.15.self_attn*",
2188
+ "model.language_model.layers.15.share_expert*",
2189
+ "model.language_model.layers.16.moe.gate",
2190
+ "model.language_model.layers.16.self_attn*",
2191
+ "model.language_model.layers.16.share_expert*",
2192
+ "model.language_model.layers.17.moe.gate",
2193
+ "model.language_model.layers.17.self_attn*",
2194
+ "model.language_model.layers.17.share_expert*",
2195
+ "model.language_model.layers.18.moe.gate",
2196
+ "model.language_model.layers.18.self_attn*",
2197
+ "model.language_model.layers.18.share_expert*",
2198
+ "model.language_model.layers.19.moe.gate",
2199
+ "model.language_model.layers.19.self_attn*",
2200
+ "model.language_model.layers.19.share_expert*",
2201
+ "model.language_model.layers.2.*",
2202
+ "model.language_model.layers.20.moe.gate",
2203
+ "model.language_model.layers.20.self_attn*",
2204
+ "model.language_model.layers.20.share_expert*",
2205
+ "model.language_model.layers.21.moe.gate",
2206
+ "model.language_model.layers.21.self_attn*",
2207
+ "model.language_model.layers.21.share_expert*",
2208
+ "model.language_model.layers.22.moe.gate",
2209
+ "model.language_model.layers.22.self_attn*",
2210
+ "model.language_model.layers.22.share_expert*",
2211
+ "model.language_model.layers.23.moe.gate",
2212
+ "model.language_model.layers.23.self_attn*",
2213
+ "model.language_model.layers.23.share_expert*",
2214
+ "model.language_model.layers.24.moe.gate",
2215
+ "model.language_model.layers.24.self_attn*",
2216
+ "model.language_model.layers.24.share_expert*",
2217
+ "model.language_model.layers.25.moe.gate",
2218
+ "model.language_model.layers.25.self_attn*",
2219
+ "model.language_model.layers.25.share_expert*",
2220
+ "model.language_model.layers.26.moe.gate",
2221
+ "model.language_model.layers.26.self_attn*",
2222
+ "model.language_model.layers.26.share_expert*",
2223
+ "model.language_model.layers.27.moe.gate",
2224
+ "model.language_model.layers.27.self_attn*",
2225
+ "model.language_model.layers.27.share_expert*",
2226
+ "model.language_model.layers.28.moe.gate",
2227
+ "model.language_model.layers.28.self_attn*",
2228
+ "model.language_model.layers.28.share_expert*",
2229
+ "model.language_model.layers.29.moe.gate",
2230
+ "model.language_model.layers.29.self_attn*",
2231
+ "model.language_model.layers.29.share_expert*",
2232
+ "model.language_model.layers.3.moe.gate",
2233
+ "model.language_model.layers.3.self_attn*",
2234
+ "model.language_model.layers.3.share_expert*",
2235
+ "model.language_model.layers.30.moe.gate",
2236
+ "model.language_model.layers.30.self_attn*",
2237
+ "model.language_model.layers.30.share_expert*",
2238
+ "model.language_model.layers.31.moe.gate",
2239
+ "model.language_model.layers.31.self_attn*",
2240
+ "model.language_model.layers.31.share_expert*",
2241
+ "model.language_model.layers.32.moe.gate",
2242
+ "model.language_model.layers.32.self_attn*",
2243
+ "model.language_model.layers.32.share_expert*",
2244
+ "model.language_model.layers.33.moe.gate",
2245
+ "model.language_model.layers.33.self_attn*",
2246
+ "model.language_model.layers.33.share_expert*",
2247
+ "model.language_model.layers.34.moe.gate",
2248
+ "model.language_model.layers.34.self_attn*",
2249
+ "model.language_model.layers.34.share_expert*",
2250
+ "model.language_model.layers.35.moe.gate",
2251
+ "model.language_model.layers.35.self_attn*",
2252
+ "model.language_model.layers.35.share_expert*",
2253
+ "model.language_model.layers.36.moe.gate",
2254
+ "model.language_model.layers.36.self_attn*",
2255
+ "model.language_model.layers.36.share_expert*",
2256
+ "model.language_model.layers.37.moe.gate",
2257
+ "model.language_model.layers.37.self_attn*",
2258
+ "model.language_model.layers.37.share_expert*",
2259
+ "model.language_model.layers.38.moe.gate",
2260
+ "model.language_model.layers.38.self_attn*",
2261
+ "model.language_model.layers.38.share_expert*",
2262
+ "model.language_model.layers.39.moe.gate",
2263
+ "model.language_model.layers.39.self_attn*",
2264
+ "model.language_model.layers.39.share_expert*",
2265
+ "model.language_model.layers.4.moe.gate",
2266
+ "model.language_model.layers.4.self_attn*",
2267
+ "model.language_model.layers.4.share_expert*",
2268
+ "model.language_model.layers.40.moe.gate",
2269
+ "model.language_model.layers.40.self_attn*",
2270
+ "model.language_model.layers.40.share_expert*",
2271
+ "model.language_model.layers.41.moe.gate",
2272
+ "model.language_model.layers.41.self_attn*",
2273
+ "model.language_model.layers.41.share_expert*",
2274
+ "model.language_model.layers.42.moe.gate",
2275
+ "model.language_model.layers.42.self_attn*",
2276
+ "model.language_model.layers.42.share_expert*",
2277
+ "model.language_model.layers.43.moe.gate",
2278
+ "model.language_model.layers.43.self_attn*",
2279
+ "model.language_model.layers.43.share_expert*",
2280
+ "model.language_model.layers.44.moe.gate",
2281
+ "model.language_model.layers.44.self_attn*",
2282
+ "model.language_model.layers.44.share_expert*",
2283
+ "model.language_model.layers.5.moe.gate",
2284
+ "model.language_model.layers.5.self_attn*",
2285
+ "model.language_model.layers.5.share_expert*",
2286
+ "model.language_model.layers.6.moe.gate",
2287
+ "model.language_model.layers.6.self_attn*",
2288
+ "model.language_model.layers.6.share_expert*",
2289
+ "model.language_model.layers.7.moe.gate",
2290
+ "model.language_model.layers.7.self_attn*",
2291
+ "model.language_model.layers.7.share_expert*",
2292
+ "model.language_model.layers.8.moe.gate",
2293
+ "model.language_model.layers.8.self_attn*",
2294
+ "model.language_model.layers.8.share_expert*",
2295
+ "model.language_model.layers.9.moe.gate",
2296
+ "model.language_model.layers.9.self_attn*",
2297
+ "model.language_model.layers.9.share_expert*",
2298
+ "model.vision_model*",
2299
+ "model.vit_large_projector"
2300
+ ],
2301
+ "kv_cache_scheme": {
2302
+ "dynamic": false,
2303
+ "num_bits": 8,
2304
+ "type": "float"
2305
+ },
2306
+ "producer": {
2307
+ "name": "modelopt",
2308
+ "version": "0.45.0.dev37+g3ad4f4f09.d20260524"
2309
+ },
2310
+ "quant_algo": "NVFP4",
2311
+ "quant_method": "modelopt"
2312
+ },
2313
+ "text_config": {
2314
+ "architectures": [
2315
+ "Step3p5ForCausalLM"
2316
+ ],
2317
+ "att_impl_type": "GQA",
2318
+ "attention_dropout": 0.0,
2319
+ "attention_other_setting": {
2320
+ "attention_type": "sliding_attention",
2321
+ "head_dim": 128,
2322
+ "num_attention_groups": 8,
2323
+ "num_attention_heads": 96,
2324
+ "true_head_dim": 128
2325
+ },
2326
+ "bos_token_id": 0,
2327
+ "dtype": "bfloat16",
2328
+ "eos_token_id": [
2329
+ 1,
2330
+ 2,
2331
+ 128007
2332
+ ],
2333
+ "head_dim": 128,
2334
+ "hidden_size": 4096,
2335
+ "intermediate_size": 11264,
2336
+ "layer_types": [
2337
+ "full_attention",
2338
+ "sliding_attention",
2339
+ "sliding_attention",
2340
+ "sliding_attention",
2341
+ "full_attention",
2342
+ "sliding_attention",
2343
+ "sliding_attention",
2344
+ "sliding_attention",
2345
+ "full_attention",
2346
+ "sliding_attention",
2347
+ "sliding_attention",
2348
+ "sliding_attention",
2349
+ "full_attention",
2350
+ "sliding_attention",
2351
+ "sliding_attention",
2352
+ "sliding_attention",
2353
+ "full_attention",
2354
+ "sliding_attention",
2355
+ "sliding_attention",
2356
+ "sliding_attention",
2357
+ "full_attention",
2358
+ "sliding_attention",
2359
+ "sliding_attention",
2360
+ "sliding_attention",
2361
+ "full_attention",
2362
+ "sliding_attention",
2363
+ "sliding_attention",
2364
+ "sliding_attention",
2365
+ "full_attention",
2366
+ "sliding_attention",
2367
+ "sliding_attention",
2368
+ "sliding_attention",
2369
+ "full_attention",
2370
+ "sliding_attention",
2371
+ "sliding_attention",
2372
+ "sliding_attention",
2373
+ "full_attention",
2374
+ "sliding_attention",
2375
+ "sliding_attention",
2376
+ "sliding_attention",
2377
+ "full_attention",
2378
+ "sliding_attention",
2379
+ "sliding_attention",
2380
+ "sliding_attention",
2381
+ "full_attention"
2382
+ ],
2383
+ "max_position_embeddings": 262144,
2384
+ "max_seq_len": 262144,
2385
+ "model_type": "step3p5",
2386
+ "moe_every_n_layer": 1,
2387
+ "moe_intermediate_size": 1280,
2388
+ "moe_layer_offset": 0,
2389
+ "moe_layers_enum": "3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44",
2390
+ "moe_num_experts": 288,
2391
+ "moe_router_activation": "sigmoid",
2392
+ "moe_router_scaling_factor": 3.0,
2393
+ "moe_top_k": 8,
2394
+ "need_fp32_gate": true,
2395
+ "norm_expert_weight": true,
2396
+ "num_attention_groups": 8,
2397
+ "num_attention_heads": 64,
2398
+ "num_hidden_layers": 45,
2399
+ "num_nextn_predict_layers": 3,
2400
+ "partial_rotary_factor": 0.5,
2401
+ "partial_rotary_factors": [
2402
+ 0.5,
2403
+ 1.0,
2404
+ 1.0,
2405
+ 1.0,
2406
+ 0.5,
2407
+ 1.0,
2408
+ 1.0,
2409
+ 1.0,
2410
+ 0.5,
2411
+ 1.0,
2412
+ 1.0,
2413
+ 1.0,
2414
+ 0.5,
2415
+ 1.0,
2416
+ 1.0,
2417
+ 1.0,
2418
+ 0.5,
2419
+ 1.0,
2420
+ 1.0,
2421
+ 1.0,
2422
+ 0.5,
2423
+ 1.0,
2424
+ 1.0,
2425
+ 1.0,
2426
+ 0.5,
2427
+ 1.0,
2428
+ 1.0,
2429
+ 1.0,
2430
+ 0.5,
2431
+ 1.0,
2432
+ 1.0,
2433
+ 1.0,
2434
+ 0.5,
2435
+ 1.0,
2436
+ 1.0,
2437
+ 1.0,
2438
+ 0.5,
2439
+ 1.0,
2440
+ 1.0,
2441
+ 1.0,
2442
+ 0.5,
2443
+ 1.0,
2444
+ 1.0,
2445
+ 1.0,
2446
+ 0.5
2447
+ ],
2448
+ "rms_norm_eps": 1e-05,
2449
+ "rope_parameters": {
2450
+ "factor": 2.0,
2451
+ "high_freq_factor": 32.0,
2452
+ "low_freq_factor": 1.0,
2453
+ "original_max_position_embeddings": 131072,
2454
+ "rope_theta": 500000.0,
2455
+ "rope_type": "llama3"
2456
+ },
2457
+ "rope_scaling": {
2458
+ "factor": 2.0,
2459
+ "high_freq_factor": 32.0,
2460
+ "low_freq_factor": 1.0,
2461
+ "original_max_position_embeddings": 131072,
2462
+ "rope_type": "llama3"
2463
+ },
2464
+ "rope_theta": [
2465
+ 5000000.0,
2466
+ 10000.0,
2467
+ 10000.0,
2468
+ 10000.0,
2469
+ 5000000.0,
2470
+ 10000.0,
2471
+ 10000.0,
2472
+ 10000.0,
2473
+ 5000000.0,
2474
+ 10000.0,
2475
+ 10000.0,
2476
+ 10000.0,
2477
+ 5000000.0,
2478
+ 10000.0,
2479
+ 10000.0,
2480
+ 10000.0,
2481
+ 5000000.0,
2482
+ 10000.0,
2483
+ 10000.0,
2484
+ 10000.0,
2485
+ 5000000.0,
2486
+ 10000.0,
2487
+ 10000.0,
2488
+ 10000.0,
2489
+ 5000000.0,
2490
+ 10000.0,
2491
+ 10000.0,
2492
+ 10000.0,
2493
+ 5000000.0,
2494
+ 10000.0,
2495
+ 10000.0,
2496
+ 10000.0,
2497
+ 5000000.0,
2498
+ 10000.0,
2499
+ 10000.0,
2500
+ 10000.0,
2501
+ 5000000.0,
2502
+ 10000.0,
2503
+ 10000.0,
2504
+ 10000.0,
2505
+ 5000000.0,
2506
+ 10000.0,
2507
+ 10000.0,
2508
+ 10000.0,
2509
+ 5000000.0,
2510
+ 10000.0,
2511
+ 10000.0,
2512
+ 10000.0
2513
+ ],
2514
+ "share_expert_dim": 1280,
2515
+ "sink": false,
2516
+ "sliding_window": 512,
2517
+ "swiglu_limits": [
2518
+ 0.0,
2519
+ 0.0,
2520
+ 0.0,
2521
+ 0.0,
2522
+ 0.0,
2523
+ 0.0,
2524
+ 0.0,
2525
+ 0.0,
2526
+ 0.0,
2527
+ 0.0,
2528
+ 0.0,
2529
+ 0.0,
2530
+ 0.0,
2531
+ 0.0,
2532
+ 0.0,
2533
+ 0.0,
2534
+ 0.0,
2535
+ 0.0,
2536
+ 0.0,
2537
+ 0.0,
2538
+ 0.0,
2539
+ 0.0,
2540
+ 0.0,
2541
+ 0.0,
2542
+ 0.0,
2543
+ 0.0,
2544
+ 0.0,
2545
+ 0.0,
2546
+ 0.0,
2547
+ 0.0,
2548
+ 0.0,
2549
+ 0.0,
2550
+ 0.0,
2551
+ 0.0,
2552
+ 0.0,
2553
+ 0.0,
2554
+ 0.0,
2555
+ 0.0,
2556
+ 0.0,
2557
+ 0.0,
2558
+ 0.0,
2559
+ 0.0,
2560
+ 0.0,
2561
+ 7,
2562
+ 7
2563
+ ],
2564
+ "swiglu_limits_shared": [
2565
+ 0.0,
2566
+ 0.0,
2567
+ 0.0,
2568
+ 0.0,
2569
+ 0.0,
2570
+ 0.0,
2571
+ 0.0,
2572
+ 0.0,
2573
+ 0.0,
2574
+ 0.0,
2575
+ 0.0,
2576
+ 0.0,
2577
+ 0.0,
2578
+ 0.0,
2579
+ 0.0,
2580
+ 0.0,
2581
+ 0.0,
2582
+ 0.0,
2583
+ 0.0,
2584
+ 0.0,
2585
+ 0.0,
2586
+ 0.0,
2587
+ 0.0,
2588
+ 0.0,
2589
+ 0.0,
2590
+ 0.0,
2591
+ 0.0,
2592
+ 0.0,
2593
+ 0.0,
2594
+ 0.0,
2595
+ 0.0,
2596
+ 0.0,
2597
+ 0.0,
2598
+ 0.0,
2599
+ 0.0,
2600
+ 0.0,
2601
+ 0.0,
2602
+ 0.0,
2603
+ 0.0,
2604
+ 0.0,
2605
+ 0.0,
2606
+ 0.0,
2607
+ 0.0,
2608
+ 16,
2609
+ 16
2610
+ ],
2611
+ "torch_dtype": "bfloat16",
2612
+ "use_cache": true,
2613
+ "use_head_wise_attn_gate": true,
2614
+ "use_mfa": false,
2615
+ "use_moe": true,
2616
+ "use_moe_router_bias": true,
2617
+ "use_qk_norm": false,
2618
+ "use_rope_layers": [],
2619
+ "vocab_size": 128896,
2620
+ "yarn_only_types": [
2621
+ "full_attention"
2622
+ ]
2623
+ },
2624
+ "transformers_version": "4.56.2",
2625
+ "understand_projector_stride": 2,
2626
+ "use_cache": true,
2627
+ "use_im_start_end": "true",
2628
+ "vision_config": {
2629
+ "heads": 16,
2630
+ "hidden_act": "quick_gelu",
2631
+ "image_size": 728,
2632
+ "layer_norm_eps": 1e-05,
2633
+ "layers": 47,
2634
+ "ls_init_value": 0.1,
2635
+ "mlp_ratio": 5.833333333333333,
2636
+ "model_type": "perception_encoder",
2637
+ "num_channels": 3,
2638
+ "output_dim": null,
2639
+ "patch_size": 14,
2640
+ "pool_type": "none",
2641
+ "ues_cls_token": false,
2642
+ "use_abs_posemb": true,
2643
+ "use_cls_token": false,
2644
+ "use_ln_post": false,
2645
+ "use_ln_pre": true,
2646
+ "use_rope2d": true,
2647
+ "width": 1536
2648
+ },
2649
+ "vision_select_layer": -1
2650
+ }
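The `ignore` list under `quantization_config` mixes exact module names with wildcard patterns. The sketch below shows how such patterns would select modules if they are read as shell-style globs; that reading, the use of Python's `fnmatch`, and the example module names are assumptions for illustration, not something this commit confirms about ModelOpt or any particular loader.

```python
from fnmatch import fnmatchcase

# A few patterns copied from the `ignore` list in config.json above.
IGNORE_PATTERNS = [
    "lm_head",
    "model.language_model.layers.0*",
    "model.language_model.layers.1.*",
    "model.language_model.layers.10.moe.gate",
    "model.language_model.layers.10.self_attn*",
    "model.language_model.layers.10.share_expert*",
    "model.vision_model*",
    "model.vit_large_projector",
]


def is_ignored(module_name: str, patterns=IGNORE_PATTERNS) -> bool:
    """Return True if any glob pattern matches the full module name."""
    return any(fnmatchcase(module_name, p) for p in patterns)


# Hypothetical module names, used only to exercise the patterns.
examples = [
    "model.language_model.layers.1.mlp.up_proj",
    "model.language_model.layers.10.self_attn.q_proj",
    "model.language_model.layers.10.moe.up_proj",
    "model.vision_model.transformer.resblocks.0.attn",
]
for name in examples:
    print(name, "->", "skipped" if is_ignored(name) else "quantized")
```

Under these semantics the trailing dot in `layers.1.*` matters: it keeps the pattern from also matching layers 10 through 19, which have their own, narrower entries.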
configuration_step3p7.py ADDED
@@ -0,0 +1,207 @@
+ from typing import Any, Optional, Sequence, Union
+
+ from transformers.configuration_utils import PretrainedConfig
+
+ class StepRoboticsVisionEncoderConfig(PretrainedConfig):
+     model_type = "perception_encoder"
+
+     def __init__(
+         self,
+         width=1536,
+         layers=47,
+         heads=16,
+         num_channels=3,
+         image_size=728,
+         mlp_ratio=8960 / 1536,
+         patch_size=14,
+         hidden_act="quick_gelu",
+         layer_norm_eps=1e-5,
+         ues_cls_token=False,
+         use_cls_token: Optional[bool] = None,
+         use_ln_pre=True,
+         use_ln_post=False,
+         use_abs_posemb=True,
+         use_rope2d=True,
+         ls_init_value=0.1,
+         **kwargs,
+     ):
+         self.width = width
+         self.layers = layers
+         self.heads = heads
+         self.num_channels = num_channels
+         self.patch_size = patch_size
+         self.image_size = image_size
+         self.mlp_ratio = mlp_ratio
+         self.layer_norm_eps = layer_norm_eps
+         self.hidden_act = hidden_act
+         if use_cls_token is None:
+             use_cls_token = ues_cls_token
+         self.ues_cls_token = use_cls_token
+         self.use_cls_token = use_cls_token
+         self.use_ln_pre = use_ln_pre
+         self.ls_init_value = ls_init_value
+         self.use_ln_post = use_ln_post
+         self.use_abs_posemb = use_abs_posemb
+         self.use_rope2d = use_rope2d
+         super().__init__(**kwargs)
+
+
+ class Step3p7TextConfig(PretrainedConfig):
+     model_type = "step3p5"
+     architectures = ["Step3p5ForCausalLM"]
+
+     def __init__(
+         self,
+         hidden_size: int = 4096,
+         intermediate_size: int = 11264,
+         num_attention_heads: int = 64,
+         num_attention_groups: int = 8,
+         num_hidden_layers: int = 45,
+         max_seq_len: int = 128000,
+         vocab_size: int = 128815,
+         rms_norm_eps: float = 1e-5,
+         moe_intermediate_size: int = 1280,
+         moe_num_experts: int = 288,
+         moe_top_k: int = 8,
+         rope_theta: float = 10000,
+         rope_scaling: Optional[dict[str, Any]] = None,
+         max_position_embeddings: int = 128000,
+         share_expert_dims: int = 1280,
+         share_expert_dim: Optional[int] = None,
+         head_dim: int = 128,
+         norm_expert_weight: bool = True,
+         layer_types: Optional[list[str]] = None,
+         sliding_window: Optional[int] = None,
+         pad_token_id: int = 1,
+         attention_dropout: float = 0.0,
+         use_head_wise_attn_gate: bool = False,
+         use_moe_router_bias: bool = False,
+         moe_router_activation: str = "softmax",
+         moe_router_scaling_factor: float = 1.0,
+         need_fp32_gate: bool = False,
+         attention_other_setting: Optional[dict[str, Any]] = None,
+         swiglu_limits: Optional[list[Optional[float]]] = None,
+         swiglu_limits_shared: Optional[list[Optional[float]]] = None,
+         use_rope_layers: Optional[list[bool]] = None,
+         yarn_only_types: Optional[list[str]] = None,
+         moe_layers_enum: tuple[int, ...] = (3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
+                                             15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
+                                             25, 26, 27, 28, 29, 30, 31, 32, 33, 34,
+                                             35, 36, 37, 38, 39, 40, 41, 42, 43, 44),
+         **kwargs,
+     ) -> None:
+         torch_dtype = kwargs.get("torch_dtype")
+         trim_layer_types = _normalize_per_layer_values(layer_types,
+                                                        num_hidden_layers)
+         if isinstance(rope_scaling, dict):
+             rope_scaling = dict(rope_scaling)
+         if share_expert_dim is None:
+             share_expert_dim = share_expert_dims
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_attention_heads = num_attention_heads
+         self.num_attention_groups = num_attention_groups
+         self.num_hidden_layers = num_hidden_layers
+         self.max_seq_len = max_seq_len
+         self.vocab_size = vocab_size
+         self.rms_norm_eps = rms_norm_eps
+         self.moe_intermediate_size = moe_intermediate_size
+         self.moe_num_experts = moe_num_experts
+         self.moe_top_k = moe_top_k
+         self.rope_theta = rope_theta
+         self.rope_scaling = rope_scaling
+         self.max_position_embeddings = max_position_embeddings
+         self.share_expert_dim = share_expert_dim
+         self.head_dim = head_dim
+         self.norm_expert_weight = norm_expert_weight
+         self.moe_layers_enum = moe_layers_enum
+         self.layer_types = trim_layer_types
+         self.sliding_window = sliding_window
+         self.pad_token_id = pad_token_id
+         self.attention_dropout = attention_dropout
+         self.use_head_wise_attn_gate = use_head_wise_attn_gate
+         self.use_moe_router_bias = use_moe_router_bias
+         self.moe_router_activation = moe_router_activation
+         self.moe_router_scaling_factor = moe_router_scaling_factor
+         self.need_fp32_gate = need_fp32_gate
+         self.attention_other_setting = attention_other_setting
+         self.swiglu_limits = swiglu_limits
+         self.swiglu_limits_shared = swiglu_limits_shared
+         self.use_rope_layers = use_rope_layers
+         self.yarn_only_types = yarn_only_types
+         super().__init__(**kwargs)
+         if torch_dtype is not None:
+             self.torch_dtype = torch_dtype
+         self.layer_types = layer_types
+
+     def to_dict(self):
+         output = super().to_dict()
+         torch_dtype = getattr(self, "torch_dtype", None)
+         if torch_dtype is not None:
+             output["torch_dtype"] = torch_dtype
+         return output
+
+
+ def _normalize_per_layer_values(
+     values: Optional[Sequence[Any]],
+     num_hidden_layers: int,
+ ) -> Optional[list[Any]]:
+     if values is None:
+         return None
+     normalized = list(values)
+     if not normalized:
+         return normalized
+     if len(normalized) < num_hidden_layers:
+         normalized.extend([normalized[-1]] *
+                           (num_hidden_layers - len(normalized)))
+     # Some checkpoints keep MTP/spec layer entries after the decoder layers.
+     # This config only builds num_hidden_layers decoder layers, and HF strict
+     # validation requires per-layer fields to match that decoder count.
+     return normalized[:num_hidden_layers]
+
+ class Step3p7Config(PretrainedConfig):
+     # This loader is a compatibility shim for original Step VL checkpoints
+     # whose top-level config model_type is `step3p7`.
+     model_type = "step3p7"
+
+     def __init__(
+         self,
+         vision_config: Optional[Union[dict, StepRoboticsVisionEncoderConfig]] = None,
+         text_config: Optional[Union[dict, Step3p7TextConfig]] = None,
+         understand_projector_stride: int = 2,
+         projector_bias: bool = False,
+         image_token_id: int = 151679,
+         **kwargs,
+     ) -> None:
+         shared_rope_scaling = kwargs.get("rope_scaling")
+         if isinstance(shared_rope_scaling, dict):
+             shared_rope_scaling = dict(shared_rope_scaling)
+
+         if vision_config is None:
+             vision_config = StepRoboticsVisionEncoderConfig()
+         elif isinstance(vision_config, dict):
+             vision_config = StepRoboticsVisionEncoderConfig(**vision_config)
+         self.vision_config = vision_config
+
+         if text_config is None:
+             text_config = Step3p7TextConfig(rope_scaling=shared_rope_scaling)
+         elif isinstance(text_config, dict):
+             text_config = dict(text_config)
+             if shared_rope_scaling is not None and "rope_scaling" not in text_config:
+                 text_config["rope_scaling"] = shared_rope_scaling
+             text_config = Step3p7TextConfig(**text_config)
+         elif shared_rope_scaling is not None and text_config.rope_scaling is None:
+             text_config.rope_scaling = dict(shared_rope_scaling)
+         self.text_config = text_config
+
+         rope_scaling = kwargs.get("rope_scaling")
+         if isinstance(rope_scaling, dict):
+             kwargs["rope_scaling"] = dict(rope_scaling)
+
+         self.understand_projector_stride = understand_projector_stride
+         self.projector_bias = projector_bias
+         self.hidden_size = text_config.hidden_size
+         self.max_position_embeddings = text_config.max_position_embeddings
+         self.image_token_id = image_token_id
+         # Help Auto classes find the correct implementation when saving/loading.
+         super().__init__(**kwargs)
generation_config.json ADDED
@@ -0,0 +1,10 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 0,
+   "eos_token_id": [
+     1,
+     2,
+     128007
+   ],
+   "transformers_version": "4.56.2"
+ }
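Because `eos_token_id` is a list here, any of the three ids ends a generation. The sketch below is a plain-Python illustration of that rule, not how `transformers` implements its stopping criteria; the helper name `truncate_at_eos` and the sample token ids are hypothetical.

```python
import json

# Same fields as generation_config.json above.
GENERATION_CONFIG = json.loads("""
{
  "_from_model_config": true,
  "bos_token_id": 0,
  "eos_token_id": [1, 2, 128007],
  "transformers_version": "4.56.2"
}
""")


def truncate_at_eos(token_ids, eos_token_id):
    """Cut a token sequence just before the first end-of-sequence id."""
    # Accept either a single id or a list of ids.
    eos_ids = {eos_token_id} if isinstance(eos_token_id, int) else set(eos_token_id)
    for i, tok in enumerate(token_ids):
        if tok in eos_ids:
            return list(token_ids[:i])
    return list(token_ids)


sample = [0, 5012, 733, 128007, 99, 1]  # hypothetical token ids
print(truncate_at_eos(sample, GENERATION_CONFIG["eos_token_id"]))
```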
hf_quant_config.json ADDED
@@ -0,0 +1,145 @@
1
+ {
2
+ "producer": {
3
+ "name": "modelopt",
4
+ "version": "0.45.0.dev37+g3ad4f4f09.d20260524"
5
+ },
6
+ "quantization": {
7
+         "quant_algo": "NVFP4",
+         "kv_cache_quant_algo": "FP8",
+         "group_size": 16,
+         "exclude_modules": [
+             "lm_head",
+             "model.language_model.layers.0*",
+             "model.language_model.layers.1.*",
+             "model.language_model.layers.10.moe.gate",
+             "model.language_model.layers.10.self_attn*",
+             "model.language_model.layers.10.share_expert*",
+             "model.language_model.layers.11.moe.gate",
+             "model.language_model.layers.11.self_attn*",
+             "model.language_model.layers.11.share_expert*",
+             "model.language_model.layers.12.moe.gate",
+             "model.language_model.layers.12.self_attn*",
+             "model.language_model.layers.12.share_expert*",
+             "model.language_model.layers.13.moe.gate",
+             "model.language_model.layers.13.self_attn*",
+             "model.language_model.layers.13.share_expert*",
+             "model.language_model.layers.14.moe.gate",
+             "model.language_model.layers.14.self_attn*",
+             "model.language_model.layers.14.share_expert*",
+             "model.language_model.layers.15.moe.gate",
+             "model.language_model.layers.15.self_attn*",
+             "model.language_model.layers.15.share_expert*",
+             "model.language_model.layers.16.moe.gate",
+             "model.language_model.layers.16.self_attn*",
+             "model.language_model.layers.16.share_expert*",
+             "model.language_model.layers.17.moe.gate",
+             "model.language_model.layers.17.self_attn*",
+             "model.language_model.layers.17.share_expert*",
+             "model.language_model.layers.18.moe.gate",
+             "model.language_model.layers.18.self_attn*",
+             "model.language_model.layers.18.share_expert*",
+             "model.language_model.layers.19.moe.gate",
+             "model.language_model.layers.19.self_attn*",
+             "model.language_model.layers.19.share_expert*",
+             "model.language_model.layers.2.*",
+             "model.language_model.layers.20.moe.gate",
+             "model.language_model.layers.20.self_attn*",
+             "model.language_model.layers.20.share_expert*",
+             "model.language_model.layers.21.moe.gate",
+             "model.language_model.layers.21.self_attn*",
+             "model.language_model.layers.21.share_expert*",
+             "model.language_model.layers.22.moe.gate",
+             "model.language_model.layers.22.self_attn*",
+             "model.language_model.layers.22.share_expert*",
+             "model.language_model.layers.23.moe.gate",
+             "model.language_model.layers.23.self_attn*",
+             "model.language_model.layers.23.share_expert*",
+             "model.language_model.layers.24.moe.gate",
+             "model.language_model.layers.24.self_attn*",
+             "model.language_model.layers.24.share_expert*",
+             "model.language_model.layers.25.moe.gate",
+             "model.language_model.layers.25.self_attn*",
+             "model.language_model.layers.25.share_expert*",
+             "model.language_model.layers.26.moe.gate",
+             "model.language_model.layers.26.self_attn*",
+             "model.language_model.layers.26.share_expert*",
+             "model.language_model.layers.27.moe.gate",
+             "model.language_model.layers.27.self_attn*",
+             "model.language_model.layers.27.share_expert*",
+             "model.language_model.layers.28.moe.gate",
+             "model.language_model.layers.28.self_attn*",
+             "model.language_model.layers.28.share_expert*",
+             "model.language_model.layers.29.moe.gate",
+             "model.language_model.layers.29.self_attn*",
+             "model.language_model.layers.29.share_expert*",
+             "model.language_model.layers.3.moe.gate",
+             "model.language_model.layers.3.self_attn*",
+             "model.language_model.layers.3.share_expert*",
+             "model.language_model.layers.30.moe.gate",
+             "model.language_model.layers.30.self_attn*",
+             "model.language_model.layers.30.share_expert*",
+             "model.language_model.layers.31.moe.gate",
+             "model.language_model.layers.31.self_attn*",
+             "model.language_model.layers.31.share_expert*",
+             "model.language_model.layers.32.moe.gate",
+             "model.language_model.layers.32.self_attn*",
+             "model.language_model.layers.32.share_expert*",
+             "model.language_model.layers.33.moe.gate",
+             "model.language_model.layers.33.self_attn*",
+             "model.language_model.layers.33.share_expert*",
+             "model.language_model.layers.34.moe.gate",
+             "model.language_model.layers.34.self_attn*",
+             "model.language_model.layers.34.share_expert*",
+             "model.language_model.layers.35.moe.gate",
+             "model.language_model.layers.35.self_attn*",
+             "model.language_model.layers.35.share_expert*",
+             "model.language_model.layers.36.moe.gate",
+             "model.language_model.layers.36.self_attn*",
+             "model.language_model.layers.36.share_expert*",
+             "model.language_model.layers.37.moe.gate",
+             "model.language_model.layers.37.self_attn*",
+             "model.language_model.layers.37.share_expert*",
+             "model.language_model.layers.38.moe.gate",
+             "model.language_model.layers.38.self_attn*",
+             "model.language_model.layers.38.share_expert*",
+             "model.language_model.layers.39.moe.gate",
+             "model.language_model.layers.39.self_attn*",
+             "model.language_model.layers.39.share_expert*",
+             "model.language_model.layers.4.moe.gate",
+             "model.language_model.layers.4.self_attn*",
+             "model.language_model.layers.4.share_expert*",
+             "model.language_model.layers.40.moe.gate",
+             "model.language_model.layers.40.self_attn*",
+             "model.language_model.layers.40.share_expert*",
+             "model.language_model.layers.41.moe.gate",
+             "model.language_model.layers.41.self_attn*",
+             "model.language_model.layers.41.share_expert*",
+             "model.language_model.layers.42.moe.gate",
+             "model.language_model.layers.42.self_attn*",
+             "model.language_model.layers.42.share_expert*",
+             "model.language_model.layers.43.moe.gate",
+             "model.language_model.layers.43.self_attn*",
+             "model.language_model.layers.43.share_expert*",
+             "model.language_model.layers.44.moe.gate",
+             "model.language_model.layers.44.self_attn*",
+             "model.language_model.layers.44.share_expert*",
+             "model.language_model.layers.5.moe.gate",
+             "model.language_model.layers.5.self_attn*",
+             "model.language_model.layers.5.share_expert*",
+             "model.language_model.layers.6.moe.gate",
+             "model.language_model.layers.6.self_attn*",
+             "model.language_model.layers.6.share_expert*",
+             "model.language_model.layers.7.moe.gate",
+             "model.language_model.layers.7.self_attn*",
+             "model.language_model.layers.7.share_expert*",
+             "model.language_model.layers.8.moe.gate",
+             "model.language_model.layers.8.self_attn*",
+             "model.language_model.layers.8.share_expert*",
+             "model.language_model.layers.9.moe.gate",
+             "model.language_model.layers.9.self_attn*",
+             "model.language_model.layers.9.share_expert*",
+             "model.vision_model*",
+             "model.vit_large_projector"
+         ]
+     }
+ }
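The `exclude_modules` list above uses glob-style patterns to keep selected modules out of NVFP4 quantization. The following is a minimal sketch of how such patterns could be evaluated, assuming `fnmatch`-style matching (an assumption; the runtime that consumes this file may match differently). The helper `is_excluded` and the module names passed to it are hypothetical and serve only as illustration. Note that `layers.1.*`, with a dot before the wildcard, covers layer 1 without also capturing layers 10 through 19.

```python
from fnmatch import fnmatchcase

# A subset of the patterns from hf_quant_config.json above.
EXCLUDE_PATTERNS = [
    "lm_head",
    "model.language_model.layers.0*",
    "model.language_model.layers.1.*",
    "model.language_model.layers.10.moe.gate",
    "model.language_model.layers.10.self_attn*",
    "model.language_model.layers.10.share_expert*",
    "model.vision_model*",
    "model.vit_large_projector",
]


def is_excluded(module_name: str, patterns=EXCLUDE_PATTERNS) -> bool:
    """Return True if any glob pattern matches the full module name.

    Assumption: case-sensitive fnmatch semantics, where '*' also matches dots.
    """
    return any(fnmatchcase(module_name, pattern) for pattern in patterns)


# Hypothetical module names, chosen only to exercise the patterns.
dense_layer_1 = is_excluded("model.language_model.layers.1.self_attn.q_proj")
attn_layer_10 = is_excluded("model.language_model.layers.10.self_attn.q_proj")
router_layer_10 = is_excluded("model.language_model.layers.10.moe.gate")
expert_layer_10 = is_excluded("model.language_model.layers.10.moe.experts.0.up_proj")

print(dense_layer_1, attn_layer_10, router_layer_10, expert_layer_10)
```

Under these assumptions, only the routed-expert weights of layers 3 through 44 are left for quantization; attention, router gates, shared experts, the first three layers, and the vision stack all match an exclusion pattern.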
jang_config.json ADDED
@@ -0,0 +1,70 @@
+ {
+   "architecture": {
+     "has_audio": false,
+     "has_mtp_tensors": false,
+     "has_vision": true,
+     "text_model_type": "step3p5",
+     "type": "step3p7"
+   },
+   "capabilities": {
+     "cache_type": "kv",
+     "family": "step3p7",
+     "modality": "vision",
+     "reasoning_parser": "qwen3",
+     "supports_thinking": true,
+     "supports_tools": true,
+     "think_in_template": true,
+     "tool_parser": "step3p5"
+   },
+   "format": "jang",
+   "format_version": "2.0",
+   "mxtq_bits": {
+     "attention": 8,
+     "embedding": 6,
+     "routed_expert": {
+       "down_proj": 3,
+       "gate_proj": 4,
+       "up_proj": 2
+     }
+   },
+   "quantization": {
+     "actual_bits": 3.398,
+     "bit_widths_used": [
+       2,
+       3,
+       4,
+       6,
+       8
+     ],
+     "block_size": 128,
+     "method": "jang-importance",
+     "passthrough_bit_widths_used": [
+       16,
+       32
+     ],
+     "profile": "JANG_2L",
+     "quantization_backend": "mx.quantize",
+     "source_weight_decode": "modelopt-nvfp4-for-routed-moe",
+     "target_bits": 2.0,
+     "total_quantized_bits": 669270114304,
+     "total_source_bits": 3151291744256
+   },
+   "runtime": {
+     "requires": [
+       "step3p7-vlm-wrapper",
+       "step3p5-text-runtime",
+       "full-and-sliding-kv-cache",
+       "head-wise-attention-gate",
+       "qk-rmsnorm",
+       "kv-scale-sidecars",
+       "image-patch-processing"
+     ],
+     "shard_count": 67,
+     "total_shard_bytes": 87621944984
+   },
+   "source_model": {
+     "dtype": "nvfp4+bf16",
+     "hub_id": "stepfun-ai/Step-3.7-Flash-NVFP4",
+     "name": "Step-3.7-Flash-NVFP4"
+   }
+ }
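The `quantization` and `runtime` figures in `jang_config.json` can be cross-checked with simple arithmetic. The sketch below assumes that `total_source_bits` counts every weight at 16 bits; this is an inference from the numbers, not something the file states. Under that assumption the average quantized width comes out close to the reported `actual_bits`, which is noticeably above `target_bits` because attention and embeddings are kept at 8 and 6 bits.

```python
# Values copied from jang_config.json above.
quantization = {
    "actual_bits": 3.398,
    "target_bits": 2.0,
    "total_quantized_bits": 669270114304,
    "total_source_bits": 3151291744256,
}
runtime = {"shard_count": 67, "total_shard_bytes": 87621944984}

# Assumption (not stated in the file): source size is counted at 16 bits per weight.
ASSUMED_SOURCE_BITS_PER_WEIGHT = 16

weight_count = quantization["total_source_bits"] // ASSUMED_SOURCE_BITS_PER_WEIGHT
bits_per_weight = quantization["total_quantized_bits"] / weight_count
compression_ratio = (
    quantization["total_source_bits"] / quantization["total_quantized_bits"]
)
mean_shard_gib = runtime["total_shard_bytes"] / runtime["shard_count"] / 2**30

print(f"weights (assumed 16-bit source): {weight_count:,}")
print(f"average bits per weight:         {bits_per_weight:.3f}")
print(f"compression vs. source:          {compression_ratio:.2f}x")
print(f"mean shard size:                 {mean_shard_gib:.2f} GiB")
```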
model-00001-of-00067.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5b88593183efb5896bdeef84490dfa878117932160cf3230c0c5789dbaa3d46e
+ size 1144607880
model-00004-of-00067.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8573cbf5d160da516142285db152e3d1876943ca35399d228c663d5bf8d0c62e
+ size 1585825616
model-00005-of-00067.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:733c4efe2bfc6cb51ec149e6d99dbedbf1d1a3c53a465fad65dea67ea2f9879a
+ size 1339963408
model-00012-of-00067.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3885bb3a2200821fe5b15c6c95a00ed599fdf79966334fcf371061dca436d31e
+ size 1150967584
model-00013-of-00067.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:842aee57a603677db6b838c577391d6d24be5929de07aa63d963bb506efa16a3
+ size 1423239760
model-00016-of-00067.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fcd7c45a1b9103dd4cb51c4fdfe9910d550513900f2af18ad8935e79a87a1480
+ size 1165028216
model-00017-of-00067.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2559644091cf2dd2f94d46ea46dfecaa56b8fdf800681de3465e9f7e56bb7371
+ size 1346111856
model-00018-of-00067.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:48ffc0acf2acee48ee0036f60940666c396c1401a9f3fdb597d25623b70ff0fb
+ size 1416833416
model-00019-of-00067.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:380ea0e6e62ef7e292a726f62ca91646036632505042c91ba4178faee5bc3389
+ size 1123882448
model-00020-of-00067.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:968d152f31d258863bc0111602286aede7046a48d0c16eec49eab23399af6625
+ size 1347358536
model-00021-of-00067.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d3a87a2791d54afc254b438f3f536ac37d26554ea717f517e6851edaf1f0ae35
+ size 1150967592
model-00024-of-00067.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:939d77dc9873753c4007400d2f5303239ab3d7ee4cfef4ce5ed3c6d45d14b374
+ size 1416833416
model-00025-of-00067.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c22ad29194eadf79192145c80c0b8adca58adb17156f874f46b4c30793ba5b52
+ size 1123874136
model-00027-of-00067.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b9b79dc8819dc1448b7940e860bbfdc1d495fa3c4da8efc86b8a988e25975235
+ size 1416833416
model-00032-of-00067.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cf3f72cd66b241ce7e5eed200194b804ffddadb0bdbafba08fd2a2a2b16c7938
+ size 1421986048
model-00033-of-00067.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bd934ec3a9d3f24d0ac9f867987cb2a732e215df256c3545e5eb08c0b498d687
+ size 1416833416
model-00036-of-00067.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9fe431a441ff5218f4be3cb2fc4d6649cf540b573a40c6413b206ea1a0f92c61
+ size 1416833416
model-00037-of-00067.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:daa0ef4517f7fa47973a81e556bf947431112675232d5be6a71df966cae08b91
+ size 1123890768
model-00038-of-00067.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7e48e303487e60a94e39ac6bd9507a46984a107689c6518c3443b33e129be767
+ size 1346111856
model-00039-of-00067.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4ddafeca14d6b660665f05c93d27e2422917e220d3e68f61b113a0c02ede3316
+ size 1416833416
model-00040-of-00067.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:508457293fdc5d853581fa2595dd425ce225bcf6da13a1024142aede3b874c66
+ size 1158620632
model-00041-of-00067.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4509e0ec408fe787cfcf26e957c59d485aeb800e5cabe6981dd39e034381cc8f
+ size 1347358536
model-00044-of-00067.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9461d83672ed03051de5fe05e6b63ec52e91147b2de62cb08e5852340f59745d
+ size 1346111856
model-00045-of-00067.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d6e2bf216b3dcefe2c8283ce0f6cb86c2323c2af6f0ee11f2e0f60a6c457d521
+ size 1416833416
model-00052-of-00067.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:88905dd2b383bc57a55f80cb4169f7e93edcf8459b12ef6e50ef3b50880ac600
+ size 1271490120
model-00053-of-00067.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5dc0f262e77977417682aeafdf5a36176da698819f65a3d5c5bdddca215b1287
+ size 1233242360
model-00056-of-00067.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1eacde34bf9e5524d3577991d2c26527a2af17b7afffcead26735ec3c1c55498
+ size 1347357280
model-00057-of-00067.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a2f2fd04b8d407595f9ad4bb868f7e6ee438dbf56d466da92fad594058abfc26
+ size 1339712544
model-00058-of-00067.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dcefc98844b20f9cee85fb362645799ef8de0e70d03cf1cef8666ecdd6981efe
+ size 1234496104
model-00059-of-00067.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e48527de7312987a4195201d741c578073ee75ef046a751a4e121e15f089354a
+ size 1346111856
model-00060-of-00067.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c5f05a4f37f9726b52c3e50e599fecf5e59b4a8f1f2fd4a8f955ec49bc80e43e
+ size 1416833416
model-00061-of-00067.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c17b661d43814171df00a41f1a619566a00d89e13e59e36d2baf2814f6953d7d
+ size 1123882448
model-00064-of-00067.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:38de06338a12ac7cba633527c89ad1d686e7a31628b5bc5ee317db61e8687277
+ size 1233242360
model-00065-of-00067.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:af2afdb70d9c9d9c4562405b0a16cdd8ce95fa5998925777dffa1276a7b7c0b2
+ size 1416833416
model-00066-of-00067.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:09c29930d31f204a6e33a2de4f97727e4b396578bba40ba9023c13a0d25d5ef6
+ size 1158620632
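Each `.safetensors` entry above is a three-line Git LFS pointer rather than the weight data itself: a spec version, a SHA-256 object id, and the size in bytes. As a minimal sketch, the pointer for `model-00066-of-00067.safetensors` could be parsed as follows; `parse_lfs_pointer` is a hypothetical helper written for this illustration, not part of any library.

```python
def parse_lfs_pointer(text: str) -> dict:
    """Parse the 'key value' lines of a Git LFS pointer file into a dict."""
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    # The oid field has the form '<hash-algorithm>:<hex digest>'.
    hash_algo, _, digest = fields["oid"].partition(":")
    return {
        "version": fields["version"],
        "hash_algo": hash_algo,
        "digest": digest,
        "size": int(fields["size"]),
    }


# Pointer contents copied from the model-00066-of-00067.safetensors hunk above.
POINTER = """\
version https://git-lfs.github.com/spec/v1
oid sha256:09c29930d31f204a6e33a2de4f97727e4b396578bba40ba9023c13a0d25d5ef6
size 1158620632
"""

info = parse_lfs_pointer(POINTER)
print(info["hash_algo"], info["size"])
```

After downloading a shard, comparing its byte length and SHA-256 digest against these two fields is a cheap integrity check.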
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_step3p7.py ADDED
@@ -0,0 +1,1405 @@
+ # Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved.
+ #
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import copy
+ import inspect
+ from dataclasses import dataclass
+ from typing import Callable, Literal, Optional, Tuple, TypedDict, Union
+
+ from PIL import Image
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers.activations import ACT2FN
+ from transformers.cache_utils import Cache, DynamicCache
+ from transformers.generation import GenerationMixin
+ from transformers.masking_utils import (
+     create_causal_mask,
+     create_sliding_window_causal_mask,
+ )
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+ from transformers.modeling_layers import GradientCheckpointingLayer
+ from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+ from transformers.processing_utils import Unpack
+ from transformers.utils import TransformersKwargs, can_return_tuple, logging
+ from .configuration_step3p7 import Step3p7Config, Step3p7TextConfig
+ from .vision_encoder import StepRoboticsVisionEncoder
+
+
+ logger = logging.get_logger(__name__)
+ _MASK_INPUT_EMBEDS_ARG = (
+     "inputs_embeds"
+     if "inputs_embeds" in inspect.signature(create_causal_mask).parameters
+     else "input_embeds"
+ )
+
+ __all__ = [
+     "Step3p7Model",
+ ]
+
+
+ class StepVLImagePixelInputs(TypedDict):
+     type: Literal["pixel_values"]
+     pixel_values: torch.Tensor
+     patch_pixel_values: Optional[torch.Tensor]
+     num_patches: list[int]
+
+
+ class StepVLImageEmbeddingInputs(TypedDict):
+     type: Literal["image_embeds"]
+     image_embeds: torch.Tensor
+
+
+ StepVLImageInputs = Union[StepVLImagePixelInputs, StepVLImageEmbeddingInputs]
+
+
+ def _flatten_embeddings(embeddings) -> torch.Tensor:
+     """
+     Recursively flattens and concatenates NestedTensors on all but the last
+     dimension.
+     """
+
+     if isinstance(embeddings, torch.Tensor):
+         # Flatten all but the last dimension.
+         return embeddings.flatten(0, -2)
+
+     return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))
+
+ def _embedding_count_expression(embeddings) -> str:
+     """
+     Constructs a debugging representation of the number of embeddings in the
+     NestedTensors.
+     """
+
+     if isinstance(embeddings, torch.Tensor):
+         return " x ".join([str(dim) for dim in embeddings.shape[:-1]])
+
+     return " + ".join(_embedding_count_expression(inner) for inner in embeddings)
+
+
+ def _merge_multimodal_embeddings(
+     inputs_embeds: torch.Tensor,
+     is_multimodal: torch.Tensor,
+     multimodal_embeddings,
+ ) -> torch.Tensor:
+     """
+     Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
+     positions in ``inputs_embeds`` selected by the boolean mask
+     ``is_multimodal`` (the placeholder-token positions of ``input_ids``).
+     Note:
+         This updates ``inputs_embeds`` in place.
+     """
+     num_expected_tokens = is_multimodal.sum().item()
+     assert isinstance(num_expected_tokens, int)
+
+     flattened = _flatten_embeddings(multimodal_embeddings)
+     if flattened.shape[0] != num_expected_tokens:
+         expr = _embedding_count_expression(multimodal_embeddings)
+         raise ValueError(
+             f"Attempted to assign {expr} = {flattened.shape[0]} "
+             f"multimodal tokens to {num_expected_tokens} placeholders"
+         )
+
+     is_multimodal = is_multimodal.to(inputs_embeds.device)
+     flattened = flattened.to(inputs_embeds.device)
+     inputs_embeds[is_multimodal] = flattened
+     return inputs_embeds
+
+ def merge_multimodal_embeddings(
+     input_ids: torch.Tensor,
+     inputs_embeds: torch.Tensor,
+     multimodal_embeddings,
+     placeholder_token_id: Union[int, list[int]],
+ ) -> torch.Tensor:
+     """
+     Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
+     positions in ``inputs_embeds`` corresponding to placeholder tokens in
+     ``input_ids``.
+
+     ``placeholder_token_id`` can be a list of token ids (e.g., the token ids
+     of img_start, img_break, and img_end tokens) when needed. In that case,
+     the order of these tokens in ``input_ids`` MUST MATCH the order of
+     their embeddings in ``multimodal_embeddings`` since we need to
+     slice-merge instead of individually scattering.
+     For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where
+     - T is text token
+     - S is image start token
+     - I is image embedding token
+     - B is image break token
+     - E is image end token.
+
+     Then the image embeddings (that correspond to I's) from the vision encoder
+     must be padded with embeddings of S, B, and E in the same order as in
+     input_ids for a correct embedding merge.
+     Note:
+         This updates ``inputs_embeds`` in place.
+     """
+     if isinstance(placeholder_token_id, list):
+         placeholder_token_id = torch.tensor(
+             placeholder_token_id, device=input_ids.device
+         )
+         return _merge_multimodal_embeddings(
+             inputs_embeds,
+             torch.isin(input_ids, placeholder_token_id),
+             multimodal_embeddings,
+         )
+
+     return _merge_multimodal_embeddings(
+         inputs_embeds,
+         (input_ids == placeholder_token_id),
+         multimodal_embeddings,
+     )
+
+
+ class Step3p7PreTrainedModel(PreTrainedModel):
+     # Link this model family to its configuration class so PreTrainedModel.from_pretrained
+     # can load the config instead of failing with a NoneType error.
+     config_class = Step3p7Config
+     supports_gradient_checkpointing = True
+     _skip_keys_device_placement = ["past_key_values"]
+     _keys_to_ignore_on_load_unexpected = [
+         r"model\.layers\.45\.*",
+         r"model\.layers\.46\.*",
+         r"model\.layers\.47\.*",
+     ]
+     _supports_flash_attn = False
+     _supports_sdpa = True
+     _supports_flex_attn = True
+     _supports_static_cache = True
+     _supports_attention_backend = True
+
+     @classmethod
+     def from_pretrained(
+         cls, pretrained_model_name_or_path, *model_args, **kwargs
+     ):
+         key_mapping = getattr(cls, "_checkpoint_conversion_mapping", None)
+         if key_mapping is not None and kwargs.get("key_mapping") is None:
+             # Transformers only applies checkpoint renaming when key_mapping is
+             # passed explicitly; inheriting the class attribute alone is not enough.
+             kwargs["key_mapping"] = copy.deepcopy(key_mapping)
+         return super().from_pretrained(
+             pretrained_model_name_or_path, *model_args, **kwargs
+         )
+
+
+ class Step3p7RotaryEmbedding(nn.Module):
+     def __init__(self, config: Step3p7TextConfig, device=None, layer_idx=None):
+         super().__init__()
+         self.layer_idx = layer_idx
+         self.max_seq_len_cached = config.max_position_embeddings
+         self.original_max_seq_len = config.max_position_embeddings
+
+         rope_theta = config.rope_theta
+         if isinstance(rope_theta, list):
+             rope_theta = rope_theta[0 if layer_idx is None else layer_idx]
+
+         partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
+         partial_rotary_factors = getattr(config, "partial_rotary_factors", None)
+         if partial_rotary_factors is not None:
+             partial_rotary_factor = partial_rotary_factors[
+                 0 if layer_idx is None else layer_idx
+             ]
+
+         self.rope_theta = rope_theta
+         self.partial_rotary_factor = partial_rotary_factor
+
+         self.config = copy.copy(config)
+         self.config.rope_theta = rope_theta
+         self.config.partial_rotary_factor = partial_rotary_factor
+
+         if config.rope_parameters is not None:
+             self.config.rope_parameters = copy.deepcopy(config.rope_parameters)
+             self.config.rope_parameters["rope_theta"] = rope_theta
+             self.config.rope_parameters["partial_rotary_factor"] = (
+                 partial_rotary_factor
+             )
+             self.rope_type = self.config.rope_parameters.get(
+                 "rope_type", self.config.rope_parameters.get("type")
+             )
+         else:
+             self.rope_type = "default"
+
+         self.rope_init_fn = self.compute_default_rope_parameters
+         if self.rope_type != "default":
+             self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+         inv_freq, self.attention_scaling = self.rope_init_fn(
+             self.config, device
+         )
+
+         self.register_buffer("inv_freq", inv_freq, persistent=False)
+         self.original_inv_freq = self.inv_freq
+
+     @torch.no_grad()
+     @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
+     def forward(self, x, position_ids):
+         inv_freq_expanded = (
+             self.inv_freq[None, :, None]
+             .float()
+             .expand(position_ids.shape[0], -1, 1)
+             .to(x.device)
+         )
+         position_ids_expanded = position_ids[:, None, :].float().to(x.device)
+
+         device_type = (
+             x.device.type
+             if isinstance(x.device.type, str) and x.device.type != "mps"
+             else "cpu"
+         )
+         with torch.autocast(
+             device_type=device_type, enabled=False
+         ):  # Force float32
+             freqs = (
+                 inv_freq_expanded.float() @ position_ids_expanded.float()
+             ).transpose(1, 2)
+             emb = torch.cat((freqs, freqs), dim=-1)
+             cos = emb.cos() * self.attention_scaling
+             sin = emb.sin() * self.attention_scaling
+
+         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+     @staticmethod
+     def compute_default_rope_parameters(
+         config: Step3p7TextConfig | None = None,
+         device: Optional["torch.device"] = None,
+     ) -> tuple["torch.Tensor", float]:
+         """
+         Computes the inverse frequencies according to the original RoPE implementation.
+         Args:
+             config ([`~transformers.PreTrainedConfig`]):
+                 The model configuration.
+             device (`torch.device`):
+                 The device to use for initialization of the inverse frequencies.
+         Returns:
+             Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+             post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+         """
+         base = config.rope_theta
+         partial_rotary_factor = getattr(
+             config, "partial_rotary_factor", 1.0
+         )
+         head_dim = (
+             getattr(config, "head_dim", None)
+             or config.hidden_size // config.num_attention_heads
+         )
+         dim = int(head_dim * partial_rotary_factor)
+
+         attention_factor = 1.0  # Unused in this type of RoPE
+
+         # Compute the inverse frequencies
+         inv_freq = 1.0 / (
+             base
+             ** (
+                 torch.arange(0, dim, 2, dtype=torch.int64).to(
+                     device=device, dtype=torch.float
+                 )
+                 / dim
+             )
+         )
+         return inv_freq, attention_factor
+
+ def rotate_half(x):
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., :x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2:]
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+     """Applies Rotary Position Embedding to the query and key tensors.
+
+     Args:
+         q (`torch.Tensor`): The query tensor.
+         k (`torch.Tensor`): The key tensor.
+         cos (`torch.Tensor`): The cosine part of the rotary embedding.
+         sin (`torch.Tensor`): The sine part of the rotary embedding.
+         position_ids (`torch.Tensor`, *optional*):
+             Deprecated and unused.
+         unsqueeze_dim (`int`, *optional*, defaults to 1):
+             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+             that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+             k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+             cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+             the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+     Returns:
+         `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
+     """
+     rotary_dim = cos.shape[-1]
+     q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
+     k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
+
+     # Apply rotary embeddings to the first `rotary_dim` channels (the full head when rotary_dim == head_dim)
+     q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
+     k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
+
+     # Concatenate back to full shape
+     q_embed = torch.cat([q_embed, q_pass], dim=-1)
+     k_embed = torch.cat([k_embed, k_pass], dim=-1)
+     return q_embed, k_embed
+
+
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+     """
+     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+     num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+     """
+     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+     if n_rep == 1:
+         return hidden_states
+     hidden_states = hidden_states[:, :, None, :, :].expand(
+         batch, num_key_value_heads, n_rep, slen, head_dim
+     )
+     return hidden_states.reshape(
+         batch, num_key_value_heads * n_rep, slen, head_dim
+     )
+
+
+ # Adapted from transformers.models.llama.modeling_llama.eager_attention_forward.
+ # Llama4 does not cast attention weights to fp32 here.
+ def eager_attention_forward(
+     module: nn.Module,
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attention_mask: Optional[torch.Tensor],
+     scaling: float,
+     dropout: float = 0.0,
+     **kwargs,
+ ):
+     key_states = repeat_kv(key, module.num_key_value_groups)
+     value_states = repeat_kv(value, module.num_key_value_groups)
+     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+     if attention_mask is not None:
+         causal_mask = attention_mask[:, :, :, :key_states.shape[-2]]
+         attn_weights = attn_weights + causal_mask
+
+     attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+     attn_weights = nn.functional.dropout(
+         attn_weights, p=dropout, training=module.training
+     )
+     attn_output = torch.matmul(attn_weights, value_states)
+     attn_output = attn_output.transpose(1, 2).contiguous()
+
+     return attn_output, attn_weights
+
+
+ @dataclass
+ class Step3p7CausalLMOutputWithPast(ModelOutput):
+     r"""
+     loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+         Language modeling loss (for next-token prediction).
+     last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+         Hidden states produced by the language model after its final normalization layer.
+     logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+         Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+     past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+         Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+         `(batch_size, num_heads, sequence_length, embed_size_per_head)`.
+         Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+         `past_key_values` input) to speed up sequential decoding.
+     """
+
+     loss: Optional[torch.FloatTensor] = None
+     last_hidden_state: Optional[torch.FloatTensor] = None
+     logits: torch.FloatTensor = None
+     past_key_values: Optional[list[torch.FloatTensor]] = None
+     hidden_states: Optional[tuple[torch.FloatTensor]] = None
+     attentions: Optional[tuple[torch.FloatTensor]] = None
423
+
424
+
+ class Step3p7MLP(nn.Module):
+     def __init__(self, config, intermediate_size=None, swiglu_limit=None):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = (
+             intermediate_size
+             if intermediate_size is not None
+             else config.intermediate_size
+         )
+         self.gate_proj = nn.Linear(self.hidden_size,
+                                    self.intermediate_size,
+                                    bias=False)
+         self.up_proj = nn.Linear(self.hidden_size,
+                                  self.intermediate_size,
+                                  bias=False)
+         self.down_proj = nn.Linear(self.intermediate_size,
+                                    self.hidden_size,
+                                    bias=False)
+         self.act_fn = ACT2FN["silu"]
+         self.limit = swiglu_limit
+
+     def forward(self, x):
+         up = self.up_proj(x)
+         gate = self.act_fn(self.gate_proj(x))
+         if self.limit is not None:
+             gate = gate.clamp(min=None, max=self.limit)
+             up = up.clamp(min=-self.limit, max=self.limit)
+
+         return self.down_proj(gate * up)
455
+
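The clamping in `Step3p7MLP.forward` is asymmetric: the activated gate is only capped from above, while the up projection is clipped on both sides. A scalar sketch of that rule, assuming plain floats in place of tensors (`silu` and `clamped_swiglu` are hypothetical helper names used only for this illustration):

```python
import math


def silu(x: float) -> float:
    """SiLU activation: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))


def clamped_swiglu(gate_pre: float, up: float, limit=None) -> float:
    """Scalar version of the gate * up product with the optional clamp."""
    gate = silu(gate_pre)
    if limit is not None:
        gate = min(gate, limit)            # upper bound only
        up = max(-limit, min(up, limit))   # symmetric bound
    return gate * up


saturated = clamped_swiglu(100.0, 50.0, limit=7.0)
negative_gate = clamped_swiglu(-1.0, -50.0, limit=7.0)
unclamped = clamped_swiglu(0.0, 5.0)
```

Because SiLU is bounded below (its minimum is roughly -0.28), leaving the gate's lower side unclamped does not let the product grow without bound.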
456
+
+ def sigmoid_routing_function(gating_output: torch.Tensor, topk: int,
+                              renormalize: bool):
+     gating_output = gating_output.float()
+     gate_prob = torch.sigmoid(gating_output)
+     gate_prob = gate_prob / gate_prob.sum(dim=-1, keepdim=True)
+     topk_prob, indices = torch.topk(gate_prob, k=topk, dim=1)
+     expert_topk_weight = topk_prob
+     if renormalize:
+         expert_topk_weight = expert_topk_weight / torch.sum(
+             expert_topk_weight, dim=-1, keepdim=True)
+     return expert_topk_weight, indices
468
+
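The sigmoid routing above can be traced for a single token without tensors. The sketch below follows the same three steps (sigmoid, normalize over all experts, top-k with optional renormalization); `sigmoid_topk_route` is a hypothetical stdlib-only stand-in, not the function used by the model:

```python
import math


def sigmoid_topk_route(logits, top_k, renormalize=True):
    """Pick top_k experts for one token from raw router logits."""
    probs = [1.0 / (1.0 + math.exp(-z)) for z in logits]
    total = sum(probs)
    probs = [p / total for p in probs]  # normalize over all experts
    # Indices of the top_k largest probabilities, best first
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
    weights = [probs[i] for i in order]
    if renormalize:
        selected_sum = sum(weights)
        weights = [w / selected_sum for w in weights]
    return weights, order


weights, experts = sigmoid_topk_route([0.0, 2.0, -1.0, 1.0], top_k=2)
```

After renormalization the selected weights sum to one, so the first normalization over all experts only affects the result when `renormalize` is false.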
469
+
+ def softmax_routing_function(gating_output: torch.Tensor, top_k: int,
+                              renormalize: bool):
+     gating_output = gating_output.float()
+     gate_prob = torch.softmax(gating_output, dim=-1)
+     gate_prob = gate_prob / gate_prob.sum(dim=-1, keepdim=True)
+     topk_prob, indices = torch.topk(gate_prob, k=top_k, dim=1)
+     expert_topk_weight = topk_prob
+     if renormalize:
+         expert_topk_weight = expert_topk_weight / torch.sum(
+             expert_topk_weight, dim=-1, keepdim=True)
+     return expert_topk_weight, indices.to(torch.int32)
481
+
482
+
+ class MoELinear(nn.Module):
+     """Bias-free linear layers for all experts, stored as one stacked weight tensor."""
+
+     def __init__(self, num_experts, in_features, out_features):
+         super().__init__()
+         self.num_experts = num_experts
+         self.in_features = in_features
+         self.out_features = out_features
+         # Shape: (num_experts, out_features, in_features); left uninitialized here
+         self.weight = nn.Parameter(
+             torch.empty(num_experts, out_features, in_features))
+
+     def forward(self, x, expert_id):
+         # Apply the selected expert's weight in fp32
+         x = F.linear(x.float(), self.weight[expert_id].float())
+         return x
496
+
497
+
+ class Step3p7MoEMLP(nn.Module):
+
+     def __init__(self, config, swiglu_limit=None):
+         super().__init__()
+         self.num_experts = config.moe_num_experts
+         self.top_k = config.moe_top_k
+         self.hidden_size = config.hidden_size
+         self.moe_intermediate_size = config.moe_intermediate_size
+
+         self.use_moe_router_bias = config.use_moe_router_bias
+         if self.use_moe_router_bias:
+             self.router_bias = nn.Parameter(torch.zeros(config.moe_num_experts,
+                                                         dtype=torch.float32),
+                                             requires_grad=False)
+             self.custom_routing_function = self.router_bias_func
+         elif config.moe_router_activation == "sigmoid":
+             self.custom_routing_function = sigmoid_routing_function
+         else:
+             self.custom_routing_function = None
+         self.need_fp32_gate = config.need_fp32_gate
+         self.routed_scaling_factor = getattr(config,
+                                              "moe_router_scaling_factor", 1.0)
+
+         # gating
+         self.gate = nn.Linear(self.hidden_size, self.num_experts, bias=False)
+
+         self.act_fn = ACT2FN["silu"]
+         self.limit = swiglu_limit
+
+         self.up_proj = MoELinear(self.num_experts, self.hidden_size,
+                                  self.moe_intermediate_size)
+         self.gate_proj = MoELinear(self.num_experts, self.hidden_size,
+                                    self.moe_intermediate_size)
+         self.down_proj = MoELinear(self.num_experts,
+                                    self.moe_intermediate_size,
+                                    self.hidden_size)
+
+     def router_bias_func(self, gating_output: torch.Tensor, topk: int,
+                          renormalize: bool):
+         gate_prob = torch.sigmoid(gating_output.float())
+         # The bias only influences which experts are selected;
+         # the returned weights are gathered from the unbiased probabilities.
+         gate_prob_with_bias = gate_prob + self.router_bias.unsqueeze(0)
+         _, indices = torch.topk(gate_prob_with_bias, k=topk, dim=1)
+         topk_prob = torch.gather(gate_prob, 1, indices)
+         expert_topk_weight = topk_prob
+         if renormalize:
+             expert_topk_weight = expert_topk_weight / (
+                 torch.sum(expert_topk_weight, dim=-1, keepdim=True) + 1e-20)
+         return expert_topk_weight, indices
+
+     def get_expert_output(self, inputs: torch.Tensor, expert_id):
+         up = self.up_proj(inputs, expert_id)
+         gate = self.act_fn(self.gate_proj(inputs, expert_id))
+         if self.limit is not None:
+             gate = gate.clamp(min=None, max=self.limit)
+             up = up.clamp(min=-self.limit, max=self.limit)
+
+         return self.down_proj(gate * up, expert_id)
+
+     def forward(self, hidden_states):
+         """Route each token to its top-k experts and return the weighted sum of their outputs."""
+         batch_size, sequence_length, hidden_dim = hidden_states.shape
+         hidden_states = hidden_states.view(-1, hidden_dim)
+         if self.need_fp32_gate:
+             router_logits = torch.matmul(
+                 hidden_states.to(torch.float32),
+                 self.gate.weight.t().to(torch.float32),
+             )
+         else:
+             # router_logits: (batch * sequence_length, n_experts)
+             router_logits = self.gate(hidden_states)
+
+         if self.custom_routing_function:
+             routing_weights, selected_experts = self.custom_routing_function(
+                 router_logits, self.top_k, renormalize=True)
+         else:
+             routing_weights = F.softmax(router_logits,
+                                         dim=1,
+                                         dtype=torch.float)
+             routing_weights, selected_experts = torch.topk(routing_weights,
+                                                            self.top_k,
+                                                            dim=-1)
+
+         routing_weights = routing_weights * self.routed_scaling_factor
+
+         final_hidden_states = torch.zeros(
+             (batch_size * sequence_length, hidden_dim),
+             dtype=hidden_states.dtype,
+             device=hidden_states.device)
+
+         # One-hot encode the selected experts to create an expert mask of shape
+         # (num_experts, top_k, num_tokens); it is used to look up which tokens each expert serves
+         expert_mask = torch.nn.functional.one_hot(
+             selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+
+         # Loop over all available experts in the model and perform the computation on each expert
+         for expert_idx in range(self.num_experts):
+             # idx: top-k slot in which the expert was chosen; top_x: token index
+             idx, top_x = torch.where(expert_mask[expert_idx])
+
+             # Index the correct hidden states and compute the expert hidden state for
+             # the current expert. The output hidden states are multiplied by the
+             # `routing_weights` of the corresponding (token, top-k slot) pairs
+             current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+             current_hidden_states = (
+                 self.get_expert_output(current_state, expert_idx) *
+                 routing_weights[top_x, idx, None])
+
+             # `index_add_` only supports torch tensors for indexing, so the
+             # `top_x` tensor is used here.
+             final_hidden_states.index_add_(
+                 0, top_x, current_hidden_states.to(hidden_states.dtype))
+         final_hidden_states = final_hidden_states.reshape(
+             batch_size, sequence_length, hidden_dim)
+         return final_hidden_states
612
+
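In `router_bias_func`, the router bias changes which experts are picked but not the weights they receive: selection uses `sigmoid(logits) + bias`, while the returned weights are read from the unbiased sigmoid values. A single-token sketch under that reading, in plain Python (`biased_topk_route` is a hypothetical name introduced for this example):

```python
import math


def biased_topk_route(logits, bias, top_k):
    """Select experts by biased score, weight them by unbiased probability."""
    probs = [1.0 / (1.0 + math.exp(-z)) for z in logits]
    biased = [p + b for p, b in zip(probs, bias)]
    # Rank by the biased score, best first
    order = sorted(range(len(biased)), key=lambda i: biased[i], reverse=True)[:top_k]
    # Gather the unbiased probabilities of the selected experts
    weights = [probs[i] for i in order]
    denom = sum(weights) + 1e-20  # same guard against division by zero
    return [w / denom for w in weights], order


weights, experts = biased_topk_route(
    logits=[2.0, 0.0, 0.0], bias=[0.0, 0.0, 1.0], top_k=2)
```

Here expert 2 ranks first only because of its bias, yet it ends up with the smaller weight, since its unbiased probability is lower than expert 0's.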
613
+
614
+ class Step3p7RMSNorm(nn.Module):
615
+
616
+ def __init__(
617
+ self,
618
+ hidden_size: int,
619
+ eps: float = 1e-5,
620
+ ) -> None:
621
+ super().__init__()
622
+ self.weight = nn.Parameter(torch.ones(hidden_size))
623
+ self.variance_epsilon = eps
624
+
625
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
626
+ dtype = x.dtype
627
+ x = x.float()
628
+ variance = x.pow(2).mean(dim=-1, keepdim=True)
629
+ normed = x * torch.rsqrt(variance + self.variance_epsilon)
630
+ normed = normed * (self.weight.float() + 1)
631
+ return normed.to(dtype)
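Note that `Step3p7RMSNorm` scales by `weight + 1` rather than by `weight` alone, so a stored weight of zero corresponds to an identity scale. A small list-based sketch of the same arithmetic (`rms_norm` is a hypothetical helper; the real module works on tensors and casts back to the input dtype):

```python
import math


def rms_norm(x, weight, eps=1e-5):
    """RMS-normalize a vector, then scale each channel by (weight + 1)."""
    variance = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(variance + eps)
    return [v * inv_rms * (w + 1.0) for v, w in zip(x, weight)]


x = [3.0, 4.0]
unit_scale = rms_norm(x, weight=[0.0, 0.0], eps=0.0)
double_scale = rms_norm(x, weight=[1.0, 1.0], eps=0.0)
```

With a zero weight the output has a mean square of one, and a weight of one doubles every channel.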
+
+
+ class Step3p7Attention(nn.Module):
633
+
634
+ def __init__(self, config: Step3p7TextConfig, layer_idx):
635
+ super().__init__()
636
+ self.config = config
637
+ self.layer_idx = layer_idx
638
+ self.num_attention_heads = config.num_attention_heads
639
+ self.num_key_value_heads = config.num_attention_groups
640
+
641
+ layer_types = getattr(config, "layer_types", [])
642
+ if layer_types:
643
+ enable_sliding_window = layer_types[
644
+ self.layer_idx] == "sliding_attention"
645
+ else:
646
+ enable_sliding_window = self.layer_idx % 2 == 0
647
+
648
+ yarn_only_types = getattr(config, "yarn_only_types", None)
649
+ if yarn_only_types and layer_types[
650
+ self.layer_idx] not in yarn_only_types:
651
+ config.rope_parameters = None
652
+ else:
653
+ config.rope_parameters = getattr(config, "rope_scaling", None)
654
+
655
+ self.sliding_window = config.sliding_window
656
+ if enable_sliding_window:
657
+ self.num_attention_heads = config.attention_other_setting[
658
+ "num_attention_heads"]
659
+ self.num_key_value_heads = config.attention_other_setting[
660
+ "num_attention_groups"]
661
+
662
+ if self.sliding_window is not None and enable_sliding_window:
663
+ self.sliding_window = (self.sliding_window)
664
+ else:
665
+ self.sliding_window = None
666
+ self.head_dim = getattr(config, "head_dim",
667
+ config.hidden_size // self.num_attention_heads)
668
+ self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
669
+
670
+ self.rotary_emb = Step3p7RotaryEmbedding(config, layer_idx=layer_idx)
671
+
672
+ self.q_size = self.num_attention_heads * self.head_dim
673
+ self.kv_size = self.num_key_value_heads * self.head_dim
674
+ self.scaling = self.head_dim**-0.5
675
+
676
+ self.q_proj = nn.Linear(config.hidden_size, self.q_size, bias=False)
677
+ self.k_proj = nn.Linear(config.hidden_size, self.kv_size, bias=False)
678
+ self.v_proj = nn.Linear(config.hidden_size, self.kv_size, bias=False)
679
+ self.o_proj = nn.Linear(self.q_size, config.hidden_size, bias=False)
680
+ self.attention_dropout = getattr(config, "attention_dropout", 0.0)
681
+ self.q_norm = Step3p7RMSNorm(self.head_dim,
682
+ eps=config.rms_norm_eps)
683
+ self.k_norm = Step3p7RMSNorm(self.head_dim,
684
+ eps=config.rms_norm_eps)
685
+
686
+ self.use_head_wise_attn_gate = config.use_head_wise_attn_gate
687
+ if self.use_head_wise_attn_gate:
688
+ self.g_proj = nn.Linear(config.hidden_size,
689
+ self.num_attention_heads,
690
+ bias=False)
691
+
692
+ self.use_rope = True
693
+ use_rope_layers = getattr(config, "use_rope_layers", None)
694
+ if use_rope_layers:
695
+ self.use_rope = use_rope_layers[self.layer_idx]
696
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor],
+         past_key_value: Optional[Cache] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         **kwargs: Unpack[FlashAttentionKwargs],
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+         input_shape = hidden_states.shape[:-1]
+         hidden_shape = (*input_shape, -1, self.head_dim)
+
+         # Per-head RMSNorm is applied to queries and keys before the rotary embedding
+         query_states = self.q_norm(
+             self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+         key_states = self.k_norm(
+             self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(
+             1, 2)
+         if self.use_head_wise_attn_gate:
+             gate_states = self.g_proj(hidden_states)
+         cos, sin = self.rotary_emb(hidden_states, position_ids)
+
+         query_states, key_states = apply_rotary_pos_emb(
+             query_states, key_states, cos, sin)
+
+         if past_key_value is not None:
+             # sin and cos are specific to RoPE models; cache_position is needed for the static cache
+             cache_kwargs = {
+                 "sin": sin,
+                 "cos": cos,
+                 "cache_position": cache_position
+             }
+             key_states, value_states = past_key_value.update(
+                 key_states, value_states, self.layer_idx, cache_kwargs)
+
+         attention_interface: Callable = eager_attention_forward
+         # TODO: considering FP8;
+         # RuntimeError: Expected attn_mask dtype to be bool or float or to match query dtype,
+         # but got attn_mask.dtype: long int and query.dtype: c10::BFloat16 instead.
+         if self.config._attn_implementation != "eager":
+             attention_interface = ALL_ATTENTION_FUNCTIONS[
+                 self.config._attn_implementation]
+
+         attn_output, attn_weights = attention_interface(
+             self,
+             query_states,
+             key_states,
+             value_states,
+             attention_mask,
+             dropout=0.0 if not self.training else self.attention_dropout,
+             scaling=self.scaling,
+             sliding_window=self.sliding_window,  # main diff with Llama
+             **kwargs,
+         )
+         attn_output = attn_output.reshape(*input_shape, -1)
+         if self.use_head_wise_attn_gate:
+             # Scale each head's output by a sigmoid gate computed from the layer input
+             output = attn_output.view(
+                 *attn_output.shape[:-1], self.num_attention_heads,
+                 self.head_dim) * gate_states.unsqueeze(-1).sigmoid()
+             attn_output = output.view(*attn_output.shape)
+         attn_output = self.o_proj(attn_output)
+
+         return attn_output, attn_weights
763
+
764
+
765
+ class Step3p7DecoderLayer(GradientCheckpointingLayer):
766
+
767
+ def __init__(self, config, layer_idx):
768
+ super().__init__()
769
+ self.hidden_size = config.hidden_size
770
+ self.layer_idx = layer_idx
771
+ self.self_attn = Step3p7Attention(config, layer_idx)
772
+ layer_types = getattr(config, "layer_types", None) or []
773
+ if layer_types:
774
+ self.attention_type = layer_types[layer_idx]
775
+ else:
776
+ self.attention_type = (
777
+ "sliding_attention" if layer_idx % 2 == 0 else "full_attention"
778
+ )
779
+
780
+ moe_layers_enum = getattr(config, "moe_layers_enum", None)
781
+ if moe_layers_enum is not None:
782
+ if isinstance(moe_layers_enum, str):
783
+ moe_layers_idx = [
784
+ int(i) for i in moe_layers_enum.split(',') if i.strip()
785
+ ]
786
+ else:
787
+ moe_layers_idx = [int(i) for i in moe_layers_enum]
788
+ else:
789
+ moe_layers_idx = [i for i in range(1, config.num_hidden_layers)]
790
+ self.is_moe_layer = layer_idx in moe_layers_idx
791
+ self.use_moe = False
792
+
793
+ if config.swiglu_limits_shared and config.swiglu_limits_shared[
794
+ layer_idx] is not None and config.swiglu_limits_shared[
795
+ layer_idx] != 0:
796
+ swiglu_limit_shared = config.swiglu_limits_shared[layer_idx]
797
+ else:
798
+ swiglu_limit_shared = None
799
+ if config.swiglu_limits and config.swiglu_limits[
800
+ layer_idx] is not None and config.swiglu_limits[layer_idx] != 0:
801
+ swiglu_limit = config.swiglu_limits[layer_idx]
802
+ else:
803
+ swiglu_limit = None
804
+ if self.is_moe_layer:
805
+ self.moe = Step3p7MoEMLP(config, swiglu_limit=swiglu_limit) #
806
+ self.share_expert = Step3p7MLP(
807
+ config,
808
+ intermediate_size=config.share_expert_dim,
809
+ swiglu_limit=swiglu_limit_shared)
810
+ self.use_moe = True
811
+ else:
812
+ self.mlp = Step3p7MLP(config,
813
+ intermediate_size=config.intermediate_size,
814
+ swiglu_limit=swiglu_limit_shared)
815
+
816
+ self.input_layernorm = Step3p7RMSNorm(
817
+ config.hidden_size,
818
+ eps=config.rms_norm_eps)
819
+ self.post_attention_layernorm = Step3p7RMSNorm(
820
+ config.hidden_size,
821
+ eps=config.rms_norm_eps)
822
+
823
+ def forward(
824
+ self,
825
+ hidden_states: torch.Tensor,
826
+ attention_mask: Optional[torch.Tensor] = None,
827
+ position_ids: Optional[torch.LongTensor] = None,
828
+ past_key_value: Optional[tuple[torch.Tensor]] = None,
829
+ cache_position: Optional[torch.LongTensor] = None,
830
+ **kwargs: Unpack[FlashAttentionKwargs],
831
+ ) -> torch.FloatTensor:
832
+ residual = hidden_states
833
+ hidden_states = self.input_layernorm(hidden_states)
834
+ hidden_states, _ = self.self_attn(
835
+ hidden_states=hidden_states,
836
+ attention_mask=attention_mask,
837
+ position_ids=position_ids,
838
+ past_key_value=past_key_value,
839
+ cache_position=cache_position,
840
+ **kwargs,
841
+ )
842
+ hidden_states = residual + hidden_states
843
+
844
+ # Fully Connected
845
+ residual = hidden_states
846
+ hidden_states = self.post_attention_layernorm(hidden_states)
847
+ if self.use_moe:
848
+ share_output = self.share_expert(hidden_states)
849
+ moe_output = self.moe(hidden_states)
850
+ ffn_output = moe_output + share_output
851
+ else:
852
+ ffn_output = self.mlp(hidden_states)
853
+ if isinstance(ffn_output, tuple):
854
+ hidden_states, _ = ffn_output
855
+ else:
856
+ hidden_states = ffn_output
857
+
858
+ hidden_states = residual + hidden_states
859
+ return hidden_states
860
+
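The MoE layer selection in `Step3p7DecoderLayer.__init__` accepts `moe_layers_enum` either as a comma-separated string or as an iterable, and falls back to "every layer except layer 0" when it is unset. A standalone sketch of that parsing (`parse_moe_layers` is a hypothetical helper that mirrors the branch structure, not a function in this file):

```python
def parse_moe_layers(moe_layers_enum, num_hidden_layers):
    """Return the list of layer indices that use the MoE feed-forward block."""
    if moe_layers_enum is None:
        # Default: all layers except the first one
        return list(range(1, num_hidden_layers))
    if isinstance(moe_layers_enum, str):
        # Empty fragments (e.g. from a trailing comma) are skipped
        return [int(i) for i in moe_layers_enum.split(",") if i.strip()]
    return [int(i) for i in moe_layers_enum]


from_string = parse_moe_layers("1, 3,5,", num_hidden_layers=8)
from_default = parse_moe_layers(None, num_hidden_layers=4)
from_iterable = parse_moe_layers((2, "4"), num_hidden_layers=8)
```

A layer then becomes an MoE layer when its index appears in the returned list.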
861
+
862
+ class Step3p7TextPreTrainedModel(PreTrainedModel):
863
+ # Link this model family to its configuration class so PreTrainedModel.from_pretrained
864
+ # can load the config instead of failing with a NoneType error.
865
+ config_class = Step3p7TextConfig
866
+ supports_gradient_checkpointing = True
867
+ _skip_keys_device_placement = ["past_key_values"]
868
+ _keys_to_ignore_on_load_unexpected = [
869
+ r"model\.layers\.45\.*",
870
+ r"model\.layers\.46\.*",
871
+ r"model\.layers\.47\.*",
872
+ ]
873
+ _supports_flash_attn = False
874
+ _supports_sdpa = True
875
+ _supports_flex_attn = True
876
+ _supports_static_cache = True
877
+ _supports_attention_backend = True
878
+
879
+
880
+ class Step3p7TextModel(Step3p7TextPreTrainedModel, GenerationMixin):
881
+ _no_split_modules = ["Step3p7DecoderLayer"]
882
+ base_model_prefix = "model"
883
+ _tied_weights_keys = ["lm_head.weight"]
884
+ config: Step3p7TextConfig
885
+
886
+ def __init__(self, config: Step3p7TextConfig):
887
+ super().__init__(config)
888
+ self.padding_idx = config.pad_token_id
889
+ self.vocab_size = config.vocab_size
890
+
891
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size,
892
+ self.padding_idx)
893
+ self.layers = nn.ModuleList([
894
+ Step3p7DecoderLayer(config, layer_idx)
895
+ for layer_idx in range(config.num_hidden_layers)
896
+ ])
897
+ self.norm = Step3p7RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
898
+ self.gradient_checkpointing = False
899
+ layer_types = self.config.layer_types or []
900
+ self.has_sliding_layers = (not layer_types or
901
+ "sliding_attention" in layer_types)
902
+
903
+ # Initialize weights and apply final processing
904
+ self.post_init()
905
+
906
+
907
+ def get_input_embeddings(self, input_ids):
908
+ return self.embed_tokens(input_ids)
909
+
910
+ @can_return_tuple
911
+ def forward(
912
+ self,
913
+ input_ids: torch.LongTensor = None,
914
+ attention_mask: Optional[torch.Tensor] = None,
915
+ position_ids: Optional[torch.LongTensor] = None,
916
+ past_key_values: Optional[Cache] = None,
917
+ inputs_embeds: Optional[torch.FloatTensor] = None,
918
+ use_cache: Optional[bool] = None,
919
+ output_attentions: Optional[bool] = None,
920
+ output_hidden_states: Optional[bool] = None,
921
+ return_dict: Optional[bool] = None,
922
+ cache_position: Optional[torch.LongTensor] = None,
923
+ **kwargs: Unpack[TransformersKwargs],
924
+ ) -> Union[tuple, BaseModelOutputWithPast]:
925
+ output_attentions = (
926
+ output_attentions
927
+ if output_attentions is not None
928
+ else self.config.output_attentions
929
+ )
930
+ output_hidden_states = (
931
+ output_hidden_states
932
+ if output_hidden_states is not None
933
+ else self.config.output_hidden_states
934
+ )
935
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
936
+ return_dict = (
937
+ return_dict
938
+ if return_dict is not None
939
+ else getattr(self.config, "return_dict", True)
940
+ )
941
+ if (input_ids is None) ^ (inputs_embeds is not None):
942
+ raise ValueError(
943
+ "You must specify exactly one of input_ids or inputs_embeds")
944
+
945
+ if self.gradient_checkpointing and self.training and use_cache:
946
+ logger.warning_once(
947
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
948
+ )
949
+ use_cache = False
950
+
951
+ if inputs_embeds is None:
952
+ inputs_embeds = self.embed_tokens(
953
+ input_ids.to(self.embed_tokens.weight.device))
954
+
955
+ if use_cache and past_key_values is None:
956
+ past_key_values = DynamicCache()
957
+
958
+ if cache_position is None:
959
+ past_seen_tokens = past_key_values.get_seq_length(
960
+ ) if past_key_values is not None else 0
961
+ cache_position = torch.arange(past_seen_tokens,
962
+ past_seen_tokens +
963
+ inputs_embeds.shape[1],
964
+ device=inputs_embeds.device)
965
+
966
+ if position_ids is None:
967
+ position_ids = cache_position.unsqueeze(0)
968
+
969
+ hidden_states = inputs_embeds
970
+
971
+ # It may already have been prepared by e.g. `generate`
972
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
973
+ # Prepare mask arguments
974
+ mask_kwargs = {
975
+ "config": self.config,
976
+ "attention_mask": attention_mask,
977
+ "past_key_values": past_key_values,
978
+ "position_ids": position_ids,
979
+ }
980
+ mask_kwargs[_MASK_INPUT_EMBEDS_ARG] = inputs_embeds
981
+ # Create the masks
982
+ causal_mask_mapping = {
983
+ "full_attention": create_causal_mask(**mask_kwargs),
984
+ }
985
+
986
+ # The sliding window alternating layers are not always activated depending on the config
987
+ if self.has_sliding_layers:
988
+ causal_mask_mapping[
989
+ "sliding_attention"] = create_sliding_window_causal_mask(
990
+ **mask_kwargs)
991
+
992
+ # # create position embeddings to be shared across the decoder layers
993
+ # decoder layers
994
+ all_hidden_states = () if output_hidden_states else None
995
+ all_self_attns = () if output_attentions else None
996
+ for decoder_layer in self.layers[:self.config.num_hidden_layers]:
997
+ if output_hidden_states:
998
+ all_hidden_states += (hidden_states, )
999
+
1000
+ layer_outputs = decoder_layer(
1001
+ hidden_states,
1002
+ attention_mask=causal_mask_mapping[
1003
+ decoder_layer.attention_type],
1004
+ position_ids=position_ids,
1005
+ past_key_value=past_key_values,
1006
+ output_attentions=output_attentions,
1007
+ use_cache=use_cache,
1008
+ cache_position=cache_position,
1009
+ **kwargs,
1010
+ )
1011
+
1012
+ hidden_states = layer_outputs
1013
+
1014
+ hidden_states = self.norm(hidden_states)
1015
+
1016
+ return BaseModelOutputWithPast(
1017
+ last_hidden_state=hidden_states,
1018
+ past_key_values=past_key_values if use_cache else None,
1019
+ hidden_states=all_hidden_states,
1020
+ attentions=all_self_attns,
1021
+ )
1022
+
1023
+
1024
+ class Step3p7Model(Step3p7PreTrainedModel, GenerationMixin):
1025
+ config: Step3p7Config
1026
+ _tied_weights_keys = ["lm_head.weight"]
1027
+ base_model_prefix = ""
1028
+
1029
+ def __init__(self, config: Step3p7Config):
1030
+ super().__init__(config)
1031
+ self.vision_model = StepRoboticsVisionEncoder(config.vision_config)
1032
+ self.language_model = Step3p7TextModel(config.text_config)
1033
+ self.vocab_size = config.text_config.vocab_size
1034
+ self.vit_large_projector = nn.Linear(
1035
+ config.vision_config.width * 4,
1036
+ config.text_config.hidden_size,
1037
+ bias=config.projector_bias)
1038
+ self.image_placeholder_token_id = config.image_token_id
1039
+
1040
+ # Initialize weights and apply final processing
1041
+ self.post_init()
1042
+
1043
+ def get_input_embeddings(
1044
+ self,
1045
+ input_ids: torch.Tensor,
1046
+ multimodal_embeddings = None,
1047
+ ) -> torch.Tensor:
1049
+ input_ids = input_ids.squeeze(0)
1050
+ if multimodal_embeddings is None:
1051
+ inputs_embeds = self.language_model.get_input_embeddings(input_ids)
1052
+ else:
1053
+ is_text = input_ids != self.config.image_token_id
1054
+ text_ids = input_ids[is_text]
1055
+ text_embeds = self.language_model.get_input_embeddings(text_ids)
1056
+
1057
+ inputs_embeds = torch.empty(input_ids.shape[0],
1058
+ text_embeds.shape[-1],
1059
+ dtype=text_embeds.dtype,
1060
+ device=text_embeds.device)
1061
+ inputs_embeds[is_text] = text_embeds
1062
+ inputs_embeds = merge_multimodal_embeddings(
1063
+ input_ids, inputs_embeds, multimodal_embeddings,
1064
+ self.config.image_token_id)
1065
+ inputs_embeds = inputs_embeds.unsqueeze(0)
1066
+ return inputs_embeds
1067
+
1068
+
1069
+ def set_input_embeddings(self, value):
1070
+ return self.language_model.set_input_embeddings(value)
1071
+
1072
+ def set_decoder(self, decoder):
1073
+ self.language_model = decoder
1074
+
1075
+ def get_decoder(self):
1076
+ return self.language_model
1077
+
1078
+ def _parse_and_validate_image_input(
1079
+ self, **kwargs: object) -> Optional[StepVLImageInputs]:
1080
+ pixel_values = kwargs.pop("pixel_values", None)
1081
+ patch_pixel_values = kwargs.pop("patch_pixel_values", None)
1082
+ num_patches = kwargs.pop("num_patches", None)
1083
+ image_embeds = kwargs.pop("image_embeds", None)
1084
+
1085
+ if pixel_values is None and image_embeds is None:
1086
+ return None
1087
+
1088
+ if pixel_values is not None:
1089
+ # pixel_values = flatten_bn(pixel_values, concat=True)
1090
+ if pixel_values.dim() >= 3:
1091
+ pixel_values = pixel_values.view(-1, *pixel_values.shape[-3:])
1092
+ if patch_pixel_values is not None:
1093
+ # patch_pixel_values = flatten_bn(patch_pixel_values,
1094
+ # concat=True)
1095
+ patch_pixel_values = patch_pixel_values.view(
1096
+ -1, *patch_pixel_values.shape[-3:])
1097
+ # Handle empty patch_pixel_values by setting to None
1098
+ if patch_pixel_values.shape[0] == 0:
1099
+ patch_pixel_values = None
1100
+
1101
+ return StepVLImagePixelInputs(
1102
+ type="pixel_values",
1103
+ pixel_values=pixel_values.to(self.dtype).to(self.device),
1104
+ patch_pixel_values=patch_pixel_values.to(self.dtype).to(
1105
+ self.device) if patch_pixel_values is not None else None,
1106
+ num_patches=num_patches,
1107
+ )
1108
+
1109
+ if image_embeds is not None:
1110
+ if image_embeds.dim() >= 2:
1111
+ image_embeds = image_embeds.view(-1, image_embeds.shape[-1])
1112
+ else:
1113
+ raise ValueError(
1114
+ f"Unexpected shape for image_embeds: {image_embeds.shape}")
1115
+
1116
+ return StepVLImageEmbeddingInputs(
1117
+ type="image_embeds",
1118
+ image_embeds=image_embeds.to(self.dtype).to(self.device),
1119
+ )
1120
+ return None
1121
+
1122
+ def _process_image_features(self,
1123
+ image_features: torch.Tensor) -> torch.Tensor:
1124
+ B, P = image_features.shape[:2]
1125
+ HW = int(P ** 0.5)
1126
+ image_features = image_features.permute(0, 2, 1).view(B, -1, HW, HW)
1127
+ image_features = self.vision_model.vit_downsampler1(image_features)
1128
+ image_features = self.vision_model.vit_downsampler2(image_features)
1129
+
1130
+ B, C, HW, HW = image_features.shape
1131
+ image_features = image_features.view(B, -1, HW * HW).permute(0, 2, 1)
1132
+ image_features = self.vit_large_projector(image_features)
1133
+ return image_features
1134
+
1135
+ def _get_vision_model_output(self,
1136
+ input_tensor: torch.Tensor) -> torch.Tensor:
1137
+ return self.vision_model(input_tensor)
1138
+
1139
+ def _process_image_input(
1140
+ self, image_input: StepVLImageInputs) -> tuple[torch.Tensor, ...]:
1141
+
1142
+ if image_input["type"] == "image_embeds":
1143
+ image_features = image_input["image_embeds"]
1144
+ else:
1145
+ image_features = self._get_vision_model_output(
1146
+ image_input["pixel_values"])
1147
+ patch_image_features = self._get_vision_model_output(
1148
+ image_input["patch_pixel_values"]
1149
+ ) if image_input["patch_pixel_values"] is not None else None
1150
+ num_patches = image_input["num_patches"]
1151
+
1152
+ image_features = self._process_image_features(image_features)
1153
+ patch_image_features = self._process_image_features(
1154
+ patch_image_features) if patch_image_features is not None else None
1155
+
1156
+ merged_image_features = []
1157
+ cur_patch_idx = 0
1158
+ for i, num_patch in enumerate(num_patches):
1159
+ cur_feature = []
1160
+ if num_patch > 0:
1161
+ patch_slice = patch_image_features[
1162
+ cur_patch_idx:cur_patch_idx + num_patch]
1163
+ cur_feature.append(patch_slice.view(-1, patch_slice.shape[-1]))
1164
+ cur_feature.append(image_features[i].view(
1165
+ -1, image_features.shape[-1]))
1166
+ cur_patch_idx += num_patch
1167
+ merged_image_features.append(
1168
+ torch.cat(cur_feature) if len(cur_feature) >
1169
+ 1 else cur_feature[0])
1170
+
1171
+ return merged_image_features
1172
+
1173
+ def get_multimodal_embeddings(self, **kwargs):
1175
+ image_input = self._parse_and_validate_image_input(**kwargs)
1176
+ if image_input is None:
1177
+ return None
1178
+ vision_embeddings = self._process_image_input(image_input)
1179
+ return vision_embeddings
1180
+
1181
+ @can_return_tuple
1182
+ def forward(
1183
+ self,
1184
+ input_ids: torch.LongTensor = None,
1185
+ attention_mask: Optional[torch.Tensor] = None,
1186
+ position_ids: Optional[torch.LongTensor] = None,
1187
+ past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
1188
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1189
+ labels: Optional[torch.LongTensor] = None,
1190
+ use_cache: Optional[bool] = None,
1191
+ output_attentions: Optional[bool] = None,
1192
+ output_hidden_states: Optional[bool] = None,
1193
+ return_dict: Optional[bool] = None,
1194
+ cache_position: Optional[torch.LongTensor] = None,
1195
+ logits_to_keep: Union[int, torch.Tensor] = 0,
1196
+ images: Optional[list[Image.Image]] = None,
1197
+ **kwargs: Unpack[TransformersKwargs],
1198
+ ) -> Union[tuple, Step3p7CausalLMOutputWithPast]:
1199
+ r"""
1200
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1201
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1202
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1203
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1204
+ Example:
1205
+ ```python
1206
+ >>> from transformers import AutoTokenizer, Llama4ForCausalLM
1207
+ >>> model = Llama4ForCausalLM.from_pretrained("meta-llama4/Llama4-2-7b-hf")
1208
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama4/Llama4-2-7b-hf")
1209
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1210
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1211
+ >>> # Generate
1212
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1213
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1214
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1215
+ ```"""
1216
+ output_attentions = (
1217
+ output_attentions
1218
+ if output_attentions is not None
1219
+ else self.config.output_attentions
1220
+ )
1221
+ output_hidden_states = (
1222
+ output_hidden_states
1223
+ if output_hidden_states is not None
1224
+ else self.config.output_hidden_states
1225
+ )
1226
+ return_dict = (
1227
+ return_dict if return_dict is not None else self.config.use_return_dict
1228
+ )
1229
+
1230
+ if inputs_embeds is None:
1231
+ # Embed the tokens and combine them with any vision embeddings.
1232
+ vision_embeddings = self.get_multimodal_embeddings(**kwargs)
1233
+ inputs_embeds = self.get_input_embeddings(input_ids,
1234
+ vision_embeddings)
1235
+ input_ids = None
1236
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
1237
+ outputs = self.language_model(
1238
+ input_ids=None,
1239
+ position_ids=position_ids,
1240
+ attention_mask=attention_mask,
1241
+ past_key_values=past_key_values,
1242
+ inputs_embeds=inputs_embeds,
1243
+ use_cache=use_cache,
1244
+ output_attentions=output_attentions,
1245
+ output_hidden_states=output_hidden_states,
1246
+ return_dict=True,
1247
+ cache_position=cache_position,
1248
+ **kwargs,
1249
+ )
1250
+
1251
+ output = Step3p7CausalLMOutputWithPast(
1252
+ last_hidden_state=outputs.last_hidden_state,
1253
+ past_key_values=outputs.past_key_values,
1254
+ attentions=outputs.attentions,
1255
+ )
1256
+ return output if return_dict else output.to_tuple()
1257
+
1258
+
1259
+ class Step3p7ForConditionalGeneration(Step3p7PreTrainedModel, GenerationMixin):
1260
+ _checkpoint_conversion_mapping = {
1261
+ "^vision_model": "model.vision_model",
1262
+ r"^model(?!\.(language_model|vision_model))": "model.language_model",
1263
+ "^vit_large_projector": "model.vit_large_projector",
1264
+ }
1265
+ _tied_weights_keys = ["lm_head.weight"]
1266
+ config: Step3p7Config
1267
+
1268
+ def __init__(self, config: Step3p7Config):
1269
+ super().__init__(config)
1270
+ self.model = Step3p7Model(config)
1271
+ self.lm_head = nn.Linear(config.hidden_size,
1272
+ config.text_config.vocab_size,
1273
+ bias=False)
1274
+
1275
+ self.post_init()
1276
+
1277
+ def get_input_embeddings(self):
1278
+ return self.model.get_input_embeddings()
1279
+
1280
+ def set_input_embeddings(self, value):
1281
+ self.model.set_input_embeddings(value)
1282
+
1283
+ def get_output_embeddings(self):
1284
+ return self.model.get_output_embeddings()
1285
+
1286
+ def set_output_embeddings(self, new_embeddings):
1287
+ self.model.set_output_embeddings(new_embeddings)
1288
+
1289
+ def set_decoder(self, decoder):
1290
+ self.model.set_decoder(decoder)
1291
+
1292
+ def get_decoder(self):
1293
+ return self.model.get_decoder()
1294
+
1295
+ @property
1296
+ def language_model(self):
1297
+ return self.model.language_model
1298
+
1299
+ @property
1300
+ def visual(self):
1301
+ return self.model.vision_model
1302
+
1303
+ def forward(
1304
+ self,
1305
+ input_ids: torch.LongTensor = None,
1306
+ pixel_values: Optional[torch.Tensor] = None,
1307
+ num_patches=None,
1308
+ patch_pixel_values=None,
1309
+ patch_newline_mask=None,
1310
+ image_embeds: Optional[torch.FloatTensor] = None,
1311
+ attention_mask: Optional[torch.Tensor] = None,
1312
+ position_ids: Optional[torch.LongTensor] = None,
1313
+ past_key_values: Optional[Cache] = None,
1314
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1315
+ labels: Optional[torch.LongTensor] = None,
1316
+ use_cache: Optional[bool] = None,
1317
+ output_attentions: Optional[bool] = None,
1318
+ output_hidden_states: Optional[bool] = None,
1319
+ return_dict: Optional[bool] = None,
1320
+ cache_position: Optional[torch.LongTensor] = None,
1321
+ **kwargs: Unpack[TransformersKwargs],
1322
+ ) -> Union[tuple, Step3p7CausalLMOutputWithPast]:
1323
+ output_attentions = (
1324
+ output_attentions
1325
+ if output_attentions is not None
1326
+ else self.config.output_attentions
1327
+ )
1328
+ output_hidden_states = (
1329
+ output_hidden_states
1330
+ if output_hidden_states is not None
1331
+ else self.config.output_hidden_states
1332
+ )
1333
+
1334
+ outputs = self.model(
1335
+ input_ids=input_ids,
1336
+ num_patches=num_patches,
1337
+ patch_pixel_values=patch_pixel_values,
1338
+ patch_newline_mask=patch_newline_mask,
1339
+ position_ids=position_ids,
1340
+ attention_mask=attention_mask,
1341
+ past_key_values=past_key_values,
1342
+ inputs_embeds=inputs_embeds,
1343
+ use_cache=use_cache,
1344
+ output_attentions=output_attentions,
1345
+ output_hidden_states=output_hidden_states,
1346
+ return_dict=return_dict,
1347
+ cache_position=cache_position,
1348
+ **kwargs,
1349
+ )
1350
+
1351
+ hidden_states = outputs.last_hidden_state
1352
+ logits = self.lm_head(hidden_states)
1353
+
1354
+ loss = None
1355
+ if labels is not None:
1356
+ loss = self.loss_function(
1357
+ logits=logits, labels=labels, vocab_size=self.config.vocab_size
1358
+ )
1359
+
1360
+ return Step3p7CausalLMOutputWithPast(
1361
+ logits=logits,
1362
+ )
1363
+
1364
+
1365
+ def prepare_inputs_for_generation(
1366
+ self,
1367
+ input_ids,
1368
+ past_key_values=None,
1369
+ inputs_embeds=None,
1370
+ pixel_values=None,
1371
+ patch_pixel_values=None,
1372
+ num_patches=None,
1373
+ image_embeds=None,
1374
+ attention_mask=None,
1375
+ cache_position=None,
1376
+ logits_to_keep=None,
1377
+ **kwargs,
1378
+ ):
1379
+ model_inputs = super().prepare_inputs_for_generation(
1380
+ input_ids,
1381
+ past_key_values=past_key_values,
1382
+ inputs_embeds=inputs_embeds,
1383
+ attention_mask=attention_mask,
1384
+ cache_position=cache_position,
1385
+ logits_to_keep=logits_to_keep,
1386
+ **kwargs,
1387
+ )
1388
+
1389
+ generation_cache_position = model_inputs.get("cache_position", cache_position)
1390
+ is_prefill = past_key_values is None
1391
+ if generation_cache_position is not None and generation_cache_position.numel() > 0:
1392
+ is_prefill = generation_cache_position[0].item() == 0
1393
+
1394
+ if is_prefill:
1395
+ # During cached decoding, input ids no longer contain image tokens,
1396
+ # so pixel values should only be passed at the first step.
1397
+ model_inputs["pixel_values"] = pixel_values
1398
+
1399
+ return model_inputs
1400
+
1401
+ def _fix_state_dict_key_on_load(self, key: str) -> tuple[str, bool]:
1402
+ if key.startswith("language_model."):
1403
+ return key[len("language_model.") :], True
1404
+
1405
+ return key, False
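The `_checkpoint_conversion_mapping` on `Step3p7ForConditionalGeneration` above rewrites older checkpoint key names with regular expressions. The following is a minimal sketch of that renaming, assuming each key is rewritten by the first pattern that matches. The helper `rename_key` is hypothetical and not part of this commit, and the loader in `transformers` may apply the mapping differently.

```python
import re

# Patterns copied from Step3p7ForConditionalGeneration._checkpoint_conversion_mapping.
CONVERSION_MAPPING = {
    "^vision_model": "model.vision_model",
    r"^model(?!\.(language_model|vision_model))": "model.language_model",
    "^vit_large_projector": "model.vit_large_projector",
}


def rename_key(key: str) -> str:
    """Hypothetical helper: rewrite a checkpoint key with the first matching pattern."""
    for pattern, replacement in CONVERSION_MAPPING.items():
        new_key, count = re.subn(pattern, replacement, key)
        if count:
            return new_key
    # Keys that match no pattern (for example lm_head.weight) pass through unchanged.
    return key
```

Under these assumptions, `rename_key("model.layers.0.mlp.w")` gains the `model.language_model.` prefix, while a key that already starts with `model.language_model.` or `model.vision_model.` is left alone because the negative lookahead rejects it.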
processing_step3.py ADDED
@@ -0,0 +1,475 @@
1
+ from transformers import BaseImageProcessor, ImageProcessingMixin
2
+ from transformers.processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
3
+ import math
4
+ from typing import Iterable, Optional, Tuple, List, TypedDict, Literal, Union, overload
5
+
6
+ from PIL import Image
7
+ import torch
8
+ import numpy as np
9
+ import torchvision
10
+ from torch import nn
11
+ from torch.nn import functional as F, LayerNorm
12
+ from torchvision.transforms.functional import InterpolationMode
13
+ from transformers.activations import ACT2FN
14
+ from torchvision import transforms
15
+ from torchvision.transforms.functional import InterpolationMode
16
+ from transformers.feature_extraction_utils import BatchFeature, TensorType
17
+ from transformers.image_utils import ImageInput
18
+ from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
19
+ from transformers.tokenization_utils_tokenizers import TokenizersBackend
20
+ from math import ceil
21
+ from itertools import product
22
+
23
+
24
+
25
+ MAX_IMAGE_SIZE: int = 3024
26
+
27
+ class Step3VLImagePixelInputs(TypedDict):
28
+ type: Literal["pixel_values"]
29
+ pixel_values: torch.Tensor
30
+ patch_pixel_values: Optional[torch.Tensor]
31
+ num_patches: list[int]
32
+
33
+
34
+ class Step3VLImageEmbeddingInputs(TypedDict):
35
+ type: Literal["image_embeds"]
36
+ image_embeds: torch.Tensor
37
+
38
+
39
+ ImageWithPatches = tuple[Image.Image, list[Image.Image], list[bool] | None]
40
+
41
+
42
+ class GPUToTensor(torch.nn.Module):
43
+
44
+ def forward(self, raw_image: Union[np.ndarray,
45
+ Image.Image]) -> torch.Tensor:
46
+ if isinstance(raw_image, Image.Image):
47
+ return transforms.ToTensor()(raw_image)
48
+ if raw_image.ndim == 2:
49
+ raw_image = raw_image[:, :, None].repeat(3, -1)
50
+ if torch.cuda.is_available():
51
+ device = torch.device("cuda")
52
+ else:
53
+ device = torch.device("cpu")
54
+ image_tensor = torch.from_numpy(raw_image).to(device)
55
+ image_tensor = torch.permute(image_tensor, (2, 0, 1)).contiguous()
56
+ if image_tensor.dtype == torch.uint8:
57
+ image_tensor = image_tensor.to(torch.float32).div(255)
58
+ return image_tensor
59
+
60
+ class Step3VisionProcessor(BaseImageProcessor):
61
+
62
+ def __init__(self, size, interpolation_mode="bicubic", patch_size=None):
63
+ mean = [0.48145466, 0.4578275, 0.40821073]
64
+ std = [0.26862954, 0.26130258, 0.27577711]
65
+ patch_size = patch_size if patch_size is not None else size
66
+
67
+ self.transform = transforms.Compose([
68
+ GPUToTensor(),
69
+ transforms.Normalize(mean, std),
70
+ transforms.Resize(
71
+ (size, size),
72
+ interpolation=InterpolationMode.BICUBIC if interpolation_mode
73
+ == "bicubic" else InterpolationMode.BILINEAR,
74
+ antialias=True),
75
+ ])
76
+
77
+ self.patch_transform = transforms.Compose([
78
+ GPUToTensor(),
79
+ transforms.Normalize(mean, std),
80
+ transforms.Resize(
81
+ (patch_size, patch_size),
82
+ interpolation=InterpolationMode.BICUBIC if interpolation_mode
83
+ == "bicubic" else InterpolationMode.BILINEAR,
84
+ antialias=True),
85
+ ]) if patch_size is not None else None
86
+
87
+ def __call__(self, image, is_patch=False):
88
+ if is_patch:
89
+ return {"pixel_values": self.patch_transform(image).unsqueeze(0)}
90
+ else:
91
+ return {"pixel_values": self.transform(image).unsqueeze(0)}
92
+
93
+ class ImagePatcher:
94
+ def determine_window_size(self, long: int, short: int) -> int:
95
+ if long <= 728:
96
+ return short if long / short > 1.5 else 0
97
+ return min(short, 504) if long / short > 4 else 504
98
+ def slide_window(
99
+ self,
100
+ width: int,
101
+ height: int,
102
+ sizes: list[tuple[int, int]],
103
+ steps: list[tuple[int, int]],
104
+ img_rate_thr: float = 0.6,
105
+ ) -> tuple[list[tuple[int, int, int, int]], tuple[int, int]]:
106
+ assert 1 >= img_rate_thr >= 0, "The `img_rate_thr` should lie in 0~1"
107
+ windows = []
108
+ # Sliding windows.
109
+ for size, step in zip(sizes, steps):
110
+ size_w, size_h = size
111
+ step_w, step_h = step
112
+
113
+ x_num = 1 if width <= size_w else ceil((width - size_w) / step_w +
114
+ 1)
115
+ x_start = [step_w * i for i in range(x_num)]
116
+ if len(x_start) > 1 and x_start[-1] + size_w > width:
117
+ x_start[-1] = width - size_w
118
+
119
+ y_num = 1 if height <= size_h else ceil((height - size_h) /
120
+ step_h + 1)
121
+ y_start = [step_h * i for i in range(y_num)]
122
+ if len(y_start) > 1 and y_start[-1] + size_h > height:
123
+ y_start[-1] = height - size_h
124
+
125
+ start = np.array(list(product(y_start, x_start)), dtype=int)
126
+ start[:, [0, 1]] = start[:, [1, 0]]
127
+ windows.append(np.concatenate([start, start + size], axis=1))
128
+ windows = np.concatenate(windows, axis=0)
129
+
130
+ return [(int(box[0]), int(box[1]), int(box[2] - box[0]),
131
+ int(box[3] - box[1])) for box in windows], (x_num, y_num)
132
+
133
+ def square_pad(self, img: Image.Image) -> Image.Image:
134
+ w, h = img.size
135
+ if w == h:
136
+ return img
137
+ size = max(w, h)
138
+ padded = Image.new(img.mode, (size, size), 0)
139
+ padded.paste(img, (0, 0))
140
+ return padded
141
+
142
+ def get_image_size_for_padding(self, img_width: int,
143
+ img_height: int) -> tuple[int, int]:
144
+ ratio = img_width / img_height
145
+ if min(img_height, img_width) < 32 and (ratio > 4 or ratio < 1 / 4):
146
+ new_size = max(img_height, img_width)
147
+ return new_size, new_size
148
+ return img_width, img_height
149
+
150
+ def get_image_size_for_preprocess(self, img_width: int,
151
+ img_height: int) -> tuple[int, int]:
152
+
153
+ if max(img_height, img_width) > MAX_IMAGE_SIZE:
154
+ scale_factor = MAX_IMAGE_SIZE / max(img_height, img_width)
155
+ img_width = int(img_width * scale_factor)
156
+ img_height = int(img_height * scale_factor)
157
+ return img_width, img_height
158
+
159
+ def get_image_size_for_crop(self, img_width: int, img_height: int,
160
+ window_size: int):
161
+ w_ratio = img_width / window_size
162
+ h_ratio = img_height / window_size
163
+
164
+ if w_ratio < 1:
165
+ width_new = img_width
166
+ else:
167
+ decimal_w = w_ratio - img_width // window_size
168
+ w_ratio = int(w_ratio) + 1 if decimal_w > 0.2 else int(w_ratio)
169
+ width_new = window_size * w_ratio
170
+ if h_ratio < 1:
171
+ height_new = img_height
172
+ else:
173
+ decimal_h = h_ratio - img_height // window_size
174
+ h_ratio = int(h_ratio) + 1 if decimal_h > 0.2 else int(h_ratio)
175
+ height_new = window_size * h_ratio
176
+ return int(width_new), int(height_new)
177
+
178
+ def patch_crop(self, img: Image.Image, i: int, j: int, th: int, tw: int):
179
+ target = img.crop((j, i, j + tw, i + th))
180
+ return target
181
+
182
+ def get_num_patches(self, img_width: int,
183
+ img_height: int) -> tuple[int, int]:
184
+ img_width, img_height = self.get_image_size_for_padding(
185
+ img_width, img_height)
186
+ img_width, img_height = self.get_image_size_for_preprocess(
187
+ img_width, img_height)
188
+ window_size = self.determine_window_size(max(img_height, img_width),
189
+ min(img_height, img_width))
190
+ if window_size == 0:
191
+ return 0, 0
192
+ else:
193
+ img_width, img_height = self.get_image_size_for_crop(
194
+ img_width, img_height, window_size)
195
+ center_list, (x_num, y_num) = self.slide_window(
196
+ img_width, img_height, [(window_size, window_size)],
197
+ [(window_size, window_size)])
198
+ full_rows = (len(center_list) - 1) // x_num + 1
199
+ if len(center_list) > 0 and len(center_list) % x_num == 0:
200
+ full_rows -= 1
201
+ return len(center_list), full_rows
202
+
203
+ def __call__(
204
+ self, img: Image.Image
205
+ ) -> tuple[Image.Image, list[Image.Image], list[bool] | None]:
206
+ img_width, img_height = img.size
207
+ new_img_width, new_img_height = self.get_image_size_for_padding(
208
+ img_width, img_height)
209
+ if new_img_width != img_width or new_img_height != img_height:
210
+ img = self.square_pad(img)
211
+ img_width, img_height = img.size
212
+
213
+ new_img_width, new_img_height = self.get_image_size_for_preprocess(
214
+ img_width, img_height)
215
+ img = img.resize((new_img_width, new_img_height),
216
+ Image.Resampling.BILINEAR)
217
+ window_size = self.determine_window_size(
218
+ max(new_img_height, new_img_width),
219
+ min(new_img_height, new_img_width))
220
+ # return img, [], None
221
+ if window_size == 0:
222
+ return img, [], None
223
+ else:
224
+ new_img_width, new_img_height = self.get_image_size_for_crop(
225
+ new_img_width, new_img_height, window_size)
226
+ if (new_img_width, new_img_height) != (img_width, img_height):
227
+ img_for_crop = img.resize((new_img_width, new_img_height),
228
+ Image.Resampling.BILINEAR)
229
+ else:
230
+ img_for_crop = img
231
+
232
+ patches = []
233
+ newlines = []
234
+ center_list, (x_num, y_num) = self.slide_window(
235
+ new_img_width, new_img_height, [(window_size, window_size)],
236
+ [(window_size, window_size)])
237
+ for patch_id, center_lf_point in enumerate(center_list):
238
+ x, y, patch_w, patch_h = center_lf_point
239
+ big_patch = self.patch_crop(img_for_crop, y, x, patch_h,
240
+ patch_w)
241
+ patches.append(big_patch)
242
+ if (patch_id + 1) % x_num == 0:
243
+ newlines.append(patch_id)
244
+
245
+ if newlines and newlines[-1] == len(patches) - 1:
246
+ newlines.pop()
247
+
248
+ return img, patches, [i in newlines for i in range(len(patches))] if len(patches) > 0 else None
249
+
250
+
251
+
252
+
253
+ class Step3VLProcessor(ProcessorMixin):
254
+ # Align ProcessorMixin with our custom components.
255
+ # We only have an image processor (not a feature extractor) plus a tokenizer.
256
+ attributes = ["tokenizer"]
257
+ tokenizer_class = "AutoTokenizer"
258
+
259
+ @classmethod
260
+ def _load_tokenizer_from_pretrained(
261
+ cls, sub_processor_type, pretrained_model_name_or_path, subfolder="", **kwargs
262
+ ):
263
+ return TokenizersBackend.from_pretrained(
264
+ pretrained_model_name_or_path,
265
+ subfolder=subfolder,
266
+ **kwargs,
267
+ )
268
+
269
+ def __init__(
270
+ self,
271
+ tokenizer=None,
272
+ chat_template=None,
273
+ **kwargs
274
+ ) -> None:
275
+ self.image_size = 728
276
+ self.patch_size = 504
277
+
278
+ self.image_preprocessor = Step3VisionProcessor(self.image_size,
279
+ "bilinear",
280
+ self.patch_size)
281
+
282
+ self.num_image_feature_size = 169
283
+ self.num_patch_feature_size = 81
284
+ self.image_token = "<im_patch>"
285
+ self.image_feature_placeholder = (self.image_token *
286
+ self.num_image_feature_size)
287
+ self.patch_feature_placeholder = (self.image_token *
288
+ self.num_patch_feature_size)
289
+ super().__init__(tokenizer=tokenizer, chat_template=chat_template, **kwargs)
290
+ self.patcher = ImagePatcher()
291
+
292
+ @property
293
+ def image_token_id(self) -> int:
294
+ return self.tokenizer.get_vocab()[self.image_token]
295
+
296
+ def get_num_image_tokens(self, img_width: int, img_height: int) -> int:
297
+ num_patches, num_newlines = self.patcher.get_num_patches(
298
+ img_width, img_height)
299
+
300
+ return num_patches * (
301
+ self.num_patch_feature_size +
302
+ 2) + self.num_image_feature_size + 2 + num_newlines
303
+
304
+ def _split_images(self,
305
+ images: list[Image.Image]) -> list[ImageWithPatches]:
306
+ result = []
307
+ for img in images:
308
+ result.append(self.patcher(img))
309
+ return result
310
+
311
+ def _convert_images_to_pixel_values(
312
+ self,
313
+ images: list[Image.Image],
314
+ is_patch: bool = False,
315
+ ) -> list[torch.Tensor]:
316
+ return [
317
+ self.image_preprocessor(img, is_patch=is_patch)["pixel_values"]
318
+ for img in images
319
+ ]
320
+
321
+ def _get_patch_repl(
322
+ self,
323
+ num_patches: int,
324
+ patch_newline_mask: list[bool] | None,
325
+ ) -> tuple[str, list[int]]:
326
+ text = ""
327
+ token_ids = []
328
+ for i in range(num_patches):
329
+ assert len(patch_newline_mask) == num_patches
330
+ text += f"<patch_start>{self.patch_feature_placeholder}<patch_end>"
331
+ token_ids.extend(
332
+ [self.tokenizer.convert_tokens_to_ids("<patch_start>")] +
333
+ [self.image_token_id] * self.num_patch_feature_size +
334
+ [self.tokenizer.convert_tokens_to_ids("<patch_end>")])
335
+ if patch_newline_mask and patch_newline_mask[i]:
336
+ text += "<patch_newline>"
337
+ token_ids.append(
338
+ self.tokenizer.convert_tokens_to_ids("<patch_newline>"))
339
+ return text, token_ids
340
+
341
+ def _get_image_repl(
342
+ self,
343
+ num_images: int,
344
+ ) -> tuple[str, list[int]]:
345
+ text = f"<im_start>{self.image_feature_placeholder}<im_end>"
346
+ token_ids = [
347
+ self.tokenizer.convert_tokens_to_ids("<im_start>")
348
+ ] + [self.image_token_id] * self.num_image_feature_size + [
349
+ self.tokenizer.convert_tokens_to_ids("<im_end>")
350
+ ]
351
+ return text * num_images, token_ids * num_images
352
+
353
+ def _get_image_repl_features(
354
+ self,
355
+ num_images: int,
356
+ num_patches: int,
357
+ patch_new_line_idx: Optional[list[bool]],
358
+ ) -> tuple[str, list[int]]:
359
+ if num_patches > 0:
360
+ patch_repl, patch_repl_ids = self._get_patch_repl(
361
+ num_patches, patch_new_line_idx)
362
+ else:
363
+ patch_repl = ""
364
+ patch_repl_ids = []
365
+ image_repl, image_repl_ids = self._get_image_repl(num_images)
366
+ return patch_repl + image_repl, patch_repl_ids + image_repl_ids
367
+
368
+ def replace_placeholder(self, text: str, placeholder: str,
369
+ repls: list[str]) -> str:
370
+ parts = text.split(placeholder)
371
+
372
+ if len(parts) - 1 != len(repls):
373
+ raise ValueError(
374
+ "The number of placeholders does not match the number of replacements." # noqa: E501
375
+ )
376
+
377
+ result = [parts[0]]
378
+ for i, repl in enumerate(repls):
379
+ result.append(repl)
380
+ result.append(parts[i + 1])
381
+
382
+ return "".join(result)
383
+
384
+ def __call__(
385
+ self,
386
+ text: Optional[Union[str, list[str]]] = None,
387
+ images: ImageInput | None = None,
388
+ return_tensors: Optional[Union[str, TensorType]] = None,
389
+ **kwargs,
390
+ ) -> BatchFeature:
391
+
392
+ if images is not None:
393
+ images = self.image_preprocessor.fetch_images(images)
394
+ if text is None:
395
+ text = []
396
+ if not isinstance(text, list):
397
+ text = [text]
398
+ if images is None:
399
+ images = []
400
+ elif not isinstance(images, list):
401
+ images = [images]
402
+ elif isinstance(images[0], list):
403
+ images = images[0]
404
+
405
+ if len(images) == 0:
406
+ image_inputs = {}
407
+ text_inputs = self.tokenizer(text)
408
+ else:
409
+ splitted_images_data = self._split_images(images)
410
+ pixel_values_lst = []
411
+ patch_pixel_values_lst = []
412
+ patch_newline_mask_lst = []
413
+ image_repl_str_lst = []
414
+ image_repl_ids_lst = []
415
+ num_patches = []
416
+ for raw_img, img_patches, patch_newline_mask in splitted_images_data: # noqa: E501
417
+ pixel_values_lst.extend(
418
+ self._convert_images_to_pixel_values([raw_img]))
419
+
420
+ if len(img_patches) > 0:
421
+ patch_pixel_values_lst.extend(
422
+ self._convert_images_to_pixel_values(img_patches,
423
+ is_patch=True))
424
+ num_patches.append(len(img_patches))
425
+
426
+ image_repl_str, image_repl_ids = self._get_image_repl_features(
427
+ 1, len(img_patches), patch_newline_mask)
428
+ image_repl_str_lst.append(image_repl_str)
429
+ image_repl_ids_lst.extend(image_repl_ids)
430
+
431
+ if patch_newline_mask is not None:
432
+ patch_newline_mask_lst.extend(patch_newline_mask)
433
+
434
+ image_inputs = {
435
+ "pixel_values": torch.cat(pixel_values_lst),
436
+ "num_patches": num_patches,
437
+ }
438
+ if patch_pixel_values_lst:
439
+ image_inputs["patch_pixel_values"] = torch.cat(
440
+ patch_pixel_values_lst)
441
+ if patch_newline_mask_lst:
442
+ image_inputs["patch_newline_mask"] = torch.tensor(
443
+ patch_newline_mask_lst, dtype=torch.bool)
444
+
445
+ text = [
446
+ self.replace_placeholder(t, self.image_token,
447
+ image_repl_str_lst) for t in text
448
+ ]
449
+ text_inputs = self.tokenizer(text)
450
+
451
+ return BatchFeature(
452
+ {
453
+ **text_inputs,
454
+ **image_inputs,
455
+ },
456
+ tensor_type=return_tensors,
457
+ )
458
+
459
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma
460
+ def batch_decode(self, *args, **kwargs):
461
+ """
462
+ This method forwards all its arguments to the wrapped tokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
463
+ refer to the docstring of this method for more information.
464
+ """
465
+ return self.tokenizer.batch_decode(*args, **kwargs)
466
+
467
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Gemma
468
+ def decode(self, *args, **kwargs):
469
+ """
470
+ This method forwards all its arguments to the wrapped tokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
471
+ the docstring of this method for more information.
472
+ """
473
+ return self.tokenizer.decode(*args, **kwargs)
474
+
475
+ __all__ = ["Step3VLProcessor"]
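The token budget that `Step3VLProcessor.get_num_image_tokens` reports follows from the window and grid arithmetic in `ImagePatcher`. The sketch below restates that arithmetic in plain Python so it can be checked without `torch` or PIL. The function names (`window_size`, `snap_to_window`, `patch_grid`, `count_image_tokens`) are made up for illustration, the constants are copied from the file above, and the newline count is simplified to `rows - 1`, which is what `get_num_patches` works out to for a full grid.

```python
from math import ceil

MAX_IMAGE_SIZE = 3024
NUM_IMAGE_FEATURES = 169  # <im_patch> placeholders for the global view
NUM_PATCH_FEATURES = 81   # <im_patch> placeholders per cropped patch


def window_size(long_side: int, short_side: int) -> int:
    """Mirror of ImagePatcher.determine_window_size; 0 means no patches."""
    if long_side <= 728:
        return short_side if long_side / short_side > 1.5 else 0
    return min(short_side, 504) if long_side / short_side > 4 else 504


def snap_to_window(length: int, window: int) -> int:
    """Round a side to a multiple of the window, rounding up past a 0.2 remainder."""
    ratio = length / window
    if ratio < 1:
        return length
    whole = length // window
    return window * (whole + 1 if ratio - whole > 0.2 else whole)


def patch_grid(width: int, height: int) -> tuple[int, int]:
    """Return (columns, rows) of the patch grid for an image size."""
    aspect = width / height
    # Very thin, tiny images are padded to a square first.
    if min(width, height) < 32 and (aspect > 4 or aspect < 1 / 4):
        width = height = max(width, height)
    longest = max(width, height)
    if longest > MAX_IMAGE_SIZE:
        scale = MAX_IMAGE_SIZE / longest
        width, height = int(width * scale), int(height * scale)
    window = window_size(max(width, height), min(width, height))
    if window == 0:
        return 0, 0
    width, height = snap_to_window(width, window), snap_to_window(height, window)
    cols = 1 if width <= window else ceil((width - window) / window + 1)
    rows = 1 if height <= window else ceil((height - window) / window + 1)
    return cols, rows


def count_image_tokens(width: int, height: int) -> int:
    """Each patch costs 81 + 2 tokens, the global view 169 + 2, plus row newlines."""
    cols, rows = patch_grid(width, height)
    num_newlines = max(rows - 1, 0)
    return cols * rows * (NUM_PATCH_FEATURES + 2) + NUM_IMAGE_FEATURES + 2 + num_newlines
```

As a worked case, a 1456x728 image gets a 504-pixel window and a 3x2 grid, so the count is 6 * (81 + 2) + 169 + 2 + 1 = 670 tokens. A 500x500 image produces no patches and costs only the 171 tokens of the global view.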
special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<|begin▁of▁sentence|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|im_end|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<|end▁of▁sentence|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ }
23
+ }
step3p7_mlx.py ADDED
@@ -0,0 +1,39 @@
1
+ """Local MLX bridge for Step3p7 bundles.
2
+
3
+ The current MLX runtime has a Step3p5 text model but no Step3p7 VLM wrapper.
4
+ For text coherence proof, this module exposes a custom ``model_file`` that
5
+ loads the nested ``text_config`` through ``mlx_lm.models.step3p5`` and drops
6
+ vision tensors during sanitize. Vision runtime remains a separate follow-up.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from dataclasses import dataclass
12
+
13
+ from mlx_lm.models import step3p5
14
+
15
+
16
+ @dataclass
17
+ class ModelArgs(step3p5.ModelArgs):
18
+ @classmethod
19
+ def from_dict(cls, params):
20
+ text_config = dict(params.get("text_config") or params)
21
+ text_config["model_type"] = "step3p5"
22
+ return super().from_dict(text_config)
23
+
24
+
25
+ class Model(step3p5.Model):
26
+ def sanitize(self, weights):
27
+ text_weights = {}
28
+ prefix = "model.language_model."
29
+ for key, value in weights.items():
30
+ if key.startswith(prefix):
31
+ text_weights["model." + key[len(prefix):]] = value
32
+ elif key.startswith("model.vision_model.") or key.startswith("model.vit_large_projector."):
33
+ continue
34
+ elif key.startswith("vision_model.") or key.startswith("vit_large_projector."):
35
+ continue
36
+ else:
37
+ text_weights[key] = value
38
+ return super().sanitize(text_weights)
39
+
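The `sanitize` override in `step3p7_mlx.py` keeps the language-model tensors, renames them to the layout the Step3p5 text model expects, and drops the vision tensors. A minimal sketch of just that key filtering follows, using a plain dictionary in place of real weights. The function name `split_text_weights` is hypothetical, and the follow-up call to the parent `sanitize` is intentionally left out.

```python
TEXT_PREFIX = "model.language_model."
VISION_PREFIXES = (
    "model.vision_model.",
    "model.vit_large_projector.",
    "vision_model.",
    "vit_large_projector.",
)


def split_text_weights(weights: dict) -> dict:
    """Hypothetical helper: keep text tensors, drop vision tensors."""
    text_weights = {}
    for key, value in weights.items():
        if key.startswith(TEXT_PREFIX):
            # model.language_model.X becomes model.X
            text_weights["model." + key[len(TEXT_PREFIX):]] = value
        elif key.startswith(VISION_PREFIXES):
            # Vision tower and projector tensors are not used by the text-only bridge.
            continue
        else:
            # Anything else, such as lm_head.weight, passes through unchanged.
            text_weights[key] = value
    return text_weights
```

Passing a tuple to `str.startswith` checks all four vision prefixes in one call, which is equivalent to the two `elif` branches in the original.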
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff
 
vision_encoder.py ADDED
@@ -0,0 +1,452 @@
1
+ from typing import Literal, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from transformers.activations import ACT2FN
7
+
8
+
9
+ from .configuration_step3p7 import StepRoboticsVisionEncoderConfig
10
+
11
+
12
+
13
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
14
+ """Rotate interleaved pairs of the last dimension, (x1, x2) -> (-x2, x1) (used by RoPE)."""
15
+ x = x.reshape(*x.shape[:-1], -1, 2)
16
+ x1, x2 = x.unbind(dim=-1)
17
+ x = torch.stack((-x2, x1), dim=-1)
18
+ return x.reshape(*x.shape[:-2], -1)
19
+
20
+
21
+ def apply_rotary_emb(freqs: torch.Tensor,
22
+ t: torch.Tensor,
23
+ start_index: int = 0,
24
+ scale: float = 1.0,
25
+ seq_dim: int = -2) -> torch.Tensor:
26
+ """Apply 2D rotary embeddings to queries / keys."""
27
+ dtype = t.dtype
28
+
29
+ if t.ndim == 3:
30
+ seq_len = t.shape[seq_dim]
31
+ freqs = freqs[-seq_len:]
32
+
33
+ rot_dim = freqs.shape[-1]
34
+ end_index = start_index + rot_dim
35
+ assert rot_dim <= t.shape[-1], (
36
+ f"feature dimension {t.shape[-1]} is too small for rot_dim {rot_dim}")
37
+
38
+ t_left, t, t_right = (
39
+ t[..., :start_index],
40
+ t[..., start_index:end_index],
41
+ t[..., end_index:],
42
+ )
43
+ t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
44
+ out = torch.cat((t_left, t, t_right), dim=-1)
45
+ return out.type(dtype)
46
+
47
+
48
+ class EncoderRope2D(nn.Module):
49
+ """Cacheable 2D rotary positional embedding."""
50
+
51
+ def __init__(
52
+ self,
53
+ dim: int,
54
+ max_grid_height: int,
55
+ max_grid_width: int,
56
+ use_cls_token: bool = False,
57
+ theta: Union[int, float] = 10000,
58
+ max_freq: int = 10,
59
+ num_freqs: int = 1,
60
+ theta_rescale_factor: float = 1.0,
61
+ ):
62
+ super().__init__()
63
+ self.dim = dim
64
+ self.max_grid_height = max_grid_height
65
+ self.max_grid_width = max_grid_width
66
+ self.use_cls_token = use_cls_token
67
+ self.theta = theta * theta_rescale_factor**(dim / (dim - 2))
68
+ self.max_freq = max_freq
69
+ self.num_freqs = num_freqs
70
+ cache = self._compute_2d_freqs()
71
+ self.register_buffer("freqs_cache", cache, persistent=False)
72
+
73
+ def _compute_inv_freq(self, base: Union[int, float],
74
+ dim: int) -> torch.Tensor:
75
+
76
+ freqs = 1.0 / (base**(
77
+ torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
78
+ return freqs
79
+
80
+ def _compute_freqs(self, t: torch.Tensor, inv_freq: torch.Tensor):
81
+ freqs = torch.einsum("..., f -> ... f", t.type(inv_freq.dtype),
82
+ inv_freq)
83
+ freqs = freqs.repeat_interleave(2, dim=-1)
84
+ return freqs
85
+
86
+ def _compute_2d_freqs(self) -> torch.Tensor:
87
+ grid_h_range = torch.arange(self.max_grid_height, dtype=torch.float)
88
+ grid_w_range = torch.arange(self.max_grid_width, dtype=torch.float)
89
+ if self.use_cls_token:
90
+ grid_h_range += 1
91
+ grid_w_range += 1
92
+ inv_freq = self._compute_inv_freq(self.theta, self.dim // 2)
93
+ freqs_h = self._compute_freqs(grid_h_range, inv_freq)[:, None].expand(
94
+ self.max_grid_height, self.max_grid_width, -1)
95
+ freqs_w = self._compute_freqs(grid_w_range, inv_freq)[None, :].expand(
96
+ self.max_grid_height, self.max_grid_width, -1)
97
+ freqs = torch.cat([freqs_w, freqs_h], dim=-1).reshape(
98
+ self.max_grid_height * self.max_grid_width, -1)
99
+ if self.use_cls_token:
100
+ freqs = torch.cat([torch.zeros(1, freqs.shape[-1]), freqs], dim=0)
101
+ freqs = freqs[None, None, ...]
102
+ return freqs
103
+
104
+ def forward(self, q: torch.Tensor, k: torch.Tensor,
105
+ grid_hw: tuple[int, int]):
106
+ # If grid matches cached shape we reuse directly to avoid recomputation.
107
+ if grid_hw[0] != self.max_grid_height or grid_hw[1] != self.max_grid_width:
108
+ rows = torch.arange(grid_hw[0], device=q.device).view(-1, 1)
109
+ cols = torch.arange(grid_hw[1], device=q.device).view(1, -1)
110
+ positions = (rows * self.max_grid_width + cols).reshape(-1).to(
111
+ torch.long)
112
+ if self.use_cls_token:
113
+ positions = torch.cat(
114
+ [torch.zeros(1, device=q.device, dtype=torch.long), positions + 1], dim=0)
115
+ freqs = self.freqs_cache.index_select(2, positions)
116
+ else:
117
+ freqs = self.freqs_cache
118
+ q = apply_rotary_emb(freqs, q)
119
+ k = apply_rotary_emb(freqs, k)
120
+ return q, k
121
+
122
+
123
+ class EncoderLayerScale(nn.Module):
124
+ """Per-channel residual scaling used when ls_init_value is set."""
125
+
126
+ def __init__(self, dim: int, init_values: float):
127
+ super().__init__()
128
+ self.gamma = nn.Parameter(torch.full((dim,), init_values))
129
+
130
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # (B, L, D)
131
+ return hidden_states * self.gamma
132
+
133
+
134
+ class EncoderMLP(nn.Module):
135
+ """Feed-forward network used inside each transformer block."""
136
+
137
+ def __init__(self, hidden_size: int, intermediate_size: int,
138
+ hidden_act: str):
139
+ super().__init__()
140
+ self.c_fc = nn.Linear(hidden_size, intermediate_size, bias=True)
141
+ self.act_fn = ACT2FN[hidden_act]
142
+ self.c_proj = nn.Linear(intermediate_size, hidden_size, bias=True)
143
+
144
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
145
+
146
+ hidden_states = self.c_proj(self.act_fn(self.c_fc(hidden_states)))
147
+ return hidden_states
148
+
149
+
150
+ class EncoderVisionAttention(nn.Module):
151
+ """Multi-head self attention with optional 2D RoPE."""
152
+
153
+ def __init__(
154
+ self,
155
+ hidden_size: int,
156
+ num_heads: int,
157
+ max_grid_height: int,
158
+ max_grid_width: int,
159
+ use_cls_token: bool = False,
160
+ use_rope2d: bool = True,
161
+ rope_theta: Union[int, float] = 10000,
162
+ rope_max_freq: int = 10,
163
+ rope_num_freqs: int = 1,
164
+ rope_theta_rescale_factor: float = 1.0,
165
+ rope_freqs_for: Literal["lang", "pixel", "constant"] = "lang",
166
+ ):
167
+ super().__init__()
168
+ if hidden_size % num_heads != 0:
169
+ raise ValueError(
170
+ f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})."
171
+ )
172
+ self.num_heads = num_heads
173
+ self.head_dim = hidden_size // num_heads
174
+ self.scale = self.head_dim**-0.5
175
+ self.in_proj_weight = nn.Parameter(torch.zeros(hidden_size * 3, hidden_size))
176
+ self.in_proj_bias = nn.Parameter(torch.zeros(hidden_size * 3))
177
+ self.out_proj = nn.Linear(hidden_size, hidden_size, bias=True)
178
+
179
+ self.rope = None
180
+ if use_rope2d:
181
+ self.rope = EncoderRope2D(
182
+ dim=self.head_dim,
183
+ max_grid_height=max_grid_height,
184
+ max_grid_width=max_grid_width,
185
+ use_cls_token=use_cls_token,
186
+ theta=rope_theta,
187
+ max_freq=rope_max_freq,
188
+ num_freqs=rope_num_freqs,
189
+ theta_rescale_factor=rope_theta_rescale_factor,
190
+ )
191
+
192
+ def forward(self, hidden_states: torch.Tensor, grid_hw: tuple[int, int]) -> torch.Tensor:
193
+ bsz, seq_len, _ = hidden_states.shape
194
+ qkv = F.linear(
195
+ hidden_states,
196
+ self.in_proj_weight,
197
+ self.in_proj_bias,
198
+ )
199
+ q, k, v = qkv.chunk(3, dim=-1)
200
+
201
+ q = q.view(bsz, seq_len, self.num_heads,
202
+ self.head_dim).transpose(1, 2)
203
+ k = k.view(bsz, seq_len, self.num_heads,
204
+ self.head_dim).transpose(1, 2)
205
+ if self.rope is not None:
206
+ q, k = self.rope(q, k, grid_hw=grid_hw)
207
+ v = v.view(bsz, seq_len, self.num_heads,
208
+ self.head_dim).transpose(1, 2)
209
+
210
+ attn_output = F.scaled_dot_product_attention(
211
+ q, k, v, is_causal=False, scale=self.scale)
212
+ attn_output = attn_output.transpose(1, 2).reshape(
213
+ bsz, seq_len, self.num_heads * self.head_dim)
214
+ return self.out_proj(attn_output)
215
+
216
+
217
+ class EncoderVisionBlock(nn.Module):
218
+ """A single Vision Transformer block (self-attention + MLP)."""
219
+
220
+ def __init__(
221
+ self,
222
+ hidden_size: int,
223
+ num_heads: int,
224
+ mlp_ratio: float,
225
+ hidden_act: str,
226
+ layer_norm_eps: float,
227
+ ls_init_value: Optional[float] = None,
228
+ max_grid_height: Optional[int] = None,
229
+ max_grid_width: Optional[int] = None,
230
+ use_cls_token: bool = False,
231
+ use_rope2d: bool = True,
232
+ rope_kwargs: Optional[dict] = None,
233
+ ):
234
+ super().__init__()
235
+ rope_kwargs = rope_kwargs or {}
236
+ self.attn = EncoderVisionAttention(
237
+ hidden_size,
238
+ num_heads,
239
+ max_grid_height=max_grid_height,
240
+ max_grid_width=max_grid_width,
241
+ use_cls_token=use_cls_token,
242
+ use_rope2d=use_rope2d,
243
+ **rope_kwargs,
244
+ )
245
+ self.ln_1 = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
246
+ self.ln_2 = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
247
+
248
+ intermediate = int(hidden_size * mlp_ratio)
249
+ self.mlp = EncoderMLP(hidden_size, intermediate, hidden_act)
250
+
251
+ self.ls_1 = EncoderLayerScale(hidden_size, ls_init_value)
252
+ self.ls_2 = EncoderLayerScale(hidden_size, ls_init_value)
253
+
254
+ def forward(self, hidden_states: torch.Tensor,
255
+ grid_hw: tuple[int, int]) -> torch.Tensor:
256
+ # breakpoint()
257
+ residual = hidden_states
258
+ hidden_states = self.ln_1(hidden_states)
259
+ hidden_states = self.attn(hidden_states, grid_hw=grid_hw)
260
+ hidden_states = residual + self.ls_1(hidden_states)
261
+
262
+ residual = hidden_states
263
+ hidden_states = self.ln_2(hidden_states)
264
+ hidden_states = self.mlp(hidden_states)
265
+ hidden_states = residual + self.ls_2(hidden_states)
266
+ return hidden_states
267
+
268
+
269
+ class EncoderVisionTransformer(nn.Module):
270
+ """Stack of encoder blocks parameterised by Step35VisionEncoderConfig."""
271
+
272
+ def __init__(
273
+ self,
274
+ embed_dim: int,
275
+ depth: int,
276
+ num_heads: int,
277
+ mlp_ratio: float,
278
+ hidden_act: str,
279
+ layer_norm_eps: float,
280
+ ls_init_value: Optional[float] = None,
281
+ max_grid_height: Optional[int] = None,
282
+ max_grid_width: Optional[int] = None,
283
+ use_cls_token: bool = False,
284
+ use_rope2d: bool = True,
285
+ rope_kwargs: Optional[dict] = None,
286
+ ):
287
+ super().__init__()
288
+ self.layers = depth
289
+ rope_kwargs = rope_kwargs or {}
290
+ self.resblocks = nn.ModuleList([
291
+ EncoderVisionBlock(embed_dim, num_heads, mlp_ratio, hidden_act,
292
+ layer_norm_eps,
293
+ max_grid_height=max_grid_height,
294
+ max_grid_width=max_grid_width,
295
+ use_cls_token=use_cls_token,
296
+ use_rope2d=use_rope2d,
297
+ ls_init_value=ls_init_value,
298
+ rope_kwargs=rope_kwargs)
299
+ for _ in range(depth)
300
+ ])
301
+
302
+ def forward(self,
303
+ hidden_states: torch.Tensor,
304
+ grid_hw: tuple[int, int]) -> torch.Tensor:
305
+ for block in self.resblocks:
306
+ hidden_states = block(hidden_states, grid_hw=grid_hw)
307
+ return hidden_states
308
+
309
+
310
+ class StepRoboticsVisionEncoder(nn.Module):
311
+ """
312
+ Vision encoder built from StepRoboticsVisionEncoderConfig.
313
+
314
+ The encoder performs patch embedding followed by a stack of transformer
315
+ blocks. Only the config fields defined in StepRoboticsVisionEncoderConfig (and
316
+ StepRoboticVLConfig.vision_config) are expected.
317
+ """
318
+
319
+ def __init__(self, config: StepRoboticsVisionEncoderConfig):
320
+ super().__init__()
321
+ self.config = config
322
+
323
+ # Align commonly used attributes so downstream code (e.g. StepRoboticVL)
324
+ # can access them without extra renaming.
325
+ self.hidden_size = config.width
326
+ self.num_heads = config.heads
327
+ self.num_hidden_layers = config.layers
328
+ self.patch_size = config.patch_size
329
+ self.image_size = config.image_size
330
+ self.use_cls_token = getattr(config, "use_cls_token", False)
331
+ self.use_rope2d = getattr(config, "use_rope2d", True)
332
+ self.use_abs_posemb = getattr(config, "use_abs_posemb", True)
333
+ self.layer_norm_eps = config.layer_norm_eps
334
+ self.mlp_ratio = getattr(config, "mlp_ratio", 8960 / 1536)
335
+ self.ls_init_value = getattr(config, "ls_init_value", None)
336
+ self.hidden_act = config.hidden_act
337
+ self.use_ln_pre = getattr(config, "use_ln_pre", False)
338
+ self.use_ln_post = getattr(config, "use_ln_post", True)
339
+
340
+ # Patch embedding.
341
+ self.conv1 = nn.Conv2d(in_channels=config.num_channels,
342
+ out_channels=self.hidden_size,
343
+ kernel_size=self.patch_size,
344
+ stride=self.patch_size,
345
+ bias=False)
346
+
347
+ self.ln_pre = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) if self.use_ln_pre else nn.Identity()
348
+ self.ln_post = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) if self.use_ln_post else nn.Identity()
349
+
350
+ grid_size = self.image_size // self.patch_size
351
+ self.base_grid = (grid_size, grid_size)
352
+
353
+ if self.use_cls_token:
354
+ self.class_embedding = nn.Parameter(
355
+ torch.randn(self.hidden_size) * (self.hidden_size**-0.5))
356
+ else:
357
+ self.class_embedding = None
358
+
359
+ if self.use_abs_posemb:
360
+ self.posemb_grid_size = self.image_size // self.patch_size
361
+ self.positional_embedding = nn.Parameter(
362
+ (self.hidden_size**-0.5) * torch.randn(
363
+ int(self.use_cls_token) + self.posemb_grid_size**2,
364
+ self.hidden_size,
365
+ ))
366
+
367
+ self.transformer = EncoderVisionTransformer(
368
+ embed_dim=self.hidden_size,
369
+ depth=self.num_hidden_layers,
370
+ num_heads=self.num_heads,
371
+ mlp_ratio=self.mlp_ratio,
372
+ hidden_act=self.hidden_act,
373
+ layer_norm_eps=self.layer_norm_eps,
374
+ ls_init_value=self.ls_init_value,
375
+ max_grid_height=self.base_grid[0],
376
+ max_grid_width=self.base_grid[1],
377
+ use_cls_token=self.use_cls_token,
378
+ use_rope2d=self.use_rope2d,
379
+ rope_kwargs={
380
+ "rope_theta": getattr(config, "rope_theta", 10000),
381
+ "rope_max_freq": getattr(config, "rope_max_freq", 10),
382
+ "rope_num_freqs": getattr(config, "rope_num_freqs", 1),
383
+ "rope_theta_rescale_factor":
384
+ getattr(config, "rope_theta_rescale_factor", 1.0),
385
+ "rope_freqs_for": getattr(config, "rope_freqs_for", "lang"),
386
+ },
387
+ )
388
+ self.vit_downsampler1 = nn.Conv2d(self.hidden_size,
389
+ self.hidden_size * 2,
390
+ kernel_size=3,
391
+ stride=2,
392
+ padding=1)
393
+ self.vit_downsampler2 = nn.Conv2d(self.hidden_size * 2,
394
+ self.hidden_size * 4,
395
+ kernel_size=3,
396
+ stride=2,
397
+ padding=1)
398
+
399
+
400
+ def sample_abs_posemb(self, grid_h: int, grid_w: int):
401
+ if self.posemb_grid_size == grid_h and self.posemb_grid_size == grid_w:
402
+ return self.positional_embedding[None, ...]
403
+
404
+ pos_embed = self.positional_embedding
405
+ if self.use_cls_token:
406
+ cls_token_embed, pos_embed = pos_embed[:1], pos_embed[1:]
407
+
408
+ pos_embed = (pos_embed.reshape(1, self.posemb_grid_size,
409
+ self.posemb_grid_size,
410
+ -1).permute(0, 3, 1, 2).contiguous())
411
+ pos_embed = F.interpolate(pos_embed,
412
+ size=(grid_h, grid_w),
413
+ mode="bilinear",
414
+ align_corners=False)
415
+ pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(-1, self.hidden_size)
416
+
417
+ if self.use_cls_token:
418
+ pos_embed = torch.cat([cls_token_embed, pos_embed], dim=0)
419
+
420
+ return pos_embed[None, ...]
421
+
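`sample_abs_posemb` resizes the learned position grid with bilinear interpolation and `align_corners=False`. The following is a one-dimensional, pure-Python sketch of the half-pixel coordinate mapping that this mode is understood to use along each axis; `resize_linear_1d` is a hypothetical helper written for illustration and does not call PyTorch.

```python
import math


def resize_linear_1d(values, out_len):
    """Linear resize of a 1D list using half-pixel centers (align_corners=False)."""
    in_len = len(values)
    scale = in_len / out_len
    out = []
    for dst in range(out_len):
        # Map the output sample center back to input coordinates, clamped at 0.
        src = max((dst + 0.5) * scale - 0.5, 0.0)
        i0 = min(int(math.floor(src)), in_len - 1)
        i1 = min(i0 + 1, in_len - 1)
        lam = src - i0
        out.append(values[i0] * (1.0 - lam) + values[i1] * lam)
    return out


print(resize_linear_1d([0.0, 1.0], 4))
# [0.0, 0.25, 0.75, 1.0]
```

The two-dimensional case applies the same weights independently along height and width, which is why the embedding is first reshaped to a `(1, D, G, G)` image-like tensor.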
422
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
423
+ """
424
+ Args:
425
+ pixel_values: Image tensor of shape (B, C, H, W).
426
+ layer_idx: Negative indices stop after a given block (e.g., -1 uses all blocks).
427
+ strip_cls_token: If True and cls token is used, remove it from output.
428
+ """
429
+ bsz, _, height, width = pixel_values.shape
430
+ grid_h, grid_w = height // self.patch_size, width // self.patch_size
431
+
432
+ hidden_state = self.conv1(pixel_values) # (B, D, Gh, Gw)
433
+ hidden_state = hidden_state.flatten(2).transpose(1, 2) # (B, Gh*Gw, D)
434
+
435
+ if self.use_cls_token:
436
+ cls_token = self.class_embedding.view(1, 1,
437
+ -1).expand(bsz, -1, -1)
438
+ hidden_state = torch.cat([cls_token, hidden_state], dim=1)
439
+
440
+ if self.use_abs_posemb:
441
+ pos_emb = self.sample_abs_posemb(grid_h, grid_w)
442
+ hidden_state = hidden_state + pos_emb
443
+ hidden_state = self.ln_pre(hidden_state)
444
+ hidden_state = self.transformer(hidden_state, grid_hw=(grid_h, grid_w))
445
+
446
+ if self.use_ln_post:
447
+ hidden_state = self.ln_post(hidden_state)
448
+
449
+ if self.use_cls_token:
450
+ hidden_state = hidden_state[:, 1:, :]
451
+
452
+ return hidden_state
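
The grid sizes involved in this encoder follow from the standard convolution output-size formula. The sketch below works through the arithmetic for the patch embedding (`conv1`) and for the two stride-2 convolutions declared in `__init__` (`vit_downsampler1`, `vit_downsampler2`), which are not called in this class's `forward`. The image and patch sizes are illustrative assumptions, not values read from the config, and `conv_out_size` is a hypothetical helper.

```python
def conv_out_size(n, kernel, stride, padding=0):
    """Output length of a convolution along one axis (dilation 1)."""
    return (n + 2 * padding - kernel) // stride + 1


# Illustrative values only; the real ones come from the vision config.
image_size, patch_size = 224, 14

grid = conv_out_size(image_size, kernel=patch_size, stride=patch_size)
down1 = conv_out_size(grid, kernel=3, stride=2, padding=1)
down2 = conv_out_size(down1, kernel=3, stride=2, padding=1)

print(grid, down1, down2)          # 16 8 4
print(grid * grid, down2 * down2)  # 256 16
```

Under these assumed sizes, each stride-2 stage halves the grid side, so the token count drops by a factor of 16 across both stages while the channel width grows from `D` to `4 * D`.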