# wenjiao's picture
# fix init Evaluation queue count + add quant description + update overhead_factor params
# 4802f1a
"""Unit tests for auto_eval (add_new_eval) and auto_quant (add_new_quant).
Test models:
- nytopop/Qwen3-30B-A3B.w4a16 (quantized W4A16 β†’ auto_eval)
- Qwen/Qwen3-30B-A3B (FP bfloat16 β†’ auto_quant)
"""
import json
import logging
import re
import sys
import os
from types import SimpleNamespace
from pathlib import Path
# Ensure project root is on the path: the directory two levels above this
# file is inserted at the front of sys.path (presumably so the `src.*`
# imports below resolve when the file is run directly).
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Configure root logging at DEBUG with a "<logger> <level>: <message>" format.
logging.basicConfig(level=logging.DEBUG, format="%(name)s %(levelname)s: %(message)s")
# Module-level logger for this test file.
logger = logging.getLogger("test_submit")
# ── Imports from the project ─────────────────────────────────────────────────
from transformers import AutoConfig
from huggingface_hub import HfApi
from src.submission.check_validity import (
get_model_size,
get_quantized_model_parameters_memory,
validate_quantization_scheme,
estimate_weight_memory_gb,
estimate_quantization_memory_gb,
get_num_layers,
select_gpu,
SUPPORTED_QUANT_SCHEMES,
SUPPORTED_INPUT_DTYPES,
PRECISION_TO_BITS,
BYTES,
KNOWN_SIZE_FACTOR,
get_gpu_display_name,
is_model_on_hub,
)
import src.submission.check_validity as check_validity
from src.submission.submit import (
_normalize_file_tag_component,
add_new_eval,
add_new_quant,
)
import src.submission.submit as submit_module
# Shared Hugging Face Hub client; inspect_model() calls API.model_info() on it.
API = HfApi()
def _is_error(result: str) -> bool:
"""Check if result is a styled_error (red) response."""
return "color: red" in result
def _is_success(result: str) -> bool:
"""Check if result is a styled_message (green) response."""
return "color: green" in result
def _is_warning(result: str) -> bool:
"""Check if result is a styled_warning (orange) response."""
return "color: orange" in result
def _consume_generator(gen):
"""Consume a generator (or plain value), return the last yielded value."""
if hasattr(gen, '__next__'):
result = None
for value in gen:
result = value
return result
return gen
# ═══════════════════════════════════════════════════════════════════════════════
# Helper: inspect model config
# ═══════════════════════════════════════════════════════════════════════════════
def inspect_model(model_name: str, revision: str = "main"):
    """Print debugging details for a Hub model and return what was fetched.

    Two independent lookups are attempted: the model config through
    ``AutoConfig.from_pretrained`` and the repo metadata through
    ``API.model_info``. Any exception raised while fetching or printing one
    half is reported on stdout and that half of the result becomes ``None``.

    Args:
        model_name: Hub repo id to inspect.
        revision: Revision forwarded to both lookups.

    Returns:
        A ``(config, info)`` tuple; either element may be ``None``.
    """
    banner = "=" * 70
    print(f"\n{banner}")
    print(f" Inspecting: {model_name}")
    print(f"{banner}")

    # 1. AutoConfig
    try:
        config = AutoConfig.from_pretrained(model_name, revision=revision, trust_remote_code=True)
        print(f"\n[Config] architectures: {getattr(config, 'architectures', None)}")
        for attr in ("torch_dtype", "num_hidden_layers", "num_attention_heads", "hidden_size"):
            print(f"[Config] {attr}: {getattr(config, attr, None)}")

        # MoE-specific attributes are only shown when the config defines them.
        for moe_attr in ("num_experts", "num_local_experts", "num_experts_per_tok"):
            moe_value = getattr(config, moe_attr, None)
            if moe_value is None:
                continue
            print(f"[Config] {moe_attr}: {moe_value}")

        # quantization_config may be an object with to_dict(), a plain dict,
        # or something else entirely (shown via its str()).
        quant_cfg = getattr(config, "quantization_config", None)
        if quant_cfg is None:
            print("[Config] quantization_config: None")
        else:
            if hasattr(quant_cfg, "to_dict"):
                quant_dict = quant_cfg.to_dict()
            elif isinstance(quant_cfg, dict):
                quant_dict = quant_cfg
            else:
                quant_dict = {"raw": str(quant_cfg)}
            print(f"[Config] quantization_config: {json.dumps(quant_dict, indent=2)}")
    except Exception as err:
        print(f"[Config] ERROR: {err}")
        config = None

    # 2. Model info from HF API
    try:
        info = API.model_info(repo_id=model_name, revision=revision)
        print(f"\n[ModelInfo] id: {info.id}")
        print(f"[ModelInfo] likes: {info.likes}")
        siblings = info.siblings
        print(f"[ModelInfo] siblings count: {len(siblings) if siblings else 0}")
        # List sibling files (first 20)
        if siblings:
            fnames = [sibling.rfilename for sibling in siblings]
            print("[ModelInfo] files (first 20):")
            for fname in fnames[:20]:
                print(f" {fname}")
            remaining = len(fnames) - 20
            if remaining > 0:
                print(f" ... and {remaining} more")
    except Exception as err:
        print(f"[ModelInfo] ERROR: {err}")
        info = None

    return config, info
