Image-Text-to-Text
Transformers
English
vision-language-model
vlm
surveillance
iot
gemma
vl-jepa
multimodal
object-detection
video-analytics
Instructions to use hardiksa/arcisvlm with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use hardiksa/arcisvlm with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("image-text-to-text", model="hardiksa/arcisvlm")

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("hardiksa/arcisvlm", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- vLLM
How to use hardiksa/arcisvlm with vLLM:
Install from pip and serve model
```bash
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "hardiksa/arcisvlm"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{ "model": "hardiksa/arcisvlm", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'
```
Use Docker
docker model run hf.co/hardiksa/arcisvlm
- SGLang
How to use hardiksa/arcisvlm with SGLang:
Install from pip and serve model
```bash
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
  --model-path "hardiksa/arcisvlm" \
  --host 0.0.0.0 \
  --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{ "model": "hardiksa/arcisvlm", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'
```
Use Docker images
```bash
docker run --gpus all \
  --shm-size 32g \
  -p 30000:30000 \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  --env "HF_TOKEN=<secret>" \
  --ipc=host \
  lmsysorg/sglang:latest \
  python3 -m sglang.launch_server \
    --model-path "hardiksa/arcisvlm" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{ "model": "hardiksa/arcisvlm", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'
```
- Docker Model Runner
How to use hardiksa/arcisvlm with Docker Model Runner:
docker model run hf.co/hardiksa/arcisvlm
| #!/usr/bin/env python3 | |
| """ | |
| Stage 3: Domain Fine-Tuning — DDP Training on 8x A100. | |
| Loads Stage 2 checkpoint and fine-tunes on domain-specific data | |
| (COCO detection, VisDrone, MOT, UCF-Crime, ActivityNet, synthetic RTSP) | |
| with very low learning rate. | |
| Usage: | |
| torchrun --nproc_per_node=8 scripts/train_stage3_ddp.py \\ | |
| --config configs/scale_1.3b.yaml \\ | |
| --stage2_ckpt checkpoints/stage2_final.pt | |
| torchrun --nproc_per_node=8 scripts/train_stage3_ddp.py \\ | |
| --config configs/scale_1.3b.yaml \\ | |
| --stage2_ckpt checkpoints/stage2_final.pt \\ | |
| --resume checkpoints/stage3_epoch3.pt | |
| """ | |
| import argparse | |
| import math | |
| import os | |
| import sys | |
| import time | |
| import torch | |
| import torch.distributed as dist | |
| import torch.nn as nn | |
| from torch.nn.parallel import DistributedDataParallel as DDP | |
| from torch.utils.data import DataLoader, Dataset | |
| from torch.utils.data.distributed import DistributedSampler | |
| import yaml | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from model.vlm import VLJEPAModel | |
| from model.tokenizer import BPETokenizer | |
| # --------------------------------------------------------------------------- | |
| # Dataset helpers | |
| # --------------------------------------------------------------------------- | |
class _LocalJSONLDataset(Dataset):
    """Image/question/answer samples read from ``*.jsonl`` files in a directory.

    Every line of every ``*.jsonl`` file (in sorted filename order) is parsed
    as JSON.  Only JSON objects whose ``image_path`` points at an existing
    file are kept; everything else is counted as skipped.  Items are returned
    as a dict with ``image``, ``question_ids``, ``question_mask``,
    ``answer_ids`` and ``answer_mask``.
    """

    def __init__(self, jsonl_dir, tokenizer, img_size=448, max_q=64, max_a=128):
        import json

        self.samples = []
        self.tokenizer = tokenizer
        self.max_q = max_q
        self.max_a = max_a
        self.img_size = img_size
        self.vocab_size = getattr(tokenizer, 'vocab_size', 32768)
        # Built on first use so that constructing the dataset does not need
        # PIL/torchvision, and so the pipeline is created once per process
        # instead of once per sample.
        self._transform = None

        skipped = 0
        for fname in sorted(os.listdir(jsonl_dir)):
            if not fname.endswith('.jsonl'):
                continue
            fpath = os.path.join(jsonl_dir, fname)
            with open(fpath, encoding="utf-8") as f:
                for line in f:
                    try:
                        item = json.loads(line.strip())
                    except json.JSONDecodeError:
                        # Blank or malformed line: ignore it.
                        continue
                    # A valid JSON line that is not an object (list, number,
                    # ...) has no image_path; skip it rather than abort the
                    # whole load with an AttributeError.
                    if not isinstance(item, dict):
                        skipped += 1
                        continue
                    img_path = item.get("image_path")
                    if isinstance(img_path, str) and img_path and os.path.exists(img_path):
                        self.samples.append(item)
                    else:
                        skipped += 1
        if skipped > 0:
            print(f" [Stage 3 data] Skipped {skipped} samples without images")

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _extract_qa(item):
        """Return ``(question, answer)`` strings from the supported record layouts.

        Looked up in order: a ``conversations`` list (first two entries),
        ``question``/``text`` and ``answer``/``multiple_choice_answer`` keys,
        then the first entry of an ``answers`` list.  Missing values fall back
        to ``"What do you see?"`` and ``"unknown"``.
        """
        question = ""
        answer = ""
        convos = item.get("conversations")
        if isinstance(convos, list) and len(convos) >= 2:
            first, second = convos[0], convos[1]
            question = first.get("value", "") if isinstance(first, dict) else str(first)
            answer = second.get("value", "") if isinstance(second, dict) else str(second)
        if not question:
            # `or` (rather than dict defaults) so that an explicit null/empty
            # value falls through instead of being tokenized as "None"/"".
            question = item.get("question") or item.get("text") or "What do you see?"
        if not answer:
            answer = item.get("answer") or item.get("multiple_choice_answer") or ""
        if not answer:
            answers = item.get("answers")
            if isinstance(answers, list) and answers:
                head = answers[0]
                answer = head.get("answer", "") if isinstance(head, dict) else str(head)
        if not answer:
            answer = "unknown"
        return str(question), str(answer)

    def _load_image(self, image_path):
        """Load, resize and normalise one image.

        Raises:
            FileNotFoundError: if the file is missing, or if opening or
                transforming it fails for any reason (the original error is
                chained as ``__cause__``).
        """
        if not image_path or not os.path.exists(image_path):
            raise FileNotFoundError(
                f"Image not found: {image_path}. Data may be corrupted or incomplete."
            )
        try:
            from PIL import Image as PILImage
            if self._transform is None:
                from torchvision import transforms
                self._transform = transforms.Compose([
                    transforms.Resize((self.img_size, self.img_size)),
                    transforms.ToTensor(),
                    transforms.Normalize(
                        mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225],
                    ),
                ])
            # Context manager guarantees the file handle is released.
            with PILImage.open(image_path) as pil_img:
                return self._transform(pil_img.convert("RGB"))
        except Exception as e:
            raise FileNotFoundError(
                f"Image not found or corrupted: {image_path}. Original error: {e}"
            ) from e

    def _pad(self, ids, max_len):
        """Truncate ``ids`` to ``max_len`` or right-pad it with the tokenizer's pad id."""
        if len(ids) > max_len:
            return ids[:max_len]
        return ids + [self.tokenizer.pad_id] * (max_len - len(ids))

    def __getitem__(self, idx):
        item = self.samples[idx]
        question, answer = self._extract_qa(item)
        image = self._load_image(item.get("image_path"))

        q_ids = self._pad(self.tokenizer.encode(question), self.max_q)
        a_ids = self._pad(self.tokenizer.encode(answer), self.max_a)
        q_tensor = torch.tensor(q_ids, dtype=torch.long)
        a_tensor = torch.tensor(a_ids, dtype=torch.long)
        return {
            "image": image,
            "question_ids": q_tensor,
            # Masks are 1 on real tokens and 0 on padding.
            "question_mask": (q_tensor != self.tokenizer.pad_id).long(),
            "answer_ids": a_tensor,
            "answer_mask": (a_tensor != self.tokenizer.pad_id).long(),
        }


