from __future__ import annotations
'''PackedLLM, By: Chance Brownfield-|-HiMindAi@proton.me'''
import ast
import argparse
import base64
import contextlib
import dataclasses
import hashlib
import importlib
import importlib.util
import io
import json
import lzma
import math
import os
import platform
import queue
import re
import shutil
import subprocess
import sys
import tempfile
import textwrap
import threading
import time
import traceback
import types
import uuid
import zipfile
from collections import OrderedDict, defaultdict, deque
from dataclasses import dataclass, field, asdict
from datetime import datetime
from pathlib import Path
from typing import (
Any, Callable, Dict, Iterable, List, Mapping,
MutableMapping, Optional, Sequence, Tuple, Union
)
from urllib.parse import urlparse, parse_qs, quote, unquote
from multiprocessing import Process, Queue, get_context
import concurrent.futures
import numpy as np
import psutil
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
try:
from transformers import (
MarianMTModel,
MarianTokenizer,
AutoModelForSeq2SeqLM,
AutoTokenizer,
)
_HAS_TRANSFORMERS = True
except ImportError:
_HAS_TRANSFORMERS = False
MarianMTModel = MarianTokenizer = None
AutoModelForSeq2SeqLM = AutoTokenizer = None
try:
from sentence_transformers import SentenceTransformer, util
_HAS_SENTENCE_TRANSFORMERS = True
except ImportError:
_HAS_SENTENCE_TRANSFORMERS = False
try:
import fitz # PyMuPDF
_HAS_PYMUPDF = True
except ImportError:
_HAS_PYMUPDF = False
try:
from sklearn.cluster import KMeans, AgglomerativeClustering
from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import cosine_similarity
_HAS_SKLEARN = True
except ImportError:
_HAS_SKLEARN = False
KMeans = AgglomerativeClustering = PCA = cosine_similarity = None
try:
import spacy
_HAS_SPACY = True
except ImportError as e:
raise RuntimeError(f"spaCy is required: {e}")
try:
from bs4 import BeautifulSoup
except ImportError:
BeautifulSoup = None
try:
import trafilatura
except ImportError:
trafilatura = None
try:
from readability import Document
except ImportError:
Document = None
try:
from newspaper import Article
except ImportError:
Article = None
try:
from goose3 import Goose
except ImportError:
Goose = None
try:
from boilerpy3 import extractors
except ImportError:
extractors = None
try:
from inscriptis import get_text as inscriptis_text
except ImportError:
inscriptis_text = None
try:
from lxml import html as lxml_html
except ImportError:
lxml_html = None
try:
from youtube_transcript_api import YouTubeTranscriptApi
except ImportError:
YouTubeTranscriptApi = None
try:
from llama_cpp import Llama
_HAS_LLAMA_CPP = True
except ImportError:
_HAS_LLAMA_CPP = False
Llama = None
try:
import wgpu
_HAS_WGPU = True
except ImportError:
_HAS_WGPU = False
try:
from huggingface_hub import snapshot_download
_HAS_HF_HUB = True
except ImportError:
_HAS_HF_HUB = False
try:
from safetensors.torch import load_file as safetensors_load
except ImportError:
safetensors_load = None
# Root directory for locally stored models; the paths below are derived from it.
MODEL_DIR = Path("models")
EMBEDDING_PATH = MODEL_DIR / "all-MiniLM-L6-v2"
SUMMARIZER_PATH = MODEL_DIR / "distilbart-cnn-12-6"
SPACY_MODEL_PATH = MODEL_DIR / "spacy" / "en_core_web_sm"
SPACY_MODEL_NAME = "en_core_web_sm"
# Default file names (relative paths) for checkpoints, bundles and test input.
DEFAULT_CHECKPOINT_PATH = "LM.pt"
DEFAULT_BUNDLE_PATH = "PackedLM.pt"
DEFAULT_IMAGE_TEST_SOURCE = "sample_img.png"
DEFAULT_ZH_EN_DIR = MODEL_DIR / "opus-mt-zh-en"
# 32 MiB: chunk size used by the chunked byte/file helpers in this module.
_CHUNK_BYTES = 32 * 1024 * 1024
# Captures the body of a fenced block, optionally tagged "python".
_CODE_FENCE_RE = re.compile(r"```(?:python)?\s*\n(.*?)```", re.DOTALL | re.IGNORECASE)
# Matches a single character in the U+4E00-U+9FFF range.
CHINESE_RE = re.compile(r"[\u4e00-\u9fff]")
# Matches runs of U+4E00-U+9FFF characters, allowing punctuation/whitespace
# between runs as long as another such run follows.
CHINESE_SPAN_RE = re.compile(
    r"[\u4e00-\u9fff]+(?:[\u3000-\u303F\uFF00-\uFFEF\u2000-\u206F"
    r"\u2E00-\u2E7F\u3000-\u303F\uFF00-\uFFEF\s,.;:!?\-—()\[\]{},。!?、;:]+"
    r"[\u4e00-\u9fff]+)*"
)
GGUF_EMBED_FILENAME = "jina-embeddings-v3-Q8_0.gguf"
# Import-time side effect: make sure the model directories exist on disk.
for p in [MODEL_DIR, EMBEDDING_PATH, SUMMARIZER_PATH, SPACY_MODEL_PATH.parent]:
    p.mkdir(parents=True, exist_ok=True)
def _extract_python_code(text: str) -> str:
    """Pull runnable Python out of a free-form reply.

    Fenced code blocks are preferred when present; otherwise the whole
    stripped text is the only candidate. The first candidate that
    ``_strip_to_valid_python`` can reduce to non-empty code wins.

    Args:
        text: Raw text that may contain fenced or bare Python.

    Returns:
        The cleaned code, the stripped input when nothing could be cleaned,
        or ``""`` when *text* is not a string.
    """
    if not isinstance(text, str):
        return ""
    stripped = text.strip()
    fenced_bodies = _CODE_FENCE_RE.findall(stripped)
    if fenced_bodies:
        pool = [body.strip() for body in fenced_bodies]
    else:
        pool = [stripped]
    cleaned_pool = (_strip_to_valid_python(item) for item in pool)
    return next((code for code in cleaned_pool if code), stripped)
def _strip_to_valid_python(code: str) -> str:
try:
ast.parse(code)
return code
except SyntaxError:
pass
lines = code.splitlines()
for start in range(1, min(len(lines), 10) + 1):
candidate = "\n".join(lines[start:])
try:
ast.parse(candidate)
return candidate
except SyntaxError:
continue
for end in range(len(lines) - 1, max(len(lines) - 10, 0), -1):
candidate = "\n".join(lines[:end])
try:
ast.parse(candidate)
return candidate
except SyntaxError:
continue
return ""
def _json_dumps(obj: Any) -> str:
return json.dumps(obj, ensure_ascii=False, default=str)
def _parse_json_safe(text: Any) -> Any:
if not isinstance(text, str):
return None
cleaned = text.strip()
if cleaned.startswith("```"):
lines = cleaned.splitlines()
if lines and lines[-1].strip() == "```":
cleaned = "\n".join(lines[1:-1])
else:
cleaned = "\n".join(lines[1:])
try:
return json.loads(cleaned)
except Exception:
for start_char, end_char in [("{", "}"), ("[", "]")]:
si = cleaned.find(start_char)
ei = cleaned.rfind(end_char)
if si != -1 and ei != -1 and ei > si:
try:
return json.loads(cleaned[si:ei + 1])
except Exception:
pass
return None
def _safe_import_class(name: str) -> Optional[type]:
try:
frame = inspect.stack()[2].frame
cls = frame.f_globals.get(name) or builtins.__dict__.get(name)
return cls if isinstance(cls, type) else None
except Exception:
return None
def _safe_call(obj: Any, name: str, *args: Any, default: Any = None, **kwargs: Any) -> Any:
fn = getattr(obj, name, None)
if not callable(fn):
return default
try:
return fn(*args, **kwargs)
except Exception:
return default
def _bytes_to_chunks(data: bytes, chunk_size: int = _CHUNK_BYTES) -> List[bytes]:
    """Split *data* into consecutive slices of at most *chunk_size* bytes.

    Empty input yields a single empty chunk rather than an empty list.
    """
    span = max(len(data), 1)
    pieces: List[bytes] = []
    for offset in range(0, span, chunk_size):
        pieces.append(data[offset: offset + chunk_size])
    return pieces
def _chunks_to_bytes(chunks: List[bytes]) -> bytes:
return b"".join(chunks)
def _read_file_chunked(path: Optional[Union[str, Path]]) -> Optional[List[bytes]]:
if not path:
return None
p = Path(path)
if not p.exists():
return None
chunks: List[bytes] = []
try:
with open(p, "rb") as fh:
while True:
chunk = fh.read(_CHUNK_BYTES)
if not chunk:
break
chunks.append(chunk)
except Exception:
return None
return chunks if chunks else [b""]
def _write_chunks_to_temp(chunks: Optional[List[bytes]], suffix: str, prefix: str = "packedllm_") -> Optional[str]:
if not chunks:
return None
fd, path = tempfile.mkstemp(prefix=prefix, suffix=suffix)
os.close(fd)
try:
with open(path, "wb") as fh:
for chunk in chunks:
fh.write(chunk)
except Exception:
try:
os.unlink(path)
except Exception:
pass
return None
return path
def _normalise_expert_name(name: str) -> str:
s = re.sub(r"(?<=[a-z0-9])([A-Z])", r"_\1", name)
return s.lower()
def _expert_names_canonical(names: List[str]) -> List[str]:
    """Normalise every name and drop duplicates, keeping first-seen order."""
    normalised = (_normalise_expert_name(raw) for raw in names)
    # dict keys preserve insertion order, which gives an ordered de-duplication.
    return list(dict.fromkeys(normalised))
def capture_telemetry() -> Dict[str, Any]:
    """Take a snapshot of CPU, RAM and (when available) NVIDIA GPU usage.

    CPU and RAM figures come from ``psutil``. GPU figures come from running
    ``nvidia-smi``; when that fails for any reason the GPU section keeps its
    "no driver detected" defaults.

    Returns:
        A dict with the keys ``timestamp_ns``, ``cpu``, ``ram`` and
        ``gpu_hardware_metrics``. Memory values are expressed in GiB
        (``nvidia-smi`` values are divided by 1024, psutil byte counts by
        1024**3).
    """
    gib = 1024 ** 3
    process = psutil.Process(os.getpid())
    cpu_total_pct = psutil.cpu_percent(interval=None)
    cpu_process_pct = process.cpu_percent(interval=None)
    ram_info = psutil.virtual_memory()
    metrics: Dict[str, Any] = {
        "timestamp_ns": time.perf_counter_ns(),
        "cpu": {
            "system_total_percent": cpu_total_pct,
            "process_percent": cpu_process_pct,
        },
        "ram": {
            "system_total_gb": ram_info.total / gib,
            "system_available_gb": ram_info.available / gib,
            "system_used_gb": ram_info.used / gib,
            "process_rss_gb": process.memory_info().rss / gib,
        },
        "gpu_hardware_metrics": {
            "driver_detected": False,
            "device_name": "None",
            "total_vram_gb": 0.0,
            "used_vram_gb": 0.0,
            "free_vram_gb": 0.0,
            "gpu_utilization_percent": 0.0,
        },
    }
    try:
        cmd = [
            "nvidia-smi",
            "--query-gpu=name,memory.total,memory.free,memory.used,utilization.gpu",
            "--format=csv,noheader,nounits",
        ]
        raw = subprocess.check_output(cmd, stderr=subprocess.DEVNULL)
        # Decode leniently so an unusual device name cannot abort the parse.
        rows = [row for row in raw.decode("utf-8", errors="replace").splitlines() if row.strip()]
        if rows:
            # One CSV row is printed per GPU. Only the first row is reported;
            # splitting the whole output on commas would merge fields of
            # neighbouring rows and make the float() conversions fail.
            parts = [p.strip() for p in rows[0].split(",")]
            metrics["gpu_hardware_metrics"] = {
                "driver_detected": True,
                "device_name": parts[0],
                "total_vram_gb": float(parts[1]) / 1024.0,
                "free_vram_gb": float(parts[2]) / 1024.0,
                "used_vram_gb": float(parts[3]) / 1024.0,
                "gpu_utilization_percent": float(parts[4]),
            }
    except Exception:
        # Deliberately best-effort: a missing or failing nvidia-smi simply
        # leaves the "no driver detected" defaults in place.
        pass
    return metrics
def calculate_delta(start: Dict[str, Any], end: Dict[str, Any]) -> Dict[str, Any]:
    """Compare two telemetry snapshots.

    Args:
        start: Snapshot taken before the measured work.
        end: Snapshot taken after the measured work.

    Returns:
        A dict with the process RSS change, the VRAM change, the end
        snapshot's GPU utilisation and the system CPU percentage change. GPU
        values are ``0.0`` when the end snapshot has no detected driver.
    """
    gpu_before = start["gpu_hardware_metrics"]
    gpu_after = end["gpu_hardware_metrics"]
    if gpu_after["driver_detected"]:
        vram_change = gpu_after["used_vram_gb"] - gpu_before["used_vram_gb"]
        utilisation = gpu_after["gpu_utilization_percent"]
    else:
        vram_change = 0.0
        utilisation = 0.0
    rss_change = end["ram"]["process_rss_gb"] - start["ram"]["process_rss_gb"]
    cpu_change = end["cpu"]["system_total_percent"] - start["cpu"]["system_total_percent"]
    return {
        "ram_process_delta_gb": rss_change,
        "vram_allocated_delta_gb": vram_change,
        "gpu_instantaneous_utilization_pct": utilisation,
        "cpu_system_delta_pct": cpu_change,
    }
def normalize_unicode(s: str) -> str:
    """Apply NFKC normalisation and strip invisible / control characters.

    Removed after normalisation: zero-width and directional marks
    (U+200B-U+200F), the byte-order mark (U+FEFF), the soft hyphen (U+00AD)
    and all ASCII control characters, which includes tabs and newlines.

    Args:
        s: Text to clean.

    Returns:
        The normalised text.
    """
    # Imported locally so the function does not depend on a module-level
    # import of unicodedata being present.
    import unicodedata

    s = unicodedata.normalize("NFKC", s)
    s = re.sub(r"[\u200B-\u200F\uFEFF\u00AD]", "", s)
    s = re.sub(r"[\x00-\x1F\x7F]", "", s)
    return s
def canonicalize_numbers(s: str) -> str:
    """Replace every integer or decimal literal in *s* with the letter ``N``."""
    number_pattern = r"\d+\.\d+|\d+"
    return re.sub(number_pattern, "N", s)
def strip_latex_wrappers(s: str) -> str:
    r"""Remove LaTeX math delimiters: ``\[ \]``, ``\( \)`` and runs of ``$``."""
    for delimiter in ("\\[", "\\]", "\\(", "\\)"):
        s = s.replace(delimiter, "")
    return re.sub(r"\$+", "", s)
def semantic_key(line: str) -> str:
    """Reduce a line to a canonical key used to detect near-duplicate lines.

    The line is trimmed and lowercased, Unicode-normalised, stripped of LaTeX
    delimiters, has ``\\frac{a}{b}`` and ``^{x}`` rewritten to plain forms,
    has every number replaced by ``N`` and has whitespace collapsed.
    """
    key = strip_latex_wrappers(normalize_unicode(line.strip().lower()))
    rewrites = (
        (r"\\frac\{([^}]*)}\{([^}]*)}", r"frac(\1,\2)"),
        (r"\^{\s*([^}]*)\s*}", r"^(\1)"),
    )
    for pattern, replacement in rewrites:
        key = re.sub(pattern, replacement, key)
    key = canonicalize_numbers(key)
    return re.sub(r"\s+", " ", key).strip()
def collapse_repeated_blocks_with_report(text: str, block_size: int = 2) -> Tuple[str, List[str]]:
    """Drop blocks of consecutive lines whose semantic keys were seen before.

    Blank lines are discarded up front. The remaining lines are walked in
    steps of *block_size*; each block is identified by the tuple of its
    lines' ``semantic_key`` values. The first occurrence of a block is kept,
    later occurrences are removed as a whole. A trailing group shorter than
    *block_size* is always kept.

    Args:
        text: Text to de-duplicate.
        block_size: Number of lines per block; values below 1 are treated as 1.

    Returns:
        ``(collapsed_text, removed_lines)``.
    """
    # Guard against zero/negative sizes, which would never advance the cursor.
    block_size = max(1, int(block_size))
    lines = [ln for ln in text.splitlines() if ln.strip()]
    kept: List[str] = []
    removed: List[str] = []
    seen = set()
    i = 0
    while i < len(lines):
        block = lines[i:i + block_size]
        if len(block) < block_size:
            kept.extend(block)
            break
        keys = tuple(semantic_key(ln) for ln in block)
        if keys in seen:
            removed.extend(block)
        else:
            seen.add(keys)
            kept.extend(block)
        # Advance past the whole block in both cases. Stepping by a single
        # line after a duplicate would re-emit the tail of a block that was
        # just reported as removed.
        i += block_size
    return "\n".join(kept), removed
def collapse_repeated_semantic_lines_with_report(text: str, max_repeat: int = 1) -> Tuple[str, List[str]]:
    """Limit runs of consecutive lines that share the same semantic key.

    Within a run, the first line plus up to *max_repeat* repeats are kept;
    further repeats are removed. A blank line is always kept and ends the
    current run.

    Returns:
        ``(collapsed_text, removed_lines)``.
    """
    kept: List[str] = []
    dropped: List[str] = []
    last_key = None
    repeats = 0
    for raw in text.splitlines():
        if not raw.strip():
            kept.append(raw)
            last_key, repeats = None, 0
            continue
        key = semantic_key(raw)
        if key != last_key:
            last_key, repeats = key, 0
            kept.append(raw)
            continue
        repeats += 1
        if repeats > max_repeat:
            dropped.append(raw)
        else:
            kept.append(raw)
    return "\n".join(kept), dropped
def collapse_repeated_lines(text: str, block_size: int = 2, max_repeat: int = 1, passes: int = 2, verbose: bool = True) -> str:
    """Run the block-level and line-level collapsers for several passes.

    Args:
        text: Text to clean; ``None`` or empty is treated as ``""``.
        block_size: Forwarded to ``collapse_repeated_blocks_with_report``.
        max_repeat: Forwarded to ``collapse_repeated_semantic_lines_with_report``.
        passes: Number of rounds to run (at least one).
        verbose: When true, print a summary of what was removed.

    Returns:
        The collapsed text with surrounding whitespace stripped.
    """
    result = text or ""
    discarded: List[str] = []
    rounds = max(1, int(passes))
    for _ in range(rounds):
        result, block_hits = collapse_repeated_blocks_with_report(result, block_size=block_size)
        discarded += block_hits
        result, line_hits = collapse_repeated_semantic_lines_with_report(result, max_repeat=max_repeat)
        discarded += line_hits
    # Ordered de-duplication of the removed lines for the report.
    distinct = list(dict.fromkeys(discarded))
    if verbose and distinct:
        print(f"[collapse_repeated_lines] Removed {len(distinct)} unique repeated line(s)/block(s). Examples:")
        for sample in distinct[:50]:
            print(f"- {sample}")
    elif verbose:
        print("[collapse_repeated_lines] No repeated blocks or semantic-line repeats detected.")
    return result.strip()
def _now_iso() -> str:
return datetime.utcnow().isoformat()
def _norm_ws(text: str) -> str:
return re.sub(r"\s+", " ", (text or "")).strip()
def _safe_json_dumps(obj: Any) -> str:
return json.dumps(obj, ensure_ascii=False, sort_keys=True, separators=(",", ":"))
def _safe_json_loads(text: str, default: Any = None) -> Any:
try:
return json.loads(text)
except Exception:
return default
def _maybe_list(x: Any) -> List[Any]:
if x is None:
return []
if isinstance(x, list):
return x
if isinstance(x, tuple):
return list(x)
return [x]
def _normalize_whitespace(text: str) -> str:
return re.sub(r"\s+", " ", text or "").strip()
def split_sentences(text: str) -> List[str]:
    """Split text into sentences at whitespace that follows ``.``, ``!`` or ``?``.

    Whitespace is normalised first; empty input yields an empty list.
    """
    cleaned = _normalize_whitespace(text)
    if not cleaned:
        return []
    sentences: List[str] = []
    for piece in re.split(r"(?<=[.!?])\s+", cleaned):
        piece = piece.strip()
        if piece:
            sentences.append(piece)
    return sentences
def _safe_hash(text: str) -> int:
    """Hash *text* after whitespace normalisation using the built-in ``hash``.

    NOTE(review): built-in string hashes are salted per interpreter process
    unless PYTHONHASHSEED is fixed, so the value is only comparable within
    one run - confirm that no caller persists it.
    """
    canonical = _normalize_whitespace(text)
    return hash(canonical)
def _ensure_min_text(text: str, fallback: str = "") -> str:
    """Return whitespace-normalised *text*, or normalised *fallback* if that is empty."""
    primary = _normalize_whitespace(text)
    return primary or _normalize_whitespace(fallback)
class FileManager:
def __init__(self, base_dir: str):
self.base_dir = os.path.abspath(base_dir)
self.global_dir = os.path.join(self.base_dir, "global_data")
os.makedirs(self.global_dir, exist_ok=True)
def abs_path(self, path: str) -> str:
return os.path.abspath(path)
def copy_to_global(self, src: str, dest_name: Optional[str] = None) -> str:
dest_name = dest_name or os.path.basename(src)
dest = os.path.join(self.global_dir, dest_name)
if os.path.abspath(src) == os.path.abspath(dest):
return dest
if not os.path.exists(dest):
shutil.copy2(src, dest)
return dest
def write_bytes_to_global(self, data: bytes, dest_name: str) -> str:
dest = os.path.join(self.global_dir, dest_name)
with open(dest, "wb") as f:
f.write(data)
return dest
def hardlink_or_copy(self, src: str, dst: str):
try:
os.link(src, dst)
except Exception:
shutil.copy2(src, dst)
def atomic_replace(self, src_tmp: str, dest: str):
os.replace(src_tmp, dest)
def compute_sha256(self, path: str) -> str:
h = hashlib.sha256()
with open(path, "rb") as f:
for chunk in iter(lambda: f.read(1 << 20), b""):
h.update(chunk)
return h.hexdigest()
class CodeBoxError(Exception):
    """Base class for the CodeBox exception hierarchy."""
class AssetNotFoundError(CodeBoxError):
    """Raised when an asset alias cannot be resolved."""
class RunnerCacheError(CodeBoxError):
    """Error type reserved for runner-cache failures."""
class LRUCache:
    """Small least-recently-used cache that releases evicted values.

    On eviction, the value's ``close()`` and ``cleanup()`` methods are called
    when present, after which the CUDA allocator cache is emptied on a
    best-effort basis.
    """

    def __init__(self, capacity: int = 2):
        """Create an empty cache holding at most *capacity* entries."""
        self.capacity = capacity
        self.cache = OrderedDict()

    def get(self, key):
        """Return the value for *key* and mark it most recently used, else ``None``."""
        try:
            self.cache.move_to_end(key)
        except KeyError:
            return None
        return self.cache[key]

    def put(self, key, value):
        """Insert or update *key*; evict the oldest entry when over capacity."""
        already_present = key in self.cache
        self.cache[key] = value
        if already_present:
            self.cache.move_to_end(key)
            return
        if len(self.cache) <= self.capacity:
            return
        _, evicted = self.cache.popitem(last=False)
        try:
            # Both hooks share one guard: a failing close() skips cleanup().
            for hook_name in ("close", "cleanup"):
                if hasattr(evicted, hook_name):
                    getattr(evicted, hook_name)()
        except Exception:
            pass
        del evicted
        self._empty_cuda_cache()

    def keys(self):
        """Return the cached keys, oldest first."""
        return list(self.cache)

    def clear(self):
        """Drop every entry (without calling release hooks) and empty the CUDA cache."""
        self.cache.clear()
        self._empty_cuda_cache()

    @staticmethod
    def _empty_cuda_cache():
        """Ask torch to release cached CUDA memory; ignore any failure."""
        try:
            import torch
            torch.cuda.empty_cache()
        except Exception:
            pass
