"""
MBA Expression Simplifier — Standalone Inference
=================================================
Usage:
    python inference.py "(x & y) + (x | y)"    # command line
    from inference import simplify             # as a library
"""
|
|
| import json, re, sys, time |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Union, Optional |
|
|
# Directory containing this file; the model, vocabulary and BFS-table files
# are looked up relative to it.
_HERE = Path(__file__).resolve().parents[0]
|
|
| |
|
|
| @dataclass(frozen=True) |
| class Var: |
| name: str |
| def __repr__(self): return self.name |
|
|
| @dataclass(frozen=True) |
| class Const: |
| val: int |
| def __repr__(self): return str(self.val) |
|
|
| @dataclass(frozen=True) |
| class BinOp: |
| op: str; left: 'Expr'; right: 'Expr' |
| def __repr__(self): return f"({self.op} {self.left} {self.right})" |
|
|
| @dataclass(frozen=True) |
| class UnOp: |
| op: str; arg: 'Expr' |
| def __repr__(self): return f"({self.op} {self.arg})" |
|
|
# Any node of an expression tree.
Expr = Union[
    Var,
    Const,
    BinOp,
    UnOp,
]
|
|
| def size(e) -> int: |
| match e: |
| case Var(_) | Const(_): return 1 |
| case BinOp(_, l, r): return 1 + size(l) + size(r) |
| case UnOp(_, a): return 1 + size(a) |
|
|
| |
|
|
def make_ev(bits: int):
    """Return ``ev(e, env)``, an evaluator for ``bits``-bit modular arithmetic.

    Variables are read from ``env`` (an unbound name reads as 0) and every
    leaf is reduced mod 2**bits. ``+ - *`` and unary ``- ~`` wrap their result
    mod 2**bits; ``& | ^`` need no wrap because their operands are already in
    range. An unsupported node or operator evaluates to None.
    """
    MASK = (1 << bits) - 1
    binary = {
        '+': lambda a, b: (a + b) & MASK,
        '-': lambda a, b: (a - b) & MASK,
        '*': lambda a, b: (a * b) & MASK,
        '&': lambda a, b: a & b,
        '|': lambda a, b: a | b,
        '^': lambda a, b: a ^ b,
    }
    unary = {
        '~': lambda a: (~a) & MASK,
        '-': lambda a: (-a) & MASK,
    }

    def ev(e, env: dict) -> int:
        if isinstance(e, Var):
            return env.get(e.name, 0) & MASK
        if isinstance(e, Const):
            return e.val & MASK
        if isinstance(e, BinOp):
            apply = binary.get(e.op)
            if apply is None:
                return None
            return apply(ev(e.left, env), ev(e.right, env))
        if isinstance(e, UnOp):
            apply = unary.get(e.op)
            if apply is None:
                return None
            return apply(ev(e.arg, env))
        return None

    return ev
|
|
# Shared 8-bit evaluator used to compute BFS-table fingerprints.
_ev8 = make_ev(bits=8)
|
|
| |
| |
| |
|
|
| _EVAL_POINTS = [ |
| (0,0),(1,0),(0,1),(1,1), |
| (2,3),(5,7),(11,13),(17,19), |
| (64,128),(128,64),(255,0), |
| (170,85),(204,51),(85,170), |
| (7,15),(31,63), |
| ] |
|
|
def _fp8(expr, v0: str, v1: str) -> tuple:
    """Fingerprint ``expr`` as its 8-bit values at every point of _EVAL_POINTS.

    ``v0`` is bound to each point's first coordinate and ``v1`` to its second.
    The result is a tuple so it can be used as a BFS-table key.
    """
    values = []
    for first, second in _EVAL_POINTS:
        values.append(_ev8(expr, {v0: first, v1: second}))
    return tuple(values)
