"""
Custom streaming data loader for AGILLM training.

Pulls from stream_server on the scraper box via HTTP.
Drop-in replacement for HuggingFace dataset streaming.
"""

import json
import time
from typing import Any, Dict, Iterator

import requests

class ScraperStreamDataset:
    """
    Streams training data from the scraper server.
    Compatible with AGILLM's _stream() interface.
    """

    def __init__(
        self,
        server_url: str = "http://localhost:8888",
        batch_size: int = 100,
        text_field: str = "text",
        shuffle: bool = True,
    ):
        self.server_url = server_url
        self.batch_size = batch_size
        self.text_field = text_field  # key under which samples are emitted
        self.shuffle = shuffle        # shuffled vs. sequential server endpoint
        self._buffer = []             # samples fetched but not yet consumed

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        return self

    def __next__(self) -> Dict[str, Any]:
        # Refill the buffer lazily; if the server returned nothing, stop.
        if not self._buffer:
            self._fetch_batch()
        if not self._buffer:
            raise StopIteration
        return self._buffer.pop(0)

    def _fetch_batch(self) -> None:
        """Fetch one batch of NDJSON lines from the stream server."""
        endpoint = "/stream" if self.shuffle else "/sequential"
        try:
            resp = requests.get(
                f"{self.server_url}{endpoint}",
                params={"batch": self.batch_size},
                stream=True,
                timeout=30,
            )
            resp.raise_for_status()
            for line in resp.iter_lines():
                if not line:
                    continue
                try:
                    obj = json.loads(line.decode("utf-8"))
                    self._buffer.append({self.text_field: obj.get("text", "")})
                except json.JSONDecodeError:
                    # Skip malformed lines rather than aborting the batch.
                    continue
        except requests.RequestException as e:
            print(f"[StreamLoader] Fetch error: {e}")

    def get_status(self) -> dict:
        """Get server status, or an error marker if the server is unreachable."""
        try:
            resp = requests.get(f"{self.server_url}/status", timeout=10)
            return resp.json()
        except (requests.RequestException, ValueError):
            return {"error": "unreachable"}


def create_stream_iterator(server_url: str = "http://localhost:8888", seed: int = 42):
    """
    Create an iterator compatible with AGILLM's _stream() function.
    Returns an infinite iterator of {"text": "..."} dicts.

    `seed` is accepted for interface compatibility but is currently unused;
    shuffling happens server-side.
    """
    dataset = ScraperStreamDataset(server_url=server_url)
    while True:
        try:
            yield next(dataset)
        except StopIteration:
            # Buffer ran dry (server empty or unreachable); back off briefly
            # before retrying so the loop doesn't hammer the server.
            time.sleep(1.0)
            dataset._fetch_batch()
            if dataset._buffer:
                yield dataset._buffer.pop(0)
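
# Example of bounded consumption of the infinite stream (a sketch; the
# tokenizer/trainer side is assumed from the surrounding project, not
# defined in this module):
#
#   from itertools import islice
#   for sample in islice(create_stream_iterator("http://localhost:8888"), 1000):
#       ...  # feed sample["text"] into the training pipeline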

if __name__ == "__main__":
    import sys

    url = sys.argv[1] if len(sys.argv) > 1 else "http://localhost:8888"
    print(f"Testing stream from {url}")

    ds = ScraperStreamDataset(server_url=url, batch_size=5)
    print(f"Status: {ds.get_status()}")

    # Smoke test: print the first five samples.
    for i, item in enumerate(ds):
        text = item["text"]
        print(f"Sample {i}: {len(text)} chars - {text[:100]}...")
        if i >= 4:
            break