# The CodeBox
class CodeBox:
    """Asset registry plus virtual-environment manager for supervised code runs.

    Instance state:
      * ``asset_registry`` -- alias -> asset entry, persisted to ``assets.pt``
        in the file manager's global directory.
      * ``env_dic`` -- venv id -> ``{"path": ..., "packages": [...]}``.
      * ``code_bank`` -- code id -> ``{"venv_id": ..., "filepath": ...}``.
    """

    ASSETS_FILENAME = "assets.pt"
    ASSET_SCHEMA_VERSION = 1

    def __init__(self, base_dir: str = "./codebox_storage", runner_cache_capacity: int = 2):
        """Create the storage layout under *base_dir* and load the saved registry."""
        self.base_dir = os.path.abspath(base_dir)
        self.envs_dir = os.path.join(self.base_dir, "envs")
        self.file_manager = FileManager(self.base_dir)
        os.makedirs(self.envs_dir, exist_ok=True)
        self.env_dic: Dict[str, Dict[str, Any]] = {}
        self.code_bank: Dict[str, Dict[str, Any]] = {}
        self.asset_registry: Dict[str, Dict[str, Any]] = {}
        self._load_registry()
        self._runner_cache = LRUCache(capacity=runner_cache_capacity)
        self._ensure_loader_template()

    def _registry_path(self) -> str:
        """Return the path of the persisted registry file."""
        return os.path.join(self.file_manager.global_dir, self.ASSETS_FILENAME)

    def _persist_registry(self):
        """Save the registry through a temp file followed by an atomic replace."""
        tmp_fd, tmp_path = tempfile.mkstemp(dir=self.file_manager.global_dir)
        os.close(tmp_fd)
        payload = {
            "schema_version": self.ASSET_SCHEMA_VERSION,
            "assets": self.asset_registry,
        }
        try:
            torch.save(payload, tmp_path)
            self.file_manager.atomic_replace(tmp_path, self._registry_path())
        except Exception:
            # Do not leave an orphaned temp file behind when saving fails.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

    def _load_registry(self):
        """Load the registry from disk; move an unreadable file aside and start empty."""
        path = self._registry_path()
        if not os.path.exists(path):
            self.asset_registry = {}
            return
        try:
            payload = torch.load(path)
            if isinstance(payload, dict) and "assets" in payload:
                self.asset_registry = payload["assets"]
            else:
                # Payload without the schema wrapper: use it as the registry.
                self.asset_registry = payload
        except Exception:
            corrupted = path + f".corrupt.{int(time.time())}"
            shutil.move(path, corrupted)
            self.asset_registry = {}

    def register_asset(self, alias: str, file_path: Optional[str] = None,
                       asset_type: str = "bin", metadata: Optional[Dict] = None,
                       embed_bytes: Optional[bytes] = None, force: bool = False) -> Dict[str, Any]:
        """Register an asset under *alias* from raw bytes or from a file.

        Args:
            alias: Registry key.
            file_path: Source file, copied into the global directory.
            asset_type: Stored as the entry's ``type``.
            metadata: Free-form metadata stored with the entry.
            embed_bytes: When given, written to ``<alias>.embedded`` and used
                instead of *file_path*.
            force: For file assets, replace the entry even if one with the
                same SHA-256 is already registered.

        Returns:
            The registry entry.

        Raises:
            ValueError: If neither *file_path* nor *embed_bytes* is given.
        """
        metadata = metadata or {}
        if embed_bytes is not None:
            bytes_name = f"{alias}.embedded"
            dest = self.file_manager.write_bytes_to_global(embed_bytes, bytes_name)
            sha = self.file_manager.compute_sha256(dest)
            entry = {
                "source_path": dest,
                "type": asset_type,
                "metadata": metadata,
                "embedded": True,
                "bytes_name": bytes_name,
                "sha256": sha,
                "registered_at": time.time(),
            }
            self.asset_registry[alias] = entry
            self._persist_registry()
            return entry
        if not file_path:
            raise ValueError("Provide either file_path or embed_bytes")
        src_abs = os.path.abspath(file_path)
        dest = self.file_manager.copy_to_global(src_abs, dest_name=os.path.basename(src_abs))
        sha = self.file_manager.compute_sha256(dest)
        entry = {
            "source_path": dest,
            "type": asset_type,
            "metadata": metadata,
            "embedded": False,
            "sha256": sha,
            "registered_at": time.time(),
        }
        if alias in self.asset_registry and not force:
            if self.asset_registry[alias].get("sha256") == sha:
                return self.asset_registry[alias]
        self.asset_registry[alias] = entry
        self._persist_registry()
        return entry

    def unregister_asset(self, alias: str):
        """Remove *alias* from the registry (the stored file is left in place)."""
        if alias in self.asset_registry:
            del self.asset_registry[alias]
            self._persist_registry()

    def _ensure_loader_template(self):
        """Write ``_codebox_loader.py`` into the global directory if it is missing."""
        loader_path = os.path.join(self.file_manager.global_dir, "_codebox_loader.py")
        if os.path.exists(loader_path):
            return
        loader_code = r'''
        import os
        import json
        # Try optional heavy deps; if missing, fall back to JSON registry and path-only behavior.
        try:
            import torch as _torch
        except Exception:
            _torch = None
        try:
            from safetensors.torch import load_file as _safetensors_load
        except Exception:
            _safetensors_load = None
        def _registry_pt_path(assets_dir):
            return os.path.join(assets_dir, "assets.pt")
        def _registry_json_path(assets_dir):
            return os.path.join(assets_dir, "assets.json")
        def _load_registry(assets_dir):
            # Prefer torch payload if torch is available and file exists
            pt_path = _registry_pt_path(assets_dir)
            json_path = _registry_json_path(assets_dir)
            if _torch is not None and os.path.exists(pt_path):
                try:
                    payload = _torch.load(pt_path)
                    if isinstance(payload, dict) and "assets" in payload:
                        return payload["assets"]
                    return payload
                except Exception:
                    # fall through to json
                    pass
            if os.path.exists(json_path):
                with open(json_path, "r", encoding="utf-8") as f:
                    return json.load(f)
            # last resort: try to load pt even without torch (will raise)
            if os.path.exists(pt_path):
                raise RuntimeError("Torch not available to read assets.pt; install torch or ensure assets.json exists.")
            return {}
        def load_asset(alias):
            assets_dir = os.environ.get("CODEBOX_ASSETS_DIR")
            if not assets_dir:
                raise RuntimeError("CODEBOX_ASSETS_DIR not set")
            registry = _load_registry(assets_dir)
            entry = registry.get(alias)
            if not entry:
                raise KeyError(f"Asset '{alias}' not found in registry")
            path = entry["source_path"]
            typ = entry.get("type", "bin")
            if typ == "safetensors":
                if _safetensors_load is None:
                    return path
                return _safetensors_load(path)
            if typ == "pt":
                if _torch is None:
                    raise RuntimeError("Torch not available in this environment to load .pt assets.")
                return _torch.load(path, map_location="cpu")
            if typ == "json":
                with open(path, "r", encoding="utf-8") as f:
                    return json.load(f)
            return path
        '''
        # The template is indented with the method; dedent before writing.
        loader_code = textwrap.dedent(loader_code)
        with open(loader_path, "w", encoding="utf-8") as f:
            f.write(loader_code)

    def _sync_required_assets(self, working_dir: str, required_assets: Optional[List[str]] = None):
        """Place the registry (``.pt`` and ``.json``) and asset files in *working_dir*.

        Args:
            working_dir: Directory the child process runs in.
            required_assets: Aliases to mount; every registered asset when empty.
        """
        registry_dst_pt = os.path.join(working_dir, self.ASSETS_FILENAME)
        registry_dst_json = os.path.join(working_dir, "assets.json")
        payload = {"schema_version": self.ASSET_SCHEMA_VERSION, "assets": self.asset_registry}
        tmp_fd, tmp_path = tempfile.mkstemp(dir=working_dir)
        os.close(tmp_fd)
        torch.save(payload, tmp_path)
        os.replace(tmp_path, registry_dst_pt)
        tmp_fd, tmp_path = tempfile.mkstemp(dir=working_dir)
        os.close(tmp_fd)
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(self.asset_registry, f)
        os.replace(tmp_path, registry_dst_json)
        assets_to_mount = required_assets if required_assets else list(self.asset_registry.keys())
        for alias in assets_to_mount:
            entry = self.asset_registry.get(alias)
            if not entry:
                continue
            src = entry["source_path"]
            target = os.path.join(working_dir, os.path.basename(src))
            if not os.path.exists(target):
                # hardlink_or_copy already falls back to a plain copy itself.
                self.file_manager.hardlink_or_copy(src, target)

    def _inject_loader_into_env(self, env_src_dir: str):
        """Copy the loader template into *env_src_dir* unless it is already there."""
        src = os.path.join(self.file_manager.global_dir, "_codebox_loader.py")
        dst = os.path.join(env_src_dir, "_codebox_loader.py")
        if not os.path.exists(dst):
            shutil.copy2(src, dst)

    def _get_python_bin(self, venv_path: str) -> str:
        """Return the interpreter path inside a venv for the current OS."""
        if os.name == 'nt':
            return os.path.join(venv_path, "Scripts", "python.exe")
        return os.path.join(venv_path, "bin", "python")

    def create_venv(self, venv_id: str, requirements: Optional[List[str]] = None) -> str:
        """Create the venv *venv_id* if it is not tracked yet and install *requirements*.

        Returns:
            The venv directory path.
        """
        venv_path = os.path.join(self.envs_dir, venv_id)
        if venv_id not in self.env_dic:
            subprocess.run([sys.executable, "-m", "venv", venv_path], check=True)
            src_dir = os.path.join(venv_path, "src")
            os.makedirs(src_dir, exist_ok=True)
            self.env_dic[venv_id] = {"path": venv_path, "packages": []}
            # inject loader template into src
            self._inject_loader_into_env(src_dir)
        if requirements:
            self.install_packages(venv_id, requirements)
        return venv_path

    def install_packages(self, venv_id: str, packages: List[str]):
        """pip-install *packages* into the venv and record them once each."""
        if venv_id not in self.env_dic:
            self.create_venv(venv_id)
        python_bin = self._get_python_bin(self.env_dic[venv_id]["path"])
        cmd = [python_bin, "-m", "pip", "install", "--quiet"] + packages
        subprocess.run(cmd, check=True)
        recorded = self.env_dic[venv_id]["packages"]
        known = set(recorded)
        for pkg in packages:
            if pkg not in known:
                # Track additions too, so a name repeated in *packages* is
                # recorded only once.
                known.add(pkg)
                recorded.append(pkg)

    def export_venv(self, venv_id: str) -> str:
        """Return a JSON manifest (``pip freeze`` output plus metadata) for a venv.

        Raises:
            ValueError: If *venv_id* is not tracked.
        """
        if venv_id not in self.env_dic:
            raise ValueError(f"Environment '{venv_id}' does not exist.")
        python_bin = self._get_python_bin(self.env_dic[venv_id]["path"])
        res = subprocess.run([python_bin, "-m", "pip", "freeze"], capture_output=True, text=True)
        manifest = {
            "venv_id": venv_id,
            "pip_freeze": res.stdout.splitlines(),
            "metadata": {k: v for k, v in self.env_dic[venv_id].items() if k != "path"},
        }
        return json.dumps(manifest, indent=2)

    def import_venv(self, manifest_json: str):
        """Recreate a venv from a manifest produced by ``export_venv``."""
        manifest = json.loads(manifest_json)
        venv_id = manifest["venv_id"]
        self.create_venv(venv_id, requirements=manifest["pip_freeze"])

    def _discard_env(self, venv_id: str):
        """Forget a venv and delete its directory, ignoring filesystem errors."""
        meta = self.env_dic.pop(venv_id, None)
        path = meta["path"] if meta else os.path.join(self.envs_dir, venv_id)
        shutil.rmtree(path, ignore_errors=True)

    def _execute_supervised(self, python_bin: str, script_path: str, working_dir: str,
                            timeout: Optional[int] = 30, max_ram_mb: int = 4096,
                            required_assets: Optional[List[str]] = None) -> Dict[str, Any]:
        """Run a script in a child process under memory, CPU and time supervision.

        A monitor thread samples the child: exceeding *max_ram_mb* kills it,
        CPU samples above 95% only add a warning. Exceeding *timeout* kills it.

        Returns:
            Dict with ``stdout``, ``stderr``, ``success``, ``exit_code`` and
            ``technical_alarms``.
        """
        self._sync_required_assets(working_dir, required_assets)
        env_vars = os.environ.copy()
        env_vars["PYTHONPATH"] = working_dir
        env_vars["CODEBOX_ASSETS_DIR"] = working_dir
        proc = subprocess.Popen(
            [python_bin, script_path],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            cwd=working_dir,
            env=env_vars,
            text=True,
        )
        alarms = []
        MAX_CPU = 95.0

        def monitor():
            try:
                p = psutil.Process(proc.pid)
                while proc.poll() is None:
                    mem_mb = p.memory_info().rss / (1024 * 1024)
                    cpu = p.cpu_percent(interval=0.2)
                    if mem_mb > max_ram_mb:
                        alarms.append(f"RESOURCE KILL: Memory usage exceeded ({mem_mb:.1f}MB > {max_ram_mb}MB)")
                        proc.kill()
                        break
                    if cpu > MAX_CPU:
                        alarms.append(f"RESOURCE WARNING: Sustained CPU spike ({cpu}%)")
            except psutil.Error:
                # The child exited or cannot be inspected; monitoring is
                # best-effort and must not die with an unhandled error.
                pass

        mon_thread = threading.Thread(target=monitor, daemon=True)
        mon_thread.start()
        try:
            stdout, stderr = proc.communicate(timeout=timeout)
        except subprocess.TimeoutExpired:
            proc.kill()
            stdout, stderr = proc.communicate()
            alarms.append(f"RESOURCE KILL: Code execution timed out ({timeout}s limit).")
        # Let the monitor finish its current sample before reading its alarms.
        mon_thread.join(timeout=1.0)
        return {
            "stdout": stdout.strip(),
            "stderr": stderr.strip(),
            "success": proc.returncode == 0 and not any("KILL" in a for a in alarms),
            "exit_code": proc.returncode,
            "technical_alarms": alarms,
        }

    def run_code(self, code_block: str, venv_id: str = "default", requirements: Optional[List[str]] = None,
                 timeout: Optional[int] = 30, max_ram_mb: int = 4096,
                 required_assets: Optional[List[str]] = None) -> Dict[str, Any]:
        """Write *code_block* to a script and execute it under supervision.

        With *requirements*, a throw-away venv is created for the run and
        deleted afterwards; otherwise *venv_id* is used (created on demand).

        Returns:
            The result dict from ``_execute_supervised``.
        """
        # A random suffix keeps names unique for runs started within the
        # same second.
        run_tag = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
        is_temp = bool(requirements)
        if is_temp:
            venv_id = f"temp_{run_tag}"
            try:
                self.create_venv(venv_id, requirements)
            except Exception:
                # A failed install must not leave the temporary venv behind.
                self._discard_env(venv_id)
                raise
        elif venv_id not in self.env_dic:
            self.create_venv(venv_id)
        env_meta = self.env_dic[venv_id]
        src_dir = os.path.join(env_meta["path"], "src")
        python_bin = self._get_python_bin(env_meta["path"])
        temp_script = os.path.join(src_dir, f"_run_{run_tag}.py")
        try:
            with open(temp_script, 'w', encoding='utf-8') as f:
                f.write(code_block)
            return self._execute_supervised(python_bin, temp_script, src_dir, timeout, max_ram_mb, required_assets)
        finally:
            if os.path.exists(temp_script):
                os.remove(temp_script)
            if is_temp:
                self._discard_env(venv_id)

    def save_script(self, code_id: str, venv_id: str, source_code: str):
        """Store *source_code* as ``<code_id>.py`` in the venv's ``src`` directory."""
        if venv_id not in self.env_dic:
            self.create_venv(venv_id)
        src_dir = os.path.join(self.env_dic[venv_id]["path"], "src")
        filepath = os.path.join(src_dir, f"{code_id}.py")
        with open(filepath, 'w', encoding='utf-8') as f:
            f.write(source_code)
        self.code_bank[code_id] = {"venv_id": venv_id, "filepath": filepath}
        self._inject_loader_into_env(src_dir)

    def call_function(self, code_id: str = None, function_call: str = None, function_map: Dict[str, Any] = None,
                      timeout: Optional[int] = 30, max_ram_mb: int = 4096,
                      required_assets: Optional[List[str]] = None) -> Dict[str, Any]:
        """Call functions from saved scripts in a supervised child process.

        Two call styles are supported:

        * ``function_map``: ordered mapping of step name to a dict with
          ``code_id``, ``function``, optional ``args`` (keyword arguments),
          optional ``positional_args`` (list) and optional ``output_var``.
          String values starting with ``$`` refer to an earlier step's
          output variable. All steps run in the venv of the first step.
        * ``code_id`` + ``function_call``: a call expression such as
          ``"f(1, x=2)"`` whose arguments must be Python literals.

        Returns:
            The result dict from ``_execute_supervised``; on success its
            stdout is the JSON-encoded context of step outputs.

        Raises:
            ValueError: If no usable call description is given or
                *function_call* cannot be parsed.
        """
        if function_map:
            first_step = next(iter(function_map.values()))
            venv_id = self.code_bank[first_step["code_id"]]["venv_id"]
            env_meta = self.env_dic[venv_id]
            src_dir = os.path.join(env_meta["path"], "src")
            python_bin = self._get_python_bin(env_meta["path"])

            def render(value: Any) -> str:
                # "$name" refers to a previous step's output in the context.
                if isinstance(value, str) and value.startswith("$"):
                    return f"context['{value[1:]}']"
                return repr(value)

            lines = ["import json\nimport sys\ncontext = {}\n"]
            for step, data in function_map.items():
                c_id, f_name = data["code_id"], data["function"]
                args = data.get("args", {})
                out_var = data.get("output_var", f"out_{step}")
                lines.append(f"import {c_id}")
                arg_strs = [render(v) for v in data.get("positional_args", ())]
                arg_strs.extend(f"{k}={render(v)}" for k, v in args.items())
                lines.append(f"try:\n context['{out_var}'] = {c_id}.{f_name}({', '.join(arg_strs)})")
                lines.append(f"except Exception as e:\n print(f'Pipeline failed at {step}: {{e}}', file=sys.stderr)\n sys.exit(1)\n")
            lines.append("print(json.dumps(context))")
            wrapper_code = "\n".join(lines)
            # Unique name so overlapping calls do not overwrite each other.
            wrapper_path = os.path.join(src_dir, f"_dag_runner_{uuid.uuid4().hex[:8]}.py")
            with open(wrapper_path, 'w', encoding='utf-8') as f:
                f.write(wrapper_code)
            try:
                return self._execute_supervised(python_bin, wrapper_path, src_dir, timeout, max_ram_mb, required_assets)
            finally:
                if os.path.exists(wrapper_path):
                    os.remove(wrapper_path)
        if code_id and function_call:
            try:
                tree = ast.parse(function_call)
                expr = tree.body[0].value
                if not isinstance(expr, ast.Call):
                    raise ValueError("Target signature is not a valid call statement.")
                func_name = expr.func.id
                # Positional arguments are forwarded as well; dropping them
                # would call the function with arguments missing.
                positional = [ast.literal_eval(node) for node in expr.args]
                extracted_args = {}
                for keyword in expr.keywords:
                    if keyword.arg is None:
                        raise ValueError("'**' argument unpacking is not supported.")
                    extracted_args[keyword.arg] = ast.literal_eval(keyword.value)
            except Exception as e:
                raise ValueError(f"AST Function Call Parser failed on expression matching: {e}") from e
            macro_map = {
                "step_1": {
                    "code_id": code_id,
                    "function": func_name,
                    "positional_args": positional,
                    "args": extracted_args,
                    "output_var": "result",
                }
            }
            return self.call_function(function_map=macro_map, timeout=timeout, max_ram_mb=max_ram_mb, required_assets=required_assets)
        raise ValueError("Provide either code_id + function_call, or function_map")

    def get_runner(self, key: str):
        """Return the cached runner for *key*, or ``None``."""
        return self._runner_cache.get(key)

    def put_runner(self, key: str, runner_obj):
        """Store *runner_obj* in the runner cache under *key*."""
        self._runner_cache.put(key, runner_obj)

    def prune_cache(self, days_unused: int = 7, logger=None):
        """Delete files in the global directory whose mtime is older than the cutoff.

        The registry file and the loader template are always kept; problems
        are reported through *logger* when one is given.
        """
        cutoff = time.time() - days_unused * 86400
        assets_to_keep = {self.ASSETS_FILENAME, "_codebox_loader.py"}
        try:
            with os.scandir(self.file_manager.global_dir) as it:
                for entry in it:
                    try:
                        if entry.is_dir():
                            continue
                        name = entry.name
                        if name in assets_to_keep:
                            continue
                        mtime = entry.stat().st_mtime
                        if mtime < cutoff:
                            try:
                                os.remove(entry.path)
                            except Exception as e:
                                if logger:
                                    logger.warning("Failed to remove %s: %s", entry.path, e)
                    except FileNotFoundError:
                        continue
                    except PermissionError:
                        if logger:
                            logger.warning("Permission denied pruning %s", entry.path)
                        continue
        except Exception as e:
            if logger:
                logger.error("Prune cache failed for %s: %s", self.file_manager.global_dir, e)

    def resolve_asset_path(self, alias: str) -> str:
        """Return the stored file path for *alias*.

        Raises:
            AssetNotFoundError: If the alias is not registered.
        """
        entry = self.asset_registry.get(alias)
        if not entry:
            raise AssetNotFoundError(alias)
        return entry["source_path"]
class Box:
"""
Persistent wrapper around CodeBox.
Resolution order:
1. box_location argument
2. AppData config location
3. Create new CodeBox.pt
"""
APP_NAME = "CodeBox"
CONFIG_FILE = "box_config.json"
DEFAULT_MODEL_FILE = "CodeBox.pt"
def __init__(
self,
box_location: Optional[str] = None,
base_dir: Optional[str] = None,
runner_cache_capacity: int = 2,
):
self._model_path = self._resolve_model_path(box_location)
if os.path.exists(self._model_path):
self.model = torch.load(self._model_path)
else:
self.model = CodeBox(
base_dir=base_dir or self._default_storage_dir(),
runner_cache_capacity=runner_cache_capacity,
)
self.save()
@classmethod
def _appdata_dir(cls):
if os.name == "nt":
root = os.getenv("APPDATA")
else:
root = os.path.expanduser("~/.config")
path = os.path.join(root, cls.APP_NAME)
os.makedirs(path, exist_ok=True)
return path
@classmethod
def _config_path(cls):
return os.path.join(cls._appdata_dir(), cls.CONFIG_FILE)
@classmethod
def _default_storage_dir(cls):
path = os.path.join(cls._appdata_dir(), "storage")
os.makedirs(path, exist_ok=True)
return path
@classmethod
def _default_model_path(cls):
return os.path.join(cls._appdata_dir(), cls.DEFAULT_MODEL_FILE)
def _resolve_model_path(self, box_location):
if box_location:
path = os.path.abspath(box_location)
self._write_config(path)
return path
config = self._read_config()
if config:
saved_path = config.get("box_location")
if saved_path and os.path.exists(saved_path):
return saved_path
path = self._default_model_path()
self._write_config(path)
return path
def _read_config(self):
cfg = self._config_path()
if not os.path.exists(cfg):
return {}
try:
with open(cfg, "r", encoding="utf-8") as f:
return json.load(f)
except Exception:
return {}
def _write_config(self, model_path):
cfg = self._config_path()
with open(cfg, "w", encoding="utf-8") as f:
json.dump(
{
"box_location": os.path.abspath(model_path)
},
f,
indent=2,
)
def save(self):
torch.save(self.model, self._model_path)
def save_as(self, path):
path = os.path.abspath(path)
torch.save(self.model, path)
self._model_path = path
self._write_config(path)
@classmethod
def load(cls, path):
return cls(box_location=path)
def register_asset(self, *args, **kwargs):
result = self.model.register_asset(*args, **kwargs)
self.save()
return result
def unregister_asset(self, *args, **kwargs):
result = self.model.unregister_asset(*args, **kwargs)
self.save()
return result
def create_venv(self, *args, **kwargs):
result = self.model.create_venv(*args, **kwargs)
self.save()
return result
def install_packages(self, *args, **kwargs):
result = self.model.install_packages(*args, **kwargs)
self.save()
return result
def save_script(self, *args, **kwargs):
result = self.model.save_script(*args, **kwargs)
self.save()
return result
def put_runner(self, *args, **kwargs):
result = self.model.put_runner(*args, **kwargs)
self.save()
return result
def __getattr__(self, name):
return getattr(self.model, name)
def __contains__(self, alias):
return alias in self.model.asset_registry
def __len__(self):
return len(self.model.asset_registry)
def __repr__(self):
return (
f"Box("
f"assets={len(self.model.asset_registry)}, "
f"envs={len(self.model.env_dic)}, "
f"path='{self._model_path}')"
)
def close(self):
    """Clear the model's runner cache (best effort), then persist via ``save``.

    A failure while clearing the cache is ignored; a failure in ``save``
    propagates to the caller.
    """
    with contextlib.suppress(Exception):
        self.model._runner_cache.clear()
    self.save()
def __del__(self):
    """Finalizer: attempt ``close()`` and never let a failure escape."""
    try:
        self.close()
    except Exception:
        # An error cannot be handled usefully during finalization.
        return
class TextModelBundle:
    """Holds the sentence-embedding model, the seq2seq summarizer with its
    tokenizer, and a spaCy pipeline, each looked up below ``model_dir`` first.
    """

    def __init__(self, model_dir: str = MODEL_DIR):
        self.model_dir = model_dir
        # Local directories checked before loading a component by name.
        self.embedding_path = os.path.join(model_dir, "all-MiniLM-L6-v2")
        self.summarizer_path = os.path.join(model_dir, "distilbart-cnn-12-6")
        self.spacy_model_path = os.path.join(model_dir, "spacy", "en_core_web_sm")
        self.spacy_model_name = SPACY_MODEL_NAME
        # Components are loaded eagerly, in this order.
        self.embedding_model = self._load_embeddings()
        self.summarizer_model, self.tokenizer = self._load_summarizer()
        self.nlp = self._load_spacy()

    def _load_embeddings(self):
        """Return the embedding model, preferring the local directory.

        When the directory is absent the model is loaded by name and saved
        to that directory for next time.
        """
        local_dir = self.embedding_path
        if os.path.isdir(local_dir):
            return SentenceTransformer(local_dir)
        fetched = SentenceTransformer("all-MiniLM-L6-v2")
        fetched.save(local_dir)
        return fetched

    def _load_summarizer(self):
        """Return ``(model, tokenizer)`` for the summarizer.

        Loads from the local directory when it exists; otherwise loads by
        name and saves both pieces into the local directory.
        """
        cached = os.path.isdir(self.summarizer_path)
        source = self.summarizer_path if cached else "sshleifer/distilbart-cnn-12-6"
        model = AutoModelForSeq2SeqLM.from_pretrained(source)
        tokenizer = AutoTokenizer.from_pretrained(source)
        if not cached:
            model.save_pretrained(self.summarizer_path)
            tokenizer.save_pretrained(self.summarizer_path)
        return model, tokenizer

    def _load_spacy(self):
        """Return a spaCy pipeline, falling back to a blank English one.

        The fallback pipeline gets a ``sentencizer`` component so sentence
        splitting is still available.
        """
        try:
            if os.path.exists(self.spacy_model_path):
                target = self.spacy_model_path
            else:
                target = self.spacy_model_name
            return spacy.load(target)
        except Exception:
            fallback = spacy.blank("en")
            if "sentencizer" not in fallback.pipe_names:
                fallback.add_pipe("sentencizer")
            return fallback

    def embed(self, texts, convert_to_tensor=True):
        """Encode ``texts`` with the embedding model."""
        encoder = self.embedding_model
        return encoder.encode(texts, convert_to_tensor=convert_to_tensor)

    def generate_summary(self, text: str, max_length: int = 300, min_length: int = 50) -> str:
        """Summarize ``text`` with beam search; empty input yields ``""``.

        The input is truncated to 1024 tokens by the tokenizer call;
        ``max_length`` and ``min_length`` are passed to ``generate``.
        """
        cleaned = _normalize_whitespace(text)
        if not cleaned:
            return ""
        encoded = self.tokenizer(
            cleaned,
            return_tensors="pt",
            truncation=True,
            max_length=1024,
        )
        # Move the inputs to wherever the summarizer's parameters live.
        target_device = next(self.summarizer_model.parameters()).device
        encoded = {key: tensor.to(target_device) for key, tensor in encoded.items()}
        with torch.no_grad():
            generated = self.summarizer_model.generate(
                **encoded,
                max_length=max_length,
                min_length=min_length,
                num_beams=4,
                early_stopping=True,
            )
        return self.tokenizer.decode(generated[0], skip_special_tokens=True)
def chunk_text_by_context(text: str, bundle: TextModelBundle, num_chunks: int = 20) -> List[str]:
    """Group the sentences of ``text`` into clusters of similar sentences.

    Sentences are embedded, reduced with PCA and clustered with
    agglomerative clustering; each cluster's sentences are joined with
    spaces in their original relative order.

    Args:
        text: Text to split.
        bundle: Provides ``embed`` for sentence embeddings.
        num_chunks: Upper bound on the number of chunks; values below 1 are
            treated as 1.

    Returns:
        One string per cluster, ordered by cluster label. ``[]`` when no
        sentences are found.
    """
    sentences = split_sentences(text)
    if not sentences:
        return []
    if len(sentences) == 1:
        return [sentences[0]]
    embeddings = bundle.embed(sentences)
    # scikit-learn needs a host NumPy array; a tensor living on an accelerator
    # cannot be converted implicitly, so move it to the CPU first.
    if isinstance(embeddings, torch.Tensor):
        embeddings = embeddings.detach().cpu().numpy()
    if embeddings.ndim != 2 or embeddings.shape[0] <= 1:
        return [" ".join(sentences)]
    n_samples, n_features = embeddings.shape
    # Both PCA and the clustering reject a non-positive count.
    wanted = max(1, int(num_chunks))
    n_components = max(1, min(wanted, n_samples, n_features))
    reduced = PCA(n_components=n_components).fit_transform(embeddings)
    k = min(wanted, n_samples)
    labels = AgglomerativeClustering(n_clusters=k).fit_predict(reduced)
    chunks: Dict[int, List[str]] = {}
    for sent, lbl in zip(sentences, labels):
        chunks.setdefault(int(lbl), []).append(sent)
    return [" ".join(chunks[label]) for label in sorted(chunks)]
