# study-partner / test_model_config.py
# nz-nz's picture
# Sync from GitHub via hub-sync
# 6736c51 verified
# Raw
# History Blame Contribute Delete
# 2.18 kB
"""
NAH-9 — model fallback is a one-env-var config flip.
Re-imports llm under different RECALL_MODEL values to confirm alias resolution
and verbatim passthrough of full HF ids. No GPU/model download (we never call
_load). Run: python3 test_model_config.py
"""
import importlib
import os
def _reload_with(model_env):
    """Set up the environment, then (re)import ``llm`` and return the module.

    model_env: value to store in RECALL_MODEL, or None to remove the
    variable from the environment entirely.

    RECALL_STUB is always forced to "1". Both environment changes are left
    in place after the call. Presumably ``llm`` reads these variables at
    import time, which is why the module is reloaded after they are set.
    """
    env = os.environ
    if model_env is not None:
        env["RECALL_MODEL"] = model_env
    else:
        env.pop("RECALL_MODEL", None)
    # keep stub on; we only check config, not GPU
    env["RECALL_STUB"] = "1"

    # Import first (a no-op if already loaded), then force a re-execution so
    # the module-level config is recomputed under the environment set above.
    module = importlib.import_module("llm")
    return importlib.reload(module)
def test_default_is_v46():
    """With RECALL_MODEL unset, MODEL_ID is MiniCPM-V-4.6 and VISION is True."""
    expected_id = "openbmb/MiniCPM-V-4.6"
    mod = _reload_with(None)
    assert mod.MODEL_ID == expected_id
    assert mod.VISION is True, "default is the multimodal model"
    print("ok default -> MiniCPM-V-4.6 (multimodal)")
def test_1b_alias():
    """RECALL_MODEL=1b must resolve to the MiniCPM5-1B id."""
    expected_id = "openbmb/MiniCPM5-1B"
    mod = _reload_with("1b")
    assert mod.MODEL_ID == expected_id
    print("ok RECALL_MODEL=1b -> MiniCPM5-1B (fast fallback)")
def test_4b_alias():
    """RECALL_MODEL=4b must resolve to the MiniCPM3-4B id."""
    expected_id = "openbmb/MiniCPM3-4B"
    mod = _reload_with("4b")
    assert mod.MODEL_ID == expected_id
    print("ok RECALL_MODEL=4b -> MiniCPM3-4B (Tiny Titan)")
def test_full_id_passthrough():
    """A value that is not a known alias is used verbatim as MODEL_ID.

    The custom id used here is also expected to leave VISION False.
    """
    custom_id = "some-org/Custom-Model-7B"
    mod = _reload_with(custom_id)
    assert mod.MODEL_ID == custom_id
    assert mod.VISION is False, "a non MiniCPM-V id is not a vision model"
    print("ok unknown value passed through as a literal HF id")
def test_vision_detection():
    """Check MODEL_ID and VISION for one vision alias and one text alias.

    "v46" must resolve to the MiniCPM-V-4.6 id with VISION True; "8b" must
    resolve to the MiniCPM4.1-8B id with VISION False.

    Each condition gets its own assert (instead of ``assert a and b``) so a
    failure pinpoints whether the id or the vision flag was wrong.
    """
    cases = (
        ("v46", "openbmb/MiniCPM-V-4.6", True),
        ("8b", "openbmb/MiniCPM4.1-8B", False),
    )
    for alias, expected_id, expected_vision in cases:
        # Each case is checked immediately after its own reload, before the
        # next iteration re-executes the module with a different value.
        mod = _reload_with(alias)
        assert mod.MODEL_ID == expected_id, (
            f"RECALL_MODEL={alias}: MODEL_ID is {mod.MODEL_ID!r}"
        )
        assert mod.VISION is expected_vision, (
            f"RECALL_MODEL={alias}: VISION is {mod.VISION!r}"
        )
    print("ok vision detection: -V ids -> VISION, text ids -> not")
def test_active_model_reports_stub():
    """active_model() must return "stub" after a reload with RECALL_STUB=1."""
    mod = _reload_with("1b")
    reported = mod.active_model()
    assert reported == "stub", "STUB on -> active_model() is 'stub'"
    print("ok active_model() reports 'stub' while stubbed")
if __name__ == "__main__":
    # Script entry point: run every check in order. A failed assert raises
    # and stops the run before the final success line is printed.
    for check in (
        test_default_is_v46,
        test_1b_alias,
        test_4b_alias,
        test_full_id_passthrough,
        test_vision_detection,
        test_active_model_reports_stub,
    ):
        check()
    print("\nAll NAH-9 model-config tests passed.")