Spaces:
Runtime error
Runtime error
| # Copyright © 2023 Apple Inc. | |
| import glob | |
| import json | |
| import logging | |
| from pathlib import Path | |
| from typing import Generator | |
| import mlx.core as mx | |
| import mlx.nn as nn | |
| import models.phi2 as phi2 | |
| import transformers | |
| from huggingface_hub import snapshot_download | |
| from transformers import AutoTokenizer | |
# Constants
# Registry keyed by the "model_type" value found in a model's config.json.
# Each value is the module implementing that architecture; _get_classes reads
# its `Model` and `ModelArgs` attributes.
MODEL_MAPPING = {"phi": phi2}
def _get_classes(config: dict):
    """
    Look up the model implementation that matches a configuration.

    Args:
        config (dict): The model configuration. Its "model_type" entry
            selects the architecture module from MODEL_MAPPING.

    Returns:
        A tuple ``(Model, ModelArgs)`` taken from the selected module.

    Raises:
        KeyError: If ``config`` has no "model_type" entry.
        ValueError: If the model type is not registered in MODEL_MAPPING
            (the message is also logged at error level).
    """
    model_type = config["model_type"]

    # Happy path first: a registered type resolves straight to its classes.
    if model_type in MODEL_MAPPING:
        module = MODEL_MAPPING[model_type]
        return module.Model, module.ModelArgs

    msg = f"Model type {model_type} not supported."
    logging.error(msg)
    raise ValueError(msg)
def fetch_from_hub(hf_path: str):
    """
    Download a model snapshot from the Hugging Face Hub and load its parts.

    Args:
        hf_path (str): Repository id handed to ``snapshot_download`` and to
            the ``transformers`` config and tokenizer loaders.

    Returns:
        A tuple ``(weights, config, tokenizer)`` where ``weights`` is a dict
        merging the contents of every ``*.safetensors`` file in the snapshot,
        ``config`` is the ``AutoConfig`` converted to a plain dict, and
        ``tokenizer`` is the loaded ``AutoTokenizer``.

    Raises:
        FileNotFoundError: If the snapshot contains no ``*.safetensors`` file.
    """
    model_path = snapshot_download(
        repo_id=hf_path,
        allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
    )

    safetensor_paths = glob.glob(f"{model_path}/*.safetensors")
    if not safetensor_paths:
        raise FileNotFoundError("No safetensors found in {}".format(model_path))

    # Merge all shard files into one flat mapping; on a duplicate key the
    # file visited last wins.
    weights = {}
    for shard_path in safetensor_paths:
        for name, array in mx.load(shard_path).items():
            weights[name] = array

    # Config and tokenizer are resolved from the repo id, not the local path.
    hf_config = transformers.AutoConfig.from_pretrained(hf_path)
    tokenizer = transformers.AutoTokenizer.from_pretrained(hf_path)
    return weights, hf_config.to_dict(), tokenizer
def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
    """
    Split a weight dictionary into size-bounded shards, preserving order.

    Entries are packed greedily in iteration order: a new shard is started
    whenever adding the next entry would push the current shard past the
    limit.

    Args:
        weights (dict): Mapping of names to objects exposing ``nbytes``.
        max_file_size_gibibyte (int): Upper bound for one shard, in GiB.

    Returns:
        A list of dicts, each a contiguous slice of ``weights``. No shard is
        empty unless ``weights`` itself is empty, in which case the result is
        ``[{}]``. An entry that is larger than the limit on its own is placed
        alone in a shard, since it cannot be split.
    """
    max_file_size_bytes = max_file_size_gibibyte << 30
    shards = []
    shard, shard_size = {}, 0
    for k, v in weights.items():
        # Roll over only when the current shard already holds something.
        # Without the `shard` guard, an oversized entry arriving while the
        # shard is empty would flush an empty shard, producing an empty
        # output file and an inflated shard count.
        if shard and shard_size + v.nbytes > max_file_size_bytes:
            shards.append(shard)
            shard, shard_size = {}, 0
        shard[k] = v
        shard_size += v.nbytes
    shards.append(shard)
    return shards
