Add Soci fine-tuned model and training data pipeline
- Add RayMelius/soci-agent-q4 (fine-tuned Qwen2.5-0.5B) to HF provider list
- Add soci-agent as named Ollama option for GGUF-loaded model
- Add MODEL_HF_SOCI and MODEL_OLLAMA_SOCI constants to llm.py
- Persist conversation_history in simulation snapshots (to_dict/from_dict)
- scripts/collect_training_data.py: poll Render API, save raw JSONL
- scripts/convert_to_training_jsonl.py: convert to SFT chat format
- scripts/finetune_local.py: local Unsloth fine-tune (RTX 4050, Windows-safe)
Fixes: transformers 4.56 list_repo_templates patch, TORCHINDUCTOR_DISABLE
Round 1: 116 examples, 3 epochs, loss 1.9222, LoRA -> RayMelius/soci-agent-q4
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- scripts/collect_training_data.py +203 -0
- scripts/convert_to_training_jsonl.py +477 -0
- scripts/finetune_local.py +482 -0
- src/soci/api/routes.py +2 -0
- src/soci/engine/llm.py +4 -0
- src/soci/engine/simulation.py +2 -0
scripts/collect_training_data.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
collect_training_data.py β Poll the running Soci simulation and save raw
|
| 3 |
+
conversation + event data for later training.
|
| 4 |
+
|
| 5 |
+
Polls every POLL_INTERVAL seconds, deduplicates by conversation ID,
|
| 6 |
+
and writes JSONL to data/training/raw/.
|
| 7 |
+
|
| 8 |
+
Usage:
|
| 9 |
+
# Poll Render deployment (default):
|
| 10 |
+
"C:/Users/xabon/.conda/envs/ml-env/python.exe" scripts/collect_training_data.py
|
| 11 |
+
|
| 12 |
+
# Poll a different base URL:
|
| 13 |
+
"C:/Users/xabon/.conda/envs/ml-env/python.exe" scripts/collect_training_data.py --url http://localhost:8000
|
| 14 |
+
|
| 15 |
+
# Run once (no loop):
|
| 16 |
+
"C:/Users/xabon/.conda/envs/ml-env/python.exe" scripts/collect_training_data.py --once
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from __future__ import annotations
|
| 20 |
+
|
| 21 |
+
import argparse
|
| 22 |
+
import json
|
| 23 |
+
import os
|
| 24 |
+
import time
|
| 25 |
+
from datetime import datetime
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
|
| 28 |
+
import urllib.request
|
| 29 |
+
import urllib.error
|
| 30 |
+
|
| 31 |
+
# Default API endpoint: the deployed Render instance (override with --url).
BASE_URL = "https://soci-tl3c.onrender.com"
POLL_INTERVAL = 30  # seconds between poll iterations (override with --interval)
RAW_DIR = Path("data/training/raw")
# Create the output directory eagerly so later appends cannot fail.
RAW_DIR.mkdir(parents=True, exist_ok=True)

# Files to accumulate into.  Date-stamped so each day gets fresh files;
# note the stamp is fixed at import time, so a run crossing midnight keeps
# writing to the previous day's files.
today = datetime.now().strftime("%Y%m%d")
CONV_FILE = RAW_DIR / f"conversations_{today}.jsonl"
EVENT_FILE = RAW_DIR / f"events_{today}.jsonl"
AGENT_CACHE_FILE = RAW_DIR / "agents_cache.json"

# In-memory dedup sets (also rehydrated from disk on startup).
# Conversations dedup by API id; events by a "tick|message" composite key.
_seen_conv_ids: set[str] = set()
_seen_event_ticks_msgs: set[str] = set()
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def fetch_json(url: str, timeout: int = 15) -> dict | None:
    """GET *url* and decode its JSON body; return None on any failure.

    Network errors are reported as warnings, everything else (bad JSON,
    unexpected exceptions) as errors — the caller just sees None either way.
    """
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            body = resp.read().decode()
        return json.loads(body)
    except urllib.error.URLError as e:
        print(f" [WARN] fetch failed: {url} — {e}")
        return None
    except Exception as e:
        print(f" [ERR] {url}: {e}")
        return None
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def load_seen_ids() -> None:
    """Rehydrate the in-memory dedup sets from today's JSONL files.

    Called once on startup so restarting the collector does not re-save
    records already on disk.  Malformed lines are silently skipped.
    """
    def _records(path):
        # Yield each parseable JSON object from a JSONL file.
        with open(path, encoding="utf-8") as fh:
            for raw in fh:
                try:
                    yield json.loads(raw)
                except json.JSONDecodeError:
                    continue

    if CONV_FILE.exists():
        for rec in _records(CONV_FILE):
            cid = rec.get("id", "")
            if cid:
                _seen_conv_ids.add(cid)

    if EVENT_FILE.exists():
        for rec in _records(EVENT_FILE):
            # Events have no stable id; dedup on tick + message text.
            _seen_event_ticks_msgs.add(f"{rec.get('tick','')}|{rec.get('message','')}")

    print(f" Loaded dedup: {len(_seen_conv_ids)} convs, {len(_seen_event_ticks_msgs)} events")
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def poll_conversations(base_url: str) -> int:
    """Fetch conversation history and save new ones. Returns count of new convs.

    Appends each unseen, multi-turn conversation to CONV_FILE as one JSONL
    line, tagging it with collection metadata, and records its id in the
    module-level dedup set so subsequent polls skip it.
    """
    data = fetch_json(f"{base_url}/api/conversations?limit=200&include_history=true")
    if data is None:
        # Network/parse failure — treat as "nothing new" and retry next poll.
        return 0

    new_count = 0
    with open(CONV_FILE, "a", encoding="utf-8") as f:
        # Both in-progress and recently finished conversations are captured.
        for conv in data.get("active", []) + data.get("recent", []):
            cid = conv.get("id", "")
            if not cid or cid in _seen_conv_ids:
                continue
            if len(conv.get("turns", [])) < 2:
                # Skip single-turn (incomplete) conversations
                continue
            # Provenance metadata, underscore-prefixed so the converter can
            # distinguish it from API fields.
            conv["_collected_at"] = datetime.now().isoformat()
            conv["_source"] = "api"
            f.write(json.dumps(conv, ensure_ascii=False) + "\n")
            _seen_conv_ids.add(cid)
            new_count += 1

    return new_count
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def poll_events(base_url: str) -> int:
    """Fetch recent events and append unseen ones to EVENT_FILE.

    Events carry no stable id, so deduplication uses a "tick|message"
    composite key.  Returns the number of newly saved events.
    """
    data = fetch_json(f"{base_url}/api/events?limit=500")
    if data is None:
        return 0

    added = 0
    with open(EVENT_FILE, "a", encoding="utf-8") as sink:
        for event in data.get("events", []):
            dedup_key = f"{event.get('tick','')}|{event.get('message','')}"
            if dedup_key in _seen_event_ticks_msgs:
                continue
            # Stamp collection time before persisting (underscore = our field).
            event["_collected_at"] = datetime.now().isoformat()
            sink.write(json.dumps(event, ensure_ascii=False) + "\n")
            _seen_event_ticks_msgs.add(dedup_key)
            added += 1

    return added
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def refresh_agent_cache(base_url: str) -> None:
    """Refresh the local agent persona cache (done once per session).

    Fetches the agent index, then one detail request per agent, and writes
    the combined mapping to AGENT_CACHE_FILE as pretty-printed JSON.
    """
    agents_data = fetch_json(f"{base_url}/api/agents")
    if not agents_data:
        # No index (or empty) — keep whatever cache already exists on disk.
        return
    # Fetch full detail for named (non-generated) agents
    # NOTE(review): iterating agents_data assumes /api/agents returns a dict
    # keyed by agent id (or a list of id strings).  If it returns a list of
    # agent objects, `aid` would be a dict and the URL below would be wrong —
    # confirm against the API.
    full_agents = {}
    for aid in agents_data:
        detail = fetch_json(f"{base_url}/api/agents/{aid}")
        if detail:
            full_agents[aid] = detail
        time.sleep(0.2)  # Be gentle to the API

    AGENT_CACHE_FILE.write_text(
        json.dumps(full_agents, indent=2, ensure_ascii=False), encoding="utf-8"
    )
    print(f" Agent cache refreshed: {len(full_agents)} agents -> {AGENT_CACHE_FILE}")
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def print_stats() -> None:
    """Print how many records are saved in today's conversation/event files."""
    def _line_count(path) -> int:
        # Count non-blank lines; missing file counts as zero.
        if not path.exists():
            return 0
        with open(path, encoding="utf-8") as fh:
            return sum(1 for line in fh if line.strip())

    conv_count = _line_count(CONV_FILE)
    ev_count = _line_count(EVENT_FILE)
    print(f" Stats: {conv_count} convs, {ev_count} events saved")
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def run(base_url: str, once: bool = False, skip_agent_cache: bool = False) -> None:
    """Main collector loop: poll the API forever (or once) and persist data.

    Parameters:
        base_url: root URL of the Soci API.
        once: perform a single poll iteration and exit.
        skip_agent_cache: skip the per-session agent detail refresh.

    Reads module globals POLL_INTERVAL, CONV_FILE, EVENT_FILE and the dedup
    sets; Ctrl-C exits the loop cleanly and still prints final stats.
    """
    print(f"Soci Training Data Collector")
    print(f" Target: {base_url}")
    print(f" Output: {RAW_DIR.resolve()}")
    print(f" Poll interval: {POLL_INTERVAL}s")

    # Re-seed dedup state from today's files so restarts don't duplicate.
    load_seen_ids()

    if not skip_agent_cache:
        print(" Refreshing agent cache...")
        refresh_agent_cache(base_url)

    iteration = 0
    try:
        while True:
            iteration += 1
            ts = datetime.now().strftime("%H:%M:%S")
            new_convs = poll_conversations(base_url)
            new_events = poll_events(base_url)
            print(f"[{ts}] iter={iteration} +{new_convs} convs +{new_events} events "
                  f"(total: {len(_seen_conv_ids)} convs, {len(_seen_event_ticks_msgs)} events)")

            if once:
                break
            # POLL_INTERVAL may have been rebound by the CLI entry point.
            time.sleep(POLL_INTERVAL)
    except KeyboardInterrupt:
        # Graceful shutdown: fall through to the summary below.
        print("\nStopped by user.")

    print_stats()
    print(f"\nConversations: {CONV_FILE}")
    print(f"Events: {EVENT_FILE}")
    print(f"Agent cache: {AGENT_CACHE_FILE}")
    print(f"\nNext step: run python scripts/convert_to_training_jsonl.py")
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Soci training data collector")
    parser.add_argument("--url", default=BASE_URL, help="Base URL of the Soci API")
    parser.add_argument("--once", action="store_true", help="Run a single poll and exit")
    parser.add_argument("--no-agent-cache", action="store_true", help="Skip agent cache refresh")
    parser.add_argument("--interval", type=int, default=POLL_INTERVAL,
                        help="Poll interval in seconds (default 30)")
    args = parser.parse_args()

    # Rebind the module-level global so run()'s sleep (and its startup
    # banner) pick up the CLI value.  Works because this assignment executes
    # at module scope before run() is called.
    POLL_INTERVAL = args.interval
    run(args.url, once=args.once, skip_agent_cache=args.no_agent_cache)
