mba-simplifier / inference.py
qox's picture
Add inference.py
f2c692e verified
Raw
History Blame Contribute Delete
25.6 kB
"""
MBA Expression Simplifier — Standalone Inference
=================================================
Usage:
python inference.py "(x & y) + (x | y)"
from inference import simplify
"""
import json, re, sys, time
from dataclasses import dataclass
from pathlib import Path
from typing import Union, Optional
# Directory containing this file; the BFS table and model sidecar files are
# looked up relative to it.
_HERE = Path(__file__).resolve().parent
# ── AST ───────────────────────────────────────────────────────────────────────
@dataclass(frozen=True)
class Var:
name: str
def __repr__(self): return self.name
@dataclass(frozen=True)
class Const:
val: int
def __repr__(self): return str(self.val)
@dataclass(frozen=True)
class BinOp:
op: str; left: 'Expr'; right: 'Expr'
def __repr__(self): return f"({self.op} {self.left} {self.right})"
@dataclass(frozen=True)
class UnOp:
op: str; arg: 'Expr'
def __repr__(self): return f"({self.op} {self.arg})"
# Any node of the expression tree: a leaf (Var/Const) or an operator node.
Expr = Union[Var, Const, BinOp, UnOp]
def size(e) -> int:
    """Count the nodes of an expression tree.

    Leaves count as one; an operator node counts as one plus the sizes of its
    operands. An object that is not one of the four AST node types yields
    ``None``.
    """
    if isinstance(e, (Var, Const)):
        return 1
    if isinstance(e, BinOp):
        return 1 + size(e.left) + size(e.right)
    if isinstance(e, UnOp):
        return 1 + size(e.arg)
    # Not an AST node: nothing to count.
    return None
# ── EVALUATOR ─────────────────────────────────────────────────────────────────
def make_ev(bits: int):
    """Return an evaluator ``ev(expr, env)`` working modulo ``2**bits``.

    ``env`` maps variable names to integers; names missing from it evaluate
    to 0. Operators other than ``+ - * & | ^`` (binary) and ``~ -`` (unary)
    make the evaluator return ``None``.
    """
    MASK = (1 << bits) - 1

    # Arithmetic results are wrapped to the bit width; bitwise results of
    # already-masked operands need no further masking.
    binary_ops = {
        '+': lambda a, b: (a + b) & MASK,
        '-': lambda a, b: (a - b) & MASK,
        '*': lambda a, b: (a * b) & MASK,
        '&': lambda a, b: a & b,
        '|': lambda a, b: a | b,
        '^': lambda a, b: a ^ b,
    }
    unary_ops = {
        '~': lambda a: (~a) & MASK,
        '-': lambda a: (-a) & MASK,
    }

    def ev(e, env: dict) -> int:
        if isinstance(e, Var):
            return env.get(e.name, 0) & MASK
        if isinstance(e, Const):
            return e.val & MASK
        if isinstance(e, BinOp):
            apply = binary_ops.get(e.op)
            if apply is None:
                return None
            return apply(ev(e.left, env), ev(e.right, env))
        if isinstance(e, UnOp):
            apply = unary_ops.get(e.op)
            if apply is None:
                return None
            return apply(ev(e.arg, env))
        return None

    return ev
# Shared 8-bit evaluator, used to compute BFS fingerprints.
_ev8 = make_ev(8)
# ── BFS FINGERPRINT (must match bfs_table_6.json exactly) ────────────────────
# Fixed 16-point evaluation set from poly_synth_final.py.
# Changing these invalidates the pre-computed table.
# (first variable, second variable) sample pairs fed to _fp8. Order matters:
# the fingerprint is the ordered tuple of results at these points.
_EVAL_POINTS = [
    (0,0),(1,0),(0,1),(1,1),
    (2,3),(5,7),(11,13),(17,19),
    (64,128),(128,64),(255,0),
    (170,85),(204,51),(85,170),
    (7,15),(31,63),
]
def _fp8(expr, v0: str, v1: str) -> tuple:
    """Fingerprint ``expr``: its 8-bit value at every point of ``_EVAL_POINTS``.

    ``v0`` is bound to the first coordinate of each point and ``v1`` to the
    second.
    """
    values = []
    for px, py in _EVAL_POINTS:
        values.append(_ev8(expr, {v0: px, v1: py}))
    return tuple(values)
