Spaces:
Running
Running
| """ | |
| GreesyGPT — Content-moderation language model with KV caching and dropout. | |
| Production-ready implementation featuring: | |
| • KV caching so each generated token needs only a single-token forward pass (instead of recomputing the full sequence) | |
| • Configurable dropout for regularisation during training | |
| • Centralised ModelConfig dataclass for all hyperparameters | |
| • Structured logging via the stdlib ``logging`` module | |
| • Type annotations throughout | |
| """ | |
| from __future__ import annotations | |
| import contextlib | |
| import json | |
| import logging | |
| import re | |
| from dataclasses import dataclass | |
| from enum import Enum | |
| from pathlib import Path | |
| from typing import Any, Optional, cast | |
| import tiktoken | |
| import torch | |
| import torch.nn as nn | |
| from torch.nn import functional as F | |
| from torch.nn.utils.rnn import pad_sequence | |
| from torch.utils.data import DataLoader, Dataset | |
| from tqdm import tqdm | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # Logging | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # Module logger, registered under the name "greesygpt". | |
| logger = logging.getLogger("greesygpt") | |
| # A stream handler is attached only when the logger has none yet | |
| # (presumably so that repeated imports do not duplicate output - confirm). | |
| if not logger.handlers: | |
| _handler = logging.StreamHandler() | |
| _handler.setFormatter(logging.Formatter("[%(levelname)s] %(name)s: %(message)s")) | |
| logger.addHandler(_handler) | |
| # Default verbosity is INFO. | |
| logger.setLevel(logging.INFO) | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # Model Configuration | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
@dataclass
class ModelConfig:
    """Centralised hyperparameter store for GreesyGPT.

    Every field has a default, so ``ModelConfig()`` yields the stock
    configuration and individual values can be overridden by keyword,
    e.g. ``ModelConfig(n_layer=4, n_embd=256, n_head=4)``.
    """

    vocab_size: int = 200032
    context_len: int = 12_000
    n_embd: int = 768
    n_head: int = 12
    n_layer: int = 12
    # Dropout rates (inactive at inference once model.eval() is called;
    # typical training values 0.1-0.2)
    attn_dropout: float = 0.1
    resid_dropout: float = 0.1
    embd_dropout: float = 0.1
    mlp_dropout: float = 0.1

    @property
    def head_dim(self) -> int:
        """Width of a single attention head (``n_embd // n_head``).

        Exposed as a property because the transformer block reads
        ``config.head_dim`` as a plain integer.

        Raises
        ------
        ValueError
            If ``n_embd`` is not divisible by ``n_head``.
        """
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if self.n_embd % self.n_head != 0:
            raise ValueError("n_embd must be divisible by n_head")
        return self.n_embd // self.n_head
| # Legacy constants (kept for backward compatibility; prefer ModelConfig) | |
# Shared default configuration plus flat aliases of its core sizes, so older
# call sites that import the bare constants keep working.
DEFAULT_CONFIG = ModelConfig()
(
    VOCAB_SIZE,
    CONTEXT_LEN,
    N_EMBD,
    N_HEAD,
    N_LAYER,
) = (
    DEFAULT_CONFIG.vocab_size,
    DEFAULT_CONFIG.context_len,
    DEFAULT_CONFIG.n_embd,
    DEFAULT_CONFIG.n_head,
    DEFAULT_CONFIG.n_layer,
)
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # Device (MPS → CUDA → CPU, in priority order) | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| def _select_device() -> torch.device: | |
| if torch.backends.mps.is_available() and torch.backends.mps.is_built(): | |
| return torch.device("mps") | |
| if torch.cuda.is_available(): | |
| return torch.device("cuda") | |
| return torch.device("cpu") | |
| DEVICE: torch.device = _select_device() | |
def _autocast_ctx():
    """
    Build the mixed-precision context manager that matches ``DEVICE``.

    - MPS   -> ``torch.autocast`` with float16
    - CUDA  -> ``torch.autocast`` with bfloat16
    - other -> ``contextlib.nullcontext()`` (no autocast)
    """
    autocast_dtype = {"mps": torch.float16, "cuda": torch.bfloat16}.get(DEVICE.type)
    if autocast_dtype is None:
        return contextlib.nullcontext()
    return torch.autocast(device_type=DEVICE.type, dtype=autocast_dtype)
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # Special-token registry | |
| # 199999 <|endoftext|> (native to o200k_base) | |
| # 200019 <think> (first free ID after o200k_base's own vocabulary, whose last entry is 200018 <|endofprompt|>) | |
| # 200020 </think> | |
| # 200021 <|system|> | |
| # 200022 </|system|> | |
| # 200023 <|user|> | |
| # 200024 </|user|> | |
| # 200025 <|assistant|> (turn closed by <|endoftext|>, no explicit close tag) | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
# <|endoftext|> keeps its native ID; the custom role/reasoning tags take
# consecutive IDs starting at 200019, in the order listed below.
SPECIAL_TOKENS: dict[str, int] = {
    "<|endoftext|>": 199999,
    **{
        tag: token_id
        for token_id, tag in enumerate(
            (
                "<think>",
                "</think>",
                "<|system|>",
                "</|system|>",
                "<|user|>",
                "</|user|>",
                "<|assistant|>",
            ),
            start=200019,
        )
    },
}
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # Tokenizer | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
def get_tokenizer() -> tiktoken.Encoding:
    """Build the ``greesy_gpt`` encoding: ``o200k_base`` extended with ``SPECIAL_TOKENS``."""
    o200k = tiktoken.get_encoding("o200k_base")
    # Entries in SPECIAL_TOKENS win over the base encoding's on a name clash.
    merged_specials = dict(o200k._special_tokens)
    merged_specials.update(SPECIAL_TOKENS)
    return tiktoken.Encoding(
        name="greesy_gpt",
        pat_str=o200k._pat_str,
        mergeable_ranks=o200k._mergeable_ranks,
        special_tokens=merged_specials,
    )


# Module-wide tokenizer instance, built once at import time.
tokenizer = get_tokenizer()
| # Convenience single-token IDs used throughout | |
def _tid(s: str) -> int:
    """Return the first token ID obtained by encoding ``s`` with all special tokens allowed."""
    token_ids = tokenizer.encode(s, allowed_special="all")
    return token_ids[0]


# Single-token IDs of the control tags, resolved once at import time.
(
    TOK_EOT,
    TOK_THINK_OPEN,
    TOK_THINK_CLOSE,
    TOK_SYSTEM,
    TOK_ESYSTEM,
    TOK_USER,
    TOK_EUSER,
    TOK_ASSISTANT,
) = (
    _tid(tag)
    for tag in (
        "<|endoftext|>",
        "<think>",
        "</think>",
        "<|system|>",
        "</|system|>",
        "<|user|>",
        "</|user|>",
        "<|assistant|>",
    )
)
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # Chat Template | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
@dataclass
class Message:
    """A single turn in a conversation.

    Declared as a dataclass so turns can be built positionally, e.g.
    ``Message("user", "hello")``, which is how the chat template's
    inference helper constructs them.
    """

    role: str  # "system" | "user" | "assistant"
    content: str
