| |
| """ |
| Streaming Teacher Cache with Async Background Fetching. |
| |
| This module implements a producer-consumer pattern for knowledge distillation: |
| 1. Producer: Background async task that continuously fetches teacher logits |
| 2. Consumer: Training dataloader that reads from the cache |
| |
| The cache grows during training, allowing training to start immediately |
| with a small initial cache while more samples are fetched in background. |
| |
| Usage: |
| cache = StreamingTeacherCache(tokenizer, max_samples=50000) |
| await cache.warmup(num_samples=10000) # Pre-cache initial samples |
| cache.start_background_fetching() # Start producer |
| |
| # Training loop pulls random batches via cache.get_batch() |
| for batch in StreamingCacheDataset(cache, batch_size=16): |
| train_step(batch) |
| """ |
|
|
| import asyncio |
| import threading |
| import torch |
| import queue |
| import time |
| from dataclasses import dataclass |
| from typing import Optional, Any, Iterator |
| from collections import deque |
| from concurrent.futures import ThreadPoolExecutor |
|
|
| from datasets import load_dataset |
| from transformers import AutoTokenizer |
|
|
| from sem_v6.training.master_teacher import create_master_teacher, MasterTeacher |
|
|
|
|
@dataclass(eq=False)
class CachedSample:
    """A single cached sample with teacher logits.

    Equality is implemented by hand: the dataclass-generated ``__eq__``
    compares the fields as a tuple, which truth-tests ``tensor == tensor`` and
    raises for tensors with more than one element.
    """

    # Token ids fed to the model.
    input_ids: torch.Tensor
    # Training targets paired with ``input_ids``.
    labels: torch.Tensor
    # Teacher output stored for this sample.
    teacher_logits: torch.Tensor

    def __eq__(self, other: object) -> bool:
        """Return True when all three tensors match in shape and values."""
        if not isinstance(other, CachedSample):
            return NotImplemented
        return (
            torch.equal(self.input_ids, other.input_ids)
            and torch.equal(self.labels, other.labels)
            and torch.equal(self.teacher_logits, other.teacher_logits)
        )

    # Defining __eq__ without __hash__ leaves the class unhashable, which is
    # what the default (eq=True, non-frozen) dataclass produced as well.
