""" MBA Expression Simplifier — Standalone Inference ================================================= Usage: python inference.py "(x & y) + (x | y)" from inference import simplify """ import json, re, sys, time from dataclasses import dataclass from pathlib import Path from typing import Union, Optional _HERE = Path(__file__).resolve().parent # ── AST ─────────────────────────────────────────────────────────────────────── @dataclass(frozen=True) class Var: name: str def __repr__(self): return self.name @dataclass(frozen=True) class Const: val: int def __repr__(self): return str(self.val) @dataclass(frozen=True) class BinOp: op: str; left: 'Expr'; right: 'Expr' def __repr__(self): return f"({self.op} {self.left} {self.right})" @dataclass(frozen=True) class UnOp: op: str; arg: 'Expr' def __repr__(self): return f"({self.op} {self.arg})" Expr = Union[Var, Const, BinOp, UnOp] def size(e) -> int: match e: case Var(_) | Const(_): return 1 case BinOp(_, l, r): return 1 + size(l) + size(r) case UnOp(_, a): return 1 + size(a) # ── EVALUATOR ───────────────────────────────────────────────────────────────── def make_ev(bits: int): MASK = (1 << bits) - 1 def ev(e, env: dict) -> int: match e: case Var(n): return env.get(n, 0) & MASK case Const(v): return v & MASK case BinOp('+',l,r): return (ev(l,env) + ev(r,env)) & MASK case BinOp('-',l,r): return (ev(l,env) - ev(r,env)) & MASK case BinOp('*',l,r): return (ev(l,env) * ev(r,env)) & MASK case BinOp('&',l,r): return ev(l,env) & ev(r,env) case BinOp('|',l,r): return ev(l,env) | ev(r,env) case BinOp('^',l,r): return ev(l,env) ^ ev(r,env) case UnOp('~',a): return (~ev(a,env)) & MASK case UnOp('-',a): return (-ev(a,env)) & MASK return ev _ev8 = make_ev(8) # ── BFS FINGERPRINT (must match bfs_table_6.json exactly) ──────────────────── # Fixed 16-point evaluation set from poly_synth_final.py. # Changing these invalidates the pre-computed table. _EVAL_POINTS = [ (0,0),(1,0),(0,1),(1,1), (2,3),(5,7),(11,13),(17,19), (64,128),(128,64),(255,0), (170,85),(204,51),(85,170), (7,15),(31,63), ] def _fp8(expr, v0: str, v1: str) -> tuple: return tuple(_ev8(expr, {v0: px, v1: py}) for px, py in _EVAL_POINTS) # ── PREFIX PARSER (for loading bfs_table_6.json values) ────────────────────── def _parse_prefix(s: str) -> Expr: pos = [0] def skip(): while pos[0] < len(s) and s[pos[0]] == ' ': pos[0] += 1 def parse(): skip() if pos[0] >= len(s): raise ValueError(f"Unexpected end: {s!r}") c = s[pos[0]] if c == '(': pos[0] += 1; skip() op_s = pos[0] while pos[0] < len(s) and s[pos[0]] not in (' ', ')'): pos[0] += 1 op = s[op_s:pos[0]]; skip() left = parse(); skip() if pos[0] < len(s) and s[pos[0]] == ')': pos[0] += 1; return UnOp(op, left) right = parse(); skip() if pos[0] < len(s) and s[pos[0]] == ')': pos[0] += 1 return BinOp(op, left, right) elif c.isdigit(): start = pos[0] while pos[0] < len(s) and s[pos[0]].isdigit(): pos[0] += 1 return Const(int(s[start:pos[0]])) elif c.isalpha() or c == '_': start = pos[0] while pos[0] < len(s) and (s[pos[0]].isalnum() or s[pos[0]] == '_'): pos[0] += 1 return Var(s[start:pos[0]]) else: raise ValueError(f"Unexpected {c!r} in {s!r}") return parse() # ── BFS TABLE LOAD ──────────────────────────────────────────────────────────── _bfs_table: Optional[dict] = None def _load_bfs() -> dict: global _bfs_table if _bfs_table is not None: return _bfs_table path = _HERE / 'bfs_table_6.json' if not path.exists(): print(f"Warning: {path} not found — BFS disabled.", file=sys.stderr) _bfs_table = {}; return _bfs_table t0 = time.perf_counter() with open(path) as f: raw = json.load(f) table = {} for fp_str, expr_str in raw.items(): table[tuple(json.loads(fp_str))] = _parse_prefix(expr_str) _bfs_table = table print(f"BFS: {len(table):,} entries ({(time.perf_counter()-t0)*1000:.0f}ms)", flush=True) return _bfs_table def _rename_vars(expr, var_map: dict) -> Expr: match expr: case Var(name): return Var(var_map.get(name, name)) case Const(_): return expr case BinOp(op,l,r): return BinOp(op, _rename_vars(l,var_map), _rename_vars(r,var_map)) case UnOp(op,a): return UnOp(op, _rename_vars(a,var_map)) # ── ORACLE VERIFY ───────────────────────────────────────────────────────────── def _verify(e1, e2, bits: int, vars_tuple: tuple, n: int = 512) -> bool: import random MASK = (1 << bits) - 1 ev = make_ev(bits) rng = random.Random(7) for _ in range(n): env = {v: rng.randint(0, MASK) for v in vars_tuple} if ev(e1, env) != ev(e2, env): return False return True # ── SYMBOLIC ENGINE (inlined from pipeline_v5.py) ───────────────────────────── def make_symbolic(bits: int): MASK = (1 << bits) - 1 def rw(e): match e: case BinOp('+',BinOp('&',a,b),BinOp('|',c,d)) if {a,b}=={c,d}: return BinOp('+',min(a,b,key=repr),max(a,b,key=repr)) case BinOp('+',BinOp('|',a,b),BinOp('&',c,d)) if {a,b}=={c,d}: return BinOp('+',min(a,b,key=repr),max(a,b,key=repr)) case BinOp('-',BinOp('|',a,b),BinOp('&',c,d)) if {a,b}=={c,d}: return BinOp('^',min(a,b,key=repr),max(a,b,key=repr)) case BinOp('-',BinOp('+',a,b),BinOp('*',Const(2),BinOp('&',c,d))) if {a,b}=={c,d}: return BinOp('^',min(a,b,key=repr),max(a,b,key=repr)) case BinOp('-',BinOp('+',a,b),BinOp('*',BinOp('&',c,d),Const(2))) if {a,b}=={c,d}: return BinOp('^',min(a,b,key=repr),max(a,b,key=repr)) case BinOp('-',UnOp('-',x),Const(1)): return UnOp('~',x) case BinOp('+',UnOp('-',x),Const(v)) if (v&MASK)==MASK: return UnOp('~',x) case BinOp('+',UnOp('~',x),Const(1)): return UnOp('-',x) case BinOp('+',Const(1),UnOp('~',x)): return UnOp('-',x) case BinOp('^',UnOp('~',a),UnOp('~',b)): return BinOp('^',min(a,b,key=repr),max(a,b,key=repr)) case UnOp('~',BinOp('&',UnOp('~',a),UnOp('~',b))): return BinOp('|',min(a,b,key=repr),max(a,b,key=repr)) case UnOp('~',BinOp('|',UnOp('~',a),UnOp('~',b))): return BinOp('&',min(a,b,key=repr),max(a,b,key=repr)) case BinOp('|',a,BinOp('&',b,c)) if a==b or a==c: return a case BinOp('|',BinOp('&',a,b),c) if c==a or c==b: return c case BinOp('&',a,BinOp('|',b,c)) if a==b or a==c: return a case BinOp('&',BinOp('|',a,b),c) if c==a or c==b: return c case BinOp('^',a,BinOp('&',b,c)) if a==b: return BinOp('&',a,UnOp('~',c)) case BinOp('^',a,BinOp('&',b,c)) if a==c: return BinOp('&',a,UnOp('~',b)) case BinOp('^',BinOp('&',a,b),c) if c==a: return BinOp('&',c,UnOp('~',b)) case BinOp('^',BinOp('&',a,b),c) if c==b: return BinOp('&',c,UnOp('~',a)) case UnOp('-',BinOp('-',a,b)): return BinOp('-',b,a) case BinOp('+',a,UnOp('-',b)) if a==b: return Const(0) case BinOp('+',UnOp('-',a),b) if a==b: return Const(0) case BinOp('-',a,UnOp('-',b)): return BinOp('+',a,b) case BinOp('&',a,b) if a==b: return a case BinOp('|',a,b) if a==b: return a case BinOp('^',a,b) if a==b: return Const(0) case BinOp('-',a,b) if a==b: return Const(0) case BinOp('&',_,Const(0))|BinOp('&',Const(0),_): return Const(0) case BinOp('|',x,Const(0))|BinOp('|',Const(0),x): return x case BinOp('^',x,Const(0))|BinOp('^',Const(0),x): return x case BinOp('^',x,Const(v)) if (v&MASK)==MASK: return UnOp('~',x) case BinOp('^',Const(v),x) if (v&MASK)==MASK: return UnOp('~',x) case BinOp('&',x,Const(v)) if (v&MASK)==MASK: return x case BinOp('&',Const(v),x) if (v&MASK)==MASK: return x case BinOp('|',_,Const(v))|BinOp('|',Const(v),_) if (v&MASK)==MASK: return Const(MASK) case BinOp('&',a,UnOp('~',b)) if a==b: return Const(0) case BinOp('&',UnOp('~',a),b) if a==b: return Const(0) case BinOp('|',a,UnOp('~',b)) if a==b: return Const(MASK) case BinOp('|',UnOp('~',a),b) if a==b: return Const(MASK) case BinOp('^',a,UnOp('~',b)) if a==b: return Const(MASK) case BinOp('+',x,Const(0))|BinOp('+',Const(0),x): return x case BinOp('-',x,Const(0)): return x case BinOp('*',x,Const(1))|BinOp('*',Const(1),x): return x case BinOp('*',_,Const(0))|BinOp('*',Const(0),_): return Const(0) case UnOp('-',UnOp('-',x)): return x case UnOp('~',UnOp('~',x)): return x case UnOp('-',Const(0)): return Const(0) case UnOp('-',UnOp('~',x)): return BinOp('+',x,Const(1&MASK)) case UnOp('~',UnOp('-',x)): return BinOp('-',x,Const(1&MASK)) case BinOp('-',BinOp('+',a,b),c) if b==c: return a case BinOp('-',BinOp('+',a,b),c) if a==c: return b case BinOp('+',BinOp('-',a,b),c) if b==c: return a case BinOp('-',a,BinOp('+',b,c)) if a==b: return UnOp('-',c) case BinOp('-',a,BinOp('+',b,c)) if a==c: return UnOp('-',b) case BinOp('-',BinOp('+',a,b),BinOp('+',c,d)) if a==c: return BinOp('-',b,d) case BinOp('-',BinOp('+',a,b),BinOp('+',c,d)) if b==d: return BinOp('-',a,c) case BinOp('-',BinOp('+',a,b),BinOp('+',c,d)) if a==d: return BinOp('-',b,c) case BinOp('-',BinOp('+',a,b),BinOp('+',c,d)) if b==c: return BinOp('-',a,d) case BinOp(op,Const(a),Const(b)): match op: case '+': return Const((a+b)&MASK) case '-': return Const((a-b)&MASK) case '*': return Const((a*b)&MASK) case '&': return Const(a&b&MASK) case '|': return Const((a|b)&MASK) case '^': return Const((a^b)&MASK) case UnOp('-',Const(a)): return Const((-a)&MASK) case UnOp('~',Const(a)): return Const((~a)&MASK) case BinOp(op,a,b) if op in('+','*','&','|','^') and repr(a)>repr(b): return BinOp(op,b,a) return None def sym_pass(e): match e: case Var(_)|Const(_): return e case BinOp(op,l,r): nl=sym_pass(l); nr=sym_pass(r); ne=BinOp(op,nl,nr) r2=rw(ne) if r2 is None: return ne return Const(r2.val&MASK) if isinstance(r2,Const) else r2 case UnOp(op,a): na=sym_pass(a); ne=UnOp(op,na) r2=rw(ne) if r2 is None: return ne return Const(r2.val&MASK) if isinstance(r2,Const) else r2 def sym(e, passes=30): for _ in range(passes): ne=sym_pass(e) if ne==e: break e=ne return e return sym, rw, sym_pass # ── NEURAL LAYER (optional) ─────────────────────────────────────────────────── class NeuralLayer: def __init__(self, model_path: str, rw_fn, bits: int): self.rw_fn=rw_fn; self.bits=bits; self.loaded=False try: import torch, torch.nn as nn self.torch=torch; self.nn=nn # Load config/vocab from JSON sidecars if present, else extract from .pt cfg_path = _HERE / "config.json" rv_path = _HERE / "rule_vocab.json" cv_path = _HERE / "char_vocab.json" sf_path = _HERE / "mba_classifier.safetensors" if cfg_path.exists() and rv_path.exists() and cv_path.exists(): cfg = json.load(open(cfg_path)) self.rules = json.load(open(rv_path)) self.char2id = json.load(open(cv_path)) else: ckpt = torch.load(model_path, map_location='cpu', weights_only=False) cfg = ckpt['config'] self.rules = ckpt['rule_vocab'] self.char2id = ckpt['char_vocab'] self.n_rules = len(self.rules) self.max_len = cfg['max_len'] self.model = self._build(cfg) # Load weights: prefer .safetensors, fall back to .pt if sf_path.exists(): from safetensors.torch import load_file self.model.load_state_dict(load_file(str(sf_path))) else: ckpt = torch.load(model_path, map_location='cpu', weights_only=False) self.model.load_state_dict(ckpt['model_state']) self.model.eval() self.reduction_rules=[i for i,r in enumerate(self.rules) if not r.startswith('comm_') and not r.startswith('fold_')] self.loaded=True n_params=sum(p.numel() for p in self.model.parameters()) print(f"Neural: {self.n_rules} rules, {n_params/1e6:.1f}M params", flush=True) except Exception as ex: print(f"Neural disabled: {ex}", file=sys.stderr) def _build(self, cfg): import torch.nn as nn; torch=self.torch class RC(nn.Module): def __init__(s,vocab_size,d_model,n_heads,n_layers,d_ff,max_len,n_rules,dropout=0.1): super().__init__() s.embed = nn.Embedding(vocab_size,d_model,padding_idx=0) s.pos_embed = nn.Embedding(max_len,d_model) el = nn.TransformerEncoderLayer(d_model,n_heads,d_ff,dropout,batch_first=True,norm_first=True) s.encoder = nn.TransformerEncoder(el,n_layers) s.norm = nn.LayerNorm(d_model) s.head = nn.Linear(d_model,n_rules) def forward(s,x): pm=(x==0); pos=torch.arange(x.size(1),device=x.device).unsqueeze(0) e=s.embed(x)+s.pos_embed(pos); enc=s.encoder(e,src_key_padding_mask=pm) L=(~pm).float().sum(1,keepdim=True).clamp(min=1) return s.head(s.norm((enc*(~pm).unsqueeze(-1).float()).sum(1)/L)) RC.torch=self.torch return RC(**cfg) def _tok(self, expr_str): ids=[self.char2id.get(c,1) for c in expr_str[:self.max_len]] ids+=[0]*(self.max_len-len(ids)) return self.torch.tensor([ids],dtype=self.torch.long) def search(self, expr, max_steps=50, top_k=5): if not self.loaded: return expr best=expr; bsz=size(expr); cur=expr; vis={repr(cur)} for _ in range(max_steps): try: rules=self._top_rules(repr(cur),top_k) except Exception: break improved=False for rid in rules: for cand in self._apply_everywhere(cur,self.rules[rid]): k=repr(cand) if k in vis: continue if size(cand) list: tokens=[]; i=0 while i Expr: MASK=(1< str: def emit(node,mp=0,ir=False): match node: case Var(n): return n case Const(v): return str(v) case UnOp(op,a): return f"{op}{emit(a,6)}" case BinOp(op,l,r): p=_PREC[op]; s=f"{emit(l,p,False)} {op} {emit(r,p,True)}" return f"({s})" if p tuple: seen=[] def w(n): match n: case Var(name): if name not in seen: seen.append(name) case BinOp(_,l,r): w(l); w(r) case UnOp(_,a): w(a) w(expr); return tuple(sorted(seen)) # ── GLOBAL STATE (lazy) ─────────────────────────────────────────────────────── _neural: Optional[NeuralLayer] = None _sym_cache: dict = {} def _get_sym(bits): if bits not in _sym_cache: _sym_cache[bits]=make_symbolic(bits) return _sym_cache[bits] def _get_neural(model_path: Optional[str]) -> Optional[NeuralLayer]: global _neural if _neural is not None: return _neural if model_path is None: sf = _HERE / 'mba_classifier.safetensors' pt = _HERE / 'mba_classifier_v2.pt' if sf.exists(): model_path=str(sf) elif pt.exists(): model_path=str(pt) else: return None _,rw,_=_get_sym(8) _neural=NeuralLayer(str(model_path), rw, bits=8) return _neural # ── PUBLIC API ──────────────────────────────────────────────────────────────── def simplify( expr: str, bits: int = 8, model_path: Optional[str] = None, ) -> dict: """ Simplify an MBA expression string. Args: expr: Infix expression. Operators: + - * & | ^ ~ (unary: - ~) Variables: identifiers [a-zA-Z_]+. Example: "(x & y) + (x | y)" bits: Bit width — 8 or 32. Default 8. model_path: Path to mba_classifier_v2.pt. Auto-detected if omitted. Returns: dict: simplified (str) — simplified infix expression original (str) — input reduction (float) — size reduction ratio 0.0–1.0 strategy (str) — bfs / symbolic / neural / none latency_ms (float) — wall-clock time Raises: ValueError: on parse error or unsupported bit width. """ if not isinstance(expr, str) or not expr.strip(): raise ValueError("expr must be a non-empty string") if bits not in (8, 32): raise ValueError("bits must be 8 or 32") t0=time.perf_counter() tokens=_tokenize(expr.strip()) ast=_parse_infix(tokens, bits) vars_tuple=_extract_vars(ast) if not vars_tuple: vars_tuple=('x',) orig_size=size(ast); strategy="none"; result=ast # 1. BFS (8-bit, ≤2 vars) if bits==8 and len(vars_tuple)<=2: table=_load_bfs() if table: v0=vars_tuple[0] v1=vars_tuple[1] if len(vars_tuple)>1 else '__' fp=_fp8(ast, v0, v1) if fp in table: bfs_e=table[fp] vm={} if len(vars_tuple)>=1: vm['x']=vars_tuple[0] if len(vars_tuple)>=2: vm['y']=vars_tuple[1] bfs_e=_rename_vars(bfs_e, vm) if size(bfs_e)<=orig_size and _verify(ast, bfs_e, bits, vars_tuple): result=bfs_e; strategy="bfs" # 2. Symbolic if strategy=="none": sym,_,_=_get_sym(bits) sr=sym(ast) if sr!=ast: result=sr; strategy="symbolic" # 3. Neural (optional) if strategy=="none": neural=_get_neural(model_path) if neural and neural.loaded: nr=neural.search(ast) if nr is not ast and size(nr)<=orig_size and _verify(ast, nr, bits, vars_tuple): result=nr; strategy="neural" out_size=size(result) reduction=1.0-out_size/orig_size if orig_size>0 else 0.0 ms=round((time.perf_counter()-t0)*1e3, 1) return { "simplified": _to_infix(result), "original": expr.strip(), "reduction": round(reduction,3), "strategy": strategy, "latency_ms": ms, } # ── CLI ─────────────────────────────────────────────────────────────────────── if __name__=="__main__": import argparse p=argparse.ArgumentParser(description="MBA Expression Simplifier") p.add_argument("expr", help='e.g. "(x & y) + (x | y)"') p.add_argument("--bits",type=int,default=8,choices=[8,32]) p.add_argument("--model",type=str,default=None) args=p.parse_args() try: r=simplify(args.expr,bits=args.bits,model_path=args.model) print(f"Input: {r['original']}") print(f"Simplified: {r['simplified']}") if r['original']!=r['simplified']: print(f"Reduction: {r['reduction']*100:.0f}% ({r['strategy']})") else: print(f"Result: already minimal ({r['strategy']})") print(f"Latency: {r['latency_ms']} ms") except ValueError as e: print(f"Error: {e}",file=sys.stderr); sys.exit(1)