# ── PREFIX PARSER (for loading bfs_table_6.json values) ──────────────────────
def _parse_prefix(s: str) -> Expr:
    """Parse a prefix (S-expression style) string such as ``(+ x (~ y))``.

    A parenthesised form with one operand becomes a ``UnOp``, with two a
    ``BinOp``. Bare digit runs become ``Const`` and identifiers become
    ``Var``. Text after the first complete expression is ignored, and a
    missing closing parenthesis after the second operand is tolerated.

    Raises:
        ValueError: on premature end of input or an unexpected character.
    """
    end = len(s)
    i = 0  # cursor shared by the nested helpers

    def skip_spaces():
        nonlocal i
        while i < end and s[i] == ' ':
            i += 1

    def take_while(accept) -> str:
        """Consume and return the longest run of characters passing ``accept``."""
        nonlocal i
        begin = i
        while i < end and accept(s[i]):
            i += 1
        return s[begin:i]

    def at_close() -> bool:
        return i < end and s[i] == ')'

    def node():
        nonlocal i
        skip_spaces()
        if i >= end:
            raise ValueError(f"Unexpected end: {s!r}")
        ch = s[i]
        if ch == '(':
            i += 1
            skip_spaces()
            # The operator is everything up to the next space or ')'.
            op = take_while(lambda c: c not in (' ', ')'))
            skip_spaces()
            first = node()
            skip_spaces()
            if at_close():
                i += 1
                return UnOp(op, first)
            second = node()
            skip_spaces()
            if at_close():
                i += 1
            return BinOp(op, first, second)
        if ch.isdigit():
            return Const(int(take_while(lambda c: c.isdigit())))
        if ch.isalpha() or ch == '_':
            return Var(take_while(lambda c: c.isalnum() or c == '_'))
        raise ValueError(f"Unexpected {ch!r} in {s!r}")

    return node()
# ── BFS TABLE LOAD ────────────────────────────────────────────────────────────
# Lazily populated cache: fingerprint tuple -> parsed expression.
# None means "not loaded yet"; {} means "table unavailable".
_bfs_table: Optional[dict] = None


def _load_bfs() -> dict:
    """Load ``bfs_table_6.json`` from this file's directory, once.

    The JSON object maps a JSON-encoded fingerprint list to a prefix
    expression string; both are decoded so that lookups use tuples and
    return AST nodes. The result is cached in ``_bfs_table``.

    Returns:
        The fingerprint table, or an empty dict when the file is missing
        (a warning is printed to stderr in that case).
    """
    global _bfs_table
    if _bfs_table is not None:
        return _bfs_table
    path = _HERE / 'bfs_table_6.json'
    t0 = time.perf_counter()
    try:
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(path, encoding="utf-8") as f:
            raw = json.load(f)
    except FileNotFoundError:
        print(f"Warning: {path} not found — BFS disabled.", file=sys.stderr)
        _bfs_table = {}
        return _bfs_table
    table = {
        tuple(json.loads(fp_str)): _parse_prefix(expr_str)
        for fp_str, expr_str in raw.items()
    }
    _bfs_table = table
    print(f"BFS: {len(table):,} entries ({(time.perf_counter()-t0)*1000:.0f}ms)", flush=True)
    return _bfs_table
def _rename_vars(expr, var_map: dict) -> Expr:
    """Return a copy of ``expr`` with variable names substituted via ``var_map``.

    Names absent from ``var_map`` are kept. Each variable is mapped exactly
    once, so a map such as ``{'x': 'y', 'y': 'z'}`` does not chain.
    """
    if isinstance(expr, Var):
        return Var(var_map.get(expr.name, expr.name))
    if isinstance(expr, Const):
        return expr
    if isinstance(expr, BinOp):
        new_left = _rename_vars(expr.left, var_map)
        new_right = _rename_vars(expr.right, var_map)
        return BinOp(expr.op, new_left, new_right)
    if isinstance(expr, UnOp):
        return UnOp(expr.op, _rename_vars(expr.arg, var_map))
    return None