def test_file_tag_component_prefers_parenthesized_value():
    """Check the file-tag component produced for three sample scheme labels.

    Labels with a parenthesized part are expected to yield that part (with
    surrounding spaces removed); a label without parentheses is expected back
    unchanged.
    """
    expectations = (
        ("INT4 (W4A16)", "W4A16"),
        ("INT8 ( W8A16 )", "W8A16"),
        ("MXFP4", "MXFP4"),
    )
    for raw_label, expected_tag in expectations:
        assert _normalize_file_tag_component(raw_label) == expected_tag
def test_gpu_display_name_uses_full_label():
    """Check the display label returned for three GPU short names.

    "4090" and "A100" are expected to expand to full product names, while
    "H200" is expected back unchanged.
    """
    expectations = (
        ("4090", "NVIDIA GeForce RTX 4090"),
        ("A100", "NVIDIA A100-SXM4-80GB"),
        ("H200", "H200"),
    )
    for short_name, display_name in expectations:
        assert get_gpu_display_name(short_name) == display_name
def test_get_num_layers_supports_nested_raw_config_dict():
    """``get_num_layers`` finds ``num_hidden_layers`` nested under ``text_config``.

    The config is passed as a plain dict rather than a config object.
    """
    text_section = {
        "num_hidden_layers": 36,
        "torch_dtype": "bfloat16",
    }
    raw_config = {"model_type": "qwen3_5", "text_config": text_section}
    assert get_num_layers(raw_config) == 36
def test_get_model_size_uses_config_param_count_fallback(monkeypatch):
    """``get_model_size`` still reports a size when safetensors metadata fails.

    ``check_validity.get_safetensors_metadata`` is replaced with a stub that
    always raises; the config's ``num_parameters`` ("0.8B") at "16bit"
    precision is expected to give 0.8 (billion params) and 1.6 (GB).
    """
    def _metadata_unavailable(*_args, **_kwargs):
        raise RuntimeError("metadata unavailable")

    monkeypatch.setattr(check_validity, "get_safetensors_metadata", _metadata_unavailable)

    fake_info = SimpleNamespace(id="org/custom-model")
    fallback_config = {"num_parameters": "0.8B"}
    params_b, size_gb = get_model_size(fake_info, precision="16bit", model_config=fallback_config)

    assert params_b == 0.8
    assert size_gb == 1.6
def test_is_model_on_hub_returns_authorization_guidance_for_gated_repo(monkeypatch):
    """A gated-repo failure yields ``(False, guidance message, None)``.

    ``AutoConfig.from_pretrained`` (as seen by ``check_validity``) is stubbed
    to raise a gated-repo error; the returned message must include the repo
    URL and the access-request hint.
    """
    def _gated_repo(*_args, **_kwargs):
        raise RuntimeError("You are trying to access a gated repo.")

    monkeypatch.setattr(check_validity.AutoConfig, "from_pretrained", _gated_repo)

    ok, message, config = is_model_on_hub("org/gated-model", revision="main")

    assert ok is False
    assert config is None
    for fragment in (
        "https://huggingface.co/org/gated-model",
        "request or accept access first",
    ):
        assert fragment in message
