Spaces:
Sleeping
Sleeping
| # ecoeval/core.py | |
| import time | |
| import traceback | |
| from typing import Dict, Any, Optional, List | |
| import torch | |
| from datasets import Dataset | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from huggingface_hub.errors import RepositoryNotFoundError | |
| from .config import EcoEvalConfig | |
| # ---------- Prompt template to force clean Python output ---------- | |
# Instruction wrapper placed around every task description before generation.
# The ``{task}`` placeholder is filled in with ``str.format`` by the benchmark
# loop, so any literal braces added to this text would have to be doubled.
PROMPT_TEMPLATE = """
You are an expert Python 3 programmer.
Write ONLY valid Python 3 code.
Requirements:
- Define exactly ONE function that solves the task.
- Do NOT print anything.
- Do NOT include explanations, comments, or examples.
- Do NOT include '>>>' prompts or any natural language text.
- Only return the function definition and any necessary helper code.
Task:
{task}
"""
| # ---------- Device + model loading ---------- | |
def _select_device(cfg: EcoEvalConfig) -> torch.device:
    """
    Pick the torch device the model should run on.

    CUDA is chosen only when ``cfg.device`` is "cuda" or "auto" and a CUDA
    device is actually available; every other combination falls back to CPU.
    """
    gpu_requested = cfg.device in ("cuda", "auto")
    device_name = "cuda" if gpu_requested and torch.cuda.is_available() else "cpu"
    return torch.device(device_name)
def load_model_and_tokenizer(cfg: EcoEvalConfig):
    """
    Fetch the tokenizer and causal language model named by ``cfg.model_id``.

    Returns a ``(tokenizer, model, device)`` tuple; the model has been moved
    to the selected device and switched to eval mode.
    An ``OSError`` or ``RepositoryNotFoundError`` raised while loading is
    re-raised as a ``RuntimeError`` with a hint about valid model ids.
    """
    device = _select_device(cfg)

    try:
        tokenizer = AutoTokenizer.from_pretrained(cfg.model_id)
        model = AutoModelForCausalLM.from_pretrained(cfg.model_id)
    except (OSError, RepositoryNotFoundError) as load_error:
        message = (
            f"Could not load model '{cfg.model_id}'. "
            "Make sure it is a valid public model on Hugging Face "
            "(e.g. 'gpt2', 'Salesforce/codegen-350M-mono', "
            "'bigcode/tiny_starcoder_py')."
        )
        raise RuntimeError(message) from load_error

    # Tokenizers without a pad token reuse the end-of-sequence token for padding.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    model.to(device)
    model.eval()
    return tokenizer, model, device
| # ---------- Output cleaning / extraction ---------- | |
def _strip_leading_docstring(text: str) -> str:
    """
    Drop a triple-quoted docstring that opens ``text``, if there is one.

    Both quote styles are recognised. The text following the closing quotes is
    returned with leading whitespace removed. When the opening quotes are never
    closed, ``text`` is returned unchanged.
    """
    for delim in ('"""', "'''"):
        if not text.startswith(delim):
            continue
        # Look for the closing delimiter after the opening one.
        _body, closing, remainder = text[len(delim):].partition(delim)
        if closing:
            return remainder.lstrip()
    return text
def _extract_code(generated: str) -> str:
    """
    Clean raw model output into executable Python.

    Steps:
    - Keep from the first 'def ' onwards when possible.
    - Remove triple-quoted docstrings, including ones that open and close
      on the same line (those no longer swallow the code that follows).
    - Drop obvious natural-language lines and Markdown code fences.
    - Stop at a top-level 'if __name__ == "__main__"' guard or other
      top-level control-flow scaffolding that often causes
      indentation errors.

    Args:
        generated: Raw text produced by the model.

    Returns:
        The cleaned source code with trailing whitespace removed.
    """
    text = generated.strip()

    # If there's a function definition, keep from there.
    idx = text.find("def ")
    if idx != -1:
        text = text[idx:]

    # Remove a leading docstring if present.
    text = _strip_leading_docstring(text)

    # Line prefixes that mark natural-language / meta text rather than code.
    bad_prefixes = (
        ">>>",
        "Example:",
        "Examples:",
        "Input:",
        "Input Format:",
        "Output:",
        "Output Format:",
        "Python 3:",
        "The function ",
        "The first line ",
        "The above code",
        "The following code",
        "- ",  # bullet lists like "- Write a function ..."
        "```",  # Markdown code fences
    )
    # Unindented statements that indicate driver/scaffold code after the
    # function; 'if ' also covers the 'if __name__ == "__main__"' guard.
    scaffold_prefixes = (
        "for ", "while ", "if ", "elif ", "else:", "try:", "except", "with ",
    )
    triple_quotes = ('"""', "'''")

    cleaned: List[str] = []
    open_quote = None  # delimiter of the docstring currently being skipped
    for line in text.splitlines():
        stripped = line.strip()

        if open_quote is not None:
            # Inside a multi-line docstring: drop lines until the delimiter
            # that opened it appears an odd number of times (i.e. closes it).
            if stripped.count(open_quote) % 2 == 1:
                open_quote = None
            continue

        # Which triple-quote delimiter (if any) appears first on this line?
        first_quote = min(
            (q for q in triple_quotes if q in stripped),
            key=stripped.find,
            default=None,
        )
        if first_quote is not None:
            # The line is dropped either way. The docstring only stays open
            # when its delimiter is unbalanced on this line; a one-line
            # docstring such as '"""Return x."""' is closed immediately.
            if stripped.count(first_quote) % 2 == 1:
                open_quote = first_quote
            continue

        if not stripped:
            # keep blank lines (can be inside function)
            cleaned.append("")
            continue

        # Drop obvious NL/meta text and code fences.
        if stripped.startswith(bad_prefixes):
            continue

        # Top level means no leading whitespace; trailing whitespace must not
        # hide an unindented line from this check.
        is_top_level = not line[0].isspace()
        if is_top_level and stripped.startswith(scaffold_prefixes):
            # likely problem-causing scaffold (or main guard); stop here
            break

        cleaned.append(line)

    return "\n".join(cleaned).rstrip()
