"""
neural_data.py — Training data manager for MLX LoRA fine-tuning.
Manages a rolling buffer of recent conversation turns and a persistent
replay buffer for anti-catastrophic-forgetting experience replay.
"""
import json
import random
import time
from collections import deque
from pathlib import Path
from typing import Optional
class TrainingExample:
    """A single chat-formatted training example built from one conversation turn.

    Fields are declared in ``__slots__``, so instances carry no ``__dict__``.
    """

    __slots__ = ("messages", "timestamp", "token_count", "session_id")

    def __init__(self, messages: list[dict], timestamp: float = 0,
                 token_count: int = 0, session_id: str = ""):
        self.messages = messages
        # Any falsy timestamp (including the default 0) means "stamp with now".
        self.timestamp = timestamp if timestamp else time.time()
        self.token_count = token_count
        self.session_id = session_id

    def to_dict(self) -> dict:
        """Return the example as a plain dict whose keys follow ``__slots__`` order."""
        return {field: getattr(self, field) for field in self.__slots__}

    @classmethod
    def from_dict(cls, d: dict) -> "TrainingExample":
        """Rebuild an example from ``to_dict`` output.

        Only ``"messages"`` is required (``KeyError`` otherwise); the other
        fields fall back to the constructor defaults.
        """
        messages = d["messages"]
        fallbacks = {"timestamp": 0, "token_count": 0, "session_id": ""}
        return cls(messages, **{key: d.get(key, default)
                                for key, default in fallbacks.items()})
