"""Model loading utilities.""" from __future__ import annotations from dataclasses import dataclass from typing import Any from transformers import AutoModelForCausalLM, AutoTokenizer from src.config import settings @dataclass class ModelBundle: """Holds tokenizer and model together.""" tokenizer: Any model: Any active_model_name: str is_mock: bool = False load_error: str = "" def load_model_bundle() -> ModelBundle: """Load Qwen2.5-Coder first, then fallback if needed.""" if settings.force_mock_mode: return ModelBundle( tokenizer=None, model=None, active_model_name="mock-rule-based", is_mock=True, load_error="FORCE_MOCK_MODE=true", ) candidate_models = [ settings.model_name, settings.fallback_model_name, settings.final_fallback_model_name, ] last_error = None for model_name in candidate_models: try: tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained( model_name, trust_remote_code=True, device_map="auto", ) return ModelBundle( tokenizer=tokenizer, model=model, active_model_name=model_name, ) except Exception as exc: # pragma: no cover last_error = exc return ModelBundle( tokenizer=None, model=None, active_model_name="mock-rule-based", is_mock=True, load_error=str(last_error), )