"""
Test different model sizes on expression generation.
Compare GPT-2 (124M), GPT-2-medium (355M), GPT-2-large (774M).
"""

import argparse
import json
import os
import sys
from pathlib import Path

import numpy as np
import torch

# Make project-local packages importable BEFORE the local imports below;
# `expression` lives under <project root>/classes.
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
sys.path.insert(0, str(PROJECT_ROOT / "classes"))

from transformers import AutoTokenizer, AutoModelForCausalLM
from expression import Expression
|
|
|
def generate_expressions(model_name: str, num_samples: int = 20, device: str = None):
    """Generate symbolic expressions with a causal LM and collect quality stats.

    Args:
        model_name: Hugging Face model identifier (e.g. "gpt2", "gpt2-medium").
        num_samples: Number of generation attempts.
        device: Torch device string; when None, CUDA is used if available.

    Returns:
        (expressions, stats) where ``expressions`` is a list of
        ``{"expression": str, "is_valid": bool}`` dicts and ``stats`` is a
        summary dict (validity %, power/nested-trig usage, depth metrics).
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(device)

    print(f"Loading {model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # GPT-2 family has no pad token; reuse EOS so generate() can pad batches.
    tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(model_name)
    model = model.to(device)
    model.eval()

    vars_list = ["x_1"]
    ops_list = ["+", "-", "*", "/", "sin", "cos", "sqrt", "log", "exp", "pow"]
    # Strip the trailing '"}' so the prompt ends mid-string at '"expr": "',
    # inviting the model to complete the expression value of the JSON object.
    prompt = json.dumps({"vars": vars_list, "ops": ops_list, "cons": "C", "expr": ""})[:-2]

    expressions = []
    valid_count = 0
    has_power = 0
    has_nested_trig = 0
    depths = []

    print(f"Generating {num_samples} expressions...")

    for i in range(num_samples):
        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=50,
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )

        text = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Extract the generated expression: everything after '"expr": "' up to
        # the first closing quote (prefer the full '"}' JSON terminator).
        expr_str = ""
        if '"expr": "' in text:
            start = text.index('"expr": "') + len('"expr": "')
            remaining = text[start:]
            for terminator in ['"}', '"']:
                if terminator in remaining:
                    expr_str = remaining[:remaining.index(terminator)].strip()
                    break

        if not expr_str:
            continue

        # Substitute the constant placeholder with a concrete value so the
        # parser can validate the expression.
        test_expr = expr_str.replace('C', '1')
        is_valid = False

        try:
            Expression(test_expr, is_prefix=False)
            is_valid = True
        except Exception:
            # Any parse failure just marks the sample invalid; don't swallow
            # KeyboardInterrupt/SystemExit (which a bare except would).
            pass

        if is_valid:
            valid_count += 1

            if '**' in expr_str or 'pow(' in expr_str:
                has_power += 1

            if any(nested in expr_str for nested in ['sin(sin', 'sin(cos', 'cos(sin', 'cos(cos']):
                has_nested_trig += 1

            # Crude depth proxy: count of opening parentheses, floored at 1.
            depth = max(expr_str.count('('), 1)
            depths.append(depth)

        expressions.append({
            "expression": expr_str,
            "is_valid": is_valid,
        })

    stats = {
        "model_name": model_name,
        "total": len(expressions),
        "valid": valid_count,
        "valid_pct": 100 * valid_count / len(expressions) if expressions else 0,
        "has_power": has_power,
        "has_power_pct": 100 * has_power / valid_count if valid_count > 0 else 0,
        "has_nested_trig": has_nested_trig,
        "has_nested_trig_pct": 100 * has_nested_trig / valid_count if valid_count > 0 else 0,
        "avg_depth": sum(depths) / len(depths) if depths else 0,
        "max_depth": max(depths) if depths else 0,
    }

    return expressions, stats
|
|
|
|
def main():
    """Benchmark each requested model and write a JSON comparison report.

    Runs `generate_expressions` for every model in --models, prints per-model
    stats plus a final comparison table, and saves everything to --output_file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--models", nargs="+", default=["gpt2", "gpt2-medium"],
                        help="Models to test")
    parser.add_argument("--num_samples", type=int, default=20, help="Samples per model")
    parser.add_argument("--output_file", type=str, default="model_size_comparison.json")
    args = parser.parse_args()

    results = {}

    for model_name in args.models:
        print()
        print("="*80)
        print(f"Testing {model_name}")
        print("="*80)

        expressions, stats = generate_expressions(model_name, args.num_samples)

        results[model_name] = {
            "stats": stats,
            "expressions": expressions,
        }

        print()
        print(f"Results for {model_name}:")
        print(f"  Valid: {stats['valid']}/{stats['total']} ({stats['valid_pct']:.1f}%)")
        print(f"  With power: {stats['has_power']} ({stats['has_power_pct']:.1f}%)")
        print(f"  With nested trig: {stats['has_nested_trig']} ({stats['has_nested_trig_pct']:.1f}%)")
        print(f"  Avg depth: {stats['avg_depth']:.2f}")
        print(f"  Max depth: {stats['max_depth']}")

        # Show up to five valid samples, truncated to keep the console tidy.
        print()
        print("Sample expressions:")
        valid_exprs = [e for e in expressions if e["is_valid"]][:5]
        for i, e in enumerate(valid_exprs, 1):
            print(f"  {i}. {e['expression'][:70]}")

    # Persist the full run (stats + raw expressions) for later analysis.
    with open(args.output_file, "w") as f:
        json.dump(results, f, indent=2)

    print()
    print(f"Saved results to {args.output_file}")

    # Final side-by-side comparison table across all models.
    print()
    print("="*80)
    print("COMPARISON")
    print("="*80)
    print(f"{'Model':<20} {'Valid%':>8} {'Power%':>8} {'NestedTrig%':>12} {'AvgDepth':>10}")
    print("-"*80)
    for model_name, data in results.items():
        stats = data["stats"]
        print(f"{model_name:<20} {stats['valid_pct']:>7.1f}% {stats['has_power_pct']:>7.1f}% {stats['has_nested_trig_pct']:>11.1f}% {stats['avg_depth']:>10.2f}")
|
|
|
|
if __name__ == "__main__":
    main()
|
|