# custom_mlx_lm/inference_mlx_lm.py
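"""Standalone inference for MLX-converted checkpoints.

Loads a model through custom_mlx_lm.custom_loader.load_model, formats the
prompt (chat template when the model ships chat_template.jinja, otherwise
BOS + raw prompt), and decodes greedily or with temperature/top-p sampling.
"""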
import argparse
import os
import sys
import time
from pathlib import Path

import mlx.core as mx
from transformers import AutoTokenizer

# Ensure the project root is importable so `custom_mlx_lm` resolves when this
# script is run directly.
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Use the robust universal loader
from custom_mlx_lm.custom_loader import load_model


def generate_text(
    prompt: str,
    model_path: str,
    max_tokens: int = 100,
    temperature: float = 0.1,
    top_p: float = 0.9,
    # Additional sampling parameters from the original inference.py can be added here.
):
    """
    Generate text with the loaded MLX model using the robust custom sampler.

    The prompt handling and sampling logic are adapted from the original
    inference.py script.
    """
print("Loading model and tokenizer using custom loader...")
model = load_model(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Align prompt handling with inference.py: prefer chat template, else prepend BOS
chat_template_path = Path(model_path) / "chat_template.jinja"
use_chat_format = chat_template_path.exists()
if use_chat_format:
messages = [{"role": "user", "content": prompt}]
formatted_prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
else:
bos = tokenizer.bos_token or ""
formatted_prompt = f"{bos}{prompt}"
print("Starting generation...")
prompt_tokens = tokenizer.encode(formatted_prompt, add_special_tokens=False)
prompt_tokens = mx.array([prompt_tokens])
start_time = time.time()
generated_tokens = []
    for _ in range(max_tokens):
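        # Note: each step re-runs the model over the full sequence (no KV
        # cache), so per-token cost grows with sequence length.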
        logits = model(prompt_tokens)
        next_token_logits = logits[0, -1, :]

        if temperature == 0:
            # Greedy decoding: take the highest-probability token.
            next_token = int(mx.argmax(next_token_logits).item())
        else:
            scaled_logits = next_token_logits / temperature
            if 0.0 < top_p < 1.0:
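                # Nucleus (top-p) filtering: keep the smallest set of tokens
                # whose cumulative probability reaches top_p and mask the rest
                # to -inf before sampling.
                # Worked example with hypothetical values: sorted probs
                # [0.5, 0.3, 0.15, 0.05] and top_p=0.9 give cumulative sums
                # [0.5, 0.8, 0.95, 1.0], so cutoff_index=2, cutoff_prob=0.15,
                # and only the 0.05 token is masked out.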
                probs = mx.softmax(scaled_logits, axis=-1)
                sorted_probs = mx.sort(probs)[::-1]
                cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
                cutoff_index = mx.sum(cumulative_probs < top_p)
                cutoff_prob = sorted_probs[cutoff_index.item()]
                mask = probs >= cutoff_prob
                scaled_logits = mx.where(mask, scaled_logits, float("-inf"))
            next_token = mx.random.categorical(scaled_logits, num_samples=1).item()
        if next_token in stop_ids:
            break
        generated_tokens.append(next_token)
        prompt_tokens = mx.concatenate(
            [prompt_tokens, mx.array([[next_token]])], axis=1
        )
    response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    print("\n--- Response ---")
    print(response)
    print("------------------")

    generation_speed = (
        len(generated_tokens) / (time.time() - start_time) if generated_tokens else 0
    )
    print(
        f"Generated {len(generated_tokens)} tokens at {generation_speed:.2f} tokens/sec"
    )

def main():
    parser = argparse.ArgumentParser(
        description="Run inference on converted MLX models."
    )
    parser.add_argument(
        "--model-path",
        type=str,
        required=True,
        help="Path to the converted MLX model directory.",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="What is the capital of France?",
        help="The prompt to generate from.",
    )
    parser.add_argument(
        "--max-tokens", type=int, default=100, help="Max tokens to generate."
    )
    parser.add_argument(
        "--temperature", type=float, default=0.1, help="Sampling temperature."
    )
    parser.add_argument("--top-p", type=float, default=0.9, help="Top-p sampling.")
    args = parser.parse_args()

    generate_text(
        args.prompt,
        args.model_path,
        args.max_tokens,
        args.temperature,
        args.top_p,
    )


if __name__ == "__main__":
    main()
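
# Example invocation (the model path below is illustrative):
#   python custom_mlx_lm/inference_mlx_lm.py \
#       --model-path ./MobileLLM-R1-950M-MLX \
#       --prompt "What is the capital of France?" \
#       --max-tokens 200 --temperature 0.6 --top-p 0.95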