#!/usr/bin/env python3
"""
Minimal inference script for the Doc-to-LoRA Perceiver.

The script downloads a trained Perceiver hypernetwork checkpoint, runs a
document through the frozen base LLM to collect hidden states at layer
``EARLY``, asks the hypernetwork to produce per-layer LoRA factors (A, B), and
then answers a query greedily with those factors injected into the target
``nn.Linear`` modules through forward hooks.

Requirements:
    pip install transformers>=4.51.0 huggingface_hub torch
"""
import re

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import hf_hub_download, login

REPO_ID = "farpluto/doc-to-lora-niah"  # filled in automatically at packaging time
HF_TOKEN = None  # set your token here if the base model is gated

if HF_TOKEN:
    login(token=HF_TOKEN)

# SECURITY: weights_only=False unpickles arbitrary Python objects stored in the
# checkpoint. Only point REPO_ID at a repository you trust.
ckpt = torch.load(
    hf_hub_download(REPO_ID, "hypernet.pt", token=HF_TOKEN),
    map_location="cuda",
    weights_only=False,
)
hcfg = ckpt["hypernet_cfg"]  # keyword arguments for PerceiverHypernet(...)
BASE = ckpt["base_model"]  # model id passed to from_pretrained below
TGT = ckpt["target_module"]  # substring selecting the nn.Linear modules to adapt
ALPHA = ckpt["lora_alpha"]
EARLY = ckpt["early_exit"]  # index into hidden_states used as hypernet context