# ── ORACLE VERIFY ─────────────────────────────────────────────────────────────
def _verify(e1, e2, bits: int, vars_tuple: tuple, n: int = 512) -> bool:
    """Randomised equivalence check of two expressions at ``bits`` width.

    Both expressions are evaluated on ``n`` random assignments of the
    variables in ``vars_tuple``. The generator uses a fixed seed, so the
    sample sequence is reproducible. Returns False at the first disagreement.
    """
    import random
    top = (1 << bits) - 1
    evaluate = make_ev(bits)
    source = random.Random(7)

    def agree_on_random_point() -> bool:
        point = {}
        for name in vars_tuple:
            point[name] = source.randint(0, top)
        return evaluate(e1, point) == evaluate(e2, point)

    return all(agree_on_random_point() for _ in range(n))
# ── SYMBOLIC ENGINE (inlined from pipeline_v5.py) ─────────────────────────────
def make_symbolic(bits: int):
    """Build the rule-based rewriting engine for one bit width.

    Returns:
        A tuple ``(sym, rw, sym_pass)`` of closures sharing ``MASK``:
        ``rw(e)`` tries to rewrite the root of ``e`` (returns ``None`` when
        no rule applies), ``sym_pass(e)`` performs one bottom-up sweep, and
        ``sym(e, passes)`` repeats sweeps until nothing changes or the pass
        limit is reached.
    """
    MASK = (1 << bits) - 1
    def rw(e):
        """Apply the first matching rule to the root of ``e``, else return None.

        Cases are tried top to bottom, so the more specific patterns must
        stay ahead of the generic ones below them.
        """
        match e:
            # (a & b) + (a | b) -> a + b, in either summand order; operands
            # are emitted in repr order.
            case BinOp('+',BinOp('&',a,b),BinOp('|',c,d)) if {a,b}=={c,d}:
                return BinOp('+',min(a,b,key=repr),max(a,b,key=repr))
            case BinOp('+',BinOp('|',a,b),BinOp('&',c,d)) if {a,b}=={c,d}:
                return BinOp('+',min(a,b,key=repr),max(a,b,key=repr))
            # (a | b) - (a & b) -> a ^ b
            case BinOp('-',BinOp('|',a,b),BinOp('&',c,d)) if {a,b}=={c,d}:
                return BinOp('^',min(a,b,key=repr),max(a,b,key=repr))
            # (a + b) - 2*(a & b) -> a ^ b, with the factor 2 on either side
            case BinOp('-',BinOp('+',a,b),BinOp('*',Const(2),BinOp('&',c,d))) if {a,b}=={c,d}:
                return BinOp('^',min(a,b,key=repr),max(a,b,key=repr))
            case BinOp('-',BinOp('+',a,b),BinOp('*',BinOp('&',c,d),Const(2))) if {a,b}=={c,d}:
                return BinOp('^',min(a,b,key=repr),max(a,b,key=repr))
            # Negation <-> complement: -x - 1 == ~x (a constant equal to MASK
            # acts as -1), and ~x + 1 == -x.
            case BinOp('-',UnOp('-',x),Const(1)): return UnOp('~',x)
            case BinOp('+',UnOp('-',x),Const(v)) if (v&MASK)==MASK: return UnOp('~',x)
            case BinOp('+',UnOp('~',x),Const(1)): return UnOp('-',x)
            case BinOp('+',Const(1),UnOp('~',x)): return UnOp('-',x)
            # ~a ^ ~b -> a ^ b
            case BinOp('^',UnOp('~',a),UnOp('~',b)):
                return BinOp('^',min(a,b,key=repr),max(a,b,key=repr))
            # De Morgan: ~(~a & ~b) -> a | b and ~(~a | ~b) -> a & b
            case UnOp('~',BinOp('&',UnOp('~',a),UnOp('~',b))):
                return BinOp('|',min(a,b,key=repr),max(a,b,key=repr))
            case UnOp('~',BinOp('|',UnOp('~',a),UnOp('~',b))):
                return BinOp('&',min(a,b,key=repr),max(a,b,key=repr))
            # Absorption: a | (a & c) -> a and a & (a | c) -> a
            case BinOp('|',a,BinOp('&',b,c)) if a==b or a==c: return a
            case BinOp('|',BinOp('&',a,b),c) if c==a or c==b: return c
            case BinOp('&',a,BinOp('|',b,c)) if a==b or a==c: return a
            case BinOp('&',BinOp('|',a,b),c) if c==a or c==b: return c
            # a ^ (a & c) -> a & ~c, in all four operand arrangements
            case BinOp('^',a,BinOp('&',b,c)) if a==b: return BinOp('&',a,UnOp('~',c))
            case BinOp('^',a,BinOp('&',b,c)) if a==c: return BinOp('&',a,UnOp('~',b))
            case BinOp('^',BinOp('&',a,b),c) if c==a: return BinOp('&',c,UnOp('~',b))
            case BinOp('^',BinOp('&',a,b),c) if c==b: return BinOp('&',c,UnOp('~',a))
            # Negation interacting with subtraction and addition
            case UnOp('-',BinOp('-',a,b)): return BinOp('-',b,a)
            case BinOp('+',a,UnOp('-',b)) if a==b: return Const(0)
            case BinOp('+',UnOp('-',a),b) if a==b: return Const(0)
            case BinOp('-',a,UnOp('-',b)): return BinOp('+',a,b)
            # Identical operands
            case BinOp('&',a,b) if a==b: return a
            case BinOp('|',a,b) if a==b: return a
            case BinOp('^',a,b) if a==b: return Const(0)
            case BinOp('-',a,b) if a==b: return Const(0)
            # Bitwise operations with the constants 0 and all-ones (MASK)
            case BinOp('&',_,Const(0))|BinOp('&',Const(0),_): return Const(0)
            case BinOp('|',x,Const(0))|BinOp('|',Const(0),x): return x
            case BinOp('^',x,Const(0))|BinOp('^',Const(0),x): return x
            case BinOp('^',x,Const(v)) if (v&MASK)==MASK: return UnOp('~',x)
            case BinOp('^',Const(v),x) if (v&MASK)==MASK: return UnOp('~',x)
            case BinOp('&',x,Const(v)) if (v&MASK)==MASK: return x
            case BinOp('&',Const(v),x) if (v&MASK)==MASK: return x
            case BinOp('|',_,Const(v))|BinOp('|',Const(v),_) if (v&MASK)==MASK: return Const(MASK)
            # An operand combined with its own complement
            case BinOp('&',a,UnOp('~',b)) if a==b: return Const(0)
            case BinOp('&',UnOp('~',a),b) if a==b: return Const(0)
            case BinOp('|',a,UnOp('~',b)) if a==b: return Const(MASK)
            case BinOp('|',UnOp('~',a),b) if a==b: return Const(MASK)
            case BinOp('^',a,UnOp('~',b)) if a==b: return Const(MASK)
            # Arithmetic identities with 0 and 1
            case BinOp('+',x,Const(0))|BinOp('+',Const(0),x): return x
            case BinOp('-',x,Const(0)): return x
            case BinOp('*',x,Const(1))|BinOp('*',Const(1),x): return x
            case BinOp('*',_,Const(0))|BinOp('*',Const(0),_): return Const(0)
            # Stacked unary operators: -(~x) == x + 1 and ~(-x) == x - 1
            case UnOp('-',UnOp('-',x)): return x
            case UnOp('~',UnOp('~',x)): return x
            case UnOp('-',Const(0)): return Const(0)
            case UnOp('-',UnOp('~',x)): return BinOp('+',x,Const(1&MASK))
            case UnOp('~',UnOp('-',x)): return BinOp('-',x,Const(1&MASK))
            # Cancellation of a common term across + and -
            case BinOp('-',BinOp('+',a,b),c) if b==c: return a
            case BinOp('-',BinOp('+',a,b),c) if a==c: return b
            case BinOp('+',BinOp('-',a,b),c) if b==c: return a
            case BinOp('-',a,BinOp('+',b,c)) if a==b: return UnOp('-',c)
            case BinOp('-',a,BinOp('+',b,c)) if a==c: return UnOp('-',b)
            case BinOp('-',BinOp('+',a,b),BinOp('+',c,d)) if a==c: return BinOp('-',b,d)
            case BinOp('-',BinOp('+',a,b),BinOp('+',c,d)) if b==d: return BinOp('-',a,c)
            case BinOp('-',BinOp('+',a,b),BinOp('+',c,d)) if a==d: return BinOp('-',b,c)
            case BinOp('-',BinOp('+',a,b),BinOp('+',c,d)) if b==c: return BinOp('-',a,d)
            # Constant folding, wrapped to the bit width
            case BinOp(op,Const(a),Const(b)):
                match op:
                    case '+': return Const((a+b)&MASK)
                    case '-': return Const((a-b)&MASK)
                    case '*': return Const((a*b)&MASK)
                    case '&': return Const(a&b&MASK)
                    case '|': return Const((a|b)&MASK)
                    case '^': return Const((a^b)&MASK)
            case UnOp('-',Const(a)): return Const((-a)&MASK)
            case UnOp('~',Const(a)): return Const((~a)&MASK)
            # Canonical order for commutative operators: smaller repr first
            case BinOp(op,a,b) if op in('+','*','&','|','^') and repr(a)>repr(b):
                return BinOp(op,b,a)
        # No rule matched.
        return None
    def sym_pass(e):
        """One bottom-up sweep: rewrite the children, then try ``rw`` on the rebuilt node."""
        match e:
            case Var(_)|Const(_): return e
            case BinOp(op,l,r):
                nl=sym_pass(l); nr=sym_pass(r); ne=BinOp(op,nl,nr)
                r2=rw(ne)
                if r2 is None: return ne
                # Constants produced by a rule are re-masked to the bit width.
                return Const(r2.val&MASK) if isinstance(r2,Const) else r2
            case UnOp(op,a):
                na=sym_pass(a); ne=UnOp(op,na)
                r2=rw(ne)
                if r2 is None: return ne
                return Const(r2.val&MASK) if isinstance(r2,Const) else r2
    def sym(e, passes=30):
        """Repeat ``sym_pass`` until the expression stops changing, at most ``passes`` times."""
        for _ in range(passes):
            ne=sym_pass(e)
            if ne==e: break
            e=ne
        return e
    return sym, rw, sym_pass
