""" LUNA SFT — Interactive Chat Loads the SFT fine-tuned model once, then lets you chat continuously. Usage: python chat.py python chat.py --ckpt "D:\\ASTERIZER 2026\\LUNA\\Base\\out\\sft\\model.pth" python chat.py --max_new 300 --temp 0.7 """ import sys, argparse, torch import torch.nn as nn import torch.nn.functional as F from pathlib import Path # ─── Model (must match train.py) ───────────────────────────────────────────── class RotaryEmbedding(nn.Module): def __init__(self, dim, max_seq_len=1024): super().__init__() inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer("inv_freq", inv_freq) t = torch.arange(max_seq_len).float() freqs = torch.einsum("i,j->ij", t, inv_freq) emb = torch.cat([freqs, freqs], dim=-1) self.register_buffer("cos_cached", emb.cos()) self.register_buffer("sin_cached", emb.sin()) def forward(self, seq_len): return self.cos_cached[:seq_len], self.sin_cached[:seq_len] def rotate_half(x): x1, x2 = x.chunk(2, dim=-1) return torch.cat([-x2, x1], dim=-1) def apply_rotary(x, cos, sin): c = cos.unsqueeze(0).unsqueeze(0) s = sin.unsqueeze(0).unsqueeze(0) return x * c + rotate_half(x) * s class CausalSelfAttention(nn.Module): def __init__(self, n_embd, n_head, block_size, rotary_pct=0.25): super().__init__() self.n_head = n_head self.head_dim = n_embd // n_head self.rot_dim = int(self.head_dim * rotary_pct) self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=True) self.c_proj = nn.Linear(n_embd, n_embd, bias=True) self.rotary = RotaryEmbedding(self.rot_dim, block_size) def forward(self, x): B, T, C = x.size() qkv = self.c_attn(x).reshape(B, T, 3, self.n_head, self.head_dim).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) cos, sin = self.rotary(T) q = torch.cat([apply_rotary(q[..., :self.rot_dim], cos, sin), q[..., self.rot_dim:]], dim=-1) k = torch.cat([apply_rotary(k[..., :self.rot_dim], cos, sin), k[..., self.rot_dim:]], dim=-1) y = F.scaled_dot_product_attention(q, k, v, is_causal=True) return self.c_proj(y.transpose(1, 2).contiguous().view(B, T, C)) class MLP(nn.Module): def __init__(self, n_embd): super().__init__() self.fc = nn.Linear(n_embd, 4 * n_embd, bias=True) self.gelu = nn.GELU() self.proj = nn.Linear(4 * n_embd, n_embd, bias=True) def forward(self, x): return self.proj(self.gelu(self.fc(x))) class Block(nn.Module): def __init__(self, n_embd, n_head, block_size): super().__init__() self.ln1 = nn.LayerNorm(n_embd) self.attn = CausalSelfAttention(n_embd, n_head, block_size) self.ln2 = nn.LayerNorm(n_embd) self.mlp = MLP(n_embd) def forward(self, x): x = x + self.attn(self.ln1(x)) x = x + self.mlp(self.ln2(x)) return x class LUNAModel(nn.Module): def __init__(self, vocab_size=50304, block_size=1024, n_layer=10, n_embd=768, n_head=12): super().__init__() self.block_size = block_size self.wte = nn.Embedding(vocab_size, n_embd) self.blocks = nn.ModuleList([Block(n_embd, n_head, block_size) for _ in range(n_layer)]) self.ln_f = nn.LayerNorm(n_embd) self.lm_head = nn.Linear(n_embd, vocab_size, bias=False) self.lm_head.weight = self.wte.weight def forward(self, idx): x = self.wte(idx) for block in self.blocks: x = block(x) return self.lm_head(self.ln_f(x)) # ─── Generation ─────────────────────────────────────────────────────────────── @torch.no_grad() def generate(model, input_ids, max_new=200, temperature=0.7, top_p=0.9, top_k=50, repetition_penalty=1.1, device="cpu"): ids = input_ids.to(device) generated = [] for _ in range(max_new): logits = model(ids[:, -model.block_size:])[:, -1, :] # Repetition penalty if repetition_penalty != 1.0: for tok_id in set(ids[0].tolist()): if logits[0, tok_id] > 0: logits[0, tok_id] /= repetition_penalty else: logits[0, tok_id] *= repetition_penalty if temperature < 1e-6: next_token = logits.argmax(dim=-1, keepdim=True) else: logits = logits / temperature probs = F.softmax(logits, dim=-1) # Top-k if top_k > 0: kval = min(top_k, probs.size(-1)) topk_vals, _ = torch.topk(probs, kval) probs[probs < topk_vals[:, [-1]]] = 0.0 probs /= probs.sum() # Top-p if top_p < 1.0: sorted_probs, sorted_idx = torch.sort(probs, descending=True) cumsum = torch.cumsum(sorted_probs, dim=-1) mask = cumsum - sorted_probs > top_p sorted_probs[mask] = 0.0 sorted_probs /= sorted_probs.sum() next_token = sorted_idx[0, torch.multinomial(sorted_probs[0], 1)] else: next_token = torch.multinomial(probs[0], 1) ids = torch.cat([ids, next_token.view(1, 1)], dim=1) generated.append(next_token.item()) if next_token.item() == 0: # EOS (pythia tokenizer) break return generated # ─── Alpaca prompt template ─────────────────────────────────────────────────── # Prompt format matching sft_train.py exactly (no preamble) def format_prompt(instruction, context=""): inst = instruction.strip() ctx = context.strip() if inst and ctx: return f"### Instruction:\n{inst}\n\n### Input:\n{ctx}\n\n### Response:\n" elif inst: return f"### Instruction:\n{inst}\n\n### Response:\n" else: return f"### Input:\n{ctx}\n\n### Response:\n" # ─── Main ───────────────────────────────────────────────────────────────────── def main(): parser = argparse.ArgumentParser(description="LUNA SFT — Interactive Chat") parser.add_argument("--ckpt", default=r"D:\ASTERIZER 2026\LUNA\Base\out\sft\model.pth") parser.add_argument("--tok_dir", default="Base/checkpoints/EleutherAI/pythia-160m") parser.add_argument("--max_new", type=int, default=150) parser.add_argument("--temp", type=float, default=0.7) parser.add_argument("--top_p", type=float, default=0.9) parser.add_argument("--top_k", type=int, default=40) parser.add_argument("--rep_pen", type=float, default=1.0) parser.add_argument("--device", default="auto") args = parser.parse_args() device = "cuda" if args.device == "auto" and torch.cuda.is_available() else args.device if device == "auto": device = "cpu" print(f"\nDevice: {device}") # Load model print(f"Loading: {args.ckpt}") ckpt = torch.load(args.ckpt, map_location="cpu", weights_only=True) state = ckpt["model"] if "model" in ckpt else ckpt model = LUNAModel() model.load_state_dict(state, strict=True) model = model.to(device).eval() params = sum(p.numel() for p in model.parameters()) print(f" Model loaded: {params:,} parameters") # Load tokenizer from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(args.tok_dir) print(f" Tokenizer: {args.tok_dir} (vocab {tokenizer.vocab_size})") # Chat loop print(f"\n{'='*60}") print(" LUNA — Interactive Chat") print(f" max_new={args.max_new} temp={args.temp} top_p={args.top_p} top_k={args.top_k}") print(f" Type your message and press Enter. Type 'quit' to exit.") print(f"{'='*60}\n") while True: try: user_input = input("You: ").strip() except (EOFError, KeyboardInterrupt): print("\nBye!") break if not user_input: continue if user_input.lower() in ("quit", "exit", "q"): print("Bye!") break prompt = format_prompt(user_input) ids = tokenizer.encode(prompt, return_tensors="pt") tokens = generate( model, ids, max_new=args.max_new, temperature=args.temp, top_p=args.top_p, top_k=args.top_k, repetition_penalty=args.rep_pen, device=device, ) response = tokenizer.decode(tokens, skip_special_tokens=True).strip() # Cut at any trailing ### if model generates next template if "### " in response: response = response.split("### ")[0].strip() print(f"\nLUNA: {response}\n") if __name__ == "__main__": main()