|
|
|
|
class StreamingTeacherCache:
    """
    Bounded cache of teacher-logit samples with an optional background producer.

    Samples live in a ``collections.deque`` created with ``maxlen=max_samples``,
    so once the cache is full every append evicts the oldest sample.
    ``warmup`` fills the cache from the calling coroutine,
    ``start_background_fetching`` launches a daemon thread that keeps
    appending, and ``get_batch`` draws random batches without removing anything.

    Args:
        tokenizer: Tokenizer providing ``encode``, ``eos_token_id`` and ``len()``
        max_samples: Maximum cache size (default: 50000)
        batch_size: Samples per teacher request (default: 4)
        max_length: Length every sample is truncated/padded to (default: 64)
        sample_positions: Positions passed to the teacher (default: the last 8
            positions of the sequence)
        temperature: Temperature passed to the teacher (default: 2.0)
        providers: Provider names passed to ``create_master_teacher``
            (default: ["gemini", "cloudflare", "opencode"])
        concurrent_batches: Size of the producer's fetch semaphore (default: 8)
        max_consecutive_errors: Number of consecutive failed fetches at which
            a ``RuntimeError`` is raised (default: 5)
    """

    # Raw texts shorter than this many characters are skipped.
    MIN_TEXT_CHARS = 100
    # Samples with fewer tokens than this (capped at max_length) are skipped.
    MIN_TOKENS = 50

    def __init__(
        self,
        tokenizer,
        max_samples: int = 50000,
        batch_size: int = 4,
        max_length: int = 64,
        sample_positions: Optional[list[int]] = None,
        temperature: float = 2.0,
        providers: Optional[list[str]] = None,
        concurrent_batches: int = 8,
        max_consecutive_errors: int = 5,
    ):
        self.tokenizer = tokenizer
        self.vocab_size = len(tokenizer)
        self.max_samples = max_samples
        self.batch_size = batch_size
        self.max_length = max_length
        # Clamp the start at 0 so sequences shorter than 8 tokens do not
        # produce negative positions.
        self.sample_positions = sample_positions or list(
            range(max(0, max_length - 8), max_length)
        )
        self.temperature = temperature
        self.concurrent_batches = concurrent_batches
        self.max_consecutive_errors = max_consecutive_errors

        self.providers = providers or ["gemini", "cloudflare", "opencode"]

        # Bounded sample store; the lock guards it against the producer thread.
        self.cache: deque[CachedSample] = deque(maxlen=max_samples)
        self.cache_lock = threading.Lock()

        # Producer thread state.
        self._stop_event = threading.Event()
        self._producer_thread: Optional[threading.Thread] = None
        self._samples_produced = 0
        self._samples_consumed = 0

        # Lazily created dataset iterator shared by warmup and the producer.
        self._data_iter: Optional[Iterator] = None

        self.stats = {
            "produced": 0,
            "consumed": 0,
            "cache_size": 0,
            "producer_errors": 0,
        }
        self._consecutive_errors = 0

    def _encode_sample(
        self, text: str
    ) -> Optional[tuple[torch.Tensor, torch.Tensor]]:
        """
        Turn one raw text into fixed-length ``(input_ids, labels)`` tensors.

        Args:
            text: Raw document text

        Returns:
            ``(input_ids, labels)``, both of length ``max_length``, or None
            when the text or its tokenization is too short to be used.
        """
        if len(text) < self.MIN_TEXT_CHARS:
            return None

        tokens = self.tokenizer.encode(
            text, max_length=self.max_length, truncation=True
        )
        # Cap the threshold at max_length: a fixed minimum of 50 would reject
        # every sample whenever max_length < 50, since tokens are truncated.
        if len(tokens) < min(self.MIN_TOKENS, self.max_length):
            return None

        pad_id = self.tokenizer.eos_token_id or 0
        tokens = tokens[: self.max_length]
        tokens = tokens + [pad_id] * (self.max_length - len(tokens))

        input_ids = torch.tensor(tokens)
        # Labels are the inputs shifted left by one, padded with pad_id.
        labels = torch.tensor(tokens[1:] + [pad_id])
        return input_ids, labels

    def _create_data_iterator(self) -> Iterator:
        """Create streaming iterator of ``(input_ids, labels)`` from OpenWebText."""
        dataset = load_dataset("openwebtext", split="train", streaming=True)

        for sample in dataset:
            encoded = self._encode_sample(sample["text"])
            if encoded is not None:
                yield encoded

    async def _fetch_batch_async(
        self,
        teacher: "MasterTeacher",
        batch_inputs: list[torch.Tensor],
        batch_labels: list[torch.Tensor],
    ) -> "list[CachedSample]":
        """
        Fetch teacher logits for a batch asynchronously.

        Args:
            teacher: Teacher whose ``get_teacher_logits_for_batch_async`` is awaited
            batch_inputs: Per-sample input id tensors (stacked for the request)
            batch_labels: Per-sample label tensors, parallel to ``batch_inputs``

        Returns:
            One ``CachedSample`` per input, or an empty list if the fetch failed.

        Raises:
            RuntimeError: When ``max_consecutive_errors`` fetches in a row failed.
        """
        x = torch.stack(batch_inputs)

        try:
            teacher_logits = await teacher.get_teacher_logits_for_batch_async(
                x,
                sample_positions=self.sample_positions,
                temperature=self.temperature,
            )

            # Index by row so a short result surfaces as an error (handled
            # below) rather than being silently truncated.
            samples = [
                CachedSample(
                    input_ids=input_ids,
                    labels=labels,
                    teacher_logits=teacher_logits[row],
                )
                for row, (input_ids, labels) in enumerate(
                    zip(batch_inputs, batch_labels)
                )
            ]
            self._consecutive_errors = 0
            return samples

        except Exception as e:
            # Best effort: a failed batch is dropped and counted; only a run
            # of failures aborts the caller.
            self.stats["producer_errors"] += 1
            self._consecutive_errors += 1
            if self._consecutive_errors >= self.max_consecutive_errors:
                raise RuntimeError(
                    f"StreamingTeacherCache failed {self._consecutive_errors} consecutive fetches"
                ) from e
            return []

    async def _producer_loop_async(self, teacher: "MasterTeacher"):
        """
        Async producer loop that fetches samples until the stop event is set.

        Each iteration pulls ``batch_size`` samples from the data iterator,
        fetches their teacher logits and appends the results to the cache.
        When the data iterator is exhausted it is recreated and any partially
        collected batch is discarded.
        """
        if self._data_iter is None:
            self._data_iter = self._create_data_iterator()

        # NOTE(review): the loop below awaits one fetch at a time, so this
        # semaphore never has more than one holder and concurrent_batches
        # does not currently increase throughput.
        semaphore = asyncio.Semaphore(self.concurrent_batches)

        async def process_batch(batch_inputs, batch_labels):
            async with semaphore:
                return await self._fetch_batch_async(
                    teacher, batch_inputs, batch_labels
                )

        while not self._stop_event.is_set():
            batch_inputs = []
            batch_labels = []

            try:
                for _ in range(self.batch_size):
                    input_ids, labels = next(self._data_iter)
                    batch_inputs.append(input_ids)
                    batch_labels.append(labels)
            except StopIteration:
                # Dataset exhausted: start over from the beginning.
                self._data_iter = self._create_data_iterator()
                continue

            samples = await process_batch(batch_inputs, batch_labels)

            if samples:
                with self.cache_lock:
                    self.cache.extend(samples)
                    self.stats["produced"] += len(samples)
                    self.stats["cache_size"] = len(self.cache)

            # Yield control briefly between batches.
            await asyncio.sleep(0.001)

    def _producer_thread_main(self):
        """Thread entry point: run the producer loop on a private event loop."""
        teacher = create_master_teacher(
            self.tokenizer,
            providers=self.providers,
        )

        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        try:
            loop.run_until_complete(self._producer_loop_async(teacher))
        finally:
            loop.close()

    async def warmup(self, num_samples: int = 10000) -> None:
        """
        Pre-cache initial samples before training starts.

        Fetching stops once at least ``num_samples`` samples were fetched or
        the dataset iterator is exhausted.

        Args:
            num_samples: Number of samples to pre-cache (default: 10000)

        Raises:
            RuntimeError: Propagated from ``_fetch_batch_async`` after too many
                consecutive failed fetches.
        """
        print(f"[StreamingCache] Warming up with {num_samples} samples...")
        start = time.time()

        teacher = create_master_teacher(
            self.tokenizer,
            providers=self.providers,
        )

        self._data_iter = self._create_data_iterator()

        samples_fetched = 0
        batch_count = 0

        while samples_fetched < num_samples:
            batch_inputs = []
            batch_labels = []

            for _ in range(self.batch_size):
                try:
                    input_ids, labels = next(self._data_iter)
                except StopIteration:
                    break
                batch_inputs.append(input_ids)
                batch_labels.append(labels)

            if not batch_inputs:
                break

            samples = await self._fetch_batch_async(teacher, batch_inputs, batch_labels)

            if samples:
                # Same locking and accounting as the background producer, so
                # warmup never resets counters or races with a running producer.
                with self.cache_lock:
                    self.cache.extend(samples)
                    self.stats["produced"] += len(samples)
                    self.stats["cache_size"] = len(self.cache)
                samples_fetched += len(samples)

            batch_count += 1
            if batch_count % 50 == 0:
                elapsed = time.time() - start
                rate = samples_fetched / elapsed * 60 if elapsed > 0 else 0.0
                print(f"  [{samples_fetched}/{num_samples}] {rate:.1f} samples/min")

        elapsed = time.time() - start

        print(
            f"[StreamingCache] Warmup complete: {len(self.cache)} samples in {elapsed:.1f}s"
        )
        print(f"  Provider stats: {teacher.stats()}")

    def start_background_fetching(self) -> None:
        """
        Start the background producer thread.

        Does nothing while a producer thread is alive (including one that is
        still shutting down). A producer that has died is replaced.
        """
        if self._producer_thread is not None and self._producer_thread.is_alive():
            return

        self._stop_event.clear()
        self._producer_thread = threading.Thread(
            target=self._producer_thread_main,
            daemon=True,
            name="TeacherCacheProducer",
        )
        self._producer_thread.start()
        print("[StreamingCache] Background producer started")

    def stop_background_fetching(self) -> None:
        """
        Signal the producer thread to stop and wait up to 5 seconds for it.

        The thread only checks the stop event between batches, so it may
        outlive the wait; in that case the handle is kept so a second producer
        cannot be started alongside it.
        """
        if self._producer_thread is None:
            return

        self._stop_event.set()
        self._producer_thread.join(timeout=5.0)
        if self._producer_thread.is_alive():
            print("[StreamingCache] Background producer still shutting down")
            return
        self._producer_thread = None
        print("[StreamingCache] Background producer stopped")

    def get_batch(self, batch_size: int) -> Optional[dict[str, torch.Tensor]]:
        """
        Get a random batch of samples from the cache (samples stay cached).

        Args:
            batch_size: Number of samples to get (must be positive)

        Returns:
            Dict with stacked input_ids, labels, teacher_logits, or None if the
            cache holds fewer than ``batch_size`` samples

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            # A negative size would otherwise slice the permutation from the
            # end and silently return the wrong number of samples.
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        with self.cache_lock:
            available = len(self.cache)
            if available < batch_size:
                return None

            # Sample without replacement.
            picks = torch.randperm(available)[:batch_size].tolist()
            samples = [self.cache[i] for i in picks]
            self.stats["consumed"] += batch_size

        return {
            "input_ids": torch.stack([s.input_ids for s in samples]),
            "labels": torch.stack([s.labels for s in samples]),
            "teacher_logits": torch.stack([s.teacher_logits for s in samples]),
        }

    def __len__(self) -> int:
        """Current cache size."""
        return len(self.cache)

    def __iter__(self):
        """
        Iterate over the samples cached at the moment iteration starts.

        Yields one dict per sample (not per batch). A snapshot is taken under
        the lock so the lock is not held while the caller processes items.
        """
        with self.cache_lock:
            snapshot = list(self.cache)

        for sample in snapshot:
            yield {
                "input_ids": sample.input_ids,
                "labels": sample.labels,
                "teacher_logits": sample.teacher_logits,
            }

    def get_stats(self) -> dict[str, Any]:
        """Get cache statistics, including live cache size and producer state."""
        return {
            **self.stats,
            "cache_size": len(self.cache),
            "producer_running": self._producer_thread is not None
            and self._producer_thread.is_alive(),
        }
|
|
|
|
class StreamingCacheDataset(torch.utils.data.IterableDataset):
    """
    PyTorch IterableDataset wrapper for StreamingTeacherCache.

    This allows the cache to be used with PyTorch DataLoader. Iteration never
    ends: it yields whatever ``cache.get_batch(batch_size)`` returns, sleeping
    while the cache is too small to serve a batch.

    Args:
        cache: StreamingTeacherCache instance
        batch_size: Batch size for training (default: 16)
        min_cache_size: Minimum samples in cache before yielding (default: 1000)
        poll_interval: Seconds to sleep between checks while waiting for the
            cache to fill (default: 0.1)
    """

    def __init__(
        self,
        cache: "StreamingTeacherCache",
        batch_size: int = 16,
        min_cache_size: int = 1000,
        poll_interval: float = 0.1,
    ):
        super().__init__()
        self.cache = cache
        self.batch_size = batch_size
        self.min_cache_size = min_cache_size
        self.poll_interval = poll_interval

    def __iter__(self):
        """Yield batches from cache, waiting if needed."""
        # get_batch() returns None below batch_size samples, so wait for the
        # larger of the two thresholds instead of spinning on None results.
        ready_size = max(self.min_cache_size, self.batch_size)

        while True:
            while len(self.cache) < ready_size:
                time.sleep(self.poll_interval)

            batch = self.cache.get_batch(self.batch_size)
            if batch is None:
                # Cache could not serve the batch after all; back off rather
                # than retrying in a tight loop.
                time.sleep(self.poll_interval)
                continue
            yield batch
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: warm the cache, run the background producer for ten
    # seconds while drawing batches, then report statistics. Needs access to
    # the tokenizer, the dataset and the teacher providers.
    print("Testing StreamingTeacherCache...")

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

    cache = StreamingTeacherCache(
        tokenizer=tokenizer,
        max_samples=1000,
        batch_size=4,
        max_length=64,
        providers=["gemini", "cloudflare", "opencode"],
    )

    # Pre-fill the cache before the producer thread takes over.
    asyncio.run(cache.warmup(num_samples=100))

    print(f"\nCache stats: {cache.get_stats()}")

    cache.start_background_fetching()

    # Stop the producer even if the simulated loop raises or is interrupted.
    try:
        print("\nSimulating training for 10 seconds...")
        start = time.time()
        step = 0

        while time.time() - start < 10:
            batch = cache.get_batch(batch_size=8)
            if batch is not None:
                step += 1
                if step % 10 == 0:
                    stats = cache.get_stats()
                    print(
                        f"Step {step}: cache_size={stats['cache_size']}, produced={stats['produced']}"
                    )
            time.sleep(0.1)
    finally:
        cache.stop_background_fetching()

    print(f"\nFinal stats: {cache.get_stats()}")
    print("✅ StreamingTeacherCache test complete!")
|
|