| from capstone import * |
| from capstone.arm import * |
| from capstone.arm64 import * |
| from capstone.x86 import * |
| import cle |
| import struct |
| from math import e as CONST_E, pi as CONST_PI |
| import sympy as sp |
|
|
| from .util import DecodeError |
|
|
def int2fp32(v):
    """Reinterpret a 32-bit integer bit pattern as an IEEE-754 single.

    Non-int inputs (e.g. values that are already floats) are returned
    unchanged. Negative ints are treated as 32-bit two's-complement bit
    patterns, so ``int2fp32(-0x40800000)`` equals ``int2fp32(0xBF800000)``.
    Values outside ``[-2**32, 2**32)`` raise ``OverflowError``.
    """
    if isinstance(v, int):
        if v < 0:
            # Two's complement: map [-2**32, 0) onto the unsigned bit pattern.
            v += 1 << 32
        v = struct.unpack("<f", v.to_bytes(4, "little"))[0]
    return v
def int2fp64(v):
    """Reinterpret a 64-bit integer bit pattern as an IEEE-754 double.

    Non-int inputs (e.g. values that are already floats) are returned
    unchanged. Negative ints are treated as 64-bit two's-complement bit
    patterns. Values outside ``[-2**64, 2**64)`` raise ``OverflowError``.
    """
    if isinstance(v, int):
        if v < 0:
            # Two's complement: map [-2**64, 0) onto the unsigned bit pattern.
            v += 1 << 64
        v = struct.unpack("<d", v.to_bytes(8, "little"))[0]
    return v
|
|
def align4(v):
    """Round the address ``v`` down to a multiple of 4.

    Only the two low bits are cleared; all higher bits are preserved so that
    64-bit (AArch64) addresses are not truncated to 32 bits.
    """
    return v & ~3
