| |
| import argparse |
| import os |
| import sys |
| import time |
|
|
| import mlx.core as mx |
| from transformers import AutoTokenizer |
| from pathlib import Path |
|
|
# Make the project root importable so `custom_mlx_lm` resolves when this
# script is run directly from its subdirectory (not installed as a package).
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if project_root not in sys.path:
    sys.path.insert(0, project_root)
|
|
| |
| from custom_mlx_lm.custom_loader import load_model |
|
|
|
|
def _sample_token(logits, temperature: float, top_p: float) -> int:
    """Sample one token id from a 1-D logits vector.

    temperature <= 0 selects greedily (argmax). Otherwise the logits are
    temperature-scaled and, when 0 < top_p < 1, restricted to the smallest
    nucleus of tokens whose cumulative probability reaches top_p.
    """
    # Guard with <= 0 (not == 0): a negative temperature previously fell
    # through to nonsensical inverted sampling.
    if temperature <= 0:
        return int(mx.argmax(logits).item())

    scaled_logits = logits / temperature
    if 0.0 < top_p < 1.0:
        probs = mx.softmax(scaled_logits, axis=-1)
        # Sort descending, find the probability at the nucleus boundary,
        # then mask out every token below that cutoff.
        sorted_probs = mx.sort(probs)[::-1]
        cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
        cutoff_index = mx.sum(cumulative_probs < top_p)
        cutoff_prob = sorted_probs[cutoff_index.item()]
        mask = probs >= cutoff_prob
        scaled_logits = mx.where(mask, scaled_logits, float("-inf"))
    return int(mx.random.categorical(scaled_logits, num_samples=1).item())


def generate_text(
    prompt: str,
    model_path: str,
    max_tokens: int = 100,
    temperature: float = 0.1,
    top_p: float = 0.9,
):
    """Generate and print text from a converted MLX model.

    Loads the model via the project's custom loader and the tokenizer via
    transformers, formats the prompt (chat template if the model directory
    ships one, plain BOS-prefixed text otherwise), then samples up to
    ``max_tokens`` tokens and prints the decoded response plus tokens/sec.

    Args:
        prompt: User prompt text.
        model_path: Path to the converted MLX model directory.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature; <= 0 means greedy decoding.
        top_p: Nucleus-sampling threshold; applied only when 0 < top_p < 1.
    """
    print("Loading model and tokenizer using custom loader...")
    model = load_model(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Use the chat template only when the model directory actually ships one.
    chat_template_path = Path(model_path) / "chat_template.jinja"
    if chat_template_path.exists():
        messages = [{"role": "user", "content": prompt}]
        formatted_prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    else:
        bos = tokenizer.bos_token or ""
        formatted_prompt = f"{bos}{prompt}"

    print("Starting generation...")
    prompt_tokens = tokenizer.encode(formatted_prompt, add_special_tokens=False)
    prompt_tokens = mx.array([prompt_tokens])

    # BUGFIX: the stop-token set is loop-invariant — build it once instead of
    # on every iteration, and tolerate tokenizers whose eos_token_id is None
    # (the old `int(None)` raised TypeError) or a list of ids.
    eos_ids = tokenizer.eos_token_id
    if eos_ids is None:
        stop_ids = set()
    elif isinstance(eos_ids, list):
        stop_ids = {int(i) for i in eos_ids}
    else:
        stop_ids = {int(eos_ids)}

    start_time = time.time()
    generated_tokens = []
    for _ in range(max_tokens):
        # NOTE: full forward pass over the whole sequence each step (no KV
        # cache) — O(n^2) overall, kept as-is for compatibility with the
        # custom loader's models.
        logits = model(prompt_tokens)
        next_token = _sample_token(logits[0, -1, :], temperature, top_p)

        if next_token in stop_ids:
            break

        generated_tokens.append(next_token)
        prompt_tokens = mx.concatenate(
            [prompt_tokens, mx.array([[next_token]])], axis=1
        )

    response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    print("\n--- Response ---")
    print(response)
    print("------------------")
    generation_speed = (
        len(generated_tokens) / (time.time() - start_time) if generated_tokens else 0
    )
    print(
        f"Generated {len(generated_tokens)} tokens at {generation_speed:.2f} tokens/sec"
    )
|
|
|
|
def main():
    """CLI entry point: parse command-line flags and run text generation."""
    arg_parser = argparse.ArgumentParser(
        description="Run inference on converted MLX models."
    )
    arg_parser.add_argument(
        "--model-path",
        type=str,
        required=True,
        help="Path to the converted MLX model directory.",
    )
    arg_parser.add_argument(
        "--prompt",
        type=str,
        default="What is the capital of France?",
        help="The prompt.",
    )
    arg_parser.add_argument(
        "--max-tokens", type=int, default=100, help="Max tokens to generate."
    )
    arg_parser.add_argument(
        "--temperature", type=float, default=0.1, help="Sampling temperature."
    )
    arg_parser.add_argument("--top-p", type=float, default=0.9, help="Top-p sampling.")
    opts = arg_parser.parse_args()

    # Forward the generation knobs by keyword for readability.
    generate_text(
        opts.prompt,
        opts.model_path,
        max_tokens=opts.max_tokens,
        temperature=opts.temperature,
        top_p=opts.top_p,
    )
|
|
|
|
# Standard script entry guard: only run the CLI when executed directly.
if __name__ == "__main__":
    main()
|
|