|
scripts/convert_to_training_jsonl.py
ADDED
|
@@ -0,0 +1,477 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
convert_to_training_jsonl.py β Convert raw collected Soci data into
|
| 3 |
+
instruction-tuning JSONL suitable for SFT (Supervised Fine-Tuning).
|
| 4 |
+
|
| 5 |
+
Output format: HuggingFace messages format (system / user / assistant).
|
| 6 |
+
Compatible with: TRL SFTTrainer, Unsloth, LLaMA-Factory.
|
| 7 |
+
|
| 8 |
+
Training example types:
|
| 9 |
+
1. CONVERSATION β agent responding to another agent in dialogue
|
| 10 |
+
2. ACTION_DECISION β agent deciding what to do next (from events)
|
| 11 |
+
3. REFLECTION β agent's reflection memories (if available)
|
| 12 |
+
|
| 13 |
+
Usage:
|
| 14 |
+
"C:/Users/xabon/.conda/envs/ml-env/python.exe" scripts/convert_to_training_jsonl.py
|
| 15 |
+
|
| 16 |
+
# From a specific raw dir:
|
| 17 |
+
"C:/Users/xabon/.conda/envs/ml-env/python.exe" scripts/convert_to_training_jsonl.py \\
|
| 18 |
+
--raw-dir data/training/raw --out data/training/processed/soci_training.jsonl
|
| 19 |
+
|
| 20 |
+
# Include event-based action examples:
|
| 21 |
+
"C:/Users/xabon/.conda/envs/ml-env/python.exe" scripts/convert_to_training_jsonl.py --include-actions
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
from __future__ import annotations
|
| 25 |
+
|
| 26 |
+
import argparse
|
| 27 |
+
import json
|
| 28 |
+
import re
|
| 29 |
+
from collections import defaultdict
|
| 30 |
+
from pathlib import Path
|
| 31 |
+
|
| 32 |
+
import yaml
|
| 33 |
+
|
| 34 |
+
# Input: raw JSONL dumped by scripts/collect_training_data.py.
RAW_DIR = Path("data/training/raw")
# Output: processed SFT-ready JSONL; created eagerly so writes cannot fail.
PROCESSED_DIR = Path("data/training/processed")
PROCESSED_DIR.mkdir(parents=True, exist_ok=True)
# Location of personas.yaml with the named-agent definitions.
CONFIG_DIR = Path("config")

