"""Model loading utilities."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from transformers import AutoModelForCausalLM, AutoTokenizer
from src.config import settings
@dataclass
class ModelBundle:
    """Holds tokenizer and model together.

    On a successful load, ``tokenizer`` and ``model`` hold the loaded
    objects. In mock mode (forced, or after every candidate model fails
    to load) both are ``None``, ``is_mock`` is True, and ``load_error``
    explains why.
    """
    tokenizer: Any  # tokenizer from AutoTokenizer.from_pretrained, or None in mock mode
    model: Any  # model from AutoModelForCausalLM.from_pretrained, or None in mock mode
    active_model_name: str  # model name actually loaded, or "mock-rule-based"
    is_mock: bool = False  # True when no real model is backing this bundle
    load_error: str = ""  # reason the real load was skipped/failed; empty on success
def load_model_bundle() -> ModelBundle:
    """Load Qwen2.5-Coder first, then fallback if needed.

    Tries ``settings.model_name``, then the two fallback names in order,
    returning a bundle for the first model that loads successfully.

    Returns:
        ModelBundle: real tokenizer/model pair on success; otherwise a
        mock bundle (``is_mock=True``) whose ``load_error`` records why
        mock mode was entered (forced via config, or the last exception
        raised while trying the candidates).
    """
    if settings.force_mock_mode:
        return ModelBundle(
            tokenizer=None,
            model=None,
            active_model_name="mock-rule-based",
            is_mock=True,
            load_error="FORCE_MOCK_MODE=true",
        )

    candidate_models = [
        settings.model_name,
        settings.fallback_model_name,
        settings.final_fallback_model_name,
    ]
    # Drop empty names and duplicates (preserving priority order) so a
    # misconfigured fallback does not trigger a redundant, slow retry of
    # the same model. dict.fromkeys keeps first-seen order.
    candidate_models = [name for name in dict.fromkeys(candidate_models) if name]

    last_error: Exception | None = None
    for model_name in candidate_models:
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                trust_remote_code=True,
                device_map="auto",
            )
            return ModelBundle(
                tokenizer=tokenizer,
                model=model,
                active_model_name=model_name,
            )
        except Exception as exc:  # pragma: no cover - depends on network/hardware
            # Remember why this candidate failed and try the next one.
            last_error = exc

    # Every candidate failed (or none were configured): degrade to mock mode.
    # Guard against last_error being None so load_error never reads "None".
    return ModelBundle(
        tokenizer=None,
        model=None,
        active_model_name="mock-rule-based",
        is_mock=True,
        load_error=str(last_error) if last_error is not None else "no candidate models configured",
    )