r"""
Evaluation script for expression generation experiments.

Evaluates trained models on:
1. Valid Rate: % of expressions that can be parsed and evaluate to a finite value
2. Stopping Rate: % of generations that stop correctly (contain the end marker)
3. Symbol Accuracy: % of expressions that use only symbols from the prompt
4. Garbage Rate: % of generations that contain non-mathematical tokens or
   from which no expression could be extracted

Usage:
    python scripts/evaluate_experiments.py \
        --model_path ./output/exp_a_json \
        --experiment_type json \
        --num_samples 200 \
        --output_file ./results/exp_a_results.json
"""
| |
|
import argparse
import json
import logging
import os
import re
import sys
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList
from peft import PeftModel

# Put the repository root on sys.path so the project-local package resolves.
sys.path.insert(0, str(Path(__file__).parent.parent))
from classes.expression import Expression  # noqa: E402
| |
|
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


# Words whose presence in a generated expression marks it as containing
# non-mathematical tokens (counted towards the Garbage Rate).
GARBAGE_WORDS = [
    "Buyable", "Instore", "Online", "Stockholm",
    "Muslims", "crash", "Berman", "expressed",
    "fluent", "Avenger", "repositories", "GREEN",
    "intuition", "records", "xstatics", "xid",
    "sinmod", "Pressure", "XP", "Variables",
    "Operators", "Constants",
]
| |
|
| |
|
class ExpressionStoppingCriteria(StoppingCriteria):
    """Stop generation once any stop sequence appears at the end of the output.

    Two checks are combined:

    1. A token-ID suffix match against each stop sequence's standalone
       encoding (cheap, exact).
    2. A string match over the decoded tail of the sequence. BPE tokenizers
       can segment the same text differently depending on the characters
       before it (e.g. ``"\\n\\nvars:"`` emitted as several single-newline
       tokens). In that case the token-ID suffix never matches and
       generation would otherwise run to ``max_new_tokens``.

    Only the first sequence in the batch (``input_ids[0]``) is inspected, so
    this is intended for batch-size-1 generation.

    NOTE(review): during the first few generation steps the decoded tail
    still contains the end of the prompt. A prompt whose tail already
    contains a stop string would therefore stop immediately. The prompts
    built in this script do not end with any of the stop strings.
    """

    def __init__(self, tokenizer, stop_sequences: List[str]):
        """
        Args:
            tokenizer: Tokenizer used to encode the stop sequences and to
                decode the generated tail.
            stop_sequences: Strings that should end generation. Empty strings
                are ignored.
        """
        self.tokenizer = tokenizer
        self.stop_sequences = [seq for seq in stop_sequences if seq]
        self.stop_ids = []
        for seq in self.stop_sequences:
            ids = tokenizer.encode(seq, add_special_tokens=False)
            if ids:
                self.stop_ids.append(ids)
        # Assumes every token decodes to at least one byte of text (true for
        # GPT-2's byte-level BPE). Under that assumption, decoding this many
        # trailing tokens always covers a stop string, however it was
        # segmented.
        self._window = max(
            (len(seq.encode("utf-8")) for seq in self.stop_sequences), default=0
        )

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        """Return True when the first sequence ends with any stop sequence."""
        sequence = input_ids[0]

        # Fast path: exact token-ID suffix match.
        for stop_ids in self.stop_ids:
            if len(sequence) >= len(stop_ids) and sequence[-len(stop_ids):].tolist() == stop_ids:
                return True

        if not self.stop_sequences:
            return False

        # Slow path: compare text, independent of how the tokenizer split it.
        tail_text = self.tokenizer.decode(
            sequence[-self._window:].tolist(),
            skip_special_tokens=False,
            clean_up_tokenization_spaces=False,
        )
        return any(seq in tail_text for seq in self.stop_sequences)
