| | |
| | |
| |
|
| | import argparse |
| | import json |
| | import os |
| | import sys |
| | import re |
| | from collections import Counter |
| | from datetime import datetime |
| |
|
| | import numpy as np |
| | import torch |
| | from datasets import load_dataset |
| | from transformers import AutoModelForCausalLM, AutoTokenizer |
| | from peft import PeftModel |
| | from tqdm import tqdm |
| |
|
| | |
| | sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
| | from classes.expression import Expression |
| |
|
| |
|
def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with one attribute per option declared below
        (model location, dataset selection, sampling settings, output
        directory, seed and device).
    """
    # (flag, add_argument keyword arguments) in the order they appear in --help.
    option_specs = [
        ("--model_path", {"type": str, "required": True,
                          "help": "Path to model (local or HuggingFace Hub)"}),
        ("--base_model", {"type": str, "default": None,
                          "help": "Base model for PEFT (if model_path is adapter)"}),
        ("--dataset_repo_id", {"type": str, "default": "augustocsc/sintetico_natural",
                               "help": "HuggingFace dataset repository"}),
        ("--data_dir", {"type": str, "default": "700K",
                        "help": "Data directory within dataset"}),
        ("--data_column", {"type": str, "default": "i_prompt_n",
                           "help": "Column name for prompts (i_prompt_n for infix, p_prompt_n for prefix)"}),
        ("--num_samples", {"type": int, "default": 500,
                           "help": "Number of samples to evaluate"}),
        ("--num_generations", {"type": int, "default": 1,
                               "help": "Number of generations per prompt"}),
        ("--max_new_tokens", {"type": int, "default": 128,
                              "help": "Maximum new tokens to generate"}),
        ("--temperature", {"type": float, "default": 0.7,
                           "help": "Sampling temperature"}),
        ("--top_p", {"type": float, "default": 0.9,
                     "help": "Top-p sampling parameter"}),
        ("--output_dir", {"type": str, "default": "./evaluation_results",
                          "help": "Directory to save evaluation results"}),
        ("--seed", {"type": int, "default": 42,
                    "help": "Random seed"}),
        ("--device", {"type": str, "default": "auto",
                      "help": "Device to use (auto, cuda, cpu)"}),
    ]

    cli = argparse.ArgumentParser(description="Evaluate a trained model on expression generation")
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
| |
|
| |
|
def extract_expression_from_output(output: str, is_prefix: bool = False) -> str:
    """Extract the expression from a decoded model output.

    Strategies, tried in order:

    1. If ``<|startofex|>`` is present and an ``<|endofex|>`` occurs *after*
       it, return the stripped text between them (possibly ``""``).
    2. If only the start marker is usable, take the text after it, cut it at
       the earliest section boundary (``"\\nvars:"``, ``"\\nVariables:"``,
       ``"\\nOperators:"``, ``"\\n\\n"`` or ``"<|endoftext|>"``), keep its
       first line and cap it at 150 characters.
    3. Otherwise look for an ``expr:`` / ``Expression:`` label
       (case-insensitive) and return the rest of that line.
    4. As a last resort return the first line of ``output``, capped at 100
       characters.

    Args:
        output: Decoded model text (may include the prompt).
        is_prefix: Accepted for interface symmetry with the other helpers;
            not used by the extraction.

    Returns:
        The extracted expression string; may be empty.
    """
    start_marker = "<|startofex|>"
    end_marker = "<|endofex|>"

    start_pos = output.find(start_marker)
    if start_pos != -1:
        start_idx = start_pos + len(start_marker)

        # Look for the end marker only after the start marker, so an empty
        # expression yields "" and a stray earlier end marker is ignored.
        end_idx = output.find(end_marker, start_idx)
        if end_idx != -1:
            return output[start_idx:end_idx].strip()

        remaining = output[start_idx:].strip()

        # Cut at the earliest boundary actually present (not the first one in
        # list order), so e.g. a trailing "<|endoftext|>" is always removed.
        boundaries = ["\nvars:", "\nVariables:", "\nOperators:", "\n\n", "<|endoftext|>"]
        cut_points = [pos for pos in (remaining.find(b) for b in boundaries) if pos != -1]
        if cut_points:
            remaining = remaining[:min(cut_points)].strip()

        remaining = remaining.split("\n")[0].strip()
        return remaining[:150]

    match = re.search(r'(?:expr|Expression):\s*(.+?)(?:\n|$)', output, re.IGNORECASE)
    if match:
        return match.group(1).strip()

    first_line = output.strip().split("\n")[0]
    return first_line[:100]
| |
|
| |
|
def validate_expression(expr_str: str, is_prefix: bool = False) -> dict:
    """Try to parse ``expr_str`` with the project's ``Expression`` class.

    Args:
        expr_str: Candidate expression text.
        is_prefix: Forwarded to ``Expression`` to select prefix vs. infix parsing.

    Returns:
        dict with keys:
            ``valid`` / ``parseable``: both True iff ``Expression`` was built
                without raising (the two flags are currently identical);
            ``error``: ``"Empty expression"`` for blank input, the exception
                message if parsing raised, else None;
            ``expression_obj``: the constructed ``Expression`` or None.
    """
    outcome = {"valid": False, "parseable": False, "error": None, "expression_obj": None}

    # Guard clause: blank or whitespace-only input never reaches the parser.
    if not expr_str or not expr_str.strip():
        outcome["error"] = "Empty expression"
        return outcome

    try:
        parsed = Expression(expr_str, is_prefix=is_prefix)
    except Exception as exc:  # parser failures of any type count as invalid
        outcome["error"] = str(exc)
    else:
        outcome.update(valid=True, parseable=True, expression_obj=parsed)

    return outcome
| |
|
| |
|
def check_prompt_adherence(expr_str: str, prompt: str, is_prefix: bool = False) -> dict:
    """Check whether an expression respects the constraints listed in its prompt.

    Constraints are read from header lines of ``prompt`` (case-insensitive):

    * variables: ``Variable:`` / ``Variables:`` or the abbreviated ``vars:``.
      The allowed set is every ``x_<digits>`` token on that line.
    * operators: ``Operator:`` / ``Operators:`` or the abbreviated ``oper:``.
      Allowed operators are recognised from a fixed list by substring.

    A missing header, or one with no recognised entries, counts as satisfied.
    Only the unary function names (sin, cos, tan, log, sqrt, exp) are checked
    against the expression, by substring; binary operators are not checked.

    Args:
        expr_str: Extracted expression text.
        prompt: Prompt text containing the constraint headers.
        is_prefix: Accepted for interface symmetry; not used by the check.

    Returns:
        dict with boolean keys ``uses_allowed_vars``, ``uses_allowed_ops`` and
        ``all_constraints_met`` (the conjunction of the first two).
    """
    known_ops = ['+', '-', '*', '/', '**', 'sin', 'cos', 'tan', 'log', 'sqrt', 'exp']
    unary_fns = ['sin', 'cos', 'tan', 'log', 'sqrt', 'exp']

    result = {
        "uses_allowed_vars": False,
        "uses_allowed_ops": False,
        "all_constraints_met": False,
    }

    # Accept the long "Variable(s):" header as well as the abbreviated "vars:".
    var_match = re.search(r"(?:Variables?|\bvars):\s*([^\n]+)", prompt, re.IGNORECASE)
    allowed_vars = set(re.findall(r"x_\d+", var_match.group(1))) if var_match else set()

    # NOTE(review): "oper:" is accepted by analogy with "vars:"/"expr:";
    # confirm the dataset's exact operator header.
    op_match = re.search(r"(?:Operators?|\boper):\s*([^\n]+)", prompt, re.IGNORECASE)
    allowed_ops = set()
    if op_match:
        op_str = op_match.group(1)
        allowed_ops = {op for op in known_ops if op in op_str}

    expr_vars = set(re.findall(r"x_\d+", expr_str))
    result["uses_allowed_vars"] = expr_vars.issubset(allowed_vars) if allowed_vars else True

    result["uses_allowed_ops"] = True
    if allowed_ops:
        result["uses_allowed_ops"] = not any(
            fn in expr_str and fn not in allowed_ops for fn in unary_fns
        )

    result["all_constraints_met"] = result["uses_allowed_vars"] and result["uses_allowed_ops"]
    return result
| |
|
| |
|
def _base_model_from_adapter_config(config_path: str):
    """Return ``base_model_name_or_path`` from a PEFT adapter config file, or None."""
    try:
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)
    except (OSError, ValueError):
        # Missing/unreadable/malformed config: let the caller use its fallback.
        return None
    if not isinstance(config, dict):
        return None
    return config.get("base_model_name_or_path") or None


def load_model_and_tokenizer(model_path: str, base_model: str = None, device: str = "auto"):
    """Load a causal LM (optionally a merged PEFT adapter) and its tokenizer.

    Args:
        model_path: Local directory or Hub id of the model or PEFT adapter.
            The tokenizer is always loaded from here.
        base_model: Base model id for a PEFT adapter. When omitted for a local
            adapter directory, the ``base_model_name_or_path`` stored in its
            ``adapter_config.json`` is used, falling back to ``"gpt2"``.
        device: ``"auto"`` (CUDA if available, else CPU) or an explicit
            device string passed to ``model.to``.

    Returns:
        ``(model, tokenizer, device)``: the model in eval mode on ``device``,
        the tokenizer (with ``pad_token`` set), and the resolved device string.
    """
    print(f"Loading model from: {model_path}")

    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    if tokenizer.pad_token is None:
        # generate() is called with pad_token_id; reuse EOS when none is defined.
        tokenizer.pad_token = tokenizer.eos_token

    # Only local directories can be probed for an adapter config; adapters
    # on the Hub must be signalled by passing base_model.
    adapter_config_path = os.path.join(model_path, "adapter_config.json")
    is_peft = os.path.isdir(model_path) and os.path.exists(adapter_config_path)

    if is_peft or base_model:
        base = base_model or _base_model_from_adapter_config(adapter_config_path) or "gpt2"
        print(f"Loading base model: {base}")
        model = AutoModelForCausalLM.from_pretrained(base)
        # Resize embeddings to the tokenizer's vocabulary size before attaching the adapter.
        model.resize_token_embeddings(len(tokenizer))

        print("Loading PEFT adapter...")
        model = PeftModel.from_pretrained(model, model_path)
        # Fold the adapter weights into the base model for plain inference.
        model = model.merge_and_unload()
    else:
        model = AutoModelForCausalLM.from_pretrained(model_path)
        model.resize_token_embeddings(len(tokenizer))

    model = model.to(device)
    model.eval()

    return model, tokenizer, device
| |
|
| |
|
def generate_expression(model, tokenizer, prompt: str, device: str,
                        max_new_tokens: int = 128, temperature: float = 0.7,
                        top_p: float = 0.9, num_return_sequences: int = 1):
    """Sample one or more continuations of ``prompt`` and decode them.

    The prompt is tokenized with truncation to 512 tokens, moved to ``device``,
    and passed to ``model.generate`` with nucleus sampling enabled.

    Args:
        model: Model exposing ``generate``.
        tokenizer: Tokenizer matching ``model``; its ``pad_token_id`` and
            ``eos_token_id`` are forwarded to ``generate``.
        prompt: Text to condition on.
        device: Device the input tensors are moved to.
        max_new_tokens: Upper bound on generated tokens per sequence.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        num_return_sequences: Number of samples to draw for this prompt.

    Returns:
        list[str] of decoded sequences, special tokens kept.
    """
    encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    model_inputs = {name: tensor.to(device) for name, tensor in encoded.items()}

    sampling_kwargs = dict(
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        num_return_sequences=num_return_sequences,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    with torch.no_grad():
        sequences = model.generate(**model_inputs, **sampling_kwargs)

    # Special tokens are kept so downstream extraction can see the markers.
    return tokenizer.batch_decode(sequences, skip_special_tokens=False)
| |
|
| |
|
def _rate(count, total):
    """Return ``count / total``, or 0 when ``total`` is 0."""
    return count / total if total > 0 else 0


def _load_eval_split(dataset_repo_id: str, data_dir: str):
    """Load ``<data_dir>/test_<data_dir>.csv``; on any error fall back to the val CSV."""
    print(f"Loading dataset: {dataset_repo_id}/{data_dir}")
    try:
        return load_dataset(
            dataset_repo_id,
            data_files={"test": f"{data_dir}/test_{data_dir}.csv"},
        )["test"]
    except Exception as e:
        # Best-effort fallback: any exception from the test-split load triggers a retry on val.
        print(f"Error loading test set, trying validation: {e}")
        return load_dataset(
            dataset_repo_id,
            data_files={"validation": f"{data_dir}/val_{data_dir}.csv"},
        )["validation"]


def _prompt_prefix(prompt: str) -> str:
    """Keep ``prompt`` up to and including its first "<|startofex|>" (unchanged if absent)."""
    marker = "<|startofex|>"
    if marker in prompt:
        return prompt.split(marker)[0] + marker
    return prompt


def _print_report(summary: dict) -> None:
    """Print the aggregate metrics in ``summary`` as a formatted report."""
    print("\n" + "="*60)
    print("EVALUATION RESULTS")
    print("="*60)
    print(f"Model: {summary['model_path']}")
    print(f"Dataset: {summary['dataset']}")
    print(f"Format: {'Prefix' if summary['is_prefix'] else 'Infix'}")
    print("-"*60)
    print(f"Total samples: {summary['num_samples']}")
    print(f"Total generations: {summary['num_generations']}")
    print("-"*60)
    print("VALIDITY METRICS:")
    print(f" Valid rate: {summary['valid_rate']:.2%}")
    print(f" Parseable rate: {summary['parseable_rate']:.2%}")
    print("-"*60)
    print("ADHERENCE METRICS:")
    print(f" Uses allowed vars: {summary['uses_allowed_vars_rate']:.2%}")
    print(f" Uses allowed ops: {summary['uses_allowed_ops_rate']:.2%}")
    print(f" All constraints met: {summary['constraints_met_rate']:.2%}")
    print("-"*60)
    print("DIVERSITY METRICS:")
    print(f" Unique expressions: {summary['unique_expressions']}")
    print(f" Diversity rate: {summary['diversity_rate']:.2%}")
    print(f" Avg expression length: {summary['avg_expression_length']:.1f}")
    print("="*60)


def _write_json(path: str, payload) -> None:
    """Write ``payload`` to ``path`` as indented JSON."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)