|
|
class DisassemblerBase:
    """Shared state and constant-naming logic for the disassemblers.

    Subclasses set ``self.md`` (a capstone ``Cs`` instance) and
    ``self.loader`` (a ``cle.Loader``) and implement :meth:`disassemble`.
    """

    def __init__(self, expr_constants=None, match_constants=False):
        """
        Args:
            expr_constants: mapping of key -> float value that
                :meth:`add_constant` tries to match when ``match_constants``
                is true. ``None`` means an empty mapping.
            match_constants: if true, floats are named after entries of
                ``expr_constants``, and unmatched non-simple floats raise
                ``DecodeError`` instead of receiving a fresh ``CSYM<n>`` name.
        """
        self.loader = None
        # Last known integer value of registers written by immediate moves.
        self.reg_values = {}
        # Next index for a freshly allocated CSYM<n> name.
        self.constidx = 0
        # Symbol key -> float value for every constant named so far.
        self.constants = {}
        # Byte addresses that were read as constant data.
        self.constaddrs = set()
        # A fresh dict per instance; a ``{}`` default would be shared.
        self.expr_constants = {} if expr_constants is None else expr_constants
        self.match_constants = match_constants

    def get_function_bytes(self, funcname):
        """Return ``(start_address, code_bytes)`` for symbol ``funcname``.

        Also stores ``self.funcrange`` as the half-open range
        ``(start, start + size)`` of the function.

        Raises:
            DecodeError: if the symbol is not found in the loaded binary.
        """
        func = self.loader.find_symbol(funcname)
        if not func:
            raise DecodeError(f"Function {funcname} not found in binary")
        faddr = func.rebased_addr
        # On ARM an odd symbol address marks Thumb code; the instructions
        # start one byte earlier. x86-64 addresses are used as-is.
        if (not isinstance(self, DisassemblerX64)) and faddr % 2 == 1:
            faddr = faddr - 1
        fbytes = self.loader.memory.load(faddr, func.size)
        self.funcrange = faddr, faddr + func.size
        return faddr, fbytes

    def find_constant(self, constants, value):
        """Find a key of ``constants`` whose value is related to ``value``.

        For each entry (in iteration order), ``value`` is compared with the
        entry directly, as its reciprocal, its negation and its negated
        reciprocal, using an absolute tolerance of 1e-5.

        Returns:
            ``(key, prefix)`` with ``prefix`` one of ``""``, ``"1/"``, ``"-"``,
            ``"-1/"``; or ``False`` if nothing matches.
        """
        for ec, cval in constants.items():
            if abs(value - cval) < 1e-5:
                return ec, ""
            # Reciprocal forms are undefined for zero; skip them instead of
            # raising ZeroDivisionError.
            if value != 0 and abs(1 / value - cval) < 1e-5:
                return ec, "1/"
            if abs(-value - cval) < 1e-5:
                return ec, "-"
            if value != 0 and abs(-1 / value - cval) < 1e-5:
                return ec, "-1/"
        return False

    def add_constant(self, value, addr=0, size=0):
        """Return a symbolic name for the float ``value`` and record it.

        Naming rules, first match wins:

        * ``CONST=0``, ``CONST=E``, ``CONST=pi`` for 0, e and pi (tol. 1e-7);
        * with ``match_constants``: an entry of ``self.expr_constants``
          (possibly negated / reciprocal), named ``<prefix>CSYM<key[1:]>``;
        * a value loaded (``size > 0``) from an address already in
          ``self.constaddrs`` that matches a previously named constant;
        * ``CONST=<r>`` if sympy can express it as an integer or a rational
          with denominator <= 16 (tolerance 1e-7);
        * otherwise a fresh ``CSYM<n>``.

        When ``size > 0`` the bytes ``[addr, addr + size)`` are added to
        ``self.constaddrs``.

        Raises:
            DecodeError: if ``match_constants`` is set and no rule applies.
        """
        if value == 0:
            cname = "CONST=0"
        elif abs(value - CONST_E) < 1e-7:
            cname = "CONST=E"
        elif abs(value - CONST_PI) < 1e-7:
            cname = "CONST=pi"
        elif self.match_constants and \
                (ecmatch := self.find_constant(self.expr_constants, value)):
            ecname, ecxpr = ecmatch
            # NOTE(review): drops the key's first character — presumably a
            # one-letter prefix in the caller's naming scheme; confirm.
            cname = f"{ecxpr}CSYM{ecname[1:]}"
            self.constants[ecname] = value
        elif size > 0 and addr in self.constaddrs and \
                (smatch := self.find_constant(self.constants, value)):
            sname, sxpr = smatch
            cname = f"{sxpr}CSYM{sname}"
        else:
            rep = sp.nsimplify(value, [sp.E, sp.pi], tolerance=1e-7)
            if isinstance(rep, sp.Integer) or \
                    (isinstance(rep, sp.Rational) and rep.q <= 16):
                cname = f"CONST={rep}"
            elif not self.match_constants:
                cname = f"CSYM{self.constidx}"
                self.constants[self.constidx] = value
                self.constidx += 1
            else:
                raise DecodeError(f"Cannot represent unmatched float {value}")

        if size > 0:
            self.constaddrs |= {addr + i for i in range(size)}
        return cname

    def disassemble(self, function):
        """Abstract: implemented by the architecture-specific subclasses."""
        raise NotImplementedError("Call disassemble on child classes, not base")