class _TokenizedWrapper(Dataset):
    """Wrap a dataset of raw-string samples and tokenize them on the fly.

    Items that already contain ``question_ids`` are passed through untouched.
    """

    def __init__(self, ds, tokenizer, img_size=448, max_q=64, max_a=128):
        self.ds = ds
        self.tokenizer = tokenizer
        self.img_size = img_size
        self.max_q = max_q
        self.max_a = max_a
        self.vocab_size = getattr(tokenizer, 'vocab_size', 32768)

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        item = self.ds[idx]
        # If already tokenized, return as-is
        if "question_ids" in item:
            return item

        pad_id = getattr(self.tokenizer, 'pad_id', 0)
        image = item.get("image")
        if image is None:
            # Text-only sample: emit an all-padding placeholder with a zero
            # image so batch shapes stay uniform.
            # NOTE(review): nothing in this file filters these out, and both
            # masks are all-zero — confirm the loss handles that without NaN.
            return {
                "image": torch.zeros(3, self.img_size, self.img_size),
                "question_ids": torch.full((self.max_q,), pad_id, dtype=torch.long),
                "question_mask": torch.zeros(self.max_q, dtype=torch.long),
                "answer_ids": torch.full((self.max_a,), pad_id, dtype=torch.long),
                "answer_mask": torch.zeros(self.max_a, dtype=torch.long),
            }

        question = str(item.get("question", "What do you see?"))
        answer = str(item.get("answer", "unknown"))
        q_ids = self.tokenizer.encode(question)
        a_ids = self.tokenizer.encode(answer)
        # Pad/truncate to fixed lengths.
        q_ids = (q_ids[:self.max_q] + [pad_id] * self.max_q)[:self.max_q]
        a_ids = (a_ids[:self.max_a] + [pad_id] * self.max_a)[:self.max_a]
        q_tensor = torch.tensor(q_ids, dtype=torch.long)
        a_tensor = torch.tensor(a_ids, dtype=torch.long)
        return {
            # NOTE(review): the image is forwarded unchanged — presumably the
            # wrapped dataset already yields a tensor; confirm.
            "image": image,
            "question_ids": q_tensor,
            "question_mask": (q_tensor != pad_id).long(),
            "answer_ids": a_tensor,
            "answer_mask": (a_tensor != pad_id).long(),
        }


