File size: 2,306 Bytes
beb6a05 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 | from __future__ import annotations
import re
import struct
from dataclasses import dataclass
from pathlib import Path
from typing import Tuple
import numpy as np
# Header layout of an add2 model file. The leading "<" selects little-endian
# byte order with standard sizes and no padding: an 8-byte magic tag followed
# by four unsigned 32-bit integers.
MAGIC = b"ADD2LLM1"  # 8 bytes
HEADER_STRUCT = "<8sIIII"  # magic, max_number, rows, cols, table_bytes
HEADER_SIZE = struct.Struct(HEADER_STRUCT).size
class Add2ModelError(Exception):
    """Error signalling that the tiny addition model cannot handle an input."""
@dataclass
class TinyAdd2LLM:
    """Lookup-table "model" that answers two-number addition questions.

    The answer for operands ``a`` and ``b`` is read from ``table[a, b]``;
    nothing is computed at query time.

    Attributes:
        model_path: Path of the file the table was loaded from.
        max_number: Largest operand value accepted by :meth:`add`.
        table: 2-D array of shape ``(rows, cols)`` as declared in the file
            header, holding unsigned 16-bit cells.
    """

    model_path: Path
    max_number: int
    table: np.ndarray

    @classmethod
    def load(cls, model_path: str | Path = "model/add2_model.bin") -> "TinyAdd2LLM":
        """Read a model file from disk and return the populated model.

        The file starts with a header packed as ``HEADER_STRUCT`` (magic,
        max_number, rows, cols, table_bytes) followed by ``table_bytes``
        bytes of little-endian unsigned 16-bit table cells.

        Args:
            model_path: Location of the model file.

        Returns:
            A ``TinyAdd2LLM`` backed by the table stored in the file.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
            Add2ModelError: If the header is truncated, the magic tag is
                wrong, or the table payload does not match the size that
                the header declares.
        """
        path = Path(model_path)
        if not path.exists():
            raise FileNotFoundError(f"Model file not found: {path}")

        with path.open("rb") as f:
            header = f.read(HEADER_SIZE)
            # A short read would otherwise surface as a raw struct.error.
            if len(header) != HEADER_SIZE:
                raise Add2ModelError("Invalid model file: truncated header")
            magic, max_number, rows, cols, table_bytes = struct.unpack(
                HEADER_STRUCT, header
            )
            if magic != MAGIC:
                raise Add2ModelError("Invalid model file: bad magic header")
            raw = f.read(table_bytes)

        # The header is little-endian, so decode the cells the same way
        # instead of relying on the host's native byte order.
        cell_dtype = np.dtype("<u2")
        expected_bytes = rows * cols * cell_dtype.itemsize
        if table_bytes != expected_bytes or len(raw) != table_bytes:
            raise Add2ModelError(
                "Invalid model file: table size mismatch "
                f"(header declares {table_bytes} bytes for a {rows}x{cols} table, "
                f"expected {expected_bytes}, read {len(raw)})"
            )

        table = np.frombuffer(raw, dtype=cell_dtype).reshape((rows, cols))
        return cls(model_path=path, max_number=max_number, table=table)

    def add(self, a: int, b: int) -> int:
        """Return the table entry for ``a`` and ``b`` as a plain ``int``.

        Raises:
            Add2ModelError: If either operand is outside
                ``0..self.max_number``.
        """
        if not (0 <= a <= self.max_number and 0 <= b <= self.max_number):
            raise Add2ModelError(
                f"This 10MB model supports only numbers 0 to {self.max_number}. "
                f"Got: {a}, {b}"
            )
        # Convert the NumPy scalar so callers get a built-in int.
        return int(self.table[a, b])

    def answer(self, text: str) -> str:
        """Parse an addition question from ``text`` and return the result as text.

        Raises:
            Add2ModelError: If the text is not a recognised two-number
                addition or an operand is out of range.
        """
        a, b = parse_two_number_addition(text)
        return str(self.add(a, b))
def parse_two_number_addition(text: str) -> Tuple[int, int]:
    """Extract the two operands from a simple addition request.

    The input is stripped and lower-cased, then tried against each accepted
    phrasing in order: ``A + B``, ``A plus B``, ``add A and B``,
    ``what is A + B`` and ``what is A plus B``; each may end with an
    optional ``?``.

    Args:
        text: The user's request.

    Returns:
        The two operands as integers, in the order they appear.

    Raises:
        Add2ModelError: If the text matches none of the accepted phrasings.
    """
    normalized = text.strip().lower()
    accepted_forms = (
        r"^\s*(\d+)\s*\+\s*(\d+)\s*\??\s*$",
        r"^\s*(\d+)\s+plus\s+(\d+)\s*\??\s*$",
        r"^\s*add\s+(\d+)\s+and\s+(\d+)\s*\??\s*$",
        r"^\s*what\s+is\s+(\d+)\s*\+\s*(\d+)\s*\??\s*$",
        r"^\s*what\s+is\s+(\d+)\s+plus\s+(\d+)\s*\??\s*$",
    )

    # Lazily try each phrasing and stop at the first one that matches.
    attempts = (re.match(form, normalized) for form in accepted_forms)
    hit = next((m for m in attempts if m is not None), None)
    if hit is None:
        raise Add2ModelError(
            "I only understand two-number addition, like: 12 + 30, 12 plus 30, or add 12 and 30."
        )

    left, right = hit.groups()
    return int(left), int(right)
|