|
|
|
|
|
""" |
|
|
Simple comparison of V1 vs V2 model generation quality |
|
|
""" |
|
|
|
|
|
import math
import sys
from pathlib import Path

import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
|
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent.parent)) |
|
|
from classes.expression import Expression |
|
|
|
|
|
|
|
|
class ExpressionStoppingCriteria(StoppingCriteria):
    """Stop generation as soon as the output ends with any stop sequence.

    Stop sequences are tokenized once at construction time so each
    generation step only compares token-id lists.
    """

    def __init__(self, tokenizer, stop_sequences):
        self.tokenizer = tokenizer
        # One id-list per stop sequence, encoded without special tokens.
        self.stop_ids = []
        for sequence in stop_sequences:
            self.stop_ids.append(tokenizer.encode(sequence, add_special_tokens=False))

    def __call__(self, input_ids, scores, **kwargs):
        # Only batch element 0 is inspected (generation here is unbatched).
        generated = input_ids[0]
        for candidate in self.stop_ids:
            n = len(candidate)
            if n == 0 or len(generated) < n:
                continue
            if generated[-n:].tolist() == candidate:
                return True
        return False
|
|
|
|
|
|
|
|
def load_model(model_name, model_label):
    """Load base GPT-2, attach the fine-tuned adapter, and merge it for inference.

    Args:
        model_name: Hub id of the PEFT adapter checkpoint.
        model_label: Human-readable tag ("V1"/"V2") used in log output.

    Returns:
        (model, tokenizer) — the merged, eval-mode model and its tokenizer.
    """
    separator = '=' * 60
    print(f"\n{separator}")
    print(f"Loading {model_label}: {model_name}")
    print(separator)

    print("Loading base GPT-2...")
    base_model = AutoModelForCausalLM.from_pretrained(
        "gpt2",
        torch_dtype=torch.float16,
        device_map="auto",
    )

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # The fine-tuned checkpoints wrap every expression in these marker tokens.
    tokenizer.add_special_tokens(
        {"additional_special_tokens": ["<|startofex|>", "<|endofex|>"]}
    )

    # Grow the embedding matrix to cover the newly added special tokens.
    base_model.resize_token_embeddings(len(tokenizer))

    print(f"Loading adapter from {model_name}...")
    adapted = PeftModel.from_pretrained(base_model, model_name)
    print("Merging adapter...")
    # Merge LoRA weights into the base so generation runs on a plain model.
    merged = adapted.merge_and_unload()
    merged.eval()

    print(f"β {model_label} loaded successfully")
    return merged, tokenizer
|
|
|
|
|
|
|
|
def test_model(model, tokenizer, model_label, n_samples=20):
    """Generate ``n_samples`` expressions and score validity and symbol purity.

    Two metrics per sample:
      * valid            — the text parses as an ``Expression`` and evaluates
                           to finite values on a fixed test point.
      * correct_symbols  — only letters belonging to x, C, sin, cos (plus '_')
                           appear, and no known garbage word is present.

    Args:
        model: merged causal LM returned by ``load_model``.
        tokenizer: matching tokenizer.
        model_label: "V1" selects the V1 sampling config; anything else uses V2's.
        n_samples: number of generations to run.

    Returns:
        dict with ``valid_count``, ``correct_symbols_count`` and per-sample
        records under ``expressions``.
    """
    print(f"\n{'='*60}")
    print(f"Testing {model_label} - {n_samples} generations")
    print('='*60)

    # Same fixed prompt for both models so the comparison is apples-to-apples.
    prompt = """vars: x_1, x_2
oper: *, +, -, sin, cos
cons: C
expr:"""

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Stop once the model closes the expression or starts hallucinating
    # a fresh prompt block.
    stopping_criteria = StoppingCriteriaList([
        ExpressionStoppingCriteria(tokenizer, ["<|endofex|>", "\n\nvars:"])
    ])

    if model_label == "V1":
        # Settings found best for V1 in earlier sweeps.
        gen_config = {
            "temperature": 0.5,
            "top_k": 40,
            "top_p": 0.9,
            "repetition_penalty": 1.15,
            "max_new_tokens": 100,
            "do_sample": True,
            "pad_token_id": tokenizer.eos_token_id,
        }
        print("Using V1 optimal config: temp=0.5, top_k=40, rep_penalty=1.15")
    else:
        # V2 prefers pure nucleus sampling (top_k disabled).
        gen_config = {
            "temperature": 0.7,
            "top_k": 0,
            "top_p": 0.8,
            "repetition_penalty": 1.0,
            "max_new_tokens": 128,
            "do_sample": True,
            "pad_token_id": tokenizer.eos_token_id,
        }
        print("Using V2 optimal config: temp=0.7, top_p=0.8 (nucleus sampling)")

    results = {
        "valid_count": 0,
        "correct_symbols_count": 0,
        "expressions": []
    }

    # Words that indicate the model drifted into generic web text.
    garbage_words = ["Buyable", "Instore", "Online", "Muslims", "crash", "Berman",
                     "vars:", "oper:", "expressed", "fluent", "Avenger", "repositories"]

    print(f"\nGenerating {n_samples} expressions...\n")

    for i in range(n_samples):
        output = model.generate(
            **inputs,
            **gen_config,
            stopping_criteria=stopping_criteria
        )
        text = tokenizer.decode(output[0], skip_special_tokens=False)

        # Keep only the generated expression: drop the prompt prefix and the
        # end-of-expression marker, if present.
        if "expr:" in text:
            expr_str = text.split("expr:")[-1].strip()
            expr_str = expr_str.split("<|endofex|>")[0].strip()
        else:
            expr_str = text

        # "Valid" = parses and evaluates to finite (non-inf, non-NaN) numbers.
        is_valid = False
        try:
            expr = Expression(expr_str, is_prefix=False)
            X_test = [[1.0, 2.0]]
            result = expr.evaluate(X_test)
            # math.isfinite rejects inf, -inf and NaN in one call.
            if len(result) > 0 and all(math.isfinite(x) for x in result):
                is_valid = True
                results["valid_count"] += 1
        except Exception:
            # Parse/eval failures just mean this sample is invalid, not fatal;
            # narrowed from a bare except so Ctrl-C still interrupts the run.
            pass

        # Symbol check: only letters that occur in x, C, sin, cos (plus '_')
        # are allowed. Digits and operators are never alphabetic, so they pass.
        has_correct_symbols = True
        expr_clean = expr_str.replace(" ", "")
        for char in expr_clean:
            if char.isalpha() and char not in "xCsinco_":
                has_correct_symbols = False
                break

        if has_correct_symbols:
            for word in garbage_words:
                if word in expr_str:
                    has_correct_symbols = False
                    break

        if has_correct_symbols:
            results["correct_symbols_count"] += 1

        results["expressions"].append({
            "index": i + 1,
            "expression": expr_str[:80],
            "valid": is_valid,
            "correct_symbols": has_correct_symbols
        })

        # Echo the first few samples so a run is easy to eyeball.
        if i < 5:
            status = "β Valid" if is_valid else "β Invalid"
            symbols = "β Clean" if has_correct_symbols else "β Garbage"
            print(f" [{i+1:2d}] {status:10s} {symbols:10s} | {expr_str[:60]}")

    print(f"\n{'-'*60}")
    print(f"RESULTS FOR {model_label}:")
    print(f" Valid expressions: {results['valid_count']:2d}/{n_samples} ({results['valid_count']/n_samples*100:.1f}%)")
    print(f" Correct symbols only: {results['correct_symbols_count']:2d}/{n_samples} ({results['correct_symbols_count']/n_samples*100:.1f}%)")
    print(f"{'-'*60}")

    return results