def build_stage3_dataset(config: dict, tokenizer,
                         jsonl_dir: str = "data/downloads/stage3",
                         min_samples: int = 100) -> Dataset:
    """
    Build Stage 3 domain fine-tuning dataset (detection, surveillance, temporal).

    Sources are tried in order:

    1. Local ``*.jsonl`` files under ``jsonl_dir``; used only when they yield
       more than ``min_samples`` samples with an existing image.
    2. ``data.multi_dataset.build_stage3_dataset``, wrapped so raw-string
       samples are tokenized on access.

    Args:
        config: Parsed YAML config; ``config["vision"]["img_size"]`` is read.
        tokenizer: Object providing ``encode(str) -> list[int]`` and ``pad_id``.
        jsonl_dir: Directory holding the local JSONL files.
        min_samples: The local dataset is accepted only if it has strictly
            more samples than this.

    Returns:
        A ``torch.utils.data.Dataset`` yielding dicts of tensors.

    Raises:
        RuntimeError: if no real data is found in either source.
    """
    img_size = config["vision"]["img_size"]

    # Try loading real data from local JSONL files (downloaded by download_datasets.py)
    if os.path.exists(jsonl_dir):
        try:
            dataset = _LocalJSONLDataset(jsonl_dir, tokenizer, img_size)
        except Exception as e:
            print(f" [WARN] Local JSONL loading failed: {e}")
        else:
            if len(dataset) > min_samples:
                print(f" [REAL DATA] Local JSONL: {len(dataset)} samples from {jsonl_dir}")
                return dataset
            # Previously this case fell through silently, which made a
            # too-small download indistinguishable from a missing one.
            print(f" [WARN] Local JSONL has only {len(dataset)} usable samples "
                  f"(need more than {min_samples}); trying multi_dataset")

    # Try loading from data/multi_dataset.py — wrap with tokenization
    try:
        from data.multi_dataset import build_stage3_dataset as _build
        raw_dataset = _build(config, tokenizer)
        if len(raw_dataset) > 0:
            dataset = _TokenizedWrapper(raw_dataset, tokenizer, img_size)
            print(f" [REAL DATA] multi_dataset: {len(dataset)} samples (with tokenization wrapper)")
            return dataset
    except Exception as e:
        print(f" [WARN] multi_dataset loading failed: {e}")

    raise RuntimeError(
        "FATAL: No Stage 3 training data found.\n"
        "Download real data first: python3 scripts/download_all_data.py --stage 3\n"
        "Required: data/downloads/stage3/ with JSONL files"
    )
