| from __future__ import absolute_import
|
|
|
| import os
|
| import re
|
| import unittest
|
| import shlex
|
| import sys
|
| import tempfile
|
| import textwrap
|
| from io import open
|
| from functools import partial
|
|
|
| from .Compiler import Errors
|
| from .CodeWriter import CodeWriter
|
| from .Compiler.TreeFragment import TreeFragment, strip_common_indent
|
| from .Compiler.Visitor import TreeVisitor, VisitorTransform
|
| from .Compiler import TreePath
|
|
|
|
|
class NodeTypeWriter(TreeVisitor):
    """Visitor that records one "<name>: <NodeClass>" line per visited node.

    The lines are collected in ``self.result``; each line is prefixed with
    one indent unit per nesting level of the node in the tree.
    """

    def __init__(self):
        super(NodeTypeWriter, self).__init__()
        self._indents = 0   # current nesting depth
        self.result = []    # output lines, in visiting order

    def visit_Node(self, node):
        """Append the line for `node`, then descend into its children."""
        path = self.access_path
        if path:
            # The last access path entry describes how this node was reached:
            # item 1 is the attribute name, item 2 an optional list index.
            tip = path[-1]
            name = tip[1] if tip[2] is None else u"%s[%d]" % tip[1:3]
        else:
            name = u"(root)"

        label = u"%s: %s" % (name, node.__class__.__name__)
        self.result.append(u" " * self._indents + label)

        self._indents += 1
        self.visitchildren(node)
        self._indents -= 1
|
|
|
|
|
def treetypes(root):
    """Return a string that describes the tree below `root` by class names.

    The text starts and ends with a newline so that it can be compared
    against a triple-quoted literal with a plain string comparison while
    the test cases still look tidy.
    """
    writer = NodeTypeWriter()
    writer.visit(root)
    # Empty first and last items produce the leading and trailing newline.
    lines = [u""]
    lines.extend(writer.result)
    lines.append(u"")
    return u"\n".join(lines)