def test_add_new_quant_surfaces_gated_repo_authorization_message(monkeypatch):
    """``add_new_quant`` relays the gated-repo guidance as a warning.

    The quant cache loader and common pre-checks are stubbed out, and
    ``is_model_on_hub`` is stubbed to report a gated repo; the final yielded
    value must be warning-styled and contain the guidance text.
    """
    gated_message = "is gated on the Hugging Face Hub. Please open https://huggingface.co/org/gated-model and request or accept access first. After access is granted, resubmit the model."

    def _report_gated(**_kwargs):
        return (False, gated_message, None)

    monkeypatch.setattr("src.submission.submit._load_quant_cache", lambda: None)
    monkeypatch.setattr("src.submission.submit._common_pre_checks", lambda *_args, **_kwargs: None)
    monkeypatch.setattr("src.submission.submit.is_model_on_hub", _report_gated)

    submission = add_new_quant(
        model="org/gated-model",
        revision="main",
        private=False,
    )
    result = _consume_generator(submission)

    assert _is_warning(result)
    for fragment in (
        "https://huggingface.co/org/gated-model",
        "request or accept access first",
    ):
        assert fragment in result
def test_add_new_quant_allows_whitelisted_resubmit_for_failed_entry(monkeypatch, tmp_path):
    """A whitelisted submitter can resubmit a quant request that previously failed.

    Setup: a "Quant Failed" status file and a stale "Pending" request file are
    written under ``tmp_path``, the matching dedup key is pre-seeded in
    ``_QUANT_REQUESTED``, and "alice" is placed in ``SIZE_WHITELIST``. Every
    Hub / sizing / ETA helper used by ``submit_module`` is stubbed.

    Expectation: the upload hook is invoked, the result is success-styled, and
    the file tag contains an underscore followed by at least four digits.
    """
    status_root, pending_root = tmp_path / "status", tmp_path / "pending"
    status_dir, pending_dir = status_root / "quant", pending_root / "quant"
    for folder in (status_dir, pending_dir):
        folder.mkdir(parents=True)

    scheme = SUPPORTED_QUANT_SCHEMES["INT4 (W4A16)"]
    model_name = "org/model"
    dedup_key = (
        f"{model_name}_main_{scheme.name}_{scheme.precision}_{scheme.weight_dtype}_{scheme.name}"
    )

    failed_entry = {
        "model": model_name,
        "revision": "main",
        "quant_scheme": scheme.name,
        "quant_precision": scheme.precision,
        "quant_weight_dtype": scheme.weight_dtype,
        "status": "Quant Failed",
    }
    pending_entry = {**failed_entry, "status": "Pending"}
    (status_dir / "failed.json").write_text(json.dumps(failed_entry), encoding="utf-8")
    (pending_dir / "stale_request_copy.json").write_text(json.dumps(pending_entry), encoding="utf-8")

    # Replacements applied to submit_module; each stub keeps the call
    # signature the production code is expected to use.
    module_stubs = [
        ("GIT_STATUS_PATH", str(status_root)),
        ("GIT_REQUESTS_PATH", str(pending_root)),
        ("SIZE_WHITELIST", {"alice"}),
        ("_QUANT_REQUESTED", {dedup_key}),
        ("_SUBMITTER_DATES", {}),
        ("_load_quant_cache", lambda: None),
        ("_common_pre_checks", lambda *_args, **_kwargs: None),
        (
            "is_model_on_hub",
            lambda **_kwargs: (
                True,
                "",
                {"architectures": ["TestArch"], "torch_dtype": "float16", "num_hidden_layers": 24},
            ),
        ),
        (
            "check_model_card",
            lambda *_args, **_kwargs: (True, "", SimpleNamespace(text="x" * 300, data=SimpleNamespace(tags=[]))),
        ),
        ("get_model_tags", lambda *_args, **_kwargs: []),
        ("is_license_approved", lambda *_args, **_kwargs: True),
        ("get_model_size", lambda *_args, **_kwargs: (7.0, 14.0)),
        ("get_num_layers", lambda *_args, **_kwargs: 24),
        ("estimate_quantization_memory_gb", lambda *_args, **_kwargs: 12.0),
        ("estimate_weight_memory_gb", lambda *_args, **_kwargs: 5.0),
        ("select_gpu_with_override", lambda *_args, **_kwargs: ("A100", 1)),
        ("get_gpu_display_name", lambda value: value),
        ("compute_single_eta", lambda *_args, **_kwargs: 1),
        ("format_eta", lambda *_args, **_kwargs: "1h"),
    ]
    for attr_name, replacement in module_stubs:
        monkeypatch.setattr(submit_module, attr_name, replacement)
    monkeypatch.setattr(
        submit_module.API,
        "model_info",
        lambda **_kwargs: SimpleNamespace(cardData={"license": "apache-2.0"}, likes=0),
    )

    uploaded = {"called": False, "file_tag": None}

    def _fake_upload(entry, user_name, model_path, file_tag, model, task_label="eval"):
        # Record that an upload was attempted and which file tag it used.
        uploaded.update(called=True, file_tag=file_tag)

    monkeypatch.setattr(submit_module, "_upload_to_hub", _fake_upload)

    submission = add_new_quant(
        model=model_name,
        revision="main",
        private=False,
        quant_scheme="INT4 (W4A16)",
        submitted_by="alice",
    )
    result = _consume_generator(submission)

    assert uploaded["called"] is True
    assert _is_success(result)
    # Re-submission must not overwrite the previous failed status file: the
    # filename gets a timestamp suffix appended to keep both records.
    assert uploaded["file_tag"] is not None
    assert re.search(r"_\d{4}", uploaded["file_tag"]), uploaded["file_tag"]
