arman-bd commited on
Commit
dd35755
·
verified ·
1 Parent(s): cb15325

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. config.py +36 -0
  2. inference.py +124 -0
  3. model.py +129 -0
config.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """GuppyLM configuration."""
2
+
3
+ from dataclasses import dataclass
4
+
5
+
6
@dataclass
class GuppyConfig:
    """Architecture hyperparameters for the GuppyLM transformer."""

    vocab_size: int = 4096  # tokenizer vocabulary size (also LM-head width)
    max_seq_len: int = 128  # maximum context length (learned positional embeddings)
    d_model: int = 384      # embedding / residual-stream width
    n_layers: int = 6       # number of transformer blocks
    n_heads: int = 6        # attention heads; head_dim = d_model // n_heads
    ffn_hidden: int = 768   # feed-forward hidden width
    dropout: float = 0.1    # dropout used in attention, FFN, and embeddings

    # Special tokens
    pad_id: int = 0
    bos_id: int = 1  # <|im_start|>
    eos_id: int = 2  # <|im_end|>
20
+
21
+
22
@dataclass
class TrainConfig:
    """Training-loop hyperparameters and paths."""

    batch_size: int = 32
    learning_rate: float = 3e-4  # peak LR after warmup
    min_lr: float = 3e-5         # floor for the LR schedule
    weight_decay: float = 0.1
    warmup_steps: int = 200
    max_steps: int = 10000
    eval_interval: int = 200     # steps between evaluations
    save_interval: int = 500     # steps between checkpoint saves
    grad_clip: float = 1.0       # gradient-norm clipping threshold
    device: str = "auto"         # "auto" presumably resolves to cuda/cpu elsewhere — confirm
    seed: int = 42
    data_dir: str = "data"
    output_dir: str = "checkpoints"
inference.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """GuppyLM inference — simple chat."""
2
+
3
import json
import os
import time
import uuid

import torch
from tokenizers import Tokenizer

