"""
collect_training_data.py — Poll the running Soci simulation and save raw
conversation + event data for later training.
Polls every POLL_INTERVAL seconds, deduplicates by conversation ID,
and writes JSONL to data/training/raw/.
Usage:
# Poll Render deployment (default):
"C:/Users/xabon/.conda/envs/ml-env/python.exe" scripts/collect_training_data.py
# Poll a different base URL:
"C:/Users/xabon/.conda/envs/ml-env/python.exe" scripts/collect_training_data.py --url http://localhost:8000
# Run once (no loop):
"C:/Users/xabon/.conda/envs/ml-env/python.exe" scripts/collect_training_data.py --once
"""
from __future__ import annotations
import argparse
import json
import os
import time
from datetime import datetime
from pathlib import Path
import urllib.request
import urllib.error
# Default deployment polled when --url is not supplied.
BASE_URL = "https://soci-tl3c.onrender.com"
# Default poll cadence; may be rebound from --interval in the __main__ block.
POLL_INTERVAL = 30 # seconds
RAW_DIR = Path("data/training/raw")
# Side effect at import time: ensure the output directory exists.
RAW_DIR.mkdir(parents=True, exist_ok=True)
# Files to accumulate into
# NOTE: the date is fixed once at startup — a run that crosses midnight keeps
# appending to the startup day's files.
today = datetime.now().strftime("%Y%m%d")
CONV_FILE = RAW_DIR / f"conversations_{today}.jsonl"
EVENT_FILE = RAW_DIR / f"events_{today}.jsonl"
AGENT_CACHE_FILE = RAW_DIR / "agents_cache.json"
# In-memory dedup sets (also rehydrated from disk on startup)
# Conversations dedup by "id"; events dedup by the "tick|message" pair.
_seen_conv_ids: set[str] = set()
_seen_event_ticks_msgs: set[str] = set()
def fetch_json(url: str, timeout: int = 15) -> dict | None:
try:
with urllib.request.urlopen(url, timeout=timeout) as resp:
return json.loads(resp.read().decode())
except urllib.error.URLError as e:
print(f" [WARN] fetch failed: {url} — {e}")
return None
except Exception as e:
print(f" [ERR] {url}: {e}")
return None
def load_seen_ids() -> None:
    """Rehydrate dedup sets from existing JSONL files."""

    def _records(path: Path):
        # Yield each parseable JSON object in *path*; skip bad lines silently.
        if not path.exists():
            return
        with open(path, encoding="utf-8") as fh:
            for raw in fh:
                try:
                    yield json.loads(raw)
                except json.JSONDecodeError:
                    continue

    for record in _records(CONV_FILE):
        conv_id = record.get("id", "")
        if conv_id:
            _seen_conv_ids.add(conv_id)

    for record in _records(EVENT_FILE):
        _seen_event_ticks_msgs.add(
            f"{record.get('tick','')}|{record.get('message','')}"
        )

    print(f" Loaded dedup: {len(_seen_conv_ids)} convs, {len(_seen_event_ticks_msgs)} events")
def poll_conversations(base_url: str) -> int:
    """Fetch conversation history and save new ones. Returns count of new convs."""
    payload = fetch_json(f"{base_url}/api/conversations?limit=200&include_history=true")
    if payload is None:
        return 0

    saved = 0
    candidates = payload.get("active", []) + payload.get("recent", [])
    with open(CONV_FILE, "a", encoding="utf-8") as out:
        for conv in candidates:
            cid = conv.get("id", "")
            if not cid or cid in _seen_conv_ids:
                continue
            # Skip single-turn (incomplete) conversations.
            if len(conv.get("turns", [])) < 2:
                continue
            # Tag with collection metadata before persisting.
            conv["_collected_at"] = datetime.now().isoformat()
            conv["_source"] = "api"
            out.write(json.dumps(conv, ensure_ascii=False) + "\n")
            _seen_conv_ids.add(cid)
            saved += 1
    return saved
def poll_events(base_url: str) -> int:
    """Fetch recent events and save new ones. Returns count of new events."""
    payload = fetch_json(f"{base_url}/api/events?limit=500")
    if payload is None:
        return 0

    saved = 0
    with open(EVENT_FILE, "a", encoding="utf-8") as out:
        for event in payload.get("events", []):
            # Events carry no id, so dedup on the tick/message pair.
            dedup_key = f"{event.get('tick','')}|{event.get('message','')}"
            if dedup_key in _seen_event_ticks_msgs:
                continue
            event["_collected_at"] = datetime.now().isoformat()
            out.write(json.dumps(event, ensure_ascii=False) + "\n")
            _seen_event_ticks_msgs.add(dedup_key)
            saved += 1
    return saved
