| |
| |
|
|
| import argparse |
| import json |
| import os |
| import sys |
| import re |
| from collections import Counter |
| from datetime import datetime |
|
|
| import numpy as np |
| import torch |
| from datasets import load_dataset |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| from peft import PeftModel |
| from tqdm import tqdm |
|
|
| |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
| from classes.expression import Expression |
|
|
|
|
def parse_args(argv=None):
    """Parse command-line arguments for the evaluation script.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse falls back to ``sys.argv[1:]``, so existing callers are
            unaffected; passing an explicit list makes the parser usable from
            tests and other scripts.

    Returns:
        argparse.Namespace holding all evaluation settings.
    """
    parser = argparse.ArgumentParser(description="Evaluate a trained model on expression generation")
    parser.add_argument("--model_path", type=str, required=True,
                        help="Path to model (local or HuggingFace Hub)")
    parser.add_argument("--base_model", type=str, default=None,
                        help="Base model for PEFT (if model_path is adapter)")
    parser.add_argument("--dataset_repo_id", type=str, default="augustocsc/sintetico_natural",
                        help="HuggingFace dataset repository")
    parser.add_argument("--data_dir", type=str, default="700K",
                        help="Data directory within dataset")
    parser.add_argument("--data_column", type=str, default="i_prompt_n",
                        help="Column name for prompts (i_prompt_n for infix, p_prompt_n for prefix)")
    parser.add_argument("--num_samples", type=int, default=500,
                        help="Number of samples to evaluate")
    parser.add_argument("--num_generations", type=int, default=1,
                        help="Number of generations per prompt")
    parser.add_argument("--max_new_tokens", type=int, default=128,
                        help="Maximum new tokens to generate")
    parser.add_argument("--temperature", type=float, default=0.7,
                        help="Sampling temperature")
    parser.add_argument("--top_p", type=float, default=0.9,
                        help="Top-p sampling parameter")
    parser.add_argument("--output_dir", type=str, default="./evaluation_results",
                        help="Directory to save evaluation results")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed")
    parser.add_argument("--device", type=str, default="auto",
                        help="Device to use (auto, cuda, cpu)")
    return parser.parse_args(argv)
|
|
|
|
def extract_expression_from_output(output: str, is_prefix: bool = False) -> str:
    """Extract the generated expression from raw model output text.

    Tries, in order:
      1. Text between ``<|startofex|>`` and ``<|endofex|>`` markers.
      2. Text after ``<|startofex|>`` only, cut at the earliest known
         boundary (section headers, blank line, or ``<|endoftext|>``),
         then truncated to its first line and at most 150 characters.
      3. An "expr:" / "Expression:" labelled line.
      4. The first line of the output, capped at 100 characters.

    NOTE(review): *is_prefix* is currently unused; extraction is
    notation-agnostic.
    """
    start_marker = "<|startofex|>"
    end_marker = "<|endofex|>"

    # Case 1: both delimiters present and in the right order.
    if start_marker in output and end_marker in output:
        start_idx = output.find(start_marker) + len(start_marker)
        end_idx = output.find(end_marker)
        if start_idx < end_idx:
            return output[start_idx:end_idx].strip()

    # Case 2: only the start marker; trim trailing echo / special tokens.
    if start_marker in output:
        start_idx = output.find(start_marker) + len(start_marker)
        remaining = output[start_idx:].strip()

        # Cut at the EARLIEST boundary occurring in the text, not the first
        # one in list order — otherwise e.g. a leading "<|endoftext|>" could
        # survive when a later "\nvars:" is split on first and then leak into
        # the returned expression.
        boundaries = ["\nvars:", "\nVariables:", "\nOperators:", "\n\n", "<|endoftext|>"]
        cut_positions = [remaining.find(b) for b in boundaries if b in remaining]
        if cut_positions:
            remaining = remaining[:min(cut_positions)].strip()

        # The expression is expected to fit on a single line.
        remaining = remaining.split("\n")[0].strip()

        # Guard against runaway generations.
        if len(remaining) > 150:
            remaining = remaining[:150]

        return remaining

    # Case 3: labelled expression, e.g. "Expression: x_1 + 2".
    match = re.search(r'(?:expr|Expression):\s*(.+?)(?:\n|$)', output, re.IGNORECASE)
    if match:
        return match.group(1).strip()

    # Case 4: last resort — the first line, capped at 100 chars.
    first_line = output.strip().split("\n")[0]
    return first_line[:100] if len(first_line) > 100 else first_line