def safe_summarize_iterative(bundle: TextModelBundle, text: str, max_length: int = 500, min_length: int = 300, overlap: int = 100) -> str:
    """Shrink ``text`` to at most ``max_length`` words where possible.

    Strategy, in order:
      1. Text already within ``max_length`` words is returned unchanged.
      2. Text of at most 1024 tokens is summarized in one call (the original
         text is returned if summarization raises).
      3. Multi-sentence text is halved and each half handled recursively.
      4. A single over-long sentence is cut into overlapping 1024-token
         windows, each summarized, and the joined result is processed again.

    Args:
        bundle: Provides ``tokenizer`` and ``generate_summary``.
        text: Input text.
        max_length: Word threshold, also passed to ``generate_summary``.
        min_length: Passed to ``generate_summary``.
        overlap: Tokens shared by consecutive windows; clamped to
            ``[0, 1023]`` so the window always advances.

    Returns:
        The condensed text, or ``""`` for empty input.
    """
    text = _normalize_whitespace(text)
    if not text:
        return ""
    word_threshold = max_length
    word_count = len(text.split())
    if word_count <= word_threshold:
        return text
    token_ids = bundle.tokenizer.encode(text, add_special_tokens=False)
    if len(token_ids) <= 1024:
        try:
            return bundle.generate_summary(text, max_length=max_length, min_length=min_length)
        except Exception:
            return text
    sentences = split_sentences(text)
    if len(sentences) > 1:
        mid = len(sentences) // 2
        left = " ".join(sentences[:mid])
        right = " ".join(sentences[mid:])
        a = safe_summarize_iterative(bundle, left, max_length, min_length, overlap)
        b = safe_summarize_iterative(bundle, right, max_length, min_length, overlap)
        return _normalize_whitespace(f"{a} {b}")
    window = 1024
    # With overlap >= window the start index would never move forward and the
    # loop would spin forever; force a stride of at least one token.
    stride = max(1, window - max(0, overlap))
    total = len(token_ids)
    chunks = []
    for start in range(0, total, stride):
        end = min(start + window, total)
        chunks.append(
            bundle.tokenizer.decode(
                token_ids[start:end],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
        )
        if end == total:
            break
    summaries = []
    for chunk in chunks:
        chunk = _normalize_whitespace(chunk)
        if not chunk:
            continue
        if len(chunk.split()) <= word_threshold:
            summaries.append(chunk)
            continue
        try:
            summaries.append(bundle.generate_summary(chunk, max_length=max_length, min_length=min_length))
        except Exception:
            summaries.append(chunk)
    combined = _normalize_whitespace(" ".join(summaries))
    if not combined:
        return text
    combined_words = len(combined.split())
    if combined_words <= word_threshold:
        return combined
    if combined_words >= word_count:
        # No progress this pass (e.g. every summary call failed); recursing
        # on the same material would never terminate.
        return combined
    return safe_summarize_iterative(bundle, combined, max_length, min_length, overlap)
def safe_summarize(bundle: TextModelBundle, text: str, max_length: int = 750, min_length: int = 500, overlap: int = 250, depth: int = 0, max_depth: int = 15) -> str:
    """Recursively shrink ``text`` to at most ``max_length`` words where possible.

    Text within the word limit is returned unchanged; text of at most 1024
    tokens is summarized directly; multi-sentence text is halved, each half
    summarized, and the joined halves summarized again; a single over-long
    sentence is cut into overlapping 1024-token windows. Once ``depth``
    exceeds ``max_depth``, or on ``RecursionError``, the work is handed to
    ``safe_summarize_iterative``.

    Args:
        bundle: Provides ``tokenizer`` and ``generate_summary``.
        text: Input text.
        max_length: Word threshold, also passed to ``generate_summary``.
        min_length: Passed to ``generate_summary``.
        overlap: Tokens shared by consecutive windows; clamped to
            ``[0, 1023]`` so the window always advances.
        depth: Current recursion depth (callers normally leave this at 0).
        max_depth: Depth after which the fallback routine is used.

    Returns:
        The condensed text, or ``""`` for empty input.
    """
    text = _normalize_whitespace(text)
    if not text:
        return ""
    if depth > max_depth:
        return safe_summarize_iterative(bundle, text, max_length=max_length, min_length=min_length, overlap=overlap)
    if len(text.split()) <= max_length:
        return text
    token_ids = bundle.tokenizer.encode(text, add_special_tokens=False)
    if len(token_ids) <= 1024:
        try:
            return bundle.generate_summary(text, max_length=max_length, min_length=min_length)
        except Exception:
            return text
    sentences = split_sentences(text)
    if len(sentences) > 1:
        mid = len(sentences) // 2
        left = " ".join(sentences[:mid])
        right = " ".join(sentences[mid:])
        try:
            s1 = safe_summarize(bundle, left, max_length, min_length, overlap, depth + 1, max_depth)
            s2 = safe_summarize(bundle, right, max_length, min_length, overlap, depth + 1, max_depth)
            return safe_summarize(bundle, f"{s1} {s2}", max_length, min_length, overlap, depth + 1, max_depth)
        except RecursionError:
            return safe_summarize_iterative(bundle, text, max_length=max_length, min_length=min_length, overlap=overlap)
    window = 1024
    # With overlap >= window the start index would never move forward and the
    # loop would spin forever; force a stride of at least one token.
    stride = max(1, window - max(0, overlap))
    total = len(token_ids)
    pieces = []
    for start in range(0, total, stride):
        end = min(start + window, total)
        pieces.append(
            bundle.tokenizer.decode(
                token_ids[start:end],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
        )
        if end == total:
            break
    chunk_summaries = []
    for piece in pieces:
        piece = _normalize_whitespace(piece)
        if not piece:
            continue
        if len(piece.split()) <= max_length:
            chunk_summaries.append(piece)
            continue
        try:
            chunk_summaries.append(bundle.generate_summary(piece, max_length=max_length, min_length=min_length))
        except Exception:
            chunk_summaries.append(piece)
    combined = _normalize_whitespace(" ".join(chunk_summaries))
    if not combined:
        return text
    if len(combined.split()) <= max_length:
        return combined
    return safe_summarize(bundle, combined, max_length, min_length, overlap, depth + 1, max_depth)
def summarize_relevant_clusters(bundle: TextModelBundle, input_query: str, texts: List[str], similarity_threshold: Optional[float] = None, num_clusters: int = 12) -> List[str]:
    """De-duplicate, cluster and summarize ``texts`` with respect to a query.

    Steps: normalize and drop empty texts; drop texts whose cosine similarity
    to an already-kept text exceeds ``similarity_threshold`` (defaulting to
    the mean pairwise similarity); cluster the remainder with KMeans; keep
    clusters whose text contains a keyword lemma from ``input_query`` (all
    clusters if none match or no keywords were found); summarize each kept
    cluster and then summarize the joined cluster summaries.

    Returns:
        A list with a single summary string, or ``[]`` for no usable input.
    """
    normalized = (_normalize_whitespace(t) for t in texts)
    texts = [t for t in normalized if t]
    if not texts:
        return []
    if len(texts) == 1:
        return [safe_summarize(bundle, texts[0])]
    embeddings = bundle.embed(texts, convert_to_tensor=True)
    sim_matrix = util.pytorch_cos_sim(embeddings, embeddings)
    if similarity_threshold is None:
        similarity_threshold = 0.85
        if sim_matrix.size(0) > 1:
            # Mean of the strictly-upper-triangular (pairwise) similarities.
            upper = torch.triu_indices(sim_matrix.size(0), sim_matrix.size(1), offset=1)
            pairwise = sim_matrix[upper[0], upper[1]]
            if pairwise.numel() > 0:
                similarity_threshold = pairwise.mean().item()
    keep_indices = []
    for candidate in range(len(texts)):
        too_close = False
        for kept in keep_indices:
            if float(sim_matrix[candidate][kept].item()) > similarity_threshold:
                too_close = True
                break
        if not too_close:
            keep_indices.append(candidate)
    dedup_texts = [texts[i] for i in keep_indices]
    if not dedup_texts:
        dedup_texts = texts[:]
    dedup_embeddings = embeddings[keep_indices] if keep_indices else embeddings
    n_clusters = max(1, min(num_clusters, len(dedup_texts)))
    if len(dedup_texts) == 1:
        return [safe_summarize(bundle, dedup_texts[0])]
    kmeans = KMeans(n_clusters=n_clusters, random_state=0, n_init="auto")
    labels = kmeans.fit_predict(dedup_embeddings.cpu().numpy())
    clusters: Dict[int, List[str]] = {}
    for position, lbl in enumerate(labels):
        clusters.setdefault(int(lbl), []).append(dedup_texts[position])
    try:
        parsed_query = bundle.nlp(input_query)
        keywords = {
            token.lemma_.lower()
            for token in parsed_query
            if getattr(token, "pos_", "") in ("NOUN", "VERB", "ADJ", "PROPN")
        }
    except Exception:
        keywords = set()

    def is_cluster_relevant(cluster_texts: List[str]) -> bool:
        # Without keywords every cluster counts as relevant.
        if not keywords:
            return True
        haystack = " ".join(cluster_texts).lower()
        for keyword in keywords:
            if keyword in haystack:
                return True
        return False

    relevant_clusters = [members for members in clusters.values() if is_cluster_relevant(members)]
    if not relevant_clusters:
        relevant_clusters = list(clusters.values())
    cluster_summaries = []
    for members in relevant_clusters:
        merged = _normalize_whitespace(" ".join(members))
        if merged:
            cluster_summaries.append(safe_summarize(bundle, merged))
    if not cluster_summaries:
        return [safe_summarize(bundle, " ".join(dedup_texts))]
    final_text = _normalize_whitespace(" ".join(cluster_summaries))
    return [safe_summarize(bundle, final_text)]
@dataclass
class PageCandidate:
    """One search hit that may be fetched."""

    title: str
    url: str
    # Result-page excerpt; empty when the engine supplied none.
    snippet: str = field(default="")
    # Ranking score slot; defaults to 0.0.
    rank: float = field(default=0.0)
@dataclass
class CrawlResult:
    """Outcome of one query: the answer text plus bookkeeping about the crawl."""

    query: str
    answer: str = field(default="")
    # Each list defaults to a fresh, per-instance list.
    partial_texts: List[str] = field(default_factory=list)
    used_candidates: List[str] = field(default_factory=list)
    failed_candidates: List[str] = field(default_factory=list)
    fallback_used: bool = field(default=False)
    elapsed_seconds: float = field(default=0.0)
    error: Optional[str] = field(default=None)
class CrawlWorker:
    """Runs web searches, fetches the best-ranked pages and condenses them.

    Owns a ``TextModelBundle`` for embedding/summarization, a thread pool used
    to put a wall-clock limit on each page fetch, and a ``requests.Session``
    that sends browser-like headers.
    """

    def __init__(self):
        self.bundle = TextModelBundle()
        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=8)
        self.session = requests.Session()
        self.session.headers.update({
            "User-Agent": (
                "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
                "AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
            ),
            "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
            "Accept-Language": "en-US,en;q=0.5",
        })

    def close(self):
        """Shut down the thread pool and close the HTTP session (best effort)."""
        try:
            self.executor.shutdown(wait=False, cancel_futures=True)
        except Exception:
            pass
        try:
            self.session.close()
        except Exception:
            pass

    def run_query(self, query: str, stats: Optional[dict] = None) -> CrawlResult:
        """Answer ``query`` from the web and return a ``CrawlResult``.

        Candidates are fetched in ranked order until five pages yield at
        least 80 characters of text. Successful pages are chunked, clustered
        and summarized. If none succeed, search-result snippets (then
        title/snippet pairs of the top five candidates) are used as a
        fallback answer. Exceptions are captured into ``result.error``
        rather than raised.

        Args:
            query: Search query.
            stats: Optional dict; a per-run record is appended to
                ``stats["runs"]`` in all cases.
        """
        t0 = time.perf_counter()
        result = CrawlResult(query=query)
        try:
            candidates = self.build_candidates(query)
            if not candidates:
                result.answer = f"No results found for: {query}"
                result.fallback_used = True
                return result
            target_successes = 5
            successful_texts: List[str] = []
            fallback_snippets: List[str] = []
            # A single pass visits every candidate at most once, so no second
            # sweep over "remaining" candidates is needed afterwards.
            for cand in candidates:
                if len(successful_texts) >= target_successes:
                    break
                text = self.fetch_and_extract(cand.url, page_timeout=20)
                if text and len(text.strip()) >= 80:
                    successful_texts.append(_normalize_whitespace(text))
                    result.used_candidates.append(cand.url)
                else:
                    result.failed_candidates.append(cand.url)
                    if cand.snippet:
                        fallback_snippets.append(f"{cand.title}. {cand.snippet}".strip())
            if successful_texts:
                merged = " ".join(successful_texts)
                chunks = chunk_text_by_context(merged, self.bundle, num_chunks=min(8, max(2, len(successful_texts))))
                summary_list = summarize_relevant_clusters(self.bundle, query, chunks, similarity_threshold=None, num_clusters=min(8, len(chunks)))
                answer = _normalize_whitespace(" ".join(summary_list))
                result.answer = answer if answer else _normalize_whitespace(merged)
                result.partial_texts = successful_texts
                return result
            snippet_text = _normalize_whitespace(" ".join(fallback_snippets))
            if snippet_text:
                result.answer = snippet_text
                result.fallback_used = True
                return result
            ranked_text = _normalize_whitespace(" ".join(f"{c.title}. {c.snippet}".strip() for c in candidates[:5]))
            result.answer = ranked_text if ranked_text else f"No usable text found for: {query}"
            result.fallback_used = True
            return result
        except Exception as e:
            result.error = f"{e}\n{traceback.format_exc()}"
            result.answer = result.answer or f"[ERROR] {e}"
            return result
        finally:
            # Runs on every exit path, including the early returns above.
            result.elapsed_seconds = time.perf_counter() - t0
            if stats is not None:
                stats.setdefault("runs", []).append({
                    "query": query,
                    "seconds": round(result.elapsed_seconds, 3),
                    "used_pages": len(result.used_candidates),
                    "failed_pages": len(result.failed_candidates),
                    "fallback_used": result.fallback_used,
                    "error": result.error,
                })

    def build_candidates(self, query: str, num_results: int = 15) -> List[PageCandidate]:
        """Collect hits from all engines, de-duplicate by URL and rank them.

        Ranking is by cosine similarity between the query embedding and the
        embedding of each hit's ``title + snippet``; the score is stored in
        ``PageCandidate.rank``. At most 15 candidates are returned.

        Args:
            query: Search query.
            num_results: Maximum hits requested from each engine.
        """
        raw = []
        raw.extend(self.duckduckgo_search(query, num_results=num_results))
        raw.extend(self.resulthunter_search(query, num_results=num_results))
        raw.extend(self.google_search(query, num_results=num_results))
        seen = set()
        scored = []
        for title, url, snippet in raw:
            norm = self.normalize_url(url)
            if not norm or norm in seen:
                continue
            seen.add(norm)
            scored.append(PageCandidate(title=title or norm, url=url, snippet=snippet or ""))
        if not scored:
            return []
        titles_and_snippets = [f"{c.title} {c.snippet}".strip() for c in scored]
        query_emb = self.bundle.embed([query], convert_to_tensor=True)[0]
        page_embs = self.bundle.embed(titles_and_snippets, convert_to_tensor=True)
        sim_scores = util.cos_sim(query_emb, page_embs)[0]
        for cand, score in zip(scored, sim_scores.tolist()):
            cand.rank = float(score)
        order = sim_scores.argsort(descending=True).tolist()
        ordered = [scored[i] for i in order]
        return ordered[:15]

    def fetch_and_extract(self, url: str, page_timeout: int = 20) -> str:
        """Fetch ``url`` on the thread pool and return its text, or ``""``.

        ``page_timeout`` bounds how long the caller waits; any failure or
        timeout yields an empty string.
        """
        url = self.normalize_url(url)
        if not url:
            return ""
        future = self.executor.submit(self._fetch_and_extract_sync, url)
        try:
            return future.result(timeout=page_timeout) or ""
        except Exception:
            # Covers concurrent.futures.TimeoutError as well as worker errors.
            return ""

    def _fetch_and_extract_sync(self, url: str) -> str:
        """Download ``url`` and extract text from the PDF or HTML response."""
        try:
            r = self.session.get(url, timeout=(8, 15), allow_redirects=True)
            r.raise_for_status()
        except Exception:
            return ""
        content_type = r.headers.get("Content-Type", "").lower()
        if "application/pdf" in content_type or url.lower().endswith(".pdf"):
            text = self._extract_pdf_bytes(r.content)
            return _normalize_whitespace(text)
        html = r.text or ""
        text = self.universal_page_parser(url, html, response=r, use_browser=False)
        return _normalize_whitespace(text)

    def _extract_pdf_bytes(self, pdf_bytes: bytes) -> str:
        """Return the text of all pages of a PDF given as bytes, or ``""``.

        The document is opened directly from memory, so no temporary file is
        created (and none can be left behind when parsing fails).
        """
        try:
            doc = fitz.open(stream=pdf_bytes, filetype="pdf")
            try:
                return " ".join(page.get_text() for page in doc)
            finally:
                doc.close()
        except Exception:
            return ""

    def universal_page_parser(self, url: str, html: str, response=None, use_browser: bool = False) -> str:
        """Extract readable text from ``html`` using a cascade of extractors.

        Each available extractor is tried in turn; the first one producing
        more than 80 characters wins (more than 50 for the final plain-text
        fallback). YouTube URLs are first tried via their transcript.
        ``response`` and ``use_browser`` are accepted but not used here.

        Returns:
            The extracted text, or ``""`` if every extractor failed.
        """
        text = ""
        if YouTubeTranscriptApi is not None and ("youtu" in url.lower()):
            video_id = self._extract_youtube_id(url)
            if video_id:
                try:
                    transcript = YouTubeTranscriptApi.get_transcript(video_id)
                    text = " ".join(entry["text"] for entry in transcript).strip()
                    if text:
                        return text
                except Exception:
                    pass
        if trafilatura is not None:
            try:
                t = trafilatura.extract(html, include_comments=False, include_tables=True)
                if t and len(t.strip()) > 80:
                    return t.strip()
            except Exception:
                pass
        if extractors is not None:
            try:
                boilerpy_text = extractors.ArticleExtractor().get_content(html)
                if boilerpy_text and len(boilerpy_text.strip()) > 80:
                    return boilerpy_text.strip()
            except Exception:
                pass
        if Document is not None:
            try:
                doc = Document(html)
                soup = BeautifulSoup(doc.summary(), "html.parser")
                text_readability = soup.get_text(" ", strip=True)
                if text_readability and len(text_readability.strip()) > 80:
                    return text_readability.strip()
            except Exception:
                pass
        if Article is not None:
            try:
                article = Article(url)
                article.set_html(html)
                article.parse()
                text_newspaper = article.text or ""
                if text_newspaper and len(text_newspaper.strip()) > 80:
                    return text_newspaper.strip()
            except Exception:
                pass
        if Goose is not None:
            try:
                goose_text = Goose().extract(raw_html=html).cleaned_text
                if goose_text and len(goose_text.strip()) > 80:
                    return goose_text.strip()
            except Exception:
                pass
        if inscriptis_text is not None:
            try:
                inscriptis_parsed = inscriptis_text(html)
                if inscriptis_parsed and len(inscriptis_parsed.strip()) > 80:
                    return inscriptis_parsed.strip()
            except Exception:
                pass
        if lxml_html is not None:
            try:
                lxml_tree = lxml_html.fromstring(html)
                lxml_text = " ".join(lxml_tree.xpath("//p//text()"))
                if lxml_text and len(lxml_text.strip()) > 80:
                    return lxml_text.strip()
            except Exception:
                pass
        # Elements removed before the two BeautifulSoup fallbacks below.
        noise_tags = ["script", "style", "noscript", "header", "footer", "nav", "aside", "form", "input", "button", "svg", "canvas", "iframe", "object", "embed", "img", "video", "audio"]
        try:
            soup = BeautifulSoup(html, "html.parser")
            for tag in soup(noise_tags):
                tag.decompose()
            tags = ["p", "li", "span", "div", "h1", "h2", "h3", "h4", "h5", "h6"]
            bs_text = " ".join(t.get_text(" ", strip=True) for tag in tags for t in soup.find_all(tag))
            if bs_text and len(bs_text.strip()) > 80:
                return bs_text.strip()
        except Exception:
            pass
        try:
            soup = BeautifulSoup(html, "html.parser")
            for tag in soup(noise_tags):
                tag.decompose()
            visible_text = soup.get_text(" ", strip=True)
            if visible_text and len(visible_text.strip()) > 50:
                return visible_text.strip()
        except Exception:
            pass
        return ""

    def _extract_youtube_id(self, url: str) -> Optional[str]:
        """Return the 11-character video id found in ``url``, or ``None``."""
        patterns = [
            r"(?:v=)([A-Za-z0-9_-]{11})",
            r"youtu\.be/([A-Za-z0-9_-]{11})",
            r"youtube\.com/shorts/([A-Za-z0-9_-]{11})",
            r"youtube\.com/embed/([A-Za-z0-9_-]{11})",
        ]
        for pattern in patterns:
            m = re.search(pattern, url)
            if m:
                return m.group(1)
        return None

    def normalize_url(self, url: str) -> str:
        """Turn a search-result link into a fetchable absolute URL.

        Unwraps ``resulthunter.com`` and ``duckduckgo.com/l/`` redirect
        links, maps ``/videos/watch/<id>`` paths to a YouTube watch URL, and
        prefixes ``https://`` to anything not starting with ``http``.
        Returns ``""`` for empty input.
        """
        if not url:
            return ""
        if "resulthunter.com" in url:
            qs = parse_qs(urlparse(url).query)
            if "url" in qs:
                return unquote(qs["url"][0])
        if url.startswith("/videos/watch/"):
            parsed = urlparse(url)
            path_parts = parsed.path.split("/")
            if len(path_parts) >= 4:
                video_id = path_parts[3]
                if len(video_id) >= 11:
                    return f"https://www.youtube.com/watch?v={video_id}"
        if "duckduckgo.com/l/" in url:
            qs = parse_qs(urlparse(url).query)
            if "uddg" in qs and qs["uddg"]:
                return unquote(qs["uddg"][0])
        return url if url.startswith("http") else f"https://{url.lstrip('/')}"

    def is_ad_link(self, url: str) -> bool:
        """Return True when ``url`` contains one of the ad-related substrings."""
        ad_keywords = ["advert", "ads", "doubleclick", "sponsor", "promo"]
        return any(term in (url or "").lower() for term in ad_keywords)

    # ---------------- search engines ----------------
    def duckduckgo_search(self, query: str, num_results: int = 10):
        """Query DuckDuckGo's HTML endpoint.

        Returns:
            Up to ``num_results`` ``(title, href, snippet)`` tuples; ``[]`` on
            a request failure or when BeautifulSoup is unavailable.
        """
        if BeautifulSoup is None:
            return []
        url = "https://html.duckduckgo.com/html/"
        data = {"q": query}
        try:
            r = self.session.post(url, data=data, timeout=(8, 20))
            r.raise_for_status()
        except requests.RequestException:
            return []
        soup = BeautifulSoup(r.text, "html.parser")
        results = []
        for result in soup.select("div.result"):
            # Prefer the headline anchor so the title is the page title; the
            # result__url anchor's text is the display URL.
            # NOTE(review): assumes DuckDuckGo's current HTML layout - confirm.
            title_a = result.select_one("a.result__a") or result.select_one("a.result__url")
            snippet_a = result.select_one("a.result__snippet") or result.select_one("div.result__snippet")
            if not title_a:
                continue
            title = title_a.get_text(strip=True)
            href = title_a.get("href", "")
            if href.startswith("//duckduckgo.com/l/?uddg="):
                href = unquote(href.split("uddg=")[1].split("&")[0])
            snippet = snippet_a.get_text(strip=True) if snippet_a else ""
            results.append((title, href, snippet))
            if len(results) >= num_results:
                break
        return results

    def resulthunter_search(self, query: str, num_results: int = 10):
        """Query resulthunter.com.

        Returns:
            Up to ``num_results`` ``(title, link, snippet)`` tuples with
            absolute ``http`` links; ``[]`` on a request failure or when
            BeautifulSoup is unavailable.
        """
        if BeautifulSoup is None:
            return []
        encoded_query = quote(query)
        url = f"https://www.resulthunter.com/search?q={encoded_query}"
        try:
            r = self.session.get(url, timeout=(8, 15))
            r.raise_for_status()
        except requests.RequestException:
            return []
        soup = BeautifulSoup(r.text, "html.parser")
        results = []
        result_divs = soup.find_all("div", class_="web-result")
        if not result_divs:
            # Fall back to any div whose class mentions "result".
            result_divs = soup.find_all("div", class_=lambda c: c and "result" in c.lower())
        for result in result_divs:
            link_tag = result.find("a", href=True)
            if not link_tag:
                continue
            title = link_tag.get_text(strip=True)
            link = link_tag["href"]
            if not link.startswith("http"):
                continue
            snippet_tag = result.find("p", class_="web-result-desc") or result.find("p")
            snippet = snippet_tag.get_text(strip=True) if snippet_tag else ""
            results.append((title, link, snippet))
            if len(results) >= num_results:
                break
        return results

    def google_search(self, query: str, num_results: int = 10):
        """Query Google's HTML results page.

        Returns:
            Up to ``num_results`` ``(title, href, snippet)`` tuples; ``[]`` on
            a request failure or when BeautifulSoup is unavailable.
        """
        if BeautifulSoup is None:
            return []
        encoded_query = quote(query)
        url = f"https://www.google.com/search?q={encoded_query}&num={num_results + 5}"
        try:
            r = self.session.get(url, timeout=(8, 15))
            r.raise_for_status()
        except requests.RequestException:
            return []
        soup = BeautifulSoup(r.text, "html.parser")
        results = []
        for g in soup.select("div.g"):
            title_el = g.select_one("h3")
            link_el = g.select_one("a[href]")
            if not (title_el and link_el):
                continue
            title = title_el.get_text(strip=True)
            href = link_el["href"]
            snippet_el = g.select_one("div.VwiC3b")
            if snippet_el:
                snippet = snippet_el.get_text(strip=True)
            else:
                # No dedicated snippet node: use the block text minus the title.
                snippet = g.get_text(separator=" ", strip=True).replace(title, "", 1).strip()
            results.append((title, href, snippet))
            if len(results) >= num_results:
                break
        return results
def _worker_main(command_queue: Queue, response_queue: Queue):
    """Worker-process loop: serve ``query`` messages until ``shutdown``.

    Falsy messages and unknown message types are ignored. A ``shutdown``
    message is acknowledged with ``{"type": "shutdown_ack"}`` before the loop
    ends. The ``CrawlWorker`` is closed on every exit path.
    """
    worker = CrawlWorker()
    try:
        while True:
            message = command_queue.get()
            if not message:
                continue
            kind = message.get("type")
            if kind == "shutdown":
                response_queue.put({"type": "shutdown_ack"})
                return
            if kind == "query":
                request_id = message["request_id"]
                query_text = message["query"]
                outcome = worker.run_query(query_text, stats=message.get("stats"))
                reply = {
                    "type": "result",
                    "request_id": request_id,
                    "result": outcome,
                }
                response_queue.put(reply)
    finally:
        worker.close()
class WebSearchModule(nn.Module):
    """``nn.Module`` front end that answers queries via a crawl worker process.

    The worker (``_worker_main``) is started lazily with the ``spawn`` start
    method and talks to this object over two queues. The process and queue
    handles are excluded from pickled state and reset on unpickling.
    """

    def __init__(self, model_dir: str = MODEL_DIR):
        super().__init__()
        self.model_dir = model_dir
        self.bundle = TextModelBundle(model_dir=model_dir)
        self._worker = None
        self._command_queue = None
        self._response_queue = None
        self._worker_ctx = None
        self._worker_started = False

    def __getstate__(self):
        """Return the instance dict with all process/queue handles blanked."""
        state = self.__dict__.copy()
        state["_worker"] = None
        state["_command_queue"] = None
        state["_response_queue"] = None
        state["_worker_ctx"] = None
        state["_worker_started"] = False
        return state

    def __setstate__(self, state):
        """Restore the instance dict; the worker is never carried over."""
        self.__dict__.update(state)
        self._worker = None
        self._command_queue = None
        self._response_queue = None
        self._worker_ctx = None
        self._worker_started = False

    def _ensure_worker(self):
        """Start a fresh daemon worker process unless one is already alive."""
        if self._worker is not None and self._worker.is_alive():
            return
        self._worker_ctx = get_context("spawn")
        self._command_queue = self._worker_ctx.Queue()
        self._response_queue = self._worker_ctx.Queue()
        self._worker = self._worker_ctx.Process(
            target=_worker_main,
            args=(self._command_queue, self._response_queue),
            daemon=True,
        )
        self._worker.start()
        self._worker_started = True

    def close(self):
        """Stop the worker process and release both queues.

        Sends ``shutdown``, waits up to 10 s for a reply and for the process
        to exit, then terminates it if it is still alive. Handles are reset
        even when shutdown fails.
        """
        if getattr(self, "_worker", None) is None:
            return
        try:
            if self._worker.is_alive() and self._command_queue is not None:
                self._command_queue.put({"type": "shutdown"})
                try:
                    if self._response_queue is not None:
                        self._response_queue.get(timeout=10)
                except Exception:
                    pass
                self._worker.join(timeout=10)
            if self._worker.is_alive():
                self._worker.terminate()
                self._worker.join(timeout=5)
        finally:
            try:
                if self._command_queue is not None:
                    self._command_queue.close()
            except Exception:
                pass
            try:
                if self._response_queue is not None:
                    self._response_queue.close()
            except Exception:
                pass
            self._worker = None
            self._command_queue = None
            self._response_queue = None
            self._worker_ctx = None
            self._worker_started = False

    def __del__(self):
        try:
            self.close()
        except Exception:
            pass

    def embed(self, texts, convert_to_tensor=True):
        """Encode ``texts`` with the bundle's embedding model."""
        return self.bundle.embed(texts, convert_to_tensor=convert_to_tensor)

    def summarize(self, text: str, max_length: int = 300, min_length: int = 50):
        """Summarize ``text`` with the bundle's summarizer."""
        return self.bundle.generate_summary(text, max_length=max_length, min_length=min_length)

    def forward(self, query: str, stats: Optional[dict] = None, timeout: Optional[float] = None) -> str:
        """Send ``query`` to the worker and wait for its answer.

        Args:
            query: Search query.
            stats: Optional dict updated in place with counters
                (``spiders_created``, ``spiders_completed``,
                ``spiders_killed``, ``total_seconds``) and a ``runs`` list.
            timeout: Seconds to wait; ``None`` waits until the worker answers
                or is found dead.

        Returns:
            The answer text; a "still running" notice when ``timeout``
            elapses; or an ``[ERROR]`` notice when the worker process has
            exited without answering.
        """
        self._ensure_worker()
        req_id = str(uuid.uuid4())
        if stats is not None:
            stats.setdefault("spiders_created", 0)
            stats.setdefault("spiders_completed", 0)
            stats.setdefault("spiders_killed", 0)
            stats.setdefault("total_seconds", 0.0)
            stats.setdefault("runs", [])
            stats["spiders_created"] += 1
        self._command_queue.put({
            "type": "query",
            "request_id": req_id,
            "query": query,
            "stats": stats,
        })
        started = time.perf_counter()
        while True:
            if timeout is not None and (time.perf_counter() - started) > timeout:
                return f"Query is still running in the worker process for: {query}"
            try:
                msg = self._response_queue.get(timeout=0.25)
            except Exception:
                # queue.Empty or a transient queue error. If the worker has
                # died no answer will ever arrive, so stop polling instead of
                # spinning forever.
                if not self._worker.is_alive():
                    if stats is not None:
                        stats["spiders_killed"] += 1
                    return f"[ERROR] Search worker exited before answering: {query}"
                continue
            if msg.get("type") != "result" or msg.get("request_id") != req_id:
                # Stale reply from an earlier (timed-out) request, or an ack.
                continue
            result: CrawlResult = msg["result"]
            if stats is not None:
                stats["spiders_completed"] += 1
                stats["total_seconds"] += result.elapsed_seconds
                stats["runs"].append({
                    "query": result.query,
                    "seconds": round(result.elapsed_seconds, 3),
                    "used_pages": len(result.used_candidates),
                    "failed_pages": len(result.failed_candidates),
                    "fallback_used": result.fallback_used,
                    "error": result.error,
                })
            return result.answer or f"No content extracted for query: {query}"

    def __call__(self, query: str, **kwargs):
        return self.forward(query, **kwargs)
class Web:
    """
    High-level wrapper around WebSearchModule.
    Features:
    - Auto-locates WebSearch.pt in AppData
    - Creates one automatically if missing
    - Exposes all WebSearchModule methods
    - Supports save/load/reload
    - Supports direct querying through __call__
    """
    DEFAULT_FOLDER = os.path.join(
        os.getenv("APPDATA", os.path.expanduser("~")),
        "PackedLLM"
    )
    DEFAULT_WEB_PATH = os.path.join(
        DEFAULT_FOLDER,
        "WebSearch.pt"
    )

    def __init__(
        self,
        web_location: Optional[str] = None,
        model_dir: str = "models",
        auto_create: bool = True,
    ):
        """Load the checkpoint at ``web_location`` or create a new module.

        Args:
            web_location: Checkpoint path; defaults to ``DEFAULT_WEB_PATH``.
            model_dir: Passed to ``WebSearchModule`` when one is created.
            auto_create: Create and save a new module when the checkpoint is
                missing.

        Raises:
            FileNotFoundError: if the checkpoint is missing and
                ``auto_create`` is false.
        """
        self.model_dir = model_dir
        if web_location:
            self.web_path = os.path.abspath(web_location)
        else:
            self.web_path = self.DEFAULT_WEB_PATH
        os.makedirs(os.path.dirname(self.web_path), exist_ok=True)
        if os.path.exists(self.web_path):
            self.web = self._load(self.web_path)
        elif auto_create:
            self.web = WebSearchModule(model_dir=model_dir)
            self.save(self.web_path)
        else:
            raise FileNotFoundError(
                f"WebSearch checkpoint not found: {self.web_path}"
            )

    def _load(self, path: str) -> WebSearchModule:
        """Unpickle a ``WebSearchModule`` from ``path``.

        Raises:
            TypeError: if the file holds some other kind of object.
        """
        # SECURITY: weights_only=False performs full pickle deserialization,
        # which can execute arbitrary code. Only load trusted checkpoints.
        obj = torch.load(path, map_location="cpu", weights_only=False)
        if not isinstance(obj, WebSearchModule):
            raise TypeError(
                f"{path} does not contain a WebSearchModule."
            )
        return obj

    def save(self, path: Optional[str] = None):
        """Save the module to ``path`` (default: current path) and remember it."""
        target = path or self.web_path
        os.makedirs(
            os.path.dirname(os.path.abspath(target)),
            exist_ok=True,
        )
        torch.save(self.web, target)
        self.web_path = target

    def reload(self):
        """Stop the current worker and re-read the module from disk."""
        self.close()
        self.web = self._load(self.web_path)

    def search(self, query: str, **kwargs):
        """Run ``query`` through the module's ``forward``."""
        return self.web.forward(query, **kwargs)

    def embed(self, texts, convert_to_tensor=True):
        """Encode ``texts`` with the module's embedding model."""
        return self.web.embed(
            texts,
            convert_to_tensor=convert_to_tensor,
        )

    def summarize(
        self,
        text: str,
        max_length: int = 300,
        min_length: int = 50,
    ):
        """Summarize ``text`` with the module's summarizer."""
        return self.web.summarize(
            text,
            max_length=max_length,
            min_length=min_length,
        )

    def close(self):
        """Close the wrapped module; failures are ignored."""
        try:
            self.web.close()
        except Exception:
            pass

    def __call__(self, query: str, **kwargs):
        return self.web(query, **kwargs)

    def __getattr__(self, item):
        """Delegate attributes missing on the wrapper to the wrapped module.

        ``web`` is fetched with ``object.__getattribute__`` because reading
        ``self.web`` while it is unset would re-enter this method and recurse
        until ``RecursionError``.

        Raises:
            AttributeError: if ``web`` is unset or lacks ``item``.
        """
        try:
            module = object.__getattribute__(self, "web")
        except AttributeError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {item!r}"
            ) from None
        return getattr(module, item)

    @property
    def location(self):
        """Current checkpoint path."""
        return self.web_path

    @property
    def exists(self):
        """Whether the checkpoint file is present on disk."""
        return os.path.exists(self.web_path)

    def info(self):
        """Return path, existence, worker liveness and model directory."""
        return {
            "web_path": self.web_path,
            "exists": self.exists,
            "worker_running": (
                self.web._worker is not None
                and self.web._worker.is_alive()
            ),
            "model_dir": self.model_dir,
        }

    def __repr__(self):
        return (
            f"Web("
            f"path='{self.web_path}', "
            f"exists={self.exists}"
            f")"
        )
# File names checked, in order, when picking a model's primary weight file.
# The trailing comma is required: without it the parentheses are plain
# grouping and the value is a str, so iterating it yields single characters.
PRIMARY_WEIGHT_FILES = ("pytorch_model.bin",)
# Exported-model artifacts listed by name as files to skip.
SKIP_BLOATED_FILES = {
    "model.onnx",
    "onnx_model.onnx",
    "openvino_model.bin",
}
@staticmethod
def _bytes_to_uint8_tensor(data: bytes) -> torch.Tensor:
    """Return a 1-D ``uint8`` tensor holding a copy of ``data``."""
    view = np.frombuffer(data, dtype=np.uint8)
    # Copy so the tensor does not share memory with the input bytes.
    owned = np.copy(view)
    return torch.from_numpy(owned)
@staticmethod
def _uint8_tensor_to_bytes(t: Union[torch.Tensor, bytes]) -> bytes:
    """Return the raw bytes of tensor ``t``; ``bytes`` input is passed through."""
    if not isinstance(t, bytes):
        # Detach, move to CPU and make contiguous before dumping the buffer.
        host = t.detach().cpu().contiguous()
        return host.numpy().tobytes()
    return t
@staticmethod
def _pick_primary_weight_file(model_dir: str) -> Optional[str]:
    """Return the first entry of ``PRIMARY_WEIGHT_FILES`` present in ``model_dir``.

    Returns ``None`` when none of the entries exists there.
    """
    present = (
        name
        for name in PRIMARY_WEIGHT_FILES
        if os.path.exists(os.path.join(model_dir, name))
    )
    return next(present, None)
@dataclass
class PackedRecord:
    """A stored text with its id, metadata and optional embedding."""

    id: str
    text: str
    meta: Dict[str, Any]
    # No embedding attached by default.
    embedding: Optional[np.ndarray] = field(default=None)
@dataclass
class PackedTreeSnapshot:
    """Serialized form of a tree: a version number plus one blob per part."""

    version: int
    docs_blob: bytes
    metas_blob: bytes
    ids_blob: bytes
    embs_blob: bytes
    # Optional extra payload; empty bytes when absent.
    extra_blob: bytes = field(default=b"")
class PackedTree:
def __init__(self, name: str, embed_fn: Callable[..., np.ndarray], cluster_k: int = 4):
    """Create an empty tree.

    Args:
        name: Tree name (used in error messages).
        embed_fn: Callable producing the embedding for a text.
        cluster_k: Cluster count setting; coerced with ``int``.
    """
    self.name = name
    self.embed_fn = embed_fn
    self.cluster_k = int(cluster_k)
    # Parallel per-document stores, all indexed by insertion position.
    self.docs: List[str] = []
    self.metas: List[Dict[str, Any]] = []
    self.ids: List[str] = []
    # Raw and normalized embedding matrices start out as empty (0, 0) arrays.
    self.embs: np.ndarray = np.empty((0, 0), dtype=np.float32)
    self.norm_embs: np.ndarray = np.empty((0, 0), dtype=np.float32)
    # Lookup tables: item id -> position, text hash -> item id.
    self.id_to_idx: Dict[str, int] = {}
    self.hash_to_id: Dict[str, str] = {}
    # Caches and synchronization.
    self.query_cache: OrderedDict = OrderedDict()
    self.cluster_cache: Dict[str, Any] = {}
    self._lock = threading.RLock()
    self._clusters_dirty = True
@staticmethod
def norm_text(text: str) -> str:
    """Return ``text`` passed through ``_norm_ws`` and lower-cased."""
    collapsed = _norm_ws(text)
    return collapsed.lower()
@staticmethod
def text_hash(text: str) -> str:
    """Return the SHA-256 hex digest of the normalized, UTF-8 encoded text."""
    canonical = PackedTree.norm_text(text).encode("utf-8")
    return hashlib.sha256(canonical).hexdigest()
def _cache_get(self, key):
    """Return the cached value for ``key`` (or ``None``), marking it most recent.

    Entries whose stored value is ``None`` are returned without being moved.
    """
    hit = self.query_cache.get(key)
    if hit is None:
        return hit
    self.query_cache.move_to_end(key)
    return hit
def _cache_put(self, key, value, max_size: int = 512):
    """Store ``value`` as the most recent entry and evict the oldest extras.

    After insertion, entries are dropped from the least-recent end until at
    most ``max_size`` remain.
    """
    cache = self.query_cache
    cache[key] = value
    cache.move_to_end(key)
    surplus = len(cache) - max_size
    for _ in range(max(0, surplus)):
        cache.popitem(last=False)
def add(self, text: str, meta: Optional[Dict[str, Any]] = None, item_id: Optional[str] = None) -> str:
    """Insert ``text`` with its embedding and return the item id.

    Args:
        text: Document text; whitespace-normalized via ``_norm_ws``.
        meta: Optional metadata (copied). ``id`` and ``hash`` entries, when
            present, are used as the item id / dedup hash. ``id``, ``hash``
            and ``timestamp`` are filled in when missing.
        item_id: Explicit id; falls back to ``meta["id"]`` or a new UUID.

    Returns:
        ``""`` for empty text, the existing id when the hash is already
        stored, otherwise the id of the new item.

    Raises:
        ValueError: if the embedding length differs from the stored
            embeddings. The tree is left unmodified in that case.
    """
    text = _norm_ws(text)
    if not text:
        return ""
    with self._lock:
        meta = dict(meta or {})
        item_id = item_id or meta.get("id") or str(uuid.uuid4())
        doc_hash = meta.get("hash") or self.text_hash(text)
        if doc_hash in self.hash_to_id:
            return self.hash_to_id[doc_hash]
        emb = np.asarray(self.embed_fn(text), dtype=np.float32)
        if emb.ndim != 1:
            emb = emb.reshape(-1)
        # Validate before touching any index or list: raising after partial
        # mutation would leave docs/ids longer than the embedding matrix.
        if self.embs.size != 0 and self.embs.shape[1] != emb.shape[0]:
            raise ValueError(f"Embedding dimension mismatch in tree '{self.name}': {emb.shape[0]} != {self.embs.shape[1]}")
        self.id_to_idx[item_id] = len(self.docs)
        self.hash_to_id[doc_hash] = item_id
        meta.setdefault("id", item_id)
        meta.setdefault("hash", doc_hash)
        meta.setdefault("timestamp", _now_iso())
        self.docs.append(text)
        self.metas.append(meta)
        self.ids.append(item_id)
        row = emb.reshape(1, -1)
        if self.embs.size == 0:
            self.embs = row.astype(np.float32)
        else:
            self.embs = np.vstack([self.embs, row])
        self.norm_embs = self._normalize_embeddings(self.embs)
        self._clusters_dirty = True
        return item_id
