"""
Minimal inference script for the Doc-to-LoRA Perceiver.

Requirements: pip install "transformers>=4.51.0" huggingface_hub torch
"""
| import re, torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import AutoTokenizer, AutoModelForCausalLM |
| from huggingface_hub import hf_hub_download, login |
|
|
# --- Configuration ----------------------------------------------------------
# Hub repository that stores the hypernetwork checkpoint, and an optional
# access token (None means no explicit login is performed).
REPO_ID = "farpluto/doc-to-lora-niah"
HF_TOKEN = None

# Authenticate against the Hub only when a token was supplied.
if HF_TOKEN:
    login(token=HF_TOKEN)

# --- Hypernetwork checkpoint ------------------------------------------------
# NOTE(security): weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code embedded in the downloaded file. Only point
# REPO_ID at a repository you trust.
ckpt = torch.load(
    hf_hub_download(REPO_ID, "hypernet.pt", token=HF_TOKEN),
    map_location="cuda",
    weights_only=False,
)
hcfg = ckpt["hypernet_cfg"]  # keyword arguments for PerceiverHypernet
# BASE: base model id, TGT: substring selecting the patched Linear modules,
# ALPHA: LoRA alpha, EARLY: index into the LLM's hidden_states tuple.
BASE, TGT, ALPHA, EARLY = (
    ckpt[key]
    for key in ("base_model", "target_module", "lora_alpha", "early_exit")
)

