firstAI / test_enhanced_fallback.py
ndc8
try
db8cd85
raw
history blame
3.36 kB
#!/usr/bin/env python3
"""
Test script to verify enhanced fallback mechanisms for pre-quantized models.
This simulates the production deployment scenario where bitsandbytes package metadata is missing.
"""
import sys
import logging
import os
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def test_pre_quantized_model_fallback():
"""Test loading a pre-quantized model without bitsandbytes package metadata."""
logger.info("πŸ§ͺ Testing enhanced fallback for pre-quantized models...")
# Set the problematic model as environment variable
os.environ["AI_MODEL"] = "unsloth/Mistral-Nemo-Instruct-2407-bnb-4bit"
try:
from backend_service import current_model, get_quantization_config
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
logger.info(f"πŸ“ Testing model: {current_model}")
# Test quantization detection
quant_config = get_quantization_config(current_model)
if quant_config:
logger.info(f"βœ… Quantization config detected: {type(quant_config).__name__}")
else:
logger.info("πŸ“ No quantization config (bitsandbytes not available)")
# Test the enhanced fallback mechanism
logger.info("πŸ”§ Testing enhanced config-based fallback...")
try:
# This simulates what happens in the lifespan function
config = AutoConfig.from_pretrained(current_model, trust_remote_code=True)
logger.info(f"βœ… Successfully loaded config: {type(config).__name__}")
# Check for quantization config in the model config
if hasattr(config, 'quantization_config'):
logger.info(f"πŸ” Found quantization_config in model config: {config.quantization_config}")
# Remove it to prevent bitsandbytes errors
config.quantization_config = None
logger.info("🚫 Removed quantization_config from model config")
else:
logger.info("πŸ“ No quantization_config found in model config")
# Test tokenizer loading
logger.info("πŸ“₯ Testing tokenizer loading...")
tokenizer = AutoTokenizer.from_pretrained(current_model)
logger.info(f"βœ… Tokenizer loaded successfully: {len(tokenizer)} tokens")
# Note: We won't actually load the full model in the test to save time/memory
logger.info("βœ… Enhanced fallback mechanism validated successfully!")
return True
except Exception as e:
logger.error(f"❌ Enhanced fallback test failed: {e}")
return False
except Exception as e:
logger.error(f"❌ Test setup failed: {e}")
return False
if __name__ == "__main__":
logger.info("πŸš€ Starting enhanced fallback mechanism test...")
success = test_pre_quantized_model_fallback()
if success:
logger.info("\nπŸŽ‰ Enhanced fallback test passed!")
logger.info("πŸ’‘ The deployment should now handle pre-quantized models correctly")
else:
logger.error("\n❌ Enhanced fallback test failed")
sys.exit(0 if success else 1)