File size: 8,776 Bytes
2253b0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
"""
Interactive test script for the multiplication LoRA model.
Loads the base model with or without the LoRA adapter and allows testing multiplication queries.

Usage:
    python test_multiply.py              # With LoRA adapter (default)
    python test_multiply.py --no-lora    # Without LoRA adapter (base model only)
    python test_multiply.py --base       # Same as --no-lora
"""

import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
sys.path.insert(0, str(Path(__file__).parent))

import argparse
import torch
import random
import re
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

import config


def load_model(use_lora: bool = True):
    """
    Build the tokenizer and causal LM named by ``config.BASE_MODEL``.

    When ``use_lora`` is true, the base model is wrapped with the adapter
    stored at ``config.LORA_PATH``; otherwise the plain base model is
    returned. ``model.eval()`` is called on the result before returning.

    Args:
        use_lora: Whether to attach the LoRA adapter on top of the base model.

    Returns:
        Tuple of (model, tokenizer)
    """
    model_id = config.BASE_MODEL
    print(f"Loading base model: {model_id}")

    # Announce which variant is about to be loaded.
    if not use_lora:
        print("Running WITHOUT LoRA adapter (base model only)")
    else:
        print(f"Loading LoRA adapter from: {config.LORA_PATH}")

    # Tokenizer: reuse the EOS token for padding when none is defined.
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Base weights: bfloat16 when CUDA is available, float32 otherwise.
    if torch.cuda.is_available():
        weights_dtype = torch.bfloat16
    else:
        weights_dtype = torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        dtype=weights_dtype,
        device_map="auto",
        trust_remote_code=True,
    )

    # Optional adapter on top of the base weights.
    if use_lora:
        model = PeftModel.from_pretrained(model, str(config.LORA_PATH))

    model.eval()

    print("Model loaded successfully!")
    return model, tokenizer


def generate_answer(model, tokenizer, query: str) -> str:
    """
    Ask the model for a reply to ``query`` and return it as stripped text.

    The query is sent as the user turn of a two-message chat whose system
    turn is ``config.SYSTEM_PROMPT``. Decoding is greedy and limited to 32
    new tokens; only the newly generated tokens are decoded.
    """
    chat = [
        {"role": "system", "content": config.SYSTEM_PROMPT},
        {"role": "user", "content": query},
    ]

    # Render the chat into the prompt string, ending with the generation prompt.
    prompt_text = tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
    )
    encoded = tokenizer(prompt_text, return_tensors="pt").to(model.device)

    generation_kwargs = {
        "max_new_tokens": 32,
        "do_sample": False,  # Greedy decoding for deterministic results
        "pad_token_id": tokenizer.pad_token_id,
    }
    with torch.no_grad():
        sequences = model.generate(**encoded, **generation_kwargs)

    # The output sequence starts with the prompt tokens; drop them.
    prompt_length = encoded["input_ids"].shape[1]
    completion_ids = sequences[0][prompt_length:]
    completion = tokenizer.decode(completion_ids, skip_special_tokens=True)
    return completion.strip()


