# ecoeval/core.py
import time
import traceback
from typing import Dict, Any, Optional, List
import torch
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub.errors import RepositoryNotFoundError
from .config import EcoEvalConfig
# ---------- Prompt template to force clean Python output ----------
PROMPT_TEMPLATE = """
You are an expert Python 3 programmer.
Write ONLY valid Python 3 code.
Requirements:
- Define exactly ONE function that solves the task.
- Do NOT print anything.
- Do NOT include explanations, comments, or examples.
- Do NOT include '>>>' prompts or any natural language text.
- Only return the function definition and any necessary helper code.
Task:
{task}
"""
# ---------- Device + model loading ----------
def _select_device(cfg: EcoEvalConfig) -> torch.device:
if cfg.device == "cuda" and torch.cuda.is_available():
return torch.device("cuda")
if cfg.device == "auto" and torch.cuda.is_available():
return torch.device("cuda")
return torch.device("cpu")
def load_model_and_tokenizer(cfg: EcoEvalConfig):
"""
Load tokenizer and model from Hugging Face Hub.
Raises a clean RuntimeError if the model id is invalid.
"""
device = _select_device(cfg)
try:
tokenizer = AutoTokenizer.from_pretrained(cfg.model_id)
model = AutoModelForCausalLM.from_pretrained(cfg.model_id)
except (OSError, RepositoryNotFoundError) as e:
raise RuntimeError(
f"Could not load model '{cfg.model_id}'. "
"Make sure it is a valid public model on Hugging Face "
"(e.g. 'gpt2', 'Salesforce/codegen-350M-mono', "
"'bigcode/tiny_starcoder_py')."
) from e
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
model.to(device)
model.eval()
return tokenizer, model, device
# ---------- Output cleaning / extraction ----------
def _strip_leading_docstring(text: str) -> str:
"""
Remove a leading triple-quoted docstring if present.
"""
for quote in ('"""', "'''"):
if text.startswith(quote):
parts = text.split(quote)
if len(parts) >= 3:
# parts: ["", docstring, rest...]
return quote.join(parts[2:]).lstrip()
return text
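# Example (illustrative) of the helper's behavior:
#   _strip_leading_docstring('"""Return the sum."""\ndef add(a, b):\n    return a + b')
#   -> 'def add(a, b):\n    return a + b'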
def _extract_code(generated: str) -> str:
"""
Clean raw model output into executable Python:
- Keep from the first 'def ' onwards when possible.
- Remove triple-quoted docstrings.
- Drop obvious natural-language lines.
- Stop at top-level 'if __name__ == "__main__"' or other
top-level control-flow scaffolding that often causes
indentation errors.
"""
text = generated.strip()
# If there's a function definition, keep from there.
idx = text.find("def ")
if idx != -1:
text = text[idx:]
# Remove a leading docstring if present.
text = _strip_leading_docstring(text)
bad_prefixes = (
">>>",
"Example:",
"Examples:",
"Input:",
"Input Format:",
"Output:",
"Output Format:",
"Python 3:",
"The function ",
"The first line ",
"The above code",
"The following code",
"- ", # bullet lists like "- Write a function ..."
)
lines = text.splitlines()
cleaned: List[str] = []
in_docstring = False
for line in lines:
stripped = line.strip()
        # Track and drop triple-quoted docstring blocks anywhere.
        # A line that both opens and closes a docstring (an even number
        # of quote markers) is skipped without toggling state, so a
        # one-line docstring no longer swallows the rest of the code.
        quote_count = stripped.count('"""') + stripped.count("'''")
        if quote_count:
            if quote_count % 2 == 1:
                in_docstring = not in_docstring
            continue
if in_docstring:
continue
if not stripped:
# keep blank lines (can be inside function)
cleaned.append("")
continue
# Drop obvious NL/meta text
if any(stripped.startswith(bp) for bp in bad_prefixes):
continue
if stripped.startswith("```"):
continue
# Detect top-level (unindented) scaffolding and stop there
is_top_level = (line == stripped) # no leading spaces/tabs
if is_top_level and stripped.startswith("if __name__"):
# stop before main-guard
break
if is_top_level and stripped.startswith(("for ", "while ", "if ", "elif ", "else:", "try:", "except", "with ")):
# likely problem-causing scaffold; stop here
break
cleaned.append(line)
code = "\n".join(cleaned).rstrip()
return code
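# Example (illustrative): a typical chatty completion and its cleaned form.
#   raw = ('Sure, here is the code:\n```python\n'
#          'def add(a, b):\n    return a + b\n```\nExample: add(1, 2)')
#   _extract_code(raw) -> 'def add(a, b):\n    return a + b'
# The prose prefix, code fence, and 'Example:' line are all dropped.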
# ---------- Generation + execution ----------
def generate_code(
prompt: str,
tokenizer,
model,
cfg: EcoEvalConfig,
device: torch.device,
) -> str:
"""
Generate Python code given a full prompt (already templated).
"""
    encoded = tokenizer(prompt, return_tensors="pt").to(device)
    do_sample = cfg.temperature > 0
    gen_kwargs: Dict[str, Any] = {
        "max_new_tokens": cfg.max_new_tokens,
        "do_sample": do_sample,
        "pad_token_id": tokenizer.pad_token_id,
    }
    # Pass sampling knobs only when sampling; recent transformers
    # versions warn if temperature/top_p are set while do_sample=False.
    if do_sample:
        gen_kwargs["temperature"] = cfg.temperature
        gen_kwargs["top_p"] = cfg.top_p
    with torch.no_grad():
        outputs = model.generate(**encoded, **gen_kwargs)
    # Decode only the newly generated tokens so the prompt is never
    # echoed back, even when decoding does not round-trip it exactly.
    prompt_len = encoded["input_ids"].shape[1]
    raw = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
    return _extract_code(raw)
def run_python_tests(pred_code: str, test_code: str) -> bool:
"""
Very simple sandbox: execs pred_code + test_code in the same namespace.
NOTE: This is not safe against malicious code. For research/demo only.
"""
namespace: Dict[str, Any] = {}
try:
exec(pred_code, namespace, namespace)
exec(test_code, namespace, namespace)
return True
except Exception:
traceback.print_exc()
return False
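# A sketch of a slightly safer alternative (an assumption, not wired into
# run_benchmark): execute candidate + tests in a child Python process with
# a wall-clock timeout, so an infinite loop in generated code cannot hang
# the whole benchmark. Still not a real sandbox: the child keeps full
# filesystem and network access.
def run_python_tests_subprocess(pred_code: str, test_code: str,
                                timeout_s: float = 10.0) -> bool:
    import subprocess
    import sys
    try:
        proc = subprocess.run(
            [sys.executable, "-c", pred_code + "\n\n" + test_code],
            capture_output=True,
            timeout=timeout_s,
        )
        # A zero exit code means no assertion failed and nothing crashed.
        return proc.returncode == 0
    except subprocess.TimeoutExpired:
        return False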
# ---------- Main benchmark loop ----------
def run_benchmark(
dataset: Dataset,
cfg: EcoEvalConfig,
limit: Optional[int] = None,
) -> Dict[str, Any]:
"""
Run the EcoEval benchmark over a dataset.
Dataset must have columns:
- 'prompt' : natural language description of the task
- 'test_code' : Python unit tests to validate the solution
"""
tokenizer, model, device = load_model_and_tokenizer(cfg)
n = len(dataset)
if limit is not None:
n = min(n, limit)
passed = 0
total = 0
per_task: List[Dict[str, Any]] = []
start = time.time()
for idx in range(n):
row = dataset[idx]
task_text = row["prompt"]
test_code = row["test_code"]
# 🔑 ALWAYS wrap the task in our strict code-only template
full_prompt = PROMPT_TEMPLATE.format(task=task_text)
t0 = time.time()
pred_code = generate_code(full_prompt, tokenizer, model, cfg, device)
ok = run_python_tests(pred_code, test_code)
t1 = time.time()
total += 1
passed += int(ok)
per_task.append(
{
"task_id": idx,
"prompt_preview": (task_text[:80] + "…") if len(task_text) > 80 else task_text,
"passed": bool(ok),
"runtime_s": round(t1 - t0, 3),
}
)
end = time.time()
elapsed = end - start
accuracy = passed / total if total > 0 else 0.0
return {
"tasks": total,
"passed": passed,
"accuracy": accuracy,
"runtime_seconds": elapsed,
"per_task": per_task,
}
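# ---------- Example usage (illustrative sketch) ----------
# Assumes EcoEvalConfig accepts at least the fields this module reads
# (model_id, device, max_new_tokens, temperature, top_p); see
# ecoeval/config.py for the actual signature.
#
#   from datasets import Dataset
#   from ecoeval.config import EcoEvalConfig
#   from ecoeval.core import run_benchmark
#
#   data = Dataset.from_dict({
#       "prompt": ["Write a function add(a, b) that returns a + b."],
#       "test_code": ["assert add(1, 2) == 3"],
#   })
#   cfg = EcoEvalConfig(model_id="bigcode/tiny_starcoder_py", device="auto")
#   report = run_benchmark(data, cfg, limit=1)
#   print(report["accuracy"])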