"""Quick test of model quality with diverse prompts.

Loads the DPO checkpoint if it exists (otherwise the SFT checkpoint), samples
a response for each prompt in ``TEST_PROMPTS`` and prints it together with
simple throughput statistics.
"""
import os
import sys
import time

import torch

# Make the sibling ``model`` package importable when run as a script.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from model.config import ModelConfig
from model.transformer import Transformer
from model.data import get_tokenizer

DPO_CKPT = "/jfs/deepak-kumar/checkpoints_dpo/dpo_final.pt"
SFT_CKPT = "/jfs/deepak-kumar/checkpoints_sft/sft_final.pt"
CHECKPOINT = DPO_CKPT if os.path.exists(DPO_CKPT) else SFT_CKPT
DEVICE = "cuda:0"

# Longest sequence ever fed to the model during generation. Presumably the
# model's training context length -- TODO confirm against ModelConfig.
MAX_CONTEXT = 2048

USER_START = "<|user|>\n"
ASST_START = "<|assistant|>\n"
TURN_END = "\n<|end|>\n"

TEST_PROMPTS = [
    "Hi! How are you?",
    "What is photosynthesis?",
    "Explain gravity to a 5-year-old.",
    "Write a short poem about the ocean.",
    "What are the three states of matter?",
    "How does a computer work?",
    "What is the capital of France and why is it famous?",
    "Give me 3 tips for learning a new language.",
    "What is machine learning in simple terms?",
]


def _stop_sequences(tokenizer):
    """Return the token-id sequences that end generation.

    EOS (if the tokenizer defines one) plus the chat markers ``<|end|>`` and
    ``<|user|>``. Markers are kept as full id sequences: if a marker is not a
    single registered token, stopping on its first piece alone would cut off
    any response that merely contains that piece.
    """
    stop_seqs = []
    if tokenizer.eos_token_id is not None:
        stop_seqs.append((tokenizer.eos_token_id,))
    for marker in ("<|end|>", "<|user|>"):
        ids = tuple(tokenizer.encode(marker, add_special_tokens=False))
        if ids:
            stop_seqs.append(ids)
    return stop_seqs


def _stop_match_length(generated, stop_seqs):
    """Return the length of the stop sequence ``generated`` ends with, else 0."""
    for seq in stop_seqs:
        n = len(seq)
        if len(generated) >= n and tuple(generated[-n:]) == seq:
            return n
    return 0


def _apply_repetition_penalty(logits, token_ids, penalty):
    """Penalize already-generated tokens in row 0 of ``logits`` (in place).

    Positive logits are divided by ``penalty`` and non-positive ones are
    multiplied by it, so both move towards "less likely". Done as one
    vectorized op instead of a per-token Python loop that would sync with the
    GPU once per distinct token.
    """
    if penalty == 1.0 or not token_ids:
        return logits
    idx = torch.tensor(sorted(set(token_ids)), dtype=torch.long, device=logits.device)
    scores = logits[0, idx]
    logits[0, idx] = torch.where(scores > 0, scores / penalty, scores * penalty)
    return logits


def _filter_logits(logits, top_k, top_p):
    """Mask logits outside the top-k set and outside the top-p nucleus with -inf."""
    if top_k > 0:
        kth_best = torch.topk(logits, min(top_k, logits.size(-1))).values[:, -1:]
        logits = logits.masked_fill(logits < kth_best, float("-inf"))
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        sorted_probs = torch.softmax(sorted_logits, dim=-1)
        # Drop a token once the probability mass of the tokens ranked above it
        # already exceeds top_p; the top-ranked token is always kept (top_p >= 0).
        remove = (sorted_probs.cumsum(dim=-1) - sorted_probs) > top_p
        sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
        # sorted_idx is a permutation, so scatter writes every position back.
        logits = logits.scatter(1, sorted_idx, sorted_logits)
    return logits