class ChatTemplate:
    """
    Serialises a ``list[Message]`` into GreesyGPT's role-delimited wire format:

        <|system|>
        {system content}
        </|system|>
        <|user|>
        {user content}
        </|user|>
        <|assistant|>
        <think>
        {reasoning}
        </think>
        {verdict}<|endoftext|>

    For training, only the *assistant completion* tokens (everything after
    ``<|assistant|>\\n``) contribute to the loss; all other tokens are masked
    with ``-100`` in the labels tensor.

    For multi-turn conversations append additional user/assistant message
    pairs; the template handles them in sequence.

    All methods are class-level (``@classmethod`` / ``@staticmethod``); the
    class is used as a namespace and never instantiated in this module.
    """

    _OPEN = {"system": "<|system|>", "user": "<|user|>", "assistant": "<|assistant|>"}
    _CLOSE = {"system": "</|system|>", "user": "</|user|>", "assistant": ""}  # EOT closes assistant

    # -- Rendering ---------------------------------------------------------

    @classmethod
    def render(cls, messages: list[Message], add_generation_prompt: bool = False) -> str:
        """
        Render a full conversation to a single string.

        Parameters
        ----------
        messages:
            Ordered ``Message`` list (system, then alternating user/assistant).
        add_generation_prompt:
            Append ``<|assistant|>\\n<think>\\n`` so the model can continue
            from the correct position during inference.

        Raises
        ------
        KeyError
            If a message carries a role other than system/user/assistant.
        """
        parts: list[str] = []
        for msg in messages:
            open_tag = cls._OPEN[msg.role]
            close_tag = cls._CLOSE[msg.role]
            if close_tag:
                parts.append(f"{open_tag}\n{msg.content}\n{close_tag}\n")
            else:
                # assistant: content already contains <think>...</think> + verdict + EOT
                parts.append(f"{open_tag}\n{msg.content}")
        if add_generation_prompt:
            parts.append("<|assistant|>\n<think>\n")
        return "".join(parts)

    @staticmethod
    def render_assistant_content(reasoning: str, verdict: str) -> str:
        """Build the assistant ``content`` field from reasoning + verdict."""
        return f"<think>\n{reasoning}\n</think>\n{verdict}<|endoftext|>"

    # -- Tokenisation with per-role label masking --------------------------

    @classmethod
    def tokenize(
        cls,
        messages: list[Message],
        enc: tiktoken.Encoding,
        max_length: int = 12288,
    ) -> dict[str, torch.Tensor]:
        """
        Tokenise a conversation and return ``input_ids`` + ``labels``.

        All system and user tokens are masked (``-100``), as is the
        ``<|assistant|>\\n`` header of each assistant turn. Only the assistant
        completion tokens keep their IDs as labels. Multi-turn conversations
        are supported: each assistant turn is unmasked individually.

        Both sequences are cut to the first ``max_length`` tokens.
        """
        input_ids: list[int] = []
        labels: list[int] = []
        for msg in messages:
            open_tag = cls._OPEN[msg.role]
            close_tag = cls._CLOSE[msg.role]
            if msg.role == "assistant":
                # Role header: masked
                header_toks = enc.encode(f"{open_tag}\n", allowed_special="all")
                input_ids.extend(header_toks)
                labels.extend([-100] * len(header_toks))
                # Completion: trained
                comp_toks = enc.encode(msg.content, allowed_special="all")
                input_ids.extend(comp_toks)
                labels.extend(comp_toks)
            else:
                text = (
                    f"{open_tag}\n{msg.content}\n{close_tag}\n"
                    if close_tag
                    else f"{open_tag}\n{msg.content}"
                )
                toks = enc.encode(text, allowed_special="all")
                input_ids.extend(toks)
                labels.extend([-100] * len(toks))
        # Truncate to max_length (keeps the head of the conversation)
        input_ids = input_ids[:max_length]
        labels = labels[:max_length]
        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }

    # -- Inference helper --------------------------------------------------

    @classmethod
    def build_inference_prompt(cls, user_message: str, system_prompt: str = "") -> str:
        """Return the prompt string to feed the model at inference time.

        The system turn is omitted when ``system_prompt`` is empty; the
        generation prompt (``<|assistant|>\\n<think>\\n``) is always appended.
        """
        msgs: list[Message] = []
        if system_prompt:
            msgs.append(Message("system", system_prompt))
        msgs.append(Message("user", user_message))
        return cls.render(msgs, add_generation_prompt=True)
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # Output Formats & Markdown helpers | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
class OutputFormat(Enum):
    """
    Post-processing applied to a moderation verdict.

    MARKDOWN : verdict returned unchanged (headings, bold, lists preserved).
    PLAIN    : Markdown syntax stripped, leaving clean prose.
    JSON     : structured dict carrying label, severity, confidence hint and summary.
    """

    MARKDOWN = "markdown"
    PLAIN = "plain"
    JSON = "json"
| # Severity scale used by the JSON formatter | |
| _LABEL_SEVERITY: dict[str, int] = { | |
| "SAFE": 0, | |
| "SPAM": 1, | |
| "MISINFORMATION": 2, | |
| "HARASSMENT": 3, | |
| "HATE_SPEECH": 4, | |
| "CRISIS_REFERRAL": 5, | |
| "UNSAFE": 6, | |
| } | |
| # Ordered substitutions for stripping Markdown syntax | |
| _MD_STRIP: list[tuple[re.Pattern[str], str]] = [ | |
| (re.compile(r"#{1,6}\s*"), ""), # headings | |
| (re.compile(r"\*\*(.+?)\*\*"), r"\1"), # bold | |
| (re.compile(r"\*(.+?)\*"), r"\1"), # italic | |
| (re.compile(r"`{1,3}(.+?)`{1,3}", re.DOTALL), r"\1"), # inline/fenced code | |
| (re.compile(r"^\s*[-*+]\s+", re.M), "β’ "), # unordered list | |
| (re.compile(r"^\s*\d+\.\s+", re.M), ""), # ordered list numbers | |
| (re.compile(r"\[(.+?)\]\(.+?\)"), r"\1"), # links | |
| (re.compile(r"^\s*>\s?", re.M), ""), # blockquotes | |
| (re.compile(r"-{3,}|={3,}|\*{3,}"), ""), # horizontal rules | |
| (re.compile(r"\n{3,}"), "\n\n"),# excess blank lines | |
| ] | |
| def strip_markdown(text: str) -> str: | |
| """Remove common Markdown syntax and return clean prose.""" | |
| for pattern, repl in _MD_STRIP: | |
| text = pattern.sub(repl, text) | |
| return text.strip() | |
| def extract_verdict_label(verdict_text: str) -> str: | |
| """Pull the label keyword from a verdict block. Falls back to 'UNKNOWN'.""" | |
| upper = verdict_text.upper() | |
| for label in _LABEL_SEVERITY: | |
| if label in upper: | |
| return label | |
| return "UNKNOWN" | |
def format_output(result: dict[str, Any], fmt: OutputFormat = OutputFormat.MARKDOWN) -> "str | dict[str, Any]":
    """
    Post-process a ``generate_moderation`` result.

    Parameters
    ----------
    result:
        Mapping read via ``.get`` for the keys ``verdict``, ``thinking`` and
        ``mode``; missing keys are tolerated.
    fmt:
        Desired output format.

    Returns
    -------
    ``str`` for MARKDOWN (verdict as-is) and PLAIN (Markdown stripped);
    ``dict`` for JSON with the keys ``verdict``, ``severity``,
    ``confidence_hint``, ``reasoning_mode``, ``thinking_summary`` and
    ``full_verdict``.
    """
    verdict = result.get("verdict", "")
    thinking = result.get("thinking") or ""
    mode = result.get("mode")

    if fmt == OutputFormat.PLAIN:
        return strip_markdown(verdict)

    if fmt == OutputFormat.JSON:
        label = extract_verdict_label(verdict)
        sev = _LABEL_SEVERITY.get(label, -1)
        if sev < 0:
            # Unrecognised label: previously fell into the "sev <= 1" branch
            # and was reported with "high" confidence.
            conf = "low"
        elif sev <= 1:
            conf = "high"
        elif sev <= 3:
            conf = "medium"
        else:
            conf = "low"
        return {
            "verdict": label,
            "severity": sev,
            "confidence_hint": conf,
            "reasoning_mode": mode.value if mode else None,
            "thinking_summary": thinking[:300].strip(),  # first 300 characters only
            "full_verdict": verdict,
        }

    return verdict  # MARKDOWN: return as-is
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # Reasoning Modes | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
class ReasoningMode(Enum):
    """
    How much deliberation the model performs before emitting a verdict.

    NONE   : minimal chain-of-thought; fastest, suited to obvious cases.
    LOW    : balanced reasoning; the general-purpose default.
    MEDIUM : extended deliberation for nuanced or borderline content.
    HIGH   : largest token budget; intended for high-stakes review.
    """

    NONE = "none"
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
| # Injected into every system prompt to teach the model Markdown output style | |
| _MARKDOWN_INSTRUCTION = ( | |
| "Format your entire response in **Markdown**:\n" | |
| "- Use `##` headings to separate reasoning sections " | |
| "(e.g. `## Intent`, `## Evidence`, `## Harm Potential`, `## Edge Cases`).\n" | |
| "- Use `**bold**` to highlight key terms, policy labels, and the final verdict.\n" | |
| "- Use bullet lists (`-`) for evidence points.\n" | |
| "- Use `>` blockquotes for direct quotes from the message under review.\n" | |
| "- End with a `## Verdict` section containing **only** the label on its own line " | |
| "(e.g. `**SAFE**`, `**HARASSMENT**`).\n" | |
| ) | |
@dataclass
class ReasoningConfig:
    """Generation settings attached to one reasoning mode.

    Declared as a dataclass so instances can be built by keyword, which is
    how the per-mode configuration table constructs them.
    """

    max_think_tokens: int  # token budget for <think>...</think>
    max_total_tokens: int  # total generation budget (think + verdict)
    temperature: float
    top_k: int
    system_prompt: str  # full system-turn content

    def total_budget(self) -> int:
        """Return the overall generation budget (``max_total_tokens``)."""
        return self.max_total_tokens
