study-partner / test_load_overrides.py
nz-nz's picture
Deploy Recall study-partner app (stub-mode demo)
7563305 verified
Raw
History Blame Contribute Delete
2.2 kB
"""
NAH-49 — local real-model dtype/device overrides.
The Space-correct defaults (bf16 + device_map="auto") produce garbage on
Apple-Silicon MPS, so _load() honors RECALL_DTYPE / RECALL_DEVICE for local dev.
These check the resolution logic only — no torch, no GPU, no model download.
Run with: python3 test_load_overrides.py
"""
import importlib
import os
def _reload_with(dtype=None, device=None):
    """Set the override environment, then import and reload ``llm``.

    ``dtype`` is written to ``RECALL_DTYPE`` and ``device`` to
    ``RECALL_DEVICE``.  Passing ``None`` for either one removes that variable
    from ``os.environ`` instead of setting it.  ``RECALL_STUB`` is always set
    to ``"1"``.

    Returns the module object handed back by ``importlib.reload``.
    """
    overrides = {"RECALL_DTYPE": dtype, "RECALL_DEVICE": device}
    for name, setting in overrides.items():
        if setting is not None:
            os.environ[name] = setting
            continue
        os.environ.pop(name, None)

    # Never actually load a model here.
    os.environ["RECALL_STUB"] = "1"

    # Import first so there is a module object to reload; the reload then
    # re-executes the module's top-level code with the environment set above.
    module = importlib.import_module("llm")
    return importlib.reload(module)
def test_defaults_match_the_space():
    """With no overrides passed, expect dtype "bfloat16" and device map "auto"."""
    module = _reload_with()  # no arguments: both override variables are removed

    dtype_name = module._resolve_dtype_name()
    assert dtype_name == "bfloat16"

    device_map = module._resolve_device_map()
    assert device_map == "auto"

    print("ok defaults -> bfloat16 + auto (unchanged for the Space)")
def test_apple_silicon_local_combo():
    """RECALL_DTYPE=float32 together with RECALL_DEVICE=cpu must both be honored."""
    module = _reload_with(dtype="float32", device="cpu")

    dtype_name = module._resolve_dtype_name()
    assert dtype_name == "float32"

    device_map = module._resolve_device_map()
    assert device_map == "cpu"

    print("ok RECALL_DTYPE=float32 RECALL_DEVICE=cpu honored (Mac smoke test)")
def test_dtype_aliases_normalize():
    """Each shorthand/mixed-case RECALL_DTYPE value resolves to its canonical name.

    The module is reloaded once per alias so every case is resolved under its
    own environment.  On failure the assertion message reports the raw input,
    the expected canonical name, and the value actually returned; this matters
    when the file is run as a plain script, where a bare ``assert`` shows
    nothing beyond its message.
    """
    alias_cases = (
        ("fp16", "float16"),
        ("half", "float16"),
        ("fp32", "float32"),
        ("BF16", "bfloat16"),
    )
    for raw, want in alias_cases:
        module = _reload_with(dtype=raw)
        got = module._resolve_dtype_name()
        assert got == want, (
            f"RECALL_DTYPE={raw!r}: expected {want!r}, got {got!r}"
        )
    print("ok dtype aliases (fp16/half/fp32/BF16) normalize")
def test_unknown_dtype_falls_back_not_crashes():
    """An unrecognized RECALL_DTYPE value must resolve to "bfloat16"."""
    module = _reload_with(dtype="not-a-dtype")
    resolved = module._resolve_dtype_name()
    assert resolved == "bfloat16", "unknown -> safe default"
    print("ok unknown RECALL_DTYPE falls back to bfloat16 (no crash at load)")
if __name__ == "__main__":
    # Script entry point: run every check in order; the first failing assert
    # aborts the run with its AssertionError.
    try:
        test_defaults_match_the_space()
        test_apple_silicon_local_combo()
        test_dtype_aliases_normalize()
        test_unknown_dtype_falls_back_not_crashes()
    finally:
        # Reset env so a follow-on import in the same shell sees clean
        # defaults.  Done in ``finally`` so the overrides are also removed
        # when one of the checks above fails.
        for k in ("RECALL_DTYPE", "RECALL_DEVICE"):
            os.environ.pop(k, None)
    # Only reached when every check passed.
    print("\nAll NAH-49 load-override tests passed.")