"""
Interactive test script for the multiplication LoRA model.
Loads the base model with or without the LoRA adapter and lets you test multiplication queries interactively.
Usage:
python test_multiply.py # With LoRA adapter (default)
python test_multiply.py --no-lora # Without LoRA adapter (base model only)
python test_multiply.py --base # Same as --no-lora
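
Interactive commands (after the model loads):
    <a> * <b>      query the model, e.g. '134345 * 7'
    eval [N]       run an accuracy evaluation on N random problems (default 500)
    quit / exit    stop the script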
"""
import sys
from pathlib import Path
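# Make the repo root and this directory importable so the local config module resolves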
sys.path.insert(0, str(Path(__file__).parent.parent))
sys.path.insert(0, str(Path(__file__).parent))
import argparse
import torch
import random
import re
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
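# Local project config; expected to define BASE_MODEL, LORA_PATH, and SYSTEM_PROMPT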
import config
def load_model(use_lora: bool = True):
"""
Load the base model with or without the LoRA adapter.
Args:
use_lora: If True, apply the LoRA adapter. If False, use base model only.
Returns:
Tuple of (model, tokenizer)
"""
base_model_name = config.BASE_MODEL
print(f"Loading base model: {base_model_name}")
if use_lora:
lora_path = config.LORA_PATH
print(f"Loading LoRA adapter from: {lora_path}")
else:
print("Running WITHOUT LoRA adapter (base model only)")
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Load base model
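    # bfloat16 halves weight memory on GPU; fall back to float32 on CPU.
    # device_map="auto" lets transformers place weights automatically (requires accelerate).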
model = AutoModelForCausalLM.from_pretrained(
base_model_name,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
device_map="auto",
trust_remote_code=True,
)
# Load LoRA adapter if requested
if use_lora:
model = PeftModel.from_pretrained(model, str(config.LORA_PATH))
model.eval()
print("Model loaded successfully!")
return model, tokenizer
def generate_answer(model, tokenizer, query: str) -> str:
"""Generate an answer for the given multiplication query."""
# Format as chat message
messages = [
{"role": "system", "content": config.SYSTEM_PROMPT},
{"role": "user", "content": query},
]
# Apply chat template
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
# Tokenize
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# Generate
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=32,
do_sample=False, # Greedy decoding for deterministic results
pad_token_id=tokenizer.pad_token_id,
)
# Decode only the generated part
generated_ids = outputs[0][inputs["input_ids"].shape[1] :]
answer = tokenizer.decode(generated_ids, skip_special_tokens=True)
return answer.strip()
def evaluate_accuracy(model, tokenizer, num_tests: int = 500, seed: int = 42):
"""
    Evaluate the model's accuracy on random 6-digit-by-7 multiplication problems.
Args:
        model: The loaded model (with or without the LoRA adapter)
tokenizer: The tokenizer
num_tests: Number of random test cases to evaluate
seed: Random seed for reproducibility
Returns:
Dictionary with accuracy metrics
"""
random.seed(seed)
correct = 0
off_by_one_digit = 0
results = []
print(f"\nEvaluating accuracy on {num_tests} random multiplication problems...")
print("-" * 50)
for i in range(num_tests):
        # Random 6-digit multiplicand; the multiplier is fixed at 7
a = random.randint(100000, 999999)
b = 7
expected = a * b
# Generate prediction
query = f"{a} * {b}"
predicted_str = generate_answer(model, tokenizer, query)
# Extract numeric answer from prediction
predicted_numbers = re.findall(r"\d+", predicted_str)
predicted = int(predicted_numbers[0]) if predicted_numbers else None
# Check correctness
        is_correct = predicted == expected
        if is_correct:
            correct += 1
elif predicted is not None:
# Check if off by one digit (common error pattern)
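            # e.g. 123456 * 7 = 864192; predicting 864102 differs in exactly one digit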
expected_str = str(expected)
predicted_str_clean = str(predicted)
if len(expected_str) == len(predicted_str_clean):
diff_count = sum(
1 for x, y in zip(expected_str, predicted_str_clean) if x != y
)
if diff_count == 1:
off_by_one_digit += 1
results.append(
{
"query": query,
"expected": expected,
"predicted": predicted,
"correct": is_correct,
}
)
# Progress indicator
if (i + 1) % 100 == 0:
print(f" Progress: {i + 1}/{num_tests} ({correct}/{i + 1} correct so far)")
accuracy = correct / num_tests
off_by_one_rate = off_by_one_digit / num_tests
print("-" * 50)
print(f"\nResults:")
print(f" Exact match accuracy: {accuracy * 100:.2f}% ({correct}/{num_tests})")
print(
f" Off by one digit: {off_by_one_rate * 100:.2f}% ({off_by_one_digit}/{num_tests})"
)
print(f" Total near-correct: {(accuracy + off_by_one_rate) * 100:.2f}%")
# Show some examples of errors
errors = [r for r in results if not r["correct"]]
if errors:
print(f"\nSample errors (showing up to 5):")
for err in errors[:5]:
print(
f" {err['query']} = {err['expected']} (predicted: {err['predicted']})"
)
return {
"accuracy": accuracy,
"correct": correct,
"total": num_tests,
"off_by_one_digit": off_by_one_digit,
"results": results,
}
def parse_args():
"""Parse command line arguments."""
parser = argparse.ArgumentParser(
description="Test the multiplication model with or without LoRA adapter"
)
parser.add_argument(
"--no-lora",
"--base",
action="store_true",
dest="no_lora",
help="Run without LoRA adapter (base model only)",
)
return parser.parse_args()
def main():
args = parse_args()
use_lora = not args.no_lora
print("=" * 60)
if use_lora:
print("Multiplication LoRA Model Tester")
else:
print("Multiplication Base Model Tester (NO LoRA)")
print("=" * 60)
# Check CUDA
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f"GPU: {torch.cuda.get_device_name(0)}")
# Load model
print("\nLoading model...")
model, tokenizer = load_model(use_lora=use_lora)
print("\n" + "=" * 60)
print("Commands:")
print(" - Enter multiplication queries like '134345 * 7' or '923256 * 7'")
print(" - Type 'eval' or 'eval 1000' to run accuracy evaluation")
print(" - Type 'quit' or 'exit' to stop")
print("=" * 60 + "\n")
while True:
try:
query = input("Query: ").strip()
if not query:
continue
if query.lower() in ("quit", "exit", "q"):
print("Goodbye!")
break
# Handle eval command
if query.lower().startswith("eval"):
parts = query.split()
num_tests = int(parts[1]) if len(parts) > 1 else 500
evaluate_accuracy(model, tokenizer, num_tests=num_tests)
print()
continue
# Generate answer
answer = generate_answer(model, tokenizer, query)
print(f"Answer: {answer}")
# Try to verify the result if it looks like a multiplication
try:
# Parse the query to extract numbers
numbers = re.findall(r"\d+", query)
if len(numbers) >= 2:
a, b = int(numbers[0]), int(numbers[1])
expected = a * b
                    # Try to extract the numeric answer (the last number in the output)
answer_numbers = re.findall(r"\d+", answer)
if answer_numbers:
predicted = int(answer_numbers[-1])
if predicted == expected:
print(f" ✓ Correct! ({a} × {b} = {expected})")
else:
print(f" ✗ Expected: {expected}")
except Exception:
pass # Silently ignore verification errors
print()
except KeyboardInterrupt:
print("\nGoodbye!")
break
except Exception as e:
print(f"Error: {e}")
if __name__ == "__main__":
main()