def bulk_add(self, items: Sequence[Tuple[str, Dict[str, Any], Optional[str]]]) -> List[str]:
    """Add each ``(text, meta, item_id)`` triple via ``add``; return the ids in order."""
    return [
        self.add(text, meta=meta, item_id=item_id)
        for text, meta, item_id in items
    ]
def update_meta(self, item_id: str, patch: Mapping[str, Any]):
    """Merge ``patch`` into the metadata of ``item_id``; unknown ids are ignored.

    A successful update marks the clusters as dirty.
    """
    with self._lock:
        position = self.id_to_idx.get(item_id)
        if position is not None:
            self.metas[position].update(dict(patch))
            self._clusters_dirty = True
def record_usage(self, item_id: str):
    """Increment ``usage_count`` and refresh ``last_used`` for ``item_id``.

    Unknown ids are ignored.
    """
    with self._lock:
        position = self.id_to_idx.get(item_id)
        if position is not None:
            entry = self.metas[position]
            entry["usage_count"] = int(entry.get("usage_count", 0)) + 1
            entry["last_used"] = _now_iso()
@staticmethod
def _compress(obj: Any) -> bytes:
    """Serialize ``obj`` to UTF-8 JSON and LZMA-compress it at preset 9.

    Values JSON cannot encode natively are handled by ``_json_default``.
    """
    as_json = json.dumps(obj, ensure_ascii=False, default=_json_default)
    return lzma.compress(as_json.encode("utf-8"), preset=9)
@staticmethod
def _decompress(blob: bytes, default: Any = None) -> Any:
if not blob:
return default
try:
raw = lzma.decompress(blob)
return json.loads(raw.decode("utf-8"))
except Exception:
return default
def snapshot(self) -> PackedTreeSnapshot:
    """Pack the tree's documents, metadata, ids, embeddings and index maps into a ``PackedTreeSnapshot``.

    Embeddings are stored as LZMA-compressed float16 bytes; their shape is
    recorded in the ``extra`` blob next to ``cluster_k`` and the lookup maps.
    """
    with self._lock:
        if self.embs.size:
            half_precision = self.embs.astype(np.float16)
            packed_embs = lzma.compress(half_precision.tobytes(), preset=9)
        else:
            packed_embs = b""
        side_info = {
            "shape": list(self.embs.shape),
            "dtype": "float16",
            "cluster_k": self.cluster_k,
            "hash_to_id": self.hash_to_id,
            "id_to_idx": self.id_to_idx,
        }
        pack = self._compress
        return PackedTreeSnapshot(
            version=1,
            docs_blob=pack(self.docs),
            metas_blob=pack(self.metas),
            ids_blob=pack(self.ids),
            embs_blob=packed_embs,
            extra_blob=pack(side_info),
        )
def restore(self, snap: PackedTreeSnapshot):
with self._lock:
self.docs = self._decompress(snap.docs_blob, default=[])
self.metas = self._decompress(snap.metas_blob, default=[])
self.ids = self._decompress(snap.ids_blob, default=[])
extra = self._decompress(snap.extra_blob, default={}) or {}
shape = tuple(extra.get("shape") or [0, 0])
self.cluster_k = int(extra.get("cluster_k", self.cluster_k))
self.hash_to_id = dict(extra.get("hash_to_id", {}))
self.id_to_idx = {k: int(v) for k, v in dict(extra.get("id_to_idx", {})).items()}
if snap.embs_blob and shape and shape[0] > 0 and shape[1] > 0:
raw = lzma.decompress(snap.embs_blob)
arr = np.frombuffer(raw, dtype=np.float16).reshape(shape).astype(np.float32)
self.embs = arr
self.norm_embs = self._normalize_embeddings(arr)
else:
self.embs = np.empty((0, 0), dtype=np.float32)
self.norm_embs = np.empty((0, 0), dtype=np.float32)
self.query_cache = OrderedDict()
self.cluster_cache = {}
self._clusters_dirty = True
@staticmethod
def _normalize_embeddings(embs: np.ndarray) -> np.ndarray:
if embs.size == 0:
return embs
norms = np.linalg.norm(embs, axis=1, keepdims=True)
norms[norms == 0] = 1.0
return embs / norms
@staticmethod
def _cosine_scores(query_emb: np.ndarray, matrix: np.ndarray) -> np.ndarray:
if matrix.size == 0:
return np.array([], dtype=np.float32)
q = query_emb.astype(np.float32).reshape(1, -1)
qn = q / np.maximum(np.linalg.norm(q, axis=1, keepdims=True), 1e-8)
mn = PackedTree._normalize_embeddings(matrix.astype(np.float32))
return (qn @ mn.T)[0]
def _build_clusters(self):
    """Rebuild ``cluster_cache`` from the normalised embeddings and clear the dirty flag.

    Without scikit-learn's ``KMeans``, or with fewer than two rows, every
    document lands in cluster 0 and ``centers`` is ``None``. When the
    effective ``k`` is 1, the first normalised row serves as the only center.
    """
    if KMeans is None or self.embs.shape[0] < 2:
        centers = None
        groups = {0: {"idxs": list(range(len(self.docs)))}}
    else:
        k = min(self.cluster_k, self.embs.shape[0])
        if k <= 1:
            centers = self.norm_embs[:1]
            groups = {0: {"idxs": list(range(len(self.docs)))}}
        else:
            model = KMeans(n_clusters=k, random_state=0, n_init="auto")
            labels = model.fit_predict(self.norm_embs)
            centers = model.cluster_centers_.astype(np.float32)
            # Pre-create every bucket so clusters that end up empty still exist.
            groups: Dict[int, Dict[str, Any]] = {i: {"idxs": []} for i in range(k)}
            for row, label in enumerate(labels):
                groups[int(label)]["idxs"].append(row)
    self.cluster_cache = {"centers": centers, "clusters": groups}
    self._clusters_dirty = False
def search(self, query: str, top_k: int = 5, min_score: float = 0.0, hybrid: bool = True, use_clusters: bool = False) -> List[Dict[str, Any]]:
    """Rank stored passages against ``query`` by cosine similarity.

    Args:
        query: Search text; whitespace-normalised with ``_norm_ws``.
        top_k: Maximum number of hits. Non-positive values yield ``[]``.
        min_score: Minimum raw cosine similarity for a hit.
        hybrid: When true, blend in a keyword-overlap term.
        use_clusters: When true, rebuild the cluster cache if it is dirty;
            the ranking itself does not use the clusters.

    Returns:
        Hits in descending raw-similarity order. Each is a dict with ``id``,
        ``passage``, ``raw_similarity``, ``score`` (clamped to [0, 1]) and
        ``metadata``. Non-empty calls are memoised in the query cache.
    """
    query = _norm_ws(query)
    # top_k <= 0 used to return one hit: the cap was only tested after appending.
    if not query or top_k <= 0:
        return []
    qkey = (query, top_k, min_score, hybrid, use_clusters)
    cached = self._cache_get(qkey)
    if cached is not None:
        return cached
    if self.embs.size == 0:
        return []
    q_emb = np.asarray(self.embed_fn(query), dtype=np.float32).reshape(-1)
    scores = self._cosine_scores(q_emb, self.embs)
    if scores.size == 0:
        return []
    order = np.argsort(scores)[::-1]
    results: List[Dict[str, Any]] = []
    seen = set()
    q_tokens = set(self.norm_text(query).split())
    for raw_idx in order:
        idx = int(raw_idx)
        score = float(scores[idx])
        if score < min_score:
            # Scores are visited in descending order, so nothing later can pass.
            break
        doc = self.docs[idx]
        doc_norm = self.norm_text(doc)
        if doc_norm in seen:
            continue
        seen.add(doc_norm)
        md = self.metas[idx]
        kw = 0.0
        if hybrid and q_tokens:
            d_tokens = set(doc_norm.split())
            kw = len(q_tokens.intersection(d_tokens)) / max(1.0, (len(q_tokens) + len(d_tokens)) / 2.0)
        final = 0.75 * score + 0.2 * kw + 0.05 * float(md.get("importance", 0.5))
        results.append({
            "id": self.ids[idx],
            "passage": doc,
            "raw_similarity": score,
            "score": max(0.0, min(1.0, final)),
            "metadata": md,
        })
        if len(results) >= top_k:
            break
    if use_clusters and self._clusters_dirty:
        self._build_clusters()
    self._cache_put(qkey, results)
    return results
def retrieve_by_semantics(self, query: str, num_clusters: int = 2, top_k_per_cluster: int = 3, min_score: float = 0.0) -> List[Dict[str, Any]]:
    """Retrieve passages from the clusters whose centers are closest to ``query``.

    The ``num_clusters`` best clusters each contribute up to
    ``top_k_per_cluster`` hits, scored with the same blend ``search`` uses.
    If no cluster centers are available the call delegates to ``search``.
    If fewer than ``top_k_per_cluster`` hits were found, ``search`` results
    not yet seen are appended.

    Returns:
        A list of hit dicts with ``id``, ``passage``, ``raw_similarity``,
        ``score`` and ``metadata``.
    """
    query = _norm_ws(query)
    if not query:
        return []
    if self.embs.size == 0:
        return []
    if self._clusters_dirty:
        self._build_clusters()
    centers = self.cluster_cache.get("centers")
    clusters = self.cluster_cache.get("clusters") or {}
    budget = num_clusters * top_k_per_cluster
    if centers is None:
        return self.search(query, top_k=budget, min_score=min_score, hybrid=True)
    q_vec = np.asarray(self.embed_fn(query), dtype=np.float32).reshape(-1)
    center_sims = self._cosine_scores(q_vec, centers)
    top_cluster_ids = np.argsort(center_sims)[::-1][:min(num_clusters, len(center_sims))]
    results: List[Dict[str, Any]] = []
    seen = set()
    # The query tokens do not change per hit; compute them once.
    q_tokens = set(self.norm_text(query).split())
    for cid in top_cluster_ids:
        idxs = clusters.get(int(cid), {}).get("idxs", [])
        if not idxs:
            continue
        sims = self._cosine_scores(q_vec, self.norm_embs[idxs])
        top_local = np.argsort(sims)[::-1][:top_k_per_cluster]
        for local_idx in top_local:
            global_idx = idxs[int(local_idx)]
            raw = float(sims[int(local_idx)])
            if raw < min_score:
                continue
            doc = self.docs[global_idx]
            doc_norm = self.norm_text(doc)
            if doc_norm in seen:
                continue
            seen.add(doc_norm)
            md = self.metas[global_idx]
            d_tokens = set(doc_norm.split())
            kw = len(q_tokens.intersection(d_tokens)) / max(1.0, (len(q_tokens) + len(d_tokens)) / 2.0)
            final = 0.75 * raw + 0.2 * kw + 0.05 * float(md.get("importance", 0.5))
            results.append({
                "id": self.ids[global_idx],
                "passage": doc,
                "raw_similarity": raw,
                "score": max(0.0, min(1.0, final)),
                "metadata": md,
            })
            if len(results) >= budget:
                break
    if len(results) < top_k_per_cluster:
        # Too few cluster hits: top up from a plain similarity search.
        extra = self.search(query, top_k=budget, min_score=min_score, hybrid=True)
        for item in extra:
            passage_norm = self.norm_text(item["passage"])
            if passage_norm not in seen:
                seen.add(passage_norm)
                results.append(item)
    return results
def _json_default(obj: Any):
if isinstance(obj, (np.integer, np.floating)):
return obj.item()
if isinstance(obj, np.ndarray):
return obj.tolist()
if isinstance(obj, (set, tuple)):
return list(obj)
if dataclasses.is_dataclass(obj):
return asdict(obj)
if isinstance(obj, bytes):
return base64.b64encode(obj).decode("ascii")
if isinstance(obj, Path):
return str(obj)
raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
def _extract_response_text_from_result(result: Any) -> str:
def coerce_to_str(v):
if isinstance(v, str):
return v
if isinstance(v, (list, tuple)):
pieces = []
for x in v:
if isinstance(x, str) and x.strip():
pieces.append(x)
elif isinstance(x, dict):
for k in ("content", "response", "assistant", "final"):
if k in x and isinstance(x[k], str) and x[k].strip():
pieces.append(x[k])
break
else:
for val in x.values():
if isinstance(val, str) and val.strip():
pieces.append(val)
else:
try:
pieces.append(str(x))
except Exception:
pass
return "\n".join(pieces)
if isinstance(v, dict):
for key in ("response", "assistant", "final", "content"):
val = v.get(key)
if isinstance(val, str) and val.strip():
return val
if isinstance(val, (list, tuple, dict)):
s = coerce_to_str(val)
if s:
return s
vals = [str(x) for x in v.values() if isinstance(x, str) and x.strip()]
if vals:
return "\n".join(vals)
try:
return json.dumps(v)
except Exception:
return str(v)
try:
return str(v)
except Exception:
return ""
if isinstance(result, dict):
for key in ("blocks", "response", "assistant", "final", "content"):
if key in result and result[key]:
return coerce_to_str(result[key])
vals = [v for v in result.values() if isinstance(v, (str, list, dict)) and v]
if vals:
return coerce_to_str(vals[0])
return json.dumps(result)
return coerce_to_str(result)
class DesktopControl:
    """Best-effort desktop automation helpers backed by optional modules.

    Optional modules are imported on first use. Methods that drive the Win32
    ``user32`` API return silently when it is unavailable (for example on
    non-Windows hosts, where ``ctypes`` imports fine but has no ``windll``).
    """

    def __init__(self):
        # None until the first _lazy_import() call, then a name -> module-or-None map.
        self._available = None

    def _lazy_import(self):
        """Import the optional automation modules once and return a name -> module (or ``None``) map."""
        if self._available is not None:
            return self._available
        mods = {}
        for name in ("pyautogui", "keyboard", "mouse", "psutil", "win32gui", "pygetwindow", "ctypes"):
            try:
                mods[name] = importlib.import_module(name)
            except Exception:
                # Some of these raise non-ImportError exceptions on import.
                mods[name] = None
        self._available = mods
        return mods

    def _user32(self):
        """Return ``ctypes.windll.user32``, or ``None`` when the Win32 API is not available."""
        ctypes_mod = self._lazy_import().get("ctypes")
        windll = getattr(ctypes_mod, "windll", None)
        return getattr(windll, "user32", None)

    def get_location_string(self) -> str:
        """Return ``"City/State/Country"`` from an IP lookup via ``geocoder``, with ``Unknown*`` placeholders on failure."""
        try:
            import geocoder  # type: ignore
            g = geocoder.ip("me")
            city = g.city or "UnknownCity"
            state = g.state or "UnknownState"
            country = g.country or "UnknownCountry"
            return f"{city}/{state}/{country}"
        except Exception:
            return "UnknownCity/UnknownState/UnknownCountry"

    def get_time_string(self) -> str:
        """Return the local date (``dd/mm/YYYY``) and lower-cased 12-hour time, separated by a newline."""
        now = datetime.now()
        date_str = now.strftime("%d/%m/%Y")
        time_str = now.strftime("%I:%M:%S/%p").lower()
        return f"{date_str}\n{time_str}"

    def is_desktop_active(self) -> bool:
        """Return whether the foreground window is the desktop window; ``False`` without ``win32gui``."""
        win32gui = self._lazy_import().get("win32gui")
        if win32gui is None:
            return False
        desktop_hwnd = win32gui.GetDesktopWindow()
        active_hwnd = win32gui.GetForegroundWindow()
        return active_hwnd == desktop_hwnd

    def is_program_active(self, program_name: str) -> bool:
        """Return whether a process named ``program_name`` is running and that name appears in the active window title."""
        mods = self._lazy_import()
        gw = mods.get("pygetwindow")
        psutil_mod = mods.get("psutil")
        if gw is None or psutil_mod is None:
            return False
        active_window = gw.getActiveWindow()
        if active_window:
            active_title = active_window.title or ""
            wanted = program_name.lower()
            for process in psutil_mod.process_iter(["pid", "name"]):
                name = (process.info.get("name") or "").lower()
                if name == wanted:
                    return wanted in active_title.lower()
        return False

    def fast_move(self, x, y):
        """Move the cursor to ``(x, y)`` through ``user32.SetCursorPos``; no-op without the Win32 API."""
        user32 = self._user32()
        if user32 is None:
            return
        user32.SetCursorPos(x, y)

    def scroll_mouse(self, delta):
        """Send a wheel event of ``delta * 120`` through ``user32.mouse_event``; no-op without the Win32 API."""
        user32 = self._user32()
        if user32 is None:
            return
        user32.mouse_event(0x0800, 0, 0, int(delta * 120), 0)

    def press_special_key(self, key):
        """Press and release one of the named media/volume keys; unknown names are ignored."""
        user32 = self._user32()
        if user32 is None:
            return
        special_keys = {
            "volume up": 0xAF,
            "volume down": 0xAE,
            "volume mute": 0xAD,
            "play/pause media": 0xB3,
            "next track": 0xB0,
            "prev track": 0xB1,
        }
        vk_code = special_keys.get(key)
        if vk_code is None:
            return
        user32.keybd_event(vk_code, 0, 0, 0)
        user32.keybd_event(vk_code, 0, 2, 0)  # flag 2: key-up event
        time.sleep(0.1)

    def minimize_all_windows(self):
        """Send the key sequence 0x5B down, 0x4D down/up, 0x5B up, then wait a second; no-op without the Win32 API."""
        user32 = self._user32()
        if user32 is None:
            return
        user32.keybd_event(0x5B, 0, 0, 0)
        user32.keybd_event(0x4D, 0, 0, 0)
        user32.keybd_event(0x4D, 0, 2, 0)
        user32.keybd_event(0x5B, 0, 2, 0)
        time.sleep(1)

    def run_pyautogui_command(self, command_name: str):
        """Call the zero-argument ``pyautogui`` function named ``command_name``.

        Raises:
            RuntimeError: If ``pyautogui`` could not be imported.
        """
        pyautogui = self._lazy_import().get("pyautogui")
        if pyautogui is None:
            raise RuntimeError("pyautogui is not installed")
        getattr(pyautogui, command_name)()
@dataclass
class ActionDecision:
    """One action parsed from a decision mapping (see ``GATOR._decision_to_actions``).

    Every field except ``type`` is optional; the parser fills missing keys
    with the defaults declared here.
    """
    # Action kind; the parser stores str(item.get("type", "None")). "retrieval"
    # is the value selected by GATOR._merge_retrieval_actions.
    type: str
    # Command identification and payload, copied verbatim from the decision mapping.
    command_name: Optional[str] = None
    command_description: Optional[str] = None
    command_text: Optional[str] = None
    # Store names the action wants to read; checked for "KnowledgeTree" / "ProfileTree".
    memory_access: List[str] = field(default_factory=list)
    # Retrieval source; compared against "KnowledgeTree" / "ProfileTree".
    source: Optional[str] = None
    # Which profile to query; GATOR.execute_retrieval_action handles
    # "user_profile" and "bot_profile".
    profile: Optional[str] = None
    # Query text for retrieval actions.
    query: Optional[str] = None
    # Flag copied from the decision mapping via bool(...).
    sufficient_to_answer: bool = False
    # Free-form extra arguments, copied from the decision mapping.
    parameters: Dict[str, Any] = field(default_factory=dict)
class CommandRegistry:
    """Registry and dispatcher for basic, shortcut-file and custom (recorded) commands.

    * basic commands map a name to a Python callable;
    * shortcuts are files under ``shortcuts_dir`` whose base name occurs in the text;
    * custom commands are phrases mapped to recorded input events.
    """

    def __init__(self, owner: "GATOR"):
        self.owner = owner
        self.commands: Dict[str, Dict[str, Any]] = {}
        self.custom_commands: Dict[str, Dict[str, Any]] = {}
        self.shortcuts_dir = "./Mods/COMMANDS"
        self._lock = threading.RLock()

    @staticmethod
    def _open_path(path):
        """Open ``path`` with the default handler on hosts without ``os.startfile``."""
        if sys.platform == "darwin":
            subprocess.Popen(["open", path])
        else:
            subprocess.Popen(["xdg-open", path])

    @staticmethod
    def _opener():
        """Return ``os.startfile`` where it exists (Windows), else the portable fallback."""
        return getattr(os, "startfile", None) or CommandRegistry._open_path

    def register_command(self, name: str, description: str, action: Optional[Callable] = None, command_type: str = "basic"):
        """Create or replace the basic command ``name``."""
        with self._lock:
            self.commands[name] = {
                "name": name,
                "description": description,
                "action": action,
                "type": command_type,
            }

    def update_command(self, name: str, description: Optional[str] = None, action: Optional[Callable] = None, command_type: Optional[str] = None):
        """Update the given fields of command ``name``, creating an empty entry first if needed."""
        with self._lock:
            entry = self.commands.setdefault(
                name, {"name": name, "description": "", "action": None, "type": "basic"}
            )
            if description is not None:
                entry["description"] = description
            if action is not None:
                entry["action"] = action
            if command_type is not None:
                entry["type"] = command_type

    def register_custom_command(self, command_name: str, phrase: str, description: str, actions: Optional[List[Dict[str, Any]]] = None):
        """Store a custom command and add it as a branch on the owner's ``command_tree``."""
        with self._lock:
            self.custom_commands[command_name] = {
                "phrase": phrase,
                "description": description,
                "actions": actions or [],
            }
            self.owner.command_tree.add_command_branch(
                command_name=phrase,
                command_action=command_name,
                command_type="custom",
                description=description,
            )

    def check_custom_commands(self, text: str, matched_commands: Optional[List[str]] = None):
        """Append the names of custom commands whose phrase occurs in ``text``.

        Returns:
            The list of matches (the caller's list when one was passed), or
            ``False`` when it is empty.
        """
        # `x or []` would silently replace a caller-supplied empty list.
        if matched_commands is None:
            matched_commands = []
        text_lower = (text or "").lower()
        for command_name, data in self.custom_commands.items():
            phrase = (data.get("phrase") or "").lower()
            if phrase and phrase in text_lower:
                matched_commands.append(command_name)
        return matched_commands if matched_commands else False

    def check_basic_commands(self, text: str, matched_commands: Optional[List[Tuple[Callable, List[Any]]]] = None):
        """Append ``(action, [])`` to ``matched_commands`` for each basic command named in ``text``.

        Returns:
            ``True`` when ``matched_commands`` is non-empty afterwards.
        """
        # Keep the caller's list even when it is empty, so the matches reach
        # process_commands(); `x or []` threw them away.
        if matched_commands is None:
            matched_commands = []
        text_lower = (text or "").lower()
        for command_dict in self.commands.values():
            action = command_dict.get("action")
            if action is None:
                continue
            aliases = command_dict.get("commands") or [command_dict.get("name")]
            for cmd in aliases:
                if cmd and cmd.lower() in text_lower:
                    matched_commands.append((action, []))
        return bool(matched_commands)

    def check_shortcuts(self, text: str, shortcut_dir: Optional[str] = None):
        """Return ``(opener, [path])`` actions for shortcut files whose base name occurs in ``text``.

        ``shortcut_dir`` (default ``self.shortcuts_dir``) is created if missing
        and searched recursively.
        """
        shortcut_dir = shortcut_dir or self.shortcuts_dir
        os.makedirs(shortcut_dir, exist_ok=True)
        return self._check_folder(shortcut_dir, (text or "").lower())

    def _matching_paths(self, base_dir, text_lower):
        """Recursively collect files under ``base_dir`` whose lower-cased base name occurs in ``text_lower``."""
        found = []
        for entry in os.listdir(base_dir):
            entry_path = os.path.join(base_dir, entry)
            if os.path.isdir(entry_path):
                found.extend(self._matching_paths(entry_path, text_lower))
            elif os.path.isfile(entry_path) and os.path.splitext(entry)[0].lower() in text_lower:
                found.append(entry_path)
        return found

    def _check_folder(self, base_dir, text_lower):
        """Return ``(opener, [path])`` actions for every matching file below ``base_dir``."""
        opener = self._opener()
        # Wrap once at the top: wrapping inside the recursion nested the tuples.
        return [(opener, [path]) for path in self._matching_paths(base_dir, text_lower)]

    @staticmethod
    def _resolve_placeholder(arg, argument_dictionary):
        """Replace a ``"{key}"`` string with ``argument_dictionary[key]`` (``""`` if absent)."""
        if isinstance(arg, str) and arg.startswith("{") and arg.endswith("}"):
            return argument_dictionary.get(arg.strip("{}"), "")
        return arg

    def execute_command(self, commands, command_executed_tags, argument_dictionary=None):
        """Call each ``(action, args)`` pair and record a tag per executed action.

        When ``argument_dictionary`` is given, ``"{key}"`` placeholder
        arguments are substituted from it before the call.
        """
        for action, args in commands:
            call_args = list(args or [])
            if argument_dictionary:
                call_args = [self._resolve_placeholder(arg, argument_dictionary) for arg in call_args]
            action(*call_args)
            command_executed_tags.append(f"Executed action: {getattr(action, '__name__', str(action))}")

    def execute_shortcut(self, actions, command_executed_tags):
        """Open the file behind each shortcut action, appending an outcome tag for every one."""
        opener = self._opener()
        for _action_fn, args in actions:
            file_path = os.path.abspath(args[0])
            if not os.path.exists(file_path):
                command_executed_tags.append(f"Path not found: {file_path}")
                continue
            if os.path.isdir(file_path):
                command_executed_tags.append(f"Skipped directory: {file_path}")
                continue
            try:
                opener(file_path)
                command_executed_tags.append(f"Opened: {file_path}")
                continue
            except Exception:
                pass
            if os.name != "nt":
                # The `cmd /c start` fallback only exists on Windows.
                command_executed_tags.append(f"Failed to open: {file_path}")
                continue
            try:
                subprocess.Popen(["cmd", "/c", "start", "", file_path], shell=True)
                command_executed_tags.append(f"Opened via cmd start: {file_path}")
            except Exception:
                command_executed_tags.append(f"Failed to open: {file_path}")

    def execute_custom_command(self, command_name):
        """Replay the recorded input events of the custom command ``command_name``.

        Key and mouse-button events go through ``pyautogui`` (imported on
        demand, so a missing install raises ``ImportError``). Mouse moves and
        scrolls go through the owner's ``desktop`` helper.
        """
        payload = self.custom_commands.get(command_name)
        if not payload:
            return
        for action in payload.get("actions") or []:
            kind = action.get("event")
            if kind == "mouse_move":
                self.owner.desktop.fast_move(action["x"], action["y"])
            elif kind == "mouse_scroll":
                self.owner.desktop.scroll_mouse(action["delta"])
            elif kind in ("key_down", "key_up", "mouse_down", "mouse_up"):
                import pyautogui  # type: ignore
                if kind == "key_down":
                    pyautogui.keyDown(action["key"])
                elif kind == "key_up":
                    pyautogui.keyUp(action["key"])
                elif kind == "mouse_down":
                    pyautogui.mouseDown(button=action["button"])
                else:
                    pyautogui.mouseUp(button=action["button"])

    def process_commands(self, command, command_type="shortcut", argument_dictionary=None):
        """Match ``command`` against the registry for ``command_type`` and execute the matches.

        Returns:
            The executed actions: ``(callable, args)`` tuples for basic and
            shortcut commands, command names for custom commands.
        """
        command_executed_tags = []
        executed_actions = []
        if command_type == "basic":
            basic_matches = []
            if self.check_basic_commands(command, basic_matches):
                self.execute_command(basic_matches, command_executed_tags, argument_dictionary)
                executed_actions.extend(basic_matches)
        elif command_type == "shortcut":
            shortcut_actions = self.check_shortcuts(command, self.shortcuts_dir)
            if shortcut_actions:
                self.execute_shortcut(shortcut_actions, command_executed_tags)
                executed_actions.extend(shortcut_actions)
        elif command_type == "custom":
            found = self.check_custom_commands(command)
            if found:
                for cmd_name in found:
                    self.execute_custom_command(cmd_name)
                    executed_actions.append(cmd_name)
        return executed_actions
# ---------------------------------------------------------------------------
# GATOR module
# ---------------------------------------------------------------------------
class GATOR(nn.Module):
STATE_VERSION = 1
def __init__(
    self,
    lm_checkpoint_path: str = "LM.pt",
    embedder_name: str = "second-state/jina-embeddings-v3-GGUF",
    embedder_local_dir: str = os.path.join("models", "jinaai"),
    embedder_filename: str = GGUF_EMBED_FILENAME,
    device: str = "cpu",
    warm_on_start: bool = True,
    compression: str = "lzma",
    store_dtype: str = "float16",
    cluster_k: int = 4,
    auto_load_lm: bool = True,
    strict_lm: bool = True,
    embedder_pack: Optional[Dict[str, Any]] = None,
):
    """Load the LM and the GGUF embedder, then create one empty ``PackedTree`` per memory store.

    Args:
        lm_checkpoint_path: Checkpoint passed to ``_load_lm``.
        embedder_name: Embedder identifier; only recorded.
        embedder_local_dir: Directory searched for the GGUF file.
        embedder_filename: GGUF file name to look for.
        device, compression, store_dtype, strict_lm: Recorded on the instance and in ``config``.
        warm_on_start: Call ``warmup()`` at the end of construction.
        cluster_k: Cluster count handed to every ``PackedTree``.
        auto_load_lm: When false, ``self.lm`` is left as ``None``.
        embedder_pack: Pre-packed embedder; used instead of the on-disk GGUF
            when it carries ``gguf_bytes``.
    """
    super().__init__()
    self.config = dict(
        lm_checkpoint_path=lm_checkpoint_path,
        embedder_name=embedder_name,
        embedder_local_dir=embedder_local_dir,
        embedder_filename=embedder_filename,
        device=device,
        warm_on_start=warm_on_start,
        compression=compression,
        store_dtype=store_dtype,
        cluster_k=cluster_k,
        auto_load_lm=auto_load_lm,
        strict_lm=strict_lm,
    )
    self.device_name = device
    self.compression = compression
    self.store_dtype = store_dtype
    self.cluster_k = int(cluster_k)
    self.strict_lm = strict_lm
    self._lock = threading.RLock()
    self._snapshot_cache: Dict[str, bytes] = {}
    self._runtime_cache: OrderedDict = OrderedDict()
    self._runtime_cache_max = 16
    self._last_route: Dict[str, Any] = {}
    self._last_response: str = ""
    self._last_plan: Dict[str, Any] = {}
    self.desktop = DesktopControl()
    self.command_registry = CommandRegistry(self)
    self.lm = self._load_lm(lm_checkpoint_path, auto_load=auto_load_lm)
    self.embedder_name = embedder_name
    self.embedder_local_dir = embedder_local_dir
    self.embedder_filename = embedder_filename
    pack_supplied = embedder_pack is not None and embedder_pack.get("gguf_bytes") is not None
    if pack_supplied:
        self.embedder_pack = embedder_pack
        self.embedder = self._restore_embedder_from_pack(self.embedder_pack)
        self.embedder_path = self.embedder_pack.get("gguf_source_path", "")
    else:
        self.embedder_path = self._resolve_local_embedder_gguf(embedder_local_dir, embedder_filename)
        self.embedder_pack = self._load_embedder_pack(self.embedder_path)
        self.embedder = self._restore_embedder_from_pack(self.embedder_pack)
    self.embedder_tokenizer = None
    # Embed one throwaway string to learn the embedding width; fall back to 1024.
    probe = self._embed_raw(["__gator_probe__"], task="retrieval.passage")
    if probe.ndim == 2 and probe.shape[-1] > 0:
        self.embed_dim = int(probe.shape[-1])
    else:
        self.embed_dim = 1024
    store_names = (
        "knowledge",
        "conversation",
        "profile_user",
        "profile_bot",
        "commands",
        "assets",
        "telemetry",
    )
    self._store: Dict[str, PackedTree] = {
        store_name: PackedTree(store_name, self.embed, cluster_k=self.cluster_k)
        for store_name in store_names
    }
    self._command_phrases: Dict[str, Dict[str, Any]] = {}
    self._warm_on_start = warm_on_start
    if warm_on_start:
        self.warmup()
