#!/usr/bin/env python3
"""Test generation with Loop Attention (use_cache=False)."""

import sys

import torch

sys.path.insert(0, '/content')  # make modeling_qwen_loop importable (Colab layout)
from modeling_qwen_loop import Qwen3LoopForCausalLM
from transformers import AutoTokenizer

MODEL_PATH = "/content/Qwen3-0.6B"
GATE_PATH = "/content/Qwen3-0.6B-looped/checkpoints/gate_projections_epoch_3.pt"

print("\n1. Loading model...")
model = Qwen3LoopForCausalLM.from_pretrained(MODEL_PATH)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

print("2. Loading trained gates...")
gate_state = torch.load(GATE_PATH, map_location=device)
for key, value in gate_state.items():
    # Checkpoint keys are expected to look like "layers.<idx>.self_attn.gate.<param>":
    # the layer index sits at position 1 and the parameter name comes last.
    parts = key.split('.')
    layer_idx = int(parts[1])
    param_name = parts[-1]
    gate = model.model.layers[layer_idx].self_attn.gate
    if param_name == 'weight':
        gate.weight.data = value.to(device)
    elif param_name == 'bias':
        gate.bias.data = value.to(device)
print("   Gates loaded!")

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
tokenizer.pad_token = tokenizer.eos_token

model.eval()
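torch.manual_seed(0)  # optional: fix the sampling seed so successive test runs are comparable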

prompts = [
    "The capital of France is",
    "def fibonacci(n):",
    "In the year 2050,",
    "The quick brown fox",
    "Explain quantum computing in simple terms:"
]

for prompt in prompts:
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        out = model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,  # silences the missing-mask warning
            max_new_tokens=50,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            use_cache=False,  # exercise the Loop Attention path without the KV cache
            pad_token_id=tokenizer.eos_token_id,
        )

    text = tokenizer.decode(out[0], skip_special_tokens=True)
    print(f"\nPrompt: {prompt}")
    print(f"Output: {text}")