Text Generation
Transformers
PyTorch
English
taonet_mini_t2
taonet
taotern
ssm
state-space-model
dplr
custom_code
experimental
Instructions to use TaoTern/TaoNet-mini-T2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use TaoTern/TaoNet-mini-T2 with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline
pipe = pipeline("text-generation", model="TaoTern/TaoNet-mini-T2", trust_remote_code=True)
# Load model directly
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("TaoTern/TaoNet-mini-T2", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- vLLM
How to use TaoTern/TaoNet-mini-T2 with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm
# Start the vLLM server:
vllm serve "TaoTern/TaoNet-mini-T2"
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{ "model": "TaoTern/TaoNet-mini-T2", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'
Use Docker
docker model run hf.co/TaoTern/TaoNet-mini-T2
- SGLang
How to use TaoTern/TaoNet-mini-T2 with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang
# Start the SGLang server:
python3 -m sglang.launch_server \
  --model-path "TaoTern/TaoNet-mini-T2" \
  --host 0.0.0.0 \
  --port 30000
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{ "model": "TaoTern/TaoNet-mini-T2", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'
Use Docker images
docker run --gpus all \
  --shm-size 32g \
  -p 30000:30000 \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  --env "HF_TOKEN=<secret>" \
  --ipc=host \
  lmsysorg/sglang:latest \
  python3 -m sglang.launch_server \
  --model-path "TaoTern/TaoNet-mini-T2" \
  --host 0.0.0.0 \
  --port 30000
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{ "model": "TaoTern/TaoNet-mini-T2", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'
- Docker Model Runner
How to use TaoTern/TaoNet-mini-T2 with Docker Model Runner:
docker model run hf.co/TaoTern/TaoNet-mini-T2
| """Background tokenization queue for streaming large JSONL datasets.""" | |
| import queue | |
| import threading | |
| import time | |
| from typing import Dict, List, Optional, Any, Callable | |
| import torch | |
| from taoTrain.data.chunk_manager import ChunkManager | |
class TokenizationQueue:
    """
    Background threads that continuously tokenize chunks and store them in a queue.

    This allows tokenization to happen in parallel with training, avoiding the
    bottleneck of tokenizing all data upfront before training starts.

    Supports multiple worker threads. Each thread greedily grabs the next
    available chunk index from a lock-protected counter.

    Typical lifecycle: ``start()`` -> repeated ``get_next_chunk()`` until it
    returns ``None`` and ``is_exhausted()`` is true -> ``shutdown()`` ->
    ``reset_for_next_epoch()`` -> ``start()`` again.

    Attributes:
        total_items: Total number of samples across all chunks
            (``chunk_manager.effective_lines``).
        max_queue_size: Maximum number of tokenized chunks buffered in memory.
        num_threads: Number of worker threads for tokenization.
        is_sft_mode: Truthy when ``config.response_loss_only`` is set; selects
            the SFT tokenization path.
    """

    # Seconds a blocked queue put/get waits before re-checking the
    # stop / error / finished state. Keeps every wait interruptible.
    _POLL_INTERVAL_SECONDS = 0.1
    # Seconds to wait for each worker thread to exit when joining.
    _JOIN_TIMEOUT_SECONDS = 5.0

    def __init__(
        self,
        chunk_manager: "ChunkManager",  # type: ignore
        tokenizer: Any,
        config: "TrainingConfig",  # type: ignore
        max_queue_size: int = 2,
        shuffle_chunks: bool = True,
        num_threads: int = 1,
    ):
        """
        Initialize tokenization queue with multithreading support.

        Args:
            chunk_manager: ChunkManager instance loaded with chunks. Must expose
                ``num_chunks``, ``effective_lines`` and ``read_chunk(idx)``.
            tokenizer: Tokenizer instance (HuggingFace or SentencePiece wrapper).
            config: Training configuration with model and dataset settings.
            max_queue_size: Maximum chunks to buffer in queue (memory constraint).
            shuffle_chunks: Whether to shuffle chunk order at initialization.
            num_threads: Number of worker threads for tokenization (default: 1).

        Raises:
            ValueError: If chunk_manager has no chunks or num_threads < 1.
        """
        if chunk_manager.num_chunks == 0:
            raise ValueError("ChunkManager must have at least one chunk")
        if num_threads < 1:
            raise ValueError(f"num_threads must be >= 1, got {num_threads}")

        self.chunk_manager = chunk_manager
        self.tokenizer = tokenizer
        self.config = config
        self.max_queue_size = max_queue_size
        self.shuffle_chunks = shuffle_chunks
        self.num_threads = num_threads

        # Detect SFT mode: check for response_loss_only flag
        self.is_sft_mode = getattr(config, "response_loss_only", False)

        # Calculate total items across all chunks
        self.total_items = chunk_manager.effective_lines

        # Thread-safe queue for tokenized chunks
        self._queue: queue.Queue[Dict[str, List]] = queue.Queue(maxsize=max_queue_size)

        # Control signals
        self._stop_event = threading.Event()
        self._error_event = threading.Event()
        self._error_messages: List[str] = []
        self._threads: List[threading.Thread] = []

        # Thread-safe chunk distribution
        self._next_chunk_idx = 0
        self._chunk_idx_lock = threading.Lock()
        self._active_threads = 0
        self._active_threads_lock = threading.Lock()

        # Chunk ordering
        self._chunk_order = list(range(chunk_manager.num_chunks))

        print(f"TokenizationQueue initialized with {chunk_manager.num_chunks} chunks, total {chunk_manager.effective_lines} samples")
        print(f"Using {num_threads} tokenization worker thread{'s' if num_threads != 1 else ''}")
        print(f"Max queue size: {max_queue_size} chunks (memory constraint)")

        if self.shuffle_chunks:
            import random

            random.shuffle(self._chunk_order)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _get_next_chunk_idx(self) -> Optional[int]:
        """
        Atomically get the next chunk index for processing.

        Returns:
            Chunk index to process, or None if all chunks have been assigned.
        """
        with self._chunk_idx_lock:
            if self._next_chunk_idx >= len(self._chunk_order):
                return None
            chunk_idx = self._chunk_order[self._next_chunk_idx]
            self._next_chunk_idx += 1
            return chunk_idx

    def _drain_queue(self) -> int:
        """Discard every buffered item without blocking; return how many were dropped."""
        dropped = 0
        while True:
            try:
                self._queue.get_nowait()
            except queue.Empty:
                return dropped
            dropped += 1

    def _put_unless_stopped(self, item: Dict[str, List]) -> bool:
        """
        Enqueue ``item``, waiting for space but giving up once a stop is requested.

        A plain blocking ``put`` can never be interrupted, so a worker stuck on
        a full queue would outlive ``shutdown``. Waiting in short slices lets
        the worker notice the stop signal.

        Returns:
            True if the item was enqueued, False if the stop signal arrived first.
        """
        while not self._stop_event.is_set():
            try:
                self._queue.put(item, timeout=self._POLL_INTERVAL_SECONDS)
            except queue.Full:
                continue
            return True
        return False

    def _workers_finished(self) -> bool:
        """Return True when no worker thread is alive (none started, or all exited)."""
        return not any(thread.is_alive() for thread in list(self._threads))

    def _raise_if_worker_failed(self) -> None:
        """Raise RuntimeError summarising worker failures, if any were recorded."""
        if self._error_event.is_set():
            error_summary = "; ".join(self._error_messages) if self._error_messages else "Unknown error"
            raise RuntimeError(f"Tokenization thread error: {error_summary}")

    # ------------------------------------------------------------------
    # Worker lifecycle
    # ------------------------------------------------------------------
    def start(self):
        """
        Start the tokenization background worker threads.

        Raises:
            RuntimeError: If threads from a previous start() are still registered
                (call shutdown() or reset_for_next_epoch() first).
        """
        if self._threads:
            raise RuntimeError(f"Tokenization threads already started ({len(self._threads)} active)")

        # A previous shutdown() leaves the stop signal set; without clearing it
        # the new workers would exit before processing a single chunk.
        self._stop_event.clear()

        # Create and start N worker threads
        for thread_id in range(self.num_threads):
            thread = threading.Thread(target=self._worker, args=(thread_id,), daemon=False)
            self._threads.append(thread)
            thread.start()

    def _worker(self, thread_id: int):
        """
        Worker thread target: greedy chunk processing with thread-safe distribution.

        Repeatedly claims a chunk index, reads and tokenizes that chunk, and
        enqueues the result. Exits when all chunks are assigned, when a stop is
        requested, or on the first exception (which is recorded and surfaced to
        the consumer via get_next_chunk()).

        Args:
            thread_id: Identifier for this worker thread (used in log output).
        """
        with self._active_threads_lock:
            self._active_threads += 1

        try:
            while not self._stop_event.is_set():
                # Get next chunk to process (atomic operation)
                chunk_num = self._get_next_chunk_idx()
                if chunk_num is None:
                    # All chunks assigned
                    break

                # Load chunk
                chunk_examples = self.chunk_manager.read_chunk(chunk_num)

                # Tokenize chunk based on mode
                if self.is_sft_mode:
                    tokenized_chunk = self._tokenize_batch_sft(chunk_examples)
                else:
                    # Extract texts for pretrain
                    text_field = self.config.dataset.text_field
                    texts = [obj.get(text_field, "") for obj in chunk_examples]
                    tokenized_chunk = self._tokenize_batch(texts)

                # Put in queue (waits while the queue is full, but stays
                # responsive to the stop signal)
                if not self._put_unless_stopped(tokenized_chunk):
                    break
                print(f"[Worker-{thread_id}] Processed chunk {chunk_num}, put {len(tokenized_chunk['input_ids'])} samples in queue")

        except Exception as e:
            error_msg = f"[Worker-{thread_id}] {str(e)}"
            print(f"Worker-{thread_id} encountered an error: {error_msg}")
            # Record the message before setting the event so a consumer that
            # observes the event always finds the message.
            self._error_messages.append(error_msg)
            self._error_event.set()

        finally:
            with self._active_threads_lock:
                self._active_threads -= 1
                remaining = self._active_threads
            print(f"[Worker-{thread_id}] Finished processing. Active threads remaining: {remaining}")

    # ------------------------------------------------------------------
    # Tokenization
    # ------------------------------------------------------------------
    def _tokenize_batch(self, texts: List[str]) -> Dict[str, List]:
        """
        Tokenize a batch of texts, join with EOS, and split into fixed-size sequences.

        This packs multiple documents into one token stream separated by EOS
        tokens, then splits the stream into sequences of max_seq_length. Only
        the final sequence can be short; it is right-padded and its attention
        mask marks the padding with 0.

        Args:
            texts: List of text strings.

        Returns:
            Dict with 'input_ids' and 'attention_mask' lists, where each element
            is a fixed-size sequence of length max_seq_length. Both lists are
            empty when no tokens were produced.

        Raises:
            ValueError: If the tokenizer has no EOS or no UNK token id.
        """
        max_seq_length = self.config.model.max_seq_length

        eos_token_id = self.tokenizer.eos_token_id
        unk_token_id = self.tokenizer.unk_token_id
        if eos_token_id is None:
            raise ValueError("Tokenizer does not have an EOS token defined")
        if unk_token_id is None:
            raise ValueError("Tokenizer does not have an UNK token defined")

        # Falls back to 0 when the tokenizer defines no pad token.
        pad_token_id = getattr(self.tokenizer, "pad_token_id", None) or 0

        # Tokenize all texts without truncation and concatenate them.
        all_token_ids: List[int] = []
        last_index = len(texts) - 1
        for i, text in enumerate(texts):
            tokenized = self.tokenizer(
                text,
                truncation=False,
                return_attention_mask=False,
            )
            # Remove UNK tokens from tokenized output (if any)
            all_token_ids.extend(tid for tid in tokenized["input_ids"] if tid != unk_token_id)
            # Add EOS token between documents (except after the last one)
            if i < last_index:
                all_token_ids.append(eos_token_id)

        # Split into fixed-size sequences
        sequences_input_ids = []
        sequences_attention_masks = []
        for start in range(0, len(all_token_ids), max_seq_length):
            seq = all_token_ids[start : start + max_seq_length]
            pad_len = max_seq_length - len(seq)
            sequences_attention_masks.append([1] * len(seq) + [0] * pad_len)
            sequences_input_ids.append(seq + [pad_token_id] * pad_len)

        return {
            "input_ids": sequences_input_ids,
            "attention_mask": sequences_attention_masks,
        }

    def _tokenize_batch_sft(self, records: List[Dict[str, Any]]) -> Dict[str, List]:
        """
        Tokenize a batch of SFT records with role tokens and response masking.

        Each record is parsed with ``parse_sft_record`` and converted with
        ``build_sft_sequence_tokens``. Records that yield no turns, or that
        raise during processing, are skipped (a warning is printed for the
        latter) so one bad record does not fail the whole chunk.

        Args:
            records: List of JSONL record dicts.

        Returns:
            Dict with 'input_ids', 'attention_mask', and 'mask' lists, one entry
            per successfully processed record.
        """
        # Import here to avoid circular imports
        from taoTrain.data.sft_utils import parse_sft_record, build_sft_sequence_tokens

        max_seq_length = self.config.model.max_seq_length
        user_token = getattr(self.config, 'user_token', '<user>')
        assistant_token = getattr(self.config, 'assistant_token', '<assistant>')

        sequences_input_ids = []
        sequences_attention_masks = []
        sequences_masks = []

        for record in records:
            try:
                # Second return value (multi-turn flag) is not needed here.
                turns, _ = parse_sft_record(record, self.config)
                if not turns:
                    # Skip records that couldn't be parsed
                    continue

                input_ids, attention_mask, mask = build_sft_sequence_tokens(
                    turns=turns,
                    tokenizer=self.tokenizer,
                    user_token=user_token,
                    assistant_token=assistant_token,
                    max_seq_length=max_seq_length,
                )
            except Exception as e:
                # Deliberate best-effort: log and continue with the next record.
                print(f"Warning: Failed to tokenize SFT record: {e}")
                continue

            sequences_input_ids.append(input_ids)
            sequences_attention_masks.append(attention_mask)
            sequences_masks.append(mask)

        return {
            "input_ids": sequences_input_ids,
            "attention_mask": sequences_attention_masks,
            "mask": sequences_masks,
        }

    # ------------------------------------------------------------------
    # Consumer API
    # ------------------------------------------------------------------
    def get_next_chunk(self, timeout: Optional[float] = None) -> Optional[Dict[str, List]]:
        """
        Get the next tokenized chunk from the queue.

        Buffered chunks are always returned before None, so items enqueued by
        workers that have since exited are never abandoned. The wait is done in
        short slices so that worker failure or completion is noticed even when
        ``timeout`` is None (a single unbounded ``Queue.get`` would hang forever
        once no worker is left to produce an item).

        Args:
            timeout: Timeout in seconds (None = wait until a chunk arrives or no
                worker is left to produce one).

        Returns:
            Dict with tokenized chunk, or None when the timeout elapsed or when
            the queue is empty and no worker thread is alive. Use
            is_exhausted() to tell the two apart.

        Raises:
            RuntimeError: If an error occurred in any worker thread.
        """
        deadline = None if timeout is None else time.monotonic() + timeout

        while True:
            self._raise_if_worker_failed()

            # Snapshot liveness BEFORE polling: if no worker was alive at this
            # point, every put has already happened, so an empty queue below
            # really means there is nothing more to come.
            finished = self._workers_finished()

            try:
                return self._queue.get_nowait()
            except queue.Empty:
                pass

            if finished:
                # A worker may have failed after the check at the top of the loop.
                self._raise_if_worker_failed()
                return None

            wait = self._POLL_INTERVAL_SECONDS
            if deadline is not None:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    return None
                wait = min(wait, remaining)

            try:
                return self._queue.get(timeout=wait)
            except queue.Empty:
                continue

    def is_exhausted(self) -> bool:
        """
        Return True only when there is nothing left to consume.

        That requires all chunks to be assigned, all workers to be idle, and
        the queue to be empty. The queue check matters: a worker may enqueue
        its final chunk and exit just before this is called, and that chunk
        must still be consumed.
        """
        with self._active_threads_lock:
            workers_done = self._active_threads == 0 and self._next_chunk_idx >= len(self._chunk_order)
        return workers_done and self._queue.empty()

    def shutdown(self, wait: bool = True):
        """
        Shutdown the tokenization worker threads gracefully.

        Sets the stop signal, discards buffered chunks, optionally joins the
        workers, and clears the thread list so start() can be called again.
        Safe to call on an instance whose __init__ did not complete.

        Args:
            wait: If True, wait for all threads to finish; otherwise return immediately.
        """
        # __del__ can run on an instance whose __init__ raised before
        # _threads was assigned.
        threads = getattr(self, "_threads", None)
        if not threads:
            return

        # Signal threads to stop
        self._stop_event.set()

        # Drop buffered chunks; workers waiting to put also watch the stop
        # signal, so they exit whether or not space frees up.
        self._drain_queue()

        # Wait for all threads to finish
        if wait:
            for thread in threads:
                thread.join(timeout=self._JOIN_TIMEOUT_SECONDS)
                if thread.is_alive():
                    print(f"⚠ Tokenization thread {thread.name} did not terminate cleanly")

        # Clear thread list to allow fresh start in next epoch
        threads.clear()
        print("✓ TokenizationQueue shutdown complete, thread list cleared")

    def reset_for_next_epoch(self):
        """
        Reset queue state for the next epoch.

        This allows the same TokenizationQueue to be reused across multiple epochs.
        Stops any worker that is still running, resets the chunk index counter,
        reshuffles chunks (if enabled), and clears buffered items, the error
        state and the stop signal. Call start() afterwards to begin the epoch.
        """
        items_drained = 0

        # Workers still running would race with the counter reset below and
        # could push stale chunks into the new epoch; stop them first.
        lingering = [thread for thread in self._threads if thread.is_alive()]
        if lingering:
            self._stop_event.set()
            items_drained += self._drain_queue()
            for thread in lingering:
                thread.join(timeout=self._JOIN_TIMEOUT_SECONDS)
                if thread.is_alive():
                    print(f"⚠ Tokenization thread {thread.name} did not terminate cleanly")

        with self._chunk_idx_lock:
            # Reset iteration counter
            self._next_chunk_idx = 0
            # Reshuffle chunk order if enabled
            if self.shuffle_chunks:
                import random

                random.shuffle(self._chunk_order)
        if self.shuffle_chunks:
            print(f"✓ Reshuffled chunk order for next epoch: {self._chunk_order}")

        # Drain any remaining items from queue
        items_drained += self._drain_queue()
        if items_drained > 0:
            print(f"⚠ Drained {items_drained} items from queue before epoch reset")

        # Clear error state
        self._error_event.clear()
        self._error_messages.clear()

        # Clear the stop signal left behind by shutdown(); otherwise the next
        # epoch's workers would exit immediately.
        self._stop_event.clear()

        # Clear threads list so new threads will be started in next epoch
        self._threads.clear()

        print(f"✓ TokenizationQueue reset for next epoch. Ready to process {len(self._chunk_order)} chunks")

    def __len__(self) -> int:
        """Return total number of samples."""
        return self.total_items

    def __del__(self):
        """Cleanup on deletion (non-blocking; tolerates partially built instances)."""
        self.shutdown(wait=False)