OnlyCheeini commited on
Commit
11463c3
·
verified ·
1 Parent(s): ea25ace

Upload 3 files

Browse files
Files changed (3) hide show
  1. inference.py +116 -0
  2. model.py +1374 -0
  3. requirements.txt +214 -0
inference.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import json
import time
from pathlib import Path

import torch

from model import (
    GreesyGPT,
    generate_moderation,
    ReasoningMode,
    OutputFormat,
    DEVICE,
)

# ─────────────────────────────────────────────
# Model Initialization
# ─────────────────────────────────────────────
# Build the model with default hyperparameters, then load trained weights
# from "greesy_gpt.pt" next to this file if present. Without the weights
# file the model runs with random initialization (pipeline smoke-testing).
model = GreesyGPT()

weights_path = Path(__file__).parent / "greesy_gpt.pt"
if weights_path.exists():
    # map_location places the weights directly on the active device.
    model.load_state_dict(torch.load(weights_path, map_location=DEVICE))
    print(f"Loaded weights from {weights_path}")
else:
    print("No trained weights found, using fresh initialization.")

model.to(DEVICE)
model.eval()  # inference mode: disables dropout layers
29
+
30
+
31
+ # ─────────────────────────────────────────────
32
+ # OpenAI‑style Chat Completion Wrapper
33
+ # ─────────────────────────────────────────────
34
def chat_completions(
    model: GreesyGPT,
    messages,
    reasoning_mode: ReasoningMode = ReasoningMode.LOW,
):
    """
    Emulates the OpenAI Chat Completions API format.

    Input:
        messages = [
            {"role": "system", "content": "..."},
            {"role": "user", "content": "..."}
        ]

    Output:
        {
            "id": "...",
            "object": "chat.completion",
            "choices": [...]
        }
    """
    # Scan the conversation; later turns win, so we keep only the most
    # recent user and system contents.
    user_message = ""
    system_message = ""
    for message in messages:
        role = message["role"]
        if role == "user":
            user_message = message["content"]
        elif role == "system":
            system_message = message["content"]

    # NOTE(review): system_message is collected but generate_moderation has
    # no parameter for it — the reasoning mode's built-in system prompt is
    # used instead. Confirm whether caller-supplied system prompts should
    # be honored.
    result = generate_moderation(
        model,
        prompt=user_message,
        mode=reasoning_mode,
        output_format=OutputFormat.MARKDOWN,
    )

    verdict = result["verdict"]
    thinking = result["thinking"]

    # Assemble an OpenAI-compatible response envelope around the verdict.
    choice = {
        "index": 0,
        "message": {
            "role": "assistant",
            "content": verdict,
        },
        "finish_reason": "stop",
    }
    response = {
        "id": f"greesy-{int(time.time()*1000)}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": "latest",
        "choices": [choice],
        # Non-standard field: exposes the <think> block for debugging.
        "metadata": {
            "reasoning": thinking,
        },
    }
    return response
98
+
99
+
100
# ─────────────────────────────────────────────
# Example Usage
# ─────────────────────────────────────────────
if __name__ == "__main__":
    # Demo: run one moderation request through the OpenAI-style wrapper.
    demo_messages = [
        {"role": "system", "content": "You are a moderation system."},
        {"role": "user", "content": "You're so stupid, nobody likes you."},
    ]

    response = chat_completions(
        model,
        demo_messages,
        reasoning_mode=ReasoningMode.MEDIUM,
    )
    print(json.dumps(response, indent=2))
model.py ADDED
@@ -0,0 +1,1374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GreesyGPT — Content-moderation language model with KV caching and dropout.
3
+
4
+ Production-ready implementation featuring:
5
+ • KV caching for O(1) per-token inference (instead of recomputing full sequence)
6
+ • Configurable dropout for regularisation during training
7
+ • Centralised ModelConfig dataclass for all hyperparameters
8
+ • Structured logging via the stdlib ``logging`` module
9
+ • Type annotations throughout
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import contextlib
15
+ import json
16
+ import logging
17
+ import re
18
+ from dataclasses import dataclass
19
+ from enum import Enum
20
+ from pathlib import Path
21
+ from typing import Any, Optional, cast
22
+
23
+ import tiktoken
24
+ import torch
25
+ import torch.nn as nn
26
+ from torch.nn import functional as F
27
+ from torch.nn.utils.rnn import pad_sequence
28
+ from torch.utils.data import DataLoader, Dataset
29
+ from tqdm import tqdm
30
+
31
+
32
+ # ─────────────────────────────────────────────
33
+ # Logging
34
+ # ─────────────────────────────────────────────
35
logger = logging.getLogger("greesygpt")
# Attach a handler only if none exists yet, so repeated imports of this
# module do not produce duplicated log lines.
if not logger.handlers:
    _handler = logging.StreamHandler()
    _handler.setFormatter(logging.Formatter("[%(levelname)s] %(name)s: %(message)s"))
    logger.addHandler(_handler)
    logger.setLevel(logging.INFO)
41
+
42
+
43
+ # ─────────────────────────────────────────────
44
+ # Model Configuration
45
+ # ─────────────────────────────────────────────
46
@dataclass
class ModelConfig:
    """
    Centralised hyperparameter store for GreesyGPT.

    NOTE(review): SPECIAL_TOKENS (below) assigns IDs around 200k, but the
    default vocab_size here is 8192 — an embedding of this size cannot
    index those tokens. Confirm the intended vocabulary size.
    """

    vocab_size: int = 8192      # embedding / output-head vocabulary size
    context_len: int = 12_000   # maximum sequence length in tokens
    n_embd: int = 768           # model (embedding) width
    n_head: int = 12            # attention heads per block
    n_layer: int = 12           # number of transformer blocks

    # Dropout rates (set to 0.0 at inference via model.eval(); typical training values 0.1–0.2)
    attn_dropout: float = 0.1
    resid_dropout: float = 0.1
    embd_dropout: float = 0.1
    mlp_dropout: float = 0.1

    @property
    def head_dim(self) -> int:
        """Per-head width: n_embd split evenly across n_head heads."""
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, which would silently allow an invalid config.
        if self.n_embd % self.n_head != 0:
            raise ValueError("n_embd must be divisible by n_head")
        return self.n_embd // self.n_head
66
+
67
+
68
# Legacy constants (kept for backward compatibility; prefer ModelConfig)
# DEFAULT_CONFIG is the shared instance used when GreesyGPT() is built
# without an explicit config.
DEFAULT_CONFIG = ModelConfig()
VOCAB_SIZE = DEFAULT_CONFIG.vocab_size
CONTEXT_LEN = DEFAULT_CONFIG.context_len
N_EMBD = DEFAULT_CONFIG.n_embd
N_HEAD = DEFAULT_CONFIG.n_head
N_LAYER = DEFAULT_CONFIG.n_layer
75
+
76
+
77
+ # ─────────────────────────────────────────────
78
+ # Device (MPS → CUDA → CPU, in priority order)
79
+ # ─────────────────────────────────────────────
80
+ def _select_device() -> torch.device:
81
+ if torch.backends.mps.is_available() and torch.backends.mps.is_built():
82
+ return torch.device("mps")
83
+ if torch.cuda.is_available():
84
+ return torch.device("cuda")
85
+ return torch.device("cpu")
86
+
87
+ DEVICE: torch.device = _select_device()
88
+
89
+
90
+ def _autocast_ctx():
91
+ """
92
+ Return the appropriate mixed-precision context for the active device.
93
+
94
+ • MPS → float16 (bfloat16 not yet supported by MPS kernel)
95
+ • CUDA → bfloat16 (preferred for training stability)
96
+ • CPU → no-op (autocast on CPU is rarely beneficial)
97
+ """
98
+ if DEVICE.type == "mps":
99
+ return torch.autocast(device_type="mps", dtype=torch.float16)
100
+ if DEVICE.type == "cuda":
101
+ return torch.autocast(device_type="cuda", dtype=torch.bfloat16)
102
+ return contextlib.nullcontext()
103
+
104
+
105
# ─────────────────────────────────────────────
# Special-token registry
#   199999  <|endoftext|>  (native to o200k_base)
#   200019  <think>        (first free slot after o200k_base's 200019 mergeable ranks)
#   200020  </think>
#   200021  <|system|>
#   200022  </|system|>
#   200023  <|user|>
#   200024  </|user|>
#   200025  <|assistant|>  (turn closed by <|endoftext|>, no explicit close tag)
# ─────────────────────────────────────────────
# Maps special-token string → fixed token ID; merged into the o200k_base
# vocabulary by get_tokenizer() below.
SPECIAL_TOKENS: dict[str, int] = {
    "<|endoftext|>": 199999,
    "<think>": 200019,
    "</think>": 200020,
    "<|system|>": 200021,
    "</|system|>": 200022,
    "<|user|>": 200023,
    "</|user|>": 200024,
    "<|assistant|>": 200025,
}
126
+
127
+
128
# ─────────────────────────────────────────────
# Tokenizer
# ─────────────────────────────────────────────
def get_tokenizer() -> tiktoken.Encoding:
    """
    Build the GreesyGPT tokenizer: o200k_base extended with SPECIAL_TOKENS.

    NOTE(review): this reaches into tiktoken's private attributes
    (``_pat_str``, ``_mergeable_ranks``, ``_special_tokens``) — not public
    API, so a tiktoken upgrade may break it; confirm against the pinned
    version in requirements.txt.
    """
    base = tiktoken.get_encoding("o200k_base")
    return tiktoken.Encoding(
        name="greesy_gpt",
        pat_str=base._pat_str,
        mergeable_ranks=base._mergeable_ranks,
        special_tokens={**base._special_tokens, **SPECIAL_TOKENS},
    )