|
|
| |
|
|
def _parse_prefix(s: str) -> Expr:
    """Parse a prefix S-expression such as ``(+ (& x y) 1)`` into an Expr.

    ``(op a)`` is a unary node and ``(op a b)`` a binary one. A run of digits
    is a Const, and a letter or ``_`` followed by letters, digits or ``_`` is
    a Var. Only the space character counts as whitespace. The parser is
    lenient: a binary node's closing ``)`` is optional, and text after the
    first complete expression is ignored.

    Raises:
        ValueError: on premature end of input or an unexpected character.
    """
    i = 0
    end = len(s)

    def skip_spaces():
        nonlocal i
        while i < end and s[i] == ' ':
            i += 1

    def take_while(pred) -> str:
        # Consume the longest run of characters satisfying `pred`.
        nonlocal i
        start = i
        while i < end and pred(s[i]):
            i += 1
        return s[start:i]

    def node():
        nonlocal i
        skip_spaces()
        if i >= end:
            raise ValueError(f"Unexpected end: {s!r}")
        ch = s[i]
        if ch.isdigit():
            return Const(int(take_while(str.isdigit)))
        if ch.isalpha() or ch == '_':
            return Var(take_while(lambda d: d.isalnum() or d == '_'))
        if ch != '(':
            raise ValueError(f"Unexpected {ch!r} in {s!r}")
        i += 1
        skip_spaces()
        op = take_while(lambda d: d not in (' ', ')'))
        skip_spaces()
        first = node()
        skip_spaces()
        if i < end and s[i] == ')':
            i += 1
            return UnOp(op, first)
        second = node()
        skip_spaces()
        if i < end and s[i] == ')':
            i += 1
        return BinOp(op, first, second)

    return node()
|
|
| |
|
|
# Cache filled by _load_bfs(): None means "not loaded yet"; an empty dict
# means the table file was missing and BFS lookup is disabled.
_bfs_table: Optional[dict] = None
|
|
def _load_bfs() -> dict:
    """Return the fingerprint -> expression BFS table, loading it on first use.

    The table is read from ``bfs_table_6.json`` next to this file. Each key is
    JSON text that is decoded and turned into a tuple, so it can be looked up
    with `_fp8` results. Each value is a prefix expression parsed with
    `_parse_prefix`. The result is cached in the module-level ``_bfs_table``.
    If the file is missing, BFS is disabled: an empty dict is cached and
    returned. Diagnostics go to stderr so library callers' stdout stays clean.
    """
    global _bfs_table
    if _bfs_table is not None:
        return _bfs_table
    path = _HERE / 'bfs_table_6.json'
    if not path.exists():
        print(f"Warning: {path} not found - BFS disabled.", file=sys.stderr)
        _bfs_table = {}
        return _bfs_table
    t0 = time.perf_counter()
    # JSON is UTF-8 by specification; don't depend on the platform's locale encoding.
    with open(path, encoding="utf-8") as f:
        raw = json.load(f)
    table = {
        tuple(json.loads(fp_str)): _parse_prefix(expr_str)
        for fp_str, expr_str in raw.items()
    }
    _bfs_table = table
    elapsed_ms = (time.perf_counter() - t0) * 1000
    print(f"BFS: {len(table):,} entries ({elapsed_ms:.0f}ms)", file=sys.stderr, flush=True)
    return _bfs_table
|
|
def _rename_vars(expr, var_map: dict) -> Expr:
    """Return a copy of ``expr`` with every Var renamed through ``var_map``.

    Names missing from ``var_map`` are kept. The map is applied to the
    original names in a single pass, so ``{'x': 'y', 'y': 'x'}`` swaps them.
    Unsupported nodes yield None.
    """
    if isinstance(expr, Var):
        return Var(var_map.get(expr.name, expr.name))
    if isinstance(expr, Const):
        return expr
    if isinstance(expr, BinOp):
        lhs = _rename_vars(expr.left, var_map)
        rhs = _rename_vars(expr.right, var_map)
        return BinOp(expr.op, lhs, rhs)
    if isinstance(expr, UnOp):
        return UnOp(expr.op, _rename_vars(expr.arg, var_map))
    return None
|
|
| |
|
|
def _verify(e1, e2, bits: int, vars_tuple: tuple, n: int = 512) -> bool:
    """Randomly test whether ``e1`` and ``e2`` agree in ``bits``-bit arithmetic.

    Both expressions are evaluated on ``n`` random assignments from a fixed
    seed, so the answer is deterministic. Returns False on the first
    mismatch. Passing is strong evidence of equivalence, not a proof.

    Args:
        e1, e2: expression trees to compare.
        bits: word size; values are drawn uniformly from [0, 2**bits).
        vars_tuple: variables to randomize first. Any other variable occurring
            in ``e1`` or ``e2`` is randomized as well, after these.
        n: number of random trials.
    """
    import random

    mask = (1 << bits) - 1
    ev = make_ev(bits)
    rng = random.Random(7)
    # make_ev reads unbound variables as 0. A variable occurring only in a
    # candidate (not in vars_tuple) would be pinned to 0 and could let a
    # non-equivalent candidate pass, so sample every variable that occurs.
    # vars_tuple keeps its position first, so the common case draws exactly
    # the same random values as before.
    names = tuple(dict.fromkeys((*vars_tuple, *_extract_vars(e1), *_extract_vars(e2))))
    for _ in range(n):
        env = {v: rng.randint(0, mask) for v in names}
        if ev(e1, env) != ev(e2, env):
            return False
    return True
