File size: 3,108 Bytes
59bcafe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#!/usr/bin/env python3
"""
Custom streaming data loader for AGILLM training
Pulls from stream_server on scraper box via HTTP
Drop-in replacement for HuggingFace dataset streaming
"""
import requests
import json
from typing import Iterator, Dict, Any

class ScraperStreamDataset:
    """
    Streams training data from the scraper server.
    Compatible with AGILLM's _stream() interface.

    Iteration protocol: batches are fetched lazily over HTTP into an
    internal FIFO buffer; each __next__ call pops one sample dict of the
    form {text_field: str}. StopIteration is raised when a fetch yields
    no data (server empty or unreachable).
    """
    def __init__(
        self,
        server_url: str = "http://localhost:8888",  # Will be SSH tunneled
        batch_size: int = 100,
        text_field: str = "text",
        shuffle: bool = True
    ):
        self.server_url = server_url
        self.batch_size = batch_size
        self.text_field = text_field
        self.shuffle = shuffle
        # FIFO of pending samples; popped from the front in __next__
        self._buffer = []

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        return self

    def __next__(self) -> Dict[str, Any]:
        """Return the next sample, fetching a fresh batch when the buffer runs dry."""
        if not self._buffer:
            self._fetch_batch()
        if not self._buffer:
            # Fetch produced nothing: either the server has no data or the
            # request failed (logged by _fetch_batch). End iteration.
            raise StopIteration
        return self._buffer.pop(0)

    def _fetch_batch(self):
        """Fetch one batch of JSONL samples from the stream server into the buffer.

        Network/HTTP failures are logged and leave the buffer unchanged;
        malformed JSON lines are skipped rather than aborting the batch.
        """
        endpoint = "/stream" if self.shuffle else "/sequential"
        try:
            resp = requests.get(
                f"{self.server_url}{endpoint}",
                params={"batch": self.batch_size},
                stream=True,
                timeout=30
            )
            # Don't parse an HTTP error body as data; HTTPError is a
            # RequestException subclass, so the handler below catches it.
            resp.raise_for_status()
            for line in resp.iter_lines():
                if line:
                    try:
                        obj = json.loads(line.decode('utf-8'))
                        # Return in format trainer expects
                        self._buffer.append({self.text_field: obj.get("text", "")})
                    except json.JSONDecodeError:
                        continue
        except requests.RequestException as e:
            print(f"[StreamLoader] Fetch error: {e}")

    def get_status(self) -> dict:
        """Get server status; returns {"error": "unreachable"} on any failure."""
        try:
            resp = requests.get(f"{self.server_url}/status", timeout=10)
            return resp.json()
        except (requests.RequestException, ValueError):
            # RequestException: network/HTTP failure.
            # ValueError: body was not valid JSON (json.JSONDecodeError subclass).
            return {"error": "unreachable"}


def create_stream_iterator(server_url: str = "http://localhost:8888", seed: int = 42):
    """
    Create iterator compatible with AGILLM's _stream() function.
    Yields an infinite stream of {"text": "..."} dicts.

    Args:
        server_url: Base URL of the stream server.
        seed: Accepted for interface compatibility; currently unused.
    """
    import time  # local import: only needed for the retry backoff

    dataset = ScraperStreamDataset(server_url=server_url)
    while True:
        try:
            yield next(dataset)
        except StopIteration:
            # The buffer was empty and the refetch inside __next__ produced
            # nothing (server empty or unreachable). Back off briefly instead
            # of retrying in a tight loop that hammers the server; the next
            # loop iteration calls next(dataset), which fetches again.
            time.sleep(1.0)


# Manual smoke test: stream a handful of samples from a running server.
if __name__ == "__main__":
    import sys

    cli_args = sys.argv[1:]
    url = cli_args[0] if cli_args else "http://localhost:8888"
    print(f"Testing stream from {url}")

    ds = ScraperStreamDataset(server_url=url, batch_size=5)
    print(f"Status: {ds.get_status()}")

    max_samples = 5
    for i, item in enumerate(ds):
        text = item["text"]
        print(f"Sample {i}: {len(text)} chars - {text[:100]}...")
        if i + 1 >= max_samples:
            break