Improve readability of README.md by spinning off the training and inference examples as files in the repo

#1
Files changed (1) hide show
  1. README.md +9 -863
README.md CHANGED
@@ -31,870 +31,16 @@ ldm_weights_path = hf_hub_download(repo_id="dataopsnick/adapt-diff-qwen-0.8b", f
31
  ldm_heads.load_state_dict(torch.load(ldm_weights_path))
32
  ```
33
 
34
- Example:
35
 
36
- ```python
37
- """
38
- ADAPT-DIFF Inference & Benchmark Script
39
- Downloads 'dataopsnick/adapt-diff-qwen-0.8b' and compares it with 'Qwen/Qwen3.5-0.8B'.
40
- """
41
-
42
- import os
43
- import gc
44
- import time
45
- import re
46
- from collections import defaultdict
47
- import torch
48
- import torch.nn as nn
49
- import torch.nn.functional as F
50
-
51
- # 1. Install/Update Dependencies
52
- print("Ensuring dependencies are installed...")
53
- os.system("pip install -q transformers>=4.40.0 datasets>=2.18.0 accelerate>=0.29.0 huggingface_hub")
54
-
55
- import transformers
56
- from transformers import AutoTokenizer, AutoConfig, AutoModel, AutoModelForCausalLM
57
- from transformers.cache_utils import DynamicCache
58
- from transformers.modeling_outputs import BaseModelOutputWithPast
59
- from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
60
- from datasets import load_dataset
61
- from huggingface_hub import hf_hub_download
62
-
63
- # Clean up GPU cache before running
64
- gc.collect()
65
- torch.cuda.empty_cache()
66
-
67
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
68
- BASE_MODEL_ID = "Qwen/Qwen3.5-0.8B"
69
- ADAPT_DIFF_ID = "dataopsnick/adapt-diff-qwen-0.8b"
70
-
71
- print(f"Loading {BASE_MODEL_ID} metadata to dynamically resolve architecture classes...")
72
- src_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
73
- if src_tokenizer.pad_token is None:
74
- src_tokenizer.pad_token = src_tokenizer.eos_token
75
-
76
- # Load temporary instance to resolve base classes exactly as in your environment
77
- temp_model = AutoModelForCausalLM.from_pretrained(
78
- BASE_MODEL_ID,
79
- torch_dtype=torch.bfloat16,
80
- device_map="cpu"
81
- )
82
- src_config = temp_model.config
83
-
84
- BaseConfig = src_config.__class__
85
- BaseModel = temp_model.model.__class__
86
- BaseCausalLM = temp_model.__class__
87
-
88
- BasePreTrainedModel = next(
89
- (cls for cls in BaseCausalLM.__mro__ if cls.__name__.endswith("PreTrainedModel")),
90
- None
91
- )
92
- if BasePreTrainedModel is None:
93
- BasePreTrainedModel = BaseCausalLM.__bases__[0]
94
-
95
- # Free temporary model memory
96
- del temp_model
97
- gc.collect()
98
-
99
-
100
- # ==============================================================================
101
- # Custom ADAPT-DIFF Architecture Classes
102
- # ==============================================================================
103
- class A2DQwenConfig(BaseConfig):
104
- model_type = "a2d-qwen"
105
-
106
- class A2DQwenModel(BaseModel):
107
- def forward(
108
- self,
109
- input_ids = None,
110
- attention_mask = None,
111
- position_ids = None,
112
- past_key_values = None,
113
- inputs_embeds = None,
114
- use_cache = None,
115
- cache_position = None,
116
- **kwargs,
117
- ):
118
- if (input_ids is None) ^ (inputs_embeds is not None):
119
- raise ValueError("Specify exactly one of input_ids or inputs_embeds")
120
-
121
- if inputs_embeds is None:
122
- inputs_embeds = self.embed_tokens(input_ids)
123
-
124
- if use_cache and past_key_values is None:
125
- past_key_values = DynamicCache(config=self.config)
126
-
127
- if cache_position is None:
128
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
129
- cache_position = torch.arange(
130
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
131
- )
132
-
133
- if position_ids is None:
134
- position_ids = cache_position.unsqueeze(0)
135
-
136
- # Core ADAPT-DIFF modification: replace causal mask with bidirectional/padding-only mask
137
- if not isinstance(causal_mask_mapping := attention_mask, dict):
138
- if attention_mask is None:
139
- attention_mask = torch.ones(
140
- inputs_embeds.shape[:2], device=inputs_embeds.device, dtype=torch.long
141
- )
142
- if not (isinstance(attention_mask, torch.Tensor) and attention_mask.ndim == 4):
143
- attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype)
144
- causal_mask_mapping = defaultdict(lambda: attention_mask)
145
-
146
- hidden_states = inputs_embeds
147
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
148
-
149
- for decoder_layer in self.layers[: self.config.num_hidden_layers]:
150
- attn_type = getattr(decoder_layer, "attention_type", "self_attn")
151
- hidden_states = decoder_layer(
152
- hidden_states,
153
- attention_mask=causal_mask_mapping[attn_type],
154
- position_ids=position_ids,
155
- past_key_values=past_key_values,
156
- use_cache=use_cache,
157
- cache_position=cache_position,
158
- position_embeddings=position_embeddings,
159
- **kwargs,
160
- )
161
-
162
- hidden_states = self.norm(hidden_states)
163
- return BaseModelOutputWithPast(
164
- last_hidden_state=hidden_states,
165
- past_key_values=past_key_values if use_cache else None,
166
- )
167
-
168
- class A2DQwenLMHeadModel(BaseCausalLM):
169
- config_class = A2DQwenConfig
170
- def __init__(self, config):
171
- BasePreTrainedModel.__init__(self, config)
172
- self.model = A2DQwenModel(config)
173
- self.vocab_size = config.vocab_size
174
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
175
- self.post_init()
176
-
177
-
178
- # Register custom classes with Hugging Face AutoClasses
179
- transformers.AutoConfig.register("a2d-qwen", A2DQwenConfig)
180
- transformers.AutoModel.register(A2DQwenConfig, A2DQwenLMHeadModel)
181
- transformers.AutoModelForCausalLM.register(A2DQwenConfig, A2DQwenLMHeadModel)
182
-
183
-
184
- # ==============================================================================
185
- # Custom Projection and Search Pipeline Components
186
- # ==============================================================================
187
- class StackedLDMHeads(nn.Module):
188
- def __init__(self, hidden_size, vocab_size, block_size=12):
189
- super().__init__()
190
- self.block_size = block_size
191
- self.proj = nn.Linear(hidden_size, block_size * hidden_size, dtype=torch.bfloat16)
192
- self.head = nn.Linear(hidden_size, vocab_size, dtype=torch.bfloat16)
193
-
194
- def forward(self, hidden_states):
195
- batch_size, seq_len, hidden_size = hidden_states.shape
196
- forecast = self.proj(hidden_states)
197
- forecast = forecast.view(batch_size, seq_len, self.block_size, hidden_size)
198
- logits = self.head(forecast)
199
- return logits
200
-
201
- class LogitUncertaintyFilter(nn.Module):
202
- def compute_entropy(self, logits: torch.Tensor) -> torch.Tensor:
203
- probs = F.softmax(logits.float(), dim=-1)
204
- entropy = -torch.sum(probs * torch.log(probs + 1e-9), dim=-1)
205
- return entropy
206
-
207
- def forward(self, logits: torch.Tensor, threshold: float):
208
- entropy = self.compute_entropy(logits)
209
- mask = entropy >= threshold
210
- return mask, entropy
211
-
212
- class ActorCriticPruner:
213
- def __init__(self, lm_head, lambda_reg=0.1):
214
- self.lm_head = lm_head
215
- self.lambda_reg = lambda_reg
216
-
217
- def evaluate_sequence_value(self, candidate_tokens, logits):
218
- log_probs = F.log_softmax(logits.float(), dim=-1)
219
- gathered = torch.gather(log_probs, -1, candidate_tokens.unsqueeze(-1)).squeeze(-1)
220
- return gathered.mean().item()
221
-
222
- def recursive_refine(self, sequence, logits, mask, entropy, depth, alpha, beta):
223
- refined_sequence = sequence.clone()
224
- if depth == 0 or mask.sum() == 0:
225
- return refined_sequence, self.evaluate_sequence_value(sequence, logits)
226
-
227
- high_unc_positions = torch.where(mask)[0]
228
- if len(high_unc_positions) == 0:
229
- return refined_sequence, self.evaluate_sequence_value(sequence, logits)
230
-
231
- target_pos = high_unc_positions[0].item()
232
- top_logits, top_tokens = torch.topk(logits[target_pos], k=3)
233
-
234
- best_val = float('-inf')
235
- for token_opt in top_tokens:
236
- candidate = sequence.clone()
237
- candidate[target_pos] = token_opt
238
-
239
- approx_val = self.evaluate_sequence_value(candidate, logits) - (self.lambda_reg * entropy[target_pos].item())
240
- if approx_val < alpha:
241
- continue
242
-
243
- new_mask = mask.clone()
244
- new_mask[target_pos] = False
245
-
246
- _, path_val = self.recursive_refine(candidate, logits, new_mask, entropy, depth - 1, alpha, beta)
247
- if path_val > alpha:
248
- alpha = path_val
249
- best_val = path_val
250
- refined_sequence = candidate
251
-
252
- if alpha >= beta:
253
- break
254
-
255
- return refined_sequence, best_val
256
-
257
-
258
- class ADAPTDIFFPipeline(nn.Module):
259
- def __init__(self, base_lm_model, block_size=12, entropy_threshold=1.5):
260
- super().__init__()
261
- self.base_model = base_lm_model.model
262
- self.lm_head = base_lm_model.lm_head
263
- self.block_size = block_size
264
- self.entropy_threshold = entropy_threshold
265
-
266
- self.ldm_heads = StackedLDMHeads(
267
- hidden_size=base_lm_model.config.hidden_size,
268
- vocab_size=base_lm_model.config.vocab_size,
269
- block_size=block_size
270
- ).to(DEVICE)
271
-
272
- self.router = LogitUncertaintyFilter()
273
- self.pruner = ActorCriticPruner(self.lm_head)
274
-
275
- def generate_adapt_diff(self, input_ids, max_new_tokens=128):
276
- current_seq = input_ids.clone()
277
- generated_count = 0
278
- total_full_transformer_evals = 0
279
-
280
- while generated_count < max_new_tokens:
281
- outputs = self.base_model(input_ids=current_seq)
282
- total_full_transformer_evals += 1
283
- last_hidden = outputs.last_hidden_state[:, -1:, :]
284
-
285
- block_logits = self.ldm_heads(last_hidden).squeeze(0).squeeze(0)
286
- draft_tokens = torch.argmax(block_logits, dim=-1)
287
-
288
- mask, entropy = self.router(block_logits, self.entropy_threshold)
289
-
290
- if not mask.any():
291
- final_block = draft_tokens
292
- else:
293
- total_full_transformer_evals += 1
294
- final_block, _ = self.pruner.recursive_refine(
295
- sequence=draft_tokens,
296
- logits=block_logits,
297
- mask=mask,
298
- entropy=entropy,
299
- depth=2,
300
- alpha=float('-inf'),
301
- beta=float('inf')
302
- )
303
-
304
- current_seq = torch.cat([current_seq, final_block.unsqueeze(0)], dim=-1)
305
- generated_count += self.block_size
306
-
307
- return current_seq[0, input_ids.shape[1]:], total_full_transformer_evals
308
-
309
-
310
- # ==============================================================================
311
- # Model Loading & LDM Weights Initialization
312
- # ==============================================================================
313
- print(f"Downloading custom bidirectional model {ADAPT_DIFF_ID} from Hugging Face...")
314
- a2d_model = AutoModelForCausalLM.from_pretrained(
315
- ADAPT_DIFF_ID,
316
- torch_dtype=torch.bfloat16,
317
- device_map=DEVICE
318
- )
319
-
320
- print(f"Downloading baseline model {BASE_MODEL_ID} for comparative evaluation...")
321
- baseline_model = AutoModelForCausalLM.from_pretrained(
322
- BASE_MODEL_ID,
323
- torch_dtype=torch.bfloat16,
324
- device_map=DEVICE
325
- )
326
-
327
- # Initialize generation pipeline and load pre-trained custom LDM weights
328
- pipeline = ADAPTDIFFPipeline(a2d_model, block_size=12, entropy_threshold=1.5)
329
- print("Downloading LDM head projection weights...")
330
- ldm_weights_path = hf_hub_download(repo_id=ADAPT_DIFF_ID, filename="ldm_heads.pt")
331
- pipeline.ldm_heads.load_state_dict(torch.load(ldm_weights_path, map_location=DEVICE))
332
- pipeline.eval()
333
-
334
-
335
- # ==============================================================================
336
- # Sub-Sampled Benchmark Initialization
337
- # ==============================================================================
338
- print("\nLoading GSM8K and MBPP evaluation datasets...")
339
- gsm8k_ds = load_dataset("openai/gsm8k", "main", split="test")
340
- mbpp_ds = load_dataset("google-research-datasets/mbpp", split="test")
341
-
342
- val_math = []
343
- for item in gsm8k_ds:
344
- val_math.append((f"Problem: {item['question']}\nSolution:", item['answer']))
345
- if len(val_math) >= 10: # Fast benchmark slice
346
- break
347
-
348
- val_code = []
349
- for item in mbpp_ds:
350
- val_code.append((f"Write a Python function to solve this task:\n{item['text']}\nSolution:\n", item['code'], item['test_list']))
351
- if len(val_code) >= 10:
352
- break
353
-
354
-
355
- # ==============================================================================
356
- # Validation Helpers
357
- # ==============================================================================
358
- def extract_answer(text):
359
- if "####" in text:
360
- text = text.split("####")[-1]
361
- matches = re.findall(r'-?[\d,]*\.?\d+', text)
362
- return matches[-1].replace(',', '') if matches else None
363
-
364
- def verify_math(generated_text, ref_ans):
365
- pred_val = extract_answer(generated_text)
366
- ref_val = extract_answer(ref_ans)
367
- if pred_val is None or ref_val is None:
368
- return 0.0
369
- try:
370
- return 1.0 if float(pred_val) == float(ref_val) else 0.0
371
- except ValueError:
372
- return 1.0 if str(pred_val).strip() == str(ref_val).strip() else 0.0
373
-
374
- def verify_code(generated_text, test_list):
375
- code_block = generated_text
376
- if "```python" in generated_text:
377
- code_block = generated_text.split("```python")[-1].split("```")[0]
378
- elif "```" in generated_text:
379
- code_block = generated_text.split("```")[-1].split("```")[0]
380
-
381
- local_scope = {}
382
- try:
383
- compiled_code = compile(code_block, "<string>", "exec")
384
- exec(compiled_code, local_scope, local_scope)
385
- for test in test_list:
386
- exec(test, local_scope, local_scope)
387
- return 1.0
388
- except Exception:
389
- return 0.0
390
-
391
-
392
- # ==============================================================================
393
- # Evaluation Loop
394
- # ==============================================================================
395
- def run_benchmark(pipeline, base_model, dataset, is_code=False):
396
- ar_correct = 0
397
- ad_correct = 0
398
- total = len(dataset)
399
-
400
- ar_total_tokens = 0
401
- ad_total_tokens = 0
402
- ar_total_time = 0.0
403
- ad_total_time = 0.0
404
- ad_total_evals = 0
405
-
406
- for idx, item in enumerate(dataset):
407
- prompt = item[0]
408
- inputs = src_tokenizer(prompt, return_tensors="pt").to(DEVICE)
409
- max_new_tokens = 48
410
-
411
- # Autoregressive generation
412
- t_start = time.time()
413
- with torch.no_grad():
414
- ar_outputs = base_model.generate(
415
- **inputs,
416
- max_new_tokens=max_new_tokens,
417
- pad_token_id=src_tokenizer.pad_token_id,
418
- eos_token_id=src_tokenizer.eos_token_id,
419
- do_sample=False
420
- )
421
- ar_total_time += (time.time() - t_start)
422
- ar_gen_tokens = ar_outputs[0][inputs.input_ids.shape[1]:]
423
- ar_total_tokens += len(ar_gen_tokens)
424
- ar_text = src_tokenizer.decode(ar_gen_tokens, skip_special_tokens=True)
425
-
426
- # ADAPT-DIFF speculative generation
427
- t_start = time.time()
428
- with torch.no_grad():
429
- ad_gen_tokens, step_evals = pipeline.generate_adapt_diff(
430
- input_ids=inputs.input_ids,
431
- max_new_tokens=max_new_tokens
432
- )
433
- ad_total_time += (time.time() - t_start)
434
- ad_total_tokens += len(ad_gen_tokens)
435
- ad_total_evals += step_evals
436
- ad_text = src_tokenizer.decode(ad_gen_tokens, skip_special_tokens=True)
437
-
438
- if is_code:
439
- ar_correct += verify_code(ar_text, item[2])
440
- ad_correct += verify_code(ad_text, item[2])
441
- else:
442
- ar_correct += verify_math(ar_text, item[1])
443
- ad_correct += verify_math(ad_text, item[1])
444
-
445
- ar_throughput = ar_total_tokens / (ar_total_time + 1e-9)
446
- ad_throughput = ad_total_tokens / (ad_total_time + 1e-9)
447
- ad_flops_per_token = ad_total_evals / (ad_total_tokens + 1e-9)
448
-
449
- return {
450
- "ar_acc": ar_correct / total,
451
- "ad_acc": ad_correct / total,
452
- "ar_speed": ar_throughput,
453
- "ad_speed": ad_throughput,
454
- "ar_flops": 1.0,
455
- "ad_flops": ad_flops_per_token
456
- }
457
-
458
- print("\nStarting evaluation run...")
459
- math_results = run_benchmark(pipeline, baseline_model, val_math, is_code=False)
460
- code_results = run_benchmark(pipeline, baseline_model, val_code, is_code=True)
461
-
462
- # Print comparative results
463
- print("\n" + "="*95)
464
- print(" ADAPT-DIFF INFERENCE BENCHMARK RESULTS (Block Size L = 12)")
465
- print("="*95)
466
- print(f"{'Task / Strategy':<30} | {'Throughput (tok/s)':<20} | {'Task Acc':<15} | {'Relative FLOPs/Tok':<20}")
467
- print("-"*95)
468
- print(f"{'GSM8K (Autoregressive Baseline)':<30} | {math_results['ar_speed']:<20.2f} | {math_results['ar_acc']:<15.2%} | {math_results['ar_flops']:<20.4f}")
469
- print(f"{'GSM8K (ADAPT-DIFF Speculative)':<30} | {math_results['ad_speed']:<20.2f} | {math_results['ad_acc']:<15.2%} | {math_results['ad_flops']:<20.4f}")
470
- print("-"*95)
471
- print(f"{'MBPP (Autoregressive Baseline)':<30} | {code_results['ar_speed']:<20.2f} | {code_results['ar_acc']:<15.2%} | {code_results['ar_flops']:<20.4f}")
472
- print(f"{'MBPP (ADAPT-DIFF Speculative)':<30} | {code_results['ad_speed']:<20.2f} | {code_results['ad_acc']:<15.2%} | {code_results['ad_flops']:<20.4f}")
473
- print("="*95)
474
- ```
475
-
476
- Here is an example script for full SFT training on the GSM8K and MBPP benchmarks:
477
- ```python
478
- """
479
- ADAPT-DIFF Calibration & Training Script
480
- Finetunes the Custom Stacked LDM Heads using target sequences from GSM8K & MBPP.
481
- """
482
-
483
- import os
484
- import gc
485
- import copy
486
- import random
487
- import time
488
- import re
489
- from collections import defaultdict
490
- import torch
491
- import torch.nn as nn
492
- import torch.nn.functional as F
493
-
494
- print("Ensuring dependencies are installed...")
495
- os.system("pip install -q transformers>=4.40.0 datasets>=2.18.0 accelerate>=0.29.0 huggingface_hub")
496
-
497
- import transformers
498
- from transformers import AutoTokenizer, AutoConfig, AutoModel, AutoModelForCausalLM
499
- from transformers.cache_utils import DynamicCache
500
- from transformers.modeling_outputs import BaseModelOutputWithPast
501
- from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
502
- from datasets import load_dataset
503
- from huggingface_hub import hf_hub_download
504
-
505
- # Clean up GPU cache before running
506
- gc.collect()
507
- torch.cuda.empty_cache()
508
-
509
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
510
- BASE_MODEL_ID = "Qwen/Qwen3.5-0.8B"
511
- ADAPT_DIFF_ID = "dataopsnick/adapt-diff-qwen-0.8b"
512
-
513
- print(f"Loading {BASE_MODEL_ID} tokenizer and model structure metadata...")
514
- src_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
515
- if src_tokenizer.pad_token is None:
516
- src_tokenizer.pad_token = src_tokenizer.eos_token
517
-
518
- # Load temporary instance to resolve base classes dynamically
519
- temp_model = AutoModelForCausalLM.from_pretrained(
520
- BASE_MODEL_ID,
521
- torch_dtype=torch.bfloat16,
522
- device_map="cpu"
523
- )
524
- src_config = temp_model.config
525
-
526
- BaseConfig = src_config.__class__
527
- BaseModel = temp_model.model.__class__
528
- BaseCausalLM = temp_model.__class__
529
-
530
- BasePreTrainedModel = next(
531
- (cls for cls in BaseCausalLM.__mro__ if cls.__name__.endswith("PreTrainedModel")),
532
- None
533
- )
534
- if BasePreTrainedModel is None:
535
- BasePreTrainedModel = BaseCausalLM.__bases__[0]
536
-
537
- del temp_model
538
- gc.collect()
539
-
540
-
541
- # ==============================================================================
542
- # Model & Pipeline Definitions
543
- # ==============================================================================
544
- class A2DQwenConfig(BaseConfig):
545
- model_type = "a2d-qwen"
546
-
547
- class A2DQwenModel(BaseModel):
548
- def forward(
549
- self,
550
- input_ids = None,
551
- attention_mask = None,
552
- position_ids = None,
553
- past_key_values = None,
554
- inputs_embeds = None,
555
- use_cache = None,
556
- cache_position = None,
557
- **kwargs,
558
- ):
559
- if (input_ids is None) ^ (inputs_embeds is not None):
560
- raise ValueError("Specify exactly one of input_ids or inputs_embeds")
561
-
562
- if inputs_embeds is None:
563
- inputs_embeds = self.embed_tokens(input_ids)
564
-
565
- if use_cache and past_key_values is None:
566
- past_key_values = DynamicCache(config=self.config)
567
-
568
- if cache_position is None:
569
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
570
- cache_position = torch.arange(
571
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
572
- )
573
-
574
- if position_ids is None:
575
- position_ids = cache_position.unsqueeze(0)
576
-
577
- if not isinstance(causal_mask_mapping := attention_mask, dict):
578
- if attention_mask is None:
579
- attention_mask = torch.ones(
580
- inputs_embeds.shape[:2], device=inputs_embeds.device, dtype=torch.long
581
- )
582
- if not (isinstance(attention_mask, torch.Tensor) and attention_mask.ndim == 4):
583
- attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype)
584
- causal_mask_mapping = defaultdict(lambda: attention_mask)
585
-
586
- hidden_states = inputs_embeds
587
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
588
-
589
- for decoder_layer in self.layers[: self.config.num_hidden_layers]:
590
- attn_type = getattr(decoder_layer, "attention_type", "self_attn")
591
- hidden_states = decoder_layer(
592
- hidden_states,
593
- attention_mask=causal_mask_mapping[attn_type],
594
- position_ids=position_ids,
595
- past_key_values=past_key_values,
596
- use_cache=use_cache,
597
- cache_position=cache_position,
598
- position_embeddings=position_embeddings,
599
- **kwargs,
600
- )
601
-
602
- hidden_states = self.norm(hidden_states)
603
- return BaseModelOutputWithPast(
604
- last_hidden_state=hidden_states,
605
- past_key_values=past_key_values if use_cache else None,
606
- )
607
-
608
- class A2DQwenLMHeadModel(BaseCausalLM):
609
- config_class = A2DQwenConfig
610
- def __init__(self, config):
611
- BasePreTrainedModel.__init__(self, config)
612
- self.model = A2DQwenModel(config)
613
- self.vocab_size = config.vocab_size
614
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
615
- self.post_init()
616
-
617
-
618
- # Register custom classes
619
- transformers.AutoConfig.register("a2d-qwen", A2DQwenConfig)
620
- transformers.AutoModel.register(A2DQwenConfig, A2DQwenLMHeadModel)
621
- transformers.AutoModelForCausalLM.register(A2DQwenConfig, A2DQwenLMHeadModel)
622
-
623
-
624
- class StackedLDMHeads(nn.Module):
625
- def __init__(self, hidden_size, vocab_size, block_size=12):
626
- super().__init__()
627
- self.block_size = block_size
628
- self.proj = nn.Linear(hidden_size, block_size * hidden_size, dtype=torch.bfloat16)
629
- self.head = nn.Linear(hidden_size, vocab_size, dtype=torch.bfloat16)
630
-
631
- def forward(self, hidden_states):
632
- batch_size, seq_len, hidden_size = hidden_states.shape
633
- forecast = self.proj(hidden_states)
634
- forecast = forecast.view(batch_size, seq_len, self.block_size, hidden_size)
635
- logits = self.head(forecast)
636
- return logits
637
-
638
- class LogitUncertaintyFilter(nn.Module):
639
- def compute_entropy(self, logits: torch.Tensor) -> torch.Tensor:
640
- probs = F.softmax(logits.float(), dim=-1)
641
- entropy = -torch.sum(probs * torch.log(probs + 1e-9), dim=-1)
642
- return entropy
643
-
644
- def forward(self, logits: torch.Tensor, threshold: float):
645
- entropy = self.compute_entropy(logits)
646
- mask = entropy >= threshold
647
- return mask, entropy
648
-
649
- class ActorCriticPruner:
650
- def __init__(self, lm_head, lambda_reg=0.1):
651
- self.lm_head = lm_head
652
- self.lambda_reg = lambda_reg
653
-
654
- def evaluate_sequence_value(self, candidate_tokens, logits):
655
- log_probs = F.log_softmax(logits.float(), dim=-1)
656
- gathered = torch.gather(log_probs, -1, candidate_tokens.unsqueeze(-1)).squeeze(-1)
657
- return gathered.mean().item()
658
-
659
- def recursive_refine(self, sequence, logits, mask, entropy, depth, alpha, beta):
660
- refined_sequence = sequence.clone()
661
- if depth == 0 or mask.sum() == 0:
662
- return refined_sequence, self.evaluate_sequence_value(sequence, logits)
663
-
664
- high_unc_positions = torch.where(mask)[0]
665
- if len(high_unc_positions) == 0:
666
- return refined_sequence, self.evaluate_sequence_value(sequence, logits)
667
-
668
- target_pos = high_unc_positions[0].item()
669
- top_logits, top_tokens = torch.topk(logits[target_pos], k=3)
670
-
671
- best_val = float('-inf')
672
- for token_opt in top_tokens:
673
- candidate = sequence.clone()
674
- candidate[target_pos] = token_opt
675
-
676
- approx_val = self.evaluate_sequence_value(candidate, logits) - (self.lambda_reg * entropy[target_pos].item())
677
- if approx_val < alpha:
678
- continue
679
-
680
- new_mask = mask.clone()
681
- new_mask[target_pos] = False
682
-
683
- _, path_val = self.recursive_refine(candidate, logits, new_mask, entropy, depth - 1, alpha, beta)
684
- if path_val > alpha:
685
- alpha = path_val
686
- best_val = path_val
687
- refined_sequence = candidate
688
-
689
- if alpha >= beta:
690
- break
691
-
692
- return refined_sequence, best_val
693
-
694
-
695
- class ADAPTDIFFPipeline(nn.Module):
696
- def __init__(self, base_lm_model, block_size=12, entropy_threshold=1.5):
697
- super().__init__()
698
- self.base_model = base_lm_model.model
699
- self.lm_head = base_lm_model.lm_head
700
- self.block_size = block_size
701
- self.entropy_threshold = entropy_threshold
702
-
703
- self.ldm_heads = StackedLDMHeads(
704
- hidden_size=base_lm_model.config.hidden_size,
705
- vocab_size=base_lm_model.config.vocab_size,
706
- block_size=block_size
707
- ).to(DEVICE)
708
-
709
- self.router = LogitUncertaintyFilter()
710
- self.pruner = ActorCriticPruner(self.lm_head)
711
-
712
- def generate_adapt_diff(self, input_ids, max_new_tokens=128):
713
- current_seq = input_ids.clone()
714
- generated_count = 0
715
- total_full_transformer_evals = 0
716
-
717
- while generated_count < max_new_tokens:
718
- outputs = self.base_model(input_ids=current_seq)
719
- total_full_transformer_evals += 1
720
- last_hidden = outputs.last_hidden_state[:, -1:, :]
721
-
722
- block_logits = self.ldm_heads(last_hidden).squeeze(0).squeeze(0)
723
- draft_tokens = torch.argmax(block_logits, dim=-1)
724
-
725
- mask, entropy = self.router(block_logits, self.entropy_threshold)
726
-
727
- if not mask.any():
728
- final_block = draft_tokens
729
- else:
730
- total_full_transformer_evals += 1
731
- final_block, _ = self.pruner.recursive_refine(
732
- sequence=draft_tokens,
733
- logits=block_logits,
734
- mask=mask,
735
- entropy=entropy,
736
- depth=2,
737
- alpha=float('-inf'),
738
- beta=float('inf')
739
- )
740
-
741
- current_seq = torch.cat([current_seq, final_block.unsqueeze(0)], dim=-1)
742
- generated_count += self.block_size
743
-
744
- return current_seq[0, input_ids.shape[1]:], total_full_transformer_evals
745
-
746
-
747
- # ==============================================================================
748
- # Model Loading
749
- # ==============================================================================
750
- print(f"Loading ADAPT-DIFF base model {ADAPT_DIFF_ID}...")
751
- a2d_model = AutoModelForCausalLM.from_pretrained(
752
- ADAPT_DIFF_ID,
753
- torch_dtype=torch.bfloat16,
754
- device_map=DEVICE
755
- )
756
-
757
- pipeline = ADAPTDIFFPipeline(a2d_model, block_size=12, entropy_threshold=1.5)
758
- print("Downloading LDM head projection weights for calibration baseline...")
759
- ldm_weights_path = hf_hub_download(repo_id=ADAPT_DIFF_ID, filename="ldm_heads.pt")
760
- pipeline.ldm_heads.load_state_dict(torch.load(ldm_weights_path, map_location=DEVICE))
761
-
762
-
763
- # ==============================================================================
764
- # SFT Training Dataset Setup
765
- # ==============================================================================
766
- print("\nDownloading datasets (GSM8K & MBPP) for calibration phase...")
767
- gsm8k_ds = load_dataset("openai/gsm8k", "main")
768
- mbpp_ds = load_dataset("google-research-datasets/mbpp")
769
-
770
- candidate_train = []
771
-
772
- if "train" in gsm8k_ds:
773
- for item in gsm8k_ds["train"]:
774
- prompt = f"Problem: {item['question']}\nSolution:"
775
- completion = f" {item['answer']}"
776
- candidate_train.append((prompt, completion))
777
- if len(candidate_train) >= 40:
778
- break
779
-
780
- mbpp_train_raw = mbpp_ds["train"] if "train" in mbpp_ds else list(mbpp_ds.values())[0]
781
- code_count = 0
782
- for item in mbpp_train_raw:
783
- if 'text' in item and 'code' in item:
784
- prompt = f"Write a Python function to solve this task:\n{item['text']}\nSolution:\n"
785
- completion = f"{item['code']}"
786
- candidate_train.append((prompt, completion))
787
- code_count += 1
788
- if code_count >= 40:
789
- break
790
-
791
- print(f"Assembled training set with {len(candidate_train)} sequences.")
792
-
793
- train_tensors = []
794
- for prompt, completion in candidate_train:
795
- full_text = prompt + completion
796
- encoded = src_tokenizer(full_text, return_tensors="pt").to(DEVICE)
797
- if encoded.input_ids.shape[1] > (pipeline.block_size + 2):
798
- train_tensors.append(encoded.input_ids)
799
-
800
-
801
- # ==============================================================================
802
- # Calibration Loop
803
- # ==============================================================================
804
- pipeline.train()
805
- optimizer = torch.optim.AdamW(pipeline.parameters(), lr=2e-4, weight_decay=0.01)
806
-
807
- def compute_ldm_forecast_loss(pipeline, input_ids):
808
- outputs = pipeline.base_model(input_ids=input_ids)
809
- hidden_states = outputs.last_hidden_state
810
-
811
- block_logits = pipeline.ldm_heads(hidden_states)
812
- B, S, L, V = block_logits.shape
813
- max_idx = S - 1 - L
814
-
815
- if max_idx <= 0:
816
- return torch.tensor(0.0, device=input_ids.device, requires_grad=True)
817
-
818
- pred_logits = block_logits[:, :max_idx, :, :]
819
- targets = torch.stack([
820
- input_ids[:, i + 1 : i + 1 + L] for i in range(max_idx)
821
- ], dim=1)
822
-
823
- loss_fct = nn.CrossEntropyLoss()
824
- return loss_fct(pred_logits.reshape(-1, V), targets.reshape(-1))
825
-
826
- epochs = 20
827
- step = 0
828
- best_loss = float('inf')
829
- best_state_dict = None
830
-
831
- print(f"\nCalibrating Stacked LDM heads across {epochs} epochs...")
832
-
833
- for epoch in range(epochs):
834
- random.shuffle(train_tensors)
835
- epoch_loss = 0.0
836
-
837
- for input_ids in train_tensors:
838
- pipeline.train()
839
- optimizer.zero_grad(set_to_none=True)
840
-
841
- loss = compute_ldm_forecast_loss(pipeline, input_ids)
842
- if loss.item() == 0.0:
843
- continue
844
-
845
- loss.backward()
846
- torch.nn.utils.clip_grad_norm_(pipeline.parameters(), max_norm=1.0)
847
- optimizer.step()
848
-
849
- current_loss = loss.item()
850
- epoch_loss += current_loss
851
- step += 1
852
-
853
- if current_loss < best_loss:
854
- best_loss = current_loss
855
- best_state_dict = copy.deepcopy(pipeline.state_dict())
856
-
857
- if step % 20 == 0:
858
- print(f"Step {step:3d} | Epoch {epoch+1} | Loss: {current_loss:.4f} (Best: {best_loss:.4f})")
859
-
860
- print("\nSFT alignment completed.")
861
- if best_state_dict is not None:
862
- pipeline.load_state_dict(best_state_dict)
863
- print(f"Successfully loaded best state checkpoint with loss: {best_loss:.4f}")
864
-
865
-
866
- # ==============================================================================
867
- # Model Post-Training Evaluation
868
- # ==============================================================================
869
- pipeline.eval()
870
- print("\nVerifying model calibration progress on training sequence forecasts...")
871
 
