|
|
""" |
|
|
Test different inference configurations to find optimal generation parameters. |
|
|
|
|
|
This script tests various combinations of: |
|
|
- Temperature (sampling randomness) |
|
|
- Top-k and top-p (nucleus sampling) |
|
|
- Repetition penalty |
|
|
- Max length |
|
|
- Stopping criteria |
|
|
|
|
|
Usage: |
|
|
python scripts/test_inference_configs.py \ |
|
|
--model_path ./output/Se124M_700K_infix_v3 \ |
|
|
--num_samples 20 \ |
|
|
--output_dir ./inference_tests/v3 |
|
|
""" |
|
|
|
|
|
import argparse
import json
import logging
import sys
import time
from pathlib import Path
from typing import Any, Dict, List, Optional

import pandas as pd
import torch
from peft import PeftModel
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    StoppingCriteria,
    StoppingCriteriaList,
)
|
|
|
|
|
# Configure root logging once at import time; all output in this script goes
# through `logger` (not print) so verbosity can be controlled globally.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class ExpressionStoppingCriteria(StoppingCriteria):
    """Stop generation as soon as the <|endofex|> marker has been emitted."""

    def __init__(self, tokenizer, prompt_length: int):
        self.tokenizer = tokenizer
        self.prompt_length = prompt_length
        # Token-id sequence the marker encodes to (may be more than one id).
        self.end_token_id = tokenizer.encode("<|endofex|>", add_special_tokens=False)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Never trigger while still inside the prompt itself.
        generated_count = input_ids.shape[1] - self.prompt_length
        if generated_count <= 0:
            return False

        # Stop exactly when the trailing ids match the marker's id sequence.
        marker_len = len(self.end_token_id)
        tail = input_ids[0, -marker_len:].tolist()
        return tail == self.end_token_id
