sathishphdai commited on
Commit
abedca1
·
verified ·
1 Parent(s): 190836e

Upload folder using huggingface_hub

Browse files
Files changed (10) hide show
  1. README.md +50 -0
  2. chat.py +90 -0
  3. config.json +18 -0
  4. config.py +102 -0
  5. model.py +223 -0
  6. model.safetensors +3 -0
  7. pytorch_model.bin +3 -0
  8. system_admin_tokenizer.json +0 -0
  9. tokenizer.json +0 -0
  10. tokenizer_config.json +8 -0
README.md ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language: [en]
3
+ license: mit
4
+ tags:
5
+ - sysadmin
6
+ - linux
7
+ - windows-server
8
+ - networking
9
+ - security
10
+ - slm
11
+ - llama-style
12
+ - rope
13
+ - 5m-context
14
+ - from-scratch
15
+ - 1b-params
16
+ pipeline_tag: text-generation
17
+ ---
18
+
19
+ # System Admin-SLM: Role-Based Small Language Model
20
+
21
+ A **LLaMA-style transformer** (~1016.6M params, ~1.02B) trained from scratch for the **System Admin** role.
22
+ Supports up to **5M token context** via RoPE; trained with gradient checkpointing for memory efficiency.
23
+
24
+ ## Architecture
25
+ | Component | Value |
26
+ |-----------|-------|
27
+ | Architecture | LLaMA-style (RoPE + RMSNorm + SwiGLU) |
28
+ | Parameters | ~1016.6M (~1.02B) |
29
+ | Layers | 32 |
30
+ | Heads | 20 |
31
+ | Embedding | 1600 |
32
+ | Max Context | 5,000,000 tokens |
33
+ | Max Output | 5,000,000 tokens |
34
+ | Vocab | 18,841 BPE |
35
+ | Model Size | ~4 GB (fp32) |
36
+
37
+ ## Training
38
+ - Best eval loss: 5.795391702651978
39
+ - Trained with gradient checkpointing on Apple M4 (MPS)
40
+ - 3 epochs, batch_size=1, grad_accum=16
41
+
42
+ ## Usage
43
+ ```python
44
+ from huggingface_hub import hf_hub_download
45
+ from tokenizers import Tokenizer
46
+
47
+ model_path = hf_hub_download("sathishphdai/system-admin-slm-5m", "model.safetensors")
48
+ tokenizer_path = hf_hub_download("sathishphdai/system-admin-slm-5m", "system_admin_tokenizer.json")
49
+ tokenizer = Tokenizer.from_file(tokenizer_path)
50
+ ```
chat.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Interactive chat and demo inference for Role SLM."""
3
+
4
+ import torch
5
+ from tokenizers import Tokenizer
6
+ from config import cfg
7
+ from model import RoleSLM
8
+
9
+
10
def load_model(checkpoint_name="best_model.pt"):
    """Load a trained RoleSLM checkpoint and its tokenizer.

    Returns (model, tokenizer, device) with the model moved to the configured
    device and put in eval mode.
    Raises FileNotFoundError if the checkpoint file is missing.
    """
    device = torch.device(cfg.device)
    ckpt_path = cfg.checkpoint_dir / checkpoint_name
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    # NOTE(review): weights_only=False unpickles arbitrary objects — only load
    # checkpoints from a trusted source.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
    # Overwrite the global config with the values the checkpoint was trained
    # with, so the architecture rebuilt below matches the saved weights.
    for key, val in ckpt.get("config", {}).items():
        if hasattr(cfg, key):
            setattr(cfg, key, val)

    model = RoleSLM()
    model.load_state_dict(ckpt["model_state_dict"], strict=False)
    model = model.to(device)
    model.eval()

    tok_path = cfg.tokenizer_dir / cfg.tokenizer_filename
    tokenizer = Tokenizer.from_file(str(tok_path))
    print(f"Model loaded: {model.count_parameters()/1e6:.2f}M params")
    return model, tokenizer, device
30
+
31
+
32
def generate_response(model, tokenizer, device, prompt, max_tokens=None,
                      temperature=0.8, top_k=50, top_p=0.9):
    """Encode *prompt*, sample a continuation from *model*, and decode it.

    Returns only the newly generated text with special tokens stripped.
    """
    if not max_tokens:
        max_tokens = min(cfg.max_new_tokens, 512)
    token_ids = list(tokenizer.encode(prompt).ids)
    # Drop a trailing <eos> (id 3) so the model continues the prompt.
    if token_ids and token_ids[-1] == 3:
        token_ids.pop()
    input_ids = torch.tensor([token_ids], dtype=torch.long, device=device)
    prompt_len = input_ids.shape[1]

    with torch.no_grad():
        output_ids = model.generate(input_ids, max_new_tokens=max_tokens,
                                    temperature=temperature, top_k=top_k, top_p=top_p)

    generated = output_ids[0][prompt_len:].tolist()
    text = tokenizer.decode(generated)
    for marker in ("<eos>", "<bos>", "<pad>"):
        text = text.replace(marker, "")
    return text.strip()
50
+
51
+
52
# Representative sysadmin prompts used by the "demo" command.
DEMO_PROMPTS = [
    'Linux system administration involves',
    'Server hardening best practices include',
    'Automated configuration management using',
    'Network troubleshooting steps include',
    'System monitoring tools help administrators by',
]
53
+
54
+
55
def demo_generation(model, tokenizer, device):
    """Run every DEMO_PROMPTS entry through the model and print the results."""
    banner = "=" * 60
    print(f"\n{banner}")
    print(f"Demo: {cfg.domain_name}-SLM Inference")
    print(f"{banner}\n")
    for idx, prompt in enumerate(DEMO_PROMPTS, 1):
        print(f"[{idx}] Prompt: {prompt}")
        reply = generate_response(model, tokenizer, device, prompt, max_tokens=256)
        print(f" Response: {reply}\n")
63
+
64
+
65
def interactive_chat():
    """REPL-style chat loop; 'quit' exits, 'demo' runs the canned prompts."""
    print("Loading model...")
    model, tokenizer, device = load_model()
    banner = "=" * 60
    print(f"\n{banner}")
    print(f"{cfg.domain_name}-SLM Interactive Chat (type 'quit' to exit, 'demo' for demos)")
    print(f"{banner}\n")
    while True:
        # Ctrl-C anywhere in the loop body (including mid-generation) exits.
        try:
            user_input = input("You: ").strip()
            if not user_input:
                continue
            command = user_input.lower()
            if command == "quit":
                print("Goodbye!")
                break
            if command == "demo":
                demo_generation(model, tokenizer, device)
                continue
            reply = generate_response(model, tokenizer, device, user_input, max_tokens=512)
            print(f"SLM: {reply}\n")
        except KeyboardInterrupt:
            print("\nGoodbye!")
            break
87
+
88
+
89
# Script entry point: start the interactive REPL.
if __name__ == "__main__":
    interactive_chat()
config.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "RoleSLM"
4
+ ],
5
+ "model_type": "system_admin-slm",
6
+ "domain": "System Admin",
7
+ "vocab_size": 18841,
8
+ "n_layer": 32,
9
+ "n_head": 20,
10
+ "n_embd": 1600,
11
+ "block_size": 512,
12
+ "dropout": 0.05,
13
+ "bias": false,
14
+ "ffn_multiplier": 2.667,
15
+ "max_position_embeddings": 5000000,
16
+ "rope_theta": 5000000.0,
17
+ "n_parameters": 1016566400
18
+ }
config.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Configuration for System-Admin-SLM: A Role-Based SLM for System Admin.
4
+ ~1B params, LLaMA-style architecture with RoPE — supports up to 5M token context.
5
+ """
6
+
7
+ from dataclasses import dataclass, field
8
+ from pathlib import Path
9
+ from typing import Optional
10
+
11
+
12
@dataclass
class SLMConfig:
    """All hyperparameters and paths in one place."""

    # ── Project paths ──────────────────────────────────────────────
    project_dir: Path = Path(__file__).resolve().parent
    # The three directories below default to subdirectories of project_dir;
    # they are filled in (and created on disk) by __post_init__.
    data_dir: Optional[Path] = field(default=None)
    tokenizer_dir: Optional[Path] = field(default=None)
    checkpoint_dir: Optional[Path] = field(default=None)

    # ── Domain ─────────────────────────────────────────────────────
    domain_name: str = "System Admin"
    domain_slug: str = "system_admin"
    tokenizer_filename: str = "system_admin_tokenizer.json"

    # ── Tokenizer ──────────────────────────────────────────────────
    # NOTE(review): 32_768 is the training-time BPE budget; the shipped
    # config.json reports an actual vocab of 18,841 — checkpoints appear to
    # override this at load time (see chat.load_model). Verify before reuse.
    vocab_size: int = 32_768
    min_frequency: int = 2
    special_tokens: list = field(
        default_factory=lambda: [
            "<pad>", "<unk>", "<bos>", "<eos>",
            "<|system|>", "<|user|>", "<|assistant|>",
        ]
    )

    # ── Model (~1B params, LLaMA-style with RoPE) ─────────────────
    n_layer: int = 32
    n_head: int = 20
    n_embd: int = 1600
    block_size: int = 1_000_000  # 1M input token context window
    dropout: float = 0.05
    bias: bool = False
    ffn_multiplier: float = 2.667

    # ── RoPE ───────────────────────────────────────────────────────
    max_position_embeddings: int = 5_000_000  # 5M context window via RoPE
    rope_theta: float = 5_000_000.0  # Scaled for 5M context window

    # ── Sliding Window ─────────────────────────────────────────────
    sliding_window: Optional[int] = None

    # ── Gradient Checkpointing (essential for 1B on 24GB) ──────────
    gradient_checkpointing: bool = True

    # ── Training ───────────────────────────────────────────────────
    batch_size: int = 1
    gradient_accumulation_steps: int = 16
    learning_rate: float = 2e-4
    weight_decay: float = 0.1
    max_epochs: int = 3
    dataset_stride: int = 512  # Training stride
    warmup_steps: int = 100
    grad_clip: float = 1.0
    eval_interval: int = 50
    eval_samples: int = 10
    log_interval: int = 10
    device: str = "auto"  # resolved to cuda/mps/cpu in __post_init__

    # ── Generation ─────────────────────────────────────────────────
    max_new_tokens: int = 5_000_000  # 5M max output tokens
    temperature: float = 0.8
    top_k: int = 50
    top_p: float = 0.9

    # ── HuggingFace ────────────────────────────────────────────────
    hf_repo_name: str = "system-admin-slm-5m"
    hf_model_card_tags: list = field(default_factory=lambda: ['sysadmin', 'linux', 'windows-server', 'networking', 'security', 'slm', 'llama-style', 'rope', '5m-context', 'from-scratch', '1b-params'])

    def __post_init__(self):
        """Fill in derived paths, create directories, and resolve 'auto' device.

        NOTE(review): instantiating the config creates directories on disk —
        a side effect that happens at import time via the module-level `cfg`.
        """
        if self.data_dir is None:
            self.data_dir = self.project_dir / "data"
        if self.tokenizer_dir is None:
            self.tokenizer_dir = self.project_dir / "tokenizer"
        if self.checkpoint_dir is None:
            self.checkpoint_dir = self.project_dir / "checkpoints"

        self.data_dir.mkdir(parents=True, exist_ok=True)
        self.tokenizer_dir.mkdir(parents=True, exist_ok=True)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        if self.device == "auto":
            # Local import keeps the config module importable without torch
            # at definition time.
            import torch
            if torch.cuda.is_available():
                self.device = "cuda"
            elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                self.device = "mps"
            else:
                self.device = "cpu"
