Spaces:
Sleeping
Sleeping
| """ | |
| Interactive test script for the multiplication LoRA model. | |
| Loads the base model with or without the LoRA adapter and allows testing multiplication queries. | |
| Usage: | |
| python test_multiply.py # With LoRA adapter (default) | |
| python test_multiply.py --no-lora # Without LoRA adapter (base model only) | |
| python test_multiply.py --base # Same as --no-lora | |
| """ | |
| import sys | |
| from pathlib import Path | |
| sys.path.insert(0, str(Path(__file__).parent.parent)) | |
| sys.path.insert(0, str(Path(__file__).parent)) | |
| import argparse | |
| import torch | |
| import random | |
| import re | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from peft import PeftModel | |
| import config | |
def load_model(use_lora: bool = True):
    """
    Build the (model, tokenizer) pair used by the tester.

    Weights and tokenizer come from ``config.BASE_MODEL``. When ``use_lora`` is
    true, the adapter stored at ``config.LORA_PATH`` is attached on top of the
    base model with ``PeftModel.from_pretrained``; otherwise the bare base
    model is returned.

    Args:
        use_lora: Attach the LoRA adapter (True, the default) or not (False).

    Returns:
        Tuple of (model, tokenizer); the model has been put in eval mode.
    """
    model_id = config.BASE_MODEL
    print(f"Loading base model: {model_id}")

    if not use_lora:
        print("Running WITHOUT LoRA adapter (base model only)")
    else:
        adapter_path = config.LORA_PATH
        print(f"Loading LoRA adapter from: {adapter_path}")

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Some tokenizers ship without a pad token; generate_answer passes
    # pad_token_id to generate(), so reuse the EOS token as padding.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # bfloat16 on GPU, full precision on CPU.
    weight_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        dtype=weight_dtype,
        device_map="auto",
        trust_remote_code=True,
    )

    if use_lora:
        model = PeftModel.from_pretrained(model, str(adapter_path))

    model.eval()
    print("Model loaded successfully!")
    return model, tokenizer
def generate_answer(model, tokenizer, query: str, max_new_tokens: int = 32) -> str:
    """
    Generate the model's answer to a single multiplication query.

    The query is wrapped in a two-turn chat (``config.SYSTEM_PROMPT`` as the
    system message, ``query`` as the user message), rendered with the
    tokenizer's chat template, and decoded greedily.

    Args:
        model: Causal LM (base or LoRA-wrapped), e.g. as returned by load_model.
        tokenizer: The matching tokenizer; load_model guarantees it has a pad
            token.
        query: User text such as ``"134345 * 7"``.
        max_new_tokens: Upper bound on the number of generated tokens
            (default 32, the previous hard-coded value).

    Returns:
        Only the newly generated text (the prompt is excluded), with special
        tokens removed and surrounding whitespace stripped.
    """
    messages = [
        {"role": "system", "content": config.SYSTEM_PROMPT},
        {"role": "user", "content": query},
    ]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # The rendered chat template typically already contains the model's
    # special tokens (e.g. BOS). Tokenizing it with the default
    # add_special_tokens=True can insert them a second time, giving the model
    # a prompt it never saw during training.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(
        model.device
    )

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # Greedy decoding for deterministic results
            pad_token_id=tokenizer.pad_token_id,
        )

    # Slice off the prompt so only the continuation is decoded.
    prompt_length = inputs["input_ids"].shape[1]
    generated_ids = outputs[0][prompt_length:]
    answer = tokenizer.decode(generated_ids, skip_special_tokens=True)
    return answer.strip()
def _extract_prediction(text: str):
    """
    Parse the first integer out of generated *text*; None if it has no digits.

    Commas used as thousands separators (``"864,192"``) are removed first, so a
    formatted answer is read as one number instead of being cut at the comma.
    """
    normalized = re.sub(r"(?<=\d),(?=\d{3}(?!\d))", "", text)
    match = re.search(r"\d+", normalized)
    return int(match.group()) if match else None


def _differs_in_one_digit(expected: int, predicted: int) -> bool:
    """Return True if both numbers have the same length and differ in exactly one digit."""
    expected_digits = str(expected)
    predicted_digits = str(predicted)
    if len(expected_digits) != len(predicted_digits):
        return False
    return sum(x != y for x, y in zip(expected_digits, predicted_digits)) == 1