from config import GuppyConfig
from model import GuppyLM
12
+
13
+
14
class GuppyInference:
    """Load a GuppyLM checkpoint plus tokenizer and serve chat completions."""

    def __init__(self, checkpoint_path, tokenizer_path, device="cpu"):
        """Load model weights and resolve the model configuration.

        Args:
            checkpoint_path: path to a .pt file — either a raw state_dict or a
                dict with a "model_state_dict" key (and optionally "config").
            tokenizer_path: path to a HuggingFace `tokenizers` JSON file.
            device: torch device string, e.g. "cpu" or "cuda".
        """
        self.device = torch.device(device)
        self.tokenizer = Tokenizer.from_file(tokenizer_path)

        # NOTE(security): weights_only=False unpickles arbitrary objects from
        # the checkpoint — only load checkpoints from trusted sources.
        ckpt = torch.load(checkpoint_path, map_location=self.device, weights_only=False)

        # Extract state_dict — handle both legacy and standard formats.
        if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
            state_dict = ckpt["model_state_dict"]
        else:
            state_dict = ckpt

        # An HF-style config.json may live next to the checkpoint file.
        config_dir = os.path.dirname(os.path.abspath(checkpoint_path))
        config_path = os.path.join(config_dir, "config.json")
        self.config = self._load_config(ckpt, config_path)

        self.model = GuppyLM(self.config).to(self.device)
        # Drop checkpoint keys the current architecture doesn't define.
        filtered = {k: v for k, v in state_dict.items() if k in self.model.state_dict()}
        self.model.load_state_dict(filtered)
        self.model.eval()

        total, _ = self.model.param_count()
        print(f"GuppyLM loaded: {total/1e6:.1f}M params")

    @staticmethod
    def _load_config(ckpt, config_path):
        """Resolve a GuppyConfig: config.json first, then the checkpoint's
        embedded config, then library defaults."""
        if os.path.exists(config_path):
            with open(config_path) as f:
                cfg = json.load(f)
            # Support both HF standard keys and our own keys
            return GuppyConfig(
                vocab_size=cfg.get("vocab_size", 4096),
                max_seq_len=cfg.get("max_position_embeddings", cfg.get("max_seq_len", 128)),
                d_model=cfg.get("hidden_size", cfg.get("d_model", 384)),
                n_layers=cfg.get("num_hidden_layers", cfg.get("n_layers", 6)),
                n_heads=cfg.get("num_attention_heads", cfg.get("n_heads", 6)),
                ffn_hidden=cfg.get("intermediate_size", cfg.get("ffn_hidden", 768)),
                dropout=cfg.get("hidden_dropout_prob", cfg.get("dropout", 0.1)),
                pad_id=cfg.get("pad_token_id", cfg.get("pad_id", 0)),
                bos_id=cfg.get("bos_token_id", cfg.get("bos_id", 1)),
                eos_id=cfg.get("eos_token_id", cfg.get("eos_id", 2)),
            )
        if isinstance(ckpt, dict) and "config" in ckpt:
            valid_fields = {f.name for f in GuppyConfig.__dataclass_fields__.values()}
            return GuppyConfig(**{k: v for k, v in ckpt["config"].items() if k in valid_fields})
        print("Warning: No config found, using defaults")
        return GuppyConfig()

    def chat_completion(self, messages, temperature=0.7, max_tokens=64,
                        top_k=50, **kwargs):
        """Chat completion — takes messages, returns response.

        Args:
            messages: list of {"role": ..., "content": ...} dicts.
            temperature: sampling temperature passed to the model.
            max_tokens: maximum number of new tokens to generate.
            top_k: top-k sampling cutoff (0 disables).

        Returns:
            An OpenAI-style dict: {"choices": [{"message": {...}}]}.
        """
        prompt = self._format_prompt(messages)
        input_ids = self.tokenizer.encode(prompt).ids
        prompt_tokens = len(input_ids)
        input_t = torch.tensor([input_ids], dtype=torch.long, device=self.device)

        output_t, _ = self.model.generate(input_t, max_tokens, temperature, top_k)
        # Decode only the newly generated tokens (skip the prompt echo).
        output_text = self.tokenizer.decode(output_t[0].tolist()[prompt_tokens:])
        # Truncate at chat-template markers — don't let the model leak into
        # the next turn: first <|im_end|>, then any stray <|im_start|>.
        for marker in ("<|im_end|>", "<|im_start|>"):
            if marker in output_text:
                output_text = output_text.split(marker)[0]
        resp_text = output_text.strip()

        return {
            "choices": [{
                "message": {"role": "assistant", "content": resp_text},
            }],
        }

    def _format_prompt(self, messages):
        """Render messages into the <|im_start|>/<|im_end|> chat template.

        System messages are dropped entirely; a trailing assistant header
        cues the model to respond.
        """
        parts = []
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content") or ""
            if role == "system":
                continue
            parts.append(f"<|im_start|>{role}\n{content}<|im_end|>")
        parts.append("<|im_start|>assistant\n")
        return "\n".join(parts)
98
+
99
+
100
def main():
    """Minimal REPL for chatting with a local Guppy checkpoint."""
    import argparse

    parser = argparse.ArgumentParser(description="Chat with Guppy")
    parser.add_argument("--checkpoint", default="checkpoints/best_model.pt")
    parser.add_argument("--tokenizer", default="data/tokenizer.json")
    parser.add_argument("--device", default="cpu")
    args = parser.parse_args()

    engine = GuppyInference(args.checkpoint, args.tokenizer, args.device)
    print("\nGuppy Chat (type 'quit' to exit)")

    history = []
    while True:
        user_text = input("\nYou> ").strip()
        if user_text.lower() in {"quit", "exit", "q"}:
            break
        history.append({"role": "user", "content": user_text})
        reply = engine.chat_completion(history)["choices"][0]["message"]
        if reply.get("content"):
            print(f"Guppy> {reply['content']}")
        history.append(reply)


if __name__ == "__main__":
    main()
