File size: 4,033 Bytes
4c69128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""Optional small-model assist: propose extra characterization inputs.

This never weakens the "capture-then-assert" guarantee in generator.py -- any
input tuple suggested here is executed by `generator.capture` exactly like the
deterministic inputs, so the recorded expectation always comes from running the
real code, never from the model. A model that proposes a redundant or useless
input just yields a redundant (still-correct) test; it can never make a test
green by assertion alone.
"""

from __future__ import annotations

import ast
import re

from analyzer import FunctionInfo

try:
    import spaces
except ImportError:
    # Environments without the `spaces` package (local/dev) get a stand-in so
    # the `@spaces.GPU(...)` decoration further down still works.
    class _SpacesShim:
        """Minimal replacement for `spaces` exposing a do-nothing `GPU` decorator."""

        @staticmethod
        def GPU(*args, **kwargs):
            """Accept both `@GPU` and `@GPU(...)`; the decorated function is returned as-is."""

            def _passthrough(fn):
                return fn

            # Bare use passes the function itself as the only positional arg;
            # parameterised use must hand back a decorator instead.
            used_bare = bool(args) and callable(args[0])
            return args[0] if used_bare else _passthrough

    spaces = _SpacesShim()


# Model identifier handed to both `from_pretrained` calls in `_load_model`.
MODEL_ID = "Qwen/Qwen2.5-Coder-0.5B-Instruct"
# Upper bound on tokens generated per request in `suggest_extra_inputs`.
MAX_NEW_TOKENS = 96

# Module-level cache populated lazily by `_load_model`; both stay None until
# the first successful load.
_model = None
_tokenizer = None


def _load_model():
    """Return the cached ``(model, tokenizer)`` pair, creating it on first use.

    `torch` and `transformers` are imported inside the function, so merely
    importing this module does not require them. The tokenizer and the
    bfloat16 model are built from `MODEL_ID` and stored in the module globals
    `_tokenizer` / `_model`; subsequent calls hand back the stored objects.
    """
    global _model, _tokenizer

    # Fast path: a model is already cached.
    if _model is not None:
        return _model, _tokenizer

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Tokenizer first, model second; `_model` doubles as the "loaded" flag, so
    # a failure while loading the model leaves the next call free to retry.
    _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    _model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)
    return _model, _tokenizer


@spaces.GPU(duration=120)
def suggest_extra_inputs(fn: FunctionInfo, existing: list[tuple], max_new: int = 2) -> list[tuple]:
    """Ask a small coder model for extra argument tuples to try for `fn`.

    Args:
        fn: Function description; `qualname`, `parameters` and `docstring`
            are read to build the prompt, and `len(fn.parameters)` is the
            arity every suggested tuple must match.
        existing: Argument tuples already tried (shown to the model so it can
            avoid repeating them).
        max_new: Maximum number of tuples to return. Zero or a negative value
            returns [] without loading the model.

    Returns:
        At most `max_new` argument tuples. Returns [] on any failure (no GPU,
        model unavailable, bad output, ...) so the deterministic pipeline
        never depends on this succeeding.
    """
    # Nothing was asked for: skip the (expensive) model load and generation.
    if max_new <= 0:
        return []

    try:
        import torch

        model, tokenizer = _load_model()
        model.to("cuda" if torch.cuda.is_available() else "cpu")
    except Exception:
        return []

    prompt = _build_prompt(fn, existing, max_new)
    try:
        messages = [{"role": "user", "content": prompt}]
        # Request a dict explicitly: this yields `input_ids` together with the
        # matching `attention_mask`, and does not depend on whether the
        # library's default return type is a bare tensor or a BatchEncoding.
        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        ).to(model.device)
        prompt_len = inputs["input_ids"].shape[1]
        output_ids = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,  # greedy decoding keeps suggestions reproducible
            pad_token_id=tokenizer.eos_token_id,
        )
        # Decode only the newly generated tokens, not the echoed prompt.
        text = tokenizer.decode(output_ids[0, prompt_len:], skip_special_tokens=True)
        # Parsing stays inside the guard so odd model output can never raise
        # out of this best-effort helper.
        return _parse_tuples(text, arity=len(fn.parameters), limit=max_new)
    except Exception:
        return []


def _build_prompt(fn: FunctionInfo, existing: list[tuple], max_new: int) -> str:
    """Compose the single user message sent to the model.

    The prompt names the function (`fn.qualname` plus its comma-joined
    `fn.parameters`), quotes its docstring, lists the first four entries of
    `existing`, and asks for `max_new` new tuples as a bare Python list.
    """
    params = ", ".join(fn.parameters)
    if not params:
        params = "(no arguments)"
    doc = fn.docstring or "(none)"
    # Only the first four tried tuples are shown to keep the prompt short.
    tried = ", ".join(map(repr, existing[:4]))

    context = "\n".join(
        [
            f"Function under test: {fn.qualname}({params})",
            f"Docstring: {doc}",
            f"Argument tuples already tried: [{tried}]",
        ]
    )
    request = (
        f"Suggest {max_new} new argument tuples that probe edge cases "
        "(empty, zero, negative, boundary, or unusual values) and differ from "
        "the ones already tried. Respond with ONLY a Python list of tuples, "
        "e.g. [(0, 'x'), (-1, '')]. No explanation, no code block."
    )
    return f"{context}\n\n{request}"


def _parse_tuples(text: str, arity: int, limit: int) -> list[tuple]:
    """Pull up to `limit` arity-matched argument tuples out of free-form model text.

    The candidate literal starts at the first ``[`` in `text`. The span up to
    the last ``]`` is tried first; if that does not evaluate to a list, the
    end is moved back one ``]`` at a time, so a valid list followed by
    bracketed commentary is still recovered.

    Args:
        text: Raw model output.
        arity: Number of arguments the function under test takes.
        limit: Maximum number of cases to return; zero or negative yields [].

    Returns:
        Tuples of length `arity`, in the order the model listed them. When
        `arity` is 1, bare non-tuple values (``0``, ``None``, ``[]``, ...) are
        wrapped into one-element tuples. Tuples of the wrong length are
        skipped. Returns [] when no list literal can be parsed; never raises
        on malformed text.
    """
    if limit <= 0:
        return []

    start = text.find("[")
    if start == -1:
        return []

    parsed = None
    end = text.rfind("]")
    while end > start:
        try:
            candidate = ast.literal_eval(text[start : end + 1])
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # literal_eval raises TypeError for e.g. an unhashable dict key and
            # Memory/RecursionError for pathological nesting.
            candidate = None
        if isinstance(candidate, list):
            parsed = candidate
            break
        # Retry with the previous closing bracket (search excludes `end`).
        end = text.rfind("]", start, end)
    if parsed is None:
        return []

    cases: list[tuple] = []
    for item in parsed:
        if isinstance(item, tuple):
            if len(item) != arity:
                continue
            cases.append(item)
        elif arity == 1:
            # Models often list bare values for one-argument functions.
            cases.append((item,))
        else:
            continue
        if len(cases) == limit:
            break
    return cases