File size: 7,258 Bytes
c082aa2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 |
#!/usr/bin/env python3
"""
Quick evaluation script for LoRA adapters trained on JSON-formatted prompts.
Reads the base model name from the adapter's adapter_config.json automatically.
"""
import argparse
import json
import logging
import os
import sys
from pathlib import Path
from tqdm import tqdm
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from datasets import load_dataset
sys.path.insert(0, str(Path(__file__).parent.parent))
from classes.expression import Expression
# Script-wide logging: INFO level, timestamped "time - LEVEL - message" lines.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
def load_model_auto(model_path: str):
    """Load a LoRA adapter merged into the base model it was trained from.

    The base model name is read from ``adapter_config.json`` inside
    ``model_path`` (key ``base_model_name_or_path``; ``"gpt2"`` if absent).
    The adapter weights are merged into the base model and the merged model
    is switched to eval mode.

    Args:
        model_path: Directory holding the adapter weights and adapter_config.json.

    Returns:
        Tuple ``(model, tokenizer, base_model_name)``.

    Raises:
        FileNotFoundError: if ``model_path`` contains no adapter_config.json.
    """
    config_file = os.path.join(model_path, "adapter_config.json")
    if not os.path.exists(config_file):
        raise FileNotFoundError(f"No adapter_config.json found in {model_path}")

    with open(config_file) as handle:
        base_model_name = json.load(handle).get("base_model_name_or_path", "gpt2")
    logger.info(f"Loading base model: {base_model_name}")

    # GPU: half precision with automatic device placement; CPU: full precision.
    if torch.cuda.is_available():
        load_kwargs = {"torch_dtype": torch.float16, "device_map": "auto"}
    else:
        load_kwargs = {"torch_dtype": torch.float32, "device_map": None}
    base = AutoModelForCausalLM.from_pretrained(base_model_name, **load_kwargs)

    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    tokenizer.pad_token = tokenizer.eos_token  # reuse EOS as the padding token

    logger.info(f"Loading LoRA adapter from {model_path}")
    merged = PeftModel.from_pretrained(base, model_path).merge_and_unload()
    merged.eval()
    return merged, tokenizer, base_model_name
def create_json_prompt(vars_list, ops_list, cons="C"):
    """Build a JSON prompt that stops right after the opening ``"expr": "``.

    The text is a JSON object with ``vars``, ``ops`` and ``cons`` keys
    followed by an unterminated ``expr`` string, so a model continuing the
    text writes the expression directly.

    Args:
        vars_list: Variable names to list under ``vars``.
        ops_list: Operator names to list under ``ops``.
        cons: Value for the ``cons`` key.

    Returns:
        The open-ended prompt string (non-ASCII characters kept as-is).
    """
    head = json.dumps({"vars": vars_list, "ops": ops_list, "cons": cons}, ensure_ascii=False)
    # Replace the closing brace with an opened, still-empty "expr" field.
    return head[:-1] + ', "expr": "'
def extract_expression_json(output: str):
    """Return the value of the first ``"expr"`` field found in ``output``.

    A properly closed ``"expr": "..."`` is preferred, and its content is
    returned verbatim (possibly ``""``). Otherwise an unterminated value is
    accepted: it is cut at the first ``}`` and whitespace-stripped.

    Returns:
        The expression text, or ``None`` if no ``"expr"`` value is present.
    """
    import re

    closed = re.search(r'"expr":\s*"([^"]*)"', output)
    if closed is not None:
        return closed.group(1)

    unclosed = re.search(r'"expr":\s*"([^"]+)', output)
    if unclosed is None:
        return None

    # The captured text cannot contain '"', so only the brace cut-off matters.
    return unclosed.group(1).partition('}')[0].strip()
def _parse_prompt_fields(prompt_text):
    """Pull the variable and operator lists out of a natural-language prompt.

    Uses the first line starting with ``vars:`` and the first line starting
    with ``oper:``; each holds a comma-separated list. Entries are
    whitespace-stripped and empty entries (e.g. from a trailing comma) are
    dropped.

    Returns:
        ``(vars_list, ops_list)``, or ``None`` when either line is missing.
    """
    lines = prompt_text.split('\n')
    vars_line = next((line for line in lines if line.startswith('vars:')), None)
    ops_line = next((line for line in lines if line.startswith('oper:')), None)
    if vars_line is None or ops_line is None:
        return None

    def split_items(line, prefix):
        # Slice off only the leading prefix; str.replace would also strip
        # any later occurrence of the prefix text.
        stripped = (item.strip() for item in line[len(prefix):].split(','))
        return [item for item in stripped if item]

    return split_items(vars_line, 'vars:'), split_items(ops_line, 'oper:')


def _check_expression(expr_str):
    """Parse ``expr_str`` with ``Expression.parse_infix`` and validate it.

    Returns:
        ``(is_valid, is_parseable, error_msg)``. ``error_msg`` is ``None`` on
        success, ``"Failed to extract expression"`` when ``expr_str`` is empty
        or ``None``, or the first 100 characters of the raised exception.
    """
    if not expr_str:
        return False, False, "Failed to extract expression"
    is_parseable = False
    try:
        expr = Expression.parse_infix(expr_str)
        is_parseable = True
        is_valid = bool(expr.validate())
    except Exception as e:  # model output is arbitrary; record any parser/validator failure
        return False, is_parseable, str(e)[:100]
    return is_valid, is_parseable, None