def test_add_new_eval_allows_whitelisted_resubmit_for_failed_entry(monkeypatch, tmp_path):
    """A whitelisted submitter can resubmit an eval request that previously failed.

    Setup: an "Eval Failed" status file and a stale "Pending" request file are
    written under ``tmp_path``, the matching dedup key is pre-seeded in
    ``_EVAL_REQUESTED``, and "alice" is placed in ``SIZE_WHITELIST``. Every
    Hub / validation / sizing / ETA helper used by ``submit_module`` is
    stubbed.

    Expectation: the upload hook is invoked, the result is success-styled, and
    the file tag contains an underscore followed by at least four digits.
    """
    status_root, pending_root = tmp_path / "status", tmp_path / "pending"
    status_dir, pending_dir = status_root / "eval", pending_root / "eval"
    for folder in (status_dir, pending_dir):
        folder.mkdir(parents=True)

    model_name = "org/quant-model"
    dedup_key = f"{model_name}_main_AutoRound_4bit_int4_INT4 (W4A16)"

    failed_entry = {
        "model": model_name,
        "revision": "main",
        "quant_type": "AutoRound",
        "precision": "4bit",
        "weight_dtype": "int4",
        "compute_dtype": "INT4 (W4A16)",
        "status": "Eval Failed",
    }
    pending_entry = {**failed_entry, "status": "Pending"}
    (status_dir / "failed.json").write_text(json.dumps(failed_entry), encoding="utf-8")
    (pending_dir / "stale_request_copy.json").write_text(json.dumps(pending_entry), encoding="utf-8")

    # Replacements applied to submit_module; each stub keeps the call
    # signature the production code is expected to use.
    module_stubs = [
        ("GIT_STATUS_PATH", str(status_root)),
        ("GIT_REQUESTS_PATH", str(pending_root)),
        ("SIZE_WHITELIST", {"alice"}),
        ("_EVAL_REQUESTED", {dedup_key}),
        ("_SUBMITTER_DATES", {}),
        ("_load_eval_cache", lambda: None),
        ("_common_pre_checks", lambda *_args, **_kwargs: None),
        (
            "is_model_on_hub",
            lambda **_kwargs: (
                True,
                "",
                {"architectures": ["TestArch"], "quantization_config": {"quant_method": "AutoRound"}},
            ),
        ),
        (
            "validate_quantization_scheme",
            lambda *_args, **_kwargs: (
                True,
                SimpleNamespace(
                    name="INT4 (W4A16)",
                    precision="4bit",
                    weight_dtype="int4",
                    bits=4,
                    hardware="A100",
                    script="auto_eval",
                ),
                "AutoRound",
            ),
        ),
        (
            "check_model_card",
            lambda *_args, **_kwargs: (True, "", SimpleNamespace(text="x" * 300, data=SimpleNamespace(tags=[]))),
        ),
        ("get_model_tags", lambda *_args, **_kwargs: []),
        ("is_license_approved", lambda *_args, **_kwargs: True),
        ("get_quantized_model_parameters_memory", lambda *_args, **_kwargs: (7.0, 3.5)),
        ("estimate_weight_memory_gb", lambda *_args, **_kwargs: 5.0),
        ("select_gpu_with_override", lambda *_args, **_kwargs: ("A100", 1)),
        ("get_gpu_display_name", lambda value: value),
        ("compute_single_eta", lambda *_args, **_kwargs: 1),
        ("format_eta", lambda *_args, **_kwargs: "1h"),
    ]
    for attr_name, replacement in module_stubs:
        monkeypatch.setattr(submit_module, attr_name, replacement)
    monkeypatch.setattr(
        submit_module.API,
        "model_info",
        lambda **_kwargs: SimpleNamespace(cardData={"license": "apache-2.0"}, likes=0),
    )

    uploaded = {"called": False, "file_tag": None}

    def _fake_upload(entry, user_name, model_path, file_tag, model, task_label="eval"):
        # Record that an upload was attempted and which file tag it used.
        uploaded.update(called=True, file_tag=file_tag)

    monkeypatch.setattr(submit_module, "_upload_to_hub", _fake_upload)

    submission = add_new_eval(
        model=model_name,
        revision="main",
        private=False,
        compute_dtype="INT4 (W4A16)",
        submitted_by="alice",
    )
    result = _consume_generator(submission)

    assert uploaded["called"] is True
    assert _is_success(result)
    # Re-submission must not overwrite the previous failed status file: the
    # filename gets a timestamp suffix appended to keep both records.
    assert uploaded["file_tag"] is not None
    assert re.search(r"_\d{4}", uploaded["file_tag"]), uploaded["file_tag"]