|
|
|
|
def validate_expression(expr_str: str, is_prefix: bool = False) -> dict:
    """Attempt to parse *expr_str* with ``Expression`` and report the outcome.

    Returns a dict with keys "valid", "parseable", "error" (message string or
    None), and "expression_obj" (the parsed object on success, else None).
    """
    outcome = {
        "valid": False,
        "parseable": False,
        "error": None,
        "expression_obj": None,
    }

    # Reject empty / whitespace-only candidates before attempting a parse.
    if not expr_str or not expr_str.strip():
        outcome["error"] = "Empty expression"
        return outcome

    try:
        parsed = Expression(expr_str, is_prefix=is_prefix)
    except Exception as exc:
        # Any parser failure is recorded as the error message.
        outcome["error"] = str(exc)
    else:
        outcome["parseable"] = True
        outcome["valid"] = True
        outcome["expression_obj"] = parsed

    return outcome
|
|
|
|
def check_prompt_adherence(expr_str: str, prompt: str, is_prefix: bool = False) -> dict:
    """Check whether *expr_str* respects the constraint lists in *prompt*.

    Constraints are parsed from "Variables: ..." and "Operators: ..." lines
    inside the prompt. When a constraint line is absent, that constraint
    passes by default. Only function-style operators (sin, cos, tan, log,
    sqrt, exp) are enforced; arithmetic operators are never rejected.
    """
    result = {
        "uses_allowed_vars": False,
        "uses_allowed_ops": False,
        "all_constraints_met": False
    }

    # Variables the prompt permits, e.g. {"x_1", "x_2"}.
    allowed_vars = set()
    vars_line = re.search(r"Variables?:\s*([^\n]+)", prompt, re.IGNORECASE)
    if vars_line:
        allowed_vars = set(re.findall(r"x_\d+", vars_line.group(1)))

    # Operators the prompt permits (substring match against the prompt line).
    allowed_ops = set()
    ops_line = re.search(r"Operators?:\s*([^\n]+)", prompt, re.IGNORECASE)
    if ops_line:
        candidates = ['+', '-', '*', '/', '**', 'sin', 'cos', 'tan', 'log', 'sqrt', 'exp']
        allowed_ops = {op for op in candidates if op in ops_line.group(1)}

    # Every variable the expression uses must be on the allowed list.
    used_vars = set(re.findall(r"x_\d+", expr_str))
    result["uses_allowed_vars"] = used_vars.issubset(allowed_vars) if allowed_vars else True

    # Reject only disallowed *function* operators appearing in the expression.
    result["uses_allowed_ops"] = True
    if allowed_ops:
        for fn_op in ['sin', 'cos', 'tan', 'log', 'sqrt', 'exp']:
            if fn_op in expr_str and fn_op not in allowed_ops:
                result["uses_allowed_ops"] = False
                break

    result["all_constraints_met"] = result["uses_allowed_vars"] and result["uses_allowed_ops"]

    return result
|
|
|
|
def load_model_and_tokenizer(model_path: str, base_model: str = None, device: str = "auto"):
    """Load a tokenizer and causal LM (optionally merging a PEFT adapter).

    When *model_path* is a local directory containing ``adapter_config.json``,
    or *base_model* is given, the base model is loaded first, its embeddings
    resized to the tokenizer, and the adapter merged in. Otherwise
    *model_path* is loaded directly as a full model.

    Returns:
        (model, tokenizer, device) with the model in eval mode on *device*.
    """
    print(f"Loading model from: {model_path}")

    # Resolve "auto" to a concrete device.
    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"

    tok = AutoTokenizer.from_pretrained(model_path)
    # GPT-style tokenizers often ship without a pad token; reuse EOS.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    # A local directory with adapter_config.json marks a PEFT checkpoint.
    is_peft = False
    if os.path.isdir(model_path):
        is_peft = os.path.exists(os.path.join(model_path, "adapter_config.json"))

    if is_peft or base_model:
        base = base_model if base_model else "gpt2"
        print(f"Loading base model: {base}")
        net = AutoModelForCausalLM.from_pretrained(base)
        net.resize_token_embeddings(len(tok))
        print("Loading PEFT adapter...")
        net = PeftModel.from_pretrained(net, model_path)
        # Fold adapter weights into the base model for plain inference.
        net = net.merge_and_unload()
    else:
        net = AutoModelForCausalLM.from_pretrained(model_path)
        net.resize_token_embeddings(len(tok))

    net = net.to(device)
    net.eval()

    return net, tok, device
