#!/usr/bin/env python3
"""
Streaming Teacher Cache with Async Background Fetching.

This module implements a producer-consumer pattern for knowledge distillation:
1. Producer: Background async task that continuously fetches teacher logits
2. Consumer: Training dataloader that reads from the cache

The cache grows during training, allowing training to start immediately
with a small initial cache while more samples are fetched in background.

Usage:
    cache = StreamingTeacherCache(tokenizer, max_samples=50000)
    await cache.warmup(num_samples=10000)  # Pre-cache initial samples
    cache.start_background_fetching()      # Start producer

    # Training loop reads random batches via cache.get_batch()
    dataset = StreamingCacheDataset(cache, batch_size=16)
    for batch in dataset:
        train_step(batch)
"""

import asyncio
import queue
import threading
import time
from collections import deque
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Any, Iterator, Optional

import torch
from datasets import load_dataset
from transformers import AutoTokenizer

from sem_v6.training.master_teacher import create_master_teacher, MasterTeacher


@dataclass
class CachedSample:
    """A single cached sample with teacher logits."""

    input_ids: torch.Tensor  # (seq_len,)
    labels: torch.Tensor  # (seq_len,)
    # One row of the teacher's batch output; the original author documents
    # this as (seq_len, vocab_size) -- TODO confirm against MasterTeacher.
    teacher_logits: torch.Tensor