|
|
|
|
|
|
|
|
def main(n_samples=20):
    """Run the V1-vs-V2 comparison end to end and print a summary table.

    Args:
        n_samples: generations per model. Also used to compute the printed
            percentage rates, so count and rate can never drift apart
            (previously ``20`` was hard-coded in four separate places).
    """
    print("\n" + "="*60)
    print("V1 vs V2 MODEL COMPARISON")
    print("="*60)
    print("Testing same prompt on both models")
    print("Measuring: valid expressions + symbol correctness\n")

    v1_model, v1_tokenizer = load_model("augustocsc/Se124M_700K_infix", "V1")
    v1_results = test_model(v1_model, v1_tokenizer, "V1", n_samples=n_samples)

    # Free GPU memory before loading the second model.
    del v1_model
    torch.cuda.empty_cache()

    v2_model, v2_tokenizer = load_model("augustocsc/Se124M_700K_infix_v2", "V2")
    v2_results = test_model(v2_model, v2_tokenizer, "V2", n_samples=n_samples)

    print("\n" + "="*60)
    print("FINAL COMPARISON")
    print("="*60)
    print(f"\n{'Metric':<30s} {'V1':>10s} {'V2':>10s} {'Winner':>10s}")
    print("-"*60)

    v1_valid = v1_results["valid_count"]
    v2_valid = v2_results["valid_count"]
    valid_winner = "V1" if v1_valid > v2_valid else ("V2" if v2_valid > v1_valid else "TIE")
    print(f"{'Valid Expressions':<30s} {v1_valid:>10d} {v2_valid:>10d} {valid_winner:>10s}")

    v1_clean = v1_results["correct_symbols_count"]
    v2_clean = v2_results["correct_symbols_count"]
    clean_winner = "V1" if v1_clean > v2_clean else ("V2" if v2_clean > v1_clean else "TIE")
    print(f"{'Correct Symbols Only':<30s} {v1_clean:>10d} {v2_clean:>10d} {clean_winner:>10s}")

    print("-"*60)
    # Rates derive from n_samples instead of a hard-coded 20.
    print(f"{'Valid Rate':<30s} {v1_valid/n_samples*100:>9.1f}% {v2_valid/n_samples*100:>9.1f}%")
    print(f"{'Clean Symbol Rate':<30s} {v1_clean/n_samples*100:>9.1f}% {v2_clean/n_samples*100:>9.1f}%")
    print("="*60)

    print("\nConclusion:")
    if v1_valid > v2_valid and v1_clean > v2_clean:
        print(" β V1 is better on both metrics")
    elif v2_valid > v1_valid and v2_clean > v1_clean:
        print(" β V2 is better on both metrics")
    else:
        print(" β Mixed results - models have different strengths")
|
|
|
|
|
|
|
|
# Entry point: run the full comparison only when executed as a script,
# so importing this module for its helpers triggers no model downloads.
if __name__ == "__main__":
    main()
|
|
|