def evaluate_model(model, tokenizer, num_samples=500, dataset_name="augustocsc/sintetico_natural", data_dir="700K",
                   model_path=None, seed=42):
    """Generate one expression per sampled dataset row and score the outputs.

    A reproducible random subset of at most ``num_samples`` rows is drawn from
    the ``train`` split. For each row, the ``vars:``/``oper:`` lines of its
    ``i_prompt_n`` prompt become a JSON prompt. The model samples a completion,
    and the extracted expression is parsed and validated. Rows lacking either
    line are skipped and not counted in ``num_samples``.

    Args:
        model: Causal LM exposing ``.device`` and ``.generate``.
        tokenizer: Tokenizer matching ``model``.
        num_samples: Maximum number of rows to evaluate (negative means none).
        dataset_name: Dataset id passed to ``load_dataset``.
        data_dir: Passed as the second positional argument of ``load_dataset``.
            NOTE(review): positionally that is the config ``name``, not
            ``data_dir=`` — confirm this is intended.
        model_path: Label stored in ``metrics["model_path"]``. It defaults to
            the model's ``name_or_path`` attribute when available.
        seed: Seed used to choose which rows are evaluated.

    Returns:
        ``(metrics, results)``: aggregate rates, and one dict per evaluated row.
    """
    import random

    device = model.device
    logger.info(f"Using device: {device}")

    logger.info(f"Loading dataset {dataset_name}/{data_dir}")
    dataset = load_dataset(dataset_name, data_dir, split="train")

    # Private RNG: same indices as random.seed(seed) + random.sample, but the
    # process-wide random state is left untouched.
    rng = random.Random(seed)
    sample_size = max(0, min(num_samples, len(dataset)))
    indices = rng.sample(range(len(dataset)), sample_size)

    results = []
    valid_count = 0
    parseable_count = 0
    skipped_count = 0
    unique_expressions = set()

    logger.info(f"Evaluating on {len(indices)} samples...")
    for idx in tqdm(indices, desc="Evaluating"):
        sample = dataset[idx]
        fields = _parse_prompt_fields(sample.get("i_prompt_n") or "")
        if fields is None:
            skipped_count += 1
            continue
        vars_list, ops_list = fields

        prompt = create_json_prompt(vars_list, ops_list)
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=100,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )
        # The full output sequence is decoded. For decoder-only models this
        # includes the prompt, whose trailing '"expr": "' the extractor anchors on.
        generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
        expr_str = extract_expression_json(generated)

        is_valid, is_parseable, error_msg = _check_expression(expr_str)
        if is_valid:
            valid_count += 1
            unique_expressions.add(expr_str)
        if is_parseable:
            parseable_count += 1

        results.append({
            "sample_idx": idx,
            "prompt": prompt,
            "generated": generated[:500],
            "expression": expr_str,
            "valid": is_valid,
            "parseable": is_parseable,
            "error": error_msg,
        })

    if skipped_count:
        logger.warning(f"Skipped {skipped_count} samples without 'vars:'/'oper:' lines")

    total = len(results)

    def rate(count):
        return count / total if total > 0 else 0

    if model_path is None:
        # str(model) would be the full module repr; prefer the model's
        # name_or_path (set on transformers models), else its class name.
        model_path = getattr(model, "name_or_path", None) or type(model).__name__

    metrics = {
        "model_path": str(model_path),
        "num_samples": total,
        "skipped_samples": skipped_count,
        "valid_rate": rate(valid_count),
        "parseable_rate": rate(parseable_count),
        "unique_expressions": len(unique_expressions),
        "diversity_rate": rate(len(unique_expressions)),
    }
    return metrics, results
def main():
    """Command-line entry point: load an adapter, evaluate it, and save results.

    Prints a summary and writes ``<name>_metrics.json`` and
    ``<name>_results.json`` into ``--output_dir``. ``<name>`` is the last path
    component of ``--model_path``.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate a LoRA adapter on JSON-formatted expression prompts."
    )
    parser.add_argument("--model_path", type=str, required=True,
                        help="Directory containing the LoRA adapter and adapter_config.json")
    parser.add_argument("--num_samples", type=int, default=500,
                        help="Maximum number of dataset rows to evaluate")
    parser.add_argument("--output_dir", type=str, default="./results_corrected",
                        help="Directory where metrics and per-sample results are written")
    parser.add_argument("--dataset_name", type=str, default="augustocsc/sintetico_natural",
                        help="Dataset id passed to datasets.load_dataset")
    parser.add_argument("--data_dir", type=str, default="700K",
                        help="Second argument passed to datasets.load_dataset")
    args = parser.parse_args()

    # normpath drops a trailing separator, so "runs/model/" yields "model"
    # instead of an empty basename (which produced "_metrics.json").
    model_name = os.path.basename(os.path.normpath(args.model_path))

    model, tokenizer, base_model_name = load_model_auto(args.model_path)

    metrics, results = evaluate_model(
        model,
        tokenizer,
        args.num_samples,
        dataset_name=args.dataset_name,
        data_dir=args.data_dir,
    )
    # Store the adapter path from the command line and the base model name.
    metrics["model_path"] = args.model_path
    metrics["base_model"] = base_model_name

    rule = "=" * 60
    print("\n" + rule)
    print(f"EVALUATION RESULTS - {model_name}")
    print(rule)
    print(f"Base model: {base_model_name}")
    print(f"Valid rate: {metrics['valid_rate']*100:.1f}%")
    print(f"Parseable rate: {metrics['parseable_rate']*100:.1f}%")
    print(f"Unique expressions: {metrics['unique_expressions']}")
    print(f"Diversity rate: {metrics['diversity_rate']*100:.1f}%")
    print(rule)

    os.makedirs(args.output_dir, exist_ok=True)
    metrics_path = os.path.join(args.output_dir, f"{model_name}_metrics.json")
    with open(metrics_path, 'w', encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)
    results_path = os.path.join(args.output_dir, f"{model_name}_results.json")
    with open(results_path, 'w', encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    logger.info(f"Results saved to {args.output_dir}")


if __name__ == "__main__":
    main()
|