|
|
|
|
|
|
|
|
|
|
|
# Named generation-parameter presets. Every entry holds keyword arguments for
# model.generate() (temperature / top_k / top_p / repetition_penalty /
# max_new_tokens / do_sample); the "description" key is human-readable
# metadata only and is stripped before the dict is forwarded to generate()
# (see generate_with_config).
INFERENCE_CONFIGS = {
    # Stock transformers sampling defaults — the baseline to beat.
    "default": {
        "temperature": 1.0,
        "top_k": 50,
        "top_p": 1.0,
        "repetition_penalty": 1.0,
        "max_new_tokens": 128,
        "do_sample": True,
        "description": "Default transformers settings"
    },
    # Deterministic argmax decoding (do_sample=False makes temperature/top_k moot).
    "greedy": {
        "temperature": 1.0,
        "top_k": 1,
        "top_p": 1.0,
        "repetition_penalty": 1.0,
        "max_new_tokens": 128,
        "do_sample": False,
        "description": "Greedy decoding (no sampling)"
    },
    "low_temp": {
        "temperature": 0.3,
        "top_k": 50,
        "top_p": 0.9,
        "repetition_penalty": 1.0,
        "max_new_tokens": 128,
        "do_sample": True,
        "description": "Low temperature for more focused output"
    },
    "high_temp": {
        "temperature": 1.5,
        "top_k": 50,
        "top_p": 0.95,
        "repetition_penalty": 1.0,
        "max_new_tokens": 128,
        "do_sample": True,
        "description": "Higher temperature for more diversity"
    },
    # top_k=0 disables the top-k filter so only nucleus (top-p) sampling applies.
    "nucleus_strict": {
        "temperature": 0.7,
        "top_k": 0,
        "top_p": 0.8,
        "repetition_penalty": 1.0,
        "max_new_tokens": 128,
        "do_sample": True,
        "description": "Strict nucleus sampling (top-p=0.8)"
    },
    "nucleus_relaxed": {
        "temperature": 0.7,
        "top_k": 0,
        "top_p": 0.95,
        "repetition_penalty": 1.0,
        "max_new_tokens": 128,
        "do_sample": True,
        "description": "Relaxed nucleus sampling (top-p=0.95)"
    },
    "with_repetition_penalty": {
        "temperature": 0.7,
        "top_k": 50,
        "top_p": 0.9,
        "repetition_penalty": 1.2,
        "max_new_tokens": 128,
        "do_sample": True,
        "description": "With repetition penalty to avoid loops"
    },
    "strong_repetition_penalty": {
        "temperature": 0.7,
        "top_k": 50,
        "top_p": 0.9,
        "repetition_penalty": 1.5,
        "max_new_tokens": 128,
        "do_sample": True,
        "description": "Strong repetition penalty"
    },
    "short_generation": {
        "temperature": 0.7,
        "top_k": 50,
        "top_p": 0.9,
        "repetition_penalty": 1.1,
        "max_new_tokens": 64,
        "do_sample": True,
        "description": "Shorter max length (64 tokens)"
    },
    # Hand-tuned middle ground across the knobs above.
    "optimized": {
        "temperature": 0.5,
        "top_k": 40,
        "top_p": 0.9,
        "repetition_penalty": 1.15,
        "max_new_tokens": 100,
        "do_sample": True,
        "description": "Optimized settings (balanced)"
    },
}
|
|
|
|
|
|
|
|
def load_model_and_tokenizer(model_path: str, base_model: str = "gpt2"):
    """Load model and tokenizer, handling both base and LoRA models.

    Tries to interpret ``model_path`` as a PEFT/LoRA adapter on top of
    ``base_model`` first; on any failure it falls back to loading
    ``model_path`` as a full standalone checkpoint.
    """
    logger.info(f"Loading model from {model_path}...")

    # Decide placement/precision once; fp16 + auto device map on GPU,
    # fp32 on CPU.
    use_cuda = torch.cuda.is_available()
    dtype = torch.float16 if use_cuda else torch.float32
    device_map = "auto" if use_cuda else None

    tokenizer = AutoTokenizer.from_pretrained(base_model)
    tokenizer.add_special_tokens(
        {"additional_special_tokens": ["<|startofex|>", "<|endofex|>"]}
    )
    # GPT-2 has no pad token; reuse EOS so batching/padding works.
    tokenizer.pad_token = tokenizer.eos_token

    try:
        # First attempt: model_path is a LoRA adapter for base_model.
        backbone = AutoModelForCausalLM.from_pretrained(
            base_model,
            torch_dtype=dtype,
            device_map=device_map,
        )
        # Embedding table must cover the two special tokens added above.
        backbone.resize_token_embeddings(len(tokenizer))

        model = PeftModel.from_pretrained(backbone, model_path)
        model = model.merge_and_unload()
        logger.info("Loaded as LoRA model and merged")
    except Exception as e:
        # Fallback: treat model_path as a complete (non-adapter) checkpoint.
        logger.info(f"Loading as regular model (not LoRA): {e}")
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=dtype,
            device_map=device_map,
        )

    model.eval()
    logger.info(f"Model loaded on: {model.device}")
    return model, tokenizer