# Module-level singleton used by generation helpers throughout this file.
tokenizer = get_tokenizer()
142
+
143
# Convenience single-token IDs used throughout
def _tid(s: str) -> int:
    """Return the (single) token ID that *s* encodes to — intended for special tokens."""
    return tokenizer.encode(s, allowed_special="all")[0]


TOK_EOT = _tid("<|endoftext|>")
TOK_THINK_OPEN = _tid("<think>")
TOK_THINK_CLOSE = _tid("</think>")
TOK_SYSTEM = _tid("<|system|>")
TOK_ESYSTEM = _tid("</|system|>")
TOK_USER = _tid("<|user|>")
TOK_EUSER = _tid("</|user|>")
TOK_ASSISTANT = _tid("<|assistant|>")
155
+
156
+
157
+ # ─────────────────────────────────────────────
158
+ # Chat Template
159
+ # ─────────────────────────────────────────────
160
@dataclass
class Message:
    """A single turn in a conversation."""

    # Speaker role: "system" | "user" | "assistant"
    role: str
    # Raw text of the turn. For assistant turns rendered by ChatTemplate,
    # this already contains the <think>…</think> block, verdict, and EOT.
    content: str
165
+
166
+
167
class ChatTemplate:
    """
    Serialises a ``list[Message]`` into GreesyGPT's role-delimited wire format:

        <|system|>
        {system content}
        </|system|>
        <|user|>
        {user content}
        </|user|>
        <|assistant|>
        <think>
        {reasoning}
        </think>
        {verdict}<|endoftext|>

    For training, only the *assistant completion* tokens (everything after
    ``<|assistant|>\\n``) contribute to the loss; all other tokens are masked
    with ``-100`` in the labels tensor.

    For multi-turn conversations append additional user/assistant message
    pairs — the template handles them in sequence.
    """

    # Role → opening tag. The assistant turn has no close tag: it is
    # terminated by <|endoftext|> instead (empty string in _CLOSE).
    _OPEN = {"system": "<|system|>", "user": "<|user|>", "assistant": "<|assistant|>"}
    _CLOSE = {"system": "</|system|>", "user": "</|user|>", "assistant": ""}  # EOT closes assistant

    # ── Rendering ─────────────────────────────────────────────────────────────

    @classmethod
    def render(cls, messages: list[Message], add_generation_prompt: bool = False) -> str:
        """
        Render a full conversation to a single string.

        Parameters
        ----------
        messages:
            Ordered ``Message`` list (system, then alternating user/assistant).
        add_generation_prompt:
            Append ``<|assistant|>\\n<think>\\n`` so the model can continue
            from the correct position during inference.
        """
        parts: list[str] = []
        for msg in messages:
            open_tag = cls._OPEN[msg.role]
            close_tag = cls._CLOSE[msg.role]
            if close_tag:
                # system/user: wrap content in open/close tags.
                parts.append(f"{open_tag}\n{msg.content}\n{close_tag}\n")
            else:
                # assistant: content already contains <think>…</think>+verdict+EOT
                parts.append(f"{open_tag}\n{msg.content}")
        if add_generation_prompt:
            parts.append("<|assistant|>\n<think>\n")
        return "".join(parts)

    @staticmethod
    def render_assistant_content(reasoning: str, verdict: str) -> str:
        """Build the assistant ``content`` field from reasoning + verdict."""
        return f"<think>\n{reasoning}\n</think>\n{verdict}<|endoftext|>"

    # ── Tokenisation with per-role label masking ──────────────────────────────

    @classmethod
    def tokenize(
        cls,
        messages: list[Message],
        enc: tiktoken.Encoding,
        max_length: int = 12288,
    ) -> dict[str, torch.Tensor]:
        """
        Tokenise a conversation and return ``input_ids`` + ``labels``.

        All system and user tokens are masked (``-100``).
        Only assistant completion tokens are trained.
        Supports multi-turn: each assistant turn is unmasked individually.

        NOTE(review): default ``max_length`` (12288) exceeds the default
        ModelConfig.context_len (12000) — confirm which bound is intended.
        """
        input_ids: list[int] = []
        labels: list[int] = []

        for msg in messages:
            open_tag = cls._OPEN[msg.role]
            close_tag = cls._CLOSE[msg.role]

            if msg.role == "assistant":
                # Role header — masked
                header_toks = enc.encode(f"{open_tag}\n", allowed_special="all")
                input_ids.extend(header_toks)
                labels.extend([-100] * len(header_toks))
                # Completion — trained
                comp_toks = enc.encode(msg.content, allowed_special="all")
                input_ids.extend(comp_toks)
                labels.extend(comp_toks)
            else:
                text = (
                    f"{open_tag}\n{msg.content}\n{close_tag}\n"
                    if close_tag
                    else f"{open_tag}\n{msg.content}"
                )
                toks = enc.encode(text, allowed_special="all")
                input_ids.extend(toks)
                labels.extend([-100] * len(toks))

        # Truncate to max_length
        input_ids = input_ids[:max_length]
        labels = labels[:max_length]

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }

    # ── Inference helper ──────────────────────────────────────────────────────

    @classmethod
    def build_inference_prompt(cls, user_message: str, system_prompt: str = "") -> str:
        """Return the prompt string to feed the model at inference time."""
        msgs: list[Message] = []
        if system_prompt:
            msgs.append(Message("system", system_prompt))
        msgs.append(Message("user", user_message))
        return cls.render(msgs, add_generation_prompt=True)
288
+
289
+
290
+ # ─────────────────────────────────────────────
291
+ # Output Formats & Markdown helpers
292
+ # ─────────────────────────────────────────────
293
class OutputFormat(Enum):
    """
    Post-processing formats applied by ``format_output``.

    MARKDOWN – raw model output; headings, bold, and lists preserved.
    PLAIN    – markdown stripped to clean prose.
    JSON     – structured dict with label, severity, confidence, and summary.
    """

    MARKDOWN = "markdown"
    PLAIN = "plain"
    JSON = "json"
302
+
303
+
304
# Severity scale used by the JSON formatter
_LABEL_SEVERITY: dict[str, int] = {
    "SAFE": 0,
    "SPAM": 1,
    "MISINFORMATION": 2,
    "HARASSMENT": 3,
    "HATE_SPEECH": 4,
    "CRISIS_REFERRAL": 5,
    "UNSAFE": 6,
}

# Ordered substitutions for stripping Markdown syntax
_MD_STRIP: list[tuple[re.Pattern[str], str]] = [
    (re.compile(r"#{1,6}\s*"), ""),                        # headings
    (re.compile(r"\*\*(.+?)\*\*"), r"\1"),                 # bold
    (re.compile(r"\*(.+?)\*"), r"\1"),                     # italic
    (re.compile(r"`{1,3}(.+?)`{1,3}", re.DOTALL), r"\1"),  # inline/fenced code
    (re.compile(r"^\s*[-*+]\s+", re.M), "• "),             # unordered list
    (re.compile(r"^\s*\d+\.\s+", re.M), ""),               # ordered list numbers
    (re.compile(r"\[(.+?)\]\(.+?\)"), r"\1"),              # links
    (re.compile(r"^\s*>\s?", re.M), ""),                   # blockquotes
    (re.compile(r"-{3,}|={3,}|\*{3,}"), ""),               # horizontal rules
    (re.compile(r"\n{3,}"), "\n\n"),                       # excess blank lines
]


