Spaces:
Runtime error
Runtime error
| from __future__ import absolute_import | |
| import hashlib | |
| import inspect | |
| import os | |
| import re | |
| import sys | |
| from distutils.core import Distribution, Extension | |
| from distutils.command.build_ext import build_ext | |
| import Cython | |
| from ..Compiler.Main import Context, default_options | |
| from ..Compiler.Visitor import CythonTransform, EnvTransform | |
| from ..Compiler.ParseTreeTransforms import SkipDeclarations | |
| from ..Compiler.TreeFragment import parse_from_strings | |
| from ..Compiler.StringEncoding import _unicode | |
| from .Dependencies import strip_string_literals, cythonize, cached_function | |
| from ..Compiler import Pipeline | |
| from ..Utils import get_cython_cache_dir | |
| import cython as cython_module | |
IS_PY3 = sys.version_info >= (3,)

# Helper that turns user-supplied ASCII strings into text (unicode) objects.
# On Python 3 every str is already text, so the helper is the identity.
if IS_PY3:
    def to_unicode(s):
        """Return *s* unchanged (Python 3 strings are already unicode)."""
        return s
else:
    def to_unicode(s):
        """Decode ASCII-encoded bytes to unicode; return anything else unchanged."""
        return s.decode('ascii') if isinstance(s, bytes) else s
# Load a module directly from a file path, using whichever loader API the
# running interpreter provides.
if sys.version_info >= (3, 5):
    import importlib.util as _importlib_util

    def load_dynamic(name, module_path):
        """Create, execute and return module *name* loaded from *module_path*.

        The resulting module object is returned to the caller only; it is not
        inserted into ``sys.modules`` here.
        """
        module_spec = _importlib_util.spec_from_file_location(name, module_path)
        loaded = _importlib_util.module_from_spec(module_spec)
        module_spec.loader.exec_module(loaded)
        return loaded
else:
    import imp

    def load_dynamic(name, module_path):
        """Load the extension module *name* from *module_path* via ``imp``."""
        return imp.load_dynamic(name, module_path)
class UnboundSymbols(EnvTransform, SkipDeclarations):
    """Tree transform that collects names not resolvable in their scope.

    Calling an instance on a tree runs the transform and returns the set of
    names whose lookup in the current environment failed.
    """

    def __init__(self):
        # Initialised through CythonTransform directly, with a None context.
        CythonTransform.__init__(self, None)
        self.unbound = set()

    def visit_NameNode(self, node):
        """Record *node*'s name when the current scope cannot look it up."""
        entry = self.current_env().lookup(node.name)
        if not entry:
            self.unbound.add(node.name)
        return node

    def __call__(self, node):
        """Run the transform over *node* and return the collected names."""
        super(UnboundSymbols, self).__call__(node)
        return self.unbound
def unbound_symbols(code, context=None):
    """Return the names referenced by *code* that it neither binds nor gets from builtins.

    *code* is parsed as a tree fragment and run through the 'pyx' pipeline
    phases up to and including AnalyseDeclarationsTransform; UnboundSymbols
    then collects the names the analysed scopes cannot resolve. A default
    Context is created when *context* is None. The result is a tuple in no
    particular order.
    """
    code = to_unicode(code)
    if context is None:
        context = Context([], default_options)
    from ..Compiler.ParseTreeTransforms import AnalyseDeclarationsTransform
    tree = parse_from_strings('(tree fragment)', code)
    pipeline = Pipeline.create_pipeline(context, 'pyx')
    for phase in pipeline:
        if phase is None:
            continue
        tree = phase(tree)
        if isinstance(phase, AnalyseDeclarationsTransform):
            # Stop once declaration analysis has run.
            break
    try:
        import builtins
    except ImportError:  # Python 2
        import __builtin__ as builtins
    unresolved = UnboundSymbols()(tree)
    return tuple(unresolved - set(dir(builtins)))
def unsafe_type(arg, context=None):
    """Like safe_type(), but map values of exact type ``int`` to ``'long'``.

    NOTE(review): presumably "unsafe" because a Python int may not fit in a
    C long — confirm.
    """
    if type(arg) is not int:
        return safe_type(arg, context)
    return 'long'