# --- Tokenizer --------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(
    BASE, token=HF_TOKEN, trust_remote_code=True
)
# Reuse the EOS token for padding when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# --- Frozen base language model ---------------------------------------------
llm = AutoModelForCausalLM.from_pretrained(
    BASE,
    token=HF_TOKEN,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    attn_implementation="sdpa",
    trust_remote_code=True,
)
llm.eval()
# Inference only: switch off gradient tracking for every parameter.
llm.requires_grad_(False)
|
|
|
|
class CrossAttentionBlock(nn.Module):
    """Pre-norm cross-attention block followed by a feed-forward layer.

    Latents act as queries and attend over a context sequence; both the
    attention output and the feed-forward output are added residually.

    Args:
        latent_dim: Width of the latent (query) stream; must be divisible by
            ``n_heads`` for the head reshape to succeed.
        ctx_dim: Width of the context (key/value) stream.
        n_heads: Number of attention heads.
    """

    def __init__(self, latent_dim, ctx_dim, n_heads=8):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = latent_dim // n_heads

        self.norm_q = nn.LayerNorm(latent_dim)
        self.norm_ctx = nn.LayerNorm(ctx_dim)

        self.q_proj = nn.Linear(latent_dim, latent_dim, bias=False)
        self.k_proj = nn.Linear(ctx_dim, latent_dim, bias=False)
        self.v_proj = nn.Linear(ctx_dim, latent_dim, bias=False)
        self.o_proj = nn.Linear(latent_dim, latent_dim, bias=False)

        self.norm_ff = nn.LayerNorm(latent_dim)
        self.ff = nn.Sequential(
            nn.Linear(latent_dim, latent_dim * 4, bias=False),
            nn.GELU(),
            nn.Linear(latent_dim * 4, latent_dim, bias=False),
        )

    def _split_heads(self, x):
        """Reshape ``(batch, seq, latent_dim)`` to ``(batch, heads, seq, head_dim)``."""
        n_batch, n_tokens, _ = x.shape
        return x.view(n_batch, n_tokens, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(self, latents, ctx, ctx_mask=None):
        """Update ``latents`` by attending over ``ctx``.

        Args:
            latents: Tensor of shape ``(batch, n_latents, latent_dim)``.
            ctx: Tensor of shape ``(batch, ctx_len, ctx_dim)``.
            ctx_mask: Optional ``(batch, ctx_len)`` mask; entries equal to 1
                are attended to, entries equal to 0 get an additive bias of
                ``-1e4`` on their attention scores.

        Returns:
            Tensor with the same shape as ``latents``.
        """
        n_batch, n_latents, width = latents.shape

        queries = self._split_heads(self.q_proj(self.norm_q(latents)))
        keys = self._split_heads(self.k_proj(self.norm_ctx(ctx)))
        # Values are projected from the raw context (no norm_ctx), unlike the
        # keys. Kept exactly as-is so the numerics stay unchanged.
        values = self._split_heads(self.v_proj(ctx))

        attn_bias = None
        if ctx_mask is not None:
            # Broadcast over heads and query positions: (batch, 1, 1, ctx_len).
            attn_bias = (1.0 - ctx_mask.float())[:, None, None, :] * -1e4
            attn_bias = attn_bias.to(queries.dtype)

        attended = F.scaled_dot_product_attention(
            queries, keys, values, attn_mask=attn_bias
        )
        merged = attended.transpose(1, 2).contiguous().view(n_batch, n_latents, width)

        latents = latents + self.o_proj(merged)
        latents = latents + self.ff(self.norm_ff(latents))
        return latents
|
|
|
|
class PerceiverHypernet(nn.Module):
    """Perceiver-style hypernetwork that emits LoRA factors from activations.

    A learned set of ``lora_r`` latent queries cross-attends over projected
    context activations; a per-layer linear head then maps every latent to
    one row of the LoRA ``A`` factor and one column of the ``B`` factor.

    Args:
        ctx_dim: Width of the incoming context activations.
        n_lora_layers: Number of per-layer output heads.
        lora_r: LoRA rank, which is also the number of latent queries.
        target_in: Input width of the adapted module (last dim of ``A``).
        target_out: Output width of the adapted module (first dim of ``B``).
        latent_dim: Width of the latent stream.
        n_blocks: Number of stacked cross-attention blocks.
    """

    def __init__(self, ctx_dim, n_lora_layers, lora_r, target_in, target_out,
                 latent_dim=512, n_blocks=8):
        super().__init__()
        self.n_lora_layers = n_lora_layers
        self.din = target_in
        self.dout = target_out
        head_width = target_in + target_out  # A row and B column, concatenated

        self.ctx_proj = nn.Linear(ctx_dim, latent_dim, bias=False)
        self.ctx_norm = nn.LayerNorm(latent_dim)
        self.latent_q = nn.Parameter(
            torch.randn(lora_r, latent_dim) * latent_dim ** -0.5
        )
        # The context is already projected to latent_dim, so each block uses
        # latent_dim for both its query and its context width.
        self.blocks = nn.ModuleList(
            [CrossAttentionBlock(latent_dim, latent_dim) for _ in range(n_blocks)]
        )
        self.head_w = nn.Parameter(
            torch.randn(n_lora_layers, latent_dim, head_width) * 0.01
        )
        self.head_b = nn.Parameter(torch.zeros(n_lora_layers, head_width))

    def forward(self, ctx_acts, ctx_mask=None):
        """Produce the LoRA factors for every head.

        Args:
            ctx_acts: Tensor of shape ``(batch, ctx_len, ctx_dim)``.
            ctx_mask: Optional ``(batch, ctx_len)`` mask forwarded unchanged
                to every cross-attention block.

        Returns:
            Tuple ``(A, B)`` where ``A`` has shape
            ``(batch, n_lora_layers, lora_r, target_in)`` and ``B`` has shape
            ``(batch, n_lora_layers, target_out, lora_r)``.
        """
        batch_size = ctx_acts.shape[0]
        context = self.ctx_norm(self.ctx_proj(ctx_acts))

        # Share the learned queries across the batch without copying them.
        latents = self.latent_q.unsqueeze(0).expand(batch_size, -1, -1)
        for block in self.blocks:
            latents = block(latents, context, ctx_mask)

        # (batch, rank, latent) x (layers, latent, din + dout)
        #   -> (batch, layers, rank, din + dout)
        fused = torch.einsum("brd,nde->bnre", latents, self.head_w)
        fused = fused + self.head_b[None, :, None, :]

        lora_a = fused[..., :self.din]
        lora_b = fused[..., self.din:].transpose(-1, -2)
        return lora_a, lora_b
|
|
|
|
# Rebuild the hypernetwork from the checkpoint's stored constructor arguments,
# move it to the GPU in bfloat16, then restore the saved weights.
hypernet = PerceiverHypernet(**hcfg)
hypernet = hypernet.to("cuda", dtype=torch.bfloat16)
hypernet.load_state_dict(ckpt["state_dict"])
hypernet.eval()
print("Perceiver loaded OK")
|
|
# Token ids of the reasoning markers as reported by the tokenizer.
_tok_open = tokenizer.convert_tokens_to_ids("<think>")
_tok_close = tokenizer.convert_tokens_to_ids("</think>")
# Keep only ids that are usable markers: drop None and the unknown-token id.
THINK_TOKENS = {_tok_open, _tok_close} - {tokenizer.unk_token_id, None}
|
|
def _strip_think(ids):
    """Decode generated token ids, dropping any ``<think>...</think>`` span.

    Args:
        ids: One-dimensional tensor of token ids (anything exposing
            ``.tolist()`` that yields a flat list of ints).

    Returns:
        The decoded text with special tokens skipped, surrounding whitespace
        stripped, and every token between an opening and a closing marker
        (markers included) removed. An opening marker without a matching
        closing marker removes everything after it.
    """
    toks = ids.tolist()

    # Fast path: no usable marker ids, or none present in this sequence.
    if THINK_TOKENS.isdisjoint(toks):
        return tokenizer.decode(toks, skip_special_tokens=True).strip()

    # Use the explicit marker ids rather than guessing from their numeric
    # order, and ignore a marker the tokenizer does not know (it would be
    # None or alias the unknown-token id).
    open_id = _tok_open if _tok_open in THINK_TOKENS else None
    close_id = _tok_close if _tok_close in THINK_TOKENS else None

    kept = []
    inside = False
    for tok in toks:
        if tok == open_id:
            inside = True
        elif tok == close_id:
            inside = False
        elif not inside:
            kept.append(tok)
    return tokenizer.decode(kept, skip_special_tokens=True).strip()
|
|
def _sorted_mods(model, mod_name):
    """Collect matching ``nn.Linear`` sub-modules ordered by layer index.

    Args:
        model: Module whose ``named_modules()`` are scanned.
        mod_name: Substring that a module's qualified name must contain.

    Returns:
        List of ``(name, module)`` pairs sorted by the first integer found in
        the name. Names without any digits sort first (key ``-1``); ties keep
        their ``named_modules()`` order because the sort is stable.
    """
    def layer_index(entry):
        first_number = re.search(r"\d+", entry[0])
        if first_number is None:
            return -1
        return int(first_number.group())

    matches = [
        (name, module)
        for name, module in model.named_modules()
        if mod_name in name and isinstance(module, nn.Linear)
    ]
    matches.sort(key=layer_index)
    return matches
|
|
# Linear layers of the base model that receive the generated LoRA update,
# ordered by layer index so that entry i pairs with hypernet head i.
target_mods = _sorted_mods(llm, TGT)
if not target_mods:
    # Without any match no hook would ever be registered and the model would
    # silently answer as if the document had never been internalized.
    raise RuntimeError(f"No nn.Linear module matching {TGT!r} found in {BASE!r}")
if len(target_mods) > hcfg["n_lora_layers"]:
    # Each target module indexes its own head; extra modules would otherwise
    # fail later with an opaque IndexError during hook registration.
    raise RuntimeError(
        f"Found {len(target_mods)} modules matching {TGT!r} but the hypernet "
        f"only produces {hcfg['n_lora_layers']} LoRA layers"
    )
# LoRA scaling factor (alpha / rank) applied to every low-rank update.
scale = ALPHA / hcfg["lora_r"]
|
|
@torch.no_grad()
def internalize_and_query(document, query, max_new_tokens=12):
    """Answer ``query`` after converting ``document`` into LoRA weights.

    The document is run through the frozen LLM once; the hidden states at
    index ``EARLY`` are fed to the hypernetwork, which returns per-layer LoRA
    factors. Those factors are applied through temporary forward hooks on the
    target modules while the query alone (the document is not in the prompt)
    is decoded greedily.

    Args:
        document: Text to internalize; tokenized without special tokens.
        query: Question text; ``" /no_think"`` is appended before tokenizing.
        max_new_tokens: Upper bound on generated tokens; decoding stops early
            once the tokenizer's EOS token is produced.

    Returns:
        The decoded continuation with any ``<think>`` span stripped.
    """
    ctx_ids = torch.tensor(
        [tokenizer.encode(document, add_special_tokens=False)], device="cuda"
    )
    ctx_mask = torch.ones_like(ctx_ids)
    qry_ids = torch.tensor(
        [tokenizer.encode(query + " /no_think", add_special_tokens=True)],
        device="cuda",
    )

    # Encode the document with the unmodified base model.
    acts = llm(
        input_ids=ctx_ids,
        attention_mask=ctx_mask,
        output_hidden_states=True,
        use_cache=False,
    ).hidden_states[EARLY]

    # Batch size is 1, so drop the batch axis:
    # lora_a is (layers, r, d_in), lora_b is (layers, d_out, r).
    lora_a, lora_b = hypernet(acts, ctx_mask)
    lora_a, lora_b = lora_a.squeeze(0), lora_b.squeeze(0)

    def _make_hook(a_i, b_i):
        # Bind this layer's factors now; the hook adds scale * (x A^T) B^T to
        # the module's original output.
        def hook(mod, inp, out):
            return out + scale * (inp[0] @ a_i.t()) @ b_i.t()
        return hook

    hooks = []
    try:
        for i, (_, mod) in enumerate(target_mods):
            hooks.append(mod.register_forward_hook(_make_hook(lora_a[i], lora_b[i])))

        # Greedy decoding that re-runs the full sequence each step (no KV cache).
        ids = qry_ids.clone()
        for _ in range(max_new_tokens):
            out = llm(
                input_ids=ids,
                attention_mask=torch.ones_like(ids),
                use_cache=False,
            )
            nxt = out.logits[:, -1, :].argmax(-1, keepdim=True)
            ids = torch.cat([ids, nxt], dim=1)
            if nxt.item() == tokenizer.eos_token_id:
                break
    finally:
        # Always detach the hooks, even if registration or generation raised,
        # so the shared LLM is never left patched with this document's weights.
        for handle in hooks:
            handle.remove()

    return _strip_think(ids[0, qry_ids.shape[1]:])
|
|
if __name__ == "__main__":
    # Smoke test: the answer is only available through the generated LoRA
    # weights, because the document itself is never placed in the prompt.
    doc = "The special magic number is 7341. The sky is blue today."
    ans = internalize_and_query(
        doc,
        "What is the special magic number?",
    )
    print(f"Answer: {ans}")
|
|