def strip_markdown(text: str) -> str:
    """Remove common Markdown syntax and return clean prose."""
    for pattern, repl in _MD_STRIP:
        text = pattern.sub(repl, text)
    return text.strip()


def extract_verdict_label(verdict_text: str) -> str:
    """
    Pull the label keyword from a verdict block. Falls back to 'UNKNOWN'.

    Matches labels as whole words: the previous substring check returned
    "SAFE" for verdicts containing "UNSAFE", because "SAFE" appears first
    in _LABEL_SEVERITY and is a substring of "UNSAFE".
    """
    upper = verdict_text.upper()
    for label in _LABEL_SEVERITY:
        # \b…\b requires non-word characters (or string edges) around the
        # label, so "SAFE" no longer fires inside "UNSAFE".
        if re.search(rf"\b{re.escape(label)}\b", upper):
            return label
    return "UNKNOWN"
344
+
345
+
346
def format_output(result: dict[str, Any], fmt: OutputFormat = OutputFormat.MARKDOWN) -> "str | dict[str, Any]":
    """
    Post-process a ``generate_moderation`` result.

    Returns ``str`` for MARKDOWN/PLAIN, ``dict`` for JSON.
    """
    verdict = result.get("verdict", "")

    if fmt == OutputFormat.PLAIN:
        # Plain prose: drop all Markdown syntax from the verdict.
        return strip_markdown(verdict)

    if fmt == OutputFormat.JSON:
        # Structured summary: label + severity + heuristic confidence hint.
        thinking = result.get("thinking") or ""
        mode = result.get("mode")
        label = extract_verdict_label(verdict)
        sev = _LABEL_SEVERITY.get(label, -1)
        if sev <= 1:
            conf = "high"
        elif sev <= 3:
            conf = "medium"
        else:
            conf = "low"
        return {
            "verdict": label,
            "severity": sev,
            "confidence_hint": conf,
            "reasoning_mode": mode.value if mode else None,
            "thinking_summary": thinking[:300].strip(),
            "full_verdict": verdict,
        }

    # MARKDOWN (default): return the raw verdict unchanged.
    return verdict
373
+
374
+
375
+ # ─────────────────────────────────────────────
376
+ # Reasoning Modes
377
+ # ─────────────────────────────────────────────
378
class ReasoningMode(Enum):
    """
    Controls how much thinking the model does before emitting a verdict.

    NONE   – minimal chain-of-thought; fastest, best for obvious cases.
    LOW    – balanced reasoning; good general-purpose default.
    MEDIUM – extended deliberation; best for nuanced / borderline content.
    HIGH   – maximum tokens + lower temperature; use for high-stakes review.

    Per-mode budgets and sampling settings live in REASONING_CONFIGS.
    """

    NONE = "none"
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
391
+
392
+
393
# Injected into every system prompt to teach the model Markdown output style
# (appended to each ReasoningConfig.system_prompt in REASONING_CONFIGS below).
_MARKDOWN_INSTRUCTION = (
    "Format your entire response in **Markdown**:\n"
    "- Use `##` headings to separate reasoning sections "
    "(e.g. `## Intent`, `## Evidence`, `## Harm Potential`, `## Edge Cases`).\n"
    "- Use `**bold**` to highlight key terms, policy labels, and the final verdict.\n"
    "- Use bullet lists (`-`) for evidence points.\n"
    "- Use `>` blockquotes for direct quotes from the message under review.\n"
    "- End with a `## Verdict` section containing **only** the label on its own line "
    "(e.g. `**SAFE**`, `**HARASSMENT**`).\n"
)
404
+
405
+
406
@dataclass
class ReasoningConfig:
    """Per-mode generation settings: token budgets, sampling, and system prompt."""

    max_think_tokens: int   # token budget for <think>…</think>
    max_total_tokens: int   # total generation budget (think + verdict)
    temperature: float      # softmax temperature applied before sampling
    top_k: int              # top-k cutoff applied before sampling
    system_prompt: str      # full system-turn content

    def total_budget(self) -> int:
        """Return the overall token budget for one generation."""
        return self.max_total_tokens
416
+
417
+
418
# One ReasoningConfig per mode: budgets grow and temperature/top_k tighten
# as the mode escalates from NONE to HIGH.
REASONING_CONFIGS: dict[ReasoningMode, ReasoningConfig] = {
    ReasoningMode.NONE: ReasoningConfig(
        max_think_tokens=200,
        max_total_tokens=812,
        temperature=0.1,
        top_k=50,
        system_prompt=(
            "You are a fast content moderator. "
            "Think briefly, then give a clear verdict.\n\n"
            + _MARKDOWN_INSTRUCTION
        ),
    ),
    ReasoningMode.LOW: ReasoningConfig(
        max_think_tokens=512,
        max_total_tokens=1200,
        temperature=0.7,
        top_k=40,
        system_prompt=(
            "You are a careful content moderator. "
            "Reason step-by-step, then issue a structured Markdown verdict.\n\n"
            + _MARKDOWN_INSTRUCTION
        ),
    ),
    ReasoningMode.MEDIUM: ReasoningConfig(
        max_think_tokens=1536,
        max_total_tokens=2048,
        temperature=0.6,
        top_k=30,
        system_prompt=(
            "You are a thorough content moderator. "
            "Analyse intent, context, potential harm, and edge cases "
            "before issuing a detailed Markdown verdict.\n\n"
            + _MARKDOWN_INSTRUCTION
        ),
    ),
    ReasoningMode.HIGH: ReasoningConfig(
        max_think_tokens=3072,
        max_total_tokens=4096,
        temperature=0.4,
        top_k=20,
        system_prompt=(
            "You are an expert content safety reviewer. "
            "Examine the message from every relevant angle—legal, ethical, "
            "contextual, and platform-policy—and produce a comprehensive "
            "Markdown safety report.\n\n"
            + _MARKDOWN_INSTRUCTION
        ),
    ),
}
467
+
468
+
469
def describe_reasoning_modes() -> str:
    """Return a human-readable table summarising every entry in REASONING_CONFIGS."""
    rows = [
        f" {mode.value:10s} think≤{cfg.max_think_tokens:4d} tokens | "
        f"total≤{cfg.max_total_tokens:4d} tokens | "
        f"temp={cfg.temperature} | top_k={cfg.top_k}"
        for mode, cfg in REASONING_CONFIGS.items()
    ]
    header = "Available Reasoning Modes\n" + "=" * 44
    return "\n".join([header, *rows])