| # --------------------------------------------------------------------------- | |
| # LR scheduler | |
| # --------------------------------------------------------------------------- | |
class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Linear warmup then cosine decay.

    For ``step < warmup_steps`` the learning rate rises linearly from 0 to
    each group's base LR.  It then follows a half cosine from the base LR
    down to ``min_lr``, reaching ``min_lr`` at ``total_steps`` and staying
    there for any later step.

    Args:
        optimizer: Wrapped optimizer.
        warmup_steps: Number of scheduler steps spent in linear warmup.
        total_steps: Step at which the cosine decay reaches ``min_lr``.
        min_lr: Floor of the cosine decay.
        last_epoch: Index of the last step, ``-1`` to start fresh.
    """

    def __init__(self, optimizer, warmup_steps: int, total_steps: int,
                 min_lr: float = 1e-7, last_epoch: int = -1):
        # These must be set before super().__init__, which calls get_lr().
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return one learning rate per parameter group for the current step."""
        step = self.last_epoch
        if step < self.warmup_steps:
            scale = step / max(1, self.warmup_steps)
            return [base_lr * scale for base_lr in self.base_lrs]

        decay_steps = max(1, self.total_steps - self.warmup_steps)
        # Clamp to 1.0: without it, stepping past total_steps (e.g. resuming
        # with a longer schedule than planned) walks the cosine back up and
        # the learning rate rises toward base_lr again.
        progress = min(1.0, (step - self.warmup_steps) / decay_steps)
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return [self.min_lr + (base_lr - self.min_lr) * cosine for base_lr in self.base_lrs]
| # --------------------------------------------------------------------------- | |
| # Helpers | |
| # --------------------------------------------------------------------------- | |
def setup_distributed():
    """Join the NCCL process group and bind this process to its GPU.

    Returns:
        The integer value of the ``LOCAL_RANK`` environment variable, which is
        also the CUDA device index selected for this process.
    """
    dist.init_process_group(backend="nccl")
    gpu_index = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(gpu_index)
    return gpu_index
def cleanup():
    """Destroy the default process group; a no-op when none is initialised."""
    if not dist.is_initialized():
        return
    dist.destroy_process_group()
def is_rank0():
    """Return True for global rank 0, or when no process group is initialised."""
    if not dist.is_initialized():
        # Single-process run: this process is by definition the primary one.
        return True
    return dist.get_rank() == 0
def log(msg: str):
    """Print ``msg`` (flushed immediately) on rank 0 only; other ranks stay silent."""
    if not is_rank0():
        return
    print(msg, flush=True)
def save_checkpoint(model, optimizer, scheduler, epoch, global_step, loss, path):
    """Write a training checkpoint to ``path`` (rank 0 only).

    The checkpoint is a dict with keys ``epoch``, ``global_step``,
    ``model_state_dict``, ``optimizer_state_dict``, ``scheduler_state_dict``
    and ``loss``.  If ``model`` has a ``.module`` attribute (e.g. a DDP
    wrapper), the inner module's state dict is saved so keys carry no
    ``module.`` prefix.

    Args:
        model: Model or wrapper whose weights are saved.
        optimizer: Optimizer whose state is saved.
        scheduler: LR scheduler whose state is saved.
        epoch: Epoch number recorded in the checkpoint.
        global_step: Optimizer-step counter recorded in the checkpoint.
        loss: Loss value recorded in the checkpoint.
        path: Destination file; its parent directory is created if needed.
    """
    if not is_rank0():
        return
    # os.path.dirname() is "" for a bare filename, and os.makedirs("") raises
    # FileNotFoundError — only create the directory when there is one.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    state_dict = model.module.state_dict() if hasattr(model, "module") else model.state_dict()
    torch.save({
        "epoch": epoch,
        "global_step": global_step,
        "model_state_dict": state_dict,
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "loss": loss,
    }, path)
    log(f" Checkpoint saved: {path}")
def push_checkpoints():
    """Placeholder for pushing checkpoints to GitHub LFS; intentionally does nothing.

    Pushing is disabled during training to avoid git lock issues. Run
    scripts/push_checkpoints.py manually once training has completed.
    """
    return None  # Disabled — run push_checkpoints.py separately after training