# ═══════════════════════════════════════════════════════════════════════════════
# Test 1: auto_eval with nytopop/Qwen3-30B-A3B.w4a16
# ═══════════════════════════════════════════════════════════════════════════════
def test_auto_eval():
    """Walk the auto_eval pipeline for a quantized W4A16 model and report.

    Steps: inspect the model (live ``AutoConfig`` / ``API.model_info``
    lookups via ``inspect_model``), validate the quantization scheme, size
    the model, estimate VRAM, pick a GPU, then call ``add_new_eval``
    end-to-end. Progress is printed at every step.

    Returns:
        ``True`` when no check was recorded as failed, ``False`` otherwise.
        The function reports through its return value rather than asserting;
        the ``__main__`` runner tallies it.
    """
    model_name = "nytopop/Qwen3-30B-A3B.w4a16"
    compute_dtype = "INT4 (W4A16)"  # what the UI passes
    print(f"\n{'#'*70}")
    print(f" TEST: auto_eval model={model_name}")
    print(f"{'#'*70}")
    config, info = inspect_model(model_name)

    # ── Step-by-step validation ──────────────────────────────────────────
    print("\n--- Step 1: Quantization scheme validation ---")
    qc = getattr(config, "quantization_config", None) if config else None
    is_valid, scheme, detected_method = validate_quantization_scheme(qc, compute_dtype)
    print(f" is_valid: {is_valid}")
    print(f" scheme: {scheme}")
    print(f" detected_method: {detected_method}")
    if scheme:
        print(f" scheme.name: {scheme.name}")
        print(f" scheme.precision: {scheme.precision}")
        print(f" scheme.weight_dtype: {scheme.weight_dtype}")
        print(f" scheme.bits: {scheme.bits}")
        print(f" scheme.hardware: {scheme.hardware}")
        print(f" scheme.script: {scheme.script}")

    # Bound up front so later steps never depend on which branch ran.
    precision = scheme.precision if scheme else "4bit"

    # ── Step 2: Model size ───────────────────────────────────────────────
    print("\n--- Step 2: Model size (get_quantized_model_parameters_memory) ---")
    if info:
        quant_method = detected_method.lower() if detected_method else ""
        print(f" quant_method arg: '{quant_method}'")
        print(f" bits arg: '{precision}'")
        print(f" KNOWN_SIZE_FACTOR has '{quant_method}': {quant_method in KNOWN_SIZE_FACTOR}")
        params_b, size_gb = get_quantized_model_parameters_memory(
            info, quant_method=quant_method, bits=precision
        )
        print(f" params_b: {params_b}")
        print(f" size_gb: {size_gb}")
        # Also test get_model_size for comparison
        print("\n--- Step 2b: get_model_size (FP-style) ---")
        params_b2, size_gb2 = get_model_size(info, precision=precision)
        print(f" params_b: {params_b2}")
        print(f" size_gb: {size_gb2}")
    else:
        params_b = None

    # ── Step 3: VRAM estimation ──────────────────────────────────────────
    print("\n--- Step 3: VRAM estimation ---")
    if params_b:
        bits = PRECISION_TO_BITS.get(precision, 4)
        est_mem = estimate_weight_memory_gb(params_b, bits=bits, overhead_factor=4.4)
        print(f" bits: {bits}")
        print(f" estimated_vram: {est_mem} GB")
    else:
        print(" SKIPPED (no params)")
        est_mem = None

    # ── Step 4: GPU selection ────────────────────────────────────────────
    print("\n--- Step 4: GPU selection ---")
    if est_mem:
        gpu_type, gpu_nums = select_gpu(est_mem)
        print(f" gpu_type: {gpu_type}")
        print(f" gpu_nums: {gpu_nums}")

    # ── Step 5: Call add_new_eval end-to-end ─────────────────────────────
    print("\n--- Step 5: add_new_eval (end-to-end) ---")
    result = _consume_generator(add_new_eval(
        model=model_name,
        revision="main",
        private=False,
        compute_dtype=compute_dtype,
    ))
    print(f" Result: {result}")

    # ── Validate expected values ─────────────────────────────────────────
    print("\n--- Validation checks ---")
    errors = []
    if not is_valid:
        errors.append(f"FAIL: Model should be detected as W4A16 quantized but is_valid={is_valid}")
    if params_b is not None:
        if not (25 <= params_b <= 35):
            errors.append(f"WARN: Expected params ~30B, got {params_b}B")
    else:
        errors.append("FAIL: params_b is None")
    if result is None:
        # _consume_generator returns None when nothing was yielded; report it
        # instead of crashing on the substring test below.
        errors.append("FAIL: add_new_eval returned no result")
    elif _is_error(result):
        errors.append(f"FAIL: add_new_eval returned error: {result[:200]}")

    if errors:
        for e in errors:
            print(f" ❌ {e}")
    else:
        print(" ✅ All checks passed")
    return len(errors) == 0