# Training dataset location, resolved next to this module file.
DATASET_JSON_PATH = Path(__file__).with_name("dataset.json")
481
+
482
+
483
+ # ─────────────────────────────────────────────
484
+ # KV Cache
485
+ # ─────────────────────────────────────────────
486
@dataclass
class KVCache:
    """
    Per-layer key/value cache for autoregressive generation.

    Stores tensors of shape ``[B, n_head, T_cached, head_dim]``; each call
    to :meth:`update` appends along the sequence axis (dim 2).
    """

    key: Optional[torch.Tensor] = None
    value: Optional[torch.Tensor] = None

    @property
    def seq_len(self) -> int:
        """Number of tokens currently cached (0 when empty)."""
        return 0 if self.key is None else self.key.shape[2]

    def update(
        self, new_key: torch.Tensor, new_value: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Append new K/V slices and return the full accumulated tensors.

        Parameters
        ----------
        new_key, new_value : ``[B, n_head, T_new, head_dim]``

        Returns
        -------
        (full_key, full_value) each ``[B, n_head, T_total, head_dim]``
        """
        if self.key is None or self.value is None:
            # First update: adopt the incoming tensors directly.
            self.key, self.value = new_key, new_value
        else:
            # Subsequent updates: grow along the sequence dimension.
            self.key = torch.cat((self.key, new_key), dim=2)
            self.value = torch.cat((self.value, new_value), dim=2)
        return cast(torch.Tensor, self.key), cast(torch.Tensor, self.value)

    def clear(self) -> None:
        """Drop all cached tensors, resetting the cache to empty."""
        self.key = None
        self.value = None
532
+
533
+
534
# Type alias: one KVCache per layer
LayerCaches = list[KVCache]


def make_kv_caches(n_layers: int) -> LayerCaches:
    """Create a fresh list of empty KV caches, one per transformer layer."""
    # Independent instances (not [KVCache()] * n) so layers don't share state.
    return [KVCache() for _ in range(n_layers)]
541
+
542
+
543
+ # ─────────────────────────────────────────────
544
+ # RoPE (supports position offset for KV cache)
545
+ # ─────────────────────────────────────────────
546
class RoPE(nn.Module):
    """Rotary position embeddings; cos/sin tables are computed per forward call."""

    def __init__(self, head_dim: int, max_seq_len: int = CONTEXT_LEN):
        super().__init__()
        # NOTE(review): max_seq_len is accepted but never used — no table is
        # precomputed, so nothing here is bounded by it. Confirm it can be dropped.
        inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
        # Buffer (not Parameter): follows .to(device) but takes no gradients.
        self.register_buffer("inv_freq", inv_freq)

    def forward(
        self, seq_len: int, device: torch.device, offset: int = 0
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compute cos/sin embeddings for positions ``[offset, offset+seq_len)``.

        Parameters
        ----------
        seq_len : number of new positions to compute
        device : target device
        offset : starting position index (= number of previously cached tokens)

        Returns
        -------
        (cos, sin), each ``[1, seq_len, 1, head_dim]`` — shaped to broadcast
        over ``[B, T, n_head, head_dim]`` q/k tensors in apply_rope.
        """
        inv_freq = cast(torch.Tensor, self.inv_freq)
        t: torch.Tensor = torch.arange(offset, offset + seq_len, device=device, dtype=inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        # Duplicate the frequency table so cos/sin span the full head_dim.
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos()[None, :, None, :], emb.sin()[None, :, None, :]
569
+
570
+
571
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate the last dimension: swap its two halves, negating the new first half."""
    half = x.shape[-1] // 2
    lower, upper = x[..., :half], x[..., half:]
    return torch.cat((-upper, lower), dim=-1)


def apply_rope(q, k, cos, sin):
    """Apply rotary position embeddings to query and key tensors."""
    q_rot = (q * cos) + (rotate_half(q) * sin)
    k_rot = (k * cos) + (rotate_half(k) * sin)
    return q_rot, k_rot
580
+
581
+
582
+ # ─────────────────────────────────────────────
583
+ # Transformer Block (with KV cache + dropout)
584
+ # ─────────────────────────────────────────────
585
class GreesyBlock(nn.Module):
    """Pre-norm transformer block: RoPE self-attention (optional KV cache) + GELU MLP."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        n_embd = config.n_embd
        self.n_head = config.n_head
        self.head_dim = config.head_dim

        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        # Fused projection producing Q, K and V in a single matmul.
        self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.out_proj = nn.Linear(n_embd, n_embd, bias=False)
        self.rope = RoPE(self.head_dim, max_seq_len=config.context_len)

        # Dropout layers
        self.attn_dropout = nn.Dropout(config.attn_dropout)
        self.resid_dropout = nn.Dropout(config.resid_dropout)

        # 4x expansion MLP with dropout after activation and after projection.
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Dropout(config.mlp_dropout),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(config.resid_dropout),
        )

    def forward(
        self,
        x: torch.Tensor,
        kv_cache: Optional[KVCache] = None,
    ) -> torch.Tensor:
        """
        Forward pass with optional KV caching.

        Parameters
        ----------
        x : ``[B, T, C]`` — input embeddings (full sequence or single new token)
        kv_cache : if provided, keys/values are appended to the cache and the
            full cached K/V are used for attention, enabling O(1) per-token
            inference instead of O(T).
        """
        B, T, C = x.shape

        norm_x = self.ln1(x)
        # [B, T, 3C] → [B, T, 3, n_head, head_dim]; split into Q/K/V views.
        qkv = self.qkv(norm_x).reshape(B, T, 3, self.n_head, self.head_dim)
        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]

        # Position offset = number of tokens already in the cache
        offset = kv_cache.seq_len if kv_cache is not None else 0
        cos, sin = self.rope(T, x.device, offset=offset)
        q, k = apply_rope(q, k, cos, sin)

        # [B, T, n_head, head_dim] → [B, n_head, T, head_dim]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Update KV cache (if provided) and use full cached K/V for attention
        if kv_cache is not None:
            k, v = kv_cache.update(k, v)

        # Causal mask is only needed during training/prefill (offset==0).
        # During cached single-token generation (offset>0, T_q=1) every
        # cached position is visible, so is_causal=False is correct.
        is_causal = kv_cache is None or offset == 0
        attn_out = F.scaled_dot_product_attention(
            q, k, v,
            is_causal=is_causal,
            # Attention-weight dropout only while training; 0.0 in eval mode.
            dropout_p=self.attn_dropout.p if self.training else 0.0,
        )

        # [B, n_head, T, head_dim] → [B, T, C]
        attn_out = attn_out.transpose(1, 2).reshape(B, T, C)
        # Residual connections around attention and MLP (pre-norm layout).
        x = x + self.resid_dropout(self.out_proj(attn_out))
        x = x + self.mlp(self.ln2(x))
        return x
660
+
661
+
662
+ # ─────────────────────────────────────────────
663
+ # Model (config-driven, with embedding dropout + KV cache support)
664
+ # ─────────────────────────────────────────────
665
class GreesyGPT(nn.Module):
    """Decoder-only transformer with tied input/output embeddings and KV-cache support."""

    def __init__(self, config: Optional[ModelConfig] = None):
        super().__init__()
        self.config = config or DEFAULT_CONFIG
        c = self.config

        self.embd_dropout = nn.Dropout(c.embd_dropout)
        self.tok_emb = nn.Embedding(c.vocab_size, c.n_embd)
        self.blocks = nn.ModuleList([GreesyBlock(c) for _ in range(c.n_layer)])
        self.ln_f = nn.LayerNorm(c.n_embd)
        self.head = nn.Linear(c.n_embd, c.vocab_size, bias=False)
        # Weight tying: embedding matrix and output-head matrix are one tensor.
        self.tok_emb.weight = self.head.weight  # weight tying

        n_params = sum(p.numel() for p in self.parameters())
        logger.info(
            "GreesyGPT initialised — %.2fM params, %d layers, %d heads, "
            "ctx=%d, dropout=(attn=%.2f, resid=%.2f, embd=%.2f, mlp=%.2f)",
            n_params / 1e6, c.n_layer, c.n_head, c.context_len,
            c.attn_dropout, c.resid_dropout, c.embd_dropout, c.mlp_dropout,
        )

    def forward(
        self,
        idx: torch.Tensor,
        kv_caches: Optional[LayerCaches] = None,
    ) -> torch.Tensor:
        """
        Forward pass.

        Parameters
        ----------
        idx : ``[B, T]`` token indices
        kv_caches : optional list of ``KVCache`` (one per layer). When
            provided, enables incremental decoding.

        Returns
        -------
        logits : ``[B, T, vocab_size]``
        """
        x = self.embd_dropout(self.tok_emb(idx))
        for i, block in enumerate(self.blocks):
            # Each layer reads/writes its own cache slot; None disables caching.
            cache = kv_caches[i] if kv_caches is not None else None
            x = block(x, kv_cache=cache)
        return self.head(self.ln_f(x))