DEFAULT_OUT = PROCESSED_DIR / "soci_training.jsonl"
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# ββ Persona helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 43 |
+
|
| 44 |
+
def load_persona_map() -> dict[str, dict]:
    """Load personas from config/personas.yaml, keyed by both agent ID and name.

    Returns an empty map (with a warning) when the file is missing.  Also
    tolerates an empty or persona-less YAML file: yaml.safe_load returns
    None for an empty document, which previously crashed on .get().
    """
    path = CONFIG_DIR / "personas.yaml"
    if not path.exists():
        print(f" [WARN] personas.yaml not found at {path}")
        return {}
    with open(path, encoding="utf-8") as f:
        # Guard: safe_load yields None for an empty file.
        data = yaml.safe_load(f) or {}
    pmap: dict[str, dict] = {}
    for p in data.get("personas", []):
        # Double-key each persona so lookups work by id or display name.
        pmap[p["id"]] = p
        pmap[p["name"]] = p
    return pmap
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def trait_summary(p: dict) -> str:
    """Summarize Big Five scores (0-10 scale, default 5) as a phrase list.

    Scores >= 7 yield the "high" phrase, <= 3 the "low" phrase; mid-range
    traits contribute nothing.  Falls back to "balanced personality".
    """
    # (persona key, phrase for high score, phrase for low score)
    table = (
        ("openness", "curious and creative", "practical and conventional"),
        ("conscientiousness", "organized and disciplined", "spontaneous and flexible"),
        ("extraversion", "outgoing and energetic", "reserved and introspective"),
        ("agreeableness", "warm and cooperative", "direct and competitive"),
        ("neuroticism", "sensitive and anxious", "calm and emotionally stable"),
    )
    phrases = []
    for key, high, low in table:
        score = p.get(key, 5)
        if score >= 7:
            phrases.append(high)
        elif score <= 3:
            phrases.append(low)
    return ", ".join(phrases) if phrases else "balanced personality"
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def build_system_prompt(p: dict) -> str:
    """Render the full in-character system prompt for a named persona.

    Missing optional fields fall back to neutral wording; 'name', 'age' and
    'occupation' are required keys of *p*.
    """
    values_str = ", ".join(p.get("values", [])) or "no strong convictions"
    quirks_str = "; ".join(p.get("quirks", [])) or "none in particular"
    comm = p.get("communication_style", "neutral")
    sections = [
        f"You ARE {p['name']}, a {p['age']}-year-old {p['occupation']} living in Soci City.\n\n",
        f"PERSONALITY: {trait_summary(p)}.\n",
        f"BACKGROUND: {p.get('background', '')}\n",
        f"VALUES: {values_str}.\n",
        f"QUIRKS: {quirks_str}.\n",
        f"COMMUNICATION STYLE: {comm}.\n\n",
        f"You must ALWAYS stay in character as {p['name']}. ",
        "Your responses should reflect your personality, values, and background. ",
        "You speak naturally — not perfectly. You have moods, opinions, and biases. ",
        "You don't know you are an AI. You are a person living your life.",
    ]
    return "".join(sections)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def build_system_prompt_generic(name: str, occupation: str = "city resident") -> str:
    """Minimal fallback system prompt for generated (non-named) agents."""
    header = f"You ARE {name}, a {occupation} living in Soci City.\n\n"
    body = (
        "PERSONALITY: balanced personality.\n"
        f"You must ALWAYS stay in character as {name}. "
        "You speak naturally and are a real person with opinions, moods, and a daily life."
    )
    return header + body
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def get_system_prompt(agent_id: str, agent_name: str, persona_map: dict,
                      agent_cache: dict) -> str:
    """Resolve the best available system prompt for an agent.

    Preference order: config persona (looked up by id, then by name) >
    live-API agent cache > generic fallback built from the name alone.
    """
    persona = persona_map.get(agent_id) or persona_map.get(agent_name)
    if persona:
        return build_system_prompt(persona)

    cached = agent_cache.get(agent_id)
    if not cached:
        return build_system_prompt_generic(agent_name)
    return build_system_prompt_generic(
        cached.get("name", agent_name),
        cached.get("occupation", "city resident"),
    )
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
# ββ Training example builders βββββββββββββββοΏ½οΏ½ββββββββββββββββββββββββββββββββββ
|
| 130 |
+
|
| 131 |
+
def make_conversation_examples(conv: dict, persona_map: dict, agent_cache: dict) -> list[dict]:
    """
    From a completed conversation, produce one training example per response turn.

    Each example:
      system = responder's persona system prompt
      user = conversation history up to last message + "{speaker} says: '{msg}'"
      assistant = JSON {"message": ..., "inner_thought": ...}
    """
    turns = conv.get("turns", [])
    if len(turns) < 2:
        # Nothing to respond to — need at least one prompt/response pair.
        return []

    participants = conv.get("participants", [])
    participant_names = conv.get("participant_names", [])
    topic = conv.get("topic", "general conversation")
    location = conv.get("location", "somewhere in the city")

    # Build name->id and id->name maps
    # NOTE(review): id_to_name is built but never read below — dead code?
    id_to_name: dict[str, str] = {}
    for pid, pname in zip(participants, participant_names):
        id_to_name[pid] = pname

    examples = []

    # Turn 0 has no prior message, so start at 1: each turn i becomes the
    # assistant target, with turns [0, i) as visible history.
    for i in range(1, len(turns)):
        current_turn = turns[i]
        prev_turn = turns[i - 1]
        # NOTE(review): assumes each turn dict has speaker_id/speaker_name/
        # message keys — a malformed turn raises KeyError here.
        responder_id = current_turn["speaker_id"]
        responder_name = current_turn["speaker_name"]
        speaker_name = prev_turn["speaker_name"]
        speaker_msg = prev_turn["message"]

        # Build conversation history string (all turns before current)
        history_lines = [f"CONVERSATION SO FAR (topic: {topic}):"]
        for t in turns[:i]:
            history_lines.append(f' {t["speaker_name"]}: "{t["message"]}"')
        history_text = "\n".join(history_lines)

        # User prompt (what the responder sees)
        user_prompt = (
            f"You are at {location}. {speaker_name} is here.\n\n"
            f"{history_text}\n\n"
            f'{speaker_name} says: "{speaker_msg}"\n\n'
            f"How do you respond? Stay in character. Be natural.\n\n"
            f"Respond with a JSON object:\n"
            f'{{\n'
            f' "message": "your spoken response",\n'
            f' "inner_thought": "what you\'re actually thinking"\n'
            f'}}'
        )

        # Assistant response (JSON)
        assistant_response = json.dumps({
            "message": current_turn["message"],
            "inner_thought": current_turn.get("inner_thought", ""),
        }, ensure_ascii=False)

        system = get_system_prompt(responder_id, responder_name, persona_map, agent_cache)

        examples.append({
            "messages": [
                {"role": "system", "content": system},
                {"role": "user", "content": user_prompt},
                {"role": "assistant", "content": assistant_response},
            ],
            # _meta is stripped before training; kept for analysis output.
            "_meta": {
                "type": "conversation",
                "conv_id": conv.get("id", ""),
                "topic": topic,
                "location": location,
                "turn_index": i,
                "responder_id": responder_id,
                "responder_name": responder_name,
            }
        })

    return examples
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def make_action_examples(events: list[dict], persona_map: dict,
                         agent_cache: dict) -> list[dict]:
    """
    From event log, build action decision training examples.

    Pattern: "<AgentName> is <activity>." ->
      system = agent's persona
      user = time-of-day prompt asking what the agent does next
      assistant = JSON {"action": ..., "detail": ..., "reasoning": ...}

    Only events whose agent name matches a config persona are kept.
    """
    # NOTE(review): the ^\s+ anchor requires at least one leading whitespace
    # character in the event message — confirm event messages are indented;
    # un-indented "X is Y." messages would be silently dropped.
    activity_pattern = re.compile(r"^\s+(\S.+?) is (.+)\.$")
    examples = []

    # Collect (name, activity, time) tuples
    for ev in events:
        msg = ev.get("message", "")
        time_str = ev.get("time", "")
        m = activity_pattern.match(msg)
        if not m:
            continue
        agent_name = m.group(1).strip()
        activity = m.group(2).strip()

        # Skip trivial / system-level messages
        if any(s in activity.lower() for s in [
            "wanders aimlessly", "can't get to", "---"
        ]):
            continue

        p = persona_map.get(agent_name)
        if not p:
            continue  # Only generate for known personas (higher quality)

        # Infer action type from activity text
        action = infer_action_type(activity)
        system = build_system_prompt(p)

        user_prompt = (
            f"It is {time_str}.\n\n"
            f"Based on your personality, needs, and the time of day — "
            f"what do you do next? Describe your current activity.\n\n"
            f"Respond with a JSON object:\n"
            f'{{\n'
            f' "action": "move|work|eat|sleep|talk|exercise|shop|relax|wander",\n'
            f' "detail": "what specifically you\'re doing, in first person",\n'
            f' "reasoning": "brief internal thought about why"\n'
            f'}}'
        )

        # Reasoning is templated, not observed — weak supervision by design.
        assistant_response = json.dumps({
            "action": action,
            "detail": activity,
            "reasoning": f"This is what {agent_name} would naturally do at this time.",
        }, ensure_ascii=False)

        examples.append({
            "messages": [
                {"role": "system", "content": system},
                {"role": "user", "content": user_prompt},
                {"role": "assistant", "content": assistant_response},
            ],
            "_meta": {
                "type": "action",
                "agent_name": agent_name,
                "activity": activity,
                "time": time_str,
            }
        })

    return examples
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def infer_action_type(activity: str) -> str:
    """Map a free-text activity description onto a coarse action label.

    Keyword rules are checked in priority order; the first match wins and
    anything unmatched falls through to "wander".
    """
    text = activity.lower()
    # (label, substrings that imply it) — order mirrors decision priority.
    rules = (
        ("move", ("commut", "walk", "moving", "heading")),
        ("work", ("work", "morning block", "afternoon block", "coding", "teaching")),
        ("eat", ("eat", "breakfast", "lunch", "dinner", "food", "coffee")),
        ("sleep", ("sleep", "nap", "rest", "sleeping in", "lounging")),
        ("talk", ("talk", "convers", "chat", "discuss")),
        ("exercise", ("gym", "exercise", "workout", "run", "jog", "fitness")),
        ("shop", ("shop", "grocery", "store", "market")),
        ("relax", ("relax", "park", "art", "music", "paint", "sketch")),
    )
    for label, needles in rules:
        if any(needle in text for needle in needles):
            return label
    return "wander"
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def make_initiation_examples(conv: dict, persona_map: dict, agent_cache: dict) -> list[dict]:
    """
    From the first turn of a conversation, build a conversation initiation example.

    Returns a one-element list (or [] for an empty conversation): the
    initiator's persona as system, an "open a conversation" user prompt,
    and the observed opening line as the assistant target.
    """
    turns = conv.get("turns", [])
    if not turns:
        return []

    opener = turns[0]
    initiator_id = opener["speaker_id"]
    initiator_name = opener["speaker_name"]
    topic = conv.get("topic", "small talk")
    location = conv.get("location", "somewhere in the city")

    # The addressee: first participant who is not the initiator.
    other_name = next(
        (n for n in conv.get("participant_names", []) if n != initiator_name),
        "someone",
    )

    system = get_system_prompt(initiator_id, initiator_name, persona_map, agent_cache)

    user_prompt = (
        f"You are at {location}. {other_name} is here.\n\n"
        f"You decide to start a conversation with {other_name}. What do you say?\n"
        f"Consider the location, your mood, and your history with them.\n\n"
        f"Respond with a JSON object:\n"
        f'{{\n'
        f' "message": "what you say to start the conversation",\n'
        f' "inner_thought": "why you\'re initiating this conversation",\n'
        f' "topic": "brief topic label"\n'
        f'}}'
    )

    assistant_response = json.dumps({
        "message": opener["message"],
        "inner_thought": opener.get("inner_thought", ""),
        "topic": topic,
    }, ensure_ascii=False)

    example = {
        "messages": [
            {"role": "system", "content": system},
            {"role": "user", "content": user_prompt},
            {"role": "assistant", "content": assistant_response},
        ],
        "_meta": {
            "type": "conversation_initiation",
            "conv_id": conv.get("id", ""),
            "topic": topic,
            "location": location,
            "initiator_id": initiator_id,
            "initiator_name": initiator_name,
            "other_name": other_name,
        }
    }
    return [example]
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
# ββ Main βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 362 |
+
|
| 363 |
+
def load_raw_jsonl(path: Path) -> list[dict]:
    """Read a JSONL file into a list, skipping blank and malformed lines.

    Returns [] when the file does not exist.  Tolerating bad lines matters
    because the collector may be appending while we read.
    """
    if not path.exists():
        return []
    records: list[dict] = []
    with open(path, encoding="utf-8") as fh:
        for raw in fh:
            raw = raw.strip()
            if not raw:
                continue
            try:
                records.append(json.loads(raw))
            except json.JSONDecodeError:
                continue  # partially-written line — skip it
    return records
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
def load_agent_cache() -> dict:
    """Load the collector's agent cache; return {} when absent or unreadable."""
    cache_file = RAW_DIR / "agents_cache.json"
    if not cache_file.exists():
        return {}
    try:
        return json.loads(cache_file.read_text(encoding="utf-8"))
    except Exception:
        # A corrupt cache is non-fatal — callers fall back to generic prompts.
        return {}
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
def run(raw_dir: Path, out_path: Path, include_actions: bool = False) -> None:
    """Convert all collected raw JSONL under *raw_dir* into SFT training data.

    Writes two files: *out_path* (clean "messages"-only JSONL for trainers)
    and a sibling .meta.jsonl that retains _meta fields for analysis.
    """
    print("Soci Training Data Converter")
    print(f" Raw dir : {raw_dir.resolve()}")
    print(f" Output : {out_path.resolve()}")

    # Load personas
    persona_map = load_persona_map()
    print(f" Personas: {len(persona_map)//2} loaded from config")  # /2 because keyed by id+name

    # Load agent cache (from collector)
    agent_cache = load_agent_cache()
    print(f" Agent cache: {len(agent_cache)} agents")

    # Load all raw conversations from all date files, deduped by id across days.
    all_convs: list[dict] = []
    seen_ids: set[str] = set()
    for conv_file in sorted(raw_dir.glob("conversations_*.jsonl")):
        items = load_raw_jsonl(conv_file)
        for c in items:
            cid = c.get("id", "")
            if cid and cid not in seen_ids:
                all_convs.append(c)
                seen_ids.add(cid)
    print(f" Conversations loaded: {len(all_convs)}")

    # Load all raw events from all date files
    all_events: list[dict] = []
    for ev_file in sorted(raw_dir.glob("events_*.jsonl")):
        all_events.extend(load_raw_jsonl(ev_file))
    print(f" Events loaded: {len(all_events)}")

    # Generate training examples
    examples: list[dict] = []

    # 1. Conversation initiation examples
    for conv in all_convs:
        examples.extend(make_initiation_examples(conv, persona_map, agent_cache))

    # 2. Conversation response examples
    for conv in all_convs:
        examples.extend(make_conversation_examples(conv, persona_map, agent_cache))

    # 3. Action decision examples (optional)
    if include_actions and all_events:
        action_examples = make_action_examples(all_events, persona_map, agent_cache)
        examples.extend(action_examples)
        print(f" Action examples: {len(action_examples)}")

    # Count by type
    type_counts: dict[str, int] = defaultdict(int)
    for ex in examples:
        type_counts[ex.get("_meta", {}).get("type", "unknown")] += 1

    print(f"\n Total training examples: {len(examples)}")
    for t, c in sorted(type_counts.items()):
        print(f" {t}: {c}")

    # Write output JSONL (without _meta for clean training files, or with --keep-meta)
    # NOTE(review): no --keep-meta flag exists; the meta copy below is always written.
    with open(out_path, "w", encoding="utf-8") as f:
        for ex in examples:
            # Write with _meta stripped (keep messages only)
            clean = {"messages": ex["messages"]}
            f.write(json.dumps(clean, ensure_ascii=False) + "\n")

    # Also write a version with meta for analysis
    meta_path = out_path.with_suffix(".meta.jsonl")
    with open(meta_path, "w", encoding="utf-8") as f:
        for ex in examples:
            f.write(json.dumps(ex, ensure_ascii=False) + "\n")

    print(f"\n Training JSONL : {out_path}")
    print(f" With meta : {meta_path}")
    print(f"\nSample (first example):")
    if examples:
        ex = examples[0]
        print(f" Type: {ex['_meta']['type']}")
        print(f" System: {ex['messages'][0]['content'][:120]}...")
        print(f" User: {ex['messages'][1]['content'][:120]}...")
        print(f" Asst: {ex['messages'][2]['content'][:120]}...")
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert raw Soci data to SFT training JSONL")
    parser.add_argument("--raw-dir", default=str(RAW_DIR), help="Directory with raw JSONL files")
    parser.add_argument("--out", default=str(DEFAULT_OUT), help="Output JSONL path")
    parser.add_argument("--include-actions", action="store_true",
                        help="Include action decision examples from events")
    args = parser.parse_args()

    # Paths arrive as strings from argparse; normalize to Path for run().
    run(Path(args.raw_dir), Path(args.out), include_actions=args.include_actions)