model.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GuppyLM — a tiny fish brain.
3
+
4
+ Vanilla transformer: multi-head attention, ReLU FFN, LayerNorm, learned positional embeddings.
5
+ No GQA, no SwiGLU, no parallel residual, no RoPE. As simple as it gets.
6
+ """
7
+
8
+ import math
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from config import GuppyConfig
13
+
14
+
15
class Attention(nn.Module):
    """Multi-head self-attention with a fused QKV projection.

    Positions where ``mask == 0`` are excluded from attention.
    """

    def __init__(self, config):
        super().__init__()
        self.n_heads = config.n_heads
        self.head_dim = config.d_model // config.n_heads

        self.qkv = nn.Linear(config.d_model, 3 * config.d_model)
        self.out = nn.Linear(config.d_model, config.d_model)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x, mask=None):
        batch, seq, dim = x.shape
        # Single matmul produces Q, K, V; then split out the heads.
        projected = self.qkv(x)
        projected = projected.reshape(batch, seq, 3, self.n_heads, self.head_dim)
        projected = projected.permute(2, 0, 3, 1, 4)
        q, k, v = projected.unbind(0)

        # Scaled dot-product scores, masked before the softmax.
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        weights = F.softmax(scores, dim=-1)
        weights = self.dropout(weights)

        # Merge heads back into the model dimension and project out.
        context = (weights @ v).transpose(1, 2).contiguous().view(batch, seq, dim)
        return self.out(context)
35
+
36
+
37
class FFN(nn.Module):
    """Position-wise feed-forward network: Linear → ReLU → Linear → Dropout."""

    def __init__(self, config):
        super().__init__()
        self.up = nn.Linear(config.d_model, config.ffn_hidden)
        self.down = nn.Linear(config.ffn_hidden, config.d_model)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        hidden = F.relu(self.up(x))
        projected = self.down(hidden)
        return self.dropout(projected)
46
+
47
+
48
class Block(nn.Module):
    """Pre-norm transformer block: LayerNorm → attention and LayerNorm → FFN,
    each wrapped in a residual connection."""

    def __init__(self, config):
        super().__init__()
        self.norm1 = nn.LayerNorm(config.d_model)
        self.attn = Attention(config)
        self.norm2 = nn.LayerNorm(config.d_model)
        self.ffn = FFN(config)

    def forward(self, x, mask=None):
        # Residual around attention, then around the feed-forward.
        x = x + self.attn(self.norm1(x), mask)
        return x + self.ffn(self.norm2(x))
60
+
61
+
62
class GuppyLM(nn.Module):
    """Tiny decoder-only transformer language model.

    Learned positional embeddings, pre-norm blocks, and a weight-tied LM head
    (the output projection shares the token-embedding matrix).
    """

    def __init__(self, config: GuppyConfig):
        super().__init__()
        self.config = config

        self.tok_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_emb = nn.Embedding(config.max_seq_len, config.d_model)
        self.drop = nn.Dropout(config.dropout)
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layers)])
        self.norm = nn.LayerNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.lm_head.weight = self.tok_emb.weight  # tie weights

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # GPT-2-style init: N(0, 0.02) weights, zero biases.
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Run the transformer.

        Args:
            idx: (B, T) int64 token ids with T <= config.max_seq_len.
            targets: optional (B, T) token ids for next-token cross-entropy.

        Returns:
            (logits, loss): logits is (B, T, vocab_size); loss is None when
            targets is None.
        """
        B, T = idx.shape
        pos = torch.arange(T, device=idx.device)
        x = self.drop(self.tok_emb(idx) + self.pos_emb(pos))
        # Causal mask: position t attends only to positions <= t.
        mask = torch.tril(torch.ones(T, T, device=idx.device)).unsqueeze(0).unsqueeze(0)

        for block in self.blocks:
            x = block(x, mask)

        logits = self.lm_head(self.norm(x))

        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, self.config.vocab_size),
                targets.view(-1),
                # Was a hard-coded 0; use the configured pad token id
                # (default 0, so behavior is unchanged for the defaults).
                ignore_index=self.config.pad_id,
            )

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=64, temperature=0.7, top_k=50, **kwargs):
        """Autoregressively sample up to max_new_tokens; stop early at EOS.

        Args:
            idx: (B, T) prompt token ids.
            temperature: logit divisor; must be non-zero.
            top_k: keep only the k highest logits per step (0 disables).

        Returns:
            (idx_with_new_tokens, []) — the empty list preserves the legacy
            return signature.
        """
        self.eval()
        for _ in range(max_new_tokens):
            # Crop context to the model's maximum sequence length.
            idx_cond = idx[:, -self.config.max_seq_len:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k > 0:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float("-inf")
            probs = F.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, next_id], dim=1)
            # Stop once every sequence emitted EOS this step. (The original
            # used next_id.item(), which raises for batch size > 1; this is
            # identical for batch size 1.)
            if (next_id == self.config.eos_id).all():
                break
        return idx, []

    def param_count(self):
        """Return (total_params, 0); the second slot is kept for API compat."""
        total = sum(p.numel() for p in self.parameters())
        return total, 0

    def param_summary(self):
        """Human-readable one-line parameter summary."""
        total, _ = self.param_count()
        return f"GuppyLM: {total:,} params ({total/1e6:.1f}M)"