|
|
|
|
|
""" |
|
|
Evaluation script for expression generation experiments. |
|
|
|
|
|
Evaluates trained models on: |
|
|
1. Valid Rate: % expressions that can be parsed and evaluated |
|
|
2. Stopping Rate: % that stop correctly (contain end marker) |
|
|
3. Symbol Accuracy: % that use only symbols from prompt |
|
|
4. Garbage Rate: % with non-mathematical tokens |
|
|
|
|
|
Usage: |
|
|
python scripts/evaluate_experiments.py \ |
|
|
--model_path ./output/exp_a_json \ |
|
|
--experiment_type json \ |
|
|
--num_samples 200 \ |
|
|
--output_file ./results/exp_a_results.json |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import json |
|
|
import logging |
|
|
import os |
|
|
import re |
|
|
import sys |
|
|
from pathlib import Path |
|
|
from typing import Dict, List, Optional, Tuple |
|
|
|
|
|
import torch |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList |
|
|
from peft import PeftModel |
|
|
|
|
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent.parent)) |
|
|
from classes.expression import Expression |
|
|
|
|
|
# Module-wide logging: timestamped INFO-level messages (default stderr handler).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
# Vocabulary that signals the model drifted into natural-language junk instead
# of a mathematical expression. Matched case-insensitively against generated
# expressions by validate_expression. Presumably collected from failure modes
# observed in earlier runs -- TODO confirm provenance.
GARBAGE_WORDS = [
    "Buyable", "Instore", "Online", "Stockholm", "Muslims", "crash",
    "Berman", "expressed", "fluent", "Avenger", "repositories",
    "GREEN", "intuition", "records", "xstatics", "xid", "sinmod",
    "Pressure", "XP", "Variables", "Operators", "Constants"
]
|
|
|
|
|
|
|
|
class ExpressionStoppingCriteria(StoppingCriteria):
    """Terminate generation once the sequence tail matches any stop marker.

    Each stop sequence is tokenized once up front; on every generation step
    the trailing token ids of the (single) sequence are compared against
    each pre-encoded marker.
    """

    def __init__(self, tokenizer, stop_sequences: List[str]):
        self.tokenizer = tokenizer
        # Pre-encode every marker, dropping any that tokenize to nothing.
        encoded = (
            tokenizer.encode(seq, add_special_tokens=False)
            for seq in stop_sequences
        )
        self.stop_ids = [ids for ids in encoded if ids]

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        # Only the first (and, for this script, only) sequence is inspected.
        seq = input_ids[0]
        seq_len = len(seq)
        return any(
            seq_len >= len(candidate)
            and seq[-len(candidate):].tolist() == candidate
            for candidate in self.stop_ids
        )
|
|
|
|
|
|
|
|
def load_model(model_path: str, experiment_type: str) -> Tuple:
    """Load a fine-tuned GPT-2 plus its tokenizer for inference.

    Reads ``experiment_info.json`` from the model directory (when present)
    to decide whether the run used GPT-2's native EOS token or the custom
    start/end markers, rebuilds the tokenizer/embedding state to match,
    then merges the LoRA adapter into the base weights.

    Returns:
        (model, tokenizer, use_native_eos)
    """
    logger.info(f"Loading model from {model_path}")

    info_file = os.path.join(model_path, "experiment_info.json")
    if os.path.exists(info_file):
        with open(info_file) as fh:
            info = json.load(fh)
        logger.info(f"Experiment info: {info}")
        use_native_eos = info.get("use_native_eos", False)
    else:
        # No training metadata on disk: fall back to the CLI flag.
        use_native_eos = (experiment_type == "eos")
        logger.warning("No experiment_info.json found, inferring from experiment_type")

    logger.info("Loading base GPT-2...")
    base_model = AutoModelForCausalLM.from_pretrained(
        "gpt2",
        torch_dtype=torch.float16,
        device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    if not use_native_eos:
        # The training run extended the vocabulary with custom markers, so
        # the tokenizer and the embedding matrix must be resized to match
        # before the adapter weights are applied.
        tokenizer.add_special_tokens(
            {"additional_special_tokens": ["<|startofex|>", "<|endofex|>"]}
        )
        base_model.resize_token_embeddings(len(tokenizer))

    logger.info("Loading adapter...")
    merged = PeftModel.from_pretrained(base_model, model_path).merge_and_unload()
    merged.eval()

    return merged, tokenizer, use_native_eos
|
|
|
|
|
|
|
|
def create_prompt_json(vars_list: List[str], ops_list: List[str], cons: str = "C") -> str:
    """Build a JSON-format generation prompt that ends mid-string.

    The prompt is a serialized object whose "expr" value is left open (no
    closing quote or brace), so the model's continuation becomes the
    expression body.
    """
    # Serialize each field independently; concatenated, this matches what
    # json.dumps would emit for the whole object, minus everything after
    # the opening quote of the "expr" value.
    vars_json = json.dumps(vars_list, ensure_ascii=False)
    ops_json = json.dumps(ops_list, ensure_ascii=False)
    cons_json = json.dumps(cons, ensure_ascii=False)
    return f'{{"vars": {vars_json}, "ops": {ops_json}, "cons": {cons_json}, "expr": "'
|
|
|
|
|
|
|
|
def create_prompt_eos(vars_list: List[str], ops_list: List[str], cons: str = "C") -> str:
    """Build a line-based (EOS-format) prompt ending with an open "expr:" field."""
    vars_line = "vars: " + ", ".join(vars_list)
    oper_line = "oper: " + ", ".join(ops_list)
    # The trailing "expr: " field is left empty for the model to complete.
    return f"{vars_line}\n{oper_line}\ncons: {cons}\nexpr: "
|
|
|
|
|
|
|
|
def extract_expression_json(output: str) -> Optional[str]:
    """Extract the "expr" value from (possibly truncated) JSON output.

    Tries strict JSON parsing first, then falls back to regex extraction
    for outputs cut off before the closing quote/brace.

    Returns:
        The expression string, or None if nothing could be extracted.
        Note: a well-formed JSON object *without* an "expr" key also
        returns None (the regex fallback is not attempted in that case).
    """
    # Fast path: output looks like a complete JSON object.
    if output.strip().endswith("}"):
        try:
            obj = json.loads(output)
        except json.JSONDecodeError:
            # Malformed JSON: fall through to the regex heuristics. (The
            # original bare `except:` also hid non-JSON errors; narrow it.)
            pass
        else:
            # Guard against non-dict top-level values before calling .get.
            if isinstance(obj, dict):
                return obj.get("expr", None)

    # Complete quoted value: "expr": "...".
    match = re.search(r'"expr":\s*"([^"]*)"', output)
    if match:
        return match.group(1)

    # Truncated value: "expr": "... with no closing quote.
    match = re.search(r'"expr":\s*"([^"]*)', output)
    if match:
        return match.group(1)

    return None
|
|
|
|
|
|
|
|
def extract_expression_eos(output: str, end_marker: str) -> Optional[str]:
    """Pull the expression out of line-format (EOS) model output.

    Takes the text after the *last* "expr:" label, truncates it at the end
    marker and at the first newline, and returns None when nothing remains.
    """
    _, label, tail = output.rpartition("expr:")
    if not label:
        # No "expr:" label anywhere in the output.
        return None

    candidate = tail.strip()
    # Cut at the end-of-expression marker, if the model emitted one.
    candidate = candidate.partition(end_marker)[0].strip()
    # Keep only the first line of whatever is left.
    candidate = candidate.partition("\n")[0].strip()
    return candidate or None
|
|
|
|
|
|
|
|
def validate_expression(expr_str: str, allowed_vars: set, allowed_ops: set) -> Dict:
    """Validate a generated expression.

    Checks, in order: non-empty, free of garbage vocabulary, parseable and
    numerically evaluable via ``Expression``, and restricted to the symbols
    offered in the prompt.

    Args:
        expr_str: Raw expression text (infix notation).
        allowed_vars: Variable names (plus the constant symbol) from the prompt.
        allowed_ops: Operators/functions (plus parentheses) from the prompt.

    Returns:
        Dict with the raw text, boolean quality flags, and an error string
        (None when no problem was recorded).
    """
    result = {
        "raw": expr_str,
        "is_valid": False,
        "is_parseable": False,
        "uses_correct_symbols": False,
        "has_garbage": False,
        "error": None
    }

    if not expr_str or not expr_str.strip():
        result["error"] = "Empty expression"
        return result

    # Garbage scan. Match whole words only: a plain substring test would
    # flag e.g. "XP" inside "exp(x_1)" and reject every valid exp() call.
    for word in GARBAGE_WORDS:
        if re.search(r'\b' + re.escape(word) + r'\b', expr_str, re.IGNORECASE):
            result["has_garbage"] = True
            result["error"] = f"Contains garbage: {word}"
            return result

    try:
        expr = Expression(expr_str, is_prefix=False)
        result["is_parseable"] = True

        # Evaluate on a dummy input to surface runtime failures (division by
        # zero, log of a negative, ...). Valid means finite and not NaN;
        # `val == val` is the NaN self-inequality test.
        X_test = [[1.0] * 10]
        eval_result = expr.evaluate(X_test)
        if len(eval_result) > 0:
            val = eval_result[0]
            if val == val and val != float('inf') and val != float('-inf'):
                result["is_valid"] = True

    except Exception as e:
        result["error"] = str(e)[:100]

    expr_clean = expr_str.replace(" ", "")

    used_vars = set(re.findall(r'x_\d+', expr_clean))

    # Detect operators longest-name-first and blank out each hit so shorter
    # names embedded in longer ones are not counted as separate operators
    # ("sin" inside "asin", "tan" inside "atan", "*" inside "**"). The old
    # substring-only scan produced false symbol violations for x**2 when
    # only "**" was allowed, and for asin() when only "asin" was allowed.
    used_ops = set()
    masked = expr_clean
    for op in ["asin", "acos", "atan", "sqrt", "sin", "cos", "tan", "exp", "log", "abs"]:
        if op in masked:
            used_ops.add(op)
            masked = masked.replace(op, " ")
    for op in ["**", "+", "-", "*", "/"]:
        if op in masked:
            used_ops.add(op)
            masked = masked.replace(op, " ")

    var_ok = used_vars.issubset(allowed_vars)
    op_ok = used_ops.issubset(allowed_ops)
    result["uses_correct_symbols"] = var_ok and op_ok

    if not var_ok:
        # Note: this overwrites any earlier parse/eval error, matching the
        # original precedence (symbol violations win).
        invalid_vars = used_vars - allowed_vars
        result["error"] = f"Invalid vars: {invalid_vars}"

    return result
|
|
|
|
|
|
|
|
def generate_and_evaluate(
    model,
    tokenizer,
    experiment_type: str,
    use_native_eos: bool,
    num_samples: int = 100,
    test_prompts: Optional[List[Dict]] = None
) -> Dict:
    """Generate expressions from the model and score their quality.

    Args:
        model: Causal LM ready for inference (on its device, eval mode).
        tokenizer: Matching tokenizer.
        experiment_type: "json" or "eos"; selects prompt format and extractor.
        use_native_eos: Whether the run uses GPT-2's native <|endoftext|>
            marker instead of the custom <|endofex|> marker.
        num_samples: Total generation budget, split evenly across prompts.
        test_prompts: Optional list of {"vars", "ops", "cons"} configs; a
            default mix of arities and operator families is used if omitted.

    Returns:
        Dict of counters (total/valid/parseable/correct_symbols/garbage/
        stopped_correctly) plus per-sample records under "samples".
    """

    if test_prompts is None:
        # Default prompt mix: varying variable counts and operator families.
        test_prompts = [
            {"vars": ["x_1", "x_2"], "ops": ["*", "+", "-", "sin", "cos"], "cons": "C"},
            {"vars": ["x_1", "x_2", "x_3"], "ops": ["*", "+", "/", "exp", "log"], "cons": "C"},
            {"vars": ["x_1"], "ops": ["*", "**", "sin", "sqrt"], "cons": "C"},
            {"vars": ["x_1", "x_2", "x_3", "x_4"], "ops": ["*", "+", "-", "/"], "cons": "C"},
        ]

    # End marker and stop sequences depend on the tokenization scheme.
    if use_native_eos:
        end_marker = "<|endoftext|>"
        stop_sequences = ["<|endoftext|>", "\n\nvars:"]
    else:
        end_marker = "<|endofex|>"
        stop_sequences = ["<|endofex|>", '"}', "\n\nvars:"]

    stopping_criteria = StoppingCriteriaList([
        ExpressionStoppingCriteria(tokenizer, stop_sequences)
    ])

    # Sampling configuration shared by every generation call.
    gen_config = {
        "temperature": 0.7,
        "top_k": 50,
        "top_p": 0.9,
        "max_new_tokens": 128,
        "do_sample": True,
        "pad_token_id": tokenizer.eos_token_id,
    }

    results = {
        "total": 0,
        "valid": 0,
        "parseable": 0,
        "correct_symbols": 0,
        "garbage": 0,
        "stopped_correctly": 0,
        "samples": []
    }

    # Guarantee at least one sample per prompt: plain integer division would
    # generate nothing at all when num_samples < len(test_prompts).
    samples_per_prompt = max(1, num_samples // len(test_prompts))

    logger.info(f"Generating {num_samples} samples ({samples_per_prompt} per prompt)...")

    for prompt_config in test_prompts:
        vars_list = prompt_config["vars"]
        ops_list = prompt_config["ops"]
        cons = prompt_config.get("cons", "C")

        # Symbols the model is allowed to use for this prompt.
        allowed_vars = set(vars_list) | {cons}
        allowed_ops = set(ops_list) | {"(", ")"}

        if experiment_type == "json":
            prompt = create_prompt_json(vars_list, ops_list, cons)
        else:
            prompt = create_prompt_eos(vars_list, ops_list, cons)

        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        for i in range(samples_per_prompt):
            results["total"] += 1

            # Inference only: disable autograd so generate() doesn't build
            # graphs or hold activations in memory.
            with torch.no_grad():
                output = model.generate(
                    **inputs,
                    **gen_config,
                    stopping_criteria=stopping_criteria
                )
            output_text = tokenizer.decode(output[0], skip_special_tokens=False)

            if experiment_type == "json":
                expr_str = extract_expression_json(output_text)
            else:
                expr_str = extract_expression_eos(output_text, end_marker)

            # "Stopped correctly" means the end marker appears in the output
            # (rather than generation hitting max_new_tokens).
            stopped_correctly = end_marker in output_text
            if stopped_correctly:
                results["stopped_correctly"] += 1

            if expr_str:
                validation = validate_expression(expr_str, allowed_vars, allowed_ops)

                if validation["is_valid"]:
                    results["valid"] += 1
                if validation["is_parseable"]:
                    results["parseable"] += 1
                if validation["uses_correct_symbols"]:
                    results["correct_symbols"] += 1
                if validation["has_garbage"]:
                    results["garbage"] += 1

                sample = {
                    "prompt_vars": vars_list,
                    "prompt_ops": ops_list,
                    "expression": expr_str,
                    "stopped_correctly": stopped_correctly,
                    **validation
                }
                results["samples"].append(sample)
            else:
                # Nothing extractable at all counts as garbage output.
                results["garbage"] += 1
                results["samples"].append({
                    "prompt_vars": vars_list,
                    "prompt_ops": ops_list,
                    "expression": None,
                    "stopped_correctly": stopped_correctly,
                    "is_valid": False,
                    "error": "Could not extract expression"
                })

            if results["total"] % 20 == 0:
                logger.info(f"Progress: {results['total']}/{num_samples}")

    return results
|
|
|
|
|
|
|
|
def print_report(results: Dict, experiment_name: str):
    """Print a human-readable evaluation report to stdout.

    Shows per-metric pass/fail (rates >= 80% pass, except Garbage Rate
    which must be < 5%), a few valid/invalid sample outputs, and an overall
    verdict (valid >= 80%, stopping >= 90%, garbage < 5%).
    """
    total = results["total"]

    print("\n" + "=" * 60)
    print(f"EVALUATION REPORT: {experiment_name}")
    print("=" * 60)

    print(f"\nTotal samples: {total}")

    # Guard: with zero samples every rate below would divide by zero.
    if total == 0:
        print("\nNo samples generated - nothing to report.")
        print("=" * 60)
        return

    metrics = [
        ("Valid Rate", results["valid"] / total * 100),
        ("Parseable Rate", results["parseable"] / total * 100),
        ("Correct Symbols", results["correct_symbols"] / total * 100),
        ("Stopping Rate", results["stopped_correctly"] / total * 100),
        ("Garbage Rate", results["garbage"] / total * 100),
    ]

    print("\nMetrics:")
    print("-" * 40)
    for name, value in metrics:
        # Garbage is a lower-is-better metric; everything else passes at >= 80%.
        if name == "Garbage Rate":
            status = "PASS" if value < 5 else "FAIL"
        else:
            status = "PASS" if value >= 80 else "FAIL"
        print(f"  {name:<20s}: {value:6.1f}% [{status}]")

    print("\n" + "-" * 40)
    print("Sample Outputs:")
    print("-" * 40)

    valid_samples = [s for s in results["samples"] if s.get("is_valid")]
    invalid_samples = [s for s in results["samples"] if not s.get("is_valid")]

    print("\nValid examples:")
    for sample in valid_samples[:5]:
        expr = sample.get("expression", "N/A")
        vars_str = ", ".join(sample.get("prompt_vars", []))
        print(f"  [{vars_str}] -> {expr}")

    print("\nInvalid examples:")
    for sample in invalid_samples[:5]:
        # Extraction failures store "expression": None (the key exists, so a
        # .get default never applies); fall back to "N/A" before slicing.
        expr = sample.get("expression") or "N/A"
        error = sample.get("error", "Unknown")
        print(f"  {expr[:50]}... | Error: {error}")

    print("\n" + "=" * 60)

    valid_rate = results["valid"] / total * 100
    stopping_rate = results["stopped_correctly"] / total * 100
    garbage_rate = results["garbage"] / total * 100

    success = valid_rate >= 80 and stopping_rate >= 90 and garbage_rate < 5

    print(f"\nOVERALL: {'SUCCESS' if success else 'NEEDS IMPROVEMENT'}")
    print("=" * 60)
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: load the model, run evaluation, report and save results."""
    parser = argparse.ArgumentParser(
        description="Evaluate expression generation experiments"
    )
    parser.add_argument("--model_path", type=str, required=True,
                        help="Path to trained model")
    parser.add_argument("--experiment_type", type=str, required=True,
                        choices=["json", "eos"],
                        help="Experiment type (json or eos)")
    parser.add_argument("--num_samples", type=int, default=200,
                        help="Number of samples to generate")
    parser.add_argument("--output_file", type=str, default=None,
                        help="Path to save results JSON")

    args = parser.parse_args()

    model, tokenizer, use_native_eos = load_model(
        args.model_path,
        args.experiment_type
    )

    results = generate_and_evaluate(
        model=model,
        tokenizer=tokenizer,
        experiment_type=args.experiment_type,
        use_native_eos=use_native_eos,
        num_samples=args.num_samples
    )

    # Experiment A = JSON format, B = EOS format (naming from the docstring).
    experiment_name = f"EXP-{'A' if args.experiment_type == 'json' else 'B'} ({args.experiment_type.upper()})"
    print_report(results, experiment_name)

    if args.output_file:
        out_dir = os.path.dirname(args.output_file)
        # dirname is "" for a bare filename; os.makedirs("") raises
        # FileNotFoundError, so only create directories when one is given.
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        # Keep the saved file small: drop the full sample list and store
        # only the counters plus up to 20 examples of each outcome.
        save_results = {k: v for k, v in results.items() if k != "samples"}
        save_results["sample_count"] = len(results["samples"])
        save_results["valid_samples"] = [s for s in results["samples"] if s.get("is_valid")][:20]
        save_results["invalid_samples"] = [s for s in results["samples"] if not s.get("is_valid")][:20]

        with open(args.output_file, "w") as f:
            json.dump(save_results, f, indent=2)

        logger.info(f"Results saved to: {args.output_file}")
|
|
|
|
|
|
|
|
# Script entry point: run the evaluation CLI when executed directly.
if __name__ == "__main__":
    main()
|
|
|