@staticmethod
def _bytes_to_uint8_tensor(data: bytes) -> torch.Tensor:
arr = np.frombuffer(data, dtype=np.uint8)
return torch.from_numpy(arr.copy())
@staticmethod
def _uint8_tensor_to_bytes(t: torch.Tensor) -> bytes:
return bytes(t.detach().cpu().contiguous().numpy().tobytes())
def __getstate__(self):
state = self.__dict__.copy()
state["embedder"] = None
state["embedder_tokenizer"] = None
state["embedder_pack"] = self._snapshot_embedder_pack()
state["lm"] = self._snapshot_lm_handle()
state["_lock"] = None
return state
def __setstate__(self, state):
    """Restore pickled state, recreating the lock, helper objects, embedder and (if missing) the stores."""
    self.__dict__.update(state)
    self._lock = threading.RLock()
    if not isinstance(self.desktop, DesktopControl):
        self.desktop = DesktopControl()
    if not isinstance(self.command_registry, CommandRegistry):
        self.command_registry = CommandRegistry(self)
    if self.embedder is None:
        # __getstate__ drops the live embedder; rebuild it from the packed GGUF bytes.
        self.embedder = self._restore_embedder_from_pack(self.embedder_pack)
        self.embedder_tokenizer = None
    self.lm = self._restore_lm_handle(self.lm)
    if not hasattr(self, "_store"):
        store_names = (
            "knowledge",
            "conversation",
            "profile_user",
            "profile_bot",
            "commands",
            "assets",
            "telemetry",
        )
        self._store = {
            store_name: PackedTree(store_name, self.embed, cluster_k=self.cluster_k)
            for store_name in store_names
        }
def _load_lm(self, lm_checkpoint_path: str, auto_load: bool = True):
if not auto_load:
return None
if load_packedlm is None:
raise RuntimeError("PackedLM.load_packedlm is unavailable. Import PackedLM before GATOR.")
if not os.path.exists(lm_checkpoint_path):
raise FileNotFoundError(f"LM checkpoint not found: {lm_checkpoint_path}")
return load_packedlm(lm_checkpoint_path)
def _snapshot_lm_handle(self):
return self.lm
def _restore_lm_handle(self, packed):
return packed
def _resolve_local_embedder_gguf(self, local_dir: str, embedder_filename: str) -> str:
candidates = [
Path(local_dir).resolve() if local_dir else None,
Path("models").resolve(),
(Path("models") / "jinaai").resolve(),
(Path("models") / "jinaai" / "jina-embeddings-v3").resolve(),
]
candidates = [p for p in candidates if p is not None]
for root in candidates:
if not root.exists():
continue
direct = root / embedder_filename
if direct.is_file():
return str(direct)
for p in root.rglob(embedder_filename):
if p.is_file():
return str(p)
raise FileNotFoundError(
f"Could not find {embedder_filename} under: {[str(p) for p in candidates]}"
)
def _load_embedder_pack(self, gguf_path: str) -> Dict[str, Any]:
pack_path = str(Path(gguf_path).with_suffix(Path(gguf_path).suffix + ".pt"))
if os.path.exists(pack_path):
try:
pack = torch.load(pack_path, map_location="cpu", weights_only=False)
if isinstance(pack, dict) and pack.get("gguf_bytes") is not None:
return pack
except Exception:
try:
os.remove(pack_path)
except Exception:
pass
raw = Path(gguf_path).read_bytes()
pack = {
"gguf_filename": Path(gguf_path).name,
"gguf_bytes": self._bytes_to_uint8_tensor(raw),
"gguf_source_path": str(Path(gguf_path).resolve()),
}
try:
torch.save(pack, pack_path, pickle_protocol=5)
except Exception:
pass
return pack
def _restore_embedder_from_pack(self, pack: Dict[str, Any]):
if not pack or pack.get("gguf_bytes") is None:
raise RuntimeError("Embedder GGUF pack is missing")
tmp_dir = tempfile.mkdtemp(prefix="gator_embedder_")
try:
gguf_path = Path(tmp_dir) / pack["gguf_filename"]
gguf_path.write_bytes(self._uint8_tensor_to_bytes(pack["gguf_bytes"]))
llm = Llama(
model_path=str(gguf_path),
embedding=True,
verbose=False,
n_ctx=8192,
use_mmap=False,
use_mlock=False,
)
return llm
finally:
shutil.rmtree(tmp_dir, ignore_errors=True)
def _load_embedder(self, model_name: str, local_dir: str):
    """Resolve, pack and load the GGUF embedder; return ``(embedder, None, pack)``.

    ``model_name`` is accepted but not used; the file name comes from
    ``self.embedder_filename``.
    """
    pack = self._load_embedder_pack(
        self._resolve_local_embedder_gguf(local_dir, self.embedder_filename)
    )
    return self._restore_embedder_from_pack(pack), None, pack
def _snapshot_embedder_pack(self) -> Dict[str, Any]:
if isinstance(self.embedder_pack, dict) and self.embedder_pack.get("gguf_bytes") is not None:
return {
"gguf_filename": self.embedder_pack["gguf_filename"],
"gguf_bytes": self.embedder_pack["gguf_bytes"],
"gguf_source_path": self.embedder_pack.get("gguf_source_path", ""),
}
raise RuntimeError("Embedder pack is missing")
def _embed_raw(self, texts: Union[str, Sequence[str]], task: str = "retrieval.passage") -> np.ndarray:
if isinstance(texts, str):
texts = [texts]
if not texts:
return np.empty((0, self.embed_dim), dtype=np.float32)
if hasattr(self.embedder, "create_embedding"):
resp = self.embedder.create_embedding(list(texts))
if isinstance(resp, dict) and "data" in resp:
embs = [row["embedding"] for row in resp["data"]]
else:
embs = resp
elif hasattr(self.embedder, "embed"):
resp = self.embedder.embed(list(texts))
if isinstance(resp, dict) and "data" in resp:
embs = [row["embedding"] for row in resp["data"]]
else:
embs = resp
else:
raise RuntimeError("Loaded GGUF embedder does not expose create_embedding() or embed()")
embs = np.asarray(embs, dtype=np.float32)
if embs.ndim == 1:
embs = embs.reshape(1, -1)
return embs
def embed(self, texts: Union[str, Sequence[str]], task: str = "retrieval.passage") -> np.ndarray:
return self._embed_raw(texts, task=task)
def embed_query(self, texts: Union[str, Sequence[str]]) -> np.ndarray:
return self.embed(texts, task="retrieval.query")
def embed_passage(self, texts: Union[str, Sequence[str]]) -> np.ndarray:
return self.embed(texts, task="retrieval.passage")
def embed_classification(self, texts: Union[str, Sequence[str]]) -> np.ndarray:
return self.embed(texts, task="classification")
def embed_matching(self, texts: Union[str, Sequence[str]]) -> np.ndarray:
return self.embed(texts, task="text-matching")
def _snapshot_store(self) -> Dict[str, Any]:
return {name: tree.snapshot() for name, tree in self._store.items()}
def _restore_store(self, packed: Mapping[str, Any]):
    """Restore every tree in ``_store`` whose name has a snapshot (object or dict form) in ``packed``."""
    for store_name, tree in self._store.items():
        entry = packed.get(store_name)
        if isinstance(entry, dict):
            # Dict form: rebuild the snapshot object from its fields.
            entry = PackedTreeSnapshot(**entry)
        elif not isinstance(entry, PackedTreeSnapshot):
            continue
        tree.restore(entry)
@staticmethod
def normalize_for_hash(obj: Any) -> Any:
if isinstance(obj, dict):
return {k: GATOR.normalize_for_hash(v) for k, v in obj.items()}
if isinstance(obj, list):
return [GATOR.normalize_for_hash(v) for v in obj]
if hasattr(obj, "item"):
try:
return obj.item()
except Exception:
return obj
return obj
def get_location_string(self):
return self.desktop.get_location_string()
def get_time_string(self):
return self.desktop.get_time_string()
def _lm_head(self, prompt: str, mode: str = "decision") -> Dict[str, Any]:
if self.lm is None:
raise RuntimeError("LM.pt is not loaded")
system_prompt = (
"You are GATOR HeadExpert. Decide whether the user request needs retrieval, a tool, or a direct answer. "
"Return strict JSON only."
)
tool_prompt = {
"input_text": prompt,
"available_commands": list(self.command_registry.commands.values()),
"system_goal": system_prompt,
"mode": mode,
}
if hasattr(self.lm, "head_expert"):
raw = self.lm.head_expert(_safe_json_dumps(tool_prompt))
else:
raw = self.lm.head(prompt)
return self._parse_json_object(raw, default={"actions": []})
def _lm_tool(self, query: str, tools: List[Dict[str, Any]]) -> Dict[str, Any]:
if self.lm is None:
raise RuntimeError("LM.pt is not loaded")
if hasattr(self.lm, "tool_expert"):
raw = self.lm.tool_expert(query, tools=tools)
else:
raw = self.lm.tool(query, tools=tools)
return self._parse_json_object(raw, default={"tool_calls": []})
@staticmethod
def _parse_json_object(text: str, default: Any = None) -> Any:
if not text:
return default
text = text.strip()
try:
return json.loads(text)
except Exception:
m = re.search(r"(\{.*})", text, re.DOTALL)
if m:
try:
return json.loads(m.group(1))
except Exception:
return default
return default
def _decision_to_actions(self, decision: Mapping[str, Any]) -> List[ActionDecision]:
actions = decision.get("actions") if isinstance(decision, Mapping) else None
if not isinstance(actions, list):
actions = [decision] if isinstance(decision, Mapping) else []
parsed: List[ActionDecision] = []
for item in actions:
if not isinstance(item, Mapping):
continue
parsed.append(ActionDecision(
type=str(item.get("type", "None")),
command_name=item.get("command_name"),
command_description=item.get("command_description"),
command_text=item.get("command_text"),
memory_access=list(item.get("memory_access") or []),
source=item.get("source"),
profile=item.get("profile"),
query=item.get("query"),
sufficient_to_answer=bool(item.get("sufficient_to_answer", False)),
parameters=dict(item.get("parameters") or {}),
))
return parsed
def store_knowledge(self, documents: Sequence[str], tags: Optional[Sequence[str]] = None, source: str = "user", importance: float = 0.5):
tags = list(tags) if tags is not None else ["knowledge"] * len(documents)
if len(tags) != len(documents):
raise ValueError("Length of tags must match length of documents")
for doc, tag in zip(documents, tags):
meta = {"tag": tag, "source": source, "importance": float(importance), "usage_count": 0, "last_used": None}
self._store["knowledge"].add(doc, meta)
def search_knowledge(self, query: str, top_k: int = 5, hybrid: bool = True, min_score: float = 0.0):
return self._store["knowledge"].search(query, top_k=top_k, hybrid=hybrid, min_score=min_score, use_clusters=True)
def process_knowledge(self, query, history="", location="", time_date=""):
if not query:
return []
try:
results = self._store["knowledge"].retrieve_by_semantics(query=query, num_clusters=3, top_k_per_cluster=3, min_score=0.0)
except Exception:
results = self._store["knowledge"].search(query=query, top_k=9, hybrid=True)
return results
def store_conversation_leaf(self, text: str, conv_id: str, leaf_type: str = "input"):
meta = {"id": conv_id, "type": "branch", "leaf_type": leaf_type}
self._store["conversation"].add(text, meta)
def process_conversation(self, query, id, type="input"):
if type == "input":
relevant_context = self.conversation_tree.add_input_leaf(query, id)
return relevant_context
else:
self.conversation_tree.add_output_leaf(query, id)
def store_profile_leaf(self, profile_id: str, text: str, importance: float = 0.5, profile_type: str = "user"):
target = "profile_user" if profile_type == "user" else "profile_bot"
meta = {"profile_id": profile_id, "importance": float(importance), "source": profile_type}
self._store[target].add(text, meta)
def search_profile_leaves(self, profile_id: str, query: str, profile_type: str = "user", top_k: int = 3, min_score: float = 0.0):
target = "profile_user" if profile_type == "user" else "profile_bot"
tree = self._store[target]
results = tree.search(query, top_k=max(top_k * 3, top_k), min_score=min_score, hybrid=True, use_clusters=True)
filtered = [r for r in results if r.get("metadata", {}).get("profile_id") == profile_id]
return filtered[:top_k]
def store_command(
    self,
    command_name: str,
    phrase: str,
    description: str,
    command_type: str = "custom",
    actions: Optional[List[Dict[str, Any]]] = None
):
    """Register a command and index its phrase in the "commands" store.

    The same metadata dict (including ``actions``, defaulting to ``[]``) is
    placed in both ``command_registry.custom_commands`` and
    ``command_registry.commands`` under ``command_name``. The store entry
    carries the same fields except ``actions`` and uses the item id
    ``command::<command_type>::<command_name>``.
    """
    index_meta = {
        "command_name": command_name,
        "phrase": phrase,
        "description": description,
        "type": command_type,
    }
    registry_meta = {**index_meta, "actions": actions or []}

    registry = self.command_registry
    registry.custom_commands[command_name] = registry_meta
    registry.commands[command_name] = registry_meta

    self._store["commands"].add(
        f"Command: {phrase}",
        index_meta,
        item_id=f"command::{command_type}::{command_name}",
    )
def search_commands(self, query: str, top_k: int = 3):
    """Run a hybrid search for ``query`` against the "commands" store."""
    command_store = self._store["commands"]
    return command_store.search(query, top_k=top_k, hybrid=True)
def execute_retrieval_action(self, action, user_id, bot_id, history="", location="", time_date=""):
    """Execute one retrieval action described by the mapping ``action``.

    Reads the ``query``, ``source`` and ``profile`` keys:

    * no query -> ``None``
    * ``source == "KnowledgeTree"`` -> ``process_knowledge`` result
    * ``source == "ProfileTree"`` with ``profile`` of ``"user_profile"`` or
      ``"bot_profile"`` -> profile search for ``user_id`` / ``bot_id``
    * anything else -> ``None``
    """
    query = action.get("query")
    if not query:
        return None

    source = action.get("source")
    if source == "KnowledgeTree":
        return self.process_knowledge(query=query, history=history, location=location, time_date=time_date)
    if source != "ProfileTree":
        return None

    profile = action.get("profile")
    if profile == "user_profile":
        return self.search_profile_leaves(user_id, query, profile_type="user")
    if profile == "bot_profile":
        return self.search_profile_leaves(bot_id, query, profile_type="bot")
    return None
def _merge_retrieval_actions(self, actions: List[ActionDecision]) -> Dict[str, List[Dict[str, Any]]]:
    """Group retrieval actions by the tree(s) they touch.

    Only actions whose ``type`` is ``"retrieval"`` are considered. An action
    is listed under a tree when its ``source`` equals the tree name or the
    tree name appears in its ``memory_access``; it may appear under both.
    Each listed action is converted to a dict with ``asdict``.
    """
    merged: Dict[str, List[Dict[str, Any]]] = {"KnowledgeTree": [], "ProfileTree": []}
    for decision in actions:
        if decision.type != "retrieval":
            continue
        for tree_name, bucket in merged.items():
            touches_tree = decision.source == tree_name or (
                decision.memory_access and tree_name in decision.memory_access
            )
            if touches_tree:
                bucket.append(asdict(decision))
    return merged
def process_actions(self, query, user_id, bot_id, history="", location="", time_date=""):
"""Plan and execute the retrieval/command actions for one user query.

Steps: look up candidate commands, ask the head model for a decision,
run any retrieval actions, run any tool/command calls, always run the
"shortcut" command pass, de-duplicate retrieved items, and record the
outcome on ``self._last_route``.

Returns a dict with ``"commands"`` (space-joined string of matched command
results, or ``None``) and ``"retrieved_data"`` (mapping of source name to a
de-duplicated list of results).
"""
# Candidate commands from the command store; only name/description/type
# are shown to the models.
relevant_commands = self.search_commands(query, top_k=3)
available_actions = [
{
"name": c["metadata"].get("command_name", c["id"]),
"description": c["metadata"].get("description", ""),
"type": c["metadata"].get("type", "command"),
}
for c in relevant_commands
]
decision_prompt = {
"query": query,
"available_commands": available_actions,
"history": history,
"location": location,
"time_date": time_date,
"user_id": user_id,
"bot_id": bot_id,
"route_goal": "Use HeadExpert to decide whether retrieval is needed. If a tool/command is needed, choose it with ToolExpert.",
}
# The prompt is serialized to JSON and sent to the head model in "decision" mode.
head_decision = self._lm_head(_safe_json_dumps(decision_prompt), mode="decision")
actions = self._decision_to_actions(head_decision)
# Only a dict decision is remembered as the last plan; anything else resets it to {}.
self._last_plan = head_decision if isinstance(head_decision, dict) else {}
# Flags are computed before the "no actions" placeholder is substituted below.
command_needed = any(a.type == "command" for a in actions)
retrieval_needed = any(a.type == "retrieval" for a in actions)
if not actions:
actions = [ActionDecision(type="None", sufficient_to_answer=True)]
matched_commands: List[Any] = []
retrieved_data: Dict[str, List[Any]] = {}
if retrieval_needed:
for action in actions:
if action.type != "retrieval":
continue
# Retrieval actions flagged as already sufficient are not executed.
if action.sufficient_to_answer:
continue
result = self.execute_retrieval_action(asdict(action), user_id, bot_id, history=history, location=location, time_date=time_date)
if result is None:
continue
# Results are grouped by source; a missing source is filed under "KnowledgeTree".
source = action.source or "KnowledgeTree"
retrieved_data.setdefault(source, [])
if isinstance(result, list):
retrieved_data[source].extend(result)
else:
retrieved_data[source].append(result)
if command_needed:
tool_calls = self._lm_tool(
query,
tools=available_actions,
)
# NOTE(review): assumes _lm_tool returns a mapping with a "tool_calls" list -- confirm.
for call in tool_calls.get("tool_calls", []):
if not isinstance(call, dict):
continue
command_name = call.get("name") or call.get("command_name")
parameters = call.get("arguments") or call.get("parameters") or {}
# Only calls naming one of the candidate commands found above are executed.
cmd_meta = next((cmd for cmd in relevant_commands if cmd["metadata"].get("command_name") == command_name), None)
if cmd_meta:
command_type = cmd_meta["metadata"].get("type", "custom")
matched = self.command_registry.process_commands(command_name, command_type, parameters)
matched_commands.extend(matched)
# The shortcut pass runs on the raw query regardless of the head decision.
shortcut_matches = self.command_registry.process_commands(query, "shortcut")
matched_commands.extend(shortcut_matches)
final_commands = " ".join([str(m) for m in matched_commands]) if matched_commands else None
# De-duplicate each source's results, keeping first occurrences in order.
for key, values in retrieved_data.items():
seen = set()
deduped = []
for v in values:
try:
normalized = self.normalize_for_hash(v)
h = _safe_json_dumps(normalized)
except Exception:
# Fall back to the plain string form when normalization/serialization fails.
h = str(v)
if h not in seen:
seen.add(h)
deduped.append(v)
retrieved_data[key] = deduped
# The stored route additionally keeps the raw head decision; the return value does not.
self._last_route = {"commands": final_commands, "retrieved_data": retrieved_data, "head_plan": head_decision}
return {"commands": final_commands, "retrieved_data": retrieved_data}
def warmup(self):
    """Embed a throwaway string once; any failure is deliberately ignored."""
    with contextlib.suppress(Exception):
        self.embed(["__gator_warmup__"], task="retrieval.passage")
def export_snapshot_bytes(self) -> bytes:
    """Serialize the current state to LZMA-compressed bytes.

    The payload (version, config, embedder pack, stores, command data and
    the last route/response/plan) is written with ``torch.save`` into an
    in-memory buffer and compressed with ``lzma`` at preset 9.
    """
    registry = self.command_registry
    snapshot = dict(
        version=self.STATE_VERSION,
        config=self.config,
        embedder_pack=self._snapshot_embedder_pack(),
        stores=self._snapshot_store(),
        command_phrases=self._command_phrases,
        custom_commands=registry.custom_commands,
        commands=registry.commands,
        last_route=self._last_route,
        last_response=self._last_response,
        last_plan=self._last_plan,
    )
    with io.BytesIO() as buffer:
        torch.save(snapshot, buffer)
        serialized = buffer.getvalue()
    return lzma.compress(serialized, preset=9)
def import_snapshot_bytes(self, blob: bytes):
    """Restore state from bytes produced by an LZMA-compressed ``torch.save``.

    Stores, command data and the last route/response/plan are replaced;
    missing or falsy entries fall back to empty values. When the payload
    has an embedder pack, the embedder, tokenizer and pack are rebuilt via
    ``_restore_embedder_from_pack`` and ``embedder_local_dir`` is refreshed
    from the restored pack.

    NOTE: ``weights_only=False`` means the blob is unpickled; only import
    snapshots from a trusted source.
    """
    decompressed = lzma.decompress(blob)
    payload = torch.load(io.BytesIO(decompressed), map_location="cpu", weights_only=False)

    self._restore_store(payload.get("stores", {}))

    registry = self.command_registry
    registry.custom_commands = payload.get("custom_commands", {}) or {}
    registry.commands = payload.get("commands", {}) or {}

    self._command_phrases = payload.get("command_phrases", {}) or {}
    self._last_route = payload.get("last_route", {}) or {}
    self._last_response = payload.get("last_response", "") or ""
    self._last_plan = payload.get("last_plan", {}) or {}

    pack = payload.get("embedder_pack")
    if not pack:
        return
    self.embedder_pack = pack
    restored = self._restore_embedder_from_pack(self.embedder_pack)
    self.embedder, self.embedder_tokenizer, self.embedder_pack = restored
    self.embedder_local_dir = self.embedder_pack.get("local_dir", self.embedder_local_dir)
def save_checkpoint(self, path: str = "GATOR.pt") -> None:
    """Write the current state to ``path`` with ``torch.save``.

    ``warmup()`` is called first. The payload holds the version, config,
    embedder pack, stores, command data and the last route/response/plan,
    and is pickled with protocol 5.
    """
    self.warmup()
    registry = self.command_registry
    checkpoint = dict(
        version=self.STATE_VERSION,
        config=self.config,
        embedder_pack=self._snapshot_embedder_pack(),
        stores=self._snapshot_store(),
        custom_commands=registry.custom_commands,
        commands=registry.commands,
        command_phrases=self._command_phrases,
        last_route=self._last_route,
        last_response=self._last_response,
        last_plan=self._last_plan,
    )
    torch.save(checkpoint, path, pickle_protocol=5)