|
|
| |
|
|
| def make_symbolic(bits: int): |
| MASK = (1 << bits) - 1 |
|
|
| def rw(e): |
| match e: |
| case BinOp('+',BinOp('&',a,b),BinOp('|',c,d)) if {a,b}=={c,d}: |
| return BinOp('+',min(a,b,key=repr),max(a,b,key=repr)) |
| case BinOp('+',BinOp('|',a,b),BinOp('&',c,d)) if {a,b}=={c,d}: |
| return BinOp('+',min(a,b,key=repr),max(a,b,key=repr)) |
| case BinOp('-',BinOp('|',a,b),BinOp('&',c,d)) if {a,b}=={c,d}: |
| return BinOp('^',min(a,b,key=repr),max(a,b,key=repr)) |
| case BinOp('-',BinOp('+',a,b),BinOp('*',Const(2),BinOp('&',c,d))) if {a,b}=={c,d}: |
| return BinOp('^',min(a,b,key=repr),max(a,b,key=repr)) |
| case BinOp('-',BinOp('+',a,b),BinOp('*',BinOp('&',c,d),Const(2))) if {a,b}=={c,d}: |
| return BinOp('^',min(a,b,key=repr),max(a,b,key=repr)) |
| case BinOp('-',UnOp('-',x),Const(1)): return UnOp('~',x) |
| case BinOp('+',UnOp('-',x),Const(v)) if (v&MASK)==MASK: return UnOp('~',x) |
| case BinOp('+',UnOp('~',x),Const(1)): return UnOp('-',x) |
| case BinOp('+',Const(1),UnOp('~',x)): return UnOp('-',x) |
| case BinOp('^',UnOp('~',a),UnOp('~',b)): |
| return BinOp('^',min(a,b,key=repr),max(a,b,key=repr)) |
| case UnOp('~',BinOp('&',UnOp('~',a),UnOp('~',b))): |
| return BinOp('|',min(a,b,key=repr),max(a,b,key=repr)) |
| case UnOp('~',BinOp('|',UnOp('~',a),UnOp('~',b))): |
| return BinOp('&',min(a,b,key=repr),max(a,b,key=repr)) |
| case BinOp('|',a,BinOp('&',b,c)) if a==b or a==c: return a |
| case BinOp('|',BinOp('&',a,b),c) if c==a or c==b: return c |
| case BinOp('&',a,BinOp('|',b,c)) if a==b or a==c: return a |
| case BinOp('&',BinOp('|',a,b),c) if c==a or c==b: return c |
| case BinOp('^',a,BinOp('&',b,c)) if a==b: return BinOp('&',a,UnOp('~',c)) |
| case BinOp('^',a,BinOp('&',b,c)) if a==c: return BinOp('&',a,UnOp('~',b)) |
| case BinOp('^',BinOp('&',a,b),c) if c==a: return BinOp('&',c,UnOp('~',b)) |
| case BinOp('^',BinOp('&',a,b),c) if c==b: return BinOp('&',c,UnOp('~',a)) |
| case UnOp('-',BinOp('-',a,b)): return BinOp('-',b,a) |
| case BinOp('+',a,UnOp('-',b)) if a==b: return Const(0) |
| case BinOp('+',UnOp('-',a),b) if a==b: return Const(0) |
| case BinOp('-',a,UnOp('-',b)): return BinOp('+',a,b) |
| case BinOp('&',a,b) if a==b: return a |
| case BinOp('|',a,b) if a==b: return a |
| case BinOp('^',a,b) if a==b: return Const(0) |
| case BinOp('-',a,b) if a==b: return Const(0) |
| case BinOp('&',_,Const(0))|BinOp('&',Const(0),_): return Const(0) |
| case BinOp('|',x,Const(0))|BinOp('|',Const(0),x): return x |
| case BinOp('^',x,Const(0))|BinOp('^',Const(0),x): return x |
| case BinOp('^',x,Const(v)) if (v&MASK)==MASK: return UnOp('~',x) |
| case BinOp('^',Const(v),x) if (v&MASK)==MASK: return UnOp('~',x) |
| case BinOp('&',x,Const(v)) if (v&MASK)==MASK: return x |
| case BinOp('&',Const(v),x) if (v&MASK)==MASK: return x |
| case BinOp('|',_,Const(v))|BinOp('|',Const(v),_) if (v&MASK)==MASK: return Const(MASK) |
| case BinOp('&',a,UnOp('~',b)) if a==b: return Const(0) |
| case BinOp('&',UnOp('~',a),b) if a==b: return Const(0) |
| case BinOp('|',a,UnOp('~',b)) if a==b: return Const(MASK) |
| case BinOp('|',UnOp('~',a),b) if a==b: return Const(MASK) |
| case BinOp('^',a,UnOp('~',b)) if a==b: return Const(MASK) |
| case BinOp('+',x,Const(0))|BinOp('+',Const(0),x): return x |
| case BinOp('-',x,Const(0)): return x |
| case BinOp('*',x,Const(1))|BinOp('*',Const(1),x): return x |
| case BinOp('*',_,Const(0))|BinOp('*',Const(0),_): return Const(0) |
| case UnOp('-',UnOp('-',x)): return x |
| case UnOp('~',UnOp('~',x)): return x |
| case UnOp('-',Const(0)): return Const(0) |
| case UnOp('-',UnOp('~',x)): return BinOp('+',x,Const(1&MASK)) |
| case UnOp('~',UnOp('-',x)): return BinOp('-',x,Const(1&MASK)) |
| case BinOp('-',BinOp('+',a,b),c) if b==c: return a |
| case BinOp('-',BinOp('+',a,b),c) if a==c: return b |
| case BinOp('+',BinOp('-',a,b),c) if b==c: return a |
| case BinOp('-',a,BinOp('+',b,c)) if a==b: return UnOp('-',c) |
| case BinOp('-',a,BinOp('+',b,c)) if a==c: return UnOp('-',b) |
| case BinOp('-',BinOp('+',a,b),BinOp('+',c,d)) if a==c: return BinOp('-',b,d) |
| case BinOp('-',BinOp('+',a,b),BinOp('+',c,d)) if b==d: return BinOp('-',a,c) |
| case BinOp('-',BinOp('+',a,b),BinOp('+',c,d)) if a==d: return BinOp('-',b,c) |
| case BinOp('-',BinOp('+',a,b),BinOp('+',c,d)) if b==c: return BinOp('-',a,d) |
| case BinOp(op,Const(a),Const(b)): |
| match op: |
| case '+': return Const((a+b)&MASK) |
| case '-': return Const((a-b)&MASK) |
| case '*': return Const((a*b)&MASK) |
| case '&': return Const(a&b&MASK) |
| case '|': return Const((a|b)&MASK) |
| case '^': return Const((a^b)&MASK) |
| case UnOp('-',Const(a)): return Const((-a)&MASK) |
| case UnOp('~',Const(a)): return Const((~a)&MASK) |
| case BinOp(op,a,b) if op in('+','*','&','|','^') and repr(a)>repr(b): |
| return BinOp(op,b,a) |
| return None |
|
|
| def sym_pass(e): |
| match e: |
| case Var(_)|Const(_): return e |
| case BinOp(op,l,r): |
| nl=sym_pass(l); nr=sym_pass(r); ne=BinOp(op,nl,nr) |
| r2=rw(ne) |
| if r2 is None: return ne |
| return Const(r2.val&MASK) if isinstance(r2,Const) else r2 |
| case UnOp(op,a): |
| na=sym_pass(a); ne=UnOp(op,na) |
| r2=rw(ne) |
| if r2 is None: return ne |
| return Const(r2.val&MASK) if isinstance(r2,Const) else r2 |
|
|
| def sym(e, passes=30): |
| for _ in range(passes): |
| ne=sym_pass(e) |
| if ne==e: break |
| e=ne |
| return e |
|
|
| return sym, rw, sym_pass |
|
|
| |
|
|
| class NeuralLayer: |
| def __init__(self, model_path: str, rw_fn, bits: int): |
| self.rw_fn=rw_fn; self.bits=bits; self.loaded=False |
| try: |
| import torch, torch.nn as nn |
| self.torch=torch; self.nn=nn |
| |
| cfg_path = _HERE / "config.json" |
| rv_path = _HERE / "rule_vocab.json" |
| cv_path = _HERE / "char_vocab.json" |
| sf_path = _HERE / "mba_classifier.safetensors" |
| if cfg_path.exists() and rv_path.exists() and cv_path.exists(): |
| cfg = json.load(open(cfg_path)) |
| self.rules = json.load(open(rv_path)) |
| self.char2id = json.load(open(cv_path)) |
| else: |
| ckpt = torch.load(model_path, map_location='cpu', weights_only=False) |
| cfg = ckpt['config'] |
| self.rules = ckpt['rule_vocab'] |
| self.char2id = ckpt['char_vocab'] |
| self.n_rules = len(self.rules) |
| self.max_len = cfg['max_len'] |
| self.model = self._build(cfg) |
| |
| if sf_path.exists(): |
| from safetensors.torch import load_file |
| self.model.load_state_dict(load_file(str(sf_path))) |
| else: |
| ckpt = torch.load(model_path, map_location='cpu', weights_only=False) |
| self.model.load_state_dict(ckpt['model_state']) |
| self.model.eval() |
| self.reduction_rules=[i for i,r in enumerate(self.rules) |
| if not r.startswith('comm_') and not r.startswith('fold_')] |
| self.loaded=True |
| n_params=sum(p.numel() for p in self.model.parameters()) |
| print(f"Neural: {self.n_rules} rules, {n_params/1e6:.1f}M params", flush=True) |
| except Exception as ex: |
| print(f"Neural disabled: {ex}", file=sys.stderr) |
|
|
| def _build(self, cfg): |
| import torch.nn as nn; torch=self.torch |
| class RC(nn.Module): |
| def __init__(s,vocab_size,d_model,n_heads,n_layers,d_ff,max_len,n_rules,dropout=0.1): |
| super().__init__() |
| s.embed = nn.Embedding(vocab_size,d_model,padding_idx=0) |
| s.pos_embed = nn.Embedding(max_len,d_model) |
| el = nn.TransformerEncoderLayer(d_model,n_heads,d_ff,dropout,batch_first=True,norm_first=True) |
| s.encoder = nn.TransformerEncoder(el,n_layers) |
| s.norm = nn.LayerNorm(d_model) |
| s.head = nn.Linear(d_model,n_rules) |
| def forward(s,x): |
| pm=(x==0); pos=torch.arange(x.size(1),device=x.device).unsqueeze(0) |
| e=s.embed(x)+s.pos_embed(pos); enc=s.encoder(e,src_key_padding_mask=pm) |
| L=(~pm).float().sum(1,keepdim=True).clamp(min=1) |
| return s.head(s.norm((enc*(~pm).unsqueeze(-1).float()).sum(1)/L)) |
| RC.torch=self.torch |
| return RC(**cfg) |
|
|
| def _tok(self, expr_str): |
| ids=[self.char2id.get(c,1) for c in expr_str[:self.max_len]] |
| ids+=[0]*(self.max_len-len(ids)) |
| return self.torch.tensor([ids],dtype=self.torch.long) |
|
|
| def search(self, expr, max_steps=50, top_k=5): |
| if not self.loaded: return expr |
| best=expr; bsz=size(expr); cur=expr; vis={repr(cur)} |
| for _ in range(max_steps): |
| try: rules=self._top_rules(repr(cur),top_k) |
| except Exception: break |
| improved=False |
| for rid in rules: |
| for cand in self._apply_everywhere(cur,self.rules[rid]): |
| k=repr(cand) |
| if k in vis: continue |
| if size(cand)<size(cur): |
| cur=cand; vis.add(k) |
| if size(cur)<bsz: best=cur; bsz=size(cur) |
| improved=True; break |
| if improved: break |
| if not improved: break |
| return best |
|
|
| def _top_rules(self, expr_str, top_k): |
| with self.torch.no_grad(): |
| lgt=self.model(self._tok(expr_str))[0] |
| m=self.torch.full_like(lgt,float('-inf')) |
| for i in self.reduction_rules: m[i]=lgt[i] |
| return m.topk(min(top_k,len(self.reduction_rules))).indices.tolist() |
|
|
| def _apply_everywhere(self, expr, rule_name): |
| rw=self.rw_fn; results=[] |
| def try_at(node,rebuild): |
| r=rw(node) |
| if r is not None: results.append(rebuild(r)) |
| match node: |
| case BinOp(op,l,r2): |
| try_at(l,lambda nl,op=op,r2=r2,rb=rebuild: rb(BinOp(op,nl,r2))) |
| try_at(r2,lambda nr,op=op,l=l,rb=rebuild: rb(BinOp(op,l,nr))) |
| case UnOp(op,a): |
| try_at(a,lambda na,op=op,rb=rebuild: rb(UnOp(op,na))) |
| try_at(expr, lambda x: x); results.sort(key=size); return results |
|
|
| |
|
|
| def _tokenize(s: str) -> list: |
| tokens=[]; i=0 |
| while i<len(s): |
| c=s[i] |
| if c.isspace(): i+=1 |
| elif s[i:i+2].lower().startswith('0x'): |
| j=i+2 |
| while j<len(s) and s[j] in '0123456789abcdefABCDEF': j+=1 |
| tokens.append(s[i:j]); i=j |
| elif c.isdigit(): |
| j=i+1 |
| while j<len(s) and s[j].isdigit(): j+=1 |
| tokens.append(s[i:j]); i=j |
| elif c.isalpha() or c=='_': |
| j=i+1 |
| while j<len(s) and (s[j].isalnum() or s[j]=='_'): j+=1 |
| tokens.append(s[i:j]); i=j |
| elif c in '+-*&|^~()': tokens.append(c); i+=1 |
| else: raise ValueError(f"Unknown char {c!r} in: {s!r}") |
| return tokens |
|
|
def _parse_infix(tokens: list, bits: int) -> Expr:
    """Parse a token list from `_tokenize` into an Expr tree.

    Recursive descent, one function per precedence level. From loosest to
    tightest (the same order as Python): ``|`` < ``^`` < ``&`` < ``+ -`` <
    ``*`` < unary ``~ -``. Binary operators are left-associative. Literals
    are reduced mod 2**bits, and a unary minus applied directly to a decimal
    literal is folded into a single Const.

    Raises:
        ValueError: on any syntax error, including running out of tokens
            (e.g. an unclosed parenthesis) or nesting too deep to parse.
    """
    MASK = (1 << bits) - 1
    pos = [0]

    def peek():
        return tokens[pos[0]] if pos[0] < len(tokens) else None

    def consume(e=None):
        # Guard the end of input so e.g. "(x" raises the documented ValueError
        # rather than an IndexError that escapes callers' error handling.
        if pos[0] >= len(tokens):
            if e is not None:
                raise ValueError(f"Expected {e!r}, got end of input")
            raise ValueError("Unexpected end of input")
        t = tokens[pos[0]]
        if e is not None and t != e:
            raise ValueError(f"Expected {e!r}, got {t!r}")
        pos[0] += 1
        return t

    def parse_or():
        L = parse_xor()
        while peek() == '|':
            consume()
            L = BinOp('|', L, parse_xor())
        return L

    def parse_xor():
        L = parse_and()
        while peek() == '^':
            consume()
            L = BinOp('^', L, parse_and())
        return L

    def parse_and():
        L = parse_add()
        while peek() == '&':
            consume()
            L = BinOp('&', L, parse_add())
        return L

    def parse_add():
        L = parse_mul()
        while peek() in ('+', '-'):
            op = consume()
            L = BinOp(op, L, parse_mul())
        return L

    def parse_mul():
        L = parse_un()
        while peek() == '*':
            consume()
            L = BinOp('*', L, parse_un())
        return L

    def parse_un():
        if peek() == '~':
            consume()
            return UnOp('~', parse_un())
        if peek() == '-':
            consume()
            if peek() is not None and re.match(r'^\d+$', peek()):
                return Const((-int(consume())) & MASK)
            return UnOp('-', parse_un())
        return parse_atom()

    def parse_atom():
        tok = peek()
        if tok == '(':
            consume('(')
            e = parse_or()
            consume(')')
            return e
        if tok and re.match(r'^0[xX][0-9a-fA-F]+$', tok):
            consume()
            return Const(int(tok, 16) & MASK)
        if tok and re.match(r'^\d+$', tok):
            consume()
            return Const(int(tok) & MASK)
        if tok and re.match(r'^[a-zA-Z_]\w*$', tok):
            consume()
            return Var(tok)
        raise ValueError(f"Unexpected token {tok!r}")

    try:
        ast = parse_or()
    except RecursionError:
        # Each parenthesis level costs several frames; report it as bad input.
        raise ValueError("Expression is nested too deeply to parse") from None
    if pos[0] != len(tokens):
        raise ValueError(f"Trailing: {tokens[pos[0]:]}")
    return ast
