"""Test generation with Loop Attention (use_cache=False).""" |
|
|
|
|
|
import sys |
|
|
import torch |
|
|
sys.path.insert(0, '/content') |
|
|
from modeling_qwen_loop import Qwen3LoopForCausalLM |
|
|
from transformers import AutoTokenizer |
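
# Local paths to the base model and the trained gate-projection checkpoint.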
MODEL_PATH = "/content/Qwen3-0.6B"
GATE_PATH = "/content/Qwen3-0.6B-looped/checkpoints/gate_projections_epoch_3.pt"

print("\n1. Loading model...")
model = Qwen3LoopForCausalLM.from_pretrained(MODEL_PATH)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

print("2. Loading trained gates...")
gate_state = torch.load(GATE_PATH, map_location=device)
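
# Copy each trained gate tensor into the matching per-layer attention gate
# projection. Checkpoint keys are assumed to be dotted paths of the form
# "layers.<idx>. ... .<weight|bias>"; the index/name parsing below relies
# on that layout.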
for key, value in gate_state.items():
    parts = key.split('.')
    layer_idx = int(parts[1])
    param_name = parts[-1]
    if param_name == 'weight':
        model.model.layers[layer_idx].self_attn.gate.weight.data = value.to(device)
    elif param_name == 'bias':
        model.model.layers[layer_idx].self_attn.gate.bias.data = value.to(device)
print(" Gates loaded!")

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
tokenizer.pad_token = tokenizer.eos_token  # ensure a pad token is defined for generation

model.eval()
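
# Varied prompts (factual recall, code, open-ended text) to spot-check
# generation quality with the trained gates.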
prompts = [
    "The capital of France is",
    "def fibonacci(n):",
    "In the year 2050,",
    "The quick brown fox",
    "Explain quantum computing in simple terms:",
]
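
# Generate without the KV cache (use_cache=False): each new token recomputes
# the full forward pass, which is how the Loop Attention path is exercised
# here (per the module docstring).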
for prompt in prompts:
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        out = model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,  # explicit mask, since pad == eos
            max_new_tokens=50,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            use_cache=False,
            pad_token_id=tokenizer.eos_token_id,
        )

    text = tokenizer.decode(out[0], skip_special_tokens=True)
    print(f"\nPrompt: {prompt}")
    print(f"Output: {text}")