@classmethod
def load_checkpoint(cls, path: str = "GATOR.pt") -> "GATOR":
    """Build an instance from a checkpoint written by ``save_checkpoint``.

    Args:
        path: Location of the checkpoint file.

    Returns:
        A new instance constructed from the saved config (with
        ``warm_on_start=False``) and then populated with the saved stores,
        command data and last route/response/plan.

    Raises:
        FileNotFoundError: ``path`` does not exist.

    SECURITY: the file is loaded with ``weights_only=False``, i.e. it is
    unpickled and can execute arbitrary code. Only load trusted checkpoints.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(path)
    payload = torch.load(path, map_location="cpu", weights_only=False)
    # Tolerate a checkpoint whose "config" entry is missing or None.
    cfg = payload.get("config", {}) or {}
    obj = cls(
        lm_checkpoint_path=cfg.get("lm_checkpoint_path", "LM.pt"),
        embedder_name=cfg.get("embedder_name", "second-state/jina-embeddings-v3-GGUF"),
        embedder_local_dir=cfg.get("embedder_local_dir", os.path.join("models", "jinaai")),
        embedder_filename=cfg.get("embedder_filename", GGUF_EMBED_FILENAME),
        device=cfg.get("device", "cpu"),
        warm_on_start=False,
        compression=cfg.get("compression", "lzma"),
        store_dtype=cfg.get("store_dtype", "float16"),
        cluster_k=int(cfg.get("cluster_k", 4)),
        auto_load_lm=bool(cfg.get("auto_load_lm", True)),
        strict_lm=bool(cfg.get("strict_lm", True)),
        embedder_pack=payload.get("embedder_pack"),
    )
    obj._restore_store(payload.get("stores", {}))
    obj.command_registry.custom_commands = payload.get("custom_commands", {}) or {}
    obj.command_registry.commands = payload.get("commands", {}) or {}
    obj._command_phrases = payload.get("command_phrases", {}) or {}
    obj._last_route = payload.get("last_route", {}) or {}
    obj._last_response = payload.get("last_response", "") or ""
    obj._last_plan = payload.get("last_plan", {}) or {}
    pack = payload.get("embedder_pack")
    if pack:
        obj.embedder_pack = pack
        restored = obj._restore_embedder_from_pack(obj.embedder_pack)
        # import_snapshot_bytes unpacks this same call as
        # (embedder, tokenizer, pack). Do the same here when a 3-tuple comes
        # back, so the whole tuple is not stored as the embedder.
        # NOTE(review): the helper is not visible from this block -- confirm
        # its return shape; a non-tuple result is stored as before.
        if isinstance(restored, tuple) and len(restored) == 3:
            obj.embedder, obj.embedder_tokenizer, obj.embedder_pack = restored
        else:
            obj.embedder = restored
    return obj
def infer_actions(self, prompt: str) -> Dict[str, Any]:
    """Send ``prompt`` to the head model in "decision" mode and return its output."""
    decision = self._lm_head(prompt, mode="decision")
    return decision
def infer_command(self, prompt_payload: Dict[str, Any]) -> Dict[str, Any]:
    """Ask the tool model which command(s) to call.

    Reads ``"input_text"`` (default ``""``) and ``"available_commands"``
    (default ``[]``) from ``prompt_payload`` and forwards them to ``_lm_tool``.
    """
    available = prompt_payload.get("available_commands", [])
    user_text = prompt_payload.get("input_text", "")
    return self._lm_tool(user_text, tools=available)
def respond(self, query: str, user_id: str, bot_id: str, history: str = "", location: str = "", time_date: str = ""):
    """Return the routing result of ``process_actions`` for this query."""
    return self.process_actions(
        query,
        user_id=user_id,
        bot_id=bot_id,
        history=history,
        location=location,
        time_date=time_date,
    )
def process_command(self, *args, **kwargs):
    """Alias that forwards every argument to ``process_actions``."""
    routed = self.process_actions(*args, **kwargs)
    return routed
def summary(self) -> Dict[str, Any]:
    """Return a snapshot of config, per-store sizes and command counts.

    For each store the entry count is ``len(tree.docs)``; the dimension is
    the second axis of ``tree.embs`` when that array is non-empty, otherwise
    ``self.embed_dim``.
    """
    config_copy = dict(self.config)

    store_stats = {}
    for store_name, tree in self._store.items():
        entry_count = len(tree.docs)
        if tree.embs.size:
            dim = int(tree.embs.shape[1])
        else:
            dim = self.embed_dim
        store_stats[store_name] = {"count": entry_count, "dim": dim}

    registry = self.command_registry
    return {
        "config": config_copy,
        "stores": store_stats,
        "commands": len(registry.commands),
        "custom_commands": len(registry.custom_commands),
        "last_route": self._last_route,
        "last_plan": self._last_plan,
    }
def run_self_test(self) -> Dict[str, Any]:
    """Exercise embedding, knowledge, profile and command paths.

    Each probe runs independently: on success its result is stored under
    its name, on any exception ``repr(exc)`` is stored under
    ``<name>_error``. Note that the probes write test entries into the
    knowledge, profile and command stores.

    Returns ``{"report": ..., "summary": self.summary()}``.
    """
    def probe_embed():
        return tuple(self.embed(["hello world"], task="retrieval.passage").shape)

    def probe_knowledge():
        self.store_knowledge(["The capital of France is Paris."], tags=["fact"], source="test", importance=1.0)
        return self.search_knowledge("What is the capital of France?", top_k=1)

    def probe_profile():
        self.store_profile_leaf("user-1", "Likes Japanese food.", importance=0.8, profile_type="user")
        return self.search_profile_leaves("user-1", "food preference", profile_type="user", top_k=1)

    def probe_commands():
        self.store_command("test_cmd", "test command", "A small test command", command_type="custom", actions=[])
        return self.search_commands("test command", top_k=1)

    probes = (
        ("embed", "embed_error", probe_embed),
        ("knowledge", "knowledge_error", probe_knowledge),
        ("profile", "profile_error", probe_profile),
        ("commands", "commands_error", probe_commands),
    )
    report = {}
    for ok_key, error_key, probe in probes:
        try:
            report[ok_key] = probe()
        except Exception as exc:
            report[error_key] = repr(exc)
    return {"report": report, "summary": self.summary()}
@property
def profile_tree(self):
    """Namespace exposing ``search_leaves`` for the user profile store.

    ``search_leaves`` delegates to ``search_profile_leaves`` with
    ``profile_type="user"``. Its ``use_clusters`` argument is accepted but
    not forwarded.
    """
    def search_leaves(profile_id, query, top_k=3, min_score=0.0, use_clusters=True):
        return self.search_profile_leaves(
            profile_id,
            query,
            profile_type="user",
            top_k=top_k,
            min_score=min_score,
        )

    return types.SimpleNamespace(search_leaves=search_leaves)
@property
def conversation_tree(self):
    """Namespace exposing ``add_input_leaf`` / ``add_output_leaf``.

    ``add_input_leaf`` only searches the "conversation" store (top 3,
    hybrid, clustered) and returns the hits; it does not use ``id`` and
    does not store the text. ``add_output_leaf`` stores the text as an
    "output" leaf via ``store_conversation_leaf``.
    """
    def add_input_leaf(text, id):
        return self._store["conversation"].search(text, top_k=3, hybrid=True, use_clusters=True)

    def add_output_leaf(text, id):
        return self.store_conversation_leaf(text, id, leaf_type="output")

    return types.SimpleNamespace(add_input_leaf=add_input_leaf, add_output_leaf=add_output_leaf)
@property
def knowledge_tree(self):
    """Namespace exposing helpers backed by the "knowledge" store.

    * ``retrieve_by_semantics`` / ``search`` forward to the store's methods
      of the same name.
    * ``_embed_text`` returns the first row of ``self.embed`` for the text
      using the "retrieval.passage" task.
    * ``_warm_collection`` calls ``self.warmup()``.
    """
    def retrieve_by_semantics(query, num_clusters=2, top_k_per_cluster=3, min_score=0.0):
        return self._store["knowledge"].retrieve_by_semantics(
            query,
            num_clusters=num_clusters,
            top_k_per_cluster=top_k_per_cluster,
            min_score=min_score,
        )

    def search(query, top_k=5, hybrid=True):
        return self._store["knowledge"].search(query, top_k=top_k, hybrid=hybrid)

    def embed_text(text):
        return self.embed(text, task="retrieval.passage")[0]

    def warm_collection():
        return self.warmup()

    return types.SimpleNamespace(
        retrieve_by_semantics=retrieve_by_semantics,
        search=search,
        _embed_text=embed_text,
        _warm_collection=warm_collection,
    )
@property
def command_tree(self):
    """Namespace exposing command helpers.

    * ``search_relevant_commands`` forwards to ``search_commands``.
    * ``add_command_branch`` calls ``store_command`` with ``command_action``
      as the command name and ``command_name`` as the phrase.
      NOTE(review): that argument mapping looks swapped relative to
      ``store_command``'s parameter names -- confirm it is intentional.
    * ``update_command_description`` forwards to
      ``command_registry.update_command``.
    """
    def search_relevant_commands(query, top_k=3):
        return self.search_commands(query, top_k=top_k)

    def add_command_branch(command_name, command_action, command_type, description):
        return self.store_command(command_action, command_name, description, command_type=command_type)

    def update_command_description(command_name, new_description):
        return self.command_registry.update_command(command_name, description=new_description)

    return types.SimpleNamespace(
        search_relevant_commands=search_relevant_commands,
        add_command_branch=add_command_branch,
        update_command_description=update_command_description,
    )
def save(self, path: str = "GATOR.pt"):
    """Alias for ``save_checkpoint(path)``; returns whatever it returns."""
    outcome = self.save_checkpoint(path)
    return outcome
def __repr__(self) -> str:
    """Show the embedder name, its local directory, store names and command count."""
    embedder = self.embedder_name
    local_dir = self.embedder_local_dir
    store_names = list(self._store.keys())
    command_count = len(self.command_registry.commands)
    return f"GATOR(embedder={embedder!r}, local_dir={local_dir!r}, stores={store_names}, commands={command_count})"
class MemoryBank:
    """
    Thin owner/wrapper around GATOR.
    - Loads an existing GATOR.pt if present.
    - Builds a new one if missing.
    - Uses AppData/LocalAppData as the default location when no path is passed.
    - Delegates all unknown attributes/methods directly to the underlying GATOR instance.

    If ``gator_location`` ends in ``.pt`` only its parent directory is used;
    the checkpoint file is always named ``DEFAULT_BUNDLE_NAME`` inside the
    resolved root directory.
    """
    DEFAULT_APP_FOLDER = "PackedLLM"
    DEFAULT_BUNDLE_NAME = "GATOR.pt"

    def __init__(
        self,
        gator_location: Optional[Union[str, os.PathLike]] = None,
        *,
        build_if_missing: bool = True,
        **gator_kwargs: Any,
    ):
        """Load the checkpoint under ``gator_location`` or build a new GATOR.

        Args:
            gator_location: Directory (or ``.pt`` path whose parent is used)
                holding the checkpoint; ``None`` selects the default
                application-data directory.
            build_if_missing: When no checkpoint exists, build a new GATOR
                from ``gator_kwargs`` and save it; otherwise raise.
            **gator_kwargs: Passed to ``GATOR(...)`` only when building.

        Raises:
            FileNotFoundError: No checkpoint exists and ``build_if_missing``
                is false.
        """
        self.root_dir = self._resolve_root_dir(gator_location)
        self.root_dir.mkdir(parents=True, exist_ok=True)
        self.checkpoint_path = self._resolve_checkpoint_path(self.root_dir)
        if self.checkpoint_path.exists():
            self.gator = GATOR.load_checkpoint(str(self.checkpoint_path))
            return
        if not build_if_missing:
            raise FileNotFoundError(f"No GATOR checkpoint found at: {self.checkpoint_path}")
        self.gator = GATOR(**gator_kwargs)
        self.gator.save_checkpoint(str(self.checkpoint_path))

    @classmethod
    def _default_appdata_dir(cls) -> Path:
        """Return ``<base>/PackedLLM`` where base is LOCALAPPDATA, else APPDATA, else ``~/AppData/Local``."""
        base_env = os.getenv("LOCALAPPDATA") or os.getenv("APPDATA")
        base = Path(base_env) if base_env else Path.home() / "AppData" / "Local"
        return base / cls.DEFAULT_APP_FOLDER

    @classmethod
    def _resolve_root_dir(cls, gator_location: Optional[Union[str, os.PathLike]]) -> Path:
        """Resolve the directory that holds the checkpoint.

        ``None`` yields the default app-data directory. A path ending in
        ``.pt`` (case-insensitive) yields its parent; any other path is
        used as the directory itself.
        """
        if gator_location is None:
            return cls._default_appdata_dir()
        path = Path(gator_location).expanduser().resolve()
        return path.parent if path.suffix.lower() == ".pt" else path

    @classmethod
    def _resolve_checkpoint_path(cls, root_dir: Path) -> Path:
        """Return the checkpoint file path inside ``root_dir``."""
        return root_dir / cls.DEFAULT_BUNDLE_NAME

    def save(self) -> str:
        """Save the wrapped GATOR to the checkpoint path and return that path."""
        self.root_dir.mkdir(parents=True, exist_ok=True)
        self.gator.save_checkpoint(str(self.checkpoint_path))
        return str(self.checkpoint_path)

    def reload(self) -> None:
        """Replace the wrapped GATOR with one loaded from the checkpoint.

        Raises:
            FileNotFoundError: The checkpoint file does not exist.
        """
        if not self.checkpoint_path.exists():
            raise FileNotFoundError(str(self.checkpoint_path))
        self.gator = GATOR.load_checkpoint(str(self.checkpoint_path))

    def rebuild(self, **gator_kwargs: Any) -> None:
        """Build a fresh GATOR from ``gator_kwargs`` and save it over the checkpoint."""
        self.gator = GATOR(**gator_kwargs)
        # Mirror save(): the directory may have been removed since __init__.
        self.root_dir.mkdir(parents=True, exist_ok=True)
        self.gator.save_checkpoint(str(self.checkpoint_path))

    def __getattr__(self, name: str) -> Any:
        """Delegate attributes not found on the wrapper to the wrapped GATOR.

        ``__getattr__`` only runs after normal lookup fails, so reading
        ``self.gator`` here while it is unset (failed ``__init__``, an
        instance made with ``__new__``, unpickling) would re-enter this
        method forever. Look it up in the instance dict instead and raise a
        regular ``AttributeError`` when it is absent.
        """
        try:
            gator = self.__dict__["gator"]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        return getattr(gator, name)

    def __dir__(self):
        """List the wrapper's own names plus, when available, the wrapped GATOR's."""
        names = set(super().__dir__())
        gator = self.__dict__.get("gator")
        if gator is not None:
            try:
                names.update(dir(gator))
            except Exception:
                # dir() on the wrapped object is best-effort only.
                pass
        return sorted(names)

    def __repr__(self) -> str:
        return f"MemoryBank(root_dir={str(self.root_dir)!r}, checkpoint_path={str(self.checkpoint_path)!r})"
class HardwareProbe:
    """Static probes for CPU/RAM and whichever GPU backend is reachable."""

    @staticmethod
    def cpu() -> Dict[str, Any]:
        """Return physical/logical core counts and available/total RAM in GiB.

        Core counts fall back to 1 when ``psutil`` cannot determine them.
        """
        physical = psutil.cpu_count(logical=False) or psutil.cpu_count(logical=True) or 1
        logical = psutil.cpu_count(logical=True) or physical
        vm = psutil.virtual_memory()
        return {
            "physical_cores": physical,
            "logical_cores": logical,
            "available_ram_gb": vm.available / (1024 ** 3),
            "total_ram_gb": vm.total / (1024 ** 3),
        }

    @staticmethod
    def _parse_nvidia_smi(out: str) -> Optional[Dict[str, Any]]:
        """Parse the CSV text printed by the ``nvidia-smi`` query in ``nvidia_gpu``.

        With several GPUs the tool prints one row per device; only the first
        non-empty row is used. The row is split from the right so that a
        comma inside the device name cannot shift the numeric columns.
        Returns ``None`` when no row is present or the memory columns are
        not numeric. A non-numeric utilization yields ``None`` for that
        field only.
        """
        rows = [row for row in out.splitlines() if row.strip()]
        if not rows:
            return None
        parts = [p.strip() for p in rows[0].rsplit(",", 4)]
        if len(parts) != 5:
            return None
        name, total_mib, free_mib, used_mib, util = parts
        try:
            total_gb = float(total_mib) / 1024.0
            free_gb = float(free_mib) / 1024.0
            used_gb = float(used_mib) / 1024.0
        except ValueError:
            return None
        try:
            utilization: Optional[float] = float(util)
        except ValueError:
            utilization = None
        return {
            "backend": "cuda",
            "device_name": name,
            "total_vram_gb": total_gb,
            "free_vram_gb": free_gb,
            "used_vram_gb": used_gb,
            "utilization_pct": utilization,
        }

    @staticmethod
    def nvidia_gpu() -> Optional[Dict[str, Any]]:
        """Query ``nvidia-smi`` for the first GPU; ``None`` if unavailable or unparsable."""
        cmd = [
            "nvidia-smi",
            "--query-gpu=name,memory.total,memory.free,memory.used,utilization.gpu",
            "--format=csv,noheader,nounits",
        ]
        try:
            raw = subprocess.check_output(cmd, stderr=subprocess.DEVNULL)
            # Replace undecodable bytes so an unusual device name cannot abort the probe.
            text = raw.decode("utf-8", errors="replace").strip()
            return HardwareProbe._parse_nvidia_smi(text)
        except Exception:
            # Missing binary, non-zero exit, etc. all mean "no NVIDIA GPU info".
            return None

    @staticmethod
    def torch_gpu() -> Optional[Dict[str, Any]]:
        """Probe CUDA, then MPS, through ``torch``; ``None`` when neither is available.

        For MPS the reported figures are system RAM numbers from ``psutil``.
        Each probe is independent: an exception in one is ignored and the
        next one is tried.
        """
        try:
            if torch.cuda.is_available():
                idx = torch.cuda.current_device()
                free_b, total_b = torch.cuda.mem_get_info(idx)
                return {
                    "backend": "cuda",
                    "device_name": torch.cuda.get_device_name(idx),
                    "total_vram_gb": total_b / (1024 ** 3),
                    "free_vram_gb": free_b / (1024 ** 3),
                    "used_vram_gb": (total_b - free_b) / (1024 ** 3),
                    "utilization_pct": None,
                }
        except Exception:
            pass
        try:
            if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
                vm = psutil.virtual_memory()
                return {
                    "backend": "metal",
                    "device_name": "Apple Silicon (MPS)",
                    "total_vram_gb": vm.total / (1024 ** 3),
                    "free_vram_gb": vm.available / (1024 ** 3),
                    "used_vram_gb": (vm.total - vm.available) / (1024 ** 3),
                    "utilization_pct": None,
                }
        except Exception:
            pass
        return None

    @staticmethod
    def webgpu() -> Optional[Dict[str, Any]]:
        """Probe a WebGPU adapter via ``wgpu``; ``None`` if unavailable or on error.

        The memory budget is an estimate: half of the adapter's maximum
        buffer size, or 1.0 GiB when that limit is not reported.
        """
        if not _WGPU_AVAILABLE:
            return None
        try:
            request = getattr(wgpu.gpu, "request_adapter_sync", None) or getattr(wgpu.gpu, "request_adapter")
            adapter = request(power_preference="high-performance")
            limits = getattr(adapter, "limits", {}) or {}
            max_buffer_bytes = limits.get("max-buffer-size") or limits.get("maxBufferSize") or 0
            est_gb = (max_buffer_bytes / (1024 ** 3)) * 0.5 if max_buffer_bytes else 1.0
            return {
                "backend": "webgpu",
                "device_name": getattr(adapter, "summary", "WebGPU adapter"),
                "total_vram_gb": est_gb,
                "free_vram_gb": est_gb,
                "used_vram_gb": 0.0,
                "utilization_pct": None,
            }
        except Exception:
            return None

    @classmethod
    def snapshot(cls) -> Dict[str, Any]:
        """Return ``{"cpu", "gpu", "webgpu"}``; WebGPU is probed only when no GPU was found."""
        gpu = cls.nvidia_gpu() or cls.torch_gpu()
        webgpu = None if gpu is not None else cls.webgpu()
        return {"cpu": cls.cpu(), "gpu": gpu, "webgpu": webgpu}
@dataclass
class RunMetrics:
# Record of one inference run. ExpertHandle.record_run builds one per call
# and keeps it in dict form (via asdict) in its metrics history.
# Wall-clock duration of the run, in seconds.
duration_sec: float
# Token count reported for the run.
tokens: int
# tokens / duration_sec as computed by the caller (0.0 when duration is not positive).
tokens_per_sec: float
# Template identifier passed by the caller, if any.
template_used: Optional[str]
# Offload plan that was active on the expert when the run was recorded.
offload_plan: Dict[str, Any]
# Telemetry deltas supplied by the caller for this run.
telemetry_deltas: Dict[str, Any]
# Creation time of the record (time.time() at construction).
timestamp: float = field(default_factory=time.time)
class ExpertHandle(nn.Module):
    """Bookkeeping handle for one expert: its spec, loaded backend and run metrics.

    ``_llama`` holds the loaded backend (``None`` when unloaded). For the
    expert named "TranslationExpert" it is a dict with "tokenizer", "model"
    and "local_dir"; for other experts it is dropped on pickling.
    """

    # Number of most recent run records kept in ``metrics_history``.
    MAX_METRICS_HISTORY = 25

    def __init__(self, name: str, spec: Dict[str, Any]):
        super().__init__()
        self.name = name
        self.spec: Dict[str, Any] = dict(spec)
        self._llama = None
        self.last_prompt: str = ""
        self.last_response: str = ""
        self.last_template_used: Optional[str] = None
        self.last_offload_plan: Dict[str, Any] = {}
        self.last_metrics: Dict[str, Any] = {}
        self.metrics_history: List[Dict[str, Any]] = []
        self.call_count: int = 0
        self.total_inference_sec: float = 0.0
        self.avg_inference_sec: float = 0.0
        self.total_tokens_generated: int = 0
        self.avg_tokens_per_sec: float = 0.0
        self.load_time_sec: Optional[float] = None

    def __getstate__(self):
        """Return picklable state.

        Non-translation experts never pickle their backend. The translation
        backend is reduced to its state dict, config, tokenizer and local
        dir. A backend dict without a model is kept only if it is already
        in that reduced form; otherwise it is dropped.
        """
        state = self.__dict__.copy()
        if self.name != "TranslationExpert":
            state["_llama"] = None
            return state
        backend = state.get("_llama")
        if not isinstance(backend, dict):
            return state
        model = backend.get("model")
        if model is None:
            # Nothing to extract a state dict from; previously this raised
            # AttributeError on ``None.state_dict()``.
            state["_llama"] = backend if "state_dict" in backend else None
            return state
        state["_llama"] = {
            "state_dict": model.state_dict(),
            "config": model.config,
            "tokenizer": backend.get("tokenizer"),
            "local_dir": backend.get("local_dir"),
        }
        return state

    def __setstate__(self, state):
        """Restore state and rebuild the translation model from its reduced form."""
        self.__dict__.update(state)
        data = self._llama
        if self.name == "TranslationExpert" and isinstance(data, dict) and "state_dict" in data:
            from transformers import MarianMTModel
            model = MarianMTModel(data["config"])
            model.load_state_dict(data["state_dict"])
            # A model built from its config starts in training mode; switch
            # to eval so dropout is off during generation.
            model.eval()
            self._llama = {
                "tokenizer": data["tokenizer"],
                "model": model,
                "local_dir": data["local_dir"],
            }

    def is_loaded(self) -> bool:
        """Return True when a backend object is attached."""
        return self._llama is not None

    def record_run(
        self,
        prompt_repr: str,
        response: str,
        duration_sec: float,
        tokens: int,
        telemetry_deltas: Dict[str, Any],
        template_used: Optional[str] = None,
    ) -> None:
        """Record one completed run and update the running statistics.

        Updates the last prompt/response/template, call count, total and
        average inference time, total tokens and the running mean of
        tokens-per-second, then appends a ``RunMetrics`` dict to
        ``metrics_history`` (bounded to ``MAX_METRICS_HISTORY`` entries).
        """
        self.last_prompt = prompt_repr
        self.last_response = response
        self.last_template_used = template_used
        self.call_count += 1
        self.total_inference_sec += duration_sec
        self.avg_inference_sec = self.total_inference_sec / self.call_count
        self.total_tokens_generated += tokens
        tps = (tokens / duration_sec) if duration_sec > 0 else 0.0
        # Incremental mean: previous mean weighted by the previous call count.
        self.avg_tokens_per_sec = ((self.avg_tokens_per_sec * (self.call_count - 1)) + tps) / self.call_count
        entry = asdict(
            RunMetrics(
                duration_sec=duration_sec,
                tokens=tokens,
                tokens_per_sec=tps,
                template_used=template_used,
                offload_plan=self.last_offload_plan,
                telemetry_deltas=telemetry_deltas,
            )
        )
        self.last_metrics = entry
        self.metrics_history.append(entry)
        if len(self.metrics_history) > self.MAX_METRICS_HISTORY:
            self.metrics_history.pop(0)

    def forward(self, *args, **kwargs):  # pragma: no cover
        """Always raises: handles are driven by their owner, not called directly."""
        raise RuntimeError(f"ExpertHandle('{self.name}') is not directly callable.")
# ================================================================
# PackedLM
# ================================================================
class PackedLM(nn.ModuleDict):
# Layer count assumed by _plan_offload when an expert's spec has no "n_layers".
DEFAULT_LAYER_GUESS = 32
# Multiplier applied to the model file size to get the VRAM needed on a CUDA/Metal GPU.
VRAM_SAFETY_MARGIN = 1.15
# Same multiplier for the WebGPU budget estimate.
WEBGPU_SAFETY_MARGIN = 1.40
# Role marker strings; presumably for R1-style raw prompts -- their use is outside this section, confirm there.
_R1_USER_TOKEN = "<|User|>"
_R1_ASSISTANT_TOKEN = "<|Assistant|>"
# Default task instruction text for tool-calling prompts.
_XLAM_TASK_INSTRUCTION_DEFAULT = (
"Based on the user's query, decide whether a function call is needed and, if so, "
"produce the correct call(s) using only the tools provided."
)
# Default output-format instruction: a JSON object with a "tool_calls" list.
_XLAM_FORMAT_INSTRUCTION_DEFAULT = (
'Generate a JSON object of the form {"tool_calls": [{"name": "func_name", '
'"arguments": {"arg1": "value1"}}, ...]}. If no function call is needed, '
'return {"tool_calls": []}. Output JSON only, nothing else.'
)
def __init__(self, bundle_path: Optional[str] = DEFAULT_BUNDLE_PATH, auto_load_bundle: bool = True):
    """Create an empty container and optionally load a bundle.

    The bundle is loaded only when ``auto_load_bundle`` is true,
    ``bundle_path`` is non-empty and the file exists.
    """
    super().__init__()
    self.bundle_path = bundle_path
    self.bundle: Optional[Dict[str, Any]] = None
    self.last_expert: str = ""
    self._embed_tempfiles: Dict[str, str] = {}
    self._hf_translation_dir = str(DEFAULT_ZH_EN_DIR)

    bundle_available = auto_load_bundle and bundle_path and os.path.exists(bundle_path)
    if bundle_available:
        self.load_bundle(bundle_path)
def __getstate__(self):
    """Return a shallow copy of the instance dict for pickling.

    Side effect: every contained expert other than "TranslationExpert" that
    has a ``_llama`` attribute gets it reset to ``None`` on the live object,
    i.e. pickling unloads those experts.
    NOTE(review): confirm that unloading the live experts here is intended.
    """
    snapshot = dict(self.__dict__)
    for expert_name, handle in self.items():
        if expert_name != "TranslationExpert" and hasattr(handle, "_llama"):
            handle._llama = None
    return snapshot
def load_bundle(self, bundle_path: str) -> "PackedLM":
    """Load a bundle file and register one ``ExpertHandle`` per model spec.

    If neither "TranslationExpert" nor "zh_en_translator" is registered
    afterwards, a default zh->en "TranslationExpert" handle is added.
    Returns ``self``.

    NOTE: the file is read with ``weights_only=False`` (unpickled); only
    load bundles from a trusted source.
    """
    self.bundle_path = bundle_path
    self.bundle = torch.load(bundle_path, map_location="cpu", weights_only=False)

    for expert_name, expert_spec in self.bundle.get("models", {}).items():
        self[expert_name] = ExpertHandle(expert_name, expert_spec)

    # Optional translation module for zh->en support.
    has_translator = "TranslationExpert" in self or "zh_en_translator" in self
    if not has_translator:
        translation_spec = {
            "kind": "hf_seq2seq",
            "repo_id": ZH_EN_REPO_ID,
            "local_dir": self._hf_translation_dir,
            "source_lang": "zh",
            "target_lang": "en",
        }
        self["TranslationExpert"] = ExpertHandle("TranslationExpert", translation_spec)
    return self
def reload_expert(self, expert_name: str) -> "PackedLM":
    """Drop and reload the backend of one expert.

    Args:
        expert_name: Key of a registered expert.

    Returns:
        ``self``.

    Raises:
        KeyError: ``expert_name`` is not registered.
    """
    if expert_name not in self:
        raise KeyError(f"Unknown expert '{expert_name}'. Loaded experts: {list(self.keys())}")
    expert = self[expert_name]
    # Clear the cached backend first so neither loader can hand it back.
    expert._llama = None
    if expert.spec.get("kind") == "hf_seq2seq":
        # Called without arguments: the translation loader's signature takes
        # none, and the previous ``force_reload=True`` keyword raised TypeError.
        self._load_translation_backend()
        return self
    self._get_llama(expert, force_reload=True)
    return self
def unload_expert(self, expert_name: str) -> None:
    """Detach the backend of ``expert_name``; unknown names are ignored."""
    if expert_name not in self:
        return
    handle = self[expert_name]
    if hasattr(handle, "_llama"):
        handle._llama = None
def unload_all(self) -> None:
    """Call ``unload_expert`` for every registered expert name."""
    # Iterate over a snapshot of the keys, as the original did.
    for expert_name in tuple(self.keys()):
        self.unload_expert(expert_name)
def summary(self) -> Dict[str, Any]:
    """Return per-expert status plus the name of the last expert used.

    Each expert entry reports whether it is loaded, its call count, rounded
    average inference time and tokens/sec, its last offload plan and a
    preview of its last response (first 160 characters followed by "..."
    when longer). The extra key ``"_last_expert"`` holds ``self.last_expert``.
    """
    report: Dict[str, Any] = {}
    for expert_name, handle in self.items():
        entry = {
            "loaded": handle.is_loaded(),
            "call_count": handle.call_count,
            "avg_inference_sec": round(handle.avg_inference_sec, 4),
            "avg_tokens_per_sec": round(handle.avg_tokens_per_sec, 2),
            "last_offload_plan": handle.last_offload_plan,
        }
        response = handle.last_response
        if len(response) > 160:
            entry["last_response_preview"] = response[:160] + "..."
        else:
            entry["last_response_preview"] = response
        report[expert_name] = entry
    report["_last_expert"] = self.last_expert
    return report
@staticmethod
def _tensor_to_bytes(t: Any) -> bytes:
    """Return the raw bytes of an embedded asset.

    Accepts ``bytes`` (returned as-is), ``bytearray``/``memoryview``
    (copied into ``bytes``) and torch tensors (detached, moved to CPU, made
    contiguous and dumped via NumPy).

    Raises:
        TypeError: ``t`` is none of the supported types.
    """
    if isinstance(t, bytes):
        return t
    if isinstance(t, (bytearray, memoryview)):
        return bytes(t)
    if torch.is_tensor(t):
        # ndarray.tobytes() already returns ``bytes``; no extra copy needed.
        return t.detach().cpu().contiguous().numpy().tobytes()
    raise TypeError(f"Unsupported embedded asset type: {type(t)}")
def _resolve_model_path(self, expert: ExpertHandle) -> str:
    """Return a filesystem path to the expert's GGUF model.

    Resolution order: the spec's ``"path"`` if it exists on disk; a temp
    file already written for this expert during this session; otherwise the
    bytes embedded in the bundle under ``assets -> gguf -> <expert name>``,
    written to a new ``.gguf`` temp file. The temp file is not deleted by
    this method.

    Raises:
        FileNotFoundError: No usable external path and no embedded bytes.
    """
    external = expert.spec.get("path", "")
    if external and os.path.exists(external):
        return external
    embedded = (self.bundle or {}).get("assets", {}).get("gguf", {}).get(expert.name)
    if embedded is None:
        raise FileNotFoundError(f"Model '{expert.name}' has no external path and no embedded GGUF bytes in the bundle.")
    cache_key = expert.name
    cached = self._embed_tempfiles.get(cache_key)
    if cached and os.path.exists(cached):
        return cached
    raw = self._tensor_to_bytes(embedded)
    tmp = tempfile.NamedTemporaryFile(prefix=f"{expert.name}_", suffix=".gguf", delete=False)
    try:
        # The context manager closes the handle even if the write fails.
        with tmp:
            tmp.write(raw)
    except BaseException:
        # Do not leave a truncated model file behind.
        with contextlib.suppress(OSError):
            os.unlink(tmp.name)
        raise
    self._embed_tempfiles[cache_key] = tmp.name
    return tmp.name
def _resolve_projector_path(self, expert: ExpertHandle) -> Optional[str]:
    """Return a filesystem path to the expert's projector file, if any.

    The spec keys ``"mmproj_path"``, ``"clip_model_path"`` and
    ``"clip_path"`` are checked in that order. If the configured value
    exists on disk it is returned. Otherwise the same value is used as a
    key into the bundle's ``assets -> gguf`` mapping and the embedded bytes
    are written to a ``.gguf`` temp file (reused on later calls; not
    deleted by this method).

    Returns ``None`` when no projector is configured or no embedded bytes
    are found.
    """
    spec = expert.spec
    configured = spec.get("mmproj_path") or spec.get("clip_model_path") or spec.get("clip_path")
    if not configured:
        return None
    if os.path.exists(configured):
        return configured
    embedded = (self.bundle or {}).get("assets", {}).get("gguf", {}).get(configured)
    if embedded is None:
        return None
    cache_key = f"{expert.name}_mmproj"
    cached = self._embed_tempfiles.get(cache_key)
    if cached and os.path.exists(cached):
        return cached
    raw = self._tensor_to_bytes(embedded)
    tmp = tempfile.NamedTemporaryFile(prefix=f"{expert.name}_mmproj_", suffix=".gguf", delete=False)
    try:
        # The context manager closes the handle even if the write fails.
        with tmp:
            tmp.write(raw)
    except BaseException:
        # Do not leave a truncated projector file behind.
        with contextlib.suppress(OSError):
            os.unlink(tmp.name)
        raise
    self._embed_tempfiles[cache_key] = tmp.name
    return tmp.name
def _plan_offload(self, expert: ExpertHandle) -> Dict[str, Any]:
    """Decide backend, GPU layer count and thread count for loading ``expert``.

    The model size is the file size of the resolved model path, or the
    spec's ``"approx_size_gb"`` (default 2.0) if the path cannot be
    resolved. With a GPU (or, failing that, a WebGPU adapter) the size times
    the matching safety margin is compared with free memory: enough memory
    means full offload (``n_gpu_layers = -1``), otherwise a proportional
    share of the layers. With neither, everything stays on the CPU. Threads
    are physical cores minus one, at least 1.
    """
    hw = HardwareProbe.snapshot()
    spec = expert.spec

    try:
        size_bytes = os.path.getsize(self._resolve_model_path(expert))
        file_size_gb = size_bytes / (1024 ** 3)
    except Exception:
        file_size_gb = float(spec.get("approx_size_gb", 2.0))

    n_layers = int(spec.get("n_layers", self.DEFAULT_LAYER_GUESS))
    notes: List[str] = []

    if hw["gpu"] is not None:
        gpu = hw["gpu"]
        needed = file_size_gb * self.VRAM_SAFETY_MARGIN
        if gpu["free_vram_gb"] < needed:
            frac = max(0.0, gpu["free_vram_gb"] / needed) if needed > 0 else 0.0
            n_gpu_layers = max(0, int(frac * n_layers))
            notes.append(f"{gpu['backend']} GPU has only {gpu['free_vram_gb']:.2f}GB free of {needed:.2f}GB needed -> partial offload of {n_gpu_layers}/{n_layers} layers")
        else:
            n_gpu_layers = -1
            notes.append(f"{gpu['backend']} GPU '{gpu['device_name']}' has {gpu['free_vram_gb']:.2f}GB free >= {needed:.2f}GB needed -> full offload")
        backend = gpu["backend"]
    elif hw["webgpu"] is not None:
        adapter = hw["webgpu"]
        needed = file_size_gb * self.WEBGPU_SAFETY_MARGIN
        if adapter["free_vram_gb"] < needed:
            frac = max(0.0, adapter["free_vram_gb"] / needed) if needed > 0 else 0.0
            n_gpu_layers = max(0, int(frac * n_layers))
            notes.append(f"WebGPU adapter budget covers ~{frac * 100:.0f}% of the model -> partial offload of {n_gpu_layers}/{n_layers} layers")
        else:
            n_gpu_layers = -1
            notes.append("WebGPU adapter's estimated budget covers the full model -> full offload")
        backend = "webgpu"
    else:
        n_gpu_layers = 0
        backend = "cpu"
        notes.append("No CUDA/Metal GPU or WebGPU adapter detected -> CPU-only")

    return {
        "backend": backend,
        "n_gpu_layers": n_gpu_layers,
        "n_threads": max(1, hw["cpu"]["physical_cores"] - 1),
        "model_size_gb": round(file_size_gb, 3),
        "hardware_snapshot": hw,
        "rationale": " | ".join(notes),
    }
def _get_llama(self, expert: ExpertHandle, force_reload: bool = False):
    """Return the expert's ``Llama`` instance, loading it when needed.

    An already loaded backend is returned unless ``force_reload`` is set.
    Loading resolves the model path, plans offloading, attaches a
    ``Qwen25VLChatHandler`` when the spec is marked ``"vision"`` and a
    projector file is found, and records load time and offload plan on the
    expert.

    Raises:
        RuntimeError: llama-cpp-python is not installed.
    """
    if expert.is_loaded() and not force_reload:
        return expert._llama
    if not _LLAMA_CPP_AVAILABLE:
        raise RuntimeError("llama-cpp-python is not installed.")

    model_path = self._resolve_model_path(expert)
    plan = self._plan_offload(expert)

    chat_handler = None
    if expert.spec.get("vision", False):
        projector_path = self._resolve_projector_path(expert)
        if projector_path:
            from llama_cpp.llama_chat_format import Qwen25VLChatHandler
            chat_handler = Qwen25VLChatHandler(clip_model_path=projector_path)

    load_options = dict(
        model_path=model_path,
        n_ctx=expert.spec.get("ctx", 8192),
        n_threads=plan["n_threads"],
        n_gpu_layers=plan["n_gpu_layers"],
        chat_handler=chat_handler,
        verbose=False,
    )
    started = time.perf_counter()
    backend = Llama(**load_options)
    expert.load_time_sec = time.perf_counter() - started
    expert.last_offload_plan = plan
    expert._llama = backend
    return backend
def _load_translation_backend(self, force_reload: bool = False):
    """Return the translation backend dict, building it from bundle assets if needed.

    Args:
        force_reload: Rebuild the backend even when one is already attached.
            Accepted so callers that pass ``force_reload=True`` no longer
            fail with ``TypeError``.

    Returns:
        ``{"tokenizer", "model", "local_dir"}`` as stored on the
        "TranslationExpert" handle.

    Raises:
        RuntimeError: No "TranslationExpert" is registered, or the bundle
            has no ``assets -> translation`` entry.
    """
    if "TranslationExpert" not in self:
        raise RuntimeError("TranslationExpert not found.")
    expert = self["TranslationExpert"]
    cached = expert._llama
    if not force_reload and isinstance(cached, dict) and "model" in cached:
        return cached
    if not (self.bundle and "assets" in self.bundle and "translation" in self.bundle["assets"]):
        raise RuntimeError("Translation model data not found or corrupted in bundle assets.")

    print("Hydrating TranslationExpert from embedded bundle assets...")
    from transformers import MarianMTModel, MarianTokenizer, MarianConfig
    asset_data = self.bundle["assets"]["translation"]

    # The tokenizer can only be loaded from files, so materialize the
    # embedded files into a scratch directory that is removed afterwards.
    tok_tmp = Path(tempfile.mkdtemp(prefix="zh_en_tok_"))
    try:
        for filename, filebytes in asset_data["tokenizer_files"].items():
            target_path = tok_tmp / filename
            target_path.parent.mkdir(parents=True, exist_ok=True)
            target_path.write_bytes(filebytes)
        tokenizer = MarianTokenizer.from_pretrained(str(tok_tmp))
    finally:
        shutil.rmtree(tok_tmp, ignore_errors=True)

    config = MarianConfig.from_dict(asset_data["config"])
    model = MarianMTModel(config)
    model.load_state_dict(asset_data["state_dict"])
    # A model built from its config starts in training mode; switch to eval
    # so dropout is disabled during generation.
    model.eval()
    expert._llama = {
        "tokenizer": tokenizer,
        "model": model,
        "local_dir": self._hf_translation_dir,
    }
    return expert._llama
def _translate_with_internal_model(self, text: str) -> str:
    """Translate ``text`` with the loaded translation backend.

    The text is tokenized (with truncation), moved to the model's device
    and passed to ``generate`` under ``torch.no_grad``; the first decoded
    sequence is returned, stripped.
    """
    backend = self._load_translation_backend()
    tokenizer, model = backend["tokenizer"], backend["model"]

    encoded = tokenizer(text, return_tensors="pt", truncation=True)
    on_device = {}
    for key, tensor in encoded.items():
        on_device[key] = tensor.to(model.device)

    with torch.no_grad():
        generated = model.generate(
            **on_device,
            max_new_tokens=128,
            renormalize_logits=True,
            repetition_penalty=1.1,
        )
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
    return decoded[0].strip()
def translate_zh_en(self, text: str, template: Optional[Union[str, Callable]] = None) -> str:
    """Translate ``text`` to English when it matches ``CHINESE_RE``.

    Falsy text yields ``""``; text without a ``CHINESE_RE`` match is
    returned unchanged. ``template`` is accepted but ignored.
    """
    del template  # reserved for future parity with the other experts
    if not text:
        return ""
    if CHINESE_RE.search(text):
        return self._translate_with_internal_model(text)
    return text
def _translate_chinese_spans(self, text: str) -> str:
    """Replace each ``CHINESE_SPAN_RE`` match in ``text`` with its translation.

    Text without a ``CHINESE_RE`` match is returned unchanged. A span is
    replaced by its stripped form when it is blank after stripping, when
    translation raises, or when the translation is empty.
    """
    if not CHINESE_RE.search(text):
        return text

    def translate_span(match: re.Match) -> str:
        segment = match.group(0).strip()
        if segment:
            try:
                english = self.translate_zh_en(segment)
            except Exception:
                # Best-effort: keep the original span when translation fails.
                return segment
            return english or segment
        return segment

    return CHINESE_SPAN_RE.sub(translate_span, text)
@staticmethod
def _is_url(value: str) -> bool:
    """Return True when ``value`` parses with an ``http`` or ``https`` scheme.

    Any exception raised while parsing or comparing yields ``False``.
    """
    try:
        parsed = urlparse(value)
        return parsed.scheme in ("http", "https")
    except Exception:
        return False
@staticmethod
def _image_to_data_uri(image_path: str, max_pixels: int = 1_000_000) -> str:
    """Encode a local image as a base64 PNG ``data:`` URI.

    Images not in RGB/RGBA mode are converted to RGB. Images with more than
    ``max_pixels`` pixels are downscaled, keeping the aspect ratio, with
    LANCZOS resampling.

    Raises:
        RuntimeError: Pillow is not available.
    """
    if not _PIL_AVAILABLE:
        raise RuntimeError("Pillow is required for local image encoding.")
    with Image.open(image_path) as img:
        if img.mode not in ("RGB", "RGBA"):
            img = img.convert("RGB")
        width, height = img.size
        if width * height > max_pixels:
            scale = (max_pixels / (width * height)) ** 0.5
            # int() truncation can produce 0 for extreme aspect ratios, which
            # resize() rejects; keep each side at least one pixel.
            new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
            img = img.resize(new_size, Image.Resampling.LANCZOS)
        buffer = io.BytesIO()
        img.save(buffer, format="PNG")
    b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f"data:image/png;base64,{b64}"
@staticmethod
def _build_messages(
    template: Optional[Union[str, list, Callable]],
    prompt: str,
    default_builder: Callable[..., List[Dict[str, Any]]],
    **fields: Any,
) -> List[Dict[str, Any]]:
    """Build the chat message list according to ``template``.

    * ``None``: result of ``default_builder(prompt, **fields)``.
    * callable: result of ``template(prompt, **fields)``.
    * list: returned as-is.
    * str: formatted with ``prompt`` and ``fields`` and used as the user
      message; the default builder's first message is kept in front of it
      when that message has the "system" role.

    Raises:
        TypeError: ``template`` is none of the above.
    """
    if template is None:
        return default_builder(prompt, **fields)

    is_plain_callable = callable(template) and not isinstance(template, (str, list))
    if is_plain_callable:
        return template(prompt, **fields)

    if isinstance(template, list):
        return template

    if not isinstance(template, str):
        raise TypeError("template must be None, a list[dict], a callable, or a str")

    defaults = default_builder(prompt, **fields)
    user_message = {"role": "user", "content": template.format(prompt=prompt, **fields)}
    if defaults and defaults[0].get("role") == "system":
        return [defaults[0], user_message]
    return [user_message]
def _exec_chat(self, expert_name: str, messages: List[Dict[str, Any]], max_tokens: int, temperature: float, **gen_kwargs):
    """Run one chat completion on the named expert.

    Telemetry is captured immediately before and after the call. Returns
    ``(text, completion_tokens, duration_sec, telemetry_deltas)``; the token
    count defaults to 0 when the response carries no usage data.
    """
    llama = self._get_llama(self[expert_name])

    telemetry_before = capture_telemetry()
    started = time.perf_counter()
    response = llama.create_chat_completion(
        messages=messages,
        max_tokens=max_tokens,
        temperature=temperature,
        **gen_kwargs,
    )
    elapsed = time.perf_counter() - started
    telemetry_after = capture_telemetry()

    content = response["choices"][0]["message"]["content"]
    completion_tokens = response.get("usage", {}).get("completion_tokens", 0)
    deltas = calculate_delta(telemetry_before, telemetry_after)
    return content, completion_tokens, elapsed, deltas
def _exec_completion(self, expert_name: str, raw_prompt: str, max_tokens: int, temperature: float, stop: Optional[List[str]] = None, **gen_kwargs):
    """Run one raw-prompt completion on the named expert and measure it.

    Returns:
        A tuple ``(text, completion_tokens, duration_sec, telemetry_deltas)``
        where the deltas come from ``calculate_delta`` over telemetry captured
        right before and right after the call.
    """
    model = self._get_llama(self[expert_name])
    telemetry_before = capture_telemetry()
    started = time.perf_counter()
    completion = model.create_completion(
        prompt=raw_prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        stop=stop,
        **gen_kwargs,
    )
    elapsed = time.perf_counter() - started
    telemetry_after = capture_telemetry()
    generated_text = completion["choices"][0]["text"]
    # Missing usage information is reported as zero generated tokens.
    generated = completion.get("usage", {}).get("completion_tokens", 0)
    return generated_text, generated, elapsed, calculate_delta(telemetry_before, telemetry_after)
def _finalize(self, expert_name: str, prompt_repr: str, final_text: str, duration_sec: float, tokens: int, telemetry_deltas: Dict[str, Any], template_used: Optional[str]) -> str:
expert = self[expert_name]
expert.record_run(prompt_repr=prompt_repr, response=final_text, duration_sec=duration_sec, tokens=tokens, telemetry_deltas=telemetry_deltas, template_used=template_used)
self.last_expert = expert_name
return final_text
def _default_creative_messages(self, prompt: str, tone: Optional[str] = None, length: Optional[str] = None, pov: Optional[str] = None, style: Optional[str] = None) -> List[Dict[str, Any]]:
system = "You are a creative writing assistant."
if tone:
system += f" Match this tone: {tone}."
constraints = [c for c in (f"length: {length}" if length else None, f"POV: {pov}" if pov else None, f"style: {style}" if style else None) if c]
user = prompt + ("\n\nConstraints: " + ", ".join(constraints) if constraints else "")
return [{"role": "system", "content": system}, {"role": "user", "content": user}]
def creative_expert(self, prompt: str, tone: Optional[str] = None, length: Optional[str] = None, pov: Optional[str] = None, style: Optional[str] = None, template: Optional[Union[str, list, Callable]] = None, max_tokens: int = 600, temperature: float = 0.9, **gen_kwargs) -> str:
    """Generate creative writing with the ``CreativeExpert`` model.

    Messages are built from ``template`` (or the default creative builder),
    sent as a chat completion, and the run is recorded before the text is
    returned. Extra ``gen_kwargs`` go to the chat call.
    """
    expert = "CreativeExpert"
    template_label = "default" if template is None else "custom"
    messages = self._build_messages(
        template,
        prompt,
        self._default_creative_messages,
        tone=tone,
        length=length,
        pov=pov,
        style=style,
    )
    reply, n_tokens, elapsed, deltas = self._exec_chat(expert, messages, max_tokens, temperature, **gen_kwargs)
    return self._finalize(expert, prompt, reply, elapsed, n_tokens, deltas, template_label)
def _default_code_messages(self, prompt: str, language: Optional[str] = None, context: Optional[str] = None, constraints: Optional[str] = None) -> List[Dict[str, Any]]:
parts = [f"Task: {prompt}"]
if language:
parts.append(f"Language: {language}")
if context:
parts.append(f"Context:\n{context}")
if constraints:
parts.append(f"Constraints:\n{constraints}")
return [{"role": "user", "content": "\n".join(parts)}]
def code_expert(self, prompt: str, language: Optional[str] = None, context: Optional[str] = None, constraints: Optional[str] = None, template: Optional[Union[str, list, Callable]] = None, max_tokens: int = 1200, temperature: float = 0.2, **gen_kwargs) -> str:
    """Answer a coding task with the ``CodeExpert`` model.

    Messages are built from ``template`` (or the default code builder), sent
    as a chat completion, and the run is recorded before the text is returned.
    """
    expert = "CodeExpert"
    template_label = "default" if template is None else "custom"
    messages = self._build_messages(
        template,
        prompt,
        self._default_code_messages,
        language=language,
        context=context,
        constraints=constraints,
    )
    reply, n_tokens, elapsed, deltas = self._exec_chat(expert, messages, max_tokens, temperature, **gen_kwargs)
    return self._finalize(expert, prompt, reply, elapsed, n_tokens, deltas, template_label)
def _logic_default_raw(self, prompt: str, mode: str):
instruction = prompt.strip()
if mode == "deep_then_answer":
instruction += "\nPlease reason step by step, then provide the final answer succinctly."
raw = f"{self._R1_USER_TOKEN}{instruction}{self._R1_ASSISTANT_TOKEN}"
return raw, None, lambda t: t.strip()
if mode == "think_only":
instruction += "\nPlease reason step by step, and do not provide a final answer."
raw = f"{self._R1_USER_TOKEN}{instruction}{self._R1_ASSISTANT_TOKEN}\n"
return raw, [""], lambda t: "\n" + t.strip() + "\n"
if mode == "skip_reasoning":
raw = f"{self._R1_USER_TOKEN}{prompt.strip()}{self._R1_ASSISTANT_TOKEN}\n\n\n\n"
return raw, None, lambda t: t.strip()
raise ValueError(f"Unknown LogicExpert mode: {mode!r}")
def logic_expert(self, prompt: str, mode: str = "deep_then_answer", template: Optional[Union[str, Callable]] = None, max_tokens: int = 1024, temperature: float = 0.6, **gen_kwargs) -> str:
    """Run the ``LogicExpert`` model on a raw completion prompt.

    Args:
        prompt: The question to reason about.
        mode: ``"deep_then_answer"``, ``"skip_reasoning"`` or ``"think_only"``.
        template: Optional callable ``(prompt, mode) -> raw prompt`` or a
            ``str.format`` string using ``{prompt}`` / ``{mode}``. When given,
            the model output is returned without post-processing.
        max_tokens: Generation budget.
        temperature: Sampling temperature.
        **gen_kwargs: Extra completion arguments. A ``stop`` entry is honoured
            in both branches; with the default prompt it is added to the
            mode's own stop strings.

    Returns:
        The (possibly post-processed) model text.

    Raises:
        ValueError: If ``mode`` is not supported.
    """
    if mode not in ("deep_then_answer", "skip_reasoning", "think_only"):
        raise ValueError("mode must be one of: deep_then_answer, skip_reasoning, think_only")

    # Always remove ``stop`` from gen_kwargs: it is passed explicitly below, so
    # leaving it in would raise "multiple values for keyword argument 'stop'".
    caller_stop = gen_kwargs.pop("stop", None)

    if template is not None:
        stop = caller_stop
        raw_prompt = template(prompt, mode) if callable(template) else str(template).format(prompt=prompt, mode=mode)
        wrap = lambda t: t
    else:
        raw_prompt, stop, wrap = self._logic_default_raw(prompt, mode)
        if caller_stop is not None:
            extra = [caller_stop] if isinstance(caller_stop, str) else list(caller_stop)
            stop = list(stop or []) + extra

    raw_text, tokens, dt, deltas = self._exec_completion("LogicExpert", raw_prompt, max_tokens, temperature, stop=stop, **gen_kwargs)
    final_text = wrap(raw_text)
    return self._finalize("LogicExpert", prompt, final_text, dt, tokens, deltas, "custom" if template is not None else mode)
def _default_role_messages(self, prompt: str, character_card: Optional[str] = None) -> List[Dict[str, Any]]:
system = character_card or "You are roleplaying a character."
system += " Respond using Classic Internet RP formatting: *action* speech *narration*."
return [{"role": "system", "content": system}, {"role": "user", "content": prompt}]
def role_expert(self, prompt: str, character_card: Optional[str] = None, template: Optional[Union[str, list, Callable]] = None, max_tokens: int = 400, temperature: float = 0.9, **gen_kwargs) -> str:
    """Produce an in-character reply with the ``RoleExpert`` model.

    Messages are built from ``template`` (or the default role builder), sent
    as a chat completion, and the run is recorded before the text is returned.
    """
    expert = "RoleExpert"
    template_label = "default" if template is None else "custom"
    messages = self._build_messages(
        template,
        prompt,
        self._default_role_messages,
        character_card=character_card,
    )
    reply, n_tokens, elapsed, deltas = self._exec_chat(expert, messages, max_tokens, temperature, **gen_kwargs)
    return self._finalize(expert, prompt, reply, elapsed, n_tokens, deltas, template_label)
def _default_affect_messages(self, text: str) -> List[Dict[str, Any]]:
system = "You are a compact classifier. Output only valid JSON."
user = f'Classify the emotional tone of this text:\n{text}\nReturn: {{"emotion": "...", "confidence": 0-1, "evidence": "..."}}'
return [{"role": "system", "content": system}, {"role": "user", "content": user}]
def affect_expert(self, text: str, template: Optional[Union[str, list, Callable]] = None, max_tokens: int = 300, temperature: float = 0.3, **gen_kwargs) -> str:
    """Classify the emotional tone of ``text`` with the ``AffectExpert`` model.

    Messages are built from ``template`` (or the default affect builder), sent
    as a chat completion, and the run is recorded before the text is returned.
    """
    expert = "AffectExpert"
    template_label = "default" if template is None else "custom"
    messages = self._build_messages(template, text, self._default_affect_messages)
    reply, n_tokens, elapsed, deltas = self._exec_chat(expert, messages, max_tokens, temperature, **gen_kwargs)
    return self._finalize(expert, text, reply, elapsed, n_tokens, deltas, template_label)
def _default_vision_messages(self, prompt: str, image: Optional[str] = None) -> List[Dict[str, Any]]:
if not image:
return [{"role": "user", "content": prompt}]
data_uri = image if self._is_url(image) else self._image_to_data_uri(image)
content = [
{"type": "image_url", "image_url": {"url": data_uri}},
{"type": "text", "text": prompt},
]
return [{"role": "user", "content": content}]
def _run_multimodal(self, expert_name: str, prompt: str, image: Optional[str], template: Optional[Union[str, list, Callable]], max_tokens: int, temperature: float, **gen_kwargs) -> str:
messages = self._build_messages(template, prompt, self._default_vision_messages, image=image)
text, tokens, dt, deltas = self._exec_chat(expert_name, messages, max_tokens, temperature, **gen_kwargs)
return self._finalize(expert_name, prompt, text, dt, tokens, deltas, "custom" if template is not None else "default")
def vision_expert(self, prompt: str, image: Optional[str] = None, template: Optional[Union[str, list, Callable]] = None, max_tokens: int = 512, temperature: float = 0.4, **gen_kwargs) -> str:
    """Answer a prompt, optionally about an image, with the ``VisionExpert`` model."""
    expert = "VisionExpert"
    return self._run_multimodal(expert, prompt, image, template, max_tokens, temperature, **gen_kwargs)
def head_expert(self, prompt: str, image: Optional[str] = None, template: Optional[Union[str, list, Callable]] = None, max_tokens: int = 512, temperature: float = 0.4, **gen_kwargs) -> str:
    """Answer a prompt, optionally about an image, with the ``HeadExpert`` model."""
    expert = "HeadExpert"
    return self._run_multimodal(expert, prompt, image, template, max_tokens, temperature, **gen_kwargs)
def _default_math_messages(self, prompt: str) -> List[Dict[str, Any]]:
system = "You are a precise math and reasoning assistant."
user = f"Solve the following. Show formulas, compute carefully, and state the final answer clearly.\n\n{prompt}"
return [{"role": "system", "content": system}, {"role": "user", "content": user}]
def math_expert(self, prompt: str, template: Optional[Union[str, list, Callable]] = None, max_tokens: int = 500, temperature: float = 0.2, **gen_kwargs) -> str:
    """Solve a math problem with the ``MathExpert`` model.

    The reply is post-processed: Chinese spans (if any are detected) are
    translated, and repeated lines are collapsed.

    Args:
        prompt: The problem statement.
        template: Optional message template (see ``_build_messages``).
        max_tokens: Generation budget.
        temperature: Sampling temperature.
        **gen_kwargs: Extra chat-completion arguments. ``repeat_penalty``
            defaults to 1.15 and may be overridden by the caller.

    Returns:
        The post-processed model text.
    """
    # Use setdefault rather than a hard-coded keyword: passing repeat_penalty
    # both explicitly and through **gen_kwargs raises TypeError.
    gen_kwargs.setdefault("repeat_penalty", 1.15)
    messages = self._build_messages(template, prompt, self._default_math_messages)
    text, tokens, dt, deltas = self._exec_chat("MathExpert", messages, max_tokens, temperature, **gen_kwargs)
    if CHINESE_RE.search(text):
        text = self._translate_chinese_spans(text)
    text = collapse_repeated_lines(text, max_repeat=1)
    return self._finalize("MathExpert", prompt, text, dt, tokens, deltas, "custom" if template is not None else "default")
def _default_tool_prompt(self, task_instruction: str, tools_json: str, format_instruction: str, query: str) -> str:
return (
"You are an AI assistant for function calling. For politically sensitive questions, "
"security and privacy issues, and other non-computer science questions, you will refuse "
"to answer\n"
"### Instruction:\n"
f"[BEGIN OF TASK INSTRUCTION]\n{task_instruction}\n[END OF TASK INSTRUCTION]\n\n"
f"[BEGIN OF AVAILABLE TOOLS]\n{tools_json}\n[END OF AVAILABLE TOOLS]\n\n"
f"[BEGIN OF FORMAT INSTRUCTION]\n{format_instruction}\n[END OF FORMAT INSTRUCTION]\n\n"
f"[BEGIN OF QUERY]\n{query}\n[END OF QUERY]\n\n"
"### Response:\n"
)
@staticmethod
def _safe_parse_tool_json(text: str) -> Optional[Dict[str, Any]]:
text = text.strip()
try:
return json.loads(text)
except json.JSONDecodeError:
start, end = text.find("{"), text.rfind("}")
if start != -1 and end != -1 and end > start:
try:
return json.loads(text[start:end + 1])
except json.JSONDecodeError:
return None
return None
def tool_expert(self, query: str, tools: Optional[List[Dict[str, Any]]] = None, task_instruction: Optional[str] = None, format_instruction: Optional[str] = None, template: Optional[Union[str, Callable]] = None, max_tokens: int = 512, temperature: float = 0.2, **gen_kwargs) -> str:
    """Ask the ``ToolExpert`` model for a function call answering ``query``.

    Args:
        query: The user request.
        tools: Tool descriptions offered to the model (default: none).
        task_instruction: Overrides the class default task instruction.
        format_instruction: Overrides the class default format instruction.
        template: Optional callable ``(query, tools, task_instruction,
            format_instruction) -> raw prompt`` or a ``str.format`` string
            using ``{query}``, ``{tools}`` (JSON text), ``{task_instruction}``
            and ``{format_instruction}``.
        max_tokens: Generation budget.
        temperature: Sampling temperature.
        **gen_kwargs: Extra completion arguments. ``stop`` defaults to
            ``["### Instruction:"]`` and may be overridden by the caller.

    Returns:
        Pretty-printed JSON when the output parses, otherwise the stripped
        raw text.
    """
    tools = tools or []
    task_instruction = task_instruction or self._XLAM_TASK_INSTRUCTION_DEFAULT
    format_instruction = format_instruction or self._XLAM_FORMAT_INSTRUCTION_DEFAULT
    tools_json = json.dumps(tools, indent=2)

    # Pop ``stop`` in every branch: it is passed explicitly below, so leaving
    # it in gen_kwargs raises "multiple values for keyword argument 'stop'".
    stop = gen_kwargs.pop("stop", ["### Instruction:"])

    if template is None:
        raw_prompt = self._default_tool_prompt(task_instruction, tools_json, format_instruction, query)
    elif callable(template):
        raw_prompt = template(query, tools, task_instruction, format_instruction)
    else:
        raw_prompt = str(template).format(query=query, tools=tools_json, task_instruction=task_instruction, format_instruction=format_instruction)

    raw_text, tokens, dt, deltas = self._exec_completion("ToolExpert", raw_prompt, max_tokens, temperature, stop=stop, **gen_kwargs)
    parsed = self._safe_parse_tool_json(raw_text)
    final_text = json.dumps(parsed, indent=2) if parsed is not None else raw_text.strip()
    return self._finalize("ToolExpert", query, final_text, dt, tokens, deltas, "custom" if template is not None else "default")
def translation_expert(self, text: str) -> str:
    """Translate ``text`` by delegating to ``translate_zh_en``."""
    translated = self.translate_zh_en(text)
    return translated
def save_checkpoint(self, path: str = DEFAULT_CHECKPOINT_PATH) -> None:
    """Serialize this whole object to ``path`` with ``torch.save``.

    ``unload_all()`` is called first, so the object is saved in its unloaded
    state.
    """
    self.unload_all()
    torch.save(obj=self, f=path)
@classmethod
def load_checkpoint(cls, path: str = DEFAULT_CHECKPOINT_PATH) -> "PackedLM":
    """Load a full-object checkpoint from ``path`` onto the CPU.

    SECURITY: ``weights_only=False`` unpickles arbitrary objects, which can
    execute code. Only load checkpoints from a trusted source.
    """
    restored = torch.load(path, map_location="cpu", weights_only=False)
    return restored
def run_self_test(self, image_test_source: Optional[str] = None) -> Dict[str, Any]:
    """Run one canned prompt through every registered expert.

    Experts are visited in ``self.keys()`` order; names without a canned
    prompt are skipped. A failure in one expert is recorded as an
    ``"[ERROR] ..."`` string and does not stop the remaining experts.

    Args:
        image_test_source: Optional image for ``HeadExpert``/``VisionExpert``.
            May be an http(s) URL or a path to an existing local file;
            anything else is ignored and those experts run text-only.

    Returns:
        ``{"per_expert_responses": {...}, "summary": self.summary()}``. The
        ``LogicExpert`` entry is a dict keyed by reasoning mode.
    """
    prompts = {
        "CreativeExpert": "Write a short fantasy story (150 words max) about a dragon that discovers a computer hidden beneath a mountain.",
        "CodeExpert": "Write a Python implementation of quicksort. Complete runnable function, include comments, briefly explain time complexity.",
        "LogicExpert": "All robots can compute. Some computers are robots. No calculators are robots. What conclusions can be logically inferred?",
        "RoleExpert": "Explain how transformer attention works while staying fully in character.",
        "HeadExpert": "Explain how Mixture-of-Experts (MoE) routing works. Focus on: expert selection, gating, token routing, and efficiency benefits.",
        "MathExpert": "A train travels 120 miles in 2 hours and then 180 miles in 3 hours. What was its average speed for the entire trip?",
        "ToolExpert": "What is the weather in St. Louis right now?",
        "AffectExpert": "Your core navigation router is completely dropping telemetry packets! Fix this or we pull our implementation down tonight!",
        "VisionExpert": "Describe the visual elements, layout, and any text in this image.",
        "TranslationExpert": "你好,世界。这个模型应该把中文翻译成英文。",
    }
    logic_modes = ("deep_then_answer", "skip_reasoning", "think_only")
    sample_tools = [
        {"name": "get_weather", "description": "Get current weather for a city", "parameters": {"city": "string"}},
        {"name": "send_discord_message", "description": "Send a message to a Discord webhook", "parameters": {"webhook_url": "string", "content": "string"}},
    ]

    def _usable_image() -> Optional[str]:
        # URLs are accepted as well as existing files: the vision message
        # builder passes URLs through, so dropping them here would silently
        # turn the image test into a text-only test.
        if not image_test_source:
            return None
        if self._is_url(image_test_source) or os.path.exists(image_test_source):
            return image_test_source
        return None

    # Built once instead of per iteration. The lambdas defer every attribute
    # lookup and call to the try-block below, so failures stay per-expert.
    runners = {
        "CreativeExpert": lambda p: self.creative_expert(p),
        "CodeExpert": lambda p: self.code_expert(p),
        "LogicExpert": lambda p: {mode: self.logic_expert(p, mode=mode) for mode in logic_modes},
        "RoleExpert": lambda p: self.role_expert(p),
        "HeadExpert": lambda p: self.head_expert(p, image=_usable_image()),
        "MathExpert": lambda p: self.math_expert(p),
        "ToolExpert": lambda p: self.tool_expert(p, tools=sample_tools),
        "AffectExpert": lambda p: self.affect_expert(p),
        "VisionExpert": lambda p: self.vision_expert(p, image=_usable_image()),
        "TranslationExpert": lambda p: self.translation_expert(p),
    }

    report: Dict[str, Any] = {}
    for name in self.keys():
        if name not in prompts:
            continue
        try:
            report[name] = runners[name](prompts[name])
        except Exception as e:
            report[name] = f"[ERROR] {type(e).__name__}: {e}"
    return {"per_expert_responses": report, "summary": self.summary()}
# Names exported by a star-import of this module (`from <module> import *`).
__all__ = ["PackedLM", "ExpertHandle", "HardwareProbe", "capture_telemetry", "calculate_delta"]
class PackedLMCheckpointRuntime:
    """Load and operate a serialized PackedLM checkpoint.

    The checkpoint is loaded in the constructor. Attributes that this wrapper
    does not define are looked up on the loaded object.

    Parameters
    ----------
    checkpoint_path:
        Path to the saved `LM.pt` checkpoint.
    packedlm_module:
        Optional module name to import before loading. Use this when the
        `PackedLM` class lives in a separate Python module.
        Examples: "packedlm", "my_project.packedlm", or None if the class
        is already imported in the current process.
    map_location:
        Passed to `torch.load`. Usually "cpu".
    weights_only:
        Must be False for a full object checkpoint saved with `torch.save(obj, ...)`.
    strict_type_check:
        If True, verifies that the loaded object looks like a PackedLM instance.
    """

    # Attributes the wrapper itself owns; never delegated to the model.
    _OWN_ATTRS = frozenset({
        "model", "checkpoint_path", "packedlm_module",
        "map_location", "weights_only", "strict_type_check",
    })

    # Attributes a loaded object must expose to pass the strict check.
    _REQUIRED_ATTRS = (
        "creative_expert",
        "code_expert",
        "logic_expert",
        "role_expert",
        "affect_expert",
        "head_expert",
        "vision_expert",
        "math_expert",
        "tool_expert",
        "translation_expert",
        "summary",
        "run_self_test",
    )

    def __init__(
        self,
        checkpoint_path: Union[str, Path] = "LM.pt",
        packedlm_module: Optional[str] = None,
        map_location: str = "cpu",
        weights_only: bool = False,
        strict_type_check: bool = True,
    ):
        self.checkpoint_path = Path(checkpoint_path)
        self.packedlm_module = packedlm_module
        self.map_location = map_location
        self.weights_only = weights_only
        self.strict_type_check = strict_type_check
        self.model: Any = None
        self.load()

    def _import_checkpoint_module(self) -> None:
        """Import ``packedlm_module`` (if configured) so unpickling can find its classes."""
        if not self.packedlm_module:
            return
        importlib.import_module(self.packedlm_module)

    def _validate_model(self) -> None:
        """Check that ``self.model`` is loaded and, in strict mode, PackedLM-like.

        Raises:
            RuntimeError: If no model is loaded.
            TypeError: If strict checking is on and required attributes are missing.
        """
        if self.model is None:
            raise RuntimeError("PackedLM checkpoint is not loaded.")
        missing = [name for name in self._REQUIRED_ATTRS if not hasattr(self.model, name)]
        if missing and self.strict_type_check:
            raise TypeError(
                "Loaded object does not look like a PackedLM instance. Missing: "
                + ", ".join(missing)
            )

    def load(self) -> Any:
        """Load the checkpoint from disk and install it as ``self.model``.

        If loading or validation fails, the previously installed model (if
        any) is left in place and the exception propagates.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
            RuntimeError, TypeError: See ``_validate_model``.
        """
        if not self.checkpoint_path.exists():
            raise FileNotFoundError(f"Checkpoint not found: {self.checkpoint_path}")
        self._import_checkpoint_module()
        # The checkpoint is a full serialized object, so weights_only must be False.
        # SECURITY: that means arbitrary pickled code can run during load --
        # only open checkpoints from a trusted source.
        candidate = torch.load(
            self.checkpoint_path,
            map_location=self.map_location,
            weights_only=self.weights_only,
        )
        previous = self.model
        self.model = candidate
        try:
            self._validate_model()
        except Exception:
            # Never leave an object that failed validation installed.
            self.model = previous
            raise
        return self.model

    def reload(self) -> Any:
        """Re-read the checkpoint; the current model is kept if that fails."""
        return self.load()

    def __getattr__(self, name: str) -> Any:
        # Only reached when normal lookup fails. Delegate unknown attributes
        # to the loaded PackedLM object, but never the wrapper's own fields:
        # a missing own field must surface as AttributeError.
        if name in PackedLMCheckpointRuntime._OWN_ATTRS:
            raise AttributeError(name)
        model = self.__dict__.get("model")
        if model is not None and hasattr(model, name):
            return getattr(model, name)
        raise AttributeError(name)

    def __getitem__(self, key: str) -> Any:
        return self.model[key]

    def __contains__(self, key: str) -> bool:
        return key in self.model

    @property
    def last_expert(self) -> str:
        """Name stored in the model's ``last_expert`` attribute ("" if absent)."""
        return getattr(self.model, "last_expert", "")

    def experts(self) -> List[str]:
        """Return the model's keys as a list."""
        return list(self.model.keys())

    def summary(self) -> Dict[str, Any]:
        """Return ``model.summary()``."""
        return self.model.summary()

    def creative(self, prompt: str, **kwargs: Any) -> str:
        """Forward to ``model.creative_expert``."""
        return self.model.creative_expert(prompt, **kwargs)

    def code(self, prompt: str, **kwargs: Any) -> str:
        """Forward to ``model.code_expert``."""
        return self.model.code_expert(prompt, **kwargs)

    def logic(self, prompt: str, mode: str = "deep_then_answer", **kwargs: Any) -> str:
        """Forward to ``model.logic_expert`` with the given reasoning mode."""
        return self.model.logic_expert(prompt, mode=mode, **kwargs)

    def role(self, prompt: str, **kwargs: Any) -> str:
        """Forward to ``model.role_expert``."""
        return self.model.role_expert(prompt, **kwargs)

    def affect(self, text: str, **kwargs: Any) -> str:
        """Forward to ``model.affect_expert``."""
        return self.model.affect_expert(text, **kwargs)

    def head(self, prompt: str, image: Optional[str] = None, **kwargs: Any) -> str:
        """Forward to ``model.head_expert``."""
        return self.model.head_expert(prompt, image=image, **kwargs)

    def vision(self, prompt: str, image: Optional[str] = None, **kwargs: Any) -> str:
        """Forward to ``model.vision_expert``."""
        return self.model.vision_expert(prompt, image=image, **kwargs)

    def math(self, prompt: str, **kwargs: Any) -> str:
        """Forward to ``model.math_expert`` and collapse repeated lines in the reply."""
        text = self.model.math_expert(prompt, **kwargs)
        return collapse_repeated_lines(text)

    def tool(
        self,
        query: str,
        tools: Optional[List[Dict[str, Any]]] = None,
        **kwargs: Any,
    ) -> str:
        """Forward to ``model.tool_expert``."""
        return self.model.tool_expert(query, tools=tools, **kwargs)

    def translation(self, text: str) -> str:
        """Forward to ``model.translation_expert``."""
        return self.model.translation_expert(text)

    def reload_expert(self, expert_name: str) -> Any:
        """Forward to ``model.reload_expert``."""
        return self.model.reload_expert(expert_name)

    def unload_expert(self, expert_name: str) -> None:
        """Forward to ``model.unload_expert``."""
        return self.model.unload_expert(expert_name)

    def unload_all(self) -> None:
        """Forward to ``model.unload_all``."""
        return self.model.unload_all()

    def save_checkpoint(self, path: Union[str, Path] = "LM.pt") -> None:
        """Forward to ``model.save_checkpoint`` with ``path`` as a string."""
        return self.model.save_checkpoint(str(path))

    def run_self_test(self, image_test_source: Optional[str] = None) -> Dict[str, Any]:
        """Forward to ``model.run_self_test``."""
        return self.model.run_self_test(image_test_source=image_test_source)

    def __repr__(self) -> str:
        model_name = type(self.model).__name__ if self.model is not None else ""
        return f"PackedLMCheckpointRuntime(checkpoint_path={self.checkpoint_path!s}, model={model_name})"