# ── NEURAL LAYER (optional) ───────────────────────────────────────────────────
class NeuralLayer:
def __init__(self, model_path: str, rw_fn, bits: int):
self.rw_fn=rw_fn; self.bits=bits; self.loaded=False
try:
import torch, torch.nn as nn
self.torch=torch; self.nn=nn
# Load config/vocab from JSON sidecars if present, else extract from .pt
cfg_path = _HERE / "config.json"
rv_path = _HERE / "rule_vocab.json"
cv_path = _HERE / "char_vocab.json"
sf_path = _HERE / "mba_classifier.safetensors"
if cfg_path.exists() and rv_path.exists() and cv_path.exists():
cfg = json.load(open(cfg_path))
self.rules = json.load(open(rv_path))
self.char2id = json.load(open(cv_path))
else:
ckpt = torch.load(model_path, map_location='cpu', weights_only=False)
cfg = ckpt['config']
self.rules = ckpt['rule_vocab']
self.char2id = ckpt['char_vocab']
self.n_rules = len(self.rules)
self.max_len = cfg['max_len']
self.model = self._build(cfg)
# Load weights: prefer .safetensors, fall back to .pt
if sf_path.exists():
from safetensors.torch import load_file
self.model.load_state_dict(load_file(str(sf_path)))
else:
ckpt = torch.load(model_path, map_location='cpu', weights_only=False)
self.model.load_state_dict(ckpt['model_state'])
self.model.eval()
self.reduction_rules=[i for i,r in enumerate(self.rules)
if not r.startswith('comm_') and not r.startswith('fold_')]
self.loaded=True
n_params=sum(p.numel() for p in self.model.parameters())
print(f"Neural: {self.n_rules} rules, {n_params/1e6:.1f}M params", flush=True)
except Exception as ex:
print(f"Neural disabled: {ex}", file=sys.stderr)
def _build(self, cfg):
import torch.nn as nn; torch=self.torch
class RC(nn.Module):
def __init__(s,vocab_size,d_model,n_heads,n_layers,d_ff,max_len,n_rules,dropout=0.1):
super().__init__()
s.embed = nn.Embedding(vocab_size,d_model,padding_idx=0)
s.pos_embed = nn.Embedding(max_len,d_model)
el = nn.TransformerEncoderLayer(d_model,n_heads,d_ff,dropout,batch_first=True,norm_first=True)
s.encoder = nn.TransformerEncoder(el,n_layers)
s.norm = nn.LayerNorm(d_model)
s.head = nn.Linear(d_model,n_rules)
def forward(s,x):
pm=(x==0); pos=torch.arange(x.size(1),device=x.device).unsqueeze(0)
e=s.embed(x)+s.pos_embed(pos); enc=s.encoder(e,src_key_padding_mask=pm)
L=(~pm).float().sum(1,keepdim=True).clamp(min=1)
return s.head(s.norm((enc*(~pm).unsqueeze(-1).float()).sum(1)/L))
RC.torch=self.torch
return RC(**cfg)
def _tok(self, expr_str):
ids=[self.char2id.get(c,1) for c in expr_str[:self.max_len]]
ids+=[0]*(self.max_len-len(ids))
return self.torch.tensor([ids],dtype=self.torch.long)
def search(self, expr, max_steps=50, top_k=5):
if not self.loaded: return expr
best=expr; bsz=size(expr); cur=expr; vis={repr(cur)}
for _ in range(max_steps):
try: rules=self._top_rules(repr(cur),top_k)
except Exception: break
improved=False
for rid in rules:
for cand in self._apply_everywhere(cur,self.rules[rid]):
k=repr(cand)
if k in vis: continue
if size(cand)<size(cur):
cur=cand; vis.add(k)
if size(cur)<bsz: best=cur; bsz=size(cur)
improved=True; break
if improved: break
if not improved: break
return best
def _top_rules(self, expr_str, top_k):
with self.torch.no_grad():
lgt=self.model(self._tok(expr_str))[0]
m=self.torch.full_like(lgt,float('-inf'))
for i in self.reduction_rules: m[i]=lgt[i]
return m.topk(min(top_k,len(self.reduction_rules))).indices.tolist()
def _apply_everywhere(self, expr, rule_name):
rw=self.rw_fn; results=[]
def try_at(node,rebuild):
r=rw(node)
if r is not None: results.append(rebuild(r))
match node:
case BinOp(op,l,r2):
try_at(l,lambda nl,op=op,r2=r2,rb=rebuild: rb(BinOp(op,nl,r2)))
try_at(r2,lambda nr,op=op,l=l,rb=rebuild: rb(BinOp(op,l,nr)))
case UnOp(op,a):
try_at(a,lambda na,op=op,rb=rebuild: rb(UnOp(op,na)))
try_at(expr, lambda x: x); results.sort(key=size); return results
# ── INFIX PARSER + PRINTER ────────────────────────────────────────────────────
def _tokenize(s: str) -> list:
tokens=[]; i=0
while i<len(s):
c=s[i]
if c.isspace(): i+=1
elif s[i:i+2].lower().startswith('0x'):
j=i+2
while j<len(s) and s[j] in '0123456789abcdefABCDEF': j+=1
tokens.append(s[i:j]); i=j
elif c.isdigit():
j=i+1
while j<len(s) and s[j].isdigit(): j+=1
tokens.append(s[i:j]); i=j
elif c.isalpha() or c=='_':
j=i+1
while j<len(s) and (s[j].isalnum() or s[j]=='_'): j+=1
tokens.append(s[i:j]); i=j
elif c in '+-*&|^~()': tokens.append(c); i+=1
else: raise ValueError(f"Unknown char {c!r} in: {s!r}")
return tokens
def _parse_infix(tokens: list, bits: int) -> Expr:
    """Parse a token list into an AST by recursive descent.

    Binary operators are left-associative; precedence from loosest to
    tightest is ``|``, ``^``, ``&``, ``+ -``, ``*``, then unary ``~ -``.
    Numeric literals are reduced modulo ``2**bits``. A unary minus directly
    followed by a decimal literal is folded into a single constant.

    Raises:
        ValueError: on any syntax error, including input that ends early
            (for example an unclosed parenthesis) and trailing tokens.
    """
    MASK = (1 << bits) - 1
    pos = 0

    def peek():
        return tokens[pos] if pos < len(tokens) else None

    def consume(expected=None):
        nonlocal pos
        if pos >= len(tokens):
            # Truncated input used to surface as IndexError here; callers
            # only handle ValueError for parse errors.
            wanted = repr(expected) if expected is not None else "a token"
            raise ValueError(f"Expected {wanted}, got end of input")
        tok = tokens[pos]
        if expected is not None and tok != expected:
            raise ValueError(f"Expected {expected!r}, got {tok!r}")
        pos += 1
        return tok

    def left_assoc(ops, operand):
        """Parse ``operand (op operand)*`` for the given operator set."""
        node = operand()
        while peek() in ops:
            op = consume()
            node = BinOp(op, node, operand())
        return node

    def parse_or():
        return left_assoc(('|',), parse_xor)

    def parse_xor():
        return left_assoc(('^',), parse_and)

    def parse_and():
        return left_assoc(('&',), parse_add)

    def parse_add():
        return left_assoc(('+', '-'), parse_mul)

    def parse_mul():
        return left_assoc(('*',), parse_un)

    def parse_un():
        if peek() == '~':
            consume()
            return UnOp('~', parse_un())
        if peek() == '-':
            consume()
            # "-<decimal>" becomes one wrapped constant, not a UnOp.
            if peek() is not None and re.match(r'^\d+$', peek()):
                return Const((-int(consume())) & MASK)
            return UnOp('-', parse_un())
        return parse_atom()

    def parse_atom():
        tok = peek()
        if tok == '(':
            consume('(')
            inner = parse_or()
            consume(')')
            return inner
        if tok and re.match(r'^0[xX][0-9a-fA-F]+$', tok):
            consume()
            return Const(int(tok, 16) & MASK)
        if tok and re.match(r'^\d+$', tok):
            consume()
            return Const(int(tok) & MASK)
        if tok and re.match(r'^[a-zA-Z_]\w*$', tok):
            consume()
            return Var(tok)
        raise ValueError(f"Unexpected token {tok!r}")

    ast = parse_or()
    if pos != len(tokens):
        raise ValueError(f"Trailing: {tokens[pos:]}")
    return ast