def refresh_agent_cache(base_url: str) -> None:
    """Refresh the local agent persona cache (done once per session)."""
    roster = fetch_json(f"{base_url}/api/agents")
    if not roster:
        return

    # Fetch full detail for named (non-generated) agents.
    # NOTE(review): iterating the roster assumes it yields agent ids
    # (a dict iterates its keys) — confirm against the API response shape.
    details = {}
    for agent_id in roster:
        info = fetch_json(f"{base_url}/api/agents/{agent_id}")
        if info:
            details[agent_id] = info
        time.sleep(0.2)  # Be gentle to the API

    AGENT_CACHE_FILE.write_text(
        json.dumps(details, indent=2, ensure_ascii=False), encoding="utf-8"
    )
    print(f" Agent cache refreshed: {len(details)} agents -> {AGENT_CACHE_FILE}")
def print_stats() -> None:
    """Print how many non-blank records each output file currently holds."""

    def _line_count(path: Path) -> int:
        # Non-existent file counts as zero records.
        if not path.exists():
            return 0
        with open(path, encoding="utf-8") as fh:
            return sum(1 for ln in fh if ln.strip())

    conv_count = _line_count(CONV_FILE)
    ev_count = _line_count(EVENT_FILE)
    print(f" Stats: {conv_count} convs, {ev_count} events saved")
def run(base_url: str, once: bool = False, skip_agent_cache: bool = False) -> None:
    """Main collection loop: poll conversations and events until interrupted.

    Args:
        base_url: Base URL of the Soci API to poll.
        once: If True, perform a single poll iteration and exit.
        skip_agent_cache: If True, skip the one-time agent persona cache refresh.

    Writes new records to CONV_FILE / EVENT_FILE and prints a per-iteration
    summary; Ctrl-C exits cleanly via KeyboardInterrupt.
    """
    # Fix: these two banner strings had no placeholders — plain literals, not f-strings.
    print("Soci Training Data Collector")
    print(f" Target: {base_url}")
    print(f" Output: {RAW_DIR.resolve()}")
    print(f" Poll interval: {POLL_INTERVAL}s")
    load_seen_ids()
    if not skip_agent_cache:
        print(" Refreshing agent cache...")
        refresh_agent_cache(base_url)
    iteration = 0
    try:
        while True:
            iteration += 1
            ts = datetime.now().strftime("%H:%M:%S")
            new_convs = poll_conversations(base_url)
            new_events = poll_events(base_url)
            print(f"[{ts}] iter={iteration} +{new_convs} convs +{new_events} events "
                  f"(total: {len(_seen_conv_ids)} convs, {len(_seen_event_ticks_msgs)} events)")
            if once:
                break
            # POLL_INTERVAL is read from the module global each iteration,
            # so the --interval override in __main__ takes effect.
            time.sleep(POLL_INTERVAL)
    except KeyboardInterrupt:
        print("\nStopped by user.")
    print_stats()
    print(f"\nConversations: {CONV_FILE}")
    print(f"Events: {EVENT_FILE}")
    print(f"Agent cache: {AGENT_CACHE_FILE}")
    print("\nNext step: run python scripts/convert_to_training_jsonl.py")
if __name__ == "__main__":
    # CLI entry point: parse flags, apply the interval override, and run.
    cli = argparse.ArgumentParser(description="Soci training data collector")
    cli.add_argument("--url", default=BASE_URL, help="Base URL of the Soci API")
    cli.add_argument("--once", action="store_true", help="Run a single poll and exit")
    cli.add_argument("--no-agent-cache", action="store_true", help="Skip agent cache refresh")
    cli.add_argument("--interval", type=int, default=POLL_INTERVAL,
                     help="Poll interval in seconds (default 30)")
    opts = cli.parse_args()
    # Rebind the module-level global so run() and its sleep pick it up.
    POLL_INTERVAL = opts.interval
    run(opts.url, once=opts.once, skip_agent_cache=opts.no_agent_cache)