def evaluate_accuracy(
    model, tokenizer, num_tests: int = 500, seed: int = 42, *, verbose: bool = False
):
    """
    Evaluate the model's accuracy on random ``<6-digit number> * 7`` problems.

    Args:
        model: The loaded model (with or without the LoRA adapter).
        tokenizer: The matching tokenizer.
        num_tests: Number of random test cases to evaluate. Values <= 0 run no
            cases and report 0% instead of dividing by zero.
        seed: Seed for a private ``random.Random`` instance. The same seed
            always produces the same problems, and the global ``random``
            module state is left untouched.
        verbose: If True, print ``predicted expected is_correct`` per case.

    Returns:
        Dictionary with ``accuracy`` (fraction in [0, 1]), ``correct``,
        ``total``, ``off_by_one_digit`` and ``results`` (one dict per case with
        ``query``, ``expected``, ``predicted`` and ``correct``).
    """
    rng = random.Random(seed)
    correct = 0
    off_by_one_digit = 0
    results = []

    print(f"\nEvaluating accuracy on {num_tests} random multiplication problems...")
    print("-" * 50)

    for i in range(num_tests):
        a = rng.randint(100000, 999999)  # random 6-digit multiplicand
        b = 7
        expected = a * b
        query = f"{a} * {b}"

        predicted = _extract_prediction(generate_answer(model, tokenizer, query))
        is_correct = predicted == expected
        if verbose:
            print(predicted, expected, is_correct)

        if is_correct:
            correct += 1
        elif predicted is not None and _differs_in_one_digit(expected, predicted):
            # Same length with a single wrong digit: a common near-miss pattern.
            off_by_one_digit += 1

        results.append(
            {
                "query": query,
                "expected": expected,
                "predicted": predicted,
                "correct": is_correct,
            }
        )

        # Progress indicator
        if (i + 1) % 100 == 0:
            print(f" Progress: {i + 1}/{num_tests} ({correct}/{i + 1} correct so far)")

    # num_tests <= 0 (e.g. "eval 0" in the REPL) used to raise ZeroDivisionError.
    if num_tests > 0:
        accuracy = correct / num_tests
        off_by_one_rate = off_by_one_digit / num_tests
    else:
        accuracy = 0.0
        off_by_one_rate = 0.0

    print("-" * 50)
    print("\nResults:")
    print(f" Exact match accuracy: {accuracy * 100:.2f}% ({correct}/{num_tests})")
    print(
        f" Off by one digit: {off_by_one_rate * 100:.2f}% ({off_by_one_digit}/{num_tests})"
    )
    print(f" Total near-correct: {(accuracy + off_by_one_rate) * 100:.2f}%")

    # Show some examples of errors
    errors = [r for r in results if not r["correct"]]
    if errors:
        print("\nSample errors (showing up to 5):")
        for err in errors[:5]:
            print(
                f" {err['query']} = {err['expected']} (predicted: {err['predicted']})"
            )

    return {
        "accuracy": accuracy,
        "correct": correct,
        "total": num_tests,
        "off_by_one_digit": off_by_one_digit,
        "results": results,
    }
def parse_args():
    """
    Read the command line and return the resulting ``argparse.Namespace``.

    The only option is a boolean switch, spelled either ``--no-lora`` or
    ``--base``, stored as ``no_lora`` (False unless the flag is present).
    """
    cli = argparse.ArgumentParser(
        description="Test the multiplication model with or without LoRA adapter"
    )
    cli.add_argument(
        "--no-lora",
        "--base",
        dest="no_lora",
        action="store_true",
        help="Run without LoRA adapter (base model only)",
    )
    namespace = cli.parse_args()
    return namespace
def main():
    """
    Interactive entry point: load the model, then answer queries in a loop.

    Reads lines from stdin until ``quit``/``exit``/``q``, Ctrl-C, or end of
    input (Ctrl-D / end of piped stdin). ``eval [N]`` runs evaluate_accuracy
    on N random problems (default 500). Any other line is sent to the model.
    When the line contains at least two integers, the answer is checked
    against the product of the first two.
    """
    args = parse_args()
    use_lora = not args.no_lora

    print("=" * 60)
    if use_lora:
        print("Multiplication LoRA Model Tester")
    else:
        print("Multiplication Base Model Tester (NO LoRA)")
    print("=" * 60)

    # Check CUDA
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")

    # Load model
    print("\nLoading model...")
    model, tokenizer = load_model(use_lora=use_lora)

    print("\n" + "=" * 60)
    print("Commands:")
    print(" - Enter multiplication queries like '134345 * 7' or '923256 * 7'")
    print(" - Type 'eval' or 'eval 1000' to run accuracy evaluation")
    print(" - Type 'quit' or 'exit' to stop")
    print("=" * 60 + "\n")

    while True:
        try:
            query = input("Query: ").strip()
            if not query:
                continue
            if query.lower() in ("quit", "exit", "q"):
                print("Goodbye!")
                break

            # Only the documented "eval [N]" command triggers evaluation;
            # other text that merely starts with "eval" goes to the model.
            parts = query.split()
            if parts[0].lower() == "eval":
                num_tests = 500
                if len(parts) > 1:
                    try:
                        num_tests = int(parts[1])
                    except ValueError:
                        print(f"Invalid test count: {parts[1]!r} (usage: eval [N])\n")
                        continue
                evaluate_accuracy(model, tokenizer, num_tests=num_tests)
                print()
                continue

            # Generate answer
            answer = generate_answer(model, tokenizer, query)
            print(f"Answer: {answer}")

            # Verify when the query has two operands. The first two integers
            # are treated as the factors regardless of the operator typed.
            operands = re.findall(r"\d+", query)
            if len(operands) >= 2:
                a, b = int(operands[0]), int(operands[1])
                expected = a * b
                # Drop thousands separators ("864,192") so the answer is not
                # split into two numbers at the comma; use the last number so
                # an echoed "a * b = c" still yields c.
                normalized = re.sub(r"(?<=\d),(?=\d{3}(?!\d))", "", answer)
                answer_numbers = re.findall(r"\d+", normalized)
                if answer_numbers:
                    predicted = int(answer_numbers[-1])
                    if predicted == expected:
                        print(f" ✓ Correct! ({a} × {b} = {expected})")
                    else:
                        print(f" ✗ Expected: {expected}")
            print()
        except (KeyboardInterrupt, EOFError):
            # EOFError (Ctrl-D / end of piped input) must exit here: input()
            # keeps raising it, so the generic handler below would loop forever.
            print("\nGoodbye!")
            break
        except Exception as e:
            print(f"Error: {e}")


if __name__ == "__main__":
    main()