Text Generation
Transformers
PyTorch
English
taonet_mini_t2
taonet
taotern
ssm
state-space-model
dplr
custom_code
experimental
Instructions to use TaoTern/TaoNet-mini-T2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use TaoTern/TaoNet-mini-T2 with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="TaoTern/TaoNet-mini-T2", trust_remote_code=True)

# Load model directly
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("TaoTern/TaoNet-mini-T2", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use TaoTern/TaoNet-mini-T2 with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm
# Start the vLLM server:
vllm serve "TaoTern/TaoNet-mini-T2"
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "TaoTern/TaoNet-mini-T2",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
Use Docker
docker model run hf.co/TaoTern/TaoNet-mini-T2
- SGLang
How to use TaoTern/TaoNet-mini-T2 with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang
# Start the SGLang server:
python3 -m sglang.launch_server \
    --model-path "TaoTern/TaoNet-mini-T2" \
    --host 0.0.0.0 \
    --port 30000
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "TaoTern/TaoNet-mini-T2",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
Use Docker images
docker run --gpus all \
    --shm-size 32g \
    -p 30000:30000 \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HF_TOKEN=<secret>" \
    --ipc=host \
    lmsysorg/sglang:latest \
    python3 -m sglang.launch_server \
        --model-path "TaoTern/TaoNet-mini-T2" \
        --host 0.0.0.0 \
        --port 30000
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "TaoTern/TaoNet-mini-T2",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
- Docker Model Runner
How to use TaoTern/TaoNet-mini-T2 with Docker Model Runner:
docker model run hf.co/TaoTern/TaoNet-mini-T2
| """Interactive/sample generation with the RepoBridge-style SSM inference fix. | |
| This intentionally overrides the checkpoint config at inference time: | |
| - ssm_finite_tail_correction = True | |
| - ssm_kernel_mode = recurrent | |
| Those settings match the temporary chat-quality fix used in RepoBridge Model Chat. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import sys | |
| import time | |
| from pathlib import Path | |
# Directory containing this script; every default path in this file is derived from it.
ROOT = Path(__file__).resolve().parent
# Source trees bundled next to the script.
TAOTRAIN_SRC = ROOT / "code" / "TaoTrain" / "src"
SSM_SRC = ROOT / "code" / "Taotern_SSM"
# Put the bundled source trees at the front of sys.path so the `taoTrain`
# imports that follow can be resolved from them. Each insert goes to index 0,
# so when both are added SSM_SRC ends up ahead of TAOTRAIN_SRC.
for path in (TAOTRAIN_SRC, SSM_SRC):
    # Skip entries that are already present to avoid duplicates.
    if str(path) not in sys.path:
        sys.path.insert(0, str(path))
| import torch | |
| from taoTrain.checkpointing.checkpoint import CheckpointManager | |
| from taoTrain.config import ModelConfig | |
| from taoTrain.inference.inferencer import Inferencer | |
| from taoTrain.models import get_model | |
def apply_ssm_overrides(model: torch.nn.Module, *, kernel_mode: str, finite_tail: bool) -> int:
    """Force the SSM inference settings onto every submodule that exposes them.

    Walks ``model.modules()`` and, for each submodule:

    * sets ``kernel_mode`` to ``kernel_mode`` if the attribute exists,
    * sets ``finite_tail_correction`` to ``finite_tail`` if the attribute exists,
    * calls ``clear_kernel_cache()`` if the submodule has a callable with that
      name (this happens whether or not an attribute was overridden).

    Args:
        model: Root module whose submodules are inspected.
        kernel_mode: Value written to each existing ``kernel_mode`` attribute.
        finite_tail: Value written to each existing ``finite_tail_correction``
            attribute.

    Returns:
        The number of submodules on which at least one attribute was set.
    """
    forced_settings = (
        ("kernel_mode", kernel_mode),
        ("finite_tail_correction", finite_tail),
    )
    touched = 0
    for submodule in model.modules():
        hits = 0
        for attr_name, attr_value in forced_settings:
            if hasattr(submodule, attr_name):
                setattr(submodule, attr_name, attr_value)
                hits += 1
        # The cache reset is independent of whether anything was overridden.
        reset_cache = getattr(submodule, "clear_kernel_cache", None)
        if callable(reset_cache):
            reset_cache()
        if hits:
            touched += 1
    return touched
def load_fixed(checkpoint_path: Path, tokenizer_path: Path, device: torch.device, dtype: torch.dtype):
    """Load a checkpoint and build the model with the SSM inference overrides applied.

    The ``model`` section of the checkpoint config is copied and two keys are
    forced before the model is constructed: ``ssm_finite_tail_correction=True``
    and ``ssm_kernel_mode="recurrent"``. After the weights are loaded the same
    settings are pushed onto the instantiated modules via
    ``apply_ssm_overrides``.

    Args:
        checkpoint_path: Checkpoint file; its parent directory is handed to
            ``CheckpointManager``.
        tokenizer_path: Path passed to ``Inferencer._load_tokenizer``.
        device: Device used for checkpoint loading, model construction and the
            final ``model.to``.
        dtype: Accepted for call-site compatibility; it is not used in this
            function (the weights are not cast here).

    Returns:
        A ``(model, tokenizer, override_count)`` tuple, where ``override_count``
        is the value returned by ``apply_ssm_overrides``.

    Raises:
        KeyError: If the checkpoint has no ``"model_state"`` entry.
    """
    import warnings

    checkpoint = CheckpointManager(checkpoint_path.parent).load(checkpoint_path, device=device)
    config_dict = checkpoint.get("config", {})
    # Copy so the checkpoint's own config dict is not mutated by the overrides.
    model_config_dict = dict(config_dict.get("model", {}))
    model_config_dict["ssm_finite_tail_correction"] = True
    model_config_dict["ssm_kernel_mode"] = "recurrent"
    model_config = ModelConfig(**model_config_dict)
    tokenizer = Inferencer._load_tokenizer(tokenizer_path)
    model = get_model(model_config, device=device)

    # strict=False tolerates key mismatches; surface them instead of silently
    # running with partially initialised weights.
    load_result = model.load_state_dict(checkpoint["model_state"], strict=False)
    missing_keys = list(getattr(load_result, "missing_keys", ()))
    unexpected_keys = list(getattr(load_result, "unexpected_keys", ()))
    if missing_keys or unexpected_keys:
        warnings.warn(
            "Checkpoint and model parameters do not match exactly: "
            f"{len(missing_keys)} missing key(s) {missing_keys[:5]}, "
            f"{len(unexpected_keys)} unexpected key(s) {unexpected_keys[:5]}",
            RuntimeWarning,
            stacklevel=2,
        )

    model.to(device=device)
    model.eval()
    override_count = apply_ssm_overrides(model, kernel_mode="recurrent", finite_tail=True)
    return model, tokenizer, override_count
def generate(
    model,
    tokenizer,
    prompt: str,
    *,
    device: torch.device,
    dtype: torch.dtype,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    repetition_penalty: float,
    greedy: bool,
) -> str:
    """Autoregressively extend ``prompt`` and return only the newly generated text.

    The whole sequence (prompt plus tokens generated so far) is fed to the
    model on every step, and only the last position's logits are used. The
    code handles a single sequence (batch index 0).

    Args:
        model: Callable invoked as ``model(input_ids=..., attention_mask=...,
            labels=None)`` that returns a mapping with a ``"logits"`` entry.
        tokenizer: Object providing ``encode(prompt, return_tensors="pt")`` and
            ``decode(ids, skip_special_tokens=True)``; an optional
            ``eos_token_id`` attribute stops generation early.
        prompt: Text to continue.
        device: Device the input ids are moved to.
        dtype: Autocast dtype; autocast is only enabled on CUDA with
            ``float16`` or ``bfloat16``.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature (ignored when ``greedy``); clamped
            to at least ``1e-6``.
        top_p: Nucleus threshold; values ``>= 1.0`` disable the filter.
        repetition_penalty: Penalty applied to tokens generated so far (prompt
            tokens are not penalised); ``1.0`` disables it.
        greedy: Pick the arg-max token instead of sampling.

    Returns:
        The decoded generated tokens, without the prompt and without the EOS
        token.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    prompt_len = input_ids.shape[1]
    generated_ids: list[int] = []
    eos_token_id = getattr(tokenizer, "eos_token_id", None)
    device_type = "cuda" if device.type == "cuda" else "cpu"
    autocast_enabled = device.type == "cuda" and dtype in {torch.float16, torch.bfloat16}
    with torch.inference_mode(), torch.autocast(device_type=device_type, dtype=dtype, enabled=autocast_enabled):
        for _ in range(max_new_tokens):
            # Re-applied before every forward pass.
            # NOTE(review): presumably a guard against the settings or kernel
            # cache going stale between steps; confirm it is required per step.
            apply_ssm_overrides(model, kernel_mode="recurrent", finite_tail=True)
            outputs = model(input_ids=input_ids, attention_mask=torch.ones_like(input_ids), labels=None)
            # Do the sampling arithmetic in float32: half-precision logits can
            # overflow when divided by a very small temperature and make the
            # cumulative top-p sum imprecise.
            logits = outputs["logits"][:, -1, :].float()
            if not greedy:
                logits = logits / max(temperature, 1e-6)
            if repetition_penalty != 1.0:
                seen_tokens = torch.unique(input_ids[0, prompt_len:])
                if seen_tokens.numel() > 0:
                    seen_logits = logits[0, seen_tokens]
                    # Sign-aware penalty: dividing a negative logit by a
                    # penalty > 1 would move it towards zero and make the
                    # token *more* likely, so negative logits are multiplied.
                    logits[0, seen_tokens] = torch.where(
                        seen_logits < 0,
                        seen_logits * repetition_penalty,
                        seen_logits / repetition_penalty,
                    )
            if greedy:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    sorted_probs = torch.softmax(sorted_logits, dim=-1)
                    cumulative = torch.cumsum(sorted_probs, dim=-1)
                    remove = cumulative > top_p
                    # Shift right so the token that crosses the threshold is
                    # kept, and always keep the most likely token.
                    remove[..., 1:] = remove[..., :-1].clone()
                    remove[..., 0] = False
                    indices_to_remove = sorted_indices[remove]
                    logits[0, indices_to_remove] = float("-inf")
                probs = torch.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            token_id = int(next_token.item())
            if eos_token_id is not None and token_id == eos_token_id:
                break
            generated_ids.append(token_id)
            input_ids = torch.cat([input_ids, next_token], dim=-1)
    apply_ssm_overrides(model, kernel_mode="recurrent", finite_tail=True)
    return tokenizer.decode(generated_ids, skip_special_tokens=True)
def main() -> None:
    """Command-line entry point: load the model with the SSM fix, then chat or sample.

    With ``--interactive`` the script reads prompts from standard input until
    ``quit``/``exit`` is typed (or input ends) and prints each completion with
    its wall-clock time. Otherwise it generates a completion for every
    ``--prompt`` (or for three built-in prompts when none are given), writes
    the run settings and samples as JSON to ``--output`` and prints the same
    JSON to standard output.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", default=str(ROOT / "model" / "pretrain_final_model.pt"))
    parser.add_argument("--tokenizer", default=str(ROOT / "tokenizer" / "tokenizer.model"))
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--dtype", choices=["float32", "bfloat16", "float16"], default="bfloat16")
    parser.add_argument("--max-new-tokens", type=int, default=64)
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--top-p", type=float, default=0.85)
    parser.add_argument("--repetition-penalty", type=float, default=1.2)
    parser.add_argument("--decode", choices=["greedy", "sample"], default="greedy")
    parser.add_argument("--prompt", action="append", default=[])
    parser.add_argument("--output", default=str(ROOT / "artifacts" / "local_test_samples_ssm_fixed.json"))
    parser.add_argument("--interactive", action="store_true")
    args = parser.parse_args()

    checkpoint_path = Path(args.checkpoint)
    # If the default-named checkpoint is absent, try the alternative file name.
    if not checkpoint_path.exists() and checkpoint_path.name == "pretrain_final_model.pt":
        checkpoint_path = ROOT / "model" / "final_model.pt"
    tokenizer_path = Path(args.tokenizer)
    # Any non-CPU device request falls back to CPU when CUDA is unavailable.
    device = torch.device(args.device if args.device == "cpu" or torch.cuda.is_available() else "cpu")
    dtype = {
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
    }[args.dtype]

    print(f"Loading checkpoint: {checkpoint_path}")
    print("SSM fix: ssm_finite_tail_correction=true, ssm_kernel_mode=recurrent")
    model, tokenizer, override_count = load_fixed(checkpoint_path, tokenizer_path, device, dtype)
    print(f"device={device}")
    print(f"ssm_overrides={override_count}")

    # Shared by the interactive and batch paths so both use identical settings.
    generation_kwargs = {
        "device": device,
        "dtype": dtype,
        "max_new_tokens": args.max_new_tokens,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "repetition_penalty": args.repetition_penalty,
        "greedy": args.decode == "greedy",
    }

    if args.interactive:
        print("Type 'quit' or 'exit' to stop.")
        while True:
            try:
                prompt = input("\nYou: ").strip()
            except (EOFError, KeyboardInterrupt):
                # Ctrl-D / Ctrl-C at the prompt ends the session without a traceback.
                print()
                break
            if prompt.lower() in {"quit", "exit"}:
                break
            if not prompt:
                continue
            # Monotonic clock: unaffected by system clock adjustments.
            start = time.perf_counter()
            completion = generate(model, tokenizer, prompt, **generation_kwargs)
            elapsed = time.perf_counter() - start
            print(f"\nAssistant: {completion}")
            print(f"\n[{elapsed:.1f}s]")
        return

    prompts = args.prompt or [
        "Fruit is now expensive so we should",
        "<user>Hello, who are you?<assistant>",
        "<user>Explain what artificial intelligence is in simple words.<assistant>",
    ]
    samples = [
        {"prompt": prompt, "completion": generate(model, tokenizer, prompt, **generation_kwargs)}
        for prompt in prompts
    ]
    result = {
        "checkpoint": str(checkpoint_path),
        "tokenizer": str(tokenizer_path),
        "device": str(device),
        "dtype": str(dtype),
        "ssm_finite_tail_correction": True,
        "ssm_kernel_mode": "recurrent",
        "ssm_overrides": override_count,
        "decode": args.decode,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "repetition_penalty": args.repetition_penalty,
        "max_new_tokens": args.max_new_tokens,
        "samples": samples,
    }
    output = Path(args.output)
    output.parent.mkdir(parents=True, exist_ok=True)
    output.write_text(json.dumps(result, indent=2, ensure_ascii=False), encoding="utf-8")
    print(json.dumps(result, indent=2, ensure_ascii=False))


if __name__ == "__main__":
    main()