coolpoodle committed
Commit b871d11 · verified · 1 Parent(s): 2d58dce

Uploaded Training / Testing / Eval Files


Make sure to update the file paths if you intend to use this.

<3
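For reference, the hard-coded paths across the three scripts are collected below. The environment-variable fallbacks are an illustrative sketch (the variable names are hypothetical), not part of the committed files:

import os

# Paths as committed (Colab layout); override if your checkout lives elsewhere.
MODEL_PATH = os.environ.get("MODEL_PATH", "/content/Qwen3-0.6B")                      # all three scripts
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "/content/Qwen3-0.6B-looped/checkpoints")   # train.py
GATE_PATH = os.environ.get("GATE_PATH", OUTPUT_DIR + "/gate_projections_epoch_3.pt")  # test_loop_generation.py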

Files changed (3)
  1. baseline_eval.py +74 -0
  2. test_loop_generation.py +60 -0
  3. train.py +179 -0
baseline_eval.py ADDED
@@ -0,0 +1,74 @@
+ """Baseline evaluation - compare the Loop Attention model to standard Qwen3."""
+
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from datasets import load_dataset
+ from torch.utils.data import DataLoader
+ from tqdm import tqdm
+
+ MODEL_PATH = "/content/Qwen3-0.6B"
+ BATCH_SIZE = 8
+ MAX_LENGTH = 256
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ print("\n1. Loading model...")
+ model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.float16).to(device)
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
+ tokenizer.pad_token = tokenizer.eos_token
+
+ print("\n2. Loading validation data...")
+ dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
+
+ def tokenize_fn(examples):
+     return tokenizer(examples["text"], truncation=True, max_length=MAX_LENGTH, padding="max_length")
+
+ tokenized = dataset.map(tokenize_fn, batched=True, remove_columns=["text"])
+ tokenized = tokenized.filter(lambda x: sum(1 for t in x["input_ids"] if t != tokenizer.pad_token_id) > 10)  # keep samples with >10 real tokens
+
+ val_data = tokenized["validation"]
+ print(f"   Validation samples: {len(val_data)}")
+
+ def collate_fn(batch):
+     input_ids = torch.tensor([x["input_ids"] for x in batch])
+     attention_mask = torch.tensor([x["attention_mask"] for x in batch])
+     labels = input_ids.clone()
+     labels[attention_mask == 0] = -100  # ignore padding in the loss
+     return {"input_ids": input_ids.to(device), "attention_mask": attention_mask.to(device), "labels": labels.to(device)}
+
+ loader = DataLoader(val_data, batch_size=BATCH_SIZE, collate_fn=collate_fn)
+
+ print("\n3. Calculating Baseline Loss...")
+ model.eval()
+ total_loss = 0
+ steps = 0
+
+ with torch.no_grad():
+     for batch in tqdm(loader, desc="Evaluating"):
+         with torch.amp.autocast('cuda', dtype=torch.float16):
+             outputs = model(**batch)
+         total_loss += outputs.loss.item()
+         steps += 1
+
+ baseline_loss = total_loss / steps
+ baseline_ppl = torch.exp(torch.tensor(baseline_loss)).item()
+
+ print("\n" + "=" * 60)
+ print("RESULTS")
+ print("=" * 60)
+ print(f"Baseline Qwen3-0.6B Loss: {baseline_loss:.4f}")
+ print(f"Baseline Qwen3-0.6B PPL:  {baseline_ppl:.2f}")
+ print()
+ print("Loop Attention Loss: 3.5549 (Epoch 3)")
+ print("Loop Attention PPL:  35.01")
+ print()
+
+ if abs(baseline_loss - 3.5549) < 0.05:  # check the noise band first so near-ties are not misreported
+     print("📊 NEUTRAL: Loop Attention matches baseline (within noise).")
+ elif baseline_loss > 3.5549:
+     delta = baseline_loss - 3.5549
+     print(f"✅ SUCCESS: Loop Attention beats baseline by {delta:.4f}!")
+ else:
+     delta = 3.5549 - baseline_loss
+     print(f"📉 Loop Attention is {delta:.4f} behind baseline.")
+ print("=" * 60)
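A caveat on the averaging above: summing `outputs.loss.item()` per batch weights every batch equally rather than every token. With the >10-token filter the gap is usually small, but a token-weighted average (a sketch reusing the same `loader`; not part of the committed file) gives the stricter perplexity:

# Token-weighted NLL: weight each batch's mean loss by its number of target tokens.
total_nll, total_tokens = 0.0, 0
with torch.no_grad():
    for batch in loader:
        outputs = model(**batch)
        n_tokens = (batch["labels"] != -100).sum().item()  # approximate; HF shifts labels internally
        total_nll += outputs.loss.item() * n_tokens
        total_tokens += n_tokens
weighted_loss = total_nll / total_tokens
weighted_ppl = torch.exp(torch.tensor(weighted_loss)).item()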
test_loop_generation.py ADDED
@@ -0,0 +1,60 @@
+ #!/usr/bin/env python3
+ """Test generation with Loop Attention (use_cache=False)."""
+
+ import sys
+ import torch
+ sys.path.insert(0, '/content')  # adjust if modeling_qwen_loop.py lives elsewhere
+ from modeling_qwen_loop import Qwen3LoopForCausalLM
+ from transformers import AutoTokenizer
+
+ MODEL_PATH = "/content/Qwen3-0.6B"
+ GATE_PATH = "/content/Qwen3-0.6B-looped/checkpoints/gate_projections_epoch_3.pt"
+
+ print("\n1. Loading model...")
+ model = Qwen3LoopForCausalLM.from_pretrained(MODEL_PATH)
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = model.to(device)
+
+ print("2. Loading trained gates...")
+ gate_state = torch.load(GATE_PATH, map_location=device)
+ for key, value in gate_state.items():
+     # Keys look like "model.layers.3.self_attn.gate.weight"; take the index after "layers".
+     parts = key.split('.')
+     layer_idx = int(parts[parts.index('layers') + 1])
+     param_name = parts[-1]
+     if param_name == 'weight':
+         model.model.layers[layer_idx].self_attn.gate.weight.data = value.to(device)
+     elif param_name == 'bias':
+         model.model.layers[layer_idx].self_attn.gate.bias.data = value.to(device)
+ print("   Gates loaded!")
+
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
+ tokenizer.pad_token = tokenizer.eos_token
+
+ model.eval()
+
+ prompts = [
+     "The capital of France is",
+     "def fibonacci(n):",
+     "In the year 2050,",
+     "The quick brown fox",
+     "Explain quantum computing in simple terms:"
+ ]
+
+ for prompt in prompts:
+     inputs = tokenizer(prompt, return_tensors="pt").to(device)
+
+     with torch.no_grad():
+         out = model.generate(
+             input_ids=inputs.input_ids,
+             max_new_tokens=50,
+             do_sample=True,
+             temperature=0.7,
+             top_p=0.9,
+             use_cache=False,
+             pad_token_id=tokenizer.eos_token_id
+         )
+
+     text = tokenizer.decode(out[0], skip_special_tokens=True)
+     print(f"\nPrompt: {prompt}")
+     print(f"Output: {text}")
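As an aside, the gate-loading loop above writes tensors into `.data` by hand. A non-strict `load_state_dict` is an equivalent and somewhat safer pattern, assuming the checkpoint keys match the model's own state-dict names (a sketch, not part of the committed file):

# All non-gate parameters are reported as missing, which is expected with strict=False;
# unexpected keys would signal a naming mismatch between checkpoint and model.
missing, unexpected = model.load_state_dict(gate_state, strict=False)
assert not unexpected, f"Gate keys did not match the model: {unexpected}"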
train.py ADDED
@@ -0,0 +1,179 @@
+ import os
+ import sys
+ import time
+ import json
+ import torch
+ import glob
+
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ torch.set_float32_matmul_precision('high')
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+
+ from torch.utils.data import DataLoader
+ from transformers import AutoTokenizer
+ from datasets import load_dataset
+ from tqdm import tqdm
+
+ sys.path.insert(0, '/content/Qwen3-0.6B-looped')
+ from modeling_qwen_loop import Qwen3LoopForCausalLM
+
+ MODEL_PATH = "/content/Qwen3-0.6B"
+ OUTPUT_DIR = "/content/Qwen3-0.6B-looped/checkpoints"
+ BATCH_SIZE = 20
+ GRADIENT_ACCUMULATION_STEPS = 4
+ LEARNING_RATE = 1e-4
+ MAX_LENGTH = 1024
+ NUM_EPOCHS = 3
+ NUM_WORKERS = 8
+ PIN_MEMORY = True
+
+ print("=" * 60)
+ print("TRAINING v3: Optimized (Compile + Workers + Checkpointing)")
+ print("=" * 60)
+
+ print("\n1. Loading model...")
+ checkpoints = sorted(glob.glob(f"{OUTPUT_DIR}/epoch_*"))
+ start_epoch = 0
+
+ if checkpoints:
+     latest_checkpoint = checkpoints[-1]
+     print(f"   Resuming from checkpoint: {latest_checkpoint}")
+     model = Qwen3LoopForCausalLM.from_pretrained(MODEL_PATH)
+     state_path = os.path.join(latest_checkpoint, "pytorch_model.bin")
+     if os.path.exists(state_path):
+         model.load_state_dict(torch.load(state_path))
+         try:
+             start_epoch = int(latest_checkpoint.split("_")[-1])
+             print(f"   Resuming at Epoch {start_epoch + 1}")
+         except ValueError:
+             start_epoch = 0
+     else:
+         print("   Warning: Checkpoint found but pytorch_model.bin missing. Starting fresh.")
+ else:
+     model = Qwen3LoopForCausalLM.from_pretrained(MODEL_PATH)
+
+ device = torch.device("cuda")
+ model = model.to(device)
+
+ print("\n2. Unfreezing gates + layer norms...")
+ model.enable_gate_and_layernorm_training()
+
+ print("   Compiling model with torch.compile()...")
+ try:
+     model = torch.compile(model)
+ except Exception as e:
+     print(f"   Warning: torch.compile failed (ignoring): {e}")
+
+ print("\n3. Loading WikiText-2...")
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
+ tokenizer.pad_token = tokenizer.eos_token
+
+ dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
+
+ def tokenize_fn(examples):
+     return tokenizer(examples["text"], truncation=True, max_length=MAX_LENGTH, padding="max_length")
+
+ tokenized = dataset.map(tokenize_fn, batched=True, remove_columns=["text"])
+ tokenized = tokenized.filter(lambda x: sum(1 for t in x["input_ids"] if t != tokenizer.pad_token_id) > 10)  # keep samples with >10 real tokens
+
+ print(f"   Train samples: {len(tokenized['train'])}")
+ print(f"   Val samples: {len(tokenized['validation'])}")
+
+ def collate_fn(batch):
+     input_ids = torch.tensor([x["input_ids"] for x in batch])
+     attention_mask = torch.tensor([x["attention_mask"] for x in batch])
+     labels = input_ids.clone()
+     labels[attention_mask == 0] = -100  # ignore padding in the loss
+     return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
+
+ train_loader = DataLoader(
+     tokenized["train"],
+     batch_size=BATCH_SIZE,
+     shuffle=True,
+     collate_fn=collate_fn,
+     num_workers=NUM_WORKERS,
+     pin_memory=PIN_MEMORY
+ )
+ val_loader = DataLoader(  # note: defined but never used in this script
+     tokenized["validation"],
+     batch_size=BATCH_SIZE,
+     shuffle=False,
+     collate_fn=collate_fn,
+     num_workers=NUM_WORKERS,
+     pin_memory=PIN_MEMORY
+ )
+
+ optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
+ total_steps = len(train_loader) * NUM_EPOCHS // GRADIENT_ACCUMULATION_STEPS
+ warmup_steps = total_steps // 10
+
+ def get_lr(step):
+     # Linear warmup, then linear decay to a 10% floor.
+     if step < warmup_steps:
+         return step / warmup_steps
+     return max(0.1, 1.0 - (step - warmup_steps) / (total_steps - warmup_steps))
+
+ scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, get_lr)
+
+ print("\n4. Training Configuration:")
+ print(f"   Context length: {MAX_LENGTH}")
+ print(f"   Batch size: {BATCH_SIZE} (Effective: {BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS})")
+ print(f"   Workers: {NUM_WORKERS}")
+ print(f"   Total steps: {total_steps}")
+
+ print("\n" + "=" * 60)
+ print("Starting Training...")
+ print("=" * 60)
+
+ # bf16 does not strictly need loss scaling; the scaler is kept so fp16 can be swapped in.
+ scaler = torch.amp.GradScaler('cuda')
+ model.train()
+ global_step = 0
+ start_time = time.time()
+
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
+
+ for epoch in range(start_epoch, NUM_EPOCHS):
+     epoch_loss = 0
+     epoch_steps = 0
+     progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
+
+     for step, batch in enumerate(progress):
+         batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
+
+         with torch.amp.autocast('cuda', dtype=torch.bfloat16):
+             outputs = model(**batch, use_cache=False)
+             loss = outputs.loss / GRADIENT_ACCUMULATION_STEPS
+
+         scaler.scale(loss).backward()
+         epoch_loss += loss.item() * GRADIENT_ACCUMULATION_STEPS
+         epoch_steps += 1
+
+         if (step + 1) % GRADIENT_ACCUMULATION_STEPS == 0:
+             scaler.unscale_(optimizer)
+             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+             scaler.step(optimizer)
+             scaler.update()
+             scheduler.step()
+             optimizer.zero_grad()
+             global_step += 1
+
+             current_lr = scheduler.get_last_lr()[0]
+             mem_usage = torch.cuda.memory_allocated() / 1024**3
+             progress.set_postfix(loss=loss.item() * GRADIENT_ACCUMULATION_STEPS, lr=current_lr, mem=f"{mem_usage:.1f}GB")
+
+     print(f"Saving checkpoint for Epoch {epoch+1}...")
+     # torch.compile wraps the model; save the underlying module.
+     model_to_save = model._orig_mod if hasattr(model, '_orig_mod') else model
+     model_to_save.save_pretrained(f"{OUTPUT_DIR}/epoch_{epoch+1}")
+
+     gate_state_dict = {k: v for k, v in model_to_save.state_dict().items() if 'gate' in k}
+     torch.save(gate_state_dict, f"{OUTPUT_DIR}/gate_projections.pt")
+     torch.save(gate_state_dict, f"{OUTPUT_DIR}/gate_projections_epoch_{epoch+1}.pt")
+
+ print("Training complete.")
+
+ model_to_save = model._orig_mod if hasattr(model, '_orig_mod') else model
+ model_to_save.save_pretrained(f"{OUTPUT_DIR}/final")
+ gate_state_dict = {k: v for k, v in model_to_save.state_dict().items() if 'gate' in k}
+ torch.save(gate_state_dict, f"{OUTPUT_DIR}/gate_projections.pt")
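One caveat in train.py's resume logic: `sorted(glob.glob(f"{OUTPUT_DIR}/epoch_*"))` sorts lexicographically, so `epoch_10` would sort before `epoch_2` if training ever ran past nine epochs. Harmless at NUM_EPOCHS = 3, but a numeric sort (a sketch, not in the committed script) is more robust:

import glob
import os

def latest_epoch_checkpoint(output_dir):
    """Return the highest-numbered epoch_* checkpoint directory, or None."""
    numbered = []
    for path in glob.glob(os.path.join(output_dir, "epoch_*")):
        suffix = path.rsplit("_", 1)[-1]
        if suffix.isdigit():
            numbered.append((int(suffix), path))
    return max(numbered)[1] if numbered else None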