|
scripts/finetune_local.py
ADDED
|
@@ -0,0 +1,482 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
finetune_local.py β Local adaptation of Soci_FineTune_3_Incremental
|
| 3 |
+
Fine-tunes Qwen2.5-0.5B-Instruct on Soci city-simulation tasks using Unsloth.
|
| 4 |
+
|
| 5 |
+
Differences from the Colab version:
|
| 6 |
+
- No Google Drive / google.colab dependencies
|
| 7 |
+
- Local checkpoint and adapter storage in data/training/
|
| 8 |
+
- Loads live conversation data from data/training/processed/
|
| 9 |
+
- HF token from HF_TOKEN env var (or .env file)
|
| 10 |
+
- --debug flag for quick 1-epoch smoke test (no HF push)
|
| 11 |
+
- --resume flag to continue from saved LoRA adapters
|
| 12 |
+
|
| 13 |
+
Usage (from project root):
|
| 14 |
+
# Debug / smoke test (fast, no push):
|
| 15 |
+
"C:/Users/xabon/.conda/envs/ml-env/python.exe" scripts/finetune_local.py --debug
|
| 16 |
+
|
| 17 |
+
# Full round-1 training + push to HF:
|
| 18 |
+
"C:/Users/xabon/.conda/envs/ml-env/python.exe" scripts/finetune_local.py
|
| 19 |
+
|
| 20 |
+
# Resume round 2 with same command:
|
| 21 |
+
"C:/Users/xabon/.conda/envs/ml-env/python.exe" scripts/finetune_local.py --resume
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
from __future__ import annotations
|
| 25 |
+
|
| 26 |
+
import sys
|
| 27 |
+
import io
|
| 28 |
+
import os
|
| 29 |
+
|
| 30 |
+
# Force UTF-8 stdout/stderr on Windows (unsloth prints emoji characters)
|
| 31 |
+
if sys.platform == "win32":
|
| 32 |
+
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", errors="replace")
|
| 33 |
+
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding="utf-8", errors="replace")
|
| 34 |
+
|
| 35 |
+
# Disable torch.compile/inductor β triton 3.x on Windows doesn't export 'triton_key'
|
| 36 |
+
# which inductor needs at compile time. Training still uses CUDA kernels, just not
|
| 37 |
+
# the AOT-compiled fusion path. Has no meaningful effect on a single-GPU setup.
|
| 38 |
+
os.environ.setdefault("TORCHINDUCTOR_DISABLE", "1")
|
| 39 |
+
os.environ.setdefault("TORCH_COMPILE_DISABLE", "1")
|
| 40 |
+
|
| 41 |
+
# Import unsloth FIRST so it can patch transformers before anything else loads.
|
| 42 |
+
# Then patch list_repo_templates to skip the 'additional_chat_templates' HF Hub
|
| 43 |
+
# check that fails on unsloth's quantized repos (transformers 4.56+ behavior).
|
| 44 |
+
import unsloth # noqa: F401 β must be first
|
| 45 |
+
import transformers.utils.hub
|
| 46 |
+
import transformers.tokenization_utils_base
|
| 47 |
+
_noop = lambda *a, **kw: []
|
| 48 |
+
transformers.tokenization_utils_base.list_repo_templates = _noop
|
| 49 |
+
transformers.utils.hub.list_repo_templates = _noop
|
| 50 |
+
|
| 51 |
+
# Stdlib imports for the rest of the script, sorted per PEP 8.
# NOTE: the previous version re-imported `os` here — it is already imported
# above (it must precede the unsloth import to set the env vars), so the
# duplicate has been dropped.
import argparse
import json
import shutil
from datetime import datetime
from pathlib import Path
|
| 57 |
+
|
| 58 |
+
# ── Parse args first (before heavy imports) ───────────────────────────────────
cli = argparse.ArgumentParser(description="Soci local fine-tune")
cli.add_argument("--resume", action="store_true",
                 help="Resume from saved LoRA adapters")
