TestForge / model_suggest.py
perceptron01's picture
Upload 7 files
4c69128 verified
Raw
History Blame Contribute Delete
4.03 kB
"""Optional small-model assist: propose extra characterization inputs.
This never weakens the "capture-then-assert" guarantee in generator.py -- any
input tuple suggested here is executed by `generator.capture` exactly like the
deterministic inputs, so the recorded expectation always comes from running the
real code, never from the model. A model that proposes a redundant or useless
input just yields a redundant (still-correct) test; it can never make a test
green by assertion alone.
"""
from __future__ import annotations
import ast
import re
from analyzer import FunctionInfo
try:
    import spaces
except ImportError:  # local/dev environments without the `spaces` package

    class _SpacesShim:
        """Minimal stand-in for the `spaces` package when it is not installed.

        Only `GPU` is provided. It is an identity decorator that works both bare
        (`@spaces.GPU`) and called with options (`@spaces.GPU(duration=...)`);
        the options are accepted and ignored.
        """

        @staticmethod
        def GPU(*args, **kwargs):
            used_bare = bool(args) and callable(args[0])
            if not used_bare:
                # Called with options: hand back a decorator that changes nothing.
                return lambda target: target
            # Applied directly to the function: return it unchanged.
            return args[0]

    spaces = _SpacesShim()
# Hub id of the instruction-tuned coder model queried for extra inputs.
MODEL_ID = "Qwen/Qwen2.5-Coder-0.5B-Instruct"
# Cap on tokens generated per suggestion request.
MAX_NEW_TOKENS = 96

# Lazily-initialised model/tokenizer pair; populated by `_load_model()` on first use.
_model = _tokenizer = None
def _load_model():
    """Return the cached `(model, tokenizer)` pair, loading it on the first call.

    `torch` and `transformers` are imported only here, so importing this module
    does not require them. The model weights are loaded in bfloat16.
    """
    global _model, _tokenizer
    if _model is not None:
        return _model, _tokenizer

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    _model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)
    return _model, _tokenizer
@spaces.GPU(duration=120)
def suggest_extra_inputs(fn: FunctionInfo, existing: list[tuple], max_new: int = 2) -> list[tuple]:
    """Ask a small coder model for extra argument tuples to try for `fn`.

    Args:
        fn: Function under test. Its `qualname`, `parameters` and `docstring`
            feed the prompt, and `len(fn.parameters)` is the arity every
            returned tuple must match.
        existing: Argument tuples already tried; a few are shown to the model
            so it can propose different ones.
        max_new: Maximum number of new tuples to return.

    Returns:
        Up to `max_new` candidate argument tuples. Returns [] on any failure
        (no GPU, model unavailable, bad output, ...) so the deterministic
        pipeline never depends on this succeeding. Failures are logged at
        DEBUG level rather than raised.
    """
    import logging

    log = logging.getLogger(__name__)
    if max_new <= 0:
        # Nothing requested: don't load, move, or run the model at all.
        return []
    try:
        import torch

        model, tokenizer = _load_model()
        model.to("cuda" if torch.cuda.is_available() else "cpu")
    except Exception:
        # Best-effort by design: missing deps, download or device errors must
        # not break the pipeline -- but leave a trace for debugging.
        log.debug("model suggestions unavailable; skipping", exc_info=True)
        return []
    try:
        # Prompt building is guarded too: repr() of an existing case runs
        # arbitrary __repr__ code from the code under test and may raise.
        prompt = _build_prompt(fn, existing, max_new)
        messages = [{"role": "user", "content": prompt}]
        input_ids = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        ).to(model.device)
        output_ids = model.generate(
            input_ids,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,  # no sampling, so output does not vary run to run
            pad_token_id=tokenizer.eos_token_id,
        )
        # Decode only the newly generated tokens, not the echoed prompt.
        text = tokenizer.decode(output_ids[0, input_ids.shape[1] :], skip_special_tokens=True)
    except Exception:
        log.debug("model generation failed; skipping suggestions", exc_info=True)
        return []
    try:
        return _parse_tuples(text, arity=len(fn.parameters), limit=max_new)
    except Exception:
        # Parsing untrusted model text must not escape the "[] on failure" contract.
        log.debug("could not parse model suggestions: %r", text, exc_info=True)
        return []
def _build_prompt(fn: FunctionInfo, existing: list[tuple], max_new: int) -> str:
    """Build the user prompt asking the model for `max_new` new argument tuples.

    The prompt contains:
    - the call signature (`qualname(param, ...)`);
    - the docstring, or "(none)";
    - at most the first four already-tried tuples;
    - the number of values each tuple must hold.

    The count matters because tuples of any other length are discarded by the
    arity filter in `_parse_tuples`.
    """
    arity = len(fn.parameters)
    call = f"{fn.qualname}({', '.join(fn.parameters)})"
    if arity == 0:
        # State this after the call rather than inside its parentheses, which
        # previously rendered as `name((no arguments))`.
        call += " (takes no arguments)"
    tried = ", ".join(repr(case) for case in existing[:4])
    return (
        f"Function under test: {call}\n"
        f"Docstring: {fn.docstring or '(none)'}\n"
        f"Argument tuples already tried: [{tried}]\n\n"
        f"Suggest {max_new} new argument tuples that probe edge cases "
        "(empty, zero, negative, boundary, or unusual values) and differ from "
        "the ones already tried. "
        f"Each tuple must contain exactly {arity} value(s), one per parameter. "
        "Respond with ONLY a Python list of tuples, "
        "e.g. [(0, 'x'), (-1, '')]. No explanation, no code block."
    )
def _parse_tuples(text: str, arity: int, limit: int) -> list[tuple]:
    """Pull up to `limit` arity-matched argument tuples out of free-form model text.

    The span from the first `[` to the last `]` is parsed with
    `ast.literal_eval` (never `eval`: the text is untrusted model output).

    Filtering rules:
    - Tuples whose length equals `arity` are kept as-is.
    - When `arity == 1`, bare scalar items (numbers, strings, bytes, booleans,
      None) are wrapped into 1-tuples, because `(x)` in a Python literal is
      just `x`.

    Returns [] when there is no bracketed span, it does not parse as a list,
    or `limit <= 0`. Malformed model text never raises.
    """
    if limit <= 0:
        return []
    match = re.search(r"\[.*\]", text, re.DOTALL)
    if not match:
        return []
    try:
        parsed = ast.literal_eval(match.group(0))
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # Besides ValueError/SyntaxError, literal_eval raises TypeError for
        # unhashable set/dict members (e.g. "[{[1]}]"), and MemoryError or
        # RecursionError for pathological input.
        return []
    if not isinstance(parsed, list):
        return []
    scalar_types = (int, float, complex, str, bytes, bool, type(None))
    cases: list[tuple] = []
    for item in parsed:
        if isinstance(item, tuple) and len(item) == arity:
            cases.append(item)
        elif arity == 1 and isinstance(item, scalar_types):
            cases.append((item,))
        if len(cases) >= limit:
            break
    return cases