OpenTransformer committed on
Commit
93b6ddd
·
verified ·
1 Parent(s): 9429964

Backup script stream_loader.py

Browse files
Files changed (1) hide show
  1. scripts/stream_loader.py +98 -0
scripts/stream_loader.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Custom streaming data loader for AGILLM training
4
+ Pulls from stream_server on scraper box via HTTP
5
+ Drop-in replacement for HuggingFace dataset streaming
6
+ """
7
+ import requests
8
+ import json
9
+ from typing import Iterator, Dict, Any
10
+
11
class ScraperStreamDataset:
    """
    Streams training data from the scraper server.
    Compatible with AGILLM's _stream() interface.

    Records are fetched lazily over HTTP in batches and buffered locally;
    each ``__next__`` pops one ``{text_field: str}`` dict from the buffer.
    """

    def __init__(
        self,
        server_url: str = "http://localhost:8888",  # Will be SSH tunneled
        batch_size: int = 100,
        text_field: str = "text",
        shuffle: bool = True,
    ):
        """
        Args:
            server_url: Base URL of the stream server.
            batch_size: Number of records requested per HTTP fetch.
            text_field: Key under which each record's text is yielded.
            shuffle: If True, pull from the /stream (shuffled) endpoint,
                otherwise from /sequential.
        """
        self.server_url = server_url
        self.batch_size = batch_size
        self.text_field = text_field
        self.shuffle = shuffle
        # Local FIFO of records already fetched but not yet yielded.
        self._buffer: list = []

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        return self

    def __next__(self) -> Dict[str, Any]:
        # Refill lazily; if a refill produces nothing, the stream ends.
        if not self._buffer:
            self._fetch_batch()
        if not self._buffer:
            raise StopIteration
        return self._buffer.pop(0)

    def _fetch_batch(self):
        """Fetch one batch from the stream server into self._buffer.

        Network errors are logged and swallowed (best-effort refill);
        malformed JSON lines are skipped silently.
        """
        endpoint = "/stream" if self.shuffle else "/sequential"
        try:
            # FIX: with stream=True the connection stays open until the
            # response is closed; the original never closed it, leaking a
            # socket per fetch. Use the response as a context manager.
            with requests.get(
                f"{self.server_url}{endpoint}",
                params={"batch": self.batch_size},
                stream=True,
                timeout=30,
            ) as resp:
                for line in resp.iter_lines():
                    if not line:
                        continue
                    try:
                        obj = json.loads(line.decode('utf-8'))
                    except json.JSONDecodeError:
                        continue  # skip malformed lines
                    # Return in format trainer expects
                    self._buffer.append({self.text_field: obj.get("text", "")})
        except requests.RequestException as e:
            print(f"[StreamLoader] Fetch error: {e}")

    def get_status(self) -> dict:
        """Get server status; returns {"error": "unreachable"} on failure.

        FIX: the original used a bare ``except:`` which also swallowed
        KeyboardInterrupt/SystemExit. Narrowed to request failures plus
        ValueError (raised by resp.json() on a non-JSON payload).
        """
        try:
            resp = requests.get(f"{self.server_url}/status", timeout=10)
            return resp.json()
        except (requests.RequestException, ValueError):
            return {"error": "unreachable"}
67
+
68
+
69
def create_stream_iterator(server_url: str = "http://localhost:8888", seed: int = 42):
    """
    Create an infinite iterator compatible with AGILLM's _stream() function.
    Yields {"text": "..."} dicts forever.

    FIX: the original ``except StopIteration`` branch duplicated the refill
    logic that ``ScraperStreamDataset.__next__`` already performs, and when
    the server was empty/unreachable the loop spun at full speed (refetch ->
    StopIteration -> refetch ...), hammering the server. A refill is now
    attempted simply by retrying ``next()``, with a short backoff between
    empty attempts.

    Args:
        server_url: Base URL of the stream server.
        seed: Accepted for interface compatibility; currently unused because
            shuffling happens server-side. TODO: wire through if client-side
            shuffling is ever added.
    """
    import time  # local import: only needed for the retry backoff

    dataset = ScraperStreamDataset(server_url=server_url)
    while True:
        try:
            # __next__ refills the buffer on demand; StopIteration means a
            # refill was attempted and produced nothing.
            yield next(dataset)
        except StopIteration:
            time.sleep(1.0)  # back off before retrying an empty server
83
+
84
+
85
# For testing
if __name__ == "__main__":
    import sys

    # Optional first CLI argument overrides the default server URL.
    url = sys.argv[1] if len(sys.argv) > 1 else "http://localhost:8888"
    print(f"Testing stream from {url}")

    dataset = ScraperStreamDataset(server_url=url, batch_size=5)
    print(f"Status: {dataset.get_status()}")

    # Pull and display the first five samples, then stop.
    for idx, sample in enumerate(dataset):
        text = sample["text"]
        print(f"Sample {idx}: {len(text)} chars - {text[:100]}...")
        if idx >= 4:
            break