Text Generation
Transformers
PyTorch
English
taonet_mini_t2
taonet
taotern
ssm
state-space-model
dplr
custom_code
experimental
Instructions to use TaoTern/TaoNet-mini-T2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use TaoTern/TaoNet-mini-T2 with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="TaoTern/TaoNet-mini-T2", trust_remote_code=True)

# Load model directly
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("TaoTern/TaoNet-mini-T2", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- vLLM
How to use TaoTern/TaoNet-mini-T2 with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "TaoTern/TaoNet-mini-T2"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "TaoTern/TaoNet-mini-T2",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
Use Docker
docker model run hf.co/TaoTern/TaoNet-mini-T2
- SGLang
How to use TaoTern/TaoNet-mini-T2 with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
    --model-path "TaoTern/TaoNet-mini-T2" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "TaoTern/TaoNet-mini-T2",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
Use Docker images
docker run --gpus all \
    --shm-size 32g \
    -p 30000:30000 \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HF_TOKEN=<secret>" \
    --ipc=host \
    lmsysorg/sglang:latest \
    python3 -m sglang.launch_server \
        --model-path "TaoTern/TaoNet-mini-T2" \
        --host 0.0.0.0 \
        --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "TaoTern/TaoNet-mini-T2",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
- Docker Model Runner
How to use TaoTern/TaoNet-mini-T2 with Docker Model Runner:
docker model run hf.co/TaoTern/TaoNet-mini-T2
| """Training implementations for pretrain, SFT, and RL.""" | |
| import time | |
| from typing import Optional, Dict, Tuple | |
| from pathlib import Path | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| from taoTrain.core.base import BaseModel, BaseTrainer | |
| from taoTrain.config import TrainingConfig, PretrainConfig, SFTConfig, RLConfig | |
| from taoTrain.data.loaders import get_dataloader | |
| from taoTrain.data.async_loader import AsyncBatchIterator | |
| from taoTrain.data.tokenization_queue import TokenizationQueue | |
| from taoTrain.logging import AimLogger | |
| from taoTrain.optimizers import get_optimizer | |
| from taoTrain.schedulers import get_scheduler | |
| from taoTrain.utils import set_seed, get_dtype | |
| # ============================================================================ | |
| # Metrics | |
| # ============================================================================ | |
class MetricsTracker:
    """Accumulate a history of metric values, keyed by metric name.

    Every call to :meth:`update` appends one value per supplied name. The
    tracker can then report the running mean or the most recent value of
    each metric.
    """

    def __init__(self):
        """Create an empty tracker."""
        self.metrics = {}

    def update(self, metrics: Dict[str, float]):
        """Append each value in ``metrics`` to the history for its key."""
        for name, value in metrics.items():
            self.metrics.setdefault(name, []).append(value)

    def get_average(self) -> Dict[str, float]:
        """Return the arithmetic mean of each metric's history.

        Metrics whose history is empty are omitted from the result.
        """
        averages = {}
        for name, history in self.metrics.items():
            if not history:
                continue
            averages[name] = sum(history) / len(history)
        return averages

    def get_latest(self) -> Dict[str, float]:
        """Return the most recently recorded value of each metric.

        A metric with an empty history reports ``0.0``.
        """
        latest = {}
        for name, history in self.metrics.items():
            latest[name] = history[-1] if history else 0.0
        return latest

    def reset(self):
        """Discard all recorded history."""
        self.metrics = {}
| # ============================================================================ | |
| # Base Trainer Implementation | |
| # ============================================================================ | |
class BaseTrainerImpl(BaseTrainer):
    """Base trainer implementation with common functionality.

    Sets up the optimizer, LR scheduler, AimStack logger and metric trackers.
    For JSONL datasets it also sets up an asynchronous batch pipeline. It
    implements the generic train/validate loop shared by the stage-specific
    trainers.

    Step accounting: ``self.global_step`` is incremented once per micro-batch
    in :meth:`train_epoch`. The optimizer and scheduler are stepped once every
    ``config.gradient_accumulation_steps`` micro-batches in
    :meth:`training_step`.
    """

    def __init__(
        self,
        model: BaseModel,
        train_dataset,
        val_dataset,
        config: TrainingConfig,
        device: torch.device,
    ):
        """Initialize trainer.

        Args:
            model: Model to train.
            train_dataset: Training dataset. A ``BaseJSONLDataset`` enables the
                async loading pipeline; any other dataset is loaded through
                ``get_dataloader``.
            val_dataset: Validation dataset, or ``None`` to skip validation.
            config: Training configuration.
            device: Device that batches are moved to before the forward pass.
        """
        super().__init__(model, train_dataset, val_dataset, config, device)

        # Setup optimizer and scheduler using factories
        print("\n✓ Setting up optimizer and scheduler...")
        self.optimizer = get_optimizer(self.model, config)
        # The scheduler needs the total number of times it will be stepped.
        num_training_steps = self._compute_num_steps()
        self.scheduler = get_scheduler(self.optimizer, config, num_training_steps)
        print(f"✓ Optimizer and scheduler setup complete. Total training steps: {num_training_steps}")

        # Setup AimStack logging
        self.logger = AimLogger(config)
        print("✓ AimStack logger initialized.")

        # Data type. Autocast is used for every dtype except float32.
        self.dtype = get_dtype(config.dtype.value)
        self.use_autocast = config.dtype.value != "float32"

        # Metrics trackers
        self.train_metrics = MetricsTracker()
        self.val_metrics = MetricsTracker()

        # Setup async loading if using JSONL datasets
        print("\n✓ Setting up data loading...")
        self._setup_async_loading(train_dataset)

    # ------------------------------------------------------------------
    # Setup helpers
    # ------------------------------------------------------------------

    def _compute_num_steps(self) -> int:
        """Compute how many times the scheduler will be stepped during training.

        The scheduler is stepped once per ``gradient_accumulation_steps``
        micro-batches (see :meth:`training_step`).

        - If ``config.max_steps`` is set, training stops once ``global_step``
          reaches it. ``global_step`` counts micro-batches, so the scheduler
          is stepped ``max_steps // gradient_accumulation_steps`` times. The
          result is clamped to at least 1 so the scheduler never receives 0.
        - Otherwise the count is derived from the dataset length, the
          effective batch size and ``num_epochs``.
        """
        accum_steps = self.config.gradient_accumulation_steps
        if self.config.max_steps:
            # Convert the micro-batch budget to optimizer/scheduler steps so
            # the LR schedule actually completes when accumulation is used.
            return max(1, self.config.max_steps // accum_steps)
        num_steps_per_epoch = (
            len(self.train_dataset) // (self.config.batch_size * accum_steps)
        )
        return num_steps_per_epoch * self.config.num_epochs

    def _setup_async_loading(self, train_dataset):
        """Set up async loading for JSONL datasets.

        For a ``BaseJSONLDataset``, creates a ``TokenizationQueue`` that feeds
        an ``AsyncBatchIterator`` and stores the iterator in
        ``self.async_loader``. For any other dataset (e.g. HuggingFace),
        ``self.async_loader`` stays ``None`` and :meth:`train_epoch` falls
        back to ``get_dataloader``.

        Note: All JSONL datasets now operate in async-only mode.

        Args:
            train_dataset: The training dataset passed to ``__init__``. The
                same object is used for both the type check and attribute
                access.
        """
        self.async_loader = None

        # Local import; only needed for the isinstance check below.
        from taoTrain.data.jsonl_base import BaseJSONLDataset

        print("\n✓ Checking dataset type for async loading...")
        if not isinstance(train_dataset, BaseJSONLDataset):
            return

        print("✓ Detected JSONL dataset, setting up async loading...")

        print("✓ Creating TokenizationQueue...")
        tokenization_queue = TokenizationQueue(
            chunk_manager=train_dataset.chunk_manager,
            tokenizer=train_dataset.tokenizer,
            config=self.config,
            max_queue_size=32,  # Memory constraint
            shuffle_chunks=True,
            num_threads=self.config.dataset.tokenizer_threads,
        )

        print("✓ Creating AsyncBatchIterator...")
        self.async_loader = AsyncBatchIterator(
            tokenization_queue=tokenization_queue,
            batch_size=self.config.batch_size,
            device=self.device,
            drop_last=True,
            gradient_accumulation_steps=self.config.gradient_accumulation_steps,
        )

    # ------------------------------------------------------------------
    # Forward-pass helpers
    # ------------------------------------------------------------------

    def _move_to_device(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Return a new dict with every tensor in ``batch`` moved to ``self.device``.

        Non-tensor values pass through unchanged. ``Tensor.to`` returns the
        tensor itself when it is already on the target device (e.g. batches
        produced by the async loader).
        """
        return {
            key: value.to(self.device) if isinstance(value, torch.Tensor) else value
            for key, value in batch.items()
        }

    def _autocast(self):
        """Return the mixed-precision context used for forward passes.

        NOTE(review): the autocast dtype is always bfloat16 when enabled,
        even though ``self.dtype`` is computed from the config. Confirm this
        is intended. Honoring float16 here would also require gradient
        scaling.
        """
        return torch.autocast(
            device_type="cuda" if self.device.type == "cuda" else "cpu",
            dtype=torch.bfloat16 if self.use_autocast else torch.float32,
            enabled=self.use_autocast,
        )

    def _compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Run the model forward on ``batch`` under autocast and return ``outputs["loss"]``."""
        with self._autocast():
            outputs = self.model(
                input_ids=batch["input_ids"],
                attention_mask=batch.get("attention_mask"),
                labels=batch.get("labels"),
            )
        return outputs["loss"]

    # ------------------------------------------------------------------
    # Steps
    # ------------------------------------------------------------------

    def training_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """Run one forward/backward pass on a single micro-batch.

        Gradients accumulate over ``config.gradient_accumulation_steps``
        micro-batches. Clipping, ``optimizer.step()``, ``scheduler.step()``
        and ``zero_grad()`` run only when ``global_step + 1`` is a multiple
        of that value (``train_epoch`` increments ``global_step`` after each
        call).

        Args:
            batch: Training batch. Must contain ``"input_ids"`` and may
                contain ``"attention_mask"`` and ``"labels"``.

        Returns:
            Dict with the unscaled micro-batch ``"loss"`` and the current
            learning rate ``"lr"``.
        """
        self.model.train()
        batch = self._move_to_device(batch)
        loss = self._compute_loss(batch)

        accum_steps = self.config.gradient_accumulation_steps
        # Read the unscaled loss for logging before scaling it for
        # accumulation, rather than dividing and multiplying back.
        raw_loss = loss.item()
        if accum_steps > 1:
            loss = loss / accum_steps
        loss.backward()

        # Any integer % 1 == 0, so with no accumulation this fires on every call.
        if (self.global_step + 1) % accum_steps == 0:
            if self.config.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    self.config.max_grad_norm,
                )
            self.optimizer.step()
            self.scheduler.step()
            self.optimizer.zero_grad()

        return {
            "loss": raw_loss,
            "lr": self.scheduler.get_last_lr()[0],
        }

    def validation_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """Compute the loss on one validation batch without tracking gradients.

        Args:
            batch: Validation batch with the same keys as in ``training_step``.

        Returns:
            ``{"val_loss": <float>}``.
        """
        self.model.eval()
        batch = self._move_to_device(batch)
        with torch.no_grad():
            loss = self._compute_loss(batch)
        return {"val_loss": loss.item()}

    # ------------------------------------------------------------------
    # Loops
    # ------------------------------------------------------------------

    def _is_due(self, interval) -> bool:
        """Return True if ``global_step`` is a multiple of ``interval``.

        A falsy ``interval`` (0 or None) disables the periodic action instead
        of raising ``ZeroDivisionError``/``TypeError``.
        """
        return bool(interval) and self.global_step % interval == 0

    def _run_periodic_validation(self) -> None:
        """Validate, log the metrics, and save ``best_model.pt`` on a new best ``val_loss``."""
        val_metrics = self.validate()
        self.logger.log_metrics({"step": self.global_step, **val_metrics})

        # Save checkpoint if best
        if val_metrics.get("val_loss", float("inf")) < self.best_loss:
            self.best_loss = val_metrics["val_loss"]
            if self.config.save_best_model:
                ckpt_path = Path(self.config.checkpoint_dir) / "best_model.pt"
                self.save_checkpoint(ckpt_path)

    def train_epoch(self) -> Dict[str, float]:
        """Train for one epoch, or until ``config.max_steps`` micro-batches are done.

        Logging, validation and periodic checkpointing are triggered by
        ``global_step`` multiples of ``log_every_steps``, ``eval_every_steps``
        and ``save_every_steps``. A value of 0 disables the corresponding
        action.

        Returns:
            Mean of the per-step training metrics for this epoch.
        """
        self.current_epoch += 1
        self.train_metrics.reset()

        # Use async loader for JSONL datasets, regular DataLoader for HuggingFace
        if self.async_loader is not None:
            print("\n✓ Using AsyncBatchIterator for training...")
            train_iterator = self.async_loader
        else:
            print("\n✓ Creating DataLoader for training dataset...")
            train_iterator = get_dataloader(
                self.train_dataset,
                self.config,
                shuffle=True,
                drop_last=True,
            )

        # Context manager so the progress bar is closed on early break or error.
        with tqdm(train_iterator, desc=f"Epoch {self.current_epoch}") as pbar:
            for batch in pbar:
                # Check if we've hit max steps
                if self.config.max_steps and self.global_step >= self.config.max_steps:
                    print(f"\n✓ Reached max steps ({self.global_step}), ending training.")
                    break

                metrics = self.training_step(batch)
                self.train_metrics.update(metrics)
                self.global_step += 1
                pbar.set_postfix(self.train_metrics.get_latest())

                if self._is_due(self.config.log_every_steps):
                    log_dict = {"step": self.global_step, "epoch": self.current_epoch}
                    log_dict.update(self.train_metrics.get_latest())
                    self.logger.log_metrics(log_dict)

                if self._is_due(self.config.eval_every_steps):
                    self._run_periodic_validation()

                # Save periodic checkpoint
                if self._is_due(self.config.save_every_steps):
                    ckpt_path = (
                        Path(self.config.checkpoint_dir)
                        / f"checkpoint_step_{self.global_step}.pt"
                    )
                    self.save_checkpoint(ckpt_path)

        print(f"\n✓ Finished epoch {self.current_epoch}.")
        return self.train_metrics.get_average()

    def validate(self) -> Dict[str, float]:
        """Run validation over the whole validation dataset.

        Returns:
            Mean of the per-batch metrics from :meth:`validation_step`, or
            ``{}`` if there is no validation dataset. Each batch is weighted
            equally, including a smaller final batch (``drop_last=False``).
        """
        if self.val_dataset is None:
            return {}

        val_loader = get_dataloader(
            self.val_dataset,
            self.config,
            shuffle=False,
            drop_last=False,
        )

        self.val_metrics.reset()
        # Outer no_grad also covers subclasses that override validation_step.
        with torch.no_grad():
            for batch in val_loader:
                self.val_metrics.update(self.validation_step(batch))

        return self.val_metrics.get_average()
| # ============================================================================ | |
| # Stage-Specific Trainers | |
| # ============================================================================ | |
class PretrainTrainer(BaseTrainerImpl):
    """Trainer for the pretraining stage.

    Adds no behavior of its own. Initialization, the training/validation
    loop and checkpointing are all inherited unchanged from
    ``BaseTrainerImpl``.
    """
class SFTTrainer(BaseTrainerImpl):
    """Trainer for supervised fine-tuning (SFT).

    Currently identical to ``BaseTrainerImpl``. SFT-specific behavior can be
    added by overriding its methods here.
    """
class RLTrainer(BaseTrainerImpl):
    """Trainer for reinforcement learning.

    Currently identical to ``BaseTrainerImpl``.
    """

    # TODO: PPO/DPO logic is planned to be implemented in a separate module.