| # ---------- Generation + execution ---------- | |
def generate_code(
    prompt: str,
    tokenizer,
    model,
    cfg: EcoEvalConfig,
    device: torch.device,
) -> str:
    """
    Generate Python code given a full prompt (already templated).

    Args:
        prompt: The complete prompt text fed to the model.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        model: Model exposing ``generate``; assumed to be decoder-only (as
            loaded via ``AutoModelForCausalLM``), i.e. its output sequence
            begins with the prompt tokens.
        cfg: Supplies ``max_new_tokens``, ``temperature`` and ``top_p``.
        device: Device the encoded prompt is moved to before generation.

    Returns:
        The generated continuation after cleaning with ``_extract_code``.
    """
    encoded = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_len = encoded["input_ids"].shape[1]

    sampling = cfg.temperature > 0
    gen_kwargs = {
        "max_new_tokens": cfg.max_new_tokens,
        "do_sample": sampling,
        "pad_token_id": tokenizer.pad_token_id,
    }
    if sampling:
        # Sampling parameters only apply when sampling; with greedy decoding
        # they are unused, so they are left out.
        gen_kwargs["temperature"] = cfg.temperature
        gen_kwargs["top_p"] = cfg.top_p

    with torch.no_grad():
        outputs = model.generate(**encoded, **gen_kwargs)

    # Drop the prompt by token count instead of comparing decoded text:
    # decoding does not always reproduce the prompt string exactly, in which
    # case a text prefix check would leave the whole prompt in the output.
    new_tokens = outputs[0][prompt_len:]
    raw = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    return _extract_code(raw)
def run_python_tests(pred_code: str, test_code: str) -> bool:
    """
    Execute ``pred_code`` and then ``test_code`` in one shared namespace.

    NOTE: This is not a real sandbox and is not safe against malicious code;
    ``exec`` runs the model's output with full interpreter privileges and
    without a timeout. For research/demo only.

    Args:
        pred_code: Candidate solution source code.
        test_code: Test source code, expected to raise on failure.

    Returns:
        True when both snippets run without raising, False otherwise (the
        traceback is printed to stderr).
    """
    namespace: Dict[str, Any] = {}
    try:
        exec(pred_code, namespace)
        exec(test_code, namespace)
    except (Exception, SystemExit):
        # SystemExit is not an Exception subclass; without catching it, code
        # that calls sys.exit()/exit() would terminate the whole benchmark
        # instead of failing a single task. KeyboardInterrupt still propagates.
        traceback.print_exc()
        return False
    return True
| # ---------- Main benchmark loop ---------- | |
def run_benchmark(
    dataset: Dataset,
    cfg: EcoEvalConfig,
    limit: Optional[int] = None,
) -> Dict[str, Any]:
    """
    Evaluate the configured model on the tasks in ``dataset``.

    Each row must provide:
    - 'prompt'    : natural language description of the task
    - 'test_code' : Python unit tests to validate the solution

    At most ``limit`` rows are evaluated when ``limit`` is given. The result
    holds the task count, the number of passing tasks, the accuracy, the total
    wall-clock runtime in seconds and a per-task record list.
    """
    tokenizer, model, device = load_model_and_tokenizer(cfg)

    task_count = len(dataset)
    if limit is not None:
        task_count = min(task_count, limit)

    records: List[Dict[str, Any]] = []
    passed = 0
    bench_start = time.time()

    for task_id in range(task_count):
        example = dataset[task_id]
        task_text, test_code = example["prompt"], example["test_code"]

        # Every task goes through the strict code-only template.
        full_prompt = PROMPT_TEMPLATE.format(task=task_text)

        # The per-task runtime covers generation plus test execution.
        task_start = time.time()
        pred_code = generate_code(full_prompt, tokenizer, model, cfg, device)
        ok = run_python_tests(pred_code, test_code)
        task_end = time.time()

        passed += int(ok)
        if len(task_text) > 80:
            preview = task_text[:80] + "…"
        else:
            preview = task_text
        records.append(
            {
                "task_id": task_id,
                "prompt_preview": preview,
                "passed": bool(ok),
                "runtime_s": round(task_end - task_start, 3),
            }
        )

    elapsed = time.time() - bench_start
    total = len(records)
    return {
        "tasks": total,
        "passed": passed,
        "accuracy": passed / total if total > 0 else 0.0,
        "runtime_seconds": elapsed,
        "per_task": records,
    }