# Binding strength of each binary operator; a larger number binds tighter.
_PREC = {'|': 1, '^': 2, '&': 3, '+': 4, '-': 4, '*': 5}


def _to_infix(expr) -> str:
    """Render an AST as an infix string, adding only the parentheses needed.

    A binary node is parenthesised when it binds more loosely than its
    context, or equally tightly while sitting on the right-hand side. The
    operand of a unary operator is rendered at a level above every binary
    operator.
    """
    def emit(node, mp=0, ir=False):
        # mp: precedence required by the context; ir: node is a right operand.
        if isinstance(node, Var):
            return node.name
        if isinstance(node, Const):
            return str(node.val)
        if isinstance(node, UnOp):
            return f"{node.op}{emit(node.arg, 6)}"
        if isinstance(node, BinOp):
            level = _PREC[node.op]
            lhs = emit(node.left, level, False)
            rhs = emit(node.right, level, True)
            text = f"{lhs} {node.op} {rhs}"
            needs_parens = level < mp or (level == mp and ir)
            return f"({text})" if needs_parens else text
        return None

    return emit(expr)
def _extract_vars(expr) -> tuple:
    """Return the distinct variable names in ``expr`` as a sorted tuple."""
    found = set()

    def visit(node):
        if isinstance(node, Var):
            found.add(node.name)
        elif isinstance(node, BinOp):
            visit(node.left)
            visit(node.right)
        elif isinstance(node, UnOp):
            visit(node.arg)

    visit(expr)
    return tuple(sorted(found))
