# Eve-2-MoE-272M / generate.py
# Uploaded by anthonym21 via huggingface_hub (commit dd9d5c4, verified)
"""
Eve-2-MoE Inference
===================
Quick generation script. Works with local weights or HuggingFace download.
Usage:
python generate.py --prompt "The future of AI is"
python generate.py --prompt "The future of AI is" --model_path ./model_final/pytorch_model.bin
python generate.py --prompt "The future of AI is" --hf_repo anthonym21/Eve-2-MoE-250M
"""
import argparse
from typing import Optional

import tiktoken
import torch

from modeling_eve import ModelConfig, DeepSeekMoE
def load_model(model_path: Optional[str] = None, hf_repo: Optional[str] = None,
               device: str = "cuda") -> "DeepSeekMoE":
    """Build a DeepSeekMoE model and optionally load pretrained weights.

    Args:
        model_path: Local path to a ``pytorch_model.bin`` state dict.
        hf_repo: HuggingFace repo id to download weights from. When given,
            the downloaded checkpoint takes precedence over ``model_path``.
        device: Device string the weights are mapped to and the model is
            moved onto (e.g. ``"cuda"`` or ``"cpu"``).

    Returns:
        The model on ``device`` in eval mode. NOTE: if neither
        ``model_path`` nor ``hf_repo`` is provided, the model keeps its
        random initialization — callers are expected to supply one.
    """
    config = ModelConfig()
    model = DeepSeekMoE(config)
    if hf_repo:
        # Local import keeps huggingface_hub an optional dependency for
        # purely local use.
        from huggingface_hub import hf_hub_download
        model_path = hf_hub_download(repo_id=hf_repo, filename="pytorch_model.bin")
    if model_path:
        # weights_only=True refuses arbitrary pickled objects in the checkpoint.
        state_dict = torch.load(model_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
    return model.to(device).eval()
def generate_streaming(model, prompt: str, max_tokens: int = 200,
                       temperature: float = 0.8, top_k: int = 50,
                       device: str = "cuda") -> None:
    """Autoregressively sample from the model, printing tokens as they arrive.

    Args:
        model: Model instance; assumed to expose ``config.block_size`` and to
            return ``(logits, loss)`` from its forward pass — TODO confirm
            against modeling_eve.
        prompt: Text prompt, encoded with the GPT-2 BPE vocabulary.
        max_tokens: Number of new tokens to generate.
        temperature: Softmax temperature. Values <= 0 now fall back to greedy
            (argmax) decoding instead of dividing by zero.
        top_k: Keep only the k most likely tokens per step; ``None`` disables
            the filter.
        device: Device the token tensor is created on.
    """
    enc = tiktoken.get_encoding("gpt2")
    tokens = torch.tensor(enc.encode(prompt), dtype=torch.long, device=device).unsqueeze(0)
    print(prompt, end="", flush=True)
    with torch.no_grad():
        for _ in range(max_tokens):
            # Crop the context to the model's maximum sequence length.
            idx_cond = tokens[:, -model.config.block_size:]
            # bf16 autocast is enabled only when actually running on CUDA.
            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16,
                                    enabled=(device == "cuda")):
                logits, _ = model(idx_cond)
            logits = logits[:, -1, :]
            if temperature <= 0:
                # Greedy decoding: temperature 0 would otherwise divide by zero.
                idx_next = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_k is not None:
                    # Mask everything below the k-th largest logit.
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float("Inf")
                probs = torch.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
            tokens = torch.cat((tokens, idx_next), dim=1)
            print(enc.decode([idx_next.item()]), end="", flush=True)
    print("\n")
def main():
    """CLI entry point: parse arguments, load the model, stream a completion."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str,
                        default="The future of artificial intelligence is")
    parser.add_argument("--model_path", type=str, default=None)
    parser.add_argument("--hf_repo", type=str, default=None)
    parser.add_argument("--max_tokens", type=int, default=200)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top_k", type=int, default=50)
    parser.add_argument("--device", type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu")
    args = parser.parse_args()

    # Fall back to the published checkpoint when no weight source was given.
    if not (args.model_path or args.hf_repo):
        args.hf_repo = "anthonym21/Eve-2-MoE-250M"

    print(f"Loading model on {args.device}...")
    model = load_model(args.model_path, args.hf_repo, args.device)
    n_params = sum(weight.numel() for weight in model.parameters())
    print(f"Parameters: {n_params / 1e6:.2f}M\n")

    generate_streaming(model, args.prompt, args.max_tokens,
                       args.temperature, args.top_k, args.device)


if __name__ == "__main__":
    main()