# ═══════════════════════════════════════════════════════════════════════════════
# Test 2: auto_quant with Qwen/Qwen3-30B-A3B
# ═══════════════════════════════════════════════════════════════════════════════
def test_auto_quant():
    """Walk the auto_quant pipeline for a full-precision model and report.

    Steps: inspect the model (live lookups via ``inspect_model``), confirm it
    has no quantization config, size it, count layers, estimate quantization
    and post-quantization eval VRAM, pick GPUs, then call ``add_new_quant``
    end-to-end. Progress is printed at every step.

    Returns:
        ``True`` when no check was recorded as failed, ``False`` otherwise.
        The function reports through its return value rather than asserting;
        the ``__main__`` runner tallies it.
    """
    model_name = "Qwen/Qwen3-30B-A3B"
    quant_scheme = "INT4 (W4A16)"
    print(f"\n{'#'*70}")
    print(f" TEST: auto_quant model={model_name}")
    print(f"{'#'*70}")
    config, info = inspect_model(model_name)

    # ── Step 1: Confirm NOT quantized ────────────────────────────────────
    print("\n--- Step 1: Confirm model is FP (not quantized) ---")
    qc = getattr(config, "quantization_config", None) if config else None
    print(f" quantization_config: {qc}")
    if qc:
        print(" ❌ Model appears quantized — auto_quant should reject it")
    torch_dtype = getattr(config, "torch_dtype", None)
    input_dtype = str(torch_dtype) if torch_dtype else "float16"
    input_bits = SUPPORTED_INPUT_DTYPES.get(input_dtype)
    if input_bits is None and input_dtype.startswith("torch."):
        # str() of a torch.dtype is e.g. "torch.bfloat16". The lookup table's
        # key format is not visible here; the "float16" default above suggests
        # bare names, so retry without the "torch." prefix before giving up.
        input_dtype = input_dtype[len("torch."):]
        input_bits = SUPPORTED_INPUT_DTYPES.get(input_dtype)
    print(f" torch_dtype: {torch_dtype}")
    print(f" input_dtype: {input_dtype}")
    print(f" input_bits: {input_bits}")

    # ── Step 2: Model size (FP) ──────────────────────────────────────────
    print("\n--- Step 2: Model size (FP) ---")
    params_b = None
    size_gb = None
    if info:
        fp_label = "16bit" if input_bits == 16 else "32bit"
        params_b, size_gb = get_model_size(info, precision=fp_label)
        print(f" precision arg: '{fp_label}'")
        print(f" params_b: {params_b}")
        print(f" size_gb: {size_gb}")

    # ── Step 3: Layer count ──────────────────────────────────────────────
    print("\n--- Step 3: Layer count ---")
    num_layers = get_num_layers(config) if config else None
    print(f" num_layers: {num_layers}")

    # ── Step 4: Quantization VRAM ────────────────────────────────────────
    print("\n--- Step 4: Quantization VRAM ---")
    quant_mem = None
    if size_gb and num_layers:
        quant_mem = estimate_quantization_memory_gb(size_gb, num_layers, overhead_factor=1.5)
        print(f" model_weight_gb: {size_gb}")
        print(f" num_layers: {num_layers}")
        print(f" quant_vram: {quant_mem} GB")
    else:
        print(f" SKIPPED (size_gb={size_gb}, num_layers={num_layers})")

    # ── Step 5: Eval VRAM (post-quantization) ────────────────────────────
    print("\n--- Step 5: Eval VRAM (post-quant W4A16) ---")
    scheme = SUPPORTED_QUANT_SCHEMES.get(quant_scheme)
    eval_mem = None
    if params_b and scheme:
        eval_mem = estimate_weight_memory_gb(params_b, bits=scheme.bits, overhead_factor=4.4)
        print(f" params_b: {params_b}")
        print(f" output_bits: {scheme.bits}")
        print(f" eval_vram: {eval_mem} GB")
        # Raw weight footprint after quantization: params (B) x bytes per weight.
        quant_model_size_gb = round(params_b * (scheme.bits / 8.0), 2)
        print(f" quant_model_size_gb: {quant_model_size_gb}")

    # ── Step 6: GPU selection ────────────────────────────────────────────
    print("\n--- Step 6: GPU selection ---")
    if quant_mem:
        qgpu, qn = select_gpu(quant_mem)
        print(f" Quantization: {qgpu} × {qn}")
    if eval_mem:
        egpu, en = select_gpu(eval_mem)
        print(f" Evaluation: {egpu} × {en}")

    # ── Step 7: Call add_new_quant end-to-end ────────────────────────────
    print("\n--- Step 7: add_new_quant (end-to-end) ---")
    result = _consume_generator(add_new_quant(
        model=model_name,
        revision="main",
        private=False,
        quant_scheme=quant_scheme,
    ))
    print(f" Result: {result}")

    # ── Validate expected values ─────────────────────────────────────────
    print("\n--- Validation checks ---")
    errors = []
    if qc:
        errors.append("FAIL: FP model has quantization_config — auto_quant should reject")
    if params_b is not None:
        if not (25 <= params_b <= 35):
            errors.append(f"WARN: Expected params ~30B, got {params_b}B")
    else:
        errors.append("FAIL: params_b is None")
    if input_bits is None:
        errors.append(f"FAIL: input_dtype '{input_dtype}' not in SUPPORTED_INPUT_DTYPES")
    if num_layers is None or num_layers <= 0:
        errors.append(f"FAIL: Could not determine num_layers: {num_layers}")
    if result is None:
        # _consume_generator returns None when nothing was yielded; report it
        # instead of crashing on the substring test below.
        errors.append("FAIL: add_new_quant returned no result")
    elif _is_error(result):
        errors.append(f"FAIL: add_new_quant returned error: {result[:200]}")

    if errors:
        for e in errors:
            print(f" ❌ {e}")
    else:
        print(" ✅ All checks passed")
    return len(errors) == 0
