from ufastrsa import srandom try: from _crypto import NUMBER as tomsfastmath pow3_ = tomsfastmath.exptmod invmod_ = tomsfastmath.invmod generate_prime_ = tomsfastmath.generate_prime def genprime(num=1024, test=25, safe=False): return generate_prime_(num, test, safe) except ImportError: pow3_ = pow def invmod_(a, b): c, d, e, f, g = 1, 0, 0, 1, b while b: q = a // b a, c, d, b, e, f = b, e, f, a - q * b, c - q * e, d - q * f assert a >= 0 and c % g >= 0 return a == 1 and c % g or 0 def miller_rabin_pass(a, n): n_minus_one = n - 1 s, d = get_lowest_set_bit(n_minus_one) a_to_power = pow3(a, d, n) if a_to_power == 1: return True for i in range(s): if a_to_power == n_minus_one: return True a_to_power = pow3(a_to_power, 2, n) if a_to_power == n_minus_one: return True return False class MillerRabinTest: def __init__(self, randint, repeat): self.randint = randint self.repeat = repeat def __call__(self, n): randint = self.randint n_minus_one = n - 1 for repeat in range(self.repeat): a = randint(1, n_minus_one) if not miller_rabin_pass(a, n): return False return True class GenPrime: def __init__(self, getrandbits, testfn): self.getrandbits = getrandbits self.testfn = testfn def __call__(self, bits): getrandbits = self.getrandbits testfn = self.testfn while True: p = (1 << (bits - 1)) | getrandbits(bits - 1) | 1 if p % 3 != 0 and p % 5 != 0 and p % 7 != 0 and testfn(p): break return p miller_rabin_test = MillerRabinTest(srandom.randint, 25) genprime = GenPrime(srandom.getrandbits, miller_rabin_test) def pow3(x, y, z): return pow3_(x, y, z) def invmod(a, b): return invmod_(a, b) def get_lowest_set_bit(n): i = 0 while n: if n & 1: return i, n n >>= 1 i += 1 raise "Error" def gcd(a, b): while b: a, b = b, a % b return a def get_bit_length(n): return srandom.get_bit_length(n) class GenRSA: def __init__(self, genprime): self.genprime = genprime def __call__(self, bits, e=None, with_crt=False): pbits = (bits + 1) >> 1 qbits = bits - pbits if e is None: e = 65537 elif e < 0: e = self.genprime(-e) while True: p = self.genprime(pbits) if gcd(e, p - 1) == 1: break while True: while True: q = self.genprime(qbits) if gcd(e, q - 1) == 1 and p != q: break n = p * q if get_bit_length(n) == bits: break p = max(p, q) p_minus_1 = p - 1 q_minus_1 = q - 1 phi = p_minus_1 * q_minus_1 d = invmod(e, phi) if with_crt: dp = d % p_minus_1 dq = d % q_minus_1 qinv = invmod(q, p) assert qinv < p return bits, n, e, d, p, q, dp, dq, qinv else: return bits, n, e, d genrsa = GenRSA(genprime)