def evaluate_model(args):
    """Evaluate a model on expression generation and save the results.

    Steps:
    1. Seed the RNGs and load the model and the evaluation split.
    2. Subsample up to ``args.num_samples`` rows and generate
       ``args.num_generations`` outputs per prompt.
    3. Score each output for validity and prompt adherence.
    4. Print a report and write ``metrics_*.json`` and ``results_*.json``
       into ``args.output_dir``.

    Args:
        args: Namespace produced by ``parse_args``.

    Returns:
        dict of aggregate metrics (the same payload written to the metrics file).
    """
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    model, tokenizer, device = load_model_and_tokenizer(
        args.model_path, args.base_model, args.device
    )

    dataset = _load_eval_split(args.dataset_repo_id, args.data_dir)

    # Subsample without replacement using the seeded NumPy RNG.
    if len(dataset) > args.num_samples:
        indices = np.random.choice(len(dataset), args.num_samples, replace=False)
        dataset = dataset.select(indices)

    print(f"Evaluating on {len(dataset)} samples...")

    # Per the --data_column help: "p_*" columns hold prefix-notation prompts.
    is_prefix = args.data_column.startswith("p_")

    metrics = {
        "total_samples": 0,
        "total_generations": 0,
        "valid_expressions": 0,
        "parseable_expressions": 0,
        "uses_allowed_vars": 0,
        "uses_allowed_ops": 0,
        "all_constraints_met": 0,
        "unique_expressions": set(),
        "expression_lengths": [],
        "errors": Counter(),
    }
    results = []

    for idx, sample in enumerate(tqdm(dataset, desc="Evaluating")):
        prompt_only = _prompt_prefix(sample[args.data_column])

        generations = generate_expression(
            model, tokenizer, prompt_only, device,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_p=args.top_p,
            num_return_sequences=args.num_generations,
        )
        metrics["total_samples"] += 1

        for gen_output in generations:
            metrics["total_generations"] += 1

            expr_str = extract_expression_from_output(gen_output, is_prefix)
            validation = validate_expression(expr_str, is_prefix)
            adherence = check_prompt_adherence(expr_str, prompt_only, is_prefix)

            if validation["valid"]:
                metrics["valid_expressions"] += 1
            if validation["parseable"]:
                metrics["parseable_expressions"] += 1
                metrics["unique_expressions"].add(expr_str)
                metrics["expression_lengths"].append(len(expr_str))
            if validation["error"]:
                # Errors are bucketed by the first 50 characters of their message.
                metrics["errors"][validation["error"][:50]] += 1

            for key in ("uses_allowed_vars", "uses_allowed_ops", "all_constraints_met"):
                if adherence[key]:
                    metrics[key] += 1

            results.append({
                "sample_idx": idx,
                "prompt": prompt_only[:200],
                "generated_output": gen_output[:500],
                "extracted_expression": expr_str,
                "valid": validation["valid"],
                "parseable": validation["parseable"],
                "error": validation["error"],
                "uses_allowed_vars": adherence["uses_allowed_vars"],
                "uses_allowed_ops": adherence["uses_allowed_ops"],
            })

    total_gen = metrics["total_generations"]
    # A single timestamp keeps the JSON "timestamp" and the output file names in sync.
    run_time = datetime.now()
    final_metrics = {
        "model_path": args.model_path,
        "dataset": f"{args.dataset_repo_id}/{args.data_dir}",
        "data_column": args.data_column,
        "is_prefix": is_prefix,
        "num_samples": metrics["total_samples"],
        "num_generations": total_gen,
        "temperature": args.temperature,
        "top_p": args.top_p,

        "valid_rate": _rate(metrics["valid_expressions"], total_gen),
        "parseable_rate": _rate(metrics["parseable_expressions"], total_gen),

        "uses_allowed_vars_rate": _rate(metrics["uses_allowed_vars"], total_gen),
        "uses_allowed_ops_rate": _rate(metrics["uses_allowed_ops"], total_gen),
        "constraints_met_rate": _rate(metrics["all_constraints_met"], total_gen),

        "unique_expressions": len(metrics["unique_expressions"]),
        "diversity_rate": _rate(len(metrics["unique_expressions"]), total_gen),
        "avg_expression_length": (
            float(np.mean(metrics["expression_lengths"])) if metrics["expression_lengths"] else 0
        ),

        "top_errors": dict(metrics["errors"].most_common(10)),

        "timestamp": run_time.isoformat(),
    }

    _print_report(final_metrics)

    os.makedirs(args.output_dir, exist_ok=True)

    # Flatten the model path into a filename-safe token. ":" is replaced too,
    # so Windows drive letters (e.g. "C:\\...") do not yield an invalid name.
    model_name = re.sub(r"[\\/:]", "_", args.model_path)
    stamp = run_time.strftime("%Y%m%d_%H%M%S")

    metrics_file = os.path.join(args.output_dir, f"metrics_{model_name}_{stamp}.json")
    _write_json(metrics_file, final_metrics)
    print(f"\nMetrics saved to: {metrics_file}")

    results_file = os.path.join(args.output_dir, f"results_{model_name}_{stamp}.json")
    _write_json(results_file, results)
    print(f"Detailed results saved to: {results_file}")

    return final_metrics
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point: parse CLI flags, then run the full evaluation.
    evaluate_model(parse_args())
| |
|