| from __future__ import annotations |
|
|
| import re |
| import struct |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Tuple |
|
|
| import numpy as np |
|
|
# Signature bytes every model file must start with.
MAGIC = b"ADD2LLM1"
# Header layout (little-endian, standard sizes): an 8-byte magic followed by
# four unsigned 32-bit ints, unpacked by the loader as
# max_number, rows, cols, table_bytes.
HEADER_STRUCT = "<8sIIII"
# Byte length of the header described by HEADER_STRUCT.
HEADER_SIZE = struct.Struct(HEADER_STRUCT).size
|
|
|
|
class Add2ModelError(Exception):
    """Signals something the tiny addition model cannot handle.

    Raised for questions that are not a recognised two-number addition, for
    operands outside the model's supported range, and for model files that
    fail validation (e.g. a bad magic header).
    """
|
|
|
|
@dataclass
class TinyAdd2LLM:
    """Lookup-table "model" that answers two-number addition questions.

    ``add(a, b)`` range-checks both operands against ``max_number`` and returns
    the table entry ``table[a, b]``; ``answer`` parses free text first.
    """

    model_path: Path  # file the model was loaded from
    max_number: int  # largest operand accepted by ``add`` (inclusive)
    table: np.ndarray  # 2-D uint16 table indexed as ``table[a, b]``

    @classmethod
    def load(cls, model_path: str | Path = "model/add2_model.bin") -> "TinyAdd2LLM":
        """Read and validate a model file, returning a ready-to-use model.

        The file starts with a header laid out as ``HEADER_STRUCT`` (magic,
        max_number, rows, cols, table_bytes), followed by ``table_bytes`` bytes
        of little-endian uint16 values forming a ``rows x cols`` table.

        Args:
            model_path: Path to the model file. A relative path is resolved
                against the current working directory.

        Returns:
            A ``TinyAdd2LLM`` wrapping the decoded table.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
            Add2ModelError: If the file is truncated, has a bad magic header,
                or declares sizes that are inconsistent with each other.
        """
        path = Path(model_path)
        if not path.exists():
            raise FileNotFoundError(f"Model file not found: {path}")

        with path.open("rb") as f:
            header = f.read(HEADER_SIZE)
            # struct.unpack would raise a bare struct.error on a short read;
            # report it as a model-file problem instead.
            if len(header) != HEADER_SIZE:
                raise Add2ModelError("Invalid model file: truncated header")
            magic, max_number, rows, cols, table_bytes = struct.unpack(HEADER_STRUCT, header)
            if magic != MAGIC:
                raise Add2ModelError("Invalid model file: bad magic header")
            raw = f.read(table_bytes)

        if len(raw) != table_bytes:
            raise Add2ModelError(
                f"Invalid model file: expected {table_bytes} table bytes, got {len(raw)}"
            )
        # Every operand in 0..max_number must be a valid row and column index;
        # otherwise ``add`` would hit IndexError for inputs it claims to accept.
        if rows <= max_number or cols <= max_number:
            raise Add2ModelError(
                f"Invalid model file: {rows}x{cols} table is too small "
                f"for numbers 0 to {max_number}"
            )

        table = cls._decode_table(raw, rows, cols)
        return cls(model_path=path, max_number=max_number, table=table)

    @staticmethod
    def _decode_table(raw: bytes, rows: int, cols: int) -> np.ndarray:
        """Interpret ``raw`` as a ``rows x cols`` table of little-endian uint16.

        The dtype is pinned to little-endian to match the ``<`` header format;
        plain ``np.uint16`` is native-endian and would mis-decode on
        big-endian hosts.

        Raises:
            Add2ModelError: If ``raw`` is not exactly ``rows * cols`` uint16s.
        """
        dtype = np.dtype("<u2")
        expected = rows * cols * dtype.itemsize
        if len(raw) != expected:
            raise Add2ModelError(
                f"Invalid model file: table has {len(raw)} bytes, "
                f"expected {expected} for a {rows}x{cols} uint16 table"
            )
        return np.frombuffer(raw, dtype=dtype).reshape((rows, cols))

    def add(self, a: int, b: int) -> int:
        """Return the model's answer for ``a + b`` (the entry ``table[a, b]``).

        Raises:
            Add2ModelError: If either operand is outside ``0..max_number``.
        """
        if not (0 <= a <= self.max_number and 0 <= b <= self.max_number):
            raise Add2ModelError(
                f"This 10MB model supports only numbers 0 to {self.max_number}. "
                f"Got: {a}, {b}"
            )
        # int() converts the numpy scalar to a plain Python int for callers.
        return int(self.table[a, b])

    def answer(self, text: str) -> str:
        """Parse an addition question from ``text`` and return the result as a string.

        Raises:
            Add2ModelError: If ``text`` is not a recognised two-number addition
                or an operand is out of the supported range.
        """
        a, b = parse_two_number_addition(text)
        return str(self.add(a, b))
|
|
|
|
# Accepted phrasings; each captures the two operands as groups 1 and 2.
# Compiled once at import rather than looked up on every call.
_ADDITION_PATTERNS = tuple(
    re.compile(pattern)
    for pattern in (
        r"^\s*(\d+)\s*\+\s*(\d+)\s*\??\s*$",
        r"^\s*(\d+)\s+plus\s+(\d+)\s*\??\s*$",
        r"^\s*add\s+(\d+)\s+and\s+(\d+)\s*\??\s*$",
        r"^\s*what\s+is\s+(\d+)\s*\+\s*(\d+)\s*\??\s*$",
        r"^\s*what\s+is\s+(\d+)\s+plus\s+(\d+)\s*\??\s*$",
    )
)


def parse_two_number_addition(text: str) -> Tuple[int, int]:
    """Extract the two operands from a simple addition question.

    Surrounding whitespace is ignored, matching is case-insensitive (the text
    is lower-cased first) and one trailing ``?`` is allowed. Accepted forms:
    ``"12 + 30"``, ``"12 plus 30"``, ``"add 12 and 30"``,
    ``"what is 12 + 30"`` and ``"what is 12 plus 30"``.

    Returns:
        The two operands as non-negative ints, in the order they appear.

    Raises:
        Add2ModelError: If the text matches none of the accepted forms, or an
            operand has too many digits for ``int()`` to convert.
    """
    normalized = text.strip().lower()
    for pattern in _ADDITION_PATTERNS:
        match = pattern.match(normalized)
        if match is None:
            continue
        try:
            return int(match.group(1)), int(match.group(2))
        except ValueError as err:
            # CPython caps str->int conversion length (4300 digits by default);
            # surface that as the module's error type so callers catching
            # Add2ModelError don't see a stray ValueError.
            raise Add2ModelError("Numbers with that many digits are not supported.") from err
    raise Add2ModelError(
        "I only understand two-number addition, like: 12 + 30, 12 plus 30, or add 12 and 30."
    )
|
|