# ── GLOBAL STATE (lazy) ───────────────────────────────────────────────────────
# Process-wide singleton for the optional neural layer (created on demand).
_neural: Optional[NeuralLayer] = None
# Symbolic engines keyed by bit width.
_sym_cache: dict = {}


def _get_sym(bits):
    """Return the ``(sym, rw, sym_pass)`` engine for ``bits``, building it on first use."""
    engine = _sym_cache.get(bits)
    if engine is None:
        engine = make_symbolic(bits)
        _sym_cache[bits] = engine
    return engine
def _get_neural(model_path: Optional[str]) -> Optional[NeuralLayer]:
    """Return the shared ``NeuralLayer``, creating it on the first call.

    When ``model_path`` is None, the ``.safetensors`` file and then the
    ``.pt`` file next to this module are tried; if neither exists, None is
    returned and nothing is cached. Once an instance exists it is returned
    regardless of ``model_path``. The layer is always built with the 8-bit
    rewrite function.
    """
    global _neural
    if _neural is not None:
        return _neural
    if model_path is None:
        candidates = (
            _HERE / 'mba_classifier.safetensors',
            _HERE / 'mba_classifier_v2.pt',
        )
        for candidate in candidates:
            if candidate.exists():
                model_path = str(candidate)
                break
        else:
            return None
    _, rw, _ = _get_sym(8)
    _neural = NeuralLayer(str(model_path), rw, bits=8)
    return _neural
