# File size: 1,166 Bytes
# d0f40b0
"""
Inference script for bitskip-v2-earlyexit.

Loads the model (with ``trust_remote_code=True``) and its tokenizer, then
generates a continuation of a sample prompt twice: once with all layers and
once after calling ``model.set_exit_layer(12)`` to exit early at layer 12.
"""

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

def _generate_and_print(model, tokenizer, inputs, max_length):
    """Run ``model.generate`` on *inputs* and print the decoded first sequence.

    Args:
        model: The loaded causal LM.
        tokenizer: Tokenizer matching *model*, used to decode the output ids.
        inputs: Tokenizer output (``input_ids``/``attention_mask``) to expand
            into ``generate`` keyword arguments.
        max_length: Passed straight to ``generate``; in ``transformers`` this
            bounds the total sequence length (prompt tokens + new tokens).
    """
    outputs = model.generate(
        **inputs,
        max_length=max_length,
        # Reuse EOS as the pad id, as the original script did; presumably to
        # avoid generate()'s missing-pad-token warning.
        pad_token_id=tokenizer.eos_token_id,
    )
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))


def main(model_path=".", prompt="Once upon a time", exit_layer=12, max_length=100):
    """Load the model and compare full-depth generation with early-exit generation.

    Args:
        model_path: Hugging Face Hub repo id or local directory; defaults to
            the current directory.
        prompt: Text the model should continue.
        exit_layer: Layer index passed to ``model.set_exit_layer`` for the
            second generation run.
        max_length: Total token budget (prompt included) for each
            ``generate`` call.
    """
    print("Loading model...")
    # trust_remote_code=True loads the repo's custom model class; presumably
    # that is where set_exit_layer() is defined.
    model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    model.eval()
    print("Model loaded!")

    # Keep the input tensors on the same device as the model weights so this
    # still works if loading is changed to place the model on a GPU.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    print(f"\nPrompt: {prompt}\n")

    # First pass: default configuration (assumed to use every layer).
    print("Generating with all layers...")
    _generate_and_print(model, tokenizer, inputs, max_length)

    # Second pass: the custom model stops at `exit_layer`. The setting stays
    # in effect on `model` after this call.
    print(f"\nGenerating with early exit at layer {exit_layer}...")
    model.set_exit_layer(exit_layer)
    _generate_and_print(model, tokenizer, inputs, max_length)

if __name__ == "__main__":
    main()  # script entry point; importing this module does not run inference