|
|
|
|
class DisassemblerARM32(DisassemblerBase):
    """Thumb-mode ARM32 disassembler that replaces float constants by names.

    Float constants are recognised when loaded from memory (VLDR / LDRD),
    encoded as VFP immediates (FCONSTS / FCONSTD), or built in core registers
    by immediate moves and then stored (STR / STRD).
    """

    def __init__(self, binpath, expr_constants=None, match_constants=False):
        """Load ``binpath`` with cle and create a detailed Thumb decoder.

        ``expr_constants`` / ``match_constants`` are forwarded to the base
        class; ``None`` means "no expression constants".
        """
        super().__init__(
            expr_constants={} if expr_constants is None else expr_constants,
            match_constants=match_constants,
        )
        self.md = Cs(CS_ARCH_ARM, CS_MODE_THUMB)
        self.md.detail = True
        self.loader = cle.Loader(binpath)

    def check_mov_imm(self, insn):
        """Return ``(dest_reg, value)`` for an immediate move, else ``False``.

        Handles MOV / MOVW / MOVT / ADR with a register destination and an
        immediate source. Negative immediates are converted to unsigned
        32-bit form; for ADR, ``insn.address + 4`` is added to the immediate.
        """
        if insn.id not in {ARM_INS_MOV, ARM_INS_MOVW,
                           ARM_INS_MOVT, ARM_INS_ADR}:
            return False
        ops = list(insn.operands)
        if len(ops) != 2:
            return False
        if ops[0].type != ARM_OP_REG or ops[1].type != ARM_OP_IMM:
            return False
        imm = ops[1].value.imm
        if imm < 0:
            imm = 2**32 + imm
        if insn.id == ARM_INS_ADR:
            # ADR is PC-relative (PC reads as the instruction address + 4).
            imm += insn.address + 4
        return ops[0].value.reg, imm

    def check_float_store(self, insn):
        """Return the float stored by STR/STRD from tracked registers, else ``False``.

        For STRD the first source register is the low word and the second the
        high word of a double; for STR the register is decoded as a single.
        Magnitudes below 1e-3 or above 100 are rejected (heuristic filter).
        """
        if insn.id not in {ARM_INS_STR, ARM_INS_STRD}:
            return False
        ops = list(insn.operands)
        if insn.id == ARM_INS_STRD:
            lo_reg = ops[0].value.reg
            hi_reg = ops[1].value.reg
            if lo_reg not in self.reg_values or hi_reg not in self.reg_values:
                return False
            fval = int2fp64((self.reg_values[hi_reg] << 32) + self.reg_values[lo_reg])
        else:
            src = ops[0].value.reg
            if src not in self.reg_values:
                return False
            fval = int2fp32(self.reg_values[src])
        if abs(fval) < 1e-3 or abs(fval) > 100:
            return False
        return fval

    def _literal_address(self, insn, mem):
        """Resolve a memory operand to an in-image address, or ``None``.

        The base is either PC (word-aligned PC + 4) or a register whose value
        is tracked in ``self.reg_values``; ``mem.disp`` is then added.
        """
        if mem.base == ARM_REG_PC:
            addr = align4(insn.address + 4) + mem.disp
        elif mem.base in self.reg_values:
            addr = align4(self.reg_values[mem.base]) + mem.disp
        else:
            return None
        # Conservatively require 8 readable bytes (the largest load handled)
        # inside the loaded image.
        if addr < self.loader.min_addr or addr + 8 > self.loader.max_addr:
            return None
        return addr

    def check_ldrd(self, insn):
        """Return ``(value, address, 8)`` for an LDRD of a double, else ``False``."""
        if insn.id != ARM_INS_LDRD:
            return False
        ops = list(insn.operands)
        if len(ops) != 3 or ops[2].type != ARM_OP_MEM:
            return False
        addr = self._literal_address(insn, ops[2].value.mem)
        if addr is None:
            return False
        # ARM data is little-endian; do not depend on the host byte order.
        fval = struct.unpack("<d", self.loader.memory.load(addr, 8))[0]
        return fval, addr, 8

    def check_vldr(self, insn):
        """Return ``(value, address, size)`` for a VLDR, else ``False``.

        Loads into D registers are decoded as doubles (8 bytes), all others
        as singles (4 bytes).
        """
        if insn.id != ARM_INS_VLDR:
            return False
        ops = list(insn.operands)
        if ops[1].type != ARM_OP_MEM:
            return False
        addr = self._literal_address(insn, ops[1].value.mem)
        if addr is None:
            return False
        if ARM_REG_D0 <= ops[0].value.reg <= ARM_REG_D31:
            size, fmt = 8, "<d"
        else:
            size, fmt = 4, "<f"
        fval = struct.unpack(fmt, self.loader.memory.load(addr, size))[0]
        return fval, addr, size

    def check_vmov(self, insn):
        """Return ``(asm, value)`` for a VFP immediate move, else ``False``.

        The asm text is rebuilt from the mnemonic, destination register name
        and the Python repr of the float.
        """
        if insn.id not in {ARM_INS_FCONSTS, ARM_INS_FCONSTD}:
            return False
        ops = list(insn.operands)
        if len(ops) != 2 or ops[1].type != ARM_OP_FP:
            return False
        fval = ops[1].value.fp
        destname = insn.reg_name(ops[0].value.reg)
        asm = f"{insn.mnemonic} {destname}, {fval}"
        return asm, fval

    def check_branch_symbol(self, insn):
        """Return branch asm with a symbolic target, else ``False``.

        Targets inside the current function (``self.funcrange``, entry
        included) become ``SELF+<offset>``; other targets are looked up as
        PLT stub names.
        """
        if insn.id not in {ARM_INS_B, ARM_INS_BL, ARM_INS_BLX}:
            return False
        ops = list(insn.operands)
        if len(ops) != 1 or ops[0].type != ARM_OP_IMM:
            return False
        addr = ops[0].value.imm
        start, end = self.funcrange
        if start <= addr < end:
            func = f"SELF+{hex(addr - start)}"
        else:
            func = self.loader.find_plt_stub_name(addr)
            if func is None:
                # NOTE(review): retries 4 bytes further on; presumably the
                # branch lands just before the stub cle knows about — confirm.
                func = self.loader.find_plt_stub_name(addr + 4)
            if func is None:
                return False
        return f"{insn.mnemonic} <{func}>"

    def disassemble(self, funcname):
        """Disassemble ``funcname`` into a single ``"; "``-joined string.

        Instructions that yield a float constant get its symbolic name
        appended as an extra operand; recognised branches are rewritten with
        a symbolic target.
        """
        funcaddr, funcbytes = self.get_function_bytes(funcname)
        disassm = []

        for insn in self.md.disasm(funcbytes, funcaddr):
            if insn.address in self.constaddrs:
                # These bytes were already read as constant data.
                continue

            cname = None
            asm = None

            if vldr := self.check_vldr(insn):
                fval, faddr, fsize = vldr
                cname = self.add_constant(fval, faddr, fsize)
            elif ldrd := self.check_ldrd(insn):
                fval, faddr, fsize = ldrd
                cname = self.add_constant(fval, faddr, fsize)
            elif strfloat := self.check_float_store(insn):
                cname = self.add_constant(strfloat)
            elif vmovfloat := self.check_vmov(insn):
                asm, fval = vmovfloat
                cname = self.add_constant(fval)
            elif branch := self.check_branch_symbol(insn):
                asm = branch

            # Track register contents produced by immediate moves; any other
            # write makes the tracked value stale.
            if movimm := self.check_mov_imm(insn):
                reg, imm = movimm
                if insn.id == ARM_INS_MOVT:
                    # MOVT replaces bits [31:16] and keeps bits [15:0].
                    low = self.reg_values.get(reg, 0) & 0xFFFF
                    self.reg_values[reg] = low | ((imm & 0xFFFF) << 16)
                else:
                    self.reg_values[reg] = imm
            else:
                _, writes = insn.regs_access()
                for r in writes:
                    self.reg_values.pop(r, None)

            if not asm:
                asm = f"{insn.mnemonic} {insn.op_str}"
            if cname:
                asm += f", {cname}"
            disassm.append(asm)

        return "; ".join(disassm)