def safe_type(arg, context=None):
    """Return a Cython type declaration string for *arg*'s runtime type.

    ``list``/``tuple``/``dict``/``str`` map to their own names, ``complex``
    to ``'double complex'``, ``float`` to ``'double'``, ``bool`` to
    ``'bint'``, and NumPy arrays (only if numpy is already imported) to a
    typed ``numpy.ndarray`` buffer declaration.

    For any other object the MRO is walked. The first base class that
    *context* resolves to a type entry in its module is returned as
    ``'module.Name'``. Reaching a builtin base returns ``'object'``. When
    *context* is None no module lookups are attempted, so the result is
    ``'object'``.
    """
    py_type = type(arg)
    if py_type in (list, tuple, dict, str):
        return py_type.__name__
    elif py_type is complex:
        return 'double complex'
    elif py_type is float:
        return 'double'
    elif py_type is bool:
        return 'bint'
    elif 'numpy' in sys.modules and isinstance(arg, sys.modules['numpy'].ndarray):
        return 'numpy.ndarray[numpy.%s_t, ndim=%s]' % (arg.dtype.name, arg.ndim)
    for base_type in py_type.__mro__:
        if base_type.__module__ in ('__builtin__', 'builtins'):
            return 'object'
        # Without a compilation context there is nothing to look types up in
        # (previously this crashed with AttributeError on None.find_module).
        if context is None:
            continue
        module = context.find_module(base_type.__module__, need_pxd=False)
        if module:
            entry = module.lookup(base_type.__name__)
            # lookup() can find nothing for this name; only a type entry counts.
            if entry is not None and entry.is_type:
                return '%s.%s' % (base_type.__module__, base_type.__name__)
    return 'object'
def _get_build_extension():
    """Return a finalized ``build_ext`` command bound to a fresh Distribution.

    The distutils configuration files are located and parsed first, so the
    build respects that configuration.
    """
    distribution = Distribution()
    distribution.parse_config_files(distribution.find_config_files())
    command = build_ext(distribution)
    command.finalize_options()
    return command
def _create_context(cython_include_dirs):
    """Build a compilation Context that searches *cython_include_dirs*."""
    include_dirs = list(cython_include_dirs)
    return Context(include_dirs, default_options)


# Per-session cache used by cython_inline: maps source code to its tuple of
# unbound symbol names, and (code, arg_sigs, key_hash) to the compiled
# __invoke function.
_cython_inline_cache = {}
# Context used when cython_inline is called without cython_include_dirs.
_cython_inline_default_context = _create_context(('.',))
| def _populate_unbound(kwds, unbound_symbols, locals=None, globals=None): | |
| for symbol in unbound_symbols: | |
| if symbol not in kwds: | |
| if locals is None or globals is None: | |
| calling_frame = inspect.currentframe().f_back.f_back.f_back | |
| if locals is None: | |
| locals = calling_frame.f_locals | |
| if globals is None: | |
| globals = calling_frame.f_globals | |
| if symbol in locals: | |
| kwds[symbol] = locals[symbol] | |
| elif symbol in globals: | |
| kwds[symbol] = globals[symbol] | |
| else: | |
| print("Couldn't find %r" % symbol) | |
def _inline_key(orig_code, arg_sigs, language_level):
    """Return a hex SHA-1 digest identifying one compiled inline module.

    The digest covers the source, the argument signatures, the interpreter
    version and executable, the language level and the Cython version, so a
    change in any of them is expected to produce a different module name.
    """
    key_parts = (orig_code, arg_sigs, sys.version_info, sys.executable,
                 language_level, Cython.__version__)
    digest = hashlib.sha1(_unicode(key_parts).encode('utf-8'))
    return digest.hexdigest()