|
|
| _PREC={'|':1,'^':2,'&':3,'+':4,'-':4,'*':5} |
|
|
def _to_infix(expr) -> str:
    """Render an Expr as infix text, adding only the parentheses that are needed.

    A binary operand is parenthesized when its operator binds more loosely
    than its parent's, or equally loosely on the right-hand side (operators
    are left-associative). A unary operator's argument is rendered at level 6,
    above every binary operator, so a compound argument is always
    parenthesized.
    """
    def render(node, outer_prec=0, is_right=False):
        if isinstance(node, Var):
            return node.name
        if isinstance(node, Const):
            return str(node.val)
        if isinstance(node, UnOp):
            return f"{node.op}{render(node.arg, 6)}"
        if isinstance(node, BinOp):
            prec = _PREC[node.op]
            body = f"{render(node.left, prec, False)} {node.op} {render(node.right, prec, True)}"
            if prec < outer_prec or (prec == outer_prec and is_right):
                return f"({body})"
            return body
        return None

    return render(expr)
|
|
| def _extract_vars(expr) -> tuple: |
| seen=[] |
| def w(n): |
| match n: |
| case Var(name): |
| if name not in seen: seen.append(name) |
| case BinOp(_,l,r): w(l); w(r) |
| case UnOp(_,a): w(a) |
| w(expr); return tuple(sorted(seen)) |
|
|
| |
|
|
# Process-wide caches, filled on first use:
#   _neural     - the NeuralLayer built by _get_neural() (may have loaded=False)
#   _sym_cache  - bits -> (sym, rw, sym_pass) from make_symbolic(), via _get_sym()
_neural: Optional[NeuralLayer] = None
_sym_cache: dict = dict()
|
|
def _get_sym(bits):
    """Return the cached ``(sym, rw, sym_pass)`` triple for ``bits``, building it once."""
    try:
        return _sym_cache[bits]
    except KeyError:
        triple = _sym_cache[bits] = make_symbolic(bits)
        return triple