872
- with torch.no_grad():
873
- for idx, input_ids in enumerate(train_tensors[:2]):
874
- seq_len = input_ids.shape[1]
875
- L = pipeline.block_size
876
- if seq_len <= L + 1:
877
- continue
878
-
879
- prefix_len = seq_len - L
880
- prefix_ids = input_ids[:, :prefix_len]
881
- target_ids = input_ids[0, prefix_len : prefix_len + L]
882
-
883
- outputs = pipeline.base_model(input_ids=prefix_ids)
884
- hidden_states = outputs.last_hidden_state
885
- block_logits = pipeline.ldm_heads(hidden_states)
886
-
887
- forecast_logits = block_logits[0, -1, :, :]
888
- pred_ids = torch.argmax(forecast_logits, dim=-1)
889
-
890
- prompt_text = src_tokenizer.decode(prefix_ids[0], skip_special_tokens=True)
891
- expected_text = src_tokenizer.decode(target_ids, skip_special_tokens=True)
892
- predicted_text = src_tokenizer.decode(pred_ids, skip_special_tokens=True)
893
-
894
- truncated_prompt = prompt_text[-200:] if len(prompt_text) > 200 else prompt_text
895
- print(f"\n--- Sequence Output Check {idx + 1} ---")
896
- print(f"[Context Prompt Segment]: ... {truncated_prompt}")
897
- print(f"[Expected Block Output]: {expected_text}")
898
- print(f"[Predicted Block Output]: {predicted_text}")
899
  ```
900
 
 
 
 
 
 
31
  ldm_heads.load_state_dict(torch.load(ldm_weights_path))
32
  ```
33
 
34
+ ### Full Inference Benchmarks & SFT Calibration
35
 
36
+ To run the complete benchmark comparison against the autoregressive baseline or to perform Supervised Fine-Tuning (SFT) calibration on your own system, clone this repository and execute the dedicated scripts included in the repository:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
+ #### 1. Run Comparative Benchmarking (GSM8K & MBPP)
39
+ ```bash
40
+ python infer.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  ```
42
 
43
+ #### 2. Run Head Alignment & SFT Training
44
+ ```bash
45
+ python train.py
46
+ ```