def cython_inline(code, get_type=unsafe_type,
                  lib_dir=os.path.join(get_cython_cache_dir(), 'inline'),
                  cython_include_dirs=None, cython_compiler_directives=None,
                  force=False, quiet=False, locals=None, globals=None, language_level=None, **kwds):
    """Compile *code* as the body of a Cython function and run it.

    Parameters:
        code: source text of the function body.
        get_type: callable ``(value, context)`` returning a Cython type string
            for an argument. None types every argument as ``object``.
        lib_dir: directory that receives the generated .pyx and the built
            extension module.
        cython_include_dirs: Cython include path. It is also used to build the
            compilation Context.
        cython_compiler_directives: directives passed to cythonize. The dict
            is copied, and ``language_level`` is added to the copy.
        force: rebuild even if the extension file already exists.
        quiet: suppress the "could not parse" message and cythonize output.
        locals, globals: mappings used to resolve names *code* uses but does
            not define. They default to those of the frame two levels above
            this function.
        language_level: defaults to '3str' unless the directives set one.
        **kwds: explicit values for names used in *code*. A value that is the
            ``cython`` module itself becomes ``cimport cython as <name>``.

    Returns the value returned by the compiled body, or its ``locals()``
    dict if the body does not return.
    """
    if get_type is None:
        # get_type is always called as get_type(value, ctx) below, so the
        # fallback must accept the context argument too.
        get_type = lambda x, context=None: 'object'
    ctx = _create_context(tuple(cython_include_dirs)) if cython_include_dirs else _cython_inline_default_context

    cython_compiler_directives = dict(cython_compiler_directives) if cython_compiler_directives else {}
    if language_level is None and 'language_level' not in cython_compiler_directives:
        language_level = '3str'
    if language_level is not None:
        cython_compiler_directives['language_level'] = language_level

    # Fast path if this has been called in this session.
    _unbound_symbols = _cython_inline_cache.get(code)
    if _unbound_symbols is not None:
        _populate_unbound(kwds, _unbound_symbols, locals, globals)
        args = sorted(kwds.items())
        arg_sigs = tuple([(get_type(value, ctx), arg) for arg, value in args])
        key_hash = _inline_key(code, arg_sigs, language_level)
        invoke = _cython_inline_cache.get((code, arg_sigs, key_hash))
        if invoke is not None:
            arg_list = [arg[1] for arg in args]
            return invoke(*arg_list)

    orig_code = code
    code = to_unicode(code)
    code, literals = strip_string_literals(code)
    code = strip_common_indent(code)
    # Frame depth matters here: f_back.f_back is the caller of our caller.
    if locals is None:
        locals = inspect.currentframe().f_back.f_back.f_locals
    if globals is None:
        globals = inspect.currentframe().f_back.f_back.f_globals
    try:
        _cython_inline_cache[orig_code] = _unbound_symbols = unbound_symbols(code)
        _populate_unbound(kwds, _unbound_symbols, locals, globals)
    except AssertionError:
        if not quiet:
            # Parsing from strings not fully supported (e.g. cimports).
            print("Could not parse code as a string (to extract unbound symbols).")

    cimports = []
    for name, arg in list(kwds.items()):
        if arg is cython_module:
            cimports.append('\ncimport cython as %s' % name)
            del kwds[name]
    arg_names = sorted(kwds)
    arg_sigs = tuple([(get_type(kwds[arg], ctx), arg) for arg in arg_names])
    key_hash = _inline_key(orig_code, arg_sigs, language_level)
    module_name = "_cython_inline_" + key_hash

    if module_name in sys.modules:
        module = sys.modules[module_name]

    else:
        build_extension = None
        if cython_inline.so_ext is None:
            # Figure out and cache current extension suffix
            build_extension = _get_build_extension()
            cython_inline.so_ext = build_extension.get_ext_filename('')

        module_path = os.path.join(lib_dir, module_name + cython_inline.so_ext)

        if not os.path.isdir(lib_dir):
            try:
                os.makedirs(lib_dir)
            except OSError:
                # Tolerate another process creating the directory meanwhile.
                if not os.path.isdir(lib_dir):
                    raise
        if force or not os.path.isfile(module_path):
            cflags = []
            c_include_dirs = []
            qualified = re.compile(r'([.\w]+)[.]')
            for arg_type, _ in arg_sigs:
                m = qualified.match(arg_type)
                if m:
                    cimports.append('\ncimport %s' % m.groups()[0])
                    # one special case
                    if m.groups()[0] == 'numpy':
                        import numpy
                        c_include_dirs.append(numpy.get_include())
            module_body, func_body = extract_func_code(code)
            params = ', '.join(['%s %s' % a for a in arg_sigs])
            module_code = """
%(module_body)s
%(cimports)s
def __invoke(%(params)s):
%(func_body)s
    return locals()
""" % {'cimports': '\n'.join(cimports),
       'module_body': module_body,
       'params': params,
       'func_body': func_body}
            for key, value in literals.items():
                module_code = module_code.replace(key, value)
            pyx_file = os.path.join(lib_dir, module_name + '.pyx')
            # Write with an explicit UTF-8 encoding instead of the locale
            # default, so non-ASCII source text survives the round trip.
            import io
            with io.open(pyx_file, 'w', encoding='utf-8') as fh:
                fh.write(module_code)
            extension = Extension(
                name=module_name,
                sources=[pyx_file],
                include_dirs=c_include_dirs,
                extra_compile_args=cflags)
            if build_extension is None:
                build_extension = _get_build_extension()
            build_extension.extensions = cythonize(
                [extension],
                include_path=cython_include_dirs or ['.'],
                compiler_directives=cython_compiler_directives,
                quiet=quiet)
            build_extension.build_temp = os.path.dirname(pyx_file)
            build_extension.build_lib = lib_dir
            build_extension.run()

        module = load_dynamic(module_name, module_path)

    _cython_inline_cache[orig_code, arg_sigs, key_hash] = module.__invoke
    arg_list = [kwds[arg] for arg in arg_names]
    return module.__invoke(*arg_list)


