singhalamaan116 commited on
Commit
a06dde5
·
verified ·
1 Parent(s): d5aad4f

Update ecoeval/core.py

Browse files
Files changed (1) hide show
  1. ecoeval/core.py +80 -35
ecoeval/core.py CHANGED
@@ -6,26 +6,32 @@ from typing import Dict, Any, Optional, List
6
  import torch
7
  from datasets import Dataset
8
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
9
 
10
  from .config import EcoEvalConfig
11
 
12
 
13
- # ---- Prompt template to force code-only output ----
 
14
  PROMPT_TEMPLATE = """
15
  You are an expert Python 3 programmer.
16
 
17
  Write ONLY valid Python 3 code.
 
18
  Requirements:
19
- - Define exactly one function that solves the task.
20
  - Do NOT print anything.
21
  - Do NOT include explanations, comments, or examples.
22
- - Do NOT include '>>>' prompts or any text outside the function.
 
23
 
24
  Task:
25
  {task}
26
  """
27
 
28
 
 
 
29
  def _select_device(cfg: EcoEvalConfig) -> torch.device:
30
  if cfg.device == "cuda" and torch.cuda.is_available():
31
  return torch.device("cuda")
@@ -35,11 +41,23 @@ def _select_device(cfg: EcoEvalConfig) -> torch.device:
35
 
36
 
37
  def load_model_and_tokenizer(cfg: EcoEvalConfig):
 
 
 
 
38
  device = _select_device(cfg)
39
- tokenizer = AutoTokenizer.from_pretrained(cfg.model_id)
40
- model = AutoModelForCausalLM.from_pretrained(cfg.model_id)
41
 
42
- # Some code/text models don't have a pad token -> use EOS as pad
 
 
 
 
 
 
 
 
 
 
43
  if tokenizer.pad_token_id is None:
44
  tokenizer.pad_token_id = tokenizer.eos_token_id
45
 
@@ -48,44 +66,74 @@ def load_model_and_tokenizer(cfg: EcoEvalConfig):
48
  return tokenizer, model, device
49
 
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  def _extract_code(generated: str) -> str:
52
  """
53
- Try to clean the raw model output into pure Python code:
54
 
55
- - keep from the first 'def ' onward if present
56
- - drop lines starting with '>>>', 'The ', 'Example:', or fenced code marks
 
57
  """
58
  text = generated.strip()
59
 
60
- # If there's a 'def ' in there, keep from that point
61
  idx = text.find("def ")
62
  if idx != -1:
63
  text = text[idx:]
64
 
65
- # Line-level cleanup
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  cleaned_lines: List[str] = []
67
  for line in text.splitlines():
68
  stripped = line.strip()
69
  if not stripped:
70
- cleaned_lines.append(line)
71
  continue
72
 
73
- # Drop obvious non-code patterns
74
- if stripped.startswith(">>>"):
75
- continue
76
- if stripped.lower().startswith("example:"):
77
  continue
 
78
  if stripped.startswith("```"):
79
  continue
80
- if stripped.lower().startswith("the above code"):
81
- continue
82
- if stripped.lower().startswith("the following code"):
83
- continue
84
 
85
  cleaned_lines.append(line)
86
 
87
- return "\n".join(cleaned_lines).strip()
 
 
88
 
 
89
 