# ═══════════════════════════════════════════════════════════════════════════════
# Test 3: Cross-check — call auto_eval on FP model (should fail)
# ═══════════════════════════════════════════════════════════════════════════════
def test_auto_eval_rejects_fp_model():
    """auto_eval should reject an FP (non-quantized) model.

    Calls ``add_new_eval`` end-to-end for a full-precision repo and treats a
    red-styled result as the expected rejection.

    Returns:
        ``True`` when the submission was rejected, ``False`` otherwise
        (including when nothing was yielded at all).
    """
    model_name = "Qwen/Qwen3-30B-A3B"
    print(f"\n{'#'*70}")
    print(f" TEST: auto_eval should REJECT FP model: {model_name}")
    print(f"{'#'*70}")
    result = _consume_generator(add_new_eval(
        model=model_name,
        revision="main",
        private=False,
        compute_dtype="INT4 (W4A16)",
    ))
    print(f" Result: {result}")

    # A None result (nothing yielded) is not a rejection; guard it so the
    # substring checks cannot raise. "color:red" (no space) is accepted in
    # addition to the marker _is_error looks for.
    rejected = result is not None and (_is_error(result) or "color:red" in result)
    if rejected:
        print(" ✅ Correctly rejected FP model")
        return True
    print(f" ❌ FAIL: Should have rejected FP model but got: {str(result)[:200]}")
    return False