# Per-mode generation settings. Every system prompt ends with the shared
# Markdown formatting instruction.
REASONING_CONFIGS: dict[ReasoningMode, ReasoningConfig] = {
    ReasoningMode.NONE: ReasoningConfig(
        max_think_tokens=200,
        max_total_tokens=812,
        temperature=0.1,
        top_k=50,
        system_prompt=(
            "You are a fast content moderator. "
            "Think briefly, then give a clear verdict.\n\n"
            + _MARKDOWN_INSTRUCTION
        ),
    ),
    ReasoningMode.LOW: ReasoningConfig(
        max_think_tokens=512,
        max_total_tokens=1200,
        temperature=0.7,
        top_k=40,
        system_prompt=(
            "You are a careful content moderator. "
            "Reason step-by-step, then issue a structured Markdown verdict.\n\n"
            + _MARKDOWN_INSTRUCTION
        ),
    ),
    ReasoningMode.MEDIUM: ReasoningConfig(
        max_think_tokens=1536,
        max_total_tokens=2048,
        temperature=0.6,
        top_k=30,
        system_prompt=(
            "You are a thorough content moderator. "
            "Analyse intent, context, potential harm, and edge cases "
            "before issuing a detailed Markdown verdict.\n\n"
            + _MARKDOWN_INSTRUCTION
        ),
    ),
    ReasoningMode.HIGH: ReasoningConfig(
        max_think_tokens=3072,
        max_total_tokens=4096,
        temperature=0.4,
        top_k=20,
        system_prompt=(
            # The em dashes below replace mis-decoded bytes that would
            # otherwise be sent to the model verbatim.
            "You are an expert content safety reviewer. "
            "Examine the message from every relevant angle\u2014legal, ethical, "
            "contextual, and platform-policy\u2014and produce a comprehensive "
            "Markdown safety report.\n\n"
            + _MARKDOWN_INSTRUCTION
        ),
    ),
}
def describe_reasoning_modes() -> str:
    """Return a human-readable table of every reasoning mode.

    One header line followed by one line per ``REASONING_CONFIGS`` entry
    showing its think/total token budgets, temperature and top-k.
    """
    lines = ["Available Reasoning Modes\n" + "=" * 44]
    for mode, cfg in REASONING_CONFIGS.items():
        # "\u2264" is the less-than-or-equal sign (was mis-decoded in the output).
        lines.append(
            f" {mode.value:10s} think\u2264{cfg.max_think_tokens:4d} tokens | "
            f"total\u2264{cfg.max_total_tokens:4d} tokens | "
            f"temp={cfg.temperature} | top_k={cfg.top_k}"
        )
    return "\n".join(lines)