|
|
|
|
|
class CythonTest(unittest.TestCase):
    """Base class for unit tests with helpers for comparing code and trees.

    ``Errors.init_thread()`` is called before and after every test.
    """

    def setUp(self):
        Errors.init_thread()

    def tearDown(self):
        Errors.init_thread()

    def assertLines(self, expected, result):
        "Checks that the given strings or lists of strings are equal line by line"
        if not isinstance(expected, list):
            expected = expected.split(u"\n")
        if not isinstance(result, list):
            result = result.split(u"\n")
        for idx, (expected_line, result_line) in enumerate(zip(expected, result)):
            self.assertEqual(expected_line, result_line,
                             "Line %d:\nExp: %s\nGot: %s" % (idx, expected_line, result_line))
        # zip() stops at the shorter input, so compare the lengths as well.
        # "Got" must show the result and "Expected" the expectation.
        self.assertEqual(len(expected), len(result),
                         "Unmatched lines. Got:\n%s\nExpected:\n%s" % (
                             u"\n".join(result), u"\n".join(expected)))

    def codeToLines(self, tree):
        """Serialise `tree` with a CodeWriter and return the written lines."""
        writer = CodeWriter()
        writer.write(tree)
        return writer.result.lines

    def codeToString(self, tree):
        """Serialise `tree` with a CodeWriter into one newline-joined string."""
        return "\n".join(self.codeToLines(tree))

    def assertCode(self, expected, result_tree):
        """Check that `result_tree` serialises to the code string `expected`.

        The common indentation is stripped from `expected` before the
        line-by-line comparison.
        """
        result_lines = self.codeToLines(result_tree)

        expected_lines = strip_common_indent(expected.split("\n"))

        for idx, (line, expected_line) in enumerate(zip(result_lines, expected_lines)):
            self.assertEqual(expected_line, line,
                             "Line %d:\nGot: %s\nExp: %s" % (idx, line, expected_line))
        self.assertEqual(len(result_lines), len(expected_lines),
                         "Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(result_lines), expected))

    def assertNodeExists(self, path, result_tree):
        """Fail unless TreePath finds a node for `path` in `result_tree`."""
        self.assertNotEqual(TreePath.find_first(result_tree, path), None,
                            "Path '%s' not found in result tree" % path)

    def fragment(self, code, pxds=None, pipeline=None):
        "Simply create a tree fragment using the name of the test-case in parse errors."
        if pxds is None:
            pxds = {}
        if pipeline is None:
            pipeline = []
        name = self.id()
        if name.startswith("__main__."):
            name = name[len("__main__."):]
        name = name.replace(".", "_")
        return TreeFragment(code, name, pxds, pipeline=pipeline)

    def treetypes(self, root):
        """Method shortcut for the module-level function `treetypes()`."""
        return treetypes(root)

    def should_fail(self, func, exc_type=Exception):
        """Calls "func" and fails if it doesn't raise the right exception
        (any exception by default). Also returns the exception in question.
        """
        try:
            func()
        except exc_type as e:
            self.assertTrue(isinstance(e, exc_type))
            return e
        else:
            # Must stay outside the "try" block: the AssertionError raised by
            # fail() is itself an Exception and would otherwise be caught and
            # returned by the default "except Exception" handler above.
            self.fail("Expected an exception of type %r" % exc_type)

    def should_not_fail(self, func):
        """Calls func and succeeds if and only if no exception is raised
        (i.e. converts exception raising into a failed testcase). Returns
        the return value of func."""
        try:
            return func()
        except Exception as exc:
            self.fail(str(exc))
|
|
|
|
|
class TransformTest(CythonTest):
    """
    Utility base class for unit tests of tree transforms.

    A test builds an input tree (by parsing a Cython code string), runs the
    transform on it, serialises the outcome with a customised Cython
    serialiser (which uses special markup for nodes that cannot be
    represented in Cython) and compares the text line by line.

    To create a test case:
     - Call run_pipeline. The pipeline should at least contain the transform
       under test; pyx is the Cython code string that gets parsed into the
       input tree. Declarations can be provided through the optional pxds
       dictionary. The return value is the transformed tree.

     - Check that the tree is correct. If wanted, assertCode can be used,
       which takes the expected code string and the resulting ModuleNode
       (it serialises the ModuleNode to a string and compares line-by-line).

    All code strings are first stripped of whitespace lines and then of
    their common indentation.
    """

    def run_pipeline(self, pipeline, pyx, pxds=None):
        """Parse `pyx` and pass the tree through each callable in `pipeline`."""
        tree = self.fragment(pyx, {} if pxds is None else pxds).root
        for transform in pipeline:
            tree = transform(tree)
        return tree
|
|
|
|
|
|
|
|
|
|
|
|
|
# Callable that removes C block comments spanning more than one line from a
# string; "/* ... */" comments on a single line are kept.  The pattern is
# written in a readable layout and all whitespace is dropped before compiling.
_strip_c_comments = partial(
    re.compile(
        "".join(r'''
            /[*] (
                (?: [^*\n] | [*][^/] )*
                [\n]
                (?: [^*] | [*][^/] )*
            ) [*]/
        '''.split())
    ).sub,
    '',
)
|
|
|
# Callable that removes <pre class="... cython line ..."> elements and
# <style> elements from an HTML string.  Only runs of two or more whitespace
# characters are dropped from the pattern text, so each line of the literal
# must stay indented, while the single space in "<pre class" is preserved.
_strip_cython_code_from_html = partial(
    re.compile(
        re.sub(r'\s\s+', '', r'''
            (?:
                <pre class=["'][^"']*cython\s+line[^"']*["']\s*>
                (?:[^<]|<(?!/pre))+
                </pre>
            )|(?:
                <style[^>]*>
                (?:[^<]|<(?!/style))+
                </style>
            )
        ''')
    ).sub,
    '',
)
|
|
|
|
|
def _parse_pattern(pattern):
    """Split a pattern string of the form "[/start/][:[/end/]] pattern".

    Returns the tuple ``(start, end, pattern)`` in which `start` and `end`
    are None when the corresponding "/.../" part is missing.  A "/" that is
    preceded by a backslash does not terminate a "/.../" part.
    """
    def split_delimited(text):
        # 'text' starts with "/": cut at the next unescaped "/" after it.
        section, rest = re.split(r"(?<!\\)/", text[1:], maxsplit=1)
        return section, rest.strip()

    start = None
    end = None
    if pattern.startswith('/'):
        start, pattern = split_delimited(pattern)
    if not pattern.startswith(':'):
        return start, end, pattern

    pattern = pattern[1:].strip()
    if pattern.startswith("/"):
        end, pattern = split_delimited(pattern)
    return start, end, pattern
|
|
|
|
|
class TreeAssertVisitor(VisitorTransform):
    """Evaluate the test assertion directives found in a tree.

    The directives 'test_assert_path_exists' and 'test_fail_if_path_exists'
    are checked against the tree while it is visited.  The patterns of the
    directives 'test_assert_c_code_has' and 'test_fail_if_c_code_has' are
    only collected here; they are checked later by the callable that
    create_c_file_validator() returns.  All failures are reported through
    Errors.error().
    """

    def __init__(self):
        super(TreeAssertVisitor, self).__init__()
        self._module_pos = None      # position of the module node, used in error reports
        self._c_patterns = []        # patterns that must be found
        self._c_antipatterns = []    # patterns that must not be found

    def create_c_file_validator(self):
        """Return a callable that checks the collected patterns.

        The callable takes a result object, reads the file named by its
        `c_file` attribute and returns the result object.  It shares the
        pattern lists with this visitor, so patterns that are collected
        after this call are still checked.
        """
        patterns, antipatterns = self._c_patterns, self._c_antipatterns

        def fail(pos, pattern, found, file_path):
            Errors.error(pos, "Pattern '%s' %s found in %s" %(
                pattern,
                'was' if found else 'was not',
                file_path,
            ))

        def extract_section(file_path, content, start, end):
            # Narrow 'content' down to the text between the first match of
            # 'start' and the following first match of 'end'.  A delimiter
            # that does not match is reported and leaves the content as is.
            if start:
                split = re.search(start, content)
                if split:
                    content = content[split.end():]
                else:
                    fail(self._module_pos, start, found=False, file_path=file_path)
            if end:
                split = re.search(end, content)
                if split:
                    content = content[:split.start()]
                else:
                    fail(self._module_pos, end, found=False, file_path=file_path)
            return content

        def validate_file_content(file_path, content):
            for pattern in patterns:
                start, end, pattern = _parse_pattern(pattern)
                section = extract_section(file_path, content, start, end)
                if not re.search(pattern, section):
                    fail(self._module_pos, pattern, found=False, file_path=file_path)

            for antipattern in antipatterns:
                start, end, antipattern = _parse_pattern(antipattern)
                section = extract_section(file_path, content, start, end)
                if re.search(antipattern, section):
                    fail(self._module_pos, antipattern, found=True, file_path=file_path)

        def validate_c_file(result):
            c_file = result.c_file
            if not (patterns or antipatterns):
                # Nothing to check.
                return result

            with open(c_file, encoding='utf8') as f:
                content = f.read()
            content = _strip_c_comments(content)
            validate_file_content(c_file, content)

            # Also check the ".html" file next to the C file, but only if it
            # is not older than the C file.
            html_file = os.path.splitext(c_file)[0] + ".html"
            if os.path.exists(html_file) and os.path.getmtime(c_file) <= os.path.getmtime(html_file):
                with open(html_file, encoding='utf8') as f:
                    content = f.read()
                content = _strip_cython_code_from_html(content)
                validate_file_content(html_file, content)

            # Return the result on this path as well, as on the early exit
            # above, instead of implicitly returning None.
            return result

        return validate_c_file

    def _check_directives(self, node):
        """Check the path directives of `node` and collect its C code patterns."""
        directives = node.directives
        if 'test_assert_path_exists' in directives:
            for path in directives['test_assert_path_exists']:
                if TreePath.find_first(node, path) is None:
                    Errors.error(
                        node.pos,
                        "Expected path '%s' not found in result tree" % path)
        if 'test_fail_if_path_exists' in directives:
            for path in directives['test_fail_if_path_exists']:
                first_node = TreePath.find_first(node, path)
                if first_node is not None:
                    Errors.error(
                        first_node.pos,
                        "Unexpected path '%s' found in result tree" % path)
        # extend() mutates the lists in place, which keeps them shared with
        # a validator that was created before.
        if 'test_assert_c_code_has' in directives:
            self._c_patterns.extend(directives['test_assert_c_code_has'])
        if 'test_fail_if_c_code_has' in directives:
            self._c_antipatterns.extend(directives['test_fail_if_c_code_has'])

    def visit_ModuleNode(self, node):
        self._module_pos = node.pos
        self._check_directives(node)
        self.visitchildren(node)
        return node

    def visit_CompilerDirectivesNode(self, node):
        self._check_directives(node)
        self.visitchildren(node)
        return node

    visit_Node = VisitorTransform.recurse_to_children
|
|
|
|
|
def unpack_source_tree(tree_file, workdir, cython_root):
    """Unpack the files that are bundled in `tree_file` into `workdir`.

    Lines before the first "#####" marker line form the header: each line
    that is not blank, not a "#" comment and not a bare triple quote is split
    into a command.  The first words PYTHON, CYTHON and CYTHONIZE are
    replaced by the Python executable, followed by the script 'cython.py' or
    'cythonize.py' in `cython_root` for the latter two.

    Each line that starts with "#####" names a file (relative to `workdir`,
    with "/" as path separator); the lines after it are written to that file
    unchanged.  A temporary directory is created if `workdir` is None.

    Returns the tuple (workdir, list of commands).
    """
    python = sys.executable
    programs = {
        'PYTHON': [python],
        'CYTHON': [python, os.path.join(cython_root, 'cython.py')],
        'CYTHONIZE': [python, os.path.join(cython_root, 'cythonize.py')],
    }

    if workdir is None:
        workdir = tempfile.mkdtemp()

    header = []
    cur_file = None
    with open(tree_file, 'rb') as f:
        try:
            for line in f:
                if line.startswith(b'#####'):
                    # Marker line: switch the output to the named file.
                    filename = line.strip().strip(b'#').strip().decode('utf8')
                    path = os.path.join(workdir, filename.replace('/', os.path.sep))
                    target_dir = os.path.dirname(path)
                    if not os.path.exists(target_dir):
                        os.makedirs(target_dir)
                    # Forget the previous file before closing it so that it
                    # is not closed a second time if close() fails.
                    previous, cur_file = cur_file, None
                    if previous is not None:
                        previous.close()
                    cur_file = open(path, 'wb')
                    continue

                if cur_file is not None:
                    cur_file.write(line)
                    continue

                # Still in the header section.
                stripped = line.strip()
                if not stripped or stripped.startswith(b'#'):
                    continue
                if stripped in (b'"""', b"'''"):
                    continue
                command = shlex.split(line.decode('utf8'))
                if not command:
                    continue
                prog = command[0]
                if prog in programs:
                    header.append(programs[prog] + command[1:])
                else:
                    header.append(command)
        finally:
            if cur_file is not None:
                cur_file.close()
    return workdir, header
|
|
|
|
|
def write_file(file_path, content, dedent=False, encoding=None):
    r"""Write `content` (text or bytes) to the file at `file_path`
    without translating `'\n'` into `os.linesep`.

    Text is encoded as `'utf-8'` unless another `encoding` is given.
    With `dedent`, the content is passed through `textwrap.dedent()` first.
    """
    if isinstance(content, bytes):
        # Binary mode takes neither an encoding nor a newline argument.
        mode, newline, default_encoding = "wb", None, None
    else:
        # newline="\n" writes "\n" characters as they are.
        mode, newline, default_encoding = "w", "\n", "utf-8"

    encoding = default_encoding if encoding is None else encoding
    text = textwrap.dedent(content) if dedent else content

    with open(file_path, mode=mode, encoding=encoding, newline=newline) as f:
        f.write(text)
|
|
|
|
|
def write_newer_file(file_path, newer_than, content, dedent=False, encoding=None):
    r"""
    Write `content` to the file `file_path` without translating `'\n'`
    into `os.linesep` and make sure it is newer than the file `newer_than`.

    If `newer_than` does not exist, the file is written once.

    The default encoding is `'utf-8'` (same as for `write_file`).
    """
    write_file(file_path, content, dedent=dedent, encoding=encoding)

    try:
        other_time = os.path.getmtime(newer_than)
    except OSError:
        # A fresh file is always newer than a non-existent one.  Returning
        # here is required: looping while "other_time is None" never ends.
        return

    # Rewrite the file until its modification time has passed 'other_time'.
    while other_time >= os.path.getmtime(file_path):
        write_file(file_path, content, dedent=dedent, encoding=encoding)
|
|
|