# NOTE: trust_remote_code=True executes Python code shipped with the BASE repo.
tokenizer = AutoTokenizer.from_pretrained(BASE, token=HF_TOKEN, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

llm = AutoModelForCausalLM.from_pretrained(
    BASE,
    token=HF_TOKEN,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    attn_implementation="sdpa",
    trust_remote_code=True,
)
llm.eval()
# The base model stays frozen; only the hook-injected LoRA deltas change behavior.
for p in llm.parameters():
    p.requires_grad_(False)


class CrossAttentionBlock(nn.Module):
    """Pre-norm block: latents cross-attend to the context, then a feed-forward MLP.

    Attribute names must not change: they are the ``state_dict`` keys of the
    trained checkpoint loaded below.
    """

    def __init__(self, latent_dim, ctx_dim, n_heads=8):
        super().__init__()
        self.n_heads, self.head_dim = n_heads, latent_dim // n_heads
        self.norm_q = nn.LayerNorm(latent_dim)
        self.norm_ctx = nn.LayerNorm(ctx_dim)
        self.q_proj = nn.Linear(latent_dim, latent_dim, bias=False)
        self.k_proj = nn.Linear(ctx_dim, latent_dim, bias=False)
        self.v_proj = nn.Linear(ctx_dim, latent_dim, bias=False)
        self.o_proj = nn.Linear(latent_dim, latent_dim, bias=False)
        self.norm_ff = nn.LayerNorm(latent_dim)
        self.ff = nn.Sequential(
            nn.Linear(latent_dim, latent_dim * 4, bias=False),
            nn.GELU(),
            nn.Linear(latent_dim * 4, latent_dim, bias=False),
        )

    def forward(self, latents, ctx, ctx_mask=None):
        """Update ``latents`` (B, L, D) by attending over ``ctx`` (B, S, C).

        ``ctx_mask`` (B, S), if given, holds 1 for positions to attend to and 0
        for positions to ignore.
        """
        B, L, D = latents.shape
        _, S, _ = ctx.shape
        H, Dh = self.n_heads, self.head_dim
        q = self.q_proj(self.norm_q(latents)).view(B, L, H, Dh).transpose(1, 2)
        k = self.k_proj(self.norm_ctx(ctx)).view(B, S, H, Dh).transpose(1, 2)
        # NOTE(review): values use the raw ctx while keys use norm_ctx(ctx).
        # Presumably this mirrors the training code; do not change without retraining.
        v = self.v_proj(ctx).view(B, S, H, Dh).transpose(1, 2)
        bias = None
        if ctx_mask is not None:
            # Additive mask broadcast to (B, 1, 1, S): 0 where mask==1, -1e4 where mask==0.
            bias = (1.0 - ctx_mask.float()).unsqueeze(1).unsqueeze(2) * -1e4
            bias = bias.to(q.dtype)
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
        latents = latents + self.o_proj(out.transpose(1, 2).contiguous().view(B, L, D))
        latents = latents + self.ff(self.norm_ff(latents))
        return latents


class PerceiverHypernet(nn.Module):
    """Map context activations to per-layer LoRA factors.

    ``lora_r`` learned latent queries attend over the projected context through
    ``n_blocks`` cross-attention blocks. A per-layer linear head then turns
    latent ``r`` into row ``r`` of A and column ``r`` of B for every layer.
    Attribute names are ``state_dict`` keys and must not change.
    """

    def __init__(self, ctx_dim, n_lora_layers, lora_r, target_in, target_out,
                 latent_dim=512, n_blocks=8):
        super().__init__()
        self.n_lora_layers = n_lora_layers
        self.din = target_in
        self.dout = target_out
        d = target_in + target_out
        self.ctx_proj = nn.Linear(ctx_dim, latent_dim, bias=False)
        self.ctx_norm = nn.LayerNorm(latent_dim)
        self.latent_q = nn.Parameter(torch.randn(lora_r, latent_dim) * latent_dim ** -0.5)
        self.blocks = nn.ModuleList(
            [CrossAttentionBlock(latent_dim, latent_dim) for _ in range(n_blocks)]
        )
        self.head_w = nn.Parameter(torch.randn(n_lora_layers, latent_dim, d) * 0.01)
        self.head_b = nn.Parameter(torch.zeros(n_lora_layers, d))

    def forward(self, ctx_acts, ctx_mask=None):
        """Return ``(A, B)`` shaped (batch, n_lora_layers, r, din) and
        (batch, n_lora_layers, dout, r)."""
        batch = ctx_acts.shape[0]
        ctx = self.ctx_norm(self.ctx_proj(ctx_acts))
        lat = self.latent_q.unsqueeze(0).expand(batch, -1, -1)
        for blk in self.blocks:
            lat = blk(lat, ctx, ctx_mask)
        # (batch, r, latent) x (layers, latent, din+dout) -> (batch, layers, r, din+dout)
        flat = torch.einsum("brd,nde->bnre", lat, self.head_w)
        flat = flat + self.head_b.unsqueeze(0).unsqueeze(2)
        return flat[..., :self.din], flat[..., self.din:].transpose(-1, -2)


hypernet = PerceiverHypernet(**hcfg).to("cuda", dtype=torch.bfloat16)
hypernet.load_state_dict(ckpt["state_dict"])
hypernet.eval()
print("Perceiver loaded OK")


def _token_id(token):
    """Return the vocab id of ``token``, or None if the tokenizer does not know it."""
    tid = tokenizer.convert_tokens_to_ids(token)
    return None if tid in (tokenizer.unk_token_id, None) else tid


# Reasoning-span delimiters (Qwen3-style, matching the " /no_think" query
# suffix below). Either is None when the base tokenizer has no such token, in
# which case nothing is stripped for it.
_tok_open = _token_id("<think>")
_tok_close = _token_id("</think>")
THINK_TOKENS = {t for t in (_tok_open, _tok_close) if t is not None}


def _strip_think(ids):
    """Decode the 1-D tensor ``ids``, dropping tokens between ``<think>`` and ``</think>``.

    The delimiter tokens themselves are always dropped. Open and close are
    identified explicitly (not by id order), so a tokenizer that defines only
    one of them is handled correctly.
    """
    clean, inside = [], False
    for t in ids.tolist():
        if t == _tok_open:  # int == None is False, so a missing tag never matches
            inside = True
        elif t == _tok_close:
            inside = False
        elif not inside:
            clean.append(t)
    return tokenizer.decode(clean, skip_special_tokens=True).strip()


def _sorted_mods(model, mod_name):
    """Return ``(name, module)`` for each nn.Linear whose name contains ``mod_name``.

    The result is sorted by the first integer in the module name. That is
    presumably the layer index in names like ``model.layers.3.mlp.down_proj``.
    Names with no digits sort first.
    """
    mods = [(n, m) for n, m in model.named_modules()
            if mod_name in n and isinstance(m, nn.Linear)]

    def _idx(name):
        nums = re.findall(r"\d+", name)
        return int(nums[0]) if nums else -1

    return sorted(mods, key=lambda x: _idx(x[0]))


def _stop_token_ids():
    """Collect the token ids that end greedy generation.

    These are the tokenizer's EOS id plus any EOS ids in the model's generation
    config, which may be an int or a list.
    """
    stop = {tokenizer.eos_token_id}
    gen_cfg = getattr(llm, "generation_config", None)
    gen_eos = getattr(gen_cfg, "eos_token_id", None)
    if isinstance(gen_eos, int):
        stop.add(gen_eos)
    elif gen_eos is not None:
        stop.update(gen_eos)
    stop.discard(None)
    return stop


target_mods = _sorted_mods(llm, TGT)
# Hypernet layer i is applied to target_mods[i]. A count mismatch would either
# crash mid-generation or silently adapt only some of the layers.
if len(target_mods) != hcfg["n_lora_layers"]:
    raise RuntimeError(
        f"found {len(target_mods)} nn.Linear modules matching {TGT!r} in {BASE!r}, "
        f"but the hypernetwork is configured for {hcfg['n_lora_layers']} layers"
    )
scale = ALPHA / hcfg["lora_r"]  # standard LoRA scaling alpha / r
STOP_IDS = _stop_token_ids()


@torch.no_grad()
def internalize_and_query(document, query, max_new_tokens=12):
    """Turn ``document`` into LoRA weights, then greedily answer ``query``.

    Args:
        document: Text to internalize; must tokenize to at least one token.
        query: Question for the adapted model (" /no_think" is appended).
        max_new_tokens: Maximum number of tokens generated greedily.

    Returns:
        The decoded answer with any ``<think>...</think>`` span removed.

    Raises:
        ValueError: If ``document`` produces no tokens.
    """
    doc_tokens = tokenizer.encode(document, add_special_tokens=False)
    if not doc_tokens:
        raise ValueError("document must contain at least one token")
    ctx_ids = torch.tensor([doc_tokens], device="cuda")
    ctx_mask = torch.ones_like(ctx_ids)
    qry_ids = torch.tensor(
        [tokenizer.encode(query + " /no_think", add_special_tokens=True)], device="cuda"
    )

    # 1) Hidden states of the un-adapted model over the document feed the hypernet.
    acts = llm(input_ids=ctx_ids, attention_mask=ctx_mask,
               output_hidden_states=True, use_cache=False).hidden_states[EARLY]
    # 2) Per-layer LoRA factors; drop the batch dimension (batch size is 1 here).
    A, B = hypernet(acts, ctx_mask)
    A, B = A.squeeze(0), B.squeeze(0)

    def _mkhook(Ai, Bi):
        # Factory binds this layer's factors (avoids late-binding closures).
        def h(mod, inp, out):
            # LoRA delta: out + scale * x A^T B^T ; the returned value replaces out.
            return out + scale * (inp[0] @ Ai.t()) @ Bi.t()
        return h

    hooks = []
    try:
        for i, (_, mod) in enumerate(target_mods):
            hooks.append(mod.register_forward_hook(_mkhook(A[i], B[i])))
        # 3) Greedy decoding without a KV cache (full re-forward each step).
        ids = qry_ids.clone()
        for _ in range(max_new_tokens):
            out = llm(input_ids=ids, attention_mask=torch.ones_like(ids), use_cache=False)
            nxt = out.logits[:, -1, :].argmax(-1, keepdim=True)
            ids = torch.cat([ids, nxt], dim=1)
            if nxt.item() in STOP_IDS:
                break
    finally:
        # Always detach the adapters, even on error, so `llm` reverts to the base model.
        for h in hooks:
            h.remove()
    return _strip_think(ids[0, qry_ids.shape[1]:])


if __name__ == "__main__":
    doc = "The special magic number is 7341. The sky is blue today."
    ans = internalize_and_query(doc, "What is the special magic number?")
    print(f"Answer: {ans}")