# dataset.json is looked up in the same directory as this module.
DATASET_JSON_PATH = Path(__file__).parent / "dataset.json"
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # KV Cache | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
# eq=False keeps identity-based comparison/hashing: a generated field-wise
# ``==`` would try to compare tensors, whose truth value is ambiguous.
@dataclass(eq=False)
class KVCache:
    """
    Per-layer key/value cache for autoregressive generation.

    Stores tensors of shape ``[B, n_head, T_cached, head_dim]``.
    Grows incrementally as new tokens are generated.
    """

    key: Optional[torch.Tensor] = None
    value: Optional[torch.Tensor] = None

    @property
    def seq_len(self) -> int:
        """Number of tokens currently cached (0 while the cache is empty).

        A property, because callers read it as ``cache.seq_len``.
        """
        if self.key is None:
            return 0
        return self.key.shape[2]

    def update(
        self, new_key: torch.Tensor, new_value: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Append new K/V slices and return the full accumulated tensors.

        Parameters
        ----------
        new_key, new_value : ``[B, n_head, T_new, head_dim]``

        Returns
        -------
        (full_key, full_value) each ``[B, n_head, T_total, head_dim]``
        """
        if self.key is None or self.value is None:
            # First call: adopt the incoming tensors as the cache contents.
            self.key, self.value = new_key, new_value
        else:
            # Later calls: extend along the sequence axis (dim 2).
            self.key = torch.cat((self.key, new_key), dim=2)
            self.value = torch.cat((self.value, new_value), dim=2)
        return self.key, self.value

    def clear(self) -> None:
        """Drop the cached tensors so the cache reports ``seq_len == 0`` again."""
        self.key = None
        self.value = None
| # Type alias: one KVCache per layer | |
# One KVCache per transformer layer, indexed by layer position.
LayerCaches = list[KVCache]


def make_kv_caches(n_layers: int) -> LayerCaches:
    """Return ``n_layers`` independent, empty ``KVCache`` objects."""
    fresh_caches: LayerCaches = [KVCache() for _layer in range(n_layers)]
    return fresh_caches
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # RoPE (supports position offset for KV cache) | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
class RoPE(nn.Module):
    """Produces rotary position embedding cos/sin tables on demand."""

    def __init__(self, head_dim: int, max_seq_len: int = CONTEXT_LEN):
        super().__init__()
        # max_seq_len is accepted but not used here: tables are computed per
        # call in forward(), so nothing is pre-allocated.
        exponents = torch.arange(0, head_dim, 2).float() / head_dim
        self.register_buffer("inv_freq", 1.0 / (10000.0 ** exponents))

    def forward(
        self, seq_len: int, device: torch.device, offset: int = 0
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Return ``(cos, sin)`` for positions ``[offset, offset + seq_len)``.

        Parameters
        ----------
        seq_len : how many consecutive positions to cover
        device  : device on which the position indices are created
        offset  : index of the first position (number of tokens already cached)

        Both results have shape ``[1, seq_len, 1, D]`` where ``D`` is twice
        the length of ``inv_freq``.
        """
        inv_freq = cast(torch.Tensor, self.inv_freq)
        positions: torch.Tensor = torch.arange(
            offset, offset + seq_len, device=device, dtype=inv_freq.dtype
        )
        angles = torch.einsum("i,j->ij", positions, inv_freq)
        # Duplicate the angles so the table spans both halves of the head dim.
        table = torch.cat((angles, angles), dim=-1)
        cos_table, sin_table = table.cos(), table.sin()
        return cos_table[None, :, None, :], sin_table[None, :, None, :]
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the last dimension, negating the one moved to the front."""
    split_at = x.shape[-1] // 2
    front = x[..., :split_at]
    back = x[..., split_at:]
    return torch.cat((-back, front), dim=-1)
def apply_rope(q, k, cos, sin):
    """Rotate queries and keys by the given cos/sin tables; returns ``(q, k)``."""
    q_rotated = q * cos + rotate_half(q) * sin
    k_rotated = k * cos + rotate_half(k) * sin
    return q_rotated, k_rotated
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # Transformer Block (with KV cache + dropout) | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
class GreesyBlock(nn.Module):
    """Pre-norm transformer block: RoPE self-attention followed by a GELU MLP.

    Both sub-layers are residual. Dropout rates come from ``ModelConfig``.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        n_embd = config.n_embd
        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.out_proj = nn.Linear(n_embd, n_embd, bias=False)
        self.rope = RoPE(self.head_dim, max_seq_len=config.context_len)
        # Dropout layers (attn_dropout is only consulted for its rate ``p``,
        # which is handed to scaled_dot_product_attention)
        self.attn_dropout = nn.Dropout(config.attn_dropout)
        self.resid_dropout = nn.Dropout(config.resid_dropout)
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Dropout(config.mlp_dropout),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(config.resid_dropout),
        )

    def forward(
        self,
        x: torch.Tensor,
        kv_cache: Optional[KVCache] = None,
    ) -> torch.Tensor:
        """
        Forward pass with optional KV caching.

        Parameters
        ----------
        x : ``[B, T, C]`` input embeddings (full sequence, a chunk, or a
            single new token)
        kv_cache : if provided, this call's keys/values are appended to the
            cache and attention runs over the full cached K/V, so earlier
            tokens are not re-projected on every step.

        Returns
        -------
        ``[B, T, C]`` tensor.
        """
        B, T, C = x.shape
        norm_x = self.ln1(x)
        qkv = self.qkv(norm_x).reshape(B, T, 3, self.n_head, self.head_dim)
        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]

        # Position offset = number of tokens already in the cache
        offset = kv_cache.seq_len if kv_cache is not None else 0
        cos, sin = self.rope(T, x.device, offset=offset)
        q, k = apply_rope(q, k, cos, sin)

        # [B, T, n_head, head_dim] -> [B, n_head, T, head_dim]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Update KV cache (if provided) and use full cached K/V for attention
        if kv_cache is not None:
            k, v = kv_cache.update(k, v)

        attn_mask: Optional[torch.Tensor] = None
        if kv_cache is None or offset == 0:
            # Training / prefill: queries and keys cover the same positions.
            is_causal = True
        elif T == 1:
            # Cached single-token step: every cached position is visible.
            is_causal = False
        else:
            # Multi-token chunk appended to a non-empty cache: query i sits at
            # absolute position offset + i and may see keys 0..offset + i.
            # is_causal=True would align the mask to the top-left corner and
            # hide valid history, so build the mask explicitly.
            is_causal = False
            attn_mask = torch.ones(
                T, k.shape[2], dtype=torch.bool, device=x.device
            ).tril(diagonal=offset)

        attn_out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            is_causal=is_causal,
            dropout_p=self.attn_dropout.p if self.training else 0.0,
        )

        # [B, n_head, T, head_dim] -> [B, T, C]
        attn_out = attn_out.transpose(1, 2).reshape(B, T, C)
        x = x + self.resid_dropout(self.out_proj(attn_out))
        x = x + self.mlp(self.ln2(x))
        return x
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # Model (config-driven, with embedding dropout + KV cache support) | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
class GreesyGPT(nn.Module):
    """Decoder-only transformer: token embedding, ``n_layer`` blocks, final norm, LM head.

    The embedding matrix and the output projection share one weight tensor.
    """

    def __init__(self, config: Optional[ModelConfig] = None):
        super().__init__()
        self.config = config or DEFAULT_CONFIG
        c = self.config
        self.embd_dropout = nn.Dropout(c.embd_dropout)
        self.tok_emb = nn.Embedding(c.vocab_size, c.n_embd)
        self.blocks = nn.ModuleList([GreesyBlock(c) for _ in range(c.n_layer)])
        self.ln_f = nn.LayerNorm(c.n_embd)
        self.head = nn.Linear(c.n_embd, c.vocab_size, bias=False)
        self.tok_emb.weight = self.head.weight  # weight tying

        n_params = sum(p.numel() for p in self.parameters())
        # The separator after "initialised" is an em dash (was mis-decoded).
        logger.info(
            "GreesyGPT initialised \u2014 %.2fM params, %d layers, %d heads, "
            "ctx=%d, dropout=(attn=%.2f, resid=%.2f, embd=%.2f, mlp=%.2f)",
            n_params / 1e6, c.n_layer, c.n_head, c.context_len,
            c.attn_dropout, c.resid_dropout, c.embd_dropout, c.mlp_dropout,
        )

    def forward(
        self,
        idx: torch.Tensor,
        kv_caches: Optional[LayerCaches] = None,
    ) -> torch.Tensor:
        """
        Forward pass.

        Parameters
        ----------
        idx : ``[B, T]`` token indices
        kv_caches : optional list of ``KVCache`` (one per layer, indexed by
            layer position). When provided, enables incremental decoding.

        Returns
        -------
        logits : ``[B, T, vocab_size]``
        """
        x = self.embd_dropout(self.tok_emb(idx))
        for i, block in enumerate(self.blocks):
            cache = kv_caches[i] if kv_caches is not None else None
            x = block(x, kv_cache=cache)
        return self.head(self.ln_f(x))
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # Generation (KV-cached, chat-template + reasoning-mode aware) | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
def _sample_next_token(logits: torch.Tensor, temperature: float, top_k: int) -> torch.Tensor:
    """Sample one token per row from last-position ``logits`` (``[B, vocab]``); returns ``[B, 1]`` IDs."""
    if temperature <= 0:
        # Dividing by a non-positive temperature yields inf/NaN logits;
        # treat it as a request for greedy decoding instead.
        return torch.argmax(logits, dim=-1, keepdim=True)
    scaled = logits / temperature
    if top_k > 0:
        # Clamp so an oversized top_k cannot make torch.topk raise.
        k = min(top_k, scaled.size(-1))
        kth_best = torch.topk(scaled, k).values[:, [-1]]
        scaled = scaled.masked_fill(scaled < kth_best, float("-inf"))
    return torch.multinomial(F.softmax(scaled, dim=-1), num_samples=1)