class StreamingTeacherCache:
    """
    Streaming cache with async background fetching.

    Samples live in a bounded deque: once ``max_samples`` is reached the
    oldest entries are evicted as new ones arrive. A background producer
    thread continuously fetches new samples while the training loop reads
    random batches from the cache.

    Args:
        tokenizer: HuggingFace tokenizer
        max_samples: Maximum cache size (default: 50000)
        batch_size: Batch size for API calls (default: 4)
        max_length: Maximum sequence length (default: 64)
        sample_positions: Which positions to sample for teacher
            (default: last 8 positions of the sequence)
        temperature: Distillation temperature (default: 2.0)
        providers: Which providers to use
            (default: ["gemini", "cloudflare", "opencode"])
        concurrent_batches: How many teacher requests the background
            producer keeps in flight at once (default: 8)
        max_consecutive_errors: Number of back-to-back failed fetches after
            which a RuntimeError is raised (default: 5)
    """

    def __init__(
        self,
        tokenizer,
        max_samples: int = 50000,
        batch_size: int = 4,
        max_length: int = 64,
        sample_positions: Optional[list[int]] = None,
        temperature: float = 2.0,
        providers: Optional[list[str]] = None,
        concurrent_batches: int = 8,
        max_consecutive_errors: int = 5,
    ):
        self.tokenizer = tokenizer
        self.vocab_size = len(tokenizer)
        self.max_samples = max_samples
        self.batch_size = batch_size
        self.max_length = max_length
        # Clamp at 0 so a max_length below 8 does not produce negative
        # (wrap-around) positions in the default.
        self.sample_positions = sample_positions or list(
            range(max(0, max_length - 8), max_length)
        )
        self.temperature = temperature
        self.concurrent_batches = concurrent_batches
        self.max_consecutive_errors = max_consecutive_errors

        # Skip OpenRouter (rate limited on free tier)
        self.providers = providers or ["gemini", "cloudflare", "opencode"]

        # Cache storage; every mutation and multi-step read goes through
        # cache_lock because producer and consumer run on different threads.
        self.cache: deque[CachedSample] = deque(maxlen=max_samples)
        self.cache_lock = threading.Lock()

        # Producer control
        self._stop_event = threading.Event()
        self._producer_thread: Optional[threading.Thread] = None
        self._samples_produced = 0
        self._samples_consumed = 0

        # Data source (created lazily by warmup() or the producer loop)
        self._data_iter: Optional[Iterator] = None

        # Stats
        self.stats = {
            "produced": 0,
            "consumed": 0,
            "cache_size": 0,
            "producer_errors": 0,
        }
        self._consecutive_errors = 0

    def _create_data_iterator(self) -> Iterator:
        """Create streaming iterator from OpenWebText.

        Yields:
            ``(input_ids, labels)`` tensor pairs of length ``max_length``.
            ``labels`` is ``input_ids`` shifted left by one with the pad id
            appended.
        """
        dataset = load_dataset("openwebtext", split="train", streaming=True)
        pad_id = self.tokenizer.eos_token_id or 0
        # The encoder truncates to max_length, so a fixed floor of 50 tokens
        # would reject every sample whenever max_length < 50.
        min_tokens = min(50, self.max_length)

        for sample in dataset:
            text = sample["text"]
            if len(text) < 100:
                continue

            tokens = self.tokenizer.encode(
                text, max_length=self.max_length, truncation=True
            )
            if len(tokens) < min_tokens:
                continue

            # Pad or truncate to exactly max_length
            if len(tokens) < self.max_length:
                tokens = tokens + [pad_id] * (self.max_length - len(tokens))
            else:
                tokens = tokens[: self.max_length]

            input_ids = torch.tensor(tokens)
            labels = torch.tensor(tokens[1:] + [pad_id])
            yield input_ids, labels

    async def _fetch_batch_async(
        self,
        teacher: MasterTeacher,
        batch_inputs: list[torch.Tensor],
        batch_labels: list[torch.Tensor],
    ) -> list[CachedSample]:
        """Fetch teacher logits for a batch asynchronously.

        Args:
            teacher: Teacher used to score the batch.
            batch_inputs: Per-sample input id tensors (stacked into one batch).
            batch_labels: Per-sample label tensors, aligned with batch_inputs.

        Returns:
            One CachedSample per input, or an empty list if the fetch failed.

        Raises:
            RuntimeError: When ``max_consecutive_errors`` fetches in a row
                have failed; the last failure is chained as the cause.
        """
        x = torch.stack(batch_inputs)

        try:
            teacher_logits = await teacher.get_teacher_logits_for_batch_async(
                x,
                sample_positions=self.sample_positions,
                temperature=self.temperature,
            )
            samples = [
                CachedSample(
                    input_ids=input_ids,
                    labels=labels,
                    teacher_logits=teacher_logits[row],
                )
                for row, (input_ids, labels) in enumerate(
                    zip(batch_inputs, batch_labels)
                )
            ]
        except Exception as e:
            self.stats["producer_errors"] += 1
            self._consecutive_errors += 1
            # Surface the failure instead of dropping it silently; a single
            # failure is tolerated, a streak is fatal.
            print(
                f"[StreamingCache] Teacher fetch failed "
                f"({self._consecutive_errors}/{self.max_consecutive_errors}): {e!r}"
            )
            if self._consecutive_errors >= self.max_consecutive_errors:
                raise RuntimeError(
                    f"StreamingTeacherCache failed {self._consecutive_errors} consecutive fetches"
                ) from e
            return []

        self._consecutive_errors = 0
        return samples

    def _next_batch(
        self,
    ) -> Optional[tuple[list[torch.Tensor], list[torch.Tensor]]]:
        """Pull one full batch from the data iterator.

        Returns:
            ``(inputs, labels)`` lists of length ``batch_size``, or None when
            the iterator ran out first. In that case the partial batch is
            dropped and a fresh iterator is installed for the next call.
        """
        inputs: list[torch.Tensor] = []
        labels: list[torch.Tensor] = []
        try:
            for _ in range(self.batch_size):
                input_ids, label_ids = next(self._data_iter)
                inputs.append(input_ids)
                labels.append(label_ids)
        except StopIteration:
            # Restart iterator
            self._data_iter = self._create_data_iterator()
            return None
        return inputs, labels

    async def _producer_loop_async(self, teacher: MasterTeacher):
        """Async producer loop that fetches samples until stop is requested.

        Each round gathers up to ``concurrent_batches`` batches and fetches
        their teacher logits concurrently, then appends all successful
        results to the cache under the lock.

        Raises:
            RuntimeError: Propagated from _fetch_batch_async once too many
                consecutive fetches have failed (after the successful results
                of the same round have been stored).
        """
        if self._data_iter is None:
            self._data_iter = self._create_data_iterator()

        # Always make progress, even if concurrent_batches is set below 1.
        fan_out = max(1, self.concurrent_batches)

        while not self._stop_event.is_set():
            # Collect this round's batches
            pending = []
            for _ in range(fan_out):
                batch = self._next_batch()
                if batch is None:
                    break
                pending.append(batch)

            if not pending:
                # Source was exhausted (or is empty): yield to the loop
                # instead of spinning while the iterator restarts.
                await asyncio.sleep(0.001)
                continue

            # Fetch concurrently. Exceptions are collected rather than raised
            # immediately so no sibling request is left dangling.
            results = await asyncio.gather(
                *(
                    self._fetch_batch_async(teacher, inputs, labels)
                    for inputs, labels in pending
                ),
                return_exceptions=True,
            )

            fetched: list[CachedSample] = []
            failure: Optional[BaseException] = None
            for result in results:
                if isinstance(result, BaseException):
                    if failure is None:
                        failure = result
                else:
                    fetched.extend(result)

            # Add to cache
            if fetched:
                with self.cache_lock:
                    self.cache.extend(fetched)
                    self.stats["produced"] += len(fetched)
                    self.stats["cache_size"] = len(self.cache)

            if failure is not None:
                raise failure

            # Brief yield to allow other tasks
            await asyncio.sleep(0.001)

    def _producer_thread_main(self):
        """Thread entry point for producer.

        Builds a teacher and drives the async producer loop on an event loop
        owned by this thread; the loop is always closed on exit.
        """
        teacher = create_master_teacher(
            self.tokenizer,
            providers=self.providers,
        )

        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        try:
            loop.run_until_complete(self._producer_loop_async(teacher))
        finally:
            loop.close()

    async def warmup(self, num_samples: int = 10000) -> None:
        """
        Pre-cache initial samples before training starts.

        Batches are fetched one after another until ``num_samples`` samples
        have been cached or the data source is exhausted. Progress is printed
        every 50 batches.

        Args:
            num_samples: Number of samples to pre-cache (default: 10000)

        Raises:
            RuntimeError: Propagated from _fetch_batch_async after too many
                consecutive failed fetches.
        """
        print(f"[StreamingCache] Warming up with {num_samples} samples...")
        start = time.time()

        teacher = create_master_teacher(
            self.tokenizer,
            providers=self.providers,
        )

        self._data_iter = self._create_data_iterator()

        samples_fetched = 0
        batch_count = 0

        while samples_fetched < num_samples:
            # Collect batch (a short final batch is still used)
            batch_inputs = []
            batch_labels = []
            for _ in range(self.batch_size):
                try:
                    input_ids, labels = next(self._data_iter)
                except StopIteration:
                    break
                batch_inputs.append(input_ids)
                batch_labels.append(labels)

            if not batch_inputs:
                break

            # Fetch teacher logits
            samples = await self._fetch_batch_async(teacher, batch_inputs, batch_labels)

            # Add to cache under the lock, same as the background producer.
            with self.cache_lock:
                self.cache.extend(samples)
            samples_fetched += len(samples)

            batch_count += 1
            if batch_count % 50 == 0:
                # Guard against a zero interval on very fast (e.g. mocked) runs.
                elapsed = max(time.time() - start, 1e-9)
                rate = samples_fetched / elapsed * 60
                print(f"  [{samples_fetched}/{num_samples}] {rate:.1f} samples/min")

        elapsed = time.time() - start
        with self.cache_lock:
            self.stats["produced"] = len(self.cache)
            self.stats["cache_size"] = len(self.cache)

        print(
            f"[StreamingCache] Warmup complete: {len(self.cache)} samples in {elapsed:.1f}s"
        )
        print(f"  Provider stats: {teacher.stats()}")

    def start_background_fetching(self) -> None:
        """Start the background producer thread.

        If a producer is already alive this only clears the stop flag (so a
        producer that was asked to stop but has not exited yet keeps
        running). A producer that has died is replaced by a new thread.
        """
        thread = self._producer_thread
        if thread is not None and thread.is_alive():
            self._stop_event.clear()
            return

        self._stop_event.clear()
        self._producer_thread = threading.Thread(
            target=self._producer_thread_main,
            daemon=True,
            name="TeacherCacheProducer",
        )
        self._producer_thread.start()
        print("[StreamingCache] Background producer started")

    def stop_background_fetching(self) -> None:
        """Stop the background producer thread.

        Signals the producer and waits up to 5 seconds for it to exit. If it
        is still busy with an in-flight fetch after that, the thread
        reference is kept so a second producer cannot be started alongside
        it; the thread exits on its own once the current round finishes.
        """
        thread = self._producer_thread
        if thread is None:
            return

        self._stop_event.set()
        thread.join(timeout=5.0)
        if thread.is_alive():
            print(
                "[StreamingCache] Background producer still finishing an "
                "in-flight fetch; it will exit afterwards"
            )
            return

        self._producer_thread = None
        print("[StreamingCache] Background producer stopped")

    def get_batch(self, batch_size: int) -> Optional[dict[str, torch.Tensor]]:
        """
        Get a batch of samples from the cache.

        Samples are drawn at random without replacement within one batch and
        are not removed, so they can be reused by later batches.

        Args:
            batch_size: Number of samples to get

        Returns:
            Dict with input_ids, labels, teacher_logits (each stacked along a
            new leading dimension), or None if the cache holds fewer than
            ``batch_size`` samples
        """
        with self.cache_lock:
            available = len(self.cache)
            if available < batch_size:
                return None

            # Plain ints avoid indexing the deque with 0-dim tensors.
            picks = torch.randperm(available)[:batch_size].tolist()
            samples = [self.cache[i] for i in picks]
            self.stats["consumed"] += batch_size

        # Stack outside the lock so the producer is not blocked on tensor work.
        return {
            "input_ids": torch.stack([s.input_ids for s in samples]),
            "labels": torch.stack([s.labels for s in samples]),
            "teacher_logits": torch.stack([s.teacher_logits for s in samples]),
        }

    def __len__(self) -> int:
        """Current cache size."""
        return len(self.cache)

    def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
        """Iterate over a snapshot of the cache, one sample dict at a time.

        The snapshot is taken under the lock and iterated outside it, so the
        producer is not blocked while the caller consumes the generator and
        calling get_batch() during iteration cannot deadlock.
        """
        with self.cache_lock:
            snapshot = list(self.cache)

        for sample in snapshot:
            yield {
                "input_ids": sample.input_ids,
                "labels": sample.labels,
                "teacher_logits": sample.teacher_logits,
            }

    def get_stats(self) -> dict[str, Any]:
        """Get cache statistics.

        Returns:
            The counters in ``self.stats`` plus the live ``cache_size`` and a
            ``producer_running`` flag.
        """
        thread = self._producer_thread
        return {
            **self.stats,
            "cache_size": len(self.cache),
            "producer_running": thread is not None and thread.is_alive(),
        }


