"""Run one chat prompt through a causal LM patched with gravity attention.

Steps, in order:
  1. Load the tokenizer and model from ``REPO``.
  2. Call ``patch_qwen_with_gravity`` on the model.
  3. Copy the saved ``gravity_mass_log`` tensors into
     ``layer.self_attn.gravity_mass_log`` for every entry of
     ``model.model.layers``.
  4. Generate up to 200 new tokens for a fixed prompt and print the reply.
"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.utils import cached_file

from gravity_attention_qwen import patch_qwen_with_gravity

REPO = "."  # or "squ11z1/Gravity-2"
DEVICE = "cuda"
MASS_FILE = "gravity_mass_log.pt"

tok = AutoTokenizer.from_pretrained(REPO)
model = AutoModelForCausalLM.from_pretrained(
    REPO,
    dtype=torch.bfloat16,
    device_map=DEVICE,
    attn_implementation="eager",
)
patch_qwen_with_gravity(model)  # re-enable gravity attention

# Resolve the mass file through the transformers hub helper so that REPO may
# be either a local directory or a Hub repo id; a plain f"{REPO}/..." path
# only works for the local-directory case.
mass_path = cached_file(REPO, MASS_FILE)
# NOTE(review): torch.load unpickles the file. On torch versions where
# weights_only does not default to True this can execute arbitrary code, so
# only load mass files from a source you trust (or pass weights_only=True
# after confirming the file holds plain tensors).
masses = torch.load(mass_path, map_location=DEVICE)

for i, layer in enumerate(model.model.layers):
    key = f"model.layers.{i}.self_attn.gravity_mass_log"
    # copy_ writes into the existing tensor in place; the source is already
    # on DEVICE via map_location, so no extra device transfer is needed.
    layer.self_attn.gravity_mass_log.data.copy_(masses[key])

model.eval()

inputs = tok.apply_chat_template(
    [{"role": "user", "content": "What is 24*17?"}],
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
).to(DEVICE)
ids = inputs["input_ids"]

# Pass the attention mask explicitly (when the tokenizer returned one) rather
# than leaving generate() to infer it from the input ids.
output = model.generate(
    ids,
    attention_mask=inputs.get("attention_mask"),
    max_new_tokens=200,
)

# Drop the prompt tokens so only the newly generated reply is printed.
print(tok.decode(output[0, ids.shape[1]:], skip_special_tokens=True))