#!/usr/bin/env python3
"""
Custom streaming data loader for AGILLM training.

Pulls from stream_server on scraper box via HTTP.
Drop-in replacement for HuggingFace dataset streaming.
"""
import json
import time
from typing import Any, Dict, Iterator

import requests


class ScraperStreamDataset:
    """
    Streams training data from the scraper server.
    Compatible with AGILLM's _stream() interface.
    """

    def __init__(
        self,
        server_url: str = "http://localhost:8888",  # Will be SSH tunneled
        batch_size: int = 100,
        text_field: str = "text",
        shuffle: bool = True,
    ):
        """
        Args:
            server_url: Base URL of the stream server.
            batch_size: Number of samples requested per HTTP fetch.
            text_field: Key under which each sample's text is returned.
            shuffle: If True, fetch from /stream (shuffled); otherwise /sequential.
        """
        self.server_url = server_url
        self.batch_size = batch_size
        self.text_field = text_field
        self.shuffle = shuffle
        self._buffer: list = []  # FIFO of {text_field: str} dicts awaiting consumption

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        return self

    def __next__(self) -> Dict[str, Any]:
        # Refill lazily; if the server yields nothing, the stream is
        # exhausted (or unreachable) for now.
        if not self._buffer:
            self._fetch_batch()
        if not self._buffer:
            raise StopIteration
        return self._buffer.pop(0)

    def _fetch_batch(self) -> None:
        """Fetch one batch of JSONL samples from the stream server into the buffer.

        Network/HTTP failures are logged and leave the buffer unchanged.
        """
        endpoint = "/stream" if self.shuffle else "/sequential"
        try:
            resp = requests.get(
                f"{self.server_url}{endpoint}",
                params={"batch": self.batch_size},
                stream=True,
                timeout=30,
            )
            # Surface HTTP errors (4xx/5xx) instead of trying to parse an
            # error page as JSONL.
            resp.raise_for_status()
            for line in resp.iter_lines():
                if not line:
                    continue
                try:
                    obj = json.loads(line.decode("utf-8"))
                except json.JSONDecodeError:
                    # Skip malformed lines rather than aborting the batch.
                    continue
                # Return in format trainer expects
                self._buffer.append({self.text_field: obj.get("text", "")})
        except requests.RequestException as e:
            print(f"[StreamLoader] Fetch error: {e}")

    def get_status(self) -> dict:
        """Get server status; returns {"error": "unreachable"} on any failure."""
        try:
            resp = requests.get(f"{self.server_url}/status", timeout=10)
            return resp.json()
        except (requests.RequestException, ValueError):
            # ValueError covers a non-JSON response body; narrowed from a
            # bare `except:` which also swallowed KeyboardInterrupt/SystemExit.
            return {"error": "unreachable"}


def create_stream_iterator(server_url: str = "http://localhost:8888", seed: int = 42):
    """
    Create iterator compatible with AGILLM's _stream() function.
    Returns infinite iterator of {"text": "..."} dicts.

    NOTE: `seed` is accepted for interface compatibility only; shuffling is
    done server-side and the value is currently unused.
    """
    dataset = ScraperStreamDataset(server_url=server_url)
    while True:
        try:
            yield next(dataset)
        except StopIteration:
            # The buffer drained and a refill produced nothing. Back off
            # briefly so an unreachable server doesn't turn this infinite
            # loop into a hot busy-loop of failing HTTP requests.
            time.sleep(1.0)
            # Refill and continue
            dataset._fetch_batch()
            if dataset._buffer:
                yield dataset._buffer.pop(0)


# For testing
if __name__ == "__main__":
    import sys

    url = sys.argv[1] if len(sys.argv) > 1 else "http://localhost:8888"
    print(f"Testing stream from {url}")
    ds = ScraperStreamDataset(server_url=url, batch_size=5)
    print(f"Status: {ds.get_status()}")
    for i, item in enumerate(ds):
        text = item["text"]
        print(f"Sample {i}: {len(text)} chars - {text[:100]}...")
        if i >= 4:
            break