100
+
101
+
102
# Shared singleton config, imported as `from config import cfg` elsewhere.
cfg = SLMConfig()
model.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ model.py — Role SLM Transformer (~1B params) with RoPE + Gradient Checkpointing
4
+ ================================================================================
5
+ Supports context lengths up to 5M tokens via:
6
+ * RoPE (no fixed position embedding table)
7
+ * RMSNorm (more efficient than LayerNorm)
8
+ * SwiGLU activation (better training dynamics)
9
+ * Flash Attention via PyTorch scaled_dot_product_attention
10
+ * Gradient checkpointing for memory-efficient training on 24GB
11
+ """
12
+
13
+ import math
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from torch.utils.checkpoint import checkpoint as grad_checkpoint
18
+ from typing import Optional, Tuple
19
+ from config import cfg
20
+
21
+
22
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean-centering, no bias).

    Each feature vector is scaled by the inverse of its RMS, then multiplied
    by a learned per-dimension gain. The statistics are computed in float32
    for numerical stability and the result is cast back to the input dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_fp32 = x.float()
        inv_rms = torch.rsqrt(x_fp32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x_fp32 * inv_rms).type_as(x) * self.weight
31
+
32
+
33
def precompute_rope_freqs(dim, max_seq_len, theta=10000.0, device=None):
    """Precompute cos/sin tables for rotary position embeddings.

    Returns two (max_seq_len, dim // 2) tensors holding cos(m * w_i) and
    sin(m * w_i) for every position m and inverse frequency w_i.
    """
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device).float() / dim))
    positions = torch.arange(max_seq_len, device=device).float()
    angles = torch.outer(positions, inv_freq)
    return angles.cos(), angles.sin()
38
+
39
+
40
def apply_rope(x, cos, sin):
    """Rotate (B, H, T, D) query/key tensors by the precomputed RoPE tables.

    Uses the "rotate half" layout: the head dimension is split into two
    halves that act as the (real, imaginary) parts of each rotated pair.
    """
    seq_len, head_dim = x.shape[2], x.shape[3]
    cos_t = cos[:seq_len].unsqueeze(0).unsqueeze(0)
    sin_t = sin[:seq_len].unsqueeze(0).unsqueeze(0)
    half = head_dim // 2
    re, im = x[..., :half], x[..., half:]
    rotated_re = re * cos_t - im * sin_t
    rotated_im = im * cos_t + re * sin_t
    return torch.cat([rotated_re, rotated_im], dim=-1)
48
+
49
+
50
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with rotary position embeddings.

    Uses PyTorch's fused scaled_dot_product_attention when available, with a
    manual masked-softmax fallback for older PyTorch versions.
    """

    def __init__(self):
        super().__init__()
        assert cfg.n_embd % cfg.n_head == 0
        self.n_head = cfg.n_head
        self.head_dim = cfg.n_embd // cfg.n_head
        self.q_proj = nn.Linear(cfg.n_embd, cfg.n_embd, bias=False)
        self.k_proj = nn.Linear(cfg.n_embd, cfg.n_embd, bias=False)
        self.v_proj = nn.Linear(cfg.n_embd, cfg.n_embd, bias=False)
        self.out_proj = nn.Linear(cfg.n_embd, cfg.n_embd, bias=False)
        self.resid_drop = nn.Dropout(cfg.dropout)

    def forward(self, x, rope_cos, rope_sin):
        B, T, C = x.shape
        # Project and reshape to (B, n_head, T, head_dim).
        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        # Position information is injected into queries and keys only.
        q = apply_rope(q, rope_cos, rope_sin)
        k = apply_rope(k, rope_cos, rope_sin)
        if hasattr(F, 'scaled_dot_product_attention'):
            y = F.scaled_dot_product_attention(q, k, v,
                dropout_p=cfg.dropout if self.training else 0.0, is_causal=True)
        else:
            scale = 1.0 / math.sqrt(self.head_dim)
            att = (q @ k.transpose(-2, -1)) * scale
            mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
            att = att.masked_fill(mask.unsqueeze(0).unsqueeze(0), float('-inf'))
            att = F.softmax(att, dim=-1)
            # Fix: apply attention dropout in the fallback path too, matching
            # the SDPA branch above (the original fallback skipped it, so the
            # two paths diverged during training).
            att = F.dropout(att, p=cfg.dropout, training=self.training)
            y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_drop(self.out_proj(y))
81
+
82
+
83
class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward network: down(silu(gate(x)) * up(x)), with dropout."""

    def __init__(self):
        super().__init__()
        hidden = int(cfg.n_embd * getattr(cfg, 'ffn_multiplier', 2.667))
        if hidden % 64:
            hidden += 64 - hidden % 64  # round up to a multiple of 64
        self.gate_proj = nn.Linear(cfg.n_embd, hidden, bias=False)
        self.up_proj = nn.Linear(cfg.n_embd, hidden, bias=False)
        self.down_proj = nn.Linear(hidden, cfg.n_embd, bias=False)
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x):
        gated = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.dropout(self.down_proj(gated))
95
+
96
+
97
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: attention then SwiGLU FFN, each residual."""

    def __init__(self):
        super().__init__()
        self.attn_norm = RMSNorm(cfg.n_embd)
        self.attn = CausalSelfAttention()
        self.ffn_norm = RMSNorm(cfg.n_embd)
        self.ffn = SwiGLUFFN()

    def forward(self, x, rope_cos, rope_sin):
        h = x + self.attn(self.attn_norm(x), rope_cos, rope_sin)
        return h + self.ffn(self.ffn_norm(h))
109
+
110
+
111
class RoleSLM(nn.Module):
    """Role-Based Small Language Model — ~1B params, LLaMA-style with gradient checkpointing."""

    def __init__(self):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.n_embd)
        self.drop = nn.Dropout(cfg.dropout)
        self.blocks = nn.ModuleList([TransformerBlock() for _ in range(cfg.n_layer)])
        self.norm = RMSNorm(cfg.n_embd)
        self.lm_head = nn.Linear(cfg.n_embd, cfg.vocab_size, bias=False)
        self.tok_emb.weight = self.lm_head.weight  # Weight tying

        self.use_checkpointing = getattr(cfg, 'gradient_checkpointing', True)

        head_dim = cfg.n_embd // cfg.n_head
        max_pos = getattr(cfg, 'max_position_embeddings', 1_000_000)
        rope_theta = getattr(cfg, 'rope_theta', 10000.0)
        # Only a modest RoPE table is precomputed up front; _extend_rope
        # grows it lazily if a longer sequence is ever seen.
        precompute_len = min(max_pos, cfg.block_size * 2)
        cos, sin = precompute_rope_freqs(head_dim, precompute_len, theta=rope_theta)
        self.register_buffer("rope_cos", cos, persistent=False)
        self.register_buffer("rope_sin", sin, persistent=False)
        self._rope_max_len = precompute_len
        self._rope_theta = rope_theta
        self._head_dim = head_dim
        self.apply(self._init_weights)

        n_params = sum(p.numel() for p in self.parameters())
        print(f"{cfg.domain_name}-SLM initialized: {n_params/1e6:.2f}M parameters ({n_params/1e9:.3f}B)")
        print(f" Architecture: {cfg.n_layer}L / {cfg.n_head}H / {cfg.n_embd}D")
        print(f" Gradient checkpointing: {self.use_checkpointing}")
        print(f" Max context: {max_pos:,} tokens (via RoPE)")
        print(f" Estimated model size: {n_params * 4 / 1e9:.2f} GB (fp32)")

    def _init_weights(self, module):
        """GPT-2-style init: normal(0, 0.02) for linear/embedding weights."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _extend_rope(self, seq_len, device):
        """Grow the cached RoPE tables (at least doubling) to cover seq_len."""
        if seq_len > self._rope_max_len:
            new_len = max(seq_len, self._rope_max_len * 2)
            cos, sin = precompute_rope_freqs(self._head_dim, new_len,
                                             theta=self._rope_theta, device=device)
            self.rope_cos = cos
            self.rope_sin = sin
            self._rope_max_len = new_len

    def _block_forward(self, block, x, rope_cos, rope_sin):
        """Wrapper for gradient checkpointing."""
        return block(x, rope_cos, rope_sin)

    def forward(self, idx, targets=None):
        """Return (logits, loss); loss is None unless targets are given.

        idx: (B, T) int64 token ids. targets, if given, are shifted labels of
        the same shape; positions labeled -1 are ignored by the loss.
        """
        B, T = idx.shape
        device = idx.device
        self._extend_rope(T, device)
        x = self.drop(self.tok_emb(idx))
        rope_cos = self.rope_cos[:T].to(device)
        rope_sin = self.rope_sin[:T].to(device)
        for block in self.blocks:
            if self.use_checkpointing and self.training:
                x = grad_checkpoint(self._block_forward, block, x, rope_cos, rope_sin,
                                    use_reentrant=False)
            else:
                x = block(x, rope_cos, rope_sin)
        x = self.norm(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=0.8, top_k=50, top_p=0.9,
                 eos_token_id=3):
        """Autoregressively sample up to max_new_tokens continuations of idx.

        temperature == 0 selects greedy decoding; otherwise top-k then
        nucleus (top-p) filtering is applied before multinomial sampling.
        eos_token_id (default 3, the tokenizer's <eos> — was hard-coded in
        the original) stops generation early for single-sequence batches.
        """
        prev_checkpointing = self.use_checkpointing
        self.use_checkpointing = False  # No checkpointing during generation
        try:
            for _ in range(max_new_tokens):
                idx_cond = idx if idx.size(1) <= cfg.block_size else idx[:, -cfg.block_size:]
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :]
                if temperature == 0:
                    idx_next = logits.argmax(dim=-1, keepdim=True)
                else:
                    logits = logits / temperature
                    if top_k > 0:
                        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                        logits[logits < v[:, [-1]]] = float('-inf')
                    if top_p < 1.0:
                        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                        sorted_indices_to_remove = cumulative_probs > top_p
                        # Shift right so the first token crossing the
                        # threshold is still kept.
                        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                        sorted_indices_to_remove[..., 0] = 0
                        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                        logits[indices_to_remove] = float('-inf')
                    probs = F.softmax(logits, dim=-1)
                    idx_next = torch.multinomial(probs, num_samples=1)
                idx = torch.cat([idx, idx_next], dim=1)
                # Fix: only early-stop for batch size 1 — the original
                # unconditional .item() raised RuntimeError for batched input.
                if idx.size(0) == 1 and idx_next.item() == eos_token_id:
                    break
        finally:
            # Fix: restore the flag even if generation raises; the original
            # left checkpointing disabled on error.
            self.use_checkpointing = prev_checkpointing
        return idx

    def count_parameters(self):
        """Total parameter count (the tied embedding/head matrix counts once)."""
        return sum(p.numel() for p in self.parameters())
217
+
218
+
219
# Smoke test: build the model and run one tiny forward pass with dummy ids.
if __name__ == "__main__":
    model = RoleSLM()
    x = torch.randint(0, cfg.vocab_size, (1, 32))
    logits, loss = model(x, x)
    print(f"Test forward: logits={logits.shape}, loss={loss.item():.4f}")
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c37efec5f714363958d855c92cb131c00aa8dc31ea09c83f013ffd641a297990
3
+ size 4186878264
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:217f370c5c5db93d56c09bd8d623cc8af8cc0355d87da08662e14cd154582d5c
3
+ size 4066376683
system_admin_tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "tokenizer_class": "PreTrainedTokenizerFast",
3
+ "bos_token": "<bos>",
4
+ "eos_token": "<eos>",
5
+ "unk_token": "<unk>",
6
+ "pad_token": "<pad>",
7
+ "model_max_length": 5000000
8
+ }