Spaces:
Build error
Build error
| from __future__ import annotations | |
| import re | |
| import struct | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Tuple | |
| import numpy as np | |
# On-disk layout of a model file: a fixed-size header followed by the table.
#
# Header fields, in order (all little-endian, no padding):
#   magic        8-byte file signature
#   max_number   uint32
#   rows         uint32
#   cols         uint32
#   table_bytes  uint32, size of the table payload that follows the header
MAGIC = b"ADD2LLM1"  # exactly 8 bytes, matches the "8s" field below
HEADER_STRUCT = "<8sIIII"
HEADER_SIZE = struct.Struct(HEADER_STRUCT).size
class Add2ModelError(Exception):
    """Signal that the tiny addition model cannot process the given input."""
| class TinyAdd2LLM: | |
| model_path: Path | |
| max_number: int | |
| table: np.ndarray | |
| def load(cls, model_path: str | Path = "model/add2_model.bin") -> "TinyAdd2LLM": | |
| path = Path(model_path) | |
| if not path.exists(): | |
| raise FileNotFoundError(f"Model file not found: {path}") | |
| with path.open("rb") as f: | |
| header = f.read(HEADER_SIZE) | |
| magic, max_number, rows, cols, table_bytes = struct.unpack(HEADER_STRUCT, header) | |
| if magic != MAGIC: | |
| raise Add2ModelError("Invalid model file: bad magic header") | |
| raw = f.read(table_bytes) | |
| table = np.frombuffer(raw, dtype=np.uint16).reshape((rows, cols)) | |
| return cls(model_path=path, max_number=max_number, table=table) | |
| def add(self, a: int, b: int) -> int: | |
| if not (0 <= a <= self.max_number and 0 <= b <= self.max_number): | |
| raise Add2ModelError( | |
| f"This 10MB model supports only numbers 0 to {self.max_number}. " | |
| f"Got: {a}, {b}" | |
| ) | |
| return int(self.table[a, b]) | |
| def answer(self, text: str) -> str: | |
| a, b = parse_two_number_addition(text) | |
| return str(self.add(a, b)) | |
def parse_two_number_addition(text: str) -> Tuple[int, int]:
    """Extract the two operands from a simple addition request.

    The input is stripped and lower-cased, then tried against each accepted
    phrasing in order: ``A + B``, ``A plus B``, ``add A and B``,
    ``what is A + B`` and ``what is A plus B``, each with an optional
    trailing ``?``.

    Args:
        text: The user's request.

    Returns:
        The two operands as integers, in the order they appear.

    Raises:
        Add2ModelError: If the text matches none of the accepted phrasings.
    """
    normalized = text.strip().lower()
    accepted_forms = (
        r"^\s*(\d+)\s*\+\s*(\d+)\s*\??\s*$",
        r"^\s*(\d+)\s+plus\s+(\d+)\s*\??\s*$",
        r"^\s*add\s+(\d+)\s+and\s+(\d+)\s*\??\s*$",
        r"^\s*what\s+is\s+(\d+)\s*\+\s*(\d+)\s*\??\s*$",
        r"^\s*what\s+is\s+(\d+)\s+plus\s+(\d+)\s*\??\s*$",
    )

    # Lazy generator: evaluation stops at the first phrasing that matches.
    attempts = (re.match(form, normalized) for form in accepted_forms)
    found = next((hit for hit in attempts if hit), None)
    if found is None:
        raise Add2ModelError(
            "I only understand two-number addition, like: 12 + 30, 12 plus 30, or add 12 and 30."
        )

    left, right = found.groups()
    return int(left), int(right)