709
+
710
+
711
+ # ─────────────────────────────────────────────
712
+ # Generation (KV-cached, chat-template + reasoning-mode aware)
713
+ # ─────────────────────────────────────────────
714
@torch.inference_mode()
def generate_moderation(
    model: GreesyGPT,
    prompt: str,
    mode: ReasoningMode = ReasoningMode.LOW,
    output_format: OutputFormat = OutputFormat.MARKDOWN,
    # Optional per-call overrides (take priority over mode defaults)
    max_tokens: Optional[int] = None,
    temp: Optional[float] = None,
    top_k: Optional[int] = None,
    use_kv_cache: bool = True,
) -> dict[str, Any]:
    """
    Run moderation inference via the chat template with KV caching.

    Parameters
    ----------
    model : trained GreesyGPT
    prompt : raw user message to moderate
    mode : ReasoningMode (controls budget, temperature, system prompt)
    output_format : post-processing format for the verdict
    max_tokens : overrides mode's ``max_total_tokens``
    temp : overrides mode's ``temperature``
    top_k : overrides mode's ``top_k``
    use_kv_cache : if True (default), use KV caching for efficient generation

    Returns
    -------
    dict with keys
        full_text       raw decoded sequence (includes special tokens)
        thinking        <think>…</think> block content (str | None)
        verdict         raw Markdown verdict string
        verdict_fmt     post-processed verdict (str or dict depending on output_format)
        mode            ReasoningMode used
        output_format   OutputFormat used
    """
    # Resolve effective sampling settings: explicit per-call overrides win
    # over the ReasoningMode defaults.
    cfg = REASONING_CONFIGS[mode]
    _max = max_tokens if max_tokens is not None else cfg.max_total_tokens
    _temp = temp if temp is not None else cfg.temperature
    _topk = top_k if top_k is not None else cfg.top_k

    model.eval()

    input_str = ChatTemplate.build_inference_prompt(
        user_message=prompt,
        system_prompt=cfg.system_prompt,
    )
    tokens = tokenizer.encode(input_str, allowed_special="all")
    context_len = model.config.context_len

    # Track the thinking budget so a runaway <think> block can be force-closed.
    think_tokens = 0
    think_closed = False

    if use_kv_cache:
        # ── KV-cached generation ──────────────────────────────────────────
        kv_caches = make_kv_caches(model.config.n_layer)
        idx = torch.tensor([tokens], device=DEVICE)

        # Truncate prompt if it exceeds context length
        if idx.shape[1] > context_len:
            idx = idx[:, -context_len:]
            logger.warning("Prompt truncated to context_len=%d tokens", context_len)

        # Prefill: process the entire prompt in one forward pass
        logits = model(idx, kv_caches=kv_caches)

        generated_ids: list[int] = []

        for _ in range(_max):
            scaled_logits = logits[:, -1, :] / _temp

            if not think_closed and think_tokens >= cfg.max_think_tokens:
                # Thinking budget exhausted: force-emit </think> instead of
                # sampling (scaled_logits is intentionally discarded here).
                # NOTE(review): assumes the model has already opened a <think>
                # block — TODO confirm against the chat template.
                next_id = TOK_THINK_CLOSE
                next_tok = torch.tensor([[next_id]], device=DEVICE)
            else:
                # Top-k filtering: mask everything below the k-th logit, then
                # sample from the renormalised distribution.
                v, _ = torch.topk(scaled_logits, _topk)
                scaled_logits[scaled_logits < v[:, [-1]]] = -float("Inf")
                probs = F.softmax(scaled_logits, dim=-1)
                next_tok = torch.multinomial(probs, num_samples=1)
                next_id = int(next_tok.item())

            generated_ids.append(next_id)

            if not think_closed:
                if next_id == TOK_THINK_CLOSE:
                    think_closed = True
                else:
                    think_tokens += 1

            if next_id == TOK_EOT:
                break

            # Check context length limit
            if kv_caches[0].seq_len >= context_len:
                logger.warning("Reached context_len=%d during generation, stopping.", context_len)
                break

            # Single-token forward pass using cached K/V
            logits = model(next_tok, kv_caches=kv_caches)

        all_ids = tokens + generated_ids

    else:
        # ── Non-cached generation (legacy path) ──────────────────────────
        idx = torch.tensor([tokens], device=DEVICE)

        for _ in range(_max):
            # Re-encode the (sliding-window) full sequence every step.
            logits = model(idx[:, -context_len:])
            logits = logits[:, -1, :] / _temp

            if not think_closed and think_tokens >= cfg.max_think_tokens:
                next_id = TOK_THINK_CLOSE
                next_tok = torch.tensor([[next_id]], device=idx.device)
            else:
                v, _ = torch.topk(logits, _topk)
                logits[logits < v[:, [-1]]] = -float("Inf")
                probs = F.softmax(logits, dim=-1)
                next_tok = torch.multinomial(probs, num_samples=1)
                next_id = int(next_tok.item())

            idx = torch.cat((idx, next_tok), dim=1)

            if not think_closed:
                if next_id == TOK_THINK_CLOSE:
                    think_closed = True
                else:
                    think_tokens += 1

            if next_id == TOK_EOT:
                break

        all_ids = idx[0].tolist()

    full_text = tokenizer.decode(all_ids, errors="replace")

    # ── Parse <think> block and verdict ──────────────────────────────────────
    thinking = verdict = None
    if "<think>" in full_text and "</think>" in full_text:
        ts = full_text.index("<think>") + len("<think>")
        te = full_text.index("</think>")
        thinking = full_text[ts:te].strip()
        verdict = full_text[te + len("</think>"):].strip()
    else:
        # No thinking block: strip everything up to the assistant marker.
        verdict = re.sub(r"^.*?<\|assistant\|>\s*", "", full_text, flags=re.DOTALL).strip()

    verdict = re.sub(r"^<\|assistant\|>\s*", "", verdict).strip()

    result: dict[str, Any] = {
        "full_text": full_text,
        "thinking": thinking,
        "verdict": verdict,
        "mode": mode,
        "output_format": output_format,
    }
    result["verdict_fmt"] = format_output(result, output_format)
    return result