@torch.no_grad()
def _generate_ids(model, tokenizer, prompt, max_new_tokens=256, temperature=0.7,
                  top_k=50, top_p=0.9, repetition_penalty=1.15,
                  max_context=MAX_CONTEXT):
    """Sample a continuation of ``prompt`` and return the new token ids.

    Generation stops after ``max_new_tokens`` tokens, when a stop sequence is
    produced (it is removed from the result), or once the running sequence
    grows beyond ``max_context`` tokens (the token that overflowed is kept).
    The full sequence is re-run through the model every step.
    """
    prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
    input_ids = torch.tensor([prompt_ids], dtype=torch.long, device=DEVICE)
    stop_seqs = _stop_sequences(tokenizer)
    generated = []

    for _ in range(max_new_tokens):
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            logits, _ = model(input_ids)
        logits = logits[:, -1, :].float()  # last position only

        logits = _apply_repetition_penalty(logits, generated, repetition_penalty)
        logits = logits / max(temperature, 1e-5)  # avoid division by zero
        logits = _filter_logits(logits, top_k, top_p)

        next_token = torch.multinomial(torch.softmax(logits, dim=-1), 1)
        generated.append(next_token.item())

        matched = _stop_match_length(generated, stop_seqs)
        if matched:
            del generated[-matched:]
            break

        input_ids = torch.cat([input_ids, next_token], dim=1)
        if input_ids.size(1) > max_context:
            break

    return generated


def generate(model, tokenizer, prompt, max_new_tokens=256, temperature=0.7,
             top_k=50, top_p=0.9, repetition_penalty=1.15,
             max_context=MAX_CONTEXT):
    """Sample a reply to ``prompt`` and return it as decoded text.

    Uses temperature, top-k, top-p (nucleus) sampling and a repetition
    penalty. Special tokens are stripped from the decoded text. See
    ``_generate_ids`` for the stopping rules.
    """
    ids = _generate_ids(model, tokenizer, prompt, max_new_tokens=max_new_tokens,
                        temperature=temperature, top_k=top_k, top_p=top_p,
                        repetition_penalty=repetition_penalty,
                        max_context=max_context)
    return tokenizer.decode(ids, skip_special_tokens=True)


def main():
    """Load the checkpoint and print a sampled answer for every test prompt."""
    ckpt_name = "DPO" if CHECKPOINT == DPO_CKPT else "SFT"
    print("=" * 70)
    print(" " + ckpt_name + " MODEL TEST")
    print("=" * 70)

    # The chat markers must exist in the vocabulary so that vocab_size matches
    # the size the checkpoint was trained with.
    tokenizer = get_tokenizer()
    special_tokens = ["<|user|>", "<|assistant|>", "<|end|>"]
    vocab = tokenizer.get_vocab()
    new_tokens = [t for t in special_tokens if t not in vocab]
    if new_tokens:
        tokenizer.add_tokens(new_tokens, special_tokens=True)

    config = ModelConfig()
    config.vocab_size = len(tokenizer)
    model = Transformer(config)

    print("")
    print("Loading checkpoint: " + CHECKPOINT)
    # NOTE: weights_only=False unpickles arbitrary objects; acceptable only
    # because CHECKPOINT is a trusted file produced by our own training runs.
    ckpt = torch.load(CHECKPOINT, map_location="cpu", weights_only=False)
    model.load_state_dict(ckpt["model"])
    step = ckpt.get("step", "?")
    del ckpt
    # Cast to bf16 on the CPU *before* moving: moving first would put an fp32
    # copy on the GPU, roughly doubling peak memory and the figure reported below.
    model = model.bfloat16().to(DEVICE).eval()
    print(f"Model loaded ({ckpt_name} step {step}, vocab {config.vocab_size})")

    mem = torch.cuda.max_memory_allocated(DEVICE) / 1e9
    print(f"GPU memory: {round(mem, 1)} GB")
    print("-" * 70)

    for i, question in enumerate(TEST_PROMPTS, 1):
        prompt = USER_START + question + TURN_END + ASST_START
        print("")
        print(f"[Test {i}/{len(TEST_PROMPTS)}]")
        print(" Q: " + question)

        t0 = time.perf_counter()
        ids = _generate_ids(model, tokenizer, prompt)
        dt = time.perf_counter() - t0

        # Count the tokens actually sampled; re-encoding the decoded text would
        # not round-trip (special tokens are skipped when decoding).
        tokens = len(ids)
        response = tokenizer.decode(ids, skip_special_tokens=True)
        # Defensive: markers not registered as special tokens survive decoding.
        response = response.split("<|end|>")[0].split("<|user|>")[0].strip()
        print(" A: " + response)

        tps = int(tokens / max(dt, 0.01))
        print(f" [{tokens} tokens, {round(dt, 1)}s, {tps} tok/s]")
        print("-" * 70)

    print("")
    print("Done!")


if __name__ == "__main__":
    main()