def load_packedlm(
    checkpoint_path: Union[str, Path] = "LM.pt",
    packedlm_module: Optional[str] = None,
    map_location: str = "cpu",
    weights_only: bool = False,
    strict_type_check: bool = True,
) -> PackedLMCheckpointRuntime:
    """Load a PackedLM checkpoint into a runtime wrapper.

    Convenience function: every argument is passed unchanged to
    ``PackedLMCheckpointRuntime``.
    """
    runtime_options = {
        "checkpoint_path": checkpoint_path,
        "packedlm_module": packedlm_module,
        "map_location": map_location,
        "weights_only": weights_only,
        "strict_type_check": strict_type_check,
    }
    return PackedLMCheckpointRuntime(**runtime_options)
@dataclass
class RouteStep:
    """One step of an execution route: which expert runs what, and why."""

    # Name of the expert that should handle this step.
    expert: str
    # Prompt text handed to that expert.
    sub_prompt: str
    # Description of what the step is meant to achieve.
    goal: str
    # Extra keyword arguments for the step; a fresh dict per instance.
    kwargs: Dict[str, Any] = field(default_factory=dict)
@dataclass
class ExecutionContext:
    """Mutable state carried through one request.

    Only ``prompt`` is required; every other field starts empty and each
    container field gets its own fresh instance.
    """

    # --- request inputs ---
    prompt: str
    image: Optional[str] = None
    tools: Optional[List[Dict[str, Any]]] = None
    deep_think: bool = False
    fast_think: bool = False

    # --- planning state ---
    think_blocks: Dict[str, str] = field(default_factory=dict)
    response_goal: Dict[str, Any] = field(default_factory=dict)
    route: List["RouteStep"] = field(default_factory=list)

    # --- execution results ---
    step_results: List[Dict[str, Any]] = field(default_factory=list)
    base_response: str = ""
    affective_state: Dict[str, Any] = field(default_factory=dict)
    final_response: str = ""
    final_review: Dict[str, Any] = field(default_factory=dict)
    command_context: Dict[str, Any] = field(default_factory=dict)
