ligongh committed · Commit be46b64 · verified · Parent: 5a96d7b

Upload README.md

Files changed (1): README.md (+111, −3)
---
library_name: transformers
tags:
- custom_generate
---

## Description
Implementation of the KV cache quantization method introduced in the [SQuat paper (COLM 2025)](https://arxiv.org/abs/2503.24358). SQuat (Subspace-orthogonal KV cache quantization) reduces the memory and compute cost of storing the KV cache by carefully quantizing the key tensors. It constructs a task-relevant subspace and ensures that quantization errors remain orthogonal to it, thereby minimizing their impact on attention outputs. SQuat is training-free and calibration-free, operates on the fly, and has strong theoretical grounding and state-of-the-art empirical results.

This repo provides a partial implementation of SQuat via a custom `SQuatCache` class. The cache requires an additional `query_states` input to `.update()`; to support this, you can monkey-patch the `LlamaAttention.forward` method, as shown in the example usage below.

For the full implementation, please refer to the [original repository](https://github.com/Red-Hat-AI-Innovation-Team/SQuat).
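The key property described above, that quantization error orthogonal to a query subspace leaves attention logits unchanged for queries inside that subspace, can be illustrated with a small self-contained sketch. This is illustrative only (a random subspace instead of one derived from query states) and is not part of `SQuatCache`:

```python
import torch

torch.manual_seed(0)
head_dim, subspace_dim = 64, 10

# Orthonormal basis Q for a "task-relevant" subspace. In SQuat this comes from
# the query states; here it is random purely for illustration.
Q, _ = torch.linalg.qr(torch.randn(head_dim, subspace_dim))

key = torch.randn(head_dim)
noise = torch.randn(head_dim)

# Project an arbitrary quantization error onto the orthogonal complement of the
# subspace, so that Q^T err_orth ≈ 0.
err_orth = noise - Q @ (Q.T @ noise)
key_hat = key + err_orth  # "dequantized" key whose error is subspace-orthogonal

# A query lying inside the subspace sees an (almost) unchanged attention logit.
query = Q @ torch.randn(subspace_dim)
print(torch.allclose(query @ key, query @ key_hat, atol=1e-4))  # True
```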

## Base model
`meta-llama/Llama-3.1-8B-Instruct`

## Model compatibility
Most models; more specifically, any `transformers` LLM/VLM trained for causal language modeling.

## Additional Arguments
- `backend` (`str`, *optional*): quantization backend, defaults to `quanto`
- `nbits` (`int`, *optional*): number of bits for quantization, defaults to `2`
- `quant_group_size` (`int`, *optional*): quantization group size, defaults to `64`
- `residual_length` (`int`, *optional*): length of the residual cache kept in full precision, defaults to `32`
- `squat_lambda` (`float`, *optional*): the λ hyperparameter from the SQuat paper, defaults to `0.001`
- `subspace_dim` (`int`, *optional*): dimension of the task-relevant subspace, defaults to `10`
- `shared_svd` (`bool`, *optional*): whether to use a shared SVD, defaults to `True`
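As a rough illustration of how `residual_length` and `quant_group_size` interact: quantized KV caches of this family typically keep the most recent `residual_length` tokens in full precision and quantize older tokens in groups of `quant_group_size` values, each group sharing one scale/zero-point. The sketch below shows only this bookkeeping (all names are illustrative; this is not the actual `SQuatCache` code):

```python
import torch

def split_and_group(keys: torch.Tensor, residual_length: int, quant_group_size: int):
    """Split keys into a to-be-quantized prefix and a full-precision residual."""
    seq_len = keys.shape[-2]
    num_quantized = max(seq_len - residual_length, 0)
    to_quantize = keys[..., :num_quantized, :]
    residual = keys[..., num_quantized:, :]
    # Reshape so each group of `quant_group_size` values can share one scale.
    groups = to_quantize.reshape(*to_quantize.shape[:-1], -1, quant_group_size)
    return groups, residual

keys = torch.randn(1, 8, 100, 128)  # (batch, heads, seq_len, head_dim)
groups, residual = split_and_group(keys, residual_length=32, quant_group_size=64)
print(groups.shape)    # torch.Size([1, 8, 68, 2, 64])
print(residual.shape)  # torch.Size([1, 8, 32, 128])
```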

## Output Type changes
(none)
33
+
34
+ ## Example usage
35
+
36
+ ```py
37
+ import torch
38
+ from typing import Callable, Optional, Tuple
39
+ from transformers.cache_utils import Cache
40
+ from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, eager_attention_forward
41
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
42
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
43
+ from transformers.processing_utils import Unpack
44
+ import transformers
45
+
46
+ from transformers import AutoTokenizer, AutoModelForCausalLM
47
+
48
+ def llama_attn_forward(
49
+ self,
50
+ hidden_states: torch.Tensor,
51
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
52
+ attention_mask: Optional[torch.Tensor],
53
+ past_key_value: Optional[Cache] = None,
54
+ cache_position: Optional[torch.LongTensor] = None,
55
+ **kwargs: Unpack[FlashAttentionKwargs],
56
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
57
+
58
+ input_shape = hidden_states.shape[:-1]
59
+ hidden_shape = (*input_shape, -1, self.head_dim)
60
+
61
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
62
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
63
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
64
+
65
+ cos, sin = position_embeddings
66
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
67
+
68
+ if past_key_value is not None:
69
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
70
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position, "query_states": query_states, "attention_mask": attention_mask}
71
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
72
+
73
+ attention_interface: Callable = eager_attention_forward
74
+
75
+ if self.config._attn_implementation != "eager":
76
+ if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
77
+ logger.warning_once(
78
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
79
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
80
+ )
81
+ else:
82
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
83
+
84
+ attn_output, attn_weights = attention_interface(
85
+ self,
86
+ query_states,
87
+ key_states,
88
+ value_states,
89
+ attention_mask,
90
+ dropout=0.0 if not self.training else self.attention_dropout,
91
+ scaling=self.scaling,
92
+ **kwargs,
93
+ )
94
+
95
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
96
+ attn_output = self.o_proj(attn_output)
97
+ return attn_output, attn_weights
98
+
99
+ def replace_llama():
100
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = llama_attn_forward
101
+
102
+ replace_llama()
103
+
104
+ tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-3.1-8B-Instruct')
105
+ model = AutoModelForCausalLM.from_pretrained('meta-llama/Llama-3.1-8B-Instruct', device_map="auto")
106
+
107
+ inputs = tokenizer(["I like rock music because"], return_tensors="pt").to(model.device)
108
+
109
+ gen_out = model.generate(**inputs, custom_generate="ligongh/squat", trust_remote_code=True)
110
+ print(tokenizer.batch_decode(gen_out))
111
+ ```