| |
|
| |
|
def load_model(model_path: str, experiment_type: str, base_model: str = "gpt2") -> Tuple:
    """Load the base model, apply the trained PEFT adapter, and merge it.

    Args:
        model_path: Directory containing the trained adapter and, optionally,
            an ``experiment_info.json`` file.
        experiment_type: ``"json"`` or ``"eos"``. Used to infer whether the
            native ``<|endoftext|>`` token is the end marker when
            ``experiment_info.json`` is missing or has no ``use_native_eos``
            entry.
        base_model: Model id or path loaded before the adapter is applied.
            Defaults to ``"gpt2"``.

    Returns:
        ``(model, tokenizer, use_native_eos)``. ``model`` has the adapter
        merged in (``merge_and_unload``) and is in eval mode. When
        ``use_native_eos`` is False, ``tokenizer`` has the
        ``<|startofex|>``/``<|endofex|>`` special tokens registered.
    """
    logger.info(f"Loading model from {model_path}")

    inferred_native_eos = experiment_type == "eos"
    exp_info_path = os.path.join(model_path, "experiment_info.json")
    if os.path.exists(exp_info_path):
        with open(exp_info_path, encoding="utf-8") as f:
            exp_info = json.load(f)
        logger.info(f"Experiment info: {exp_info}")
        # Fall back to the experiment type rather than a hard-coded False.
        # Otherwise an "eos" run whose info file lacks the key would be
        # evaluated with the custom <|endofex|> marker.
        use_native_eos = exp_info.get("use_native_eos", inferred_native_eos)
    else:
        use_native_eos = inferred_native_eos
        logger.warning("No experiment_info.json found, inferring from experiment_type")

    logger.info(f"Loading base model {base_model}...")
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(base_model)

    if not use_native_eos:
        # Register the custom expression markers and grow the embedding matrix
        # before the adapter is loaded. This presumably reproduces the
        # vocabulary used during training — confirm against the training
        # script.
        tokenizer.add_special_tokens({
            "additional_special_tokens": ["<|startofex|>", "<|endofex|>"]
        })
        model.resize_token_embeddings(len(tokenizer))

    logger.info("Loading adapter...")
    model = PeftModel.from_pretrained(model, model_path)
    model = model.merge_and_unload()
    model.eval()

    return model, tokenizer, use_native_eos
| |
|
| |
|
def create_prompt_json(vars_list: List[str], ops_list: List[str], cons: str = "C") -> str:
    """Build an open-ended JSON prompt whose ``expr`` value is left for the model.

    The ``vars``/``ops``/``cons`` object is serialised, and its closing brace
    is replaced by ``, "expr": "``. The model's continuation is therefore the
    expression string itself.

    Example::

        create_prompt_json(["x_1"], ["+"])
        # -> '{"vars": ["x_1"], "ops": ["+"], "cons": "C", "expr": "'
    """
    header = json.dumps(
        {"vars": vars_list, "ops": ops_list, "cons": cons},
        ensure_ascii=False,
    )
    # json.dumps of a dict always ends with "}". Dropping that character
    # leaves the object open for one more key.
    return header[:-1] + ', "expr": "'
| |
|
| |
|
def create_prompt_eos(vars_list: List[str], ops_list: List[str], cons: str = "C") -> str:
    """Build the line-oriented prompt used by the EOS-format experiment.

    Four newline-separated lines are produced: ``vars``, ``oper``, ``cons``
    and ``expr``. The final ``expr: `` line ends in a single space, with no
    trailing newline, and is left open for the model to complete.
    """
    variables = ", ".join(vars_list)
    operators = ", ".join(ops_list)
    return f"vars: {variables}\noper: {operators}\ncons: {cons}\nexpr: "
| |
|
| |
|
def extract_expression_json(output: str) -> Optional[str]:
    """Extract the ``expr`` value from a (possibly truncated) JSON generation.

    Strategies, tried in order:

    1. If the stripped text ends with ``}``, parse it as JSON. When it parses
       to an object, return its ``expr`` field if that field is a string,
       otherwise None.
    2. The first ``"expr": "..."`` that has a closing quote.
    3. The first ``"expr": "...`` without a closing quote (generation cut off).

    Returns:
        The expression text, or None when nothing usable is found.
    """
    if output.strip().endswith("}"):
        try:
            obj = json.loads(output)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            obj = None
        if isinstance(obj, dict):
            expr = obj.get("expr")
            # A non-string value (number, list, ...) would break the string
            # handling downstream, so treat it as "no expression".
            return expr if isinstance(expr, str) else None

    closed = re.search(r'"expr":\s*"([^"]*)"', output)
    if closed:
        return closed.group(1)

    unterminated = re.search(r'"expr":\s*"([^"]*)', output)
    if unterminated:
        return unterminated.group(1)

    return None