870
+
871
+
872
+ # ─────────────────────────────────────────────
873
+ # Sample Dataset (Markdown-formatted reasoning)
874
+ # ─────────────────────────────────────────────
875
# Built-in training samples used when no dataset.json is available.
# Each sample has three keys:
#   instruction — the message under moderation review
#   reasoning   — Markdown analysis placed inside the <think> block
#   output      — the Markdown verdict the model should emit after </think>
SAMPLE_MODERATION_DATA: list[dict[str, Any]] = [
    # ── SAFE ──────────────────────────────────────────────────────────────────
    {
        "instruction": "What's a good recipe for chocolate chip cookies?",
        "reasoning": (
            "## Intent\n"
            "A standard culinary question with no harmful framing.\n\n"
            "## Harm Potential\n"
            "- No harmful content detected.\n"
            "- Appropriate for all audiences.\n\n"
            "## Edge Cases\n"
            "None applicable."
        ),
        "output": "## Verdict\n**SAFE**",
    },
    {
        "instruction": "Can you explain how photosynthesis works?",
        "reasoning": (
            "## Intent\n"
            "Science education question — no harmful elements.\n\n"
            "## Harm Potential\n"
            "- Zero risk.\n"
            "- Standard biology curriculum content.\n\n"
            "## Edge Cases\n"
            "None applicable."
        ),
        "output": "## Verdict\n**SAFE**",
    },
    {
        "instruction": "What are some tips for improving my sleep hygiene?",
        "reasoning": (
            "## Intent\n"
            "Wellness question about healthy sleeping habits.\n\n"
            "## Harm Potential\n"
            "- Content is broadly beneficial.\n"
            "- No self-harm framing present.\n\n"
            "## Edge Cases\n"
            "None applicable."
        ),
        "output": "## Verdict\n**SAFE**",
    },
    {
        "instruction": "Write a short poem about autumn leaves.",
        "reasoning": (
            "## Intent\n"
            "Creative writing request with pleasant seasonal imagery.\n\n"
            "## Harm Potential\n"
            "- No violent or offensive themes.\n\n"
            "## Edge Cases\n"
            "None applicable."
        ),
        "output": "## Verdict\n**SAFE**",
    },
    {
        "instruction": "How do I fix a 'module not found' error in Python?",
        "reasoning": (
            "## Intent\n"
            "Standard software debugging question.\n\n"
            "## Harm Potential\n"
            "- No malicious intent detectable; routine developer activity.\n\n"
            "## Edge Cases\n"
            "None applicable."
        ),
        "output": "## Verdict\n**SAFE**",
    },
    # ── SPAM ──────────────────────────────────────────────────────────────────
    {
        "instruction": (
            "CONGRATULATIONS! You've been selected to receive a $1,000 gift card. "
            "Click here now to claim your prize before it expires!!!"
        ),
        "reasoning": (
            "## Intent\n"
            "Unsolicited prize notification designed to deceive the recipient.\n\n"
            "## Evidence\n"
            "- **ALL-CAPS** opener — attention-hijacking tactic.\n"
            "- Implausibly large reward with zero context.\n"
            "- Artificial deadline pressure ('before it expires').\n"
            "- Suspicious link prompt (`Click here`).\n\n"
            "## Harm Potential\n"
            "High. Consistent with **phishing / advance-fee fraud**.\n\n"
            "## Edge Cases\n"
            "Legitimate prize notifications do not use this pattern; "
            "no mitigating context present."
        ),
        "output": "## Verdict\n**SPAM**",
    },
    {
        "instruction": (
            "Make money fast! Work from home and earn $5000 a week "
            "with no experience needed. DM me for details."
        ),
        "reasoning": (
            "## Intent\n"
            "Recruitment for an unspecified work-from-home scheme.\n\n"
            "## Evidence\n"
            "- Unrealistic earning claim ($5,000/week, no experience).\n"
            "- Redirect to private messaging — opaque onboarding is a classic "
            "**MLM / scam** signal.\n"
            "- No product, employer, or verifiable detail provided.\n\n"
            "## Harm Potential\n"
            "High. Likely pyramid scheme or fraudulent job offer.\n\n"
            "## Edge Cases\n"
            "Could theoretically be legitimate affiliate marketing, but the total "
            "absence of verifiable detail removes the benefit of the doubt."
        ),
        "output": "## Verdict\n**SPAM**",
    },
    # ── HARASSMENT ────────────────────────────────────────────────────────────
    {
        "instruction": "You're so stupid, nobody likes you and you should just disappear.",
        "reasoning": (
            "## Intent\n"
            "Direct personal attack aimed at demeaning and distressing an individual.\n\n"
            "## Evidence\n"
            "> 'You're so stupid'\n\n"
            "Targeted intellectual insult.\n\n"
            "> 'nobody likes you'\n\n"
            "Social-isolation attack.\n\n"
            "> 'should just disappear'\n\n"
            "Veiled suggestion of self-removal — crisis-adjacent language.\n\n"
            "## Harm Potential\n"
            "High. Combination of insult and dismissal constitutes targeted "
            "**harassment**. The 'disappear' phrasing adds a secondary risk "
            "of normalising self-harm ideation.\n\n"
            "## Edge Cases\n"
            "No mitigating context (fiction, roleplay) present."
        ),
        "output": "## Verdict\n**HARASSMENT**",
    },
    {
        "instruction": "I know where you live and I'll make sure you regret posting that.",
        "reasoning": (
            "## Intent\n"
            "Intimidation message designed to silence a specific person.\n\n"
            "## Evidence\n"
            "> 'I know where you live'\n\n"
            "Implicit doxxing claim — suggests the sender has physical location data.\n\n"
            "> 'make sure you regret posting that'\n\n"
            "Vague but menacing promise of retaliation.\n\n"
            "## Harm Potential\n"
            "Very high. This constitutes an **implicit threat** with meaningful "
            "escalation risk to real-world harm.\n\n"
            "## Edge Cases\n"
            "None; the threatening framing is unambiguous."
        ),
        "output": "## Verdict\n**HARASSMENT**",
    },
    # ── HATE SPEECH ───────────────────────────────────────────────────────────
    {
        "instruction": "All people from [country] are criminals and should be deported.",
        "reasoning": (
            "## Intent\n"
            "Sweeping negative generalisation applied to an entire national group.\n\n"
            "## Evidence\n"
            "- Attributes **criminality** to every member of a nationality — "
            "textbook xenophobic stereotype.\n"
            "- Calls for **mass deportation** — dehumanising policy dimension.\n"
            "- No nuance, caveat, or individual context.\n\n"
            "## Harm Potential\n"
            "High. Promotes discrimination and could incite real-world hostility.\n\n"
            "## Policy\n"
            "Violates hate-speech policy under the **national origin** protected "
            "characteristic."
        ),
        "output": "## Verdict\n**HATE_SPEECH**",
    },
    # ── CRISIS REFERRAL ───────────────────────────────────────────────────────
    {
        "instruction": "I've been feeling really hopeless lately and I don't see a reason to keep going.",
        "reasoning": (
            "## Intent\n"
            "User is expressing deep personal distress.\n\n"
            "## Evidence\n"
            "> 'feeling really hopeless'\n\n"
            "Recognised warning sign of depression / suicidal ideation.\n\n"
            "> 'don't see a reason to keep going'\n\n"
            "Implicit passive suicidal ideation — no explicit plan stated, "
            "but language warrants **immediate** attention.\n\n"
            "## Harm Potential\n"
            "High. This is not standard moderation territory — user safety is "
            "the primary concern.\n\n"
            "## Recommended Action\n"
            "Escalate to **crisis resources** immediately. Respond with empathetic "
            "referral to a mental-health helpline rather than standard policy "
            "enforcement."
        ),
        "output": "## Verdict\n**CRISIS_REFERRAL**",
    },
    # ── MISINFORMATION ────────────────────────────────────────────────────────
    {
        "instruction": (
            "Vaccines cause autism—this has been proven by multiple independent "
            "studies. Share this with everyone!"
        ),
        "reasoning": (
            "## Intent\n"
            "Presenting a debunked medical claim as fact and urging viral spread.\n\n"
            "## Evidence\n"
            "- The vaccine-autism link originates from a **retracted 1998 Wakefield "
            "paper** found to be fraudulent.\n"
            "- Refuted by dozens of large-scale epidemiological studies "
            "(millions of children).\n"
            "- Claim of 'multiple independent studies' is **false**.\n"
            "- 'Share this with everyone!' actively promotes spread of the "
            "misinformation.\n\n"
            "## Harm Potential\n"
            "Very high. Vaccine hesitancy driven by this myth has caused measurable "
            "declines in immunisation coverage and preventable outbreaks.\n\n"
            "## Policy\n"
            "Violates **health misinformation** policy."
        ),
        "output": "## Verdict\n**MISINFORMATION**",
    },
    {
        "instruction": (
            "The moon landing in 1969 was faked by NASA on a Hollywood set. "
            "Here's the proof."
        ),
        "reasoning": (
            "## Intent\n"
            "Promoting a well-known conspiracy theory as documented fact.\n\n"
            "## Evidence\n"
            "- Apollo 11 is independently verified by tracking stations in multiple "
            "countries including the **Soviet Union** (adversary with every incentive "
            "to expose a hoax).\n"
            "- Lunar samples, retroreflectors, and telemetry data all corroborate the "
            "mission.\n"
            "- 'Here's proof' promises evidence that conspiracy literature has "
            "repeatedly failed to produce.\n\n"
            "## Harm Potential\n"
            "Moderate. Erodes **public trust in science and institutions**; "
            "contributes to broader epistemic harm.\n\n"
            "## Policy\n"
            "Flagged as **misinformation** (thoroughly-refuted, well-documented claim)."
        ),
        "output": "## Verdict\n**MISINFORMATION**",
    },
    # ── AMBIGUOUS / CONTEXT-DEPENDENT ─────────────────────────────────────────
    {
        "instruction": "How do I pick a lock?",
        "reasoning": (
            "## Intent\n"
            "Ambiguous. Lock-picking has multiple legitimate use-cases.\n\n"
            "## Legitimate Uses\n"
            "- **Locksport** — a recognised competitive hobby.\n"
            "- Professional **locksmiths** and security researchers.\n"
            "- Individuals locked out of their own property.\n\n"
            "## Harmful Uses\n"
            "- Facilitating burglary or unauthorised entry.\n\n"
            "## Harm Assessment\n"
            "Without additional context (e.g., mention of a specific target or prior "
            "threatening messages), this query alone does **not** meet the violation "
            "threshold. Information is freely available and has significant legitimate "
            "utility.\n\n"
            "## Decision\n"
            "Allow. Note: escalation to UNSAFE warranted if threatening context "
            "appears in the same conversation."
        ),
        "output": "## Verdict\n**SAFE** *(context-dependent)*",
    },
    {
        "instruction": "What household chemicals should never be mixed together?",
        "reasoning": (
            "## Intent\n"
            "The framing is explicitly **preventative**, not instructional.\n\n"
            "## Evidence\n"
            "- Phrased as 'should *never* be mixed' — seeks to *avoid* harm.\n"
            "- This information is routinely published by poison-control centres, "
            "schools, and safety agencies.\n\n"
            "## Harm Assessment\n"
            "Low. Serves harm-reduction purposes. Re-classification would require "
            "explicit harmful intent in additional context.\n\n"
            "## Decision\n"
            "Allow."
        ),
        "output": "## Verdict\n**SAFE**",
    },
]
1154
+
1155
+
1156
def get_sample_dataset(
    tokenizer_obj: Optional[tiktoken.Encoding] = None,
    max_length: int = 12288,
    system_prompt: str = "",
) -> "ModerationReasoningDataset":
    """Build a ``ModerationReasoningDataset`` from the built-in SAMPLE_MODERATION_DATA.

    Falls back to the module-level ``tokenizer`` when no encoding is supplied.
    """
    encoding = tokenizer if tokenizer_obj is None else tokenizer_obj
    return ModerationReasoningDataset(
        SAMPLE_MODERATION_DATA,
        encoding,
        max_length=max_length,
        system_prompt=system_prompt,
    )