|
|
class DisassemblerAArch64(DisassemblerBase):
    """AArch64 disassembler that replaces float constants by names.

    Float constants are recognised from FMOV (immediate, or from a GPR whose
    value was tracked) and from LDR of S/D registers at resolvable addresses.
    """

    def __init__(self, binpath, expr_constants=None, match_constants=False):
        """Load ``binpath`` with cle and create a detailed AArch64 decoder.

        ``expr_constants`` / ``match_constants`` are forwarded to the base
        class; ``None`` means "no expression constants".
        """
        super().__init__(
            expr_constants={} if expr_constants is None else expr_constants,
            match_constants=match_constants,
        )
        self.md = Cs(CS_ARCH_ARM64, CS_MODE_ARM)
        self.md.detail = True
        self.loader = cle.Loader(binpath)

    def reg_size_type(self, reg):
        """Return ``(bit_width, python_type)`` for a scalar register id.

        W/X registers map to ``int``, S/D registers to ``float``; any other
        register returns ``(0, None)``.
        """
        if ARM64_REG_W0 <= reg <= ARM64_REG_W30:
            return 32, int
        # X29/X30 are tested by name: in some capstone releases they (the
        # FP/LR aliases) are numbered outside the contiguous X0..X28 run.
        if ARM64_REG_X0 <= reg <= ARM64_REG_X28 or reg in (ARM64_REG_X29, ARM64_REG_X30):
            return 64, int
        if ARM64_REG_S0 <= reg <= ARM64_REG_S31:
            return 32, float
        if ARM64_REG_D0 <= reg <= ARM64_REG_D31:
            return 64, float
        return 0, None

    def check_mov_imm(self, insn):
        """Return ``(dest_reg, value)`` for ADRP/ADR/MOV/MOVK with an immediate, else ``False``.

        An LSL-shifted immediate is shifted into place. MOVK merges its 16-bit
        field into the destination's tracked value when one is known.
        """
        if insn.id not in {ARM64_INS_ADRP, ARM64_INS_ADR, ARM64_INS_MOV, ARM64_INS_MOVK}:
            return False
        ops = insn.operands
        if len(ops) != 2:
            return False
        if ops[0].type != ARM64_OP_REG or ops[1].type != ARM64_OP_IMM:
            return False

        imm = ops[1].value.imm
        if ops[1].shift.type == ARM64_SFT_LSL:
            imm <<= ops[1].shift.value
        # Field written by MOVK; always defined, so an unshifted MOVK works.
        mask = 0xFFFF << ops[1].shift.value

        if insn.id == ARM64_INS_ADR:
            # NOTE(review): hard-coded 0x400000 correction, presumably tied to
            # the loader's base address — confirm.
            imm -= 0x400000
            imm += insn.address + 4
        elif insn.id == ARM64_INS_MOVK:
            dest = ops[0].value.reg
            if dest in self.reg_values:
                # MOVK keeps every bit outside its 16-bit field.
                imm = (imm & mask) | (self.reg_values[dest] & ~mask)
        # ADRP: the immediate is used unchanged (presumably capstone already
        # reports the absolute page address).
        return ops[0].value.reg, imm

    def check_fmov(self, insn):
        """Return ``(asm, value)`` for an FMOV that yields a known float, else ``False``.

        * FP immediate: the asm is rebuilt with the Python float repr.
        * GPR source with a tracked value: its low 32/64 bits (the
          destination width) are reinterpreted as a float; magnitudes below
          1e-5 or above 1e5 are rejected (heuristic filter).
        """
        if insn.id != ARM64_INS_FMOV:
            return False
        ops = insn.operands
        if len(ops) != 2:
            return False

        destsize, _ = self.reg_size_type(ops[0].value.reg)
        destname = insn.reg_name(ops[0].value.reg)
        if ops[1].type == ARM64_OP_FP:
            fval = ops[1].value.fp
            return f"{insn.mnemonic} {destname}, {fval}", fval
        if ops[1].type != ARM64_OP_REG:
            # No recoverable constant for any other operand kind.
            return False

        reg = ops[1].value.reg
        if reg not in self.reg_values or destsize not in (32, 64):
            return False
        # Mask to the destination width: handles two's-complement (negative)
        # tracked values and never overflows the int->float conversion.
        bits = self.reg_values[reg] & ((1 << destsize) - 1)
        fval = int2fp64(bits) if destsize == 64 else int2fp32(bits)
        if abs(fval) < 1e-5 or abs(fval) > 1e5:
            return False
        return f"{insn.mnemonic} {insn.op_str}", fval

    def check_ldr(self, insn):
        """Return ``(value, address, size)`` for an LDR into S/D, else ``False``.

        Only ``[base, #disp]`` forms whose base register value is tracked are
        resolved (base aligned down to 4). Register-indexed forms are
        rejected because the index value is unknown.
        """
        if insn.id != ARM64_INS_LDR:
            return False
        ops = insn.operands
        if len(ops) < 2:
            return False
        destsize, desttype = self.reg_size_type(ops[0].value.reg)
        if desttype is not float:
            return False
        if ops[1].type != ARM64_OP_MEM:
            return False
        mem = ops[1].value.mem
        if mem.index != ARM64_REG_INVALID or mem.base not in self.reg_values:
            return False
        addr = align4(self.reg_values[mem.base]) + mem.disp
        size = destsize // 8
        if addr < self.loader.min_addr or addr + size > self.loader.max_addr:
            return False
        fhex = self.loader.memory.load(addr, size)
        # AArch64 data is little-endian; do not depend on host byte order.
        fval = struct.unpack("<d" if size == 8 else "<f", fhex)[0]
        return fval, addr, size

    def check_branch_symbol(self, insn):
        """Return branch asm with a symbolic target, else ``False``.

        Targets inside the current function (``self.funcrange``, entry
        included) become ``SELF+<offset>``; other targets are looked up as
        PLT stub names.
        """
        if insn.id not in {ARM64_INS_BL, ARM64_INS_B}:
            return False
        ops = insn.operands
        if len(ops) != 1 or ops[0].type != ARM64_OP_IMM:
            return False
        addr = ops[0].value.imm
        start, end = self.funcrange
        if start <= addr < end:
            func = f"SELF+{hex(addr - start)}"
        else:
            func = self.loader.find_plt_stub_name(addr)
            if func is None:
                # NOTE(review): retries 4 bytes further on — confirm why.
                func = self.loader.find_plt_stub_name(addr + 4)
            if func is None:
                return False
        return f"{insn.mnemonic} <{func}>"

    def disassemble(self, funcname):
        """Disassemble ``funcname`` into a single ``"; "``-joined string.

        Instructions that yield a float constant get its symbolic name
        appended as an extra operand; recognised branches are rewritten with
        a symbolic target.
        """
        funcaddr, funcbytes = self.get_function_bytes(funcname)
        disassm = []

        for insn in self.md.disasm(funcbytes, funcaddr):
            if insn.address in self.constaddrs:
                # These bytes were already read as constant data.
                continue

            cname = None
            asm = None

            # Track immediate moves; any other register write invalidates
            # the tracked value.
            if movimm := self.check_mov_imm(insn):
                reg, imm = movimm
                self.reg_values[reg] = imm
            else:
                _, writes = insn.regs_access()
                for r in writes:
                    self.reg_values.pop(r, None)

            if fmov := self.check_fmov(insn):
                asm, fval = fmov
                cname = self.add_constant(fval)
            elif ldr := self.check_ldr(insn):
                fval, faddr, fsize = ldr
                cname = self.add_constant(fval, faddr, fsize)
            elif branch := self.check_branch_symbol(insn):
                asm = branch

            if not asm:
                asm = f"{insn.mnemonic} {insn.op_str}"
            if cname:
                asm += f", {cname}"
            disassm.append(asm)

        return "; ".join(disassm)