|
|
|
|
|
|
|
|
def generate_with_config(
    model,
    tokenizer,
    prompt: str,
    config: Dict[str, Any],
    use_stopping_criteria: bool = True
) -> tuple[str, Dict[str, Any]]:
    """Generate text with specific configuration.

    Args:
        model: Causal LM exposing a HuggingFace-style ``generate`` method.
        tokenizer: Tokenizer matching the model.
        prompt: Text prompt fed to the model.
        config: generate() kwargs; the "description" key is metadata and is
            stripped before the call.
        use_stopping_criteria: When True, stop at the first <|endofex|>.

    Returns:
        Tuple of (text generated after the prompt, stats dict with token
        counts, wall-clock generation time and tokens/second).
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    if torch.cuda.is_available():
        inputs = {k: v.cuda() for k, v in inputs.items()}

    prompt_length = inputs["input_ids"].shape[1]

    stopping_criteria = None
    if use_stopping_criteria:
        stopping_criteria = StoppingCriteriaList([
            ExpressionStoppingCriteria(tokenizer, prompt_length)
        ])

    # "description" is human-readable metadata, not a generate() kwarg.
    gen_kwargs = {k: v for k, v in config.items() if k != "description"}

    start_time = time.time()
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            **gen_kwargs,
            stopping_criteria=stopping_criteria,
            pad_token_id=tokenizer.eos_token_id,
        )
    generation_time = time.time() - start_time

    # Only the continuation matters downstream; keep special tokens so the
    # caller can detect the <|endofex|> marker.
    generated_only = tokenizer.decode(
        outputs[0][prompt_length:],
        skip_special_tokens=False
    )

    generated_tokens = outputs.shape[1] - prompt_length
    stats = {
        "total_tokens": outputs.shape[1],
        "generated_tokens": generated_tokens,
        "generation_time": generation_time,
        # Guard against a zero-duration clock reading: coarse timers can
        # report 0.0 elapsed for very short generations, which previously
        # raised ZeroDivisionError here.
        "tokens_per_second": generated_tokens / max(generation_time, 1e-9),
    }

    return generated_only, stats
|
|
|
|
|
|
|
|
def extract_expression(generated_text: str) -> tuple[str, str]:
    """Extract expression from generated text.

    Returns the expression string plus the name of the extraction strategy
    that produced it ("marker", "prefix", "first_line", or "raw").
    """
    # Best case: the model emitted the explicit end-of-expression marker.
    if "<|endofex|>" in generated_text:
        candidate = generated_text.split("<|endofex|>")[0].strip()
        # Keep only what follows the final "expr:" label, if present.
        if "expr:" in candidate:
            candidate = candidate.rsplit("expr:", 1)[-1].strip()
        return candidate, "marker"

    # Fallback: an "expr:" label without the end marker.
    if "expr:" in generated_text:
        segments = generated_text.split("expr:")
        if len(segments) > 1:
            # Take the segment after the first label, cut at the first
            # newline, then drop any trailing "vars:" section.
            candidate = segments[1].split("\n")[0].strip()
            candidate = candidate.split("vars:")[0].strip()
            return candidate, "prefix"

    # Last resorts: first line if non-empty, else the whole stripped text.
    head = generated_text.split("\n")[0].strip()
    if head:
        return head, "first_line"

    return generated_text.strip(), "raw"
|
|
|
|
|
|
|
|
def validate_expression(expr: str) -> Dict[str, Any]:
    """Simple validation of expression quality.

    Returns a dict with an overall ``is_valid`` flag plus the individual
    signals (operator/variable presence, issue list, length).
    """
    issues = []

    # Flag the first 3-char window that repeats more than 3 times —
    # a telltale sign of degenerate, looping generation.
    if len(expr) > 10:
        for start in range(len(expr) - 5):
            window = expr[start:start + 3]
            if expr.count(window) > 3:
                issues.append(f"repetition: '{window}'")
                break

    # The end marker should have been stripped during extraction.
    if "<|endofex|>" in expr:
        issues.append("marker_in_expression")

    # Junk tokens the model is known to hallucinate.
    for junk in ("Buyable", "Instore", "AndOnline", "Store", "Online",
                 "Product", "Available", "Shopping"):
        if junk in expr:
            issues.append(f"garbage: {junk}")

    operators = ("sin", "cos", "tan", "log", "exp", "sqrt", "abs",
                 "+", "-", "*", "/", "**")
    has_operator = any(op in expr for op in operators)

    # A usable expression mentions the constant C or a variable x_1..x_19.
    has_variable = "C" in expr or any(f"x_{i}" in expr for i in range(1, 20))

    return {
        "is_valid": not issues and has_operator and has_variable,
        "has_operator": has_operator,
        "has_variable": has_variable,
        "issues": issues,
        "length": len(expr),
    }
|
|
|
|
|
|
|
|
def test_configurations(
    model,
    tokenizer,
    test_prompts: List[str],
    output_dir: Path,
    configs_to_test: Optional[List[str]] = None
):
    """Test all configurations on test prompts.

    Args:
        model: Loaded causal LM.
        tokenizer: Tokenizer matching the model.
        test_prompts: Prompts generated under every configuration.
        output_dir: Kept for interface compatibility; this function does not
            write files itself (the caller persists the results).
        configs_to_test: Names of INFERENCE_CONFIGS entries to run
            (default: all of them).

    Returns:
        List of per-(config, prompt) result dicts. Failed generations are
        recorded with an "error" key and ``is_valid=False``.
    """
    if configs_to_test is None:
        configs_to_test = list(INFERENCE_CONFIGS.keys())

    results = []

    logger.info(f"\nTesting {len(configs_to_test)} configurations on {len(test_prompts)} prompts...")

    for config_name in configs_to_test:
        config = INFERENCE_CONFIGS[config_name]
        logger.info(f"\n{'='*60}")
        logger.info(f"Testing config: {config_name}")
        logger.info(f"Description: {config['description']}")
        logger.info(f"{'='*60}")

        # Per-config tally used for the summary printed below.
        config_results = []

        for i, prompt in enumerate(test_prompts):
            logger.info(f"\nPrompt {i+1}/{len(test_prompts)}: {prompt[:50]}...")

            try:
                generated, stats = generate_with_config(
                    model, tokenizer, prompt, config, use_stopping_criteria=True
                )

                expr, extraction_method = extract_expression(generated)
                validation = validate_expression(expr)

                result = {
                    "config_name": config_name,
                    "config_description": config["description"],
                    "prompt": prompt,
                    # Truncate long strings so the CSV stays readable.
                    "generated_raw": generated[:200],
                    "expression": expr[:200],
                    "extraction_method": extraction_method,
                    "is_valid": validation["is_valid"],
                    "has_operator": validation["has_operator"],
                    "has_variable": validation["has_variable"],
                    "issues": ", ".join(validation["issues"]) if validation["issues"] else "",
                    "expr_length": validation["length"],
                    "total_tokens": stats["total_tokens"],
                    "generated_tokens": stats["generated_tokens"],
                    "generation_time": stats["generation_time"],
                    "tokens_per_second": stats["tokens_per_second"],
                }

                config_results.append(result)
                results.append(result)

                status = "✅ VALID" if validation["is_valid"] else "❌ INVALID"
                logger.info(f"  {status}: {expr[:80]}")
                if validation["issues"]:
                    logger.info(f"  Issues: {', '.join(validation['issues'])}")

            except Exception as e:
                logger.error(f"Error generating with {config_name}: {e}")
                error_result = {
                    "config_name": config_name,
                    "config_description": config["description"],
                    "prompt": prompt,
                    "error": str(e),
                    "is_valid": False,
                }
                # Record failures in the per-config tally too, so the
                # summary's valid-rate denominator counts every prompt
                # (previously errors were only appended to `results`,
                # inflating the per-config valid rate).
                config_results.append(error_result)
                results.append(error_result)

        # Per-config summary; .get(..., 0) tolerates error rows that lack
        # token/time fields.
        valid_count = sum(1 for r in config_results if r.get("is_valid", False))
        valid_rate = valid_count / len(config_results) * 100 if config_results else 0

        avg_tokens = sum(r.get("generated_tokens", 0) for r in config_results) / len(config_results) if config_results else 0
        avg_time = sum(r.get("generation_time", 0) for r in config_results) / len(config_results) if config_results else 0

        logger.info(f"\n{'='*60}")
        logger.info(f"Config {config_name} Summary:")
        logger.info(f"  Valid: {valid_count}/{len(config_results)} ({valid_rate:.1f}%)")
        logger.info(f"  Avg tokens: {avg_tokens:.1f}")
        logger.info(f"  Avg time: {avg_time:.3f}s")
        logger.info(f"{'='*60}\n")

    return results
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: load a model, run every requested inference config
    over generated test prompts, and write per-sample CSV plus a JSON summary
    ranked by valid-expression rate."""
    parser = argparse.ArgumentParser(
        description="Test different inference configurations"
    )
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to model (local or HuggingFace Hub)"
    )
    parser.add_argument(
        "--base_model",
        type=str,
        default="gpt2",
        help="Base model for LoRA"
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=20,
        help="Number of test prompts to generate"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Directory to save results"
    )
    parser.add_argument(
        "--configs",
        type=str,
        nargs="+",
        default=None,
        help="Specific configs to test (default: all)"
    )

    args = parser.parse_args()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    model, tokenizer = load_model_and_tokenizer(args.model_path, args.base_model)

    # Fixed pool of five prompt templates, tiled and truncated to exactly
    # num_samples prompts.
    test_prompts = [
        "vars: x_1, x_2, x_3\noper: *, +, -, sin, cos, log\ncons: C\nexpr:",
        "vars: x_1, x_2\noper: *, **, exp, log\ncons: C\nexpr:",
        "vars: x_1, x_2, x_3, x_4\noper: *, +, /, sqrt, abs\ncons: C\nexpr:",
        "vars: x_1\noper: sin, cos, exp\ncons: C\nexpr:",
        "vars: x_1, x_2, x_3\noper: *, +, -, tan\ncons: C\nexpr:",
    ] * (args.num_samples // 5 + 1)
    test_prompts = test_prompts[:args.num_samples]

    results = test_configurations(
        model,
        tokenizer,
        test_prompts,
        output_dir,
        args.configs
    )

    # Persist every per-(config, prompt) row for offline analysis.
    df = pd.DataFrame(results)
    results_file = output_dir / "inference_config_results.csv"
    df.to_csv(results_file, index=False)
    logger.info(f"\nDetailed results saved to: {results_file}")

    # Aggregate per-config statistics. The "in config_df" checks guard
    # against columns that are absent when every row errored out.
    summary = {}
    for config_name in df["config_name"].unique():
        config_df = df[df["config_name"] == config_name]
        summary[config_name] = {
            "description": config_df["config_description"].iloc[0] if len(config_df) > 0 else "",
            "valid_rate": (config_df["is_valid"].sum() / len(config_df) * 100) if len(config_df) > 0 else 0,
            "total_samples": len(config_df),
            "valid_count": int(config_df["is_valid"].sum()),
            "avg_tokens": float(config_df["generated_tokens"].mean()) if "generated_tokens" in config_df else 0,
            "avg_time": float(config_df["generation_time"].mean()) if "generation_time" in config_df else 0,
            "common_issues": config_df["issues"].value_counts().head(3).to_dict() if "issues" in config_df else {},
        }

    # Rank configs best-first by valid-expression rate.
    summary = dict(sorted(summary.items(), key=lambda x: x[1]["valid_rate"], reverse=True))

    summary_file = output_dir / "inference_config_summary.json"
    with open(summary_file, "w") as f:
        json.dump(summary, f, indent=2)
    logger.info(f"Summary saved to: {summary_file}")

    # Human-readable recap of the ranked summary.
    logger.info("\n" + "="*60)
    logger.info("FINAL SUMMARY")
    logger.info("="*60)
    for config_name, stats in summary.items():
        logger.info(f"\n{config_name}:")
        logger.info(f"  Description: {stats['description']}")
        logger.info(f"  Valid rate: {stats['valid_rate']:.1f}% ({stats['valid_count']}/{stats['total_samples']})")
        logger.info(f"  Avg tokens: {stats['avg_tokens']:.1f}")
        logger.info(f"  Avg time: {stats['avg_time']:.3f}s")

    logger.info("\n" + "="*60)
    logger.info("Testing complete!")
    logger.info("="*60)


if __name__ == "__main__":
    main()
|
|
|