# ── PUBLIC API ────────────────────────────────────────────────────────────────
def simplify(
    expr: str,
    bits: int = 8,
    model_path: Optional[str] = None,
) -> dict:
    """
    Simplify an MBA expression string.

    Strategies are tried in order: BFS table lookup (8-bit inputs with at
    most two variables), symbolic rewriting, then the optional neural
    search. The first strategy that applies decides the result.

    Args:
        expr: Infix expression. Operators: + - * & | ^ ~ (unary: - ~).
            Variables are identifiers; literals are decimal or 0x-prefixed
            hexadecimal.
            Example: "(x & y) + (x | y)"
        bits: Bit width — 8 or 32. Default 8.
        model_path: Path to mba_classifier_v2.pt. Auto-detected if omitted.

    Returns:
        dict:
            simplified (str)   — simplified infix expression
            original (str)     — input, stripped of surrounding whitespace
            reduction (float)  — size reduction ratio, rounded to 3 places
            strategy (str)     — bfs / symbolic / neural / none
            latency_ms (float) — wall-clock time

    Raises:
        ValueError: on parse error or unsupported bit width.
    """
    if not (isinstance(expr, str) and expr.strip()):
        raise ValueError("expr must be a non-empty string")
    if bits not in (8, 32):
        raise ValueError("bits must be 8 or 32")

    started = time.perf_counter()
    ast = _parse_infix(_tokenize(expr.strip()), bits)
    # A constant-only expression still needs one name for fingerprinting.
    names = _extract_vars(ast) or ('x',)
    size_before = size(ast)
    chosen = ast
    strategy = "none"

    # 1. BFS (8-bit, ≤2 vars)
    table = _load_bfs() if bits == 8 and len(names) <= 2 else None
    if table:
        first = names[0]
        second = names[1] if len(names) > 1 else '__'
        hit = table.get(_fp8(ast, first, second))
        if hit is not None:
            # Table entries are written over x and y; map them back to the
            # caller's variable names.
            hit = _rename_vars(hit, dict(zip(('x', 'y'), names)))
            if size(hit) <= size_before and _verify(ast, hit, bits, names):
                chosen = hit
                strategy = "bfs"

    # 2. Symbolic
    if strategy == "none":
        sym, _, _ = _get_sym(bits)
        rewritten = sym(ast)
        if rewritten != ast:
            chosen = rewritten
            strategy = "symbolic"

    # 3. Neural (optional)
    if strategy == "none":
        neural = _get_neural(model_path)
        if neural and neural.loaded:
            found = neural.search(ast)
            if (
                found is not ast
                and size(found) <= size_before
                and _verify(ast, found, bits, names)
            ):
                chosen = found
                strategy = "neural"

    size_after = size(chosen)
    if size_before > 0:
        reduction = 1.0 - size_after / size_before
    else:
        reduction = 0.0
    elapsed_ms = round((time.perf_counter() - started) * 1e3, 1)
    return {
        "simplified": _to_infix(chosen),
        "original": expr.strip(),
        "reduction": round(reduction, 3),
        "strategy": strategy,
        "latency_ms": elapsed_ms,
    }
# ── CLI ───────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Command-line entry point: simplify one expression and print a report.
    import argparse

    parser = argparse.ArgumentParser(description="MBA Expression Simplifier")
    parser.add_argument("expr", help='e.g. "(x & y) + (x | y)"')
    parser.add_argument("--bits", type=int, default=8, choices=[8, 32])
    parser.add_argument("--model", type=str, default=None)
    cli = parser.parse_args()
    try:
        outcome = simplify(cli.expr, bits=cli.bits, model_path=cli.model)
        print(f"Input: {outcome['original']}")
        print(f"Simplified: {outcome['simplified']}")
        if outcome['original'] == outcome['simplified']:
            print(f"Result: already minimal ({outcome['strategy']})")
        else:
            print(f"Reduction: {outcome['reduction']*100:.0f}% ({outcome['strategy']})")
        print(f"Latency: {outcome['latency_ms']} ms")
    except ValueError as err:
        # Parse errors and bad arguments are reported without a traceback.
        print(f"Error: {err}", file=sys.stderr)
        sys.exit(1)