#!/usr/bin/env python3
"""
Minimal inference script for the Doc-to-LoRA Perceiver.
Requirements: pip install transformers>=4.51.0 huggingface_hub torch
"""
import re, torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import hf_hub_download, login

REPO_ID = "farpluto/doc-to-lora-niah" # filled in automatically at packaging time
HF_TOKEN = None # set your token here if the base model is gated
if HF_TOKEN:
    login(token=HF_TOKEN)

# Fetch the hypernetwork checkpoint and unpack its metadata.
# NOTE(review): weights_only=False executes arbitrary pickle code at load
# time — only acceptable because the checkpoint comes from a trusted repo.
ckpt = torch.load(hf_hub_download(REPO_ID, "hypernet.pt", token=HF_TOKEN),
                  map_location="cuda", weights_only=False)
hcfg = ckpt["hypernet_cfg"]    # constructor kwargs for PerceiverHypernet(**hcfg)
BASE = ckpt["base_model"]      # HF model id of the frozen base LLM
TGT = ckpt["target_module"]    # substring selecting the Linear modules to patch
ALPHA = ckpt["lora_alpha"]     # LoRA alpha; effective scale = ALPHA / lora_r
EARLY = ckpt["early_exit"]     # hidden-state layer index fed to the hypernetwork

tokenizer = AutoTokenizer.from_pretrained(BASE, token=HF_TOKEN, trust_remote_code=True)
if tokenizer.pad_token is None:
    # Reuse EOS as padding so attention masks can be built.
    tokenizer.pad_token = tokenizer.eos_token

# Frozen base model: bf16 on CUDA, inference only (no grads anywhere).
llm = AutoModelForCausalLM.from_pretrained(
    BASE, token=HF_TOKEN, torch_dtype=torch.bfloat16,
    device_map="cuda", attn_implementation="sdpa", trust_remote_code=True)
llm.eval()
for p in llm.parameters():
    p.requires_grad_(False)
class CrossAttentionBlock(nn.Module):
    """One Perceiver block: latents cross-attend to a context, then an MLP.

    Pre-norm residual layout. Note that only queries and keys pass through
    a LayerNorm; values are projected from the raw context — this matches
    the trained checkpoint and must not be "corrected" here.
    """

    def __init__(self, latent_dim, ctx_dim, n_heads=8):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = latent_dim // n_heads
        self.norm_q = nn.LayerNorm(latent_dim)
        self.norm_ctx = nn.LayerNorm(ctx_dim)
        self.q_proj = nn.Linear(latent_dim, latent_dim, bias=False)
        self.k_proj = nn.Linear(ctx_dim, latent_dim, bias=False)
        self.v_proj = nn.Linear(ctx_dim, latent_dim, bias=False)
        self.o_proj = nn.Linear(latent_dim, latent_dim, bias=False)
        self.norm_ff = nn.LayerNorm(latent_dim)
        self.ff = nn.Sequential(
            nn.Linear(latent_dim, latent_dim * 4, bias=False),
            nn.GELU(),
            nn.Linear(latent_dim * 4, latent_dim, bias=False),
        )

    def forward(self, latents, ctx, ctx_mask=None):
        batch, n_lat, dim = latents.shape
        n_ctx = ctx.shape[1]
        heads, hdim = self.n_heads, self.head_dim

        def to_heads(x, length):
            # (B, len, D) -> (B, H, len, Dh)
            return x.view(batch, length, heads, hdim).transpose(1, 2)

        q = to_heads(self.q_proj(self.norm_q(latents)), n_lat)
        k = to_heads(self.k_proj(self.norm_ctx(ctx)), n_ctx)
        v = to_heads(self.v_proj(ctx), n_ctx)

        attn_bias = None
        if ctx_mask is not None:
            # Additive bias: 0 on attended positions, -1e4 on padding.
            attn_bias = (1.0 - ctx_mask.float()).unsqueeze(1).unsqueeze(2) * -1e4
            attn_bias = attn_bias.to(q.dtype)

        attended = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
        attended = attended.transpose(1, 2).contiguous().view(batch, n_lat, dim)
        latents = latents + self.o_proj(attended)
        return latents + self.ff(self.norm_ff(latents))
class PerceiverHypernet(nn.Module):
    """Map context activations to per-layer LoRA factors (A, B).

    Perceiver-style encoder: `lora_r` learned latent queries repeatedly
    cross-attend to the projected context; a per-layer linear head then
    decodes each latent row into one rank component of A and B.
    """

    def __init__(self, ctx_dim, n_lora_layers, lora_r, target_in, target_out,
                 latent_dim=512, n_blocks=8):
        super().__init__()
        self.n_lora_layers = n_lora_layers
        self.din = target_in
        self.dout = target_out
        fused = target_in + target_out  # A row and B column decoded jointly
        self.ctx_proj = nn.Linear(ctx_dim, latent_dim, bias=False)
        self.ctx_norm = nn.LayerNorm(latent_dim)
        # One latent query per LoRA rank component, scaled like 1/sqrt(dim).
        self.latent_q = nn.Parameter(torch.randn(lora_r, latent_dim) * latent_dim ** -0.5)
        self.blocks = nn.ModuleList(
            CrossAttentionBlock(latent_dim, latent_dim) for _ in range(n_blocks))
        self.head_w = nn.Parameter(torch.randn(n_lora_layers, latent_dim, fused) * 0.01)
        self.head_b = nn.Parameter(torch.zeros(n_lora_layers, fused))

    def forward(self, ctx_acts, ctx_mask=None):
        """Return (A, B): A is (B, n_layers, r, din); B is (B, n_layers, dout, r)."""
        batch = ctx_acts.shape[0]
        ctx = self.ctx_norm(self.ctx_proj(ctx_acts))
        latents = self.latent_q.unsqueeze(0).expand(batch, -1, -1)
        for block in self.blocks:
            latents = block(latents, ctx, ctx_mask)
        # (b, r, d) x (n, d, e) -> (b, n, r, e) where e = din + dout.
        decoded = torch.einsum("brd,nde->bnre", latents, self.head_w)
        decoded = decoded + self.head_b.unsqueeze(0).unsqueeze(2)
        a_fac = decoded[..., :self.din]
        b_fac = decoded[..., self.din:].transpose(-1, -2)
        return a_fac, b_fac
