# AGILLM-3-backup / stream_loader.py
# Uploaded with huggingface_hub by OpenTransformer (rev 59bcafe, verified)
#!/usr/bin/env python3
"""
Custom streaming data loader for AGILLM training
Pulls from stream_server on scraper box via HTTP
Drop-in replacement for HuggingFace dataset streaming
"""
import json
import time
from typing import Any, Dict, Iterator

import requests
class ScraperStreamDataset:
    """
    Streams training data from the scraper server.
    Compatible with AGILLM's _stream() interface.

    Iterating the dataset yields ``{text_field: "..."}`` dicts, fetched
    lazily in batches of ``batch_size`` from the HTTP stream server.
    Iteration stops (StopIteration) when a fetch returns no samples.
    """

    def __init__(
        self,
        server_url: str = "http://localhost:8888",  # Will be SSH tunneled
        batch_size: int = 100,
        text_field: str = "text",
        shuffle: bool = True
    ):
        # server_url: base URL of the stream server (no trailing slash expected)
        self.server_url = server_url
        # batch_size: number of samples requested per HTTP fetch
        self.batch_size = batch_size
        # text_field: key under which the sample text is exposed to the trainer
        self.text_field = text_field
        # shuffle: True -> /stream endpoint, False -> /sequential endpoint
        self.shuffle = shuffle
        # Local FIFO of already-fetched samples, drained by __next__.
        self._buffer: list = []

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        return self

    def __next__(self) -> Dict[str, Any]:
        """Return the next sample, fetching a fresh batch when the buffer runs dry.

        Raises:
            StopIteration: if a fetch yields no samples (server empty/unreachable).
        """
        if not self._buffer:
            self._fetch_batch()
        if not self._buffer:
            raise StopIteration
        return self._buffer.pop(0)

    def _fetch_batch(self):
        """Fetch a batch from stream server into self._buffer.

        Network and HTTP errors are logged and swallowed, leaving the
        buffer unchanged; malformed JSON lines are skipped.
        """
        endpoint = "/stream" if self.shuffle else "/sequential"
        try:
            resp = requests.get(
                f"{self.server_url}{endpoint}",
                params={"batch": self.batch_size},
                stream=True,
                timeout=30
            )
            # Don't parse an error page (404/500 body) as training data.
            resp.raise_for_status()
            for line in resp.iter_lines():
                if line:
                    try:
                        obj = json.loads(line.decode('utf-8'))
                        # Return in format trainer expects
                        self._buffer.append({self.text_field: obj.get("text", "")})
                    except json.JSONDecodeError:
                        continue
        except requests.RequestException as e:
            print(f"[StreamLoader] Fetch error: {e}")

    def get_status(self) -> dict:
        """Get server status as a dict; {"error": "unreachable"} on failure."""
        try:
            resp = requests.get(f"{self.server_url}/status", timeout=10)
            return resp.json()
        # Narrowed from a bare `except:` (which also swallowed
        # KeyboardInterrupt/SystemExit). ValueError covers a non-JSON body.
        except (requests.RequestException, ValueError):
            return {"error": "unreachable"}
def create_stream_iterator(server_url: str = "http://localhost:8888", seed: int = 42):
    """
    Create iterator compatible with AGILLM's _stream() function.
    Returns infinite iterator of {"text": "..."} dicts.

    NOTE(review): ``seed`` is accepted for interface compatibility but is
    currently unused — shuffling appears to be done server-side; confirm.
    """
    dataset = ScraperStreamDataset(server_url=server_url)
    while True:
        try:
            # __next__ refills the buffer itself, so no explicit refill is
            # needed here (the original double-fetched on every miss).
            yield next(dataset)
        except StopIteration:
            # Server empty or unreachable: back off briefly instead of
            # hammering it in a tight busy-loop, then retry forever.
            time.sleep(1.0)
# For testing
if __name__ == "__main__":
    import sys

    # Optional CLI arg overrides the default tunnel endpoint.
    url = "http://localhost:8888"
    if len(sys.argv) > 1:
        url = sys.argv[1]
    print(f"Testing stream from {url}")
    ds = ScraperStreamDataset(server_url=url, batch_size=5)
    print(f"Status: {ds.get_status()}")
    # Pull five samples and show a short preview of each.
    shown = 0
    for item in ds:
        text = item["text"]
        print(f"Sample {shown}: {len(text)} chars - {text[:100]}...")
        shown += 1
        if shown >= 5:
            break