class PackedLLM(nn.ModuleDict):
# Expert keys in snake_case; each has a matching entry in the runtime
# dispatch table built by _call_expert.
MODEL_EXPERTS: List[str] = [
"head_expert", "affect_expert", "role_expert", "creative_expert",
"code_expert", "logic_expert", "math_expert", "vision_expert",
"tool_expert", "translation_expert",
]
# Additional expert keys; _call_expert resolves these with getattr(..., None),
# so they may be absent from a loaded runtime.
PIPELINE_EXPERTS: List[str] = ["action_expert", "web_expert"]
REQUIRED_EXPERTS: List[str] = MODEL_EXPERTS + PIPELINE_EXPERTS
# Loop limits -- presumably consumed by pipeline code outside this excerpt.
MAX_STEP_RETRIES: int = 3
MAX_ACTION_ATTEMPTS: int = 3
MAX_WEB_ROUNDS: int = 3
_CHECKPOINT_FORMAT_VERSION: int = 3
# Per-stage generation settings. _stage_kwargs(stage) returns a copy of the
# matching entry (or {} for an unknown stage), which callers unpack as
# keyword arguments into _call_expert.
_STAGE_SETTINGS: Dict[str, Dict[str, Any]] = {
"head_plan_response_goal": {"temperature": 0.2},
"head_build_route": {"temperature": 0.5},
"head_retry_or_reroute": {"temperature": 0.8, "top_p": 0.9},
"head_plan_detour": {"temperature": 1.0, "top_p": 0.95},
"head_synthesize_base": {"temperature": 1.0, "top_p": 0.95},
"head_review_final_response": {"temperature": 0.0},
"head_action_review": {"temperature": 1.0, "top_p": 0.95},
"head_web_queries": {"temperature": 0.5},
"head_web_answer_subquery": {"temperature": 0.5},
"head_web_review": {"temperature": 1.0, "top_p": 0.95},
"head_web_synthesis": {"temperature": 0.8, "top_p": 0.9},
"head_validate_response": {"temperature": 0.8, "top_p": 0.9},
"affect_evaluate_step": {"temperature": 0.2},
"affect_build_affective_state": {"temperature": 0.5},
"role_apply_persona": {"temperature": 1.0, "top_p": 0.95},
"logic_action_planning": {"temperature": 0.0},
"code_action_generation": {"temperature": 0.0},
"logic_action_repair": {"temperature": 0.0},
"deep_think": {"temperature": 0.0},
}
def __init__(
    self,
    bot_id: Optional[str] = None,
    user_id: Optional[str] = None,
    model_dir: str = "models",
    memory_dir: Optional[str] = None,
    web: bool = False,
    hardware_probe: bool = True,
    expert_modules: Optional[Dict[str, nn.Module]] = None,
    packedlm_checkpoint: Optional[str] = "LM.pt",
    packedlm_module: Optional[str] = "PackedLM",
):
    """Set up configuration, experts, memory, profiles and optional add-ons.

    Args:
        bot_id: Identifier used to load the bot profile, if any.
        user_id: Identifier used to load the user profile, if any.
        model_dir: Directory passed to expert constructors.
        memory_dir: Root for the memory bank; the project root when unset.
        web: Attach the web module when True.
        hardware_probe: Collect a hardware snapshot when True.
        expert_modules: Ready-made experts to register instead of building them.
        packedlm_checkpoint: Checkpoint to back the experts with; resolved
            relative to this file when it is a relative path.
        packedlm_module: Module imported before the checkpoint is loaded.
    """
    super().__init__()

    # Caller-supplied configuration.
    self.bot_id = bot_id
    self.user_id = user_id
    self.model_dir = model_dir
    self.memory_dir = memory_dir
    self._hardware_probe_enabled = hardware_probe
    self.packedlm_checkpoint = self._resolve_local_path(packedlm_checkpoint)
    self.packedlm_module = packedlm_module

    # Runtime state, filled in by the helpers below.
    self._memory_bank: Any = None
    self._bot_profile: Dict[str, Any] = {}
    self._user_profile: Dict[str, Any] = {}
    self._hardware_state: Dict[str, Any] = {}
    self._web: Any = None
    self._codebox: Any = None
    self._packedlm_runtime: Optional[Any] = None
    self._runtime_expert_names: List[str] = []

    # Prefer the checkpoint runtime; build experts only when it is unavailable.
    self._init_packedlm_runtime()
    runtime_missing = self._packedlm_runtime is None
    if runtime_missing:
        self._build_experts(expert_modules)

    self._init_memory()
    self._load_profiles()
    if hardware_probe:
        self._probe_hardware()
    if web:
        self._attach_web()
# ------------------------------------------------------------------
# Pickle support – live handles must never be serialised
# ------------------------------------------------------------------
def __getstate__(self):
state = self.__dict__.copy()
state["_packedlm_runtime"] = None
state["_web"] = None
state["_codebox"] = None
state["_memory_bank"] = None
return state
def __setstate__(self, state):
self.__dict__.update(state)
self._runtime_expert_names = self.__dict__.get("_runtime_expert_names", [])
self._packedlm_runtime = None
self._web = None
self._codebox = None
self._memory_bank = None
self._init_packedlm_runtime()
# ------------------------------------------------------------------
# Path helpers
# ------------------------------------------------------------------
def _resolve_local_path(self, path: Optional[str]) -> Optional[str]:
if not path:
return None
p = Path(path)
if p.is_absolute():
return str(p)
script_dir = Path(__file__).resolve().parent
candidate = script_dir / p
if candidate.exists():
return str(candidate)
return str(p.resolve())
def _project_root(self) -> Path:
try:
return Path(__file__).resolve().parent
except Exception:
return Path.cwd()
# ------------------------------------------------------------------
# Initialisation helpers
# ------------------------------------------------------------------
def _init_packedlm_runtime(self) -> None:
if not self.packedlm_checkpoint or PackedLMCheckpointRuntime is None:
return
if not os.path.exists(self.packedlm_checkpoint):
return
try:
self._packedlm_runtime = PackedLMCheckpointRuntime(
checkpoint_path=self.packedlm_checkpoint,
packedlm_module=self.packedlm_module,
map_location="cpu",
weights_only=False,
strict_type_check=True,
)
raw_names = list(self._packedlm_runtime.experts())
self._runtime_expert_names = _expert_names_canonical(raw_names)
except Exception as exc:
print(f"[PackedLLM] Warning: could not load PackedLM runtime: {exc}")
self._packedlm_runtime = None
self._runtime_expert_names = []
def _build_experts(self, expert_modules: Optional[Dict[str, nn.Module]]) -> None:
if expert_modules:
for key, module in expert_modules.items():
self[key] = module
return
class_map = {
"head_expert": ("HeadExpert", {"model_dir": self.model_dir}),
"affect_expert": ("AffectExpert", {"model_dir": self.model_dir}),
"role_expert": ("RoleExpert", {"model_dir": self.model_dir}),
"creative_expert": ("CreativeExpert", {"model_dir": self.model_dir}),
"code_expert": ("CodeExpert", {"model_dir": self.model_dir}),
"logic_expert": ("LogicExpert", {"model_dir": self.model_dir}),
"math_expert": ("MathExpert", {"model_dir": self.model_dir}),
"vision_expert": ("VisionExpert", {"model_dir": self.model_dir}),
"tool_expert": ("ToolExpert", {"model_dir": self.model_dir}),
"translation_expert": ("TranslationExpert", {"model_dir": self.model_dir}),
"action_expert": ("ActionExpert", {"model_dir": self.model_dir}),
"web_expert": ("WebExpert", {"model_dir": self.model_dir}),
}
frame_globals: Dict[str, Any] = {}
try:
frame = inspect.stack()[2].frame
frame_globals = frame.f_globals
except Exception:
pass
for key, (cls_name, kwargs) in class_map.items():
cls = frame_globals.get(cls_name) or builtins.__dict__.get(cls_name)
if cls is not None:
try:
self[key] = cls(**kwargs)
except Exception as exc:
print(f"[PackedLLM] Warning: could not instantiate {cls_name}: {exc}")
def _init_memory(self) -> None:
    """Create the memory bank, or fall back to running without one.

    The bank is rooted at ``memory_dir`` when set, otherwise at the project
    root. Any failure prints a warning and leaves ``_memory_bank`` as None.
    """
    if MemoryBank is None:
        return
    try:
        if self.memory_dir:
            root = Path(self.memory_dir).expanduser().resolve()
        else:
            root = self._project_root()
        bank = MemoryBank(gator_location=root, build_if_missing=False)
        self._memory_bank = bank
        # Switch the bank's gator to lazy mode when it offers that option.
        supports_lazy = hasattr(bank, "gator") and hasattr(bank.gator, "set_lazy")
        if supports_lazy:
            bank.gator.set_lazy(True)
    except Exception as exc:
        print(f"[PackedLLM] Warning: MemoryBank degraded mode: {exc}")
        self._memory_bank = None
def _load_profiles(self) -> None:
if self._memory_bank is None:
return
if self.bot_id:
try:
self._bot_profile = _safe_call(self._memory_bank, "get_profile", "bot", self.bot_id, default={}) or {}
except Exception:
self._bot_profile = {}
if self.user_id:
try:
self._user_profile = _safe_call(self._memory_bank, "get_profile", "user", self.user_id, default={}) or {}
except Exception:
self._user_profile = {}
def _probe_hardware(self) -> None:
    """Collect a snapshot of platform, RAM and GPU facts into ``_hardware_state``.

    Sizes are reported in units of 1e9 bytes. RAM and GPU probing are each
    best-effort: an error there keeps whatever was gathered so far.
    """
    bytes_per_gb = 1e9
    snapshot: Dict[str, Any] = {
        "platform": platform.system(),
        "python_version": platform.python_version(),
        "cpu_count": os.cpu_count(),
    }

    if psutil is not None:
        try:
            memory = psutil.virtual_memory()
            snapshot["ram_total_gb"] = round(memory.total / bytes_per_gb, 1)
            snapshot["ram_available_gb"] = round(memory.available / bytes_per_gb, 1)
            snapshot["ram_percent_used"] = memory.percent
        except Exception:
            pass

    if not torch.cuda.is_available():
        snapshot["gpu"] = "none"
    else:
        try:
            snapshot["gpu_name"] = torch.cuda.get_device_name(0)
            device = torch.cuda.get_device_properties(0)
            snapshot["gpu_vram_total_gb"] = round(device.total_memory / bytes_per_gb, 1)
            snapshot["gpu_vram_used_gb"] = round(torch.cuda.memory_allocated(0) / bytes_per_gb, 2)
        except Exception:
            pass

    self._hardware_state = snapshot
def _attach_web(self) -> None:
    """Create the web helper, or set ``_web`` to None with a printed warning."""
    handle = None
    try:
        if Web is None:
            raise RuntimeError("CompileWeb.Web unavailable")
        handle = Web()
    except Exception as exc:
        print(f"[PackedLLM] Warning: Web module unavailable: {exc}")
    self._web = handle
def _get_codebox(self) -> Any:
if self._codebox is not None:
return self._codebox
if Box is not None:
try:
self._codebox = Box()
return self._codebox
except Exception as exc:
print(f"[PackedLLM] Warning: Box unavailable: {exc}")
if CodeBox is not None:
try:
self._codebox = CodeBox()
return self._codebox
except Exception as exc:
print(f"[PackedLLM] Warning: CodeBox unavailable: {exc}")
return None
def _stage_kwargs(self, stage: str) -> Dict[str, Any]:
return dict(self._STAGE_SETTINGS.get(stage, {}))
def _inject_think_blocks(
self,
ctx: ExecutionContext,
*,
stage: str,
target_expert: str,
task_prompt: str,
output_contract: str,
constraints: Optional[List[str]] = None,
) -> str:
if not ctx.deep_think or target_expert == "translation_expert":
return task_prompt
cache_key = f"{stage}::{target_expert}::{hash(task_prompt) & 0xFFFFFFFF}"
if cache_key in ctx.think_blocks:
think = ctx.think_blocks[cache_key]
else:
constraints_text = "\n".join(f"- {c}" for c in (constraints or [])) or "- none"
think_prompt = (
"You are LogicExpert. Produce PRIVATE planning blocks only.\n"
"Return ONLY ... blocks.\n"
"Do not answer the task. Do not produce JSON. Do not produce prose.\n"
"Make the blocks task-specific, brief, and directly useful to the target expert.\n\n"
f"Stage: {stage}\n"
f"Target expert: {target_expert}\n"
f"Task prompt:\n{task_prompt}\n\n"
f"Output contract:\n{output_contract}\n\n"
f"Constraints:\n{constraints_text}\n\n"
"Generate 2 to 4 blocks covering: objective, risks, format, and the safest path."
)
raw = self._call_expert("logic_expert", think_prompt, **self._stage_kwargs("deep_think"))
think = self._extract_think_blocks(raw)
ctx.think_blocks[cache_key] = think
if think.strip():
return (
"\n"
f"{think.strip()}\n"
"\n\n"
f"{task_prompt}"
)
return task_prompt
def _extract_think_blocks(self, text: str) -> str:
if not text:
return ""
blocks = re.findall(
r".*?",
text,
flags=re.IGNORECASE | re.DOTALL,
)
if blocks:
return "\n".join(
b.strip()
for b in blocks
if b.strip()
)
text = text.strip()
if not text:
return ""
return f"{text}"
def _route_allowed(self, ctx: ExecutionContext, expert: str) -> bool:
if ctx.fast_think:
return expert in {"head_expert", "role_expert", "translation_expert"}
return True
def _fast_path_response(self, ctx: ExecutionContext) -> str:
prompt = (
"Answer the user's request directly and as fast as possible.\n"
"Do not call tools. Do not browse. Do not write JSON.\n"
"Do not add meta commentary.\n\n"
f"User prompt: {ctx.prompt}"
)
return self._call_expert("head_expert", prompt, **self._stage_kwargs("head_synthesize_base")).strip()
def _fast_apply_persona(self, ctx: ExecutionContext, base_response: str) -> str:
if not self._bot_profile:
return base_response
character_card = self._bot_profile.get("character_card", "")
user_name = self._user_profile.get("name", "the user")
prompt = (
f"\nRewrite the base response in the character voice while preserving facts.\n"
f"Stay concise and accurate.\n\n\n"
f"{character_card}\n"
f"{user_name}\n"
f"{ctx.prompt}\n"
f"{base_response}"
)
return self._call_expert("role_expert", prompt,
**self._stage_kwargs("role_apply_persona")).strip() or base_response
# ------------------------------------------------------------------
# Forward pass
# ------------------------------------------------------------------
def forward(
    self,
    prompt: str,
    image: Optional[str] = None,
    tools: Optional[List[Dict[str, Any]]] = None,
    stream: bool = False,
    deep_think: bool = False,
    fast_think: bool = False,
) -> Union[str, Generator[str, None, None]]:
    """Run one request through the pipeline and return the final response.

    Args:
        prompt: The user's request.
        image: Optional image reference stored on the execution context.
        tools: Optional tool descriptions stored on the execution context.
        stream: When true, the final text is passed through
            ``_stream_response`` and that result is returned.
        deep_think: Stored on the execution context for later stages.
        fast_think: When true, only the fast-path answer (plus the persona
            rewrite if a bot profile is set) is produced.

    Returns:
        The final response text, or whatever ``_stream_response`` returns
        for it when *stream* is true.
    """
    ctx = ExecutionContext(
        prompt=prompt,
        image=image,
        tools=tools,
        deep_think=deep_think,
        fast_think=fast_think,
    )

    if fast_think:
        # Fast path: minimal work, no memory, no action, no web.
        answer = self._fast_path_response(ctx)
        if self._bot_profile:
            answer = self._fast_apply_persona(ctx, answer)
    else:
        # Full pipeline; each stage reads and mutates the shared context.
        stages = (
            self._plan_response_goal,
            self._consult_commands,
            self._build_route,
            self._execute_route,
            self._synthesize_base,
            self._build_affective_state,
            self._apply_persona,
            self._review_final_response,
            self._finalize,
        )
        for stage in stages:
            stage(ctx)
        answer = ctx.final_response

    return self._stream_response(answer) if stream else answer
# ------------------------------------------------------------------
# Expert dispatch
# ------------------------------------------------------------------
def _call_expert(self, key: str, *args: Any, **kwargs: Any) -> str:
    """Invoke the expert registered under *key* and return its output as text.

    When a packed runtime is attached, the expert is looked up as an
    attribute of that runtime; otherwise it is obtained from
    ``_get_expert``.  This is a best-effort boundary: an unknown or missing
    expert, a ``None`` result, or any exception raised by the expert all
    yield ``""`` (exceptions are reported on stdout).

    Args:
        key: Expert name; normalised with ``_normalise_expert_name``.
        *args: Positional arguments forwarded to the expert.
        **kwargs: Keyword arguments forwarded to the expert.

    Returns:
        ``str(result)`` of the expert call, or ``""`` as described above.
    """
    # Normalise key to snake_case so callers don't need to worry.
    key = _normalise_expert_name(key)

    runtime = self._packedlm_runtime
    if runtime is not None:
        # Expert key -> attribute name on the packed runtime.
        runtime_attrs = {
            "creative_expert": "creative",
            "code_expert": "code",
            "logic_expert": "logic",
            "role_expert": "role",
            "affect_expert": "affect",
            "head_expert": "head",
            "vision_expert": "vision",
            "math_expert": "math",
            "tool_expert": "tool",
            "translation_expert": "translation",
            "action_expert": "action",
            "web_expert": "web",
        }
        attr = runtime_attrs.get(key)
        # Resolve only the requested attribute: reading all of them up front
        # meant a single missing attribute raised for every expert call.
        fn = getattr(runtime, attr, None) if attr is not None else None
        label = f"runtime:{key}"
    else:
        fn = self._get_expert(key)
        label = key

    if fn is None:
        return ""
    try:
        result = fn(*args, **kwargs)
        # No unconditional print of the result here: the stray debug print
        # dumped every expert output to stdout on the non-runtime path only.
        return str(result) if result is not None else ""
    except Exception as exc:
        # Deliberate catch-all so one failing expert cannot abort the request.
        print(f"[PackedLLM] _call_expert({label}) error: {exc}")
        return ""
# ------------------------------------------------------------------
# Pipeline stages
# ------------------------------------------------------------------
def _plan_response_goal(self, ctx: ExecutionContext) -> None:
    """Ask the head expert for a response plan and store it on the context.

    The planner prompt requests a JSON object describing intent, tone,
    success criteria, constraints and the vision/web/action needs.  The
    prompt is passed through ``_inject_think_blocks`` before the call.  The
    parsed reply is stored in ``ctx.response_goal`` only when it is a
    non-empty dict; anything else (unparseable text, a JSON array, a bare
    string, ...) is replaced by a default goal so the stored goal always
    supports ``.get`` lookups.
    """
    prompt = (
        "You are a response planner.\n"
        "Return ONLY a JSON object with exactly these keys:\n"
        "intent (string), tone (string), success (string), constraints (array), "
        "needs_vision (boolean), needs_web (boolean), needs_action (boolean).\n"
        "Do not add commentary, markdown, or extra keys.\n"
        "Always respond in English unless the task is explicitly a translation task.\n"
        "TranslationExpert is Chinese→English only; only flag translation-related routing when "
        "the source text is actually Chinese.\n\n"
        f"User prompt: {ctx.prompt}"
    )
    prompt = self._inject_think_blocks(
        ctx,
        stage="head_plan_response_goal",
        target_expert="head_expert",
        task_prompt=prompt,
        output_contract="JSON object with intent/tone/success/constraints/needs_vision/needs_web/needs_action.",
    )
    raw = self._call_expert(
        "head_expert",
        prompt,
        image=ctx.image if ctx.image else None,
        **self._stage_kwargs("head_plan_response_goal"),
    )
    parsed = _parse_json_safe(raw)
    # Accept only the contracted shape; a truthy list or string used to slip
    # through the plain ``or`` fallback and become the goal.
    if isinstance(parsed, dict) and parsed:
        ctx.response_goal = parsed
        return
    ctx.response_goal = {
        "intent": ctx.prompt,
        "tone": "helpful",
        "success": "Answer the user helpfully.",
        "constraints": [],
        "needs_vision": False,
        "needs_web": False,
        "needs_action": False,
    }
def _consult_commands(self, ctx: ExecutionContext) -> None:
if ctx.fast_think or self._memory_bank is None:
ctx.command_context = {"executed": [], "results": [], "coverage": "none"}
return
command_context = {
"executed": [],
"results": [],
"coverage": "none",
}
gator = getattr(self._memory_bank, "gator", None)
if gator is None:
ctx.command_context = command_context
return
maybe_execute = getattr(gator, "maybe_execute_commands", None)
if callable(maybe_execute):
try:
result = maybe_execute(ctx.prompt, ctx.response_goal)
if isinstance(result, dict):
command_context.update(result)
except Exception as exc:
command_context["results"].append({"error": str(exc)})
else:
try:
if hasattr(gator, "process_actions"):
routed = gator.process_actions(
ctx.prompt,
user_id=self.user_id or "",
bot_id=self.bot_id or "",
history="",
location="",
time_date="",
)
if isinstance(routed, dict):
command_context["executed"] = routed.get("commands", [])
command_context["results"] = routed.get("retrieved_data", {})
if routed.get("commands") or routed.get("retrieved_data"):
command_context["coverage"] = "partial"
except Exception as exc:
command_context["results"].append({"error": str(exc)})
if command_context["coverage"] == "none":
executed = command_context.get("executed") or []
results = command_context.get("results") or []
if executed or results:
command_context["coverage"] = "partial"
ctx.command_context = command_context
def _build_route(self, ctx: ExecutionContext) -> None:
"""Populate ``ctx.route`` with the ordered RouteStep list to execute.

Fast-think requests get a fixed route: a head_expert step plus, when a bot
profile is set, a role_expert step.  Otherwise the head expert is asked for
a JSON array of steps; dict entries whose normalised expert passes
``_route_allowed`` become RouteStep objects.  When that yields nothing, a
route is derived from the needs_web / needs_action / needs_vision flags of
``ctx.response_goal``.  Finally, a step whose expert equals the previous
step's expert is dropped, except for translation_expert.

NOTE(review): this file's indentation is not visible in the reviewed text,
so the nesting of the head_expert append that follows the needs_* checks
(fallback-only vs. unconditional) could not be confirmed -- verify against
the properly indented source.
"""
if ctx.fast_think:
# Fixed minimal route; the router model is not consulted.
ctx.route = [RouteStep(expert="head_expert", sub_prompt=ctx.prompt,
goal="Produce the fastest useful answer possible.")]
if self._bot_profile:
ctx.route.append(RouteStep(expert="role_expert", sub_prompt=ctx.prompt, goal="Apply persona only."))
return
# Router prompt: embeds the planned goal, the command context, the original
# prompt and the names listed in REQUIRED_EXPERTS.
prompt = (
"You are a response router.\n"
"Return ONLY a JSON array of step objects.\n"
"Each step object must include: expert (string), sub_prompt (string), goal (string), "
"and optional kwargs (object).\n"
"Use the fewest steps needed. Prefer web_expert for live information, action_expert for "
"executable tasks, and head_expert for planning/validation.\n"
"Always write sub_prompt and goal in English.\n"
"Do not include duplicate or redundant steps.\n\n"
f"Response goal: {json.dumps(ctx.response_goal, ensure_ascii=False, default=str)}\n"
f"Command context: {json.dumps(ctx.command_context, ensure_ascii=False, default=str)}\n"
f"Original prompt: {ctx.prompt}\n"
f"Available experts: {', '.join(self.REQUIRED_EXPERTS)}"
)
prompt = self._inject_think_blocks(
ctx,
stage="head_build_route",
target_expert="head_expert",
task_prompt=prompt,
output_contract="JSON array of routing steps.",
)
raw = self._call_expert("head_expert", prompt, **self._stage_kwargs("head_build_route"))
steps_raw = _parse_json_safe(raw)
route: List[RouteStep] = []
# Only a non-empty list is treated as a usable router reply.
if isinstance(steps_raw, list) and steps_raw:
for s in steps_raw:
# Non-dict entries are ignored.
if not isinstance(s, dict):
continue
# A step without an "expert" key defaults to head_expert.
expert = _normalise_expert_name(s.get("expert", "head_expert"))
if not self._route_allowed(ctx, expert):
continue
route.append(
RouteStep(
expert=expert,
sub_prompt=str(s.get("sub_prompt", ctx.prompt)),
goal=str(s.get("goal", "Complete the sub-task.")),
# kwargs that are not a dict are replaced by an empty dict.
kwargs=s.get("kwargs", {}) if isinstance(s.get("kwargs", {}), dict) else {},
)
)
# Fallback when the router produced no usable step: derive steps from the
# planner's needs_* flags (a non-dict goal is treated as having no flags).
if not route:
goal = ctx.response_goal if isinstance(ctx.response_goal, dict) else {}
if goal.get("needs_web") and self._route_allowed(ctx, "web_expert"):
route.append(RouteStep(expert="web_expert", sub_prompt=ctx.prompt,
goal="Gather and synthesize fresh external information."))
if goal.get("needs_action") and self._route_allowed(ctx, "action_expert"):
route.append(RouteStep(expert="action_expert", sub_prompt=ctx.prompt,
goal="Execute the requested action or code workflow."))
# The vision step is added only when an image is actually present.
if goal.get("needs_vision") and ctx.image and self._route_allowed(ctx, "vision_expert"):
route.append(
RouteStep(expert="vision_expert", sub_prompt=ctx.prompt, goal="Interpret the provided image.",
kwargs={"image": ctx.image}))
# NOTE(review): presumably still part of the fallback branch -- confirm.
route.append(
RouteStep(expert="head_expert", sub_prompt=ctx.prompt, goal="Produce a complete, helpful response."))
# Remove redundant consecutive steps.
compressed: List[RouteStep] = []
for step in route:
# Keeps the first of each run of same-expert steps; translation steps are
# never collapsed.
if compressed and compressed[-1].expert == step.expert and step.expert != "translation_expert":
continue
compressed.append(step)
ctx.route = compressed
def _execute_route(self, ctx: ExecutionContext) -> None:
i = 0
while i < len(ctx.route):
step = ctx.route[i]
if step.expert == "action_expert":
result = self._run_action_pipeline(ctx, step)
ctx.step_results.append({
"expert": step.expert,
"sub_prompt": step.sub_prompt,
"result": result,
"passed": bool(result.get("passed", False)),
"action": result.get("action", "action_pipeline"),
})
i += 1
continue
if step.expert == "web_expert":
result = self._run_web_pipeline(ctx, step)
ctx.step_results.append({
"expert": step.expert,
"sub_prompt": step.sub_prompt,
"result": result,
"passed": bool(result.get("passed", False)),
"action": result.get("action", "web_pipeline"),
})
i += 1
continue
retries = 0
while True:
result = self._execute_step(step, ctx)
passed = self._evaluate_step(result, step, ctx)
if passed:
ctx.step_results.append({
"expert": step.expert,
"sub_prompt": step.sub_prompt,
"result": result,
"passed": True,
})
break
retries += 1
action = self._retry_or_reroute(step, retries, result, ctx)
if action == "retry" and retries < self.MAX_STEP_RETRIES:
continue
if action == "detour":
new_steps = self._plan_detour(step, result, ctx)
ctx.route = ctx.route[:i + 1] + new_steps + ctx.route[i + 1:]
ctx.step_results.append({
"expert": step.expert,
"sub_prompt": step.sub_prompt,
"result": result,
"passed": False,
"action": "detour",
})
break
ctx.step_results.append({
"expert": step.expert,
"sub_prompt": step.sub_prompt,
"result": result,
"passed": False,
"action": "skip",
})
break
i += 1
def _execute_step(self, step: RouteStep, ctx: ExecutionContext) -> str:
kwargs = dict(step.kwargs)
if step.expert in ("head_expert", "vision_expert") and ctx.image:
kwargs.setdefault("image", ctx.image)
if step.expert == "tool_expert" and ctx.tools:
kwargs.setdefault("tools", ctx.tools)
if step.expert == "translation_expert":
# Translation stays direct; no deep-think wrapper.
kwargs = {}
if not CHINESE_RE.search(step.sub_prompt or "") and not CHINESE_RE.search(ctx.prompt or ""):
return step.sub_prompt
prompt_text = step.sub_prompt
if ctx.deep_think and step.expert != "translation_expert":
prompt_text = self._inject_think_blocks(
ctx,
stage=f"{step.expert}:{step.goal}",
target_expert=step.expert,
task_prompt=step.sub_prompt,
output_contract="Task-specific expert response.",
constraints=[step.goal],
)
return self._call_expert(step.expert, prompt_text, **kwargs)
def _evaluate_step(self, result: str, step: RouteStep, ctx: ExecutionContext) -> bool:
    """Ask the affect expert whether *result* satisfies the step's goal.

    Only the first 500 characters of *result* are shown to the evaluator.
    The check fails open: when the reply is not a JSON object, or the object
    has no ``pass`` key, the step is considered passed.

    Args:
        result: Output produced for the step.
        step: The step whose ``goal`` is being checked.
        ctx: Execution context (not read here).

    Returns:
        ``False`` only when the evaluator explicitly reports a failure.
    """
    meta = (
        "Evaluate whether the following result meets the stated goal. "
        "Return ONLY JSON: {\"pass\": true/false, \"reason\": \"...\"}.\n\n"
        f"Goal: {step.goal}\n"
        f"Result: {result[:500]}"
    )
    raw = self._call_expert("affect_expert", meta)
    parsed = _parse_json_safe(raw)
    if not isinstance(parsed, dict):
        return True
    verdict = parsed.get("pass", True)
    if isinstance(verdict, str):
        # Models sometimes quote the boolean.  bool("false") is True, so a
        # stringified negative verdict would otherwise count as a pass.
        negatives = {"false", "no", "0", "fail", "failed"}
        return bool(verdict) and verdict.strip().lower() not in negatives
    return bool(verdict)
def _retry_or_reroute(self, step: RouteStep, retry_count: int, result: str, ctx: ExecutionContext) -> str:
options = ["detour", "skip"]
if retry_count < self.MAX_STEP_RETRIES:
options = ["retry"] + options
meta = (
"A pipeline step has failed its quality check. Decide what to do.\n"
f"Failed expert: {step.expert}\n"
f"Sub-prompt: {step.sub_prompt}\n"
f"Step goal: {step.goal}\n"
f"Result so far: {result[:300]}\n"
f"Retry count: {retry_count}\n"
f"Available options: {options}\n"
"Return ONLY JSON: {\"action\": \"