""" Interactive test script for the multiplication LoRA model. Loads the base model with or without the LoRA adapter and allows testing multiplication queries. Usage: python test_multiply.py # With LoRA adapter (default) python test_multiply.py --no-lora # Without LoRA adapter (base model only) python test_multiply.py --base # Same as --no-lora """ import sys from pathlib import Path sys.path.insert(0, str(Path(__file__).parent.parent)) sys.path.insert(0, str(Path(__file__).parent)) import argparse import torch import random import re from transformers import AutoModelForCausalLM, AutoTokenizer from peft import PeftModel import config def load_model(use_lora: bool = True): """ Load the base model with or without the LoRA adapter. Args: use_lora: If True, apply the LoRA adapter. If False, use base model only. Returns: Tuple of (model, tokenizer) """ base_model_name = config.BASE_MODEL print(f"Loading base model: {base_model_name}") if use_lora: lora_path = config.LORA_PATH print(f"Loading LoRA adapter from: {lora_path}") else: print("Running WITHOUT LoRA adapter (base model only)") # Load tokenizer tokenizer = AutoTokenizer.from_pretrained(base_model_name) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token # Load base model model = AutoModelForCausalLM.from_pretrained( base_model_name, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32, device_map="auto", trust_remote_code=True, ) # Load LoRA adapter if requested if use_lora: model = PeftModel.from_pretrained(model, str(config.LORA_PATH)) model.eval() print("Model loaded successfully!") return model, tokenizer def generate_answer(model, tokenizer, query: str) -> str: """Generate an answer for the given multiplication query.""" # Format as chat message messages = [ {"role": "system", "content": config.SYSTEM_PROMPT}, {"role": "user", "content": query}, ] # Apply chat template prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) # Tokenize inputs = tokenizer(prompt, return_tensors="pt").to(model.device) # Generate with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=32, do_sample=False, # Greedy decoding for deterministic results pad_token_id=tokenizer.pad_token_id, ) # Decode only the generated part generated_ids = outputs[0][inputs["input_ids"].shape[1] :] answer = tokenizer.decode(generated_ids, skip_special_tokens=True) return answer.strip() def evaluate_accuracy(model, tokenizer, num_tests: int = 500, seed: int = 42): """ Evaluate the model's accuracy on random multiplication problems. 

def evaluate_accuracy(model, tokenizer, num_tests: int = 500, seed: int = 42):
    """
    Evaluate the model's accuracy on random multiplication problems.

    Args:
        model: The loaded model (with or without the LoRA adapter)
        tokenizer: The tokenizer
        num_tests: Number of random test cases to evaluate
        seed: Random seed for reproducibility

    Returns:
        Dictionary with accuracy metrics
    """
    random.seed(seed)

    correct = 0
    off_by_one_digit = 0
    results = []

    print(f"\nEvaluating accuracy on {num_tests} random multiplication problems...")
    print("-" * 50)

    for i in range(num_tests):
        # Generate a random 6-digit number; the multiplier is fixed at 7,
        # matching the training task
        a = random.randint(100000, 999999)
        b = 7
        expected = a * b

        # Generate prediction
        query = f"{a} * {b}"
        predicted_str = generate_answer(model, tokenizer, query)

        # Extract numeric answer from prediction
        predicted_numbers = re.findall(r"\d+", predicted_str)
        predicted = int(predicted_numbers[0]) if predicted_numbers else None

        # Check correctness
        is_correct = predicted == expected
        if is_correct:
            correct += 1
        elif predicted is not None:
            # Check if off by one digit (common error pattern)
            expected_str = str(expected)
            predicted_str_clean = str(predicted)
            if len(expected_str) == len(predicted_str_clean):
                diff_count = sum(
                    1 for x, y in zip(expected_str, predicted_str_clean) if x != y
                )
                if diff_count == 1:
                    off_by_one_digit += 1

        results.append(
            {
                "query": query,
                "expected": expected,
                "predicted": predicted,
                "correct": is_correct,
            }
        )

        # Progress indicator
        if (i + 1) % 100 == 0:
            print(f"  Progress: {i + 1}/{num_tests} ({correct}/{i + 1} correct so far)")

    accuracy = correct / num_tests
    off_by_one_rate = off_by_one_digit / num_tests

    print("-" * 50)
    print("\nResults:")
    print(f"  Exact match accuracy: {accuracy * 100:.2f}% ({correct}/{num_tests})")
    print(
        f"  Off by one digit: {off_by_one_rate * 100:.2f}% ({off_by_one_digit}/{num_tests})"
    )
    print(f"  Total near-correct: {(accuracy + off_by_one_rate) * 100:.2f}%")

    # Show some examples of errors
    errors = [r for r in results if not r["correct"]]
    if errors:
        print("\nSample errors (showing up to 5):")
        for err in errors[:5]:
            print(
                f"  {err['query']} = {err['expected']} (predicted: {err['predicted']})"
            )

    return {
        "accuracy": accuracy,
        "correct": correct,
        "total": num_tests,
        "off_by_one_digit": off_by_one_digit,
        "results": results,
    }
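
# Optional sketch (not wired into the CLI below): persist the metrics dict
# returned by evaluate_accuracy() for offline comparison between runs, e.g.
# with and without the adapter. The default filename is an arbitrary choice,
# not something config defines; the local import keeps the sketch self-contained.
def save_metrics(metrics: dict, path: str = "eval_results.json") -> None:
    """Write evaluation metrics (including per-example results) to a JSON file."""
    import json

    with open(path, "w") as f:
        json.dump(metrics, f, indent=2)
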

def parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(
        description="Test the multiplication model with or without LoRA adapter"
    )
    parser.add_argument(
        "--no-lora",
        "--base",
        action="store_true",
        dest="no_lora",
        help="Run without LoRA adapter (base model only)",
    )
    return parser.parse_args()


def main():
    args = parse_args()
    use_lora = not args.no_lora

    print("=" * 60)
    if use_lora:
        print("Multiplication LoRA Model Tester")
    else:
        print("Multiplication Base Model Tester (NO LoRA)")
    print("=" * 60)

    # Check CUDA
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")

    # Load model
    print("\nLoading model...")
    model, tokenizer = load_model(use_lora=use_lora)

    print("\n" + "=" * 60)
    print("Commands:")
    print("  - Enter multiplication queries like '134345 * 7' or '923256 * 7'")
    print("  - Type 'eval' or 'eval 1000' to run accuracy evaluation")
    print("  - Type 'quit' or 'exit' to stop")
    print("=" * 60 + "\n")

    while True:
        try:
            query = input("Query: ").strip()
            if not query:
                continue
            if query.lower() in ("quit", "exit", "q"):
                print("Goodbye!")
                break

            # Handle eval command
            if query.lower().startswith("eval"):
                parts = query.split()
                num_tests = int(parts[1]) if len(parts) > 1 else 500
                evaluate_accuracy(model, tokenizer, num_tests=num_tests)
                print()
                continue

            # Generate answer
            answer = generate_answer(model, tokenizer, query)
            print(f"Answer: {answer}")

            # Try to verify the result if it looks like a multiplication
            try:
                # Parse the query to extract numbers
                numbers = re.findall(r"\d+", query)
                if len(numbers) >= 2:
                    a, b = int(numbers[0]), int(numbers[1])
                    expected = a * b

                    # Try to extract the numeric answer
                    answer_numbers = re.findall(r"\d+", answer)
                    if answer_numbers:
                        predicted = int(answer_numbers[-1])
                        if predicted == expected:
                            print(f"  ✓ Correct! ({a} × {b} = {expected})")
                        else:
                            print(f"  ✗ Expected: {expected}")
            except Exception:
                pass  # Silently ignore verification errors

            print()

        except KeyboardInterrupt:
            print("\nGoodbye!")
            break
        except Exception as e:
            print(f"Error: {e}")


if __name__ == "__main__":
    main()
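
# Example interactive session (illustrative output; exact answers depend on
# the trained adapter):
#
#   $ python test_multiply.py
#   ...
#   Query: 134345 * 7
#   Answer: 940415
#     ✓ Correct! (134345 × 7 = 940415)
#   Query: eval 100
#   ...
#   Query: quit
#   Goodbye!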