|
|
def _get_neural(model_path: Optional[str]) -> Optional[NeuralLayer]:
    """Return the process-wide NeuralLayer, creating it on the first call.

    When ``model_path`` is None, ``mba_classifier.safetensors`` and then
    ``mba_classifier_v2.pt`` next to this file are tried. If neither exists,
    None is returned and nothing is cached. Once a layer has been created,
    ``model_path`` is ignored on later calls. The layer always uses the 8-bit
    rewrite function.
    """
    global _neural
    if _neural is None:
        if model_path is None:
            candidates = (
                _HERE / 'mba_classifier.safetensors',
                _HERE / 'mba_classifier_v2.pt',
            )
            model_path = next((str(p) for p in candidates if p.exists()), None)
            if model_path is None:
                return None
        _, rewrite, _ = _get_sym(8)
        _neural = NeuralLayer(str(model_path), rewrite, bits=8)
    return _neural
|
|
| |
|
|
def simplify(
    expr: str,
    bits: int = 8,
    model_path: Optional[str] = None,
) -> dict:
    """
    Simplify an MBA expression string.

    Strategies are tried in order and the first that applies wins: BFS table
    lookup (8-bit, at most two variables), then the symbolic rewrite system,
    then the neural-guided search. BFS and neural candidates are accepted only
    if they are no larger than the input, mention no variable the input did
    not, and pass randomized equivalence testing.

    Args:
        expr: Infix expression. Operators: + - * & | ^ ~ (unary: - ~), with
              Python precedence ( | < ^ < & < + - < * ) and parentheses.
              Variables: a letter or _ followed by letters, digits or _.
              Constants: decimal or 0x-prefixed hexadecimal.
              Example: "(x & y) + (x | y)"
        bits: Bit width — 8 or 32. Default 8.
        model_path: Path to the neural model (mba_classifier_v2.pt or
              mba_classifier.safetensors). Auto-detected next to this file if
              omitted. Only used by the call that first loads the model; the
              loaded layer is cached for the rest of the process.

    Returns:
        dict:
            simplified (str) — simplified infix expression
            original (str) — input, stripped of surrounding whitespace
            reduction (float) — node-count reduction ratio, 0.0–1.0
            strategy (str) — bfs / symbolic / neural / none
            latency_ms (float) — wall-clock time

    Raises:
        ValueError: on parse error or unsupported bit width.
    """
    if not isinstance(expr, str) or not expr.strip():
        raise ValueError("expr must be a non-empty string")
    if bits not in (8, 32):
        raise ValueError("bits must be 8 or 32")

    t0 = time.perf_counter()
    tokens = _tokenize(expr.strip())
    ast = _parse_infix(tokens, bits)
    input_vars = _extract_vars(ast)
    # A constant expression still needs one variable name for fingerprinting
    # and randomized verification.
    vars_tuple = input_vars or ('x',)
    orig_size = size(ast)
    strategy = "none"
    result = ast

    def acceptable(candidate) -> bool:
        # A simplification must not grow the expression or introduce a variable
        # the caller never wrote. BFS entries are renamed only from 'x'/'y', so
        # for single-variable input a stray 'y' could survive. It must also
        # agree with the input on random samples.
        return (size(candidate) <= orig_size
                and set(_extract_vars(candidate)) <= set(input_vars)
                and _verify(ast, candidate, bits, vars_tuple))

    # 1) Exact lookup in the precomputed 8-bit BFS table.
    if bits == 8 and len(vars_tuple) <= 2:
        table = _load_bfs()
        if table:
            v0 = vars_tuple[0]
            v1 = vars_tuple[1] if len(vars_tuple) > 1 else '__'
            fp = _fp8(ast, v0, v1)
            if fp in table:
                # Table expressions are written over 'x' and 'y'; map them
                # onto the caller's variable names.
                var_map = {'x': vars_tuple[0]}
                if len(vars_tuple) >= 2:
                    var_map['y'] = vars_tuple[1]
                candidate = _rename_vars(table[fp], var_map)
                if acceptable(candidate):
                    result = candidate
                    strategy = "bfs"

    # 2) Rule-based rewriting at the requested width.
    if strategy == "none":
        sym, _, _ = _get_sym(bits)
        reduced = sym(ast)
        if reduced != ast:
            result = reduced
            strategy = "symbolic"

    # 3) Neural-guided search. _get_neural builds the layer with the 8-bit
    #    rewrite rules, so its candidates are re-verified here at `bits`.
    if strategy == "none":
        neural = _get_neural(model_path)
        if neural is not None and neural.loaded:
            found = neural.search(ast)
            if found is not ast and acceptable(found):
                result = found
                strategy = "neural"

    out_size = size(result)
    reduction = 1.0 - out_size / orig_size if orig_size > 0 else 0.0
    ms = round((time.perf_counter() - t0) * 1e3, 1)
    return {
        "simplified": _to_infix(result),
        "original": expr.strip(),
        "reduction": round(reduction, 3),
        "strategy": strategy,
        "latency_ms": ms,
    }
|
|
| |
if __name__ == "__main__":
    import argparse

    p = argparse.ArgumentParser(description="MBA Expression Simplifier")
    p.add_argument("expr", help='e.g. "(x & y) + (x | y)"')
    p.add_argument("--bits", type=int, default=8, choices=[8, 32])
    p.add_argument("--model", type=str, default=None)
    args = p.parse_args()
    try:
        r = simplify(args.expr, bits=args.bits, model_path=args.model)
    except (ValueError, RecursionError) as e:
        # Bad input (parse error, unsupported width) or an expression nested
        # too deeply for the recursive simplifier: report it, no traceback.
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)
    print(f"Input: {r['original']}")
    print(f"Simplified: {r['simplified']}")
    # Decide by strategy, not by comparing text: re-rendering alone changes
    # spacing/parentheses (e.g. "x+y" -> "x + y") without simplifying.
    if r['strategy'] != "none":
        print(f"Reduction: {r['reduction']*100:.0f}% ({r['strategy']})")
    else:
        print(f"Result: already minimal ({r['strategy']})")
    print(f"Latency: {r['latency_ms']} ms")
|
|