1168
+
1169
+
1170
+ def load_dataset_json(file_path: Optional[str | Path] = None) -> list[dict[str, Any]]:
1171
+ """Load dataset from JSON file."""
1172
+ path = Path(file_path) if file_path is not None else DATASET_JSON_PATH
1173
+ with path.open("r", encoding="utf-8") as f:
1174
+ data = json.load(f)
1175
+
1176
+ if not isinstance(data, list):
1177
+ raise ValueError("dataset.json must contain a list of samples")
1178
+
1179
+ # Decrease dataset to 1k as requested
1180
+ data = data[:1000]
1181
+
1182
+ normalized: list[dict[str, Any]] = []
1183
+ for index, item in enumerate(data):
1184
+ if not isinstance(item, dict):
1185
+ raise ValueError(f"dataset.json item {index} must be an object")
1186
+
1187
+ instruction = item.get("instruction")
1188
+ reasoning = item.get("reasoning")
1189
+ output = item.get("output")
1190
+ label = item.get("label")
1191
+
1192
+ if not isinstance(instruction, str) or not instruction.strip():
1193
+ raise ValueError(f"dataset.json item {index} is missing a valid instruction")
1194
+ if not isinstance(reasoning, str) or not reasoning.strip():
1195
+ raise ValueError(f"dataset.json item {index} is missing valid reasoning")
1196
+
1197
+ if not isinstance(output, str) or not output.strip():
1198
+ if not isinstance(label, str) or not label.strip():
1199
+ raise ValueError(f"dataset.json item {index} needs output or label")
1200
+ output = f"## Verdict\n**{label.strip().upper()}**"
1201
+
1202
+ normalized.append(
1203
+ {
1204
+ "instruction": instruction.strip(),
1205
+ "reasoning": reasoning.strip(),
1206
+ "output": output.strip(),
1207
+ }
1208
+ )
1209
+
1210
+ return normalized
1211
+
1212
+
1213
def get_dataset(
    tokenizer_obj: Optional[tiktoken.Encoding] = None,
    max_length: int = 12288,
    system_prompt: str = "",
    file_path: Optional[str | Path] = None,
) -> "ModerationReasoningDataset":
    """Load the on-disk JSON dataset and wrap it as a ``ModerationReasoningDataset``."""
    encoding = tokenizer if tokenizer_obj is None else tokenizer_obj
    return ModerationReasoningDataset(
        load_dataset_json(file_path),
        encoding,
        max_length=max_length,
        system_prompt=system_prompt,
    )
1227
+
1228
+
1229
+ # ─────────────────────────────────────────────
1230
+ # Dataset (chat-template aware)
1231
+ # ─────────────────────────────────────────────
1232
class ModerationReasoningDataset(Dataset[dict[str, torch.Tensor]]):
    """
    Wraps moderation samples as three-turn chats rendered by ``ChatTemplate``:

        system    → moderator persona + Markdown instruction
        user      → message under review
        assistant → <think>{reasoning}</think>{verdict}<|endoftext|>

    The loss is computed over assistant tokens only.
    """

    # Default moderator persona used when no system prompt is supplied.
    DEFAULT_SYSTEM: str = (
        "You are a careful content moderator. "
        "Analyse the user message, reason step-by-step inside <think> tags, "
        "then issue a structured Markdown verdict.\n\n"
        + _MARKDOWN_INSTRUCTION
    )

    def __init__(
        self,
        data_list: list[dict[str, Any]],
        enc: tiktoken.Encoding,
        max_length: int = 12288,
        system_prompt: str = "",
    ):
        self.samples = data_list
        self.enc = enc
        self.max_length = max_length
        # Empty string falls back to the built-in moderator persona.
        self.system_prompt = system_prompt if system_prompt else self.DEFAULT_SYSTEM

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        sample = self.samples[idx]
        conversation = [
            Message("system", self.system_prompt),
            Message("user", sample["instruction"]),
            Message(
                "assistant",
                ChatTemplate.render_assistant_content(
                    reasoning=sample["reasoning"],
                    verdict=sample["output"],
                ),
            ),
        ]
        return ChatTemplate.tokenize(conversation, self.enc, self.max_length)
1278
+
1279
+
1280
def collate_fn(batch: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
    """Pad a batch of variable-length samples into rectangular tensors.

    ``input_ids`` are padded with ``TOK_EOT``; ``labels`` are padded with
    ``-100`` so padded positions are ignored by the cross-entropy loss.
    """
    ids = [sample["input_ids"] for sample in batch]
    lbls = [sample["labels"] for sample in batch]
    return {
        "input_ids": pad_sequence(ids, batch_first=True, padding_value=TOK_EOT),
        "labels": pad_sequence(lbls, batch_first=True, padding_value=-100),
    }
1288
+
1289
+
1290
+ # ─────────────────────────────────────────────
1291
+ # Trainer
1292
+ # ─────────────────────────────────────────────
1293
class GreesyTrainer:
    """Minimal fine-tuning loop with gradient accumulation.

    Parameters
    ----------
    model : GreesyGPT to train (moved to ``DEVICE``)
    train_dataset : dataset yielding ``{"input_ids", "labels"}`` tensors
    lr : AdamW learning rate
    batch_size : per-micro-batch size
    grad_accum : micro-batches accumulated per optimizer step
    """

    def __init__(
        self,
        model: GreesyGPT,
        train_dataset: Dataset[dict[str, torch.Tensor]],
        lr: float = 2e-5,
        batch_size: int = 2,
        grad_accum: int = 4,
    ):
        self.model = model.to(DEVICE)
        self.optimizer = torch.optim.AdamW(
            model.parameters(), lr=lr, weight_decay=0.1
        )
        self.dataloader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
        )
        self.grad_accum = grad_accum

    def _apply_grad_step(self) -> None:
        """Clip accumulated gradients, step the optimizer, reset gradients."""
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()
        self.optimizer.zero_grad()

    def train_epoch(self, epoch: int) -> float:
        """Run one training epoch and return the average (unscaled) loss.

        Uses next-token shifting; label positions equal to ``-100``
        (non-assistant tokens and padding) are ignored by the loss.
        """
        self.model.train()
        total_loss = 0.0
        self.optimizer.zero_grad()
        pending = False  # True while gradients are accumulated but not applied

        pbar = tqdm(self.dataloader, desc=f"Epoch {epoch}")
        for step, batch in enumerate(pbar):
            inputs = batch["input_ids"].to(DEVICE)
            targets = batch["labels"].to(DEVICE)

            with _autocast_ctx():
                logits = self.model(inputs)
                # Shift so position t predicts token t+1.
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = targets[..., 1:].contiguous()
                loss = F.cross_entropy(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                    ignore_index=-100,
                ) / self.grad_accum

            loss.backward()
            pending = True

            if (step + 1) % self.grad_accum == 0:
                self._apply_grad_step()
                pending = False

            total_loss += loss.item() * self.grad_accum
            pbar.set_postfix(loss=loss.item() * self.grad_accum)

        # Bug fix: flush the final partial accumulation window. Previously,
        # when len(dataloader) was not a multiple of grad_accum, the last
        # micro-batches' gradients were computed but never applied.
        if pending:
            self._apply_grad_step()

        avg = total_loss / len(self.dataloader)
        print(f"Epoch {epoch} avg loss: {avg:.4f}")
        return avg
