|
|
""" |
|
|
Baseline evaluation: Vanilla SmolLM2-360M on arithmetic |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import random |
|
|
import re |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
|
|
# Target device for model weights and input tensors.
# NOTE(review): hard-coded; assumes a CUDA GPU is available — confirm or add a CPU fallback.
DEVICE = "cuda"


# Hugging Face Hub id of the instruction-tuned checkpoint under evaluation.
MODEL_ID = "HuggingFaceTB/SmolLM2-360M-Instruct"


# Few-shot system prompt: instructs the model to emit bare digits and shows both
# arithmetic and comparison examples (comparisons encoded as 1/0, which matches
# ground_truth and the yes/no handling in extract_answer). Sent verbatim as the
# "system" message by format_prompt — do not reformat this literal.
SYSTEM_PROMPT = """You are a calculator. Output only the numeric answer. No words, no explanation, just digits. Examples:


User: 5 + 3


Assistant: 8


User: 12 * 7


Assistant: 84


User: 100 > 50


Assistant: 1


User: 25 < 10


Assistant: 0"""
|
|
|
|
|
|
|
|
def load_model():
    """Load the SmolLM2 checkpoint and its tokenizer, ready for batched decoding.

    Returns:
        (model, tokenizer): the fp16 model placed on DEVICE in eval mode, and a
        tokenizer configured for left padding (required so completions of a
        decoder-only model start at the same offset across a padded batch).
    """
    print(f"Loading {MODEL_ID}...")

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    tokenizer.padding_side = "left"
    # Some checkpoints ship without a pad token; reuse EOS so batching works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16,
        device_map=DEVICE,
    ).eval()

    n_params = sum(p.numel() for p in model.parameters())
    print(f" Loaded. Parameters: {n_params:,}")
    return model, tokenizer
|
|
|
|
|
|
|
|
def format_prompt(tokenizer, op_str):
    """Render one arithmetic question as a chat-templated prompt string.

    The calculator SYSTEM_PROMPT is prepended and the generation prompt is
    appended so the model continues with the assistant's answer.
    """
    chat = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": op_str},
    ]
    return tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )
|
|
|
|
|
|
|
|
def generate_batch(model, tokenizer, prompts, max_new_tokens=16):
    """Greedy-decode a batch of prompts and return the stripped completions.

    Args:
        model: causal LM already placed on DEVICE.
        tokenizer: tokenizer with left padding configured (see load_model).
        prompts: list of fully formatted prompt strings.
        max_new_tokens: generation budget per prompt.

    Returns:
        list[str]: one decoded, whitespace-stripped continuation per prompt,
        in the same order as `prompts`.
    """
    inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(DEVICE)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding: deterministic answers
            pad_token_id=tokenizer.eos_token_id,
        )
    # With left padding, every prompt occupies the first `prompt_len` positions,
    # so the continuation starts at the same offset for every row.
    # (Original used enumerate() with an unused index and a manual append loop.)
    prompt_len = inputs.input_ids.shape[1]
    return [
        tokenizer.decode(output[prompt_len:], skip_special_tokens=True).strip()
        for output in outputs
    ]
|
|
|
|
|
|
|
|
def extract_answer(text):
    """Leniently parse a model response into an int, or None if nothing numeric.

    Boolean phrasings ("yes"/"true" and "no"/"false") are mapped onto the 1/0
    encoding used for comparison ops; otherwise the LAST integer (possibly
    negative) appearing anywhere in the text is returned.
    """
    cleaned = text.strip().lower()
    if not cleaned:
        return None

    # Exact boolean answers first, then generous prefix matches.
    if cleaned in ('yes', 'true', '1'):
        return 1
    if cleaned in ('no', 'false', '0'):
        return 0
    if cleaned.startswith('yes'):
        return 1
    if cleaned.startswith('no'):
        return 0

    # Fall back to the last integer found anywhere in the output.
    matches = re.findall(r'-?\d+', cleaned)
    return int(matches[-1]) if matches else None
|
|
|
|
|
|
|
|
def ground_truth(a, b, op):
    """Expected result of `a op b`, with add/sub/mul wrapped to 8 bits.

    Comparison ops return 1/0; division is integer division with a
    divide-by-zero guard that yields 0. Raises ValueError for an unknown op.
    """
    table = {
        'add': lambda: (a + b) & 0xFF,
        'sub': lambda: (a - b) & 0xFF,
        'mul': lambda: (a * b) & 0xFF,
        'div': lambda: a // b if b != 0 else 0,
        'and': lambda: a & b,
        'or': lambda: a | b,
        'xor': lambda: a ^ b,
        'gt': lambda: int(a > b),
        'lt': lambda: int(a < b),
        'eq': lambda: int(a == b),
        'ge': lambda: int(a >= b),
        'le': lambda: int(a <= b),
    }
    if op not in table:
        raise ValueError(f"Unknown op: {op}")
    return table[op]()
|
|
|
|
|
|
|
|
def op_to_str(a, b, op):
    """Render the operation as the plain-text infix expression shown to the model."""
    symbol = {
        'add': '+', 'sub': '-', 'mul': '*', 'div': '/',
        'and': '&', 'or': '|', 'xor': '^',
        'gt': '>', 'lt': '<', 'eq': '==', 'ge': '>=', 'le': '<=',
    }[op]  # unknown ops raise KeyError, as in the original lookup
    return f"{a} {symbol} {b}"
|
|
|
|
|
|
|
|
def evaluate(model, tokenizer, n_samples=1000, batch_size=32, ops=None):
    """Measure exact-match accuracy on random 8-bit arithmetic problems.

    Args:
        model, tokenizer: as returned by load_model().
        n_samples: number of random (a, b, op) problems to score.
        batch_size: prompts per generate_batch call.
        ops: operation names to sample from; defaults to
            ['add', 'sub', 'mul', 'gt', 'lt', 'eq']. If 'div' is included it
            is drawn ~10% of the time with a nonzero divisor.

    Returns:
        (results, correct, total): per-op {'correct', 'total'} counts plus
        overall correct/total counts.
    """
    if ops is None:
        ops = ['add', 'sub', 'mul', 'gt', 'lt', 'eq']

    results = {op: {'correct': 0, 'total': 0} for op in ops}
    all_correct = 0
    all_total = 0

    # Pre-draw all problems so a fixed seed reproduces the same test set.
    # NOTE: the random calls below must stay in this exact order to keep
    # results identical across seeded runs.
    non_div_ops = [o for o in ops if o != 'div']
    samples = []
    for _ in range(n_samples):
        a = random.randint(0, 255)
        b = random.randint(0, 255)
        if 'div' in ops and random.random() < 0.1:
            op = 'div'
            b = random.randint(1, 255)  # avoid division by zero
        else:
            # Fix: if 'div' is the only requested op, fall back to it instead
            # of crashing on random.choice([]).
            op = random.choice(non_div_ops) if non_div_ops else 'div'
        samples.append((a, b, op))

    print(f"\nEvaluating {n_samples} samples (batch_size={batch_size})...")

    for batch_start in range(0, n_samples, batch_size):
        batch = samples[batch_start:batch_start + batch_size]
        prompts = [format_prompt(tokenizer, op_to_str(a, b, op)) for a, b, op in batch]
        responses = generate_batch(model, tokenizer, prompts)

        for (a, b, op), response in zip(batch, responses):
            expected = ground_truth(a, b, op)
            extracted = extract_answer(response)

            results[op]['total'] += 1
            all_total += 1
            if extracted == expected:
                results[op]['correct'] += 1
                all_correct += 1

        # Report whenever a 200-sample boundary is crossed, and at the end.
        # Fix: the original `(batch_start + batch_size) % 200 == 0` only fired
        # when batch_size divides 200, so e.g. batch_size=64 printed nothing
        # until the final batch.
        done = min(batch_start + batch_size, n_samples)
        if done // 200 > batch_start // 200 or done >= n_samples:
            pct = 100 * all_correct / all_total
            print(f" Progress: {done}/{n_samples} | Accuracy: {pct:.2f}%")

    return results, all_correct, all_total
|
|
|
|
|
|
|
|
def main():
    """Run the baseline: seed RNGs, sanity-check five prompts, then evaluate.

    Returns:
        float: overall accuracy (fraction correct) on the 2000-sample eval.
    """
    random.seed(42)
    torch.manual_seed(42)

    model, tokenizer = load_model()

    # Quick qualitative check that prompting and answer parsing work at all.
    print("\nSanity check (5 examples):")
    test_cases = [
        ("5 + 3", 8),
        ("100 - 37", 63),
        ("12 * 11", 132),
        ("50 > 30", 1),
        ("25 < 10", 0),
    ]
    prompts = [format_prompt(tokenizer, question) for question, _ in test_cases]
    responses = generate_batch(model, tokenizer, prompts)
    for (question, expected), response in zip(test_cases, responses):
        parsed = extract_answer(response)
        verdict = "OK" if parsed == expected else "FAIL"
        print(f" {question} = {expected} | Model: '{response}' -> {parsed} [{verdict}]")

    bar = "=" * 60
    print("\n" + bar)
    print(" BASELINE EVALUATION")
    print(bar)

    ops = ['add', 'sub', 'mul', 'gt', 'lt', 'eq']
    results, correct, total = evaluate(
        model, tokenizer, n_samples=2000, batch_size=64, ops=ops
    )

    print("\n" + bar)
    print(" RESULTS BY OPERATION")
    print(bar)
    for op in ops:
        stats = results[op]
        accuracy = 100 * stats['correct'] / stats['total'] if stats['total'] > 0 else 0
        print(f" {op:6}: {stats['correct']:4}/{stats['total']:4} ({accuracy:6.2f}%)")

    print("\n" + bar)
    print(" OVERALL")
    print(bar)
    fitness = correct / total
    print(f" Correct: {correct}/{total}")
    print(f" Fitness: {fitness:.4f} ({100*fitness:.2f}%)")
    print(bar)

    return fitness
|
|
|
|
|
|
|
|
# Script entry point: run the baseline evaluation when executed directly.
if __name__ == "__main__":

    main()
|
|
|