def evaluate_accuracy(model, tokenizer, num_tests: int = 500, seed: int = 42):
    """
    Evaluate the model's accuracy on random multiplication problems.

    Each problem multiplies a random 6-digit integer by 7. The query is sent
    through ``generate_answer`` and the first run of digits in the reply is
    taken as the model's answer. An answer that is wrong but has the same
    number of digits as the true product and differs in exactly one position
    is counted as "off by one digit".

    Progress and a summary are printed to stdout.

    Args:
        model: The loaded model (with or without the LoRA adapter)
        tokenizer: The tokenizer
        num_tests: Number of random test cases to evaluate. Values of zero
            or less run no test cases and report rates of 0.0.
        seed: Random seed for reproducibility

    Returns:
        Dictionary with accuracy metrics: ``accuracy`` (fraction of exact
        matches), ``correct``, ``total``, ``off_by_one_digit`` and
        ``results`` (one dict per test case with ``query``, ``expected``,
        ``predicted`` and ``correct``).
    """
    # A non-positive count means there is nothing to evaluate. Clamping keeps
    # the rate computations below from dividing by zero.
    num_tests = max(num_tests, 0)

    # Private generator: yields the same sequence as seeding the module-level
    # functions with `seed`, without reseeding the process-wide RNG.
    rng = random.Random(seed)
    digit_run = re.compile(r"\d+")

    correct = 0
    off_by_one_digit = 0
    results = []

    print(f"\nEvaluating accuracy on {num_tests} random multiplication problems...")
    print("-" * 50)

    for i in range(num_tests):
        # Generate random 6-digit number
        a = rng.randint(100000, 999999)
        b = 7
        expected = a * b

        # Generate prediction
        query = f"{a} * {b}"
        predicted_str = generate_answer(model, tokenizer, query)

        # Extract numeric answer from prediction (first run of digits, if any)
        first_number = digit_run.search(predicted_str)
        predicted = int(first_number.group()) if first_number else None

        # Check correctness
        is_correct = predicted == expected
        print(predicted, expected, is_correct)
        if is_correct:
            correct += 1
        elif predicted is not None:
            # Check if off by one digit (common error pattern)
            expected_digits = str(expected)
            predicted_digits = str(predicted)
            if len(expected_digits) == len(predicted_digits):
                diff_count = sum(
                    x != y for x, y in zip(expected_digits, predicted_digits)
                )
                if diff_count == 1:
                    off_by_one_digit += 1

        results.append(
            {
                "query": query,
                "expected": expected,
                "predicted": predicted,
                "correct": is_correct,
            }
        )

        # Progress indicator
        if (i + 1) % 100 == 0:
            print(f"  Progress: {i + 1}/{num_tests} ({correct}/{i + 1} correct so far)")

    # With zero test cases both rates are reported as 0.0.
    accuracy = correct / num_tests if num_tests else 0.0
    off_by_one_rate = off_by_one_digit / num_tests if num_tests else 0.0

    print("-" * 50)
    print("\nResults:")
    print(f"  Exact match accuracy: {accuracy * 100:.2f}% ({correct}/{num_tests})")
    print(
        f"  Off by one digit:     {off_by_one_rate * 100:.2f}% ({off_by_one_digit}/{num_tests})"
    )
    print(f"  Total near-correct:   {(accuracy + off_by_one_rate) * 100:.2f}%")

    # Show some examples of errors
    errors = [r for r in results if not r["correct"]]
    if errors:
        print("\nSample errors (showing up to 5):")
        for err in errors[:5]:
            print(
                f"  {err['query']} = {err['expected']} (predicted: {err['predicted']})"
            )

    return {
        "accuracy": accuracy,
        "correct": correct,
        "total": num_tests,
        "off_by_one_digit": off_by_one_digit,
        "results": results,
    }


def parse_args():
    """
    Parse the command line and return the resulting namespace.

    The only option is the flag ``--no-lora`` (alias ``--base``), stored as
    ``no_lora``; it is False unless one of the two spellings is given.
    """
    cli = argparse.ArgumentParser(
        description="Test the multiplication model with or without LoRA adapter",
    )

    # Both spellings write to the same destination attribute.
    flag_names = ("--no-lora", "--base")
    cli.add_argument(
        *flag_names,
        dest="no_lora",
        action="store_true",
        help="Run without LoRA adapter (base model only)",
    )

    namespace = cli.parse_args()
    return namespace