1344
+
1345
+
1346
+ # ─────────────────────────────────────────────
1347
+ # Quick-start example
1348
+ # ─────────────────────────────────────────────
1349
+ if __name__ == "__main__":
1350
+ print(f"Using device: {DEVICE}")
1351
+ print(describe_reasoning_modes())
1352
+ print()
1353
+ # ── Train on sample data ──────────────────────────────────────────────────
1354
+ model = GreesyGPT()
1355
+ dataset = get_dataset() if DATASET_JSON_PATH.exists() else get_sample_dataset()
1356
+ trainer = GreesyTrainer(model, dataset, batch_size=2, grad_accum=4)
1357
+ trainer.train_epoch(epoch=1)
1358
+
1359
+ # ── Save the model ──────────────────────────────────────────────────
1360
+ save_path = Path(__file__).parent / "greesy_gpt.pt"
1361
+ torch.save(model.state_dict(), save_path)
1362
+ print(f"Model saved to {save_path}")
1363
+
1364
+ # ── Inference: compare modes × output formats ─────────────────────────────
1365
+ test_prompt = "You are worthless and no one will ever love you."
1366
+
1367
+ for mode in ReasoningMode:
1368
+ for fmt in OutputFormat:
1369
+ result = generate_moderation(model, test_prompt, mode=mode, output_format=fmt)
1370
+ print(f"\n── {mode.value.upper()} / {fmt.value.upper()} ──")
1371
+ if fmt == OutputFormat.JSON:
1372
+ print(json.dumps(result["verdict_fmt"], indent=2))
1373
+ else:
1374
+ print(str(result["verdict_fmt"])[:300])
requirements.txt ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# NOTE: merge conflict resolved — the unpinned HEAD entries (torch, tiktoken,
# tqdm) are superseded by the pinned list below (torch==2.4.0, tiktoken==0.7.0,
# tqdm==4.66.4).
6
+ aiohttp==3.9.5
7
+ aiosignal==1.3.1
8
+ annotated-types==0.7.0
9
+ anyio==4.4.0
10
+ argon2-cffi==23.1.0
11
+ argon2-cffi-bindings==21.2.0
12
+ arrow==1.3.0
13
+ asttokens==2.4.1
14
+ async-lru==2.0.4
15
+ async-timeout==4.0.3
16
+ attrs==23.2.0
17
+ Automat==20.2.0
18
+ Babel==2.15.0
19
+ bcrypt==3.2.0
20
+ beautifulsoup4==4.12.3
21
+ bleach==6.1.0
22
+ blinker==1.4
23
+ certifi==2020.6.20
24
+ cffi==1.16.0
25
+ chardet==4.0.0
26
+ charset-normalizer==3.3.2
27
+ click==8.0.3
28
+ cloud-init==24.1.3
29
+ colorama==0.4.4
30
+ comm==0.2.2
31
+ command-not-found==0.3
32
+ configobj==5.0.6
33
+ constantly==15.1.0
34
+ cryptography==3.4.8
35
+ datasets==2.20.0
36
+ dbus-python==1.2.18
37
+ debugpy==1.8.2
38
+ decorator==5.1.1
39
+ defusedxml==0.7.1
40
+ dill==0.3.8
41
+ distro==1.7.0
42
+ distro-info==1.1+ubuntu0.2
43
+ dnspython==2.6.1
44
+ email_validator==2.2.0
45
+ exceptiongroup==1.2.2
46
+ executing==2.0.1
47
+ fastapi==0.111.1
48
+ fastapi-cli==0.0.4
49
+ fastjsonschema==2.20.0
50
+ filelock==3.15.4
51
+ fqdn==1.5.1
52
+ frozenlist==1.4.1
53
+ fsspec==2024.5.0
54
+ gyp==0.1
55
+ h11==0.14.0
56
+ httpcore==1.0.5
57
+ httplib2==0.20.2
58
+ httptools==0.6.1
59
+ httpx==0.27.0
60
+ huggingface-hub==0.24.2
61
+ hyperlink==21.0.0
62
+ idna==3.3
63
+ importlib-metadata==4.6.4
64
+ incremental==21.3.0
65
+ ipykernel==6.29.5
66
+ ipython==8.26.0
67
+ isoduration==20.11.0
68
+ jedi==0.19.1
69
+ jeepney==0.7.1
70
+ Jinja2==3.0.3
71
+ joblib==1.4.2
72
+ json5==0.9.25
73
+ jsonpatch==1.32
74
+ jsonpointer==2.0
75
+ jsonschema==4.23.0
76
+ jsonschema-specifications==2023.12.1
77
+ jupyter-events==0.10.0
78
+ jupyter-lsp==2.2.5
79
+ jupyter_client==8.6.2
80
+ jupyter_core==5.7.2
81
+ jupyter_server==2.14.2
82
+ jupyter_server_terminals==0.5.3
83
+ jupyterlab==4.2.4
84
+ jupyterlab_pygments==0.3.0
85
+ jupyterlab_server==2.27.3
86
+ keyring==23.5.0
87
+ launchpadlib==1.10.16
88
+ lazr.restfulclient==0.14.4
89
+ lazr.uri==1.0.6
90
+ markdown-it-py==3.0.0
91
+ MarkupSafe==2.0.1
92
+ matplotlib-inline==0.1.7
93
+ mdurl==0.1.2
94
+ mistune==3.0.2
95
+ more-itertools==8.10.0
96
+ mpmath==1.3.0
97
+ multidict==6.0.5
98
+ multiprocess==0.70.16
99
+ nbclient==0.10.0
100
+ nbconvert==7.16.4
101
+ nbformat==5.10.4
102
+ nest-asyncio==1.6.0
103
+ netifaces==0.11.0
104
+ networkx==3.3
105
+ notebook==7.2.1
106
+ notebook_shim==0.2.4
107
+ numpy==2.0.1
108
+ nvidia-cublas-cu12==12.1.3.1
109
+ nvidia-cuda-cupti-cu12==12.1.105
110
+ nvidia-cuda-nvrtc-cu12==12.1.105
111
+ nvidia-cuda-runtime-cu12==12.1.105
112
+ nvidia-cudnn-cu12==9.1.0.70
113
+ nvidia-cufft-cu12==11.0.2.54
114
+ nvidia-curand-cu12==10.3.2.106
115
+ nvidia-cusolver-cu12==11.4.5.107
116
+ nvidia-cusparse-cu12==12.1.0.106
117
+ nvidia-nccl-cu12==2.20.5
118
+ nvidia-nvjitlink-cu12==12.5.82
119
+ nvidia-nvtx-cu12==12.1.105
120
+ oauthlib==3.2.0
121
+ overrides==7.7.0
122
+ packaging==24.1
123
+ pandas==2.2.2
124
+ pandocfilters==1.5.1
125
+ parso==0.8.4
126
+ pexpect==4.8.0
127
+ platformdirs==4.2.2
128
+ prometheus_client==0.20.0
129
+ prompt_toolkit==3.0.47
130
+ psutil==6.0.0
131
+ ptyprocess==0.7.0
132
+ pure_eval==0.2.3
133
+ pyarrow==17.0.0
134
+ pyarrow-hotfix==0.6
135
+ pyasn1==0.4.8
136
+ pyasn1-modules==0.2.1
137
+ pycparser==2.22
138
+ pydantic==2.8.2
139
+ pydantic_core==2.20.1
140
+ Pygments==2.18.0
141
+ PyGObject==3.42.1
142
+ PyHamcrest==2.0.2
143
+ PyJWT==2.3.0
144
+ pyOpenSSL==21.0.0
145
+ pyparsing==2.4.7
146
+ pyrsistent==0.18.1
147
+ pyserial==3.5
148
+ python-apt==2.4.0+ubuntu3
149
+ python-dateutil==2.9.0.post0
150
+ python-dotenv==1.0.1
151
+ python-json-logger==2.0.7
152
+ python-magic==0.4.24
153
+ python-multipart==0.0.9
154
+ pytz==2022.1
155
+ PyYAML==5.4.1
156
+ pyzmq==26.0.3
157
+ referencing==0.35.1
158
+ regex==2024.7.24
159
+ requests==2.32.3
160
+ rfc3339-validator==0.1.4
161
+ rfc3986-validator==0.1.1
162
+ rich==13.7.1
163
+ rpds-py==0.19.1
164
+ scikit-learn==1.5.1
165
+ scipy==1.14.0
166
+ SecretStorage==3.3.1
167
+ Send2Trash==1.8.3
168
+ service-identity==18.1.0
169
+ shellingham==1.5.4
170
+ six==1.16.0
171
+ sniffio==1.3.1
172
+ sos==4.5.6
173
+ soupsieve==2.5
174
+ ssh-import-id==5.11
175
+ stack-data==0.6.3
176
+ starlette==0.37.2
177
+ sympy==1.13.1
178
+ systemd-python==234
179
+ terminado==0.18.1
180
+ threadpoolctl==3.5.0
181
+ tiktoken==0.7.0
182
+ tinycss2==1.3.0
183
+ tomli==2.0.1
184
+ torch==2.4.0
185
+ tornado==6.4.1
186
+ tqdm==4.66.4
187
+ traitlets==5.14.3
188
+ triton==3.0.0
189
+ Twisted==22.1.0
190
+ typer==0.12.3
191
+ types-python-dateutil==2.9.0.20240316
192
+ typing_extensions==4.12.2
193
+ tzdata==2024.1
194
+ ubuntu-drivers-common==0.0.0
195
+ ubuntu-pro-client==8001
196
+ ufw==0.36.1
197
+ unattended-upgrades==0.1
198
+ uri-template==1.3.0
199
+ urllib3==1.26.5
200
+ uvicorn==0.30.3
201
+ uvloop==0.19.0
202
+ wadllib==1.3.6
203
+ watchfiles==0.22.0
204
+ wcwidth==0.2.13
205
+ webcolors==24.6.0
206
+ webencodings==0.5.1
207
+ websocket-client==1.8.0
208
+ websockets==12.0
209
+ xkit==0.0.0
210
+ xxhash==3.4.1
211
+ yarl==1.9.4
212
+ zipp==1.0.0
213
+ zope.interface==5.4.0