| |
|
| |
|
def extract_expression_eos(output: str, end_marker: str) -> Optional[str]:
    """Extract the generated expression from an EOS-format output.

    ``output`` is the decoded prompt plus continuation. The prompt's last
    line is ``expr: ``, so the generated expression follows the *first*
    ``expr:`` in the text. That text is cut at ``end_marker`` (if non-empty
    and present) and then at the first newline.

    Args:
        output: Decoded model output, including the prompt.
        end_marker: End-of-expression marker, e.g. ``<|endoftext|>``.

    Returns:
        The stripped expression, or None if there is no ``expr:`` or the
        extracted text is empty.
    """
    _, found, expr_part = output.partition("expr:")
    if not found:
        return None

    # Use the first "expr:" (the prompt's own line). Later occurrences come
    # from the model running on into another example. Taking the last one
    # would report that hallucinated expression instead of the one generated
    # for this prompt.
    expr_part = expr_part.strip()

    # An empty marker would make str.split raise ValueError.
    if end_marker and end_marker in expr_part:
        expr_part = expr_part.split(end_marker, 1)[0].strip()

    expr_part = expr_part.split("\n", 1)[0].strip()

    return expr_part or None
| |
|
| |
|
# Function-style operators recognised in generated expressions. The letter
# look-arounds stop "sin" from matching inside "asin" or inside garbage such
# as "sinmod"; plain substring checks allowed both.
_MATH_FUNCTION_RE = re.compile(
    r"(?<![A-Za-z])(asin|acos|atan|sin|cos|tan|exp|log|sqrt|abs)(?![A-Za-z])"
)


def validate_expression(
    expr_str: str,
    allowed_vars: set,
    allowed_ops: set,
    garbage_words: Optional[List[str]] = None,
) -> Dict:
    """Validate a generated infix expression.

    Checks, in order:

    1. The string is non-empty (else returns with error "Empty expression").
    2. No garbage word occurs in it, case-insensitively, after legitimate
       function names are masked out. On a hit, returns immediately with
       ``has_garbage`` set.
    3. ``Expression(expr_str, is_prefix=False)`` constructs successfully
       (``is_parseable``). Evaluating it on a single row of ten 1.0 values
       must give a non-NaN, non-infinite first value (``is_valid``).
    4. Every ``x_<n>`` variable is in ``allowed_vars`` and every detected
       operator is in ``allowed_ops`` (``uses_correct_symbols``).

    Args:
        expr_str: Expression text extracted from the model output.
        allowed_vars: Variable/constant names offered in the prompt.
        allowed_ops: Operator names offered in the prompt.
        garbage_words: Words that mark the output as garbage. Defaults to
            the module-level ``GARBAGE_WORDS``.

    Returns:
        Dict with keys ``raw``, ``is_valid``, ``is_parseable``,
        ``uses_correct_symbols``, ``has_garbage`` and ``error``. ``error``
        holds the most recent problem found, or None.
    """
    result = {
        "raw": expr_str,
        "is_valid": False,
        "is_parseable": False,
        "uses_correct_symbols": False,
        "has_garbage": False,
        "error": None,
    }

    if not expr_str or not expr_str.strip():
        result["error"] = "Empty expression"
        return result

    words = GARBAGE_WORDS if garbage_words is None else garbage_words
    # Mask legitimate function names before the substring search. Otherwise
    # every "exp(...)" is flagged as garbage because it contains "xp" ("XP").
    # Masking with a space prevents surrounding fragments from joining.
    masked = _MATH_FUNCTION_RE.sub(" ", expr_str.lower())
    for word in words:
        if word.lower() in masked:
            result["has_garbage"] = True
            result["error"] = f"Contains garbage: {word}"
            return result

    try:
        expr = Expression(expr_str, is_prefix=False)
        result["is_parseable"] = True

        # Single-row evaluation; NaN is the only value unequal to itself.
        eval_result = expr.evaluate([[1.0] * 10])
        if len(eval_result) > 0:
            val = eval_result[0]
            if val == val and val != float('inf') and val != float('-inf'):
                result["is_valid"] = True
    except Exception as e:  # the project parser's exception types are not visible here
        result["error"] = str(e)[:100]

    expr_clean = expr_str.replace(" ", "")

    used_vars = set(re.findall(r'x_\d+', expr_clean))
    used_ops = set(_MATH_FUNCTION_RE.findall(expr_clean))

    # Record "**" on its own and blank it out before scanning for
    # single-character operators, so "x_1**2" is not also reported as "*".
    if "**" in expr_clean:
        used_ops.add("**")
    without_pow = expr_clean.replace("**", " ")
    used_ops.update(op for op in ("+", "-", "*", "/") if op in without_pow)

    var_ok = used_vars.issubset(allowed_vars)
    op_ok = used_ops.issubset(allowed_ops)
    result["uses_correct_symbols"] = var_ok and op_ok

    if not var_ok:
        result["error"] = f"Invalid vars: {used_vars - allowed_vars}"
    elif not op_ok:
        result["error"] = f"Invalid ops: {used_ops - allowed_ops}"

    return result
