#!/usr/bin/env python3
"""
Test different model sizes on expression generation.
Compare GPT-2 (124M), GPT-2-medium (355M), GPT-2-large (774M).
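
Usage:
    python scripts/test_model_sizes.py --models gpt2 gpt2-medium gpt2-large --num_samples 20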
"""
import sys
import json
import argparse
from pathlib import Path
from typing import Optional
import torch
# Add project root to path
PROJECT_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
sys.path.insert(0, str(PROJECT_ROOT / "classes"))
from transformers import AutoTokenizer, AutoModelForCausalLM
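# The local `expression` module lives under classes/, so this import relies on
# the sys.path setup above.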
from expression import Expression
def generate_expressions(model_name: str, num_samples: int = 20, device: Optional[str] = None):
"""Generate expressions with a given model."""
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
device = torch.device(device)
print(f"Loading {model_name}...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
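    # GPT-2 ships without a pad token; reuse EOS so generation can pad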
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name)
model = model.to(device)
model.eval()
# Build prompt (JSON format)
vars_list = ["x_1"]
ops_list = ["+", "-", "*", "/", "sin", "cos", "sqrt", "log", "exp", "pow"]
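    # json.dumps ends with '"expr": ""}'; dropping the final '"}' leaves the
    # prompt ending in '"expr": "', so the model completes the expression field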
prompt = json.dumps({"vars": vars_list, "ops": ops_list, "cons": "C", "expr": ""})[:-2]
expressions = []
valid_count = 0
has_power = 0
has_nested_trig = 0
depths = []
print(f"Generating {num_samples} expressions...")
for i in range(num_samples):
inputs = tokenizer(prompt, return_tensors="pt").to(device)
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=50,
temperature=0.7,
do_sample=True,
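                # An explicit pad_token_id avoids the GPT-2 "no pad token" warning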
pad_token_id=tokenizer.eos_token_id,
)
text = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Extract expression
expr_str = ""
if '"expr": "' in text:
start = text.index('"expr": "') + len('"expr": "')
remaining = text[start:]
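            # Prefer the full '"}' terminator; fall back to a lone closing
            # quote when the generated JSON object was never closed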
for terminator in ['"}', '"']:
if terminator in remaining:
expr_str = remaining[:remaining.index(terminator)].strip()
break
if not expr_str:
continue
        # Validate: substitute the constant placeholder 'C' with a literal so
        # the parser sees a fully concrete expression
        test_expr = expr_str.replace('C', '1')
        is_valid = False
        try:
            Expression(test_expr, is_prefix=False)
            # Simple validation - just check if it parses
            is_valid = True
        except Exception:
            pass
# Count features
if is_valid:
valid_count += 1
if '**' in expr_str or 'pow(' in expr_str:
has_power += 1
if any(nested in expr_str for nested in ['sin(sin', 'sin(cos', 'cos(sin', 'cos(cos']):
has_nested_trig += 1
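            # Parenthesis count serves as a crude proxy for expression depth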
depth = max(expr_str.count('('), 1)
depths.append(depth)
expressions.append({
"expression": expr_str,
"is_valid": is_valid,
})
# Stats
stats = {
"model_name": model_name,
"total": len(expressions),
"valid": valid_count,
"valid_pct": 100 * valid_count / len(expressions) if expressions else 0,
"has_power": has_power,
"has_power_pct": 100 * has_power / valid_count if valid_count > 0 else 0,
"has_nested_trig": has_nested_trig,
"has_nested_trig_pct": 100 * has_nested_trig / valid_count if valid_count > 0 else 0,
"avg_depth": sum(depths) / len(depths) if depths else 0,
"max_depth": max(depths) if depths else 0,
}
return expressions, stats
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--models", nargs="+", default=["gpt2", "gpt2-medium"],
help="Models to test")
parser.add_argument("--num_samples", type=int, default=20, help="Samples per model")
parser.add_argument("--output_file", type=str, default="model_size_comparison.json")
args = parser.parse_args()
results = {}
for model_name in args.models:
print()
print("="*80)
print(f"Testing {model_name}")
print("="*80)
expressions, stats = generate_expressions(model_name, args.num_samples)
results[model_name] = {
"stats": stats,
"expressions": expressions,
}
print()
print(f"Results for {model_name}:")
print(f" Valid: {stats['valid']}/{stats['total']} ({stats['valid_pct']:.1f}%)")
print(f" With power: {stats['has_power']} ({stats['has_power_pct']:.1f}%)")
print(f" With nested trig: {stats['has_nested_trig']} ({stats['has_nested_trig_pct']:.1f}%)")
print(f" Avg depth: {stats['avg_depth']:.2f}")
print(f" Max depth: {stats['max_depth']}")
# Show examples
print()
print("Sample expressions:")
valid_exprs = [e for e in expressions if e["is_valid"]][:5]
for i, e in enumerate(valid_exprs, 1):
print(f" {i}. {e['expression'][:70]}")
# Save
with open(args.output_file, "w") as f:
json.dump(results, f, indent=2)
print()
print(f"Saved results to {args.output_file}")
# Comparison table
print()
print("="*80)
print("COMPARISON")
print("="*80)
print(f"{'Model':<20} {'Valid%':>8} {'Power%':>8} {'NestedTrig%':>12} {'AvgDepth':>10}")
print("-"*80)
for model_name, data in results.items():
stats = data["stats"]
print(f"{model_name:<20} {stats['valid_pct']:>7.1f}% {stats['has_power_pct']:>7.1f}% {stats['has_nested_trig_pct']:>11.1f}% {stats['avg_depth']:>10.2f}")
if __name__ == "__main__":
main()