| # --------------------------------------------------------------------------- | |
| # Main | |
| # --------------------------------------------------------------------------- | |
def _run_training(args, local_rank):
    """Run Stage 3 fine-tuning on an already-initialised process group.

    Reads the ``train_stage3`` section of the YAML config, builds the dataset
    and model, optionally loads the Stage 2 and resume checkpoints, then
    trains for the configured number of epochs, writing
    ``checkpoints/stage3_epoch{N}.pt`` after every epoch and
    ``checkpoints/stage3_final.pt`` at the end.

    Args:
        args: Parsed CLI namespace with ``config``, ``stage2_ckpt``, ``resume``.
        local_rank: CUDA device index for this process.

    Raises:
        ValueError: if the configured batch size or gradient accumulation
            cannot produce at least one sample / one micro-batch per step.
        RuntimeError: if the data loader yields no batches.
    """
    world_size = dist.get_world_size()
    global_rank = dist.get_rank()
    device = torch.device(f"cuda:{local_rank}")

    # ---- Config ----
    with open(args.config) as f:
        config = yaml.safe_load(f)
    stage_cfg = config["train_stage3"]
    per_gpu_batch = stage_cfg["batch_size"] // world_size
    grad_accum = stage_cfg.get("gradient_accumulation", 4)
    max_epochs = stage_cfg["max_epochs"]
    lr = stage_cfg["learning_rate"]  # Very low: 5e-6
    warmup_steps = stage_cfg["warmup_steps"]
    grad_clip = stage_cfg["gradient_clip"]
    lb_weight = stage_cfg["load_balance_weight"]

    # Fail early with a clear message instead of deep inside DataLoader or
    # with a ZeroDivisionError in the accumulation logic.
    if per_gpu_batch < 1:
        raise ValueError(
            f"train_stage3.batch_size={stage_cfg['batch_size']} is smaller than "
            f"world size {world_size}; per-GPU batch would be 0"
        )
    if grad_accum < 1:
        raise ValueError(f"train_stage3.gradient_accumulation must be >= 1, got {grad_accum}")

    log("=" * 70)
    log("ArcisVLM — Stage 3: Domain Fine-Tuning (DDP)")
    log("=" * 70)
    log(f" World size: {world_size}")
    log(f" Global batch: {stage_cfg['batch_size']}")
    log(f" Per-GPU batch: {per_gpu_batch}")
    log(f" Gradient accumulation:{grad_accum}")
    log(f" Effective batch: {stage_cfg['batch_size'] * grad_accum}")
    log(f" Max epochs: {max_epochs}")
    log(f" Learning rate: {lr}")
    log(f" Warmup steps: {warmup_steps}")
    log(f" Load balance weight: {lb_weight}")
    log(f" Precision: {stage_cfg.get('precision', 'bf16')}")

    # ---- Tokenizer ----
    tokenizer = BPETokenizer(vocab_size=config["decoder"]["vocab_size"])
    for tok_path in ["checkpoints/tokenizer_32k.json", "checkpoints/tokenizer.json"]:
        if os.path.exists(tok_path):
            tokenizer.load(tok_path)
            log(f" Tokenizer: {len(tokenizer)} tokens (from {tok_path})")
            break
    else:
        # for/else: reached only when no candidate path existed.
        log(" [WARN] No tokenizer found — using untrained tokenizer")

    # ---- Dataset ----
    dataset = build_stage3_dataset(config, tokenizer)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=global_rank, shuffle=True)
    loader = DataLoader(
        dataset,
        batch_size=per_gpu_batch,
        sampler=sampler,
        num_workers=4,
        pin_memory=True,
        drop_last=True,
    )
    log(f" Dataset: {len(dataset)} samples, {len(loader)} batches/GPU")
    num_batches = len(loader)
    if num_batches == 0:
        # drop_last=True discards a partial batch, so a tiny dataset would
        # otherwise "train" for zero steps and still write checkpoints.
        raise RuntimeError(
            f"Data loader is empty: {len(dataset)} samples cannot fill one "
            f"batch of {per_gpu_batch} on each of {world_size} ranks"
        )

    # ---- Model + Stage 2 checkpoint ----
    model = VLJEPAModel(config).to(device)
    if os.path.exists(args.stage2_ckpt):
        # NOTE(review): weights_only=False unpickles arbitrary objects — only
        # load checkpoints from a trusted source.
        ckpt = torch.load(args.stage2_ckpt, map_location=device, weights_only=False)
        model.load_state_dict(ckpt["model_state_dict"])
        log(f" Loaded Stage 2 ckpt: {args.stage2_ckpt} (epoch {ckpt['epoch']}, loss {ckpt['loss']:.4f})")
    else:
        log(f" [WARN] Stage 2 checkpoint not found: {args.stage2_ckpt} — training from scratch")

    # Stage 3: all parameters trainable (unfrozen) but at very low LR
    model.unfreeze_all()
    if is_rank0():
        params = model.count_parameters()
        for k, v in params.items():
            log(f" {k}: {v:,}")

    # ---- Optimizer ----
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)

    # ---- Scheduler ----
    # ceil: the trailing partial accumulation window also gets an optimizer
    # step (see the training loop), so this matches the real step count.
    steps_per_epoch = math.ceil(num_batches / grad_accum)
    total_steps = max_epochs * steps_per_epoch
    scheduler = CosineWarmupScheduler(optimizer, warmup_steps=warmup_steps, total_steps=total_steps)

    # ---- Mixed precision ----
    use_bf16 = stage_cfg.get("precision", "bf16") == "bf16"
    scaler = torch.amp.GradScaler("cuda", enabled=(not use_bf16))
    autocast_dtype = torch.bfloat16 if use_bf16 else torch.float16

    # ---- Resume ----
    start_epoch = 0
    global_step = 0
    # Defined up front so the final checkpoint/log still work when the resume
    # checkpoint is already at max_epochs and the epoch loop never runs.
    avg_loss = float("nan")
    avg_decode = float("nan")
    if args.resume and os.path.exists(args.resume):
        ckpt = torch.load(args.resume, map_location=device, weights_only=False)
        model.load_state_dict(ckpt["model_state_dict"])
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        if "scheduler_state_dict" in ckpt:
            scheduler.load_state_dict(ckpt["scheduler_state_dict"])
        start_epoch = ckpt["epoch"]
        global_step = ckpt.get("global_step", start_epoch * steps_per_epoch)
        avg_loss = ckpt["loss"]
        log(f" Resumed from {args.resume} (epoch {start_epoch}, loss {ckpt['loss']:.4f})")
    elif args.resume:
        log(f" [WARN] Resume checkpoint not found: {args.resume} — starting Stage 3 from epoch 0")

    # ---- DDP wrap ----
    model = DDP(model, device_ids=[local_rank], output_device=local_rank,
                find_unused_parameters=False)

    # ---- Training loop ----
    model.train()
    os.makedirs("checkpoints", exist_ok=True)

    # Batches from this index onward belong to the trailing, shorter
    # accumulation window (empty when num_batches divides evenly).
    tail_size = num_batches % grad_accum
    tail_start = num_batches - tail_size

    for epoch in range(start_epoch, max_epochs):
        sampler.set_epoch(epoch)
        epoch_loss = 0.0
        epoch_decode_loss = 0.0
        epoch_lb_loss = 0.0
        epoch_steps = 0
        epoch_start = time.time()
        optimizer.zero_grad(set_to_none=True)

        for batch_idx, batch in enumerate(loader):
            images = batch["image"].to(device, non_blocking=True)
            q_ids = batch["question_ids"].to(device, non_blocking=True)
            q_mask = batch["question_mask"].to(device, non_blocking=True)
            a_ids = batch["answer_ids"].to(device, non_blocking=True)

            with torch.amp.autocast("cuda", dtype=autocast_dtype):
                # NOTE(review): calling model.module.forward_stage2 bypasses
                # DDP.forward. DDP arms its gradient all-reduce in the
                # wrapper's forward, so ranks may be training without
                # synchronising gradients here — confirm and, if so, route
                # this call through the DDP wrapper.
                # NOTE(review): q_mask is 1 on real tokens and 0 on padding;
                # confirm forward_stage2 expects that polarity for
                # `query_padding_mask`.
                output = model.module.forward_stage2(
                    images=images,
                    query_ids=q_ids,
                    query_padding_mask=q_mask,
                    answer_ids=a_ids,
                    load_balance_weight=lb_weight,
                )

            # Average over the micro-batches that share one optimizer step.
            window = grad_accum if batch_idx < tail_start else tail_size
            loss = output["loss"] / window
            if use_bf16:
                loss.backward()
            else:
                scaler.scale(loss).backward()

            epoch_loss += output["loss"].item()
            epoch_decode_loss += output["decode_loss"].item()
            epoch_lb_loss += output["load_balance_loss"].item()
            epoch_steps += 1

            # Optimizer step every grad_accum batches, plus once at the end of
            # the epoch so gradients of a trailing partial window are applied
            # instead of being thrown away by the next zero_grad().
            at_boundary = (batch_idx + 1) % grad_accum == 0
            at_epoch_end = (batch_idx + 1) == num_batches
            if at_boundary or at_epoch_end:
                if not use_bf16:
                    # Gradients must be unscaled before clipping.
                    scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                if use_bf16:
                    optimizer.step()
                else:
                    scaler.step(optimizer)
                    scaler.update()
                scheduler.step()
                optimizer.zero_grad(set_to_none=True)
                global_step += 1

                if global_step % 50 == 0:
                    current_lr = scheduler.get_last_lr()[0]
                    gpu_mem = torch.cuda.max_memory_allocated(device) / 1e9
                    log(f" [Step {global_step}] loss={output['loss'].item():.4f} "
                        f"decode={output['decode_loss'].item():.4f} "
                        f"lb={output['load_balance_loss'].item():.4f} "
                        f"lr={current_lr:.2e} GPU={gpu_mem:.1f}GB")

        # ---- Epoch summary ----
        # Sum per-rank totals and batch counts, then divide, giving the mean
        # over every micro-batch on every rank.
        metrics = torch.tensor([epoch_loss, epoch_decode_loss, epoch_lb_loss, epoch_steps],
                               device=device, dtype=torch.float64)
        dist.all_reduce(metrics, op=dist.ReduceOp.SUM)
        avg_loss = (metrics[0] / metrics[3]).item()
        avg_decode = (metrics[1] / metrics[3]).item()
        avg_lb = (metrics[2] / metrics[3]).item()
        epoch_time = time.time() - epoch_start
        current_lr = scheduler.get_last_lr()[0]
        gpu_mem = torch.cuda.max_memory_allocated(device) / 1e9
        log(f"\nEpoch {epoch + 1}/{max_epochs}: loss={avg_loss:.4f} decode={avg_decode:.4f} "
            f"lb={avg_lb:.4f} lr={current_lr:.2e} time={epoch_time:.0f}s GPU={gpu_mem:.1f}GB")

        # ---- Checkpoint every epoch ----
        ckpt_path = f"checkpoints/stage3_epoch{epoch + 1}.pt"
        save_checkpoint(model, optimizer, scheduler, epoch + 1, global_step, avg_loss, ckpt_path)
        push_checkpoints()
        # Keep other ranks from racing ahead while rank 0 is writing.
        dist.barrier()

    # ---- Final checkpoint ----
    save_checkpoint(model, optimizer, scheduler, max_epochs, global_step, avg_loss,
                    "checkpoints/stage3_final.pt")
    push_checkpoints()
    log("\n" + "=" * 70)
    log(f"Stage 3 complete. Final loss: {avg_loss:.4f} decode: {avg_decode:.4f}")
    log("=" * 70)


def main():
    """CLI entry point: parse arguments, set up DDP, train, and always tear down.

    Flags: ``--config`` (required YAML path), ``--stage2_ckpt`` (Stage 2
    checkpoint to start from) and ``--resume`` (Stage 3 checkpoint to
    continue from).
    """
    parser = argparse.ArgumentParser(description="Stage 3 DDP: Domain Fine-Tuning")
    parser.add_argument("--config", type=str, required=True, help="Path to YAML config")
    parser.add_argument("--stage2_ckpt", type=str, default="checkpoints/stage2_final.pt",
                        help="Path to Stage 2 checkpoint")
    parser.add_argument("--resume", type=str, default=None,
                        help="Path to Stage 3 checkpoint to resume from")
    args = parser.parse_args()

    try:
        # ---- Distributed setup ----
        local_rank = setup_distributed()
        _run_training(args, local_rank)
    finally:
        # Destroy the process group even when training raises; cleanup() is a
        # no-op if initialisation itself failed.
        cleanup()


if __name__ == "__main__":
    main()