# Spaces: Running on Zero
"""Inspect one greedy decoding step and the KV cache of a causal LM.

Loads ``MODEL_ID`` with Hugging Face ``transformers``, runs a single forward
pass over a chat-formatted prompt with ``use_cache=True``, prints the greedy
next token, and reports on the returned ``past_key_values`` object.
"""

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "dphn/dolphin-2.9.4-llama3.1-8b"


def _describe_cache(cache) -> None:
    """Print the type, layer count and cached sequence length of *cache*."""
    print("KV cache type:", type(cache).__name__)
    try:
        print("KV cache layers:", len(cache))
    except TypeError:
        # len() raises TypeError for objects without __len__ (including None).
        print("KV cache layers: unknown")
    if hasattr(cache, "get_seq_length"):
        print("KV cache sequence length:", cache.get_seq_length())


def main(model_id: str = MODEL_ID, question: str = "KV cache 是什么?") -> None:
    """Run one forward pass on *question* and print next-token and KV-cache info.

    Args:
        model_id: Hugging Face Hub id (or local path) of the causal LM to load.
        question: Content of the single user message in the chat prompt.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype="auto",  # load in the dtype recorded by the checkpoint
        device_map="auto",  # let transformers/accelerate place the weights
    )

    messages = [
        {"role": "user", "content": question},
    ]
    # Render the chat template to text and append the assistant-turn header,
    # so the next predicted token is the start of the model's reply.
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # NOTE(review): tokenizing with the default add_special_tokens=True adds the
    # tokenizer's special tokens (e.g. BOS). If this model's chat template
    # already emits them, they appear twice — confirm against the template.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model(**inputs, use_cache=True)

    # Logits at the last prompt position score the token that would come next.
    logits = outputs.logits[:, -1, :]
    next_token_id = logits.argmax(dim=-1)
    next_token = tokenizer.decode(next_token_id)

    print("input token count:", inputs["input_ids"].shape[-1])
    # This is the model's output dimension, which may differ from len(tokenizer).
    print("vocab size:", logits.shape[-1])
    print("greedy next token id:", next_token_id.item())
    print("greedy next token:", repr(next_token))

    _describe_cache(outputs.past_key_values)


if __name__ == "__main__":
    main()