"""
Test: verify that output_attentions actually returns the attention output.

When output_attentions=True, Gemma4TextDecoderLayer returns
(hidden_states, attn_output), where attn_output is the output of self_attn
(the hidden states before post_attention_layernorm is applied).
The capture_outputs hook captures output[1] (attn_weights) of
Gemma4TextAttention, but with the sdpa implementation attn_weights is None,
so nothing is collected. This script therefore verifies at the DecoderLayer
level that attn_output can be obtained correctly. (An illustrative sketch of
such a hook follows the imports below.)
"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
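
# Illustrative sketch only (assumption): the actual capture_outputs hook is
# defined in the Gemma4 modeling code, not in this script, and this helper is
# never called below. A forward hook that collects output[1] (attn_weights)
# from each attention module would look roughly like this; with the sdpa
# implementation it collects nothing because attn_weights is None. The helper
# name and the `storage` argument are hypothetical and exist only for
# illustration.
def sketch_register_attn_weight_hooks(model, storage):
    """Register forward hooks that append each layer's attn_weights (if present) to storage."""
    handles = []
    for layer in model.model.layers:
        def hook(module, args, output, _storage=storage):
            # Gemma4TextAttention is expected to return (attn_output, attn_weights);
            # with sdpa, attn_weights is None, so nothing gets appended.
            if isinstance(output, tuple) and len(output) > 1 and output[1] is not None:
                _storage.append(output[1].detach())
        handles.append(layer.self_attn.register_forward_hook(hook))
    return handles
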
MODEL_PATH = "/workspace/llm/gemma-4-31B-Text"
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
inputs = tokenizer("hello", return_tensors="pt")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
inputs = inputs.to(model.device)
num_layers = model.config.num_hidden_layers
hidden_size = model.config.hidden_size
seq_len = inputs["input_ids"].shape[1]
batch_size = inputs["input_ids"].shape[0]
print(f"Model: num_layers={num_layers}, hidden_size={hidden_size}")
print(f"Input: batch={batch_size}, seq_len={seq_len}")
# =========================================================
# Test 1: model.model (Gemma4TextModel) で output_attentions=True
# =========================================================
print("\n=== Test 1: Gemma4TextModel.forward(output_attentions=True) ===")
with torch.no_grad():
    text_outputs = model.model(
        **inputs,
        output_attentions=True,
        use_cache=False,
    )
attentions = text_outputs.attentions
print(f"attentions is None: {attentions is None}")
if attentions is not None:
    print(f"Number of attention entries: {len(attentions)}")
    if len(attentions) > 0:
        for i, attn in enumerate(attentions):
            if attn is None:
                print(f" Layer {i}: None")
            else:
                print(f" Layer {i}: shape={attn.shape}, dtype={attn.dtype}")
                if i == 0:
                    # attn_output should have shape (batch, seq_len, hidden_size)
                    expected_shape = (batch_size, seq_len, hidden_size)
                    if attn.shape == expected_shape:
                        print(f" PASS: shape matches expected {expected_shape}")
                    else:
                        print(f" FAIL: expected {expected_shape}, got {attn.shape}")
    else:
        print(" (empty tuple - capture_outputs hook did not collect anything)")
# =========================================================
# Test 2: DecoderLayer を直接呼んで attn_output を確認
# =========================================================
print("\n=== Test 2: DecoderLayer direct call with output_attentions=True ===")
with torch.no_grad():
    # Prepare the embeddings and position information first
    input_ids = inputs["input_ids"].to(model.device)
    inputs_embeds = model.model.embed_tokens(input_ids)
    position_ids = torch.arange(seq_len, device=model.device).unsqueeze(0)
    # Rotary embedding
    layer_type = model.config.layer_types[0]
    position_embeddings = model.model.rotary_emb(inputs_embeds, position_ids, layer_type)
    # Causal mask (simplified: None means attending over all positions)
    first_layer = model.model.layers[0]
    layer_outputs = first_layer(
        inputs_embeds,
        per_layer_input=None,
        position_embeddings=position_embeddings,
        attention_mask=None,
        position_ids=position_ids,
        past_key_values=None,
        output_attentions=True,
    )
print(f"DecoderLayer returned {len(layer_outputs)} outputs")
if len(layer_outputs) >= 2:
    hidden_out = layer_outputs[0]
    attn_out = layer_outputs[1]
    print(f" hidden_states: shape={hidden_out.shape}, dtype={hidden_out.dtype}")
    print(f" attn_output: shape={attn_out.shape}, dtype={attn_out.dtype}")
    expected_shape = (batch_size, seq_len, hidden_size)
    if attn_out.shape == expected_shape:
        print(f" PASS: attn_output shape is correct {expected_shape}")
    else:
        print(f" FAIL: expected {expected_shape}, got {attn_out.shape}")
    # Check that attn_output is not all zeros
    if attn_out.abs().sum() > 0:
        print(f" PASS: attn_output is non-zero (norm={attn_out.float().norm().item():.4f})")
    else:
        print(" FAIL: attn_output is all zeros")
    # Check that hidden_states and attn_output differ
    # (attn_output is taken before layernorm + residual, so it should differ from hidden_states)
    if not torch.equal(hidden_out, attn_out):
        print(" PASS: attn_output differs from hidden_states (as expected)")
    else:
        print(" FAIL: attn_output is identical to hidden_states")
else:
    print(f" FAIL: expected 2 outputs, got {len(layer_outputs)}")
# =========================================================
# Test 3: output_attentions=False では attn_output が返らないこと
# =========================================================
print("\n=== Test 3: DecoderLayer with output_attentions=False ===")
with torch.no_grad():
    layer_outputs_no_attn = first_layer(
        inputs_embeds,
        per_layer_input=None,
        position_embeddings=position_embeddings,
        attention_mask=None,
        position_ids=position_ids,
        past_key_values=None,
        output_attentions=False,
    )
print(f"DecoderLayer returned {len(layer_outputs_no_attn)} outputs")
if len(layer_outputs_no_attn) == 1:
    print(" PASS: only hidden_states returned (no attn_output)")
else:
    print(f" FAIL: expected 1 output, got {len(layer_outputs_no_attn)}")
# =========================================================
# Test 4: CausalLM の output_attentions の伝播確認
# =========================================================
print("\n=== Test 4: Gemma4ForCausalLM output_attentions propagation ===")
with torch.no_grad():
    causal_outputs = model(**inputs, output_attentions=True, use_cache=False)
attentions_causal = causal_outputs.attentions
print(f"CausalLM attentions is None: {attentions_causal is None}")
if attentions_causal is not None:
    print(f"CausalLM attentions length: {len(attentions_causal)}")
    if len(attentions_causal) == num_layers:
        print(f" PASS: got {num_layers} layers of attention output")
    elif len(attentions_causal) == 0:
        print(" FAIL: empty tuple (capture_outputs hook could not collect attn_weights from sdpa)")
        print(" NOTE: This is a known issue - sdpa does not return attention weights.")
        print(" Use attn_implementation='eager' to get attention weights via this path.")
    else:
        print(f" Got {len(attentions_causal)} (expected {num_layers})")
# =========================================================
# Summary
# =========================================================
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
print("- DecoderLayer correctly returns attn_output when output_attentions=True")
print("- DecoderLayer correctly omits attn_output when output_attentions=False")
print("- capture_outputs hook on CausalLM/TextModel collects Gemma4TextAttention output[1]")
print(" which is attn_weights (None with sdpa), so CausalLM.attentions is empty.")
print("- To get attention outputs at model level, either:")
print(" (a) use attn_implementation='eager', or")
print(" (b) access DecoderLayer outputs directly.")