| |
|
| |
|
def generate_and_evaluate(
    model,
    tokenizer,
    experiment_type: str,
    use_native_eos: bool,
    num_samples: int = 100,
    test_prompts: Optional[List[Dict]] = None
) -> Dict:
    """Sample expressions for each test prompt and aggregate quality counters.

    ``num_samples`` is split as evenly as possible across ``test_prompts``.
    Every prompt gets ``num_samples // len(test_prompts)`` samples, and the
    first ``num_samples % len(test_prompts)`` prompts get one extra, so exactly
    ``num_samples`` generations are made.

    Args:
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``.
        experiment_type: ``"json"`` selects the JSON prompt and extractor;
            any other value uses the line-oriented EOS format.
        use_native_eos: If True the end marker is ``<|endoftext|>``,
            otherwise the custom ``<|endofex|>`` token.
        num_samples: Total number of generations to run.
        test_prompts: Dicts with ``vars`` and ``ops`` lists and an optional
            ``cons`` string. A built-in set of four prompts is used when None.

    Returns:
        Dict with integer counters ``total``, ``valid``, ``parseable``,
        ``correct_symbols``, ``garbage`` and ``stopped_correctly``, plus
        ``samples``: one record per generation.

    Raises:
        ValueError: If ``test_prompts`` is an empty list.
    """
    if test_prompts is None:
        test_prompts = [
            {"vars": ["x_1", "x_2"], "ops": ["*", "+", "-", "sin", "cos"], "cons": "C"},
            {"vars": ["x_1", "x_2", "x_3"], "ops": ["*", "+", "/", "exp", "log"], "cons": "C"},
            {"vars": ["x_1"], "ops": ["*", "**", "sin", "sqrt"], "cons": "C"},
            {"vars": ["x_1", "x_2", "x_3", "x_4"], "ops": ["*", "+", "-", "/"], "cons": "C"},
        ]
    if not test_prompts:
        raise ValueError("test_prompts must contain at least one prompt")

    if use_native_eos:
        end_marker = "<|endoftext|>"
        stop_sequences = ["<|endoftext|>", "\n\nvars:"]
    else:
        end_marker = "<|endofex|>"
        # NOTE(review): stopping on '"}' can end a JSON generation before the
        # model emits <|endofex|>; such samples count as not stopped_correctly.
        stop_sequences = ["<|endofex|>", '"}', "\n\nvars:"]

    stopping_criteria = StoppingCriteriaList([
        ExpressionStoppingCriteria(tokenizer, stop_sequences)
    ])

    gen_config = {
        "temperature": 0.7,
        "top_k": 50,
        "top_p": 0.9,
        "max_new_tokens": 128,
        "do_sample": True,
        "pad_token_id": tokenizer.eos_token_id,
    }

    results = {
        "total": 0,
        "valid": 0,
        "parseable": 0,
        "correct_symbols": 0,
        "garbage": 0,
        "stopped_correctly": 0,
        "samples": []
    }

    # Plain floor division silently dropped the remainder. It also produced
    # zero samples (and a divide-by-zero in the report) whenever
    # num_samples < len(test_prompts). Hand the remainder to the first prompts.
    base_count, extra = divmod(max(num_samples, 0), len(test_prompts))
    logger.info(
        f"Generating {num_samples} samples over {len(test_prompts)} prompts "
        f"({base_count} per prompt, +1 for the first {extra})..."
    )

    for prompt_index, prompt_config in enumerate(test_prompts):
        vars_list = prompt_config["vars"]
        ops_list = prompt_config["ops"]
        cons = prompt_config.get("cons", "C")

        allowed_vars = set(vars_list) | {cons}
        allowed_ops = set(ops_list) | {"(", ")"}

        if experiment_type == "json":
            prompt = create_prompt_json(vars_list, ops_list, cons)
        else:
            prompt = create_prompt_eos(vars_list, ops_list, cons)

        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        prompt_samples = base_count + (1 if prompt_index < extra else 0)

        for _ in range(prompt_samples):
            results["total"] += 1

            output = model.generate(
                **inputs,
                **gen_config,
                stopping_criteria=stopping_criteria
            )
            # Keep special tokens so the end marker is visible in the text.
            output_text = tokenizer.decode(output[0], skip_special_tokens=False)

            if experiment_type == "json":
                expr_str = extract_expression_json(output_text)
            else:
                expr_str = extract_expression_eos(output_text, end_marker)

            stopped_correctly = end_marker in output_text
            if stopped_correctly:
                results["stopped_correctly"] += 1

            if expr_str:
                validation = validate_expression(expr_str, allowed_vars, allowed_ops)

                if validation["is_valid"]:
                    results["valid"] += 1
                if validation["is_parseable"]:
                    results["parseable"] += 1
                if validation["uses_correct_symbols"]:
                    results["correct_symbols"] += 1
                if validation["has_garbage"]:
                    results["garbage"] += 1

                results["samples"].append({
                    "prompt_vars": vars_list,
                    "prompt_ops": ops_list,
                    "expression": expr_str,
                    "stopped_correctly": stopped_correctly,
                    **validation
                })
            else:
                # An output with no extractable expression counts as garbage.
                results["garbage"] += 1
                results["samples"].append({
                    "prompt_vars": vars_list,
                    "prompt_ops": ops_list,
                    "expression": None,
                    "stopped_correctly": stopped_correctly,
                    "is_valid": False,
                    "error": "Could not extract expression"
                })

            if results["total"] % 20 == 0:
                logger.info(f"Progress: {results['total']}/{num_samples}")

    return results
