File size: 1,001 Bytes
d3b580b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from gravity_attention_qwen import patch_qwen_with_gravity

# Inference example: load the checkpoint, patch its attention modules, restore
# the per-layer `gravity_mass_log` tensors from the side-car file, then answer
# one chat prompt and print only the newly generated text.

# Imported here (not at file top) because it is only needed to locate the
# side-car mass file; it resolves files from a local directory or a Hub repo.
from transformers.utils import cached_file

# Local checkpoint directory, or a Hugging Face Hub repo id.
REPO = "."  # or "squ11z1/Gravity-2"
DEVICE = "cuda"
MASS_FILE = "gravity_mass_log.pt"

tok = AutoTokenizer.from_pretrained(REPO)
model = AutoModelForCausalLM.from_pretrained(
    REPO,
    dtype=torch.bfloat16,
    device_map=DEVICE,
    attn_implementation="eager",
)
patch_qwen_with_gravity(model)  # re-enable gravity attention
# NOTE(review): presumably the patch is what creates `gravity_mass_log` on each
# `self_attn`, so it has to run before the copy loop below -- confirm.

# A plain f"{REPO}/..." path only exists when REPO is a local directory;
# cached_file() also handles the Hub-id case advertised above.
# weights_only=True restricts unpickling to tensors/containers, since the file
# may have been downloaded. NOTE(review): assumes the file is a plain
# {name: tensor} mapping -- confirm.
masses = torch.load(
    cached_file(REPO, MASS_FILE),
    map_location=DEVICE,
    weights_only=True,
)
with torch.no_grad():  # in-place overwrite of module tensors, no autograd tracking
    for i, layer in enumerate(model.model.layers):
        key = f"model.layers.{i}.self_attn.gravity_mass_log"
        # copy_() converts dtype/device to match the destination tensor.
        layer.self_attn.gravity_mass_log.copy_(masses[key])
model.eval()

enc = tok.apply_chat_template(
    [{"role": "user", "content": "What is 24*17?"}],
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
).to(DEVICE)
ids = enc["input_ids"]
# Pass the attention mask along instead of discarding it, so generate() does
# not have to infer it from the pad token.
out = model.generate(ids, attention_mask=enc.get("attention_mask"), max_new_tokens=200)
# Slice off the prompt tokens so only the model's continuation is printed.
print(tok.decode(out[0, ids.shape[1]:], skip_special_tokens=True))