"""
Test different model sizes on expression generation.

Compare GPT-2 (124M), GPT-2-medium (355M), and GPT-2-large (774M).
"""

import sys
import json
import argparse
from pathlib import Path

import torch

# Make the project root and its classes/ directory importable before
# pulling in the project-local Expression class.
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
sys.path.insert(0, str(PROJECT_ROOT / "classes"))

from transformers import AutoTokenizer, AutoModelForCausalLM

from expression import Expression


def generate_expressions(model_name: str, num_samples: int = 20, device: str = None):
    """Generate expressions with a given model and collect validity statistics."""
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(device)

    print(f"Loading {model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 tokenizers ship without a pad token

    model = AutoModelForCausalLM.from_pretrained(model_name)
    model = model.to(device)
    model.eval()

    # The prompt is the JSON record with its closing '"}' stripped, so it ends
    # with '"expr": "' and the model is left to complete the expression field.
    vars_list = ["x_1"]
    ops_list = ["+", "-", "*", "/", "sin", "cos", "sqrt", "log", "exp", "pow"]
    prompt = json.dumps({"vars": vars_list, "ops": ops_list, "cons": "C", "expr": ""})[:-2]

    expressions = []
    valid_count = 0
    has_power = 0
    has_nested_trig = 0
    depths = []

    print(f"Generating {num_samples} expressions...")

    for _ in range(num_samples):
        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=50,
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )

        text = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Extract the completion of the "expr" field: everything after the
        # field's opening quote, up to the first closing terminator.
        expr_str = ""
        if '"expr": "' in text:
            start = text.index('"expr": "') + len('"expr": "')
            remaining = text[start:]
            for terminator in ['"}', '"']:
                if terminator in remaining:
                    expr_str = remaining[:remaining.index(terminator)].strip()
                    break

        if not expr_str:
            continue

        # Substitute a literal for the constant placeholder and try to parse;
        # anything Expression rejects is counted as invalid.
        test_expr = expr_str.replace('C', '1')
        is_valid = False

        try:
            Expression(test_expr, is_prefix=False)
            is_valid = True
        except Exception:
            pass

        if is_valid:
            valid_count += 1

            if '**' in expr_str or 'pow(' in expr_str:
                has_power += 1

            if any(nested in expr_str for nested in ['sin(sin', 'sin(cos', 'cos(sin', 'cos(cos']):
                has_nested_trig += 1

            # Parenthesis count is a cheap proxy for nesting depth.
            depth = max(expr_str.count('('), 1)
            depths.append(depth)

        expressions.append({
            "expression": expr_str,
            "is_valid": is_valid,
        })

    stats = {
        "model_name": model_name,
        "total": len(expressions),
        "valid": valid_count,
        "valid_pct": 100 * valid_count / len(expressions) if expressions else 0,
        "has_power": has_power,
        "has_power_pct": 100 * has_power / valid_count if valid_count > 0 else 0,
        "has_nested_trig": has_nested_trig,
        "has_nested_trig_pct": 100 * has_nested_trig / valid_count if valid_count > 0 else 0,
        "avg_depth": sum(depths) / len(depths) if depths else 0,
        "max_depth": max(depths) if depths else 0,
    }

    return expressions, stats
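
# Minimal direct use (a sketch; main() below drives this once per model):
#   exprs, stats = generate_expressions("gpt2", num_samples=5)
#   print(f"{stats['valid']}/{stats['total']} valid")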


def main():
    parser = argparse.ArgumentParser(description="Compare expression generation across GPT-2 sizes.")
    parser.add_argument("--models", nargs="+", default=["gpt2", "gpt2-medium"],
                        help="Models to test")
    parser.add_argument("--num_samples", type=int, default=20, help="Samples per model")
    parser.add_argument("--output_file", type=str, default="model_size_comparison.json",
                        help="Where to write the JSON results")
    args = parser.parse_args()

    results = {}

    for model_name in args.models:
        print()
        print("=" * 80)
        print(f"Testing {model_name}")
        print("=" * 80)

        expressions, stats = generate_expressions(model_name, args.num_samples)

        results[model_name] = {
            "stats": stats,
            "expressions": expressions,
        }

        print()
        print(f"Results for {model_name}:")
        print(f"  Valid: {stats['valid']}/{stats['total']} ({stats['valid_pct']:.1f}%)")
        print(f"  With power: {stats['has_power']} ({stats['has_power_pct']:.1f}%)")
        print(f"  With nested trig: {stats['has_nested_trig']} ({stats['has_nested_trig_pct']:.1f}%)")
        print(f"  Avg depth: {stats['avg_depth']:.2f}")
        print(f"  Max depth: {stats['max_depth']}")

        print()
        print("Sample expressions:")
        valid_exprs = [e for e in expressions if e["is_valid"]][:5]
        for i, e in enumerate(valid_exprs, 1):
            print(f"  {i}. {e['expression'][:70]}")

    with open(args.output_file, "w") as f:
        json.dump(results, f, indent=2)

    print()
    print(f"Saved results to {args.output_file}")

    print()
    print("=" * 80)
    print("COMPARISON")
    print("=" * 80)
    print(f"{'Model':<20} {'Valid%':>8} {'Power%':>8} {'NestedTrig%':>12} {'AvgDepth':>10}")
    print("-" * 80)
    for model_name, data in results.items():
        stats = data["stats"]
        print(f"{model_name:<20} {stats['valid_pct']:>7.1f}% {stats['has_power_pct']:>7.1f}% "
              f"{stats['has_nested_trig_pct']:>11.1f}% {stats['avg_depth']:>10.2f}")


if __name__ == "__main__":
    main()