anthonym21 commited on
Commit
4a33e99
·
1 Parent(s): 6e816f4

Upload generate.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. generate.py +85 -0
generate.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Eve-2-MoE Inference
3
+ ===================
4
+ Quick generation script. Works with local weights or HuggingFace download.
5
+
6
+ Usage:
7
+ python generate.py --prompt "The future of AI is"
8
+ python generate.py --prompt "The future of AI is" --model_path ./model_final/pytorch_model.bin
9
+ python generate.py --prompt "The future of AI is" --hf_repo anthonym21/Eve-2-MoE-250M
10
+ """
11
+
12
+ import argparse
13
+ import torch
14
+ import tiktoken
15
+ from modeling_eve import ModelConfig, DeepSeekMoE
16
+
17
+
18
def load_model(model_path: str | None = None, hf_repo: str | None = None, device: str = "cuda"):
    """Build an Eve-2-MoE model and optionally load pretrained weights.

    Args:
        model_path: Local path to a ``pytorch_model.bin`` state dict, or None.
        hf_repo: HuggingFace repo id to download weights from. When given, it
            takes precedence: the downloaded path overwrites ``model_path``.
        device: Target device string (e.g. "cuda" or "cpu").

    Returns:
        The model moved to ``device`` and switched to eval mode. If neither
        ``model_path`` nor ``hf_repo`` is provided, the returned model keeps
        its randomly initialized weights.
    """
    config = ModelConfig()
    model = DeepSeekMoE(config)

    if hf_repo:
        # Lazy import: huggingface_hub is only required for remote weights.
        from huggingface_hub import hf_hub_download
        model_path = hf_hub_download(repo_id=hf_repo, filename="pytorch_model.bin")

    if model_path:
        # weights_only=True blocks arbitrary-code execution from untrusted checkpoints.
        state_dict = torch.load(model_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)

    return model.to(device).eval()
31
+
32
+
33
def generate_streaming(model, prompt: str, max_tokens: int = 200,
                       temperature: float = 0.8, top_k: int | None = 50,
                       device: str = "cuda", encoding_name: str = "gpt2") -> str:
    """Autoregressively sample from the model, streaming each token to stdout.

    Args:
        model: Language model; invoked as ``logits, _ = model(idx)`` and
            expected to expose ``model.config.block_size`` (context window).
        prompt: Text to condition on; printed first, then continued.
        max_tokens: Number of new tokens to sample.
        temperature: Softmax temperature; lower values are greedier.
        top_k: If not None, restrict sampling to the k most likely tokens.
        device: Device the token tensor is created on.
        encoding_name: tiktoken encoding used to tokenize (default "gpt2").

    Returns:
        The newly generated text (excluding the prompt), so callers can use
        the completion programmatically in addition to the streamed output.
    """
    enc = tiktoken.get_encoding(encoding_name)
    tokens = torch.tensor(enc.encode(prompt), dtype=torch.long, device=device).unsqueeze(0)
    prompt_len = tokens.size(1)

    print(prompt, end="", flush=True)

    with torch.no_grad():
        for _ in range(max_tokens):
            # Crop the running context to the model's maximum window.
            idx_cond = tokens[:, -model.config.block_size:]

            # bf16 autocast only on CUDA; enabled=False makes this a no-op on CPU.
            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=(device == "cuda")):
                logits, _ = model(idx_cond)

            # Only the last position's logits matter for the next token.
            logits = logits[:, -1, :] / temperature

            if top_k is not None:
                # Mask everything below the k-th highest logit to -inf.
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("Inf")

            probs = torch.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            tokens = torch.cat((tokens, idx_next), dim=1)

            print(enc.decode([idx_next.item()]), end="", flush=True)

    print("\n")
    return enc.decode(tokens[0, prompt_len:].tolist())
60
+
61
+
62
def main():
    """CLI entry point: parse arguments, load the model, stream a completion."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, default="The future of artificial intelligence is")
    parser.add_argument("--model_path", type=str, default=None)
    parser.add_argument("--hf_repo", type=str, default=None)
    parser.add_argument("--max_tokens", type=int, default=200)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top_k", type=int, default=50)
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    args = parser.parse_args()

    # No weight source given: fall back to the published checkpoint repo.
    if not (args.model_path or args.hf_repo):
        args.hf_repo = "anthonym21/Eve-2-MoE-250M"

    print(f"Loading model on {args.device}...")
    model = load_model(args.model_path, args.hf_repo, args.device)
    total_params = sum(weight.numel() for weight in model.parameters())
    print(f"Parameters: {total_params / 1e6:.2f}M\n")

    generate_streaming(model, args.prompt, args.max_tokens, args.temperature, args.top_k, args.device)


if __name__ == "__main__":
    main()