cli.add_argument("--debug", action="store_true",
                 help="Debug/smoke-test: 1 epoch, 20 examples, no push")
cli.add_argument("--no-push", action="store_true", help="Skip HF Hub push")
cli.add_argument("--no-gguf", action="store_true", help="Skip GGUF export")
cli.add_argument("--epochs", type=int, default=None, help="Override epoch count")
cli.add_argument("--hf-repo", default=None, help="HF repo ID (overrides default)")
args = cli.parse_args()

# ── Paths ─────────────────────────────────────────────────────────────────────
TRAIN_DIR = Path("data/training")
LORA_SAVE_DIR = TRAIN_DIR / "lora_adapters"
DATA_ARCHIVE_DIR = TRAIN_DIR / "data_archive"
GGUF_DIR = TRAIN_DIR / "gguf"
CHECKPOINTS_DIR = TRAIN_DIR / "checkpoints"
ROUND_FILE = TRAIN_DIR / "training_round.json"
CORE_DATA_FILE = TRAIN_DIR / "core_examples.json"
LIVE_DATA_FILE = TRAIN_DIR / "processed" / "soci_training.jsonl"

# Make sure every output directory exists before training starts.
for _out_dir in (LORA_SAVE_DIR, DATA_ARCHIVE_DIR, GGUF_DIR, CHECKPOINTS_DIR):
    _out_dir.mkdir(parents=True, exist_ok=True)
|
| 80 |
+
|
| 81 |
+
# ── Config ────────────────────────────────────────────────────────────────────
MAX_SEQ_LENGTH = 2048
HF_USERNAME = "RayMelius"
REPO_NAME = "soci-agent-q4"
HF_REPO_ID = args.hf_repo or f"{HF_USERNAME}/{REPO_NAME}"

# Load HF token: prefer python-dotenv when installed; otherwise fall back to a
# minimal hand parse of the project .env file below.
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    pass
HF_TOKEN = os.environ.get("HF_TOKEN", "")
if not HF_TOKEN:
    # Try to read from the project .env
    env_file = Path(".env")
    if env_file.exists():
        for line in env_file.read_text().splitlines():
            if line.startswith("HF_TOKEN="):
                # Strip whitespace and EITHER quote style — the previous
                # parser only removed double quotes, so single-quoted
                # values kept their quotes and were sent to the Hub as-is.
                HF_TOKEN = line.split("=", 1)[1].strip().strip('"').strip("'")

# ── GPU check ─────────────────────────────────────────────────────────────────
import torch
if not torch.cuda.is_available():
    print("[WARN] No CUDA GPU detected — training will be very slow on CPU.")
    print("       Consider running on Colab or a machine with a GPU.")