class TrainingDataManager:
    """Manages rolling buffer + persistent replay for LoRA training.

    * The rolling buffer is a ``deque`` holding the ``rolling_size`` most
      recently accepted turns.
    * The replay buffer is a reservoir sample of at most ``replay_size``
      accepted turns. When ``replay_path`` is given it is loaded from that
      JSON Lines file on construction and written back by ``save_replay``.
    """

    def __init__(self, rolling_size: int = 100, replay_size: int = 500,
                 replay_path: str = "", min_response_tokens: int = 10):
        self.rolling_size = rolling_size
        self.replay_size = replay_size
        self.min_response_tokens = min_response_tokens
        self.replay_path = replay_path
        self._rolling: deque[TrainingExample] = deque(maxlen=rolling_size)
        self._replay: list[TrainingExample] = []
        # Accepted turns since construction / the last clear(); see stats().
        self._total_added = 0
        # How many examples the replay reservoir has been offered, including
        # those loaded from disk. Reservoir sampling must use this count rather
        # than _total_added: after a reload the reservoir is already populated
        # while _total_added restarts at 0, which made each new turn overwrite
        # one of the first few slots instead of being kept with prob k/i.
        self._replay_seen = 0
        if replay_path:
            self._load_replay()

    @property
    def rolling_count(self) -> int:
        """Number of examples currently in the rolling buffer."""
        return len(self._rolling)

    @property
    def replay_count(self) -> int:
        """Number of examples currently in the replay buffer."""
        return len(self._replay)

    @property
    def total_added(self) -> int:
        """Accepted turns since construction or the last ``clear()``."""
        return self._total_added

    def add_turn(self, user_text: str, assistant_text: str,
                 system_prompt: str = "", session_id: str = "") -> bool:
        """Add a conversation turn to the training buffers.

        The turn is stored as a chat-style ``messages`` list (optional system
        message, then user, then assistant). It is appended to the rolling
        buffer and offered to the replay reservoir.

        Args:
            user_text: The user's message.
            assistant_text: The assistant's reply. It must be non-blank and
                contain at least ``min_response_tokens`` whitespace-separated
                words.
            system_prompt: Optional system message; omitted when empty.
            session_id: Identifier stored on the example.

        Returns:
            True if the example was accepted, False if it was filtered out.
        """
        # Quality filters: skip blank replies (e.g. tool-only turns) and
        # replies that are too short. The word count approximates tokens.
        if not assistant_text.strip():
            return False
        approx_tokens = len(assistant_text.split())
        if approx_tokens < self.min_response_tokens:
            return False

        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": user_text})
        messages.append({"role": "assistant", "content": assistant_text})

        example = TrainingExample(
            messages=messages,
            token_count=approx_tokens,
            session_id=session_id,
        )
        self._rolling.append(example)
        self._total_added += 1
        self._offer_to_replay(example)
        return True

    def _offer_to_replay(self, example: "TrainingExample") -> None:
        """Offer *example* to the replay buffer via reservoir sampling."""
        self._replay_seen += 1
        if len(self._replay) < self.replay_size:
            self._replay.append(example)
            return
        # Keep the i-th offered example with probability replay_size / i,
        # overwriting a uniformly chosen slot.
        idx = random.randrange(self._replay_seen)
        if idx < self.replay_size:
            self._replay[idx] = example

    def get_training_batch(self, batch_size: int = 1,
                           replay_ratio: float = 0.3) -> "list[TrainingExample]":
        """Get a shuffled training batch mixing recent and replay examples.

        Args:
            batch_size: Total examples in batch. 0 (or negative) = all
                available data: the whole rolling buffer plus every replay
                example not already in it.
            replay_ratio: Fraction of the batch drawn from the replay buffer.
                It is clamped to [0.0, 1.0].

        Returns:
            List of TrainingExample with no object repeated. The list is
            empty when the rolling buffer is empty. It may be shorter than
            ``batch_size`` when the buffers hold too few distinct examples.
        """
        if not self._rolling:
            return []

        if batch_size <= 0:
            batch = list(self._rolling)
            in_batch = {id(ex) for ex in batch}
            batch.extend(ex for ex in self._replay if id(ex) not in in_batch)
            random.shuffle(batch)
            return batch

        replay_ratio = min(max(replay_ratio, 0.0), 1.0)
        n_replay = int(batch_size * replay_ratio)
        n_recent = batch_size - n_replay

        # The newest n_recent examples (slicing copes with a short buffer).
        batch = list(self._rolling)[-n_recent:] if n_recent > 0 else []

        if n_replay > 0:
            # New turns are placed in both buffers, so exclude examples that
            # are already in the batch to avoid training on duplicates.
            in_batch = {id(ex) for ex in batch}
            pool = [ex for ex in self._replay if id(ex) not in in_batch]
            batch.extend(random.sample(pool, min(n_replay, len(pool))))

        random.shuffle(batch)
        return batch

    def get_recent(self, n: int = 5) -> "list[TrainingExample]":
        """Get the N most recent training examples (empty list when n <= 0)."""
        if n <= 0:
            # Without this guard, [-0:] would return the entire buffer.
            return []
        return list(self._rolling)[-n:]

    def _default_rolling_path(self) -> str:
        """Return ``buffer.jsonl`` in the directory of ``replay_path``."""
        return str(Path(self.replay_path).parent / "buffer.jsonl")

    @staticmethod
    def _write_jsonl(path: str, examples) -> None:
        """Write *examples* to *path* as JSON Lines, one example per line.

        Output goes to a sibling ``.tmp`` file first and is then moved over
        the target with ``Path.replace``. Where rename is atomic (e.g. POSIX),
        an interrupted write leaves the previous file intact instead of a
        truncated one.
        """
        target = Path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        tmp = target.with_name(target.name + ".tmp")
        with open(tmp, "w", encoding="utf-8") as f:
            f.writelines(json.dumps(ex.to_dict()) + "\n" for ex in examples)
        tmp.replace(target)

    @staticmethod
    def _read_jsonl(path: str) -> "list[TrainingExample]":
        """Read a JSON Lines file written by ``_write_jsonl``, skipping blank lines."""
        with open(path, encoding="utf-8") as f:
            return [TrainingExample.from_dict(json.loads(line))
                    for line in f if line.strip()]

    def save_rolling(self, path: str = ""):
        """Save the rolling buffer to disk.

        Defaults to ``buffer.jsonl`` next to ``replay_path``, or the current
        directory when ``replay_path`` is empty.
        """
        self._write_jsonl(path or self._default_rolling_path(), self._rolling)

    def load_rolling(self, path: str = ""):
        """Load the rolling buffer from disk, replacing its contents.

        Uses the same default path as ``save_rolling``. Does nothing if the
        file does not exist.
        """
        path = path or self._default_rolling_path()
        if not Path(path).exists():
            return
        # Parse everything before touching the buffer, so a read/parse error
        # leaves the current contents intact.
        loaded = self._read_jsonl(path)
        self._rolling.clear()
        # The deque's maxlen keeps only the newest rolling_size entries.
        self._rolling.extend(loaded)

    def save_replay(self):
        """Persist replay buffer to disk (no-op when ``replay_path`` is empty)."""
        if not self.replay_path:
            return
        self._write_jsonl(self.replay_path, self._replay)

    def _load_replay(self):
        """Load the replay buffer from ``replay_path``, if that file exists."""
        if not self.replay_path or not Path(self.replay_path).exists():
            return
        loaded = self._read_jsonl(self.replay_path)
        # The historical offer count is not persisted. The number of stored
        # examples is the best available lower bound for reservoir sampling.
        self._replay_seen = len(loaded)
        # Trim to max size with a uniform random subset.
        if len(loaded) > self.replay_size:
            loaded = random.sample(loaded, max(self.replay_size, 0))
        self._replay = loaded

    def clear(self):
        """Clear all in-memory buffers and counters (files on disk are untouched)."""
        self._rolling.clear()
        self._replay.clear()
        self._total_added = 0
        self._replay_seen = 0

    def stats(self) -> dict:
        """Return buffer statistics."""
        return {
            "rolling_count": self.rolling_count,
            "rolling_capacity": self.rolling_size,
            "replay_count": self.replay_count,
            "replay_capacity": self.replay_size,
            "total_added": self._total_added,
        }
|