fuvty committed
Commit
5ccf219
1 Parent(s): f07198b

[init] demo

Files changed (50)
  1. app.py +572 -0
  2. requirements.txt +20 -0
  3. rosetta/__pycache__/__init__.cpython-310.pyc +0 -0
  4. rosetta/baseline/__pycache__/multi_stage.cpython-310.pyc +0 -0
  5. rosetta/baseline/__pycache__/multi_stage.cpython-312.pyc +0 -0
  6. rosetta/baseline/__pycache__/two_stage_rosetta.cpython-310.pyc +0 -0
  7. rosetta/baseline/__pycache__/two_stage_rosetta.cpython-312.pyc +0 -0
  8. rosetta/baseline/multi_stage.py +992 -0
  9. rosetta/model/__pycache__/__init__.cpython-310.pyc +0 -0
  10. rosetta/model/__pycache__/aggregator.cpython-310.pyc +0 -0
  11. rosetta/model/__pycache__/aggregator.cpython-312.pyc +0 -0
  12. rosetta/model/__pycache__/aligner.cpython-310.pyc +0 -0
  13. rosetta/model/__pycache__/aligner.cpython-312.pyc +0 -0
  14. rosetta/model/__pycache__/all_in_one_projector.cpython-310.pyc +0 -0
  15. rosetta/model/__pycache__/all_in_one_projector.cpython-312.pyc +0 -0
  16. rosetta/model/__pycache__/cache.cpython-310.pyc +0 -0
  17. rosetta/model/__pycache__/oracle.cpython-310.pyc +0 -0
  18. rosetta/model/__pycache__/oracle.cpython-312.pyc +0 -0
  19. rosetta/model/__pycache__/projector.cpython-310.pyc +0 -0
  20. rosetta/model/__pycache__/projector.cpython-312.pyc +0 -0
  21. rosetta/model/__pycache__/sampling.cpython-310.pyc +0 -0
  22. rosetta/model/__pycache__/sampling.cpython-312.pyc +0 -0
  23. rosetta/model/__pycache__/wrapper.cpython-310.pyc +0 -0
  24. rosetta/model/__pycache__/wrapper.cpython-312.pyc +0 -0
  25. rosetta/model/ablation_projector.py +291 -0
  26. rosetta/model/aligner.py +639 -0
  27. rosetta/model/oracle.py +562 -0
  28. rosetta/model/projector.py +1194 -0
  29. rosetta/model/sampling.py +82 -0
  30. rosetta/model/wrapper.py +840 -0
  31. rosetta/train/__init__.py +20 -0
  32. rosetta/train/__pycache__/__init__.cpython-310.pyc +0 -0
  33. rosetta/train/__pycache__/__init__.cpython-312.pyc +0 -0
  34. rosetta/train/__pycache__/dataset_adapters.cpython-310.pyc +0 -0
  35. rosetta/train/__pycache__/dataset_adapters.cpython-312.pyc +0 -0
  36. rosetta/train/__pycache__/model_utils.cpython-310.pyc +0 -0
  37. rosetta/train/__pycache__/model_utils.cpython-312.pyc +0 -0
  38. rosetta/train/__pycache__/trainer.cpython-310.pyc +0 -0
  39. rosetta/train/__pycache__/trainer.cpython-312.pyc +0 -0
  40. rosetta/train/dataset_adapters.py +1648 -0
  41. rosetta/train/model_utils.py +142 -0
  42. rosetta/utils/__pycache__/evaluate.cpython-310.pyc +0 -0
  43. rosetta/utils/__pycache__/evaluate.cpython-312.pyc +0 -0
  44. rosetta/utils/__pycache__/multi_stage.cpython-310.pyc +0 -0
  45. rosetta/utils/__pycache__/multi_stage.cpython-312.pyc +0 -0
  46. rosetta/utils/__pycache__/registry.cpython-310.pyc +0 -0
  47. rosetta/utils/__pycache__/registry.cpython-312.pyc +0 -0
  48. rosetta/utils/evaluate.py +639 -0
  49. rosetta/utils/matheval.py +455 -0
  50. rosetta/utils/registry.py +279 -0
app.py ADDED
@@ -0,0 +1,572 @@
1
+ """
2
+ Gradio Side-by-Side Model Comparison Demo
3
+
4
+ This creates a web interface to compare three inference modes simultaneously:
5
+ 1. Single: Regular HuggingFace model
6
+ 2. T2T: Two-stage inference (shows context + answer)
7
+ 3. C2C: Rosetta model with projectors
8
+
9
+ All models are loaded at startup and respond to the same input in parallel.
10
+ """
11
+
12
+ import os
13
+ import sys
14
+ import torch
15
+ import argparse
16
+ import gradio as gr
17
+ from pathlib import Path
18
+ from typing import Optional, Generator
19
+ from queue import Queue
20
+ from threading import Thread
21
+
22
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
23
+ from rosetta.utils.evaluate import load_rosetta_model, load_hf_model, set_default_chat_template
24
+ from rosetta.model.wrapper import RosettaModel
25
+ from rosetta.baseline.multi_stage import TwoStageInference
26
+
27
+
28
+ class ModelManager:
29
+ """Manages loading and inference for all three model types."""
30
+
31
+ def __init__(
32
+ self,
33
+ single_model_name: str = "Qwen/Qwen2.5-0.5B-Instruct",
34
+ t2t_context_model: str = "Qwen/Qwen3-0.6B",
35
+ t2t_answer_model: str = "Qwen/Qwen2.5-0.5B-Instruct",
36
+ c2c_checkpoint_path: str = "local/checkpoints/qwen3_0.6b+qwen2.5_0.5b_Fuser",
37
+ device: str = "cuda"
38
+ ):
39
+ """
40
+ Initialize ModelManager with model configurations.
41
+
42
+ Args:
43
+ single_model_name: HuggingFace model name for single mode
44
+ t2t_context_model: Context model for T2T mode
45
+ t2t_answer_model: Answer model for T2T mode
46
+ c2c_checkpoint_path: Path to C2C checkpoint directory
47
+ device: Device to use (cuda, cpu, or auto)
48
+ """
49
+ if device == "auto":
50
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
51
+ else:
52
+ self.device = torch.device(device)
53
+ print(f"Using device: {self.device}")
54
+
55
+ # Model configurations
56
+ self.single_model_name = single_model_name
57
+ self.t2t_context_model = t2t_context_model
58
+ self.t2t_answer_model = t2t_answer_model
59
+ self.c2c_checkpoint_path = c2c_checkpoint_path
60
+
61
+ # T2T prompt configurations
62
+ self.t2t_background_prompt = "Briefly describe the most useful background to answer the question:\n\n{question}"
63
+ self.t2t_answer_prompt = "Based on the background, answer the question:\n\n{question}" # Format for second round question
64
+ self.t2t_context_max_tokens = 512
65
+ self.t2t_answer_max_tokens = 512
66
+
67
+ # Generation configuration (shared across all models)
68
+ # To enable sampling: set use_sampling=True and adjust temperature/top_p/top_k
69
+ # Current mode: Greedy decoding (do_sample=False)
70
+ self.use_sampling = False # Set to True to enable sampling
71
+ self.temperature = 0.7 # Used when use_sampling=True
72
+ self.top_p = 0.8 # Used when use_sampling=True
73
+ self.top_k = 20 # Used when use_sampling=True
74
+
75
+ # Initialize models
76
+ self.single_model = None
77
+ self.single_tokenizer = None
78
+ self.t2t_model = None
79
+ self.c2c_model = None
80
+ self.c2c_tokenizer = None
81
+
82
+ # C2C model names (will be loaded from config)
83
+ self.c2c_base_model = None
84
+ self.c2c_teacher_model = None
85
+
86
+ print("=" * 60)
87
+ print("Initializing models... This may take a few minutes.")
88
+ print("=" * 60)
89
+
90
+ self._load_all_models()
91
+
92
+ def _load_single_model(self):
93
+ """Load single HuggingFace model."""
94
+ print(f"\n[Single] Loading {self.single_model_name}...")
95
+ self.single_model, self.single_tokenizer = load_hf_model(
96
+ self.single_model_name, self.device
97
+ )
98
+ set_default_chat_template(self.single_tokenizer, self.single_model_name)
99
+ print("[Single] ✓ Model loaded")
100
+
101
+ def _load_t2t_model(self):
102
+ """Load two-stage model."""
103
+ print(f"\n[T2T] Loading two-stage model...")
104
+ print(f" Context: {self.t2t_context_model}")
105
+ print(f" Answer: {self.t2t_answer_model}")
106
+ print(f" Background prompt: {self.t2t_background_prompt}")
107
+ print(f" Answer prompt: {self.t2t_answer_prompt}")
108
+
109
+ self.t2t_model = TwoStageInference(
110
+ context_model_path=self.t2t_context_model,
111
+ answer_model_path=self.t2t_answer_model,
112
+ device=str(self.device),
113
+ background_prompt=self.t2t_background_prompt
114
+ )
115
+ print("[T2T] ✓ Model loaded")
116
+
117
+ def _load_c2c_model(self):
118
+ """Load Rosetta (C2C) model."""
119
+ print(f"\n[C2C] Loading Rosetta model from {self.c2c_checkpoint_path}...")
120
+
121
+ # Check if checkpoint exists
122
+ if not Path(self.c2c_checkpoint_path).exists():
123
+ raise FileNotFoundError(
124
+ f"C2C checkpoint not found: {self.c2c_checkpoint_path}\n"
125
+ "You can download the checkpoints automatically with:\n"
126
+ " from huggingface_hub import snapshot_download\n"
127
+ " snapshot_download(\n"
128
+ " repo_id='nics-efc/C2C_Fuser',\n"
129
+ " allow_patterns=['qwen3_0.6b+qwen2.5_0.5b_Fuser/*'],\n"
130
+ " local_dir='local/checkpoints'\n"
131
+ " )\n"
132
+ "Or see the project README for more details."
133
+ )
134
+
135
+ # Load config
136
+ import json
137
+ config_path = Path(self.c2c_checkpoint_path) / "config.json"
138
+ if not config_path.exists():
139
+ raise FileNotFoundError(f"Config file not found: {config_path}")
140
+
141
+ with open(config_path, "r") as f:
142
+ config = json.load(f)
143
+
144
+ # Store model names from config
145
+ self.c2c_base_model = config["model"]["base_model"]
146
+ self.c2c_teacher_model = config["model"]["teacher_model"]
147
+
148
+ # Load Rosetta model
149
+ subfolder_dir = Path(self.c2c_checkpoint_path) / "final"
150
+ if not subfolder_dir.exists():
151
+ raise FileNotFoundError(f"Final checkpoint directory not found: {subfolder_dir}")
152
+
153
+ model_config = {
154
+ "model_name": "Rosetta",
155
+ "rosetta_config": {
156
+ "checkpoints_dir": str(subfolder_dir),
157
+ "base_model": self.c2c_base_model,
158
+ "teacher_model": self.c2c_teacher_model,
159
+ "is_do_alignment": config["model"].get("is_do_alignment", False),
160
+ "alignment_strategy": config["model"].get("alignment_strategy", "first")
161
+ }
162
+ }
163
+
164
+ eval_config = {"checkpoints_dir": str(subfolder_dir)}
165
+
166
+ self.c2c_model, self.c2c_tokenizer = load_rosetta_model(
167
+ model_config, eval_config, self.device
168
+ )
169
+ print("[C2C] ✓ Model loaded")
170
+
171
+ def _load_all_models(self):
172
+ """Load all models sequentially."""
173
+ try:
174
+ self._load_single_model()
175
+ self._load_t2t_model()
176
+ self._load_c2c_model()
177
+ print("\n" + "=" * 60)
178
+ print("✓ All models loaded successfully!")
179
+ print("=" * 60 + "\n")
180
+ except Exception as e:
181
+ print(f"\n✗ Error loading models: {e}")
182
+ raise
183
+
184
+ def _get_generation_kwargs(self, max_new_tokens: int) -> dict:
185
+ """
186
+ Get generation kwargs with consistent settings across all models.
187
+
188
+ Args:
189
+ max_new_tokens: Maximum number of new tokens to generate
190
+
191
+ Returns:
192
+ Dictionary of generation parameters
193
+ """
194
+ kwargs = {
195
+ 'max_new_tokens': max_new_tokens,
196
+ 'do_sample': self.use_sampling
197
+ }
198
+
199
+ if self.use_sampling:
200
+ kwargs.update({
201
+ 'temperature': self.temperature,
202
+ 'top_p': self.top_p,
203
+ 'top_k': self.top_k
204
+ })
205
+
206
+ return kwargs
207
+
208
+ def generate_single(self, user_input: str) -> Generator[str, None, None]:
209
+ """Generate response from single model with streaming."""
210
+ messages = [{"role": "system", "content": ""}, {"role": "user", "content": user_input}]
211
+ text = self.single_tokenizer.apply_chat_template(
212
+ messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
213
+ )
214
+ inputs = self.single_tokenizer(text, return_tensors="pt").to(self.device)
215
+
216
+ # Setup streamer
217
+ streamer = TextIteratorStreamer(
218
+ self.single_tokenizer,
219
+ skip_prompt=True,
220
+ skip_special_tokens=True
221
+ )
222
+
223
+ # Generation parameters
224
+ generation_kwargs = {
225
+ 'input_ids': inputs.input_ids,
226
+ 'attention_mask': inputs.attention_mask,
227
+ 'streamer': streamer,
228
+ **self._get_generation_kwargs(max_new_tokens=2048)
229
+ }
230
+
231
+ # Run generation in separate thread
232
+ thread = Thread(target=self.single_model.generate, kwargs=generation_kwargs)
233
+ thread.start()
234
+
235
+ # Stream tokens
236
+ generated_text = ""
237
+ for token in streamer:
238
+ generated_text += token
239
+ yield generated_text
240
+
241
+ def generate_t2t(self, user_input: str) -> Generator[tuple[str, str], None, None]:
242
+ """Generate response from T2T model with streaming (returns context, answer)."""
243
+ # Stage 1: Context generation
244
+ context_streamer = TextIteratorStreamer(
245
+ self.t2t_model.context_tokenizer,
246
+ skip_prompt=True,
247
+ skip_special_tokens=True
248
+ )
249
+
250
+ prompt = self.t2t_background_prompt.format(question=user_input)
251
+ inputs = self.t2t_model.context_tokenizer.apply_chat_template(
252
+ [{"role": "user", "content": prompt}],
253
+ tokenize=True,
254
+ add_generation_prompt=True,
255
+ return_tensors="pt",
256
+ enable_thinking=False
257
+ ).to(self.device)
258
+
259
+ generation_kwargs = {
260
+ 'input_ids': inputs,
261
+ 'streamer': context_streamer,
262
+ **self._get_generation_kwargs(max_new_tokens=self.t2t_context_max_tokens)
263
+ }
264
+
265
+ # Generate context in thread
266
+ thread = Thread(target=self.t2t_model.context_model.generate, kwargs=generation_kwargs)
267
+ thread.start()
268
+
269
+ # Stream context tokens
270
+ context_text = ""
271
+ for token in context_streamer:
272
+ context_text += token
273
+ yield context_text, ""
274
+
275
+ thread.join()
276
+
277
+ # Re-run generation to obtain the full decoded context for stage 2 (greedy decoding, so it matches the streamed text)
278
+ with torch.inference_mode():
279
+ outputs = self.t2t_model.context_model.generate(
280
+ inputs, **self._get_generation_kwargs(max_new_tokens=self.t2t_context_max_tokens)
281
+ )
282
+ context = self.t2t_model.context_tokenizer.batch_decode(
283
+ outputs[:, inputs.shape[-1]:], skip_special_tokens=True
284
+ )[0]
285
+
286
+ # Stage 2: Answer generation
287
+ answer_streamer = TextIteratorStreamer(
288
+ self.t2t_model.answer_tokenizer,
289
+ skip_prompt=True,
290
+ skip_special_tokens=True
291
+ )
292
+
293
+ # Format the second round question
294
+ answer_question = self.t2t_answer_prompt.format(question=user_input)
295
+
296
+ inputs = self.t2t_model.answer_tokenizer.apply_chat_template(
297
+ [
298
+ {"role": "user", "content": prompt},
299
+ {"role": "assistant", "content": context},
300
+ {"role": "user", "content": answer_question}
301
+ ],
302
+ tokenize=True,
303
+ add_generation_prompt=True,
304
+ return_tensors="pt",
305
+ enable_thinking=False
306
+ ).to(self.device)
307
+
308
+ generation_kwargs = {
309
+ 'input_ids': inputs,
310
+ 'streamer': answer_streamer,
311
+ **self._get_generation_kwargs(max_new_tokens=self.t2t_answer_max_tokens)
312
+ }
313
+
314
+ # Generate answer in thread
315
+ thread = Thread(target=self.t2t_model.answer_model.generate, kwargs=generation_kwargs)
316
+ thread.start()
317
+
318
+ # Stream answer tokens
319
+ answer_text = ""
320
+ for token in answer_streamer:
321
+ answer_text += token
322
+ yield context_text, answer_text
323
+
324
+ def generate_c2c(self, user_input: str) -> Generator[str, None, None]:
325
+ """Generate response from C2C model with streaming."""
326
+ messages = [{"role": "system", "content": ""}, {"role": "user", "content": user_input}]
327
+ text = self.c2c_tokenizer.apply_chat_template(
328
+ messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
329
+ )
330
+ inputs = self.c2c_tokenizer(text, return_tensors="pt").to(self.device)
331
+
332
+ # Setup streamer
333
+ streamer = TextIteratorStreamer(
334
+ self.c2c_tokenizer,
335
+ skip_prompt=True,
336
+ skip_special_tokens=True
337
+ )
338
+
339
+ # Prepare C2C-specific inputs
340
+ full_length = inputs.input_ids.shape[1]
341
+ instruction_index = torch.tensor([1, 0], dtype=torch.long).repeat(
342
+ full_length - 1, 1
343
+ ).unsqueeze(0).to(self.device)
344
+ label_index = torch.tensor([-1, 0], dtype=torch.long).repeat(
345
+ 1, 1
346
+ ).unsqueeze(0).to(self.device)
347
+ position_ids = inputs.attention_mask.long().cumsum(-1) - 1 if inputs.attention_mask is not None else \
348
+ torch.arange(full_length, dtype=torch.long).unsqueeze(0).to(self.device)
349
+
350
+ # Generation parameters
351
+ generation_kwargs = {
352
+ 'kv_cache_index': [instruction_index, label_index],
353
+ 'input_ids': inputs.input_ids,
354
+ 'attention_mask': inputs.attention_mask,
355
+ 'position_ids': position_ids,
356
+ 'streamer': streamer,
357
+ **self._get_generation_kwargs(max_new_tokens=2048)
358
+ }
359
+
360
+ # Run generation in separate thread
361
+ thread = Thread(target=self.c2c_model.generate, kwargs=generation_kwargs)
362
+ thread.start()
363
+
364
+ # Stream tokens
365
+ generated_text = ""
366
+ for token in streamer:
367
+ generated_text += token
368
+ yield generated_text
369
+
370
+
371
+ def create_demo(model_manager: ModelManager):
372
+ """Create Gradio interface."""
373
+
374
+ # Preset example questions
375
+ EXAMPLE_QUESTIONS = {
376
+ "example1": """Instead of asking why the act of destroying the environment might be immoral, Hill wants to ask ...
377
+
378
+ A. Why the act of destroying nature might be immoral.
379
+ B. Why people who destroy the environment might be bad people.
380
+ C. How the decision to preserve the environment benefits the environment.
381
+ D. Whether plants have interests.""",
382
+ "example2": "Which company launched the Gemini 1.5 Pro model in early 2024?"
383
+ }
384
+
385
+ def respond(user_input: str):
386
+ """Main response function that yields updates for all three models."""
387
+ if not user_input.strip():
388
+ yield "", "", "", ""
+ return
389
+
390
+ # Generators for each model
391
+ single_gen = model_manager.generate_single(user_input)
392
+ t2t_gen = model_manager.generate_t2t(user_input)
393
+ c2c_gen = model_manager.generate_c2c(user_input)
394
+
395
+ single_done = False
396
+ t2t_done = False
397
+ c2c_done = False
398
+
399
+ single_text = ""
400
+ t2t_context = ""
401
+ t2t_answer = ""
402
+ c2c_text = ""
403
+
404
+ # Stream from all three models
405
+ while not (single_done and t2t_done and c2c_done):
406
+ # Update single
407
+ if not single_done:
408
+ try:
409
+ single_text = next(single_gen)
410
+ except StopIteration:
411
+ single_done = True
412
+
413
+ # Update T2T
414
+ if not t2t_done:
415
+ try:
416
+ t2t_context, t2t_answer = next(t2t_gen)
417
+ except StopIteration:
418
+ t2t_done = True
419
+
420
+ # Update C2C
421
+ if not c2c_done:
422
+ try:
423
+ c2c_text = next(c2c_gen)
424
+ except StopIteration:
425
+ c2c_done = True
426
+
427
+ # Yield current state
428
+ yield single_text, t2t_context, t2t_answer, c2c_text
429
+
430
+ # Create Gradio interface
431
+ with gr.Blocks(title="C2C Demo", theme=gr.themes.Base()) as demo:
432
+ # Header with logo
433
+ with gr.Row():
434
+ with gr.Column(scale=1, min_width=100):
435
+ gr.Image("resource/logo.png", show_label=False, show_download_button=False, container=False, height=80)
436
+ with gr.Column(scale=5):
437
+ gr.Markdown("# Cache-to-Cache Communication Demo")
438
+ gr.Markdown("Compare three inference modes side-by-side: **Single** | **Text-to-Text Communication** | **Cache-to-Cache Communication**")
439
+
440
+ gr.Markdown("---")
441
+
442
+ # Input section
443
+ gr.Markdown("## Question")
444
+
445
+ # Preset question examples
446
+ gr.Markdown("Example Questions:")
447
+ with gr.Row():
448
+ example1_btn = gr.Button("📝 Example 1: Philosophy", size="sm")
449
+ example2_btn = gr.Button("📝 Example 2: Knowledge Cutoff", size="sm")
450
+
451
+
452
+ with gr.Row():
453
+ user_input = gr.Textbox(
454
+ label="",
455
+ placeholder="Type your question here...",
456
+ lines=2,
457
+ scale=4,
458
+ show_label=False
459
+ )
460
+
461
+ with gr.Row():
462
+ submit_btn = gr.Button("🚀 Submit", variant="primary", scale=1)
463
+ clear_btn = gr.Button("🗑️ Clear", scale=1)
464
+
465
+ gr.Markdown("---")
466
+
467
+ # Output section - three columns
468
+ gr.Markdown("## Responses")
469
+ with gr.Row():
470
+ # Single column
471
+ with gr.Column():
472
+ gr.Markdown("### Single Model")
473
+ gr.Markdown(f"*{model_manager.single_model_name}*")
474
+ single_output = gr.Textbox(
475
+ label="",
476
+ lines=18,
477
+ max_lines=30,
478
+ interactive=False,
479
+ show_label=False
480
+ )
481
+
482
+ # T2T column (with two sub-boxes)
483
+ with gr.Column():
484
+ gr.Markdown("### Text-to-Text Communication")
485
+ gr.Markdown(f"*{model_manager.t2t_context_model} → {model_manager.t2t_answer_model}*")
486
+ t2t_context_output = gr.Textbox(
487
+ label="📝 Context",
488
+ lines=6,
489
+ max_lines=12,
490
+ interactive=False
491
+ )
492
+ t2t_answer_output = gr.Textbox(
493
+ label="💬 Answer",
494
+ lines=7,
495
+ max_lines=14,
496
+ interactive=False
497
+ )
498
+
499
+ # C2C column
500
+ with gr.Column():
501
+ gr.Markdown("### Cache-to-Cache Communication")
502
+ gr.Markdown(f"*{model_manager.c2c_base_model} → {model_manager.c2c_teacher_model}*")
503
+ c2c_output = gr.Textbox(
504
+ label="",
505
+ lines=18,
506
+ max_lines=30,
507
+ interactive=False,
508
+ show_label=False
509
+ )
510
+
511
+ # Event handlers
512
+ submit_btn.click(
513
+ fn=respond,
514
+ inputs=[user_input],
515
+ outputs=[single_output, t2t_context_output, t2t_answer_output, c2c_output]
516
+ )
517
+
518
+ user_input.submit(
519
+ fn=respond,
520
+ inputs=[user_input],
521
+ outputs=[single_output, t2t_context_output, t2t_answer_output, c2c_output]
522
+ )
523
+
524
+ clear_btn.click(
525
+ fn=lambda: ("", "", "", "", ""),
526
+ inputs=None,
527
+ outputs=[user_input, single_output, t2t_context_output, t2t_answer_output, c2c_output]
528
+ )
529
+
530
+ # Example question handlers
531
+ example1_btn.click(
532
+ fn=lambda: EXAMPLE_QUESTIONS["example1"],
533
+ inputs=None,
534
+ outputs=[user_input]
535
+ )
536
+
537
+ example2_btn.click(
538
+ fn=lambda: EXAMPLE_QUESTIONS["example2"],
539
+ inputs=None,
540
+ outputs=[user_input]
541
+ )
542
+
543
+ return demo
544
+
545
+
546
+ def main():
547
+ """Main entry point."""
548
+ print("=" * 60)
549
+ print("Model Comparison Demo - Gradio Interface")
550
+ print("=" * 60)
551
+
552
+ # Initialize models
553
+ model_manager = ModelManager()
554
+
555
+ # Create and launch demo
556
+ demo = create_demo(model_manager)
557
+
558
+ print("\n" + "=" * 60)
559
+ print("🚀 Launching Gradio interface...")
560
+ print("=" * 60)
561
+
562
+ demo.launch(
563
+ server_name="0.0.0.0",
564
+ server_port=7860,
565
+ share=False,
566
+ show_error=True
567
+ )
568
+
569
+
570
+ if __name__ == "__main__":
571
+ main()
572
+
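Note: a minimal quick-start sketch (not part of the commit) for running this demo locally. It assumes a CUDA device, reuses the checkpoint download snippet from the _load_c2c_model error message above, and then serves the Gradio app on port 7860 as configured in main().

# Quick-start sketch (hypothetical usage, not committed code)
from huggingface_hub import snapshot_download

# Fetch the C2C checkpoint the demo expects under local/checkpoints/
snapshot_download(
    repo_id="nics-efc/C2C_Fuser",
    allow_patterns=["qwen3_0.6b+qwen2.5_0.5b_Fuser/*"],
    local_dir="local/checkpoints",
)

# Start the Gradio interface (equivalent to running `python app.py`)
from app import main
main()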
requirements.txt ADDED
@@ -0,0 +1,20 @@
1
+ # Core dependencies (from pyproject.toml)
2
+ torch==2.6.0
3
+ transformers==4.52.4
4
+
5
+ # Gradio web interface
6
+ gradio==5.9.1
7
+
8
+ # HuggingFace Hub for checkpoint downloads
9
+ huggingface-hub>=0.26.0
10
+
11
+ # Configuration file parsing
12
+ pyyaml>=6.0
13
+
14
+ # Additional dependencies for rosetta package
15
+ datasets>=2.0
16
+ accelerate>=0.20
17
+
18
+ # Utilities
19
+ einops>=0.8
20
+
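Note: these pins can be installed with pip install -r requirements.txt. qwen-vl-utils is not listed and appears to be needed only for the optional VLM pipeline in rosetta/baseline/multi_stage.py (its import is wrapped in a try/except there). A small, hypothetical post-install sanity check (not part of the commit):

# Confirm the pinned core libraries resolved to the expected versions
import torch, transformers, gradio
print(torch.__version__)         # expected: 2.6.0
print(transformers.__version__)  # expected: 4.52.4
print(gradio.__version__)        # expected: 5.9.1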
rosetta/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (530 Bytes).
 
rosetta/baseline/__pycache__/multi_stage.cpython-310.pyc ADDED
Binary file (24.3 kB).
 
rosetta/baseline/__pycache__/multi_stage.cpython-312.pyc ADDED
Binary file (34.7 kB).
 
rosetta/baseline/__pycache__/two_stage_rosetta.cpython-310.pyc ADDED
Binary file (9.34 kB).
 
rosetta/baseline/__pycache__/two_stage_rosetta.cpython-312.pyc ADDED
Binary file (13.2 kB).
 
rosetta/baseline/multi_stage.py ADDED
@@ -0,0 +1,992 @@
1
+ """
2
+ Multi-stage evaluation utilities for VLM+LLM and LLM+LLM pipelines.
3
+
4
+ This module provides utilities for multi-stage evaluation where:
5
+ 1. VLM describes/analyzes images + LLM performs reasoning
6
+ 2. LLM provides background context + LLM performs reasoning
7
+ """
8
+
9
+ from typing import Dict, Optional, Any
10
+ import torch
11
+ from transformers import (
12
+ # Qwen2_5_VLForConditionalGeneration,
13
+ AutoProcessor,
14
+ AutoTokenizer,
15
+ AutoModelForCausalLM,
16
+ )
17
+ from rosetta.utils.evaluate import set_default_chat_template, apply_generation_config
18
+
19
+ try:
20
+ from qwen_vl_utils import process_vision_info
21
+ from transformers import Qwen2_5_VLForConditionalGeneration
22
+ except ImportError:
23
+ print("Please install qwen-vl-utils to use VLM models")
24
+
25
+ class TwoStageInference:
26
+ """Two-stage LLM+LLM inference pipeline for question answering."""
27
+
28
+ def __init__(
29
+ self,
30
+ context_model_path: str,
31
+ answer_model_path: str,
32
+ device: str = "cuda",
33
+ max_new_tokens: int = 1024,
34
+ background_prompt: str = "Briefly describe the most useful background to solve the problem:\n\n{question}",
35
+ generation_config: Optional[Dict[str, Any]] = None
36
+ ):
37
+ """
38
+ Initialize two-stage LLM pipeline.
39
+
40
+ Args:
41
+ context_model_path: Path to context-providing LLM
42
+ answer_model_path: Path to answer-generating LLM
43
+ device: Device to use
44
+ max_new_tokens: Maximum number of new tokens to generate
45
+ background_prompt: Prompt template for background generation
46
+ generation_config: Optional generation configuration to apply to models
47
+ """
48
+ self.device = device
49
+ self.max_new_tokens = max_new_tokens
50
+ self.background_prompt = background_prompt
51
+ self.generation_config = generation_config or {}
52
+ self._load_models(context_model_path, answer_model_path)
53
+
54
+ def _load_models(self, context_path: str, answer_path: str):
55
+ """Load both LLM models."""
56
+ # Load context LLM
57
+ self.context_tokenizer = AutoTokenizer.from_pretrained(context_path)
58
+ # for gemma, set sliding_window=4096
59
+ if context_path == "google/gemma-3-1b-it":
60
+ torch._dynamo.config.cache_size_limit = 64
61
+ self.context_model = AutoModelForCausalLM.from_pretrained(
62
+ context_path, torch_dtype=torch.bfloat16, device_map={"": self.device}, sliding_window=4096
63
+ )
64
+ else:
65
+ self.context_model = AutoModelForCausalLM.from_pretrained(
66
+ context_path, torch_dtype=torch.bfloat16, device_map={"": self.device}
67
+ )
68
+ # Apply generation config to context model
69
+ apply_generation_config(self.context_model, self.generation_config)
70
+
71
+ # Load answer LLM
72
+ self.answer_tokenizer = AutoTokenizer.from_pretrained(answer_path)
73
+ self.answer_model = AutoModelForCausalLM.from_pretrained(
74
+ answer_path, torch_dtype=torch.bfloat16, device_map={"": self.device}
75
+ )
76
+ # Apply generation config to answer model
77
+ apply_generation_config(self.answer_model, self.generation_config)
78
+
79
+ def get_background_context(
80
+ self,
81
+ question: str,
82
+ max_new_tokens: Optional[int] = None
83
+ ) -> str:
84
+ """
85
+ Get background context from the first LLM.
86
+
87
+ Args:
88
+ question: Question text (without options)
89
+ max_new_tokens: Max tokens to generate (uses instance default if None)
90
+
91
+ Returns:
92
+ Background context
93
+ """
94
+ prompt = self.background_prompt.format(question=question)
95
+ messages = [{"role": "user", "content": prompt}]
96
+
97
+ template_kwargs = {'enable_thinking': False}
98
+
99
+ inputs = self.context_tokenizer.apply_chat_template(
100
+ messages,
101
+ tokenize=True,
102
+ add_generation_prompt=True,
103
+ return_tensors="pt",
104
+ **template_kwargs
105
+ )
106
+ inputs = inputs.to(self.device)
107
+
108
+ if max_new_tokens is None:
109
+ max_new_tokens = self.max_new_tokens
110
+
111
+ with torch.inference_mode():
112
+ outputs = self.context_model.generate(
113
+ inputs,
114
+ max_new_tokens=max_new_tokens,
115
+ do_sample=False
116
+ )
117
+
118
+ generated_ids = outputs[:, inputs.shape[-1]:]
119
+ context = self.context_tokenizer.batch_decode(
120
+ generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
121
+ )[0]
122
+
123
+ return context
124
+
125
+ def answer_with_context(
126
+ self,
127
+ question: str,
128
+ context: str,
129
+ max_new_tokens: Optional[int] = None,
130
+ original_question: Optional[str] = None
131
+ ) -> str:
132
+ """
133
+ Answer question using the second LLM with context.
134
+
135
+ Args:
136
+ question: Full question with options and proper template
137
+ context: Background context from first LLM
138
+ max_new_tokens: Max tokens to generate (uses instance default if None)
139
+ original_question: Original question asked to first LLM (for conversation format)
140
+
141
+ Returns:
142
+ Generated answer
143
+ """
144
+ # Use conversation format: user asks for background, assistant provides it, user asks main question
145
+ if original_question:
146
+ messages = [
147
+ {"role": "user", "content": self.background_prompt.format(question=original_question)},
148
+ {"role": "assistant", "content": context},
149
+ {"role": "user", "content": question}
150
+ ]
151
+ else:
152
+ # Fallback to simple format
153
+ messages = [{"role": "user", "content": f"Background context: {context}\n\n{question}"}]
154
+
155
+ template_kwargs = {'enable_thinking': False}
156
+
157
+ inputs = self.answer_tokenizer.apply_chat_template(
158
+ messages,
159
+ tokenize=True,
160
+ add_generation_prompt=True,
161
+ return_tensors="pt",
162
+ **template_kwargs
163
+ )
164
+ inputs = inputs.to(self.device)
165
+
166
+ if max_new_tokens is None:
167
+ max_new_tokens = self.max_new_tokens
168
+
169
+ with torch.inference_mode():
170
+ outputs = self.answer_model.generate(
171
+ inputs,
172
+ max_new_tokens=max_new_tokens,
173
+ do_sample=False
174
+ )
175
+
176
+ generated_ids = outputs[:, inputs.shape[-1]:]
177
+ answer = self.answer_tokenizer.batch_decode(
178
+ generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
179
+ )[0]
180
+
181
+ return answer
182
+
183
+ def forward_with_context(
184
+ self,
185
+ question: str,
186
+ context: str,
187
+ original_question: Optional[str] = None,
188
+ response_text: Optional[str] = None,
189
+ **forward_kwargs
190
+ ) -> Any:
191
+ """
192
+ Run a forward pass on the answer model using provided context (logits mode).
193
+
194
+ Args:
195
+ question: Full question with options and proper template
196
+ context: Background context from first LLM
197
+ original_question: Original question asked to first LLM (for conversation format)
198
+ response_text: Optional text to append after the chat template to steer next-token logits
199
+ **forward_kwargs: Extra kwargs forwarded to the model's forward
200
+
201
+ Returns:
202
+ Model outputs from the forward pass (e.g., logits)
203
+ """
204
+ # Use conversation format: user asks for background, assistant provides it, user asks main question
205
+ if original_question:
206
+ messages = [
207
+ {"role": "user", "content": self.background_prompt.format(question=original_question)},
208
+ {"role": "assistant", "content": context},
209
+ {"role": "user", "content": question}
210
+ ]
211
+ else:
212
+ # Fallback to simple format
213
+ messages = [{"role": "user", "content": f"Background context: {context}\n\n{question}"}]
214
+
215
+ template_kwargs = {'enable_thinking': False}
216
+
217
+ # Build model inputs; if response_text is provided, append it to steer next-token prediction
218
+ if response_text is not None:
219
+ # Build raw text then append response_text
220
+ text = self.answer_tokenizer.apply_chat_template(
221
+ messages,
222
+ tokenize=False,
223
+ add_generation_prompt=True,
224
+ **template_kwargs
225
+ )
226
+ text = text + response_text
227
+ tokenized = self.answer_tokenizer(text, return_tensors="pt")
228
+ else:
229
+ # Directly build tensors with generation prompt (predict next assistant token)
230
+ tokenized = self.answer_tokenizer.apply_chat_template(
231
+ messages,
232
+ tokenize=True,
233
+ add_generation_prompt=True,
234
+ return_tensors="pt",
235
+ **template_kwargs
236
+ )
237
+
238
+ inputs = {k: v.to(self.device) for k, v in tokenized.items()}
239
+
240
+ with torch.inference_mode():
241
+ outputs = self.answer_model(**inputs, **forward_kwargs)
242
+
243
+ return outputs
244
+
245
+ def forward(
246
+ self,
247
+ question_without_options: str,
248
+ question_with_options: str,
249
+ response_text: Optional[str] = None,
250
+ **forward_kwargs
251
+ ) -> Any:
252
+ """
253
+ Two-stage forward pass (logits mode):
254
+ 1) Generate background context with the context model
255
+ 2) Run a forward pass on the answer model conditioned on that context
256
+
257
+ Args:
258
+ question_without_options: Question text without multiple choice options
259
+ question_with_options: Full question with options and proper template
260
+ response_text: Optional text appended after the chat template to steer next-token logits
261
+ **forward_kwargs: Extra kwargs forwarded to the model's forward
262
+
263
+ Returns:
264
+ Model outputs from the forward pass (e.g., logits)
265
+ """
266
+ context = self.get_background_context(question_without_options)
267
+ return self.forward_with_context(
268
+ question=question_with_options,
269
+ context=context,
270
+ original_question=question_without_options,
271
+ response_text=response_text,
272
+ **forward_kwargs
273
+ )
274
+
275
+ def logits_with_context(
276
+ self,
277
+ question_without_options: str,
278
+ question_with_options: str,
279
+ response_text: Optional[str] = None,
280
+ **forward_kwargs
281
+ ) -> Any:
282
+ """
283
+ Two-stage logits helper that also returns the generated background context
284
+ for logging as CoT.
285
+
286
+ Returns:
287
+ (outputs, context)
288
+ """
289
+ context = self.get_background_context(question_without_options)
290
+ outputs = self.forward_with_context(
291
+ question=question_with_options,
292
+ context=context,
293
+ original_question=question_without_options,
294
+ response_text=response_text,
295
+ **forward_kwargs
296
+ )
297
+ return outputs, context
298
+
299
+ def generate(
300
+ self,
301
+ question_without_options: str,
302
+ question_with_options: str,
303
+ communication_max_new_tokens: Optional[int] = None,
304
+ response_max_new_tokens: Optional[int] = None,
305
+ **kwargs
306
+ ) -> str:
307
+ """
308
+ Generate answer using two-stage processing.
309
+
310
+ Args:
311
+ question_without_options: Question text without multiple choice options
312
+ question_with_options: Full question with options and proper template
313
+ communication_max_new_tokens: Maximum tokens to generate for the background context
314
+ response_max_new_tokens: Maximum tokens to generate for the answer
315
+ **kwargs: Additional generation parameters (ignored for compatibility)
316
+
317
+ Returns:
318
+ Generated answer string
319
+ """
320
+ # Stage 1: Get background context
321
+ context = self.get_background_context(question_without_options, communication_max_new_tokens)
322
+
323
+ # Stage 2: Answer question with context
324
+ answer = self.answer_with_context(
325
+ question=question_with_options,
326
+ context=context,
327
+ original_question=question_without_options,
328
+ max_new_tokens=response_max_new_tokens
329
+ )
330
+
331
+ return answer
332
+
333
+ def process(
334
+ self,
335
+ question_without_options: str,
336
+ question_with_options: str
337
+ ) -> Dict[str, str]:
338
+ """
339
+ Full two-stage processing (legacy method for backward compatibility).
340
+
341
+ Args:
342
+ question_without_options: Question text without multiple choice options
343
+ question_with_options: Full question with options and proper template
344
+
345
+ Returns:
346
+ Dictionary with context and answer
347
+ """
348
+ # Stage 1: Get background context
349
+ context = self.get_background_context(question_without_options)
350
+
351
+ # Stage 2: Answer question with context
352
+ answer = self.answer_with_context(
353
+ question=question_with_options,
354
+ context=context,
355
+ original_question=question_without_options
356
+ )
357
+
358
+ return {
359
+ "context": context,
360
+ "answer": answer
361
+ }
362
+
363
+
364
+ class TwoStageRosetta(TwoStageInference):
365
+ """Two-stage LLM+Rosetta inference pipeline for question answering."""
366
+
367
+ def __init__(
368
+ self,
369
+ context_model_path: str,
370
+ rosetta_checkpoint_dir: str,
371
+ rosetta_subfolder: str = "final",
372
+ device: str = "cuda",
373
+ max_new_tokens: int = 1024,
374
+ background_prompt: str = "Briefly describe the most useful background to solve the problem:\n\n{question}",
375
+ generation_config: Optional[Dict[str, Any]] = None
376
+ ):
377
+ """
378
+ Initialize two-stage pipeline with Rosetta as second model.
379
+
380
+ Args:
381
+ context_model_path: Path to context-providing LLM
382
+ rosetta_checkpoint_dir: Path to Rosetta checkpoint directory
383
+ rosetta_subfolder: Subfolder name in checkpoint directory (e.g., 'final', 'checkpoint-1000')
384
+ device: Device to use
385
+ max_new_tokens: Maximum number of new tokens to generate
386
+ background_prompt: Prompt template for background generation
387
+ generation_config: Optional generation configuration to apply to models
388
+ """
389
+ # Initialize parent class with dummy answer model path
390
+ # We'll override the answer model loading
391
+ super().__init__(
392
+ context_model_path=context_model_path,
393
+ answer_model_path=None, # Will be overridden
394
+ device=device,
395
+ max_new_tokens=max_new_tokens,
396
+ background_prompt=background_prompt,
397
+ generation_config=generation_config
398
+ )
399
+
400
+ self.rosetta_checkpoint_dir = rosetta_checkpoint_dir
401
+ self.rosetta_subfolder = rosetta_subfolder
402
+ self._load_rosetta_model()
403
+
404
+ def _load_models(self, context_path: str, answer_path: str):
405
+ """
406
+ Override parent class _load_models to prevent loading dummy answer model.
407
+ We only load the context model here, and the Rosetta model is loaded separately.
408
+ """
409
+ # Only load context LLM (answer model is replaced by Rosetta)
410
+ self.context_tokenizer = AutoTokenizer.from_pretrained(context_path)
411
+ self.context_model = AutoModelForCausalLM.from_pretrained(
412
+ context_path, torch_dtype=torch.bfloat16, device_map={"": self.device}
413
+ )
414
+ # Apply generation config to context model
415
+ apply_generation_config(self.context_model, self.generation_config)
416
+
417
+ # Skip loading answer model - we use Rosetta instead
418
+ print(f"Loaded context model from {context_path}")
419
+ print("Skipping answer model loading - using Rosetta model instead")
420
+
421
+ def _load_rosetta_model(self):
422
+ """Load Rosetta model and related components following load_model_from_checkpoint pattern."""
423
+ import json
424
+ from pathlib import Path
425
+ from rosetta.utils.evaluate import load_rosetta_model
426
+
427
+ checkpoint_path = Path(self.rosetta_checkpoint_dir)
428
+
429
+ # Load config
430
+ config_path = checkpoint_path / "config.json"
431
+ if not config_path.exists():
432
+ raise FileNotFoundError(f"Config file not found: {config_path}")
433
+
434
+ with open(config_path, 'r') as f:
435
+ config = json.load(f)
436
+
437
+ # Check if this is a Rosetta model (has projectors)
438
+ subfolder_dir = checkpoint_path / self.rosetta_subfolder
439
+ has_projectors = subfolder_dir.exists() and any(
440
+ f.name.startswith("projector_") and f.name.endswith(".pt")
441
+ for f in subfolder_dir.iterdir()
442
+ )
443
+
444
+ if not has_projectors:
445
+ raise ValueError(f"No projectors found in {subfolder_dir}. This doesn't appear to be a Rosetta checkpoint.")
446
+
447
+ # Load Rosetta model (following load_model_from_checkpoint pattern)
448
+ print(f"Loading Rosetta model from {self.rosetta_checkpoint_dir}")
449
+
450
+ # Create model config for Rosetta loading
451
+ model_config = {
452
+ "model_name": "Rosetta",
453
+ "rosetta_config": {
454
+ "checkpoints_dir": str(subfolder_dir),
455
+ "base_model": config["model"]["base_model"],
456
+ "teacher_model": config["model"]["teacher_model"],
457
+ "is_do_alignment": config["model"].get("is_do_alignment", False),
458
+ "alignment_strategy": config["model"].get("alignment_strategy", "first")
459
+ }
460
+ }
461
+
462
+ print(f"Model config: {model_config}")
463
+
464
+ eval_config = {
465
+ "checkpoints_dir": str(subfolder_dir)
466
+ }
467
+
468
+ # Load Rosetta model using the existing utility
469
+ self.rosetta_model, self.rosetta_tokenizer = load_rosetta_model(
470
+ model_config,
471
+ eval_config,
472
+ device=self.device
473
+ )
474
+
475
+ # Load LLM tokenizer for alignment if needed
476
+ is_do_alignment = config["model"].get("is_do_alignment", False)
477
+ llm_model_path = config["model"].get("teacher_model")
478
+ self.llm_tokenizer = None
479
+
480
+ if is_do_alignment and llm_model_path:
481
+ try:
482
+ self.llm_tokenizer = AutoTokenizer.from_pretrained(str(llm_model_path))
483
+ if self.llm_tokenizer.pad_token is None:
484
+ self.llm_tokenizer.pad_token = self.llm_tokenizer.eos_token
485
+ set_default_chat_template(self.llm_tokenizer, llm_model_path)
486
+ except Exception as e:
487
+ print(f"Failed to load LLM tokenizer '{llm_model_path}': {e}")
488
+ self.llm_tokenizer = None
489
+
490
+ print(f"Initialized TwoStageRosetta with Rosetta model on {self.device}")
491
+
492
+ def _prepare_rosetta_inputs(
493
+ self,
494
+ question: str,
495
+ context: str,
496
+ original_question: Optional[str] = None,
497
+ max_new_tokens: Optional[int] = None,
498
+ answer_method: str = "generate",
499
+ response_text: Optional[str] = None
500
+ ) -> Dict[str, Any]:
501
+ """
502
+ Prepare inputs for Rosetta model using the simpler approach from live_chat_example.py.
503
+
504
+ Args:
505
+ question: Question to answer
506
+ context: Background context from first LLM
507
+ original_question: Original question asked to first LLM (for conversation format)
508
+ max_new_tokens: Max tokens to generate (uses instance default if None)
509
+
510
+ Returns:
511
+ Dictionary with prepared inputs for Rosetta model
512
+ """
513
+ # Use conversation format: user asks for background, assistant provides it, user asks main question
514
+ if original_question:
515
+ messages = [
516
+ {"role": "user", "content": self.background_prompt.format(question=original_question)},
517
+ {"role": "assistant", "content": context},
518
+ {"role": "user", "content": question}
519
+ ]
520
+ else:
521
+ # Fallback to simple format
522
+ messages = [{"role": "user", "content": f"Background context: {context}\n\n{question}"}]
523
+
524
+ # Apply chat template (following live_chat_example.py pattern)
525
+ base_text = None
526
+ if hasattr(self.rosetta_tokenizer, 'apply_chat_template'):
527
+ base_text = self.rosetta_tokenizer.apply_chat_template(
528
+ messages,
529
+ tokenize=False,
530
+ add_generation_prompt=True,
531
+ enable_thinking=False
532
+ )
533
+ else:
534
+ base_text = f"### Human: {question}\n### Assistant:"
535
+
536
+ # Optionally append response_text for logits mode to steer next-token logits
537
+ if answer_method == 'logits' and response_text is not None:
538
+ text = base_text + response_text
539
+ else:
540
+ text = base_text
541
+
542
+ # Tokenize input
543
+ inputs = self.rosetta_tokenizer(text, return_tensors="pt").to(self.device)
544
+
545
+ # Create kv_cache_index for Rosetta model
546
+ full_length = inputs.input_ids.shape[1]
547
+ if answer_method == 'logits':
548
+ # Compute response length as the extra tokens appended by response_text
549
+ if response_text is not None:
550
+ base_tok = self.rosetta_tokenizer(base_text, return_tensors="pt")
551
+ response_length = int(inputs.input_ids.shape[1] - base_tok.input_ids.shape[1])
552
+ response_length = max(response_length, 0)
553
+ else:
554
+ response_length = 0
555
+ instr_len = max(full_length - response_length, 0)
556
+ instruction_index = torch.tensor([1, 0], dtype=torch.long).repeat(instr_len, 1).unsqueeze(0).to(self.device)
557
+ if response_length > 0:
558
+ response_index = torch.tensor([-1, 0], dtype=torch.long).repeat(response_length, 1).unsqueeze(0).to(self.device)
559
+ kv_cache_list = [instruction_index, response_index]
560
+ else:
561
+ kv_cache_list = [instruction_index]
562
+ else:
563
+ # Generate: treat the last position as response (length 1)
564
+ instruction_index = torch.tensor([1, 0], dtype=torch.long).repeat(full_length - 1, 1).unsqueeze(0).to(self.device)
565
+ label_index = torch.tensor([-1, 0], dtype=torch.long).repeat(1, 1).unsqueeze(0).to(self.device)
566
+ kv_cache_list = [instruction_index, label_index]
567
+
568
+ # Add position_ids if needed
569
+ if inputs.attention_mask is None:
570
+ position_ids = torch.arange(inputs.input_ids.shape[-1], dtype=torch.long).unsqueeze(0).to(self.device)
571
+ else:
572
+ position_ids = inputs.attention_mask.long().cumsum(-1) - 1
573
+
574
+ outputs = {
575
+ "inputs": {
576
+ "input_ids": inputs.input_ids,
577
+ "attention_mask": inputs.attention_mask,
578
+ "position_ids": position_ids,
579
+ "kv_cache_index": kv_cache_list
580
+ },
581
+ "printable_text": text
582
+ }
583
+
584
+ return outputs
585
+
586
+ def answer_with_context(
587
+ self,
588
+ question: str,
589
+ context: str,
590
+ max_new_tokens: Optional[int] = None,
591
+ original_question: Optional[str] = None
592
+ ) -> str:
593
+ """
594
+ Answer question using Rosetta model with context.
595
+ Overrides parent class method to use Rosetta instead of regular LLM.
596
+
597
+ Args:
598
+ question: Question to answer
599
+ context: Background context from first LLM
600
+ max_new_tokens: Max tokens to generate (uses instance default if None)
601
+ original_question: Original question asked to first LLM (for conversation format)
602
+
603
+ Returns:
604
+ Generated answer
605
+ """
606
+ # Prepare inputs for Rosetta model
607
+ prepared = self._prepare_rosetta_inputs(
608
+ question=question,
609
+ context=context,
610
+ original_question=original_question,
611
+ max_new_tokens=max_new_tokens
612
+ )
613
+
614
+ if max_new_tokens is None:
615
+ max_new_tokens = self.max_new_tokens
616
+
617
+ # Generation parameters (following live_chat_example.py pattern)
618
+ sampling_params = {
619
+ 'do_sample': False,
620
+ 'max_new_tokens': max_new_tokens
621
+ }
622
+
623
+ # Generate using Rosetta model (following live_chat_example.py pattern)
624
+ input_length = prepared['inputs']['input_ids'].shape[1]
625
+
626
+ with torch.inference_mode():
627
+ outputs = self.rosetta_model.generate(
628
+ kv_cache_index=prepared['inputs']['kv_cache_index'],
629
+ input_ids=prepared['inputs']['input_ids'],
630
+ attention_mask=prepared['inputs']['attention_mask'],
631
+ position_ids=prepared['inputs']['position_ids'],
632
+ **sampling_params
633
+ )
634
+ generated_ids = outputs[0]
635
+
636
+ # Decode response
637
+ answer = self.rosetta_tokenizer.decode(generated_ids[input_length:], skip_special_tokens=True).strip()
638
+
639
+ return answer
640
+
641
+ def forward_with_context(
642
+ self,
643
+ question: str,
644
+ context: str,
645
+ original_question: Optional[str] = None,
646
+ response_text: Optional[str] = None,
647
+ **forward_kwargs
648
+ ) -> Any:
649
+ """
650
+ Run a forward pass on the Rosetta model using provided context (logits mode).
651
+
652
+ Args:
653
+ question: Full question with options and proper template
654
+ context: Background context from first LLM
655
+ original_question: Original question asked to first LLM (for conversation format)
656
+ response_text: Optional text appended after the chat template to steer next-token logits
657
+ **forward_kwargs: Extra kwargs forwarded to the model's forward
658
+
659
+ Returns:
660
+ Model outputs from the forward pass (e.g., logits)
661
+ """
662
+ prepared = self._prepare_rosetta_inputs(
663
+ question=question,
664
+ context=context,
665
+ original_question=original_question,
666
+ answer_method='logits',
667
+ response_text=response_text
668
+ )
669
+
670
+ inputs = prepared['inputs']
671
+ with torch.inference_mode():
672
+ outputs = self.rosetta_model.forward(
673
+ kv_cache_index=inputs['kv_cache_index'],
674
+ input_ids=inputs['input_ids'],
675
+ attention_mask=inputs['attention_mask'],
676
+ position_ids=inputs['position_ids'],
677
+ **forward_kwargs
678
+ )
679
+ return outputs
680
+
681
+ def forward(
682
+ self,
683
+ question_without_options: str,
684
+ question_with_options: str,
685
+ response_text: Optional[str] = None,
686
+ **forward_kwargs
687
+ ) -> Any:
688
+ """
689
+ Two-stage forward pass (logits mode) for Rosetta:
690
+ 1) Generate background context with the context model
691
+ 2) Run a forward pass on the Rosetta model conditioned on that context
692
+
693
+ Args:
694
+ question_without_options: Question text without multiple choice options
695
+ question_with_options: Full question with options and proper template
696
+ response_text: Optional text appended after the chat template to steer next-token logits
697
+ **forward_kwargs: Extra kwargs forwarded to the model's forward
698
+
699
+ Returns:
700
+ Model outputs from the forward pass (e.g., logits)
701
+ """
702
+ # Work in progress
703
+ raise NotImplementedError
704
+ context = self.get_background_context(question_without_options)
705
+ return self.forward_with_context(
706
+ question=question_with_options,
707
+ context=context,
708
+ original_question=question_without_options,
709
+ response_text=response_text,
710
+ **forward_kwargs
711
+ )
712
+
713
+ def logits_with_context(
714
+ self,
715
+ question_without_options: str,
716
+ question_with_options: str,
717
+ response_text: Optional[str] = None,
718
+ **forward_kwargs
719
+ ) -> Any:
720
+ """
721
+ Two-stage logits helper that also returns the generated background context
722
+ for logging as CoT (Rosetta backend).
723
+
724
+ Returns:
725
+ (outputs, context)
726
+ """
727
+ context = self.get_background_context(question_without_options)
728
+ outputs = self.forward_with_context(
729
+ question=question_with_options,
730
+ context=context,
731
+ original_question=question_without_options,
732
+ response_text=response_text,
733
+ **forward_kwargs
734
+ )
735
+ return outputs, context
736
+
737
+ def generate(
738
+ self,
739
+ question_without_options: str,
740
+ question_with_options: str,
741
+ max_new_tokens: Optional[int] = None,
742
+ **kwargs
743
+ ) -> str:
744
+ """
745
+ Generate answer using two-stage processing with Rosetta.
746
+
747
+ Args:
748
+ question_without_options: Question text without multiple choice options
749
+ question_with_options: Full question with options and proper template
750
+ max_new_tokens: Maximum tokens to generate (passed to both stages)
751
+ **kwargs: Additional generation parameters (ignored for compatibility)
752
+
753
+ Returns:
754
+ Generated answer string
755
+ """
756
+ # Stage 1: Get background context (uses parent class method)
757
+ context = self.get_background_context(question_without_options, max_new_tokens)
758
+
759
+ # Stage 2: Answer question with context using Rosetta
760
+ answer = self.answer_with_context(
761
+ question=question_with_options,
762
+ context=context,
763
+ original_question=question_without_options,
764
+ max_new_tokens=max_new_tokens
765
+ )
766
+
767
+ return answer
768
+
769
+ def process(
770
+ self,
771
+ question_without_options: str,
772
+ question_with_options: str
773
+ ) -> Dict[str, str]:
774
+ """
775
+ Full two-stage processing with Rosetta (legacy method for backward compatibility).
776
+
777
+ Args:
778
+ question_without_options: Question text without multiple choice options
779
+ question_with_options: Full question with options and proper template
780
+
781
+ Returns:
782
+ Dictionary with context and answer
783
+ """
784
+ # Stage 1: Get background context (uses parent class method)
785
+ context = self.get_background_context(question_without_options)
786
+
787
+ # Stage 2: Answer question with context using Rosetta
788
+ answer = self.answer_with_context(
789
+ question=question_with_options,
790
+ context=context,
791
+ original_question=question_without_options
792
+ )
793
+
794
+ return {
795
+ "context": context,
796
+ "answer": answer
797
+ }
798
+
799
+
800
+
801
+ class MultiModalInference:
802
+ """Multi-modal VLM+LLM inference pipeline."""
803
+
804
+ def __init__(
805
+ self,
806
+ vlm_model_path: str,
807
+ llm_model_path: str,
808
+ device: str = "cuda",
809
+ max_new_tokens: int = 1024,
810
+ generation_config: Optional[Dict[str, Any]] = None
811
+ ):
812
+ """
813
+ Initialize multi-modal pipeline.
814
+
815
+ Args:
816
+ vlm_model_path: Path to VLM model
817
+ llm_model_path: Path to LLM model
818
+ device: Device to use
819
+ max_new_tokens: Maximum number of new tokens to generate
820
+ generation_config: Optional generation configuration to apply to models
821
+ """
822
+ self.device = device
823
+ self.max_new_tokens = max_new_tokens
824
+ self.generation_config = generation_config or {}
825
+ self._load_models(vlm_model_path, llm_model_path)
826
+
827
+ def _load_models(self, vlm_path: str, llm_path: str):
828
+ """Load VLM and LLM models."""
829
+ # Load VLM
830
+ self.vlm_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
831
+ vlm_path,
832
+ torch_dtype=torch.bfloat16,
833
+ device_map={"": self.device},
834
+ )
835
+ # Apply generation config to VLM model
836
+ apply_generation_config(self.vlm_model, self.generation_config)
837
+ self.vlm_processor = AutoProcessor.from_pretrained(vlm_path)
838
+
839
+ # Load LLM
840
+ self.llm_tokenizer = AutoTokenizer.from_pretrained(llm_path)
841
+ self.llm_model = AutoModelForCausalLM.from_pretrained(
842
+ llm_path, torch_dtype=torch.bfloat16, device_map={"": self.device}
843
+ )
844
+ # Apply generation config to LLM model
845
+ apply_generation_config(self.llm_model, self.generation_config)
846
+
847
+ def get_image_description(
848
+ self,
849
+ image_path: str,
850
+ prompt: str = "Describe this image in detail.",
851
+ max_new_tokens: Optional[int] = None
852
+ ) -> str:
853
+ """
854
+ Get image description from VLM.
855
+
856
+ Args:
857
+ image_path: Path to image
858
+ prompt: Description prompt
859
+ max_new_tokens: Max tokens to generate (uses instance default if None)
860
+
861
+ Returns:
862
+ Image description
863
+ """
864
+ messages = [{
865
+ "role": "user",
866
+ "content": [
867
+ {"type": "image", "image": image_path},
868
+ {"type": "text", "text": prompt}
869
+ ]
870
+ }]
871
+
872
+ text = self.vlm_processor.apply_chat_template(
873
+ messages, tokenize=False, add_generation_prompt=True
874
+ )
875
+ image_inputs, video_inputs = process_vision_info(messages)
876
+ inputs = self.vlm_processor(
877
+ text=[text],
878
+ images=image_inputs,
879
+ videos=video_inputs,
880
+ padding=True,
881
+ return_tensors="pt",
882
+ )
883
+ inputs = inputs.to(self.device)
884
+
885
+ if max_new_tokens is None:
886
+ max_new_tokens = self.max_new_tokens
887
+
888
+ with torch.inference_mode():
889
+ outputs = self.vlm_model.generate(
890
+ **inputs,
891
+ max_new_tokens=max_new_tokens,
892
+ do_sample=False,
893
+ )
894
+
895
+ generated_ids = outputs[:, inputs["input_ids"].shape[-1]:]
896
+ description = self.vlm_processor.batch_decode(
897
+ generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
898
+ )[0]
899
+
900
+ return description
901
+
902
+ def answer_with_context(
903
+ self,
904
+ question: str,
905
+ context: str,
906
+ max_new_tokens: Optional[int] = None,
907
+ original_question: Optional[str] = None
908
+ ) -> str:
909
+ """
910
+ Answer question using LLM with context.
911
+
912
+ Args:
913
+ question: Question to answer
914
+ context: Context (e.g., image description from VLM)
915
+ max_new_tokens: Max tokens to generate (uses instance default if None)
916
+ original_question: Original question asked to VLM (for conversation format)
917
+
918
+ Returns:
919
+ Generated answer
920
+ """
921
+ # Use conversation format: user asks about image, assistant describes, user asks follow-up
922
+ if original_question:
923
+ messages = [
924
+ {"role": "user", "content": original_question},
925
+ {"role": "assistant", "content": context},
926
+ {"role": "user", "content": question}
927
+ ]
928
+ else:
929
+ # Fallback to simple format
930
+ messages = [{"role": "user", "content": f"Context: {context}\n\nQuestion: {question}"}]
931
+
932
+ template_kwargs = {'enable_thinking': False}
933
+
934
+ # Some tokenizers may not support enable_thinking parameter
935
+ inputs = self.llm_tokenizer.apply_chat_template(
936
+ messages,
937
+ tokenize=True,
938
+ add_generation_prompt=True,
939
+ return_tensors="pt",
940
+ **template_kwargs
941
+ )
942
+
943
+ inputs = inputs.to(self.llm_model.device)
944
+
945
+ if max_new_tokens is None:
946
+ max_new_tokens = self.max_new_tokens
947
+
948
+ with torch.inference_mode():
949
+ outputs = self.llm_model.generate(
950
+ inputs,
951
+ max_new_tokens=max_new_tokens,
952
+ do_sample=False
953
+ )
954
+
955
+ generated_ids = outputs[:, inputs.shape[-1]:]
956
+ answer = self.llm_tokenizer.batch_decode(
957
+ generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
958
+ )[0]
959
+
960
+ return answer
961
+
962
+ def process(
963
+ self,
964
+ image_path: str,
965
+ question: str,
966
+ description_prompt: str = "Briefly describe this image."
967
+ ) -> Dict[str, str]:
968
+ """
969
+ Full multi-stage processing.
970
+
971
+ Args:
972
+ image_path: Path to image
973
+ question: Question to answer
974
+ description_prompt: Prompt for image description
975
+
976
+ Returns:
977
+ Dictionary with description and answer
978
+ """
979
+ # Stage 1: Get image description
980
+ description = self.get_image_description(image_path, description_prompt)
981
+
982
+ # Stage 2: Answer question with context (pass original prompt for conversation format)
983
+ answer = self.answer_with_context(
984
+ question=question,
985
+ context=description,
986
+ original_question=description_prompt
987
+ )
988
+
989
+ return {
990
+ "description": description,
991
+ "answer": answer
992
+ }
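
Usage sketch for the pipeline above (the checkpoint names, image path, and question are hypothetical placeholders, not part of this commit):

from rosetta.baseline.multi_stage import MultiModalInference

pipeline = MultiModalInference(
    vlm_model_path="Qwen/Qwen2.5-VL-7B-Instruct",   # placeholder VLM checkpoint
    llm_model_path="Qwen/Qwen3-8B",                  # placeholder LLM checkpoint
    device="cuda",
    max_new_tokens=512,
)
result = pipeline.process(
    image_path="example.jpg",                        # placeholder image
    question="How many people are in the picture?",
)
print(result["description"])   # stage 1: VLM image description
print(result["answer"])        # stage 2: LLM answer grounded in that description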
rosetta/model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (419 Bytes). View file
 
rosetta/model/__pycache__/aggregator.cpython-310.pyc ADDED
Binary file (5.83 kB). View file
 
rosetta/model/__pycache__/aggregator.cpython-312.pyc ADDED
Binary file (8.53 kB). View file
 
rosetta/model/__pycache__/aligner.cpython-310.pyc ADDED
Binary file (16.3 kB). View file
 
rosetta/model/__pycache__/aligner.cpython-312.pyc ADDED
Binary file (24.2 kB). View file
 
rosetta/model/__pycache__/all_in_one_projector.cpython-310.pyc ADDED
Binary file (27.8 kB). View file
 
rosetta/model/__pycache__/all_in_one_projector.cpython-312.pyc ADDED
Binary file (53.5 kB). View file
 
rosetta/model/__pycache__/cache.cpython-310.pyc ADDED
Binary file (980 Bytes). View file
 
rosetta/model/__pycache__/oracle.cpython-310.pyc ADDED
Binary file (13.8 kB). View file
 
rosetta/model/__pycache__/oracle.cpython-312.pyc ADDED
Binary file (23.5 kB). View file
 
rosetta/model/__pycache__/projector.cpython-310.pyc ADDED
Binary file (30.8 kB). View file
 
rosetta/model/__pycache__/projector.cpython-312.pyc ADDED
Binary file (56 kB). View file
 
rosetta/model/__pycache__/sampling.cpython-310.pyc ADDED
Binary file (2.08 kB). View file
 
rosetta/model/__pycache__/sampling.cpython-312.pyc ADDED
Binary file (3.49 kB). View file
 
rosetta/model/__pycache__/wrapper.cpython-310.pyc ADDED
Binary file (18.6 kB). View file
 
rosetta/model/__pycache__/wrapper.cpython-312.pyc ADDED
Binary file (33.8 kB). View file
 
rosetta/model/ablation_projector.py ADDED
@@ -0,0 +1,291 @@
1
+ """
2
+ Ablation Projector: A configurable projector for ablation studies based on C2CProjector.
3
+ Allows gradual removal of components to study their individual contributions.
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch import Tensor
9
+ from typing import Optional, Tuple, Literal
10
+
11
+ from rosetta.utils.registry import register_model, capture_init_args
12
+ from rosetta.model.projector import Projector
13
+ from rosetta.model.projector import RegularMLP
14
+
15
+
16
+ @register_model
17
+ @capture_init_args
18
+ class AblationProjector(Projector):
19
+ """
20
+ Ablation study projector based on C2CProjector with configurable component removal.
21
+
22
+ Ablation levels:
23
+ 0. Full C2C (baseline)
24
+ 1. Remove scalar weights (set to 1.0)
25
+ 2. Remove gates (set to 1.0)
26
+ 3. Remove target contribution (only use source)
27
+ 4. Remove gates only (gates=1.0), keep scalars and target
28
+
29
+ Levels 1-3 each build on the previous one, allowing a gradual degradation study; level 4 is a standalone variant that removes only the gates.
30
+ """
31
+
32
+ def __init__(
33
+ self,
34
+ source_dim: int,
35
+ target_dim: int,
36
+ source_num_heads: int = 1,
37
+ target_num_heads: int = 1,
38
+ intermediate_dim: int = 1024,
39
+ hidden_dim: int = 1024,
40
+ num_layers: int = 3,
41
+ dropout: float = 0.1,
42
+ initial_temperature: float = 1.0,
43
+ final_temperature: float = 0.001,
44
+ anneal_steps: int = 1929,
45
+ dtype: torch.dtype = torch.float32,
46
+
47
+ # Ablation configuration
48
+ ablation_level: int = 0, # 0=full, 1=no_scalar, 2=no_gate+no_scalar, 3=no_target, 4=no_gate_only
49
+ use_scalar_weights: bool = True, # Can be overridden by ablation_level
50
+ use_gates: bool = True, # Can be overridden by ablation_level
51
+ use_target: bool = True, # Can be overridden by ablation_level
52
+ ):
53
+ super().__init__()
54
+
55
+ assert 0 <= ablation_level <= 4, "ablation_level must be 0, 1, 2, 3, or 4"
56
+
57
+ # Dimensions
58
+ self.source_dim = source_dim
59
+ self.target_dim = target_dim
60
+ self.source_num_heads = source_num_heads
61
+ self.target_num_heads = target_num_heads
62
+ self.ablation_level = ablation_level
63
+
64
+ # Override component usage based on ablation level
65
+ if ablation_level == 4:
66
+ # Special case: disable gates only, keep scalars and target
67
+ use_scalar_weights = True
68
+ use_gates = False
69
+ use_target = True
70
+ else:
71
+ if ablation_level >= 1:
72
+ use_scalar_weights = False
73
+ if ablation_level >= 2:
74
+ use_gates = False
75
+ if ablation_level >= 3:
76
+ use_target = False
77
+
78
+ self.use_scalar_weights = use_scalar_weights
79
+ self.use_gates = use_gates
80
+ self.use_target = use_target
81
+
82
+ # Sizes
83
+ in_dim = source_dim * source_num_heads
84
+ out_dim = target_dim * target_num_heads
85
+
86
+ # 1) concat(source_X, target_X) then project to hidden_dim
87
+ # If not using target, only use source features
88
+ if self.use_target:
89
+ self.key_in = nn.Linear(in_dim + out_dim, hidden_dim, bias=True, dtype=dtype)
90
+ self.value_in = nn.Linear(in_dim + out_dim, hidden_dim, bias=True, dtype=dtype)
91
+ else:
92
+ # Only use source features
93
+ self.key_in = nn.Linear(in_dim, hidden_dim, bias=True, dtype=dtype)
94
+ self.value_in = nn.Linear(in_dim, hidden_dim, bias=True, dtype=dtype)
95
+
96
+ # 2) one-layer common embedding MLP to get intermediate representation (at hidden_dim)
97
+ self.key_mlp1 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, num_layers=1, dropout=dropout, dtype=dtype)
98
+ self.value_mlp1 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, num_layers=1, dropout=dropout, dtype=dtype)
99
+
100
+ # 3a) intermediate representation → (L-2)-layer MLP for weights → project to head dim
101
+ # Only build if using scalar weights
102
+ if self.use_scalar_weights:
103
+ self.key_scalar_mlp2 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=hidden_dim, num_layers=1, dropout=dropout, dtype=dtype)
104
+ self.value_scalar_mlp2 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=hidden_dim, num_layers=1, dropout=dropout, dtype=dtype)
105
+ self.key_scalar_head = nn.Linear(hidden_dim, target_num_heads, dtype=dtype)
106
+ self.value_scalar_head = nn.Linear(hidden_dim, target_num_heads, dtype=dtype)
107
+
108
+ # 3b) intermediate representation → (L-2)-layer MLP for projected_X → finally project hidden_dim → out_dim
109
+ self.key_proj_mlp2 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, num_layers=num_layers-2, dropout=dropout, dtype=dtype)
110
+ self.value_proj_mlp2 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, num_layers=num_layers-2, dropout=dropout, dtype=dtype)
111
+ self.key_proj_out = nn.Linear(hidden_dim, out_dim, bias=True, dtype=dtype)
112
+ self.value_proj_out = nn.Linear(hidden_dim, out_dim, bias=True, dtype=dtype)
113
+
114
+ # Scalar key/value gate parameters and temperature schedule
115
+ # Only build if using gates
116
+ if self.use_gates:
117
+ self.key_gate_logit = nn.Parameter(torch.tensor(0.0, dtype=dtype))
118
+ self.value_gate_logit = nn.Parameter(torch.tensor(0.0, dtype=dtype))
119
+ self.use_gumbel = True
120
+ self.register_buffer("gate_temperature", torch.tensor(initial_temperature, dtype=dtype))
121
+ self.initial_temperature = initial_temperature
122
+ self.final_temperature = final_temperature
123
+ self.anneal_steps = anneal_steps
124
+
125
+ # Temperature for weight normalization
126
+ self.scalar_temperature = 1.0
127
+
128
+ def update_temperature(self, step: int):
129
+ """Update temperature using exponential annealing schedule for gates."""
130
+ if self.use_gates:
131
+ ratio = min(step / self.anneal_steps, 1.0)
132
+ temp = self.initial_temperature * (self.final_temperature / self.initial_temperature) ** ratio
133
+ self.gate_temperature.fill_(temp)
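
For reference, the schedule above anneals the gate temperature geometrically from initial_temperature to final_temperature; a rough sketch with the constructor defaults (initial 1.0, final 0.001, anneal_steps 1929):

# T(step) = initial_temperature * (final_temperature / initial_temperature) ** min(step / anneal_steps, 1.0)
# step = 0      -> 1.0
# step ~ 965    -> about 0.032  (halfway: 0.001 ** 0.5)
# step >= 1929  -> 0.001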
134
+
135
+ def forward(
136
+ self,
137
+ source_kv: Tuple[Tensor, Tensor],
138
+ target_kv: Tuple[Tensor, Tensor],
139
+ position_ids: Optional[Tensor] = None,
140
+ max_pos: Optional[Tensor] = None,
141
+ ) -> Tuple[Tensor, Tensor]:
142
+ source_key, source_value = source_kv
143
+ target_key, target_value = target_kv
144
+
145
+ B, Hs, N, Ds = source_key.shape
146
+ _, Ht, _, Dt = target_key.shape
147
+
148
+ # Flatten heads
149
+ source_key_flat = source_key.transpose(1, 2).contiguous().view(B, N, Hs * Ds)
150
+ source_value_flat = source_value.transpose(1, 2).contiguous().view(B, N, Hs * Ds)
151
+ target_key_flat = target_key.transpose(1, 2).contiguous().view(B, N, Ht * Dt)
152
+ target_value_flat = target_value.transpose(1, 2).contiguous().view(B, N, Ht * Dt)
153
+
154
+ # 1) Prepare input features based on ablation level
155
+ if self.use_target:
156
+ # Full C2C: concat source and target features
157
+ key_cat = torch.cat([source_key_flat, target_key_flat], dim=-1)
158
+ value_cat = torch.cat([source_value_flat, target_value_flat], dim=-1)
159
+ else:
160
+ # Ablation level 3: only use source features
161
+ key_cat = source_key_flat
162
+ value_cat = source_value_flat
163
+
164
+ # 2) project to hidden dim
165
+ key_hidden = self.key_in(key_cat)
166
+ value_hidden = self.value_in(value_cat)
167
+
168
+ # 3) one-layer common embedding MLP to get intermediate representation (at hidden_dim)
169
+ key_hidden = self.key_mlp1(key_hidden)
170
+ value_hidden = self.value_mlp1(value_hidden)
171
+
172
+ # 4b) intermediate representation -> projected feature path
173
+ key_proj_hidden = self.key_proj_out(self.key_proj_mlp2(key_hidden)) # (B, N, Ht * Dt)
174
+ value_proj_hidden = self.value_proj_out(self.value_proj_mlp2(value_hidden)) # (B, N, Ht * Dt)
175
+ projected_key = key_proj_hidden.view(B, N, Ht, Dt).transpose(1, 2) # (B, Ht, N, Dt)
176
+ projected_value = value_proj_hidden.view(B, N, Ht, Dt).transpose(1, 2) # (B, Ht, N, Dt)
177
+
178
+ # 4a) intermediate representation -> scalar path (if using scalar weights)
179
+ if self.use_scalar_weights:
180
+ key_scalar = self.key_scalar_head(self.key_scalar_mlp2(key_hidden)) # (B, N, Ht)
181
+ value_scalar = self.value_scalar_head(self.value_scalar_mlp2(value_hidden)) # (B, N, Ht)
182
+ key_scalar = key_scalar.permute(0, 2, 1).unsqueeze(-1) # (B, Ht, N, 1)
183
+ value_scalar = value_scalar.permute(0, 2, 1).unsqueeze(-1) # (B, Ht, N, 1)
184
+ # Normalize scalars
185
+ norm_key_scalar = torch.sigmoid(key_scalar)
186
+ norm_value_scalar = torch.sigmoid(value_scalar)
187
+ else:
188
+ # Ablation level 1+: set scalar weights to 1.0
189
+ norm_key_scalar = torch.ones(B, Ht, N, 1, device=projected_key.device, dtype=projected_key.dtype)
190
+ norm_value_scalar = torch.ones(B, Ht, N, 1, device=projected_value.device, dtype=projected_value.dtype)
191
+
192
+ # Key/value gates (if using gates)
193
+ if self.use_gates:
194
+ key_gate_logit = self.key_gate_logit.view(1, 1, 1, 1)
195
+ value_gate_logit = self.value_gate_logit.view(1, 1, 1, 1)
196
+ if self.training and self.use_gumbel:
197
+ u1 = torch.rand(B, Ht, N, 1, device=key_gate_logit.device, dtype=key_gate_logit.dtype)
198
+ u2 = torch.rand(B, Ht, N, 1, device=value_gate_logit.device, dtype=value_gate_logit.dtype)
199
+ g1 = -torch.log(-torch.log(u1 + 1e-20) + 1e-20)
200
+ g2 = -torch.log(-torch.log(u2 + 1e-20) + 1e-20)
201
+ key_gate = torch.sigmoid((key_gate_logit + g1) / self.gate_temperature)
202
+ value_gate = torch.sigmoid((value_gate_logit + g2) / self.gate_temperature)
203
+ else:
204
+ key_gate = (key_gate_logit > 0).float()
205
+ value_gate = (value_gate_logit > 0).float()
206
+ else:
207
+ # Gates disabled: set gates to 1.0 (always open)
208
+ key_gate = torch.ones(B, Ht, N, 1, device=projected_key.device, dtype=projected_key.dtype)
209
+ value_gate = torch.ones(B, Ht, N, 1, device=projected_value.device, dtype=projected_value.dtype)
210
+
211
+ # Compute projected contribution
212
+ projected_key_term = key_gate * norm_key_scalar * projected_key
213
+ projected_value_term = value_gate * norm_value_scalar * projected_value
214
+
215
+ # Compute target contribution (if using target)
216
+ if self.use_target:
217
+ # Full C2C: add target with projected
218
+ output_key = target_key + projected_key_term
219
+ output_value = target_value + projected_value_term
220
+ else:
221
+ # Ablation level 3: only use projected (no target)
222
+ output_key = projected_key_term
223
+ output_value = projected_value_term
224
+
225
+ return output_key, output_value
226
+
227
+ def get_ablation_info(self) -> dict:
228
+ """Return information about current ablation configuration."""
229
+ return {
230
+ 'ablation_level': self.ablation_level,
231
+ 'use_scalar_weights': self.use_scalar_weights,
232
+ 'use_gates': self.use_gates,
233
+ 'use_target': self.use_target,
234
+ 'description': self._get_ablation_description()
235
+ }
236
+
237
+ def _get_ablation_description(self) -> str:
238
+ """Get human-readable description of current ablation level."""
239
+ descriptions = {
240
+ 0: "Full C2C (baseline)",
241
+ 1: "No scalar weights (scalars=1.0)",
242
+ 2: "No gates (gates=1.0) + No scalar weights",
243
+ 3: "No target (source-only) + No gates + No scalar weights",
244
+ 4: "No gates (gates=1.0), keep scalars and target"
245
+ }
246
+ return descriptions.get(self.ablation_level, "Unknown ablation level")
247
+
248
+
249
+ # Convenience functions for creating specific ablation levels
250
+ def create_ablation_projector(
251
+ source_dim: int,
252
+ target_dim: int,
253
+ source_num_heads: int = 1,
254
+ target_num_heads: int = 1,
255
+ ablation_level: int = 0,
256
+ **kwargs
257
+ ) -> AblationProjector:
258
+ """Create an AblationProjector with specified ablation level."""
259
+ return AblationProjector(
260
+ source_dim=source_dim,
261
+ target_dim=target_dim,
262
+ source_num_heads=source_num_heads,
263
+ target_num_heads=target_num_heads,
264
+ ablation_level=ablation_level,
265
+ **kwargs
266
+ )
267
+
268
+
269
+ def create_full_c2c_projector(**kwargs) -> AblationProjector:
270
+ """Create full C2C projector (ablation level 0)."""
271
+ return create_ablation_projector(ablation_level=0, **kwargs)
272
+
273
+
274
+ def create_no_scalar_projector(**kwargs) -> AblationProjector:
275
+ """Create projector without scalar weights (ablation level 1)."""
276
+ return create_ablation_projector(ablation_level=1, **kwargs)
277
+
278
+
279
+ def create_no_gate_projector(**kwargs) -> AblationProjector:
280
+ """Create projector without gates (ablation level 2)."""
281
+ return create_ablation_projector(ablation_level=2, **kwargs)
282
+
283
+
284
+ def create_source_only_projector(**kwargs) -> AblationProjector:
285
+ """Create source-only projector (ablation level 3)."""
286
+ return create_ablation_projector(ablation_level=3, **kwargs)
287
+
288
+
289
+ def create_no_gate_only_projector(**kwargs) -> AblationProjector:
290
+ """Create projector without gates but with scalar weights and target (ablation level 4)."""
291
+ return create_ablation_projector(ablation_level=4, **kwargs)
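
A minimal usage sketch for the constructors above (dimensions, head counts, and shapes are illustrative, not taken from any config in this commit):

import torch
from rosetta.model.ablation_projector import create_ablation_projector

proj = create_ablation_projector(
    source_dim=64, target_dim=128,            # per-head dims (illustrative)
    source_num_heads=8, target_num_heads=4,
    ablation_level=2,                          # scalar weights and gates disabled
)
B, N = 2, 16
source_kv = (torch.randn(B, 8, N, 64), torch.randn(B, 8, N, 64))
target_kv = (torch.randn(B, 4, N, 128), torch.randn(B, 4, N, 128))
out_key, out_value = proj(source_kv, target_kv)    # each (B, 4, N, 128)
print(proj.get_ablation_info()["description"])     # "No gates (gates=1.0) + No scalar weights"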
rosetta/model/aligner.py ADDED
@@ -0,0 +1,639 @@
1
+ """
2
+ Token Aligner for handling different tokenizers between SLM and LLM models.
3
+
4
+ This module provides functionality to align tokens between two different tokenizers,
5
+ handling cases where the same text is tokenized differently.
6
+ """
7
+
8
+ from typing import List, Tuple, Optional, Dict, Literal, Union, Any
9
+ import torch
10
+ from transformers import PreTrainedTokenizerBase
11
+ from enum import Enum
12
+
13
+
14
+ class AlignmentStrategy(Enum):
15
+ """Strategies for handling 1-to-many token alignments"""
16
+ FIRST = "first" # Always take the first LLM token
17
+ LONGEST = "longest" # Take the LLM token with the longest string
18
+
19
+
20
+ class TokenAligner:
21
+ """
22
+ Aligns tokens between SLM (Small Language Model) and LLM (Large Language Model) tokenizers.
23
+
24
+ This class handles the case where the same text sequence is tokenized differently
25
+ by different tokenizers, using the SLM tokenization as the base and finding
26
+ corresponding LLM tokens for each SLM token.
27
+ """
28
+
29
+ def __init__(
30
+ self,
31
+ slm_tokenizer: PreTrainedTokenizerBase,
32
+ llm_tokenizer: PreTrainedTokenizerBase,
33
+ strategy: Union[AlignmentStrategy, str] = AlignmentStrategy.FIRST,
34
+ verbose: bool = False
35
+ ):
36
+ """
37
+ Initialize the TokenAligner.
38
+
39
+ Args:
40
+ slm_tokenizer: The tokenizer for the Small Language Model (base)
41
+ llm_tokenizer: The tokenizer for the Large Language Model
42
+ strategy: Strategy for handling 1-to-many token mappings
43
+ Either AlignmentStrategy enum or string ('first' or 'longest')
44
+ verbose: Whether to print debug information during alignment
45
+ """
46
+ self.slm_tokenizer = slm_tokenizer
47
+ self.llm_tokenizer = llm_tokenizer
48
+
49
+ if self.slm_tokenizer.pad_token is None:
50
+ self.slm_tokenizer.pad_token = self.slm_tokenizer.eos_token
51
+ self.slm_tokenizer.pad_token_id = self.slm_tokenizer.eos_token_id
52
+ if self.llm_tokenizer.pad_token is None:
53
+ self.llm_tokenizer.pad_token = self.llm_tokenizer.eos_token
54
+ self.llm_tokenizer.pad_token_id = self.llm_tokenizer.eos_token_id
55
+
56
+ # Handle string strategy input
57
+ if isinstance(strategy, str):
58
+ strategy = AlignmentStrategy(strategy.lower())
59
+ self.strategy = strategy
60
+ self.verbose = verbose
61
+
62
+ # Cache for token mappings to improve performance
63
+ self._alignment_cache: Dict[Tuple[int, ...], List[int]] = {}
64
+
65
+ def align_tokens(
66
+ self,
67
+ slm_token_ids: Union[List[int], torch.Tensor],
68
+ return_mapping: bool = False
69
+ ) -> Union[List[int], Tuple[List[int], List[Tuple[int, List[int]]]]]:
70
+ """
71
+ Align SLM tokens to LLM tokens.
72
+
73
+ Args:
74
+ slm_token_ids: Token IDs from the SLM tokenizer
75
+ return_mapping: If True, also return the detailed mapping
76
+
77
+ Returns:
78
+ If return_mapping is False: List of aligned LLM token IDs
79
+ If return_mapping is True: Tuple of (aligned_llm_token_ids, mapping_details)
80
+ where mapping_details is a list of (slm_token_id, [candidate_llm_token_ids])
81
+ """
82
+ # Convert to list if tensor
83
+ if isinstance(slm_token_ids, torch.Tensor):
84
+ slm_token_ids = slm_token_ids.tolist()
85
+
86
+ # Check cache
87
+ cache_key = tuple(slm_token_ids)
88
+ if cache_key in self._alignment_cache and not return_mapping:
89
+ return self._alignment_cache[cache_key]
90
+
91
+ aligned_llm_tokens = []
92
+ mapping_details = []
93
+
94
+ for slm_token_id in slm_token_ids:
95
+ # Decode SLM token to string (without special token processing)
96
+ slm_token_str = self.slm_tokenizer.decode(
97
+ [slm_token_id],
98
+ skip_special_tokens=False,
99
+ clean_up_tokenization_spaces=False
100
+ )
101
+
102
+ # Handle special tokens
103
+ if slm_token_id in self.slm_tokenizer.all_special_ids:
104
+ # Try to find corresponding special token in LLM tokenizer
105
+ llm_token_id = self._map_special_token(slm_token_id, slm_token_str)
106
+ aligned_llm_tokens.append(llm_token_id)
107
+ mapping_details.append((slm_token_id, [llm_token_id]))
108
+ continue
109
+
110
+ # Tokenize the string with LLM tokenizer
111
+ llm_token_ids = self.llm_tokenizer.encode(
112
+ slm_token_str,
113
+ add_special_tokens=False,
114
+ return_tensors=None
115
+ )
116
+
117
+ if len(llm_token_ids) == 0:
118
+ # Handle empty tokenization (shouldn't normally happen)
119
+ if self.verbose:
120
+ print(f"Warning: SLM token {slm_token_id} ('{slm_token_str}') "
121
+ f"resulted in empty LLM tokenization")
122
+ # Use unknown token as fallback
123
+ llm_token_id = self.llm_tokenizer.unk_token_id or 0
124
+ aligned_llm_tokens.append(llm_token_id)
125
+ mapping_details.append((slm_token_id, [llm_token_id]))
126
+
127
+ elif len(llm_token_ids) == 1:
128
+ # Perfect 1-to-1 mapping
129
+ aligned_llm_tokens.append(llm_token_ids[0])
130
+ mapping_details.append((slm_token_id, llm_token_ids))
131
+
132
+ else:
133
+ # 1-to-many mapping, apply strategy
134
+ selected_token = self._apply_strategy(llm_token_ids, slm_token_str)
135
+ aligned_llm_tokens.append(selected_token)
136
+ mapping_details.append((slm_token_id, llm_token_ids))
137
+
138
+ if self.verbose:
139
+ selected_str = self.llm_tokenizer.decode(
140
+ [selected_token],
141
+ skip_special_tokens=False,
142
+ clean_up_tokenization_spaces=False
143
+ )
144
+ print(f"SLM token {slm_token_id} ('{slm_token_str}') -> "
145
+ f"LLM tokens {llm_token_ids}, selected {selected_token} ('{selected_str}')")
146
+
147
+ # Cache the result
148
+ self._alignment_cache[cache_key] = aligned_llm_tokens
149
+
150
+ if return_mapping:
151
+ return aligned_llm_tokens, mapping_details
152
+ return aligned_llm_tokens
153
+
154
+ def _map_special_token(self, slm_token_id: int, slm_token_str: str) -> int:
155
+ """
156
+ Map special tokens between tokenizers.
157
+
158
+ Args:
159
+ slm_token_id: The SLM special token ID
160
+ slm_token_str: The string representation of the special token
161
+
162
+ Returns:
163
+ The corresponding LLM token ID
164
+ """
165
+ # Common special token mappings
166
+ special_token_map = {
167
+ self.slm_tokenizer.pad_token_id: self.llm_tokenizer.pad_token_id,
168
+ self.slm_tokenizer.eos_token_id: self.llm_tokenizer.eos_token_id,
169
+ self.slm_tokenizer.bos_token_id: self.llm_tokenizer.bos_token_id,
170
+ self.slm_tokenizer.unk_token_id: self.llm_tokenizer.unk_token_id,
171
+ }
172
+
173
+ # Direct mapping if available
174
+ if slm_token_id in special_token_map and special_token_map[slm_token_id] is not None:
175
+ return special_token_map[slm_token_id]
176
+
177
+ # Try to find by string representation
178
+ try:
179
+ llm_token_id = self.llm_tokenizer.convert_tokens_to_ids(slm_token_str)
180
+ if llm_token_id != self.llm_tokenizer.unk_token_id:
181
+ return llm_token_id
182
+ except Exception:
183
+ pass
184
+
185
+ # Fallback to unknown token
186
+ return self.llm_tokenizer.unk_token_id or 0
187
+
188
+ def _apply_strategy(self, llm_token_ids: List[int], original_str: str) -> int:
189
+ """
190
+ Apply the selected strategy to choose one LLM token from multiple candidates.
191
+
192
+ Args:
193
+ llm_token_ids: List of candidate LLM token IDs
194
+ original_str: The original string from SLM token
195
+
196
+ Returns:
197
+ The selected LLM token ID
198
+ """
199
+ if self.strategy == AlignmentStrategy.FIRST:
200
+ return llm_token_ids[0]
201
+
202
+ elif self.strategy == AlignmentStrategy.LONGEST:
203
+ # Find the token with the longest string representation
204
+ longest_token = llm_token_ids[0]
205
+ longest_length = 0
206
+
207
+ for token_id in llm_token_ids:
208
+ token_str = self.llm_tokenizer.decode(
209
+ [token_id],
210
+ skip_special_tokens=False,
211
+ clean_up_tokenization_spaces=False
212
+ )
213
+ if len(token_str) > longest_length:
214
+ longest_length = len(token_str)
215
+ longest_token = token_id
216
+
217
+ return longest_token
218
+
219
+ else:
220
+ # Default to first token if unknown strategy
221
+ return llm_token_ids[0]
222
+
223
+ def align_sequence(
224
+ self,
225
+ text: str,
226
+ return_details: bool = False
227
+ ) -> Union[Tuple[List[int], List[int]], Dict[str, Any]]:
228
+ """
229
+ Tokenize text with both tokenizers and return aligned sequences.
230
+
231
+ Args:
232
+ text: The input text to tokenize and align
233
+ return_details: If True, return detailed alignment information
234
+
235
+ Returns:
236
+ If return_details is False: Tuple of (slm_token_ids, aligned_llm_token_ids)
237
+ If return_details is True: Dictionary with detailed alignment information
238
+ """
239
+ # Tokenize with SLM
240
+ slm_tokens = self.slm_tokenizer.encode(
241
+ text,
242
+ add_special_tokens=True,
243
+ return_tensors=None
244
+ )
245
+
246
+ # Get aligned LLM tokens
247
+ if return_details:
248
+ aligned_llm_tokens, mapping = self.align_tokens(slm_tokens, return_mapping=True)
249
+
250
+ # Decode tokens for inspection
251
+ slm_decoded = [
252
+ self.slm_tokenizer.decode([tid], skip_special_tokens=False, clean_up_tokenization_spaces=False)
253
+ for tid in slm_tokens
254
+ ]
255
+ llm_decoded = [
256
+ self.llm_tokenizer.decode([tid], skip_special_tokens=False, clean_up_tokenization_spaces=False)
257
+ for tid in aligned_llm_tokens
258
+ ]
259
+
260
+ # Original LLM tokenization for comparison
261
+ original_llm_tokens = self.llm_tokenizer.encode(
262
+ text,
263
+ add_special_tokens=True,
264
+ return_tensors=None
265
+ )
266
+
267
+ # One-to-one mapping statistics
268
+ num_tokens = len(slm_tokens)
269
+ one_to_one_count = sum(1 for _slm_id, candidates in mapping if len(candidates) == 1)
270
+ one_to_one_rate = (one_to_one_count / num_tokens) if num_tokens > 0 else 0.0
271
+
272
+ return {
273
+ 'text': text,
274
+ 'slm_token_ids': slm_tokens,
275
+ 'slm_decoded': slm_decoded,
276
+ 'aligned_llm_token_ids': aligned_llm_tokens,
277
+ 'aligned_llm_decoded': llm_decoded,
278
+ 'original_llm_token_ids': original_llm_tokens,
279
+ 'mapping': mapping,
280
+ 'strategy': self.strategy.value,
281
+ 'num_tokens': num_tokens,
282
+ 'one_to_one_count': one_to_one_count,
283
+ 'one_to_one_rate': one_to_one_rate
284
+ }
285
+ else:
286
+ aligned_llm_tokens = self.align_tokens(slm_tokens)
287
+ return slm_tokens, aligned_llm_tokens
288
+
289
+ def visualize_alignment(self, text: str):
290
+ """
291
+ Print a visual representation of the token alignment.
292
+
293
+ Args:
294
+ text: The text to analyze
295
+ """
296
+ details = self.align_sequence(text, return_details=True)
297
+
298
+ print("=" * 80)
299
+ print(f"Text: {text}")
300
+ print(f"Strategy: {details['strategy']}")
301
+ print("=" * 80)
302
+ print(f"SLM tokens ({len(details['slm_token_ids'])}): {details['slm_token_ids']}")
303
+ print(f"Aligned LLM tokens ({len(details['aligned_llm_token_ids'])}): {details['aligned_llm_token_ids']}")
304
+ print(f"Original LLM tokens ({len(details['original_llm_token_ids'])}): {details['original_llm_token_ids']}")
305
+ print("-" * 80)
306
+ print("Token-by-token alignment:")
307
+
308
+ for i, (slm_id, llm_id) in enumerate(zip(details['slm_token_ids'], details['aligned_llm_token_ids'])):
309
+ slm_str = details['slm_decoded'][i]
310
+ llm_str = details['aligned_llm_decoded'][i]
311
+ mapping_info = details['mapping'][i]
312
+
313
+ if len(mapping_info[1]) > 1:
314
+ candidates_str = ', '.join([
315
+ f"{tid}:'{self.llm_tokenizer.decode([tid], skip_special_tokens=False, clean_up_tokenization_spaces=False)}'"
316
+ for tid in mapping_info[1]
317
+ ])
318
+ print(f" [{i:3d}] SLM {slm_id:6d} ('{slm_str}') -> "
319
+ f"LLM {llm_id:6d} ('{llm_str}') "
320
+ f"[candidates: {candidates_str}]")
321
+ else:
322
+ print(f" [{i:3d}] SLM {slm_id:6d} ('{slm_str}') -> "
323
+ f"LLM {llm_id:6d} ('{llm_str}')")
324
+ print("=" * 80)
325
+
326
+ def clear_cache(self):
327
+ """Clear the alignment cache."""
328
+ self._alignment_cache.clear()
329
+
330
+ # ========================
331
+ # Chat messages alignment
332
+ # ========================
333
+ def _apply_chat_template_to_ids(
334
+ self,
335
+ tokenizer: PreTrainedTokenizerBase,
336
+ messages: List[Dict[str, str]],
337
+ add_generation_prompt: bool,
338
+ enable_thinking: bool,
339
+ remove_last_surfix: bool
340
+ ) -> Tuple[str, List[int], Optional[List[Tuple[int, int]]]]:
341
+ """
342
+ Apply chat template (no tokenization) then tokenize to ids with optional offsets.
343
+ If remove_last_surfix is True, the last message must be an assistant turn: the template is applied to the preceding messages with a generation prompt and the assistant content is appended raw, with no trailing template suffix.
344
+ Returns (templated_text, input_ids, offsets) where offsets may be None.
345
+ """
346
+ if remove_last_surfix:
347
+ assert messages[-1]["role"] == "assistant", "Last message must be an assistant message"
348
+ templated_text = tokenizer.apply_chat_template(
349
+ messages[:-1],
350
+ tokenize=False,
351
+ add_generation_prompt=True,
352
+ enable_thinking=enable_thinking
353
+ )
354
+ templated_text += messages[-1]["content"]
355
+ else:
356
+ templated_text = tokenizer.apply_chat_template(
357
+ messages,
358
+ tokenize=False,
359
+ add_generation_prompt=add_generation_prompt,
360
+ enable_thinking=enable_thinking
361
+ )
362
+ encoded = tokenizer(
363
+ templated_text,
364
+ add_special_tokens=False,
365
+ return_offsets_mapping=True
366
+ )
367
+ input_ids: List[int] = encoded["input_ids"]
368
+ offsets = encoded.get("offset_mapping")
369
+ return templated_text, input_ids, offsets
370
+
371
+ @staticmethod
372
+ def _first_non_empty_content(messages: List[Dict[str, str]]) -> Optional[str]:
373
+ for m in messages:
374
+ content = m.get("content")
375
+ if isinstance(content, str) and len(content.strip()) > 0:
376
+ return content
377
+ return None
378
+
379
+ def _find_boundary_token_index(
380
+ self,
381
+ tokenizer: PreTrainedTokenizerBase,
382
+ templated_text: str,
383
+ offsets: Optional[List[Tuple[int, int]]],
384
+ content_text: Optional[str]
385
+ ) -> int:
386
+ """
387
+ Find token index where the first non-empty message content starts.
388
+ Falls back to 0 if not found.
389
+ """
390
+ if not content_text:
391
+ return 0
392
+ char_idx = templated_text.find(content_text)
393
+ if char_idx < 0:
394
+ # Try a shorter probe to improve chances
395
+ probe = content_text[: min(32, len(content_text))]
396
+ if len(probe) > 0:
397
+ char_idx = templated_text.find(probe)
398
+ if char_idx < 0:
399
+ return 0
400
+
401
+ if offsets:
402
+ for idx, (start, _end) in enumerate(offsets):
403
+ if start >= char_idx:
404
+ return idx
405
+ return len(offsets)
406
+
407
+ # Fallback without offsets: tokenize prefix and count tokens
408
+ prefix = templated_text[:char_idx]
409
+ prefix_ids = tokenizer(prefix, add_special_tokens=False)["input_ids"]
410
+ return len(prefix_ids)
411
+
412
+ @staticmethod
413
+ def _compute_content_spans(templated_text: str, messages: List[Dict[str, str]]) -> List[Tuple[int, int]]:
414
+ """
415
+ Compute character spans in templated_text that correspond to message contents.
416
+ Searches sequentially to reduce ambiguity when contents repeat.
417
+ Enhanced matching: ensures the found content is followed by '<' (special token start)
418
+ to avoid matching content inside special tokens like <begin_of_text>.
419
+ """
420
+ spans: List[Tuple[int, int]] = []
421
+ search_from = 0
422
+ for m in messages:
423
+ content = m.get("content")
424
+ if not isinstance(content, str) or len(content) == 0:
425
+ continue
426
+
427
+ # Find all possible matches starting from search_from
428
+ idx = search_from
429
+ found_valid_match = False
430
+
431
+ while idx < len(templated_text):
432
+ idx = templated_text.find(content, idx)
433
+ if idx < 0:
434
+ break
435
+
436
+ # Check if this match is valid (followed by '<' indicating a special token)
437
+ end_pos = idx + len(content)
438
+ if end_pos < len(templated_text) and templated_text[end_pos] == '<':
439
+ # Valid match: content is followed by a special token
440
+ spans.append((idx, end_pos))
441
+ search_from = end_pos
442
+ found_valid_match = True
443
+ break
444
+ else:
445
+ # Check if this is the end of the text (also valid for last message)
446
+ if end_pos == len(templated_text):
447
+ spans.append((idx, end_pos))
448
+ search_from = end_pos
449
+ found_valid_match = True
450
+ break
451
+
452
+ # Invalid match, try next occurrence
453
+ idx += 1
454
+
455
+ # Fallback: if no valid match found with '<' requirement, use the old method
456
+ # but only as a last resort and with additional validation
457
+ if not found_valid_match:
458
+ idx = templated_text.find(content, search_from)
459
+ if idx < 0:
460
+ # Try searching from start as last resort
461
+ idx = templated_text.find(content)
462
+
463
+ if idx >= 0:
464
+ end_pos = idx + len(content)
465
+ # Additional check: avoid matching inside obvious special tokens
466
+ # Check if we're inside a special token (preceded by '<' and not followed by '>')
467
+ start_context = templated_text[max(0, idx-10):idx]
468
+ end_context = templated_text[end_pos:min(len(templated_text), end_pos+10)]
469
+
470
+ # Skip if we're clearly inside a special token
471
+ if ('<' in start_context and '>' not in start_context and
472
+ 'begin_of_text' in templated_text[max(0, idx-20):idx+20]):
473
+ # This looks like we're matching inside <begin_of_text> or similar
474
+ continue
475
+
476
+ spans.append((idx, end_pos))
477
+ search_from = end_pos
478
+
479
+ return spans
480
+
481
+ @staticmethod
482
+ def _build_token_mask_from_spans(
483
+ offsets: Optional[List[Tuple[int, int]]],
484
+ num_tokens: int,
485
+ spans: List[Tuple[int, int]]
486
+ ) -> List[bool]:
487
+ """
488
+ Build a boolean mask for tokens whose offset range overlaps any span.
489
+ If offsets are missing, default to all False.
490
+ """
491
+ if not offsets or len(offsets) != num_tokens:
492
+ return [False] * num_tokens
493
+ mask: List[bool] = []
494
+ for (start, end) in offsets:
495
+ if end <= start:
496
+ mask.append(False)
497
+ continue
498
+ is_msg = False
499
+ for s, e in spans:
500
+ # overlap check
501
+ if start < e and end > s:
502
+ is_msg = True
503
+ break
504
+ mask.append(is_msg)
505
+ return mask
506
+
507
+ @staticmethod
508
+ def _spans_to_token_ranges(
509
+ offsets: List[Tuple[int, int]],
510
+ spans: List[Tuple[int, int]]
511
+ ) -> List[Tuple[int, int]]:
512
+ """
513
+ Convert character spans to token index ranges using offsets.
514
+ start token = first token with end > span_start
515
+ end token = first token with start >= span_end
516
+ """
517
+ ranges: List[Tuple[int, int]] = []
518
+ n = len(offsets)
519
+ for s, e in spans:
520
+ # find start index
521
+ start_idx = 0
522
+ while start_idx < n and offsets[start_idx][1] <= s:
523
+ start_idx += 1
524
+ # find end index
525
+ end_idx = start_idx
526
+ while end_idx < n and offsets[end_idx][0] < e:
527
+ end_idx += 1
528
+ ranges.append((start_idx, end_idx))
529
+ return ranges
530
+
531
+ def align_chat_messages(
532
+ self,
533
+ messages: List[Dict[str, str]],
534
+ add_generation_prompt: bool = True,
535
+ enable_thinking: bool = False,
536
+ return_details: bool = False,
537
+ remove_last_surfix: bool = False
538
+ ) -> Dict[str, Any]:
539
+ """
540
+ Align chat-templated sequences by sections (template/message/template...):
541
+ - Preserve all template tokens (pad the shorter template section)
542
+ - For each message section, map SLM tokens to LLM tokens 1:1 via strategy
543
+ - If remove_last_surfix is True, the final assistant message is appended without the trailing template suffix (see _apply_chat_template_to_ids)
544
+ Returns essentials: slm_ids_padded, llm_ids_padded, message_mask (shared),
545
+ slm_padding_mask, llm_padding_mask (True where token is padding inserted).
546
+ When return_details=True, also returns 'sections' with aligned ranges.
547
+ """
548
+ assert not (add_generation_prompt and remove_last_surfix), "add_generation_prompt and remove_last_surfix cannot be True at the same time"
549
+
550
+ # Build templated sequences with offsets
551
+ slm_text, slm_ids, slm_offsets = self._apply_chat_template_to_ids(
552
+ self.slm_tokenizer, messages, add_generation_prompt, enable_thinking, remove_last_surfix
553
+ )
554
+ llm_text, llm_ids, llm_offsets = self._apply_chat_template_to_ids(
555
+ self.llm_tokenizer, messages, add_generation_prompt, enable_thinking, remove_last_surfix
556
+ )
557
+
558
+ # Required pad tokens
559
+ assert self.slm_tokenizer.pad_token_id is not None, "SLM pad_token_id required"
560
+ assert self.llm_tokenizer.pad_token_id is not None, "LLM pad_token_id required"
561
+ slm_pad_id = self.slm_tokenizer.pad_token_id
562
+ llm_pad_id = self.llm_tokenizer.pad_token_id
563
+
564
+ # Content spans (char) and token ranges
565
+ content_spans_slm = self._compute_content_spans(slm_text, messages)
566
+ content_spans_llm = self._compute_content_spans(llm_text, messages)
567
+ assert slm_offsets is not None and llm_offsets is not None, "offset_mapping required"
568
+ slm_msg_ranges = self._spans_to_token_ranges(slm_offsets, content_spans_slm)
569
+ llm_msg_ranges = self._spans_to_token_ranges(llm_offsets, content_spans_llm)
570
+ # Build section ranges (template/message alternating)
571
+ def build_sections(total_len: int, msg_ranges: List[Tuple[int,int]]):
572
+ sections: List[Tuple[str,int,int]] = []
573
+ prev = 0
574
+ for (s, e) in msg_ranges:
575
+ if prev < s:
576
+ sections.append(("template", prev, s))
577
+ sections.append(("message", s, e))
578
+ prev = e
579
+ if prev < total_len:
580
+ sections.append(("template", prev, total_len))
581
+ return sections
582
+ slm_sections = build_sections(len(slm_ids), slm_msg_ranges)
583
+ llm_sections = build_sections(len(llm_ids), llm_msg_ranges)
584
+ assert len(slm_sections) == len(llm_sections), "Section count mismatch"
585
+
586
+ slm_out: List[int] = []
587
+ llm_out: List[int] = []
588
+ mask_out: List[bool] = []
589
+ slm_pad_mask_out: List[bool] = []
590
+ llm_pad_mask_out: List[bool] = []
591
+ detailed_sections: List[Dict[str, Union[str, Tuple[int,int]]]] = []
592
+
593
+ for (stype_s, s_s, e_s), (stype_l, s_l, e_l) in zip(slm_sections, llm_sections):
594
+ assert stype_s == stype_l, "Section type mismatch"
595
+ slm_start_out = len(slm_out)
596
+ llm_start_out = len(llm_out)
597
+ if stype_s == "template":
598
+ slm_seg_len = e_s - s_s
599
+ llm_seg_len = e_l - s_l
600
+ target_len = slm_seg_len if slm_seg_len >= llm_seg_len else llm_seg_len
601
+ slm_pad_needed = target_len - slm_seg_len
602
+ llm_pad_needed = target_len - llm_seg_len
603
+ slm_seg = slm_ids[s_s:e_s] + [slm_pad_id] * slm_pad_needed
604
+ llm_seg = llm_ids[s_l:e_l] + [llm_pad_id] * llm_pad_needed
605
+ slm_out.extend(slm_seg)
606
+ llm_out.extend(llm_seg)
607
+ mask_out.extend([False] * target_len)
608
+ slm_pad_mask_out.extend([False] * slm_seg_len + [True] * slm_pad_needed)
609
+ llm_pad_mask_out.extend([False] * llm_seg_len + [True] * llm_pad_needed)
610
+ else: # message
611
+ slm_msg = slm_ids[s_s:e_s]
612
+ llm_msg = self.align_tokens(slm_msg)
613
+ assert len(llm_msg) == len(slm_msg)
614
+ slm_out.extend(slm_msg)
615
+ llm_out.extend(llm_msg)
616
+ mask_out.extend([True] * len(slm_msg))
617
+ # no padding in message sections
618
+ slm_pad_mask_out.extend([False] * len(slm_msg))
619
+ llm_pad_mask_out.extend([False] * len(slm_msg))
620
+ slm_end_out = len(slm_out)
621
+ llm_end_out = len(llm_out)
622
+ detailed_sections.append({
623
+ 'type': stype_s,
624
+ 'slm_range': (slm_start_out, slm_end_out),
625
+ 'llm_range': (llm_start_out, llm_end_out)
626
+ })
627
+
628
+ result_min = {
629
+ 'slm_ids_padded': slm_out,
630
+ 'llm_ids_padded': llm_out,
631
+ 'message_mask': mask_out,
632
+ 'slm_padding_mask': slm_pad_mask_out,
633
+ 'llm_padding_mask': llm_pad_mask_out
634
+ }
635
+ if return_details:
636
+ result_min['sections'] = detailed_sections
637
+ result_min['slm_text'] = slm_text
638
+ result_min['llm_text'] = llm_text
639
+ return result_min
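
A minimal sketch of the aligner in use (tokenizer names are illustrative placeholders):

from transformers import AutoTokenizer
from rosetta.model.aligner import TokenAligner

slm_tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")    # placeholder SLM tokenizer
llm_tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")      # placeholder LLM tokenizer
aligner = TokenAligner(slm_tok, llm_tok, strategy="longest")

slm_ids, llm_ids = aligner.align_sequence("The quick brown fox jumps over the lazy dog.")
assert len(slm_ids) == len(llm_ids)    # one aligned LLM token per SLM token

aligned = aligner.align_chat_messages(
    [{"role": "user", "content": "What is 2 + 2?"}],
    add_generation_prompt=True,
)
print(sum(aligned["message_mask"]), "message tokens out of", len(aligned["slm_ids_padded"]))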
rosetta/model/oracle.py ADDED
@@ -0,0 +1,562 @@
1
+ from typing import List, Optional, Union
2
+ import torch
3
+ from torch import nn
4
+ from transformers.cache_utils import Cache, DynamicCache
5
+ from transformers.modeling_utils import PreTrainedModel
6
+ from transformers.modeling_outputs import CausalLMOutputWithPast
7
+ import json
8
+
9
+ from rosetta.model.projector import Projector
10
+ from rosetta.model.sampling import sample_token
11
+ from transformers.utils import ModelOutput
12
+ try:
13
+ from transformers.generation.utils import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput
14
+ except Exception:
15
+ GreedySearchDecoderOnlyOutput = None
16
+ SampleDecoderOnlyOutput = None
17
+ from rosetta.model.wrapper import RosettaModel
18
+
19
+ class OracleRosettaModel(nn.Module):
20
+ """
21
+ Drop-in replacement for standard transformers causal LM classes such as Qwen3ForCausalLM.
22
+ """
23
+ def __init__(self, model_list: List[PreTrainedModel], base_model_idx = 0, projector_list: List[Projector] = [], aggregator_list: List[nn.Module] = []):
24
+ super().__init__()
25
+ # model list: a list of models; model 0 is the base model by default
26
+ # projector list: a list of projectors
27
+ # standard init, with the additional model-list parameter
28
+ # kv-cache dict: key (source_model_idx, target_model_idx), value (Cache); conversion is assumed to happen only at prefill, with a single source model type
29
+ # projector dict: {target_model_idx: {source_model_idx: {target_layer_idx: [(source_layer_idx, projector_idx), ...]}}}
30
+
31
+ self.base_model_idx = base_model_idx
32
+ self.model_list = nn.ModuleList(model_list)
33
+
34
+ device = model_list[base_model_idx].device
35
+ dtype = model_list[base_model_idx].dtype
36
+ self.projector_list = nn.ModuleList(projector_list).to(device=device, dtype=dtype)
37
+ self.aggregator_list = nn.ModuleList(aggregator_list).to(device=device, dtype=dtype)
38
+
39
+ self.projector_dict = {}
40
+ self.aggregator_dict = {}
41
+ self.kv_cache_dict = {}
42
+ self._generation_hook_handlers = []
43
+
44
+ @property
45
+ def device(self):
46
+ return self.model_list[self.base_model_idx].device
47
+
48
+ def to(self, device):
49
+ """
50
+ Move the OracleRosettaModel and all underlying models, projectors, and aggregators to the specified device.
51
+ """
52
+ super().to(device)
53
+ for model in self.model_list:
54
+ model.to(device)
55
+ for projector in self.projector_list:
56
+ projector.to(device)
57
+ for aggregator in self.aggregator_list:
58
+ aggregator.to(device)
59
+ return self
60
+
61
+ # set projector
62
+ def set_projector_config(self,
63
+ source_model_idx: int,
64
+ source_model_layer_idx: int,
65
+ target_model_idx: int,
66
+ target_model_layer_idx: int,
67
+ projector_idx: int):
68
+ """
69
+ Set the projector configuration
70
+ Args:
71
+ source_model_idx: int, the index of the source model
72
+ source_model_layer_idx: int, the index of the source model layer
73
+ target_model_idx: int, the index of the target model
74
+ target_model_layer_idx: int, the index of the target model layer
75
+ projector_idx: int, the index of the projector
76
+
77
+ The projector dict structure supports multiple projectors per target layer.
78
+ Structure:
79
+ {
80
+ target_model_idx: {
81
+ source_model_idx: {
82
+ target_model_layer_idx: [(source_model_layer_idx, projector_idx), ...]
83
+ }
84
+ }
85
+ }
86
+ Repeated calls for the same (target, source, target_layer) append additional pairs.
87
+ """
88
+
89
+ if target_model_idx not in self.projector_dict.keys():
90
+ self.projector_dict[target_model_idx] = {}
91
+ if source_model_idx not in self.projector_dict[target_model_idx].keys():
92
+ self.projector_dict[target_model_idx][source_model_idx] = {}
93
+ # Accumulate list of (source_layer, projector_idx) for this target layer
94
+ layer_entry = self.projector_dict[target_model_idx][source_model_idx].get(target_model_layer_idx)
95
+ if layer_entry is None:
96
+ self.projector_dict[target_model_idx][source_model_idx][target_model_layer_idx] = [(source_model_layer_idx, projector_idx)]
97
+ else:
98
+ layer_entry.append((source_model_layer_idx, projector_idx))
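
A small sketch of how the mapping above might be populated (an instance named oracle and all indices are illustrative):

oracle.set_projector_config(source_model_idx=1, source_model_layer_idx=10,
                            target_model_idx=0, target_model_layer_idx=12, projector_idx=0)
oracle.set_projector_config(source_model_idx=1, source_model_layer_idx=11,
                            target_model_idx=0, target_model_layer_idx=12, projector_idx=1)
# oracle.projector_dict is now {0: {1: {12: [(10, 0), (11, 1)]}}}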
99
+
100
+
101
+ def load_projector(self, projector_list):
102
+ self.projector_list: List[Projector] = projector_list
103
+
104
+ def load_aggregator(self, aggregator_list):
105
+ self.aggregator_list: List[nn.Module] = aggregator_list
106
+
107
+
108
+ def get_projector(self,
109
+ source_model_idx,
110
+ source_model_layer_idx,
111
+ target_model_idx,
112
+ target_model_layer_idx):
113
+ pair_list = self.projector_dict[target_model_idx][source_model_idx][target_model_layer_idx]
114
+ if len(pair_list) == 0:
115
+ raise ValueError("No projector configured for the given target layer")
116
+ # Prefer exact source layer match
117
+ for src_layer, projector_id in pair_list:
118
+ if src_layer == source_model_layer_idx:
119
+ return self.projector_list[projector_id]
120
+ # Fallback: return the first projector
121
+ return self.projector_list[pair_list[0][1]]
122
+
123
+ def set_aggregator_idx(self,
124
+ source_model_idx: int,
125
+ target_model_idx: int,
126
+ target_model_layer_idx: int,
127
+ aggregator_idx: int):
128
+ if target_model_idx not in self.aggregator_dict:
129
+ self.aggregator_dict[target_model_idx] = {}
130
+ if source_model_idx not in self.aggregator_dict[target_model_idx]:
131
+ self.aggregator_dict[target_model_idx][source_model_idx] = {}
132
+ self.aggregator_dict[target_model_idx][source_model_idx][target_model_layer_idx] = aggregator_idx
133
+
134
+
135
+ @staticmethod
136
+ def load_json(file_name):
137
+ with open(file_name, "r") as f:
138
+ result = json.load(f)
139
+ return result
140
+
141
+ @staticmethod
142
+ def _convert_dict_keys_to_ints(obj):
143
+ """
144
+ Recursively convert dictionary keys that look like integers back to int.
145
+ This reverses json.dump's coercion of dict keys to strings.
146
+ """
147
+ if isinstance(obj, dict):
148
+ new_obj = {}
149
+ for key, value in obj.items():
150
+ if isinstance(key, str) and key.lstrip('-').isdigit():
151
+ new_key = int(key)
152
+ else:
153
+ new_key = key
154
+ new_obj[new_key] = RosettaModel._convert_dict_keys_to_ints(value)
155
+ return new_obj
156
+ if isinstance(obj, list):
157
+ return [RosettaModel._convert_dict_keys_to_ints(v) for v in obj]
158
+ return obj
159
+
160
+
161
+ def save_projector_config(self, file_name):
162
+ with open(file_name, "w") as f:
163
+ json.dump(self.projector_dict, f)
164
+
165
+
166
+ def load_projector_config(self, config_path):
167
+ if config_path.endswith(".json"):
168
+ loaded = RosettaModel.load_json(config_path)
169
+ self.projector_dict = RosettaModel._convert_dict_keys_to_ints(loaded)
170
+
171
+ def save_aggregator_config(self, file_name):
172
+ with open(file_name, "w") as f:
173
+ json.dump(self.aggregator_dict, f)
174
+
175
+ def load_aggregator_config(self, config_path):
176
+ if config_path.endswith(".json"):
177
+ loaded = RosettaModel.load_json(config_path)
178
+ self.aggregator_dict = RosettaModel._convert_dict_keys_to_ints(loaded)
179
+
180
+
181
+ def set_kv_cache_dict(self, source_model_idx, target_model_idx, cache):
182
+ if target_model_idx not in self.kv_cache_dict.keys():
183
+ self.kv_cache_dict[target_model_idx] = {}
184
+ if cache is None:
185
+ # Initialize with a DynamicCache instead of RosettaCache for now
186
+ self.kv_cache_dict[target_model_idx][source_model_idx] = DynamicCache() # noqa, maybe we should use RosettaCache here
187
+ else:
188
+ self.kv_cache_dict[target_model_idx][source_model_idx] = cache
189
+
190
+ def forward(
191
+ self,
192
+ kv_cache_index: Optional[List] = None,
193
+ input_ids: Optional[Union[torch.LongTensor, List[torch.LongTensor]]] = None,
194
+ attention_mask: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
195
+ position_ids: Optional[torch.LongTensor] = None,
196
+ past_key_values: Optional[Cache] = None,
197
+ inputs_embeds: Optional[torch.FloatTensor] = None,
198
+ labels: Optional[torch.LongTensor] = None,
199
+ use_cache: Optional[bool] = None,
200
+ output_attentions: Optional[bool] = None,
201
+ output_hidden_states: Optional[bool] = None,
202
+ cache_position: Optional[torch.LongTensor] = None,
203
+ logits_to_keep: Union[int, torch.Tensor] = 0,
204
+ # **kwargs: Unpack[KwargsForCausalLM],
205
+ identifier = -1,
206
+ subject = None,
207
+ *args,
208
+ **kwargs,
209
+ ) -> CausalLMOutputWithPast:
210
+ """
211
+ Forward pass
212
+ kv_cache_index is a list with one tensor per prompt section, each of shape (B, section_len, 2), holding the (source model idx, target model idx) for that section's KV cache
213
+
214
+ If input_ids is a LongTensor, the same input ids are used for every model.
215
+ If input_ids is a list, each model receives its own input ids (one entry per model).
216
+
217
+ A section marked (-1, 0) is left untouched (no Rosetta projection).
218
+ """
219
+
220
+ # Reset the cross-model KV-cache bookkeeping for this forward pass
221
+ self.kv_cache_dict = dict()
222
+
223
+ # Handle different input formats: if input_ids is a list, use per-model inputs
224
+ if isinstance(input_ids, list):
225
+ # Use list format: different input_ids and attention_mask for each model
226
+ base_input_ids = input_ids[self.base_model_idx] if input_ids is not None else None
227
+ base_attention_mask = attention_mask[self.base_model_idx] if attention_mask is not None else None
228
+ _, seqlen = base_input_ids.size() if base_input_ids is not None else (0, 0)
229
+ else:
230
+ # Use tensor format: same input_ids and attention_mask for all models (backward compatibility)
231
+ base_input_ids = input_ids
232
+ base_attention_mask = attention_mask
233
+ _, seqlen = input_ids.size() if input_ids is not None else (0, 0)
234
+
235
+ num_sections = len(kv_cache_index) if kv_cache_index is not None else 1
236
+
237
+ section_lengths = [kv_cache_index[i].shape[1] for i in range(num_sections)] if kv_cache_index is not None else [seqlen]
238
+ section_starts = [0]
239
+ for l in section_lengths:
240
+ section_starts.append(section_starts[-1] + l)
241
+
242
+ curr_base_kv_cache = past_key_values
243
+
244
+ if seqlen > 1:
245
+ for i in range(num_sections):
246
+ start = section_starts[i]
247
+ end = section_starts[i + 1]
248
+ prefill_input_ids = base_input_ids[:, start:end] if base_input_ids is not None else None
249
+ prefill_attention_mask = base_attention_mask[:, :end] if base_attention_mask is not None else None
250
+ prefill_position_ids = position_ids[:, start:end] if position_ids is not None else None
251
+ prefill_labels = labels[:, start:end] if labels is not None else None
252
+
253
+ # calculate target model kvcache
254
+ output = self.model_list[self.base_model_idx].forward(
255
+ input_ids=prefill_input_ids,
256
+ attention_mask=prefill_attention_mask,
257
+ position_ids=prefill_position_ids,
258
+ past_key_values=curr_base_kv_cache,
259
+ labels=prefill_labels,
260
+ use_cache=use_cache,
261
+ output_attentions=output_attentions,
262
+ output_hidden_states=output_hidden_states,
263
+ *args,
264
+ **kwargs
265
+ )
266
+
267
+ if self.base_model_idx not in self.kv_cache_dict:
268
+ self.kv_cache_dict[self.base_model_idx] = {}
269
+ if self.base_model_idx not in self.kv_cache_dict[self.base_model_idx]:
270
+ self.kv_cache_dict[self.base_model_idx][self.base_model_idx] = None
271
+ self.kv_cache_dict[self.base_model_idx][self.base_model_idx] = output.past_key_values
272
+
273
+ curr_base_kv_cache: DynamicCache = output.past_key_values
274
+
275
+ # if i != num_sections - 1:
276
+ for source_model_idx in range(1, len(self.model_list)):
277
+ if self.base_model_idx not in self.kv_cache_dict:
278
+ self.kv_cache_dict[self.base_model_idx] = {}
279
+ if source_model_idx not in self.kv_cache_dict[self.base_model_idx]:
280
+ self.kv_cache_dict[self.base_model_idx][source_model_idx] = None
281
+
282
+ # Get model-specific input_ids and attention_mask
283
+ if isinstance(input_ids, list):
284
+ source_input_ids = input_ids[source_model_idx]
285
+ source_attention_mask = attention_mask[source_model_idx] if attention_mask is not None else None
286
+ source_prefill_input_ids = source_input_ids[:, start:end] if source_input_ids is not None else None
287
+ source_prefill_attention_mask = source_attention_mask[:, :end] if source_attention_mask is not None else None
288
+ else:
289
+ # Backward compatibility: use same input for all models
290
+ source_prefill_input_ids = prefill_input_ids
291
+ source_prefill_attention_mask = prefill_attention_mask
292
+
293
+ curr_source_kv_cache = self.model_list[source_model_idx].forward(
294
+ input_ids=source_prefill_input_ids,
295
+ attention_mask=source_prefill_attention_mask,
296
+ position_ids=prefill_position_ids,
297
+ past_key_values=self.kv_cache_dict[self.base_model_idx][source_model_idx],
298
+ use_cache=use_cache,
299
+ output_attentions=output_attentions,
300
+ output_hidden_states=output_hidden_states,
301
+ *args,
302
+ **kwargs
303
+ ).past_key_values
304
+ self.kv_cache_dict[self.base_model_idx][source_model_idx] = curr_source_kv_cache
305
+
306
+ # apply the learned projections from this section's source model onto the base KV cache
307
+ if self.base_model_idx in self.projector_dict:
308
+ source_model_idx = kv_cache_index[i][0][0][0].item() # Get the source model index from the kv_cache_index
309
+ if source_model_idx != -1:
310
+ for target_layer_idx, entry in self.projector_dict[self.base_model_idx][source_model_idx].items():
311
+ base_key_cache, base_value_cache = curr_base_kv_cache[target_layer_idx]
312
+ new_base_key_cache = base_key_cache[:, :, start:end, :]
313
+ new_base_value_cache = base_value_cache[:, :, start:end, :]
314
+ new_base_kv_cache = (new_base_key_cache, new_base_value_cache)
315
+
316
+ pair_list = entry
317
+
318
+ projected_kv_list = []
319
+ source_kv_list = []
320
+ for source_model_layer_idx, projector_idx in pair_list:
321
+ source_key_cache, source_value_cache = self.kv_cache_dict[self.base_model_idx][source_model_idx][source_model_layer_idx]
322
+ new_source_key_cache = source_key_cache[:, :, start:end, :]
323
+ new_source_value_cache = source_value_cache[:, :, start:end, :]
324
+ new_source_kv_cache = (new_source_key_cache, new_source_value_cache)
325
+ projected_key, projected_value = self.projector_list[projector_idx].forward(
326
+ new_source_kv_cache, # tuple of (key, value), each of shape (B, N, H, D)
327
+ new_base_kv_cache
328
+ )
329
+ projected_kv_list.append((projected_key, projected_value))
330
+
331
+ # --------------
332
+ # debug: save target and projected KV caches for oracle analysis (assumes `subject` and `identifier` are defined in the enclosing scope and the oracle/ directories exist)
333
+ torch.save((projected_key, projected_value), f"oracle/projected_kv/{subject}_{identifier}_{i}.pt")
334
+ torch.save(new_base_kv_cache, f"oracle/target_kv/{subject}_{identifier}_{i}.pt")
335
+ # --------------
336
+ source_kv_list.append(new_source_kv_cache)
337
+
338
+ # Aggregate (fallback to first projector if no aggregator is available)
339
+ use_aggregator = (
340
+ len(projected_kv_list) > 1 and
341
+ len(self.aggregator_list) > 0 and
342
+ self.base_model_idx in self.aggregator_dict and
343
+ source_model_idx in self.aggregator_dict[self.base_model_idx] and
344
+ target_layer_idx in self.aggregator_dict[self.base_model_idx][source_model_idx]
345
+ )
346
+
347
+ if use_aggregator:
348
+ aggregator_idx = self.aggregator_dict[self.base_model_idx][source_model_idx][target_layer_idx]
349
+ agg_key, agg_value = self.aggregator_list[aggregator_idx].forward(
350
+ source_kv_list,
351
+ new_base_kv_cache,
352
+ projected_kv_list
353
+ )
354
+ else:
355
+ # Fallback to first projector result when no aggregator is available
356
+ agg_key, agg_value = projected_kv_list[0]
357
+
358
+ # Update cache with aggregated result
359
+ curr_base_kv_cache.key_cache[target_layer_idx][:, :, start:end, :] = agg_key
360
+ curr_base_kv_cache.value_cache[target_layer_idx][:, :, start:end, :] = agg_value
361
+
362
+ output.past_key_values = curr_base_kv_cache
363
+
364
+ # use base model for decode phase
365
+ else:
366
+ # Handle list input format for decode phase as well
367
+ decode_input_ids = input_ids[self.base_model_idx] if isinstance(input_ids, list) else input_ids
368
+ decode_attention_mask = attention_mask[self.base_model_idx] if isinstance(attention_mask, list) else attention_mask
369
+
370
+ output = self.model_list[self.base_model_idx].forward(
371
+ input_ids=decode_input_ids,
372
+ attention_mask=decode_attention_mask,
373
+ position_ids=position_ids,
374
+ past_key_values=curr_base_kv_cache,
375
+ inputs_embeds=inputs_embeds,
376
+ labels=labels,
377
+ use_cache=use_cache,
378
+ output_attentions=output_attentions,
379
+ output_hidden_states=output_hidden_states,
380
+ cache_position=cache_position,
381
+ *args,
382
+ **kwargs
383
+ )
384
+
385
+ return output
386
+
387
+ @torch.no_grad()
388
+ def generate(
389
+ self,
390
+ kv_cache_index,
391
+ input_ids,
392
+ max_new_tokens: Optional[int] = None,
393
+ past_key_values: Optional[Cache] = None,
394
+ attention_mask: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
395
+ position_ids: Optional[torch.LongTensor] = None,
396
+ eos_token_id: Optional[Union[int, List[int]]] = None,
397
+ pad_token_id: Optional[int] = None,
398
+ temperature: float = 1.0,
399
+ top_p: float = 1.0,
400
+ top_k: int = -1,
401
+ do_sample: Optional[bool] = None,
402
+ return_dict_in_generate: Optional[bool] = None,
403
+ output_scores: Optional[bool] = None,
404
+ max_length: Optional[int] = None,
405
+ use_cache: bool = True,
406
+ *args,
407
+ **kwargs,
408
+ ):
409
+ """
410
+ New generation loop without using the base model's generate.
411
+ - Uses this module's forward for prefill and per-token decode.
412
+ - Samples tokens via rosetta.model.sampling.sample_token.
413
+ Returns a tensor of shape [batch, prompt_len + generated_len] for the base model stream.
414
+ """
415
+ # Derive number of tokens to generate
416
+ # If max_new_tokens not provided, infer from max_length
417
+ if isinstance(input_ids, list):
418
+ base_input_ids_for_len = input_ids[self.base_model_idx]
419
+ else:
420
+ base_input_ids_for_len = input_ids
421
+ prompt_len = base_input_ids_for_len.size(1)
422
+
423
+ # Default eos/pad from base model tokenizer/config if not provided
424
+ base_model = self.model_list[self.base_model_idx]
425
+ gen_cfg = getattr(base_model, "generation_config", None)
426
+ cfg_obj = gen_cfg if gen_cfg is not None else getattr(base_model, "config", None)
427
+ if eos_token_id is None and cfg_obj is not None:
428
+ eos_token_id = getattr(cfg_obj, "eos_token_id", None)
429
+ if pad_token_id is None and cfg_obj is not None:
430
+ pad_token_id = getattr(cfg_obj, "pad_token_id", None)
431
+ if pad_token_id is None and eos_token_id is not None:
432
+ pad_token_id = eos_token_id if isinstance(eos_token_id, int) else eos_token_id[0]
433
+
434
+ if max_new_tokens is None:
435
+ if max_length is not None:
436
+ if max_length <= prompt_len:
437
+ max_new_tokens = 0
438
+ else:
439
+ max_new_tokens = max_length - prompt_len
440
+ else:
441
+ raise ValueError("Provide max_new_tokens or max_length")
442
+ if max_new_tokens < 0:
443
+ raise ValueError("max_new_tokens must be non-negative")
444
+
445
+ # Resolve base inputs
446
+ if isinstance(input_ids, list):
447
+ base_input_ids = input_ids[self.base_model_idx]
448
+ base_attention_mask = attention_mask[self.base_model_idx] if attention_mask is not None else None
449
+ else:
450
+ base_input_ids = input_ids
451
+ base_attention_mask = attention_mask
452
+
453
+ if base_attention_mask is None:
454
+ base_attention_mask = torch.ones_like(base_input_ids, dtype=torch.long, device=base_input_ids.device)
455
+
456
+ batch_size = base_input_ids.size(0)
457
+
458
+ # Prefill to build caches and obtain initial logits
459
+ prefill_output = self.forward(
460
+ kv_cache_index=kv_cache_index,
461
+ input_ids=input_ids,
462
+ attention_mask=attention_mask,
463
+ position_ids=position_ids,
464
+ past_key_values=past_key_values,
465
+ use_cache=use_cache,
466
+ *args,
467
+ **kwargs,
468
+ )
469
+
470
+ current_past = prefill_output.past_key_values
471
+ all_input_ids = base_input_ids
472
+ current_attention_mask = base_attention_mask
473
+
474
+ # EOS handling setup
475
+ eos_set = None
476
+ if eos_token_id is not None:
477
+ eos_set = set(eos_token_id if isinstance(eos_token_id, list) else [eos_token_id])
478
+ finished = torch.zeros(batch_size, dtype=torch.bool, device=all_input_ids.device)
479
+
480
+ # Start from last prefill logits
481
+ last_logits = prefill_output.logits[:, -1, :]
482
+
483
+ # Determine sampling mode
484
+ if do_sample is None:
485
+ do_sample = False
486
+ effective_temperature = temperature if do_sample else 0.0
487
+
488
+ # Optional scores collection
489
+ collect_scores = bool(return_dict_in_generate) and bool(output_scores)
490
+ scores = []
491
+
492
+ for _ in range(max_new_tokens):
493
+ if collect_scores:
494
+ scores.append(last_logits)
495
+ # Sample next token
496
+ next_token = sample_token(last_logits, temperature=effective_temperature, top_p=top_p, top_k=top_k)
497
+ if not isinstance(next_token, torch.Tensor):
498
+ next_token = torch.tensor([next_token], device=all_input_ids.device, dtype=torch.long).repeat(batch_size)
499
+
500
+ # Apply EOS logic
501
+ if eos_set is not None:
502
+ just_finished = torch.zeros_like(finished)
503
+ for eid in eos_set:
504
+ just_finished |= (next_token == eid)
505
+ finished = finished | just_finished
506
+ if pad_token_id is not None:
507
+ next_token = torch.where(
508
+ finished,
509
+ torch.tensor(pad_token_id, device=next_token.device, dtype=next_token.dtype),
510
+ next_token,
511
+ )
512
+
513
+ # Append sampled token
514
+ next_token_unsqueezed = next_token.unsqueeze(1)
515
+ all_input_ids = torch.cat([all_input_ids, next_token_unsqueezed], dim=1)
516
+ current_attention_mask = torch.cat(
517
+ [
518
+ current_attention_mask,
519
+ torch.ones((batch_size, 1), device=current_attention_mask.device, dtype=current_attention_mask.dtype),
520
+ ],
521
+ dim=1,
522
+ )
523
+
524
+ # Early stop if all sequences finished
525
+ if eos_set is not None and torch.all(finished):
526
+ break
527
+
528
+ # Decode one step using cached states; pass base-stream tensors
529
+ kv_cache_index = [torch.tensor([-1, 0], dtype=torch.long).repeat(1, 1).unsqueeze(0).to(all_input_ids.device)]
530
+
531
+ decode_output = self.forward(
532
+ kv_cache_index=kv_cache_index,
533
+ input_ids=next_token_unsqueezed,
534
+ attention_mask=current_attention_mask,
535
+ position_ids=None,
536
+ past_key_values=current_past,
537
+ use_cache=True,
538
+ *args,
539
+ **kwargs,
540
+ )
541
+ current_past = decode_output.past_key_values
542
+ last_logits = decode_output.logits[:, -1, :]
543
+
544
+ # Return style compatible with HF generate
545
+ if return_dict_in_generate:
546
+ if GreedySearchDecoderOnlyOutput is not None and SampleDecoderOnlyOutput is not None:
547
+ if do_sample:
548
+ return SampleDecoderOnlyOutput(
549
+ sequences=all_input_ids,
550
+ scores=scores if collect_scores else None,
551
+ )
552
+ else:
553
+ return GreedySearchDecoderOnlyOutput(
554
+ sequences=all_input_ids,
555
+ scores=scores if collect_scores else None,
556
+ )
557
+ # Fallback to generic ModelOutput
558
+ result = {"sequences": all_input_ids}
559
+ if collect_scores:
560
+ result["scores"] = scores
561
+ return ModelOutput(**result)
562
+ return all_input_ids
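A minimal usage sketch for the generation loop above (not part of the committed file). It assumes a wrapper instance `rosetta_model` exposing this `generate`, that the base model sits at index 0, hypothetical tokenizers `tok_base` / `tok_source` for the base and source models, and that both tokenizations of the prompt have the same length so the single prompt section lines up across streams.

```python
import torch

prompt = "Explain KV-cache translation in one sentence."
base_ids = tok_base(prompt, return_tensors="pt").input_ids      # stream for the base model (index 0)
source_ids = tok_source(prompt, return_tensors="pt").input_ids  # stream for the source model (index 1)

# One section spanning the whole prompt; the last dim stores [source_model_idx, 0],
# mirroring the decode-time construction above (-1 would disable projection).
kv_cache_index = [torch.tensor([1, 0], dtype=torch.long).repeat(1, base_ids.size(1), 1)]

sequences = rosetta_model.generate(
    kv_cache_index=kv_cache_index,
    input_ids=[base_ids, source_ids],
    attention_mask=[torch.ones_like(base_ids), torch.ones_like(source_ids)],
    max_new_tokens=64,
    do_sample=False,
)
print(tok_base.decode(sequences[0, base_ids.size(1):], skip_special_tokens=True))
```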
rosetta/model/projector.py ADDED
@@ -0,0 +1,1194 @@
1
+ """
2
+ Projector nn.Module implementations for the unified memory
3
+ """
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch import Tensor
8
+ from transformers import Cache, DynamicCache
9
+ from typing import Optional, Tuple, Literal, Union
10
+ import copy
11
+ import math
12
+
13
+ from rosetta.utils.registry import register_model, get_projector_class, PROJECTOR_REGISTRY, capture_init_args, save_object, load_object
14
+
15
+ class Projector(nn.Module):
16
+ """Base projector class for unified memory"""
17
+
18
+ def forward(self, source_kv: Tuple[Tensor, Tensor], target_kv: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
19
+ """
20
+ Project and combine the source key-value tensors to the target key-value tensors
21
+ Args:
22
+ source_kv: Tuple of (key, value) tensors, each (..., D_s) where ... are arbitrary leading dimensions
23
+ target_kv: Tuple of (key, value) tensors, each (..., D_t) where ... are arbitrary leading dimensions
24
+ Returns:
25
+ Tuple of (key, value) tensors, each (..., D_t) with same leading dimensions as input
26
+ """
27
+ raise NotImplementedError("Subclasses must implement forward method")
28
+
29
+ def cache_project(self, source_kv_cache: Cache, target_kv_cache: Cache) -> Cache:
30
+ """
31
+ Project the source kv cache to the target kv cache
32
+ """
33
+ if not isinstance(source_kv_cache, DynamicCache) or not isinstance(target_kv_cache, DynamicCache):
34
+ raise ValueError("Only DynamicCache is supported")
35
+
36
+ projected_cache = DynamicCache()
37
+
38
+ # Process each layer
39
+ for layer_idx in range(len(source_kv_cache.key_cache)):
40
+ source_key = source_kv_cache.key_cache[layer_idx] # (B, H, N, D_s)
41
+ source_value = source_kv_cache.value_cache[layer_idx] # (B, H, N, D_s)
42
+
43
+ # Get corresponding target tensors (for reference/combination)
44
+ if layer_idx < len(target_kv_cache.key_cache):
45
+ target_key = target_kv_cache.key_cache[layer_idx] # (B, H, N, D_t)
46
+ target_value = target_kv_cache.value_cache[layer_idx] # (B, H, N, D_t)
47
+ else:
48
+ # If target cache doesn't have this layer, create dummy tensors
49
+ B, H, N, D_s = source_key.shape
50
+ D_t = source_key.shape[-1] # Assume same dimension for simplicity
51
+ target_key = torch.zeros(B, H, N, D_t, device=source_key.device, dtype=source_key.dtype)
52
+ target_value = torch.zeros(B, H, N, D_t, device=source_value.device, dtype=source_value.dtype)
53
+
54
+ # Reshape for forward pass: DynamicCache format (B, H, N, D) -> projector format (B, N, H, D)
55
+ source_key_reshaped = source_key.transpose(1, 2)
56
+ source_value_reshaped = source_value.transpose(1, 2)
57
+ target_key_reshaped = target_key.transpose(1, 2)
58
+ target_value_reshaped = target_value.transpose(1, 2)
59
+
60
+ # Project using forward method with tuple input/output
61
+ source_kv = (source_key_reshaped, source_value_reshaped)
62
+ target_kv = (target_key_reshaped, target_value_reshaped)
63
+ projected_key, projected_value = self.forward(source_kv, target_kv)
64
+
65
+ # Reshape back: projector format (B, N, H, D) -> DynamicCache format (B, H, N, D)
66
+ projected_key = projected_key.transpose(1, 2)
67
+ projected_value = projected_value.transpose(1, 2)
68
+
69
+ # Update cache
70
+ projected_cache.update(projected_key, projected_value, layer_idx)
71
+
72
+ return projected_cache
73
+
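A hedged sketch (not from the repo) of the contract the base class expects from subclasses: `cache_project` hands `forward` per-layer (key, value) tuples transposed to (B, N, H, D) and transposes the result back into the DynamicCache layout.

```python
import torch
from transformers import DynamicCache

class MeanBlendProjector(Projector):
    """Toy subclass: average source and target KV (assumes matching shapes)."""
    def forward(self, source_kv, target_kv):
        source_key, source_value = source_kv
        target_key, target_value = target_kv
        return 0.5 * (source_key + target_key), 0.5 * (source_value + target_value)

# Two caches with identical geometry: 2 layers of (B=1, H=4, N=6, D=64).
source_cache, target_cache = DynamicCache(), DynamicCache()
for layer_idx in range(2):
    source_cache.update(torch.randn(1, 4, 6, 64), torch.randn(1, 4, 6, 64), layer_idx)
    target_cache.update(torch.randn(1, 4, 6, 64), torch.randn(1, 4, 6, 64), layer_idx)

blended = MeanBlendProjector().cache_project(source_cache, target_cache)
assert blended.key_cache[0].shape == (1, 4, 6, 64)
```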
74
+ @register_model
75
+ @capture_init_args
76
+ class TrivialProjector(Projector):
77
+ """
78
+ Trivial projector that directly outputs the target key-value pairs without any modification.
79
+ This is useful as a baseline or when you want to effectively disable projection.
80
+ """
81
+
82
+ def __init__(self, **kwargs):
83
+ """
84
+ Initialize the trivial projector.
85
+
86
+ Args:
87
+ source_dim: Source dimension (may be passed via **kwargs; ignored, kept for compatibility)
88
+ target_dim: Target dimension (may be passed via **kwargs; ignored, kept for compatibility)
89
+ **kwargs: Additional arguments (ignored, kept for compatibility)
90
+ """
91
+ super().__init__()
92
+
93
+ def forward(self, source_kv: Tuple[Tensor, Tensor], target_kv: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
94
+ """
95
+ Return the target key-value pairs unchanged, ignoring the source.
96
+
97
+ Args:
98
+ source_kv: Tuple of (key, value) tensors (ignored)
99
+ target_kv: Tuple of (key, value) tensors to return unchanged
100
+ Returns:
101
+ The target key-value pairs unchanged
102
+ """
103
+ return target_kv
104
+
105
+ @register_model
106
+ @capture_init_args
107
+ class ReplaceProjector(Projector):
108
+ """
109
+ Replacement projector that projects the source key-value tensors to the target dimension with an MLP;
110
+ the projected tensors then replace the target tensors outright.
111
+ """
112
+ def __init__(
113
+ self,
114
+ source_dim: int,
115
+ target_dim: int,
116
+ hidden_dim: int = 512,
117
+ num_layers: int = 2,
118
+ dropout: float = 0.1,
119
+ activation: str = "gelu",
120
+ use_layer_norm: bool = True,
121
+ init_weight: float = 0.1,
122
+ anneal_steps: int = 1360,
123
+ initial_temperature: float = 1.0,
124
+ final_temperature: float = 0.01,
125
+ scalar_temperature: float = 0.005,
126
+ # shared_key_projection: nn.Module = None,
127
+ # shared_value_projection: nn.Module = None,
128
+ dtype: torch.dtype = torch.float32
129
+ ):
130
+ super().__init__()
131
+
132
+ self.source_dim = source_dim
133
+ self.target_dim = target_dim
134
+ self.hidden_dim = hidden_dim
135
+ self.num_layers = num_layers
136
+
137
+ # Activation function
138
+ if activation.lower() == "gelu":
139
+ self.activation = nn.GELU()
140
+ elif activation.lower() == "relu":
141
+ self.activation = nn.ReLU()
142
+ elif activation.lower() == "silu":
143
+ self.activation = nn.SiLU()
144
+ else:
145
+ raise ValueError(f"Unsupported activation: {activation}")
146
+
147
+ # Build separate MLP layers for key and value projection
148
+ self.key_projection = self._build_mlp(source_dim, hidden_dim, target_dim, num_layers, use_layer_norm, dropout, dtype)
149
+ self.value_projection = self._build_mlp(source_dim, hidden_dim, target_dim, num_layers, use_layer_norm, dropout, dtype)
150
+
151
+
152
+ def _build_mlp(self, source_dim: int, hidden_dim: int, target_dim: int, num_layers: int,
153
+ use_layer_norm: bool, dropout: float, dtype: torch.dtype) -> nn.Sequential:
154
+ """Build a single MLP projection module"""
155
+ layers = []
156
+
157
+ # Input projection
158
+ layers.append(nn.Linear(source_dim, hidden_dim, dtype=dtype))
159
+ if use_layer_norm:
160
+ layers.append(nn.LayerNorm(hidden_dim, dtype=dtype))
161
+ layers.append(copy.deepcopy(self.activation))
162
+ layers.append(nn.Dropout(dropout))
163
+
164
+ # Hidden layers
165
+ for _ in range(num_layers - 2):
166
+ layers.append(nn.Linear(hidden_dim, hidden_dim, dtype=dtype))
167
+ if use_layer_norm:
168
+ layers.append(nn.LayerNorm(hidden_dim, dtype=dtype))
169
+ layers.append(copy.deepcopy(self.activation))
170
+ layers.append(nn.Dropout(dropout))
171
+
172
+ # Output projection
173
+ if num_layers > 1:
174
+ layers.append(nn.Linear(hidden_dim, target_dim, dtype=dtype))
175
+ else:
176
+ # Single layer case
177
+ layers = [nn.Linear(source_dim, target_dim, dtype=dtype)]
178
+
179
+ return nn.Sequential(*layers)
180
+
181
+
182
+ def forward(self, source_kv: Tuple[Tensor, Tensor], target_kv: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
183
+ """
184
+ Project source key-value tensors to the target dimension; the result replaces the target tensors
185
+ Args:
186
+ source_kv: Tuple of (key, value) tensors, each (..., D_s) where ... are arbitrary leading dimensions
187
+ target_kv: Tuple of (key, value) tensors, each (..., D_t) where ... are arbitrary leading dimensions
188
+ Returns:
189
+ Tuple of (key, value) tensors, each (..., D_t) with same leading dimensions as input
190
+ """
191
+
192
+ source_key, source_value = source_kv
193
+ target_key, target_value = target_kv
194
+
195
+ # Get shapes - assuming format is (B, H, N, D) where H is num_heads, N is seq_len, D is head_dim
196
+ source_shape = source_key.shape # (B, H_s, N, D_s)
197
+ target_shape = target_key.shape # (B, H_t, N, D_t)
198
+
199
+ # Extract dimensions
200
+ batch_size, source_num_heads, seq_len, source_head_dim = source_shape
201
+ _, target_num_heads, _, target_head_dim = target_shape
202
+
203
+ # Reshape source: merge num_heads and head_dim for projection
204
+ # (B, H_s, N, D_s) -> (B, N, H_s * D_s)
205
+ source_key_reshaped = source_key.transpose(1, 2) # (B, N, H_s, D_s)
206
+ source_value_reshaped = source_value.transpose(1, 2) # (B, N, H_s, D_s)
207
+
208
+ source_key_flat = source_key_reshaped.contiguous().view(batch_size, seq_len, source_num_heads * source_head_dim)
209
+ source_value_flat = source_value_reshaped.contiguous().view(batch_size, seq_len, source_num_heads * source_head_dim)
210
+
211
+ # Project source tensors from (H_s * D_s) to (H_t * D_t)
212
+ projected_key_flat = self.key_projection(source_key_flat) # (B, N, H_t * D_t)
213
+ projected_value_flat = self.value_projection(source_value_flat) # (B, N, H_t * D_t)
214
+
215
+ # Reshape projected tensors back to target format
216
+ # (B, N, H_t * D_t) -> (B, N, H_t, D_t) -> (B, H_t, N, D_t)
217
+ projected_key_reshaped = projected_key_flat.view(batch_size, seq_len, target_num_heads, target_head_dim)
218
+ projected_value_reshaped = projected_value_flat.view(batch_size, seq_len, target_num_heads, target_head_dim)
219
+
220
+ projected_key = projected_key_reshaped.transpose(1, 2) # (B, H_t, N, D_t)
221
+ projected_value = projected_value_reshaped.transpose(1, 2) # (B, H_t, N, D_t)
222
+
223
+ return (projected_key, projected_value)
224
+
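A small shape-check sketch with assumed dimensions (not from the repo). Note that `source_dim` / `target_dim` here are the flattened per-token sizes, i.e. num_heads * head_dim, since `forward` flattens the heads before the MLPs.

```python
import torch

proj = ReplaceProjector(source_dim=4 * 64, target_dim=8 * 128, hidden_dim=256, num_layers=2)

B, N = 2, 10
source_kv = (torch.randn(B, 4, N, 64), torch.randn(B, 4, N, 64))    # (B, H_s, N, D_s)
target_kv = (torch.randn(B, 8, N, 128), torch.randn(B, 8, N, 128))  # (B, H_t, N, D_t)

new_key, new_value = proj(source_kv, target_kv)
assert new_key.shape == (B, 8, N, 128) and new_value.shape == (B, 8, N, 128)
```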
225
+ class ModernMLP(nn.Module):
226
+ """
227
+ Modern MLP with residual connections, layer normalization, and configurable architecture.
228
+ """
229
+
230
+ def __init__(
231
+ self,
232
+ input_dim: int,
233
+ output_dim: int,
234
+ hidden_dim: int = 512,
235
+ num_layers: int = 2,
236
+ activation: str = "gelu",
237
+ use_layer_norm: bool = True,
238
+ use_residual: bool = True,
239
+ dropout: float = 0.1,
240
+ use_swiglu: bool = False,
241
+ dtype: torch.dtype = torch.float32
242
+ ):
243
+ super().__init__()
244
+
245
+ self.input_dim = input_dim
246
+ self.output_dim = output_dim
247
+ self.hidden_dim = hidden_dim
248
+ self.num_layers = num_layers
249
+ self.use_residual = use_residual and (input_dim == output_dim)
250
+ self.use_swiglu = use_swiglu
251
+
252
+ # Activation function
253
+ if activation.lower() == "gelu":
254
+ self.activation = nn.GELU()
255
+ elif activation.lower() == "relu":
256
+ self.activation = nn.ReLU()
257
+ elif activation.lower() == "silu":
258
+ self.activation = nn.SiLU()
259
+ else:
260
+ raise ValueError(f"Unsupported activation: {activation}")
261
+
262
+ # Build layers
263
+ self.layers = nn.ModuleList()
264
+
265
+ for i in range(num_layers):
266
+ layer_input_dim = input_dim if i == 0 else hidden_dim
267
+ layer_output_dim = output_dim if i == num_layers - 1 else hidden_dim
268
+
269
+ if self.use_swiglu and i < num_layers - 1: # Don't use SwiGLU on output layer
270
+ layer = SwiGLUBlock(layer_input_dim, layer_output_dim, dtype=dtype)
271
+ else:
272
+ layer = nn.Linear(layer_input_dim, layer_output_dim, dtype=dtype)
273
+
274
+ self.layers.append(layer)
275
+
276
+ # Add layer norm after each layer except the last one
277
+ if use_layer_norm and i < num_layers - 1:
278
+ self.layers.append(nn.LayerNorm(layer_output_dim, dtype=dtype))
279
+
280
+ # Add activation after each layer except the last one
281
+ if i < num_layers - 1 and not self.use_swiglu:
282
+ self.layers.append(copy.deepcopy(self.activation))
283
+
284
+ # Add dropout after activation
285
+ if dropout > 0 and i < num_layers - 1:
286
+ self.layers.append(nn.Dropout(dropout))
287
+
288
+ # Residual projection if dimensions don't match
289
+ if self.use_residual and input_dim != output_dim:
290
+ self.residual_proj = nn.Linear(input_dim, output_dim, dtype=dtype)
291
+ else:
292
+ self.residual_proj = None
293
+
294
+ def forward(self, x: Tensor) -> Tensor:
295
+ """Forward pass with optional residual connection."""
296
+ residual = x
297
+
298
+ for layer in self.layers:
299
+ x = layer(x)
300
+
301
+ # Add residual connection
302
+ if self.use_residual:
303
+ if self.residual_proj is not None:
304
+ residual = self.residual_proj(residual)
305
+ x = x + residual
306
+
307
+ return x
308
+
309
+
310
+ class SwiGLUBlock(nn.Module):
311
+ """SwiGLU activation block for modern transformer architectures."""
312
+
313
+ def __init__(self, input_dim: int, output_dim: int, dtype: torch.dtype = torch.float32):
314
+ super().__init__()
315
+ self.gate_proj = nn.Linear(input_dim, output_dim, dtype=dtype)
316
+ self.up_proj = nn.Linear(input_dim, output_dim, dtype=dtype)
317
+ self.activation = nn.SiLU()
318
+
319
+ def forward(self, x: Tensor) -> Tensor:
320
+ gate = self.activation(self.gate_proj(x))
321
+ up = self.up_proj(x)
322
+ return gate * up
323
+
324
+
325
+ @register_model
326
+ @capture_init_args
327
+ class AllInOneProjector(Projector):
328
+ """
329
+ Unified projector that consolidates all projection functionalities with modern patterns.
330
+
331
+ Features:
332
+ 1. Gate logit granularity: scalar, token-wise, head-wise, or value-wise
333
+ 2. (DISABLED) Selector logit granularity: scalar, token-wise, head-wise, or value-wise
334
+ 3. Key/Value weight granularity: scalar, token-wise, head-wise, or value-wise
335
+ 4. Input-dependent gates and weights via MLP or parameters
336
+ 5. Optional concatenation with combiner networks
337
+ 6. Modern MLP architecture with residual connections and SwiGLU
338
+ 7. Configurable target preservation: choose between traditional blending or simplified projection
339
+ 8. Optional adding of target (self) signal to outputs via add_self
340
+
341
+ Target Preservation Modes:
342
+ - preserve_target_weight=True (default): output = (1-weight)*target + gate*selector*weight*projected
343
+ - preserve_target_weight=False: output = target + gate*selector*weight*projected (no weight coefficient on target)
344
+
345
+ Note: Selector functionality has been disabled/commented out.
346
+ """
347
+
348
+ def __init__(
349
+ self,
350
+ source_dim: int,
351
+ target_dim: int,
352
+ source_num_heads: int = 1,
353
+ target_num_heads: int = 1,
354
+ hidden_dim: int = 512,
355
+ num_layers: int = 2,
356
+ dropout: float = 0.1,
357
+ activation: str = "gelu",
358
+ use_layer_norm: bool = True,
359
+ use_residual: bool = True,
360
+ use_swiglu: bool = False,
361
+
362
+ # Gate configuration
363
+ gate_granularity: Literal["scalar", "token", "head", "head_merged", "value"] = "scalar",
364
+ gate_depends_on_input: bool = False,
365
+ gate_input_features: Optional[str] = "target_key", # "target_key", "target_value", "both", "target_projected_key", "target_projected_value", "target_projected_both"
366
+ gate_init_value: float = 0.0,
367
+
368
+ # Weight configuration
369
+ weight_granularity: Literal["scalar", "token", "head", "head_merged", "value"] = "scalar",
370
+ weight_depends_on_input: bool = False,
371
+ weight_input_features: Optional[str] = "target_key", # "target_key", "target_value", "both", "target_projected_key", "target_projected_value", "target_projected_both"
372
+ weight_init_value: float = 0.0,
373
+
374
+ # Target preservation configuration
375
+ preserve_target_weight: bool = True, # If False, target won't be multiplied by (1 - normalized_weight)
376
+ add_self: bool = True, # If False, target (self) won't be added to outputs
377
+
378
+ # Concat configuration
379
+ use_concat: bool = False,
380
+ # combiner_hidden_dim: int = 128,
381
+ weight_hidden_dim: int = 1024,
382
+
383
+ # Temperature and gumbel
384
+ use_gumbel: bool = True,
385
+ initial_temperature: float = 1.0,
386
+ final_temperature: float = 0.01,
387
+ anneal_steps: int = 1360,
388
+ scalar_temperature: float = 0.005,
389
+
390
+ # Sequence length configuration
391
+ max_sequence_length: int = 8192, # Maximum sequence length for token-level parameters
392
+
393
+ pos_emb: bool = False,
394
+
395
+ dtype: torch.dtype = torch.float32
396
+ ):
397
+ super().__init__()
398
+
399
+ self.source_dim = source_dim
400
+ self.target_dim = target_dim
401
+ self.source_num_heads = source_num_heads
402
+ self.target_num_heads = target_num_heads
403
+ self.hidden_dim = hidden_dim
404
+ self.weight_hidden_dim = weight_hidden_dim
405
+ self.max_sequence_length = max_sequence_length
406
+
407
+ # Configuration
408
+ self.gate_granularity = gate_granularity
409
+ self.gate_depends_on_input = gate_depends_on_input
410
+ self.gate_input_features = gate_input_features
411
+ self.weight_granularity = weight_granularity
412
+ self.weight_depends_on_input = weight_depends_on_input
413
+ self.weight_input_features = weight_input_features
414
+ self.preserve_target_weight = preserve_target_weight
415
+ self.add_self = add_self
416
+ self.use_concat = use_concat
417
+ self.use_gumbel = use_gumbel
418
+ self.scalar_temperature = scalar_temperature
419
+
420
+ # Temperature annealing for gate only (removed selector temperature)
421
+ self.register_buffer("gate_temperature", torch.tensor(initial_temperature, dtype=dtype))
422
+ self.initial_temperature = initial_temperature
423
+ self.final_temperature = final_temperature
424
+ self.anneal_steps = anneal_steps
425
+
426
+ # Build projection networks
427
+ self.key_projection = self._build_projection_mlp(
428
+ source_dim * source_num_heads,
429
+ target_dim * target_num_heads,
430
+ hidden_dim, num_layers, activation, use_layer_norm,
431
+ use_residual, dropout, use_swiglu, dtype
432
+ )
433
+ self.value_projection = self._build_projection_mlp(
434
+ source_dim * source_num_heads,
435
+ target_dim * target_num_heads,
436
+ hidden_dim, num_layers, activation, use_layer_norm,
437
+ use_residual, dropout, use_swiglu, dtype
438
+ )
439
+
440
+ # Build gate components
441
+ self._build_gate_components(dtype)
442
+
443
+ # Build weight components
444
+ self._build_weight_components(weight_init_value, dtype)
445
+
446
+ # Build concat components if needed
447
+ if self.use_concat:
448
+ in_dim = target_dim * target_num_heads * 2
449
+ out_dim = target_dim * target_num_heads
450
+ self.key_combiner = nn.Linear(in_dim, out_dim, dtype=dtype)
451
+ self.value_combiner = nn.Linear(in_dim, out_dim, dtype=dtype)
452
+
453
+ def _build_projection_mlp(
454
+ self, input_dim: int, output_dim: int, hidden_dim: int,
455
+ num_layers: int, activation: str, use_layer_norm: bool,
456
+ use_residual: bool, dropout: float, use_swiglu: bool, dtype: torch.dtype
457
+ ) -> ModernMLP:
458
+ """Build modern MLP for projection."""
459
+ return ModernMLP(
460
+ input_dim=input_dim,
461
+ output_dim=output_dim,
462
+ hidden_dim=hidden_dim,
463
+ num_layers=num_layers,
464
+ activation=activation,
465
+ use_layer_norm=use_layer_norm,
466
+ use_residual=use_residual,
467
+ dropout=dropout,
468
+ use_swiglu=use_swiglu,
469
+ dtype=dtype
470
+ )
471
+
472
+ def _build_gate_components(self, dtype: torch.dtype):
473
+ """Build gate logit components based on configuration."""
474
+ if not self.gate_depends_on_input:
475
+ # Parameter-based gate
476
+ gate_shape = self._get_parameter_shape(self.gate_granularity)
477
+ self.gate_logit = nn.Parameter(torch.zeros(gate_shape, dtype=dtype))
478
+ else:
479
+ # Input-dependent gate via MLP
480
+ input_dim = self._get_gate_input_dim()
481
+ output_dim = self._get_gate_output_dim()
482
+
483
+ self.gate_generator = ModernMLP(
484
+ input_dim=input_dim,
485
+ output_dim=output_dim,
486
+ hidden_dim=self.hidden_dim,
487
+ num_layers=2,
488
+ activation="gelu",
489
+ use_layer_norm=True,
490
+ use_residual=False,
491
+ dropout=0.1,
492
+ dtype=dtype
493
+ )
494
+
495
+ def _build_weight_components(self, weight_init_value: float, dtype: torch.dtype):
496
+ """Build weight components based on configuration."""
497
+ if not self.weight_depends_on_input:
498
+ # Parameter-based weights
499
+ weight_shape = self._get_parameter_shape(self.weight_granularity)
500
+ self.key_weight = nn.Parameter(torch.full(weight_shape, weight_init_value, dtype=dtype))
501
+ self.value_weight = nn.Parameter(torch.full(weight_shape, weight_init_value, dtype=dtype))
502
+ else:
503
+ # Input-dependent weights via MLP
504
+ input_dim = self._get_weight_input_dim()
505
+ output_dim = self._get_weight_output_dim()
506
+
507
+ # Shared hidden layer for efficiency
508
+ self.weight_hidden = ModernMLP(
509
+ input_dim=input_dim,
510
+ output_dim=self.weight_hidden_dim,
511
+ hidden_dim=self.weight_hidden_dim,
512
+ num_layers=2,
513
+ activation="gelu",
514
+ use_layer_norm=True,
515
+ use_residual=False,
516
+ dropout=0.1,
517
+ dtype=dtype
518
+ )
519
+
520
+ # Separate heads for key and value weights
521
+ self.key_weight_head = nn.Linear(self.weight_hidden_dim, output_dim, dtype=dtype)
522
+ self.value_weight_head = nn.Linear(self.weight_hidden_dim, output_dim, dtype=dtype)
523
+
524
+ def _get_parameter_shape(self, granularity: str) -> tuple:
525
+ """Get parameter shape based on granularity."""
526
+ if granularity == "scalar":
527
+ return () # Scalar
528
+ elif granularity == "token":
529
+ return (self.max_sequence_length,) # Token-level parameters with max sequence length
530
+ elif granularity == "head":
531
+ return (self.max_sequence_length, self.target_num_heads) # Token and head level parameters
532
+ elif granularity == "head_merged":
533
+ return (self.max_sequence_length, self.target_num_heads) # Token and head level parameters
534
+ elif granularity == "value":
535
+ return (self.max_sequence_length, self.target_num_heads, self.target_dim) # Token, head and value level parameters
536
+ else:
537
+ raise ValueError(f"Invalid granularity: {granularity}")
538
+
539
+ def _get_gate_input_dim(self) -> int:
540
+ """Get input dimension for gate generator."""
541
+ base_dim = 0
542
+ if self.gate_input_features == "target_key":
543
+ base_dim = self.target_dim
544
+ elif self.gate_input_features == "target_value":
545
+ base_dim = self.target_dim
546
+ elif self.gate_input_features == "both":
547
+ base_dim = self.target_dim * 2
548
+ elif self.gate_input_features == "target_projected_key":
549
+ base_dim = self.target_dim * 2 # target_key + projected_key
550
+ elif self.gate_input_features == "target_projected_value":
551
+ base_dim = self.target_dim * 2 # target_value + projected_value
552
+ elif self.gate_input_features == "target_projected_both":
553
+ base_dim = self.target_dim * 4 # target_key + target_value + projected_key + projected_value
554
+ else:
555
+ raise ValueError(f"Invalid gate input features: {self.gate_input_features}")
556
+
557
+ # Adjust for granularity processing strategy
558
+ if self.gate_granularity == "scalar":
559
+ # Scalar: process aggregated features across all heads
560
+ return base_dim # Use pooled features
561
+ elif self.gate_granularity == "token":
562
+ # Token: process merged head dimensions
563
+ return base_dim * self.target_num_heads # Flatten (H, D) to (H*D)
564
+ elif self.gate_granularity == "head_merged":
565
+ # Head-merged: similar to token granularity, merge H and D
566
+ return base_dim * self.target_num_heads # (B, N, H*D)
567
+ elif self.gate_granularity == "head":
568
+ # Head-local: per head processing, do not merge heads
569
+ return base_dim # (B, H, N, D)
570
+ else: # value
571
+ # Value: process per-head features
572
+ return base_dim # Keep per-head processing (B, H, N, D)
573
+
574
+ def _get_gate_output_dim(self) -> int:
575
+ """Get output dimension for gate generator."""
576
+ if self.gate_granularity == "scalar":
577
+ return 1
578
+ elif self.gate_granularity == "token":
579
+ return 1 # Per token
580
+ elif self.gate_granularity == "head_merged":
581
+ # Per token per head after merge: output one value per head
582
+ return self.target_num_heads
583
+ elif self.gate_granularity == "head":
584
+ # Per token per head: scalar per head
585
+ return 1
586
+ elif self.gate_granularity == "value":
587
+ return self.target_dim # Per token per head per value (but processed per-head, so output D per head)
588
+ else:
589
+ raise ValueError(f"Invalid gate granularity: {self.gate_granularity}")
590
+
591
+ def _get_weight_input_dim(self) -> int:
592
+ """Get input dimension for weight generator."""
593
+ base_dim = 0
594
+ if self.weight_input_features == "target_key":
595
+ base_dim = self.target_dim
596
+ elif self.weight_input_features == "target_value":
597
+ base_dim = self.target_dim
598
+ elif self.weight_input_features == "both":
599
+ base_dim = self.target_dim * 2
600
+ elif self.weight_input_features == "target_projected_key":
601
+ base_dim = self.target_dim * 2 # target_key + projected_key
602
+ elif self.weight_input_features == "target_projected_value":
603
+ base_dim = self.target_dim * 2 # target_value + projected_value
604
+ elif self.weight_input_features == "target_projected_both":
605
+ base_dim = self.target_dim * 4 # target_key + target_value + projected_key + projected_value
606
+ else:
607
+ raise ValueError(f"Invalid weight input features: {self.weight_input_features}")
608
+
609
+ # Adjust for granularity processing strategy
610
+ if self.weight_granularity == "scalar":
611
+ # Scalar: process aggregated features across all heads
612
+ return base_dim # Use pooled features
613
+ elif self.weight_granularity == "token":
614
+ # Token: process merged head dimensions
615
+ return base_dim * self.target_num_heads # Flatten (H, D) to (H*D)
616
+ elif self.weight_granularity == "head_merged":
617
+ # Head-merged: similar to token granularity, merge H and D
618
+ return base_dim * self.target_num_heads # (B, N, H*D)
619
+ elif self.weight_granularity == "head":
620
+ # Head-local: per head processing, do not merge heads
621
+ return base_dim # (B, H, N, D)
622
+ else: # value
623
+ # Value: process per-head features
624
+ return base_dim # Keep per-head processing (B, H, N, D)
625
+
626
+ def _get_weight_output_dim(self) -> int:
627
+ """Get output dimension for weight generator."""
628
+ if self.weight_granularity == "scalar":
629
+ return 1
630
+ elif self.weight_granularity == "token":
631
+ return 1 # Per token
632
+ elif self.weight_granularity == "head_merged":
633
+ # Per token per head after merge: output one value per head
634
+ return self.target_num_heads
635
+ elif self.weight_granularity == "head":
636
+ # Per token per head: scalar per head
637
+ return 1
638
+ elif self.weight_granularity == "value":
639
+ return self.target_dim # Per token per head per value (but processed per-head, so output D per head)
640
+ else:
641
+ raise ValueError(f"Invalid weight granularity: {self.weight_granularity}")
642
+
643
+ def _generate_gates(self, target_key: Tensor, target_value: Tensor, projected_key: Tensor = None, projected_value: Tensor = None) -> Tensor:
644
+ """Generate gate logits based on configuration."""
645
+ if not self.gate_depends_on_input:
646
+ # Use parameter-based gate
647
+ return self.gate_logit
648
+ else:
649
+ # Generate input-dependent gate
650
+ # First, prepare the base input features
651
+ if self.gate_input_features == "target_key":
652
+ base_input = target_key
653
+ elif self.gate_input_features == "target_value":
654
+ base_input = target_value
655
+ elif self.gate_input_features == "both":
656
+ base_input = torch.cat([target_key, target_value], dim=-1)
657
+ elif self.gate_input_features == "target_projected_key":
658
+ if projected_key is None:
659
+ raise ValueError("projected_key is required for target_projected_key input features")
660
+ base_input = torch.cat([target_key, projected_key], dim=-1)
661
+ elif self.gate_input_features == "target_projected_value":
662
+ if projected_value is None:
663
+ raise ValueError("projected_value is required for target_projected_value input features")
664
+ base_input = torch.cat([target_value, projected_value], dim=-1)
665
+ elif self.gate_input_features == "target_projected_both":
666
+ if projected_key is None or projected_value is None:
667
+ raise ValueError("Both projected_key and projected_value are required for target_projected_both input features")
668
+ base_input = torch.cat([target_key, target_value, projected_key, projected_value], dim=-1)
669
+
670
+ # Now process based on granularity
671
+ # base_input shape: (B, H, N, D_input)
672
+ B, H, N, D_input = base_input.shape
673
+
674
+ if self.gate_granularity == "scalar":
675
+ # For scalar granularity, aggregate all dimensions: (B, H, N, D_input) -> (B, D_input)
676
+ gate_input = base_input.mean(dim=(1, 2)) # Average over heads and tokens
677
+ elif self.gate_granularity == "token":
678
+ # For token granularity, merge H and D_input dimensions: (B, H, N, D_input) -> (B, N, H*D_input)
679
+ gate_input = base_input.transpose(1, 2).contiguous().view(B, N, H * D_input)
680
+ elif self.gate_granularity == "head_merged":
681
+ # For head granularity, merge H and D like token: (B, H, N, D_in) -> (B, N, H*D_in)
682
+ gate_input = base_input.transpose(1, 2).contiguous().view(B, N, H * D_input)
683
+ elif self.gate_granularity == "head":
684
+ # For head granularity, keep per-head processing: (B, H, N, D_input)
685
+ gate_input = base_input
686
+ elif self.gate_granularity == "value":
687
+ # For value granularity, keep per-head processing: (B, H, N, D_input)
688
+ gate_input = base_input
689
+
690
+ return self.gate_generator(gate_input)
691
+
692
+ def _generate_weights(self, target_key: Tensor, target_value: Tensor, projected_key: Tensor = None, projected_value: Tensor = None) -> Tuple[Tensor, Tensor]:
693
+ """Generate weights based on configuration."""
694
+ if not self.weight_depends_on_input:
695
+ # Use parameter-based weights
696
+ return self.key_weight, self.value_weight
697
+ else:
698
+ # Generate input-dependent weights
699
+ # First, prepare the base input features
700
+ if self.weight_input_features == "target_key":
701
+ base_input = target_key
702
+ elif self.weight_input_features == "target_value":
703
+ base_input = target_value
704
+ elif self.weight_input_features == "both":
705
+ base_input = torch.cat([target_key, target_value], dim=-1)
706
+ elif self.weight_input_features == "target_projected_key":
707
+ if projected_key is None:
708
+ raise ValueError("projected_key is required for target_projected_key input features")
709
+ base_input = torch.cat([target_key, projected_key], dim=-1)
710
+ elif self.weight_input_features == "target_projected_value":
711
+ if projected_value is None:
712
+ raise ValueError("projected_value is required for target_projected_value input features")
713
+ base_input = torch.cat([target_value, projected_value], dim=-1)
714
+ elif self.weight_input_features == "target_projected_both":
715
+ if projected_key is None or projected_value is None:
716
+ raise ValueError("Both projected_key and projected_value are required for target_projected_both input features")
717
+ base_input = torch.cat([target_key, target_value, projected_key, projected_value], dim=-1)
718
+
719
+ # Now process based on granularity
720
+ # base_input shape: (B, H, N, D_input)
721
+ B, H, N, D_input = base_input.shape
722
+
723
+ if self.weight_granularity == "scalar":
724
+ # For scalar granularity, aggregate all dimensions: (B, H, N, D_input) -> (B, D_input)
725
+ weight_input = base_input.mean(dim=(1, 2)) # Average over heads and tokens
726
+ elif self.weight_granularity == "token":
727
+ # For token granularity, merge H and D_input dimensions: (B, H, N, D_input) -> (B, N, H*D_input)
728
+ weight_input = base_input.transpose(1, 2).contiguous().view(B, N, H * D_input)
729
+ elif self.weight_granularity == "head_merged":
730
+ # For head granularity, merge H and D like token: (B, H, N, D_in) -> (B, N, H*D_in)
731
+ weight_input = base_input.transpose(1, 2).contiguous().view(B, N, H * D_input)
732
+ elif self.weight_granularity == "head":
733
+ # For head granularity, keep per-head processing: (B, H, N, D_input)
734
+ weight_input = base_input
735
+ elif self.weight_granularity == "value":
736
+ # For value granularity, keep per-head processing: (B, H, N, D_input)
737
+ weight_input = base_input
738
+
739
+ weight_hidden = self.weight_hidden(weight_input)
740
+ key_weight = self.key_weight_head(weight_hidden)
741
+ value_weight = self.value_weight_head(weight_hidden)
742
+
743
+ return key_weight, value_weight
744
+
745
+ def _apply_gumbel_sigmoid(self, gate_logit: Tensor) -> Tensor:
746
+ """Apply Gumbel sigmoid trick for training."""
747
+ if self.training and self.use_gumbel:
748
+ gumbel_noise = self._sample_gumbel(gate_logit.shape, gate_logit.device, gate_logit.dtype)
749
+ return torch.sigmoid((gate_logit + gumbel_noise) / self.gate_temperature)
750
+ else:
751
+ return (gate_logit > 0).float()
752
+
753
+ @staticmethod
754
+ def _sample_gumbel(shape: tuple, device: torch.device, dtype: torch.dtype, eps: float = 1e-20) -> Tensor:
755
+ """Sample from Gumbel distribution."""
756
+ u = torch.rand(shape, device=device, dtype=dtype)
757
+ return -torch.log(-torch.log(u + eps) + eps)
758
+
759
+ def _reshape_for_granularity(self, tensor: Tensor, granularity: str, target_shape: tuple) -> Tensor:
760
+ """Reshape tensor to match target shape based on granularity."""
761
+ B, H, N, D = target_shape
762
+
763
+ if granularity == "scalar":
764
+ # Scalar -> (B, H, N, D)
765
+ return tensor.view(1, 1, 1, 1).expand(B, H, N, D)
766
+ elif granularity == "token":
767
+ # (max_seq_len,) -> (B, H, N, D) - slice to actual sequence length
768
+ token_params = tensor[:N] # Take first N tokens
769
+ return token_params.view(1, 1, N, 1).expand(B, H, N, D)
770
+ elif granularity == "head":
771
+ # (max_seq_len, H) -> (B, H, N, D) - slice to actual sequence length, each token each head independent
772
+ head_params = tensor[:N, :] # Take first N tokens, all heads: (N, H)
773
+ return head_params.view(1, N, H, 1).transpose(1, 2).expand(B, H, N, D) # (1, N, H, 1) -> (1, H, N, 1) -> (B, H, N, D)
774
+ elif granularity == "head_merged":
775
+ raise NotImplementedError
776
+ elif granularity == "value":
777
+ # (max_seq_len, H, D) -> (B, H, N, D) - slice to actual sequence length, each token each head each value independent
778
+ value_params = tensor[:N, :, :] # Take first N tokens: (N, H, D)
779
+ return value_params.view(1, N, H, D).transpose(1, 2).expand(B, H, N, D) # (1, N, H, D) -> (1, H, N, D) -> (B, H, N, D)
780
+ else:
781
+ raise ValueError(f"Invalid granularity: {granularity}")
782
+
783
+ def update_temperature(self, step: int):
784
+ """Update temperature using exponential annealing schedule for gate only."""
785
+ # Update gate temperature
786
+ gate_ratio = min(step / self.anneal_steps, 1.0)
787
+ gate_temp = self.initial_temperature * (self.final_temperature / self.initial_temperature) ** gate_ratio
788
+ self.gate_temperature.fill_(gate_temp)
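+ # e.g. with the defaults (initial=1.0, final=0.01, anneal_steps=1360) the gate
+ # temperature is 1.0 at step 0, about 0.1 at step 680, and 0.01 from step 1360 onward.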
789
+
790
+
791
+ def forward(self, source_kv: Tuple[Tensor, Tensor], target_kv: Tuple[Tensor, Tensor], position_ids: Optional[Tensor] = None, max_pos: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
792
+ """
793
+ Forward pass with unified projection logic.
794
+
795
+ Args:
796
+ source_kv: Tuple of (key, value) tensors, each (B, H_s, N, D_s)
797
+ target_kv: Tuple of (key, value) tensors, each (B, H_t, N, D_t)
798
+ position_ids: Position ids tensor (B, N), optional, required if pos_emb=True
799
+ Returns:
800
+ Tuple of (key, value) tensors, each (B, H_t, N, D_t)
801
+ """
802
+ source_key, source_value = source_kv
803
+ target_key, target_value = target_kv
804
+
805
+ # Get shapes
806
+ B, H_s, N, D_s = source_key.shape
807
+ _, H_t, _, D_t = target_key.shape
808
+
809
+ # Reshape for projection: (B, H, N, D) -> (B, N, H*D)
810
+ source_key_flat = source_key.transpose(1, 2).contiguous().view(B, N, H_s * D_s)
811
+ source_value_flat = source_value.transpose(1, 2).contiguous().view(B, N, H_s * D_s)
812
+
813
+ # Project source to target dimension
814
+ projected_key_flat = self.key_projection(source_key_flat) # (B, N, H_t * D_t)
815
+ projected_value_flat = self.value_projection(source_value_flat) # (B, N, H_t * D_t)
816
+
817
+ # Handle concatenation if enabled
818
+ if self.use_concat:
819
+ target_key_flat = target_key.transpose(1, 2).contiguous().view(B, N, H_t * D_t)
820
+ target_value_flat = target_value.transpose(1, 2).contiguous().view(B, N, H_t * D_t)
821
+
822
+ # Concatenate and combine
823
+ combined_key = torch.cat([projected_key_flat, target_key_flat], dim=-1)
824
+ combined_value = torch.cat([projected_value_flat, target_value_flat], dim=-1)
825
+
826
+ final_projected_key_flat = self.key_combiner(combined_key)
827
+ final_projected_value_flat = self.value_combiner(combined_value)
828
+ else:
829
+ final_projected_key_flat = projected_key_flat
830
+ final_projected_value_flat = projected_value_flat
831
+
832
+ # Reshape back: (B, N, H_t * D_t) -> (B, H_t, N, D_t)
833
+ projected_key = final_projected_key_flat.view(B, N, H_t, D_t).transpose(1, 2)
834
+ projected_value = final_projected_value_flat.view(B, N, H_t, D_t).transpose(1, 2)
835
+
836
+ # Generate gates, selectors and weights (may need projected tensors for input features)
837
+ needs_projected_for_gate = self.gate_depends_on_input and self.gate_input_features in [
838
+ "target_projected_key", "target_projected_value", "target_projected_both"
839
+ ]
840
+ needs_projected_for_weight = self.weight_depends_on_input and self.weight_input_features in [
841
+ "target_projected_key", "target_projected_value", "target_projected_both"
842
+ ]
843
+
844
+ if needs_projected_for_gate or needs_projected_for_weight:
845
+ gate_logit = self._generate_gates(target_key, target_value, projected_key, projected_value)
846
+ key_weight, value_weight = self._generate_weights(target_key, target_value, projected_key, projected_value)
847
+ else:
848
+ gate_logit = self._generate_gates(target_key, target_value)
849
+ key_weight, value_weight = self._generate_weights(target_key, target_value)
850
+
851
+ # Reshape gates and weights to match target shape
852
+ target_shape = (B, H_t, N, D_t)
853
+ if self.gate_depends_on_input:
854
+ # Reshape based on gate granularity - all preserve token dimension N
855
+ if self.gate_granularity == "scalar":
856
+ # For scalar, gate_logit is already (B, 1) from MLP, just expand
857
+ gate_logit = gate_logit.view(B, 1, 1, 1).expand(target_shape)
858
+ elif self.gate_granularity == "token":
859
+ gate_logit = gate_logit.unsqueeze(1).unsqueeze(-1).expand(target_shape) # (B, N, 1) -> (B, H, N, D)
860
+ elif self.gate_granularity == "head_merged":
861
+ # (B, N, H) -> (B, H, N, D) - per token per head, broadcast over D
862
+ gate_logit = gate_logit.permute(0, 2, 1).unsqueeze(-1).expand(B, H_t, N, D_t)
863
+ elif self.gate_granularity == "head":
864
+ # (B, H, N, 1) -> (B, H, N, D) - per token per head scalar, broadcast over D
865
+ gate_logit = gate_logit.expand(B, H_t, N, D_t)
866
+ elif self.gate_granularity == "value":
867
+ # (B, H, N, D) -> (B, H, N, D) - each token each head each value has one value
868
+ pass # Already in correct shape
869
+ else:
870
+ gate_logit = self._reshape_for_granularity(gate_logit, self.gate_granularity, target_shape)
871
+
872
+ if self.weight_depends_on_input:
873
+ # Reshape weights based on granularity - all preserve token dimension N
874
+ if self.weight_granularity == "scalar":
875
+ # For scalar, weights are already (B, 1) from MLP, just expand
876
+ key_weight = key_weight.view(B, 1, 1, 1).expand(target_shape)
877
+ value_weight = value_weight.view(B, 1, 1, 1).expand(target_shape)
878
+ elif self.weight_granularity == "token":
879
+ key_weight = key_weight.unsqueeze(1).expand(target_shape) # (B, N, 1) -> (B, H, N, D)
880
+ value_weight = value_weight.unsqueeze(1).expand(target_shape)
881
+ elif self.weight_granularity == "head_merged":
882
+ # (B, N, H) -> (B, H, N, D) - per token per head, broadcast over D
883
+ key_weight = key_weight.permute(0, 2, 1).unsqueeze(-1).expand(B, H_t, N, D_t)
884
+ value_weight = value_weight.permute(0, 2, 1).unsqueeze(-1).expand(B, H_t, N, D_t)
885
+ elif self.weight_granularity == "head":
886
+ # (B, H, N, 1) -> (B, H, N, D) - per token per head scalar, broadcast over D
887
+ key_weight = key_weight.expand(B, H_t, N, D_t)
888
+ value_weight = value_weight.expand(B, H_t, N, D_t)
889
+ elif self.weight_granularity == "value":
890
+ # (B, H, N, D) -> (B, H, N, D) - each token each head each value has one value
891
+ pass # Already in correct shape
892
+ else:
893
+ key_weight = self._reshape_for_granularity(key_weight, self.weight_granularity, target_shape)
894
+ value_weight = self._reshape_for_granularity(value_weight, self.weight_granularity, target_shape)
895
+
896
+ # Apply gating and selection
897
+ gate = self._apply_gumbel_sigmoid(gate_logit)
898
+
899
+ # Normalize weights using dynamic temperature
900
+ normalized_key_weight = torch.sigmoid(key_weight / self.scalar_temperature)
901
+ normalized_value_weight = torch.sigmoid(value_weight / self.scalar_temperature)
902
+
903
+ # Final combination
904
+ # Compute projected contribution (always present)
905
+ projected_key_term = gate * normalized_key_weight * projected_key
906
+ projected_value_term = gate * normalized_value_weight * projected_value
907
+
908
+ # Compute target (self) contribution depending on flags
909
+ if self.add_self:
910
+ if self.preserve_target_weight:
911
+ target_key_term = (1 - normalized_key_weight) * target_key
912
+ target_value_term = (1 - normalized_value_weight) * target_value
913
+ else:
914
+ target_key_term = target_key
915
+ target_value_term = target_value
916
+ else:
917
+ target_key_term = torch.zeros_like(target_key)
918
+ target_value_term = torch.zeros_like(target_value)
919
+
920
+ # Final outputs
921
+ output_key = target_key_term + projected_key_term
922
+ output_value = target_value_term + projected_value_term
923
+
924
+ return (output_key, output_value)
925
+
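A tiny numeric sketch (made-up scalars, not from the repo) of the combination rule above once the gate has been hardened and the weight squashed through the sigmoid, contrasting the two target-preservation modes with add_self=True:

```python
import torch

target, projected = torch.tensor(1.0), torch.tensor(0.5)
gate = (torch.tensor(2.3) > 0).float()               # hard gate used outside training
w = torch.sigmoid(torch.tensor(2.0) / 0.005)         # normalized weight, ~1.0 here

blended = (1 - w) * target + gate * w * projected    # preserve_target_weight=True
additive = target + gate * w * projected             # preserve_target_weight=False
# With add_self=False, the target term above is replaced by zero.
```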
926
+ class QwenStyleLayer(nn.Module):
927
+ """
928
+ One Qwen3-style MLP sublayer:
929
+ y = x + Dropout( down( SiLU(gate(LN(x))) * up(LN(x)) ) )
930
+ - Pre-norm with RMSNorm
931
+ - Bias-free linears
932
+ """
933
+ def __init__(self, hidden_size: int, intermediate_size: int, dropout: float = 0.0, dtype: torch.dtype = torch.float32):
934
+ super().__init__()
935
+ self.norm = nn.RMSNorm(hidden_size, eps=1e-6, dtype=dtype)
936
+ self.gate = nn.Linear(hidden_size, intermediate_size, bias=False, dtype=dtype)
937
+ self.up = nn.Linear(hidden_size, intermediate_size, bias=False, dtype=dtype)
938
+ self.down = nn.Linear(intermediate_size, hidden_size, bias=False, dtype=dtype)
939
+ self.act = nn.SiLU()
940
+ self.drop = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
941
+
942
+ def forward(self, x: Tensor) -> Tensor:
943
+ h = self.norm(x)
944
+ h = self.act(self.gate(h)) * self.up(h) # SwiGLU
945
+ h = self.down(h)
946
+ h = self.drop(h)
947
+ return x + h
948
+
949
+ class StandardFFNLayer(nn.Module):
950
+ """
951
+ Pre-norm RMSNorm, classic MLP:
952
+ y = x + Dropout( W2( Act( W1( RMSNorm(x) ) ) ) )
953
+ - No SwiGLU: single hidden nonlinearity (GELU/ReLU/SiLU)
954
+ - Bias-free linears (common in modern LLM FFNs)
955
+ """
956
+ def __init__(
957
+ self,
958
+ hidden_size: int,
959
+ intermediate_size: int,
960
+ dropout: float = 0.0,
961
+ dtype: torch.dtype = torch.float32,
962
+ activation: str = "gelu",
963
+ ):
964
+ super().__init__()
965
+ self.norm = nn.RMSNorm(hidden_size, eps=1e-6, dtype=dtype)
966
+ self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False, dtype=dtype)
967
+ self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False, dtype=dtype)
968
+ self.drop = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
969
+
970
+ act = activation.lower()
971
+ if act == "gelu":
972
+ self.act = nn.GELU()
973
+ elif act == "relu":
974
+ self.act = nn.ReLU()
975
+ elif act == "silu":
976
+ self.act = nn.SiLU()
977
+ else:
978
+ raise ValueError(f"Unsupported activation: {activation}")
979
+
980
+ def forward(self, x: Tensor) -> Tensor:
981
+ h = self.norm(x)
982
+ h = self.act(self.w1(h))
983
+ h = self.w2(h)
984
+ h = self.drop(h)
985
+ return x + h
986
+
987
+ class RegularMLP(nn.Module):
988
+ """
989
+ Stacked pre-norm FFN MLP operating at a fixed hidden size.
990
+ - No input/output projections; caller is responsible for projections.
991
+ - num_layers repeats of StandardFFNLayer (pre-RMSNorm, single-activation FFN, bias-free)
992
+ """
993
+ def __init__(
994
+ self,
995
+ hidden_dim: int = 1024,
996
+ intermediate_dim: int = 3072,
997
+ num_layers: int = 3,
998
+ dropout: float = 0.1,
999
+ dtype: torch.dtype = torch.float32,
1000
+ ):
1001
+ super().__init__()
1002
+ assert num_layers >= 1, "num_layers must be >= 1"
1003
+
1004
+ self.blocks = nn.ModuleList([
1005
+ StandardFFNLayer(hidden_size=hidden_dim, intermediate_size=intermediate_dim, dropout=dropout, dtype=dtype)
1006
+ for _ in range(num_layers)
1007
+ ])
1008
+
1009
+ def forward(self, x: Tensor) -> Tensor:
1010
+ for blk in self.blocks:
1011
+ x = blk(x)
1012
+ return x
1013
+
1014
+ @register_model
1015
+ @capture_init_args
1016
+ class C2CProjector(Projector):
1017
+ """
1018
+ Concise projector specialized to a fixed C2C configuration built from RegularMLP blocks.
1019
+ - Projections: RegularMLP (pre-RMSNorm FFN sublayers, residual per sublayer)
1020
+ - Concat: enabled, followed by linear combiner to target size
1021
+ - Gate: scalar parameter with Gumbel-sigmoid during training
1022
+ - Weights: input-dependent, head_merged granularity using target and projected key
1023
+ - Target preservation: add_self=True, preserve_target_weight=False
1024
+ - Temperatures: annealed gate temperature (1.0 -> 0.001 over 1929 steps), scalar_temperature=1.0
1025
+ """
1026
+
1027
+ def __init__(
1028
+ self,
1029
+ source_dim: int,
1030
+ target_dim: int,
1031
+ source_num_heads: int = 1,
1032
+ target_num_heads: int = 1,
1033
+ intermediate_dim: int = 1024,
1034
+ hidden_dim: int = 1024,
1035
+ num_layers: int = 3,
1036
+ dropout: float = 0.1,
1037
+ initial_temperature: float = 1.0,
1038
+ final_temperature: float = 0.001,
1039
+ anneal_steps: int = 1929,
1040
+ dtype: torch.dtype = torch.float32,
1041
+ ):
1042
+ super().__init__()
1043
+
1044
+ assert num_layers >= 3, "num_layers must be >= 3"
1045
+
1046
+ # Dimensions
1047
+ self.source_dim = source_dim
1048
+ self.target_dim = target_dim
1049
+ self.source_num_heads = source_num_heads
1050
+ self.target_num_heads = target_num_heads
1051
+
1052
+ # Sizes
1053
+ in_dim = source_dim * source_num_heads
1054
+ out_dim = target_dim * target_num_heads
1055
+
1056
+ # 1) concat(source_X, target_X) then project to hidden_dim
1057
+ self.key_in = nn.Linear(in_dim + out_dim, hidden_dim, bias=True, dtype=dtype)
1058
+ self.value_in = nn.Linear(in_dim + out_dim, hidden_dim, bias=True, dtype=dtype)
1059
+
1060
+ # 2) one-layer common embedding MLP to get intermediate representation (at hidden_dim)
1061
+ self.key_mlp1 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, num_layers=1, dropout=dropout, dtype=dtype)
1062
+ self.value_mlp1 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, num_layers=1, dropout=dropout, dtype=dtype)
1063
+
1064
+ # 3a) intermediate representation → one-layer MLP for scalar weights → project to one scalar per target head
1065
+ self.key_scalar_mlp2 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=hidden_dim, num_layers=1, dropout=dropout, dtype=dtype)
1066
+ self.value_scalar_mlp2 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=hidden_dim, num_layers=1, dropout=dropout, dtype=dtype)
1067
+ self.key_scalar_head = nn.Linear(hidden_dim, target_num_heads, dtype=dtype)
1068
+ self.value_scalar_head = nn.Linear(hidden_dim, target_num_heads, dtype=dtype)
1069
+
1070
+ # 3b) intermediate representation → (L-2)-layer MLP for projected_X → finally project hidden_dim → out_dim
1071
+ self.key_proj_mlp2 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, num_layers=num_layers-2, dropout=dropout, dtype=dtype)
1072
+ self.value_proj_mlp2 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, num_layers=num_layers-2, dropout=dropout, dtype=dtype)
1073
+ self.key_proj_out = nn.Linear(hidden_dim, out_dim, bias=True, dtype=dtype)
1074
+ self.value_proj_out = nn.Linear(hidden_dim, out_dim, bias=True, dtype=dtype)
1075
+
1076
+ # Scalar key/value gate parameters and temperature schedule
1077
+ self.key_gate_logit = nn.Parameter(torch.tensor(0.0, dtype=dtype))
1078
+ self.value_gate_logit = nn.Parameter(torch.tensor(0.0, dtype=dtype))
1079
+ self.use_gumbel = True
1080
+ self.register_buffer("gate_temperature", torch.tensor(initial_temperature, dtype=dtype))
1081
+ self.initial_temperature = initial_temperature
1082
+ self.final_temperature = final_temperature
1083
+ self.anneal_steps = anneal_steps
1084
+
1085
+ # Temperature for weight normalization
1086
+ self.scalar_temperature = 1.0
1087
+
1088
+ def update_temperature(self, step: int):
1089
+ ratio = min(step / self.anneal_steps, 1.0)
1090
+ temp = self.initial_temperature * (self.final_temperature / self.initial_temperature) ** ratio
1091
+ self.gate_temperature.fill_(temp)
1092
+
1093
+ def forward(
1094
+ self,
1095
+ source_kv: Tuple[Tensor, Tensor],
1096
+ target_kv: Tuple[Tensor, Tensor],
1097
+ position_ids: Optional[Tensor] = None,
1098
+ max_pos: Optional[Tensor] = None,
1099
+ ) -> Tuple[Tensor, Tensor]:
1100
+ source_key, source_value = source_kv
1101
+ target_key, target_value = target_kv
1102
+
1103
+ B, Hs, N, Ds = source_key.shape
1104
+ _, Ht, _, Dt = target_key.shape
1105
+
1106
+ # Flatten heads
1107
+ source_key_flat = source_key.transpose(1, 2).contiguous().view(B, N, Hs * Ds)
1108
+ source_value_flat = source_value.transpose(1, 2).contiguous().view(B, N, Hs * Ds)
1109
+ target_key_flat = target_key.transpose(1, 2).contiguous().view(B, N, Ht * Dt)
1110
+ target_value_flat = target_value.transpose(1, 2).contiguous().view(B, N, Ht * Dt)
1111
+
1112
+ # 1) concat source and target features along channel
1113
+ key_cat = torch.cat([source_key_flat, target_key_flat], dim=-1)
1114
+ value_cat = torch.cat([source_value_flat, target_value_flat], dim=-1)
1115
+
1116
+ # 2) project to hidden dim
1117
+ key_hidden = self.key_in(key_cat)
1118
+ value_hidden = self.value_in(value_cat)
1119
+
1120
+ # 3) one-layer common embedding MLP to get intermediate representation (at hidden_dim)
1121
+ key_hidden = self.key_mlp1(key_hidden)
1122
+ value_hidden = self.value_mlp1(value_hidden)
1123
+
1124
+ # 4b) intermediate representation -> projected feature path
1125
+ key_proj_hidden = self.key_proj_out(self.key_proj_mlp2(key_hidden)) # (B, N, Ht * Dt)
1126
+ value_proj_hidden = self.value_proj_out(self.value_proj_mlp2(value_hidden)) # (B, N, Ht * Dt)
1127
+ projected_key = key_proj_hidden.view(B, N, Ht, Dt).transpose(1, 2) # (B, Ht, N, Dt)
1128
+ projected_value = value_proj_hidden.view(B, N, Ht, Dt).transpose(1, 2) # (B, Ht, N, Dt)
1129
+
1130
+ # 4a) intermediate representation -> scalar path
1131
+ key_scalar = self.key_scalar_head(self.key_scalar_mlp2(key_hidden)) # (B, N, Ht)
1132
+ value_scalar = self.value_scalar_head(self.value_scalar_mlp2(value_hidden)) # (B, N, Ht)
1133
+ key_scalar = key_scalar.permute(0, 2, 1).unsqueeze(-1) # (B, Ht, N, 1)
1134
+ value_scalar = value_scalar.permute(0, 2, 1).unsqueeze(-1) # (B, Ht, N, 1)
1135
+
1136
+ # Key/value gates: element-wise Gumbel noise with scalar logits (broadcast over channels)
1137
+ key_gate_logit = self.key_gate_logit.view(1, 1, 1, 1)
1138
+ value_gate_logit = self.value_gate_logit.view(1, 1, 1, 1)
1139
+ if self.training and self.use_gumbel:
1140
+ u1 = torch.rand(B, Ht, N, 1, device=key_gate_logit.device, dtype=key_gate_logit.dtype)
1141
+ u2 = torch.rand(B, Ht, N, 1, device=value_gate_logit.device, dtype=value_gate_logit.dtype)
1142
+ g1 = -torch.log(-torch.log(u1 + 1e-20) + 1e-20)
1143
+ g2 = -torch.log(-torch.log(u2 + 1e-20) + 1e-20)
1144
+ key_gate = torch.sigmoid((key_gate_logit + g1) / self.gate_temperature)
1145
+ value_gate = torch.sigmoid((value_gate_logit + g2) / self.gate_temperature)
1146
+ else:
1147
+ key_gate = (key_gate_logit > 0).float()
1148
+ value_gate = (value_gate_logit > 0).float()
1149
+
1150
+ # Normalize scalars (scalar_temperature=1.0)
1151
+ norm_key_scalar = torch.sigmoid(key_scalar)
1152
+ norm_value_scalar = torch.sigmoid(value_scalar)
1153
+
1154
+ # Combine (preserve_target_weight=False, add_self=True)
1155
+ output_key = target_key + key_gate * norm_key_scalar * projected_key
1156
+ output_value = target_value + value_gate * norm_value_scalar * projected_value
1157
+
1158
+ # Expose capture attributes for downstream analysis scripts
1159
+ try:
1160
+ # Store normalized scalars (detach to avoid autograd, keep device-agnostic via CPU)
1161
+ self.last_norm_key_scalar = norm_key_scalar.detach().cpu()
1162
+ self.last_norm_value_scalar = norm_value_scalar.detach().cpu()
1163
+ # Store gate logits as python floats (parameters are scalar)
1164
+ self.last_key_gate_logit = float(self.key_gate_logit.detach().cpu().item())
1165
+ self.last_value_gate_logit = float(self.value_gate_logit.detach().cpu().item())
1166
+ except Exception:
1167
+ # Best-effort capture; never break forward path
1168
+ pass
1169
+
1170
+ return output_key, output_value
1171
+
1172
+ def save_projector(obj: Projector, file_path: str) -> None:
1173
+ save_object(obj, file_path)
1174
+
1175
+ def load_projector(file_path: str, override_args: Optional[dict] = None) -> Projector:
1176
+ return load_object(file_path, get_projector_class, override_args)
1177
+
1178
+ def create_projector(projector_type: str, **kwargs) -> Projector:
1179
+ """
1180
+ Factory function to create a projector based on type.
1181
+
1182
+ Args:
1183
+ projector_type: String indicating the type of projector
1184
+ **kwargs: Additional arguments to pass to the projector constructor
1185
+
1186
+ Returns:
1187
+ An instance of the appropriate projector
1188
+ """
1189
+ # Prefer using the unified registry getter (handles case-insensitive keys)
1190
+ try:
1191
+ cls = get_projector_class(projector_type)
1192
+ except ValueError as e:
1193
+ raise e
1194
+ return cls(**kwargs)
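For orientation, a minimal usage sketch of the factory above and the C2CProjector forward contract; the registry key, dimensions, and head counts are illustrative assumptions rather than values fixed by this commit:

import torch
from rosetta.model.projector import create_projector

# Assumes C2CProjector is registered under its class name (registry lookup is case-insensitive).
proj = create_projector(
    "C2CProjector",
    source_dim=128, target_dim=128,          # placeholder head dims
    source_num_heads=8, target_num_heads=8,  # placeholder KV head counts
)
B, N = 2, 16
source_kv = (torch.randn(B, 8, N, 128), torch.randn(B, 8, N, 128))  # (key, value), each (B, H, N, D)
target_kv = (torch.randn(B, 8, N, 128), torch.randn(B, 8, N, 128))
out_key, out_value = proj(source_kv, target_kv)  # each (B, 8, N, 128): target cache fused with the projected source cache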
rosetta/model/sampling.py ADDED
@@ -0,0 +1,82 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from typing import Union
4
+
5
+ def sample_token(logits: torch.Tensor, temperature: float = 1.0, top_p: float = 1.0, top_k: int = -1) -> Union[int, torch.Tensor]:
6
+ """Sample a token from logits using temperature, top-p, and top-k sampling.
7
+ Args:
8
+ logits: Token logits of shape [vocab_size] or [batch_size, vocab_size]
9
+ temperature: Temperature for sampling (>0). Higher values produce more random samples.
10
+ top_p: Top-p probability threshold for nucleus sampling (0 < top_p ≤ 1)
11
+ top_k: Top-k threshold for sampling (if -1, no top-k filtering is applied)
12
+ Returns:
13
+ Sampled token ID (int for single sample, tensor for batch)
14
+ """
15
+ if not isinstance(logits, torch.Tensor):
16
+ raise TypeError("logits must be a torch.Tensor")
17
+
18
+ if logits.dim() not in [1, 2]:
19
+ raise ValueError("logits must have shape [vocab_size] or [batch_size, vocab_size]")
20
+
21
+ # Handle single dimension input
22
+ is_single_input = logits.dim() == 1
23
+ if is_single_input:
24
+ logits = logits.unsqueeze(0)
25
+
26
+ batch_size = logits.shape[0]
27
+
28
+ # For greedy sampling (temperature=0), just return argmax
29
+ if temperature <= 1e-5:
30
+ tokens = torch.argmax(logits, dim=-1)
31
+ return tokens.item() if is_single_input else tokens
32
+
33
+ # Convert to probabilities
34
+ probs = F.softmax(logits / temperature, dim=-1)
35
+
36
+ # Apply top-k filtering first (if specified)
37
+ if top_k != -1:
38
+ # Get top-k values and indices
39
+ top_k_values, top_k_indices = torch.topk(probs, k=min(top_k, probs.shape[-1]), dim=-1)
40
+
41
+ # Create a mask to zero out non-top-k probabilities
42
+ mask = torch.zeros_like(probs, dtype=torch.bool)
43
+ mask.scatter_(-1, top_k_indices, True)
44
+
45
+ # Zero out non-top-k probabilities
46
+ probs = probs * mask.float()
47
+
48
+ # Renormalize probabilities
49
+ probs = probs / probs.sum(dim=-1, keepdim=True)
50
+
51
+ # Apply top-p (nucleus) sampling
52
+ if top_p < 1.0:
53
+ # Sort probabilities in descending order
54
+ sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
55
+
56
+ # Calculate cumulative probabilities
57
+ cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
58
+
59
+ # Create a mask for probabilities to keep
60
+ # Values above top_p threshold are masked out
61
+ mask = cumulative_probs <= top_p
62
+
63
+ # Always keep at least one token
64
+ mask[:, 0] = True
65
+
66
+ # Zero out masked positions to exclude them from sampling
67
+ sorted_probs = sorted_probs * mask.float()
68
+
69
+ # Renormalize probabilities
70
+ sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
71
+
72
+ # Sample from the filtered distribution
73
+ sampled_indices = torch.multinomial(sorted_probs, num_samples=1)
74
+
75
+ # Map back to original vocabulary indices
76
+ tokens = torch.gather(sorted_indices, dim=-1, index=sampled_indices)
77
+ tokens = tokens.squeeze(-1) # Remove sample dimension
78
+ else:
79
+ # Direct sampling if no top-p filtering
80
+ tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
81
+
82
+ return tokens.item() if is_single_input else tokens
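A short illustration of the two input shapes sample_token accepts (the vocabulary size is arbitrary):

import torch
from rosetta.model.sampling import sample_token

logits = torch.randn(32000)                         # [vocab_size] -> returns a Python int
tok = sample_token(logits, temperature=0.7, top_p=0.9, top_k=50)

batch_logits = torch.randn(4, 32000)                # [batch_size, vocab_size] -> returns a LongTensor of shape (4,)
toks = sample_token(batch_logits, temperature=0.0)  # temperature at or near 0 falls back to greedy argmax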
rosetta/model/wrapper.py ADDED
@@ -0,0 +1,840 @@
1
+ """
2
+ The ensemble of multiple standard transformers LLM models, with automatic kv-cache projection. It shares the same interface as the standard transformers LLM models.
3
+ """
4
+
5
+ from typing import List, Optional, Union
6
+ import torch
7
+ from torch import nn
8
+ from transformers.cache_utils import Cache, DynamicCache
9
+ from transformers.modeling_utils import PreTrainedModel
10
+ from transformers.modeling_outputs import CausalLMOutputWithPast
11
+ import json
12
+
13
+ from rosetta.model.projector import Projector
14
+ from rosetta.model.sampling import sample_token
15
+ from transformers.utils import ModelOutput
16
+ try:
17
+ from transformers.generation.utils import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput
18
+ except Exception:
19
+ GreedySearchDecoderOnlyOutput = None
20
+ SampleDecoderOnlyOutput = None
21
+
22
+ def clone_kv_cache(kv_cache: DynamicCache) -> DynamicCache:
23
+ new_cache = DynamicCache()
24
+ for k, v in zip(kv_cache.key_cache, kv_cache.value_cache):
25
+ new_cache.key_cache.append(k.clone().detach())
26
+ new_cache.value_cache.append(v.clone().detach())
27
+ return new_cache
28
+
29
+ def hybrid_to_dynamic(hybrid_cache):
30
+ if hybrid_cache is None:
31
+ return None
32
+ if isinstance(hybrid_cache, DynamicCache):
33
+ return hybrid_cache
34
+
35
+ # Manually extract key/value tensors from a HybridCache
36
+ if hasattr(hybrid_cache, "key_cache") and hasattr(hybrid_cache, "value_cache"):
37
+ keys = hybrid_cache.key_cache
38
+ values = hybrid_cache.value_cache
39
+ assert len(keys) == len(values), "key/value layer counts do not match"
40
+
41
+ legacy_cache = [(k, v) for k, v in zip(keys, values)]
42
+ return DynamicCache.from_legacy_cache(legacy_cache)
43
+
44
+ raise TypeError(f"Unsupported cache type: {type(hybrid_cache)}")
45
+
46
+ class RosettaModel(nn.Module):
47
+ """
48
+ Drop in replacement for the standard transformers LLM models, like Qwen3ForCausalLM
49
+ """
50
+ def __init__(self, model_list: List[PreTrainedModel], base_model_idx: int = 0, projector_list: Optional[List[Projector]] = None, aggregator_list: Optional[List[nn.Module]] = None):
51
+ super().__init__()
52
+ # model list: a list of model, model 0 by default is the base model
53
+ # projector list: a list of projector
54
+ # standard init with additional model list parameter
55
+ # kv-cache dict: key (source_model_idx, target_model_idx), value (Cache), assume only convert at prefill with one type of model
56
+ # projector dict: key (source_model_idx, target_model_idx) value dict(key (source_model_layer_idx, M_target value )
57
+
58
+ self.base_model_idx = base_model_idx
59
+ self.model_list = nn.ModuleList(model_list)
60
+
61
+ device = model_list[base_model_idx].device
62
+ dtype = model_list[base_model_idx].dtype
63
+ self.projector_list = nn.ModuleList(projector_list or []).to(device=device, dtype=dtype)
64
+ self.aggregator_list = nn.ModuleList(aggregator_list or []).to(device=device, dtype=dtype)
65
+
66
+ self.projector_dict = {}
67
+ self.aggregator_dict = {}
68
+ self.kv_cache_dict = {}
69
+ self._generation_hook_handlers = []
70
+
71
+ @property
72
+ def device(self):
73
+ return self.model_list[self.base_model_idx].device
74
+
75
+ def to(self, device):
76
+ """
77
+ Move the RosettaModel and all underlying models and projectors to the specified device.
78
+ """
79
+ super().to(device)
80
+ for model in self.model_list:
81
+ model.to(device)
82
+ for projector in self.projector_list:
83
+ projector.to(device)
84
+ for aggregator in self.aggregator_list:
85
+ aggregator.to(device)
86
+ return self
87
+
88
+ # set projector
89
+ def set_projector_config(self,
90
+ source_model_idx: int,
91
+ source_model_layer_idx: int,
92
+ target_model_idx: int,
93
+ target_model_layer_idx: int,
94
+ projector_idx: int):
95
+ """
96
+ Set the projector configuration
97
+ Args:
98
+ source_model_idx: int, the index of the source model
99
+ source_model_layer_idx: int, the index of the source model layer
100
+ target_model_idx: int, the index of the target model
101
+ target_model_layer_idx: int, the index of the target model layer
102
+ projector_idx: int, the index of the projector
103
+
104
+ The projector dict structure supports multiple projectors per target layer.
105
+ Structure:
106
+ {
107
+ target_model_idx: {
108
+ source_model_idx: {
109
+ target_model_layer_idx: [(source_model_layer_idx, projector_idx), ...]
110
+ }
111
+ }
112
+ }
113
+ Repeated calls for the same (target, source, target_layer) append additional pairs.
114
+ """
115
+
116
+ if target_model_idx not in self.projector_dict.keys():
117
+ self.projector_dict[target_model_idx] = {}
118
+ if source_model_idx not in self.projector_dict[target_model_idx].keys():
119
+ self.projector_dict[target_model_idx][source_model_idx] = {}
120
+ # Accumulate list of (source_layer, projector_idx) for this target layer
121
+ layer_entry = self.projector_dict[target_model_idx][source_model_idx].get(target_model_layer_idx)
122
+ if layer_entry is None:
123
+ self.projector_dict[target_model_idx][source_model_idx][target_model_layer_idx] = [(source_model_layer_idx, projector_idx)]
124
+ else:
125
+ layer_entry.append((source_model_layer_idx, projector_idx))
126
+
127
+
128
+ def load_projector(self, projector_list):
129
+ self.projector_list: List[Projector] = projector_list
130
+
131
+ def load_aggregator(self, aggregator_list):
132
+ self.aggregator_list: List[nn.Module] = aggregator_list
133
+
134
+
135
+ def get_projector(self,
136
+ source_model_idx,
137
+ source_model_layer_idx,
138
+ target_model_idx,
139
+ target_model_layer_idx):
140
+ pair_list = self.projector_dict[target_model_idx][source_model_idx][target_model_layer_idx]
141
+ if len(pair_list) == 0:
142
+ raise ValueError("No projector configured for the given target layer")
143
+ # Prefer exact source layer match
144
+ for src_layer, projector_id in pair_list:
145
+ if src_layer == source_model_layer_idx:
146
+ return self.projector_list[projector_id]
147
+ # Fallback: return the first projector
148
+ return self.projector_list[pair_list[0][1]]
149
+
150
+ def set_aggregator_idx(self,
151
+ source_model_idx: int,
152
+ target_model_idx: int,
153
+ target_model_layer_idx: int,
154
+ aggregator_idx: int):
155
+ if target_model_idx not in self.aggregator_dict:
156
+ self.aggregator_dict[target_model_idx] = {}
157
+ if source_model_idx not in self.aggregator_dict[target_model_idx]:
158
+ self.aggregator_dict[target_model_idx][source_model_idx] = {}
159
+ self.aggregator_dict[target_model_idx][source_model_idx][target_model_layer_idx] = aggregator_idx
160
+
161
+
162
+ @staticmethod
163
+ def load_json(file_name):
164
+ with open(file_name, "r") as f:
165
+ result = json.load(f)
166
+ return result
167
+
168
+ @staticmethod
169
+ def _convert_dict_keys_to_ints(obj):
170
+ """
171
+ Recursively convert dictionary keys that look like integers back to int.
172
+ This reverses json.dump's coercion of dict keys to strings.
173
+ """
174
+ if isinstance(obj, dict):
175
+ new_obj = {}
176
+ for key, value in obj.items():
177
+ if isinstance(key, str) and key.lstrip('-').isdigit():
178
+ new_key = int(key)
179
+ else:
180
+ new_key = key
181
+ new_obj[new_key] = RosettaModel._convert_dict_keys_to_ints(value)
182
+ return new_obj
183
+ if isinstance(obj, list):
184
+ return [RosettaModel._convert_dict_keys_to_ints(v) for v in obj]
185
+ return obj
186
+
187
+
188
+ def save_projector_config(self, file_name):
189
+ with open(file_name, "w") as f:
190
+ json.dump(self.projector_dict, f)
191
+
192
+
193
+ def load_projector_config(self, config_path):
194
+ if config_path.endswith(".json"):
195
+ loaded = RosettaModel.load_json(config_path)
196
+ self.projector_dict = RosettaModel._convert_dict_keys_to_ints(loaded)
197
+
198
+ def save_aggregator_config(self, file_name):
199
+ with open(file_name, "w") as f:
200
+ json.dump(self.aggregator_dict, f)
201
+
202
+ def load_aggregator_config(self, config_path):
203
+ if config_path.endswith(".json"):
204
+ loaded = RosettaModel.load_json(config_path)
205
+ self.aggregator_dict = RosettaModel._convert_dict_keys_to_ints(loaded)
206
+
207
+
208
+ def set_kv_cache_dict(self, source_model_idx, target_model_idx, cache):
209
+ if target_model_idx not in self.kv_cache_dict.keys():
210
+ self.kv_cache_dict[target_model_idx] = {}
211
+ if cache is None:
212
+ # Initialize with a DynamicCache instead of RosettaCache for now
213
+ self.kv_cache_dict[target_model_idx][source_model_idx] = DynamicCache() # noqa, maybe we should use RosettaCache here
214
+ else:
215
+ self.kv_cache_dict[target_model_idx][source_model_idx] = cache
216
+
217
+ def forward(
218
+ self,
219
+ kv_cache_index: Optional[List] = None,
220
+ input_ids: Optional[Union[torch.LongTensor, List[torch.LongTensor]]] = None,
221
+ attention_mask: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
222
+ position_ids: Optional[torch.LongTensor] = None,
223
+ past_key_values: Optional[Cache] = None,
224
+ inputs_embeds: Optional[torch.FloatTensor] = None,
225
+ labels: Optional[torch.LongTensor] = None,
226
+ use_cache: Optional[bool] = None,
227
+ output_attentions: Optional[bool] = None,
228
+ output_hidden_states: Optional[bool] = None,
229
+ cache_position: Optional[torch.LongTensor] = None,
230
+ logits_to_keep: Union[int, torch.Tensor] = 0,
231
+ # **kwargs: Unpack[KwargsForCausalLM],
232
+ *args,
233
+ **kwargs,
234
+ ) -> CausalLMOutputWithPast:
235
+ """
236
+ Forward pass
237
+ KVCache index is a list of tensors with shape (B, sec_seq_len, 2), indicating the source and target kv cache model index
238
+
239
+ If input_ids is a LongTensor, the same input ids are shared across all models.
240
+ If input_ids is a list, each model receives its own input ids.
241
+
242
+ No Rosetta: (-1, 0)
243
+ """
244
+
245
+ # noqa
246
+ self.kv_cache_dict = dict()
247
+
248
+ # Handle different input formats: if input_ids is a list, use per-model inputs
249
+ if isinstance(input_ids, list):
250
+ # Use list format: different input_ids and attention_mask for each model
251
+ base_input_ids = input_ids[self.base_model_idx] if input_ids is not None else None
252
+ base_attention_mask = attention_mask[self.base_model_idx] if attention_mask is not None else None
253
+ _, seqlen = base_input_ids.size() if base_input_ids is not None else (0, 0)
254
+ else:
255
+ # Use tensor format: same input_ids and attention_mask for all models (backward compatibility)
256
+ base_input_ids = input_ids
257
+ base_attention_mask = attention_mask
258
+ _, seqlen = input_ids.size() if input_ids is not None else (0, 0)
259
+
260
+ num_sections = len(kv_cache_index) if kv_cache_index is not None else 1
261
+
262
+ section_lengths = [kv_cache_index[i].shape[1] for i in range(num_sections)] if kv_cache_index is not None else [seqlen]
263
+ section_starts = [0]
264
+ for l in section_lengths:
265
+ section_starts.append(section_starts[-1] + l)
266
+
267
+ curr_base_kv_cache = past_key_values
268
+
269
+ if seqlen >= 1:
270
+ for i in range(num_sections):
271
+ start = section_starts[i]
272
+ end = section_starts[i + 1]
273
+ prefill_input_ids = base_input_ids[:, start:end] if base_input_ids is not None else None
274
+ prefill_attention_mask = base_attention_mask[:, :end] if base_attention_mask is not None else None
275
+ prefill_position_ids = position_ids[:, start:end] if position_ids is not None else None
276
+ prefill_labels = labels[:, start:end] if labels is not None else None
277
+
278
+ output = self.model_list[self.base_model_idx].forward(
279
+ input_ids=prefill_input_ids,
280
+ attention_mask=prefill_attention_mask,
281
+ position_ids=prefill_position_ids,
282
+ past_key_values=curr_base_kv_cache,
283
+ labels=prefill_labels,
284
+ use_cache=use_cache,
285
+ output_attentions=output_attentions,
286
+ output_hidden_states=output_hidden_states,
287
+ *args,
288
+ **kwargs
289
+ )
290
+
291
+ if self.base_model_idx not in self.kv_cache_dict:
292
+ self.kv_cache_dict[self.base_model_idx] = {}
293
+ if self.base_model_idx not in self.kv_cache_dict[self.base_model_idx]:
294
+ self.kv_cache_dict[self.base_model_idx][self.base_model_idx] = None
295
+ self.kv_cache_dict[self.base_model_idx][self.base_model_idx] = output.past_key_values
296
+
297
+ curr_base_kv_cache: DynamicCache = output.past_key_values
298
+
299
+ if i != num_sections - 1:
300
+ for source_model_idx in range(1, len(self.model_list)):
301
+ if self.base_model_idx not in self.kv_cache_dict:
302
+ self.kv_cache_dict[self.base_model_idx] = {}
303
+ if source_model_idx not in self.kv_cache_dict[self.base_model_idx]:
304
+ self.kv_cache_dict[self.base_model_idx][source_model_idx] = None
305
+
306
+ # Get model-specific input_ids and attention_mask
307
+ if isinstance(input_ids, list):
308
+ source_input_ids = input_ids[source_model_idx]
309
+ source_attention_mask = attention_mask[source_model_idx] if attention_mask is not None else None
310
+ source_prefill_input_ids = source_input_ids[:, start:end] if source_input_ids is not None else None
311
+ source_prefill_attention_mask = source_attention_mask[:, :end] if source_attention_mask is not None else None
312
+ else:
313
+ # Backward compatibility: use same input for all models
314
+ source_prefill_input_ids = prefill_input_ids
315
+ source_prefill_attention_mask = prefill_attention_mask
316
+
317
+ model = self.model_list[source_model_idx]
318
+ was_training = model.training
319
+ had_gc = getattr(model, "is_gradient_checkpointing", False)
320
+
321
+ try:
322
+ if was_training:
323
+ model.eval()
324
+ if had_gc:
325
+ model.gradient_checkpointing_disable()
326
+
327
+ with torch.no_grad():
328
+ out = model(
329
+ input_ids=source_prefill_input_ids,
330
+ attention_mask=source_prefill_attention_mask,
331
+ position_ids=prefill_position_ids,
332
+ past_key_values=self.kv_cache_dict[self.base_model_idx][source_model_idx],
333
+ use_cache=True,
334
+ return_dict=True,
335
+ )
336
+ curr_source_kv_cache = out.past_key_values
337
+ finally:
338
+ if had_gc:
339
+ model.gradient_checkpointing_enable()
340
+ if was_training:
341
+ model.train()
342
+
343
+ curr_source_kv_cache = hybrid_to_dynamic(curr_source_kv_cache)
344
+ self.kv_cache_dict[self.base_model_idx][source_model_idx] = curr_source_kv_cache
345
+
346
+ # calculate source model kvcache and apply projections
347
+ if self.base_model_idx in self.projector_dict:
348
+ source_model_idx = kv_cache_index[i][0][0][0].item() # Get the source model index from the kv_cache_index
349
+ if source_model_idx != -1:
350
+ for target_layer_idx, entry in self.projector_dict[self.base_model_idx][source_model_idx].items():
351
+ base_key_cache, base_value_cache = curr_base_kv_cache[target_layer_idx]
352
+ new_base_key_cache = base_key_cache[:, :, start:end, :]
353
+ new_base_value_cache = base_value_cache[:, :, start:end, :]
354
+ new_base_kv_cache = (new_base_key_cache, new_base_value_cache)
355
+
356
+ pair_list = entry
357
+
358
+ projected_kv_list = []
359
+ source_kv_list = []
360
+ for source_model_layer_idx, projector_idx in pair_list:
361
+ source_key_cache, source_value_cache = self.kv_cache_dict[self.base_model_idx][source_model_idx][source_model_layer_idx]
362
+ new_source_key_cache = source_key_cache[:, :, start:end, :]
363
+ new_source_value_cache = source_value_cache[:, :, start:end, :]
364
+ new_source_kv_cache = (new_source_key_cache, new_source_value_cache)
365
+ projected_key, projected_value = self.projector_list[projector_idx].forward(
366
+ new_source_kv_cache, # tuple of (key, value), each of shape (B, N, H, D)
367
+ new_base_kv_cache
368
+ )
369
+ projected_kv_list.append((projected_key, projected_value))
370
+ source_kv_list.append(new_source_kv_cache)
371
+
372
+ # Aggregate (fallback to first projector if no aggregator is available)
373
+ use_aggregator = (
374
+ len(projected_kv_list) > 1 and
375
+ len(self.aggregator_list) > 0 and
376
+ self.base_model_idx in self.aggregator_dict and
377
+ source_model_idx in self.aggregator_dict[self.base_model_idx] and
378
+ target_layer_idx in self.aggregator_dict[self.base_model_idx][source_model_idx]
379
+ )
380
+
381
+ if use_aggregator:
382
+ aggregator_idx = self.aggregator_dict[self.base_model_idx][source_model_idx][target_layer_idx]
383
+ agg_key, agg_value = self.aggregator_list[aggregator_idx].forward(
384
+ source_kv_list,
385
+ new_base_kv_cache,
386
+ projected_kv_list
387
+ )
388
+ else:
389
+ # Fallback to first projector result when no aggregator is available
390
+ agg_key, agg_value = projected_kv_list[0]
391
+
392
+ # Update cache with aggregated result
393
+ curr_base_kv_cache.key_cache[target_layer_idx][:, :, start:end, :] = agg_key
394
+ curr_base_kv_cache.value_cache[target_layer_idx][:, :, start:end, :] = agg_value
395
+
396
+ output.past_key_values = curr_base_kv_cache
397
+
398
+ # use base model for decode phase
399
+ else:
400
+ # Handle list input format for decode phase as well
401
+ decode_input_ids = input_ids[self.base_model_idx] if isinstance(input_ids, list) else input_ids
402
+ decode_attention_mask = attention_mask[self.base_model_idx] if isinstance(attention_mask, list) else attention_mask
403
+
404
+ output = self.model_list[self.base_model_idx].forward(
405
+ input_ids=decode_input_ids,
406
+ attention_mask=decode_attention_mask,
407
+ position_ids=position_ids,
408
+ past_key_values=curr_base_kv_cache,
409
+ inputs_embeds=inputs_embeds,
410
+ labels=labels,
411
+ use_cache=use_cache,
412
+ output_attentions=output_attentions,
413
+ output_hidden_states=output_hidden_states,
414
+ cache_position=cache_position,
415
+ *args,
416
+ **kwargs
417
+ )
418
+ return output
419
+
420
+ def oracle_forward(
421
+ self,
422
+ kv_cache_index: Optional[List] = None,
423
+ input_ids: Optional[Union[torch.LongTensor, List[torch.LongTensor]]] = None,
424
+ attention_mask: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
425
+ position_ids: Optional[torch.LongTensor] = None,
426
+ past_key_values: Optional[Cache] = None,
427
+ inputs_embeds: Optional[torch.FloatTensor] = None,
428
+ labels: Optional[torch.LongTensor] = None,
429
+ use_cache: Optional[bool] = None,
430
+ output_attentions: Optional[bool] = None,
431
+ output_hidden_states: Optional[bool] = None,
432
+ cache_position: Optional[torch.LongTensor] = None,
433
+ logits_to_keep: Union[int, torch.Tensor] = 0,
434
+ # **kwargs: Unpack[KwargsForCausalLM],
435
+ *args,
436
+ **kwargs,
437
+ ) -> CausalLMOutputWithPast:
438
+ """
439
+ Forward pass
440
+ KVCache index is a list of tensors with shape (B, sec_seq_len, 2), indicating the source and target kv cache model index
441
+
442
+ If input_ids is a LongTensor, the same input ids are shared across all models.
443
+ If input_ids is a list, each model receives its own input ids.
444
+
445
+ No Rosetta: (-1, 0)
446
+ """
447
+
448
+ # noqa
449
+ self.kv_cache_dict = dict()
450
+
451
+ # Handle different input formats: if input_ids is a list, use per-model inputs
452
+ if isinstance(input_ids, list):
453
+ # Use list format: different input_ids and attention_mask for each model
454
+ base_input_ids = input_ids[self.base_model_idx] if input_ids is not None else None
455
+ base_attention_mask = attention_mask[self.base_model_idx] if attention_mask is not None else None
456
+ _, seqlen = base_input_ids.size() if base_input_ids is not None else (0, 0)
457
+ else:
458
+ # Use tensor format: same input_ids and attention_mask for all models (backward compatibility)
459
+ base_input_ids = input_ids
460
+ base_attention_mask = attention_mask
461
+ _, seqlen = input_ids.size() if input_ids is not None else (0, 0)
462
+
463
+ num_sections = len(kv_cache_index) if kv_cache_index is not None else 1
464
+
465
+ section_lengths = [kv_cache_index[i].shape[1] for i in range(num_sections)] if kv_cache_index is not None else [seqlen]
466
+ section_starts = [0]
467
+ for l in section_lengths:
468
+ section_starts.append(section_starts[-1] + l)
469
+
470
+ curr_base_kv_cache = past_key_values
471
+
472
+ loss = nn.MSELoss()
473
+ loss_output = 0
474
+ if seqlen > 1:
475
+ for i in range(num_sections):
476
+ start = section_starts[i]
477
+ end = section_starts[i + 1]
478
+ prefill_input_ids = base_input_ids[:, start:end] if base_input_ids is not None else None
479
+ prefill_attention_mask = base_attention_mask[:, :end] if base_attention_mask is not None else None
480
+ prefill_position_ids = position_ids[:, start:end] if position_ids is not None else None
481
+ prefill_labels = labels[:, start:end] if labels is not None else None
482
+
483
+ # calculate target model kvcache
484
+ output = self.model_list[self.base_model_idx].forward(
485
+ input_ids=prefill_input_ids,
486
+ attention_mask=prefill_attention_mask,
487
+ position_ids=prefill_position_ids,
488
+ past_key_values=curr_base_kv_cache,
489
+ labels=prefill_labels,
490
+ use_cache=use_cache,
491
+ output_attentions=output_attentions,
492
+ output_hidden_states=output_hidden_states,
493
+ *args,
494
+ **kwargs
495
+ )
496
+
497
+ if self.base_model_idx not in self.kv_cache_dict:
498
+ self.kv_cache_dict[self.base_model_idx] = {}
499
+ if self.base_model_idx not in self.kv_cache_dict[self.base_model_idx]:
500
+ self.kv_cache_dict[self.base_model_idx][self.base_model_idx] = None
501
+ self.kv_cache_dict[self.base_model_idx][self.base_model_idx] = output.past_key_values
502
+
503
+ curr_base_kv_cache: DynamicCache = output.past_key_values
504
+
505
+ if i != num_sections - 1:
506
+ for source_model_idx in range(1, len(self.model_list)):
507
+ if self.base_model_idx not in self.kv_cache_dict:
508
+ self.kv_cache_dict[self.base_model_idx] = {}
509
+ if source_model_idx not in self.kv_cache_dict[self.base_model_idx]:
510
+ self.kv_cache_dict[self.base_model_idx][source_model_idx] = None
511
+
512
+ # Get model-specific input_ids and attention_mask
513
+ if isinstance(input_ids, list):
514
+ source_input_ids = input_ids[source_model_idx]
515
+ source_attention_mask = attention_mask[source_model_idx] if attention_mask is not None else None
516
+ source_prefill_input_ids = source_input_ids[:, start:end] if source_input_ids is not None else None
517
+ source_prefill_attention_mask = source_attention_mask[:, :end] if source_attention_mask is not None else None
518
+ else:
519
+ # Backward compatibility: use same input for all models
520
+ source_prefill_input_ids = prefill_input_ids
521
+ source_prefill_attention_mask = prefill_attention_mask
522
+
523
+ curr_source_kv_cache = self.model_list[source_model_idx].forward(
524
+ input_ids=source_prefill_input_ids,
525
+ attention_mask=source_prefill_attention_mask,
526
+ position_ids=prefill_position_ids,
527
+ past_key_values=self.kv_cache_dict[self.base_model_idx][source_model_idx],
528
+ use_cache=use_cache,
529
+ output_attentions=output_attentions,
530
+ output_hidden_states=output_hidden_states,
531
+ *args,
532
+ **kwargs
533
+ ).past_key_values
534
+ self.kv_cache_dict[self.base_model_idx][source_model_idx] = curr_source_kv_cache
535
+
536
+ # calculate source model kvcache and apply projections
537
+ if self.base_model_idx in self.projector_dict:
538
+ source_model_idx = kv_cache_index[i][0][0][0].item() # Get the source model index from the kv_cache_index
539
+ if source_model_idx != -1:
540
+ for target_layer_idx, entry in self.projector_dict[self.base_model_idx][source_model_idx].items():
541
+ base_key_cache, base_value_cache = curr_base_kv_cache[target_layer_idx]
542
+ new_base_key_cache = base_key_cache[:, :, start:end, :]
543
+ new_base_value_cache = base_value_cache[:, :, start:end, :]
544
+ new_base_kv_cache = (new_base_key_cache, new_base_value_cache)
545
+
546
+ pair_list = entry
547
+
548
+ projected_kv_list = []
549
+ source_kv_list = []
550
+ for source_model_layer_idx, projector_idx in pair_list:
551
+ source_key_cache, source_value_cache = self.kv_cache_dict[self.base_model_idx][source_model_idx][source_model_layer_idx]
552
+ new_source_key_cache = source_key_cache[:, :, start:end, :]
553
+ new_source_value_cache = source_value_cache[:, :, start:end, :]
554
+ new_source_kv_cache = (new_source_key_cache, new_source_value_cache)
555
+ projected_key, projected_value = self.projector_list[projector_idx].forward(
556
+ new_source_kv_cache, # tuple of (key, value), each of shape (B, N, H, D)
557
+ new_base_kv_cache
558
+ )
559
+ loss_output = loss_output + loss(torch.dstack([projected_key, projected_value]),
560
+ torch.dstack([new_source_key_cache, new_source_value_cache]))
561
+ projected_kv_list.append((projected_key, projected_value))
562
+ source_kv_list.append(new_source_kv_cache)
563
+
564
+ # Aggregate (fallback to first projector if no aggregator is available)
565
+ use_aggregator = (
566
+ len(projected_kv_list) > 1 and
567
+ len(self.aggregator_list) > 0 and
568
+ self.base_model_idx in self.aggregator_dict and
569
+ source_model_idx in self.aggregator_dict[self.base_model_idx] and
570
+ target_layer_idx in self.aggregator_dict[self.base_model_idx][source_model_idx]
571
+ )
572
+
573
+ if use_aggregator:
574
+ aggregator_idx = self.aggregator_dict[self.base_model_idx][source_model_idx][target_layer_idx]
575
+ agg_key, agg_value = self.aggregator_list[aggregator_idx].forward(
576
+ source_kv_list,
577
+ new_base_kv_cache,
578
+ projected_kv_list
579
+ )
580
+ else:
581
+ # Fallback to first projector result when no aggregator is available
582
+ agg_key, agg_value = projected_kv_list[0]
583
+
584
+ # Update cache with aggregated result
585
+ curr_base_kv_cache.key_cache[target_layer_idx][:, :, start:end, :] = agg_key
586
+ curr_base_kv_cache.value_cache[target_layer_idx][:, :, start:end, :] = agg_value
587
+
588
+ output.past_key_values = curr_base_kv_cache
589
+
590
+ # use base model for decode phase
591
+ else:
592
+ # Handle list input format for decode phase as well
593
+ decode_input_ids = input_ids[self.base_model_idx] if isinstance(input_ids, list) else input_ids
594
+ decode_attention_mask = attention_mask[self.base_model_idx] if isinstance(attention_mask, list) else attention_mask
595
+
596
+ output = self.model_list[self.base_model_idx].forward(
597
+ input_ids=decode_input_ids,
598
+ attention_mask=decode_attention_mask,
599
+ position_ids=position_ids,
600
+ past_key_values=curr_base_kv_cache,
601
+ inputs_embeds=inputs_embeds,
602
+ labels=labels,
603
+ use_cache=use_cache,
604
+ output_attentions=output_attentions,
605
+ output_hidden_states=output_hidden_states,
606
+ cache_position=cache_position,
607
+ *args,
608
+ **kwargs
609
+ )
610
+
611
+ return output, loss_output
612
+
613
+ @torch.no_grad()
614
+ def generate(
615
+ self,
616
+ kv_cache_index,
617
+ input_ids,
618
+ max_new_tokens: Optional[int] = None,
619
+ past_key_values: Optional[Cache] = None,
620
+ attention_mask: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
621
+ position_ids: Optional[torch.LongTensor] = None,
622
+ eos_token_id: Optional[Union[int, List[int]]] = None,
623
+ pad_token_id: Optional[int] = None,
624
+ temperature: float = 1.0,
625
+ top_p: float = 1.0,
626
+ top_k: int = -1,
627
+ repetition_penalty: float = 1.0,
628
+ presence_penalty: float = 0.0,
629
+ frequency_penalty: float = 0.0,
630
+ do_sample: Optional[bool] = None,
631
+ return_dict_in_generate: Optional[bool] = None,
632
+ output_scores: Optional[bool] = None,
633
+ max_length: Optional[int] = None,
634
+ use_cache: bool = True,
635
+ streamer = None,
636
+ *args,
637
+ **kwargs,
638
+ ):
639
+ """
640
+ New generation loop without using the base model's generate.
641
+ - Uses this module's forward for prefill and per-token decode.
642
+ - Samples tokens via rosetta.model.sampling.sample_token.
643
+ Returns a tensor of shape [batch, prompt_len + generated_len] for the base model stream.
644
+ """
645
+ # Derive number of tokens to generate
646
+ # If max_new_tokens not provided, infer from max_length
647
+ if isinstance(input_ids, list):
648
+ base_input_ids_for_len = input_ids[self.base_model_idx]
649
+ else:
650
+ base_input_ids_for_len = input_ids
651
+ prompt_len = base_input_ids_for_len.size(1)
652
+
653
+ # Default eos/pad from base model tokenizer/config if not provided
654
+ base_model = self.model_list[self.base_model_idx]
655
+ gen_cfg = getattr(base_model, "generation_config", None)
656
+ cfg_obj = gen_cfg if gen_cfg is not None else getattr(base_model, "config", None)
657
+ if eos_token_id is None and cfg_obj is not None:
658
+ eos_token_id = getattr(cfg_obj, "eos_token_id", None)
659
+ if pad_token_id is None and cfg_obj is not None:
660
+ pad_token_id = getattr(cfg_obj, "pad_token_id", None)
661
+ if pad_token_id is None and eos_token_id is not None:
662
+ pad_token_id = eos_token_id if isinstance(eos_token_id, int) else eos_token_id[0]
663
+
664
+ if max_new_tokens is None:
665
+ if max_length is not None:
666
+ if max_length <= prompt_len:
667
+ max_new_tokens = 0
668
+ else:
669
+ max_new_tokens = max_length - prompt_len
670
+ else:
671
+ raise ValueError("Provide max_new_tokens or max_length")
672
+ if max_new_tokens < 0:
673
+ raise ValueError("max_new_tokens must be non-negative")
674
+
675
+ # Resolve base inputs
676
+ if isinstance(input_ids, list):
677
+ base_input_ids = input_ids[self.base_model_idx]
678
+ base_attention_mask = attention_mask[self.base_model_idx] if attention_mask is not None else None
679
+ else:
680
+ base_input_ids = input_ids
681
+ base_attention_mask = attention_mask
682
+
683
+ if base_attention_mask is None:
684
+ base_attention_mask = torch.ones_like(base_input_ids, dtype=torch.long, device=base_input_ids.device)
685
+
686
+ batch_size = base_input_ids.size(0)
687
+
688
+ # Prefill to build caches and obtain initial logits
689
+ prefill_output = self.forward(
690
+ kv_cache_index=kv_cache_index,
691
+ input_ids=input_ids,
692
+ attention_mask=attention_mask,
693
+ position_ids=position_ids,
694
+ past_key_values=past_key_values,
695
+ use_cache=use_cache,
696
+ *args,
697
+ **kwargs,
698
+ )
699
+
700
+ current_past = prefill_output.past_key_values
701
+ all_input_ids = base_input_ids
702
+ current_attention_mask = base_attention_mask
703
+
704
+ # Initialize streamer with prompt if provided
705
+ if streamer is not None:
706
+ streamer.put(base_input_ids)
707
+
708
+ # EOS handling setup
709
+ eos_set = None
710
+ if eos_token_id is not None:
711
+ eos_set = set(eos_token_id if isinstance(eos_token_id, list) else [eos_token_id])
712
+ finished = torch.zeros(batch_size, dtype=torch.bool, device=all_input_ids.device)
713
+
714
+ # Start from last prefill logits
715
+ last_logits = prefill_output.logits[:, -1, :]
716
+
717
+ # Determine sampling mode
718
+ if do_sample is None:
719
+ do_sample = False
720
+ effective_temperature = temperature if do_sample else 0.0
721
+
722
+ # Optional scores collection
723
+ collect_scores = bool(return_dict_in_generate) and bool(output_scores)
724
+ scores = []
725
+
726
+ for _ in range(max_new_tokens):
727
+ if collect_scores:
728
+ scores.append(last_logits)
729
+ # Apply repetition/presence/frequency penalties to logits before sampling
730
+ adjusted_logits = last_logits
731
+ if (
732
+ (repetition_penalty is not None and repetition_penalty != 1.0) or
733
+ (presence_penalty is not None and presence_penalty != 0.0) or
734
+ (frequency_penalty is not None and frequency_penalty != 0.0)
735
+ ):
736
+ adjusted_logits = last_logits.clone()
737
+ vocab_size = adjusted_logits.size(-1)
738
+ # Per-batch penalty application for clarity and correctness
739
+ for b in range(batch_size):
740
+ seq_tokens = all_input_ids[b]
741
+ if seq_tokens.numel() == 0:
742
+ continue
743
+ counts = torch.bincount(seq_tokens, minlength=vocab_size)
744
+ if counts.dtype != torch.float32 and counts.dtype != torch.float64:
745
+ counts = counts.to(adjusted_logits.dtype)
746
+ # Presence penalty: penalize any token that has appeared
747
+ if presence_penalty and presence_penalty != 0.0:
748
+ presence_mask = counts > 0
749
+ if presence_mask.any():
750
+ adjusted_logits[b, presence_mask] = adjusted_logits[b, presence_mask] - presence_penalty
751
+ # Frequency penalty: penalize proportionally to frequency
752
+ if frequency_penalty and frequency_penalty != 0.0:
753
+ adjusted_logits[b] = adjusted_logits[b] - frequency_penalty * counts
754
+ # Repetition penalty (HF-style): divide positive logits, multiply negative logits
755
+ if repetition_penalty and repetition_penalty != 1.0:
756
+ rep_mask = counts > 0
757
+ if rep_mask.any():
758
+ pos_mask = rep_mask & (adjusted_logits[b] > 0)
759
+ neg_mask = rep_mask & ~pos_mask
760
+ if pos_mask.any():
761
+ adjusted_logits[b, pos_mask] = adjusted_logits[b, pos_mask] / repetition_penalty
762
+ if neg_mask.any():
763
+ adjusted_logits[b, neg_mask] = adjusted_logits[b, neg_mask] * repetition_penalty
764
+
765
+ # Sample next token
766
+ next_token = sample_token(adjusted_logits, temperature=effective_temperature, top_p=top_p, top_k=top_k)
767
+ if not isinstance(next_token, torch.Tensor):
768
+ next_token = torch.tensor([next_token], device=all_input_ids.device, dtype=torch.long).repeat(batch_size)
769
+
770
+ # Apply EOS logic
771
+ if eos_set is not None:
772
+ just_finished = torch.zeros_like(finished)
773
+ for eid in eos_set:
774
+ just_finished |= (next_token == eid)
775
+ finished = finished | just_finished
776
+ if pad_token_id is not None:
777
+ next_token = torch.where(
778
+ finished,
779
+ torch.tensor(pad_token_id, device=next_token.device, dtype=next_token.dtype),
780
+ next_token,
781
+ )
782
+
783
+ # Append sampled token
784
+ next_token_unsqueezed = next_token.unsqueeze(1)
785
+ all_input_ids = torch.cat([all_input_ids, next_token_unsqueezed], dim=1)
786
+ current_attention_mask = torch.cat(
787
+ [
788
+ current_attention_mask,
789
+ torch.ones((batch_size, 1), device=current_attention_mask.device, dtype=current_attention_mask.dtype),
790
+ ],
791
+ dim=1,
792
+ )
793
+
794
+ # Stream the new token if streamer provided
795
+ if streamer is not None:
796
+ streamer.put(next_token_unsqueezed)
797
+
798
+ # Early stop if all sequences finished
799
+ if eos_set is not None and torch.all(finished):
800
+ break
801
+
802
+ # Decode one step using cached states; pass base-stream tensors
803
+ kv_cache_index = [torch.tensor([-1, 0], dtype=torch.long).repeat(1, 1).unsqueeze(0).to(all_input_ids.device)]
804
+
805
+ decode_output = self.forward(
806
+ kv_cache_index=kv_cache_index,
807
+ input_ids=next_token_unsqueezed,
808
+ attention_mask=current_attention_mask,
809
+ position_ids=None,
810
+ past_key_values=current_past,
811
+ use_cache=True,
812
+ *args,
813
+ **kwargs,
814
+ )
815
+ current_past = decode_output.past_key_values
816
+ last_logits = decode_output.logits[:, -1, :]
817
+
818
+ # End streaming if streamer provided
819
+ if streamer is not None:
820
+ streamer.end()
821
+
822
+ # Return style compatible with HF generate
823
+ if return_dict_in_generate:
824
+ if GreedySearchDecoderOnlyOutput is not None and SampleDecoderOnlyOutput is not None:
825
+ if do_sample:
826
+ return SampleDecoderOnlyOutput(
827
+ sequences=all_input_ids,
828
+ scores=scores if collect_scores else None,
829
+ )
830
+ else:
831
+ return GreedySearchDecoderOnlyOutput(
832
+ sequences=all_input_ids,
833
+ scores=scores if collect_scores else None,
834
+ )
835
+ # Fallback to generic ModelOutput
836
+ result = {"sequences": all_input_ids}
837
+ if collect_scores:
838
+ result["scores"] = scores
839
+ return ModelOutput(**result)
840
+ return all_input_ids
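For orientation, a hypothetical end-to-end sketch of wiring two causal LMs through a projector and generating; the checkpoints, projector hyperparameters, and layer indices below are placeholders, not values prescribed by this commit:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from rosetta.model.wrapper import RosettaModel
from rosetta.model.projector import create_projector

target = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")   # base model (index 0)
source = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-1.7B")   # context donor (index 1)

def kv_dims(model):
    # Derive KV head count and head dim from the config (the fallback is an assumption).
    cfg = model.config
    head_dim = getattr(cfg, "head_dim", cfg.hidden_size // cfg.num_attention_heads)
    return cfg.num_key_value_heads, head_dim

src_heads, src_dim = kv_dims(source)
tgt_heads, tgt_dim = kv_dims(target)
proj = create_projector("C2CProjector", source_dim=src_dim, target_dim=tgt_dim,
                        source_num_heads=src_heads, target_num_heads=tgt_heads)

model = RosettaModel([target, source], base_model_idx=0, projector_list=[proj])
model.set_projector_config(source_model_idx=1, source_model_layer_idx=0,
                           target_model_idx=0, target_model_layer_idx=0, projector_idx=0)
model.eval()

tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
ctx_ids = tok("Shared context goes here.", return_tensors="pt").input_ids
q_ids = tok("Question: what does the context say?", return_tensors="pt").input_ids
input_ids = torch.cat([ctx_ids, q_ids], dim=1)

# One (source_model_idx, target_model_idx) pair per section: project model 1's cache for the
# context section, keep plain base-model behaviour (-1, 0) for the final question section.
kv_cache_index = [
    torch.tensor([1, 0]).expand(1, ctx_ids.size(1), 2),
    torch.tensor([-1, 0]).expand(1, q_ids.size(1), 2),
]
out_ids = model.generate(kv_cache_index, input_ids, max_new_tokens=16)
print(tok.decode(out_ids[0], skip_special_tokens=True))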
rosetta/train/__init__.py ADDED
@@ -0,0 +1,20 @@
1
+ """
2
+ Training utilities for RosettaModel
3
+ """
4
+
5
+ from .dataset_adapters import (
6
+ ChatDataset,
7
+ RosettaDataCollator,
8
+ )
9
+ from .model_utils import setup_models
10
+
11
+ __all__ = [
12
+ "RosettaTrainer",
13
+ "ProjectorSaveCallback",
14
+ "freeze_model_components",
15
+ "InstructCoderChatDataset",
16
+ "ChatDataset",
17
+ "RosettaDataCollator",
18
+ "create_instructcoder_dataset",
19
+ "setup_models"
20
+ ]
rosetta/train/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (505 Bytes). View file
 
rosetta/train/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (616 Bytes). View file
 
rosetta/train/__pycache__/dataset_adapters.cpython-310.pyc ADDED
Binary file (50 kB). View file
 
rosetta/train/__pycache__/dataset_adapters.cpython-312.pyc ADDED
Binary file (73.3 kB). View file
 
rosetta/train/__pycache__/model_utils.cpython-310.pyc ADDED
Binary file (3.99 kB). View file
 
rosetta/train/__pycache__/model_utils.cpython-312.pyc ADDED
Binary file (5.58 kB). View file
 
rosetta/train/__pycache__/trainer.cpython-310.pyc ADDED
Binary file (2.56 kB). View file
 
rosetta/train/__pycache__/trainer.cpython-312.pyc ADDED
Binary file (3.74 kB). View file
 
rosetta/train/dataset_adapters.py ADDED
@@ -0,0 +1,1648 @@
1
+ """
2
+ Simple dataset adapter for converting InstructCoder to chat format
3
+ """
4
+
5
+ from typing import List, Dict, Any, Optional, Union, Callable
6
+ from datasets import load_dataset, load_from_disk
7
+ from torch.utils.data import Dataset
8
+ import torch
9
+ from transformers import AutoTokenizer
10
+ import inspect
11
+ import os
12
+ import hashlib
13
+ # Dataset Registry System
14
+ DATASET_REGISTRY = {}
15
+
16
+ def register_dataset(cls=None, name=None):
17
+ """
18
+ Register a dataset class in the global registry.
19
+ Can be used as a decorator with or without arguments.
20
+
21
+ Args:
22
+ cls: The class to register
23
+ name: Optional name to register the class under. If None, uses the class name.
24
+
25
+ Returns:
26
+ The registered class
27
+ """
28
+ def _register(cls):
29
+ dataset_name = name if name is not None else cls.__name__
30
+ DATASET_REGISTRY[dataset_name] = cls
31
+ # Also register with lowercase name for case-insensitive lookup
32
+ DATASET_REGISTRY[dataset_name.lower()] = cls
33
+ return cls
34
+
35
+ # Called as @register_dataset
36
+ if cls is not None:
37
+ return _register(cls)
38
+
39
+ # Called as @register_dataset() or @register_dataset(name="DatasetName")
40
+ return _register
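As a small illustration of the decorator forms supported above (ToyChatDataset is a stand-in class, not one defined in this file):

from torch.utils.data import Dataset

@register_dataset(name="ToyChat")
class ToyChatDataset(Dataset):
    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

assert DATASET_REGISTRY["ToyChat"] is ToyChatDataset
assert DATASET_REGISTRY["toychat"] is ToyChatDataset   # lowercase key is registered too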
41
+
42
+
43
+ def capture_init_args(cls):
44
+ """
45
+ Decorator to capture initialization arguments of a dataset class.
46
+
47
+ Args:
48
+ cls: The class to decorate
49
+
50
+ Returns:
51
+ The decorated class with automatic init args capture
52
+ """
53
+ original_init = cls.__init__
54
+
55
+ def new_init(self, *args, **kwargs):
56
+ # Store all initialization arguments
57
+ self._init_args = {}
58
+
59
+ # Get parameter names from the original __init__ method
60
+ sig = inspect.signature(original_init)
61
+ param_names = list(sig.parameters.keys())[1:] # Skip 'self'
62
+
63
+ # Map positional args to parameter names
64
+ for i, arg in enumerate(args):
65
+ if i < len(param_names):
66
+ self._init_args[param_names[i]] = arg
67
+
68
+ # Add keyword args
69
+ self._init_args.update(kwargs)
70
+
71
+ # Call the original __init__
72
+ original_init(self, *args, **kwargs)
73
+
74
+ cls.__init__ = new_init
75
+ return cls
76
+
77
+
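+ # Illustrative sketch (not in the original source): after decoration, the constructor
+ # arguments are recoverable from `_init_args`. `Demo` is a hypothetical class.
+ #
+ #   @capture_init_args
+ #   class Demo:
+ #       def __init__(self, split, num_samples=None):
+ #           pass
+ #
+ #   d = Demo("train", num_samples=8)
+ #   d._init_args    # {'split': 'train', 'num_samples': 8}
+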
78
+ # Unified batch filtering functions
79
+
80
+
81
+ def create_text_length_filter(
82
+ max_length: int,
83
+ text_extractor: Callable[[Dict[str, Any]], str],
84
+ tokenizer: Optional[Any] = None,
85
+ use_tokens: bool = False
86
+ ):
87
+ """
88
+ Unified text length filter that can handle both word count and token count filtering.
89
+
90
+ Args:
91
+ max_length: Maximum allowed length (words or tokens)
92
+ text_extractor: Function that extracts text from a single sample
93
+ tokenizer: Tokenizer for token counting (required if use_tokens=True)
94
+ use_tokens: If True, count tokens; if False, count words
95
+
96
+ Returns:
97
+ Filter function that can be used with dataset.filter(batched=True)
98
+ """
99
+ if use_tokens and tokenizer is None:
100
+ raise ValueError("Tokenizer must be provided when use_tokens=True")
101
+
102
+ def _text_length_filter_batch(batch):
103
+ batch_size = len(next(iter(batch.values())))
104
+ samples = [{key: values[i] for key, values in batch.items()} for i in range(batch_size)]
105
+ try:
106
+ texts = [text_extractor(sample) for sample in samples]
107
+ if use_tokens:
108
+ if hasattr(tokenizer, 'apply_chat_template') and any(isinstance(t, list) for t in texts):
109
+ rendered = []
110
+ for t in texts:
111
+ if isinstance(t, list):
112
+ rendered.append(tokenizer.apply_chat_template(t, tokenize=False, add_generation_prompt=False))
113
+ else:
114
+ rendered.append(str(t))
115
+ tokenized = tokenizer(rendered, add_special_tokens=False)
116
+ else:
117
+ tokenized = tokenizer([str(t) for t in texts], add_special_tokens=False)
118
+ lengths = [len(ids) for ids in tokenized["input_ids"]]
119
+ else:
120
+ lengths = [len(str(t).split()) for t in texts]
121
+ return [length <= max_length for length in lengths]
122
+ except Exception as e:
123
+ print(f"Error in text length filter: {e}")
124
+ return [False] * batch_size
125
+
126
+ return _text_length_filter_batch
127
+
128
+
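+ # Illustrative usage sketch (assumed field name "question", not from the original file):
+ # the returned function is meant to be passed to `datasets.Dataset.filter` with batched=True.
+ #
+ #   word_filter = create_text_length_filter(
+ #       max_length=512,
+ #       text_extractor=lambda s: s.get("question", ""),
+ #   )
+ #   dataset = dataset.filter(word_filter, batched=True, batch_size=1024)
+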
129
+ def create_field_value_filter(target_value: Any, field_name: str, comparison: str = 'equal'):
130
+ """
131
+ Unified field value filter for exact matching, language filtering, etc.
132
+
133
+ Args:
134
+ target_value: Value to compare against
135
+ field_name: Field name to check
136
+ comparison: Type of comparison ('equal', 'not_equal', 'in', 'not_in')
137
+
138
+ Returns:
139
+ Filter function that can be used with dataset.filter(batched=True)
140
+ """
141
+ def _field_value_filter_batch(batch):
142
+ field_values = batch.get(field_name, [])
143
+
144
+ if comparison == 'equal':
145
+ return [value == target_value for value in field_values]
146
+ elif comparison == 'not_equal':
147
+ return [value != target_value for value in field_values]
148
+ elif comparison == 'in':
149
+ return [value in target_value for value in field_values]
150
+ elif comparison == 'not_in':
151
+ return [value not in target_value for value in field_values]
152
+ else:
153
+ raise ValueError(f"Unsupported comparison: {comparison}")
154
+
155
+ return _field_value_filter_batch
156
+
157
+
158
+ def create_modulo_filter(mod_base: int, exclude_values: Union[int, List[int]], field_name: str = '_id'):
159
+ """
160
+ Unified modulo filter for ID-based filtering.
161
+
162
+ Args:
163
+ mod_base: Modulo base
164
+ exclude_values: Value(s) to exclude (can be single int or list)
165
+ field_name: Field name containing the ID
166
+
167
+ Returns:
168
+ Filter function that can be used with dataset.filter(batched=True)
169
+ """
170
+ if isinstance(exclude_values, int):
171
+ exclude_values = [exclude_values]
172
+
173
+ def _modulo_filter_batch(batch):
174
+ ids = batch.get(field_name, [])
175
+ results = []
176
+
177
+ for _id in ids:
178
+ try:
179
+ # Try numeric conversion first
180
+ id_num = int(_id)
181
+ mod_result = id_num % mod_base
182
+ except (ValueError, TypeError):
183
+ # Use hash for non-numeric IDs
184
+ id_hash = hash(str(_id))
185
+ mod_result = id_hash % mod_base
186
+
187
+ results.append(mod_result not in exclude_values)
188
+
189
+ return results
190
+
191
+ return _modulo_filter_batch
192
+
193
+
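+ # Illustrative behaviour (not from the original file): numeric ids use int(_id) % mod_base,
+ # non-numeric ids fall back to hash(str(_id)) % mod_base.
+ #
+ #   keep = create_modulo_filter(4, exclude_values=[1])
+ #   keep({"_id": ["0", "1", "2", "3"]})   # -> [True, False, True, True]
+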
194
+ def create_conversation_length_filter(min_messages: int, text_field: str = 'conversations'):
195
+ """
196
+ Unified conversation length filter for OpenHermes-style datasets.
197
+
198
+ Args:
199
+ min_messages: Keep samples containing strictly more than this many non-system messages
200
+ text_field: Field name containing the conversation
201
+
202
+ Returns:
203
+ Filter function that can be used with dataset.filter(batched=True)
204
+ """
205
+ def _conversation_length_filter_batch(batch):
206
+ conversations_list = batch.get(text_field, [])
207
+ results = []
208
+
209
+ for conversations in conversations_list:
210
+ try:
211
+ # Extract messages (excluding system)
212
+ message_count = 0
213
+ for msg in conversations:
214
+ role = msg.get('from') or msg.get('role')
215
+ if role in ('human', 'user', 'gpt', 'assistant'):
216
+ message_count += 1
217
+
218
+ results.append(message_count > min_messages)
219
+ except Exception:
220
+ results.append(False)
221
+
222
+ return results
223
+
224
+ return _conversation_length_filter_batch
225
+
226
+
227
+ # Text extraction functions for common dataset patterns
228
+ def extract_mmlu_text(sample: Dict[str, Any], question_field: str = 'question', choices_field: str = 'choices') -> str:
229
+ """Extract text from MMLU-style samples"""
230
+ question = sample.get(question_field, '')
231
+ choices = sample.get(choices_field, [])
232
+
233
+ # Handle both list and dict formats for choices
234
+ if isinstance(choices, dict):
235
+ choices_text = choices.get('text', [])
236
+ else:
237
+ choices_text = choices
238
+
239
+ return (str(question) + " " + " ".join(map(str, choices_text))).strip()
240
+
241
+
242
+ def extract_chat_text(sample: Dict[str, Any], input_field: str = 'input',
243
+ context_field: str = 'context', answers_field: str = 'answers') -> List[Dict[str, str]]:
244
+ """Extract chat messages from LongBench-style samples"""
245
+ input_text = str(sample.get(input_field, ''))
246
+ context = str(sample.get(context_field, ''))
247
+ answers = sample.get(answers_field, [])
248
+
249
+ assistant_message = answers[0] if answers and len(answers) > 0 else "No answer provided"
250
+
251
+ # Build complete chat format
252
+ if context:
253
+ human_message = f"Context: {context}\n\nInstruction: {input_text}"
254
+ else:
255
+ human_message = f"Instruction: {input_text}"
256
+
257
+ return [
258
+ {"role": "user", "content": human_message.strip()},
259
+ {"role": "assistant", "content": assistant_message.strip()}
260
+ ]
261
+
262
+
263
+ def extract_conversation_text(sample: Dict[str, Any], text_field: str = 'conversations') -> str:
264
+ """Extract text from OpenHermes-style conversation samples"""
265
+ conversations = sample.get(text_field, [])
266
+
267
+ if conversations and len(conversations) > 0:
268
+ return conversations[0].get('value', '')
269
+ return ''
270
+
271
+
272
+ def extract_first_user_message(sample: Dict[str, Any], text_field: str = 'conversations') -> str:
273
+ """Extract the first human/user message from conversation-style samples."""
274
+ conversations = sample.get(text_field, [])
275
+ for msg in conversations:
276
+ role = msg.get('from') or msg.get('role')
277
+ if role in ('human', 'user'):
278
+ return str(msg.get('value', ''))
279
+ # Fallback to first message if role tags are missing
280
+ if conversations:
281
+ return str(conversations[0].get('value', ''))
282
+ return ''
283
+
284
+
285
+ def extract_first_assistant_message(sample: Dict[str, Any], text_field: str = 'conversations') -> str:
286
+ """Extract the first gpt/assistant message from conversation-style samples."""
287
+ conversations = sample.get(text_field, [])
288
+ for msg in conversations:
289
+ role = msg.get('from') or msg.get('role')
290
+ if role in ('gpt', 'assistant'):
291
+ return str(msg.get('value', ''))
292
+ # Fallback to second message if present
293
+ if len(conversations) > 1:
294
+ return str(conversations[1].get('value', ''))
295
+ return ''
296
+
297
+
298
+ def extract_openhermes_messages(sample: Dict[str, Any], text_field: str = 'conversations') -> List[Dict[str, str]]:
299
+ """Build chat messages excluding system; include all human/user and gpt/assistant in order."""
300
+ conversation = sample.get(text_field, [])
301
+ messages: List[Dict[str, str]] = []
302
+ for msg in conversation:
303
+ role = msg.get('from') or msg.get('role')
304
+ if role == 'system':
305
+ continue
306
+ if role in ('human', 'user'):
307
+ messages.append({"role": "user", "content": str(msg.get('value', '')).strip()})
308
+ elif role in ('gpt', 'assistant'):
309
+ messages.append({"role": "assistant", "content": str(msg.get('value', ''))})
310
+ return messages
311
+
312
+
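+ # Illustrative example (assumed input shape, not from the original file): an OpenHermes-style
+ # record maps to plain chat messages, with system turns dropped.
+ #
+ #   sample = {"conversations": [
+ #       {"from": "system", "value": "be brief"},
+ #       {"from": "human", "value": "Hi"},
+ #       {"from": "gpt", "value": "Hello!"},
+ #   ]}
+ #   extract_openhermes_messages(sample)
+ #   # -> [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]
+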
313
+ def extract_instruction_text(sample: Dict[str, Any], instruction_field: str = 'instruction',
314
+ inputs_field: str = 'inputs') -> str:
315
+ """Extract text from Inkuba-style instruction samples"""
316
+ instruction = sample.get(instruction_field)
317
+ inputs = sample.get(inputs_field, '')
318
+
319
+ if instruction is not None:
320
+ return str(instruction) + "\n\n" + str(inputs)
321
+ else:
322
+ return str(inputs)
323
+
324
+
325
+ def extract_chat_pair_text(sample: Dict[str, Any], user_field: str = 'inputs',
326
+ assistant_field: str = 'targets') -> List[Dict[str, str]]:
327
+ """Extract chat messages from Aya-style samples"""
328
+ user_text = str(sample.get(user_field, ''))
329
+ assistant_text = str(sample.get(assistant_field, ''))
330
+
331
+ return [
332
+ {"role": "user", "content": user_text.strip()},
333
+ {"role": "assistant", "content": assistant_text.strip()}
334
+ ]
335
+
336
+
337
+
338
+ def extract_dolly_chat_messages(sample: Dict[str, Any]) -> List[Dict[str, str]]:
339
+ """Extract chat messages from Dolly-style samples.
340
+
341
+ Fields:
342
+ - instruction: str
343
+ - context: str (may be empty)
344
+ - response: str
345
+ - category: optional, may be empty/missing
346
+ """
347
+ instruction = str(sample.get('instruction', '')).strip()
348
+ context = str(sample.get('context', '') or '').strip()
349
+ response = str(sample.get('response', '')).strip()
350
+
351
+ if context:
352
+ user_message = f"{context}\n\n{instruction}"
353
+ else:
354
+ user_message = f"{instruction}"
355
+
356
+ return [
357
+ {"role": "user", "content": user_message.strip()},
358
+ {"role": "assistant", "content": response}
359
+ ]
360
+
361
+
362
+ def extract_mmmlu_chat_messages(sample: Dict[str, Any]) -> List[Dict[str, str]]:
363
+ """Extract chat messages from MMMLU-style samples (OpenAI/MMMLU)."""
364
+ choice_labels = ['A', 'B', 'C', 'D']
365
+
366
+ template = (
367
+ "Jibu kwa usahihi swali lifuatalo:\n\n"
368
+ "{{question}}\n\n"
369
+ "Chaguo:\n"
370
+ "{{choices}}\n\n"
371
+ "Maelekezo:\n"
372
+ "- Soma swali na chaguo zote kwa makini.\n"
373
+ "- Chagua jibu sahihi zaidi kati ya yaliyotolewa.\n"
374
+ "- Jibu TU kwa herufi (A, B, C, D) inayolingana na jibu sahihi.\n"
375
+ "- Usijumuishe maelezo, maandishi ya ziada, au alama yoyote ya uakifishaji.\n\n"
376
+ "Jibu lako:"
377
+ )
378
+
379
+ choices_text = ""
380
+ for label in choice_labels:
381
+ content = sample.get(label, '')
382
+ choices_text += f"{label}. {content}\n"
383
+
384
+ user_prompt = template.replace("{{choices}}", choices_text).replace("{{question}}", str(sample.get('Question', '')))
385
+
386
+ correct_label = sample.get('Answer', '')
387
+ correct_content = sample.get(correct_label, '')
388
+ assistant_response = f"**Jibu lako: {correct_label}. {correct_content}.**"
389
+
390
+ return [
391
+ {"role": "user", "content": user_prompt.strip()},
392
+ {"role": "assistant", "content": assistant_response}
393
+ ]
394
+
395
+
396
+
397
+
398
+ def apply_batch_filters(dataset, filters: list, filter_descriptions: list = None,
399
+ batch_size: int = 4096, combine_filters: bool = True,
400
+ num_proc: Optional[int] = None):
401
+ """
402
+ Apply multiple filters using native batched filtering for maximum performance.
403
+
404
+ Args:
405
+ dataset: Dataset to filter
406
+ filters: List of batched filter functions
407
+ filter_descriptions: Optional list of descriptions for logging
408
+ batch_size: Batch size for filtering operations
409
+ combine_filters: If True, combine all filters into a single batched operation
+ num_proc: Optional number of worker processes for dataset.filter (values > 1 enable multiprocessing)
410
+
411
+ Returns:
412
+ Filtered dataset and original length
413
+ """
414
+ if not filters:
415
+ return dataset, len(dataset)
416
+
417
+ original_len = len(dataset)
418
+
419
+ if combine_filters and len(filters) > 1:
420
+ # Combine all filters into a single batched operation for maximum efficiency
421
+ def _combined_batch_filter(batch):
422
+ # Get results from all filters
423
+ filter_results = []
424
+ for filter_func in filters:
425
+ filter_results.append(filter_func(batch))
426
+
427
+ # Combine results with AND logic
428
+ combined_results = []
429
+ batch_size = len(filter_results[0]) if filter_results else 0
430
+
431
+ for i in range(batch_size):
432
+ combined_results.append(all(result[i] for result in filter_results))
433
+
434
+ return combined_results
435
+
436
+ # Apply combined filter in a single pass
437
+ filtered_dataset = dataset.filter(
438
+ _combined_batch_filter,
439
+ batched=True,
440
+ batch_size=batch_size,
441
+ num_proc=num_proc if num_proc and num_proc > 1 else None,
442
+ desc="Combined batch filtering"
443
+ )
444
+
445
+ # Print filtering results
446
+ final_len = len(filtered_dataset)
447
+ if original_len != final_len:
448
+ print(f"Applied combined batch filtering: {original_len} -> {final_len} samples")
449
+ if filter_descriptions:
450
+ for desc in filter_descriptions:
451
+ print(f" - {desc}")
452
+
453
+ else:
454
+ # Apply each filter sequentially with batched processing
455
+ current_dataset = dataset
456
+
457
+ for i, (filter_func, desc) in enumerate(zip(filters, filter_descriptions or [''] * len(filters))):
458
+ pre_filter_len = len(current_dataset)
459
+
460
+ current_dataset = current_dataset.filter(
461
+ filter_func,
462
+ batched=True,
463
+ batch_size=batch_size,
464
+ num_proc=num_proc if num_proc and num_proc > 1 else None,
465
+ desc=f"Filtering: {desc}" if desc else f"Filter {i+1}"
466
+ )
467
+
468
+ post_filter_len = len(current_dataset)
469
+ if desc and pre_filter_len != post_filter_len:
470
+ print(f" - {desc}: {pre_filter_len} -> {post_filter_len} samples")
471
+
472
+ filtered_dataset = current_dataset
473
+ final_len = len(filtered_dataset)
474
+ if original_len != final_len:
475
+ print(f"Applied sequential batch filtering: {original_len} -> {final_len} samples")
476
+
477
+ return filtered_dataset, original_len
478
+
479
+
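+ # Illustrative usage sketch (hypothetical field names, not from the original file): several
+ # batched filters can be applied in one combined pass over the dataset.
+ #
+ #   filters = [
+ #       create_field_value_filter("en", "language"),
+ #       create_modulo_filter(4, exclude_values=1, field_name="_id"),
+ #   ]
+ #   dataset, n_before = apply_batch_filters(dataset, filters,
+ #                                           ["language == en", "_id % 4 != 1"])
+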
480
+ def generate_kv_cache_index(instruction_length: int, full_length: int) -> torch.Tensor:
481
+ """
482
+ Generate KV cache index for the input sequence.
483
+
484
+ Args:
485
+ instruction_length: Length of the instruction tokens
486
+ full_length: Total length of the full conversation tokens
487
+
488
+ Returns:
489
+ Tensor with KV cache index
490
+ """
491
+ assert instruction_length <= full_length
492
+
493
+ instruction_index = torch.tensor([1, 0], dtype=torch.long).repeat(instruction_length - 1, 1)
494
+ label_index = torch.tensor([-1, 0], dtype=torch.long).repeat(full_length - instruction_length + 1, 1)
495
+
496
+ kv_cache_index = torch.cat([instruction_index, label_index], dim=0) # shape: (seq_len, 2)
497
+
498
+ return kv_cache_index
499
+
500
+
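+ # Worked example (illustrative, not from the original file): with instruction_length=3 and
+ # full_length=5, the first instruction_length - 1 positions get [1, 0] and the rest get [-1, 0]:
+ #
+ #   generate_kv_cache_index(3, 5)
+ #   # tensor([[ 1, 0], [ 1, 0], [-1, 0], [-1, 0], [-1, 0]])
+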
501
+ """
502
+ Instruction dataset
503
+
504
+ Convert inputs of any form to the standard chat message format
505
+ """
506
+
507
+ @register_dataset
508
+ @capture_init_args
509
+ class LongBenchChatDataset(Dataset):
510
+ """Convert the LongBench dataset into the original LongBench prompt format"""
511
+
512
+ def __init__(self, split: str = "test", num_samples: Optional[int] = None,
513
+ dataset_name: Optional[str] = None, language: Optional[str] = None,
514
+ max_word_count: Optional[int] = None, max_length: Optional[int] = 14000,
515
+ use_longbench_e: bool = True, filter_mod4: bool = True):
516
+ """
517
+ Initialize the LongBench dataset
518
+
519
+ Args:
520
+ split: Dataset split ("test" - LongBench mainly uses the test split)
521
+ num_samples: Number of samples to use (None for all)
522
+ dataset_name: Specific dataset name (None for all datasets)
523
+ language: Language filter ("en" or "zh")
524
+ max_word_count: Maximum word-count limit (for English text)
525
+ max_length: Maximum prompt length in tokens before head/tail truncation
526
+ use_longbench_e: Whether to use the LongBench-E variant
527
+ filter_mod4: Whether to drop samples whose _id hash mod 4 equals 1
528
+ """
529
+ print(f"Loading LongBench{' -E' if use_longbench_e else ''} dataset (split: {split}, dataset: {dataset_name})...")
530
+
531
+ # List of datasets included in LongBench
532
+ longbench_datasets = [
533
+ "narrativeqa", "qasper", "multifieldqa_en", "multifieldqa_zh", "hotpotqa",
534
+ "2wikimqa", "musique", "dureader", "gov_report", "qmsum", "multi_news",
535
+ "vcsum", "trec", "triviaqa", "samsum", "lsht", "passage_count",
536
+ "passage_retrieval_en", "passage_retrieval_zh", "lcc", "repobench-p"
537
+ ]
538
+
539
+ longbench_e_datasets = [
540
+ "qasper", "multifieldqa_en", "hotpotqa", "2wikimqa", "gov_report",
541
+ "multi_news", "trec", "triviaqa", "samsum", "passage_count",
542
+ "passage_retrieval_en", "lcc", "repobench-p"
543
+ ]
544
+
545
+ target_datasets = longbench_e_datasets if use_longbench_e else longbench_datasets
546
+
547
+ # Define the LongBench prompt templates
548
+ self.dataset_prompt_formats = {
549
+ "narrativeqa": "You are given a story, which can be either a novel or a movie script, and a question. Answer the question as concisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nStory: {context}\n\nNow, answer the question based on the story as concisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:",
550
+ "qasper": "You are given a scientific article and a question. Answer the question as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nArticle: {context}\n\n Answer the question based on the above article as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:",
551
+ "multifieldqa_en": "Read the following text and answer briefly.\n\n{context}\n\nNow, answer the following question based on the above text, only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
552
+ "multifieldqa_zh": "阅读以下文字并用中文简短回答:\n\n{context}\n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{input}\n回答:",
553
+ "hotpotqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
554
+ "2wikimqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
555
+ "musique": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
556
+ "dureader": "请基于给定的文章回答下述问题。\n\n文章:{context}\n\n请基于上述文章回答下面的问题。\n\n问题:{input}\n回答:",
557
+ "gov_report": "You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{context}\n\nNow, write a one-page summary of the report.\n\nSummary:",
558
+ "qmsum": "You are given a meeting transcript and a query containing a question or instruction. Answer the query in one or more sentences.\n\nTranscript:\n{context}\n\nNow, answer the query based on the above meeting transcript in one or more sentences.\n\nQuery: {input}\nAnswer:",
559
+ "multi_news": "You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{context}\n\nNow, write a one-page summary of all the news.\n\nSummary:",
560
+ "vcsum": "下面有一段会议记录,请你阅读后,写一段总结,总结会议的内容。\n会议记录:\n{context}\n\n会议总结:",
561
+ "trec": "Please determine the type of the question below. Here are some examples of questions.\n\n{context}\n{input}",
562
+ "triviaqa": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n{context}\n\n{input}",
563
+ "samsum": "Summarize the dialogue into a few short sentences. The following are some examples.\n\n{context}\n\n{input}",
564
+ "lsht": "请判断给定新闻的类别,下面是一些例子。\n\n{context}\n{input}",
565
+ "passage_count": "There are some paragraphs below sourced from Wikipedia. Some of them may be duplicates. Please carefully read these paragraphs and determine how many unique paragraphs there are after removing duplicates. In other words, how many non-repeating paragraphs are there in total?\n\n{context}\n\nPlease enter the final count of unique paragraphs after removing duplicates. The output format should only contain the number, such as 1, 2, 3, and so on.\n\nThe final answer is: ",
566
+ "passage_retrieval_en": "Here are 30 paragraphs from Wikipedia, along with an abstract. Please determine which paragraph the abstract is from.\n\n{context}\n\nThe following is an abstract.\n\n{input}\n\nPlease enter the number of the paragraph that the abstract is from. The answer format must be like \"Paragraph 1\", \"Paragraph 2\", etc.\n\nThe answer is: ",
567
+ "passage_retrieval_zh": "以下是若干段落文字,以及其中一个段落的摘要。请确定给定的摘要出自哪一段。\n\n{context}\n\n下面是一个摘要\n\n{input}\n\n请输入摘要所属段落的编号。答案格式必须是\"段落1\",\"段落2\"等格式\n\n答案是:",
568
+ "lcc": "Please complete the code given below. \n{context}Next line of code:\n",
569
+ "repobench-p": "Please complete the code given below. \n{context}{input}Next line of code:\n"
570
+ }
571
+
572
+ # Tasks that would skip the chat template (currently disabled)
573
+ # self.no_chat_template_tasks = ["trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"]
574
+ self.no_chat_template_tasks = ['']
575
+ self.use_longbench_e = use_longbench_e
576
+ self.max_length = max_length
577
+
578
+ if dataset_name:
579
+ if dataset_name not in target_datasets:
580
+ raise ValueError(f"Dataset {dataset_name} not found in LongBench{' -E' if use_longbench_e else ''}")
581
+ target_datasets = [dataset_name]
582
+ self.current_evaluating_subject = dataset_name
583
+ else:
584
+ self.current_evaluating_subject = None
585
+
586
+ # Load all selected datasets
587
+ all_data = []
588
+ for dataset in target_datasets:
589
+ try:
590
+ dataset_suffix = f"{dataset}_e" if use_longbench_e else dataset
591
+ data = load_dataset('THUDM/LongBench', dataset_suffix, split=split)
592
+ print(f" Loaded {len(data)} samples from {dataset}")
593
+
594
+ # Tag each sample with its source dataset name
595
+ data = data.map(lambda x: {"dataset_source": dataset})
596
+ all_data.append(data)
597
+ except Exception as e:
598
+ print(f"Warning: Failed to load {dataset}: {e}")
599
+ continue
600
+
601
+ if not all_data:
602
+ raise ValueError("No datasets were successfully loaded")
603
+
604
+
605
+ from datasets import concatenate_datasets
606
+ self.dataset = concatenate_datasets(all_data)
607
+
608
+
609
+
610
+
611
+
612
+ # mod4!=1
613
+ if filter_mod4:
614
+ original_len = len(self.dataset)
615
+
616
+ def _mod4_not_1(example):
617
+ _id = example.get('_id', '')
618
+ id_hash = int(hashlib.sha256(str(_id).encode('utf-8')).hexdigest(), 16)
619
+
620
+ return id_hash % 4 != 1
621
+
622
+ self.dataset = self.dataset.filter(_mod4_not_1)
623
+ print(f"Filtered by _id mod4 != 1: {original_len} -> {len(self.dataset)} samples")
624
+
625
+ # Limit the number of samples
626
+ if num_samples and num_samples < len(self.dataset):
627
+ self.dataset = self.dataset.select(range(num_samples))
628
+
629
+ print(f"Loaded total {len(self.dataset)} samples from LongBench{' -E' if use_longbench_e else ''}")
630
+ def __len__(self):
631
+ return len(self.dataset)
632
+
633
+ def _format_longbench_example(self, example: Dict[str, Any], tokenizer: AutoTokenizer) -> str:
634
+
635
+ # 1. Determine the task type
636
+ dataset_source = example.get('dataset_source', '')
637
+ if self.current_evaluating_subject:
638
+ current_subject = self.current_evaluating_subject
639
+ else:
640
+ current_subject = dataset_source
641
+
642
+ # Only strip the suffix when the string ends with "_e"
643
+ import re
644
+ subject = re.sub(r"_e$", "", current_subject) if self.use_longbench_e else current_subject
645
+
646
+ # 2. Get the prompt template
647
+ if subject not in self.dataset_prompt_formats:
648
+ subject = "narrativeqa"  # default template
649
+ prompt_format = self.dataset_prompt_formats[subject]
650
+
651
+ # 3. Fill in all fields directly by unpacking **example
652
+ raw_prompt = prompt_format.format(**example)
653
+
654
+ # 4. Truncate over-long prompts (keep the first and last half of the token budget)
655
+ tokenized_raw = tokenizer(raw_prompt, truncation=False, return_tensors="pt").input_ids[0]
656
+ if len(tokenized_raw) > self.max_length:
657
+ half_len = int(self.max_length / 2)
658
+ raw_prompt = tokenizer.decode(tokenized_raw[:half_len], skip_special_tokens=True) + \
659
+ tokenizer.decode(tokenized_raw[-half_len:], skip_special_tokens=True)
660
+
661
+ # 5. Apply the chat template (not applied here; the raw prompt is returned as-is)
662
+
663
+ final_prompt = raw_prompt
664
+ print(len(tokenized_raw))
665
+ return final_prompt
666
+
667
+ def __getitem__(self, idx):
668
+
669
+ sample = self.dataset[idx]
670
+
671
+ # Format the sample
672
+ formatted_prompt = self._format_longbench_example(sample, self.tokenizer)  # NOTE: self.tokenizer is not set in __init__ and must be attached externally before indexing
673
+
674
+ # Extract the answer
675
+ answers = sample.get('answers', [])
676
+ assistant_message = answers[0] if answers and len(answers) > 0 else "No answer provided"
677
+
678
+ return [
679
+ {
680
+ "role": "user",
681
+ "content": formatted_prompt.strip()
682
+ },
683
+ {
684
+ "role": "assistant",
685
+ "content": assistant_message.strip()
686
+ }
687
+ ]
688
+
689
+ @register_dataset
690
+ @capture_init_args
691
+ class MMLUChatDataset(Dataset):
692
+ """Simple MMLU dataset converted to chat format"""
693
+
694
+ def __init__(self, split: str = "train", num_samples: Optional[int] = None, max_word_count: Optional[int] = None):
695
+ """
696
+ Initialize the dataset
697
+
698
+ Args:
699
+ split: Dataset split
700
+ num_samples: Number of samples to use (None for all)
701
+ max_word_count: If set, drop samples whose full chat (user + assistant) exceeds this many tokens
702
+ """
703
+ print(f"Loading MMLU dataset (split: {split})...")
704
+ # Load dataset
705
+ dataset = load_dataset("cais/mmlu", "all")
706
+ dataset = dataset[split]
707
+
708
+ # Ensure we have a proper Dataset object
709
+ if hasattr(dataset, 'select'):
710
+ self.dataset = dataset
711
+ else:
712
+ raise ValueError(f"Unexpected dataset type: {type(dataset)}")
713
+
714
+ # Limit samples if specified
715
+ if num_samples and num_samples < len(self.dataset):
716
+ self.dataset = self.dataset.select(range(num_samples))
717
+
718
+ # Apply total token length filtering on full chat (user + assistant)
719
+ if max_word_count is not None:
720
+ # Use a small tokenizer for speed; total token length = chat(user+assistant)
721
+ self._mmlu_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
722
+ extractor = lambda sample: self._build_chat_messages(sample)
723
+ filters = [create_text_length_filter(max_word_count, extractor, self._mmlu_tokenizer, use_tokens=True)]
724
+ filter_descriptions = [f"Token count filter (full chat): max {max_word_count}"]
725
+ self.dataset, _ = apply_batch_filters(self.dataset, filters, filter_descriptions)
726
+
727
+ print(f"Loaded {len(self.dataset)} samples")
728
+
729
+ def __len__(self):
730
+ return len(self.dataset)
731
+
732
+ def __getitem__(self, idx):
733
+ sample = self.dataset[idx]
734
+ return self._build_chat_messages(sample)
735
+
736
+ def _build_chat_messages(self, sample: Dict[str, Any]) -> List[Dict[str, str]]:
737
+ choice_labels = ['A', 'B', 'C', 'D']
738
+ question = sample.get('question', '')
739
+ choices_list = sample.get('choices', [])
740
+ user_prompt = f"Question: {question}\n\nChoices:\n"
741
+ for i, choice in enumerate(choices_list):
742
+ label = choice_labels[i] if i < len(choice_labels) else chr(65 + i)
743
+ user_prompt += f"{label}. {choice}\n"
744
+ ans_idx = sample.get('answer', 0)
745
+ if isinstance(ans_idx, str) and ans_idx.isdigit():
746
+ ans_idx = int(ans_idx)
747
+ ans_label = choice_labels[ans_idx] if 0 <= int(ans_idx) < len(choice_labels) else chr(65 + int(ans_idx))
748
+ assistant_text = f"The correct answer is {ans_label}."
749
+ return [
750
+ {"role": "user", "content": user_prompt.strip()},
751
+ {"role": "assistant", "content": assistant_text.strip()},
752
+ ]
753
+
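+ # Illustrative output (hypothetical sample, not from the original file): __getitem__ yields a
+ # two-message chat such as
+ #   [{"role": "user", "content": "Question: ...\n\nChoices:\nA. ...\nB. ..."},
+ #    {"role": "assistant", "content": "The correct answer is B."}]
+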
754
+ @register_dataset
755
+ @capture_init_args
756
+ class MMLUCotChatDataset(Dataset):
757
+ """Simple MMLUCot dataset converted to chat format"""
758
+
759
+ def __init__(self, split: str = "train", num_samples: Optional[int] = None):
760
+ """
761
+ Initialize the dataset
762
+
763
+ Args:
764
+ split: Dataset split
765
+ num_samples: Number of samples to use (None for all)
766
+ """
767
+ print(f"Loading MMLUCot dataset (split: {split})...")
768
+ # Load dataset
769
+ dataset = load_dataset("Brench/MMLU-Pro-CoT-Train-43K")
770
+ dataset = dataset[split]
771
+
772
+ # Ensure we have a proper Dataset object
773
+ if hasattr(dataset, 'select'):
774
+ self.dataset = dataset
775
+ else:
776
+ raise ValueError(f"Unexpected dataset type: {type(dataset)}")
777
+
778
+ # Limit samples if specified
779
+ if num_samples and num_samples < len(self.dataset):
780
+ self.dataset = self.dataset.select(range(num_samples))
781
+
782
+ print(f"Loaded {len(self.dataset)} samples")
783
+
784
+ def __len__(self):
785
+ return len(self.dataset)
786
+
787
+ def __getitem__(self, idx):
788
+ sample = self.dataset[idx]
789
+
790
+ user_prompt = sample['question'] + "\n"
791
+
792
+ assistant_response = sample['chain_of_thoughts']
793
+
794
+ return [
795
+ {
796
+ "role": "user",
797
+ "content": user_prompt.strip()
798
+ },
799
+ {
800
+ "role": "assistant",
801
+ "content": assistant_response
802
+ }
803
+ ]
804
+
805
+ @register_dataset
806
+ @capture_init_args
807
+ class LLMGeneratedChatDataset(Dataset):
808
+ """Simple LLM Generated dataset converted to chat format"""
809
+
810
+ def __init__(self, split: str = "train", num_samples: Optional[int] = None, data_path: str = "./teacher_datasets/output/dataset_finished", max_word_count: Optional[int] = None):
811
+ """
812
+ Initialize the dataset
813
+
814
+ Args:
815
+ split: Dataset split
816
+ num_samples: Number of samples to use (None for all)
817
+ """
818
+ print(f"Loading LLMGeneratedCot dataset (split: {split})...")
819
+ # Load dataset
820
+ dataset = load_from_disk(data_path)
821
+
822
+ # Ensure we have a proper Dataset object
823
+ if hasattr(dataset, 'select'):
824
+ self.dataset = dataset
825
+ else:
826
+ raise ValueError(f"Unexpected dataset type: {type(dataset)}")
827
+
828
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
829
+
830
+ if max_word_count is not None:
831
+ original_len = len(self.dataset)
832
+ half = max_word_count // 2
833
+ def _under_token_limit(batch):
834
+ q = tokenizer(batch["input_text"], add_special_tokens=False, padding=False, truncation=False)
835
+ a = tokenizer(batch["model_response"], add_special_tokens=False, padding=False, truncation=False)
836
+ return [
837
+ (len(q_ids) <= half) and (len(q_ids) + len(a_ids) <= max_word_count)
838
+ for q_ids, a_ids in zip(q["input_ids"], a["input_ids"])
839
+ ]
840
+
841
+ self.dataset = self.dataset.filter(
842
+ _under_token_limit,
843
+ batched=True,
844
+ batch_size=2048,  # increase depending on available memory
845
+ num_proc=min(8, os.cpu_count() or 1),
846
+ load_from_cache_file=True,
847
+ desc=f"Filter max_word_count={max_word_count}",
848
+ )
849
+ print(f"Filtered by max_word_count={max_word_count}: {original_len} -> {len(self.dataset)} samples")
850
+
851
+ # Limit samples if specified
852
+ if num_samples and num_samples < len(self.dataset):
853
+ self.dataset = self.dataset.select(range(num_samples))
854
+
855
+ print(f"Loaded {len(self.dataset)} samples")
856
+
857
+ def __len__(self):
858
+ return len(self.dataset)
859
+
860
+ def __getitem__(self, idx):
861
+ sample = self.dataset[idx]
862
+
863
+ input_text = sample.get('input_text', '') or ''
864
+
865
+ # Parse question and choices from input_text, which is expected to contain a
866
+ # "Choices:" section followed by labeled options like "A. ..."
867
+ def _parse_question_and_choices(text: str):
868
+ lines = (text or '').splitlines()
869
+ # Find the line index for "Choices:" (case-insensitive, ignoring spaces)
870
+ choices_idx = -1
871
+ for i, line in enumerate(lines):
872
+ if line.strip().lower().startswith('choices'):
873
+ choices_idx = i
874
+ break
875
+
876
+ if choices_idx == -1:
877
+ # Fallback: no explicit Choices header found
878
+ question_part = text.strip()
879
+ return question_part, ''
880
+
881
+ question_part = '\n'.join(lines[:choices_idx]).strip()
882
+
883
+ # Collect labeled choices until blank line or instruction-like line
884
+ collected = []
885
+ for raw in lines[choices_idx + 1:]:
886
+ s = raw.strip()
887
+ if not s:
888
+ # Stop on first blank after having collected at least one choice
889
+ if collected:
890
+ break
891
+ else:
892
+ continue
893
+ lower = s.lower()
894
+ # Stop when hitting instruction section common in prompts
895
+ if lower.startswith('instructions:') or lower.startswith("let's ") or lower.startswith('you must'):
896
+ break
897
+ # Accept formats like "A. ..." or "A) ..."
898
+ if len(s) >= 3 and s[0] in 'ABCDEFGHIJ' and s[1] in ').' and s[2] == ' ':
899
+ collected.append(s)
900
+ else:
901
+ # If we've started collecting and this line doesn't look like a choice, stop
902
+ if collected:
903
+ break
904
+ # Otherwise ignore preamble noise
905
+ continue
906
+
907
+ choices_block = '\n'.join(collected).strip()
908
+ return question_part, choices_block
909
+
910
+ question, choices_block = _parse_question_and_choices(input_text)
911
+
912
+ # Rebuild user prompt using the evaluation CoT template
913
+ template = """Accurately answer the following question:
914
+
915
+ {{question}}
916
+
917
+ Choices:
918
+ {{choices}}
919
+
920
+ Instructions:
921
+ - Carefully read the question and all options.
922
+ - Let's think step by step and you must explain your reasoning briefly.
923
+ - Then give the final answer.
924
+ - Keep your response within 150 words."""
925
+
926
+ filled_prompt = (
927
+ template
928
+ .replace("{{question}}", question or '')
929
+ .replace("{{choices}}", choices_block or '')
930
+ )
931
+
932
+ user_prompt = filled_prompt.strip() + "\n"
933
+
934
+ assistant_response = sample['model_response']
935
+
936
+ return [
937
+ {
938
+ "role": "user",
939
+ "content": user_prompt.strip()
940
+ },
941
+ {
942
+ "role": "assistant",
943
+ "content": assistant_response
944
+ }
945
+ ]
946
+
947
+ @register_dataset
948
+ @capture_init_args
949
+ class OpenBookChatDataset(Dataset):
950
+ """Simple OpenBook dataset converted to chat format"""
951
+
952
+ def __init__(self, split: str = "train", num_samples: Optional[int] = None):
953
+ """
954
+ Initialize the dataset
955
+
956
+ Args:
957
+ split: Dataset split
958
+ num_samples: Number of samples to use (None for all)
959
+ """
960
+ print(f"Loading OpenBook dataset (split: {split})...")
961
+ # Load dataset
962
+ dataset = load_dataset("allenai/openbookqa", "main")
963
+ dataset = dataset[split]
964
+
965
+ # Ensure we have a proper Dataset object
966
+ if hasattr(dataset, 'select'):
967
+ self.dataset = dataset
968
+ else:
969
+ raise ValueError(f"Unexpected dataset type: {type(dataset)}")
970
+
971
+ # Limit samples if specified
972
+ if num_samples and num_samples < len(self.dataset):
973
+ self.dataset = self.dataset.select(range(num_samples))
974
+
975
+ print(f"Loaded {len(self.dataset)} samples")
976
+
977
+ def __len__(self):
978
+ return len(self.dataset)
979
+
980
+ def __getitem__(self, idx):
981
+ sample = self.dataset[idx]
982
+ choice_labels = ['A', 'B', 'C', 'D']
983
+
984
+ user_prompt = (
985
+ f"Question: {sample['question_stem']}\n\n"
986
+ f"Choices:\n"
987
+ )
988
+ for choice_idx, choice in enumerate(sample['choices']['text']):
988
+ label = choice_labels[choice_idx]
990
+ user_prompt += f"{label}. {choice}\n"
991
+
992
+ correct_label = sample["answerKey"]
993
+ assistant_response = f"The correct answer is {correct_label}."
994
+
995
+ return [
996
+ {
997
+ "role": "user",
998
+ "content": user_prompt.strip()
999
+ },
1000
+ {
1001
+ "role": "assistant",
1002
+ "content": assistant_response
1003
+ }
1004
+ ]
1005
+
1006
+ @register_dataset
1007
+ @capture_init_args
1008
+ class OpenHermesChatDataset(Dataset):
1009
+ """Simple general dataset converted to chat format"""
1010
+
1011
+ def __init__(self, split: str = "train", num_samples: Optional[int] = None, max_word_count: Optional[int] = None, min_conversation_turns: int = 0):
1012
+ """
1013
+ Initialize the dataset
1014
+
1015
+ Args:
1016
+ split: Dataset split
1017
+ num_samples: Number of samples to use (None for all)
1018
+ max_word_count: Maximum token count for filtering
1019
+ min_conversation_turns: Minimum number of conversation turns required to keep a sample (0 disables the filter)
1020
+ """
1021
+ print(f"Loading OpenHermes dataset (split: {split})...")
1022
+ # Load dataset
1023
+ dataset = load_dataset("teknium/OpenHermes-2.5")
1024
+ dataset = dataset[split]
1025
+
1026
+ # Ensure we have a proper Dataset object
1027
+ if hasattr(dataset, 'select'):
1028
+ self.dataset = dataset
1029
+ else:
1030
+ raise ValueError(f"Unexpected dataset type: {type(dataset)}")
1031
+
1032
+ # Limit samples if specified
1033
+ if num_samples and num_samples < len(self.dataset):
1034
+ self.dataset = self.dataset.select(range(num_samples))
1035
+
1036
+ # Apply filters
1037
+ filters = []
1038
+ filter_descriptions = []
1039
+
1040
+ # Filter by minimum conversation length (exclude conversations with <= 2 messages)
1041
+ if min_conversation_turns > 0:
1042
+ filters.append(create_conversation_length_filter(min_conversation_turns - 1, 'conversations'))
1043
+ filter_descriptions.append(f"Conversation length filter: min {min_conversation_turns} messages (multi-turn only)")
1044
+
1045
+ # Apply conversation-level token count filtering (all messages combined <= max_word_count)
1046
+ if max_word_count is not None:
1047
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
1048
+ extractor = lambda sample: extract_openhermes_messages(sample, 'conversations')
1049
+ filters.append(create_text_length_filter(max_word_count, extractor, tokenizer, use_tokens=True))
1050
+ filter_descriptions.append(f"Token count filter: max {max_word_count}")
1051
+
1052
+ # Apply all filters
1053
+ if filters:
1054
+ self.dataset, _ = apply_batch_filters(self.dataset, filters, filter_descriptions, num_proc=8)
1055
+
1056
+ print(f"Loaded {len(self.dataset)} samples")
1057
+
1058
+ def __len__(self):
1059
+ return len(self.dataset)
1060
+
1061
+ def __getitem__(self, idx):
1062
+ sample = self.dataset[idx]
1063
+ return extract_openhermes_messages(sample, 'conversations')
1064
+
1065
+ """
1066
+ Chat dataset
1067
+
1068
+ Convert standard message format to input_ids and labels
1069
+ """
1070
+ class ChatDataset(Dataset):
1071
+ """Dataset for chat format training with HuggingFace Trainer compatibility"""
1072
+
1073
+ def __init__(self, chat_dataset, tokenizer: AutoTokenizer, max_length: int = 32768):
1074
+ self.chat_dataset = chat_dataset
1075
+ self.tokenizer = tokenizer
1076
+ self.max_length = max_length
1077
+
1078
+ def __len__(self):
1079
+ return len(self.chat_dataset)
1080
+
1081
+ def __getitem__(self, idx) -> Dict[str, Any]:
1082
+ messages = self.chat_dataset[idx]
1083
+
1084
+ # Get instruction (first message)
1085
+ instruction = self.tokenizer.apply_chat_template(
1086
+ messages[:-1],
1087
+ tokenize=False,
1088
+ add_generation_prompt=True,
1089
+ enable_thinking=False,
1090
+ )
1091
+
1092
+ # Get full conversation
1093
+ full_text = self.tokenizer.apply_chat_template(
1094
+ messages,
1095
+ tokenize=False,
1096
+ add_generation_prompt=False,
1097
+ enable_thinking=False,
1098
+ )
1099
+
1100
+ # Tokenize instruction and full text
1101
+ instruction_tokens = self.tokenizer(instruction, add_special_tokens=False)["input_ids"]
1102
+ full_tokens = self.tokenizer(full_text, add_special_tokens=False)["input_ids"]
1103
+
1104
+ # Truncate if necessary
1105
+ if len(full_tokens) > self.max_length:
1106
+ full_tokens = full_tokens[:self.max_length]
1107
+
1108
+ # Create labels (-100 for instruction tokens, actual tokens for response)
1109
+ labels = [-100] * len(instruction_tokens) + full_tokens[len(instruction_tokens):]
1110
+ # labels = [-100] * (len(full_tokens) - 4) + full_tokens[-4:]
1111
+ if len(labels) > self.max_length:
1112
+ labels = labels[:self.max_length]
1113
+
1114
+ kv_cache_index = generate_kv_cache_index(len(instruction_tokens), len(full_tokens))
1115
+ # kv_cache_index = generate_kv_cache_index(len(full_tokens)-4, len(full_tokens))
1116
+ # kv_cache_index = generate_kv_cache_index(len(full_tokens) + 1, len(full_tokens))
1117
+
1118
+ return {
1119
+ "input_ids": full_tokens,
1120
+ "labels": labels,
1121
+ "kv_cache_index": kv_cache_index
1122
+ }
1123
+
1124
+
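+ # Illustrative sketch (assumed numbers, not from the original file): if the rendered instruction
+ # covers the first 10 tokens of a 14-token conversation, ChatDataset produces
+ #   labels         = [-100] * 10 + full_tokens[10:14]
+ #   kv_cache_index = 9 rows of [1, 0] followed by 5 rows of [-1, 0]
+ # so the loss is computed only on the response tokens.
+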
1125
+ class AlignedChatDataset(Dataset):
1126
+ """Dataset that precomputes aligned inputs for SLM/LLM using a TokenAligner"""
1127
+
1128
+ def __init__(self, instruct_dataset: Dataset, aligner: Any, max_length: int = 32768):
1129
+ self.dataset = instruct_dataset
1130
+ self.aligner = aligner
1131
+ self.max_length = max_length
1132
+
1133
+ def __len__(self):
1134
+ return len(self.dataset)
1135
+
1136
+ def __getitem__(self, idx):
1137
+ messages = self.dataset[idx]
1138
+
1139
+ # Build aligned sequences and section map
1140
+ details = self.aligner.align_chat_messages(messages, add_generation_prompt=False, return_details=True)
1141
+ slm_ids: List[int] = details['slm_ids_padded']
1142
+ llm_ids: List[int] = details['llm_ids_padded']
1143
+ sections = details['sections']
1144
+
1145
+ slm_pad_mask = torch.tensor(details['slm_padding_mask'])
1146
+ llm_pad_mask = torch.tensor(details['llm_padding_mask'])
1147
+ message_mask = torch.tensor(details['message_mask'])
1148
+
1149
+ # Determine instruction boundary as start of the last message section
1150
+ instr_end = 0
1151
+ for sec_idx in range(len(sections) - 1, -1, -1):
1152
+ sec = sections[sec_idx]
1153
+ if sec['type'] == 'message':
1154
+ instr_end = sec['slm_range'][0]
1155
+ break
1156
+
1157
+ # Labels: follow ChatDataset policy (-100 for instruction-only, supervise the rest)
1158
+ labels = [-100] * instr_end + slm_ids[instr_end:]
1159
+ if len(labels) > self.max_length:
1160
+ labels = labels[:self.max_length]
1161
+
1162
+ # Truncate inputs if needed
1163
+ if len(slm_ids) > self.max_length:
1164
+ slm_ids = slm_ids[:self.max_length]
1165
+ # Truncate padding mask accordingly
1166
+ slm_pad_mask = slm_pad_mask[:self.max_length]
1167
+ if len(llm_ids) > self.max_length:
1168
+ llm_ids = llm_ids[:self.max_length]
1169
+ llm_pad_mask = llm_pad_mask[:self.max_length]
1170
+
1171
+ # KV cache index based on instruction length
1172
+ kv_cache_index = generate_kv_cache_index(instr_end, len(slm_ids))
1173
+ # Additionally mask non-message parts
1174
+ kv_cache_index[~message_mask] = torch.tensor([[-1,0]])
1175
+
1176
+ return {
1177
+ "input_ids": [slm_ids, llm_ids],
1178
+ "labels": labels,
1179
+ "kv_cache_index": kv_cache_index,
1180
+ "messages": messages,
1181
+ # Per-model aligned inputs (per-sample, pre-batch)
1182
+ "model_padding_mask": [slm_pad_mask, llm_pad_mask],
1183
+ }
1184
+
1185
+
1186
+ class BaselineChatDataset(Dataset):
1187
+ """Simple dataset for baseline model training without Rosetta-specific features"""
1188
+
1189
+ def __init__(self, chat_dataset, tokenizer: AutoTokenizer, max_length: int = 2048):
1190
+ self.chat_dataset = chat_dataset
1191
+ self.tokenizer = tokenizer
1192
+ self.max_length = max_length
1193
+
1194
+ def __len__(self):
1195
+ return len(self.chat_dataset)
1196
+
1197
+ def __getitem__(self, idx):
1198
+ messages = self.chat_dataset[idx]
1199
+
1200
+ # Get instruction (first message)
1201
+ instruction = self.tokenizer.apply_chat_template(
1202
+ messages[:1],
1203
+ tokenize=False,
1204
+ add_generation_prompt=True,
1205
+ enable_thinking=False,
1206
+ )
1207
+
1208
+ # Get full conversation
1209
+ full_text = self.tokenizer.apply_chat_template(
1210
+ messages,
1211
+ tokenize=False,
1212
+ add_generation_prompt=False,
1213
+ enable_thinking=False,
1214
+ )
1215
+
1216
+ # Tokenize instruction and full text
1217
+ instruction_tokens = self.tokenizer(instruction, add_special_tokens=False)["input_ids"]
1218
+ full_tokens = self.tokenizer(full_text, add_special_tokens=False)["input_ids"]
1219
+
1220
+ # Truncate if necessary
1221
+ if len(full_tokens) > self.max_length:
1222
+ full_tokens = full_tokens[:self.max_length]
1223
+
1224
+ # Create labels (-100 for instruction tokens, actual tokens for response)
1225
+ labels = [-100] * len(instruction_tokens) + full_tokens[len(instruction_tokens):]
1226
+ if len(labels) > self.max_length:
1227
+ labels = labels[:self.max_length]
1228
+
1229
+ return {
1230
+ "input_ids": full_tokens,
1231
+ "labels": labels,
1232
+ }
1233
+
1234
+ """
1235
+ Data collator
1236
+
1237
+ Batch chat data to model input
1238
+ """
1239
+
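+ # Illustrative wiring sketch (hypothetical tokenizer variables, not from the original file):
+ #
+ #   collator = RosettaDataCollator(slm_tokenizer=slm_tok, llm_tokenizer=llm_tok, max_length=4096)
+ #   loader = torch.utils.data.DataLoader(chat_dataset, batch_size=4, collate_fn=collator)
+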
1240
+ class RosettaDataCollator:
1241
+ """Improved data collator for RosettaModel training with cleaner logic"""
1242
+
1243
+ def __init__(self, slm_tokenizer: AutoTokenizer, llm_tokenizer: AutoTokenizer = None,
1244
+ pad_to_multiple_of: Optional[int] = None, max_length: Optional[int] = None,
1245
+ aligner: Optional[Any] = None, do_alignment: bool = False):
1246
+ """
1247
+ Initialize the collator.
1248
+
1249
+ Args:
1250
+ slm_tokenizer: Small language model tokenizer
1251
+ llm_tokenizer: Large language model tokenizer (optional)
1252
+ pad_to_multiple_of: Pad sequence length to multiple of this value
1253
+ max_length: Maximum sequence length
1254
+ aligner: Alignment module (if needed)
1255
+ do_alignment: Whether to perform alignment
1256
+ """
1257
+ self.slm_tokenizer = slm_tokenizer
1258
+ self.llm_tokenizer = llm_tokenizer
1259
+ self.pad_to_multiple_of = pad_to_multiple_of
1260
+ self.max_length = max_length
1261
+ self.aligner = aligner
1262
+ self.do_alignment = do_alignment
1263
+
1264
+ if self.do_alignment:
1265
+ assert self.aligner is not None, "Aligner must be provided if do_alignment is True"
1266
+
1267
+ # Store padding token IDs for different models
1268
+ self.slm_pad_token_id = self.slm_tokenizer.pad_token_id
1269
+ self.llm_pad_token_id = self.llm_tokenizer.pad_token_id if self.llm_tokenizer else self.slm_pad_token_id
1270
+
1271
+ def _normalize_input_format(self, feature: Dict[str, Any]) -> Dict[str, Any]:
1272
+ """
1273
+ Normalize input format to handle both single and dual model inputs.
1274
+
1275
+ Args:
1276
+ feature: Input feature dictionary
1277
+
1278
+ Returns:
1279
+ Normalized feature with consistent format
1280
+ """
1281
+ # Normalize input_ids: ensure it's always a list of tensors
1282
+ input_ids = feature['input_ids']
1283
+ if isinstance(input_ids, list) and len(input_ids) > 0:
1284
+ if isinstance(input_ids[0], list):
1285
+ # Case: [[ids1], [ids2]] -> convert to list of tensors
1286
+ input_ids_tensors = [torch.tensor(ids, dtype=torch.long) for ids in input_ids]
1287
+ else:
1288
+ # Case: [id1, id2, ...] -> single model case
1289
+ input_ids_tensors = [torch.tensor(input_ids, dtype=torch.long)]
1290
+ else:
1291
+ # Fallback: assume single model
1292
+ input_ids_tensors = [torch.tensor(input_ids, dtype=torch.long)]
1293
+
1294
+ # Normalize attention_mask
1295
+ attention_masks = []
1296
+ if "model_padding_mask" in feature:
1297
+ # Use model-specific padding masks
1298
+ for model_padding_mask in feature["model_padding_mask"]:
1299
+ attention_masks.append((~model_padding_mask).float())
1300
+ else:
1301
+ # Generate default attention masks
1302
+ for input_tensor in input_ids_tensors:
1303
+ attention_masks.append(torch.ones(len(input_tensor), dtype=torch.float))
1304
+
1305
+ return {
1306
+ 'input_ids': input_ids_tensors,
1307
+ 'attention_mask': attention_masks,
1308
+ 'labels': torch.tensor(feature['labels'], dtype=torch.long),
1309
+ 'kv_cache_index': feature['kv_cache_index'],
1310
+ 'position_ids': torch.arange(len(feature['labels']), dtype=torch.long)
1311
+ }
1312
+
1313
+ def _split_into_sections(self, normalized_feature: Dict[str, Any]) -> List[Dict[str, Any]]:
1314
+ """
1315
+ Split sequence into sections based on kv_cache_index changes.
1316
+
1317
+ Args:
1318
+ normalized_feature: Normalized feature dictionary
1319
+
1320
+ Returns:
1321
+ List of sections
1322
+ """
1323
+ kv_idx = normalized_feature['kv_cache_index']
1324
+
1325
+ # Find change points in kv_cache_index
1326
+ change_points = [0]
1327
+ for i in range(1, kv_idx.size(0)):
1328
+ if not torch.equal(kv_idx[i], kv_idx[i - 1]):
1329
+ change_points.append(i)
1330
+ change_points.append(kv_idx.size(0))
1331
+
1332
+ # Create sections
1333
+ sections = []
1334
+ for i in range(len(change_points) - 1):
1335
+ start, end = change_points[i], change_points[i + 1]
1336
+ section = {
1337
+ 'input_ids': [ids[start:end] for ids in normalized_feature['input_ids']],
1338
+ 'attention_mask': [mask[start:end] for mask in normalized_feature['attention_mask']],
1339
+ 'labels': normalized_feature['labels'][start:end],
1340
+ 'kv_cache_index': normalized_feature['kv_cache_index'][start:end],
1341
+ 'position_ids': normalized_feature['position_ids'][start:end]
1342
+ }
1343
+ sections.append(section)
1344
+
1345
+ return sections
1346
+
1347
+ def _pad_sections(self, all_sections: List[List[Dict[str, Any]]]) -> Dict[str, Any]:
1348
+ """
1349
+ Pad sections to ensure uniform structure across batch.
1350
+
1351
+ Args:
1352
+ all_sections: List of section lists for each sample
1353
+
1354
+ Returns:
1355
+ Padded batch dictionary
1356
+ """
1357
+ max_sections = max(len(sections) for sections in all_sections)
1358
+ num_models = len(all_sections[0][0]['input_ids']) if all_sections else 1
1359
+
1360
+ # Initialize output structure - keep models separate throughout
1361
+ padded_output = {
1362
+ 'input_ids_per_model': [[] for _ in range(num_models)], # One list per model
1363
+ 'attention_mask_per_model': [[] for _ in range(num_models)], # One list per model
1364
+ 'labels': [],
1365
+ 'kv_cache_index': [],
1366
+ 'position_ids': []
1367
+ }
1368
+
1369
+ # Process each section index
1370
+ for sec_idx in range(max_sections):
1371
+ section_data = self._collect_section_data(all_sections, sec_idx, num_models)
1372
+ padded_section = self._pad_single_section(section_data, num_models)
1373
+
1374
+ # Add to output - keep models separate
1375
+ for model_idx in range(num_models):
1376
+ padded_output['input_ids_per_model'][model_idx].append(
1377
+ padded_section['input_ids_per_model'][model_idx])
1378
+ padded_output['attention_mask_per_model'][model_idx].append(
1379
+ padded_section['attention_mask_per_model'][model_idx])
1380
+
1381
+ padded_output['labels'].append(padded_section['labels'])
1382
+ padded_output['kv_cache_index'].append(padded_section['kv_cache_index'])
1383
+ padded_output['position_ids'].append(padded_section['position_ids'])
1384
+
1385
+ # Concatenate sections and finalize
1386
+ return self._finalize_output(padded_output, num_models, len(all_sections))
1387
+
1388
+ def _collect_section_data(self, all_sections: List[List[Dict[str, Any]]],
1389
+ sec_idx: int, num_models: int) -> Dict[str, List]:
1390
+ """Collect data for a specific section across all samples."""
1391
+ # Separate collections for each model to avoid confusion
1392
+ section_data = {
1393
+ 'input_ids_per_model': [[] for _ in range(num_models)], # [[slm_seqs], [llm_seqs]]
1394
+ 'attention_mask_per_model': [[] for _ in range(num_models)],
1395
+ 'labels': [],
1396
+ 'kv_cache_index': [],
1397
+ 'position_ids': []
1398
+ }
1399
+
1400
+ for sample_sections in all_sections:
1401
+ # Some samples may have fewer sections; create default empty tensors when missing
1402
+ if sec_idx < len(sample_sections):
1403
+ sec = sample_sections[sec_idx]
1404
+ for model_idx in range(num_models):
1405
+ section_data['input_ids_per_model'][model_idx].append(sec['input_ids'][model_idx])
1406
+ section_data['attention_mask_per_model'][model_idx].append(sec['attention_mask'][model_idx])
1407
+ section_data['labels'].append(sec['labels'])
1408
+ section_data['kv_cache_index'].append(sec['kv_cache_index'])
1409
+ section_data['position_ids'].append(sec['position_ids'])
1410
+ else:
1411
+ # Default empty tensors; downstream pad_sequence will pad appropriately
1412
+ for model_idx in range(num_models):
1413
+ section_data['input_ids_per_model'][model_idx].append(torch.tensor([], dtype=torch.long))
1414
+ section_data['attention_mask_per_model'][model_idx].append(torch.tensor([], dtype=torch.float))
1415
+ section_data['labels'].append(torch.tensor([], dtype=torch.long))
1416
+ section_data['kv_cache_index'].append(torch.empty((0, 2), dtype=torch.long))
1417
+ section_data['position_ids'].append(torch.tensor([], dtype=torch.long))
1418
+
1419
+ return section_data
1420
+
1421
+ def _pad_single_section(self, section_data: Dict[str, List], num_models: int) -> Dict[str, Any]:
1422
+ """Pad tensors within a single section."""
1423
+ # Pad input_ids separately for each model with their respective pad tokens
1424
+ padded_input_ids_per_model = []
1425
+ padded_attention_mask_per_model = []
1426
+
1427
+ for model_idx in range(num_models):
1428
+ pad_token_id = self.slm_pad_token_id if model_idx == 0 else self.llm_pad_token_id
1429
+
1430
+ # Pad input_ids for this model
1431
+ padded_input_ids = torch.nn.utils.rnn.pad_sequence(
1432
+ section_data['input_ids_per_model'][model_idx],
1433
+ batch_first=True,
1434
+ padding_value=pad_token_id
1435
+ )
1436
+ padded_input_ids_per_model.append(padded_input_ids)
1437
+
1438
+ # Pad attention_mask for this model
1439
+ padded_attention_mask = torch.nn.utils.rnn.pad_sequence(
1440
+ section_data['attention_mask_per_model'][model_idx],
1441
+ batch_first=True,
1442
+ padding_value=0
1443
+ )
1444
+ padded_attention_mask_per_model.append(padded_attention_mask)
1445
+
1446
+ # Standard padding for other tensors
1447
+ padded_labels = torch.nn.utils.rnn.pad_sequence(
1448
+ section_data['labels'], batch_first=True, padding_value=-100)
1449
+ padded_kv_cache = torch.nn.utils.rnn.pad_sequence(
1450
+ section_data['kv_cache_index'], batch_first=True, padding_value=-1)
1451
+ padded_position_ids = torch.nn.utils.rnn.pad_sequence(
1452
+ section_data['position_ids'], batch_first=True, padding_value=0)
1453
+
1454
+ return {
1455
+ 'input_ids_per_model': padded_input_ids_per_model, # Keep separate per model
1456
+ 'attention_mask_per_model': padded_attention_mask_per_model, # Keep separate per model
1457
+ 'labels': padded_labels,
1458
+ 'kv_cache_index': padded_kv_cache,
1459
+ 'position_ids': padded_position_ids,
1460
+ 'num_models': num_models
1461
+ }
1462
+
1463
+ def _finalize_output(self, padded_output: Dict[str, List],
1464
+ num_models: int, batch_size: int) -> Dict[str, Any]:
1465
+ """Finalize the output by concatenating sections - keep models separate throughout."""
1466
+ final_output = {}
1467
+
1468
+ # Handle input_ids and attention_mask - keep separate per model
1469
+ if num_models == 1:
1470
+ # Single model case: concatenate sections for the single model
1471
+ final_output['input_ids'] = torch.cat(padded_output['input_ids_per_model'][0], dim=1)
1472
+ final_output['attention_mask'] = torch.cat(padded_output['attention_mask_per_model'][0], dim=1)
1473
+ else:
1474
+ # Multi-model case: keep as list of tensors, one per model
1475
+ final_output['input_ids'] = [
1476
+ torch.cat(padded_output['input_ids_per_model'][model_idx], dim=1)
1477
+ for model_idx in range(num_models)
1478
+ ]
1479
+ final_output['attention_mask'] = [
1480
+ torch.cat(padded_output['attention_mask_per_model'][model_idx], dim=1)
1481
+ for model_idx in range(num_models)
1482
+ ]
1483
+
1484
+ # Concatenate other tensors normally
1485
+ final_output['labels'] = torch.cat(padded_output['labels'], dim=1)
1486
+ final_output['position_ids'] = torch.cat(padded_output['position_ids'], dim=1)
1487
+ final_output['kv_cache_index'] = padded_output['kv_cache_index'] # Keep as list of sections
1488
+
1489
+ return final_output
1490
+
1491
+ def _apply_length_constraints(self, output: Dict[str, Any]) -> Dict[str, Any]:
1492
+ """Apply max_length truncation if specified."""
1493
+ if self.max_length is None:
1494
+ return output
1495
+
1496
+ # Determine current sequence length
1497
+ if isinstance(output['input_ids'], list):
1498
+ seq_length = output['input_ids'][0].size(1)
1499
+ else:
1500
+ seq_length = output['input_ids'].size(1)
1501
+
1502
+ if seq_length <= self.max_length:
1503
+ return output
1504
+
1505
+ # Truncate sequences
1506
+ if isinstance(output['input_ids'], list):
1507
+ output['input_ids'] = [ids[:, :self.max_length] for ids in output['input_ids']]
1508
+ output['attention_mask'] = [mask[:, :self.max_length] for mask in output['attention_mask']]
1509
+ else:
1510
+ output['input_ids'] = output['input_ids'][:, :self.max_length]
1511
+ output['attention_mask'] = output['attention_mask'][:, :self.max_length]
1512
+
1513
+ output['labels'] = output['labels'][:, :self.max_length]
1514
+ output['position_ids'] = output['position_ids'][:, :self.max_length]
1515
+
1516
+ # Truncate kv_cache_index sections appropriately
1517
+ output['kv_cache_index'] = self._truncate_kv_cache_sections(
1518
+ output['kv_cache_index'], self.max_length)
1519
+
1520
+ return output
1521
+
1522
+ def _truncate_kv_cache_sections(self, kv_cache_sections: List[torch.Tensor],
1523
+ max_length: int) -> List[torch.Tensor]:
1524
+ """Truncate kv_cache sections to fit within max_length."""
1525
+ truncated_sections = []
1526
+ current_pos = 0
1527
+
1528
+ for section in kv_cache_sections:
1529
+ section_length = section.size(1)
1530
+ remaining_length = max_length - current_pos
1531
+
1532
+ if remaining_length <= 0:
1533
+ break
1534
+ elif remaining_length >= section_length:
1535
+ truncated_sections.append(section)
1536
+ current_pos += section_length
1537
+ else:
1538
+ truncated_section = section[:, :remaining_length]
1539
+ truncated_sections.append(truncated_section)
1540
+ break
1541
+
1542
+ return truncated_sections
1543
+
1544
+ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
1545
+ """
1546
+ Main collation function with improved logic.
1547
+
1548
+ Args:
1549
+ features: List of feature dictionaries from dataset
1550
+
1551
+ Returns:
1552
+ Batched and padded output dictionary
1553
+ """
1554
+ if not features:
1555
+ return {}
1556
+
1557
+ # Step 1: Normalize input format for all features
1558
+ normalized_features = [self._normalize_input_format(feat) for feat in features]
1559
+
1560
+ # Step 2: Split each feature into sections
1561
+ all_sections = [self._split_into_sections(feat) for feat in normalized_features]
1562
+
1563
+ # Step 3: Pad sections to create uniform batch structure
1564
+ output = self._pad_sections(all_sections)
1565
+
1566
+ # Step 4: Apply length constraints if needed
1567
+ output = self._apply_length_constraints(output)
1568
+
1569
+ return output
1570
+
1571
+
1572
+ class BaselineDataCollator:
1573
+ """Custom data collator for baseline model training"""
1574
+
1575
+ def __init__(self, tokenizer: AutoTokenizer, pad_to_multiple_of: Optional[int] = None):
1576
+ self.tokenizer = tokenizer
1577
+ self.pad_to_multiple_of = pad_to_multiple_of
1578
+
1579
+ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
1580
+ # Extract input_ids and labels
1581
+ input_ids = [f["input_ids"] for f in features]
1582
+ labels = [f["labels"] for f in features]
1583
+
1584
+ # Find max length in batch
1585
+ max_length = max(len(ids) for ids in input_ids)
1586
+
1587
+ # Apply pad_to_multiple_of if specified
1588
+ if self.pad_to_multiple_of is not None:
1589
+ max_length = ((max_length + self.pad_to_multiple_of - 1) // self.pad_to_multiple_of) * self.pad_to_multiple_of
1590
+
1591
+ # Pad sequences
1592
+ batch_input_ids = []
1593
+ batch_labels = []
1594
+ batch_attention_mask = []
1595
+
1596
+ for ids, lbls in zip(input_ids, labels):
1597
+ # Pad input_ids
1598
+ padded_ids = ids + [self.tokenizer.pad_token_id] * (max_length - len(ids))
1599
+ batch_input_ids.append(padded_ids)
1600
+
1601
+ # Pad labels (use -100 for padding)
1602
+ padded_labels = lbls + [-100] * (max_length - len(lbls))
1603
+ batch_labels.append(padded_labels)
1604
+
1605
+ # Create attention mask
1606
+ attention_mask = [1] * len(ids) + [0] * (max_length - len(ids))
1607
+ batch_attention_mask.append(attention_mask)
1608
+
1609
+ return {
1610
+ "input_ids": torch.tensor(batch_input_ids, dtype=torch.long),
1611
+ "labels": torch.tensor(batch_labels, dtype=torch.long),
1612
+ "attention_mask": torch.tensor(batch_attention_mask, dtype=torch.long),
1613
+ }
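A quick sanity check of `BaselineDataCollator`'s padding behaviour; since `__call__` only touches `pad_token_id`, a lightweight stand-in object can replace a real `AutoTokenizer` in this sketch.

```python
from types import SimpleNamespace

fake_tokenizer = SimpleNamespace(pad_token_id=0)  # only pad_token_id is accessed by __call__
collator = BaselineDataCollator(fake_tokenizer, pad_to_multiple_of=8)

features = [
    {"input_ids": [101, 7, 8, 9], "labels": [-100, 7, 8, 9]},
    {"input_ids": [101, 5],       "labels": [-100, 5]},
]
batch = collator(features)
print(batch["input_ids"].shape)    # torch.Size([2, 8]) after rounding up to a multiple of 8
print(batch["labels"][1])          # real labels followed by -100 padding
print(batch["attention_mask"][1])  # 1s for real tokens, 0s for padding
```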
1614
+
1615
+
1616
+
1617
+ """
1618
+ Helper functions
1619
+ """
1620
+
1621
+
1622
+ def create_dataset(dataset_type: str, **kwargs) -> Dataset:
1623
+ """
1624
+ Factory function to create a dataset based on type.
1625
+
1626
+ Args:
1627
+ dataset_type: String indicating the type of dataset
1628
+ **kwargs: Additional arguments to pass to the dataset constructor
1629
+
1630
+ Returns:
1631
+ An instance of the appropriate dataset
1632
+ """
1633
+ # First, check if dataset_type is directly in the registry (exact match)
1634
+ if dataset_type in DATASET_REGISTRY:
1635
+ return DATASET_REGISTRY[dataset_type](**kwargs)
1636
+
1637
+ # Then check for case-insensitive match
1638
+ dataset_type_lower = dataset_type.lower()
1639
+ if dataset_type_lower in DATASET_REGISTRY:
1640
+ return DATASET_REGISTRY[dataset_type_lower](**kwargs)
1641
+
1642
+ # If not found in registry, raise an error with valid options
1643
+ valid_options = list(
1644
+ set([name for name, cls in DATASET_REGISTRY.items() if name == cls.__name__])
1645
+ ) # Only include actual class names
1646
+ raise ValueError(
1647
+ f"Unknown dataset type: {dataset_type}. Valid options are: {valid_options}"
1648
+ )
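Typical usage looks like the sketch below; the registry key and keyword arguments are placeholders, since the concrete adapters are registered elsewhere in this module via `DATASET_REGISTRY`.

```python
# Hypothetical registry key and kwargs, shown only to illustrate the lookup path.
try:
    dataset = create_dataset("some_dataset_type", tokenizer=None, split="train")
except ValueError as err:
    # Unknown types fail loudly and report the registered class names.
    print(err)
```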
rosetta/train/model_utils.py ADDED
@@ -0,0 +1,142 @@
1
+ """
2
+ Model setup utilities for RosettaModel training/evaluation
3
+ """
4
+
5
+ import torch
6
+ from typing import Dict, Any, List
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer
8
+
9
+ from rosetta.model.wrapper import RosettaModel
10
+ from rosetta.model.projector import create_projector
11
+
12
+ """
13
+ Mapping strategies
14
+ """
15
+ def k_nearest_sources(num_target_layers: int, num_source_layers: int, k: int) -> Dict[int, List[int]]:
16
+ """
17
+ Compute a per-target mapping to K nearest source layers.
18
+
19
+ Returns: Dict[target_idx, List[source_idx]] only for targets we map.
20
+ Distances are computed by placing target and source layers uniformly in [0, 1]
21
+ and sorting by absolute distance.
22
+ """
23
+ if num_target_layers <= 1:
24
+ target_positions = [0.0]
25
+ else:
26
+ target_positions = [i / (num_target_layers - 1) for i in range(num_target_layers)]
27
+ if num_source_layers <= 1:
28
+ source_positions = [0.0]
29
+ else:
30
+ source_positions = [j / (num_source_layers - 1) for j in range(num_source_layers)]
31
+
32
+ mapping: Dict[int, List[int]] = {}
33
+ for t_idx, t_pos in enumerate(target_positions):
34
+ sorted_src = sorted(range(num_source_layers), key=lambda j: abs(source_positions[j] - t_pos))
35
+ chosen = sorted_src[:max(0, k)]
36
+ if len(chosen) > 0:
37
+ mapping[t_idx] = chosen
38
+ return mapping
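A small worked example of the uniform-placement rule: with 4 target layers, 8 source layers and k=2, each target layer picks its two closest source positions.

```python
mapping = k_nearest_sources(num_target_layers=4, num_source_layers=8, k=2)
print(mapping)
# {0: [0, 1], 1: [2, 3], 2: [5, 4], 3: [7, 6]}
# e.g. target layer 2 sits at 2/3 of the depth, closest to source layers 5 (5/7) and 4 (4/7).
```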
39
+
40
+
41
+ def last_aligned_sources(num_target_layers: int, num_source_layers: int, k: int = 1) -> Dict[int, List[int]]:
42
+ """
43
+ Return a per-target mapping that aligns the last target layer to the last
44
+ source layer and walks toward the front.
45
+
46
+ Returns: Dict[target_idx, List[source_idx]] only for targets we map. For each
47
+ target t, we choose up to K sources anchored at the aligned index, preferring
48
+ backward indices first then forward to satisfy K.
49
+
50
+ Example (T=11, S=33): target 10 -> [32, 31, ...], target 9 -> [31, 30, ...]
51
+ """
52
+ mapping: Dict[int, List[int]] = {}
53
+ if num_target_layers <= 0 or num_source_layers <= 0:
54
+ return mapping
55
+
56
+ # Align ends; offset >= 0 means extra source layers at the front
57
+ offset = num_source_layers - num_target_layers
58
+
59
+ def take_k_from(s0: int) -> List[int]:
60
+ result: List[int] = []
61
+ # Prefer moving backward from the anchor (last-to-front)
62
+ for back in range(k):
63
+ idx = s0 - back
64
+ if 0 <= idx < num_source_layers:
65
+ result.append(idx)
66
+ # If not enough due to boundary, extend forward
67
+ next_idx = s0 + 1
68
+ while len(result) < k and next_idx < num_source_layers:
69
+ result.append(next_idx)
70
+ next_idx += 1
71
+ return result
72
+
73
+ for t in range(num_target_layers):
74
+ s0 = offset + t
75
+ # Clamp to valid range for edge cases (e.g., fewer source layers)
76
+ if s0 < 0:
77
+ s0 = 0
78
+ elif s0 > num_source_layers - 1:
79
+ s0 = num_source_layers - 1
80
+ chosen = take_k_from(s0)
81
+ if len(chosen) > 0:
82
+ mapping[t] = chosen
83
+
84
+ return mapping
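Continuing the docstring's T=11, S=33 example, the end-aligned mapping with k=2 looks like this:

```python
mapping = last_aligned_sources(num_target_layers=11, num_source_layers=33, k=2)
print(mapping[10])  # [32, 31] -- last target layer anchored at the last source layer
print(mapping[0])   # [22, 21] -- offset = 33 - 11 = 22, walking backward from the anchor
```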
85
+
86
+
87
+ def setup_models(model_config: Dict[str, Any], device: str = "cuda", dtype: torch.dtype = torch.bfloat16):
88
+ """Setup RosettaModel with base model, teacher model, and projectors"""
89
+
90
+ # Load tokenizer
91
+ tokenizer = AutoTokenizer.from_pretrained(model_config["base_model"])
92
+ if tokenizer.pad_token is None:
93
+ tokenizer.pad_token = tokenizer.eos_token
94
+
95
+ # Load models
96
+ base_model = AutoModelForCausalLM.from_pretrained(
97
+ model_config["base_model"],
98
+ torch_dtype=dtype,
99
+ device_map=device
100
+ )
101
+
102
+ teacher_model = AutoModelForCausalLM.from_pretrained(
103
+ model_config["teacher_model"],
104
+ torch_dtype=dtype,
105
+ device_map=device
106
+ )
107
+
108
+ # Create projector
109
+ projector_config = model_config["projector"]
110
+ projector_params = projector_config["params"].copy()
111
+ projector_params["dtype"] = dtype
112
+
113
+ projector = create_projector(
114
+ projector_config["type"],
115
+ source_dim=teacher_model.config.head_dim,
116
+ target_dim=base_model.config.head_dim,
117
+ **projector_params
118
+ )
119
+
120
+ # Setup RosettaModel
121
+ rosetta_model = RosettaModel(
122
+ model_list=[base_model, teacher_model],
123
+ base_model_idx=0,
124
+ projector_list=[projector]
125
+ ).to(device)
126
+
127
+ # Configure projector mappings
128
+ num_layers_to_map = min(
129
+ base_model.config.num_hidden_layers,
130
+ teacher_model.config.num_hidden_layers
131
+ )
132
+
133
+ for layer_idx in range(num_layers_to_map):
134
+ rosetta_model.set_projector_config(
135
+ source_model_idx=1, # Teacher
136
+ source_model_layer_idx=layer_idx,
137
+ target_model_idx=0, # Base
138
+ target_model_layer_idx=layer_idx,
139
+ projector_idx=0
140
+ )
141
+
142
+ return rosetta_model, tokenizer
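The expected shape of `model_config` is sketched below; the checkpoint names and projector settings are placeholders rather than values taken from this repository's configs, so the final call is left commented out.

```python
import torch

# Hypothetical configuration illustrating the keys setup_models() reads.
model_config = {
    "base_model": "Qwen/Qwen3-0.6B",    # placeholder SLM checkpoint
    "teacher_model": "Qwen/Qwen3-4B",   # placeholder LLM checkpoint
    "projector": {
        "type": "linear",               # must be a type known to create_projector
        "params": {},                   # extra kwargs forwarded to the projector
    },
}
# rosetta_model, tokenizer = setup_models(model_config, device="cuda", dtype=torch.bfloat16)
```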
rosetta/utils/__pycache__/evaluate.cpython-310.pyc ADDED
Binary file (15.8 kB). View file
 
rosetta/utils/__pycache__/evaluate.cpython-312.pyc ADDED
Binary file (24 kB). View file
 
rosetta/utils/__pycache__/multi_stage.cpython-310.pyc ADDED
Binary file (5.22 kB). View file
 
rosetta/utils/__pycache__/multi_stage.cpython-312.pyc ADDED
Binary file (7.49 kB). View file
 
rosetta/utils/__pycache__/registry.cpython-310.pyc ADDED
Binary file (7.43 kB). View file
 
rosetta/utils/__pycache__/registry.cpython-312.pyc ADDED
Binary file (9.98 kB). View file
 
rosetta/utils/evaluate.py ADDED
@@ -0,0 +1,639 @@
1
+ """
2
+ Common evaluation utilities for benchmark tasks.
3
+
4
+ This module provides shared functions for model evaluation across different benchmarks
5
+ like MMLU-Redux and MMMLU.
6
+ """
7
+
8
+ import re
9
+ import os
10
+ import torch
11
+ import torch.nn as nn
12
+ import numpy as np
13
+ from typing import Dict, Any, List, Tuple, Optional
14
+ from transformers import AutoModelForCausalLM, AutoTokenizer
15
+
16
+ from rosetta.model.projector import load_projector
17
+ from rosetta.model.wrapper import RosettaModel
18
+ from rosetta.model.oracle import OracleRosettaModel
19
+
20
+ def build_prompt(dataset: str, locale: str, question: str, choices: str, use_cot: bool, use_template: bool = True) -> str:
21
+ """
22
+ Build the evaluation prompt for a given dataset and locale.
+
+ Note: the current implementation uses a single unified English template for
+ all datasets and locales; `dataset` (e.g., "mmmlu") and `locale` (e.g., "SW_KE")
+ are accepted for forward compatibility but do not change the prompt text.
27
+
28
+ Args:
29
+ dataset: Dataset identifier (e.g., "mmmlu")
30
+ locale: Locale/subject code (e.g., "SW_KE")
31
+ question: Question text
32
+ choices: Formatted choices string
33
+ use_cot: Whether to include CoT instruction
34
+
35
+ Returns:
36
+ Localized prompt string
37
+ """
38
+
39
+ # Unified default English templates (shared by MMLU and MMMLU)
40
+ if not use_cot:
41
+ template = """Accurately answer the following question:
42
+
43
+ {{question}}
44
+
45
+ Choices:
46
+ {{choices}}
47
+
48
+ Instructions:
49
+ - Carefully read the question and all options.
50
+ - Select the single most correct answer.
51
+ - Respond ONLY in the following format: "The correct answer is A/B/C/D".
52
+ - Do not include any explanations, additional text, or punctuation besides the answer.
53
+
54
+ The correct answer is"""
55
+
56
+ else:
57
+ template = """Accurately answer the following question:
58
+
59
+ {{question}}
60
+
61
+ Choices:
62
+ {{choices}}
63
+
64
+ Instructions:
65
+ - Carefully read the question and all options.
66
+ - Let's think step by step and explain your reasoning briefly.
67
+ - Then give the final answer starting with The correct answer is"""
68
+
69
+ prompt = template.replace("{{question}}", question)
70
+ prompt = prompt.replace("{{choices}}", choices)
71
+
72
+ if not use_template:
73
+ prompt = question + "\n\nChoices:\n" + choices
74
+
75
+ return prompt
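For example, the non-CoT template renders like this:

```python
choices = "A. 2\nB. 3\nC. 4\nD. 5"
prompt = build_prompt("mmmlu", "SW_KE", "What is 1 + 1?", choices, use_cot=False)

print(prompt.splitlines()[0])                    # Accurately answer the following question:
print(prompt.endswith("The correct answer is"))  # True -- the model is expected to continue with the letter
```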
76
+
77
+
78
+ def parse_answer(answer_str: str) -> List[str]:
79
+ """
80
+ Parse answer string to extract valid answer options.
81
+ Converts digits 0/1/2/3 to letters A/B/C/D.
82
+
83
+ Args:
84
+ answer_str: String containing answer digits
85
+
86
+ Returns:
87
+ List of parsed answer letters
88
+ """
89
+ if not isinstance(answer_str, str):
90
+ return []
91
+ valid_digits = [c for c in answer_str if c in {'0','1','2','3'}]
92
+ return sorted(list({
93
+ chr(65 + int(d)) # 0->A, 1->B, 2->C, 3->D
94
+ for d in valid_digits
95
+ }))
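A couple of quick examples of the digit-to-letter conversion:

```python
print(parse_answer("2, 0"))  # ['A', 'C'] -- digits map to letters, de-duplicated and sorted
print(parse_answer("331"))   # ['B', 'D']
print(parse_answer(None))    # []        -- non-string inputs are rejected
```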
96
+
97
+
98
+ def extract_answer_from_content(text: str) -> Optional[str]:
99
+ """
100
+ Extract answer from model output with robust multi-pattern matching.
101
+ Supports multiple languages and response formats.
102
+
103
+ Args:
104
+ text: Model output text
105
+
106
+ Returns:
107
+ Extracted answer letter or None
108
+ """
109
+ text = text.strip()
110
+ if not text:
111
+ return None
112
+
113
+ # Define multiple answer patterns for different languages and formats
114
+ answer_patterns = [
115
+ # English patterns
116
+ r'Answer:\s*(.*)',
117
+ r'answer:\s*(.*)',
118
+ r'ANSWER:\s*(.*)',
119
+ r'Your answer:\s*(.*)',
120
+ r'your answer:\s*(.*)',
121
+ r'YOUR ANSWER:\s*(.*)',
122
+ r'The answer is\s*(.*)',
123
+ r'the answer is\s*(.*)',
124
+ r'THE ANSWER IS\s*(.*)',
125
+ r'Correct answer is\s*(.*)',
126
+ r'correct answer is\s*(.*)',
127
+ r'Correct answer is:\s*(.*)',
128
+ r'correct answer is:\s*(.*)',
129
+ r'Correct answer:\s*(.*)',
130
+ r'correct answer:\s*(.*)',
131
+ r'CORRECT ANSWER:\s*(.*)',
132
+
133
+ # Swahili patterns
134
+ r'Jibu lako:\s*(.*)',
135
+ r'jibu lako:\s*(.*)',
136
+ r'JIBU LAKO:\s*(.*)',
137
+ r'Jibu:\s*(.*)',
138
+ r'jibu:\s*(.*)',
139
+ r'JIBU:\s*(.*)',
140
+ r'Jibu sahihi:\s*(.*)',
141
+ r'jibu sahihi:\s*(.*)',
142
+ r'JIBU SAHIHI:\s*(.*)',
143
+
144
+ # Other common patterns
145
+ r'Response:\s*(.*)',
146
+ r'response:\s*(.*)',
147
+ r'RESPONSE:\s*(.*)',
148
+ r'Choice:\s*(.*)',
149
+ r'choice:\s*(.*)',
150
+ r'CHOICE:\s*(.*)',
151
+ r'Option:\s*(.*)',
152
+ r'option:\s*(.*)',
153
+ r'OPTION:\s*(.*)',
154
+ ]
155
+
156
+ # 1. Try to match any of the answer patterns
157
+ for pattern in answer_patterns:
158
+ match = re.search(pattern, text, re.IGNORECASE)
159
+ if match:
160
+ answer_part = match.group(1).strip()
161
+ # Search for first A-D letter in the matched part
162
+ for char in answer_part:
163
+ if char in {'A', 'B', 'C', 'D'}:
164
+ return char
165
+
166
+ # 2. Look for standalone A-D letters that are likely answers
167
+ # Prioritize letters at the end of text or with clear answer-like context
168
+ standalone_patterns = [
169
+ r'\b([A-D])(?:\s*[.,!?:)]?\s*$)', # A-D at end of text with optional punctuation
170
+ r'\b([A-D])(?:\s*[.,!?:)]\s)', # A-D followed by punctuation and space
171
+ r'(?:^|\s)([A-D])(?:\s*$)', # A-D at start or with word boundary at end
172
+ ]
173
+
174
+ for pattern in standalone_patterns:
175
+ matches = re.findall(pattern, text, re.IGNORECASE)
176
+ if matches:
177
+ # Check if this looks like mathematical expressions rather than answers
178
+ math_indicators = ['+', '-', '*', '/', '=', '^', 'x^', 'y^', 'z^', 'mod', 'sqrt', 'sin', 'cos', 'tan']
179
+ has_math = any(indicator in text for indicator in math_indicators)
180
+ has_answer_indicators = any(phrase in text.lower() for phrase in ['jibu', 'answer', 'choice', 'option', 'response', 'correct', 'sahihi'])
181
+
182
+ # If it has math indicators but no answer indicators, it's likely mathematical notation
183
+ if has_math and not has_answer_indicators:
184
+ continue # Skip this match, try next pattern
185
+
186
+ return matches[-1].upper()
187
+
188
+ # 3. Fallback: find all A-D letters but be more selective
189
+ all_letters = re.findall(r'\b([A-D])\b', text, re.IGNORECASE)
190
+ if all_letters:
191
+ # Check if this looks like mathematical expressions rather than answers
192
+ math_indicators = ['+', '-', '*', '/', '=', '^', 'x^', 'y^', 'z^', 'mod', 'sqrt', 'sin', 'cos', 'tan']
193
+ has_math = any(indicator in text for indicator in math_indicators)
194
+ has_answer_indicators = any(phrase in text.lower() for phrase in ['jibu', 'answer', 'choice', 'option', 'response', 'correct', 'sahihi'])
195
+
196
+ # If it has math indicators but no answer indicators, it's likely mathematical notation
197
+ if has_math and not has_answer_indicators:
198
+ return None
199
+
200
+ # Otherwise, return the last letter found
201
+ return all_letters[-1].upper()
202
+
203
+ # 4. Final fallback: search backwards for any A-D letter
204
+ for char in reversed(text):
205
+ if char in {'A', 'B', 'C', 'D'}:
206
+ return char
207
+
208
+ return None
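A few representative cases of the extraction heuristics above:

```python
print(extract_answer_from_content("The correct answer is B."))  # 'B' (explicit answer pattern)
print(extract_answer_from_content("Jibu: C"))                   # 'C' (Swahili pattern)
print(extract_answer_from_content("x^2 + y^2 = z^2"))           # None (math notation, no answer cue)
```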
209
+
210
+
211
+ def apply_generation_config(model: Any, generation_config: Optional[Dict[str, Any]] = None) -> None:
212
+ """
213
+ Apply generation configuration to a model and handle sampling parameters.
214
+
215
+ This function applies the provided generation config to the model and removes
216
+ sampling parameters (temperature, top_p, top_k, min_p) when do_sample=False
217
+ to avoid warnings from the transformers library. If no config is provided,
218
+ it defaults to greedy decoding with cleaned sampling parameters.
219
+
220
+ Args:
221
+ model: Model object with generation_config attribute
222
+ generation_config: Optional generation configuration dictionary.
223
+ If None, defaults to greedy decoding (do_sample=False).
224
+ """
225
+ if not hasattr(model, 'generation_config'):
226
+ return
227
+
228
+ # If no config provided, default to greedy decoding
229
+ if not generation_config:
230
+ generation_config = {'do_sample': False}
231
+
232
+ # Apply all configuration parameters
233
+ for key, value in generation_config.items():
234
+ setattr(model.generation_config, key, value)
235
+
236
+ # Disable sampling parameters if do_sample=False to avoid warnings
237
+ # We set them to None instead of deleting, since some model code may
238
+ # access these attributes unconditionally.
239
+ if not generation_config.get('do_sample', True):
240
+ sampling_params = ['temperature', 'top_p', 'top_k', 'min_p', 'repetition_penalty']
241
+ for param in sampling_params:
242
+ try:
243
+ setattr(model.generation_config, param, None)
244
+ except Exception:
245
+ # If the backend does not allow setting, ignore silently
246
+ pass
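Because only the `generation_config` attribute is touched, the behaviour is easy to check with a stand-in object:

```python
from types import SimpleNamespace

model = SimpleNamespace(generation_config=SimpleNamespace())
apply_generation_config(model, {"do_sample": False, "max_new_tokens": 64})

print(model.generation_config.max_new_tokens)  # 64
print(model.generation_config.temperature)     # None -- sampling params cleared for greedy decoding
```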
247
+
248
+
249
+ def set_default_chat_template(tokenizer, model_name: str):
250
+ """
251
+ Set default chat template for models without one.
252
+
253
+ Args:
254
+ tokenizer: Tokenizer object
255
+ model_name: Name of the model
256
+ """
257
+ if tokenizer.chat_template is None:
258
+ if "UlizaLlama3".lower() in model_name.lower():
259
+ tokenizer.chat_template = (
260
+ "{%- for message in messages %}"
261
+ "{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n' }}"
262
+ "{{- message['content'] }}"
263
+ "{{- '<|eot_id|>' }}"
264
+ "{%- endfor %}"
265
+ "{%- if add_generation_prompt %}"
266
+ "{{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}"
267
+ "{%- endif %}"
268
+ )
269
+ else:
270
+ print(f"Model {model_name} has no chat template, setting default template...")
271
+ default_template = """{% for message in messages %}{% if message['role'] == 'user' %}### Human: {{ message['content'] }}{% elif message['role'] == 'assistant' %}### Assistant: {{ message['content'] }}{% endif %}{% if not loop.last %}
272
+ {% endif %}{% endfor %}{% if add_generation_prompt %}
273
+ ### Assistant:{% endif %}"""
274
+ tokenizer.chat_template = default_template
275
+ print("Default chat template has been set.")
276
+ else:
277
+ print(f"Model {model_name} already has a chat template.")
278
+
279
+
280
+ def load_hf_model(model_name: str, device: torch.device, generation_config: Optional[Dict[str, Any]] = None) -> Tuple[Any, Any]:
281
+ """
282
+ Load Hugging Face model and tokenizer.
283
+
284
+ Args:
285
+ model_name: Model name or path
286
+ device: Device to load model on
287
+ generation_config: Optional generation configuration to apply
288
+
289
+ Returns:
290
+ Tuple of (model, tokenizer)
291
+ """
292
+ tokenizer = AutoTokenizer.from_pretrained(
293
+ str(model_name),
294
+ trust_remote_code=True,
295
+ padding_side='left'
296
+ )
297
+
298
+ if tokenizer.pad_token is None:
299
+ tokenizer.pad_token = tokenizer.eos_token
300
+
301
+ # Check and set chat template
302
+ set_default_chat_template(tokenizer, model_name)
303
+
304
+ if model_name == "google/gemma-3-1b-it":
305
+ torch._dynamo.config.cache_size_limit = 64
306
+ model = AutoModelForCausalLM.from_pretrained(
307
+ str(model_name),
308
+ torch_dtype=torch.bfloat16,
309
+ device_map={"": device},
310
+ sliding_window=4096
311
+ ).eval()
312
+ else:
313
+ model = AutoModelForCausalLM.from_pretrained(
314
+ str(model_name),
315
+ torch_dtype=torch.bfloat16,
316
+ device_map={"": device}
317
+ ).eval()
318
+
319
+ # Apply generation config
320
+ apply_generation_config(model, generation_config)
321
+
322
+ return model, tokenizer
323
+
324
+
325
+ def load_rosetta_model(model_config: Dict[str, Any], eval_config: Dict[str, Any],
326
+ device: torch.device, generation_config: Optional[Dict[str, Any]] = None) -> Tuple[Any, Any]:
327
+ """
328
+ Load Rosetta model with projectors and aggregators.
329
+
330
+ Args:
331
+ model_config: Model configuration dict
332
+ eval_config: Evaluation configuration dict
333
+ device: Device to load model on
334
+ generation_config: Optional generation configuration to apply
335
+
336
+ Returns:
337
+ Tuple of (rosetta_model, tokenizer)
338
+ """
339
+ # Prefer checkpoints_dir under model.rosetta_config; fall back to eval config for backward compatibility
340
+ rosetta_config = model_config["rosetta_config"]
341
+ checkpoint_dir = rosetta_config.get("checkpoints_dir", eval_config.get("checkpoints_dir"))
342
+ if checkpoint_dir is None:
343
+ raise KeyError("checkpoints_dir must be provided under model.rosetta_config (preferred) or eval config (legacy)")
344
+ slm_model_path = rosetta_config["base_model"]
345
+ llm_model_path = rosetta_config["teacher_model"]
346
+
347
+ # Load tokenizer
348
+ slm_tokenizer = AutoTokenizer.from_pretrained(str(slm_model_path))
349
+ set_default_chat_template(slm_tokenizer, slm_model_path)
350
+
351
+ # Load models
352
+ slm_model = AutoModelForCausalLM.from_pretrained(
353
+ str(slm_model_path),
354
+ torch_dtype=torch.bfloat16,
355
+ device_map={"": device}
356
+ ).eval()
357
+
358
+ # Apply generation config to SLM
359
+ apply_generation_config(slm_model, generation_config)
360
+
361
+ if llm_model_path == "google/gemma-3-1b-it":
362
+ llm_model = AutoModelForCausalLM.from_pretrained(
363
+ str(llm_model_path),
364
+ torch_dtype=torch.bfloat16,
365
+ device_map={"": device},
366
+ sliding_window=4096
367
+ ).eval()
368
+ else:
369
+ llm_model = AutoModelForCausalLM.from_pretrained(
370
+ str(llm_model_path),
371
+ torch_dtype=torch.bfloat16,
372
+ device_map={"": device}
373
+ ).eval()
374
+
375
+ # Apply generation config to LLM
376
+ apply_generation_config(llm_model, generation_config)
377
+
378
+ # Load projectors
379
+ num_projectors = len([f for f in os.listdir(checkpoint_dir) if re.match(r"projector_\d+\.pt", f)])
380
+ projector_list = []
381
+ for t in range(num_projectors):
382
+ json_cfg = os.path.join(checkpoint_dir, f"projector_{t}.json")
383
+ proj = load_projector(json_cfg)
384
+ proj = proj.to(device)
385
+ pt_path = os.path.join(checkpoint_dir, f"projector_{t}.pt")
386
+ if os.path.exists(pt_path):
387
+ state_dict = torch.load(pt_path, map_location=device)
388
+ proj.load_state_dict(state_dict, strict=False)
389
+ projector_list.append(proj)
390
+
391
+ aggregator_list = []
392
+
393
+ # Initialize Rosetta model
394
+ rosetta_model = RosettaModel(
395
+ model_list=[slm_model, llm_model],
396
+ base_model_idx=0,
397
+ projector_list=projector_list,
398
+ aggregator_list=aggregator_list,
399
+ ).to(device).eval()
400
+
401
+ # Load projector/aggregator mapping configs
402
+ proj_cfg_path = os.path.join(checkpoint_dir, "projector_config.json")
403
+ agg_cfg_path = os.path.join(checkpoint_dir, "aggregator_config.json")
404
+ rosetta_model.load_projector_config(proj_cfg_path)
405
+ rosetta_model.load_aggregator_config(agg_cfg_path)
406
+
407
+ return rosetta_model, slm_tokenizer
408
+
409
+
410
+ def load_oracle_rosetta_model(model_config: Dict[str, Any], eval_config: Dict[str, Any],
411
+ device: torch.device) -> Tuple[Any, Any]:
412
+ """
413
+ Load OracleRosettaModel with projectors and aggregators.
414
+
415
+ Args:
416
+ model_config: Model configuration dict
417
+ eval_config: Evaluation configuration dict
418
+ device: Device to load model on
419
+
420
+ Returns:
421
+ Tuple of (rosetta_model, tokenizer)
422
+ """
423
+ # Prefer checkpoints_dir under model.rosetta_config; fall back to eval config for backward compatibility
424
+ rosetta_config = model_config["rosetta_config"]
425
+ checkpoint_dir = rosetta_config.get("checkpoints_dir", eval_config.get("checkpoints_dir"))
426
+ if checkpoint_dir is None:
427
+ raise KeyError("checkpoints_dir must be provided under model.rosetta_config (preferred) or eval config (legacy)")
428
+ slm_model_path = rosetta_config["base_model"]
429
+ llm_model_path = rosetta_config["teacher_model"]
430
+
431
+ # Load tokenizer
432
+ slm_tokenizer = AutoTokenizer.from_pretrained(str(slm_model_path))
433
+ set_default_chat_template(slm_tokenizer, slm_model_path)
434
+
435
+ # Load models
436
+ slm_model = AutoModelForCausalLM.from_pretrained(
437
+ str(slm_model_path),
438
+ torch_dtype=torch.bfloat16,
439
+ device_map={"": device}
440
+ ).eval()
441
+
442
+ llm_model = AutoModelForCausalLM.from_pretrained(
443
+ str(llm_model_path),
444
+ torch_dtype=torch.bfloat16,
445
+ device_map={"": device}
446
+ ).eval()
447
+
448
+ # Load projectors
449
+ num_projectors = len([f for f in os.listdir(checkpoint_dir) if re.match(r"projector_\d+\.pt", f)])
450
+ projector_list = []
451
+ for t in range(num_projectors):
452
+ json_cfg = os.path.join(checkpoint_dir, f"projector_{t}.json")
453
+ proj = load_projector(json_cfg)
454
+ proj = proj.to(device)
455
+ pt_path = os.path.join(checkpoint_dir, f"projector_{t}.pt")
456
+ if os.path.exists(pt_path):
457
+ state_dict = torch.load(pt_path, map_location=device)
458
+ proj.load_state_dict(state_dict, strict=False)
459
+ projector_list.append(proj)
460
+
461
+ aggregator_list = []
462
+
463
+ # Initialize Rosetta model
464
+ rosetta_model = OracleRosettaModel(
465
+ model_list=[slm_model, llm_model],
466
+ base_model_idx=0,
467
+ projector_list=projector_list,
468
+ aggregator_list=aggregator_list,
469
+ ).to(device).eval()
470
+
471
+ # Load projector/aggregator mapping configs
472
+ proj_cfg_path = os.path.join(checkpoint_dir, "projector_config.json")
473
+ agg_cfg_path = os.path.join(checkpoint_dir, "aggregator_config.json")
474
+ rosetta_model.load_projector_config(proj_cfg_path)
475
+ rosetta_model.load_aggregator_config(agg_cfg_path)
476
+
477
+ return rosetta_model, slm_tokenizer
478
+
479
+
480
+ def get_option_token_ids(tokenizer, num_options: int = 4) -> List[int]:
481
+ """
482
+ Get token IDs for options A, B, C, D (or more up to J).
483
+
484
+ Args:
485
+ tokenizer: Tokenizer object
486
+ num_options: Number of options to get (default 4 for A-D, max 10 for A-J)
487
+
488
+ Returns:
489
+ List of token IDs for options
490
+ """
491
+ # Limit to maximum of 10 options (A-J)
492
+ num_options = min(num_options, 10)
493
+ option_ids = []
494
+ for i in range(num_options):
495
+ letter = chr(65 + i) # A=65, B=66, etc.
496
+ ids = tokenizer.encode(" " + letter, add_special_tokens=False)
497
+ option_ids.append(ids[0] if ids else tokenizer.eos_token_id)
498
+ return option_ids
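Usage sketch; the checkpoint name is a placeholder, and any Hugging Face tokenizer that encodes " A"…" D" works here.

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # placeholder checkpoint
option_ids = get_option_token_ids(tok, num_options=4)

print(len(option_ids))               # 4 -- one token id per option letter A-D
print(tok.decode([option_ids[0]]))   # typically " A" for BPE tokenizers
```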
499
+
500
+ """
501
+ Deprecated
502
+ """
503
+
504
+ @torch.no_grad()
505
+ def generate_answer_with_logits(model, tokenizer, prompt: str, option_ids: List[int],
506
+ device: torch.device, model_type: str = "hf") -> Tuple[str, np.ndarray]:
507
+ """
508
+ Generate answer using logits method.
509
+
510
+ Args:
511
+ model: Model object
512
+ tokenizer: Tokenizer object
513
+ prompt: Input prompt
514
+ option_ids: Token IDs for options A, B, C, D
515
+ device: Device to run on
516
+ model_type: Type of model ("rosetta", "qwen", or "hf")
517
+
518
+ Returns:
519
+ Tuple of (predicted_answer, probabilities)
520
+ """
521
+ messages = [{
522
+ "role": "user",
523
+ "content": prompt
524
+ }]
525
+
526
+ # Try to apply chat template
527
+ try:
528
+ text = tokenizer.apply_chat_template(
529
+ messages,
530
+ tokenize=False,
531
+ add_generation_prompt=True,
532
+ enable_thinking=False if model_type == "qwen" else None
533
+ )
534
+ except Exception as e:
535
+ print(f"Failed to apply chat template for {model_type} model: {e}")
536
+ text = f"### Human: {prompt}\n### Assistant:"
537
+
538
+ text += "The correct answer is"
539
+ input_ids = tokenizer(text, return_tensors="pt").to(device)['input_ids']
540
+ attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(device)
541
+ position_ids = attention_mask.long().cumsum(-1) - 1
542
+
543
+ if model_type == "rosetta":
544
+ instruction_index = torch.tensor([1, 0], dtype=torch.long).repeat(input_ids.shape[1]-1, 1).unsqueeze(0).to(device)
545
+ response_index = torch.tensor([[-1, 0]], dtype=torch.long).unsqueeze(0)
546
+ outputs = model.forward(
547
+ input_ids=input_ids,
548
+ attention_mask=attention_mask,
549
+ position_ids=position_ids,
550
+ kv_cache_index=[instruction_index, response_index]
551
+ )
552
+ else:
553
+ outputs = model(input_ids)
554
+
555
+ logits = outputs.logits[0, -1]
556
+ option_logits = torch.tensor([
557
+ logits[option_ids[0]].item(),
558
+ logits[option_ids[1]].item(),
559
+ logits[option_ids[2]].item(),
560
+ logits[option_ids[3]].item()
561
+ ])
562
+
563
+ probs = torch.nn.functional.softmax(option_logits, dim=0).numpy()
564
+ pred = chr(65 + np.argmax(probs))
565
+ return pred, probs
566
+
567
+
568
+ @torch.no_grad()
569
+ def generate_answer_with_generate(model, tokenizer, prompt: str, device: torch.device,
570
+ model_type: str = "hf") -> Tuple[str, np.ndarray, int, int, str]:
571
+ """
572
+ Generate answer using text generation method.
573
+
574
+ Args:
575
+ model: Model object
576
+ tokenizer: Tokenizer object
577
+ prompt: Input prompt
578
+ device: Device to run on
579
+ model_type: Type of model ("rosetta" or "hf")
580
+
581
+ Returns:
582
+ Tuple of (predicted_answer, probabilities, input_length, generation_length, generated_text)
583
+ """
584
+ messages = [{
585
+ "role": "user",
586
+ "content": prompt
587
+ }]
588
+
589
+ # Apply chat template
590
+ try:
591
+ text = tokenizer.apply_chat_template(
592
+ messages,
593
+ tokenize=False,
594
+ add_generation_prompt=True,
595
+ enable_thinking=False
596
+ )
597
+ except Exception as e:
598
+ print(f"Failed to apply chat template: {e}")
599
+ text = f"### Human: {prompt}\n### Assistant:"
600
+
601
+ # Prepare model input
602
+ inputs = tokenizer(text, return_tensors="pt").to(device)
603
+
604
+ # Generation parameters
605
+ sampling_params = {
606
+ 'do_sample': True,
607
+ 'temperature': 0.7,
608
+ 'top_p': 0.8,
609
+ 'top_k': 20,
610
+ 'min_p': 0.0,
611
+ 'repetition_penalty': 1.2,
612
+ 'max_new_tokens': 1024
613
+ }
614
+
615
+ # Generate text
616
+ outputs = model.generate(
617
+ input_ids=inputs.input_ids,
618
+ attention_mask=inputs.attention_mask,
619
+ **sampling_params
620
+ )
621
+
622
+ # Parse output
623
+ if isinstance(model, RosettaModel):
624
+ generated_ids = outputs[0]
625
+ else:
626
+ generated_ids = outputs[0][inputs.input_ids.shape[1]:]
627
+ content = tokenizer.decode(generated_ids, skip_special_tokens=True).strip("\n")
628
+
629
+ # Extract answer
630
+ pred = extract_answer_from_content(content)
631
+
632
+ # Return uniform distribution for generate method
633
+ probs = np.array([0.25, 0.25, 0.25, 0.25])
634
+
635
+ input_length = inputs.input_ids.shape[1]
636
+ gen_length = generated_ids.shape[0]
637
+
638
+ return pred, probs, input_length, gen_length, content
639
+
rosetta/utils/matheval.py ADDED
@@ -0,0 +1,455 @@
1
+ import re
2
+ import json
3
+ import argparse
4
+ import jsonlines
5
+ from collections import defaultdict
6
+ from openai import OpenAI
7
+ from typing import Dict, Any, Tuple
8
+ from math_verify import parse, verify, LatexExtractionConfig, ExprExtractionConfig, StringExtractionConfig
9
+ from latex2sympy2_extended import NormalizationConfig
10
+ from transformers import AutoTokenizer
11
+ import requests
12
+ import os
13
+
14
+ class MathEvaluator:
15
+
16
+ def rule_judge(self, solution_str: str, ground_truth: str, finish_generation: bool = True) -> Tuple[bool, str]:
17
+ raise NotImplementedError
18
+
19
+ def extract_after_think(self, text: str, truncate_length: int = 1000, finish_generation: bool = True) -> str:
20
+ pattern = r"</think>(.*)"
21
+ match = re.search(pattern, text, re.DOTALL)
22
+ return match.group(1).strip() if (match and finish_generation) else text[-truncate_length:]
23
+
24
+ def get_llm_judge_prompt(self, solution_str: str, ground_truth: str, extracted_answer: str = "", finish_generation: bool = True) -> str:
25
+ raise NotImplementedError
26
+
27
+ def get_llm_judge_prompt_not_finished(self, solution_str: str, ground_truth: str, extracted_answer: str = "", finish_generation: bool = True) -> str:
28
+ return f"""Please determine whether the final answer in the model-generated response was already correctly derived early in the reasoning process, and whether the subsequent content consists mainly of unnecessary verification, overthinking, or repetitive reasoning. If the correct answer is derived early, return "YES"; if it is not, return "NO". Only return "YES" or "NO", and do not generate any other content.
29
+ Reference answer: {ground_truth}
30
+ Model-generated response: {solution_str}
31
+ """.strip()
32
+
33
+ def llm_judge(self, solution_str: str, ground_truth: str, extracted_answer: str = "", finish_generation: bool = True) -> bool:
34
+ global OPENAI_CLIENT, MODEL_NAME
35
+ def get_inputs(scene_description):
36
+ body = [
37
+ {"role": "user", "content": scene_description},
38
+ ]
39
+ return body
40
+
41
+ def run_api(inputs):
42
+ completion = OPENAI_CLIENT.chat.completions.create(
43
+ model=MODEL_NAME,
44
+ messages=inputs
45
+ )
46
+ return completion.choices[0].message.content.strip()
47
+ if finish_generation:
48
+ scene_description = self.get_llm_judge_prompt(solution_str, ground_truth, extracted_answer, finish_generation)
49
+ else:
50
+ scene_description = self.get_llm_judge_prompt_not_finished(solution_str, ground_truth, extracted_answer, finish_generation)
51
+ inputs = get_inputs(scene_description)
52
+ response = run_api(inputs)
53
+
54
+ return "YES" in response
55
+
56
+
57
+ class AIMEEvaluator(MathEvaluator):
58
+ def rule_judge(self, solution_str: str, ground_truth: str, finish_generation: bool = True) -> Tuple[bool, str]:
59
+ # if not ground_truth.startswith("$"):
60
+ # ground_truth = f"${ground_truth}$"
61
+ gold = parse(
62
+ ground_truth,
63
+ extraction_config=[ExprExtractionConfig()],
64
+ )
65
+ answer = parse(
66
+ solution_str,
67
+ extraction_config=[
68
+ LatexExtractionConfig(
69
+ normalization_config=NormalizationConfig(
70
+ nits=False,
71
+ malformed_operators=False,
72
+ basic_latex=True,
73
+ boxed="all",
74
+ units=True,
75
+ ),
76
+ boxed_match_priority=0,
77
+ try_extract_without_anchor=False,
78
+ ),
79
+ ExprExtractionConfig(),
80
+ ],
81
+ extraction_mode="first_match",
82
+ )
83
+ if len(answer) == 0:
84
+ return False, "No extracted answer"
85
+ else:
86
+ return verify(gold, answer), str(answer)
87
+
88
+ def get_llm_judge_prompt(self, solution_str: str, ground_truth: str, extract_answer: str = "", finish_generation: bool = True) -> str:
89
+ solution_str = self.extract_after_think(solution_str, finish_generation=finish_generation)
90
+ return f"""Please determine whether the final answer provided in the model-generated response is equivalent to the reference answer from a math question. The final answer may either be enclosed in \\boxed{{}} or appear after "Answer:". If they are equivalent, return "YES"; if they are not, return "NO". Only return "YES" or "NO", and do not generate any other content.
91
+ Model-generated answer: {solution_str}
92
+ Reference answer: {ground_truth}""".strip()
93
+
94
+
95
+ class GSM8KEvaluator(MathEvaluator):
96
+ def rule_judge(self, solution_str: str, ground_truth: str, finish_generation: bool = True) -> Tuple[bool, str]:
97
+ # if not ground_truth.startswith("$"):
98
+ # ground_truth = f"${ground_truth}$"
99
+ gold = parse(
100
+ ground_truth,
101
+ extraction_config=[ExprExtractionConfig()],
102
+ )
103
+ answer = parse(
104
+ solution_str,
105
+ extraction_config=[
106
+ LatexExtractionConfig(
107
+ normalization_config=NormalizationConfig(
108
+ nits=False,
109
+ malformed_operators=False,
110
+ basic_latex=True,
111
+ boxed="all",
112
+ units=True,
113
+ ),
114
+ boxed_match_priority=0,
115
+ try_extract_without_anchor=False,
116
+ ),
117
+ ExprExtractionConfig(),
118
+ ],
119
+ extraction_mode="first_match",
120
+ )
121
+ if len(answer) == 0:
122
+ return False, "No extracted answer"
123
+ else:
124
+ return verify(gold, answer), str(answer)
125
+
126
+ def get_llm_judge_prompt(self, solution_str: str, ground_truth: str, extract_answer: str = "", finish_generation: bool = True) -> str:
127
+ solution_str = self.extract_after_think(solution_str, finish_generation=finish_generation)
128
+ return f"""Please determine whether the final answer provided in the model-generated response with rule-based extracted answer is equivalent to the reference answer from a math question. The final answer may either be enclosed in the \\boxed{{}} or appear after the "Answer:". If they are equivalent, return "YES"; if they are not, return "NO". Only return "YES" or "NO", and do not generate any other content.
129
+
130
+ 1. The reference answer does not include percentage signs, units or time formats (e.g., am, pm), but the Model-generated answer may include them.
131
+ For example, 1 is equivalent to 1 %, 1 kg, 1 am, 1 pm, 1:00 am, 1:00 pm, etc.
132
+ Model-generated answer: 1%
133
+ Reference answer: 1
134
+ Your output: YES
135
+
136
+ Model-generated answer: 1 kg
137
+ Reference answer: 1
138
+ Your output: YES
139
+
140
+ Model-generated answer: 1:00 pm
141
+ Reference answer: 1
142
+ Your output: YES
143
+
144
+ 2. The reference answer only includes one single number, but the Model-generated answer may include multiple numbers.
145
+ For example, 10 is equivalent to \\boxed{{(4, 6)}}, etc.
146
+ Model-generated answer: 5, 5
147
+ Reference answer: 10
148
+ Your output: YES
149
+
150
+ Model-generated answer: 4, 6
151
+ Reference answer: 10
152
+ Your output: YES
153
+
154
+ Model-generated answer: 86, 42
155
+ Reference answer: 128
156
+ Your output: YES
157
+
158
+ Now let's try a real example.
159
+ Model-generated answer: {solution_str}
160
+ Reference answer: {ground_truth}
161
+ """.strip()
162
+
163
+
164
+ class MATH500Evaluator(MathEvaluator):
165
+ def rule_judge(self, solution_str: str, ground_truth: str, finish_generation: bool = True) -> Tuple[bool, str]:
166
+ if not ground_truth.startswith("$"):
167
+ ground_truth = f"${ground_truth}$"
168
+ gold = parse(
169
+ ground_truth,
170
+ extraction_config=[LatexExtractionConfig()],
171
+ )
172
+ answer = parse(
173
+ solution_str,
174
+ extraction_config=[
175
+ LatexExtractionConfig(
176
+ normalization_config=NormalizationConfig(
177
+ nits=False,
178
+ malformed_operators=False,
179
+ basic_latex=True,
180
+ boxed="all",
181
+ units=True,
182
+ ),
183
+ boxed_match_priority=0,
184
+ try_extract_without_anchor=False,
185
+ ),
186
+ ExprExtractionConfig(),
187
+ ],
188
+ extraction_mode="first_match",
189
+ )
190
+ if len(answer) == 0:
191
+ return False, "No extracted answer"
192
+ else:
193
+ return verify(gold, answer), str(answer)
194
+ def get_llm_judge_prompt(self, solution_str: str, ground_truth: str, extract_answer: str = "", finish_generation: bool = True) -> str:
195
+ solution_str = self.extract_after_think(solution_str, finish_generation=finish_generation)
196
+ return f"""Please determine whether the final answer provided in the model-generated response is equivalent to the reference answer from a math question. The final answer may either be enclosed in \\boxed{{}} or appear after "Answer:". If they are equivalent, return "YES"; if they are not, return "NO". Only return "YES" or "NO", and do not generate any other content.
197
+ Model-generated answer: {solution_str}
198
+ Reference answer: {ground_truth}""".strip()
199
+
200
+ class AMCEvaluator(MathEvaluator):
201
+ def rule_judge(self, solution_str: str, ground_truth: str, finish_generation: bool = True) -> Tuple[bool, str]:
202
+ if not ground_truth.startswith("$"):
203
+ ground_truth = f"${ground_truth}$"
204
+ gold = parse(
205
+ ground_truth,
206
+ extraction_config=[LatexExtractionConfig()],
207
+ )
208
+ answer = parse(
209
+ solution_str,
210
+ extraction_config=[
211
+ LatexExtractionConfig(
212
+ normalization_config=NormalizationConfig(
213
+ nits=False,
214
+ malformed_operators=False,
215
+ basic_latex=True,
216
+ boxed="all",
217
+ units=True,
218
+ ),
219
+ boxed_match_priority=0,
220
+ try_extract_without_anchor=False,
221
+ ),
222
+ ExprExtractionConfig(),
223
+ ],
224
+ extraction_mode="first_match",
225
+ )
226
+ if len(answer) == 0:
227
+ return False, "No extracted answer"
228
+ else:
229
+ return verify(gold, answer), str(answer)
230
+ def get_llm_judge_prompt(self, solution_str: str, ground_truth: str, extract_answer: str = "", finish_generation: bool = True) -> str:
231
+ solution_str = self.extract_after_think(solution_str, finish_generation=finish_generation)
232
+ return f"""Please determine whether the final answer provided in the model-generated response is equivalent to the reference answer from a math question. The final answer may either be enclosed in \\boxed{{}} or appear after "Answer:". If they are equivalent, return "YES"; if they are not, return "NO". Only return "YES" or "NO", and do not generate any other content.
233
+ Model-generated answer: {solution_str}
234
+ Reference answer: {ground_truth}""".strip()
235
+
236
+
237
+ class GPQAEvaluator(MathEvaluator):
238
+ def rule_judge(self, solution_str: str, ground_truth: str, finish_generation: bool = True) -> Tuple[bool, str]:
239
+ # if not ground_truth.startswith("$"):
240
+ # ground_truth = f"${ground_truth}$"
241
+ gold = parse(
242
+ ground_truth,
243
+ extraction_config=[StringExtractionConfig()],
244
+ )
245
+ answer = parse(
246
+ solution_str,
247
+ extraction_config=[
248
+ StringExtractionConfig(),
249
+ ]
250
+ )
251
+ if len(answer) == 0:
252
+ return False, "No extracted answer"
253
+ else:
254
+ return verify(gold, answer), str(answer)
255
+
256
+ def get_llm_judge_prompt(self, solution_str: str, ground_truth: str, extract_answer: str = "", finish_generation: bool = True) -> str:
257
+ solution_str = self.extract_after_think(solution_str, finish_generation=finish_generation)
258
+ return f"""Please determine whether the final answer provided in the model-generated response is equivalent to the reference answer from a multiple choice question. The final answer may either be enclosed in \\boxed{{}} or appear after "Answer:". If they are equivalent, return "YES"; if they are not, return "NO". Only return "YES" or "NO", and do not generate any other content.
259
+ Model-generated answer: {solution_str}
260
+ Reference answer: {ground_truth}""".strip()
261
+
262
+
263
+ # class MBPPEvaluator(Evaluator):
264
+ # def rule_judge(self, solution_str: str, ground_truth: str, finish_generation: bool = True) -> bool:
265
+ # return True, "No extracted answer"
266
+
267
+ # def get_llm_judge_prompt(self, solution_str: str, ground_truth: str, extract_answer: str = "", finish_generation: bool = True) -> str:
268
+ # solution_str = self.extract_after_think(solution_str, finish_generation=finish_generation)
269
+ # return f"""Please determine whether the final answer provided in the model-generated response is equivalent to the reference answer from a multiple choice question. The final answer may either be enclosed in \\boxed{{}} or appear after "Answer:". If they are equivalent, return "YES"; if they are not, return "NO". Only return "YES" or "NO", and do not generate any other content.
270
+ # Model-generated answer: {solution_str}
271
+ # Reference answer: {ground_truth}""".strip()
272
+
273
+
274
+ # class HUMANEVALEvaluator(Evaluator):
275
+ # def rule_judge(self, solution_str: str, ground_truth: str, finish_generation: bool = True) -> bool:
276
+ # return True, "No extracted answer"
277
+
278
+ # def get_llm_judge_prompt(self, solution_str: str, ground_truth: str, extract_answer: str = "", finish_generation: bool = True) -> str:
279
+ # solution_str = self.extract_after_think(solution_str, finish_generation=finish_generation)
280
+ # return f"""Please determine whether the final answer provided in the model-generated response is equivalent to the reference answer from a multiple choice question. The final answer may either be enclosed in \\boxed{{}} or appear after "Answer:". If they are equivalent, return "YES"; if they are not, return "NO". Only return "YES" or "NO", and do not generate any other content.
281
+ # Model-generated answer: {solution_str}
282
+ # Reference answer: {ground_truth}""".strip()
283
+
284
+
285
+ evaluator_map = {
286
+ "aime2024": AIMEEvaluator(),
287
+ "aime2025": AIMEEvaluator(),
288
+ "gsm8k": GSM8KEvaluator(),
289
+ "math500": MATH500Evaluator(),
290
+ "gpqa_diamond": GPQAEvaluator(),
291
+ "amc23": AMCEvaluator(),
292
+ }
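Each evaluator's `rule_judge` returns an `(is_correct, extracted_answer)` pair; a small sanity check using math_verify's boxed extraction might look like this (the example strings are illustrative):

```python
ok, extracted = evaluator_map["math500"].rule_judge(
    solution_str="Therefore the area is \\boxed{\\frac{1}{2}}.",
    ground_truth="\\frac{1}{2}",
)
print(ok)         # True if math_verify parses and matches the boxed expression
print(extracted)  # string form of the parsed answer
```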
293
+
294
+ API_BASE = None
295
+ DEPLOYMENT_NAME = None
296
+ API_VERSION = None
297
+ CONSTRUCTED_URL = None
298
+ API_KEY = None
299
+ HEADERS = None
300
+ OPENAI_CLIENT = None
301
+ MODEL_NAME = None
302
+
303
+ def set_client(api_base=None, deployment_name=None, api_version=None, api_key=None, model_name="gpt-4.1-2025-04-14"):
304
+ global API_BASE, DEPLOYMENT_NAME, API_VERSION, CONSTRUCTED_URL, API_KEY, HEADERS, MODEL_NAME, OPENAI_CLIENT
305
+
306
+ API_BASE = api_base
307
+ DEPLOYMENT_NAME = deployment_name
308
+ API_VERSION = api_version
309
+ CONSTRUCTED_URL = f"{api_base}/openai/deployments/{deployment_name}/chat/completions?api-version={api_version}"
310
+ API_KEY = api_key or os.getenv("OPENAI_API_KEY", "")
311
+ MODEL_NAME = model_name
312
+ HEADERS = {
313
+ "Content-Type": "application/json",
314
+ "api-key": API_KEY,
315
+ }
316
+ if API_KEY:
317
+ print("Using OpenAI API key from argument or environment variable.")
318
+ OPENAI_CLIENT = OpenAI(api_key=API_KEY)
319
+ else:
320
+ OPENAI_CLIENT = None
321
+
322
+
323
+
324
+
325
+ # def call_llm_judge(message: list, args: argparse.Namespace) -> str:
326
+ # """
327
+ # Call the Qwen API with the given message.
328
+
329
+ # Args:
330
+ # message (list): Message list for the API.
331
+ # args (argparse.Namespace): Parsed arguments.
332
+
333
+ # Returns:
334
+ # str: The content of the completion response.
335
+
336
+ # Raises:
337
+ # Exception: When the API call fails.
338
+ # """
339
+
340
+ # try:
341
+ # completion = client.chat.completions.create(
342
+ # model=args.model_name,
343
+ # messages=message,
344
+ # temperature=args.temperature,
345
+ # top_p=args.top_p,
346
+ # )
347
+ # return completion.choices[0].message.content.strip()
348
+ # except Exception as e:
349
+ # print(f"API call failed: {str(e)}")
350
+ # raise
351
+
352
+
353
+ # def rule_judge(completion, ground_truth, dataset_name: str) -> Dict[str, Any]:
354
+ # rule_judge_result = None
355
+
356
+ # rule_judge_result = evaluator_map[dataset_name].rule_judge(completion, ground_truth)
357
+
358
+ # if not rule_judge_result:
359
+ # print(f"No valid answer detected | LLM judge")
360
+ # call_llm_judge
361
+
362
+
363
+ # return example
364
+
365
+
366
+
367
+ # def process_example(example: Dict[str, Any], args: argparse.Namespace) -> Dict[str, Any]:
368
+ # """
369
+ # Process an individual example with token counting.
370
+ # """
371
+ # try:
372
+
373
+ # completion_text = example.get('completion', '')
374
+ # before_truncation_text = completion_text.split("</think>")[0]
375
+ # before_truncation_token_count = count_tokens(before_truncation_text, args)
376
+ # after_truncation_token_count = example.get('generated_tokens', 0) - before_truncation_token_count
377
+
378
+ # example['before_truncation_token_count'] = before_truncation_token_count
379
+ # example['after_truncation_token_count'] = after_truncation_token_count
380
+
381
+ # example = post_process(example)
382
+ # if not example.get('rule_judge_result', False):
383
+ # think_truncation = extract_after_think(completion_text) or ''
384
+
385
+ # prompt = generate_prompt(example, think_truncation)
386
+ # messages = format_prompt(prompt)
387
+ # llm_judge_response = call_qwen(messages, args)
388
+ # print(llm_judge_response)
389
+ # example['llm_judge_response'] = llm_judge_response
390
+ # if llm_judge_response == "YES":
391
+ # example['llm_judge_result'] = True
392
+ # example['final_judge_result'] = True
393
+ # elif llm_judge_response == "NO":
394
+ # example['llm_judge_result'] = False
395
+ # example['final_judge_result'] = False
396
+ # else:
397
+ # example['llm_judge_result'] = None
398
+ # example['final_judge_result'] = None
399
+
400
+ # return example
401
+ # except Exception as e:
402
+ # print(f"Failed to process example {example.get('idx', 'unknown')}: {str(e)}")
403
+ # example['final_judge_result'] = None
404
+ # return example
405
+
406
+
407
+
408
+ # def main():
409
+ # args = parse_arguments()
410
+ # try:
411
+ # final_results = []
412
+ # with jsonlines.open(args.result_save_name, mode='w') as writer:
413
+ # with open(args.data_load_name, 'r', encoding='utf-8') as f:
414
+ # data = json.load(f)
415
+
416
+ # for idx, example in enumerate(data, start=1):
417
+ # print(f"Processed example ID: {idx}")
418
+ # new_example = process_example(example, args)
419
+ # final_results.append(new_example)
420
+
421
+ # writer.write_all(final_results)
422
+ # print(f"Processing complete! Valid results saved to: {args.result_save_name}")
423
+
424
+ # grouped_data = defaultdict(list)
425
+ # for entry in final_results:
426
+ # grouped_data[entry['idx']].append(entry)
427
+
428
+ # pass_at_1_per_idx = {}
429
+ # for idx, entries in grouped_data.items():
430
+ # correct_count = sum(entry['final_judge_result'] for entry in entries if entry['final_judge_result'] is True)
431
+ # total_count = len(entries)
432
+ # pass_at_1_per_idx[idx] = correct_count / total_count if total_count else 0
433
+
434
+ # overall_pass_at_1 = sum(pass_at_1_per_idx.values()) / len(pass_at_1_per_idx)
435
+ # average_generated_tokens = sum(entry['generated_tokens'] for entry in final_results) / len(final_results)
436
+ # average_reasoning_tokens = sum(entry['before_truncation_token_count'] for entry in final_results) / len(final_results)
437
+ # average_answer_tokens = sum(entry['after_truncation_token_count'] for entry in final_results) / len(final_results)
438
+
439
+ # print(f'\nPass@1 per idx:\n{pass_at_1_per_idx}\n')
440
+ # print(f'Overall pass@1: {overall_pass_at_1:.4f}')
441
+ # print(f'Average generated_tokens: {average_generated_tokens:.2f}')
442
+ # print(f'Average reasoning_tokens: {average_reasoning_tokens:.2f}')
443
+ # print(f'Average answer_tokens: {average_answer_tokens:.2f}')
444
+
445
+ # except json.JSONDecodeError as e:
446
+ # print(f"Data loading failed: Invalid JSON format - {str(e)}")
447
+ # raise
448
+
449
+
450
+ if __name__ == '__main__':
451
+ api_key = os.getenv("OPENAI_API_KEY", "")
452
+
453
+ set_client(api_base="", deployment_name="", api_version="", api_key=api_key)
454
+ response = "Okay, so I need to find the number of triples (a, b, c) of nonnegative integers such that a + b + c = 300 and the expression a\u00b2b + a\u00b2c + b\u00b2a + b\u00b2c + c\u00b2a + c\u00b2b equals 6,000,000. Hmm, that seems a bit complicated, but maybe I can break it down step by step.\n\nFirst, let me note down the given equations:\n\n1. a + b + c = 300\n2. a\u00b2b + a\u00b2c + b\u00b2a + b\u00b2c + c\u00b2a + c\u00b2b = 6,000,000\n\nI need to find all nonnegative integer triples (a, b, c) that satisfy both equations. Since the problem involves symmetric expressions in a, b, c, maybe there's a way to exploit symmetry here.\n\nLet me first if I can rewrite the second equation in a more manageable form. The expression a\u00b2b + a\u00b2c + b\u00b2a + b\u00b2c + c\u00b2a + c\u00b2b looks symmetric. Maybe I can factor it somehow?\n\nLet me see: Each term is of the form (variable squared times another variable). So, for example, a\u00b2b + a\u00b2c = a\u00b2(b + c), similarly for the other terms. Let me try that:\n\na\u00b2(b + c) + b\u00b2(a + c) + c\u00b2(a + b)\n\nYes, that's the same as the original expression. Now, since a + b + c = 300, maybe I can substitute (b + c) with (300 - a), and similarly for the others. Let's try that substitution.\n\nSo, substituting:\n\na\u00b2(300 - a) + b\u00b2(300 - b) + c\u00b2(300 - c)\n\nWait, hold on. Let me check that again. The original expression is't exactly that. Wait, the's see:\n\nWait, the original expression is a\u00b2(b + a\u00b2c + b\u00b2a + b\u00b2c + c\u00b2a + c\u00b2b. So I group terms as a\u00b2(b + c) + b\u00b2(a + c) + c\u00b2(a + b), then yes, each term is variable squared times the sum of the other two variables. Since the sum of the other two variables is 300 minus the variable. So, for example, (b + c) = 300 - a, so substituting that in, we get:\n\na\u00b2(300 - a) + b\u00b2(300 - b) + c\u00b2(300 - c)\n\nTherefore, the second equation can be rewritten as:\n\na\u00b2(300 - a) + b\u00b2(300 - b) + c\u00b2(300 - c) = 6,000,000\n\nHmm, that might be a useful form. Let me write that down:\n\nEquation 2: a\u00b2(300 - a) + b\u00b2(300 - b) + c\u00b2(300 - c) = 6,000,000\n\nNow, perhaps I can consider expanding expression in terms of the variables. Let me think if there's a way to relate this to the first equation. Since a + b + c = 300, maybe we can express some relationship between these variables here Let't immediately yet.\n\nAlternatively, maybe I can consider expanding the expression or finding another way to express it. Let me see ifAlternatively, perhaps I can consider the expression a\u00b2b + a\u00b2c + b\u00b2a + b\u00b2c + c\u00b2a + c\u00b2b. Let me see if this can be written in terms of (a + b + c)^3 or something similar.\n\nWait, let me recall that (a + b + c)^3 = a^3 + b^3 + c^3 + 3(a + b)(a + c)(b + c). Hmm, not sure if that helps. Alternatively, maybe expanding (a + b + c)(ab + bc + ca). Let me see:\n\n(a + b + c)(ab + bc + ca) = a^2b + a^2c + ab^2 + b^2c + a c^2 + b c^2 + 3abc\n\nWait, that's exactly our expression in the second equation plus 3abc. So:\n\n(a + b + c)(ab + bc + ca) = (a\u00b2b + a\u00b2c + b\u00b2a + b\u00b2c + c\u00b2a + c\u00b2b) + 3abc\n\nTherefore, the second equation can be rewritten as:\n\n(a + b + c)(ab + bc + ca) - 3abc = 6,000,000\n\nSince a know that a + b + c = 300, substitute that in:\n\n300(abab + bc + ca) - 3abc = 6,000,000\n\nHmm, that might be a better way to handle it equation. 
Let me write that down:\n\nEquation: 300(ab + bc + ca) - 3abc = 6,000,000\n\nWe can divide both sides by 3 to simplify:\n\n100(ab + bc + ca) - abc = 2,000,000\n\nSo now, the equation becomes:\n\n100(ab + bc + ca) - abc = 2,000,000\n\nHmm, that seems a bit more manageable. Let me write that as:\n\nabc - 100(ab + bc + ca) = -2,000,000\n\nAlternatively, maybe rearranged as:\n\nabc = 100(ab + bc + ca) - 2,000,000\n\nHmm, not sure yet. Let me see if I can find another relation here.\n\nSince we have a + b + c = 300, perhaps we can express everything in terms of two variables: say, let me set variables as a, b, and then c = 300 - a - b, and substitute expressions in terms of a and b. That might get complicated, but perhaps it's manageable.\n\nAlternatively, maybe there's a substitution or symmetry here. Let me think.\n\nThe problem is symmetric in a, b, c, so perhaps the solutions are symmetric in some way. Maybe two variables are equal, or all three are equal?\n\nLet me first check if all three variables are equal. If a = b = c, then since a + b + c = 300, each would be 100. Let me check if that satisfies the second equation.\n\nIf a = b = c = 100, then compute the second expression:\n\na\u00b2b + a\u00b2c + b\u00b2a + b\u00b2c + c\u00b2a + c\u00b2b\n\n = 6*(100)^2*100 = 6*100\u00b3 = 6,000,000. Wait, that's exactly the required value! So (100, 100, 100) is a solution.\n\nSo that's one solution. But the problem asks for all triples, so maybe there are more?\n\nHmm, so maybe there are other triples where variables are not all equal but still satisfy the equations. Let me see.\n\nAlternatively, maybe two variables are equal and the third is different. Let me suppose that two variables are equal, say a = b, and c is different. Then 2a + c = 300, so c = 300 - 2a.\n\nThen, let's plug into the second equation.\n\nFirst, let me compute ab + bc + ca when a = b:\n\nab + bc + ca = a\u00b2 + a c + a c = a\u00b2 + 2a c\n\nSimilarly, abc = a\u00b2 c\n\nSo substituting into the equation: 100(ab + bc + ca) - abc = 100(a\u00b2 + 2a c) - a\u00b2 c = 100a\u00b2 + 200a c - a\u00b2 c\n\nBut since c = 300 - 2a, substitute that in:\n\n100a\u00b2 + 200a(300 - 2a) - a\u00b2(300 - 2a)\n\nLet me compute each term:\n\nFirst term: 100a\u00b2\n\nSecond term: 200a*(300 - 2a) = 200a*300 - 200a*2a = 60,000a - 400a\u00b2\n\nThird term: -a\u00b2*(300 - 2a) = -300a\u00b2 + 2a\u00b3\n\nSo adding all three terms together:\n\n100a\u00b2 + (60,000a - 400a\u00b2) + (-300a\u00b2 + 2a\u00b3)\n\nCombine like terms:\n\n100a\u00b2 - 400a\u00b2 - 300a\u00b2 + 2a\u00b3 + 60,000a\n\nThat's:\n\n(100 - 400 - 300)a\u00b2 + 2a\u00b3 + 60,000a\n\n= (-600a\u00b2) + 2a\u00b3 + 60,000a\n\nSo the entire expression is 2a\u00b3 - 600a\u00b2 + 60,000a\n\nSet this equal to 2,000,000 (since the equation was 100(ab + bc + ca) - abc = 2,000,000):\n\n2a\u00b3 - 600a\u00b2 + 60,000a = 2,000,000\n\nDivide both sides by 2:\n\na\u00b3 - 300a\u00b2 + 30,000a = 1,000,000\n\nBring all terms to the left-hand side:\n\na\u00b3 - 300a\u00b2 + 30,000a - 1,000,000 = 0\n\nHmm, I need to solve this cubic equation for integer a. Let me see if I can factor this or find integer roots.\n\nLet me denote f(a) = a\u00b3 - 300a\u00b2 + 30,000a - 1,000,000\n\nWe can try possible integer roots. Since a must be a nonnegative integer such that 2a \u2264 300 (since c = 300 - 2a must be nonnegative), a \u2264 150.\n\nPossible rational roots are factors of 1,000,000 divided by factors of 1, so possible integer roots are divisors of 1,000,000. But that's a lot. 
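Rather than testing divisors of 1,000,000 by hand, a direct scan over the allowed range 0..150 is quicker; a minimal sketch in plain Python (not part of the original argument):

def f(a):
    return a**3 - 300*a**2 + 30_000*a - 1_000_000

print([a for a in range(151) if f(a) == 0])  # [100]
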
Maybe try some values.\n\nFirst, let me try a = 100. Since when all variables are 100, that's a solution. so if a = 100, then c = 300 - 200 = 100, so that's the case where all three are equal, so that should be a root.\n\nLet me check f(100):\n\n100\u00b3 - 300*(100)^2 + 30,000*100 - 1,000,000\n\n= 1,000,000 - 300*10,000 + 3,000,000 - 1,000,000\n\n= 1,000,000 - 3,000,000 + 3,000,000 - 1,000,000 = 0. Yes, so a = 100 is a root.\n\nTherefore, (a - 100) is a factor. Let's perform polynomial division or factor it out.\n\nDivide f(a) by (a - 100):\n\nUsing synthetic division:\n\nCoefficients: 1 | -300 | 30,000 | -1,000,000\n\nDivide by (a - 100), so root at 100.\n\nBring down the 1.\n\nMultiply 1 by 100: 100. Add to next coefficient: -300 + 100 = -200\n\nMultiply -200 by 100 = -20,000. Add to next coefficient: 30,000 + (-20,000) = 10,000\n\nMultiply 10,000 by 100 = 1,000,000. Add to last term: -1,000,000 + 1,000,000 = 0. Perfect.\n\nSo the polynomial factors as (a - 100)(a\u00b2 - 200a + 10,000)\n\nNow set the quadratic factor to zero:\n\na\u00b2 - 200a + 10,000 = 0\n\nLet me compute discriminant D = (200)^2 - 4*1*10,000 = 40,000 - 40,000 = 0\n\nSo discriminant is zero, so a = [200 \u00b1 0]/2 = 100. So the quadratic has a double root at a = 100.\n\nTherefore, the cubic factors as (a - 100)^3 = 0. Wait, but that would mean that the original cubic is (a - 100)^3, but let me check:\n\nWait, (a - 100)^3 = a\u00b3 - 300a\u00b2 + 30,000a - 1,000,000, which matches exactly the cubic. So indeed, the equation is (a - 100)^3 = 0, so the only real root is a = 100. Therefore, in the case where a = b, the only solution is a = b = 100, which gives c = 100. so that's the triple-equal case.\n\nHmm, so if we assume two variables are equal, the only solution is when all three are equal. So maybe the only symmetric with two variables equal is the all-equal case. So perhaps other are solutions where all variables are different?\n\nAlternatively, maybe there are solutions where one variable is 0. Let me check that possibility.\n\nSuppose c = 0. Then the equations become a + b = 300 and and the second equation becomes a\u00b2b + a\u00b2*0 + b\u00b2a + b\u00b2*0 + 0 + 0 = a\u00b2b + ab\u00b2 = ab(a + b) = ab*300.\n\nSo the second equation becomes 300ab = 6,000,000, so ab = 20,000.\n\nSo if have a + b = 300 and ab = 20,000. Let me see if there are integer solutions here here.\n\nWe can set up the quadratic equation: Let a and b be roots of x\u00b2 - 300x + 20,000 = 0.\n\nCompute discriminant D = 300\u00b2 - 4*1*20,000 = 90,000 - 80,000 = 10,000\n\nsqrt(D) = 100\n\n so roots are [300 \u00b1 100]/2 = (400)/2 = 200 or (200)/2 = 100. So the roots are 200 and 100. Therefore, the solutions are (a, b) = (200, 100) or (100, 200). Since c = 0, the triples would be (200, 100, 0), (100, 200, 0), and and also permutations where c is in different positions. Wait, since c was set to 0, but the variables are ordered, so actually, if we set c = 0, then the triples are (a, b, 0) where a + b = 300 and ab = 20,000. So the solutions are (200, 100, 0), (100, 200, 0). But since the variables are ordered, but the problem counts triples (a, b, c) as ordered? Wait, actually, the problem says \"triples of nonnegative integers (a, b, c)\", so I think order matters here. Wait, but actually, in combinatorics, unless specified otherwise, sometimes triples are considered ordered. But let me check the problem statement again.\n\nWait, the problem says \"the number of triples of nonnegative integers (a, b, c)\", so I think they are ordered triples. 
so different permutationsings count as different triples. However, in the case where c is 0, but we could also have a = 0 or b = 0. So actually's see.\n\nWait, so if c = 0, then we have two solutions (200,100,0 and 100,200,0. Similarly, if a = 0, then we have (0, 200, 100) and (0, 100, 200). Similarly, if b = 0, then (200, 0, 100) and (100, 0, 200). So total of 6 triples where one variable is 0 and the other two are 200 and 100. Additionally, the all-100 case. So that's 7 so far. But maybe there are more solutions where none of the variables are zero?\n\nAlternatively, maybe there are solutions where two variables are zero? Let me check that.\n\n quickly. Suppose two variables are zero, say b = c = 0, then a = 300. Then the second equation becomes a\u00b2b + ... all terms would be zero, so 0 \u2260 6,000,000. So that's invalid. So no solutions with two variables zero.\n\nSo the only solutions with a zero are the ones above with one zero and the other two being 200 and 100. So that gives 6 triples. Plus the all-100 case. So total 7? But maybe there are other triples where none of the variables are zero?\n\nWait, let me see.\n\n Let me suppose of another approach. Let me suppose variables a, b, c all positive integers (since if one is zero, we already considered that case). Let me see if there are solutions where all variables are positive and and different from 100, 100, 100 and 200,100,0 etc.\n\nAlternatively, maybe there can consider the equation 100(ab + bc + ca) - abc = 2,000,000. Let me denote this as:\n\nabc - 100(ab + bc + ca) = -2,000,000\n\nHmm, perhaps we can factor this equation somehow. Let me see:\n\nLet me add 1,000,000 to both sides:\n\nabc - 100(ab + bc + ca) + 1,000,000 = -1,000,000\n\nHmm, not sure. Alternatively, perhaps we't think of variables x = a, y = b, z = c, and see if there's a substitution.\n\nAlternatively, let set variables = 300 - a, y = 300 - b, z = 300 - c, but not sure.\n\nAlternatively, perhaps consider the equation:\n\n100(ab + bc + ca) - abc = 2,000,000. Let me rearr this as:\n\nabc - 100(ab - 100bc - 100ca + 2,000,000 = 0\n\nHmm, perhaps factor terms:\n\nLet me see, perhaps factor terms with ab, bc, ca:\n\nab(c - 100) + bc(a - 100) + ca(b - 100) = -2,000,000\n\nHmm, not sure. Alternatively, perhaps rearr the equation as:\n\nabc = 100(ab + bc + ca) - 2,000,000\n\nLet me think if I can write this as:\n\nabc = 100(ab + bc + ca) - 2,000,000\n\nHmm, perhaps divide both sides by abc (assuming a,b,c \u2260 zero):\n\n1 = 100(1/c + 1/a + 1/b) - 2,000,000/(abc)\n\nHmm, not sure if helpful helps.\n\nAlternatively, maybe think of variables in terms of ratios. Let me suppose that variables are in some proportion. Let me see.\n\nAlternatively, let me consider that in the case where one variable is 0, we have solutions. The case where all variables equal 100 is a solution. Maybe there are other solutions where variables are different but not involving zeros?\n\nAlternatively me suppose that variables a, b, c are all different from 100 and none are zero. Let me see if such solutions exist.\n\nAlternatively, let try to set variables as 150, 100, 50. Let me test that.\n\nWait, let me pick some numbers. Let me see, for't suppose a = 150, b = 100, c = 50. Then a + b + c = 300. 
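The spot checks that follow are easy to mechanise; a tiny helper (a minimal sketch in plain Python, with sym_sum as an illustrative name) already reproduces the two kinds of solution found so far:

def sym_sum(a, b, c):
    # a^2 b + a^2 c + b^2 a + b^2 c + c^2 a + c^2 b
    return a*a*b + a*a*c + b*b*a + b*b*c + c*c*a + c*c*b

assert sym_sum(100, 100, 100) == 6_000_000   # all three equal
assert sym_sum(200, 100, 0) == 6_000_000     # one variable zero
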
Let me compute the second equation:\n\na\u00b2b + a\u00b2c + b\u00b2a + b\u00b2c + c\u00b2a + c\u00b2b\n\nCompute each term:\n\na\u00b2b = 150\u00b2 * 100 = 22500 * 100 = 2,250,000\n\na\u00b2c = 150\u00b2 *50 = 22500 *50 = 1,125,000\n\nb\u00b2a = 100\u00b2 *150 = 10,000 *150 = 1,500,000\n\nb\u00b2c = 100\u00b2 *50 = 10,000 *50 = 500,000\n\nc\u00b2a = 50\u00b2 *150 = 2500 *150 = 375,000\n\nc\u00b2b = 50\u00b2 *100 = 2500 *100 = 250,000\n\nAdding all these up:\n\n2,250,000 + 1,125,000 = 3,375,000\n\n+1,500,000 = 4,875,000\n\n+500,000 = 5,375,000\n\n+375,000 = 5,750,000\n\n+250,000 = 6,000,000\n\nWait a second! That works! So (150, 100, 50) is another solution. Hmm that's a solution where all variables are different and none are zero. So that's another solution. Hmm my previous thought that only the all-100 and the ones with a zero were solutions was incorrect. So there are more solutions.\n\nHmm, so that complicates things. Let me see why that works. Let me check the calculation again.\n\nWait, so (150, 100, 50) gives the required sum and the required product expression6,000,000. So that's a valid solution. So there are more solutions. So I need to find all such triples.\n\nHmm, so perhaps there are multiple families variables are in arithmetic progression? Let me see. 50, 100, 150 is an arithmetic progression with difference 50. Let me see if that's a pattern.\n\nAlternatively, maybe variables a can think of variables as multiples of 50? Let me see.\n\nWait, let me see if there are other triples. that. Let me try (200, 100, 0) which we already know works, but (150, 100, 50) works, and (100, 100, 100) works. Let me see if there are others.\n\nLet me try (200, 50, 50). Let's check:\n\na + b + c = 200 +50 +50 = 300. Good.\n\nCompute the second equation:\n\na\u00b2b + a\u00b2c + b\u00b2a + b\u00b2c + c\u00b2a + c\u00b2b\n\n= 200\u00b2*50 + 200\u00b2*50 + 50\u00b2*200 + 50\u00b2*50 + 50\u00b2*200 + 50\u00b2*50\n\nWait, let me compute each term:\n\na\u00b2b = (200)^2 *50 = 40,000 *50 = 2,000,000\n\na\u00b2c = same as a\u00b2b since b and c are both 50, so another 2,000,000\n\nb\u00b2a = (50)^2 *200 = 2,500 *200 = 500,000\n\nb\u00b2c = (50)^2 *50 = 2,500 *50 = 125,000\n\nc\u00b2a = same as b\u00b2a, so another 500,000\n\nc\u00b2b = same as b\u00b2c, so another 125,000\n\nAdding all together:\n\n2,000,000 + 2,000,000 = 4,000,000\n\n+500,000 + 00,000 = 4,625,000\n\nWait, no, let me do step by step:\n\nFirst two terms: 2,000,000 + 2,000,000 = 4,000,000\n\nThen b\u00b2a + b\u00b2c = 500,000 + 125,000 = 625,000\n\nThen c\u00b2a + c\u00b2b = same as b\u00b2a + b\u00b2c, so another 625,000\n\nTotal: 4,000,000 + 625,000 + 625,000 = 5,250,000. which is less than 6,000,000. So that doesn't work.\n\nHmm, so that's not a solution. So maybe arithmetic progression is a better idea.\n\nWait, let me try another triple. Let me see, maybe't try (120, 100, 80). Let me check sum: 120 + 100 +80 = 300. Good.\n\nCompute the second equation:\n\na\u00b2b + a\u00b2c + b\u00b2a + b\u00b2c + c\u00b2a + c\u00b2b\n\nCompute me compute each term:\n\na\u00b2b = 120\u00b2 *100 = 14,400 *100 = 1,440,000\n\na\u00b2c = 120\u00b2 *80 = 14,400 *80 = 1,152,000\n\nb\u00b2a = 100\u00b2 *120 = 10,000 *120 = 1,200,000\n\nb\u00b2c = 100\u00b2 *80 = 10,000 *80 = 800,000\n\nc\u00b2a = 80\u00b2 *120 = 6,400 *120 = 768,000\n\nc\u00b2b = 80\u00b2 *100 = 6,400 *100 = 640,000\n\nAdding all terms:\n\n1,440,000 + 1,152,000 = 2,592,000\n\n+1,200,000 = 3,792,000\n\n+800,000 = 4,592,000\n\n+768,000 = 5,360,000\n\n+640,000 = 6,000,000\n\nOh! 
That works too! So (120, 100, 80) is another solution. Interesting. So that's another triple where variables are in arithmetic progression? 80, 100, 120 with difference 20. Wait, so that's also an arithmetic progression. Hmm, so maybe arithmetic progressions are the key here?\n\nWait, let me see another arithmetic progression. Let me try (180, 100, 20). Let me check sumSum is 180 + 100 +20 = 300. Let me compute the second equation:\n\na\u00b2b + a\u00b2c + b\u00b2a + b\u00b2c + c\u00b2a + c\u00b2b\n\na\u00b2b = 180\u00b2 *100 = 32,400 *100 = 3,240,000\n\na\u00b2c = 180\u00b2 *20 = 32,400 *20 = 648,000\n\nb\u00b2a = 100\u00b2 *180 = 10,000 *180 = 1,800,000\n\nb\u00b2c = 100\u00b2 *20 = 10,000 *20 = 200,000\n\nc\u00b2a = 20\u00b2 *180 = 400 *180 = 72,000\n\nc\u00b2b = 20\u00b2 *100 = 400 *100 = 40,000\n\nAdding all terms:\n\n3,240,000 + 648,000 = 3,888,000\n\n+1,800,000 = 5,688,000\n\n+200,000 = 5,888,000\n\n+72,000 = 5,960,000\n\n+40,000 = 6,000,000\n\nWow, that works works! So (180, 100, 20) is another solution. So arithmetic progression with difference 80? Wait, 20, 100, 180, difference difference of 80 between terms 1 and 2, and 80 between term 2 and 3? Wait,100 -20 =80, 180 -100=80. Yes, so that's an arithmetic progression with common difference 80. So that works too as well.\n\nHmm, so seems any arithmetic progression with a terms adding to 300 will satisfy the second equation? Let me check another one.\n\n Let me try (160, 100, 40). Sum is 160 + 00 +40 = 300.\n\nCompute the second equation:\n\na\u00b2b + a\u00b2c + b\u00b2a + b\u00b2c + c\u00b2a + c\u00b2b\n\na\u00b2b = 160\u00b2 *100 = 25,600 *100 = 2,560,000\n\na\u00b2c = 160\u00b2 *40 = 25,600 *40 = 1,024,000\n\nb\u00b2a = 100\u00b2 *160 = 10,000 *160 = 1,600,000\n\nb\u00b2c = 100\u00b2 *40 = 10,000 *40 = 400,000\n\nc\u00b2a = 40\u00b2 *160 = 1,600 *160 = 256,000\n\nc\u00b2b = 40\u00b2 *100 = 1,600 *100 = 160,000\n\nAdding all terms:\n\n2,560,000 + ,024,000 = 3,584,000\n\n+1,600,000 = 5,184,000\n\n+400,000 = 5,584,000\n\n+256,000 = 5,840,000\n\n+160,000 = 6,000,000\n\nYes! That works too. So (160, 100, 40) is another solution. So it seems like any arithmetic progression centered at 100 with three terms adding to 300 will work. Let me see.\n\nWait, let me think of the arithmetic progression. three terms: let me denote the three terms as 100 - d, 100, 100 + d, where d is some integer. Then their sum is (100 - d) + 100 + (100 + d) = 300, which works. So that's an arithmetic progression with common difference d. 
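Before verifying this algebraically, an exhaustive numeric check of the guess is cheap; a minimal sketch in plain Python (the helper mirrors the expression being tested):

def sym_sum(a, b, c):
    return a*a*b + a*a*c + b*b*a + b*b*c + c*c*a + c*c*b

# Every centred triple (100 - d, 100, 100 + d) with 0 <= d <= 100 hits the target.
assert all(sym_sum(100 - d, 100, 100 + d) == 6_000_000 for d in range(101))
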
Let me check if this satisfies the second equation.\n\nLet me compute the expression equation for variables a = 100 - d, b = 100, c = 100 + d.\n\nCompute a\u00b2b + a\u00b2c + b\u00b2a + b\u00b2c + c\u00b2a + c\u00b2b.\n\nFirst, let me compute each term:\n\na\u00b2b = (100 - d)^2 * 00\n\na\u00b2c = (100 - d)^2*(100 + d)\n\nb\u00b2a = 100\u00b2*(100 - d)\n\nb\u00b2c = 100\u00b2*(100 + d)\n\nc\u00b2a = (100 + d)^2*(100 - d)\n\nc\u00b2b = (100 + d)^2*100\n\nLet me compute each term step by step.\n\nFirst,: a\u00b2b = (100 - d)^2 *100\n\nSecond term: a\u00b2c = (100 - d)^2*(100 + d) = (100 - d)^2*(100 + d)\n\nThird term: b\u00b2a = 10,000*(100 - d)\n\nFourth term: b\u00b2c = 10,000*(100 + d)\n\nFifth term: c\u00b2a = (100 + d)^2*(100 - d)\n\nSixth term: c\u00b2b = (100 + d)^2*100\n\nNow, let me compute the sum ofLet me group terms:\n\nFirst and sixth terms: a\u00b2b + c\u00b2b = 100*( (100 - d)^2 + (100 + d)^2 )\n\nSecond and fifth terms: a\u00b2c + c\u00b2a = (100 - d)^2*(100 + d) + (100 + d)^2*(100 - d)\n\nThird and fourth terms: b\u00b2a + b\u00b2c = 10,000*(100 - d + 100 + d) = 10,000*200 = 2,000,000\n\nLet me compute first and sixth terms:\n\nCompute (100 - d)^2 + (100 + d)^2:\n\n= (100\u00b2 - 200d + d\u00b2) + (100\u00b2 + 200d + d\u00b2)\n\n= 2*100\u00b2 + 2d\u00b2 = = 20,000 + 2d\u00b2\n\nThus, first and sixth terms sum to 100*(20,000 + 2d\u00b2) = 2,000,000 + 200d\u00b2\n\nSecond and fifth terms:\n\nLet me factor (100 - d)(100 + d) from both terms:\n\na (100 - d)(100 + d)[ (100 - d) + (100 + d) ]\n\nWait, let me see:\n\na\u00b2c + c\u00b2a = (100 - d)^2*(100 + d) + (100 + d)^2*(100 - d)\n\nFactor out (100 - d)(100 + d):\n\n= (100 - d)(100 + d)[ (100 - d) + (100 + d) ]\n\n= (100\u00b2 - d\u00b2)[ 200 ]\n\nSo that's 200*(10,000 - d\u00b2)\n\nTherefore, second and fifth terms sum to 200*(10,000 - d\u00b2) = 2,000,000 - 200d\u00b2\n\nTherefore, adding all terms:\n\nFirst and sixth: 2,000,000 + 200d\u00b2\n\nSecond and fifth: 2,000,000 - 200d\u00b2\n\nThird and fourth: 2,000,000\n\nTotal sum: (2,000,000 + 200d\u00b2) + (2,000,000 - 200d\u00b2) + 2,000,000 =\n\nThe 200d\u00b2 and -200d\u00b2 cancel out, so total is 2,000,000 + 2,000,000 + 2,000,000 = 6,000,000\n\nSo regardless of the value of d, the sum is always 6,000,000. That's fascinating! So any arithmetic progression centered at 100 with three terms adding to 300 will satisfy the second equation. Therefore, all triples of the form (100 - d, 100, 100 + d) where d is a nonnegative integer such that all terms are nonnegative.\n\nWait, but d can be positive or negative? Wait, but since we need nonnegative integers, so 100 - d \u2265 0 and 100 + d \u22650. Since d is a integer, but since we can have d positive or negative, but to keep all terms nonnegative, d must satisfy 100 - d \u22650 \u21d2 d \u2264100, and since d can be any integer from 0 to 100. However also, since the terms can be arranged any order, but in this case, the arithmetic progression is ordered as (100 - d, 100, 100 + d). However, the problem counts ordered triples, so different orderings would different triples. Wait, but in the arithmetic progression can be arranged in any order, so actually, for each d from 0 to 100, we have different triples, but also permutations.\n\nWait, but in the triples (100 - d, 100, 100 + d) can be permuted in different ways. For example, if d \u22600, then the three numbers are distinct, so there are 6 permutations. 
However, when d =0, all three are 100, so only one permutation.\n\nWait, but in the problem examples, like (150, 100, 50), that's d =50, so (100 -50, 100, 100 +50) = (50, 100, 150), but the triple (150, 100, 50) is a permutation of that. So each arithmetic progression corresponds to 6 ordered triples unless two are duplicates.\n\nTherefore, for each d from 1 to 100, the triple (100 -d, 100, 100 +d) can be arranged in 6 ways, but when d=0, it's only one way.\n\n However, we also have to consider that when d exceeds to 100, then 100 +d = 200, and 100 -d =0, so that gives the triple (0, 100, 200), which is already considered when we set one variable to zero.\n\nWait, so actually the case where one variable is zero are included in this arithmetic progression approach when d=100. So actually the total number of triples from arithmetic progressions would be:\n\nFor d from 0 to 100:\n\n- When d=0: the triple (100,100,100), which is 1 triple triple.\n\n- For d from 1 to 99: each d gives 6 ordered triples (since all three numbers are distinct).\n\n- When d =100: the triple (0, 100, 200), which can permutations in 3! /1! =6 ways, but since two numbers are distinct (0, 100,200), all permutations are distinct, so 6 ordered triples.\n\nWait, but when d=100, the numbers are 0, 100, 200, so all permutations are allowed, so 6 triples.\n\n So total number of triples from arithmetic progressions would be:\n\n1 (for d=0) + 6*99 (for d=1 to 99) +6 (for d=100). = 1 + 594 +6 = 601? Wait, but that can't be right because when d=100, it's included in the d=1 to 99? Wait no, d=100 is separate.\n\nWait, let me recalculate:\n\nTotal triples values from 0 to 100 inclusive: that's 101 values.\n\nFor each d from 1 to 99: 99 values, each gives 6 triples.\n\nFor d=0: 1 triple.\n\nFor d=100: 1 value, giving 6 triples.\n\nTotal triples: 1 + 99*6 +6 = 1 + 594 +6 = 601.\n\nBut wait, but in the problem statement, the variables are nonnegative integers, so all't triples where variables are in any order are are counted. separately separate. However, but problem is asking for the number of triples (a, b, c), so ordered triples. So this approach would give 601 triples. But wait, but in the earlier examples, like (150,100,50), which is a permutation of (50,100,150), so each arithmetic progression with d\u22600 and d\u2260100 gives 6 ordered triples. However, d=100, the triple is (0,100,200), which also has 6 permutations. And d=0 gives only one triple.\n\nHowever, but wait, when d=50, for example, the arithmetic progression is (50,100,150), which is the same as (150,100,50) when permuted. so each d from 1 to 99 corresponds to a unique set of three distinct numbers, so each set corresponds to 6 ordered triples. So the total number would be 1 + 6*100 = 601? Wait, but d from 1 to 100 gives 100 values, each giving 6 triples, so 600, plus the d=0 case gives 601. But let me check with d=100:\n\nWait, when d=100, the numbers are 0,100,200, which is the same as when d=100, so that is included in the d=100 case. So yes, that's correct.\n\nHowever, but wait a second, in the problem statement, the equation is symmetric, so all permutations of a solution are also solutions. Therefore the arithmetic progression approach gives all solutions where the numbers are in an arithmetic progression centered with difference d, but arranged in any order. So the total number of ordered triples would indeed 601. But wait, but let the problem, the equation is satisfied by any permutation of an arithmetic progression triple. 
so all permutations permutations are valid solutions.\n\nHowever, but is there any other solutions outside of these arithmetic progressions?\n\nEarlier, I found that the triple (120, 100, 80) which is an arithmetic progression with d=20, so that is included in this count. Similarly, (180,100,20) is d=80, so also included. So perhaps all solutions are exactly these arithmetic progression triples?\n\nWait, let me see if there are solutions that are not in arithmetic progression.\n\nSuppose I take a triple like (140, 100, 60). Let me check if that works.\n\nSum is 140 +100 +60 = 300.\n\nCompute the second equation:\n\na\u00b2b + a\u00b2c + b\u00b2a + b\u00b2c + c\u00b2a + c\u00b2b\n\na\u00b2140, b=100, c=60.\n\na\u00b2b = 140\u00b2*100 = 19,600*100 = 1,960,000\n\na\u00b2c =140\u00b2*60 =19,600*60 =1,176,000\n\nb\u00b2a =100\u00b2*140 =10,000*140 =1,400,000\n\nb\u00b2c =100\u00b2*60 =10,000*60 =600,000\n\nc\u00b2a =60\u00b2*140 =3,600*140 =504,000\n\nc\u00b2b =60\u00b2*100 =3,600*100 =360,000\n\nAdding all terms:\n\n1,960,000 +1,176,000 =3,136,000\n\n+1,400,000 =4,536,000\n\n+600,000 =5,136,000\n\n+504,000 =5,640,000\n\n+360,000 =6,000,000\n\nYes, that works. too. And ( this is an arithmetic progression with d=40: 60,100,140. So yes is included in the arithmetic progression case.\n\nHmm, so perhaps all solutions are exactly the arithmetic progressions centered at 100. Let me see if there is a solution that is not an arithmetic progression.\n\nLet me try a triple like (130, 110, 60). Sum is 130+110+60=300.\n\nCompute the second equation:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\na=130, b=110, c=60.\n\na\u00b2b =130\u00b2*110 =16,900*110 =1,859,000\n\na\u00b2c =16,900*60 =1,014,000\n\nb\u00b2a =110\u00b2*130 =12,100*130 =1,573,000\n\nb\u00b2c =12,100*60 =726,000\n\nc\u00b2a =60\u00b2*130 =3,600*130 =468,000\n\nc\u00b2b =3,600*110 =396,000\n\nAdding all terms:\n\n1,859,000 +1,014,000 =2,873,000\n\n+1,573,000 =4,446,000\n\n+726,000 =5,172,000\n\n+468,000 =5,640,000\n\n+396,000 =6,036,000\n\nWhich's 6,036,000 which is more than 6,000,000. So that doesn't work. So that triple is not a solution.\n\nHmm test: Let me try (120, 90, 150). Wait, sum is 160+90+150=400, which is over. Not good.\n\nLet me try (120, 90, 90). Sum is 120+90+90=300.\n\nCompute the second equation:\n\na\u00b22b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\na=120, b=90, c=90.\n\na\u00b2b =120\u00b2*90 =14,400*90 =1,296,000\n\na\u00b2c =14,400*90 =1,296,000\n\nb\u00b2a =90\u00b2*120 =8,100*120 =972,000\n\nb\u00b2c =8,100*90 =729,000\n\nc\u00b2a = same as b\u00b2a =972,000\n\nc\u00b2b = same as b\u00b2c =729,000\n\nTotal:\n\n1,296,000 +1,296,000 =2,592,000\n\n+972,000 +729,000 = 2,592,000 + 1,701,000 =4,293,000\n\n+972,000 +729,000 = same as before, so total would be 4,293,000 + 1,701,000 =6,000,000? Wait, let me compute step by step:\n\nWait, let me add all terms:\n\nFirst.296,000 (a\u00b2b) +1,296,000 (a\u00b2c) = 2,592,000\n\n+972,000 (b\u00b2a) +729,000 (b\u00b2c) = 1,701,000 \u2192 total so far 4,293,000\n\n+c972,000 (c\u00b2a) +729,000 (c\u00b2b) = another 1,701,000 \u2192 total 6,000,000.\n\nWait, so that works! So (120, 90, 90) is a solution. But this is not an arithmetic progression. Let me check:\n\nThe numbers are 120, 90, 90. So two of them are equal. So this is a case where two variables are equal, but not the middle term. Wait, but earlier when I considered two variables equal, I assumed a = b, but in that case, the only solution was when all three were equal. 
But here, with't two variables equal (b and c), but a different.\n\n from them. So this is a case where two variables are equal but not the first two.\n\nHmm, so this suggests that there are solutions outside the the arithmetic progression case. So my previous conclusion was wrong.\n\nWait, let let me analyze this case. Let me suppose that two variables are equal, say b = c, then a + 2b = 300, so a = 300 - 2b.\n\nThen, let's compute the second equation:\n\na\u00b2b + a\u00b2c + b\u00b2a + b\u00b2c + c\u00b2a + c\u00b2b\n\nSince b = c, this becomes:\n\na\u00b2b + a\u00b2b + b\u00b2a + b\u00b2b + b\u00b2a + b\u00b2b\n\n= 2a\u00b2b + 2b\u00b2a + 2b\u00b3\n\n= 2a\u00b2b + 2ab\u00b2 + 2b\u00b3\n\nFactor out 2b:\n\n2b(a\u00b2 + ab + b\u00b2)\n\nSet this equal to 6,000,000:\n\n2b(a\u00b2 + ab + b\u00b2) = 6,000,000\n\nDivide both sides by 2:\n\nb(a\u00b2 + ab + b\u00b2) = 3,000,000\n\nBut since a = 300 - 2b, substitute that in:\n\nb[(300 - 2b)^2 + (300 - 2b)b + b\u00b2] = 3,000,000\n\nLet me compute the expression inside the brackets:\n\nFirst term: (300 - 2b)^2 = 90,000 - 1,200b +4b\u00b2\n\nSecond term: (300 - 2b)bb = 300b - 2b\u00b2\n\nThird term: b\u00b2\n\nAdding them three terms:\n\n90,000 - ,200b +4b\u00b2 +300b -2b\u00b2 +b\u00b2\n\nCombine like terms:\n\n90,000 + (-1,200b + 00b) + (4b\u00b2 -2b\u00b2 +b\u00b2)\n\n= 90,000 - 900b + b\u00b2\n\nTherefore, the equation becomes:\n\nb*(90,000 - 900b +3b\u00b2) = 3,000,000\n\nLet me write this as:\n\n3b\u00b3 - 900b\u00b2 +90,000b -3,000,000 =0\n\nDivide both sides by 3:\n\nb\u00b3 - 300b\u00b2 +30,000b -1,000,000 =0\n\nWait, this is the same cubic equation as before when I set a = b!\n\nIndeed, the cubic is the same as when we set a = b. earlier. which had a root at b =100, and fact factored into (b -100)^3 =0. Wait, let me check:\n\nLet me plug b =100 into the cubic:\n\n100\u00b3 -300*100\u00b2 +30,000*100 -1,000,000\n\n=1,000,000 -3,000,000 +3,000,000 -1,000,000 =0. So yes, b=100 is a root.\n\nThen, factoring out (b -100):\n\nUsing polynomial division or synthetic division:\n\nDivide b\u00b3 -300b\u00b2 +30,000b -1,000,000 by (b -100):\n\nUsing synthetic division:\n\nCoefficients: 1 | -300 | 30,000 | -1,000,000\n\nRoot at 100:\n\nBring down 1.\n\nMultiply by 100: 100. Add to next300: -200\n\nMultiply -200 by 100: -20,000. Add to 30,000: 10,000\n\nMultiply 10,000 by 100: 1,000,000. Add to -1,000,000: 0.\n\nThus the cubic factors as (b -100)(b\u00b2 -200b +10,000). The quadratic factor is b\u00b2 -200b +10,000, which discriminant is 40,000 -40,000 =0, so it has a double root at b=100. Thus the equation is (b -100)^3 =0, so the only solution is b=100.\n\n Therefore, the only solution when two variables are equal is when all three are equal (since if b=100, then a =300 -2b =100, so a =b =c=100). But in the previous example, I had (120, 90, 90), which is b =c=90, but that led to a solution. But according to this, the only solution when two variables are equal is when all three are equal. Contradiction?\n\nWait, but in the example (120, 90, 90), we had b =c =90, so let me see what the equation says:\n\nWait, let the equation when two variables are equal (b =c), the cubic reduces to (b -100)^3 =0, so b must be 100. But in the example, b =90, which is't satisfy that. 
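Both algebraic steps in this case analysis can be checked numerically over the whole admissible range of b; a minimal sketch in plain Python (not part of the original derivation):

for b in range(151):          # with c = b and a = 300 - 2b, b runs from 0 to 150
    a = 300 - 2*b
    # the six-term sum really does reduce to 2b(a^2 + ab + b^2) when b = c
    assert 2*a*a*b + 2*a*b*b + 2*b**3 == 2*b*(a*a + a*b + b*b)
    # and the cubic really is (b - 100)^3
    assert b**3 - 300*b**2 + 30_000*b - 1_000_000 == (b - 100)**3
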
So there must be an error in my calculations.\n\nWait, let me reccheck the example (120, 90, 90):\n\nCompute the second equation:\n\na\u00b2b + a\u00b2c + b\u00b2a + b\u00b2c + c\u00b2a + c\u00b2b\n\nWith a=120, b=90, c=90:\n\na\u00b2b =120\u00b2*90 =14,400*90 =1,296,000\n\na\u00b2c = same as a\u00b2b =1,296,000\n\nb\u00b2a =90\u00b2*120 =8,100*120 =972,000\n\nb\u00b2c =90\u00b2*90 =8,100*90 =729,000\n\nc\u00b2a = same as b\u00b2a =972,000\n\nc\u00b2b = same as b\u00b2c =729,000\n\nTotal sum:\n\n1,296,000 +1,296,000 =2,592,000\n\n+972,000 +729,000 = 1,701,000 \u2192 total 4,293,000\n\n+972,000 +729,000 = another 1,701,000 \u2192 total 6,000,000. So it works.\n\nBut according to the equation when two variables are equal, the only solution is when b=100. So why does this work?\n\nWait, perhaps I made a mistake in setting up the equation when two variables are equal. Let me re-derive it.\n\nLet me suppose that two variables are equal, say b =c. Then a =300 - 2b.\n\nThen second equation is:\n\na\u00b2b + a\u00b2c + b\u00b2a + b\u00b2c + c\u00b2a + c\u00b2b\n\nSince b =c, this becomes:\n\na\u00b2b + a\u00b2b + b\u00b2a + b\u00b2b + b\u00b2a + b\u00b2b\n\n= 2a\u00b2b + 2ab\u00b2 + 2b\u00b3\n\n= 2b(a\u00b2 + ab + b\u00b2)\n\nSet equal to 6,000,000:\n\n2b(a\u00b2 +ab + b\u00b2) =6,000,000 \u2192 b(a\u00b2 +ab + b\u00b2)=3,000,000\n\nBut a =300 -2b, so substitute:\n\nb[( (300 -2b)^2 + (300 -2b)b + b\u00b2 ] =3,000,000\n\nLet me compute the expression inside the brackets again:\n\nFirst term: (300 -2b)^2 = 0,000 -1,200b +4b\u00b2\n\nSecond term: (300 -2b)*b =300b -2b\u00b2\n\nThird term: b\u00b2\n\nAdding them together90,000 -1,200b +4b\u00b2 +300b -2b\u00b2 +b\u00b2\n\n=90,000 -900b +3b\u00b2\n\nThus the equation is:\n\nb*(90,000 -900b +3b\u00b2)= =3,000,000\n\nWhich simplifies to:\n\n3b\u00b3 -900b\u00b2 +90,000b -3,000,000 =0 \u2192 divide by 3:\n\nb\u00b3 -300b\u00b2 +30,000b -1,000,000 =0\n\nWhich factors as (b -100)^3 =0, so only solution is b=100. But in the example, b=90,. a solution. Contr contradiction must be an error here.\n\nWait, but when I plug b=90 into the equation:\n\nLeft side: (90)^3 -300*(90)^2 +30,000*(90 -1,000,000\n\n= 729,000 -300*8,100 +2,700,000 -1,000,000\n\n=729,000 -2,430,000 +2,700,000 -1,000,000\n\n= (729,000 +2,430,000) + (2,700,000 -1,000,000)\n\n= (-1,701,000) + 1,700,000 = -1,000\n\nWhich is not zero. So that suggests that my example (120,90,90) does not satisfy the equation? But when I computed manually, it did. Wait must be a miscalculation.\n\nWait, let me recompute the example:\n\na=120, b=90, c=90.\n\nCompute the second equation:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\n= (120\u00b2*90 +120\u00b2*90 +90\u00b2*120 +90\u00b2*90 +90\u00b2*120 +90\u00b2*90\n\nWait, let me compute each term step-by step:\n\nFirst term: a\u00b2b =120\u00b2 *90 =14,400 *90 =1,296,000\n\nSecond term: a\u00b2c = same as first term:1,296,000\n\nThird term: b\u00b2a =90\u00b2 *120 =8,100 *120 =972,000\n\nFourth term: b\u00b2c =90\u00b2 *90 =8,100 *90 =729,000\n\nFifth term: c\u00b2a = same as third term 972,000\n\nSixth term: c\u00b2b = same as fourth term 729,000\n\nNow adding all terms:\n\nFirst two terms: 1,296,000 +1,296,000 =2,592,000\n\nThird and fourth: 72,000 +729,000 =1,701,000 \u2192 total so far 4,293,000\n\nFifth and sixth:972,000 +729,000 =1,701,000 \u2192 total 6,000,000. So it does add up.\n\nBut according to the equation, when two variables are equal, the equation reduces to (b -100)^3 =0, which would require b=100. 
So why is there a discrepancy?\n\nAh! Wait a second, perhaps think I made a mistake in the substitution. Let me recheck the substitution step.\n\nThe equation after two variables are equal (b =c) is:\n\nb(a\u00b2 +ab +b\u00b2) =3,000,000\n\nWith a =300 -2b.\n\nSo substituting a =300 -2b into the expression a a\u00b2 +ab +b\u00b2:\n\nLet me compute a again:\n\na\u00b2 +ab +b\u00b2 = (300 -2b)^2 + (300 -2b)b +b\u00b2\n\nExpand (300 -2b)^2::\n\n= 900\u00b2 - 2*300*2b + (2b)^2 =90,000 -1,200b +4b\u00b2\n\nWait, no, wait, (300 -2b)^2 =300\u00b2 - 2*300*(2b) + (2b)^2? Wait, wait's incorrect.\n\nWait, (a - b)^2 =a\u00b2 -2ab +b\u00b2. So (300 -2b)^2 =300\u00b2 -2*300*(2b) + (2b)^2 =90,000 -1,200b +4b\u00b2. That part is correct.\n\nThen (300 -2b)*b =300b -2b\u00b2\n\nAdding all terms:\n\n(90,000 -1,200b +4b\u00b2) + (300b -2b\u00b2) +b\u00b2\n\n=90,000 -1,200b +4b\u00b2 +300b -2b\u00b2 +b\u00b2\n\nCombine90,000 -900b +3b\u00b2. That is correct.\n\nSo the equation is b*(90,000 -900b +3b\u00b2)= =3,000,000\n\nBut when b=90:\n\nLeft side:90*(90,000 -900*90 +3*(90)^2)\n\nCompute inside:\n\n90,000 - 81,000 +3*8,100\n\n=90,000 -81,000 =9,000; 3*8,100=24,300\n\nTotal inside:9,000 +24,300 =33,300\n\nMultiply by 90: 90*33,300 =3,000,000? Let me compute:\n\n33,300 *90 = (33,000 *90) + (300 *90) =2,970,000 +27,000 =2,997,000. Which is not 3,000,000. So discrepancy is a discrepancy here here.\n\nWait, but when I computed manually, the total was 6,000,000. So there must be an error in my manual calculation.\n\nWait, let me recompute the example:\n\na=120, b=90, c=90.\n\nCompute each term:\n\na\u00b2b =120\u00b2 *90 =14,400 *90 =1,296,000\n\na\u00b2c =14,400 *90 =1,296,000\n\nb\u00b2a =90\u00b2 *120 =8,100 *120 =972,000\n\nb\u00b2c =8,100 *90 =729,000\n\nc\u00b2a = same as b\u00b2a =972,000\n\nc\u00b2b = same as b\u00b2c =729,000\n\nAdding all terms:\n\n1,296,000 +1,296,000 =2,592,000\n\n+972,000 +729,000 =1,701,000 \u2192 total 4,293,000\n\n+972,000 +729,000 =1,701,000 \u2192 total 6,000,000.\n\nWait, but according to the equation, it should be 2,997,000. So there is a contradiction. Therefore I must have made a mistake in the equation setup.\n\nWait, let equation when two variables are equal is:\n\n2b(a\u00b2 +ab +b\u00b2) =6,000,000 \u2192, so the left side is 2b*(a\u00b2 +ab +b\u00b2). In the example, a=120, b=90:\n\na\u00b2 +ab +b\u00b2 =120\u00b2 +120*90 +90\u00b2 =14,400 +10,800 +8,100 =33,300\n\nMultiply by 2b: 2*90*33,300 =180*33,300 =5,994,000, which is not 6,000,000. Wait when I computed manually, I got 6,000,000. So there must be an error in manual manual calculation.\n\nWait, let me recompute the terms:\n\na\u00b2b =120\u00b2 *90 =14,400 *90 =1,296,000\n\na\u00b2c =14,400 *90 =1,296,000 \u2192 total for a\u00b2 terms: 2,592,000\n\nb\u00b2a =90\u00b2 *120 =8,100 *120 =972,000\n\nb\u00b2c =8,100 *90 =729,000 \u2192 total for b\u00b2 terms: 1,701,000\n\nc\u00b2a =90\u00b2 *120 = same as b\u00b2a =972,000\n\nc\u00b2b =90\u00b2 *90 = same as b\u00b2c =729,000 \u2192 total for c\u00b2 terms: 1,701,000\n\nTotal sum:2,592,000 +1,701,000 +1,701,000 =6,000,000? Wait, 2,592,000 +1,701,000 is=4,293,000 +1,701,000 =6,000,000. But according to the equation, it should be 5,994,000. There is a discrepancy of 6,000,000 -5,994,000 =6,000.\n\nWait, so where is the mistake?\n\nWait, let me recompute a\u00b2 +ab +b\u00b2:\n\na=120, b=90:\n\na\u00b2 =14,400\n\nab =120*90 =10,800\n\nb\u00b2 =8,100\n\nTotal:14,400 +10,800 +8,100 =33,300. 
Correct that's correct.\n\nThen 2b*(a\u00b2 +ab +b\u00b2) =2*90*33,300 =180*33,300 =5,994,000. But manual calculation gives 6,000,000. So there is a mistake in the manual calculation.\n\nWait, let me recompute the terms again carefullya\u00b2b =120\u00b2 *90 =14,400 *90 =1,296,000\n\na\u00b2c =14,400 *90 =1,296,000 \u2192 total a\u00b2 terms: 2,592,000\n\nb\u00b2a =90\u00b2 *120 =8,100 *120 =972,000\n\nb\u00b2c =8,100 *90 =729,000 \u2192 total b\u00b2 terms: 1,701,000\n\nc\u00b2a =90\u00b2 *120 = same as b\u00b2a =972,000\n\nc\u00b2b =90\u00b2 *90 = same as b\u00b2c =729,000 \u2192 total c\u00b2 terms: 1,701,000\n\nTotal sum:2,592,000 +1,701,000 +1,701,000 =6,000,000.Wait according to the equation, it should be 5,994,000. There is a discrepancy of 6,000. So where must be an error in the equation setup.\n\nWait, perhaps me check the equation again.The second equation is a\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b =6,000,000.\n\nWhen two variables are equal ( say b =c, then:\n\na\u00b2b +a\u00b2c =a\u00b2b2b\n\nb\u00b2a +b\u00b2c =b\u00b2a +b\u00b3b =b\u00b2(a +b\u00b3\n\nc\u00b2a +c\u00b2b =c\u00b2a +c\u00b2b = same as b\u00b2a +b\u00b2b =b\u00b2a +b\u00b3\n\nSo total equation2a\u00b2b + 2b\u00b2a + 2b\u00b3 =2a\u00b2b +2ab\u00b2 +2b\u00b3 =2b(a\u00b2 +ab +b\u00b2)\n\nYes, that's correct. So the equation is correct. So why the discrepancy?\n\nWait, in the example, a I computed manually the terms, I got 6,000,000, but according to the equation, it should be 5,994,000. So there must be an arithmetic error in the manual calculation.\n\nWait, let me recompute the terms:\n\nFirst=120, b=90, c=90.\n\na\u00b2b =120\u00b22 *90 =14,400 *90 =1,296,000\n\na\u00b2c =14,400 *90 =1,296,000 \u2192 total a\u00b2 terms: 2,592,000\n\nb\u00b2a =90\u00b2 *120 =8,100 *120 =972,000\n\nb\u00b2c =9,100 *90 =729,000 \u2192 total b\u00b2 terms: 1,701,000\n\nc\u00b2a =90\u00b2 *120 =8,100 *120 =972,000\n\nc\u00b2b =90\u00b2 *90 =8,100 *90 =729,000 \u2192 total c\u00b2 terms: 1,701,000\n\nTotal sum:2,592,000 +1,701,000 +1,701,000 =6,000,000\n\nWait, but according to the equation, it should be 2b*(a\u00b2 +ab +b\u00b2)= =2*90*(14,400 +10,800 +8,100)= 180*(33,300)=5,994,000. So there is a discrepancy of 6,000,000 -5,994,000 =6,000.\n\nWait, so where must be an mistakeiscalculation in the equation setup.\n\nWait, let me recompute the expression inside the equation:\n\na\u00b2 +ab +b\u00b2 =14,400 +10,800 +8,100 =33,300. Correct's correct.\n\nThen 2b*(a\u00b2 +ab +b\u00b2)= =2*90*33,300 =180*33,300.\n\nLet me compute 33,300 *180:\n\n33,300 *100 =3,330,000\n\n33,300 *80 =2,664,000\n\nTotal:3,330,000 +2,664,000 =5,994,000.\n\nBut manual calculation gives 6,000,000. So there must be an error in the manual calculation.\n\nWait, let me recompute the terms again:\n\na\u00b2b =120\u00b2 *90 =14,400 *90 =1,296,000\n\na\u00b2c =14,400 *90 =1,296,000 \u2192 total a\u00b2 terms: 2,592,000\n\nb\u00b2a =90\u00b2 *120 =8,100 *120 =972,000\n\nb\u00b2c =8,100 *90 =729,000 \u2192 total b\u00b2 terms: 1,701,000\n\nc\u00b2a =90\u00b2 *120 =8,100 *120 =972,000\n\nc\u00b2b =90\u00b2 *90 =8,100 *90 =729,000 \u2192 total c\u00b2 terms: 1,701,000\n\nAdding all terms:\n\n2,592,000 +1,701,000 =4,293,000\n\n+1,701,000 =6,000,000\n\nWait, but according to the equation, it should be 5,994,000. So there is a discrepancy of 6,000. Where suggests't make sense. 
Unless I made a mistake in the equation setup?\n\nWait, let me check the original equation again:\n\nThe second equation is a\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b =6,000,000.\n\nWhen two variables are equal, say b =c, then:\n\na\u00b2b +a\u00b2c =a\u00b2(b +c) =a\u00b2*(2b)\n\nb\u00b2a +b\u00b2c =b\u00b2a +b\u00b2b =b\u00b2(a +b)\n\nc\u00b2a +c\u00b2b =c\u00b2a +c\u00b2b =c\u00b2(a +b) =b\u00b2(a +b)\n\nSo total:\n\n2a\u00b2b + \u00b2(a +b) +b\u00b2(a +b) =2a\u00b2b +2b\u00b2(a +b)\n\n=2a\u00b2b +2ab\u00b2 +2b\u00b3 =2b(a\u00b2 +ab +b\u00b2)\n\nYes, that's correct. So the equation is correct. So why the discrepancy?\n\nWait, perhaps I made a mistake in the example's values. Let me check the values again:\n\na=120, b=90, c=90. Sum is 120+90+90=300. Correct's correct.\n\nCompute the second equation:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\n=120\u00b2*90 +120\u00b2*90 +90\u00b2*120 +90\u00b2*90 +90\u00b2*120 +90\u00b2*90\n\nWait, but 120 squared is 14,400, yes thatFirst two terms:14,400*90*2 =28,800*90 =2,592,000\n\nNext two terms:90\u00b2*120 +90\u00b2*90 =8,100*120 +8,100*90 =972,000 +729,000 =1,701,000\n\nLast two terms: same as the next two terms: another72,000 +729,000 =1,701,000\n\nTotal:2,592,000 +1,701,000 +1,701,000 =6,000,000\n\nBut according to the equation, it should be 2,994,000. So there must be an mistake in the equation setup.\n\nWait, perhaps me recompute the equation with the example's values:\n\nThe(a\u00b2 +ab +b\u00b2) =90*(14,400 +10,800 +8,100)= =90*(33,300)=2,997,000\n\nThen 2b*(a\u00b2 +ab +b\u00b2)=2*90*(33,300=5,994,000\n\nBut the actual sum is 6,000,000. So there is a difference of 6,000. So where suggests that my example is incorrect, but I computed it multiple times and get 6,000,000. So perhaps must be an error in the equation setup.\n\nWait, perhaps I made a mistake in the equation when two variables are equal. Let me re-derive it.\n\nLet me suppose b =c, so variables are (a, b, b.\n\n with a +2b =300.\n\nThe second equation is:\n\na\u00b2b +a\u00b2b +b\u00b2a +b\u00b2b +b\u00b2a +b\u00b2b\n\n=2a\u00b2b +2ab\u00b2 +2b\u00b3\n\n=2b(a\u00b2 +ab +b\u00b2)\n\nYes, that's correct.\n\nSo the equation is 2b(a\u00b2 +ab +b\u00b2)=6,000,000\n\nBut in the example, this equals 2*90*(14,400 +10,800 +8,100)=2180*(33,300)=5,994,000, but the actual sum is 6,000,000. So there must be an mistake in the example.\n\nWait, perhaps I made a mistake in the example's variables. Let me check:\n\nWait, if a=120, b=90, c=90, then a120+90+90=300. That's correct.\n\nWait, let me computecompute the terms:\n\na\u00b2b =120\u00b2 *90 =14,400 *90 =1,296,000\n\na\u00b2c =14,400 *90 =1,296,000 \u2192 total a\u00b2 terms:2,592,000\n\nb\u00b2a =90\u00b2 *120 =8,100 *120 =972,000\n\nb\u00b2c =8,100 *90 =729,000 \u2192 total b\u00b2 terms:1,701,000\n\nc\u00b2a =90\u00b2 *120 =8,100 *120 =972,000\n\nc\u00b2b =9,100 *90 =729,000 \u2192 total c\u00b2 terms:1,701,000\n\nTotal sum:2,592,000 +1,701,000 +1,701,000 =6,000,000\n\nBut according to the equation, it should be 5,994,000. So there is a discrepancy of 6,000. This is very confusing.\n\nWait, perhaps I made a mistake in the equation setup. Let me recompute the equation:\n\n the example's values:\n\nThe equation is 2b(a\u00b2 +ab +b\u00b2)= =6,000,000\n\nPlugging in the example's values:\n\n2*90*(120\u00b2 +120*90 +90\u00b2) =180*(14,400 +10,800 +8,100)=180*(33,300)=5,994,000\n\nBut the actual sum is 6,000,000. 
So there is a mistake somewhere.\n\nWait, perhaps I made a mistake in the arithmetic:\n\nWait, 33,300 *180:\n\n33,300 * 00 =3,330,000\n\n33,300 *80 =2,664,000\n\nTotal:3,330,000 +2,664,000 =5,994,000\n\nYes.\n\nBut the manual calculation gives 6,000,000. So there must be an error in the manual calculation.\n\nWait, let me computecompute the terms:\n\n one more time:\n\na\u00b22b =120\u00b2 *90 =14,400 *90 =1,296,000\n\na\u00b2c =14,400 *90 =1,296,000 \u2192 total a\u00b2 terms:2,592,000\n\nb\u00b2a =90\u00b2 *120 =8,100 *120 =972,000\n\nb\u00b2c =8,100 *90 =729,000 \u2192 total b\u00b2 terms:1,701,000\n\nc\u00b2a =90\u00b2 *120 =8,100 *120 =972,000\n\nc\u00b2b =8,100 *90 =729,000 \u2192 total c\u00b2 terms:1,701,000\n\nAdding all terms:\n\n2,592,000 +1,701,000 =4,293,000\n\n+1,701,000 =6,000,000\n\nWait, but according to the equation, it should be 5,994,000. So there is a discrepancy of 6,000. This is perplexing.\n\nWait, perhaps I made a mistake in the equation setup. Let me recheck the equation:\n\nThe second equation is a\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b =When b =c, this becomes:\n\na\u00b2b +a\u00b2b +b\u00b2a +b\u00b2b +b\u00b2a +b\u00b2b\n\n=2a\u00b2b +2ab\u00b2 +2b\u00b3\n\nYes2b(a\u00b2 +ab +b\u00b2)\n\nYes, that's correct.\n\nSo the equation is correct. So why the discrepancy?\n\nWait, perhaps I made a mistake in the example's variables. Let me check:\n\n the values again:\n\na=120, b=90, c=90.\n\nCompute a\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b:\n\na\u00b2b =120\u00b2 *90 =14,400 *90 =1,296,000\n\na\u00b2c =14,400 *90 =1,296,000 \u2192 total a\u00b2 terms:2,592,000\n\nb\u00b2a =90\u00b2 *120 =8,100 *120 =972,000\n\nb\u00b2c =8,100 *90 =729,000 \u2192 total b\u00b2 terms:1,701,000\n\nc\u00b2a =90\u00b2 *120 =8,100 *120 =972,000\n\nc\u00b2b =8,100 *90 =729,000 \u2192 total c\u00b2 terms:1,701,000\n\nTotal sum:2,592,000 +1,701,000 +1,701,000 =6,000,000\n\nBut according to the equation, it should be 5,994,000. So there must be an error in the equation setup.\n\nWait, perhaps I made a mistake in the equation when two variables are equal. Let me re-check:\n\nThe second equation is a\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b.\n\nIf b =c, then:\n\na\u00b2b +a\u00b2c =a\u00b2(b +c) =a\u00b2*(2b)\n\nb\u00b2a +b\u00b2c =b\u00b2a +b\u00b2b =b\u00b2(a +b)\n\nc\u00b2a +c\u00b2b =c\u00b2a +c\u00b2b =c\u00b2(a +b) =b\u00b2(a +b)\n\nSo total:\n\n2a\u00b2b +b\u00b2(a +b) +b\u00b2(a +b) =2a\u00b2b +2b\u00b2(a +b)\n\n=2a\u00b2b +2ab\u00b2 +2b\u00b3 =2b(a\u00b2 +ab +b\u00b2)\n\nYes, that's correct. So the equation is correct. So why the discrepancy?\n\nWait, perhaps I made a mistake in the arithmetic of calculating the example's terms. Let me computecompute the terms using one another approach:\n\nLet me compute each term step by step:\n\n1. a\u00b2b = (20\u00b2 *90 =14,400 *90 =1,296,000\n\n2. a\u00b2c =14,400 *90 =1,296,000\n\n3. b\u00b2a =90\u00b2 *120 =8,100 *120 =972,000\n\n4. b\u00b2c =8,100 *90 =729,000\n\n5. c\u00b2a =8,100 *120 =972,000\n\n6. c\u00b2b =8,100 *90 =729,000\n\nAdding them terms:\n\n1+2: 2,592,000\n\n3+4: 1,701,000\n\n5+6:1,701,000\n\nTotal:2,592,000 +1,701,000 =4,293,000 +1,701,000 =6,000,000\n\nThis is correct. So the equation must be wrong.Wait, but according to the equation, it should be 5,994,000. 
So there must be a mistake in the equation setup.\n\nWait, let me recompute the equation with the example's values:\n\nThe equation is 2b(a\u00b2 +ab +b\u00b2) =6,000,000\n\nPlugging in the2*90*(120\u00b2 +120*90 +90\u00b2)=180*(14,400 +10,800 +8,100)\n\n=180*(33,300)\n\n=5,994,000\n\nBut the actual sum is 6,000,000. So there is a discrepancy of 6,000. This suggests that there is an error in the equation setup.\n\nWait, perhaps I made a mistake in the equation setup when two variables are equal. Let me re-derive it againLet me recompute the expression:\n\nWhen two =c,The second equation is:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\n= a\u00b2b +a\u00b2b +b\u00b2a +b\u00b2b +b\u00b2a +b\u00b2b\n\n=2a\u00b2b +222b\u00b2a + 2b\u00b2b\n\n=2a\u00b2b +2ab\u00b2 +2b\u00b3\n\n=2b(a\u00b2 +ab +b\u00b2)\n\nYes, that's correct. So the equation is correct. So why the discrepancy?\n\nWait, perhaps I made a mistake in the example's values. Let me check the values again:\n\na=120, b=90, c=90.\n\nCompute the second equation:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\n=120\u00b2*90 +120\u00b2*90 +90\u00b2*120 +90\u00b2*90 +90\u00b2*120 +90\u00b2*90\n\n=2*(120\u00b2*90) +2*(90\u00b2*120) +2*(90\u00b2*90)\n\n=2*(1,296,000) +2*(109,800) +2*(729,000)\n\nWait, wait, think I made a mistake here:\n\nWait, 90\u00b2*120 is 8,100*120 =972,000, not 109,800. So:\n\n=2*(1,296,000) +2*(972,000) +2*(729,000)\n\n=2,592,000 +1,944,000 +1,458,000 = =2,592,000 +1,944,000 =4,536,000 +1,458,000 =6,000,000\n\nWait according to the equation:\n\n2b(a\u00b2 +ab +b\u00b2)= =2*90*(14,400 +10,800 +8,100)=180*(33,300)=5,994,000\n\nWait there is a discrepancy of 6,000,000 -5,994,000 =6,000.\n\nWait, but where terms are:\n\n2*(a\u00b2b) +2*(b\u00b2a) +2*(b\u00b3)\n\n=2a\u00b2b +2ab\u00b2 +2b\u00b3 = =2b(a\u00b2 +ab +b\u00b2)\n\nYes in the example:\n\n2a\u00b2b =2*14,400*90 =2*1,296,000 =2,592,000\n\n2ab\u00b2 =2*120*8,100 =2*972,000 =1,944,000\n\n2b\u00b3 =2*(729,000 =1,458,000\n\nTotal:2,592,000 +1,944,000 +1,458,000 =6,000,000\n\nBut according to the equation,2b(a\u00b2 +ab +b\u00b2)=2b90*(14,400 +10,800 +8,100)=2*90*(33,300)=5,994,000\n\nWait, but 2b*(\u00b2 +ab +b\u00b2)= is 2*90*(14,400 +10,800 +8,100)= 180*(33,300)=5,994,000\n\nBut the actual terms are 6,000,000. So there must be an mistake in the equation setup.\n\nWait, perhaps's possible that I made a mistake in the equation setup. Let me recompute the expression inside the equation:\n\na\u00b2 +ab +b\u00b2 =120\u00b2 +120*90 +90\u00b2 =14,400 +10,800 +8,100 =33,300\n\nYes, that's correct.\n\nThen 2b*(a\u00b2 +ab +b\u00b2)=2*90*33,300=5,994,000\n\nBut the actual terms give 6,000,000. So there is a mistake somewhere the equation setup.\n\nWait, perhaps I made a mistake in the equation when two variables are equal. Let me re-derive it againThe second equation is:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\nWhen b =c,= a\u00b2b +a\u00b2b +b\u00b2a +b\u00b2b +b\u00b2a +b\u00b2b\n\n=2a\u00b2b +2ab\u00b2 +2b\u00b3\n\n=2b(a\u00b2 +ab +b\u00b2)\n\nYes, that's correct. So the equation is correct. So why the discrepancy?\n\nWait, perhaps I made a mistake in the example's values. 
Let me check the values again:\n\na=120, b=90, c=90.\n\nCompute the second equation:\n\na\u00b2b ++a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\n=120\u00b2*90 +120\u00b2*90 +90\u00b2*120 +90\u00b2*90 +90\u00b2*120 +90\u00b2*90\n\n=2*(120\u00b2*90) +2*(90\u00b2*120) +2*(90\u00b2*90)\n\n=2*(1,296,000) +2*(972,000) +2*(729,000)\n\n=2,592,000 +1,944,000 +1,458,000 =6,000,000\n\nBut according to the equation, it should be 5,994,000. So there is a discrepancy of 6,000. This is't possible. unless there is a calculation error.\n\nWait, let me compute 20\u00b2*90:\n\n120 squared is 14,400. 14,400 *90 =1,296,000. That's correct.\n\n90 squared is 8,100. 8,100 *120 =972,000. Correct.\n\n8,100 *90 =729,000. Correct.\n\nSo the terms are correct. So the equation must be wrong.\n\nWait, perhaps I made a mistake in the equation setup. Let me recheck:\n\nThe second equation is a\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\nWhen b =c, this is:\n\na\u00b2b +a\u00b2b +b\u00b2a +b\u00b2b +b\u00b2a +b\u00b2b\n\n=2a\u00b2b +2ab\u00b2 +2b\u00b3\n\n=2b(a\u00b2 +ab +b\u00b2)\n\nYes, that's correct. So the equation is correct. So why the discrepancy?\n\nWait, perhaps I made a mistake in the arithmetic when calculating the equation's value:\n\nLet me computecompute 2b(a\u00b2 +ab +b\u00b2):\n\nWith b=90, a=120:\n\na\u00b2 +ab +b\u00b2 =14,400 +10,800 +8,100 =33,300\n\nMultiply\n\nOkay, I think I need to take a step back. here. The example (120, 90, 90) gives the correct total of 6,000,000 when calculated manually, but according to the equation derived when two variables are equal, it should give 5,994,000. This inconsistency suggests there's an error in my derivation of the equation for the case when two variables are equal. However, after multiple checks, the derivation seems correct. This contradiction't make sense. Perhaps I made a mistake in the example's variables?\n\nWait, let me try another example where two variables are equal. Let's take b =c =100, then a =300 -200 =100. So the triple is (100,100,100). which should work. Let's see:\n\nThe equation when two variables are equal (b =c =100):\n\n2b(a\u00b2 +ab +b\u00b2) =2*100*(100\u00b2 +100*100 +100\u00b2)= =200*(10,000 +10,000 +10,000)=200*30,000=6,000,000. Which matches the required value. So that works.\n\nAnother example: let me take b =c =50, so a =300 -100 =200. So the triple is (200,50,50). Let's compute the second equation:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\n=200\u00b2*50 +200\u00b2*50 +50\u00b2*200 +50\u00b2*50 +50\u00b2*200 +50\u00b2*50\n\n=2*(40,000*50) +2*(2,500*200) +2*(2,500*50)\n\n=2*(2,000,000 +2*500,000 +2*125,000\n\n=4,000,000 +1,000,000 +250,000 =5,250,000\n\nAccording to the equation:\n\n2b(a\u00b2 +ab +b\u00b2) =2*50*(200\u00b2 +200*50 +50\u00b2)\n\n=100*(40,000 +10,000 +2,500)=100*(52,500)=5,250,000. Which matches.\n\nSo the equation works for this case. But in the case of b=90, the gives a discrepancy. So why?\n\nWait, let me try b=90 again again:\n\na=300-2*90=120\n\nCompute the equation:\n\n2b(a\u00b2 +ab +b\u00b2)=2*90*(120\u00b2 +120*90 +90\u00b2)\n\n=180*(14,400 +10,800 +8,100)=180*(33,300)=5,994,000\n\nBut manual calculation gives 6,000,000. 
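For what it is worth, a machine check of both quantities for the disputed triple, together with a brute-force count of all ordered solutions, can be dropped in here (a minimal sketch in plain Python; it is not part of the original reasoning):

a, b, c = 120, 90, 90
direct = a*a*b + a*a*c + b*b*a + b*b*c + c*c*a + c*c*b
print(direct, 2*b*(a*a + a*b + b*b))  # 5994000 5994000 -- the two computations agree

count = 0
for x in range(301):
    for y in range(301 - x):
        z = 300 - x - y
        if x*x*y + x*x*z + y*y*x + y*y*z + z*z*x + z*z*y == 6_000_000:
            count += 1
print(count)  # 601, matching the arithmetic-progression tally

In other words, the slip is in the by-hand running total: 4,293,000 + 1,701,000 is 5,994,000, not 6,000,000, so (120, 90, 90) is not actually a solution and the reduced formula was right all along.
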
So there must be an error in the manual calculation.\n\nWait, let me recompute the terms for (120,90,90):\n\na\u00b2b =120\u00b2 *90 =14,400 *90 =1,296,000\n\na\u00b2c =14,400 *90 =1,296,000 \u2192 total a\u00b2 terms:2,592,000\n\nb\u00b2a =90\u00b2 *120 =8,100 *120 =972,000\n\nb\u00b2c =8,100 *90 =729,000 \u2192 total b\u00b2 terms:1,701,000\n\nc\u00b2a =90\u00b2 *120 =8,100 *120 =972,000\n\nc\u00b2b =8,100 *90 =729,000 \u2192 total c\u00b2 terms:1,701,000\n\nTotal sum:2,592,000 +1,701,000 +1,701,000 =6,000,000\n\nBut according to the equation, it should be 5,994,000. So is a contradiction. I must have made a mistake in the equation setup.\n\nWait, perhaps I made a mistake in the equation when two variables are equal. Let me re-derive it again:\n\nThe second equation is:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\nWhen b =c:\n\n= a\u00b2b +a\u00b2b +b\u00b2a +b\u00b2b +b\u00b2a +b\u00b2b\n\n=2a\u00b2b +2ab\u00b2 +2b\u00b3\n\n=2b(a\u00b2 +ab +b\u00b2)\n\nYes, that's correct. So the equation is correct. So why the discrepancy?\n\nWait, perhaps I made a mistake in the arithmetic's values. Let me check the values again:\n\na=120, b=90, c=90.\n\nCompute the second equation:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\n=120\u00b2*90 +120\u00b2*90 +90\u00b2*120 +90\u00b2*90 +90\u00b2*120 +90\u00b2*90\n\n=2*(120\u00b2*90) +2*(90\u00b2*120) +2*(90\u00b2*90)\n\n=2*(1,296,000) +2*(972,000) +2*(729,000)\n\nWait, wait, 90\u00b2*120 is 8,100*120 =972,000, not 972,000. So:\n\n=2*(1,296,000) +2*(972,000) +2*(729,000)\n\n=2,592,000 +1,944,000 +1,458,000 =6,000,000\n\nBut according to the equation:\n\n2b(a\u00b2 +ab +b\u00b2)=2*90*(14,400 +10,800 +8,100)=180*(33,300)=5,994,000\n\nThere is a discrepancy of 6,000. This suggests't make sense. I must be missing something.\n\nWait, perhaps I made a mistake in the equation setup. Let me recompute the expression inside when two variables are equal:\n\nThe second equation is:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\nWhen b =c:\n\n= a\u00b2b +a\u00b2b +b\u00b2a +b\u00b2b +b\u00b2a +b\u00b2b\n\n=2a\u00b2b +2ab\u00b2 +2b\u00b3\n\n=2b(a\u00b2 +ab +b\u00b2)\n\nYes, that's correct. So the equation is correct. So why the discrepancy?\n\nWait, perhaps I made a mistake in the arithmetic when calculating the equation's value:\n\nLet me recompute 2b(a\u00b2 +ab +b\u00b2) with b=90 and a=120:\n\na\u00b2 +ab +b\u00b2 =120\u00b2 +120*90 +90\u00b2 =14,400 +10,800 +8,100 =33,300\n\nMultiply by 2b: 2*90*33,300 =180*33,300\n\nLet3,300 *180:\n\n33,300 *100 =3,330,000\n\n33,300 *80 =2,664,000\n\nTotal:3,330,000 +2,664,000 =5,994,000\n\nBut manual calculation gives 6,000,000. So there is be an error in the manual calculation.\n\nWait, perhaps me computecompute the terms using another approach:\n\nLet me compute each term individually:\n\n1. a\u00b2b =120\u00b22 *90 =14,400 *90 =1,296,000\n\n2. a\u00b2c =14,400 *90 =1,296,000\n\n3. b\u00b2a =90\u00b2 *120 =8,100 *120 =972,000\n\n4. b\u00b2c =8,100 *90 =729,000\n\n5. c\u00b2a =90\u00b2 *120 =8,100 *120 =972,000\n\n6. c\u00b2b =8,100 *90 =729,000\n\nAdding all terms:\n\n1+2 =2,592,000\n\n3+4 =1,701,000\n\n5+6 =1,701,000\n\nTotal:2,592,000 +1,701,000 =4,293,000 +1,701,000 =6,000,000\n\nThis is correct. So the equation must be wrong. But I can't see where.\n\nWait, perhaps I made a mistake in the equation setup. 
Let me try to compute the equation with the example's values using the original equation:\n\nThe second equation is a\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b =6,000,000\n\nWith a=120, b=90, c=90, expanding term by term again gives 6,000,000, while the factored form for b =c gives 2b(a\u00b2 +ab +b\u00b2) =2*90*33,300 =5,994,000. No matter how many times I recheck the values, re-derive the case b =c, and redo the multiplication 33,300*180 =5,994,000, the same discrepancy of 6,000 appears. This is perplexing, and I can't see where the mistake is.\n\nWait, perhaps I made a mistake in the equation setup. Let me re-derive it"
455
+ print(evaluator_map["aime2024"].llm_judge(response, "601", "", False))
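The quoted response above keeps circling between two figures for the second equation at (a, b, c) = (120, 90, 90): 6,000,000 from the term-by-term expansion and 5,994,000 from the factored form 2b(a² + ab + b²). Evaluating both forms directly settles it; the short Python check below is illustrative only (it is not part of matheval.py) and shows that both forms give 5,994,000, i.e. the 6,000,000 total in the quoted reasoning comes from an addition slip.

    # Evaluate both forms of the second equation for (a, b, c) = (120, 90, 90).
    a, b, c = 120, 90, 90

    term_by_term = a*a*b + a*a*c + b*b*a + b*b*c + c*c*a + c*c*b
    factored = 2 * b * (a*a + a*b + b*b)  # equivalent form when b == c

    print(term_by_term)  # 5994000
    print(factored)      # 5994000
    # 2,592,000 + 1,944,000 + 1,458,000 = 5,994,000, not 6,000,000,
    # so (120, 90, 90) does not satisfy the second equation.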
rosetta/utils/registry.py ADDED
@@ -0,0 +1,279 @@
1
+ """
2
+ Unified registry utilities and simple JSON-based save/load helpers.
3
+
4
+ This module provides:
5
+ - create_registry: factory to create (registry dict, register decorator, get_class)
6
+ - capture_init_args: decorator to record __init__ kwargs on instances as _init_args
7
+ - save_object / load_object: serialize/deserialize object configs via registry
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import inspect
13
+ import json
14
+ from typing import Dict, Type, Callable, Optional, Tuple, TypeVar, Any
15
+ import torch
16
+
17
+ T = TypeVar("T")
18
+
19
+
20
+ def create_registry(
21
+ registry_name: str,
22
+ case_insensitive: bool = False,
23
+ ) -> Tuple[Dict[str, Type[T]], Callable[..., Type[T]], Callable[[str], Type[T]]]:
24
+ """
25
+ Create a registry system with register and get functions.
26
+
27
+ Args:
28
+ registry_name: Name used in error messages (e.g., "projector")
29
+ case_insensitive: Whether to store lowercase versions of names
30
+
31
+ Returns:
32
+ (registry_dict, register_function, get_function)
33
+ """
34
+
35
+ registry: Dict[str, Type[T]] = {}
36
+
37
+ def register(cls_or_name=None, name: Optional[str] = None):
38
+ """Register a class in the registry. Supports multiple usage patterns.
39
+
40
+ Usage:
41
+ @register
42
+ class Foo: ...
43
+
44
+ @register("foo")
45
+ class Foo: ...
46
+
47
+ @register(name="foo")
48
+ class Foo: ...
49
+ """
50
+
51
+ def _register(c: Type[T]) -> Type[T]:
52
+ # Determine the name to use
53
+ if isinstance(cls_or_name, str):
54
+ class_name = cls_or_name
55
+ elif name is not None:
56
+ class_name = name
57
+ else:
58
+ class_name = c.__name__
59
+
60
+ registry[class_name] = c
61
+ if case_insensitive:
62
+ registry[class_name.lower()] = c
63
+ return c
64
+
65
+ if cls_or_name is not None and not isinstance(cls_or_name, str):
66
+ # Called as @register or register(cls)
67
+ return _register(cls_or_name)
68
+ else:
69
+ # Called as @register("name") or @register(name="name")
70
+ return _register
71
+
72
+ def get_class(name: str) -> Type[T]:
73
+ """Get class by name from registry."""
74
+ if name not in registry:
75
+ # Build readable available list without duplicates when case_insensitive
76
+ seen = set()
77
+ available = []
78
+ for k in registry.keys():
79
+ if k.lower() in seen:
80
+ continue
81
+ seen.add(k.lower())
82
+ available.append(k)
83
+ raise ValueError(
84
+ f"Unknown {registry_name} class: {name}. Available: {available}"
85
+ )
86
+ return registry[name]
87
+
88
+ return registry, register, get_class
89
+
90
+
91
+ def capture_init_args(cls):
92
+ """
93
+ Decorator to capture initialization arguments of a class.
94
+
95
+ Stores the mapping of the constructor's parameters to the values supplied
96
+ at instantiation time into `self._init_args` for later serialization.
97
+ """
98
+ original_init = cls.__init__
99
+
100
+ def new_init(self, *args, **kwargs):
101
+ # Store all initialization arguments
102
+ init_args: Dict[str, Any] = {}
103
+
104
+ # Get parameter names from the original __init__ method
105
+ sig = inspect.signature(original_init)
106
+ param_names = list(sig.parameters.keys())[1:] # Skip 'self'
107
+
108
+ # Map positional args to parameter names
109
+ for i, arg in enumerate(args):
110
+ if i < len(param_names):
111
+ init_args[param_names[i]] = arg
112
+
113
+ # Add keyword args
114
+ init_args.update(kwargs)
115
+
116
+ self._init_args = init_args
117
+
118
+ # Call the original __init__
119
+ original_init(self, *args, **kwargs)
120
+
121
+ cls.__init__ = new_init
122
+ return cls
123
+
124
+
125
+ # -------------------------
126
+ # Serialization utilities
127
+ # -------------------------
128
+
129
+ def _encode_value(value: Any) -> Any:
130
+ """Best-effort JSON encoding for common ML types."""
131
+ # Primitives and None
132
+ if value is None or isinstance(value, (bool, int, float, str)):
133
+ return value
134
+
135
+ # Tuples -> lists
136
+ if isinstance(value, tuple):
137
+ return [
138
+ _encode_value(v) for v in value
139
+ ]
140
+
141
+ # Lists
142
+ if isinstance(value, list):
143
+ return [
144
+ _encode_value(v) for v in value
145
+ ]
146
+
147
+ # Dicts
148
+ if isinstance(value, dict):
149
+ return {k: _encode_value(v) for k, v in value.items()}
150
+
151
+ # torch-specific types
152
+ if torch is not None:
153
+ # torch.dtype
154
+ if isinstance(value, type(getattr(torch, "float32", object))):
155
+ # type(torch.float32) is torch.dtype, so this isinstance check matches dtype values; parse str(value)
156
+ s = str(value)
157
+ if s.startswith("torch."):
158
+ return {"__type__": "torch.dtype", "value": s.split(".")[-1]}
159
+
160
+ # torch.device
161
+ if isinstance(value, getattr(torch, "device", ())):
162
+ return {"__type__": "torch.device", "value": str(value)}
163
+
164
+ # Fallback to string representation
165
+ return {"__type__": "str", "value": str(value)}
166
+
167
+
168
+ def _decode_value(value: Any) -> Any:
169
+ """Decode values produced by _encode_value, recursively for containers."""
170
+ # Lists: decode each element
171
+ if isinstance(value, list):
172
+ return [_decode_value(v) for v in value]
173
+
174
+ # Dicts: either a typed-marker dict or a regular mapping that needs recursive decoding
175
+ if isinstance(value, dict):
176
+ if "__type__" in value:
177
+ t = value.get("__type__")
178
+ v = value.get("value")
179
+
180
+ if t == "torch.dtype" and torch is not None:
181
+ dtype = getattr(torch, str(v), None)
182
+ if dtype is None:
183
+ raise ValueError(f"Unknown torch.dtype: {v}")
184
+ return dtype
185
+
186
+ if t == "torch.device" and torch is not None:
187
+ return torch.device(v)
188
+
189
+ if t == "str":
190
+ return str(v)
191
+
192
+ # Unknown type marker; return raw as-is
193
+ return value
194
+
195
+ # Regular dict: decode values recursively
196
+ return {k: _decode_value(v) for k, v in value.items()}
197
+
198
+ # Primitives and anything else: return as-is
199
+ return value
200
+
201
+
202
+ def save_object(obj: Any, file_path: str) -> None:
203
+ """
204
+ Save an object's construction config to a JSON file.
205
+
206
+ The object is expected to have been decorated with capture_init_args,
207
+ so that `obj._init_args` exists.
208
+ """
209
+ class_name = obj.__class__.__name__
210
+ init_args = getattr(obj, "_init_args", {})
211
+
212
+ serializable_args = _encode_value(init_args)
213
+ payload = {
214
+ "class": class_name,
215
+ "init_args": serializable_args,
216
+ }
217
+
218
+ with open(file_path, "w", encoding="utf-8") as f:
219
+ json.dump(payload, f, indent=2)
220
+
221
+
222
+ def load_object(
223
+ file_path: str,
224
+ get_class_fn: Callable[[str], Type[T]],
225
+ override_args: Optional[Dict[str, Any]] = None,
226
+ ) -> T:
227
+ """
228
+ Load an object from a JSON config file previously saved by save_object.
229
+
230
+ Args:
231
+ file_path: Path to JSON file
232
+ get_class_fn: Function to resolve class names from registry
233
+ override_args: Optional dict to override stored init args
234
+
235
+ Returns:
236
+ Instantiated object of type T
237
+ """
238
+ with open(file_path, "r", encoding="utf-8") as f:
239
+ payload = json.load(f)
240
+
241
+ class_name = payload["class"]
242
+ encoded_args = payload.get("init_args", {})
243
+ init_args = _decode_value(encoded_args)
244
+
245
+ if override_args:
246
+ init_args.update(override_args)
247
+
248
+ cls = get_class_fn(class_name)
249
+ return cls(**init_args)
250
+
251
+
252
+ def dumps_object_config(obj: Any) -> str:
253
+ """Return a JSON string with the object's class and init args."""
254
+ class_name = obj.__class__.__name__
255
+ init_args = getattr(obj, "_init_args", {})
256
+ serializable_args = _encode_value(init_args)
257
+ return json.dumps({"class": class_name, "init_args": serializable_args}, indent=2)
258
+
259
+
260
+ def loads_object_config(
261
+ s: str,
262
+ get_class_fn: Callable[[str], Type[T]],
263
+ override_args: Optional[Dict[str, Any]] = None,
264
+ ) -> T:
265
+ """Instantiate an object from a JSON string produced by dumps_object_config."""
266
+ payload = json.loads(s)
267
+ class_name = payload["class"]
268
+ encoded_args = payload.get("init_args", {})
269
+ init_args = _decode_value(encoded_args)
270
+ if override_args:
271
+ init_args.update(override_args)
272
+ cls = get_class_fn(class_name)
273
+ return cls(**init_args)
274
+
275
+
276
+ # Model Registry System (case-insensitive for backward compatibility)
277
+ PROJECTOR_REGISTRY, register_model, get_projector_class = create_registry(
278
+ "projector", case_insensitive=True
279
+ )
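Taken together, the registry, the init-args capture, and the JSON helpers above are meant to be composed. The sketch below is a minimal usage example; ToyProjector and its arguments are invented for illustration, while the helper functions are the ones defined in rosetta/utils/registry.py.

    import torch

    from rosetta.utils.registry import (
        capture_init_args,
        dumps_object_config,
        get_projector_class,
        loads_object_config,
        register_model,
    )

    # Hypothetical class used only for this example.
    @register_model
    @capture_init_args
    class ToyProjector:
        def __init__(self, in_dim: int, out_dim: int, dtype=torch.float32):
            self.in_dim = in_dim
            self.out_dim = out_dim
            self.dtype = dtype

    proj = ToyProjector(16, 32, dtype=torch.bfloat16)

    # The captured init args are serialized; torch.bfloat16 becomes
    # {"__type__": "torch.dtype", "value": "bfloat16"} via _encode_value.
    config_json = dumps_object_config(proj)

    # Rebuild an equivalent object: the class is resolved from the registry
    # by its class name, and the dtype marker is decoded back to torch.bfloat16.
    restored = loads_object_config(config_json, get_projector_class)
    print(type(restored).__name__, restored.dtype)  # ToyProjector torch.bfloat16

save_object and load_object follow the same path, only writing the config to a JSON file instead of returning a string.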