# ═══════════════════════════════════════════════════════════════════════════════
# Test 4: Cross-check — call auto_quant on quantized model (should fail)
# ═══════════════════════════════════════════════════════════════════════════════
def test_auto_quant_rejects_quantized_model():
    """auto_quant should reject an already-quantized model.

    Calls ``add_new_quant`` end-to-end for a quantized repo and treats either
    an error-styled or a warning-styled result as the expected rejection.

    Returns:
        ``True`` when the submission was rejected, ``False`` otherwise
        (including when nothing was yielded at all).
    """
    model_name = "nytopop/Qwen3-30B-A3B.w4a16"
    print(f"\n{'#'*70}")
    print(f" TEST: auto_quant should REJECT quantized model: {model_name}")
    print(f"{'#'*70}")
    result = _consume_generator(add_new_quant(
        model=model_name,
        revision="main",
        private=False,
        quant_scheme="INT4 (W4A16)",
    ))
    print(f" Result: {result}")

    # A None result (nothing yielded) is not a rejection; guard it so the
    # substring checks cannot raise.
    rejected = result is not None and (_is_error(result) or _is_warning(result))
    if rejected:
        print(" ✅ Correctly rejected quantized model")
        return True
    print(f" ❌ FAIL: Should have rejected quantized model but got: {str(result)[:200]}")
    return False
# ═══════════════════════════════════════════════════════════════════════════════
# Main
# ═══════════════════════════════════════════════════════════════════════════════
if __name__ == "__main__":
    # Script entry point: run the four live end-to-end checks in order, print
    # a pass/fail summary, and exit non-zero when any check failed so shells
    # and CI can detect the failure.
    print("=" * 70)
    print(" submit.py Unit Tests")
    print("=" * 70)
    results = {}
    results["test_auto_eval"] = test_auto_eval()
    results["test_auto_quant"] = test_auto_quant()
    results["test_auto_eval_rejects_fp"] = test_auto_eval_rejects_fp_model()
    results["test_auto_quant_rejects_quantized"] = test_auto_quant_rejects_quantized_model()
    print(f"\n{'='*70}")
    print(" SUMMARY")
    print(f"{'='*70}")
    for name, ok in results.items():
        status = "✅ PASS" if ok else "❌ FAIL"
        print(f" {status} {name}")
    total = len(results)
    passed = sum(1 for ok in results.values() if ok)
    print(f"\n {passed}/{total} tests passed")
    print(f"{'='*70}")
    sys.exit(0 if passed == total else 1)