else:
    print(f"GPU : {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# ── Determine training round ──────────────────────────────────────────────────
# training_round.json records the last completed round; --resume continues
# from round+1. With --resume but no file we assume round 2 (legacy runs).
RESUME = args.resume
if RESUME and ROUND_FILE.exists():
    round_info = json.loads(ROUND_FILE.read_text())
    CURRENT_ROUND = round_info["round"] + 1
    print(f"Resuming from round {round_info['round']} -> round {CURRENT_ROUND}")
    print(f"Previous loss: {round_info.get('final_loss', 'N/A')}")
elif RESUME:
    CURRENT_ROUND = 2
    print("No round file found, assuming round 2")
else:
    CURRENT_ROUND = 1
    print("Starting fresh (round 1)")
|
| 124 |
+
|
| 125 |
+
# ── Load model ────────────────────────────────────────────────────────────────
from unsloth import FastLanguageModel  # noqa: already imported via 'import unsloth'

# Resume path: load straight from the saved adapter directory (presumably
# unsloth resolves the base model from metadata saved alongside the adapters
# — TODO confirm against unsloth docs).
if RESUME and LORA_SAVE_DIR.exists() and any(LORA_SAVE_DIR.iterdir()):
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = str(LORA_SAVE_DIR),
        max_seq_length = MAX_SEQ_LENGTH,
        dtype = None,           # let unsloth auto-select the dtype
        load_in_4bit = True,    # 4-bit quantized base weights (QLoRA-style)
    )
    print(f"Resumed LoRA adapters from {LORA_SAVE_DIR}")
else:
    if RESUME:
        # --resume requested but no adapters on disk: fall back to round 1.
        print(f"[WARN] No LoRA adapters at {LORA_SAVE_DIR}, starting fresh.")
        CURRENT_ROUND = 1
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = "unsloth/Qwen2.5-0.5B-Instruct",
        max_seq_length = MAX_SEQ_LENGTH,
        dtype = None,
        load_in_4bit = True,
    )
    print("Fresh base model loaded (round 1)")

# ── Attach LoRA ───────────────────────────────────────────────────────────────
# Round 1 attaches new adapters; resumed rounds already carry adapters from
# the checkpoint, so only gradient checkpointing is re-enabled.
if CURRENT_ROUND == 1:
    model = FastLanguageModel.get_peft_model(
        model,
        r = 16,                 # LoRA rank
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                          "gate_proj", "up_proj", "down_proj"],
        lora_alpha = 16,
        lora_dropout = 0,
        bias = "none",
        use_gradient_checkpointing = "unsloth",  # unsloth's memory-saving variant
        random_state = 42,
    )
    print("Fresh LoRA adapters attached")
else:
    model.gradient_checkpointing_enable()
    print(f"Resumed LoRA adapters from round {CURRENT_ROUND - 1}")

model.print_trainable_parameters()
|
| 167 |
+
|
| 168 |
+
# ── System prompt ─────────────────────────────────────────────────────────────
# One shared system prompt for every training example and inference probe:
# the model must answer with bare JSON and nothing else.
SYSTEM_PROMPT = (
    "You are the reasoning engine for Soci, an LLM-powered city population "
    "simulator. You control AI agents (NPCs) living in a city. Each agent has "
    "a persona, needs (hunger, energy, social, purpose, comfort, fun), "
    "memories, and relationships. You receive structured context and must "
    "respond ONLY with valid JSON. Never add explanation outside the JSON."
)
|
| 176 |
+
|
| 177 |
+
# ── Load training data ────────────────────────────────────────────────────────
print("\nLoading training data...")

# 1. Core examples (from data/training/core_examples.json, extracted from v3 script)
core_examples: list[dict] = []
if CORE_DATA_FILE.exists():
    core_examples = json.loads(CORE_DATA_FILE.read_text(encoding="utf-8"))
    print(f"  Core examples: {len(core_examples)}")
else:
    print(f"  [WARN] {CORE_DATA_FILE} not found — run extract step or collect_training_data.py first")

# 2. Live collected data from the running simulation
live_examples: list[dict] = []
if LIVE_DATA_FILE.exists():
    with open(LIVE_DATA_FILE, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                ex = json.loads(line)
                msgs = ex.get("messages", [])
                if len(msgs) >= 3:
                    # Fold the per-example system/persona message into the
                    # instruction, since training uses one unified system
                    # prompt (SYSTEM_PROMPT) for all examples.
                    persona_ctx = msgs[0]["content"]
                    user_content = msgs[1]["content"]
                    asst_content = msgs[2]["content"]
                    live_examples.append({
                        "instruction": f"{persona_ctx}\n\n{user_content}",
                        "response": asst_content,
                    })
            except (json.JSONDecodeError, KeyError):
                pass  # skip malformed lines — collection is best-effort
print(f"  Live examples: {len(live_examples)} (from Render simulation)")

# 3. Replay archived examples from previous rounds
replay_examples: list[dict] = []
if CURRENT_ROUND > 1:
    for archive_f in sorted(DATA_ARCHIVE_DIR.glob("round_*.json")):
        try:
            replay_examples.extend(json.loads(archive_f.read_text(encoding="utf-8")))
        except (OSError, json.JSONDecodeError) as err:
            # Was a bare `except Exception: pass`, which silently dropped an
            # entire replay file; surface the problem but keep going.
            print(f"  [WARN] Skipping unreadable archive {archive_f.name}: {err}")
print(f"  Replay examples: {len(replay_examples)}")

# 4. New examples for this round (add yours here for incremental training)
new_examples_this_round: list[dict] = [
    # Add new instruction/response pairs here for incremental training rounds.
    # Example:
    # {"instruction": "You are playing Diana Novak, 41, grocery store owner. ...",
    #  "response": '{"action": "work", "location": "grocery_store", "reason": "..."}'},
]
if new_examples_this_round:
    print(f"  New examples this round: {len(new_examples_this_round)}")

# Merge and deduplicate by instruction prefix (first 100 chars): a cheap
# near-duplicate filter; exact duplicates from replay archives collapse here.
seen: set[str] = set()
all_examples: list[dict] = []
for ex in core_examples + live_examples + new_examples_this_round + replay_examples:
    key = ex.get("instruction", "")[:100]
    if key not in seen:
        seen.add(key)
        all_examples.append(ex)

if args.debug:
    all_examples = all_examples[:20]
    print(f"  DEBUG mode: using {len(all_examples)} examples")

print(f"  Total (deduped): {len(all_examples)}")
|
| 251 |
+
|
| 252 |
+
# ── Format into chat template ─────────────────────────────────────────────────
from datasets import Dataset

def format_example(ex: dict) -> dict:
    """Render one instruction/response pair through the tokenizer's chat template."""
    conversation = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": ex["instruction"]},
        {"role": "assistant", "content": ex["response"]},
    ]
    rendered = tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=False
    )
    return {"text": rendered}

dataset = Dataset.from_list(all_examples).map(format_example)
print(f"Formatted {len(dataset)} examples. Sample:")
print(dataset[0]["text"][:400])
|
| 268 |
+
|
| 269 |
+
# ── Training config ───────────────────────────────────────────────────────────
from trl import SFTTrainer, SFTConfig
from unsloth import is_bfloat16_supported

# Hyperparameters per round: full LR for round 1; gentler LR + cosine schedule
# for incremental rounds so earlier rounds' knowledge is not clobbered.
if args.debug:
    LR, EPOCHS, WARMUP, SCHEDULER = 2e-4, 1, 2, "linear"
    print("\nDEBUG: 1 epoch smoke test")
elif CURRENT_ROUND == 1:
    LR, EPOCHS, WARMUP, SCHEDULER = 2e-4, 3, 5, "linear"
    print(f"\nRound 1: Full training — LR={LR}, epochs={EPOCHS}")
else:
    LR, EPOCHS, WARMUP, SCHEDULER = 5e-5, 2, 10, "cosine"
    print(f"\nRound {CURRENT_ROUND}: Incremental — LR={LR}, epochs={EPOCHS}")

