"""
Quick evaluation script for JSON-formatted models.
Reads the base model name from adapter_config.json automatically.
"""

import argparse
import json
import logging
import os
import random
import re
import sys
from pathlib import Path

import torch
from datasets import load_dataset
from peft import PeftModel
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

# Make the project root importable so `classes.expression` resolves.
sys.path.insert(0, str(Path(__file__).parent.parent))
from classes.expression import Expression

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def load_model_auto(model_path: str):
    """Load a model, detecting the base model from adapter_config.json."""
    adapter_config_path = os.path.join(model_path, "adapter_config.json")

    if not os.path.exists(adapter_config_path):
        raise FileNotFoundError(f"No adapter_config.json found in {model_path}")

    with open(adapter_config_path) as f:
        adapter_config = json.load(f)

    base_model_name = adapter_config.get("base_model_name_or_path", "gpt2")
    logger.info(f"Loading base model: {base_model_name}")

    # Use fp16 and automatic device placement only when a GPU is available.
    model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto" if torch.cuda.is_available() else None
    )

    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2-style models ship without a pad token

    # Apply the LoRA adapter, then fold its weights into the base model so
    # generation runs on a plain causal LM.
    logger.info(f"Loading LoRA adapter from {model_path}")
    model = PeftModel.from_pretrained(model, model_path)
    model = model.merge_and_unload()
    model.eval()

    return model, tokenizer, base_model_name


def create_json_prompt(vars_list, ops_list, cons="C"):
    """Create a JSON-format prompt, left open at the "expr" value."""
    prompt = {
        "vars": vars_list,
        "ops": ops_list,
        "cons": cons,
        "expr": ""
    }
    prompt_str = json.dumps(prompt, ensure_ascii=False)
    # Drop the empty "expr" value and the closing brace so the model is
    # prompted to complete the expression string itself.
    prompt_str = prompt_str.rsplit('"expr":', 1)[0] + '"expr": "'
    return prompt_str
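
# For vars_list=["x", "y"] and ops_list=["+", "*"] the resulting prompt is
# (values illustrative):
#   {"vars": ["x", "y"], "ops": ["+", "*"], "cons": "C", "expr": "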


def extract_expression_json(output: str):
    """Extract the expression from JSON-formatted model output."""
    # Complete case: the model closed the string, e.g. "expr": "x + C".
    match = re.search(r'"expr":\s*"([^"]*)"', output)
    if match:
        return match.group(1)

    # Truncated case: generation stopped before the closing quote; trim any
    # trailing junk after the expression.
    match = re.search(r'"expr":\s*"([^"]+)', output)
    if match:
        expr = match.group(1)
        expr = expr.split('"')[0].split('}')[0].strip()
        return expr

    return None
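
# Examples (illustrative inputs):
#   extract_expression_json('... "expr": "x + C"}')  -> 'x + C'
#   extract_expression_json('... "expr": "x * (y')   -> 'x * (y'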


def evaluate_model(model, tokenizer, num_samples=500, dataset_name="augustocsc/sintetico_natural", data_dir="700K"):
    """Evaluate the model on a random sample of the dataset."""
    device = model.device
    logger.info(f"Using device: {device}")

    logger.info(f"Loading dataset {dataset_name}/{data_dir}")
    dataset = load_dataset(dataset_name, data_dir, split="train")

    # Fixed seed so repeated runs evaluate the same subset.
    random.seed(42)
    indices = random.sample(range(len(dataset)), min(num_samples, len(dataset)))

    results = []
    valid_count = 0
    parseable_count = 0
    unique_expressions = set()

    logger.info(f"Evaluating on {len(indices)} samples...")

    for idx in tqdm(indices, desc="Evaluating"):
        sample = dataset[idx]
        prompt_text = sample.get("i_prompt_n", "")

        # Pull the variable and operator lists out of the natural-language prompt.
        vars_line = [line for line in prompt_text.split('\n') if line.startswith('vars:')]
        ops_line = [line for line in prompt_text.split('\n') if line.startswith('oper:')]

        if not vars_line or not ops_line:
            continue

        vars_list = [v.strip() for v in vars_line[0].replace('vars:', '').split(',')]
        ops_list = [o.strip() for o in ops_line[0].replace('oper:', '').split(',')]

        prompt = create_json_prompt(vars_list, ops_list)

        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=100,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )

        generated = tokenizer.decode(outputs[0], skip_special_tokens=True)

        expr_str = extract_expression_json(generated)

        is_valid = False
        is_parseable = False
        error_msg = None

        if expr_str:
            try:
                expr = Expression.parse_infix(expr_str)
                is_parseable = True
                is_valid = expr.validate()
                if is_valid:
                    unique_expressions.add(expr_str)
            except Exception as e:
                error_msg = str(e)[:100]
        else:
            error_msg = "Failed to extract expression"

        if is_valid:
            valid_count += 1
        if is_parseable:
            parseable_count += 1

        results.append({
            "sample_idx": idx,
            "prompt": prompt,
            "generated": generated[:500],
            "expression": expr_str,
            "valid": is_valid,
            "parseable": is_parseable,
            "error": error_msg
        })

    total = len(results)
    metrics = {
        # str(model) would dump the full architecture; record the name instead.
        "model_path": getattr(model, "name_or_path", "unknown"),
        "num_samples": total,
        "valid_rate": valid_count / total if total > 0 else 0,
        "parseable_rate": parseable_count / total if total > 0 else 0,
        "unique_expressions": len(unique_expressions),
        "diversity_rate": len(unique_expressions) / total if total > 0 else 0,
    }

    return metrics, results


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--num_samples", type=int, default=500)
    parser.add_argument("--output_dir", type=str, default="./results_corrected")
    args = parser.parse_args()

    model, tokenizer, base_model_name = load_model_auto(args.model_path)

    metrics, results = evaluate_model(model, tokenizer, args.num_samples)

    print("\n" + "="*60)
    print(f"EVALUATION RESULTS - {os.path.basename(args.model_path)}")
    print("="*60)
    print(f"Base model: {base_model_name}")
    print(f"Valid rate: {metrics['valid_rate']*100:.1f}%")
    print(f"Parseable rate: {metrics['parseable_rate']*100:.1f}%")
    print(f"Unique expressions: {metrics['unique_expressions']}")
    print(f"Diversity rate: {metrics['diversity_rate']*100:.1f}%")
    print("="*60)

    os.makedirs(args.output_dir, exist_ok=True)
    model_name = os.path.basename(args.model_path)

    metrics_path = os.path.join(args.output_dir, f"{model_name}_metrics.json")
    with open(metrics_path, 'w') as f:
        json.dump(metrics, f, indent=2)

    results_path = os.path.join(args.output_dir, f"{model_name}_results.json")
    with open(results_path, 'w') as f:
        json.dump(results, f, indent=2)

    logger.info(f"Results saved to {args.output_dir}")


if __name__ == "__main__":
    main()