| """ | |
| Utilities for extracting and manipulating attention weights from transformer models, | |
| starting from pre-computed hidden states. | |
| This module provides functions to compute attention weights from various transformer | |
| models (like Llama, Phi, Qwen, Gemma) and use them for attribution. We compute only | |
| the relevant attention weights (as specified by `attribution_start` and | |
| `attribution_end`) in order to be able to efficiently compute and store them. If we | |
| were to use `output_attentions=True` in the forward pass, we would (1) only be able | |
| to use the `eager` attention implementation, and (2) would need to store the entire | |
| attention matrix which grows quadratically with the sequence length. Most of the | |
| logic here is replicated from the `transformers` library. | |
| If you'd like to perform attribution on a model that is not currently supported, | |
| you can add it yourself by modifying `infer_model_type` and | |
| `get_layer_attention_weights`. Please see `tests/attribution/test_attention.py` | |
| to ensure that your implementation matches the expected attention weights when | |
| using the `output_attentions=True`. | |
| """ | |
import math
from typing import Any, Optional

import torch as ch
import transformers.models
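
# Example usage (a minimal sketch; the checkpoint name and the `text` variable are
# illustrative, not part of this module):
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#
#     model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
#     tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
#     inputs = tokenizer(text, return_tensors="pt").to(model.device)
#     outputs = model(**inputs, output_hidden_states=True)
#     # `outputs.hidden_states` is a tuple with one entry per layer (plus the embedding
#     # output at index 0), which is the `hidden_states` argument expected below.
#     weights = get_attention_weights(model, outputs.hidden_states)
#     # weights: (num_layers, num_heads, num_target_tokens, num_tokens - 1)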
def infer_model_type(model):
    """Infer the `transformers` model type from the model's name or path."""
    model_type_to_keyword = {
        "llama": "llama",
        "phi3": "phi",
        "qwen2": "qwen",
        "gemma3": "gemma",
    }
    for model_type, keyword in model_type_to_keyword.items():
        if keyword in model.name_or_path.lower():
            return model_type
    raise ValueError(f"Unknown model: {model.name_or_path}. Specify `model_type`.")
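
# For example (a sketch; the checkpoint name is illustrative):
#
#     infer_model_type(model)  # -> "llama" for a model loaded from "meta-llama/Llama-3.1-8B"
#
# Matching is done on `model.name_or_path`, so pass `model_type` explicitly for
# checkpoints whose names do not contain one of the keywords above.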
def get_helpers(model_type):
    """Fetch `apply_rotary_pos_emb` and `repeat_kv` from the model's modeling module."""
    if not hasattr(transformers.models, model_type):
        raise ValueError(f"Unknown model: {model_type}")
    model_module = getattr(transformers.models, model_type)
    modeling_module = getattr(model_module, f"modeling_{model_type}")
    return modeling_module.apply_rotary_pos_emb, modeling_module.repeat_kv
def get_position_ids_and_attention_mask(model, hidden_states):
    """Build position ids and an additive causal attention mask from the embeddings."""
    input_embeds = hidden_states[0]
    _, seq_len, _ = input_embeds.shape
    position_ids = ch.arange(0, seq_len, device=model.device).unsqueeze(0)
    attention_mask = ch.ones(
        seq_len, seq_len + 1, device=model.device, dtype=model.dtype
    )
    attention_mask = ch.triu(attention_mask, diagonal=1)
    attention_mask *= ch.finfo(model.dtype).min
    attention_mask = attention_mask[None, None]
    return position_ids, attention_mask
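
# For example (a sketch, with seq_len = 3 and ignoring dtype/device), the additive mask
# returned above is upper-triangular with one extra key column:
#
#     [[0, m, m, m],
#      [0, 0, m, m],
#      [0, 0, 0, m]]   where m = torch.finfo(model.dtype).min
#
# Adding it to the attention logits drives attention to future positions to (near) zero
# after the softmax.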
def get_attentions_shape(model):
    """Return (num_layers, num_attention_heads) for the model."""
    num_layers = len(model.model.layers)
    num_heads = model.model.config.num_attention_heads
    return num_layers, num_heads
def get_layer_attention_weights(
    model,
    hidden_states,
    layer_index,
    position_ids,
    attention_mask,
    attribution_start=None,
    attribution_end=None,
    model_type=None,
):
    model_type = model_type or infer_model_type(model)
    assert 0 <= layer_index < len(model.model.layers)
    layer = model.model.layers[layer_index]
    self_attn = layer.self_attn

    hidden_states = hidden_states[layer_index]
    hidden_states = layer.input_layernorm(hidden_states)
    bsz, q_len, _ = hidden_states.size()

    num_attention_heads = model.model.config.num_attention_heads
    num_key_value_heads = model.model.config.num_key_value_heads
    head_dim = self_attn.head_dim

    if model_type in ("llama", "qwen2", "qwen1.5", "gemma3", "glm"):
        query_states = self_attn.q_proj(hidden_states)
        key_states = self_attn.k_proj(hidden_states)
    elif model_type in ("phi3",):
        # Phi-3 uses a fused QKV projection; slice out the query and key parts.
        qkv = self_attn.qkv_proj(hidden_states)
        query_pos = num_attention_heads * head_dim
        query_states = qkv[..., :query_pos]
        key_states = qkv[..., query_pos : query_pos + num_key_value_heads * head_dim]
    else:
        raise ValueError(f"Unknown model: {model.name_or_path}")

    query_states = query_states.view(bsz, q_len, num_attention_heads, head_dim)
    query_states = query_states.transpose(1, 2)
    key_states = key_states.view(bsz, q_len, num_key_value_heads, head_dim)
    key_states = key_states.transpose(1, 2)

    if model_type in ("gemma3",):
        query_states = self_attn.q_norm(query_states)
        key_states = self_attn.k_norm(key_states)
        # Gemma 3 alternates between local (sliding-window) and global attention
        # layers, each with its own rotary embedding.
        if self_attn.is_sliding:
            position_embeddings = model.model.rotary_emb_local(
                hidden_states, position_ids
            )
        else:
            position_embeddings = model.model.rotary_emb(hidden_states, position_ids)
    else:
        position_embeddings = model.model.rotary_emb(hidden_states, position_ids)
    cos, sin = position_embeddings

    apply_rotary_pos_emb, repeat_kv = get_helpers(model_type)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
    key_states = repeat_kv(key_states, self_attn.num_key_value_groups)

    causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
    attribution_start = attribution_start if attribution_start is not None else 1
    attribution_end = attribution_end if attribution_end is not None else q_len + 1
    # Only keep the query rows we want to attribute (1-indexed over the sequence).
    causal_mask = causal_mask[:, :, attribution_start - 1 : attribution_end - 1]
    query_states = query_states[:, :, attribution_start - 1 : attribution_end - 1]

    attn_weights = ch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
        head_dim
    )
    attn_weights = attn_weights + causal_mask
    dtype = attn_weights.dtype
    # Softmax in float32 for numerical stability, then cast back to the model dtype.
    attn_weights = ch.softmax(attn_weights, dim=-1, dtype=ch.float32).to(dtype)
    return attn_weights
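
# For example (a sketch): with a sequence of q_len = 10 tokens, attribution_start = 5
# and attribution_end = 8 return the softmax-normalized attention of query positions
# 4..6 (0-indexed) over all 10 key positions, i.e. a tensor of shape
# (batch, num_attention_heads, 3, 10). The defaults (start = 1, end = q_len + 1)
# recover the full set of query rows.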
def get_attention_weights(
    model: Any,
    hidden_states: Any,
    attribution_start: Optional[int] = None,
    attribution_end: Optional[int] = None,
    model_type: Optional[str] = None,
) -> Any:
    """
    Compute the attention weights for the given model and hidden states.

    Args:
        model: The model to compute the attention weights for.
        hidden_states: The pre-computed hidden states.
        attribution_start: The start index of the tokens we would like to attribute.
        attribution_end: The end index of the tokens we would like to attribute.
        model_type: The type of model to compute the attention weights for (each model
            in the `transformers` library has its own specific attention implementation).
    """
    with ch.no_grad():
        position_ids, attention_mask = get_position_ids_and_attention_mask(
            model, hidden_states
        )
        num_layers, num_heads = get_attentions_shape(model)
        num_tokens = hidden_states[0].shape[1] + 1
        attribution_start = attribution_start if attribution_start is not None else 1
        attribution_end = attribution_end if attribution_end is not None else num_tokens
        num_target_tokens = attribution_end - attribution_start
        weights = ch.zeros(
            num_layers,
            num_heads,
            num_target_tokens,
            num_tokens - 1,
            device=model.device,
            dtype=model.dtype,
        )
        for i in range(len(model.model.layers)):
            cur_weights = get_layer_attention_weights(
                model,
                hidden_states,
                i,
                position_ids,
                attention_mask,
                attribution_start=attribution_start,
                attribution_end=attribution_end,
                model_type=model_type,
            )
            weights[i, :, :, :] = cur_weights[0]
        return weights
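
# For example (a sketch, not the only way to use the returned weights): averaging over
# the layer and head dimensions gives a (num_target_tokens, num_tokens - 1) matrix of
# attention mass from each target token onto each input token:
#
#     weights = get_attention_weights(
#         model, hidden_states, attribution_start=5, attribution_end=8
#     )
#     per_token_scores = weights.mean(dim=(0, 1))  # shape (3, num_tokens - 1)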
def get_attention_weights_one_layer(
    model: Any,
    hidden_states: Any,
    layer_index: int,
    attribution_start: Optional[int] = None,
    attribution_end: Optional[int] = None,
    model_type: Optional[str] = None,
) -> Any:
    """
    Compute the attention weights of a single layer for the given model and hidden states.

    Args:
        model: The model to compute the attention weights for.
        hidden_states: The pre-computed hidden states.
        layer_index: The index of the layer to compute the attention weights for.
        attribution_start: The start index of the tokens we would like to attribute.
        attribution_end: The end index of the tokens we would like to attribute.
        model_type: The type of model to compute the attention weights for (each model
            in the `transformers` library has its own specific attention implementation).
    """
    with ch.no_grad():
        position_ids, attention_mask = get_position_ids_and_attention_mask(
            model, hidden_states
        )
        num_tokens = hidden_states[0].shape[1] + 1
        attribution_start = attribution_start if attribution_start is not None else 1
        attribution_end = attribution_end if attribution_end is not None else num_tokens
        weights = get_layer_attention_weights(
            model,
            hidden_states,
            layer_index,
            position_ids,
            attention_mask,
            attribution_start=attribution_start,
            attribution_end=attribution_end,
            model_type=model_type,
        )
        return weights
def get_hidden_states_one_layer(
    model: Any,
    hidden_states: Any,
    layer_index: int,
    attribution_start: Optional[int] = None,
    attribution_end: Optional[int] = None,
    model_type: Optional[str] = None,
) -> Any:
    """
    Compute per-token key states for a single layer of the given model, starting from
    pre-computed hidden states. The key projections (taken after the input layernorm,
    before rotary embeddings) are averaged over the batch and key/value head
    dimensions, yielding a tensor of shape (seq_len, head_dim).

    Args:
        model: The model to compute the key states for.
        hidden_states: The pre-computed hidden states.
        layer_index: The index of the layer to compute the key states for.
        attribution_start: The start index of the tokens we would like to attribute.
        attribution_end: The end index of the tokens we would like to attribute.
        model_type: The type of model (each model in the `transformers` library has its
            own specific attention implementation).
    """

    def get_hidden_states(
        model,
        hidden_states,
        layer_index,
        position_ids,
        attention_mask,
        attribution_start=None,
        attribution_end=None,
        model_type=None,
    ):
        model_type = model_type or infer_model_type(model)
        assert 0 <= layer_index < len(model.model.layers)
        layer = model.model.layers[layer_index]
        self_attn = layer.self_attn

        hidden_states = hidden_states[layer_index]
        hidden_states = layer.input_layernorm(hidden_states)
        bsz, q_len, _ = hidden_states.size()

        num_attention_heads = model.model.config.num_attention_heads
        num_key_value_heads = model.model.config.num_key_value_heads
        head_dim = self_attn.head_dim

        if model_type in ("llama", "qwen2", "qwen1.5", "gemma3", "glm"):
            query_states = self_attn.q_proj(hidden_states)
            key_states = self_attn.k_proj(hidden_states)
        elif model_type in ("phi3",):
            # Phi-3 uses a fused QKV projection; slice out the query and key parts.
            qkv = self_attn.qkv_proj(hidden_states)
            query_pos = num_attention_heads * head_dim
            query_states = qkv[..., :query_pos]
            key_states = qkv[..., query_pos : query_pos + num_key_value_heads * head_dim]
        else:
            raise ValueError(f"Unknown model: {model.name_or_path}")

        query_states = query_states.view(bsz, q_len, num_attention_heads, head_dim)
        query_states = query_states.transpose(1, 2)
        # Average the key states over the batch and key/value head dimensions.
        key_states = key_states.view(bsz, q_len, num_key_value_heads, head_dim).mean(
            dim=(0, 2)
        )
        return key_states

    with ch.no_grad():
        position_ids, attention_mask = get_position_ids_and_attention_mask(
            model, hidden_states
        )
        num_tokens = hidden_states[0].shape[1] + 1
        attribution_start = attribution_start if attribution_start is not None else 1
        attribution_end = attribution_end if attribution_end is not None else num_tokens
        hidden_states = get_hidden_states(
            model,
            hidden_states,
            layer_index,
            position_ids,
            attention_mask,
            attribution_start=attribution_start,
            attribution_end=attribution_end,
            model_type=model_type,
        )
        return hidden_states