|
|
|
|
def generate_expression(model, tokenizer, prompt: str, device: str,
                        max_new_tokens: int = 128, temperature: float = 0.7,
                        top_p: float = 0.9, num_return_sequences: int = 1):
    """Sample one or more continuations of *prompt* from *model*.

    Returns the decoded texts WITH special tokens kept, so downstream
    extraction can still locate the <|startofex|>/<|endofex|> markers.
    """
    encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    encoded = {key: tensor.to(device) for key, tensor in encoded.items()}

    # No gradients are needed for pure sampling.
    with torch.no_grad():
        sampled = model.generate(
            **encoded,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            num_return_sequences=num_return_sequences,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    return tokenizer.batch_decode(sampled, skip_special_tokens=False)
|
|
|
|
def evaluate_model(args):
    """Run the full evaluation loop and report/save aggregate metrics.

    Loads the model and the dataset's test split (falling back to the
    validation split), generates expressions for up to ``args.num_samples``
    prompts, and scores each generation for validity (via
    ``validate_expression``), prompt adherence, and diversity. Prints a
    summary, writes two timestamped JSON files (summary metrics and
    per-generation results) under ``args.output_dir``, and returns the
    summary dict.
    """
    # Seed both torch and numpy so sampling and subset selection are reproducible.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    model, tokenizer, device = load_model_and_tokenizer(
        args.model_path, args.base_model, args.device
    )

    # Prefer the test split; fall back to validation if the test CSV is missing.
    print(f"Loading dataset: {args.dataset_repo_id}/{args.data_dir}")
    try:
        dataset = load_dataset(
            args.dataset_repo_id,
            data_files={
                "test": f"{args.data_dir}/test_{args.data_dir}.csv"
            }
        )["test"]
    except Exception as e:
        print(f"Error loading test set, trying validation: {e}")
        dataset = load_dataset(
            args.dataset_repo_id,
            data_files={
                "validation": f"{args.data_dir}/val_{args.data_dir}.csv"
            }
        )["validation"]

    # Subsample without replacement when the split is larger than requested.
    if len(dataset) > args.num_samples:
        indices = np.random.choice(len(dataset), args.num_samples, replace=False)
        dataset = dataset.select(indices)

    print(f"Evaluating on {len(dataset)} samples...")

    # Column naming convention: "p_" prefixed columns hold prefix-notation prompts.
    is_prefix = args.data_column.startswith("p_")

    # Running counters accumulated across every generation of every sample.
    metrics = {
        "total_samples": 0,
        "total_generations": 0,
        "valid_expressions": 0,
        "parseable_expressions": 0,
        "uses_allowed_vars": 0,
        "uses_allowed_ops": 0,
        "all_constraints_met": 0,
        "unique_expressions": set(),
        "expression_lengths": [],
        "errors": Counter(),
    }

    # Per-generation records saved alongside the aggregate metrics.
    results = []

    for idx, sample in enumerate(tqdm(dataset, desc="Evaluating")):
        prompt = sample[args.data_column]

        # If the dataset row contains a full example, keep only the text up to
        # and including the start marker so the model completes the expression.
        if "<|startofex|>" in prompt:
            prompt_only = prompt.split("<|startofex|>")[0] + "<|startofex|>"
        else:
            prompt_only = prompt

        generations = generate_expression(
            model, tokenizer, prompt_only, device,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_p=args.top_p,
            num_return_sequences=args.num_generations
        )

        metrics["total_samples"] += 1

        for gen_output in generations:
            metrics["total_generations"] += 1

            # Pull the candidate expression out of the raw decoded text.
            expr_str = extract_expression_from_output(gen_output, is_prefix)

            # Syntactic validity via the project's Expression parser.
            validation = validate_expression(expr_str, is_prefix)

            # Constraint adherence against the prompt's Variables/Operators lines.
            adherence = check_prompt_adherence(expr_str, prompt_only, is_prefix)

            if validation["valid"]:
                metrics["valid_expressions"] += 1
            if validation["parseable"]:
                metrics["parseable_expressions"] += 1
            metrics["unique_expressions"].add(expr_str)
            metrics["expression_lengths"].append(len(expr_str))
            if validation["error"]:
                # Truncate messages so near-identical errors bucket together.
                metrics["errors"][validation["error"][:50]] += 1

            if adherence["uses_allowed_vars"]:
                metrics["uses_allowed_vars"] += 1
            if adherence["uses_allowed_ops"]:
                metrics["uses_allowed_ops"] += 1
            if adherence["all_constraints_met"]:
                metrics["all_constraints_met"] += 1

            # Prompt and raw output are truncated to keep the results file small.
            results.append({
                "sample_idx": idx,
                "prompt": prompt_only[:200],
                "generated_output": gen_output[:500],
                "extracted_expression": expr_str,
                "valid": validation["valid"],
                "parseable": validation["parseable"],
                "error": validation["error"],
                "uses_allowed_vars": adherence["uses_allowed_vars"],
                "uses_allowed_ops": adherence["uses_allowed_ops"],
            })

    # Convert raw counters into rates; guard against division by zero when
    # the dataset was empty.
    total_gen = metrics["total_generations"]
    final_metrics = {
        "model_path": args.model_path,
        "dataset": f"{args.dataset_repo_id}/{args.data_dir}",
        "data_column": args.data_column,
        "is_prefix": is_prefix,
        "num_samples": metrics["total_samples"],
        "num_generations": total_gen,
        "temperature": args.temperature,
        "top_p": args.top_p,

        # Validity: share of generations the Expression parser accepts.
        "valid_rate": metrics["valid_expressions"] / total_gen if total_gen > 0 else 0,
        "parseable_rate": metrics["parseable_expressions"] / total_gen if total_gen > 0 else 0,

        # Adherence: share of generations respecting the prompt constraints.
        "uses_allowed_vars_rate": metrics["uses_allowed_vars"] / total_gen if total_gen > 0 else 0,
        "uses_allowed_ops_rate": metrics["uses_allowed_ops"] / total_gen if total_gen > 0 else 0,
        "constraints_met_rate": metrics["all_constraints_met"] / total_gen if total_gen > 0 else 0,

        # Diversity: distinct expression strings across all generations.
        "unique_expressions": len(metrics["unique_expressions"]),
        "diversity_rate": len(metrics["unique_expressions"]) / total_gen if total_gen > 0 else 0,
        "avg_expression_length": np.mean(metrics["expression_lengths"]) if metrics["expression_lengths"] else 0,

        # Ten most frequent (truncated) parser error messages.
        "top_errors": dict(metrics["errors"].most_common(10)),

        "timestamp": datetime.now().isoformat(),
    }

    # Human-readable summary on stdout.
    print("\n" + "="*60)
    print("EVALUATION RESULTS")
    print("="*60)
    print(f"Model: {args.model_path}")
    print(f"Dataset: {args.dataset_repo_id}/{args.data_dir}")
    print(f"Format: {'Prefix' if is_prefix else 'Infix'}")
    print("-"*60)
    print(f"Total samples: {metrics['total_samples']}")
    print(f"Total generations: {total_gen}")
    print("-"*60)
    print("VALIDITY METRICS:")
    print(f"  Valid rate: {final_metrics['valid_rate']:.2%}")
    print(f"  Parseable rate: {final_metrics['parseable_rate']:.2%}")
    print("-"*60)
    print("ADHERENCE METRICS:")
    print(f"  Uses allowed vars: {final_metrics['uses_allowed_vars_rate']:.2%}")
    print(f"  Uses allowed ops: {final_metrics['uses_allowed_ops_rate']:.2%}")
    print(f"  All constraints met: {final_metrics['constraints_met_rate']:.2%}")
    print("-"*60)
    print("DIVERSITY METRICS:")
    print(f"  Unique expressions: {final_metrics['unique_expressions']}")
    print(f"  Diversity rate: {final_metrics['diversity_rate']:.2%}")
    print(f"  Avg expression length: {final_metrics['avg_expression_length']:.1f}")
    print("="*60)

    os.makedirs(args.output_dir, exist_ok=True)

    # Build filesystem-safe, timestamped output file names from the model path.
    model_name = args.model_path.replace("/", "_").replace("\\", "_")
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Summary metrics JSON.
    metrics_file = os.path.join(args.output_dir, f"metrics_{model_name}_{timestamp}.json")
    with open(metrics_file, "w") as f:
        json.dump(final_metrics, f, indent=2)
    print(f"\nMetrics saved to: {metrics_file}")

    # Per-generation detail JSON.
    results_file = os.path.join(args.output_dir, f"results_{model_name}_{timestamp}.json")
    with open(results_file, "w") as f:
        json.dump(results, f, indent=2)
    print(f"Detailed results saved to: {results_file}")

    return final_metrics
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse CLI flags and run the evaluation.
    evaluate_model(parse_args())
|
|