if args.epochs is not None:
    EPOCHS = args.epochs
    print(f"Epoch override: {EPOCHS}")

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    dataset_num_proc = 2,
    args = SFTConfig(
        # Dataset handling — declared ONCE here. Previously dataset_text_field
        # and max_seq_length were also passed directly to SFTTrainer, which
        # modern TRL rejects as duplicate/conflicting kwargs (they moved into
        # SFTConfig).
        dataset_text_field = "text",
        max_seq_length = MAX_SEQ_LENGTH,
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,   # effective batch size 8
        warmup_steps = WARMUP,
        num_train_epochs = EPOCHS,
        learning_rate = LR,
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        logging_steps = 5,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = SCHEDULER,
        seed = 42,
        output_dir = str(CHECKPOINTS_DIR),
        report_to = "none",
    ),
)
|
| 313 |
+
|
| 314 |
+
print(f"\nTraining round {CURRENT_ROUND} on {len(dataset)} examples...")
stats = trainer.train()
print(f"\nRound {CURRENT_ROUND} complete!")
print(f" Steps: {stats.global_step} | Final loss: {stats.training_loss:.4f}")

# ── Save LoRA adapters ────────────────────────────────────────────────────────
# Only adapter weights + tokenizer files are persisted here; this directory is
# what the resume path loads at the top of the script.
print(f"\nSaving LoRA adapters to {LORA_SAVE_DIR}...")
model.save_pretrained(str(LORA_SAVE_DIR))
tokenizer.save_pretrained(str(LORA_SAVE_DIR))
print(" Saved.")

# ── Save round metadata ───────────────────────────────────────────────────────
# training_round.json is read by --resume to determine the next round number
# and report the previous loss.
round_info = {
    "round": CURRENT_ROUND,
    "final_loss": stats.training_loss,
    "global_steps": stats.global_step,
    "total_examples": len(all_examples),
    "new_examples": len(new_examples_this_round) + len(live_examples),
    "learning_rate": LR,
    "epochs": EPOCHS,
    "timestamp": datetime.now().isoformat(),
}
ROUND_FILE.write_text(json.dumps(round_info, indent=2))
print(f" Round info: {ROUND_FILE}")

# Archive this round's new examples so future rounds can replay them
# (consumed by the replay_examples loader earlier in the script).
all_new = new_examples_this_round + live_examples
if all_new:
    archive_file = DATA_ARCHIVE_DIR / f"round_{CURRENT_ROUND:03d}.json"
    archive_file.write_text(json.dumps(all_new, indent=2, ensure_ascii=False))
    print(f" Archived {len(all_new)} new examples")

# Training history: one JSON line appended per round; rendered as a table at
# the end of the run.
history_file = TRAIN_DIR / "training_history.jsonl"
with open(history_file, "a", encoding="utf-8") as f:
    f.write(json.dumps(round_info) + "\n")
|
| 350 |
+
|
| 351 |
+
# ── Quick inference test ──────────────────────────────────────────────────────
print(f"\n=== Testing after Round {CURRENT_ROUND} ===\n")
FastLanguageModel.for_inference(model)

# Generate on whatever device the model actually lives on. The previous
# hard-coded .to("cuda") crashed on the CPU-only fallback path that the GPU
# check at the top of the script explicitly allows.
_device = next(model.parameters()).device

def ask(question: str, label: str = "") -> None:
    """Send one prompt through the fine-tuned model and pretty-print the reply.

    Parses the response as JSON when possible (the training objective) and
    falls back to printing the raw text.
    """
    msgs = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": question},
    ]
    encoded = tokenizer.apply_chat_template(
        msgs, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    )
    # apply_chat_template may return a BatchEncoding or a bare tensor
    # depending on the tokenizer version.
    inp = encoded.input_ids if hasattr(encoded, "input_ids") else encoded
    inp = inp.to(_device)
    out = model.generate(
        input_ids=inp, max_new_tokens=200,
        temperature=0.7, top_p=0.9, do_sample=True,
    )
    # Decode only the newly generated tokens (skip the prompt).
    resp = tokenizer.decode(out[0][inp.shape[1]:], skip_special_tokens=True)
    print(f"[{label}]")
    print(f"Q: {question[:100]}...")
    try:
        parsed = json.loads(resp)
        print(f"A (valid JSON):\n{json.dumps(parsed, indent=2)}")
    except Exception:
        print(f"A (raw): {resp}")
    print("-" * 60)

ask(
    "You are playing Elena Vasquez, 34, software engineer. "
    "Needs: energy=0.3, hunger=0.7. Location: office. Time: 12:30. "
    "Decide next action. JSON: {\"action\": str, \"location\": str, \"reason\": str}",
    "decide_action",
)
ask(
    "You are playing Marcus Chen talking to Zoe. "
    "Zoe says: 'Marcus, I bombed my exam.' Continue as Marcus. "
    "JSON: {\"speech\": str, \"emotion\": str}",
    "conversation_turn",
)
|
| 393 |
+
|
| 394 |
+
# ── GGUF export ───────────────────────────────────────────────────────────────
# Windows: unsloth GGUF export requires building llama.cpp via apt-get (Linux only).
# Auto-skip on Windows; use --no-gguf on Linux too if llama.cpp isn't set up.
import platform
_on_windows = platform.system() == "Windows"
skip_gguf = args.no_gguf or args.debug or _on_windows
if _on_windows and not args.no_gguf and not args.debug:
    print("\nSkipping GGUF export (Windows — llama.cpp build not supported via unsloth on Win)")
    print("  To export GGUF manually, use llama.cpp's convert_hf_to_gguf.py")
    print(f"  LoRA merged weights saved to: {GGUF_DIR}/ (after push)")

# Always defined: the previous code only assigned the empty list on the debug
# branch of the skip path, so a Windows non-debug run reached the HF push loop
# with gguf_files undefined (NameError).
gguf_files: list[Path] = []
if not skip_gguf:
    print("\nExporting GGUF Q4_K_M (takes a few minutes)...")
    model.save_pretrained_gguf(str(GGUF_DIR), tokenizer, quantization_method="q4_k_m")
    gguf_files = list(GGUF_DIR.glob("*.gguf"))
    for gf in gguf_files:
        print(f"  GGUF: {gf.name} ({gf.stat().st_size / 1e6:.0f} MB)")
elif args.debug:
    print("\nSkipping GGUF export (debug mode)")
|
| 415 |
+
|
| 416 |
+
# ── Push to HuggingFace Hub ───────────────────────────────────────────────────
skip_push = args.no_push or args.debug
if skip_push:
    print("\nSkipping HF push (debug mode or --no-push)")
else:
    if not HF_TOKEN:
        print("\n[WARN] No HF_TOKEN found — skipping push.")
        print("       Set HF_TOKEN env var or add to .env file.")
    else:
        # Lazy import: huggingface_hub is only needed when actually pushing.
        from huggingface_hub import login, HfApi
        print(f"\nPushing to HuggingFace: {HF_REPO_ID}")
        login(token=HF_TOKEN)
        api = HfApi()
        # exist_ok=True makes repeated rounds idempotent against the repo.
        api.create_repo(repo_id=HF_REPO_ID, repo_type="model", exist_ok=True)

        # Push LoRA adapters under a subfolder so GGUF files can live at root.
        print(" Uploading LoRA adapters...")
        api.upload_folder(
            folder_path = str(LORA_SAVE_DIR),
            repo_id = HF_REPO_ID,
            repo_type = "model",
            path_in_repo= "lora_adapters",
        )
        print(f" LoRA -> https://huggingface.co/{HF_REPO_ID}/tree/main/lora_adapters")

        # Push GGUF file(s) — list is empty when export was skipped
        # (Windows / debug / --no-gguf), making this loop a no-op.
        for gf in gguf_files:
            mb = gf.stat().st_size / 1e6
            print(f" Uploading {gf.name} ({mb:.0f} MB)...")
            api.upload_file(
                path_or_fileobj = str(gf),
                path_in_repo = gf.name,
                repo_id = HF_REPO_ID,
                repo_type = "model",
            )
            print(f" Done: https://huggingface.co/{HF_REPO_ID}/blob/main/{gf.name}")

        # Push round metadata so the repo records which round produced it.
        api.upload_file(
            path_or_fileobj = str(ROUND_FILE),
            path_in_repo = "training_round.json",
            repo_id = HF_REPO_ID,
            repo_type = "model",
        )

        print(f"\nUpload complete! Model at: https://huggingface.co/{HF_REPO_ID}")

