edgemindroboticslabs commited on
Commit
9e87788
·
verified ·
1 Parent(s): 1059a9e

Upload generate.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. generate.py +76 -0
generate.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Generate text from a trained checkpoint."""
2
+ import argparse
3
+ import torch
4
+
5
+ from model import GPT, GPTConfig
6
+ from tokenizer import load_tokenizer
7
+
8
+
9
+ def get_device():
10
+ if torch.backends.mps.is_available():
11
+ return torch.device("mps")
12
+ if torch.cuda.is_available():
13
+ return torch.device("cuda")
14
+ return torch.device("cpu")
15
+
16
+
17
+ def load_model(checkpoint_path, device):
18
+ ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
19
+ config = GPTConfig(**ckpt["config"])
20
+ model = GPT(config).to(device)
21
+ model.load_state_dict(ckpt["model_state"])
22
+ model.eval()
23
+ return model
24
+
25
+
26
+ def alpaca_prompt(instruction, input_text=""):
27
+ """Format a prompt in Alpaca instruction style (for models trained on Alpaca)."""
28
+ if input_text.strip():
29
+ return f"### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n"
30
+ return f"### Instruction:\n{instruction}\n\n### Response:\n"
31
+
32
+
33
+ def generate_text(model, tokenizer, prompt, max_new_tokens=200, temperature=1.0, top_k=40, device="cpu"):
34
+ encoded = tokenizer.encode(prompt)
35
+ if not encoded:
36
+ encoded = [0]
37
+ idx = torch.tensor([encoded], dtype=torch.long, device=device)
38
+ with torch.no_grad():
39
+ out = model.generate(idx, max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k)
40
+ return tokenizer.decode(out[0].tolist())
41
+
42
+
43
+ if __name__ == "__main__":
44
+ parser = argparse.ArgumentParser()
45
+ parser.add_argument("--checkpoint", default="checkpoints/best_model.pt")
46
+ parser.add_argument("--tokenizer", default="tokenizer.json")
47
+ parser.add_argument("--prompt", default="To be or not to be")
48
+ parser.add_argument("--instruction", default=None,
49
+ help="Use Alpaca-style prompt. Overrides --prompt.")
50
+ parser.add_argument("--input", default="", help="Optional input for Alpaca prompt")
51
+ parser.add_argument("--max_new_tokens", type=int, default=300)
52
+ parser.add_argument("--temperature", type=float, default=0.8)
53
+ parser.add_argument("--top_k", type=int, default=40)
54
+ args = parser.parse_args()
55
+
56
+ device = get_device()
57
+ print(f"Device: {device}")
58
+
59
+ tokenizer = load_tokenizer(args.tokenizer)
60
+ model = load_model(args.checkpoint, device)
61
+ print(f"Model loaded ({model.num_params():,} params)\n")
62
+
63
+ if args.instruction:
64
+ prompt = alpaca_prompt(args.instruction, args.input)
65
+ print(f"Prompt:\n{prompt}")
66
+ else:
67
+ prompt = args.prompt
68
+
69
+ result = generate_text(
70
+ model, tokenizer, prompt,
71
+ max_new_tokens=args.max_new_tokens,
72
+ temperature=args.temperature,
73
+ top_k=args.top_k,
74
+ device=device,
75
+ )
76
+ print(result)