# Gravity-2 / load_gravity2.py
# Uploaded by: squ11z1
# Gravity-2 stage-1: VibeThinker-3B with gravity attention (LoRA merged + trained masses)
# Commit: d3b580b (verified)
# (Hugging Face file-view residue: Raw · History · Blame · Contribute · Delete · 1 kB)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.utils import cached_file

from gravity_attention_qwen import patch_qwen_with_gravity
# Checkpoint location: a local directory or a Hub repo id such as "squ11z1/Gravity-2".
REPO = "."  # or "squ11z1/Gravity-2"
# Device used for the model weights and the restored gravity masses.
DEVICE = "cuda"

tok = AutoTokenizer.from_pretrained(REPO)
# Eager attention is requested explicitly. This is presumably required because
# patch_qwen_with_gravity hooks the eager attention path (TODO: confirm in
# gravity_attention_qwen).
model = AutoModelForCausalLM.from_pretrained(
    REPO,
    dtype=torch.bfloat16,
    device_map=DEVICE,
    attn_implementation="eager",
)
patch_qwen_with_gravity(model)  # re-enable gravity attention

# The trained per-layer masses ship as a separate file next to the weights.
# cached_file resolves the file both for a local directory and for a Hub repo
# id. A plain f"{REPO}/..." path only works in the local case.
mass_path = cached_file(REPO, "gravity_mass_log.pt")
# weights_only=True limits unpickling to tensors and plain containers, so a
# downloaded file cannot run arbitrary code while loading.
masses = torch.load(mass_path, map_location=DEVICE, weights_only=True)

# Copy the saved masses into each layer's `self_attn.gravity_mass_log`. That
# attribute is presumably created by patch_qwen_with_gravity (confirm). The
# masses are keyed by the layer's state-dict name. no_grad replaces the legacy
# `.data` access, and copy_ handles any device/dtype mismatch.
with torch.no_grad():
    for i, layer in enumerate(model.model.layers):
        key = f"model.layers.{i}.self_attn.gravity_mass_log"
        layer.self_attn.gravity_mass_log.copy_(masses[key])
model.eval()
# Build a single-turn chat prompt. The full encoding (input_ids and
# attention_mask) is kept so generate() receives an explicit attention mask
# instead of inferring one.
inputs = tok.apply_chat_template(
    [{"role": "user", "content": "What is 24*17?"}],
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
).to(model.device)  # follow the model's device instead of hard-coding .cuda()
ids = inputs["input_ids"]
output = model.generate(**inputs, max_new_tokens=200)
# Decode only the newly generated continuation; the prompt tokens come first.
print(tok.decode(output[0, ids.shape[1]:], skip_special_tokens=True))