# ── Training history display ──────────────────────────────────────────────────
# Renders training_history.jsonl (one JSON object per completed round, written
# above) as a fixed-width table.
print("\n=== Training History ===\n")
if history_file.exists():
    print(f"{'Round':>6} {'Loss':>8} {'Steps':>7} {'Examples':>9} {'New':>5} {'LR':>10} {'Date':>12}")
    print("-" * 65)
    with open(history_file, encoding="utf-8") as f:
        for line in f:
            r = json.loads(line)
            date = r.get("timestamp", "")[:10]  # ISO timestamp -> YYYY-MM-DD
            print(f"{r['round']:>6} {r['final_loss']:>8.4f} {r['global_steps']:>7} "
                  f"{r['total_examples']:>9} {r['new_examples']:>5} "
                  f"{r['learning_rate']:>10.1e} {date:>12}")

print(f"\nTo resume: python scripts/finetune_local.py --resume")
print(f"LoRA adapters: {LORA_SAVE_DIR}")
if gguf_files:
    print(f"GGUF: {gguf_files[0]}")
# Instructions for serving the exported model through Ollama locally.
print(f"\nOllama integration:")
print(f" ollama create soci-agent -f Modelfile")
print(f" set SOCI_PROVIDER=ollama && set OLLAMA_MODEL=soci-agent")
|
src/soci/api/routes.py
CHANGED
|
@@ -286,10 +286,12 @@ async def get_llm_providers():
|
|
| 286 |
or os.environ.get("HF_API_TOKEN")
|
| 287 |
)
|
| 288 |
if has_hf:
|
|
|
|
| 289 |
providers.append({"id": "hf", "model": "HuggingFaceH4/zephyr-7b-beta", "label": "HF Zephyr 7B", "icon": "π€"})
|
| 290 |
providers.append({"id": "hf", "model": "Qwen/Qwen2.5-7B-Instruct", "label": "HF Qwen 2.5 7B", "icon": "π€"})
|
| 291 |
providers.append({"id": "hf", "model": "meta-llama/Llama-3.2-3B-Instruct", "label": "HF Llama 3.2 3B", "icon": "π€"})
|
| 292 |
providers.append({"id": "hf", "model": "mistralai/Mistral-7B-Instruct-v0.3", "label": "HF Mistral 7B", "icon": "π€"})
|
|
|
|
| 293 |
providers.append({"id": "ollama", "label": "Ollama (local)", "icon": "π¦", "model": ""})
|
| 294 |
return {"current": current, "current_model": current_model, "providers": providers}
|
| 295 |
|
|
|
|
| 286 |
or os.environ.get("HF_API_TOKEN")
|
| 287 |
)
|
| 288 |
if has_hf:
|
| 289 |
+
providers.append({"id": "hf", "model": "RayMelius/soci-agent-q4", "label": "Soci Agent (fine-tuned)", "icon": "π"})
|
| 290 |
providers.append({"id": "hf", "model": "HuggingFaceH4/zephyr-7b-beta", "label": "HF Zephyr 7B", "icon": "π€"})
|
| 291 |
providers.append({"id": "hf", "model": "Qwen/Qwen2.5-7B-Instruct", "label": "HF Qwen 2.5 7B", "icon": "π€"})
|
| 292 |
providers.append({"id": "hf", "model": "meta-llama/Llama-3.2-3B-Instruct", "label": "HF Llama 3.2 3B", "icon": "π€"})
|
| 293 |
providers.append({"id": "hf", "model": "mistralai/Mistral-7B-Instruct-v0.3", "label": "HF Mistral 7B", "icon": "π€"})
|
| 294 |
+
providers.append({"id": "ollama", "label": "Soci Agent (Ollama)", "icon": "π", "model": "soci-agent"})
|
| 295 |
providers.append({"id": "ollama", "label": "Ollama (local)", "icon": "π¦", "model": ""})
|
| 296 |
return {"current": current, "current_model": current_model, "providers": providers}
|
| 297 |
|
src/soci/engine/llm.py
CHANGED
|
@@ -64,6 +64,10 @@ MODEL_HF_QWEN = "Qwen/Qwen2.5-7B-Instruct" # default β auto-routed, g
|
|
| 64 |
MODEL_HF_LLAMA = "meta-llama/Llama-3.2-3B-Instruct"
|
| 65 |
MODEL_HF_MISTRAL = "mistralai/Mistral-7B-Instruct-v0.3"
|
| 66 |
MODEL_HF_SMOL = "HuggingFaceTB/SmolLM3-3B:hf-inference" # CPU inference, no credits needed
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
# Approximate cost per 1M tokens (USD) β Ollama is free, Groq is very cheap
|
| 69 |
COST_PER_1M = {
|
|
|
|
| 64 |
MODEL_HF_LLAMA = "meta-llama/Llama-3.2-3B-Instruct"
|
| 65 |
MODEL_HF_MISTRAL = "mistralai/Mistral-7B-Instruct-v0.3"
|
| 66 |
MODEL_HF_SMOL = "HuggingFaceTB/SmolLM3-3B:hf-inference" # CPU inference, no credits needed
|
| 67 |
+
MODEL_HF_SOCI = "RayMelius/soci-agent-q4" # Soci fine-tuned Qwen2.5-0.5B (LoRA)
|
| 68 |
+
|
| 69 |
+
# Ollama model IDs for Soci fine-tuned models
|
| 70 |
+
MODEL_OLLAMA_SOCI = "soci-agent" # load via: ollama create soci-agent -f Modelfile
|
| 71 |
|
| 72 |
# Approximate cost per 1M tokens (USD) β Ollama is free, Groq is very cheap
|
| 73 |
COST_PER_1M = {
|
src/soci/engine/simulation.py
CHANGED
|
@@ -858,6 +858,7 @@ class Simulation:
|
|
| 858 |
"events": self.events.to_dict(),
|
| 859 |
"entropy": self.entropy.to_dict(),
|
| 860 |
"conversation_counter": self._conversation_counter,
|
|
|
|
| 861 |
}
|
| 862 |
|
| 863 |
@classmethod
|
|
@@ -869,6 +870,7 @@ class Simulation:
|
|
| 869 |
sim.events = EventSystem.from_dict(data["events"])
|
| 870 |
sim.entropy = EntropyManager.from_dict(data["entropy"])
|
| 871 |
sim._conversation_counter = data.get("conversation_counter", 0)
|
|
|
|
| 872 |
for aid, agent_data in data["agents"].items():
|
| 873 |
agent = Agent.from_dict(agent_data)
|
| 874 |
sim.agents[agent.id] = agent
|
|
|
|
| 858 |
"events": self.events.to_dict(),
|
| 859 |
"entropy": self.entropy.to_dict(),
|
| 860 |
"conversation_counter": self._conversation_counter,
|
| 861 |
+
"conversation_history": self.conversation_history,
|
| 862 |
}
|
| 863 |
|
| 864 |
@classmethod
|
|
|
|
| 870 |
sim.events = EventSystem.from_dict(data["events"])
|
| 871 |
sim.entropy = EntropyManager.from_dict(data["entropy"])
|
| 872 |
sim._conversation_counter = data.get("conversation_counter", 0)
|
| 873 |
+
sim.conversation_history = data.get("conversation_history", [])
|
| 874 |
for aid, agent_data in data["agents"].items():
|
| 875 |
agent = Agent.from_dict(agent_data)
|
| 876 |
sim.agents[agent.id] = agent
|