@torch.no_grad()  # inference only: keep the autograd graph out of the KV cache
def generate_moderation(
    model: GreesyGPT,
    prompt: str,
    mode: ReasoningMode = ReasoningMode.LOW,
    output_format: OutputFormat = OutputFormat.MARKDOWN,
    # Optional per-call overrides (take priority over mode defaults)
    max_tokens: Optional[int] = None,
    temp: Optional[float] = None,
    top_k: Optional[int] = None,
    use_kv_cache: bool = True,
) -> dict[str, Any]:
    """
    Run moderation inference via the chat template, optionally with KV caching.

    Parameters
    ----------
    model : trained GreesyGPT (switched to eval mode by this call)
    prompt : raw user message to moderate
    mode : ReasoningMode (controls budget, temperature, system prompt)
    output_format : post-processing format for the verdict
    max_tokens : overrides mode's ``max_total_tokens``
    temp : overrides mode's ``temperature``; values <= 0 select greedy decoding
    top_k : overrides mode's ``top_k``; values <= 0 disable top-k filtering
    use_kv_cache : if True (default), prefill once and then feed one token
        per step; otherwise re-run the model on the whole sequence each step

    Returns
    -------
    dict with keys
        full_text      decoded prompt + generation (includes special tokens)
        thinking       generated reasoning before ``</think>`` (str | None)
        verdict        raw Markdown verdict string (trailing ``<|endoftext|>`` removed)
        verdict_fmt    post-processed verdict (str or dict depending on output_format)
        mode           ReasoningMode used
        output_format  OutputFormat used
    """
    cfg = REASONING_CONFIGS[mode]
    _max = max_tokens if max_tokens is not None else cfg.max_total_tokens
    _temp = temp if temp is not None else cfg.temperature
    _topk = top_k if top_k is not None else cfg.top_k

    model.eval()
    input_str = ChatTemplate.build_inference_prompt(
        user_message=prompt,
        system_prompt=cfg.system_prompt,
    )
    # NOTE(review): allowed_special="all" also turns tag-like text inside the
    # user-supplied ``prompt`` into real control tokens. Consider encoding the
    # user content separately with special tokens disabled.
    tokens = tokenizer.encode(input_str, allowed_special="all")
    context_len = model.config.context_len
    idx = torch.tensor([tokens], device=DEVICE)

    kv_caches: Optional[LayerCaches] = None
    logits: Optional[torch.Tensor] = None
    if use_kv_cache:
        kv_caches = make_kv_caches(model.config.n_layer)
        # Truncate prompt if it exceeds context length (keep the tail)
        if idx.shape[1] > context_len:
            idx = idx[:, -context_len:]
            logger.warning("Prompt truncated to context_len=%d tokens", context_len)
        # Prefill: process the entire prompt in one forward pass
        logits = model(idx, kv_caches=kv_caches)

    generated_ids: list[int] = []
    think_tokens = 0
    think_closed = False
    for _ in range(_max):
        if kv_caches is None:
            # Legacy path: recompute over the (windowed) full sequence.
            logits = model(idx[:, -context_len:])

        if not think_closed and think_tokens >= cfg.max_think_tokens:
            # Thinking budget exhausted: force the closing tag.
            next_tok = torch.tensor([[TOK_THINK_CLOSE]], device=DEVICE)
        else:
            next_tok = _sample_next_token(logits[:, -1, :], _temp, _topk)
        next_id = int(next_tok.item())
        generated_ids.append(next_id)

        if not think_closed:
            if next_id == TOK_THINK_CLOSE:
                think_closed = True
            else:
                think_tokens += 1
        if next_id == TOK_EOT:
            break

        if kv_caches is None:
            idx = torch.cat((idx, next_tok), dim=1)
        else:
            # Check context length limit
            if kv_caches[0].seq_len >= context_len:
                logger.warning("Reached context_len=%d during generation, stopping.", context_len)
                break
            # Single-token forward pass using cached K/V
            logits = model(next_tok, kv_caches=kv_caches)

    full_text = tokenizer.decode(tokens + generated_ids, errors="replace")

    # -- Parse <think> block and verdict -----------------------------------
    # Only the GENERATED text is parsed. The prompt already ends with
    # "<think>\n", so generation starts inside the reasoning block; searching
    # the full text would let "<think>"/"</think>" strings in the moderated
    # message hijack the split.
    generated_text = tokenizer.decode(generated_ids, errors="replace")
    reasoning, close_tag, remainder = generated_text.partition("</think>")
    if close_tag:
        thinking: Optional[str] = reasoning.strip()
        verdict = remainder
    else:
        # Reasoning block never closed: no separate thinking section.
        thinking = None
        verdict = generated_text
    # Drop the end-of-text marker so the verdict is pure Markdown.
    verdict = verdict.strip().removesuffix("<|endoftext|>").strip()

    result: dict[str, Any] = {
        "full_text": full_text,
        "thinking": thinking,
        "verdict": verdict,
        "mode": mode,
        "output_format": output_format,
    }
    result["verdict_fmt"] = format_output(result, output_format)
    return result
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # Sample Dataset (Markdown-formatted reasoning) | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| SAMPLE_MODERATION_DATA: list[dict[str, Any]] = [ | |
| # ββ SAFE ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| { | |
| "instruction": "What's a good recipe for chocolate chip cookies?", | |
| "reasoning": ( | |
| "## Intent\n" | |
| "A standard culinary question with no harmful framing.\n\n" | |
| "## Harm Potential\n" | |
| "- No harmful content detected.\n" | |
| "- Appropriate for all audiences.\n\n" | |
| "## Edge Cases\n" | |
| "None applicable." | |
| ), | |
| "output": "## Verdict\n**SAFE**", | |
| }, | |
| { | |
| "instruction": "Can you explain how photosynthesis works?", | |
| "reasoning": ( | |
| "## Intent\n" | |
| "Science education question β no harmful elements.\n\n" | |
| "## Harm Potential\n" | |
| "- Zero risk.\n" | |
| "- Standard biology curriculum content.\n\n" | |
| "## Edge Cases\n" | |
| "None applicable." | |
| ), | |
| "output": "## Verdict\n**SAFE**", | |
| }, | |
| { | |
| "instruction": "What are some tips for improving my sleep hygiene?", | |
| "reasoning": ( | |
| "## Intent\n" | |
| "Wellness question about healthy sleeping habits.\n\n" | |
| "## Harm Potential\n" | |
| "- Content is broadly beneficial.\n" | |
| "- No self-harm framing present.\n\n" | |
| "## Edge Cases\n" | |
| "None applicable." | |
| ), | |
| "output": "## Verdict\n**SAFE**", | |
| }, | |
| { | |
| "instruction": "Write a short poem about autumn leaves.", | |
| "reasoning": ( | |
| "## Intent\n" | |
| "Creative writing request with pleasant seasonal imagery.\n\n" | |
| "## Harm Potential\n" | |
| "- No violent or offensive themes.\n\n" | |
| "## Edge Cases\n" | |
| "None applicable." | |
| ), | |
| "output": "## Verdict\n**SAFE**", | |
| }, | |
| { | |
| "instruction": "How do I fix a 'module not found' error in Python?", | |
| "reasoning": ( | |
| "## Intent\n" | |
| "Standard software debugging question.\n\n" | |
| "## Harm Potential\n" | |
| "- No malicious intent detectable; routine developer activity.\n\n" | |
| "## Edge Cases\n" | |
| "None applicable." | |
| ), | |
| "output": "## Verdict\n**SAFE**", | |
| }, | |
| # ββ SPAM ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| { | |
| "instruction": ( | |
| "CONGRATULATIONS! You've been selected to receive a $1,000 gift card. " | |
| "Click here now to claim your prize before it expires!!!" | |
| ), | |
| "reasoning": ( | |
| "## Intent\n" | |
| "Unsolicited prize notification designed to deceive the recipient.\n\n" | |
| "## Evidence\n" | |
| "- **ALL-CAPS** opener β attention-hijacking tactic.\n" | |
| "- Implausibly large reward with zero context.\n" | |
| "- Artificial deadline pressure ('before it expires').\n" | |
| "- Suspicious link prompt (`Click here`).\n\n" | |
| "## Harm Potential\n" | |
| "High. Consistent with **phishing / advance-fee fraud**.\n\n" | |
| "## Edge Cases\n" | |
| "Legitimate prize notifications do not use this pattern; " | |
| "no mitigating context present." | |
| ), | |
| "output": "## Verdict\n**SPAM**", | |
| }, | |
| { | |
| "instruction": ( | |
| "Make money fast! Work from home and earn $5000 a week " | |
| "with no experience needed. DM me for details." | |
| ), | |
| "reasoning": ( | |
| "## Intent\n" | |
| "Recruitment for an unspecified work-from-home scheme.\n\n" | |
| "## Evidence\n" | |
| "- Unrealistic earning claim ($5,000/week, no experience).\n" | |
| "- Redirect to private messaging β opaque onboarding is a classic " | |
| "**MLM / scam** signal.\n" | |
| "- No product, employer, or verifiable detail provided.\n\n" | |
| "## Harm Potential\n" | |
| "High. Likely pyramid scheme or fraudulent job offer.\n\n" | |
| "## Edge Cases\n" | |
| "Could theoretically be legitimate affiliate marketing, but the total " | |
| "absence of verifiable detail removes the benefit of the doubt." | |
| ), | |
| "output": "## Verdict\n**SPAM**", | |
| }, | |
| # ββ HARASSMENT ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| { | |
| "instruction": "You're so stupid, nobody likes you and you should just disappear.", | |
| "reasoning": ( | |
| "## Intent\n" | |
| "Direct personal attack aimed at demeaning and distressing an individual.\n\n" | |
| "## Evidence\n" | |
| "> 'You're so stupid'\n\n" | |
| "Targeted intellectual insult.\n\n" | |
| "> 'nobody likes you'\n\n" | |
| "Social-isolation attack.\n\n" | |
| "> 'should just disappear'\n\n" | |
| "Veiled suggestion of self-removal β crisis-adjacent language.\n\n" | |
| "## Harm Potential\n" | |
| "High. Combination of insult and dismissal constitutes targeted " | |
| "**harassment**. The 'disappear' phrasing adds a secondary risk " | |
| "of normalising self-harm ideation.\n\n" | |
| "## Edge Cases\n" | |
| "No mitigating context (fiction, roleplay) present." | |
| ), | |
| "output": "## Verdict\n**HARASSMENT**", | |
| }, | |
| { | |
| "instruction": "I know where you live and I'll make sure you regret posting that.", | |
| "reasoning": ( | |
| "## Intent\n" | |
| "Intimidation message designed to silence a specific person.\n\n" | |
| "## Evidence\n" | |
| "> 'I know where you live'\n\n" | |
| "Implicit doxxing claim β suggests the sender has physical location data.\n\n" | |
| "> 'make sure you regret posting that'\n\n" | |
| "Vague but menacing promise of retaliation.\n\n" | |
| "## Harm Potential\n" | |
| "Very high. This constitutes an **implicit threat** with meaningful " | |
| "escalation risk to real-world harm.\n\n" | |
| "## Edge Cases\n" | |
| "None; the threatening framing is unambiguous." | |
| ), | |
| "output": "## Verdict\n**HARASSMENT**", | |
| }, | |
| # ββ HATE SPEECH βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| { | |
| "instruction": "All people from [country] are criminals and should be deported.", | |
| "reasoning": ( | |
| "## Intent\n" | |
| "Sweeping negative generalisation applied to an entire national group.\n\n" | |
| "## Evidence\n" | |
| "- Attributes **criminality** to every member of a nationality β " | |
| "textbook xenophobic stereotype.\n" | |
| "- Calls for **mass deportation** β dehumanising policy dimension.\n" | |
| "- No nuance, caveat, or individual context.\n\n" | |
| "## Harm Potential\n" | |
| "High. Promotes discrimination and could incite real-world hostility.\n\n" | |
| "## Policy\n" | |
| "Violates hate-speech policy under the **national origin** protected " | |
| "characteristic." | |
| ), | |
| "output": "## Verdict\n**HATE_SPEECH**", | |
| }, | |
| # ββ CRISIS REFERRAL βββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| { | |
| "instruction": "I've been feeling really hopeless lately and I don't see a reason to keep going.", | |
| "reasoning": ( | |
| "## Intent\n" | |
| "User is expressing deep personal distress.\n\n" | |
| "## Evidence\n" | |
| "> 'feeling really hopeless'\n\n" | |
| "Recognised warning sign of depression / suicidal ideation.\n\n" | |
| "> 'don't see a reason to keep going'\n\n" | |
| "Implicit passive suicidal ideation β no explicit plan stated, " | |
| "but language warrants **immediate** attention.\n\n" | |
| "## Harm Potential\n" | |
| "High. This is not standard moderation territory β user safety is " | |
| "the primary concern.\n\n" | |
| "## Recommended Action\n" | |
| "Escalate to **crisis resources** immediately. Respond with empathetic " | |
| "referral to a mental-health helpline rather than standard policy " | |
| "enforcement." | |
| ), | |
| "output": "## Verdict\n**CRISIS_REFERRAL**", | |
| }, | |
| # ββ MISINFORMATION ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| { | |
| "instruction": ( | |
| "Vaccines cause autismβthis has been proven by multiple independent " | |
| "studies. Share this with everyone!" | |
| ), | |
| "reasoning": ( | |
| "## Intent\n" | |
| "Presenting a debunked medical claim as fact and urging viral spread.\n\n" | |
| "## Evidence\n" | |
| "- The vaccine-autism link originates from a **retracted 1998 Wakefield " | |
| "paper** found to be fraudulent.\n" | |
| "- Refuted by dozens of large-scale epidemiological studies " | |
| "(millions of children).\n" | |
| "- Claim of 'multiple independent studies' is **false**.\n" | |
| "- 'Share this with everyone!' actively promotes spread of the " | |
| "misinformation.\n\n" | |
| "## Harm Potential\n" | |
| "Very high. Vaccine hesitancy driven by this myth has caused measurable " | |
| "declines in immunisation coverage and preventable outbreaks.\n\n" | |
| "## Policy\n" | |
| "Violates **health misinformation** policy." | |
| ), | |
| "output": "## Verdict\n**MISINFORMATION**", | |
| }, | |
| { | |
| "instruction": ( | |
| "The moon landing in 1969 was faked by NASA on a Hollywood set. " | |
| "Here's the proof." | |
| ), | |
| "reasoning": ( | |
| "## Intent\n" | |
| "Promoting a well-known conspiracy theory as documented fact.\n\n" | |
| "## Evidence\n" | |
| "- Apollo 11 is independently verified by tracking stations in multiple " | |
| "countries including the **Soviet Union** (adversary with every incentive " | |
| "to expose a hoax).\n" | |
| "- Lunar samples, retroreflectors, and telemetry data all corroborate the " | |
| "mission.\n" | |
| "- 'Here's proof' promises evidence that conspiracy literature has " | |
| "repeatedly failed to produce.\n\n" | |
| "## Harm Potential\n" | |
| "Moderate. Erodes **public trust in science and institutions**; " | |
| "contributes to broader epistemic harm.\n\n" | |
| "## Policy\n" | |
| "Flagged as **misinformation** (thoroughly-refuted, well-documented claim)." | |
| ), | |
| "output": "## Verdict\n**MISINFORMATION**", | |
| }, | |
| # ββ AMBIGUOUS / CONTEXT-DEPENDENT βββββββββββββββββββββββββββββββββββββββββ | |
| { | |
| "instruction": "How do I pick a lock?", | |
| "reasoning": ( | |
| "## Intent\n" | |
| "Ambiguous. Lock-picking has multiple legitimate use-cases.\n\n" | |
| "## Legitimate Uses\n" | |
| "- **Locksport** β a recognised competitive hobby.\n" | |
| "- Professional **locksmiths** and security researchers.\n" | |
| "- Individuals locked out of their own property.\n\n" | |
| "## Harmful Uses\n" | |
| "- Facilitating burglary or unauthorised entry.\n\n" | |
| "## Harm Assessment\n" | |
| "Without additional context (e.g., mention of a specific target or prior " | |
| "threatening messages), this query alone does **not** meet the violation " | |
| "threshold. Information is freely available and has significant legitimate " | |
| "utility.\n\n" | |
| "## Decision\n" | |
| "Allow. Note: escalation to UNSAFE warranted if threatening context " | |
| "appears in the same conversation." | |
| ), | |
| "output": "## Verdict\n**SAFE** *(context-dependent)*", | |
| }, | |
| { | |
| "instruction": "What household chemicals should never be mixed together?", | |
| "reasoning": ( | |
| "## Intent\n" | |
| "The framing is explicitly **preventative**, not instructional.\n\n" | |
| "## Evidence\n" | |
| "- Phrased as 'should *never* be mixed' β seeks to *avoid* harm.\n" | |
| "- This information is routinely published by poison-control centres, " | |
| "schools, and safety agencies.\n\n" | |
| "## Harm Assessment\n" | |
| "Low. Serves harm-reduction purposes. Re-classification would require " | |
| "explicit harmful intent in additional context.\n\n" | |
| "## Decision\n" | |
| "Allow." | |
| ), | |
| "output": "## Verdict\n**SAFE**", | |
| }, | |
| ] | |
def get_sample_dataset(
    tokenizer_obj: Optional[tiktoken.Encoding] = None,
    max_length: int = 12288,
    system_prompt: str = "",
) -> "ModerationReasoningDataset":
    """Build a ``ModerationReasoningDataset`` over the built-in sample data.

    Args:
        tokenizer_obj: Encoding handed to the dataset. When it is falsy
            (e.g. ``None``) the module-level ``tokenizer`` is used instead.
        max_length: Forwarded unchanged as the dataset's ``max_length``.
        system_prompt: Forwarded unchanged as the dataset's ``system_prompt``.

    Returns:
        A dataset wrapping ``SAMPLE_MODERATION_DATA``.
    """
    encoding = tokenizer_obj if tokenizer_obj else tokenizer
    dataset_kwargs = {"max_length": max_length, "system_prompt": system_prompt}
    return ModerationReasoningDataset(
        SAMPLE_MODERATION_DATA, encoding, **dataset_kwargs
    )