90
  def generate_code(
91
  prompt: str,
@@ -95,12 +143,9 @@ def generate_code(
95
  device: torch.device,
96
  ) -> str:
97
  """
98
- Generate code completion for a given full prompt (already templated).
99
  """
100
- encoded = tokenizer(
101
- prompt,
102
- return_tensors="pt",
103
- ).to(device)
104
 
105
  with torch.no_grad():
106
  outputs = model.generate(
@@ -114,7 +159,7 @@ def generate_code(
114
 
115
  full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
116
 
117
- # Heuristic: take the part after the prompt
118
  if full_text.startswith(prompt):
119
  raw = full_text[len(prompt):].strip()
120
  else:
@@ -125,10 +170,9 @@ def generate_code(
125
 
126
  def run_python_tests(pred_code: str, test_code: str) -> bool:
127
  """
128
- Extremely simple sandbox: execs pred_code + test_code in the same restricted namespace.
129
 
130
- NOTE: This is *not* secure against malicious code. For research/demo only.
131
- In a serious setting, you should use a proper sandbox (separate process, limits, etc.).
132
  """
133
  namespace: Dict[str, Any] = {}
134
  try:
@@ -140,17 +184,19 @@ def run_python_tests(pred_code: str, test_code: str) -> bool:
140
  return False
141
 
142
 
 
 
143
  def run_benchmark(
144
  dataset: Dataset,
145
  cfg: EcoEvalConfig,
146
  limit: Optional[int] = None,
147
  ) -> Dict[str, Any]:
148
  """
149
- Run a full benchmark over a dataset of code tasks.
150
 
151
  Dataset must have columns:
152
- - 'prompt' (natural-language task description)
153
- - 'test_code' (Python unit tests)
154
  """
155
  tokenizer, model, device = load_model_and_tokenizer(cfg)
156
 
@@ -160,7 +206,6 @@ def run_benchmark(
160
 
161
  passed = 0
162
  total = 0
163
-
164
  per_task: List[Dict[str, Any]] = []
165
 
166
  start = time.time()
@@ -170,7 +215,7 @@ def run_benchmark(
170
  task_text = row["prompt"]
171
  test_code = row["test_code"]
172
 
173
- # Build a strong instruction-style prompt
174
  full_prompt = PROMPT_TEMPLATE.format(task=task_text)
175
 
176
  t0 = time.time()
 
6
  import torch
7
  from datasets import Dataset
8
  from transformers import AutoTokenizer, AutoModelForCausalLM
9
+ from huggingface_hub.errors import RepositoryNotFoundError
10
 
11
  from .config import EcoEvalConfig
12
 
13
 
14
+ # ---------- Prompt template to force clean Python output ----------
15
+
16
  PROMPT_TEMPLATE = """
17
  You are an expert Python 3 programmer.
18
 
19
  Write ONLY valid Python 3 code.
20
+
21
  Requirements:
22
+ - Define exactly ONE function that solves the task.
23
  - Do NOT print anything.
24
  - Do NOT include explanations, comments, or examples.
25
+ - Do NOT include '>>>' prompts or any natural language text.
26
+ - Only return the function definition and any necessary helper code.
27
 
28
  Task:
29
  {task}
30
  """
31
 
32
 
33
+ # ---------- Device + model loading ----------
34
+
35
  def _select_device(cfg: EcoEvalConfig) -> torch.device:
36
  if cfg.device == "cuda" and torch.cuda.is_available():
37
  return torch.device("cuda")
 
41
 
42
 
43
  def load_model_and_tokenizer(cfg: EcoEvalConfig):
44
+ """
45
+ Load tokenizer and model from Hugging Face Hub.
46
+ Raises a clean RuntimeError if the model id is invalid.
47
+ """
48
  device = _select_device(cfg)
 
 
49
 
50
+ try:
51
+ tokenizer = AutoTokenizer.from_pretrained(cfg.model_id)
52
+ model = AutoModelForCausalLM.from_pretrained(cfg.model_id)
53
+ except (OSError, RepositoryNotFoundError) as e:
54
+ raise RuntimeError(
55
+ f"Could not load model '{cfg.model_id}'. "
56
+ "Make sure it is a valid public model on Hugging Face "
57
+ "(e.g. 'gpt2', 'Salesforce/codegen-350M-mono', "
58
+ "'bigcode/tiny_starcoder_py')."
59
+ ) from e
60
+
61
  if tokenizer.pad_token_id is None:
62
  tokenizer.pad_token_id = tokenizer.eos_token_id
63
 
 
66
  return tokenizer, model, device
67
 
68
 
69
+ # ---------- Output cleaning / extraction ----------
70
+
71
+ def _strip_leading_docstring(text: str) -> str:
72
+ """
73
+ Remove a leading triple-quoted docstring if present.
74
+ """
75
+ for quote in ('"""', "'''"):
76
+ if text.startswith(quote):
77
+ parts = text.split(quote)
78
+ if len(parts) >= 3:
79
+ # parts: ["", docstring, rest...]
80
+ return quote.join(parts[2:]).lstrip()
81
+ return text
82
+
83
+
84
  def _extract_code(generated: str) -> str:
85
  """
86
+ Clean raw model output into executable Python:
87
 
88
+ - Keep from the first 'def ' onwards when possible.
89
+ - Strip leading docstrings.
90
+ - Drop lines that are clearly meta-text (Input:, Output:, >>>, etc.).
91
  """
92
  text = generated.strip()
93
 
94
+ # If there's a function definition, keep from there.
95
  idx = text.find("def ")
96
  if idx != -1:
97
  text = text[idx:]
98
 
99
+ # Remove a leading docstring if present.
100
+ text = _strip_leading_docstring(text)
101
+
102
+ bad_prefixes = (
103
+ ">>>",
104
+ "Example:",
105
+ "Examples:",
106
+ "Input:",
107
+ "Input Format:",
108
+ "Output:",
109
+ "Output Format:",
110
+ "Python 3:",
111
+ "The function ",
112
+ "The above code",
113
+ "The following code",
114
+ "- ", # bullet lists like "- Write a function ..."
115
+ )
116
+
117
  cleaned_lines: List[str] = []
118
  for line in text.splitlines():
119
  stripped = line.strip()
120
  if not stripped:
121
+ cleaned_lines.append("") # keep blank lines for indentation blocks
122
  continue
123
 
124
+ if any(stripped.startswith(bp) for bp in bad_prefixes):
 
 
 
125
  continue
126
+
127
  if stripped.startswith("```"):
128
  continue
 
 
 
 
129
 
130
  cleaned_lines.append(line)
131
 
132
+ cleaned = "\n".join(cleaned_lines).strip()
133
+ return cleaned
134
+
135
 
136
+ # ---------- Generation + execution ----------
137
 
138
  def generate_code(
139
  prompt: str,
 
143
  device: torch.device,
144
  ) -> str:
145
  """
146
+ Generate Python code given a full prompt (already templated).
147
  """
148
+ encoded = tokenizer(prompt, return_tensors="pt").to(device)
 
 
 
149
 
150
  with torch.no_grad():
151
  outputs = model.generate(
 
159
 
160
  full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
161
 
162
+ # Take the part after the prompt to avoid echoing it.
163
  if full_text.startswith(prompt):
164
  raw = full_text[len(prompt):].strip()
165
  else:
 
170
 
171
  def run_python_tests(pred_code: str, test_code: str) -> bool:
172
  """
173
+ Very simple sandbox: execs pred_code + test_code in the same namespace.
174
 
175
+ NOTE: This is not safe against malicious code. For research/demo only.
 
176
  """
177
  namespace: Dict[str, Any] = {}
178
  try:
 
184
  return False
185
 
186
 
187
+ # ---------- Main benchmark loop ----------
188
+
189
  def run_benchmark(
190
  dataset: Dataset,
191
  cfg: EcoEvalConfig,
192
  limit: Optional[int] = None,
193
  ) -> Dict[str, Any]:
194
  """
195
+ Run the EcoEval benchmark over a dataset.
196
 
197
  Dataset must have columns:
198
+ - 'prompt' : natural language description of the task
199
+ - 'test_code' : Python unit tests to validate the solution
200
  """
201
  tokenizer, model, device = load_model_and_tokenizer(cfg)
202
 
 
206
 
207
  passed = 0
208
  total = 0
 
209
  per_task: List[Dict[str, Any]] = []
210
 
211
  start = time.time()
 
215
  task_text = row["prompt"]
216
  test_code = row["test_code"]
217
 
218
+ # 🔑 ALWAYS wrap the task in our strict code-only template
219
  full_prompt = PROMPT_TEMPLATE.format(task=task_text)
220
 
221
  t0 = time.time()