def save_model(save_dir: str, weights, tokenizer, config):
    """
    Write weights, tokenizer and config into a model directory.

    Args:
        save_dir (str): Destination directory; created (with parents) if it
            does not exist.
        weights: Mapping of names to arrays, sharded via ``make_shards`` with
            a 5 GiB limit.
        tokenizer: Object providing ``save_pretrained(directory)``.
        config: JSON-serialisable configuration, written to ``config.json``.

    The weights go to ``model.safetensors`` when there is a single shard,
    otherwise to ``model-XXXXX-of-YYYYY.safetensors`` files numbered from 1.
    """
    out_dir = Path(save_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    shards = make_shards(weights, max_file_size_gibibyte=5)
    total = len(shards)

    # A single shard gets the plain name; str.format ignores the unused
    # positional arguments in that case.
    if total > 1:
        name_template = "model-{:05d}-of-{:05d}.safetensors"
    else:
        name_template = "model.safetensors"

    for index, shard in enumerate(shards, start=1):
        target = out_dir / name_template.format(index, total)
        mx.save_safetensors(str(target), shard)

    tokenizer.save_pretrained(out_dir)

    with open(out_dir / "config.json", "w") as fid:
        json.dump(config, fid, indent=4)
def load(path):
    """
    Load a model, its tokenizer and its arguments from a local directory.

    Args:
        path: Directory containing ``config.json``, tokenizer files and one
            or more ``*.safetensors`` weight files.

    Returns:
        A tuple ``(model, tokenizer, model_args)``.

    Raises:
        FileNotFoundError: If the directory has no ``*.safetensors`` files.
        ValueError: If the config's model type is unsupported (raised by
            ``_get_classes``).
    """
    model_path = Path(path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    with open(model_path / "config.json", "r") as f:
        config = json.load(f)

    # Resolve the architecture, then build its args and an empty model.
    model_class, model_args_class = _get_classes(config)
    model_args = model_args_class.from_dict(config)
    model = model_class(model_args)

    pattern = str(model_path / "*.safetensors")
    weight_files = glob.glob(pattern)
    if not weight_files:
        raise FileNotFoundError(f"No .safetensors files found in {model_path}")

    # Merge every shard into one mapping before handing it to the model.
    weights = {}
    for shard_path in weight_files:
        shard = mx.load(shard_path)
        weights.update(shard)

    # Quantized checkpoints need the model quantized before their weights
    # are loaded into it.
    if "quantization" in config:
        print("[INFO] Loading quantized model")
        quant_cfg = config["quantization"]
        nn.quantize(model, quant_cfg["group_size"], quant_cfg["bits"])

    model.load_weights(list(weights.items()))
    return model, tokenizer, model_args
def generate(
    prompt: mx.array, model: nn.Module, temp: float = 0.0
) -> Generator[mx.array, None, None]:
    """
    Endlessly sample continuations from a model, one step at a time.

    Args:
        prompt (mx.array): The input prompt, fed to the model on the first step.
        model (nn.Module): Callable as ``model(inputs, cache=...)`` returning
            ``(logits, cache)``.
        temp (float): Sampling temperature. With 0 the argmax of the logits
            is taken; otherwise a categorical sample of ``logits / temp``.

    Yields:
        mx.array: The value sampled from the logits at the last sequence
        position. It is also fed back as the next model input.

    Note:
        The generator never terminates on its own; the caller decides when
        to stop iterating.
    """

    def sample(logits: mx.array) -> mx.array:
        if temp == 0:
            return mx.argmax(logits, axis=-1)
        return mx.random.categorical(logits * (1 / temp))

    tokens, kv_cache = prompt, None
    while True:
        # `[None]` adds a leading axis before the model call.
        logits, kv_cache = model(tokens[None], cache=kv_cache)
        tokens = sample(logits[:, -1, :])
        yield tokens