# Cached suffix used by cython_inline above.  None should get
# overridden with actual value upon the first cython_inline invocation
cython_inline.so_ext = None
_find_non_space = re.compile('[^ ]').search


def strip_common_indent(code):
    """Remove the smallest common leading-space indent from the lines of *code*.

    Only spaces count as indentation. Blank lines and comment lines (whose
    first non-space character is ``#``) do not contribute to the common
    indent. They are also left unchanged, because a comment may be indented
    less than the surrounding code. The lines are re-joined with newlines.
    """
    min_indent = None
    lines = code.splitlines()
    for line in lines:
        match = _find_non_space(line)
        if not match:
            continue  # blank
        indent = match.start()
        if line[indent] == '#':
            continue  # comment
        if min_indent is None or min_indent > indent:
            min_indent = indent
    for ix, line in enumerate(lines):
        match = _find_non_space(line)
        # Skip blank and comment lines, judging each line by its OWN first
        # non-space character (not a leftover indent from the loop above).
        if not match or line[match.start()] == '#':
            continue
        lines[ix] = line[min_indent:]
    return '\n'.join(lines)
# Unindented lines that belong at module level of the generated .pyx rather
# than inside the function body: cdef extern/class, cimports and star imports.
module_statement = re.compile(r'^((cdef +(extern|class))|cimport|(from .+ cimport)|(from .+ import +[*]))')


def extract_func_code(code):
    """Split *code* into ``(module_source, function_body_source)``.

    Tabs are first replaced by single spaces. Each unindented line decides
    where it, and the indented lines that follow it, go. Lines matching
    ``module_statement`` start a module-level section; any other unindented
    line starts a function-body section. Every line of the function part is
    prefixed with four spaces.
    """
    code = code.replace('\t', ' ')
    module_lines = []
    function_lines = []
    target = function_lines
    for line in code.split('\n'):
        if not line.startswith(' '):
            target = module_lines if module_statement.match(line) else function_lines
        target.append(line)
    indent = '    '
    return '\n'.join(module_lines), indent + ('\n' + indent).join(function_lines)
try:
    from inspect import getcallargs
except ImportError:
    def getcallargs(func, *arg_values, **kwd_values):
        """Fallback used when ``inspect.getcallargs`` cannot be imported.

        Bind *arg_values* and *kwd_values* to *func*'s parameters the way a
        call would, and return a dict mapping parameter name to value.
        Raises TypeError for duplicate, unexpected or missing arguments.
        """
        bound = {}
        args, varargs, kwds, defaults = inspect.getargspec(func)
        if varargs is not None:
            bound[varargs] = arg_values[len(args):]
        bound.update(zip(args, arg_values))
        for name in list(kwd_values):
            if name not in args:
                continue
            if name in bound:
                raise TypeError("Duplicate argument %s" % name)
            bound[name] = kwd_values.pop(name)
        if kwds is not None:
            bound[kwds] = kwd_values
        elif kwd_values:
            raise TypeError("Unexpected keyword arguments: %s" % list(kwd_values))
        defaults = defaults or ()
        first_default = len(args) - len(defaults)
        for ix, name in enumerate(args):
            if name in bound:
                continue
            if ix < first_default:
                raise TypeError("Missing argument: %s" % name)
            bound[name] = defaults[ix - first_default]
        return bound
def get_body(source):
    """Return the body of a function's source text.

    Everything after the first ``:`` in *source* is taken as the body. If
    *source* starts with ``lambda``, the body is an expression and is turned
    into a ``return`` statement.

    Raises ValueError if *source* contains no ``:``.
    """
    ix = source.index(':')
    # 'lambda' is six characters long; comparing a five-character slice
    # against it (as before) could never match.
    if source.startswith('lambda'):
        return "return %s" % source[ix+1:]
    else:
        return source[ix+1:]
# Lots to be done here... It would be especially cool if compiled functions
# could invoke each other quickly.
class RuntimeCompiledFunction(object):
    """Wrap a Python function so each call runs its body through cython_inline.

    The body text is captured once, from the function's source, when the
    wrapper is created.
    """

    def __init__(self, f):
        self._f = f
        self._body = get_body(inspect.getsource(f))

    def __call__(self, *args, **kwds):
        """Bind the call's arguments to the parameters of the wrapped function and run it inline."""
        bound = getcallargs(self._f, *args, **kwds)
        # The wrapped function's module globals serve as both locals and globals.
        func_globals = self._f.__globals__ if IS_PY3 else self._f.func_globals
        return cython_inline(self._body, locals=func_globals, globals=func_globals, **bound)