|
|
class DisassemblerX64(DisassemblerBase):
    """x86-64 disassembler that replaces RIP-relative float loads by names."""

    def __init__(self, binpath, expr_constants=None, match_constants=False):
        """Load ``binpath`` with cle and create a detailed x86-64 decoder.

        ``expr_constants`` / ``match_constants`` are forwarded to the base
        class; ``None`` means "no expression constants".
        """
        super().__init__(
            expr_constants={} if expr_constants is None else expr_constants,
            match_constants=match_constants,
        )
        self.md = Cs(CS_ARCH_X86, CS_MODE_64)
        self.md.detail = True
        self.loader = cle.Loader(binpath)

    def check_call_symbol(self, insn):
        """Return ``call <name>`` asm for a direct call to a PLT stub, else ``False``."""
        if insn.id != X86_INS_CALL:
            return False
        ops = insn.operands
        if len(ops) != 1 or ops[0].type != X86_OP_IMM:
            return False
        func = self.loader.find_plt_stub_name(ops[0].value.imm)
        if func is None:
            return False
        return f"{insn.mnemonic} <{func}>"

    def check_fload(self, insn):
        """Return ``(value, address, size)`` for a RIP-relative 4/8-byte operand, else ``False``.

        Any instruction with exactly one RIP-relative memory operand of size
        4 or 8 is decoded as a little-endian float/double respectively.
        """
        memops = [op for op in insn.operands
                  if op.type == X86_OP_MEM and op.value.mem.base == X86_REG_RIP]
        if len(memops) != 1:
            return False
        mem, size = memops[0].value.mem, memops[0].size
        # Only single/double precision widths can be decoded; other sizes
        # (e.g. byte/word operands) would make struct.unpack raise.
        if size not in (4, 8):
            return False
        # RIP-relative displacements are relative to the next instruction.
        addr = insn.address + insn.size + mem.disp
        if addr < self.loader.min_addr or addr + size > self.loader.max_addr:
            return False
        fhex = self.loader.memory.load(addr, size)
        fval = struct.unpack("<f" if size == 4 else "<d", fhex)[0]
        return fval, addr, size

    def disassemble(self, funcname):
        """Disassemble ``funcname`` into a single ``"; "``-joined string.

        RIP-relative float loads get the constant's symbolic name appended as
        an extra operand; direct PLT calls are rewritten as ``<symbol>``.
        """
        funcaddr, funcbytes = self.get_function_bytes(funcname)
        disassm = []

        for insn in self.md.disasm(funcbytes, funcaddr):
            asm = None
            cname = None
            if fload := self.check_fload(insn):
                fval, faddr, fsize = fload
                cname = self.add_constant(fval, faddr, fsize)
            elif call := self.check_call_symbol(insn):
                asm = call

            if not asm:
                asm = f"{insn.mnemonic} {insn.op_str}"
            if cname:
                asm += f", {cname}"
            disassm.append(asm)

        return "; ".join(disassm)
|
|
|
|
| |
if __name__ == "__main__":
    import argparse

    # The first positional argument of ArgumentParser is ``prog``; pass the
    # text as the description so the usage line shows the real program name.
    parser = argparse.ArgumentParser(
        description="Pre-process assembly to replace constants and dump")
    parser.add_argument("--bin", required=True)
    parser.add_argument("--func", required=True)
    # Restricting the choices makes an unknown architecture a usage error
    # instead of a NameError on an undefined disassembler.
    parser.add_argument("--arch", required=True,
                        choices=["arm32", "aarch64", "x64"])
    args = parser.parse_args()

    disassemblers = {
        "arm32": DisassemblerARM32,
        "aarch64": DisassemblerAArch64,
        "x64": DisassemblerX64,
    }
    D = disassemblers[args.arch](args.bin)
    diss = D.disassemble(args.func)
    print(diss)
    print(D.constants)
|
|