File size: 2,306 Bytes
beb6a05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from __future__ import annotations

import re
import struct
from dataclasses import dataclass
from pathlib import Path
from typing import Tuple

import numpy as np

# On-disk model layout: a fixed-size little-endian header, then the table payload.
MAGIC = b"ADD2LLM1"  # file signature; exactly 8 bytes to fill the "8s" header field
# Header fields, in order: magic, max_number, rows, cols, table_bytes.
HEADER_STRUCT = "<8sIIII"
# Byte length of the packed header (8-byte signature + four 4-byte unsigned ints).
HEADER_SIZE = struct.Struct(HEADER_STRUCT).size


class Add2ModelError(Exception):
    """Error type for model files, operands, or prompts the addition model rejects.

    Raised in this module for a bad magic header, for operands outside the
    supported range, and for text that is not a recognised addition request.
    """


@dataclass
class TinyAdd2LLM:
    """Two-number addition "model" backed by a lookup table read from disk.

    ``add(a, b)`` performs no arithmetic: it returns the entry stored at
    ``table[a, b]`` (presumably ``a + b``, as written by the model generator).
    """

    # Path the model was loaded from.
    model_path: Path
    # Largest operand the model accepts (inclusive), as recorded in the header.
    max_number: int
    # 2-D array of 16-bit unsigned results, indexed as table[a, b].
    table: np.ndarray

    @classmethod
    def load(cls, model_path: str | Path = "model/add2_model.bin") -> "TinyAdd2LLM":
        """Read a model file and return the corresponding instance.

        The file is a little-endian header (``HEADER_STRUCT``: magic,
        max_number, rows, cols, table_bytes) followed by ``table_bytes`` bytes
        of 16-bit unsigned table entries.

        Args:
            model_path: Location of the model file.

        Returns:
            A ``TinyAdd2LLM`` whose table has shape ``(rows, cols)``.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
            Add2ModelError: If the file is truncated, has a wrong magic value,
                or its header is inconsistent with the table it describes.
        """
        path = Path(model_path)
        if not path.exists():
            raise FileNotFoundError(f"Model file not found: {path}")

        # The header is packed little-endian ("<"), so decode the table with an
        # explicit little-endian dtype too instead of the host's native order.
        entry_dtype = np.dtype("<u2")

        with path.open("rb") as f:
            header = f.read(HEADER_SIZE)
            # A short read would otherwise surface as a raw struct.error.
            if len(header) != HEADER_SIZE:
                raise Add2ModelError("Invalid model file: truncated header")
            magic, max_number, rows, cols, table_bytes = struct.unpack(HEADER_STRUCT, header)
            if magic != MAGIC:
                raise Add2ModelError("Invalid model file: bad magic header")
            # reshape() below only succeeds when the payload is exactly rows*cols
            # entries; report a mismatch as a model error rather than a NumPy one.
            if table_bytes != rows * cols * entry_dtype.itemsize:
                raise Add2ModelError(
                    f"Invalid model file: table size {table_bytes} bytes does not match "
                    f"a {rows}x{cols} table of 16-bit entries"
                )
            raw = f.read(table_bytes)

        if len(raw) != table_bytes:
            raise Add2ModelError("Invalid model file: truncated table data")
        # add() indexes table[a, b] for every 0 <= a, b <= max_number, so both
        # dimensions must have at least max_number + 1 entries.
        if rows <= max_number or cols <= max_number:
            raise Add2ModelError(
                f"Invalid model file: {rows}x{cols} table cannot cover numbers 0 to {max_number}"
            )

        table = np.frombuffer(raw, dtype=entry_dtype).reshape((rows, cols))
        return cls(model_path=path, max_number=max_number, table=table)

    def add(self, a: int, b: int) -> int:
        """Return the table entry for ``a`` and ``b`` as a plain ``int``.

        Raises:
            Add2ModelError: If either operand is outside ``0..max_number``.
        """
        # The lower bound matters: a negative index would silently wrap around
        # to the far end of the table instead of failing.
        if not (0 <= a <= self.max_number and 0 <= b <= self.max_number):
            raise Add2ModelError(
                f"This 10MB model supports only numbers 0 to {self.max_number}. "
                f"Got: {a}, {b}"
            )
        return int(self.table[a, b])

    def answer(self, text: str) -> str:
        """Parse an addition request from ``text`` and return the result as a string.

        Parsing is delegated to ``parse_two_number_addition``; errors raised by
        it or by ``add`` propagate unchanged.
        """
        a, b = parse_two_number_addition(text)
        return str(self.add(a, b))


def parse_two_number_addition(text: str) -> Tuple[int, int]:
    """Extract the two operands from a simple addition request.

    The input is stripped and lower-cased, then tried against each accepted
    phrasing in turn: ``A + B``, ``A plus B``, ``add A and B``,
    ``what is A + B`` and ``what is A plus B``, each optionally ending in ``?``.

    Returns:
        The two operands as integers, in the order they appear.

    Raises:
        Add2ModelError: If no accepted phrasing matches the whole input.
    """
    normalized = text.strip().lower()
    accepted_forms = (
        r"^\s*(\d+)\s*\+\s*(\d+)\s*\??\s*$",
        r"^\s*(\d+)\s+plus\s+(\d+)\s*\??\s*$",
        r"^\s*add\s+(\d+)\s+and\s+(\d+)\s*\??\s*$",
        r"^\s*what\s+is\s+(\d+)\s*\+\s*(\d+)\s*\??\s*$",
        r"^\s*what\s+is\s+(\d+)\s+plus\s+(\d+)\s*\??\s*$",
    )
    # Lazily try each form; the first one that matches wins.
    attempts = (re.match(form, normalized) for form in accepted_forms)
    found = next((hit for hit in attempts if hit), None)
    if found is None:
        raise Add2ModelError(
            "I only understand two-number addition, like: 12 + 30, 12 plus 30, or add 12 and 30."
        )
    # Every form has exactly two capture groups: the left and right operand.
    left, right = found.groups()
    return int(left), int(right)