def _parse_eval_count(query: str, default: int = 500):
    """
    Interpret ``query`` as an ``eval [N]`` command.

    Returns:
        ``None`` when the first word of ``query`` is not ``eval``
        (case-insensitive); otherwise the requested number of tests, or
        ``default`` when no count is given.

    Raises:
        ValueError: If a count is given but is not a positive integer.
    """
    words = query.split()
    # Match the command word exactly so that a query that merely begins with
    # "eval" (e.g. "evaluate 5 * 7") is still sent to the model.
    if not words or words[0].lower() != "eval":
        return None
    if len(words) == 1:
        return default
    count = int(words[1])
    if count <= 0:
        raise ValueError(f"test count must be positive, got {count}")
    return count


def _verification_line(query: str, answer: str):
    """
    Build the line that reports whether ``answer`` matches the true product.

    The first two runs of digits in ``query`` are taken as the operands and
    the last run of digits in ``answer`` as the model's result.

    Returns:
        The line to print, or ``None`` when the query has fewer than two
        numbers or the answer contains no number.
    """
    operands = re.findall(r"\d+", query)
    answer_numbers = re.findall(r"\d+", answer)
    if len(operands) < 2 or not answer_numbers:
        return None

    a, b = int(operands[0]), int(operands[1])
    expected = a * b
    predicted = int(answer_numbers[-1])
    if predicted == expected:
        return f"  ✓ Correct! ({a} × {b} = {expected})"
    return f"  ✗ Expected: {expected}"


def main():
    """
    Run the interactive tester.

    Parses the command line, loads the model (with the LoRA adapter unless
    ``--no-lora``/``--base`` was given) and then reads queries from stdin in
    a loop:

    * ``quit``, ``exit`` or ``q`` ends the session;
    * ``eval`` or ``eval N`` runs ``evaluate_accuracy`` on 500 or N problems;
    * anything else is sent to the model, and when the query contains two
      numbers the answer is checked against their product.

    Ctrl-C or end of input (Ctrl-D, exhausted piped stdin) also ends the
    session. Any other error is reported and the loop continues.
    """
    args = parse_args()
    use_lora = not args.no_lora

    print("=" * 60)
    if use_lora:
        print("Multiplication LoRA Model Tester")
    else:
        print("Multiplication Base Model Tester (NO LoRA)")
    print("=" * 60)

    # Check CUDA
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")

    # Load model
    print("\nLoading model...")
    model, tokenizer = load_model(use_lora=use_lora)

    print("\n" + "=" * 60)
    print("Commands:")
    print("  - Enter multiplication queries like '134345 * 7' or '923256 * 7'")
    print("  - Type 'eval' or 'eval 1000' to run accuracy evaluation")
    print("  - Type 'quit' or 'exit' to stop")
    print("=" * 60 + "\n")

    while True:
        try:
            query = input("Query: ").strip()

            if not query:
                continue

            if query.lower() in ("quit", "exit", "q"):
                print("Goodbye!")
                break

            # Handle eval command
            try:
                num_tests = _parse_eval_count(query)
            except ValueError:
                print(
                    "Error: 'eval' expects a positive integer test count, "
                    "e.g. 'eval 1000'"
                )
                print()
                continue
            if num_tests is not None:
                evaluate_accuracy(model, tokenizer, num_tests=num_tests)
                print()
                continue

            # Generate answer
            answer = generate_answer(model, tokenizer, query)
            print(f"Answer: {answer}")

            # Verification is best-effort: a failure here must not hide the
            # answer already printed, but it is reported rather than ignored.
            try:
                verdict = _verification_line(query, answer)
                if verdict is not None:
                    print(verdict)
            except (ValueError, UnicodeError) as exc:
                print(f"  (could not verify the answer: {exc})")

            print()

        except (KeyboardInterrupt, EOFError):
            # EOFError must end the loop: once stdin is exhausted every
            # further input() call raises it again immediately.
            print("\nGoodbye!")
            break
        except Exception as e:
            # Top-level boundary of the interactive loop: report and go on.
            print(f"Error: {e}")

# Start the interactive tester only when this file is executed as a script.
if __name__ == "__main__":
    main()