File size: 1,892 Bytes
e75c925 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
def _config_without_quantization(
    model_name: str,
    trust_remote_code: bool,
    user_config: Any = None,
) -> Any:
    """Return a model config whose ``quantization_config`` attribute is removed.

    Args:
        model_name: Model id/path used to load the config when the caller
            did not supply one.
        trust_remote_code: Forwarded to ``AutoConfig.from_pretrained``.
        user_config: The ``config`` value the caller passed (if any). A
            str/path is loaded via ``AutoConfig``; any other object is
            deep-copied so the caller's instance is never mutated.

    Returns:
        The config object, without a ``quantization_config`` attribute.
    """
    import copy
    import os

    if user_config is None or isinstance(user_config, (str, os.PathLike)):
        source = model_name if user_config is None else user_config
        config = AutoConfig.from_pretrained(
            source,
            trust_remote_code=trust_remote_code,
        )
    else:
        # Copy so deleting the attribute does not mutate the caller's object.
        config = copy.deepcopy(user_config)
    if hasattr(config, "quantization_config"):
        delattr(config, "quantization_config")
    return config


def load_model_with_quantization_fallback(
    model_name: str = "deepseek-ai/DeepSeek-R1",
    trust_remote_code: bool = True,
    device_map: Optional[Union[str, Dict[str, Any]]] = "auto",
    **kwargs
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    """Load a model and tokenizer, retrying without quantization if needed.

    The model is first loaded with its original configuration. If
    ``AutoModel.from_pretrained`` raises a ``ValueError`` whose message
    contains "Unknown quantization type", the config is reloaded (or the
    caller-supplied ``config`` is copied), its ``quantization_config``
    attribute is removed, and loading is retried once.

    NOTE(review): the fallback simply drops the quantization config; whether
    the resulting weights are usable depends on the checkpoint — confirm for
    the target model.

    Args:
        model_name: Model id or local path passed to ``from_pretrained``.
        trust_remote_code: Forwarded to model, config and tokenizer loading.
        device_map: Forwarded to ``AutoModel.from_pretrained``.
        **kwargs: Extra arguments forwarded to ``AutoModel.from_pretrained``.
            A ``config`` entry is honored in the fallback path too.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        ValueError: Re-raised from the first load attempt when the message
            does not mention "Unknown quantization type".
        Exception: Anything raised by the fallback model load or by the
            tokenizer load is logged and re-raised.
    """
    try:
        model = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=trust_remote_code,
            device_map=device_map,
            **kwargs
        )
    except ValueError as e:
        if "Unknown quantization type" not in str(e):
            logger.error("Unexpected error during model loading: %s", e)
            raise
        logger.warning(
            "Quantization type not supported directly. "
            "Attempting to load without quantization..."
        )
        # Pop a caller-supplied config so it is not passed twice below
        # (``config=`` plus ``**kwargs`` would raise TypeError).
        fallback_kwargs = dict(kwargs)
        config = _config_without_quantization(
            model_name,
            trust_remote_code,
            fallback_kwargs.pop("config", None),
        )
        try:
            model = AutoModel.from_pretrained(
                model_name,
                config=config,
                trust_remote_code=trust_remote_code,
                device_map=device_map,
                **fallback_kwargs
            )
        except Exception as inner_e:
            logger.error("Failed to load model without quantization: %s", inner_e)
            raise
        success_message = "Model loaded successfully without quantization"
    else:
        success_message = "Model loaded successfully with original configuration"

    # The tokenizer is loaded identically on both paths. Pass
    # trust_remote_code here too so models shipping custom tokenizer code
    # work on the primary path, not only on the fallback.
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=trust_remote_code,
        )
    except Exception as tok_e:
        logger.error("Failed to load tokenizer: %s", tok_e)
        raise
    logger.info(success_message)
    return model, tokenizer