| |
|
| |
|
def print_report(results: Dict, experiment_name: str):
    """Print a human-readable summary of evaluation results.

    Prints each rate with a PASS/FAIL flag and up to five valid and five
    invalid sample outputs. Ends with an overall verdict: SUCCESS requires
    valid >= 80%, stopping >= 90% and garbage < 5%. The per-metric PASS
    thresholds use the same limits (other rates need >= 80%). An empty
    result set (``total == 0``) is reported without dividing by zero.

    Args:
        results: Dict with ``total``, ``valid``, ``parseable``,
            ``correct_symbols``, ``stopped_correctly``, ``garbage`` counters
            and a ``samples`` list of per-generation records.
        experiment_name: Label shown in the report header.
    """
    valid_min = 80.0
    stopping_min = 90.0
    garbage_max = 5.0
    other_min = 80.0

    total = results["total"]

    print("\n" + "=" * 60)
    print(f"EVALUATION REPORT: {experiment_name}")
    print("=" * 60)

    print(f"\nTotal samples: {total}")

    if total == 0:
        print("\nNo samples were generated; nothing to report.")
        print("=" * 60)
        return

    def pct(key: str) -> float:
        return results[key] / total * 100

    valid_rate = pct("valid")
    parseable_rate = pct("parseable")
    symbols_rate = pct("correct_symbols")
    stopping_rate = pct("stopped_correctly")
    garbage_rate = pct("garbage")

    # (label, value, passed). The stopping threshold matches the overall
    # verdict below; it previously showed PASS from 80% while the verdict
    # required 90%.
    metrics = [
        ("Valid Rate", valid_rate, valid_rate >= valid_min),
        ("Parseable Rate", parseable_rate, parseable_rate >= other_min),
        ("Correct Symbols", symbols_rate, symbols_rate >= other_min),
        ("Stopping Rate", stopping_rate, stopping_rate >= stopping_min),
        ("Garbage Rate", garbage_rate, garbage_rate < garbage_max),
    ]

    print("\nMetrics:")
    print("-" * 40)
    for name, value, passed in metrics:
        status = "PASS" if passed else "FAIL"
        print(f"  {name:<20s}: {value:6.1f}% [{status}]")

    print("\n" + "-" * 40)
    print("Sample Outputs:")
    print("-" * 40)

    valid_samples = [s for s in results["samples"] if s.get("is_valid")]
    invalid_samples = [s for s in results["samples"] if not s.get("is_valid")]

    print("\nValid examples:")
    for sample in valid_samples[:5]:
        expr = sample.get("expression") or "N/A"
        vars_str = ", ".join(sample.get("prompt_vars", []))
        print(f"  [{vars_str}] -> {expr}")

    print("\nInvalid examples:")
    for sample in invalid_samples[:5]:
        # "expression" is present but None when nothing could be extracted,
        # so .get(key, default) would return None and slicing would crash.
        expr = sample.get("expression") or "N/A"
        error = sample.get("error") or "Unknown"
        print(f"  {expr[:50]}... | Error: {error}")

    print("\n" + "=" * 60)

    success = (
        valid_rate >= valid_min
        and stopping_rate >= stopping_min
        and garbage_rate < garbage_max
    )

    print(f"\nOVERALL: {'SUCCESS' if success else 'NEEDS IMPROVEMENT'}")
    print("=" * 60)