class StreamingCacheDataset(torch.utils.data.IterableDataset):
    """
    PyTorch IterableDataset wrapper for StreamingTeacherCache.

    This allows the cache to be used with PyTorch DataLoader.

    Args:
        cache: StreamingTeacherCache instance
        batch_size: Batch size for training
        min_cache_size: Minimum samples in cache before yielding (default: 1000)
    """

    def __init__(
        self,
        cache: StreamingTeacherCache,
        batch_size: int = 16,
        min_cache_size: int = 1000,
    ):
        self.cache = cache
        self.batch_size = batch_size
        self.min_cache_size = min_cache_size

    def __iter__(self):
        """Yield batches from cache indefinitely, waiting if needed.

        Sleeps in 0.1s steps while the cache is below the required size or
        cannot serve a full batch yet.
        """
        # The cache can never hold more than max_samples, so a larger
        # min_cache_size would otherwise wait forever.
        required = min(self.min_cache_size, self.cache.max_samples)

        while True:
            if len(self.cache) >= required:
                # Get random samples from cache
                batch = self.cache.get_batch(self.batch_size)
                if batch is not None:
                    yield batch
                    continue
            # Not enough data yet: back off instead of busy-spinning.
            time.sleep(0.1)


def main() -> None:
    """Smoke-test the streaming cache: warm up, then sample for 10 seconds."""
    print("Testing StreamingTeacherCache...")

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

    cache = StreamingTeacherCache(
        tokenizer=tokenizer,
        max_samples=1000,
        batch_size=4,
        max_length=64,
        providers=["gemini", "cloudflare", "opencode"],  # Skip rate-limited openrouter
    )

    # Warmup with small amount
    asyncio.run(cache.warmup(num_samples=100))

    print(f"\nCache stats: {cache.get_stats()}")

    # Start background producer
    cache.start_background_fetching()

    try:
        # Simulate training loop
        print("\nSimulating training for 10 seconds...")
        start = time.time()
        step = 0
        while time.time() - start < 10:
            batch = cache.get_batch(batch_size=8)
            if batch is not None:
                step += 1
                if step % 10 == 0:
                    stats = cache.get_stats()
                    print(
                        f"Step {step}: cache_size={stats['cache_size']}, produced={stats['produced']}"
                    )
            time.sleep(0.1)
    finally:
        # Always signal the producer, even if the loop above raised.
        cache.stop_background_fetching()

    print(f"\nFinal stats: {cache.get_stats()}")
    print("✅ StreamingTeacherCache test complete!")


if __name__ == "__main__":
    main()