"""
Test: output_attentions が正しく Attention Output を返すか検証する。

Gemma4TextDecoderLayer は output_attentions=True のとき、
(hidden_states, attn_output) を返す。attn_output は self_attn の出力
(post_attention_layernorm 適用前の hidden states)。

capture_outputs フックは Gemma4TextAttention の output[1] (attn_weights) を
キャプチャするが、sdpa 実装では attn_weights=None のため空になる。
そこで DecoderLayer レベルで attn_output が正しく取得できるかを検証する。
"""

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_PATH = "/workspace/llm/gemma-4-31B-Text"

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
inputs = tokenizer("hello", return_tensors="pt")

model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
inputs = inputs.to(model.device)

num_layers = model.config.num_hidden_layers
hidden_size = model.config.hidden_size
seq_len = inputs["input_ids"].shape[1]
batch_size = inputs["input_ids"].shape[0]

print(f"Model: num_layers={num_layers}, hidden_size={hidden_size}")
print(f"Input: batch={batch_size}, seq_len={seq_len}")

# =========================================================
# Test 1: call model.model (Gemma4TextModel) with output_attentions=True
# =========================================================
print("\n=== Test 1: Gemma4TextModel.forward(output_attentions=True) ===")
with torch.no_grad():
    text_outputs = model.model(
        **inputs,
        output_attentions=True,
        use_cache=False,
    )

attentions = text_outputs.attentions
print(f"attentions is None: {attentions is None}")

if attentions is not None:
    print(f"Number of attention entries: {len(attentions)}")
    if len(attentions) > 0:
        for i, attn in enumerate(attentions):
            if attn is None:
                print(f"  Layer {i}: None")
            else:
                print(f"  Layer {i}: shape={attn.shape}, dtype={attn.dtype}")
                if i == 0:
                    # attn_output should have shape (batch, seq_len, hidden_size)
                    expected_shape = (batch_size, seq_len, hidden_size)
                    if attn.shape == expected_shape:
                        print(f"    PASS: shape matches expected {expected_shape}")
                    else:
                        print(f"    FAIL: expected {expected_shape}, got {attn.shape}")
    else:
        print("  (empty tuple - capture_outputs hook did not collect anything)")

# =========================================================
# Test 2: call the DecoderLayer directly and check attn_output
# =========================================================
print("\n=== Test 2: DecoderLayer direct call with output_attentions=True ===")
with torch.no_grad():
    # Prepare the input embeddings and position information first
    input_ids = inputs["input_ids"].to(model.device)
    inputs_embeds = model.model.embed_tokens(input_ids)
    position_ids = torch.arange(seq_len, device=model.device).unsqueeze(0)

    # Rotary embedding
    layer_type = model.config.layer_types[0]
    position_embeddings = model.model.rotary_emb(inputs_embeds, position_ids, layer_type)
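    # position_embeddings is expected to be a (cos, sin) tuple, as in the standard
    # transformers rotary-embedding modules (assumption; the custom Gemma4 rotary_emb
    # signature additionally takes layer_type here)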

    # Causal mask (simplified: attention_mask=None allows full attention over all positions)
    first_layer = model.model.layers[0]

    layer_outputs = first_layer(
        inputs_embeds,
        per_layer_input=None,
        position_embeddings=position_embeddings,
        attention_mask=None,
        position_ids=position_ids,
        past_key_values=None,
        output_attentions=True,
    )

    print(f"DecoderLayer returned {len(layer_outputs)} outputs")
    if len(layer_outputs) >= 2:
        hidden_out = layer_outputs[0]
        attn_out = layer_outputs[1]
        print(f"  hidden_states: shape={hidden_out.shape}, dtype={hidden_out.dtype}")
        print(f"  attn_output:   shape={attn_out.shape}, dtype={attn_out.dtype}")

        expected_shape = (batch_size, seq_len, hidden_size)
        if attn_out.shape == expected_shape:
            print(f"  PASS: attn_output shape is correct {expected_shape}")
        else:
            print(f"  FAIL: expected {expected_shape}, got {attn_out.shape}")

        # Check that attn_output is not all zeros
        if attn_out.abs().sum() > 0:
            print(f"  PASS: attn_output is non-zero (norm={attn_out.float().norm().item():.4f})")
        else:
            print("  FAIL: attn_output is all zeros")

        # Check that attn_output differs from hidden_states
        # (attn_output is taken before layernorm + residual, so it should differ from hidden_states)
        if not torch.equal(hidden_out, attn_out):
            print("  PASS: attn_output differs from hidden_states (as expected)")
        else:
            print("  FAIL: attn_output is identical to hidden_states")
    else:
        print(f"  FAIL: expected 2 outputs, got {len(layer_outputs)}")

# =========================================================
# Test 3: attn_output must not be returned when output_attentions=False
# =========================================================
print("\n=== Test 3: DecoderLayer with output_attentions=False ===")
with torch.no_grad():
    layer_outputs_no_attn = first_layer(
        inputs_embeds,
        per_layer_input=None,
        position_embeddings=position_embeddings,
        attention_mask=None,
        position_ids=position_ids,
        past_key_values=None,
        output_attentions=False,
    )
    print(f"DecoderLayer returned {len(layer_outputs_no_attn)} outputs")
    if len(layer_outputs_no_attn) == 1:
        print("  PASS: only hidden_states returned (no attn_output)")
    else:
        print(f"  FAIL: expected 1 output, got {len(layer_outputs_no_attn)}")

# =========================================================
# Test 4: verify output_attentions propagation through the CausalLM
# =========================================================
print("\n=== Test 4: Gemma4ForCausalLM output_attentions propagation ===")
with torch.no_grad():
    causal_outputs = model(**inputs, output_attentions=True, use_cache=False)

attentions_causal = causal_outputs.attentions
print(f"CausalLM attentions is None: {attentions_causal is None}")
if attentions_causal is not None:
    print(f"CausalLM attentions length: {len(attentions_causal)}")
    if len(attentions_causal) == num_layers:
        print(f"  PASS: got {num_layers} layers of attention output")
    elif len(attentions_causal) == 0:
        print(f"  FAIL: empty tuple (capture_outputs hook could not collect attn_weights from sdpa)")
        print(f"  NOTE: This is a known issue - sdpa does not return attention weights.")
        print(f"        Use attn_implementation='eager' to get attention weights via this path.")
    else:
        print(f"  Got {len(attentions_causal)} (expected {num_layers})")

# =========================================================
# Summary
# =========================================================
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
print("- DecoderLayer correctly returns attn_output when output_attentions=True")
print("- DecoderLayer correctly omits attn_output when output_attentions=False")
print("- capture_outputs hook on CausalLM/TextModel collects Gemma4TextAttention output[1]")
print("  which is attn_weights (None with sdpa), so CausalLM.attentions is empty.")
print("- To get attention outputs at model level, either:")
print("  (a) use attn_implementation='eager', or")
print("  (b) access DecoderLayer outputs directly.")