# Instantiate the hypernetwork on CUDA in bf16 and load the trained weights.
hypernet = PerceiverHypernet(**hcfg).to("cuda", dtype=torch.bfloat16)
hypernet.load_state_dict(ckpt["state_dict"])
hypernet.eval()
print("Perceiver loaded OK")

# Token ids delimiting "thinking" spans (Qwen-style <think>...</think>).
# Tokenizers lacking these tokens yield unk/None, which the filter drops.
_tok_open = tokenizer.convert_tokens_to_ids("<think>")
_tok_close = tokenizer.convert_tokens_to_ids("</think>")
THINK_TOKENS = {t for t in [_tok_open, _tok_close]
                if t not in (tokenizer.unk_token_id, None)}
def _strip_think(ids):
    """Decode a 1-D tensor of token ids, dropping <think>...</think> spans.

    Tokens between an opening and closing think marker (markers included)
    are removed before decoding; everything else is decoded with special
    tokens skipped and surrounding whitespace stripped.

    Fix: use the explicit `_tok_open` / `_tok_close` ids instead of
    `min(THINK_TOKENS)` / `max(THINK_TOKENS)`, which silently assumed the
    opening token id is numerically smaller than the closing one.
    """
    toks = ids.tolist()
    if not THINK_TOKENS or not any(t in THINK_TOKENS for t in toks):
        return tokenizer.decode(toks, skip_special_tokens=True).strip()
    clean, inside = [], False
    for t in toks:
        if t == _tok_open:
            inside = True
        elif t == _tok_close:
            inside = False
        elif not inside:
            clean.append(t)
    return tokenizer.decode(clean, skip_special_tokens=True).strip()
def _sorted_mods(model,mod_name):
mods=[(n,m) for n,m in model.named_modules()
if mod_name in n and isinstance(m,nn.Linear)]
def _idx(name):
nums=re.findall(r"\d+",name)
return int(nums[0]) if nums else -1
return sorted(mods,key=lambda x:_idx(x[0]))
# Linear modules that receive the LoRA deltas (ordered by layer index),
# and the standard LoRA scaling factor alpha / r.
target_mods = _sorted_mods(llm, TGT)
scale = ALPHA / hcfg["lora_r"]
@torch.no_grad()
def internalize_and_query(document, query, max_new_tokens=12):
    """Generate LoRA weights from `document`, attach them via forward hooks,
    then greedily decode an answer to `query` with the patched model.

    Returns the decoded answer text with any <think> spans stripped.
    """
    # Encode the document alone (no special tokens): this is the context
    # the hypernetwork "internalizes" into LoRA factors.
    ctx_ids = torch.tensor(
        [tokenizer.encode(document, add_special_tokens=False)], device="cuda")
    ctx_mask = torch.ones_like(ctx_ids)
    # " /no_think" is a Qwen3-style prompt flag suppressing thinking mode.
    qry_ids = torch.tensor(
        [tokenizer.encode(query + " /no_think", add_special_tokens=True)], device="cuda")
    # Early-exit hidden states of the frozen LLM are the hypernetwork input.
    acts = llm(input_ids=ctx_ids, attention_mask=ctx_mask,
               output_hidden_states=True, use_cache=False).hidden_states[EARLY]
    A, B = hypernet(acts, ctx_mask)
    # Drop the batch dim: A -> (n_layers, r, din), B -> (n_layers, dout, r).
    A, B = A.squeeze(0), B.squeeze(0)
    hooks = []
    def _mkhook(Ai, Bi):
        # Forward hook adding the low-rank update: out + scale * x A^T B^T.
        def h(mod, inp, out): return out + scale * (inp[0] @ Ai.t()) @ Bi.t()
        return h
    for i, (_, mod) in enumerate(target_mods):
        hooks.append(mod.register_forward_hook(_mkhook(A[i], B[i])))
    # Greedy decoding. NOTE(review): use_cache=False re-runs the full
    # sequence each step (quadratic in length) — tolerable at this small
    # max_new_tokens, but a KV cache would make this linear.
    ids = qry_ids.clone()
    for _ in range(max_new_tokens):
        out = llm(input_ids=ids, attention_mask=torch.ones_like(ids), use_cache=False)
        nxt = out.logits[:, -1, :].argmax(-1, keepdim=True)
        ids = torch.cat([ids, nxt], dim=1)
        if nxt.item() == tokenizer.eos_token_id: break
    # Always detach the LoRA hooks so the base model is left unpatched.
    for h in hooks: h.remove()
    return _strip_think(ids[0, qry_ids.shape[1]:])
if __name__ == "__main__":
    # Smoke test: internalize a tiny document, then query a fact from it.
    sample_doc = "The special magic number is 7341. The sky is blue today."
    reply = internalize_and_query(sample_doc, "What is the special magic number?")
    print(f"Answer: {reply}")