Spaces:
Sleeping
Sleeping
| """ | |
| utils/model_loader.py | |
| ───────────────────── | |
| Load a local (or HF Hub) model via HuggingFace Transformers. | |
| Returns a model_bundle dict that every benchmark consumes. | |
| """ | |
| from __future__ import annotations | |
| import torch | |
| from typing import Any | |
# Maps the dtype names accepted by load_model to the matching torch dtypes.
# Each key is also the attribute name of the dtype on the torch module.
DTYPE_MAP = {
    name: getattr(torch, name)
    for name in ("float32", "float16", "bfloat16")
}
def load_model(
    model_path: str,
    device: str = "auto",
    dtype: str = "bfloat16",
    model_type: str | None = "auto",
) -> dict[str, Any]:
    """
    Load model + tokenizer from a local path or HF Hub ID.

    Parameters
    ----------
    model_path : local directory or HF Hub repo ID (coerced with ``str()``).
    device     : forwarded to ``from_pretrained`` as ``device_map``.
    dtype      : "float32", "float16" or "bfloat16" for plain loading, or
                 "int4" / "int8" to load through a ``BitsAndBytesConfig``.
                 Any other value falls back to bfloat16 with a warning.
    model_type : "auto", "hf", or a falsy value; anything else is rejected.

    Returns
    -------
    model_bundle : dict with keys
        model       - the model returned by AutoModelForCausalLM
        tokenizer   - the matching AutoTokenizer (left padding)
        device      - device of the model's first parameter, as a string
        dtype       - the ``dtype`` string exactly as requested by the caller
        param_count - float (billions)
        model_path  - original path string
        model_type  - always "hf"
        generate_fn - convenience callable (prompt -> str)

    Raises
    ------
    ValueError  : if ``model_type`` is not one of the supported values.
    ImportError : if ``transformers`` cannot be imported.
    """
    if model_type and model_type not in ("auto", "hf"):
        raise ValueError(
            f"Unsupported model_type {model_type!r}; use 'auto' or 'hf'."
        )

    # -- Lazy import so the module is importable without transformers --------
    try:
        from transformers import (
            AutoModelForCausalLM,
            AutoTokenizer,
            BitsAndBytesConfig,
        )
    except ImportError as e:
        raise ImportError(
            "transformers is required: pip install transformers accelerate"
        ) from e

    model_path = str(model_path)

    # -- Quantization config / dtype resolution ------------------------------
    quant_cfg = None
    if dtype == "int4":
        quant_cfg = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        # Quantized loads pass no explicit torch_dtype.
        torch_dtype = None
    elif dtype == "int8":
        quant_cfg = BitsAndBytesConfig(load_in_8bit=True)
        torch_dtype = None
    else:
        torch_dtype = DTYPE_MAP.get(dtype)
        if torch_dtype is None:
            # Keep the historical bfloat16 fallback, but make it visible so a
            # typo such as "fp16" does not silently change the precision.
            import warnings

            warnings.warn(
                f"Unknown dtype {dtype!r}; falling back to 'bfloat16'.",
                stacklevel=2,
            )
            torch_dtype = torch.bfloat16

    # -- Load tokenizer ------------------------------------------------------
    # NOTE(review): trust_remote_code=True lets the checkpoint repository run
    # its own Python code at load time; only point this at trusted sources.
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True,
        padding_side="left",
    )
    if tokenizer.pad_token is None:
        # Reuse EOS as the padding token when the tokenizer defines none.
        tokenizer.pad_token = tokenizer.eos_token

    # -- Load model ----------------------------------------------------------
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map=device,
        torch_dtype=torch_dtype,
        quantization_config=quant_cfg,
        trust_remote_code=True,
    )
    model.eval()

    # -- Parameter count (in billions) ---------------------------------------
    param_count = sum(p.numel() for p in model.parameters()) / 1e9

    # -- Resolve actual device -----------------------------------------------
    # Inputs are moved to wherever the first parameter lives.
    resolved_device = next(model.parameters()).device

    # -- Convenience generate function ---------------------------------------
    def generate_fn(
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.0,
        stop_strings: list[str] | None = None,
    ) -> str:
        """Run inference and return the decoded completion (without prompt).

        ``temperature > 0`` enables sampling with ``top_p=0.95``; otherwise
        decoding is greedy. When ``stop_strings`` is given, the completion is
        cut just before the earliest occurrence of any of them. Truncation is
        applied after decoding, so generation itself still runs until EOS or
        ``max_new_tokens``.
        """
        inputs = tokenizer(prompt, return_tensors="pt").to(resolved_device)
        gen_kwargs: dict[str, Any] = dict(
            **inputs,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
        if temperature > 0:
            gen_kwargs.update(do_sample=True, temperature=temperature, top_p=0.95)
        else:
            gen_kwargs["do_sample"] = False

        with torch.no_grad():
            output_ids = model.generate(**gen_kwargs)

        # Strip the input tokens from output
        new_ids = output_ids[0][inputs["input_ids"].shape[1]:]
        completion = tokenizer.decode(new_ids, skip_special_tokens=True).strip()

        if stop_strings:
            # Empty stop strings are skipped: they would match at index 0 and
            # wipe out the whole completion.
            hits = [completion.find(stop) for stop in stop_strings if stop]
            hits = [pos for pos in hits if pos != -1]
            if hits:
                completion = completion[: min(hits)].rstrip()
        return completion

    return {
        "model": model,
        "tokenizer": tokenizer,
        "device": str(resolved_device),
        "dtype": dtype,
        "param_count": param_count,
        "model_path": model_path,
        "model_type": "hf",
        "generate_fn": generate_fn,
    }