Text Generation
PyTorch
Transformers
English
language-model
graph-neural-network
sparse-attention
adaptive-depth
temporal-decay
mesh-attention
efficient-transformer
novel-architecture
causal-lm
research
preprint
mesh-transformer
dynamic-graph
early-exit
per-token-routing
Eval Results (legacy)
Instructions to use vigneshwar234/TemporalMesh-Transformer with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use vigneshwar234/TemporalMesh-Transformer with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="vigneshwar234/TemporalMesh-Transformer")

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("vigneshwar234/TemporalMesh-Transformer", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use vigneshwar234/TemporalMesh-Transformer with vLLM:
Install from pip and serve model
```bash
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "vigneshwar234/TemporalMesh-Transformer"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "vigneshwar234/TemporalMesh-Transformer",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
```
Use Docker
docker model run hf.co/vigneshwar234/TemporalMesh-Transformer
- SGLang
How to use vigneshwar234/TemporalMesh-Transformer with SGLang:
Install from pip and serve model
```bash
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
    --model-path "vigneshwar234/TemporalMesh-Transformer" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "vigneshwar234/TemporalMesh-Transformer",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
```
Use Docker images
```bash
docker run --gpus all \
    --shm-size 32g \
    -p 30000:30000 \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HF_TOKEN=<secret>" \
    --ipc=host \
    lmsysorg/sglang:latest \
    python3 -m sglang.launch_server \
        --model-path "vigneshwar234/TemporalMesh-Transformer" \
        --host 0.0.0.0 \
        --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "vigneshwar234/TemporalMesh-Transformer",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
```
- Docker Model Runner
How to use vigneshwar234/TemporalMesh-Transformer with Docker Model Runner:
docker model run hf.co/vigneshwar234/TemporalMesh-Transformer
| """ | |
| trainer.py — TMT training loop with wandb logging. | |
| Trains on wikitext-2 (or tinystories) using AdamW + cosine warmup schedule. | |
| Logs: train loss, val perplexity, exit rate per layer, and memory anchor norms. | |
| """ | |
| from __future__ import annotations | |
| import os | |
| from dataclasses import dataclass, field | |
| from typing import Optional | |
| import torch | |
| import torch.nn as nn | |
| from torch import Tensor | |
| from torch.optim import AdamW | |
| from torch.utils.data import DataLoader | |
| from ..model.config import TMTConfig | |
| from ..model.model import TMTModel | |
| from .loss import compute_loss | |
| from .scheduler import cosine_warmup_scheduler | |
@dataclass
class TrainConfig:
    """Hyper-parameters for ``TMTTrainer``.

    Declared as a dataclass so that fields can be overridden by keyword
    (``TrainConfig(lr=1e-3)``) and so that ``vars(cfg)`` returns the field
    values — the trainer passes ``vars(train_cfg)`` to ``wandb.init`` as the
    run config, which would otherwise be an empty dict.
    """

    # Data
    dataset: str = "wikitext-2"  # or "tinystories"
    batch_size: int = 16
    seq_len: int = 256  # shorter than max for memory efficiency

    # Optimiser
    lr: float = 3e-4
    weight_decay: float = 0.1
    grad_clip: float = 1.0  # max global grad-norm passed to clip_grad_norm_
    warmup_steps: int = 500
    total_steps: int = 10_000  # number of optimiser steps the trainer runs

    # Saving (cadences are in optimiser steps)
    save_dir: str = "checkpoints"
    save_every: int = 500
    eval_every: int = 100

    # Logging
    use_wandb: bool = False  # set True when wandb is configured
    project: str = "temporal-mesh-transformer"

    # Device — evaluated once, when this class is defined
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # Loss
    exit_gate_coeff: float = 0.1  # weight of the exit-gate term in compute_loss
class TMTTrainer:
    """Train a ``TMTModel`` with AdamW, a cosine-warmup LR schedule, and
    periodic logging, validation and checkpointing.

    All cadences (printing every 10 steps, ``eval_every``, ``save_every``) are
    counted in optimiser steps, tracked in ``self.step``.
    """

    def __init__(
        self,
        model_cfg: TMTConfig,
        train_cfg: TrainConfig,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader] = None,
    ) -> None:
        """Build the model, optimiser, scheduler, optional wandb run and save dir.

        Args:
            model_cfg: architecture config passed to ``TMTModel``.
            train_cfg: optimisation / logging / saving settings.
            train_loader: yields batches indexed by ``"input_ids"``; it is cycled
                until ``train_cfg.total_steps`` optimiser steps have run.
            val_loader: optional loader with the same batch format, used by
                ``evaluate`` every ``train_cfg.eval_every`` steps.
        """
        self.cfg = train_cfg
        self.device = torch.device(train_cfg.device)
        self.model = TMTModel(model_cfg).to(self.device)
        self.optimizer = AdamW(
            self.model.parameters(),
            lr=train_cfg.lr,
            weight_decay=train_cfg.weight_decay,
        )
        self.scheduler = cosine_warmup_scheduler(
            self.optimizer,
            warmup_steps=train_cfg.warmup_steps,
            total_steps=train_cfg.total_steps,
        )
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.step = 0

        self.wandb = None
        if train_cfg.use_wandb:
            # Only the import is guarded: wandb is an optional dependency.
            try:
                import wandb
            except ImportError:
                print("wandb not installed — skipping wandb logging")
            else:
                # vars() only sees instance attributes, so both configs should
                # be dataclass-style objects for the run config to be populated.
                wandb.init(project=train_cfg.project, config={
                    "model": vars(model_cfg),
                    "train": vars(train_cfg),
                })
                self.wandb = wandb

        os.makedirs(train_cfg.save_dir, exist_ok=True)
        print(self.model)

    def _infinite_batches(self):
        """Yield training batches forever, restarting ``train_loader`` after each pass.

        Raises:
            ValueError: if a full pass over ``train_loader`` yields no batches
                (otherwise a bare ``StopIteration`` would escape ``train``).
        """
        while True:
            empty = True
            for batch in self.train_loader:
                empty = False
                yield batch
            if empty:
                raise ValueError("train_loader yielded no batches")

    def train(self) -> None:
        """Run optimisation until ``self.step`` reaches ``cfg.total_steps``.

        Each step does next-token prediction on one batch, a gradient-clipped
        AdamW update and one scheduler step. Every 10 steps losses, exit rate
        and LR are printed (and logged to wandb if enabled); every
        ``eval_every`` steps validation perplexity is computed when a
        ``val_loader`` was given; every ``save_every`` steps a checkpoint is
        written to ``cfg.save_dir``.

        Raises:
            ValueError: if ``train_loader`` yields no batches.
        """
        self.model.train()
        batches = self._infinite_batches()
        while self.step < self.cfg.total_steps:
            input_ids = next(batches)["input_ids"].to(self.device)
            # Next-token prediction: targets are the inputs shifted by 1.
            x = input_ids[:, :-1]
            targets = input_ids[:, 1:]

            # Forward
            output = self.model(x)
            total_loss, ce_loss, gate_loss = compute_loss(
                output.logits,
                targets,
                output.confidences,
                self.cfg.exit_gate_coeff,
            )

            # Backward, clipped update, then advance the LR schedule.
            self.optimizer.zero_grad()
            total_loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.grad_clip)
            self.optimizer.step()
            self.scheduler.step()
            self.step += 1

            # Logging
            if self.step % 10 == 0:
                lr = self.optimizer.param_groups[0]["lr"]
                avg_exit_rate = self._compute_exit_rate(output)
                print(
                    f"step={self.step:5d} | loss={total_loss.item():.4f} | "
                    f"ce={ce_loss.item():.4f} | gate={gate_loss.item():.4f} | "
                    f"exit={avg_exit_rate:.3f} | lr={lr:.2e}"
                )
                if self.wandb:
                    self.wandb.log({
                        "loss/total": total_loss.item(),
                        "loss/ce": ce_loss.item(),
                        "loss/gate": gate_loss.item(),
                        "train/exit_rate": avg_exit_rate,
                        "train/lr": lr,
                        "step": self.step,
                    })

            # `is not None`: truthiness of a DataLoader calls __len__, which
            # raises for iterable-style datasets.
            if self.val_loader is not None and self.step % self.cfg.eval_every == 0:
                val_ppl = self.evaluate()
                print(f" val_perplexity={val_ppl:.2f}")
                if self.wandb:
                    self.wandb.log({"val/perplexity": val_ppl, "step": self.step})
                self.model.train()  # evaluate() leaves the model in eval mode

            if self.step % self.cfg.save_every == 0:
                self._save(os.path.join(self.cfg.save_dir, f"ckpt_step{self.step}.pt"))

    def evaluate(self) -> float:
        """Return validation perplexity: ``exp`` of the mean per-batch cross-entropy.

        Only the cross-entropy term of ``compute_loss`` is used; the exit-gate
        term is a regulariser and would inflate perplexity if included. Each
        batch's CE is weighted equally (presumably CE is a token mean — confirm
        in ``loss.py``). Runs under ``torch.no_grad()`` and leaves the model in
        eval mode; callers switch back with ``model.train()``.

        Returns:
            The perplexity, or ``nan`` if ``val_loader`` yields no batches.

        Raises:
            ValueError: if the trainer was constructed without a ``val_loader``.
        """
        import math

        if self.val_loader is None:
            raise ValueError("evaluate() requires a val_loader")

        self.model.eval()
        ce_sum, n_batches = 0.0, 0
        with torch.no_grad():  # no autograd graph needed for evaluation
            for batch in self.val_loader:
                input_ids = batch["input_ids"].to(self.device)
                x, targets = input_ids[:, :-1], input_ids[:, 1:]
                out = self.model(x)
                _, ce_loss, _ = compute_loss(out.logits, targets, out.confidences)
                ce_sum += ce_loss.item()
                n_batches += 1

        if n_batches == 0:
            # No data: report "unknown" rather than a perfect perplexity of 1.0.
            return float("nan")
        return math.exp(ce_sum / n_batches)

    @staticmethod
    def _compute_exit_rate(output) -> float:
        """Return the mean of the last mask in ``output.exit_masks`` as a float.

        Only the final mask is used (no per-layer breakdown). Returns 0.0 when
        ``output.exit_masks`` is empty.
        """
        if not output.exit_masks:
            return 0.0
        final_mask = output.exit_masks[-1]
        return final_mask.float().mean().item()

    def _save(self, path: str) -> None:
        """Write the current step, model weights and optimiser state to ``path``."""
        torch.save({
            "step": self.step,
            "model_state": self.model.state_dict(),
            "optimizer_state": self.optimizer.state_dict(),
        }, path)
        print(f" saved checkpoint → {path}")