| """HF Inference Endpoint custom handler for NX-AI/xLSTM-7b. |
| |
| Deploys matrix-memory recurrent architecture (Beck et al. 2024) via the HF |
| Endpoints custom-handler interface. xLSTM introduces mLSTM (matrix-memory |
| long short-term memory) and sLSTM (exponential-gating scalar LSTM) blocks, |
| representing a non-SSM non-attention recurrent family. |
| |
| Input schema (Bench 1.6-A concatenated completion format): |
| { |
| "inputs": "<flat text prompt with system + user turns concatenated>", |
| "parameters": { |
| "max_new_tokens": 512, |
| "temperature": 0.7, |
| "top_p": 0.95, |
| "do_sample": true, |
| } |
| } |
| |
| Output schema: |
| { |
| "generated_text": "<model completion>", |
| "input_tokens": <int>, |
| "output_tokens": <int>, |
| "model": "NX-AI/xLSTM-7b" |
| } |
| |
| Preregistered per docs/BENCH-1.6A-PREREG-V1.1-AMENDMENT.md as Cell A3. |
| Base-model asymmetry (v1.0 §5.5) applies: xLSTM-7b is a base model with no |
| instruction tuning, receives completion-format prompts. |
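
Example client call (illustrative sketch only; the endpoint URL, token, and
prompt are placeholders, not values taken from this deployment):

    import requests

    resp = requests.post(
        "https://<endpoint-url>",
        headers={"Authorization": "Bearer <hf-token>"},
        json={
            "inputs": "Q: What distinguishes an mLSTM block? A:",
            "parameters": {"max_new_tokens": 128, "temperature": 0.7},
        },
    )
    print(resp.json()["generated_text"])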
| """ |
from __future__ import annotations

from typing import Any

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "NX-AI/xLSTM-7b"


class EndpointHandler:
    """HF Endpoints custom handler entry point."""

    def __init__(self, path: str = "") -> None:
        self.model_id = MODEL_ID
        # Kept for reference; actual placement is handled by device_map="auto".
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # trust_remote_code lets the hub repo supply xLSTM modeling code that
        # may not ship with the installed transformers version.
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_id,
            trust_remote_code=True,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        self.model.eval()

    def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
        prompt: str = data.get("inputs", "")
        params: dict[str, Any] = data.get("parameters", {}) or {}

        max_new_tokens: int = int(params.get("max_new_tokens", 512))
        temperature: float = float(params.get("temperature", 0.7))
        top_p: float = float(params.get("top_p", 0.95))
        do_sample: bool = bool(params.get("do_sample", True))

        if not prompt:
            return {
                "generated_text": "",
                "input_tokens": 0,
                "output_tokens": 0,
                "model": self.model_id,
                "error": "empty_input",
            }

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        input_tokens = int(inputs["input_ids"].shape[-1])

        # Fall back to the EOS token for padding when the tokenizer defines
        # no pad token, as is common for base models.
        pad_token_id = (
            self.tokenizer.pad_token_id
            if self.tokenizer.pad_token_id is not None
            else self.tokenizer.eos_token_id
        )

        # Pass sampling parameters only when sampling is enabled; supplying
        # temperature/top_p alongside do_sample=False has no effect on greedy
        # decoding and triggers transformers warnings.
        gen_kwargs: dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "pad_token_id": pad_token_id,
        }
        if do_sample:
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = top_p

        with torch.no_grad():
            outputs = self.model.generate(**inputs, **gen_kwargs)

        # Decode only the newly generated token ids. Slicing by position is
        # more robust than stripping the prompt string, since decoding with
        # skip_special_tokens may not reproduce the prompt text verbatim.
        generated_ids = outputs[0][input_tokens:]
        generated_only = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
        output_tokens = int(generated_ids.shape[-1])

        return {
            "generated_text": generated_only,
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "model": self.model_id,
        }
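
# Optional local smoke test: an illustrative sketch, not part of the HF
# Endpoints contract (the runtime imports EndpointHandler directly). Running
# it downloads the full 7B checkpoint and assumes enough accelerator memory
# for bfloat16 inference; the prompt and parameters are arbitrary examples.
if __name__ == "__main__":
    handler = EndpointHandler()
    result = handler(
        {
            "inputs": "In contrast to attention, the mLSTM memory update",
            "parameters": {"max_new_tokens": 32, "do_sample": False},
        }
    )
    print(result)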