| |
|
| |
|
def main(argv: Optional[List[str]] = None):
    """Command-line entry point: load a model, evaluate it, print and save results.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``. Defaults to
            None, meaning the process arguments.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate expression generation experiments"
    )
    parser.add_argument("--model_path", type=str, required=True,
                        help="Path to trained model")
    parser.add_argument("--experiment_type", type=str, required=True,
                        choices=["json", "eos"],
                        help="Experiment type (json or eos)")
    parser.add_argument("--num_samples", type=int, default=200,
                        help="Number of samples to generate")
    parser.add_argument("--output_file", type=str, default=None,
                        help="Path to save results JSON")

    args = parser.parse_args(argv)

    model, tokenizer, use_native_eos = load_model(
        args.model_path,
        args.experiment_type
    )

    results = generate_and_evaluate(
        model=model,
        tokenizer=tokenizer,
        experiment_type=args.experiment_type,
        use_native_eos=use_native_eos,
        num_samples=args.num_samples
    )

    experiment_name = f"EXP-{'A' if args.experiment_type == 'json' else 'B'} ({args.experiment_type.upper()})"
    print_report(results, experiment_name)

    if args.output_file:
        output_path = Path(args.output_file)
        # os.makedirs(os.path.dirname("results.json")) raises because the
        # dirname is "". Path(...).parent is "." there, which mkdir accepts.
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Save the counters plus at most 20 valid and 20 invalid samples.
        save_results = {k: v for k, v in results.items() if k != "samples"}
        save_results["sample_count"] = len(results["samples"])
        save_results["valid_samples"] = [s for s in results["samples"] if s.get("is_valid")][:20]
        save_results["invalid_samples"] = [s for s in results["samples"] if not s.get("is_valid")][:20]

        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(save_results, f, indent=2)

        logger.info(f"Results saved to: {args.output_file}")


if __name__ == "__main__":
    main()
| |
|