def load_dataset_json(
    file_path: Optional[str | Path] = None,
    max_samples: Optional[int] = 1000,
) -> list[dict[str, Any]]:
    """Load, truncate and validate moderation samples from a JSON file.

    The file must contain a JSON list of objects. Each object needs a
    non-blank string ``instruction`` and ``reasoning``, plus either a
    non-blank string ``output`` or a non-blank string ``label``. When only
    ``label`` is usable, the output is synthesised as
    ``"## Verdict\\n**<LABEL>**"`` with the label stripped and upper-cased.

    Args:
        file_path: Location of the JSON file. ``None`` falls back to the
            module-level ``DATASET_JSON_PATH``.
        max_samples: Keep at most this many leading samples. Defaults to
            1000, which was previously a hard-coded cap; pass ``None`` to
            load every sample. Samples beyond the cap are not validated.

    Returns:
        A list of dicts with exactly the keys ``instruction``, ``reasoning``
        and ``output``, each a whitespace-stripped string.

    Raises:
        ValueError: If ``max_samples`` is negative, the top-level JSON value
            is not a list, or a retained sample fails validation. Messages
            refer to "dataset.json" regardless of the actual file name.
        OSError: If the file cannot be opened.
    """
    if max_samples is not None and max_samples < 0:
        # A negative slice bound would silently drop samples from the END.
        raise ValueError("max_samples must be non-negative or None")

    path = Path(file_path) if file_path is not None else DATASET_JSON_PATH
    with path.open("r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, list):
        raise ValueError("dataset.json must contain a list of samples")
    if max_samples is not None:
        data = data[:max_samples]

    def _is_text(value: Any) -> bool:
        """Return True for a string with at least one non-space character."""
        return isinstance(value, str) and bool(value.strip())

    normalized: list[dict[str, Any]] = []
    for index, item in enumerate(data):
        if not isinstance(item, dict):
            raise ValueError(f"dataset.json item {index} must be an object")
        instruction = item.get("instruction")
        reasoning = item.get("reasoning")
        output = item.get("output")
        if not _is_text(instruction):
            raise ValueError(f"dataset.json item {index} is missing a valid instruction")
        if not _is_text(reasoning):
            raise ValueError(f"dataset.json item {index} is missing valid reasoning")
        if not _is_text(output):
            # No usable verdict text: derive it from the bare label instead.
            label = item.get("label")
            if not _is_text(label):
                raise ValueError(f"dataset.json item {index} needs output or label")
            output = f"## Verdict\n**{label.strip().upper()}**"
        normalized.append(
            {
                "instruction": instruction.strip(),
                "reasoning": reasoning.strip(),
                "output": output.strip(),
            }
        )
    return normalized
def get_dataset(
    tokenizer_obj: Optional[tiktoken.Encoding] = None,
    max_length: int = 12288,
    system_prompt: str = "",
    file_path: Optional[str | Path] = None,
) -> "ModerationReasoningDataset":
    """Build a ``ModerationReasoningDataset`` from a JSON dataset file.

    Args:
        tokenizer_obj: Encoding handed to the dataset. When it is falsy
            (e.g. ``None``) the module-level ``tokenizer`` is used instead.
        max_length: Forwarded unchanged as the dataset's ``max_length``.
        system_prompt: Forwarded unchanged as the dataset's ``system_prompt``.
        file_path: Passed straight to ``load_dataset_json``, which decides
            what ``None`` means.

    Returns:
        A dataset wrapping the samples returned by ``load_dataset_json``.
    """
    encoding = tokenizer_obj if tokenizer_obj else tokenizer
    records = load_dataset_json(file_path)
    dataset_kwargs = {"max_length": max_length, "system_prompt": system_prompt}
    return ModerationReasoningDataset(records, encoding, **dataset_kwargs)
# ─────────────────────────────────────────────
# Dataset (chat-template aware)
# ─────────────────────────────────────────────
class ModerationReasoningDataset(Dataset[dict[str, torch.Tensor]]):
    """
    Map-style dataset that renders every sample as a three-message chat.

    Messages assembled by ``__getitem__``, in order:
        system    -> ``system_prompt`` (``DEFAULT_SYSTEM`` when none is given)
        user      -> the sample's ``"instruction"`` text (message under review)
        assistant -> ``ChatTemplate.render_assistant_content(reasoning, verdict)``

    NOTE(review): the original notes say the assistant turn is laid out as
    ``<think>{reasoning}</think>{verdict}<|endoftext|>`` and that only
    assistant tokens contribute to the training loss. Both behaviours live
    in ``ChatTemplate`` (defined elsewhere) - confirm there.
    """

    # Fallback system prompt: moderator persona followed by the shared
    # Markdown formatting instruction.
    DEFAULT_SYSTEM: str = "".join(
        (
            "You are a careful content moderator. ",
            "Analyse the user message, reason step-by-step inside <think> tags, ",
            "then issue a structured Markdown verdict.\n\n",
            _MARKDOWN_INSTRUCTION,
        )
    )

    def __init__(
        self,
        data_list: list[dict[str, Any]],
        enc: tiktoken.Encoding,
        max_length: int = 12288,
        system_prompt: str = "",
    ):
        """Store the raw samples and tokenisation settings.

        Args:
            data_list: Samples; each must provide ``"instruction"``,
                ``"reasoning"`` and ``"output"`` keys (read in ``__getitem__``).
            enc: Encoding passed to ``ChatTemplate.tokenize``.
            max_length: Passed to ``ChatTemplate.tokenize`` as its third argument.
            system_prompt: Custom system message; an empty string selects
                ``DEFAULT_SYSTEM``.
        """
        self.samples = data_list
        self.enc = enc
        self.max_length = max_length
        self.system_prompt = system_prompt if system_prompt else self.DEFAULT_SYSTEM

    def __len__(self) -> int:
        """Return the number of stored samples."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """Render sample ``idx`` as a chat and return its tokenised form."""
        sample = self.samples[idx]
        rendered_reply = ChatTemplate.render_assistant_content(
            reasoning=sample["reasoning"],
            verdict=sample["output"],
        )
        turns = (
            ("system", self.system_prompt),
            ("user", sample["instruction"]),
            ("assistant", rendered_reply),
        )
        conversation = [Message(role, text) for role, text in turns]
        return ChatTemplate.tokenize(conversation, self.enc, self.max_length)
def collate_fn(batch: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
    """Pad a list of tokenised samples into batch-first tensors.

    ``input_ids`` are padded with ``TOK_EOT`` and ``labels`` with ``-100``,
    the value the trainer passes to ``cross_entropy`` as ``ignore_index``.

    Args:
        batch: Samples, each holding ``"input_ids"`` and ``"labels"`` tensors.

    Returns:
        A dict with the padded ``"input_ids"`` and ``"labels"`` tensors.
    """
    id_seqs = [example["input_ids"] for example in batch]
    label_seqs = [example["labels"] for example in batch]
    return {
        "input_ids": pad_sequence(id_seqs, batch_first=True, padding_value=TOK_EOT),
        "labels": pad_sequence(label_seqs, batch_first=True, padding_value=-100),
    }
# ─────────────────────────────────────────────
# Trainer
# ─────────────────────────────────────────────
class GreesyTrainer:
    """Single-device training loop for ``GreesyGPT`` with gradient accumulation.

    The optimiser is AdamW with ``weight_decay=0.1``. Gradients from
    ``grad_accum`` consecutive batches are accumulated, clipped to a global
    norm of 1.0 and applied in one optimiser step.
    """

    def __init__(
        self,
        model: GreesyGPT,
        train_dataset: Dataset[dict[str, torch.Tensor]],
        lr: float = 2e-5,
        batch_size: int = 2,
        grad_accum: int = 4,
    ):
        """Move the model to ``DEVICE`` and build the optimiser and loader.

        Args:
            model: Network to train; moved to ``DEVICE``.
            train_dataset: Dataset whose items ``collate_fn`` can batch.
            lr: AdamW learning rate.
            batch_size: Samples per batch (per backward pass).
            grad_accum: Batches accumulated per optimiser step.

        Raises:
            ValueError: If ``grad_accum`` is smaller than 1.
        """
        if grad_accum < 1:
            # Would otherwise fail later as a ZeroDivisionError (or silently
            # flip the gradient sign for negative values).
            raise ValueError("grad_accum must be >= 1")
        self.model = model.to(DEVICE)
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=lr, weight_decay=0.1
        )
        self.dataloader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
        )
        self.grad_accum = grad_accum

    def train_epoch(self, epoch: int) -> float:
        """Run one pass over the data and return the mean per-batch loss.

        Args:
            epoch: Epoch number, used only for the progress-bar label and
                the summary line.

        Returns:
            Sum of the per-batch losses divided by the number of batches
            (0.0 when the loader yields no batches).
        """
        self.model.train()
        self.optimizer.zero_grad()

        num_batches = len(self.dataloader)
        # Batches left over after the last complete accumulation group. They
        # form a shorter final group that must still be applied; previously
        # their gradients were computed and then discarded.
        tail_size = num_batches % self.grad_accum
        tail_start = num_batches - tail_size

        total_loss = 0.0
        pbar = tqdm(self.dataloader, desc=f"Epoch {epoch}")
        for step, batch in enumerate(pbar):
            inputs = batch["input_ids"].to(DEVICE)
            targets = batch["labels"].to(DEVICE)
            # NOTE(review): _autocast_ctx is defined elsewhere; presumably a
            # mixed-precision context - confirm.
            with _autocast_ctx():
                logits = self.model(inputs)

            # Next-token objective: position t predicts token t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = targets[..., 1:].contiguous()
            batch_loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )

            # Scale by the size of the group this batch belongs to, so the
            # accumulated gradient is the group's mean.
            group_size = self.grad_accum if step < tail_start else tail_size
            (batch_loss / group_size).backward()

            is_last_batch = step + 1 == num_batches
            if (step + 1) % self.grad_accum == 0 or is_last_batch:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                self.optimizer.step()
                self.optimizer.zero_grad()

            # One .item() call per step: each call forces a device sync.
            step_loss = batch_loss.item()
            total_loss += step_loss
            pbar.set_postfix(loss=step_loss)

        avg = total_loss / max(num_batches, 1)
        print(f"Epoch {epoch} avg loss: {avg:.4f}")
        return avg
# ─────────────────────────────────────────────
# Quick-start example
# ─────────────────────────────────────────────
if __name__ == "__main__":
    # Quick-start demo: train for one epoch, save the weights, then print a
    # verdict for one test prompt in every reasoning mode / output format.
    print(f"Using device: {DEVICE}")
    print(describe_reasoning_modes())
    print()

    # ── Train: use the JSON dataset when it exists, else the built-in samples ──
    model = GreesyGPT()
    dataset = get_dataset() if DATASET_JSON_PATH.exists() else get_sample_dataset()
    trainer = GreesyTrainer(model, dataset, batch_size=2, grad_accum=4)
    trainer.train_epoch(epoch=1)

    # ── Save the model (state dict only, next to this file) ──
    save_path = Path(__file__).parent / "greesy_gpt.pt"
    torch.save(model.state_dict(), save_path)
    print(f"Model saved to {save_path}")

    # ── Inference: compare modes × output formats ──
    # train_epoch() leaves the model in training mode, where dropout is
    # active. Switch to eval mode explicitly; the call is idempotent, so it
    # is harmless if generate_moderation() also does it.
    model.eval()
    test_prompt = "You are worthless and no one will ever love you."
    for mode in ReasoningMode:
        for fmt in OutputFormat:
            result = generate_moderation(model, test_prompt, mode=mode, output_format=fmt)
            print(f"\n── {mode.value.upper()} / {fmt.value.upper()} ──")
            if fmt == OutputFormat.JSON:
                print(json.dumps(result["verdict_fmt"], indent=2))
            else:
                # Cap free-form verdicts at 300 characters for readability.
                print(str(result["verdict_fmt"])[:300])