Spaces:
Runtime error
Runtime error
[init] demo
Browse files- app.py +572 -0
- requirements.txt +20 -0
- rosetta/__pycache__/__init__.cpython-310.pyc +0 -0
- rosetta/baseline/__pycache__/multi_stage.cpython-310.pyc +0 -0
- rosetta/baseline/__pycache__/multi_stage.cpython-312.pyc +0 -0
- rosetta/baseline/__pycache__/two_stage_rosetta.cpython-310.pyc +0 -0
- rosetta/baseline/__pycache__/two_stage_rosetta.cpython-312.pyc +0 -0
- rosetta/baseline/multi_stage.py +992 -0
- rosetta/model/__pycache__/__init__.cpython-310.pyc +0 -0
- rosetta/model/__pycache__/aggregator.cpython-310.pyc +0 -0
- rosetta/model/__pycache__/aggregator.cpython-312.pyc +0 -0
- rosetta/model/__pycache__/aligner.cpython-310.pyc +0 -0
- rosetta/model/__pycache__/aligner.cpython-312.pyc +0 -0
- rosetta/model/__pycache__/all_in_one_projector.cpython-310.pyc +0 -0
- rosetta/model/__pycache__/all_in_one_projector.cpython-312.pyc +0 -0
- rosetta/model/__pycache__/cache.cpython-310.pyc +0 -0
- rosetta/model/__pycache__/oracle.cpython-310.pyc +0 -0
- rosetta/model/__pycache__/oracle.cpython-312.pyc +0 -0
- rosetta/model/__pycache__/projector.cpython-310.pyc +0 -0
- rosetta/model/__pycache__/projector.cpython-312.pyc +0 -0
- rosetta/model/__pycache__/sampling.cpython-310.pyc +0 -0
- rosetta/model/__pycache__/sampling.cpython-312.pyc +0 -0
- rosetta/model/__pycache__/wrapper.cpython-310.pyc +0 -0
- rosetta/model/__pycache__/wrapper.cpython-312.pyc +0 -0
- rosetta/model/ablation_projector.py +291 -0
- rosetta/model/aligner.py +639 -0
- rosetta/model/oracle.py +562 -0
- rosetta/model/projector.py +1194 -0
- rosetta/model/sampling.py +82 -0
- rosetta/model/wrapper.py +840 -0
- rosetta/train/__init__.py +20 -0
- rosetta/train/__pycache__/__init__.cpython-310.pyc +0 -0
- rosetta/train/__pycache__/__init__.cpython-312.pyc +0 -0
- rosetta/train/__pycache__/dataset_adapters.cpython-310.pyc +0 -0
- rosetta/train/__pycache__/dataset_adapters.cpython-312.pyc +0 -0
- rosetta/train/__pycache__/model_utils.cpython-310.pyc +0 -0
- rosetta/train/__pycache__/model_utils.cpython-312.pyc +0 -0
- rosetta/train/__pycache__/trainer.cpython-310.pyc +0 -0
- rosetta/train/__pycache__/trainer.cpython-312.pyc +0 -0
- rosetta/train/dataset_adapters.py +1648 -0
- rosetta/train/model_utils.py +142 -0
- rosetta/utils/__pycache__/evaluate.cpython-310.pyc +0 -0
- rosetta/utils/__pycache__/evaluate.cpython-312.pyc +0 -0
- rosetta/utils/__pycache__/multi_stage.cpython-310.pyc +0 -0
- rosetta/utils/__pycache__/multi_stage.cpython-312.pyc +0 -0
- rosetta/utils/__pycache__/registry.cpython-310.pyc +0 -0
- rosetta/utils/__pycache__/registry.cpython-312.pyc +0 -0
- rosetta/utils/evaluate.py +639 -0
- rosetta/utils/matheval.py +455 -0
- rosetta/utils/registry.py +279 -0
app.py
ADDED
|
@@ -0,0 +1,572 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio Side-by-Side Model Comparison Demo
|
| 3 |
+
|
| 4 |
+
This creates a web interface to compare three inference modes simultaneously:
|
| 5 |
+
1. Single: Regular HuggingFace model
|
| 6 |
+
2. T2T: Two-stage inference (shows context + answer)
|
| 7 |
+
3. C2C: Rosetta model with projectors
|
| 8 |
+
|
| 9 |
+
All models are loaded at startup and respond to the same input in parallel.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import os
|
| 13 |
+
import sys
|
| 14 |
+
import torch
|
| 15 |
+
import argparse
|
| 16 |
+
import gradio as gr
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from typing import Optional, Generator
|
| 19 |
+
from queue import Queue
|
| 20 |
+
from threading import Thread
|
| 21 |
+
|
| 22 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
|
| 23 |
+
from rosetta.utils.evaluate import load_rosetta_model, load_hf_model, set_default_chat_template
|
| 24 |
+
from rosetta.model.wrapper import RosettaModel
|
| 25 |
+
from rosetta.baseline.multi_stage import TwoStageInference
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class ModelManager:
|
| 29 |
+
"""Manages loading and inference for all three model types."""
|
| 30 |
+
|
| 31 |
+
def __init__(
|
| 32 |
+
self,
|
| 33 |
+
single_model_name: str = "Qwen/Qwen2.5-0.5B-Instruct",
|
| 34 |
+
t2t_context_model: str = "Qwen/Qwen3-0.6B",
|
| 35 |
+
t2t_answer_model: str = "Qwen/Qwen2.5-0.5B-Instruct",
|
| 36 |
+
c2c_checkpoint_path: str = "local/checkpoints/qwen3_0.6b+qwen2.5_0.5b_Fuser",
|
| 37 |
+
device: str = "cuda"
|
| 38 |
+
):
|
| 39 |
+
"""
|
| 40 |
+
Initialize ModelManager with model configurations.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
single_model_name: HuggingFace model name for single mode
|
| 44 |
+
t2t_context_model: Context model for T2T mode
|
| 45 |
+
t2t_answer_model: Answer model for T2T mode
|
| 46 |
+
c2c_checkpoint_path: Path to C2C checkpoint directory
|
| 47 |
+
device: Device to use (cuda, cpu, or auto)
|
| 48 |
+
"""
|
| 49 |
+
if device == "auto":
|
| 50 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 51 |
+
else:
|
| 52 |
+
self.device = torch.device(device)
|
| 53 |
+
print(f"Using device: {self.device}")
|
| 54 |
+
|
| 55 |
+
# Model configurations
|
| 56 |
+
self.single_model_name = single_model_name
|
| 57 |
+
self.t2t_context_model = t2t_context_model
|
| 58 |
+
self.t2t_answer_model = t2t_answer_model
|
| 59 |
+
self.c2c_checkpoint_path = c2c_checkpoint_path
|
| 60 |
+
|
| 61 |
+
# T2T prompt configurations
|
| 62 |
+
self.t2t_background_prompt = "Briefly describe the most useful background to answer the question:\n\n{question}"
|
| 63 |
+
self.t2t_answer_prompt = "Based on the background, answer the question:\n\n{question}" # Format for second round question
|
| 64 |
+
self.t2t_context_max_tokens = 512
|
| 65 |
+
self.t2t_answer_max_tokens = 512
|
| 66 |
+
|
| 67 |
+
# Generation configuration (shared across all models)
|
| 68 |
+
# To enable sampling: set use_sampling=True and adjust temperature/top_p/top_k
|
| 69 |
+
# Current mode: Greedy decoding (do_sample=False)
|
| 70 |
+
self.use_sampling = False # Set to True to enable sampling
|
| 71 |
+
self.temperature = 0.7 # Used when use_sampling=True
|
| 72 |
+
self.top_p = 0.8 # Used when use_sampling=True
|
| 73 |
+
self.top_k = 20 # Used when use_sampling=True
|
| 74 |
+
|
| 75 |
+
# Initialize models
|
| 76 |
+
self.single_model = None
|
| 77 |
+
self.single_tokenizer = None
|
| 78 |
+
self.t2t_model = None
|
| 79 |
+
self.c2c_model = None
|
| 80 |
+
self.c2c_tokenizer = None
|
| 81 |
+
|
| 82 |
+
# C2C model names (will be loaded from config)
|
| 83 |
+
self.c2c_base_model = None
|
| 84 |
+
self.c2c_teacher_model = None
|
| 85 |
+
|
| 86 |
+
print("=" * 60)
|
| 87 |
+
print("Initializing models... This may take a few minutes.")
|
| 88 |
+
print("=" * 60)
|
| 89 |
+
|
| 90 |
+
self._load_all_models()
|
| 91 |
+
|
| 92 |
+
def _load_single_model(self):
|
| 93 |
+
"""Load single HuggingFace model."""
|
| 94 |
+
print(f"\n[Single] Loading {self.single_model_name}...")
|
| 95 |
+
self.single_model, self.single_tokenizer = load_hf_model(
|
| 96 |
+
self.single_model_name, self.device
|
| 97 |
+
)
|
| 98 |
+
set_default_chat_template(self.single_tokenizer, self.single_model_name)
|
| 99 |
+
print("[Single] ✓ Model loaded")
|
| 100 |
+
|
| 101 |
+
def _load_t2t_model(self):
|
| 102 |
+
"""Load two-stage model."""
|
| 103 |
+
print(f"\n[T2T] Loading two-stage model...")
|
| 104 |
+
print(f" Context: {self.t2t_context_model}")
|
| 105 |
+
print(f" Answer: {self.t2t_answer_model}")
|
| 106 |
+
print(f" Background prompt: {self.t2t_background_prompt}")
|
| 107 |
+
print(f" Answer prompt: {self.t2t_answer_prompt}")
|
| 108 |
+
|
| 109 |
+
self.t2t_model = TwoStageInference(
|
| 110 |
+
context_model_path=self.t2t_context_model,
|
| 111 |
+
answer_model_path=self.t2t_answer_model,
|
| 112 |
+
device=str(self.device),
|
| 113 |
+
background_prompt=self.t2t_background_prompt
|
| 114 |
+
)
|
| 115 |
+
print("[T2T] ✓ Model loaded")
|
| 116 |
+
|
| 117 |
+
def _load_c2c_model(self):
|
| 118 |
+
"""Load Rosetta (C2C) model."""
|
| 119 |
+
print(f"\n[C2C] Loading Rosetta model from {self.c2c_checkpoint_path}...")
|
| 120 |
+
|
| 121 |
+
# Check if checkpoint exists
|
| 122 |
+
if not Path(self.c2c_checkpoint_path).exists():
|
| 123 |
+
raise FileNotFoundError(
|
| 124 |
+
f"C2C checkpoint not found: {self.c2c_checkpoint_path}\n"
|
| 125 |
+
"You can download the checkpoints automatically with:\n"
|
| 126 |
+
" from huggingface_hub import snapshot_download\n"
|
| 127 |
+
" snapshot_download(\n"
|
| 128 |
+
" repo_id='nics-efc/C2C_Fuser',\n"
|
| 129 |
+
" allow_patterns=['qwen3_0.6b+qwen2.5_0.5b_Fuser/*'],\n"
|
| 130 |
+
" local_dir='local/checkpoints'\n"
|
| 131 |
+
" )\n"
|
| 132 |
+
"Or see the project README for more details."
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
# Load config
|
| 136 |
+
import yaml
|
| 137 |
+
config_path = Path(self.c2c_checkpoint_path) / "config.json"
|
| 138 |
+
if not config_path.exists():
|
| 139 |
+
raise FileNotFoundError(f"Config file not found: {config_path}")
|
| 140 |
+
|
| 141 |
+
with open(config_path, "r") as f:
|
| 142 |
+
config = yaml.safe_load(f)
|
| 143 |
+
|
| 144 |
+
# Store model names from config
|
| 145 |
+
self.c2c_base_model = config["model"]["base_model"]
|
| 146 |
+
self.c2c_teacher_model = config["model"]["teacher_model"]
|
| 147 |
+
|
| 148 |
+
# Load Rosetta model
|
| 149 |
+
subfolder_dir = Path(self.c2c_checkpoint_path) / "final"
|
| 150 |
+
if not subfolder_dir.exists():
|
| 151 |
+
raise FileNotFoundError(f"Final checkpoint directory not found: {subfolder_dir}")
|
| 152 |
+
|
| 153 |
+
model_config = {
|
| 154 |
+
"model_name": "Rosetta",
|
| 155 |
+
"rosetta_config": {
|
| 156 |
+
"checkpoints_dir": str(subfolder_dir),
|
| 157 |
+
"base_model": self.c2c_base_model,
|
| 158 |
+
"teacher_model": self.c2c_teacher_model,
|
| 159 |
+
"is_do_alignment": config["model"].get("is_do_alignment", False),
|
| 160 |
+
"alignment_strategy": config["model"].get("alignment_strategy", "first")
|
| 161 |
+
}
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
eval_config = {"checkpoints_dir": str(subfolder_dir)}
|
| 165 |
+
|
| 166 |
+
self.c2c_model, self.c2c_tokenizer = load_rosetta_model(
|
| 167 |
+
model_config, eval_config, self.device
|
| 168 |
+
)
|
| 169 |
+
print("[C2C] ✓ Model loaded")
|
| 170 |
+
|
| 171 |
+
def _load_all_models(self):
|
| 172 |
+
"""Load all models sequentially."""
|
| 173 |
+
try:
|
| 174 |
+
self._load_single_model()
|
| 175 |
+
self._load_t2t_model()
|
| 176 |
+
self._load_c2c_model()
|
| 177 |
+
print("\n" + "=" * 60)
|
| 178 |
+
print("✓ All models loaded successfully!")
|
| 179 |
+
print("=" * 60 + "\n")
|
| 180 |
+
except Exception as e:
|
| 181 |
+
print(f"\n✗ Error loading models: {e}")
|
| 182 |
+
raise
|
| 183 |
+
|
| 184 |
+
def _get_generation_kwargs(self, max_new_tokens: int) -> dict:
|
| 185 |
+
"""
|
| 186 |
+
Get generation kwargs with consistent settings across all models.
|
| 187 |
+
|
| 188 |
+
Args:
|
| 189 |
+
max_new_tokens: Maximum number of new tokens to generate
|
| 190 |
+
|
| 191 |
+
Returns:
|
| 192 |
+
Dictionary of generation parameters
|
| 193 |
+
"""
|
| 194 |
+
kwargs = {
|
| 195 |
+
'max_new_tokens': max_new_tokens,
|
| 196 |
+
'do_sample': self.use_sampling
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
if self.use_sampling:
|
| 200 |
+
kwargs.update({
|
| 201 |
+
'temperature': self.temperature,
|
| 202 |
+
'top_p': self.top_p,
|
| 203 |
+
'top_k': self.top_k
|
| 204 |
+
})
|
| 205 |
+
|
| 206 |
+
return kwargs
|
| 207 |
+
|
| 208 |
+
def generate_single(self, user_input: str) -> Generator[str, None, None]:
|
| 209 |
+
"""Generate response from single model with streaming."""
|
| 210 |
+
messages = [{"role": "system", "content": ""}, {"role": "user", "content": user_input}]
|
| 211 |
+
text = self.single_tokenizer.apply_chat_template(
|
| 212 |
+
messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
|
| 213 |
+
)
|
| 214 |
+
inputs = self.single_tokenizer(text, return_tensors="pt").to(self.device)
|
| 215 |
+
|
| 216 |
+
# Setup streamer
|
| 217 |
+
streamer = TextIteratorStreamer(
|
| 218 |
+
self.single_tokenizer,
|
| 219 |
+
skip_prompt=True,
|
| 220 |
+
skip_special_tokens=True
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
# Generation parameters
|
| 224 |
+
generation_kwargs = {
|
| 225 |
+
'input_ids': inputs.input_ids,
|
| 226 |
+
'attention_mask': inputs.attention_mask,
|
| 227 |
+
'streamer': streamer,
|
| 228 |
+
**self._get_generation_kwargs(max_new_tokens=2048)
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
# Run generation in separate thread
|
| 232 |
+
thread = Thread(target=self.single_model.generate, kwargs=generation_kwargs)
|
| 233 |
+
thread.start()
|
| 234 |
+
|
| 235 |
+
# Stream tokens
|
| 236 |
+
generated_text = ""
|
| 237 |
+
for token in streamer:
|
| 238 |
+
generated_text += token
|
| 239 |
+
yield generated_text
|
| 240 |
+
|
| 241 |
+
def generate_t2t(self, user_input: str) -> Generator[tuple[str, str], None, None]:
|
| 242 |
+
"""Generate response from T2T model with streaming (returns context, answer)."""
|
| 243 |
+
# Stage 1: Context generation
|
| 244 |
+
context_streamer = TextIteratorStreamer(
|
| 245 |
+
self.t2t_model.context_tokenizer,
|
| 246 |
+
skip_prompt=True,
|
| 247 |
+
skip_special_tokens=True
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
prompt = self.t2t_background_prompt.format(question=user_input)
|
| 251 |
+
inputs = self.t2t_model.context_tokenizer.apply_chat_template(
|
| 252 |
+
[{"role": "user", "content": prompt}],
|
| 253 |
+
tokenize=True,
|
| 254 |
+
add_generation_prompt=True,
|
| 255 |
+
return_tensors="pt",
|
| 256 |
+
enable_thinking=False
|
| 257 |
+
).to(self.device)
|
| 258 |
+
|
| 259 |
+
generation_kwargs = {
|
| 260 |
+
'input_ids': inputs,
|
| 261 |
+
'streamer': context_streamer,
|
| 262 |
+
**self._get_generation_kwargs(max_new_tokens=self.t2t_context_max_tokens)
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
# Generate context in thread
|
| 266 |
+
thread = Thread(target=self.t2t_model.context_model.generate, kwargs=generation_kwargs)
|
| 267 |
+
thread.start()
|
| 268 |
+
|
| 269 |
+
# Stream context tokens
|
| 270 |
+
context_text = ""
|
| 271 |
+
for token in context_streamer:
|
| 272 |
+
context_text += token
|
| 273 |
+
yield context_text, ""
|
| 274 |
+
|
| 275 |
+
thread.join()
|
| 276 |
+
|
| 277 |
+
# Decode full context
|
| 278 |
+
with torch.inference_mode():
|
| 279 |
+
outputs = self.t2t_model.context_model.generate(
|
| 280 |
+
inputs, **self._get_generation_kwargs(max_new_tokens=self.t2t_context_max_tokens)
|
| 281 |
+
)
|
| 282 |
+
context = self.t2t_model.context_tokenizer.batch_decode(
|
| 283 |
+
outputs[:, inputs.shape[-1]:], skip_special_tokens=True
|
| 284 |
+
)[0]
|
| 285 |
+
|
| 286 |
+
# Stage 2: Answer generation
|
| 287 |
+
answer_streamer = TextIteratorStreamer(
|
| 288 |
+
self.t2t_model.answer_tokenizer,
|
| 289 |
+
skip_prompt=True,
|
| 290 |
+
skip_special_tokens=True
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
# Format the second round question
|
| 294 |
+
answer_question = self.t2t_answer_prompt.format(question=user_input)
|
| 295 |
+
|
| 296 |
+
inputs = self.t2t_model.answer_tokenizer.apply_chat_template(
|
| 297 |
+
[
|
| 298 |
+
{"role": "user", "content": prompt},
|
| 299 |
+
{"role": "assistant", "content": context},
|
| 300 |
+
{"role": "user", "content": answer_question}
|
| 301 |
+
],
|
| 302 |
+
tokenize=True,
|
| 303 |
+
add_generation_prompt=True,
|
| 304 |
+
return_tensors="pt",
|
| 305 |
+
enable_thinking=False
|
| 306 |
+
).to(self.device)
|
| 307 |
+
|
| 308 |
+
generation_kwargs = {
|
| 309 |
+
'input_ids': inputs,
|
| 310 |
+
'streamer': answer_streamer,
|
| 311 |
+
**self._get_generation_kwargs(max_new_tokens=self.t2t_answer_max_tokens)
|
| 312 |
+
}
|
| 313 |
+
|
| 314 |
+
# Generate answer in thread
|
| 315 |
+
thread = Thread(target=self.t2t_model.answer_model.generate, kwargs=generation_kwargs)
|
| 316 |
+
thread.start()
|
| 317 |
+
|
| 318 |
+
# Stream answer tokens
|
| 319 |
+
answer_text = ""
|
| 320 |
+
for token in answer_streamer:
|
| 321 |
+
answer_text += token
|
| 322 |
+
yield context_text, answer_text
|
| 323 |
+
|
| 324 |
+
def generate_c2c(self, user_input: str) -> Generator[str, None, None]:
|
| 325 |
+
"""Generate response from C2C model with streaming."""
|
| 326 |
+
messages = [{"role": "system", "content": ""}, {"role": "user", "content": user_input}]
|
| 327 |
+
text = self.c2c_tokenizer.apply_chat_template(
|
| 328 |
+
messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
|
| 329 |
+
)
|
| 330 |
+
inputs = self.c2c_tokenizer(text, return_tensors="pt").to(self.device)
|
| 331 |
+
|
| 332 |
+
# Setup streamer
|
| 333 |
+
streamer = TextIteratorStreamer(
|
| 334 |
+
self.c2c_tokenizer,
|
| 335 |
+
skip_prompt=True,
|
| 336 |
+
skip_special_tokens=True
|
| 337 |
+
)
|
| 338 |
+
|
| 339 |
+
# Prepare C2C-specific inputs
|
| 340 |
+
full_length = inputs.input_ids.shape[1]
|
| 341 |
+
instruction_index = torch.tensor([1, 0], dtype=torch.long).repeat(
|
| 342 |
+
full_length - 1, 1
|
| 343 |
+
).unsqueeze(0).to(self.device)
|
| 344 |
+
label_index = torch.tensor([-1, 0], dtype=torch.long).repeat(
|
| 345 |
+
1, 1
|
| 346 |
+
).unsqueeze(0).to(self.device)
|
| 347 |
+
position_ids = inputs.attention_mask.long().cumsum(-1) - 1 if inputs.attention_mask is not None else \
|
| 348 |
+
torch.arange(full_length, dtype=torch.long).unsqueeze(0).to(self.device)
|
| 349 |
+
|
| 350 |
+
# Generation parameters
|
| 351 |
+
generation_kwargs = {
|
| 352 |
+
'kv_cache_index': [instruction_index, label_index],
|
| 353 |
+
'input_ids': inputs.input_ids,
|
| 354 |
+
'attention_mask': inputs.attention_mask,
|
| 355 |
+
'position_ids': position_ids,
|
| 356 |
+
'streamer': streamer,
|
| 357 |
+
**self._get_generation_kwargs(max_new_tokens=2048)
|
| 358 |
+
}
|
| 359 |
+
|
| 360 |
+
# Run generation in separate thread
|
| 361 |
+
thread = Thread(target=self.c2c_model.generate, kwargs=generation_kwargs)
|
| 362 |
+
thread.start()
|
| 363 |
+
|
| 364 |
+
# Stream tokens
|
| 365 |
+
generated_text = ""
|
| 366 |
+
for token in streamer:
|
| 367 |
+
generated_text += token
|
| 368 |
+
yield generated_text
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
def create_demo(model_manager: ModelManager):
|
| 372 |
+
"""Create Gradio interface."""
|
| 373 |
+
|
| 374 |
+
# Preset example questions
|
| 375 |
+
EXAMPLE_QUESTIONS = {
|
| 376 |
+
"example1": """Instead of asking why the act of destroying the environment might be immoral, Hill wants to ask ...
|
| 377 |
+
|
| 378 |
+
A. Why the act of destroying nature might be immoral.
|
| 379 |
+
B. Why people who destroy the environment might be bad people.
|
| 380 |
+
C. How the decision to preserve the environment benefits the environment.
|
| 381 |
+
D. Whether plants have interests.""",
|
| 382 |
+
"example2": "Which company launched the Gemini 1.5 Pro model in early 2024?"
|
| 383 |
+
}
|
| 384 |
+
|
| 385 |
+
def respond(user_input: str):
|
| 386 |
+
"""Main response function that yields updates for all three models."""
|
| 387 |
+
if not user_input.strip():
|
| 388 |
+
yield "", "", "", ""
|
| 389 |
+
|
| 390 |
+
# Generators for each model
|
| 391 |
+
single_gen = model_manager.generate_single(user_input)
|
| 392 |
+
t2t_gen = model_manager.generate_t2t(user_input)
|
| 393 |
+
c2c_gen = model_manager.generate_c2c(user_input)
|
| 394 |
+
|
| 395 |
+
single_done = False
|
| 396 |
+
t2t_done = False
|
| 397 |
+
c2c_done = False
|
| 398 |
+
|
| 399 |
+
single_text = ""
|
| 400 |
+
t2t_context = ""
|
| 401 |
+
t2t_answer = ""
|
| 402 |
+
c2c_text = ""
|
| 403 |
+
|
| 404 |
+
# Stream from all three models
|
| 405 |
+
while not (single_done and t2t_done and c2c_done):
|
| 406 |
+
# Update single
|
| 407 |
+
if not single_done:
|
| 408 |
+
try:
|
| 409 |
+
single_text = next(single_gen)
|
| 410 |
+
except StopIteration:
|
| 411 |
+
single_done = True
|
| 412 |
+
|
| 413 |
+
# Update T2T
|
| 414 |
+
if not t2t_done:
|
| 415 |
+
try:
|
| 416 |
+
t2t_context, t2t_answer = next(t2t_gen)
|
| 417 |
+
except StopIteration:
|
| 418 |
+
t2t_done = True
|
| 419 |
+
|
| 420 |
+
# Update C2C
|
| 421 |
+
if not c2c_done:
|
| 422 |
+
try:
|
| 423 |
+
c2c_text = next(c2c_gen)
|
| 424 |
+
except StopIteration:
|
| 425 |
+
c2c_done = True
|
| 426 |
+
|
| 427 |
+
# Yield current state
|
| 428 |
+
yield single_text, t2t_context, t2t_answer, c2c_text
|
| 429 |
+
|
| 430 |
+
# Create Gradio interface
|
| 431 |
+
with gr.Blocks(title="C2C Demo", theme=gr.themes.Base()) as demo:
|
| 432 |
+
# Header with logo
|
| 433 |
+
with gr.Row():
|
| 434 |
+
with gr.Column(scale=1, min_width=100):
|
| 435 |
+
gr.Image("resource/logo.png", show_label=False, show_download_button=False, container=False, height=80)
|
| 436 |
+
with gr.Column(scale=5):
|
| 437 |
+
gr.Markdown("# Cache-to-Cache Communication Demo")
|
| 438 |
+
gr.Markdown("Compare three inference modes side-by-side: **Single** | **Text-to-Text Communication** | **Cache-to-Cache Communication**")
|
| 439 |
+
|
| 440 |
+
gr.Markdown("---")
|
| 441 |
+
|
| 442 |
+
# Input section
|
| 443 |
+
gr.Markdown("## Question")
|
| 444 |
+
|
| 445 |
+
# Preset question examples
|
| 446 |
+
gr.Markdown("Example Questions:")
|
| 447 |
+
with gr.Row():
|
| 448 |
+
example1_btn = gr.Button("📝 Example 1: Philosophy", size="sm")
|
| 449 |
+
example2_btn = gr.Button("📝 Example 2: Knowledge Cutoff", size="sm")
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
with gr.Row():
|
| 453 |
+
user_input = gr.Textbox(
|
| 454 |
+
label="",
|
| 455 |
+
placeholder="Type your question here...",
|
| 456 |
+
lines=2,
|
| 457 |
+
scale=4,
|
| 458 |
+
show_label=False
|
| 459 |
+
)
|
| 460 |
+
|
| 461 |
+
with gr.Row():
|
| 462 |
+
submit_btn = gr.Button("🚀 Submit", variant="primary", scale=1)
|
| 463 |
+
clear_btn = gr.Button("🗑️ Clear", scale=1)
|
| 464 |
+
|
| 465 |
+
gr.Markdown("---")
|
| 466 |
+
|
| 467 |
+
# Output section - three columns
|
| 468 |
+
gr.Markdown("## Responses")
|
| 469 |
+
with gr.Row():
|
| 470 |
+
# Single column
|
| 471 |
+
with gr.Column():
|
| 472 |
+
gr.Markdown("### Single Model")
|
| 473 |
+
gr.Markdown(f"*{model_manager.single_model_name}*")
|
| 474 |
+
single_output = gr.Textbox(
|
| 475 |
+
label="",
|
| 476 |
+
lines=18,
|
| 477 |
+
max_lines=30,
|
| 478 |
+
interactive=False,
|
| 479 |
+
show_label=False
|
| 480 |
+
)
|
| 481 |
+
|
| 482 |
+
# T2T column (with two sub-boxes)
|
| 483 |
+
with gr.Column():
|
| 484 |
+
gr.Markdown("### Text-to-Text Communication")
|
| 485 |
+
gr.Markdown(f"*{model_manager.t2t_context_model} → {model_manager.t2t_answer_model}*")
|
| 486 |
+
t2t_context_output = gr.Textbox(
|
| 487 |
+
label="📝 Context",
|
| 488 |
+
lines=6,
|
| 489 |
+
max_lines=12,
|
| 490 |
+
interactive=False
|
| 491 |
+
)
|
| 492 |
+
t2t_answer_output = gr.Textbox(
|
| 493 |
+
label="💬 Answer",
|
| 494 |
+
lines=7,
|
| 495 |
+
max_lines=14,
|
| 496 |
+
interactive=False
|
| 497 |
+
)
|
| 498 |
+
|
| 499 |
+
# C2C column
|
| 500 |
+
with gr.Column():
|
| 501 |
+
gr.Markdown("### Cache-to-Cache Communication")
|
| 502 |
+
gr.Markdown(f"*{model_manager.c2c_base_model} → {model_manager.c2c_teacher_model}*")
|
| 503 |
+
c2c_output = gr.Textbox(
|
| 504 |
+
label="",
|
| 505 |
+
lines=18,
|
| 506 |
+
max_lines=30,
|
| 507 |
+
interactive=False,
|
| 508 |
+
show_label=False
|
| 509 |
+
)
|
| 510 |
+
|
| 511 |
+
# Event handlers
|
| 512 |
+
submit_btn.click(
|
| 513 |
+
fn=respond,
|
| 514 |
+
inputs=[user_input],
|
| 515 |
+
outputs=[single_output, t2t_context_output, t2t_answer_output, c2c_output]
|
| 516 |
+
)
|
| 517 |
+
|
| 518 |
+
user_input.submit(
|
| 519 |
+
fn=respond,
|
| 520 |
+
inputs=[user_input],
|
| 521 |
+
outputs=[single_output, t2t_context_output, t2t_answer_output, c2c_output]
|
| 522 |
+
)
|
| 523 |
+
|
| 524 |
+
clear_btn.click(
|
| 525 |
+
fn=lambda: ("", "", "", "", ""),
|
| 526 |
+
inputs=None,
|
| 527 |
+
outputs=[user_input, single_output, t2t_context_output, t2t_answer_output, c2c_output]
|
| 528 |
+
)
|
| 529 |
+
|
| 530 |
+
# Example question handlers
|
| 531 |
+
example1_btn.click(
|
| 532 |
+
fn=lambda: EXAMPLE_QUESTIONS["example1"],
|
| 533 |
+
inputs=None,
|
| 534 |
+
outputs=[user_input]
|
| 535 |
+
)
|
| 536 |
+
|
| 537 |
+
example2_btn.click(
|
| 538 |
+
fn=lambda: EXAMPLE_QUESTIONS["example2"],
|
| 539 |
+
inputs=None,
|
| 540 |
+
outputs=[user_input]
|
| 541 |
+
)
|
| 542 |
+
|
| 543 |
+
return demo
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
def main():
|
| 547 |
+
"""Main entry point."""
|
| 548 |
+
print("=" * 60)
|
| 549 |
+
print("Model Comparison Demo - Gradio Interface")
|
| 550 |
+
print("=" * 60)
|
| 551 |
+
|
| 552 |
+
# Initialize models
|
| 553 |
+
model_manager = ModelManager()
|
| 554 |
+
|
| 555 |
+
# Create and launch demo
|
| 556 |
+
demo = create_demo(model_manager)
|
| 557 |
+
|
| 558 |
+
print("\n" + "=" * 60)
|
| 559 |
+
print("🚀 Launching Gradio interface...")
|
| 560 |
+
print("=" * 60)
|
| 561 |
+
|
| 562 |
+
demo.launch(
|
| 563 |
+
server_name="0.0.0.0",
|
| 564 |
+
server_port=7860,
|
| 565 |
+
share=False,
|
| 566 |
+
show_error=True
|
| 567 |
+
)
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
if __name__ == "__main__":
|
| 571 |
+
main()
|
| 572 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Core dependencies (from pyproject.toml)
|
| 2 |
+
torch==2.6.0
|
| 3 |
+
transformers==4.52.4
|
| 4 |
+
|
| 5 |
+
# Gradio web interface
|
| 6 |
+
gradio==5.9.1
|
| 7 |
+
|
| 8 |
+
# HuggingFace Hub for checkpoint downloads
|
| 9 |
+
huggingface-hub>=0.26.0
|
| 10 |
+
|
| 11 |
+
# Configuration file parsing
|
| 12 |
+
pyyaml>=6.0
|
| 13 |
+
|
| 14 |
+
# Additional dependencies for rosetta package
|
| 15 |
+
datasets>=2.0
|
| 16 |
+
accelerate>=0.20
|
| 17 |
+
|
| 18 |
+
# Utilities
|
| 19 |
+
einops>=0.8
|
| 20 |
+
|
rosetta/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (530 Bytes). View file
|
|
|
rosetta/baseline/__pycache__/multi_stage.cpython-310.pyc
ADDED
|
Binary file (24.3 kB). View file
|
|
|
rosetta/baseline/__pycache__/multi_stage.cpython-312.pyc
ADDED
|
Binary file (34.7 kB). View file
|
|
|
rosetta/baseline/__pycache__/two_stage_rosetta.cpython-310.pyc
ADDED
|
Binary file (9.34 kB). View file
|
|
|
rosetta/baseline/__pycache__/two_stage_rosetta.cpython-312.pyc
ADDED
|
Binary file (13.2 kB). View file
|
|
|
rosetta/baseline/multi_stage.py
ADDED
|
@@ -0,0 +1,992 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Multi-stage evaluation utilities for VLM+LLM and LLM+LLM pipelines.
|
| 3 |
+
|
| 4 |
+
This module provides utilities for multi-stage evaluation where:
|
| 5 |
+
1. VLM describes/analyzes images + LLM performs reasoning
|
| 6 |
+
2. LLM provides background context + LLM performs reasoning
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from typing import Dict, Optional, Any
|
| 10 |
+
import torch
|
| 11 |
+
from transformers import (
|
| 12 |
+
# Qwen2_5_VLForConditionalGeneration,
|
| 13 |
+
AutoProcessor,
|
| 14 |
+
AutoTokenizer,
|
| 15 |
+
AutoModelForCausalLM,
|
| 16 |
+
)
|
| 17 |
+
from rosetta.utils.evaluate import set_default_chat_template, apply_generation_config
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
from qwen_vl_utils import process_vision_info
|
| 21 |
+
from transformers import Qwen2_5_VLForConditionalGeneration
|
| 22 |
+
except ImportError:
|
| 23 |
+
print("Please install qwen-vl-utils to use VLM models")
|
| 24 |
+
|
| 25 |
+
class TwoStageInference:
|
| 26 |
+
"""Two-stage LLM+LLM inference pipeline for question answering."""
|
| 27 |
+
|
| 28 |
+
def __init__(
|
| 29 |
+
self,
|
| 30 |
+
context_model_path: str,
|
| 31 |
+
answer_model_path: str,
|
| 32 |
+
device: str = "cuda",
|
| 33 |
+
max_new_tokens: int = 1024,
|
| 34 |
+
background_prompt: str = "Briefly describe the most useful background to solve the problem:\n\n{question}",
|
| 35 |
+
generation_config: Optional[Dict[str, Any]] = None
|
| 36 |
+
):
|
| 37 |
+
"""
|
| 38 |
+
Initialize two-stage LLM pipeline.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
context_model_path: Path to context-providing LLM
|
| 42 |
+
answer_model_path: Path to answer-generating LLM
|
| 43 |
+
device: Device to use
|
| 44 |
+
max_new_tokens: Maximum number of new tokens to generate
|
| 45 |
+
background_prompt: Prompt template for background generation
|
| 46 |
+
generation_config: Optional generation configuration to apply to models
|
| 47 |
+
"""
|
| 48 |
+
self.device = device
|
| 49 |
+
self.max_new_tokens = max_new_tokens
|
| 50 |
+
self.background_prompt = background_prompt
|
| 51 |
+
self.generation_config = generation_config or {}
|
| 52 |
+
self._load_models(context_model_path, answer_model_path)
|
| 53 |
+
|
| 54 |
+
def _load_models(self, context_path: str, answer_path: str):
|
| 55 |
+
"""Load both LLM models."""
|
| 56 |
+
# Load context LLM
|
| 57 |
+
self.context_tokenizer = AutoTokenizer.from_pretrained(context_path)
|
| 58 |
+
# for gemma, set sliding_window=4096
|
| 59 |
+
if context_path == "google/gemma-3-1b-it":
|
| 60 |
+
torch._dynamo.config.cache_size_limit = 64
|
| 61 |
+
self.context_model = AutoModelForCausalLM.from_pretrained(
|
| 62 |
+
context_path, torch_dtype=torch.bfloat16, device_map={"": self.device}, sliding_window=4096
|
| 63 |
+
)
|
| 64 |
+
else:
|
| 65 |
+
self.context_model = AutoModelForCausalLM.from_pretrained(
|
| 66 |
+
context_path, torch_dtype=torch.bfloat16, device_map={"": self.device}
|
| 67 |
+
)
|
| 68 |
+
# Apply generation config to context model
|
| 69 |
+
apply_generation_config(self.context_model, self.generation_config)
|
| 70 |
+
|
| 71 |
+
# Load answer LLM
|
| 72 |
+
self.answer_tokenizer = AutoTokenizer.from_pretrained(answer_path)
|
| 73 |
+
self.answer_model = AutoModelForCausalLM.from_pretrained(
|
| 74 |
+
answer_path, torch_dtype=torch.bfloat16, device_map={"": self.device}
|
| 75 |
+
)
|
| 76 |
+
# Apply generation config to answer model
|
| 77 |
+
apply_generation_config(self.answer_model, self.generation_config)
|
| 78 |
+
|
| 79 |
+
def get_background_context(
|
| 80 |
+
self,
|
| 81 |
+
question: str,
|
| 82 |
+
max_new_tokens: Optional[int] = None
|
| 83 |
+
) -> str:
|
| 84 |
+
"""
|
| 85 |
+
Get background context from the first LLM.
|
| 86 |
+
|
| 87 |
+
Args:
|
| 88 |
+
question: Question text (without options)
|
| 89 |
+
max_new_tokens: Max tokens to generate (uses instance default if None)
|
| 90 |
+
|
| 91 |
+
Returns:
|
| 92 |
+
Background context
|
| 93 |
+
"""
|
| 94 |
+
prompt = self.background_prompt.format(question=question)
|
| 95 |
+
messages = [{"role": "user", "content": prompt}]
|
| 96 |
+
|
| 97 |
+
template_kwargs = {'enable_thinking': False}
|
| 98 |
+
|
| 99 |
+
inputs = self.context_tokenizer.apply_chat_template(
|
| 100 |
+
messages,
|
| 101 |
+
tokenize=True,
|
| 102 |
+
add_generation_prompt=True,
|
| 103 |
+
return_tensors="pt",
|
| 104 |
+
**template_kwargs
|
| 105 |
+
)
|
| 106 |
+
inputs = inputs.to(self.device)
|
| 107 |
+
|
| 108 |
+
if max_new_tokens is None:
|
| 109 |
+
max_new_tokens = self.max_new_tokens
|
| 110 |
+
|
| 111 |
+
with torch.inference_mode():
|
| 112 |
+
outputs = self.context_model.generate(
|
| 113 |
+
inputs,
|
| 114 |
+
max_new_tokens=max_new_tokens,
|
| 115 |
+
do_sample=False
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
generated_ids = outputs[:, inputs.shape[-1]:]
|
| 119 |
+
context = self.context_tokenizer.batch_decode(
|
| 120 |
+
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
| 121 |
+
)[0]
|
| 122 |
+
|
| 123 |
+
return context
|
| 124 |
+
|
| 125 |
+
def answer_with_context(
|
| 126 |
+
self,
|
| 127 |
+
question: str,
|
| 128 |
+
context: str,
|
| 129 |
+
max_new_tokens: Optional[int] = None,
|
| 130 |
+
original_question: Optional[str] = None
|
| 131 |
+
) -> str:
|
| 132 |
+
"""
|
| 133 |
+
Answer question using the second LLM with context.
|
| 134 |
+
|
| 135 |
+
Args:
|
| 136 |
+
question: Full question with options and proper template
|
| 137 |
+
context: Background context from first LLM
|
| 138 |
+
max_new_tokens: Max tokens to generate (uses instance default if None)
|
| 139 |
+
original_question: Original question asked to first LLM (for conversation format)
|
| 140 |
+
|
| 141 |
+
Returns:
|
| 142 |
+
Generated answer
|
| 143 |
+
"""
|
| 144 |
+
# Use conversation format: user asks for background, assistant provides it, user asks main question
|
| 145 |
+
if original_question:
|
| 146 |
+
messages = [
|
| 147 |
+
{"role": "user", "content": self.background_prompt.format(question=original_question)},
|
| 148 |
+
{"role": "assistant", "content": context},
|
| 149 |
+
{"role": "user", "content": question}
|
| 150 |
+
]
|
| 151 |
+
else:
|
| 152 |
+
# Fallback to simple format
|
| 153 |
+
messages = [{"role": "user", "content": f"Background context: {context}\n\n{question}"}]
|
| 154 |
+
|
| 155 |
+
template_kwargs = {'enable_thinking': False}
|
| 156 |
+
|
| 157 |
+
inputs = self.answer_tokenizer.apply_chat_template(
|
| 158 |
+
messages,
|
| 159 |
+
tokenize=True,
|
| 160 |
+
add_generation_prompt=True,
|
| 161 |
+
return_tensors="pt",
|
| 162 |
+
**template_kwargs
|
| 163 |
+
)
|
| 164 |
+
inputs = inputs.to(self.device)
|
| 165 |
+
|
| 166 |
+
if max_new_tokens is None:
|
| 167 |
+
max_new_tokens = self.max_new_tokens
|
| 168 |
+
|
| 169 |
+
with torch.inference_mode():
|
| 170 |
+
outputs = self.answer_model.generate(
|
| 171 |
+
inputs,
|
| 172 |
+
max_new_tokens=max_new_tokens,
|
| 173 |
+
do_sample=False
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
generated_ids = outputs[:, inputs.shape[-1]:]
|
| 177 |
+
answer = self.answer_tokenizer.batch_decode(
|
| 178 |
+
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
| 179 |
+
)[0]
|
| 180 |
+
|
| 181 |
+
return answer
|
| 182 |
+
|
| 183 |
+
def forward_with_context(
|
| 184 |
+
self,
|
| 185 |
+
question: str,
|
| 186 |
+
context: str,
|
| 187 |
+
original_question: Optional[str] = None,
|
| 188 |
+
response_text: Optional[str] = None,
|
| 189 |
+
**forward_kwargs
|
| 190 |
+
) -> Any:
|
| 191 |
+
"""
|
| 192 |
+
Run a forward pass on the answer model using provided context (logits mode).
|
| 193 |
+
|
| 194 |
+
Args:
|
| 195 |
+
question: Full question with options and proper template
|
| 196 |
+
context: Background context from first LLM
|
| 197 |
+
original_question: Original question asked to first LLM (for conversation format)
|
| 198 |
+
response_text: Optional text to append after the chat template to steer next-token logits
|
| 199 |
+
**forward_kwargs: Extra kwargs forwarded to the model's forward
|
| 200 |
+
|
| 201 |
+
Returns:
|
| 202 |
+
Model outputs from the forward pass (e.g., logits)
|
| 203 |
+
"""
|
| 204 |
+
# Use conversation format: user asks for background, assistant provides it, user asks main question
|
| 205 |
+
if original_question:
|
| 206 |
+
messages = [
|
| 207 |
+
{"role": "user", "content": self.background_prompt.format(question=original_question)},
|
| 208 |
+
{"role": "assistant", "content": context},
|
| 209 |
+
{"role": "user", "content": question}
|
| 210 |
+
]
|
| 211 |
+
else:
|
| 212 |
+
# Fallback to simple format
|
| 213 |
+
messages = [{"role": "user", "content": f"Background context: {context}\n\n{question}"}]
|
| 214 |
+
|
| 215 |
+
template_kwargs = {'enable_thinking': False}
|
| 216 |
+
|
| 217 |
+
# Build model inputs; if response_text is provided, append it to steer next-token prediction
|
| 218 |
+
if response_text is not None:
|
| 219 |
+
# Build raw text then append response_text
|
| 220 |
+
text = self.answer_tokenizer.apply_chat_template(
|
| 221 |
+
messages,
|
| 222 |
+
tokenize=False,
|
| 223 |
+
add_generation_prompt=True,
|
| 224 |
+
**template_kwargs
|
| 225 |
+
)
|
| 226 |
+
text = text + response_text
|
| 227 |
+
tokenized = self.answer_tokenizer(text, return_tensors="pt")
|
| 228 |
+
else:
|
| 229 |
+
# Directly build tensors with generation prompt (predict next assistant token)
|
| 230 |
+
tokenized = self.answer_tokenizer.apply_chat_template(
|
| 231 |
+
messages,
|
| 232 |
+
tokenize=True,
|
| 233 |
+
add_generation_prompt=True,
|
| 234 |
+
return_tensors="pt",
|
| 235 |
+
**template_kwargs
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
inputs = {k: v.to(self.device) for k, v in tokenized.items()}
|
| 239 |
+
|
| 240 |
+
with torch.inference_mode():
|
| 241 |
+
outputs = self.answer_model(**inputs, **forward_kwargs)
|
| 242 |
+
|
| 243 |
+
return outputs
|
| 244 |
+
|
| 245 |
+
def forward(
|
| 246 |
+
self,
|
| 247 |
+
question_without_options: str,
|
| 248 |
+
question_with_options: str,
|
| 249 |
+
response_text: Optional[str] = None,
|
| 250 |
+
**forward_kwargs
|
| 251 |
+
) -> Any:
|
| 252 |
+
"""
|
| 253 |
+
Two-stage forward pass (logits mode):
|
| 254 |
+
1) Generate background context with the context model
|
| 255 |
+
2) Run a forward pass on the answer model conditioned on that context
|
| 256 |
+
|
| 257 |
+
Args:
|
| 258 |
+
question_without_options: Question text without multiple choice options
|
| 259 |
+
question_with_options: Full question with options and proper template
|
| 260 |
+
response_text: Optional text appended after the chat template to steer next-token logits
|
| 261 |
+
**forward_kwargs: Extra kwargs forwarded to the model's forward
|
| 262 |
+
|
| 263 |
+
Returns:
|
| 264 |
+
Model outputs from the forward pass (e.g., logits)
|
| 265 |
+
"""
|
| 266 |
+
context = self.get_background_context(question_without_options)
|
| 267 |
+
return self.forward_with_context(
|
| 268 |
+
question=question_with_options,
|
| 269 |
+
context=context,
|
| 270 |
+
original_question=question_without_options,
|
| 271 |
+
response_text=response_text,
|
| 272 |
+
**forward_kwargs
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
def logits_with_context(
|
| 276 |
+
self,
|
| 277 |
+
question_without_options: str,
|
| 278 |
+
question_with_options: str,
|
| 279 |
+
response_text: Optional[str] = None,
|
| 280 |
+
**forward_kwargs
|
| 281 |
+
) -> Any:
|
| 282 |
+
"""
|
| 283 |
+
Two-stage logits helper that also returns the generated background context
|
| 284 |
+
for logging as CoT.
|
| 285 |
+
|
| 286 |
+
Returns:
|
| 287 |
+
(outputs, context)
|
| 288 |
+
"""
|
| 289 |
+
context = self.get_background_context(question_without_options)
|
| 290 |
+
outputs = self.forward_with_context(
|
| 291 |
+
question=question_with_options,
|
| 292 |
+
context=context,
|
| 293 |
+
original_question=question_without_options,
|
| 294 |
+
response_text=response_text,
|
| 295 |
+
**forward_kwargs
|
| 296 |
+
)
|
| 297 |
+
return outputs, context
|
| 298 |
+
|
| 299 |
+
def generate(
|
| 300 |
+
self,
|
| 301 |
+
question_without_options: str,
|
| 302 |
+
question_with_options: str,
|
| 303 |
+
communication_max_new_tokens: Optional[int] = None,
|
| 304 |
+
response_max_new_tokens: Optional[int] = None,
|
| 305 |
+
**kwargs
|
| 306 |
+
) -> str:
|
| 307 |
+
"""
|
| 308 |
+
Generate answer using two-stage processing.
|
| 309 |
+
|
| 310 |
+
Args:
|
| 311 |
+
question_without_options: Question text without multiple choice options
|
| 312 |
+
question_with_options: Full question with options and proper template
|
| 313 |
+
communication_max_new_tokens: Maximum tokens to generate for the background context
|
| 314 |
+
response_max_new_tokens: Maximum tokens to generate for the answer
|
| 315 |
+
**kwargs: Additional generation parameters (ignored for compatibility)
|
| 316 |
+
|
| 317 |
+
Returns:
|
| 318 |
+
Generated answer string
|
| 319 |
+
"""
|
| 320 |
+
# Stage 1: Get background context
|
| 321 |
+
context = self.get_background_context(question_without_options, communication_max_new_tokens)
|
| 322 |
+
|
| 323 |
+
# Stage 2: Answer question with context
|
| 324 |
+
answer = self.answer_with_context(
|
| 325 |
+
question=question_with_options,
|
| 326 |
+
context=context,
|
| 327 |
+
original_question=question_without_options,
|
| 328 |
+
max_new_tokens=response_max_new_tokens
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
return answer
|
| 332 |
+
|
| 333 |
+
def process(
|
| 334 |
+
self,
|
| 335 |
+
question_without_options: str,
|
| 336 |
+
question_with_options: str
|
| 337 |
+
) -> Dict[str, str]:
|
| 338 |
+
"""
|
| 339 |
+
Full two-stage processing (legacy method for backward compatibility).
|
| 340 |
+
|
| 341 |
+
Args:
|
| 342 |
+
question_without_options: Question text without multiple choice options
|
| 343 |
+
question_with_options: Full question with options and proper template
|
| 344 |
+
|
| 345 |
+
Returns:
|
| 346 |
+
Dictionary with context and answer
|
| 347 |
+
"""
|
| 348 |
+
# Stage 1: Get background context
|
| 349 |
+
context = self.get_background_context(question_without_options)
|
| 350 |
+
|
| 351 |
+
# Stage 2: Answer question with context
|
| 352 |
+
answer = self.answer_with_context(
|
| 353 |
+
question=question_with_options,
|
| 354 |
+
context=context,
|
| 355 |
+
original_question=question_without_options
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
return {
|
| 359 |
+
"context": context,
|
| 360 |
+
"answer": answer
|
| 361 |
+
}
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
class TwoStageRosetta(TwoStageInference):
|
| 365 |
+
"""Two-stage LLM+Rosetta inference pipeline for question answering."""
|
| 366 |
+
|
| 367 |
+
def __init__(
|
| 368 |
+
self,
|
| 369 |
+
context_model_path: str,
|
| 370 |
+
rosetta_checkpoint_dir: str,
|
| 371 |
+
rosetta_subfolder: str = "final",
|
| 372 |
+
device: str = "cuda",
|
| 373 |
+
max_new_tokens: int = 1024,
|
| 374 |
+
background_prompt: str = "Briefly describe the most useful background to solve the problem:\n\n{question}",
|
| 375 |
+
generation_config: Optional[Dict[str, Any]] = None
|
| 376 |
+
):
|
| 377 |
+
"""
|
| 378 |
+
Initialize two-stage pipeline with Rosetta as second model.
|
| 379 |
+
|
| 380 |
+
Args:
|
| 381 |
+
context_model_path: Path to context-providing LLM
|
| 382 |
+
rosetta_checkpoint_dir: Path to Rosetta checkpoint directory
|
| 383 |
+
rosetta_subfolder: Subfolder name in checkpoint directory (e.g., 'final', 'checkpoint-1000')
|
| 384 |
+
device: Device to use
|
| 385 |
+
max_new_tokens: Maximum number of new tokens to generate
|
| 386 |
+
background_prompt: Prompt template for background generation
|
| 387 |
+
generation_config: Optional generation configuration to apply to models
|
| 388 |
+
"""
|
| 389 |
+
# Initialize parent class with dummy answer model path
|
| 390 |
+
# We'll override the answer model loading
|
| 391 |
+
super().__init__(
|
| 392 |
+
context_model_path=context_model_path,
|
| 393 |
+
answer_model_path=None, # Will be overridden
|
| 394 |
+
device=device,
|
| 395 |
+
max_new_tokens=max_new_tokens,
|
| 396 |
+
background_prompt=background_prompt,
|
| 397 |
+
generation_config=generation_config
|
| 398 |
+
)
|
| 399 |
+
|
| 400 |
+
self.rosetta_checkpoint_dir = rosetta_checkpoint_dir
|
| 401 |
+
self.rosetta_subfolder = rosetta_subfolder
|
| 402 |
+
self._load_rosetta_model()
|
| 403 |
+
|
| 404 |
+
def _load_models(self, context_path: str, answer_path: str):
|
| 405 |
+
"""
|
| 406 |
+
Override parent class _load_models to prevent loading dummy answer model.
|
| 407 |
+
We only load the context model here, and the Rosetta model is loaded separately.
|
| 408 |
+
"""
|
| 409 |
+
# Only load context LLM (answer model is replaced by Rosetta)
|
| 410 |
+
self.context_tokenizer = AutoTokenizer.from_pretrained(context_path)
|
| 411 |
+
self.context_model = AutoModelForCausalLM.from_pretrained(
|
| 412 |
+
context_path, torch_dtype=torch.bfloat16, device_map={"": self.device}
|
| 413 |
+
)
|
| 414 |
+
# Apply generation config to context model
|
| 415 |
+
apply_generation_config(self.context_model, self.generation_config)
|
| 416 |
+
|
| 417 |
+
# Skip loading answer model - we use Rosetta instead
|
| 418 |
+
print(f"Loaded context model from {context_path}")
|
| 419 |
+
print("Skipping answer model loading - using Rosetta model instead")
|
| 420 |
+
|
| 421 |
+
def _load_rosetta_model(self):
|
| 422 |
+
"""Load Rosetta model and related components following load_model_from_checkpoint pattern."""
|
| 423 |
+
import json
|
| 424 |
+
from pathlib import Path
|
| 425 |
+
from rosetta.utils.evaluate import load_rosetta_model
|
| 426 |
+
|
| 427 |
+
checkpoint_path = Path(self.rosetta_checkpoint_dir)
|
| 428 |
+
|
| 429 |
+
# Load config
|
| 430 |
+
config_path = checkpoint_path / "config.json"
|
| 431 |
+
if not config_path.exists():
|
| 432 |
+
raise FileNotFoundError(f"Config file not found: {config_path}")
|
| 433 |
+
|
| 434 |
+
with open(config_path, 'r') as f:
|
| 435 |
+
config = json.load(f)
|
| 436 |
+
|
| 437 |
+
# Check if this is a Rosetta model (has projectors)
|
| 438 |
+
subfolder_dir = checkpoint_path / self.rosetta_subfolder
|
| 439 |
+
has_projectors = subfolder_dir.exists() and any(
|
| 440 |
+
f.name.startswith("projector_") and f.name.endswith(".pt")
|
| 441 |
+
for f in subfolder_dir.iterdir()
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
if not has_projectors:
|
| 445 |
+
raise ValueError(f"No projectors found in {subfolder_dir}. This doesn't appear to be a Rosetta checkpoint.")
|
| 446 |
+
|
| 447 |
+
# Load Rosetta model (following load_model_from_checkpoint pattern)
|
| 448 |
+
print(f"Loading Rosetta model from {self.rosetta_checkpoint_dir}")
|
| 449 |
+
|
| 450 |
+
# Create model config for Rosetta loading
|
| 451 |
+
model_config = {
|
| 452 |
+
"model_name": "Rosetta",
|
| 453 |
+
"rosetta_config": {
|
| 454 |
+
"checkpoints_dir": str(subfolder_dir),
|
| 455 |
+
"base_model": config["model"]["base_model"],
|
| 456 |
+
"teacher_model": config["model"]["teacher_model"],
|
| 457 |
+
"is_do_alignment": config["model"].get("is_do_alignment", False),
|
| 458 |
+
"alignment_strategy": config["model"].get("alignment_strategy", "first")
|
| 459 |
+
}
|
| 460 |
+
}
|
| 461 |
+
|
| 462 |
+
print(f"Model config: {model_config}")
|
| 463 |
+
|
| 464 |
+
eval_config = {
|
| 465 |
+
"checkpoints_dir": str(subfolder_dir)
|
| 466 |
+
}
|
| 467 |
+
|
| 468 |
+
# Load Rosetta model using the existing utility
|
| 469 |
+
self.rosetta_model, self.rosetta_tokenizer = load_rosetta_model(
|
| 470 |
+
model_config,
|
| 471 |
+
eval_config,
|
| 472 |
+
device=self.device
|
| 473 |
+
)
|
| 474 |
+
|
| 475 |
+
# Load LLM tokenizer for alignment if needed
|
| 476 |
+
is_do_alignment = config["model"].get("is_do_alignment", False)
|
| 477 |
+
llm_model_path = config["model"].get("teacher_model")
|
| 478 |
+
self.llm_tokenizer = None
|
| 479 |
+
|
| 480 |
+
if is_do_alignment and llm_model_path:
|
| 481 |
+
try:
|
| 482 |
+
self.llm_tokenizer = AutoTokenizer.from_pretrained(str(llm_model_path))
|
| 483 |
+
if self.llm_tokenizer.pad_token is None:
|
| 484 |
+
self.llm_tokenizer.pad_token = self.llm_tokenizer.eos_token
|
| 485 |
+
set_default_chat_template(self.llm_tokenizer, llm_model_path)
|
| 486 |
+
except Exception as e:
|
| 487 |
+
print(f"Failed to load LLM tokenizer '{llm_model_path}': {e}")
|
| 488 |
+
self.llm_tokenizer = None
|
| 489 |
+
|
| 490 |
+
print(f"Initialized TwoStageRosetta with Rosetta model on {self.device}")
|
| 491 |
+
|
| 492 |
+
def _prepare_rosetta_inputs(
|
| 493 |
+
self,
|
| 494 |
+
question: str,
|
| 495 |
+
context: str,
|
| 496 |
+
original_question: Optional[str] = None,
|
| 497 |
+
max_new_tokens: Optional[int] = None,
|
| 498 |
+
answer_method: str = "generate",
|
| 499 |
+
response_text: Optional[str] = None
|
| 500 |
+
) -> Dict[str, Any]:
|
| 501 |
+
"""
|
| 502 |
+
Prepare inputs for Rosetta model using the simpler approach from live_chat_example.py.
|
| 503 |
+
|
| 504 |
+
Args:
|
| 505 |
+
question: Question to answer
|
| 506 |
+
context: Background context from first LLM
|
| 507 |
+
original_question: Original question asked to first LLM (for conversation format)
|
| 508 |
+
max_new_tokens: Max tokens to generate (uses instance default if None)
|
| 509 |
+
|
| 510 |
+
Returns:
|
| 511 |
+
Dictionary with prepared inputs for Rosetta model
|
| 512 |
+
"""
|
| 513 |
+
# Use conversation format: user asks for background, assistant provides it, user asks main question
|
| 514 |
+
if original_question:
|
| 515 |
+
messages = [
|
| 516 |
+
{"role": "user", "content": self.background_prompt.format(question=original_question)},
|
| 517 |
+
{"role": "assistant", "content": context},
|
| 518 |
+
{"role": "user", "content": question}
|
| 519 |
+
]
|
| 520 |
+
else:
|
| 521 |
+
# Fallback to simple format
|
| 522 |
+
messages = [{"role": "user", "content": f"Background context: {context}\n\n{question}"}]
|
| 523 |
+
|
| 524 |
+
# Apply chat template (following live_chat_example.py pattern)
|
| 525 |
+
base_text = None
|
| 526 |
+
if hasattr(self.rosetta_tokenizer, 'apply_chat_template'):
|
| 527 |
+
base_text = self.rosetta_tokenizer.apply_chat_template(
|
| 528 |
+
messages,
|
| 529 |
+
tokenize=False,
|
| 530 |
+
add_generation_prompt=True,
|
| 531 |
+
enable_thinking=False
|
| 532 |
+
)
|
| 533 |
+
else:
|
| 534 |
+
base_text = f"### Human: {question}\n### Assistant:"
|
| 535 |
+
|
| 536 |
+
# Optionally append response_text for logits mode to steer next-token logits
|
| 537 |
+
if answer_method == 'logits' and response_text is not None:
|
| 538 |
+
text = base_text + response_text
|
| 539 |
+
else:
|
| 540 |
+
text = base_text
|
| 541 |
+
|
| 542 |
+
# Tokenize input
|
| 543 |
+
inputs = self.rosetta_tokenizer(text, return_tensors="pt").to(self.device)
|
| 544 |
+
|
| 545 |
+
# Create kv_cache_index for Rosetta model
|
| 546 |
+
full_length = inputs.input_ids.shape[1]
|
| 547 |
+
if answer_method == 'logits':
|
| 548 |
+
# Compute response length as the extra tokens appended by response_text
|
| 549 |
+
if response_text is not None:
|
| 550 |
+
base_tok = self.rosetta_tokenizer(base_text, return_tensors="pt")
|
| 551 |
+
response_length = int(inputs.input_ids.shape[1] - base_tok.input_ids.shape[1])
|
| 552 |
+
response_length = max(response_length, 0)
|
| 553 |
+
else:
|
| 554 |
+
response_length = 0
|
| 555 |
+
instr_len = max(full_length - response_length, 0)
|
| 556 |
+
instruction_index = torch.tensor([1, 0], dtype=torch.long).repeat(instr_len, 1).unsqueeze(0).to(self.device)
|
| 557 |
+
if response_length > 0:
|
| 558 |
+
response_index = torch.tensor([-1, 0], dtype=torch.long).repeat(response_length, 1).unsqueeze(0).to(self.device)
|
| 559 |
+
kv_cache_list = [instruction_index, response_index]
|
| 560 |
+
else:
|
| 561 |
+
kv_cache_list = [instruction_index]
|
| 562 |
+
else:
|
| 563 |
+
# Generate: treat the last position as response (length 1)
|
| 564 |
+
instruction_index = torch.tensor([1, 0], dtype=torch.long).repeat(full_length - 1, 1).unsqueeze(0).to(self.device)
|
| 565 |
+
label_index = torch.tensor([-1, 0], dtype=torch.long).repeat(1, 1).unsqueeze(0).to(self.device)
|
| 566 |
+
kv_cache_list = [instruction_index, label_index]
|
| 567 |
+
|
| 568 |
+
# Add position_ids if needed
|
| 569 |
+
if inputs.attention_mask is None:
|
| 570 |
+
position_ids = torch.arange(inputs.input_ids.shape[-1], dtype=torch.long).unsqueeze(0).to(self.device)
|
| 571 |
+
else:
|
| 572 |
+
position_ids = inputs.attention_mask.long().cumsum(-1) - 1
|
| 573 |
+
|
| 574 |
+
outputs = {
|
| 575 |
+
"inputs": {
|
| 576 |
+
"input_ids": inputs.input_ids,
|
| 577 |
+
"attention_mask": inputs.attention_mask,
|
| 578 |
+
"position_ids": position_ids,
|
| 579 |
+
"kv_cache_index": kv_cache_list
|
| 580 |
+
},
|
| 581 |
+
"printable_text": text
|
| 582 |
+
}
|
| 583 |
+
|
| 584 |
+
return outputs
|
| 585 |
+
|
| 586 |
+
def answer_with_context(
|
| 587 |
+
self,
|
| 588 |
+
question: str,
|
| 589 |
+
context: str,
|
| 590 |
+
max_new_tokens: Optional[int] = None,
|
| 591 |
+
original_question: Optional[str] = None
|
| 592 |
+
) -> str:
|
| 593 |
+
"""
|
| 594 |
+
Answer question using Rosetta model with context.
|
| 595 |
+
Overrides parent class method to use Rosetta instead of regular LLM.
|
| 596 |
+
|
| 597 |
+
Args:
|
| 598 |
+
question: Question to answer
|
| 599 |
+
context: Background context from first LLM
|
| 600 |
+
max_new_tokens: Max tokens to generate (uses instance default if None)
|
| 601 |
+
original_question: Original question asked to first LLM (for conversation format)
|
| 602 |
+
|
| 603 |
+
Returns:
|
| 604 |
+
Generated answer
|
| 605 |
+
"""
|
| 606 |
+
# Prepare inputs for Rosetta model
|
| 607 |
+
prepared = self._prepare_rosetta_inputs(
|
| 608 |
+
question=question,
|
| 609 |
+
context=context,
|
| 610 |
+
original_question=original_question,
|
| 611 |
+
max_new_tokens=max_new_tokens
|
| 612 |
+
)
|
| 613 |
+
|
| 614 |
+
if max_new_tokens is None:
|
| 615 |
+
max_new_tokens = self.max_new_tokens
|
| 616 |
+
|
| 617 |
+
# Generation parameters (following live_chat_example.py pattern)
|
| 618 |
+
sampling_params = {
|
| 619 |
+
'do_sample': False,
|
| 620 |
+
'max_new_tokens': max_new_tokens
|
| 621 |
+
}
|
| 622 |
+
|
| 623 |
+
# Generate using Rosetta model (following live_chat_example.py pattern)
|
| 624 |
+
input_length = prepared['inputs']['input_ids'].shape[1]
|
| 625 |
+
|
| 626 |
+
with torch.inference_mode():
|
| 627 |
+
outputs = self.rosetta_model.generate(
|
| 628 |
+
kv_cache_index=prepared['inputs']['kv_cache_index'],
|
| 629 |
+
input_ids=prepared['inputs']['input_ids'],
|
| 630 |
+
attention_mask=prepared['inputs']['attention_mask'],
|
| 631 |
+
position_ids=prepared['inputs']['position_ids'],
|
| 632 |
+
**sampling_params
|
| 633 |
+
)
|
| 634 |
+
generated_ids = outputs[0]
|
| 635 |
+
|
| 636 |
+
# Decode response
|
| 637 |
+
answer = self.rosetta_tokenizer.decode(generated_ids[input_length:], skip_special_tokens=True).strip()
|
| 638 |
+
|
| 639 |
+
return answer
|
| 640 |
+
|
| 641 |
+
def forward_with_context(
|
| 642 |
+
self,
|
| 643 |
+
question: str,
|
| 644 |
+
context: str,
|
| 645 |
+
original_question: Optional[str] = None,
|
| 646 |
+
response_text: Optional[str] = None,
|
| 647 |
+
**forward_kwargs
|
| 648 |
+
) -> Any:
|
| 649 |
+
"""
|
| 650 |
+
Run a forward pass on the Rosetta model using provided context (logits mode).
|
| 651 |
+
|
| 652 |
+
Args:
|
| 653 |
+
question: Full question with options and proper template
|
| 654 |
+
context: Background context from first LLM
|
| 655 |
+
original_question: Original question asked to first LLM (for conversation format)
|
| 656 |
+
response_text: Optional text appended after the chat template to steer next-token logits
|
| 657 |
+
**forward_kwargs: Extra kwargs forwarded to the model's forward
|
| 658 |
+
|
| 659 |
+
Returns:
|
| 660 |
+
Model outputs from the forward pass (e.g., logits)
|
| 661 |
+
"""
|
| 662 |
+
prepared = self._prepare_rosetta_inputs(
|
| 663 |
+
question=question,
|
| 664 |
+
context=context,
|
| 665 |
+
original_question=original_question,
|
| 666 |
+
answer_method='logits',
|
| 667 |
+
response_text=response_text
|
| 668 |
+
)
|
| 669 |
+
|
| 670 |
+
inputs = prepared['inputs']
|
| 671 |
+
with torch.inference_mode():
|
| 672 |
+
outputs = self.rosetta_model.forward(
|
| 673 |
+
kv_cache_index=inputs['kv_cache_index'],
|
| 674 |
+
input_ids=inputs['input_ids'],
|
| 675 |
+
attention_mask=inputs['attention_mask'],
|
| 676 |
+
position_ids=inputs['position_ids'],
|
| 677 |
+
**forward_kwargs
|
| 678 |
+
)
|
| 679 |
+
return outputs
|
| 680 |
+
|
| 681 |
+
def forward(
|
| 682 |
+
self,
|
| 683 |
+
question_without_options: str,
|
| 684 |
+
question_with_options: str,
|
| 685 |
+
response_text: Optional[str] = None,
|
| 686 |
+
**forward_kwargs
|
| 687 |
+
) -> Any:
|
| 688 |
+
"""
|
| 689 |
+
Two-stage forward pass (logits mode) for Rosetta:
|
| 690 |
+
1) Generate background context with the context model
|
| 691 |
+
2) Run a forward pass on the Rosetta model conditioned on that context
|
| 692 |
+
|
| 693 |
+
Args:
|
| 694 |
+
question_without_options: Question text without multiple choice options
|
| 695 |
+
question_with_options: Full question with options and proper template
|
| 696 |
+
response_text: Optional text appended after the chat template to steer next-token logits
|
| 697 |
+
**forward_kwargs: Extra kwargs forwarded to the model's forward
|
| 698 |
+
|
| 699 |
+
Returns:
|
| 700 |
+
Model outputs from the forward pass (e.g., logits)
|
| 701 |
+
"""
|
| 702 |
+
# Work in progress
|
| 703 |
+
raise NotImplementedError
|
| 704 |
+
context = self.get_background_context(question_without_options)
|
| 705 |
+
return self.forward_with_context(
|
| 706 |
+
question=question_with_options,
|
| 707 |
+
context=context,
|
| 708 |
+
original_question=question_without_options,
|
| 709 |
+
response_text=response_text,
|
| 710 |
+
**forward_kwargs
|
| 711 |
+
)
|
| 712 |
+
|
| 713 |
+
def logits_with_context(
|
| 714 |
+
self,
|
| 715 |
+
question_without_options: str,
|
| 716 |
+
question_with_options: str,
|
| 717 |
+
response_text: Optional[str] = None,
|
| 718 |
+
**forward_kwargs
|
| 719 |
+
) -> Any:
|
| 720 |
+
"""
|
| 721 |
+
Two-stage logits helper that also returns the generated background context
|
| 722 |
+
for logging as CoT (Rosetta backend).
|
| 723 |
+
|
| 724 |
+
Returns:
|
| 725 |
+
(outputs, context)
|
| 726 |
+
"""
|
| 727 |
+
context = self.get_background_context(question_without_options)
|
| 728 |
+
outputs = self.forward_with_context(
|
| 729 |
+
question=question_with_options,
|
| 730 |
+
context=context,
|
| 731 |
+
original_question=question_without_options,
|
| 732 |
+
response_text=response_text,
|
| 733 |
+
**forward_kwargs
|
| 734 |
+
)
|
| 735 |
+
return outputs, context
|
| 736 |
+
|
| 737 |
+
def generate(
|
| 738 |
+
self,
|
| 739 |
+
question_without_options: str,
|
| 740 |
+
question_with_options: str,
|
| 741 |
+
max_new_tokens: Optional[int] = None,
|
| 742 |
+
**kwargs
|
| 743 |
+
) -> str:
|
| 744 |
+
"""
|
| 745 |
+
Generate answer using two-stage processing with Rosetta.
|
| 746 |
+
|
| 747 |
+
Args:
|
| 748 |
+
question_without_options: Question text without multiple choice options
|
| 749 |
+
question_with_options: Full question with options and proper template
|
| 750 |
+
max_new_tokens: Maximum tokens to generate (passed to both stages)
|
| 751 |
+
**kwargs: Additional generation parameters (ignored for compatibility)
|
| 752 |
+
|
| 753 |
+
Returns:
|
| 754 |
+
Generated answer string
|
| 755 |
+
"""
|
| 756 |
+
# Stage 1: Get background context (uses parent class method)
|
| 757 |
+
context = self.get_background_context(question_without_options, max_new_tokens)
|
| 758 |
+
|
| 759 |
+
# Stage 2: Answer question with context using Rosetta
|
| 760 |
+
answer = self.answer_with_context(
|
| 761 |
+
question=question_with_options,
|
| 762 |
+
context=context,
|
| 763 |
+
original_question=question_without_options,
|
| 764 |
+
max_new_tokens=max_new_tokens
|
| 765 |
+
)
|
| 766 |
+
|
| 767 |
+
return answer
|
| 768 |
+
|
| 769 |
+
def process(
|
| 770 |
+
self,
|
| 771 |
+
question_without_options: str,
|
| 772 |
+
question_with_options: str
|
| 773 |
+
) -> Dict[str, str]:
|
| 774 |
+
"""
|
| 775 |
+
Full two-stage processing with Rosetta (legacy method for backward compatibility).
|
| 776 |
+
|
| 777 |
+
Args:
|
| 778 |
+
question_without_options: Question text without multiple choice options
|
| 779 |
+
question_with_options: Full question with options and proper template
|
| 780 |
+
|
| 781 |
+
Returns:
|
| 782 |
+
Dictionary with context and answer
|
| 783 |
+
"""
|
| 784 |
+
# Stage 1: Get background context (uses parent class method)
|
| 785 |
+
context = self.get_background_context(question_without_options)
|
| 786 |
+
|
| 787 |
+
# Stage 2: Answer question with context using Rosetta
|
| 788 |
+
answer = self.answer_with_context(
|
| 789 |
+
question=question_with_options,
|
| 790 |
+
context=context,
|
| 791 |
+
original_question=question_without_options
|
| 792 |
+
)
|
| 793 |
+
|
| 794 |
+
return {
|
| 795 |
+
"context": context,
|
| 796 |
+
"answer": answer
|
| 797 |
+
}
|
| 798 |
+
|
| 799 |
+
|
| 800 |
+
|
| 801 |
+
class MultiModalInference:
|
| 802 |
+
"""Multi-modal VLM+LLM inference pipeline."""
|
| 803 |
+
|
| 804 |
+
def __init__(
|
| 805 |
+
self,
|
| 806 |
+
vlm_model_path: str,
|
| 807 |
+
llm_model_path: str,
|
| 808 |
+
device: str = "cuda",
|
| 809 |
+
max_new_tokens: int = 1024,
|
| 810 |
+
generation_config: Optional[Dict[str, Any]] = None
|
| 811 |
+
):
|
| 812 |
+
"""
|
| 813 |
+
Initialize multi-modal pipeline.
|
| 814 |
+
|
| 815 |
+
Args:
|
| 816 |
+
vlm_model_path: Path to VLM model
|
| 817 |
+
llm_model_path: Path to LLM model
|
| 818 |
+
device: Device to use
|
| 819 |
+
max_new_tokens: Maximum number of new tokens to generate
|
| 820 |
+
generation_config: Optional generation configuration to apply to models
|
| 821 |
+
"""
|
| 822 |
+
self.device = device
|
| 823 |
+
self.max_new_tokens = max_new_tokens
|
| 824 |
+
self.generation_config = generation_config or {}
|
| 825 |
+
self._load_models(vlm_model_path, llm_model_path)
|
| 826 |
+
|
| 827 |
+
def _load_models(self, vlm_path: str, llm_path: str):
|
| 828 |
+
"""Load VLM and LLM models."""
|
| 829 |
+
# Load VLM
|
| 830 |
+
self.vlm_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 831 |
+
vlm_path,
|
| 832 |
+
torch_dtype=torch.bfloat16,
|
| 833 |
+
device_map={"": self.device},
|
| 834 |
+
)
|
| 835 |
+
# Apply generation config to VLM model
|
| 836 |
+
apply_generation_config(self.vlm_model, self.generation_config)
|
| 837 |
+
self.vlm_processor = AutoProcessor.from_pretrained(vlm_path)
|
| 838 |
+
|
| 839 |
+
# Load LLM
|
| 840 |
+
self.llm_tokenizer = AutoTokenizer.from_pretrained(llm_path)
|
| 841 |
+
self.llm_model = AutoModelForCausalLM.from_pretrained(
|
| 842 |
+
llm_path, torch_dtype=torch.bfloat16, device_map={"": self.device}
|
| 843 |
+
)
|
| 844 |
+
# Apply generation config to LLM model
|
| 845 |
+
apply_generation_config(self.llm_model, self.generation_config)
|
| 846 |
+
|
| 847 |
+
def get_image_description(
|
| 848 |
+
self,
|
| 849 |
+
image_path: str,
|
| 850 |
+
prompt: str = "Describe this image in detail.",
|
| 851 |
+
max_new_tokens: Optional[int] = None
|
| 852 |
+
) -> str:
|
| 853 |
+
"""
|
| 854 |
+
Get image description from VLM.
|
| 855 |
+
|
| 856 |
+
Args:
|
| 857 |
+
image_path: Path to image
|
| 858 |
+
prompt: Description prompt
|
| 859 |
+
max_new_tokens: Max tokens to generate (uses instance default if None)
|
| 860 |
+
|
| 861 |
+
Returns:
|
| 862 |
+
Image description
|
| 863 |
+
"""
|
| 864 |
+
messages = [{
|
| 865 |
+
"role": "user",
|
| 866 |
+
"content": [
|
| 867 |
+
{"type": "image", "image": image_path},
|
| 868 |
+
{"type": "text", "text": prompt}
|
| 869 |
+
]
|
| 870 |
+
}]
|
| 871 |
+
|
| 872 |
+
text = self.vlm_processor.apply_chat_template(
|
| 873 |
+
messages, tokenize=False, add_generation_prompt=True
|
| 874 |
+
)
|
| 875 |
+
image_inputs, video_inputs = process_vision_info(messages)
|
| 876 |
+
inputs = self.vlm_processor(
|
| 877 |
+
text=[text],
|
| 878 |
+
images=image_inputs,
|
| 879 |
+
videos=video_inputs,
|
| 880 |
+
padding=True,
|
| 881 |
+
return_tensors="pt",
|
| 882 |
+
)
|
| 883 |
+
inputs = inputs.to(self.device)
|
| 884 |
+
|
| 885 |
+
if max_new_tokens is None:
|
| 886 |
+
max_new_tokens = self.max_new_tokens
|
| 887 |
+
|
| 888 |
+
with torch.inference_mode():
|
| 889 |
+
outputs = self.vlm_model.generate(
|
| 890 |
+
**inputs,
|
| 891 |
+
max_new_tokens=max_new_tokens,
|
| 892 |
+
do_sample=False,
|
| 893 |
+
)
|
| 894 |
+
|
| 895 |
+
generated_ids = outputs[:, inputs["input_ids"].shape[-1]:]
|
| 896 |
+
description = self.vlm_processor.batch_decode(
|
| 897 |
+
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
| 898 |
+
)[0]
|
| 899 |
+
|
| 900 |
+
return description
|
| 901 |
+
|
| 902 |
+
def answer_with_context(
|
| 903 |
+
self,
|
| 904 |
+
question: str,
|
| 905 |
+
context: str,
|
| 906 |
+
max_new_tokens: Optional[int] = None,
|
| 907 |
+
original_question: Optional[str] = None
|
| 908 |
+
) -> str:
|
| 909 |
+
"""
|
| 910 |
+
Answer question using LLM with context.
|
| 911 |
+
|
| 912 |
+
Args:
|
| 913 |
+
question: Question to answer
|
| 914 |
+
context: Context (e.g., image description from VLM)
|
| 915 |
+
max_new_tokens: Max tokens to generate (uses instance default if None)
|
| 916 |
+
original_question: Original question asked to VLM (for conversation format)
|
| 917 |
+
|
| 918 |
+
Returns:
|
| 919 |
+
Generated answer
|
| 920 |
+
"""
|
| 921 |
+
# Use conversation format: user asks about image, assistant describes, user asks follow-up
|
| 922 |
+
if original_question:
|
| 923 |
+
messages = [
|
| 924 |
+
{"role": "user", "content": original_question},
|
| 925 |
+
{"role": "assistant", "content": context},
|
| 926 |
+
{"role": "user", "content": question}
|
| 927 |
+
]
|
| 928 |
+
else:
|
| 929 |
+
# Fallback to simple format
|
| 930 |
+
messages = [{"role": "user", "content": f"Context: {context}\n\nQuestion: {question}"}]
|
| 931 |
+
|
| 932 |
+
template_kwargs = {'enable_thinking': False}
|
| 933 |
+
|
| 934 |
+
# Some tokenizers may not support enable_thinking parameter
|
| 935 |
+
inputs = self.llm_tokenizer.apply_chat_template(
|
| 936 |
+
messages,
|
| 937 |
+
tokenize=True,
|
| 938 |
+
add_generation_prompt=True,
|
| 939 |
+
return_tensors="pt",
|
| 940 |
+
**template_kwargs
|
| 941 |
+
)
|
| 942 |
+
|
| 943 |
+
inputs = inputs.to(self.llm_model.device)
|
| 944 |
+
|
| 945 |
+
if max_new_tokens is None:
|
| 946 |
+
max_new_tokens = self.max_new_tokens
|
| 947 |
+
|
| 948 |
+
with torch.inference_mode():
|
| 949 |
+
outputs = self.llm_model.generate(
|
| 950 |
+
inputs,
|
| 951 |
+
max_new_tokens=max_new_tokens,
|
| 952 |
+
do_sample=False
|
| 953 |
+
)
|
| 954 |
+
|
| 955 |
+
generated_ids = outputs[:, inputs.shape[-1]:]
|
| 956 |
+
answer = self.llm_tokenizer.batch_decode(
|
| 957 |
+
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
| 958 |
+
)[0]
|
| 959 |
+
|
| 960 |
+
return answer
|
| 961 |
+
|
| 962 |
+
def process(
|
| 963 |
+
self,
|
| 964 |
+
image_path: str,
|
| 965 |
+
question: str,
|
| 966 |
+
description_prompt: str = "Briefly describe this image."
|
| 967 |
+
) -> Dict[str, str]:
|
| 968 |
+
"""
|
| 969 |
+
Full multi-stage processing.
|
| 970 |
+
|
| 971 |
+
Args:
|
| 972 |
+
image_path: Path to image
|
| 973 |
+
question: Question to answer
|
| 974 |
+
description_prompt: Prompt for image description
|
| 975 |
+
|
| 976 |
+
Returns:
|
| 977 |
+
Dictionary with description and answer
|
| 978 |
+
"""
|
| 979 |
+
# Stage 1: Get image description
|
| 980 |
+
description = self.get_image_description(image_path, description_prompt)
|
| 981 |
+
|
| 982 |
+
# Stage 2: Answer question with context (pass original prompt for conversation format)
|
| 983 |
+
answer = self.answer_with_context(
|
| 984 |
+
question=question,
|
| 985 |
+
context=description,
|
| 986 |
+
original_question=description_prompt
|
| 987 |
+
)
|
| 988 |
+
|
| 989 |
+
return {
|
| 990 |
+
"description": description,
|
| 991 |
+
"answer": answer
|
| 992 |
+
}
|
rosetta/model/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (419 Bytes). View file
|
|
|
rosetta/model/__pycache__/aggregator.cpython-310.pyc
ADDED
|
Binary file (5.83 kB). View file
|
|
|
rosetta/model/__pycache__/aggregator.cpython-312.pyc
ADDED
|
Binary file (8.53 kB). View file
|
|
|
rosetta/model/__pycache__/aligner.cpython-310.pyc
ADDED
|
Binary file (16.3 kB). View file
|
|
|
rosetta/model/__pycache__/aligner.cpython-312.pyc
ADDED
|
Binary file (24.2 kB). View file
|
|
|
rosetta/model/__pycache__/all_in_one_projector.cpython-310.pyc
ADDED
|
Binary file (27.8 kB). View file
|
|
|
rosetta/model/__pycache__/all_in_one_projector.cpython-312.pyc
ADDED
|
Binary file (53.5 kB). View file
|
|
|
rosetta/model/__pycache__/cache.cpython-310.pyc
ADDED
|
Binary file (980 Bytes). View file
|
|
|
rosetta/model/__pycache__/oracle.cpython-310.pyc
ADDED
|
Binary file (13.8 kB). View file
|
|
|
rosetta/model/__pycache__/oracle.cpython-312.pyc
ADDED
|
Binary file (23.5 kB). View file
|
|
|
rosetta/model/__pycache__/projector.cpython-310.pyc
ADDED
|
Binary file (30.8 kB). View file
|
|
|
rosetta/model/__pycache__/projector.cpython-312.pyc
ADDED
|
Binary file (56 kB). View file
|
|
|
rosetta/model/__pycache__/sampling.cpython-310.pyc
ADDED
|
Binary file (2.08 kB). View file
|
|
|
rosetta/model/__pycache__/sampling.cpython-312.pyc
ADDED
|
Binary file (3.49 kB). View file
|
|
|
rosetta/model/__pycache__/wrapper.cpython-310.pyc
ADDED
|
Binary file (18.6 kB). View file
|
|
|
rosetta/model/__pycache__/wrapper.cpython-312.pyc
ADDED
|
Binary file (33.8 kB). View file
|
|
|
rosetta/model/ablation_projector.py
ADDED
|
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Ablation Projector: A configurable projector for ablation studies based on C2CProjector.
|
| 3 |
+
Allows gradual removal of components to study their individual contributions.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
from typing import Optional, Tuple, Literal
|
| 10 |
+
|
| 11 |
+
from rosetta.utils.registry import register_model, capture_init_args
|
| 12 |
+
from rosetta.model.projector import Projector
|
| 13 |
+
from rosetta.model.projector import RegularMLP
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@register_model
|
| 17 |
+
@capture_init_args
|
| 18 |
+
class AblationProjector(Projector):
|
| 19 |
+
"""
|
| 20 |
+
Ablation study projector based on C2CProjector with configurable component removal.
|
| 21 |
+
|
| 22 |
+
Ablation levels:
|
| 23 |
+
0. Full C2C (baseline)
|
| 24 |
+
1. Remove scalar weights (set to 1.0)
|
| 25 |
+
2. Remove gates (set to 1.0)
|
| 26 |
+
3. Remove target contribution (only use source)
|
| 27 |
+
4. Remove gates only (gates=1.0), keep scalars and target
|
| 28 |
+
|
| 29 |
+
Each level builds on the previous one, allowing gradual degradation study.
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
def __init__(
|
| 33 |
+
self,
|
| 34 |
+
source_dim: int,
|
| 35 |
+
target_dim: int,
|
| 36 |
+
source_num_heads: int = 1,
|
| 37 |
+
target_num_heads: int = 1,
|
| 38 |
+
intermediate_dim: int = 1024,
|
| 39 |
+
hidden_dim: int = 1024,
|
| 40 |
+
num_layers: int = 3,
|
| 41 |
+
dropout: float = 0.1,
|
| 42 |
+
initial_temperature: float = 1.0,
|
| 43 |
+
final_temperature: float = 0.001,
|
| 44 |
+
anneal_steps: int = 1929,
|
| 45 |
+
dtype: torch.dtype = torch.float32,
|
| 46 |
+
|
| 47 |
+
# Ablation configuration
|
| 48 |
+
ablation_level: int = 0, # 0=full, 1=no_scalar, 2=no_gate+no_scalar, 3=no_target, 4=no_gate_only
|
| 49 |
+
use_scalar_weights: bool = True, # Can be overridden by ablation_level
|
| 50 |
+
use_gates: bool = True, # Can be overridden by ablation_level
|
| 51 |
+
use_target: bool = True, # Can be overridden by ablation_level
|
| 52 |
+
):
|
| 53 |
+
super().__init__()
|
| 54 |
+
|
| 55 |
+
assert 0 <= ablation_level <= 4, "ablation_level must be 0, 1, 2, 3, or 4"
|
| 56 |
+
|
| 57 |
+
# Dimensions
|
| 58 |
+
self.source_dim = source_dim
|
| 59 |
+
self.target_dim = target_dim
|
| 60 |
+
self.source_num_heads = source_num_heads
|
| 61 |
+
self.target_num_heads = target_num_heads
|
| 62 |
+
self.ablation_level = ablation_level
|
| 63 |
+
|
| 64 |
+
# Override component usage based on ablation level
|
| 65 |
+
if ablation_level == 4:
|
| 66 |
+
# Special case: disable gates only, keep scalars and target
|
| 67 |
+
use_scalar_weights = True
|
| 68 |
+
use_gates = False
|
| 69 |
+
use_target = True
|
| 70 |
+
else:
|
| 71 |
+
if ablation_level >= 1:
|
| 72 |
+
use_scalar_weights = False
|
| 73 |
+
if ablation_level >= 2:
|
| 74 |
+
use_gates = False
|
| 75 |
+
if ablation_level >= 3:
|
| 76 |
+
use_target = False
|
| 77 |
+
|
| 78 |
+
self.use_scalar_weights = use_scalar_weights
|
| 79 |
+
self.use_gates = use_gates
|
| 80 |
+
self.use_target = use_target
|
| 81 |
+
|
| 82 |
+
# Sizes
|
| 83 |
+
in_dim = source_dim * source_num_heads
|
| 84 |
+
out_dim = target_dim * target_num_heads
|
| 85 |
+
|
| 86 |
+
# 1) concat(source_X, target_X) then project to hidden_dim
|
| 87 |
+
# If not using target, only use source features
|
| 88 |
+
if self.use_target:
|
| 89 |
+
self.key_in = nn.Linear(in_dim + out_dim, hidden_dim, bias=True, dtype=dtype)
|
| 90 |
+
self.value_in = nn.Linear(in_dim + out_dim, hidden_dim, bias=True, dtype=dtype)
|
| 91 |
+
else:
|
| 92 |
+
# Only use source features
|
| 93 |
+
self.key_in = nn.Linear(in_dim, hidden_dim, bias=True, dtype=dtype)
|
| 94 |
+
self.value_in = nn.Linear(in_dim, hidden_dim, bias=True, dtype=dtype)
|
| 95 |
+
|
| 96 |
+
# 2) one-layer common embedding MLP to get intermediate representation (at hidden_dim)
|
| 97 |
+
self.key_mlp1 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, num_layers=1, dropout=dropout, dtype=dtype)
|
| 98 |
+
self.value_mlp1 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, num_layers=1, dropout=dropout, dtype=dtype)
|
| 99 |
+
|
| 100 |
+
# 3a) intermediate representation → (L-2)-layer MLP for weights → project to head dim
|
| 101 |
+
# Only build if using scalar weights
|
| 102 |
+
if self.use_scalar_weights:
|
| 103 |
+
self.key_scalar_mlp2 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=hidden_dim, num_layers=1, dropout=dropout, dtype=dtype)
|
| 104 |
+
self.value_scalar_mlp2 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=hidden_dim, num_layers=1, dropout=dropout, dtype=dtype)
|
| 105 |
+
self.key_scalar_head = nn.Linear(hidden_dim, target_num_heads, dtype=dtype)
|
| 106 |
+
self.value_scalar_head = nn.Linear(hidden_dim, target_num_heads, dtype=dtype)
|
| 107 |
+
|
| 108 |
+
# 3b) intermediate representation → (L-2)-layer MLP for projected_X → finally project hidden_dim → out_dim
|
| 109 |
+
self.key_proj_mlp2 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, num_layers=num_layers-2, dropout=dropout, dtype=dtype)
|
| 110 |
+
self.value_proj_mlp2 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, num_layers=num_layers-2, dropout=dropout, dtype=dtype)
|
| 111 |
+
self.key_proj_out = nn.Linear(hidden_dim, out_dim, bias=True, dtype=dtype)
|
| 112 |
+
self.value_proj_out = nn.Linear(hidden_dim, out_dim, bias=True, dtype=dtype)
|
| 113 |
+
|
| 114 |
+
# Scalar key/value gate parameters and temperature schedule
|
| 115 |
+
# Only build if using gates
|
| 116 |
+
if self.use_gates:
|
| 117 |
+
self.key_gate_logit = nn.Parameter(torch.tensor(0.0, dtype=dtype))
|
| 118 |
+
self.value_gate_logit = nn.Parameter(torch.tensor(0.0, dtype=dtype))
|
| 119 |
+
self.use_gumbel = True
|
| 120 |
+
self.register_buffer("gate_temperature", torch.tensor(initial_temperature, dtype=dtype))
|
| 121 |
+
self.initial_temperature = initial_temperature
|
| 122 |
+
self.final_temperature = final_temperature
|
| 123 |
+
self.anneal_steps = anneal_steps
|
| 124 |
+
|
| 125 |
+
# Temperature for weight normalization
|
| 126 |
+
self.scalar_temperature = 1.0
|
| 127 |
+
|
| 128 |
+
def update_temperature(self, step: int):
|
| 129 |
+
"""Update temperature using exponential annealing schedule for gates."""
|
| 130 |
+
if self.use_gates:
|
| 131 |
+
ratio = min(step / self.anneal_steps, 1.0)
|
| 132 |
+
temp = self.initial_temperature * (self.final_temperature / self.initial_temperature) ** ratio
|
| 133 |
+
self.gate_temperature.fill_(temp)
|
| 134 |
+
|
| 135 |
+
def forward(
|
| 136 |
+
self,
|
| 137 |
+
source_kv: Tuple[Tensor, Tensor],
|
| 138 |
+
target_kv: Tuple[Tensor, Tensor],
|
| 139 |
+
position_ids: Optional[Tensor] = None,
|
| 140 |
+
max_pos: Optional[Tensor] = None,
|
| 141 |
+
) -> Tuple[Tensor, Tensor]:
|
| 142 |
+
source_key, source_value = source_kv
|
| 143 |
+
target_key, target_value = target_kv
|
| 144 |
+
|
| 145 |
+
B, Hs, N, Ds = source_key.shape
|
| 146 |
+
_, Ht, _, Dt = target_key.shape
|
| 147 |
+
|
| 148 |
+
# Flatten heads
|
| 149 |
+
source_key_flat = source_key.transpose(1, 2).contiguous().view(B, N, Hs * Ds)
|
| 150 |
+
source_value_flat = source_value.transpose(1, 2).contiguous().view(B, N, Hs * Ds)
|
| 151 |
+
target_key_flat = target_key.transpose(1, 2).contiguous().view(B, N, Ht * Dt)
|
| 152 |
+
target_value_flat = target_value.transpose(1, 2).contiguous().view(B, N, Ht * Dt)
|
| 153 |
+
|
| 154 |
+
# 1) Prepare input features based on ablation level
|
| 155 |
+
if self.use_target:
|
| 156 |
+
# Full C2C: concat source and target features
|
| 157 |
+
key_cat = torch.cat([source_key_flat, target_key_flat], dim=-1)
|
| 158 |
+
value_cat = torch.cat([source_value_flat, target_value_flat], dim=-1)
|
| 159 |
+
else:
|
| 160 |
+
# Ablation level 3: only use source features
|
| 161 |
+
key_cat = source_key_flat
|
| 162 |
+
value_cat = source_value_flat
|
| 163 |
+
|
| 164 |
+
# 2) project to hidden dim
|
| 165 |
+
key_hidden = self.key_in(key_cat)
|
| 166 |
+
value_hidden = self.value_in(value_cat)
|
| 167 |
+
|
| 168 |
+
# 3) one-layer common embedding MLP to get intermediate representation (at hidden_dim)
|
| 169 |
+
key_hidden = self.key_mlp1(key_hidden)
|
| 170 |
+
value_hidden = self.value_mlp1(value_hidden)
|
| 171 |
+
|
| 172 |
+
# 4b) intermediate representation -> projected feature path
|
| 173 |
+
key_proj_hidden = self.key_proj_out(self.key_proj_mlp2(key_hidden)) # (B, N, Ht * Dt)
|
| 174 |
+
value_proj_hidden = self.value_proj_out(self.value_proj_mlp2(value_hidden)) # (B, N, Ht * Dt)
|
| 175 |
+
projected_key = key_proj_hidden.view(B, N, Ht, Dt).transpose(1, 2) # (B, Ht, N, Dt)
|
| 176 |
+
projected_value = value_proj_hidden.view(B, N, Ht, Dt).transpose(1, 2) # (B, Ht, N, Dt)
|
| 177 |
+
|
| 178 |
+
# 4a) intermediate representation -> scalar path (if using scalar weights)
|
| 179 |
+
if self.use_scalar_weights:
|
| 180 |
+
key_scalar = self.key_scalar_head(self.key_scalar_mlp2(key_hidden)) # (B, N, Ht)
|
| 181 |
+
value_scalar = self.value_scalar_head(self.value_scalar_mlp2(value_hidden)) # (B, N, Ht)
|
| 182 |
+
key_scalar = key_scalar.permute(0, 2, 1).unsqueeze(-1) # (B, Ht, N, 1)
|
| 183 |
+
value_scalar = value_scalar.permute(0, 2, 1).unsqueeze(-1) # (B, Ht, N, 1)
|
| 184 |
+
# Normalize scalars
|
| 185 |
+
norm_key_scalar = torch.sigmoid(key_scalar)
|
| 186 |
+
norm_value_scalar = torch.sigmoid(value_scalar)
|
| 187 |
+
else:
|
| 188 |
+
# Ablation level 1+: set scalar weights to 1.0
|
| 189 |
+
norm_key_scalar = torch.ones(B, Ht, N, 1, device=projected_key.device, dtype=projected_key.dtype)
|
| 190 |
+
norm_value_scalar = torch.ones(B, Ht, N, 1, device=projected_value.device, dtype=projected_value.dtype)
|
| 191 |
+
|
| 192 |
+
# Key/value gates (if using gates)
|
| 193 |
+
if self.use_gates:
|
| 194 |
+
key_gate_logit = self.key_gate_logit.view(1, 1, 1, 1)
|
| 195 |
+
value_gate_logit = self.value_gate_logit.view(1, 1, 1, 1)
|
| 196 |
+
if self.training and self.use_gumbel:
|
| 197 |
+
u1 = torch.rand(B, Ht, N, 1, device=key_gate_logit.device, dtype=key_gate_logit.dtype)
|
| 198 |
+
u2 = torch.rand(B, Ht, N, 1, device=value_gate_logit.device, dtype=value_gate_logit.dtype)
|
| 199 |
+
g1 = -torch.log(-torch.log(u1 + 1e-20) + 1e-20)
|
| 200 |
+
g2 = -torch.log(-torch.log(u2 + 1e-20) + 1e-20)
|
| 201 |
+
key_gate = torch.sigmoid((key_gate_logit + g1) / self.gate_temperature)
|
| 202 |
+
value_gate = torch.sigmoid((value_gate_logit + g2) / self.gate_temperature)
|
| 203 |
+
else:
|
| 204 |
+
key_gate = (key_gate_logit > 0).float()
|
| 205 |
+
value_gate = (value_gate_logit > 0).float()
|
| 206 |
+
else:
|
| 207 |
+
# Gates disabled: set gates to 1.0 (always open)
|
| 208 |
+
key_gate = torch.ones(B, Ht, N, 1, device=projected_key.device, dtype=projected_key.dtype)
|
| 209 |
+
value_gate = torch.ones(B, Ht, N, 1, device=projected_value.device, dtype=projected_value.dtype)
|
| 210 |
+
|
| 211 |
+
# Compute projected contribution
|
| 212 |
+
projected_key_term = key_gate * norm_key_scalar * projected_key
|
| 213 |
+
projected_value_term = value_gate * norm_value_scalar * projected_value
|
| 214 |
+
|
| 215 |
+
# Compute target contribution (if using target)
|
| 216 |
+
if self.use_target:
|
| 217 |
+
# Full C2C: add target with projected
|
| 218 |
+
output_key = target_key + projected_key_term
|
| 219 |
+
output_value = target_value + projected_value_term
|
| 220 |
+
else:
|
| 221 |
+
# Ablation level 3: only use projected (no target)
|
| 222 |
+
output_key = projected_key_term
|
| 223 |
+
output_value = projected_value_term
|
| 224 |
+
|
| 225 |
+
return output_key, output_value
|
| 226 |
+
|
| 227 |
+
def get_ablation_info(self) -> dict:
|
| 228 |
+
"""Return information about current ablation configuration."""
|
| 229 |
+
return {
|
| 230 |
+
'ablation_level': self.ablation_level,
|
| 231 |
+
'use_scalar_weights': self.use_scalar_weights,
|
| 232 |
+
'use_gates': self.use_gates,
|
| 233 |
+
'use_target': self.use_target,
|
| 234 |
+
'description': self._get_ablation_description()
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
def _get_ablation_description(self) -> str:
|
| 238 |
+
"""Get human-readable description of current ablation level."""
|
| 239 |
+
descriptions = {
|
| 240 |
+
0: "Full C2C (baseline)",
|
| 241 |
+
1: "No scalar weights (scalars=1.0)",
|
| 242 |
+
2: "No gates (gates=1.0) + No scalar weights",
|
| 243 |
+
3: "No target (source-only) + No gates + No scalar weights",
|
| 244 |
+
4: "No gates (gates=1.0), keep scalars and target"
|
| 245 |
+
}
|
| 246 |
+
return descriptions.get(self.ablation_level, "Unknown ablation level")
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
# Convenience functions for creating specific ablation levels
|
| 250 |
+
def create_ablation_projector(
|
| 251 |
+
source_dim: int,
|
| 252 |
+
target_dim: int,
|
| 253 |
+
source_num_heads: int = 1,
|
| 254 |
+
target_num_heads: int = 1,
|
| 255 |
+
ablation_level: int = 0,
|
| 256 |
+
**kwargs
|
| 257 |
+
) -> AblationProjector:
|
| 258 |
+
"""Create an AblationProjector with specified ablation level."""
|
| 259 |
+
return AblationProjector(
|
| 260 |
+
source_dim=source_dim,
|
| 261 |
+
target_dim=target_dim,
|
| 262 |
+
source_num_heads=source_num_heads,
|
| 263 |
+
target_num_heads=target_num_heads,
|
| 264 |
+
ablation_level=ablation_level,
|
| 265 |
+
**kwargs
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def create_full_c2c_projector(**kwargs) -> AblationProjector:
|
| 270 |
+
"""Create full C2C projector (ablation level 0)."""
|
| 271 |
+
return create_ablation_projector(ablation_level=0, **kwargs)
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def create_no_scalar_projector(**kwargs) -> AblationProjector:
|
| 275 |
+
"""Create projector without scalar weights (ablation level 1)."""
|
| 276 |
+
return create_ablation_projector(ablation_level=1, **kwargs)
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def create_no_gate_projector(**kwargs) -> AblationProjector:
|
| 280 |
+
"""Create projector without gates (ablation level 2)."""
|
| 281 |
+
return create_ablation_projector(ablation_level=2, **kwargs)
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def create_source_only_projector(**kwargs) -> AblationProjector:
|
| 285 |
+
"""Create source-only projector (ablation level 3)."""
|
| 286 |
+
return create_ablation_projector(ablation_level=3, **kwargs)
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def create_no_gate_only_projector(**kwargs) -> AblationProjector:
|
| 290 |
+
"""Create projector without gates but with scalar weights and target (ablation level 4)."""
|
| 291 |
+
return create_ablation_projector(ablation_level=4, **kwargs)
|
rosetta/model/aligner.py
ADDED
|
@@ -0,0 +1,639 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Token Aligner for handling different tokenizers between SLM and LLM models.
|
| 3 |
+
|
| 4 |
+
This module provides functionality to align tokens between two different tokenizers,
|
| 5 |
+
handling cases where the same text is tokenized differently.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import List, Tuple, Optional, Dict, Literal, Union
|
| 9 |
+
import torch
|
| 10 |
+
from transformers import PreTrainedTokenizerBase
|
| 11 |
+
from enum import Enum
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class AlignmentStrategy(Enum):
|
| 15 |
+
"""Strategies for handling 1-to-many token alignments"""
|
| 16 |
+
FIRST = "first" # Always take the first LLM token
|
| 17 |
+
LONGEST = "longest" # Take the LLM token with the longest string
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class TokenAligner:
|
| 21 |
+
"""
|
| 22 |
+
Aligns tokens between SLM (Small Language Model) and LLM (Large Language Model) tokenizers.
|
| 23 |
+
|
| 24 |
+
This class handles the case where the same text sequence is tokenized differently
|
| 25 |
+
by different tokenizers, using the SLM tokenization as the base and finding
|
| 26 |
+
corresponding LLM tokens for each SLM token.
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
def __init__(
|
| 30 |
+
self,
|
| 31 |
+
slm_tokenizer: PreTrainedTokenizerBase,
|
| 32 |
+
llm_tokenizer: PreTrainedTokenizerBase,
|
| 33 |
+
strategy: Union[AlignmentStrategy, str] = AlignmentStrategy.FIRST,
|
| 34 |
+
verbose: bool = False
|
| 35 |
+
):
|
| 36 |
+
"""
|
| 37 |
+
Initialize the TokenAligner.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
slm_tokenizer: The tokenizer for the Small Language Model (base)
|
| 41 |
+
llm_tokenizer: The tokenizer for the Large Language Model
|
| 42 |
+
strategy: Strategy for handling 1-to-many token mappings
|
| 43 |
+
Either AlignmentStrategy enum or string ('first' or 'longest')
|
| 44 |
+
verbose: Whether to print debug information during alignment
|
| 45 |
+
"""
|
| 46 |
+
self.slm_tokenizer = slm_tokenizer
|
| 47 |
+
self.llm_tokenizer = llm_tokenizer
|
| 48 |
+
|
| 49 |
+
if self.slm_tokenizer.pad_token is None:
|
| 50 |
+
self.slm_tokenizer.pad_token = self.slm_tokenizer.eos_token
|
| 51 |
+
self.slm_tokenizer.pad_token_id = self.slm_tokenizer.eos_token_id
|
| 52 |
+
if self.llm_tokenizer.pad_token is None:
|
| 53 |
+
self.llm_tokenizer.pad_token = self.llm_tokenizer.eos_token
|
| 54 |
+
self.llm_tokenizer.pad_token_id = self.llm_tokenizer.eos_token_id
|
| 55 |
+
|
| 56 |
+
# Handle string strategy input
|
| 57 |
+
if isinstance(strategy, str):
|
| 58 |
+
strategy = AlignmentStrategy(strategy.lower())
|
| 59 |
+
self.strategy = strategy
|
| 60 |
+
self.verbose = verbose
|
| 61 |
+
|
| 62 |
+
# Cache for token mappings to improve performance
|
| 63 |
+
self._alignment_cache: Dict[Tuple[int, ...], List[int]] = {}
|
| 64 |
+
|
| 65 |
+
def align_tokens(
|
| 66 |
+
self,
|
| 67 |
+
slm_token_ids: Union[List[int], torch.Tensor],
|
| 68 |
+
return_mapping: bool = False
|
| 69 |
+
) -> Union[List[int], Tuple[List[int], List[Tuple[int, List[int]]]]]:
|
| 70 |
+
"""
|
| 71 |
+
Align SLM tokens to LLM tokens.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
slm_token_ids: Token IDs from the SLM tokenizer
|
| 75 |
+
return_mapping: If True, also return the detailed mapping
|
| 76 |
+
|
| 77 |
+
Returns:
|
| 78 |
+
If return_mapping is False: List of aligned LLM token IDs
|
| 79 |
+
If return_mapping is True: Tuple of (aligned_llm_token_ids, mapping_details)
|
| 80 |
+
where mapping_details is a list of (slm_token_id, [candidate_llm_token_ids])
|
| 81 |
+
"""
|
| 82 |
+
# Convert to list if tensor
|
| 83 |
+
if isinstance(slm_token_ids, torch.Tensor):
|
| 84 |
+
slm_token_ids = slm_token_ids.tolist()
|
| 85 |
+
|
| 86 |
+
# Check cache
|
| 87 |
+
cache_key = tuple(slm_token_ids)
|
| 88 |
+
if cache_key in self._alignment_cache and not return_mapping:
|
| 89 |
+
return self._alignment_cache[cache_key]
|
| 90 |
+
|
| 91 |
+
aligned_llm_tokens = []
|
| 92 |
+
mapping_details = []
|
| 93 |
+
|
| 94 |
+
for slm_token_id in slm_token_ids:
|
| 95 |
+
# Decode SLM token to string (without special token processing)
|
| 96 |
+
slm_token_str = self.slm_tokenizer.decode(
|
| 97 |
+
[slm_token_id],
|
| 98 |
+
skip_special_tokens=False,
|
| 99 |
+
clean_up_tokenization_spaces=False
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
# Handle special tokens
|
| 103 |
+
if slm_token_id in self.slm_tokenizer.all_special_ids:
|
| 104 |
+
# Try to find corresponding special token in LLM tokenizer
|
| 105 |
+
llm_token_id = self._map_special_token(slm_token_id, slm_token_str)
|
| 106 |
+
aligned_llm_tokens.append(llm_token_id)
|
| 107 |
+
mapping_details.append((slm_token_id, [llm_token_id]))
|
| 108 |
+
continue
|
| 109 |
+
|
| 110 |
+
# Tokenize the string with LLM tokenizer
|
| 111 |
+
llm_token_ids = self.llm_tokenizer.encode(
|
| 112 |
+
slm_token_str,
|
| 113 |
+
add_special_tokens=False,
|
| 114 |
+
return_tensors=None
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
if len(llm_token_ids) == 0:
|
| 118 |
+
# Handle empty tokenization (shouldn't normally happen)
|
| 119 |
+
if self.verbose:
|
| 120 |
+
print(f"Warning: SLM token {slm_token_id} ('{slm_token_str}') "
|
| 121 |
+
f"resulted in empty LLM tokenization")
|
| 122 |
+
# Use unknown token as fallback
|
| 123 |
+
llm_token_id = self.llm_tokenizer.unk_token_id or 0
|
| 124 |
+
aligned_llm_tokens.append(llm_token_id)
|
| 125 |
+
mapping_details.append((slm_token_id, [llm_token_id]))
|
| 126 |
+
|
| 127 |
+
elif len(llm_token_ids) == 1:
|
| 128 |
+
# Perfect 1-to-1 mapping
|
| 129 |
+
aligned_llm_tokens.append(llm_token_ids[0])
|
| 130 |
+
mapping_details.append((slm_token_id, llm_token_ids))
|
| 131 |
+
|
| 132 |
+
else:
|
| 133 |
+
# 1-to-many mapping, apply strategy
|
| 134 |
+
selected_token = self._apply_strategy(llm_token_ids, slm_token_str)
|
| 135 |
+
aligned_llm_tokens.append(selected_token)
|
| 136 |
+
mapping_details.append((slm_token_id, llm_token_ids))
|
| 137 |
+
|
| 138 |
+
if self.verbose:
|
| 139 |
+
selected_str = self.llm_tokenizer.decode(
|
| 140 |
+
[selected_token],
|
| 141 |
+
skip_special_tokens=False,
|
| 142 |
+
clean_up_tokenization_spaces=False
|
| 143 |
+
)
|
| 144 |
+
print(f"SLM token {slm_token_id} ('{slm_token_str}') -> "
|
| 145 |
+
f"LLM tokens {llm_token_ids}, selected {selected_token} ('{selected_str}')")
|
| 146 |
+
|
| 147 |
+
# Cache the result
|
| 148 |
+
self._alignment_cache[cache_key] = aligned_llm_tokens
|
| 149 |
+
|
| 150 |
+
if return_mapping:
|
| 151 |
+
return aligned_llm_tokens, mapping_details
|
| 152 |
+
return aligned_llm_tokens
|
| 153 |
+
|
| 154 |
+
def _map_special_token(self, slm_token_id: int, slm_token_str: str) -> int:
|
| 155 |
+
"""
|
| 156 |
+
Map special tokens between tokenizers.
|
| 157 |
+
|
| 158 |
+
Args:
|
| 159 |
+
slm_token_id: The SLM special token ID
|
| 160 |
+
slm_token_str: The string representation of the special token
|
| 161 |
+
|
| 162 |
+
Returns:
|
| 163 |
+
The corresponding LLM token ID
|
| 164 |
+
"""
|
| 165 |
+
# Common special token mappings
|
| 166 |
+
special_token_map = {
|
| 167 |
+
self.slm_tokenizer.pad_token_id: self.llm_tokenizer.pad_token_id,
|
| 168 |
+
self.slm_tokenizer.eos_token_id: self.llm_tokenizer.eos_token_id,
|
| 169 |
+
self.slm_tokenizer.bos_token_id: self.llm_tokenizer.bos_token_id,
|
| 170 |
+
self.slm_tokenizer.unk_token_id: self.llm_tokenizer.unk_token_id,
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
# Direct mapping if available
|
| 174 |
+
if slm_token_id in special_token_map and special_token_map[slm_token_id] is not None:
|
| 175 |
+
return special_token_map[slm_token_id]
|
| 176 |
+
|
| 177 |
+
# Try to find by string representation
|
| 178 |
+
try:
|
| 179 |
+
llm_token_id = self.llm_tokenizer.convert_tokens_to_ids(slm_token_str)
|
| 180 |
+
if llm_token_id != self.llm_tokenizer.unk_token_id:
|
| 181 |
+
return llm_token_id
|
| 182 |
+
except:
|
| 183 |
+
pass
|
| 184 |
+
|
| 185 |
+
# Fallback to unknown token
|
| 186 |
+
return self.llm_tokenizer.unk_token_id or 0
|
| 187 |
+
|
| 188 |
+
def _apply_strategy(self, llm_token_ids: List[int], original_str: str) -> int:
|
| 189 |
+
"""
|
| 190 |
+
Apply the selected strategy to choose one LLM token from multiple candidates.
|
| 191 |
+
|
| 192 |
+
Args:
|
| 193 |
+
llm_token_ids: List of candidate LLM token IDs
|
| 194 |
+
original_str: The original string from SLM token
|
| 195 |
+
|
| 196 |
+
Returns:
|
| 197 |
+
The selected LLM token ID
|
| 198 |
+
"""
|
| 199 |
+
if self.strategy == AlignmentStrategy.FIRST:
|
| 200 |
+
return llm_token_ids[0]
|
| 201 |
+
|
| 202 |
+
elif self.strategy == AlignmentStrategy.LONGEST:
|
| 203 |
+
# Find the token with the longest string representation
|
| 204 |
+
longest_token = llm_token_ids[0]
|
| 205 |
+
longest_length = 0
|
| 206 |
+
|
| 207 |
+
for token_id in llm_token_ids:
|
| 208 |
+
token_str = self.llm_tokenizer.decode(
|
| 209 |
+
[token_id],
|
| 210 |
+
skip_special_tokens=False,
|
| 211 |
+
clean_up_tokenization_spaces=False
|
| 212 |
+
)
|
| 213 |
+
if len(token_str) > longest_length:
|
| 214 |
+
longest_length = len(token_str)
|
| 215 |
+
longest_token = token_id
|
| 216 |
+
|
| 217 |
+
return longest_token
|
| 218 |
+
|
| 219 |
+
else:
|
| 220 |
+
# Default to first token if unknown strategy
|
| 221 |
+
return llm_token_ids[0]
|
| 222 |
+
|
| 223 |
+
def align_sequence(
|
| 224 |
+
self,
|
| 225 |
+
text: str,
|
| 226 |
+
return_details: bool = False
|
| 227 |
+
) -> Union[Tuple[List[int], List[int]], Dict[str, any]]:
|
| 228 |
+
"""
|
| 229 |
+
Tokenize text with both tokenizers and return aligned sequences.
|
| 230 |
+
|
| 231 |
+
Args:
|
| 232 |
+
text: The input text to tokenize and align
|
| 233 |
+
return_details: If True, return detailed alignment information
|
| 234 |
+
|
| 235 |
+
Returns:
|
| 236 |
+
If return_details is False: Tuple of (slm_token_ids, aligned_llm_token_ids)
|
| 237 |
+
If return_details is True: Dictionary with detailed alignment information
|
| 238 |
+
"""
|
| 239 |
+
# Tokenize with SLM
|
| 240 |
+
slm_tokens = self.slm_tokenizer.encode(
|
| 241 |
+
text,
|
| 242 |
+
add_special_tokens=True,
|
| 243 |
+
return_tensors=None
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
# Get aligned LLM tokens
|
| 247 |
+
if return_details:
|
| 248 |
+
aligned_llm_tokens, mapping = self.align_tokens(slm_tokens, return_mapping=True)
|
| 249 |
+
|
| 250 |
+
# Decode tokens for inspection
|
| 251 |
+
slm_decoded = [
|
| 252 |
+
self.slm_tokenizer.decode([tid], skip_special_tokens=False, clean_up_tokenization_spaces=False)
|
| 253 |
+
for tid in slm_tokens
|
| 254 |
+
]
|
| 255 |
+
llm_decoded = [
|
| 256 |
+
self.llm_tokenizer.decode([tid], skip_special_tokens=False, clean_up_tokenization_spaces=False)
|
| 257 |
+
for tid in aligned_llm_tokens
|
| 258 |
+
]
|
| 259 |
+
|
| 260 |
+
# Original LLM tokenization for comparison
|
| 261 |
+
original_llm_tokens = self.llm_tokenizer.encode(
|
| 262 |
+
text,
|
| 263 |
+
add_special_tokens=True,
|
| 264 |
+
return_tensors=None
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
# One-to-one mapping statistics
|
| 268 |
+
num_tokens = len(slm_tokens)
|
| 269 |
+
one_to_one_count = sum(1 for _slm_id, candidates in mapping if len(candidates) == 1)
|
| 270 |
+
one_to_one_rate = (one_to_one_count / num_tokens) if num_tokens > 0 else 0.0
|
| 271 |
+
|
| 272 |
+
return {
|
| 273 |
+
'text': text,
|
| 274 |
+
'slm_token_ids': slm_tokens,
|
| 275 |
+
'slm_decoded': slm_decoded,
|
| 276 |
+
'aligned_llm_token_ids': aligned_llm_tokens,
|
| 277 |
+
'aligned_llm_decoded': llm_decoded,
|
| 278 |
+
'original_llm_token_ids': original_llm_tokens,
|
| 279 |
+
'mapping': mapping,
|
| 280 |
+
'strategy': self.strategy.value,
|
| 281 |
+
'num_tokens': num_tokens,
|
| 282 |
+
'one_to_one_count': one_to_one_count,
|
| 283 |
+
'one_to_one_rate': one_to_one_rate
|
| 284 |
+
}
|
| 285 |
+
else:
|
| 286 |
+
aligned_llm_tokens = self.align_tokens(slm_tokens)
|
| 287 |
+
return slm_tokens, aligned_llm_tokens
|
| 288 |
+
|
| 289 |
+
def visualize_alignment(self, text: str):
|
| 290 |
+
"""
|
| 291 |
+
Print a visual representation of the token alignment.
|
| 292 |
+
|
| 293 |
+
Args:
|
| 294 |
+
text: The text to analyze
|
| 295 |
+
"""
|
| 296 |
+
details = self.align_sequence(text, return_details=True)
|
| 297 |
+
|
| 298 |
+
print("=" * 80)
|
| 299 |
+
print(f"Text: {text}")
|
| 300 |
+
print(f"Strategy: {details['strategy']}")
|
| 301 |
+
print("=" * 80)
|
| 302 |
+
print(f"SLM tokens ({len(details['slm_token_ids'])}): {details['slm_token_ids']}")
|
| 303 |
+
print(f"Aligned LLM tokens ({len(details['aligned_llm_token_ids'])}): {details['aligned_llm_token_ids']}")
|
| 304 |
+
print(f"Original LLM tokens ({len(details['original_llm_token_ids'])}): {details['original_llm_token_ids']}")
|
| 305 |
+
print("-" * 80)
|
| 306 |
+
print("Token-by-token alignment:")
|
| 307 |
+
|
| 308 |
+
for i, (slm_id, llm_id) in enumerate(zip(details['slm_token_ids'], details['aligned_llm_token_ids'])):
|
| 309 |
+
slm_str = details['slm_decoded'][i]
|
| 310 |
+
llm_str = details['aligned_llm_decoded'][i]
|
| 311 |
+
mapping_info = details['mapping'][i]
|
| 312 |
+
|
| 313 |
+
if len(mapping_info[1]) > 1:
|
| 314 |
+
candidates_str = ', '.join([
|
| 315 |
+
f"{tid}:'{self.llm_tokenizer.decode([tid], skip_special_tokens=False, clean_up_tokenization_spaces=False)}'"
|
| 316 |
+
for tid in mapping_info[1]
|
| 317 |
+
])
|
| 318 |
+
print(f" [{i:3d}] SLM {slm_id:6d} ('{slm_str}') -> "
|
| 319 |
+
f"LLM {llm_id:6d} ('{llm_str}') "
|
| 320 |
+
f"[candidates: {candidates_str}]")
|
| 321 |
+
else:
|
| 322 |
+
print(f" [{i:3d}] SLM {slm_id:6d} ('{slm_str}') -> "
|
| 323 |
+
f"LLM {llm_id:6d} ('{llm_str}')")
|
| 324 |
+
print("=" * 80)
|
| 325 |
+
|
| 326 |
+
def clear_cache(self):
|
| 327 |
+
"""Clear the alignment cache."""
|
| 328 |
+
self._alignment_cache.clear()
|
| 329 |
+
|
| 330 |
+
# ========================
|
| 331 |
+
# Chat messages alignment
|
| 332 |
+
# ========================
|
| 333 |
+
def _apply_chat_template_to_ids(
|
| 334 |
+
self,
|
| 335 |
+
tokenizer: PreTrainedTokenizerBase,
|
| 336 |
+
messages: List[Dict[str, str]],
|
| 337 |
+
add_generation_prompt: bool,
|
| 338 |
+
enable_thinking: bool,
|
| 339 |
+
remove_last_surfix: bool
|
| 340 |
+
) -> Tuple[str, List[int], Optional[List[Tuple[int, int]]]]:
|
| 341 |
+
"""
|
| 342 |
+
Apply chat template (no tokenization) then tokenize to ids with optional offsets.
|
| 343 |
+
If remove_last_surfix is True, remove the last suffix from the LLM text
|
| 344 |
+
Returns (templated_text, input_ids, offsets) where offsets may be None.
|
| 345 |
+
"""
|
| 346 |
+
if remove_last_surfix:
|
| 347 |
+
assert messages[-1]["role"] == "assistant", "Last message must be an assistant message"
|
| 348 |
+
templated_text = tokenizer.apply_chat_template(
|
| 349 |
+
messages[:-1],
|
| 350 |
+
tokenize=False,
|
| 351 |
+
add_generation_prompt=True,
|
| 352 |
+
enable_thinking=enable_thinking
|
| 353 |
+
)
|
| 354 |
+
templated_text += messages[-1]["content"]
|
| 355 |
+
else:
|
| 356 |
+
templated_text = tokenizer.apply_chat_template(
|
| 357 |
+
messages,
|
| 358 |
+
tokenize=False,
|
| 359 |
+
add_generation_prompt=add_generation_prompt,
|
| 360 |
+
enable_thinking=enable_thinking
|
| 361 |
+
)
|
| 362 |
+
encoded = tokenizer(
|
| 363 |
+
templated_text,
|
| 364 |
+
add_special_tokens=False,
|
| 365 |
+
return_offsets_mapping=True
|
| 366 |
+
)
|
| 367 |
+
input_ids: List[int] = encoded["input_ids"]
|
| 368 |
+
offsets = encoded.get("offset_mapping")
|
| 369 |
+
return templated_text, input_ids, offsets
|
| 370 |
+
|
| 371 |
+
@staticmethod
|
| 372 |
+
def _first_non_empty_content(messages: List[Dict[str, str]]) -> Optional[str]:
|
| 373 |
+
for m in messages:
|
| 374 |
+
content = m.get("content")
|
| 375 |
+
if isinstance(content, str) and len(content.strip()) > 0:
|
| 376 |
+
return content
|
| 377 |
+
return None
|
| 378 |
+
|
| 379 |
+
def _find_boundary_token_index(
|
| 380 |
+
self,
|
| 381 |
+
tokenizer: PreTrainedTokenizerBase,
|
| 382 |
+
templated_text: str,
|
| 383 |
+
offsets: Optional[List[Tuple[int, int]]],
|
| 384 |
+
content_text: Optional[str]
|
| 385 |
+
) -> int:
|
| 386 |
+
"""
|
| 387 |
+
Find token index where the first non-empty message content starts.
|
| 388 |
+
Falls back to 0 if not found.
|
| 389 |
+
"""
|
| 390 |
+
if not content_text:
|
| 391 |
+
return 0
|
| 392 |
+
char_idx = templated_text.find(content_text)
|
| 393 |
+
if char_idx < 0:
|
| 394 |
+
# Try a shorter probe to improve chances
|
| 395 |
+
probe = content_text[: min(32, len(content_text))]
|
| 396 |
+
if len(probe) > 0:
|
| 397 |
+
char_idx = templated_text.find(probe)
|
| 398 |
+
if char_idx < 0:
|
| 399 |
+
return 0
|
| 400 |
+
|
| 401 |
+
if offsets:
|
| 402 |
+
for idx, (start, _end) in enumerate(offsets):
|
| 403 |
+
if start >= char_idx:
|
| 404 |
+
return idx
|
| 405 |
+
return len(offsets)
|
| 406 |
+
|
| 407 |
+
# Fallback without offsets: tokenize prefix and count tokens
|
| 408 |
+
prefix = templated_text[:char_idx]
|
| 409 |
+
prefix_ids = tokenizer(prefix, add_special_tokens=False)["input_ids"]
|
| 410 |
+
return len(prefix_ids)
|
| 411 |
+
|
| 412 |
+
@staticmethod
|
| 413 |
+
def _compute_content_spans(templated_text: str, messages: List[Dict[str, str]]) -> List[Tuple[int, int]]:
|
| 414 |
+
"""
|
| 415 |
+
Compute character spans in templated_text that correspond to message contents.
|
| 416 |
+
Searches sequentially to reduce ambiguity when contents repeat.
|
| 417 |
+
Enhanced matching: ensures the found content is followed by '<' (special token start)
|
| 418 |
+
to avoid matching content inside special tokens like <begin_of_text>.
|
| 419 |
+
"""
|
| 420 |
+
spans: List[Tuple[int, int]] = []
|
| 421 |
+
search_from = 0
|
| 422 |
+
for m in messages:
|
| 423 |
+
content = m.get("content")
|
| 424 |
+
if not isinstance(content, str) or len(content) == 0:
|
| 425 |
+
continue
|
| 426 |
+
|
| 427 |
+
# Find all possible matches starting from search_from
|
| 428 |
+
idx = search_from
|
| 429 |
+
found_valid_match = False
|
| 430 |
+
|
| 431 |
+
while idx < len(templated_text):
|
| 432 |
+
idx = templated_text.find(content, idx)
|
| 433 |
+
if idx < 0:
|
| 434 |
+
break
|
| 435 |
+
|
| 436 |
+
# Check if this match is valid (followed by '<' indicating a special token)
|
| 437 |
+
end_pos = idx + len(content)
|
| 438 |
+
if end_pos < len(templated_text) and templated_text[end_pos] == '<':
|
| 439 |
+
# Valid match: content is followed by a special token
|
| 440 |
+
spans.append((idx, end_pos))
|
| 441 |
+
search_from = end_pos
|
| 442 |
+
found_valid_match = True
|
| 443 |
+
break
|
| 444 |
+
else:
|
| 445 |
+
# Check if this is the end of the text (also valid for last message)
|
| 446 |
+
if end_pos == len(templated_text):
|
| 447 |
+
spans.append((idx, end_pos))
|
| 448 |
+
search_from = end_pos
|
| 449 |
+
found_valid_match = True
|
| 450 |
+
break
|
| 451 |
+
|
| 452 |
+
# Invalid match, try next occurrence
|
| 453 |
+
idx += 1
|
| 454 |
+
|
| 455 |
+
# Fallback: if no valid match found with '<' requirement, use the old method
|
| 456 |
+
# but only as a last resort and with additional validation
|
| 457 |
+
if not found_valid_match:
|
| 458 |
+
idx = templated_text.find(content, search_from)
|
| 459 |
+
if idx < 0:
|
| 460 |
+
# Try searching from start as last resort
|
| 461 |
+
idx = templated_text.find(content)
|
| 462 |
+
|
| 463 |
+
if idx >= 0:
|
| 464 |
+
end_pos = idx + len(content)
|
| 465 |
+
# Additional check: avoid matching inside obvious special tokens
|
| 466 |
+
# Check if we're inside a special token (preceded by '<' and not followed by '>')
|
| 467 |
+
start_context = templated_text[max(0, idx-10):idx]
|
| 468 |
+
end_context = templated_text[end_pos:min(len(templated_text), end_pos+10)]
|
| 469 |
+
|
| 470 |
+
# Skip if we're clearly inside a special token
|
| 471 |
+
if ('<' in start_context and '>' not in start_context and
|
| 472 |
+
'begin_of_text' in templated_text[max(0, idx-20):idx+20]):
|
| 473 |
+
# This looks like we're matching inside <begin_of_text> or similar
|
| 474 |
+
continue
|
| 475 |
+
|
| 476 |
+
spans.append((idx, end_pos))
|
| 477 |
+
search_from = end_pos
|
| 478 |
+
|
| 479 |
+
return spans
|
| 480 |
+
|
| 481 |
+
@staticmethod
|
| 482 |
+
def _build_token_mask_from_spans(
|
| 483 |
+
offsets: Optional[List[Tuple[int, int]]],
|
| 484 |
+
num_tokens: int,
|
| 485 |
+
spans: List[Tuple[int, int]]
|
| 486 |
+
) -> List[bool]:
|
| 487 |
+
"""
|
| 488 |
+
Build a boolean mask for tokens whose offset range overlaps any span.
|
| 489 |
+
If offsets are missing, default to all False.
|
| 490 |
+
"""
|
| 491 |
+
if not offsets or len(offsets) != num_tokens:
|
| 492 |
+
return [False] * num_tokens
|
| 493 |
+
mask: List[bool] = []
|
| 494 |
+
for (start, end) in offsets:
|
| 495 |
+
if end <= start:
|
| 496 |
+
mask.append(False)
|
| 497 |
+
continue
|
| 498 |
+
is_msg = False
|
| 499 |
+
for s, e in spans:
|
| 500 |
+
# overlap check
|
| 501 |
+
if start < e and end > s:
|
| 502 |
+
is_msg = True
|
| 503 |
+
break
|
| 504 |
+
mask.append(is_msg)
|
| 505 |
+
return mask
|
| 506 |
+
|
| 507 |
+
@staticmethod
|
| 508 |
+
def _spans_to_token_ranges(
|
| 509 |
+
offsets: List[Tuple[int, int]],
|
| 510 |
+
spans: List[Tuple[int, int]]
|
| 511 |
+
) -> List[Tuple[int, int]]:
|
| 512 |
+
"""
|
| 513 |
+
Convert character spans to token index ranges using offsets.
|
| 514 |
+
start token = first token with end > span_start
|
| 515 |
+
end token = first token with start >= span_end
|
| 516 |
+
"""
|
| 517 |
+
ranges: List[Tuple[int, int]] = []
|
| 518 |
+
n = len(offsets)
|
| 519 |
+
for s, e in spans:
|
| 520 |
+
# find start index
|
| 521 |
+
start_idx = 0
|
| 522 |
+
while start_idx < n and offsets[start_idx][1] <= s:
|
| 523 |
+
start_idx += 1
|
| 524 |
+
# find end index
|
| 525 |
+
end_idx = start_idx
|
| 526 |
+
while end_idx < n and offsets[end_idx][0] < e:
|
| 527 |
+
end_idx += 1
|
| 528 |
+
ranges.append((start_idx, end_idx))
|
| 529 |
+
return ranges
|
| 530 |
+
|
| 531 |
+
def align_chat_messages(
|
| 532 |
+
self,
|
| 533 |
+
messages: List[Dict[str, str]],
|
| 534 |
+
add_generation_prompt: bool = True,
|
| 535 |
+
enable_thinking: bool = False,
|
| 536 |
+
return_details: bool = False,
|
| 537 |
+
remove_last_surfix: bool = False
|
| 538 |
+
) -> Dict[str, any]:
|
| 539 |
+
"""
|
| 540 |
+
Align chat-templated sequences by sections (template/message/template...):
|
| 541 |
+
- Preserve all template tokens (pad the shorter template section)
|
| 542 |
+
- For each message section, map SLM tokens to LLM tokens 1:1 via strategy
|
| 543 |
+
- If remove_last_surfix is True, remove the last suffix from the LLM text
|
| 544 |
+
Returns essentials: slm_ids_padded, llm_ids_padded, message_mask (shared),
|
| 545 |
+
slm_padding_mask, llm_padding_mask (True where token is padding inserted).
|
| 546 |
+
When return_details=True, also returns 'sections' with aligned ranges.
|
| 547 |
+
"""
|
| 548 |
+
assert not (add_generation_prompt and remove_last_surfix), "add_generation_prompt and remove_last_surfix cannot be True at the same time"
|
| 549 |
+
|
| 550 |
+
# Build templated sequences with offsets
|
| 551 |
+
slm_text, slm_ids, slm_offsets = self._apply_chat_template_to_ids(
|
| 552 |
+
self.slm_tokenizer, messages, add_generation_prompt, enable_thinking, remove_last_surfix
|
| 553 |
+
)
|
| 554 |
+
llm_text, llm_ids, llm_offsets = self._apply_chat_template_to_ids(
|
| 555 |
+
self.llm_tokenizer, messages, add_generation_prompt, enable_thinking, remove_last_surfix
|
| 556 |
+
)
|
| 557 |
+
|
| 558 |
+
# Required pad tokens
|
| 559 |
+
assert self.slm_tokenizer.pad_token_id is not None, "SLM pad_token_id required"
|
| 560 |
+
assert self.llm_tokenizer.pad_token_id is not None, "LLM pad_token_id required"
|
| 561 |
+
slm_pad_id = self.slm_tokenizer.pad_token_id
|
| 562 |
+
llm_pad_id = self.llm_tokenizer.pad_token_id
|
| 563 |
+
|
| 564 |
+
# Content spans (char) and token ranges
|
| 565 |
+
content_spans_slm = self._compute_content_spans(slm_text, messages)
|
| 566 |
+
content_spans_llm = self._compute_content_spans(llm_text, messages)
|
| 567 |
+
assert slm_offsets is not None and llm_offsets is not None, "offset_mapping required"
|
| 568 |
+
slm_msg_ranges = self._spans_to_token_ranges(slm_offsets, content_spans_slm)
|
| 569 |
+
llm_msg_ranges = self._spans_to_token_ranges(llm_offsets, content_spans_llm)
|
| 570 |
+
# Build section ranges (template/message alternating)
|
| 571 |
+
def build_sections(total_len: int, msg_ranges: List[Tuple[int,int]]):
|
| 572 |
+
sections: List[Tuple[str,int,int]] = []
|
| 573 |
+
prev = 0
|
| 574 |
+
for (s, e) in msg_ranges:
|
| 575 |
+
if prev < s:
|
| 576 |
+
sections.append(("template", prev, s))
|
| 577 |
+
sections.append(("message", s, e))
|
| 578 |
+
prev = e
|
| 579 |
+
if prev < total_len:
|
| 580 |
+
sections.append(("template", prev, total_len))
|
| 581 |
+
return sections
|
| 582 |
+
slm_sections = build_sections(len(slm_ids), slm_msg_ranges)
|
| 583 |
+
llm_sections = build_sections(len(llm_ids), llm_msg_ranges)
|
| 584 |
+
assert len(slm_sections) == len(llm_sections), "Section count mismatch"
|
| 585 |
+
|
| 586 |
+
slm_out: List[int] = []
|
| 587 |
+
llm_out: List[int] = []
|
| 588 |
+
mask_out: List[bool] = []
|
| 589 |
+
slm_pad_mask_out: List[bool] = []
|
| 590 |
+
llm_pad_mask_out: List[bool] = []
|
| 591 |
+
detailed_sections: List[Dict[str, Union[str, Tuple[int,int]]]] = []
|
| 592 |
+
|
| 593 |
+
for (stype_s, s_s, e_s), (stype_l, s_l, e_l) in zip(slm_sections, llm_sections):
|
| 594 |
+
assert stype_s == stype_l, "Section type mismatch"
|
| 595 |
+
slm_start_out = len(slm_out)
|
| 596 |
+
llm_start_out = len(llm_out)
|
| 597 |
+
if stype_s == "template":
|
| 598 |
+
slm_seg_len = e_s - s_s
|
| 599 |
+
llm_seg_len = e_l - s_l
|
| 600 |
+
target_len = slm_seg_len if slm_seg_len >= llm_seg_len else llm_seg_len
|
| 601 |
+
slm_pad_needed = target_len - slm_seg_len
|
| 602 |
+
llm_pad_needed = target_len - llm_seg_len
|
| 603 |
+
slm_seg = slm_ids[s_s:e_s] + [slm_pad_id] * slm_pad_needed
|
| 604 |
+
llm_seg = llm_ids[s_l:e_l] + [llm_pad_id] * llm_pad_needed
|
| 605 |
+
slm_out.extend(slm_seg)
|
| 606 |
+
llm_out.extend(llm_seg)
|
| 607 |
+
mask_out.extend([False] * target_len)
|
| 608 |
+
slm_pad_mask_out.extend([False] * slm_seg_len + [True] * slm_pad_needed)
|
| 609 |
+
llm_pad_mask_out.extend([False] * llm_seg_len + [True] * llm_pad_needed)
|
| 610 |
+
else: # message
|
| 611 |
+
slm_msg = slm_ids[s_s:e_s]
|
| 612 |
+
llm_msg = self.align_tokens(slm_msg)
|
| 613 |
+
assert len(llm_msg) == len(slm_msg)
|
| 614 |
+
slm_out.extend(slm_msg)
|
| 615 |
+
llm_out.extend(llm_msg)
|
| 616 |
+
mask_out.extend([True] * len(slm_msg))
|
| 617 |
+
# no padding in message sections
|
| 618 |
+
slm_pad_mask_out.extend([False] * len(slm_msg))
|
| 619 |
+
llm_pad_mask_out.extend([False] * len(slm_msg))
|
| 620 |
+
slm_end_out = len(slm_out)
|
| 621 |
+
llm_end_out = len(llm_out)
|
| 622 |
+
detailed_sections.append({
|
| 623 |
+
'type': stype_s,
|
| 624 |
+
'slm_range': (slm_start_out, slm_end_out),
|
| 625 |
+
'llm_range': (llm_start_out, llm_end_out)
|
| 626 |
+
})
|
| 627 |
+
|
| 628 |
+
result_min = {
|
| 629 |
+
'slm_ids_padded': slm_out,
|
| 630 |
+
'llm_ids_padded': llm_out,
|
| 631 |
+
'message_mask': mask_out,
|
| 632 |
+
'slm_padding_mask': slm_pad_mask_out,
|
| 633 |
+
'llm_padding_mask': llm_pad_mask_out
|
| 634 |
+
}
|
| 635 |
+
if return_details:
|
| 636 |
+
result_min['sections'] = detailed_sections
|
| 637 |
+
result_min['slm_text'] = slm_text
|
| 638 |
+
result_min['llm_text'] = llm_text
|
| 639 |
+
return result_min
|
rosetta/model/oracle.py
ADDED
|
@@ -0,0 +1,562 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional, Union
|
| 2 |
+
import torch
|
| 3 |
+
from torch import nn
|
| 4 |
+
from transformers.cache_utils import Cache, DynamicCache
|
| 5 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 6 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 7 |
+
import json
|
| 8 |
+
|
| 9 |
+
from rosetta.model.projector import Projector
|
| 10 |
+
from rosetta.model.sampling import sample_token
|
| 11 |
+
from transformers.utils import ModelOutput
|
| 12 |
+
try:
|
| 13 |
+
from transformers.generation.utils import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput
|
| 14 |
+
except Exception:
|
| 15 |
+
GreedySearchDecoderOnlyOutput = None
|
| 16 |
+
SampleDecoderOnlyOutput = None
|
| 17 |
+
from rosetta.model.wrapper import RosettaModel
|
| 18 |
+
|
| 19 |
+
class OracleRosettaModel(nn.Module):
|
| 20 |
+
"""
|
| 21 |
+
Drop in replacement for the standard transformers LLM models, like Qwen3ForCausalLM
|
| 22 |
+
"""
|
| 23 |
+
def __init__(self, model_list: List[PreTrainedModel], base_model_idx = 0, projector_list: List[Projector] = [], aggregator_list: List[nn.Module] = []):
|
| 24 |
+
super().__init__()
|
| 25 |
+
# model list: a list of model, model 0 by default is the base model
|
| 26 |
+
# projector list: a list of projector
|
| 27 |
+
# standard init with additional model list parameter
|
| 28 |
+
# kv-cache dict: key (source_model_idx, target_model_idx), value (Cache), assume only convert at prefill with one type of model
|
| 29 |
+
# projector dict: key (source_model_idx, target_model_idx) value dict(key (source_model_layer_idx, M_target value )
|
| 30 |
+
|
| 31 |
+
self.base_model_idx = base_model_idx
|
| 32 |
+
self.model_list = nn.ModuleList(model_list)
|
| 33 |
+
|
| 34 |
+
device = model_list[base_model_idx].device
|
| 35 |
+
dtype = model_list[base_model_idx].dtype
|
| 36 |
+
self.projector_list = nn.ModuleList(projector_list).to(device=device, dtype=dtype)
|
| 37 |
+
self.aggregator_list = nn.ModuleList(aggregator_list).to(device=device, dtype=dtype)
|
| 38 |
+
|
| 39 |
+
self.projector_dict = {}
|
| 40 |
+
self.aggregator_dict = {}
|
| 41 |
+
self.kv_cache_dict = {}
|
| 42 |
+
self._generation_hook_handlers = []
|
| 43 |
+
|
| 44 |
+
@property
|
| 45 |
+
def device(self):
|
| 46 |
+
return self.model_list[self.base_model_idx].device
|
| 47 |
+
|
| 48 |
+
def to(self, device):
|
| 49 |
+
"""
|
| 50 |
+
Move the RosettaModel and all underlying models and projectors to the specified device.
|
| 51 |
+
"""
|
| 52 |
+
super().to(device)
|
| 53 |
+
for model in self.model_list:
|
| 54 |
+
model.to(device)
|
| 55 |
+
for projector in self.projector_list:
|
| 56 |
+
projector.to(device)
|
| 57 |
+
for aggregator in self.aggregator_list:
|
| 58 |
+
aggregator.to(device)
|
| 59 |
+
return self
|
| 60 |
+
|
| 61 |
+
# set projector
|
| 62 |
+
def set_projector_config(self,
|
| 63 |
+
source_model_idx: int,
|
| 64 |
+
source_model_layer_idx: int,
|
| 65 |
+
target_model_idx: int,
|
| 66 |
+
target_model_layer_idx: int,
|
| 67 |
+
projector_idx: int):
|
| 68 |
+
"""
|
| 69 |
+
Set the projector configuration
|
| 70 |
+
Args:
|
| 71 |
+
source_model_idx: int, the index of the source model
|
| 72 |
+
source_model_layer_idx: int, the index of the source model layer
|
| 73 |
+
target_model_idx: int, the index of the target model
|
| 74 |
+
target_model_layer_idx: int, the index of the target model layer
|
| 75 |
+
projector_idx: int, the index of the projector
|
| 76 |
+
|
| 77 |
+
The projector dict structure supports multiple projectors per target layer.
|
| 78 |
+
Structure:
|
| 79 |
+
{
|
| 80 |
+
target_model_idx: {
|
| 81 |
+
source_model_idx: {
|
| 82 |
+
target_model_layer_idx: [(source_model_layer_idx, projector_idx), ...]
|
| 83 |
+
}
|
| 84 |
+
}
|
| 85 |
+
}
|
| 86 |
+
Repeated calls for the same (target, source, target_layer) append additional pairs.
|
| 87 |
+
"""
|
| 88 |
+
|
| 89 |
+
if target_model_idx not in self.projector_dict.keys():
|
| 90 |
+
self.projector_dict[target_model_idx] = {}
|
| 91 |
+
if source_model_idx not in self.projector_dict[target_model_idx].keys():
|
| 92 |
+
self.projector_dict[target_model_idx][source_model_idx] = {}
|
| 93 |
+
# Accumulate list of (source_layer, projector_idx) for this target layer
|
| 94 |
+
layer_entry = self.projector_dict[target_model_idx][source_model_idx].get(target_model_layer_idx)
|
| 95 |
+
if layer_entry is None:
|
| 96 |
+
self.projector_dict[target_model_idx][source_model_idx][target_model_layer_idx] = [(source_model_layer_idx, projector_idx)]
|
| 97 |
+
else:
|
| 98 |
+
layer_entry.append((source_model_layer_idx, projector_idx))
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def load_projector(self, projector_list):
|
| 102 |
+
self.projector_list: List[Projector] = projector_list
|
| 103 |
+
|
| 104 |
+
def load_aggregator(self, aggregator_list):
|
| 105 |
+
self.aggregator_list: List[nn.Module] = aggregator_list
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def get_projector(self,
|
| 109 |
+
source_model_idx,
|
| 110 |
+
source_model_layer_idx,
|
| 111 |
+
target_model_idx,
|
| 112 |
+
target_model_layer_idx):
|
| 113 |
+
pair_list = self.projector_dict[target_model_idx][source_model_idx][target_model_layer_idx]
|
| 114 |
+
if len(pair_list) == 0:
|
| 115 |
+
raise ValueError("No projector configured for the given target layer")
|
| 116 |
+
# Prefer exact source layer match
|
| 117 |
+
for src_layer, projector_id in pair_list:
|
| 118 |
+
if src_layer == source_model_layer_idx:
|
| 119 |
+
return self.projector_list[projector_id]
|
| 120 |
+
# Fallback: return the first projector
|
| 121 |
+
return self.projector_list[pair_list[0][1]]
|
| 122 |
+
|
| 123 |
+
def set_aggregator_idx(self,
|
| 124 |
+
source_model_idx: int,
|
| 125 |
+
target_model_idx: int,
|
| 126 |
+
target_model_layer_idx: int,
|
| 127 |
+
aggregator_idx: int):
|
| 128 |
+
if target_model_idx not in self.aggregator_dict:
|
| 129 |
+
self.aggregator_dict[target_model_idx] = {}
|
| 130 |
+
if source_model_idx not in self.aggregator_dict[target_model_idx]:
|
| 131 |
+
self.aggregator_dict[target_model_idx][source_model_idx] = {}
|
| 132 |
+
self.aggregator_dict[target_model_idx][source_model_idx][target_model_layer_idx] = aggregator_idx
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
@staticmethod
|
| 136 |
+
def load_json(file_name):
|
| 137 |
+
with open(file_name, "r") as f:
|
| 138 |
+
result = json.load(f)
|
| 139 |
+
return result
|
| 140 |
+
|
| 141 |
+
@staticmethod
|
| 142 |
+
def _convert_dict_keys_to_ints(obj):
|
| 143 |
+
"""
|
| 144 |
+
Recursively convert dictionary keys that look like integers back to int.
|
| 145 |
+
This reverses json.dump's coercion of dict keys to strings.
|
| 146 |
+
"""
|
| 147 |
+
if isinstance(obj, dict):
|
| 148 |
+
new_obj = {}
|
| 149 |
+
for key, value in obj.items():
|
| 150 |
+
if isinstance(key, str) and key.lstrip('-').isdigit():
|
| 151 |
+
new_key = int(key)
|
| 152 |
+
else:
|
| 153 |
+
new_key = key
|
| 154 |
+
new_obj[new_key] = RosettaModel._convert_dict_keys_to_ints(value)
|
| 155 |
+
return new_obj
|
| 156 |
+
if isinstance(obj, list):
|
| 157 |
+
return [RosettaModel._convert_dict_keys_to_ints(v) for v in obj]
|
| 158 |
+
return obj
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def save_projector_config(self, file_name):
|
| 162 |
+
with open(file_name, "w") as f:
|
| 163 |
+
json.dump(self.projector_dict, f)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def load_projector_config(self, config_path):
|
| 167 |
+
if config_path.endswith(".json"):
|
| 168 |
+
loaded = RosettaModel.load_json(config_path)
|
| 169 |
+
self.projector_dict = RosettaModel._convert_dict_keys_to_ints(loaded)
|
| 170 |
+
|
| 171 |
+
def save_aggregator_config(self, file_name):
|
| 172 |
+
with open(file_name, "w") as f:
|
| 173 |
+
json.dump(self.aggregator_dict, f)
|
| 174 |
+
|
| 175 |
+
def load_aggregator_config(self, config_path):
|
| 176 |
+
if config_path.endswith(".json"):
|
| 177 |
+
loaded = RosettaModel.load_json(config_path)
|
| 178 |
+
self.aggregator_dict = RosettaModel._convert_dict_keys_to_ints(loaded)
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def set_kv_cache_dict(self, source_model_idx, target_model_idx, cache):
|
| 182 |
+
if target_model_idx not in self.kv_cache_dict.keys():
|
| 183 |
+
self.kv_cache_dict[target_model_idx] = {}
|
| 184 |
+
if cache is None:
|
| 185 |
+
# Initialize with a DynamicCache instead of RosettaCache for now
|
| 186 |
+
self.kv_cache_dict[target_model_idx][source_model_idx] = DynamicCache() # noqa, maybe we should use RosettaCache here
|
| 187 |
+
else:
|
| 188 |
+
self.kv_cache_dict[target_model_idx][source_model_idx] = cache
|
| 189 |
+
|
| 190 |
+
def forward(
|
| 191 |
+
self,
|
| 192 |
+
kv_cache_index: Optional[List] = None,
|
| 193 |
+
input_ids: Optional[Union[torch.LongTensor, List[torch.LongTensor]]] = None,
|
| 194 |
+
attention_mask: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
|
| 195 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 196 |
+
past_key_values: Optional[Cache] = None,
|
| 197 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 198 |
+
labels: Optional[torch.LongTensor] = None,
|
| 199 |
+
use_cache: Optional[bool] = None,
|
| 200 |
+
output_attentions: Optional[bool] = None,
|
| 201 |
+
output_hidden_states: Optional[bool] = None,
|
| 202 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 203 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 204 |
+
# **kwargs: Unpack[KwargsForCausalLM],
|
| 205 |
+
identifier = -1,
|
| 206 |
+
subject = None,
|
| 207 |
+
*args,
|
| 208 |
+
**kwargs,
|
| 209 |
+
) -> CausalLMOutputWithPast:
|
| 210 |
+
"""
|
| 211 |
+
Forward pass
|
| 212 |
+
KVCache index is a list of tensors with shape (B, sec_seq_len, 2), indicating the source and target kv cache model index
|
| 213 |
+
|
| 214 |
+
If input_ids is LongTensor, default to same input ids for different models
|
| 215 |
+
If input_ids is Tuple, default to different input ids for different models.
|
| 216 |
+
|
| 217 |
+
No Rosetta: (-1, 0)
|
| 218 |
+
"""
|
| 219 |
+
|
| 220 |
+
# noqa
|
| 221 |
+
self.kv_cache_dict = dict()
|
| 222 |
+
|
| 223 |
+
# Handle different input formats: if input_ids is a list, use per-model inputs
|
| 224 |
+
if isinstance(input_ids, list):
|
| 225 |
+
# Use list format: different input_ids and attention_mask for each model
|
| 226 |
+
base_input_ids = input_ids[self.base_model_idx] if input_ids is not None else None
|
| 227 |
+
base_attention_mask = attention_mask[self.base_model_idx] if attention_mask is not None else None
|
| 228 |
+
_, seqlen = base_input_ids.size() if base_input_ids is not None else (0, 0)
|
| 229 |
+
else:
|
| 230 |
+
# Use tensor format: same input_ids and attention_mask for all models (backward compatibility)
|
| 231 |
+
base_input_ids = input_ids
|
| 232 |
+
base_attention_mask = attention_mask
|
| 233 |
+
_, seqlen = input_ids.size() if input_ids is not None else (0, 0)
|
| 234 |
+
|
| 235 |
+
num_sections = len(kv_cache_index) if kv_cache_index is not None else 1
|
| 236 |
+
|
| 237 |
+
section_lengths = [kv_cache_index[i].shape[1] for i in range(num_sections)] if kv_cache_index is not None else [seqlen]
|
| 238 |
+
section_starts = [0]
|
| 239 |
+
for l in section_lengths:
|
| 240 |
+
section_starts.append(section_starts[-1] + l)
|
| 241 |
+
|
| 242 |
+
curr_base_kv_cache = past_key_values
|
| 243 |
+
|
| 244 |
+
if seqlen > 1:
|
| 245 |
+
for i in range(num_sections):
|
| 246 |
+
start = section_starts[i]
|
| 247 |
+
end = section_starts[i + 1]
|
| 248 |
+
prefill_input_ids = base_input_ids[:, start:end] if base_input_ids is not None else None
|
| 249 |
+
prefill_attention_mask = base_attention_mask[:, :end] if base_attention_mask is not None else None
|
| 250 |
+
prefill_position_ids = position_ids[:, start:end] if position_ids is not None else None
|
| 251 |
+
prefill_labels = labels[:, start:end] if labels is not None else None
|
| 252 |
+
|
| 253 |
+
# calculate target model kvcache
|
| 254 |
+
output = self.model_list[self.base_model_idx].forward(
|
| 255 |
+
input_ids=prefill_input_ids,
|
| 256 |
+
attention_mask=prefill_attention_mask,
|
| 257 |
+
position_ids=prefill_position_ids,
|
| 258 |
+
past_key_values=curr_base_kv_cache,
|
| 259 |
+
labels=prefill_labels,
|
| 260 |
+
use_cache=use_cache,
|
| 261 |
+
output_attentions=output_attentions,
|
| 262 |
+
output_hidden_states=output_hidden_states,
|
| 263 |
+
*args,
|
| 264 |
+
**kwargs
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
if self.base_model_idx not in self.kv_cache_dict:
|
| 268 |
+
self.kv_cache_dict[self.base_model_idx] = {}
|
| 269 |
+
if self.base_model_idx not in self.kv_cache_dict[self.base_model_idx]:
|
| 270 |
+
self.kv_cache_dict[self.base_model_idx][self.base_model_idx] = None
|
| 271 |
+
self.kv_cache_dict[self.base_model_idx][self.base_model_idx] = output.past_key_values
|
| 272 |
+
|
| 273 |
+
curr_base_kv_cache: DynamicCache = output.past_key_values
|
| 274 |
+
|
| 275 |
+
# if i != num_sections - 1:
|
| 276 |
+
for source_model_idx in range(1, len(self.model_list)):
|
| 277 |
+
if self.base_model_idx not in self.kv_cache_dict:
|
| 278 |
+
self.kv_cache_dict[self.base_model_idx] = {}
|
| 279 |
+
if source_model_idx not in self.kv_cache_dict[self.base_model_idx]:
|
| 280 |
+
self.kv_cache_dict[self.base_model_idx][source_model_idx] = None
|
| 281 |
+
|
| 282 |
+
# Get model-specific input_ids and attention_mask
|
| 283 |
+
if isinstance(input_ids, list):
|
| 284 |
+
source_input_ids = input_ids[source_model_idx]
|
| 285 |
+
source_attention_mask = attention_mask[source_model_idx] if attention_mask is not None else None
|
| 286 |
+
source_prefill_input_ids = source_input_ids[:, start:end] if source_input_ids is not None else None
|
| 287 |
+
source_prefill_attention_mask = source_attention_mask[:, :end] if source_attention_mask is not None else None
|
| 288 |
+
else:
|
| 289 |
+
# Backward compatibility: use same input for all models
|
| 290 |
+
source_prefill_input_ids = prefill_input_ids
|
| 291 |
+
source_prefill_attention_mask = prefill_attention_mask
|
| 292 |
+
|
| 293 |
+
curr_source_kv_cache = self.model_list[source_model_idx].forward(
|
| 294 |
+
input_ids=source_prefill_input_ids,
|
| 295 |
+
attention_mask=source_prefill_attention_mask,
|
| 296 |
+
position_ids=prefill_position_ids,
|
| 297 |
+
past_key_values=self.kv_cache_dict[self.base_model_idx][source_model_idx],
|
| 298 |
+
use_cache=use_cache,
|
| 299 |
+
output_attentions=output_attentions,
|
| 300 |
+
output_hidden_states=output_hidden_states,
|
| 301 |
+
*args,
|
| 302 |
+
**kwargs
|
| 303 |
+
).past_key_values
|
| 304 |
+
self.kv_cache_dict[self.base_model_idx][source_model_idx] = curr_source_kv_cache
|
| 305 |
+
|
| 306 |
+
# calculate source model kvcache and apply projections
|
| 307 |
+
if self.base_model_idx in self.projector_dict:
|
| 308 |
+
source_model_idx = kv_cache_index[i][0][0][0].item() # Get the source model index from the kv_cache_index
|
| 309 |
+
if source_model_idx != -1:
|
| 310 |
+
for target_layer_idx, entry in self.projector_dict[self.base_model_idx][source_model_idx].items():
|
| 311 |
+
base_key_cache, base_value_cache = curr_base_kv_cache[target_layer_idx]
|
| 312 |
+
new_base_key_cache = base_key_cache[:, :, start:end, :]
|
| 313 |
+
new_base_value_cache = base_value_cache[:, :, start:end, :]
|
| 314 |
+
new_base_kv_cache = (new_base_key_cache, new_base_value_cache)
|
| 315 |
+
|
| 316 |
+
pair_list = entry
|
| 317 |
+
|
| 318 |
+
projected_kv_list = []
|
| 319 |
+
source_kv_list = []
|
| 320 |
+
for source_model_layer_idx, projector_idx in pair_list:
|
| 321 |
+
source_key_cache, source_value_cache = self.kv_cache_dict[self.base_model_idx][source_model_idx][source_model_layer_idx]
|
| 322 |
+
new_source_key_cache = source_key_cache[:, :, start:end, :]
|
| 323 |
+
new_source_value_cache = source_value_cache[:, :, start:end, :]
|
| 324 |
+
new_source_kv_cache = (new_source_key_cache, new_source_value_cache)
|
| 325 |
+
projected_key, projected_value = self.projector_list[projector_idx].forward(
|
| 326 |
+
new_source_kv_cache, # tuple of (key, value), each of shape (B, N, H, D)
|
| 327 |
+
new_base_kv_cache
|
| 328 |
+
)
|
| 329 |
+
projected_kv_list.append((projected_key, projected_value))
|
| 330 |
+
|
| 331 |
+
# --------------
|
| 332 |
+
# save base and projected kv cache
|
| 333 |
+
torch.save((projected_key, projected_value), f"oracle/projected_kv/{subject}_{identifier}_{i}.pt")
|
| 334 |
+
torch.save(new_base_kv_cache, f"oracle/target_kv/{subject}_{identifier}_{i}.pt")
|
| 335 |
+
# --------------
|
| 336 |
+
source_kv_list.append(new_source_kv_cache)
|
| 337 |
+
|
| 338 |
+
# Aggregate (fallback to first projector if no aggregator is available)
|
| 339 |
+
use_aggregator = (
|
| 340 |
+
len(projected_kv_list) > 1 and
|
| 341 |
+
len(self.aggregator_list) > 0 and
|
| 342 |
+
self.base_model_idx in self.aggregator_dict and
|
| 343 |
+
source_model_idx in self.aggregator_dict[self.base_model_idx] and
|
| 344 |
+
target_layer_idx in self.aggregator_dict[self.base_model_idx][source_model_idx]
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
if use_aggregator:
|
| 348 |
+
aggregator_idx = self.aggregator_dict[self.base_model_idx][source_model_idx][target_layer_idx]
|
| 349 |
+
agg_key, agg_value = self.aggregator_list[aggregator_idx].forward(
|
| 350 |
+
source_kv_list,
|
| 351 |
+
new_base_kv_cache,
|
| 352 |
+
projected_kv_list
|
| 353 |
+
)
|
| 354 |
+
else:
|
| 355 |
+
# Fallback to first projector result when no aggregator is available
|
| 356 |
+
agg_key, agg_value = projected_kv_list[0]
|
| 357 |
+
|
| 358 |
+
# Update cache with aggregated result
|
| 359 |
+
curr_base_kv_cache.key_cache[target_layer_idx][:, :, start:end, :] = agg_key
|
| 360 |
+
curr_base_kv_cache.value_cache[target_layer_idx][:, :, start:end, :] = agg_value
|
| 361 |
+
|
| 362 |
+
output.past_key_values = curr_base_kv_cache
|
| 363 |
+
|
| 364 |
+
# use base model for decode phase
|
| 365 |
+
else:
|
| 366 |
+
# Handle list input format for decode phase as well
|
| 367 |
+
decode_input_ids = input_ids[self.base_model_idx] if isinstance(input_ids, list) else input_ids
|
| 368 |
+
decode_attention_mask = attention_mask[self.base_model_idx] if isinstance(attention_mask, list) else attention_mask
|
| 369 |
+
|
| 370 |
+
output = self.model_list[self.base_model_idx].forward(
|
| 371 |
+
input_ids=decode_input_ids,
|
| 372 |
+
attention_mask=decode_attention_mask,
|
| 373 |
+
position_ids=position_ids,
|
| 374 |
+
past_key_values=curr_base_kv_cache,
|
| 375 |
+
inputs_embeds=inputs_embeds,
|
| 376 |
+
labels=labels,
|
| 377 |
+
use_cache=use_cache,
|
| 378 |
+
output_attentions=output_attentions,
|
| 379 |
+
output_hidden_states=output_hidden_states,
|
| 380 |
+
cache_position=cache_position,
|
| 381 |
+
*args,
|
| 382 |
+
**kwargs
|
| 383 |
+
)
|
| 384 |
+
|
| 385 |
+
return output
|
| 386 |
+
|
| 387 |
+
@torch.no_grad()
|
| 388 |
+
def generate(
|
| 389 |
+
self,
|
| 390 |
+
kv_cache_index,
|
| 391 |
+
input_ids,
|
| 392 |
+
max_new_tokens: Optional[int] = None,
|
| 393 |
+
past_key_values: Optional[Cache] = None,
|
| 394 |
+
attention_mask: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
|
| 395 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 396 |
+
eos_token_id: Optional[Union[int, List[int]]] = None,
|
| 397 |
+
pad_token_id: Optional[int] = None,
|
| 398 |
+
temperature: float = 1.0,
|
| 399 |
+
top_p: float = 1.0,
|
| 400 |
+
top_k: int = -1,
|
| 401 |
+
do_sample: Optional[bool] = None,
|
| 402 |
+
return_dict_in_generate: Optional[bool] = None,
|
| 403 |
+
output_scores: Optional[bool] = None,
|
| 404 |
+
max_length: Optional[int] = None,
|
| 405 |
+
use_cache: bool = True,
|
| 406 |
+
*args,
|
| 407 |
+
**kwargs,
|
| 408 |
+
):
|
| 409 |
+
"""
|
| 410 |
+
New generation loop without using the base model's generate.
|
| 411 |
+
- Uses this module's forward for prefill and per-token decode.
|
| 412 |
+
- Samples tokens via rosetta.model.sampling.sample_token.
|
| 413 |
+
Returns a tensor of shape [batch, prompt_len + generated_len] for the base model stream.
|
| 414 |
+
"""
|
| 415 |
+
# Derive number of tokens to generate
|
| 416 |
+
# If max_new_tokens not provided, infer from max_length
|
| 417 |
+
if isinstance(input_ids, list):
|
| 418 |
+
base_input_ids_for_len = input_ids[self.base_model_idx]
|
| 419 |
+
else:
|
| 420 |
+
base_input_ids_for_len = input_ids
|
| 421 |
+
prompt_len = base_input_ids_for_len.size(1)
|
| 422 |
+
|
| 423 |
+
# Default eos/pad from base model tokenizer/config if not provided
|
| 424 |
+
base_model = self.model_list[self.base_model_idx]
|
| 425 |
+
gen_cfg = getattr(base_model, "generation_config", None)
|
| 426 |
+
cfg_obj = gen_cfg if gen_cfg is not None else getattr(base_model, "config", None)
|
| 427 |
+
if eos_token_id is None and cfg_obj is not None:
|
| 428 |
+
eos_token_id = getattr(cfg_obj, "eos_token_id", None)
|
| 429 |
+
if pad_token_id is None and cfg_obj is not None:
|
| 430 |
+
pad_token_id = getattr(cfg_obj, "pad_token_id", None)
|
| 431 |
+
if pad_token_id is None and eos_token_id is not None:
|
| 432 |
+
pad_token_id = eos_token_id if isinstance(eos_token_id, int) else eos_token_id[0]
|
| 433 |
+
|
| 434 |
+
if max_new_tokens is None:
|
| 435 |
+
if max_length is not None:
|
| 436 |
+
if max_length <= prompt_len:
|
| 437 |
+
max_new_tokens = 0
|
| 438 |
+
else:
|
| 439 |
+
max_new_tokens = max_length - prompt_len
|
| 440 |
+
else:
|
| 441 |
+
raise ValueError("Provide max_new_tokens or max_length")
|
| 442 |
+
if max_new_tokens < 0:
|
| 443 |
+
raise ValueError("max_new_tokens must be non-negative")
|
| 444 |
+
|
| 445 |
+
# Resolve base inputs
|
| 446 |
+
if isinstance(input_ids, list):
|
| 447 |
+
base_input_ids = input_ids[self.base_model_idx]
|
| 448 |
+
base_attention_mask = attention_mask[self.base_model_idx] if attention_mask is not None else None
|
| 449 |
+
else:
|
| 450 |
+
base_input_ids = input_ids
|
| 451 |
+
base_attention_mask = attention_mask
|
| 452 |
+
|
| 453 |
+
if base_attention_mask is None:
|
| 454 |
+
base_attention_mask = torch.ones_like(base_input_ids, dtype=torch.long, device=base_input_ids.device)
|
| 455 |
+
|
| 456 |
+
batch_size = base_input_ids.size(0)
|
| 457 |
+
|
| 458 |
+
# Prefill to build caches and obtain initial logits
|
| 459 |
+
prefill_output = self.forward(
|
| 460 |
+
kv_cache_index=kv_cache_index,
|
| 461 |
+
input_ids=input_ids,
|
| 462 |
+
attention_mask=attention_mask,
|
| 463 |
+
position_ids=position_ids,
|
| 464 |
+
past_key_values=past_key_values,
|
| 465 |
+
use_cache=use_cache,
|
| 466 |
+
*args,
|
| 467 |
+
**kwargs,
|
| 468 |
+
)
|
| 469 |
+
|
| 470 |
+
current_past = prefill_output.past_key_values
|
| 471 |
+
all_input_ids = base_input_ids
|
| 472 |
+
current_attention_mask = base_attention_mask
|
| 473 |
+
|
| 474 |
+
# EOS handling setup
|
| 475 |
+
eos_set = None
|
| 476 |
+
if eos_token_id is not None:
|
| 477 |
+
eos_set = set(eos_token_id if isinstance(eos_token_id, list) else [eos_token_id])
|
| 478 |
+
finished = torch.zeros(batch_size, dtype=torch.bool, device=all_input_ids.device)
|
| 479 |
+
|
| 480 |
+
# Start from last prefill logits
|
| 481 |
+
last_logits = prefill_output.logits[:, -1, :]
|
| 482 |
+
|
| 483 |
+
# Determine sampling mode
|
| 484 |
+
if do_sample is None:
|
| 485 |
+
do_sample = False
|
| 486 |
+
effective_temperature = temperature if do_sample else 0.0
|
| 487 |
+
|
| 488 |
+
# Optional scores collection
|
| 489 |
+
collect_scores = bool(return_dict_in_generate) and bool(output_scores)
|
| 490 |
+
scores = []
|
| 491 |
+
|
| 492 |
+
for _ in range(max_new_tokens):
|
| 493 |
+
if collect_scores:
|
| 494 |
+
scores.append(last_logits)
|
| 495 |
+
# Sample next token
|
| 496 |
+
next_token = sample_token(last_logits, temperature=effective_temperature, top_p=top_p, top_k=top_k)
|
| 497 |
+
if not isinstance(next_token, torch.Tensor):
|
| 498 |
+
next_token = torch.tensor([next_token], device=all_input_ids.device, dtype=torch.long).repeat(batch_size)
|
| 499 |
+
|
| 500 |
+
# Apply EOS logic
|
| 501 |
+
if eos_set is not None:
|
| 502 |
+
just_finished = torch.zeros_like(finished)
|
| 503 |
+
for eid in eos_set:
|
| 504 |
+
just_finished |= (next_token == eid)
|
| 505 |
+
finished = finished | just_finished
|
| 506 |
+
if pad_token_id is not None:
|
| 507 |
+
next_token = torch.where(
|
| 508 |
+
finished,
|
| 509 |
+
torch.tensor(pad_token_id, device=next_token.device, dtype=next_token.dtype),
|
| 510 |
+
next_token,
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
# Append sampled token
|
| 514 |
+
next_token_unsqueezed = next_token.unsqueeze(1)
|
| 515 |
+
all_input_ids = torch.cat([all_input_ids, next_token_unsqueezed], dim=1)
|
| 516 |
+
current_attention_mask = torch.cat(
|
| 517 |
+
[
|
| 518 |
+
current_attention_mask,
|
| 519 |
+
torch.ones((batch_size, 1), device=current_attention_mask.device, dtype=current_attention_mask.dtype),
|
| 520 |
+
],
|
| 521 |
+
dim=1,
|
| 522 |
+
)
|
| 523 |
+
|
| 524 |
+
# Early stop if all sequences finished
|
| 525 |
+
if eos_set is not None and torch.all(finished):
|
| 526 |
+
break
|
| 527 |
+
|
| 528 |
+
# Decode one step using cached states; pass base-stream tensors
|
| 529 |
+
kv_cache_index = [torch.tensor([-1, 0], dtype=torch.long).repeat(1, 1).unsqueeze(0).to(all_input_ids.device)]
|
| 530 |
+
|
| 531 |
+
decode_output = self.forward(
|
| 532 |
+
kv_cache_index=kv_cache_index,
|
| 533 |
+
input_ids=next_token_unsqueezed,
|
| 534 |
+
attention_mask=current_attention_mask,
|
| 535 |
+
position_ids=None,
|
| 536 |
+
past_key_values=current_past,
|
| 537 |
+
use_cache=True,
|
| 538 |
+
*args,
|
| 539 |
+
**kwargs,
|
| 540 |
+
)
|
| 541 |
+
current_past = decode_output.past_key_values
|
| 542 |
+
last_logits = decode_output.logits[:, -1, :]
|
| 543 |
+
|
| 544 |
+
# Return style compatible with HF generate
|
| 545 |
+
if return_dict_in_generate:
|
| 546 |
+
if GreedySearchDecoderOnlyOutput is not None and SampleDecoderOnlyOutput is not None:
|
| 547 |
+
if do_sample:
|
| 548 |
+
return SampleDecoderOnlyOutput(
|
| 549 |
+
sequences=all_input_ids,
|
| 550 |
+
scores=scores if collect_scores else None,
|
| 551 |
+
)
|
| 552 |
+
else:
|
| 553 |
+
return GreedySearchDecoderOnlyOutput(
|
| 554 |
+
sequences=all_input_ids,
|
| 555 |
+
scores=scores if collect_scores else None,
|
| 556 |
+
)
|
| 557 |
+
# Fallback to generic ModelOutput
|
| 558 |
+
result = {"sequences": all_input_ids}
|
| 559 |
+
if collect_scores:
|
| 560 |
+
result["scores"] = scores
|
| 561 |
+
return ModelOutput(**result)
|
| 562 |
+
return all_input_ids
|
rosetta/model/projector.py
ADDED
|
@@ -0,0 +1,1194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Projector nn module for the unified memory
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
from transformers import Cache, DynamicCache
|
| 9 |
+
from typing import Optional, Tuple, Literal, Union
|
| 10 |
+
import copy
|
| 11 |
+
import math
|
| 12 |
+
|
| 13 |
+
from rosetta.utils.registry import register_model, get_projector_class, PROJECTOR_REGISTRY, capture_init_args, save_object, load_object
|
| 14 |
+
|
| 15 |
+
class Projector(nn.Module):
|
| 16 |
+
"""Base projector class for unified memory"""
|
| 17 |
+
|
| 18 |
+
def forward(self, source_kv: Tuple[Tensor, Tensor], target_kv: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
|
| 19 |
+
"""
|
| 20 |
+
Project and combine the source key-value tensors to the target key-value tensors
|
| 21 |
+
Args:
|
| 22 |
+
source_kv: Tuple of (key, value) tensors, each (..., D_s) where ... are arbitrary leading dimensions
|
| 23 |
+
target_kv: Tuple of (key, value) tensors, each (..., D_t) where ... are arbitrary leading dimensions
|
| 24 |
+
Returns:
|
| 25 |
+
Tuple of (key, value) tensors, each (..., D_t) with same leading dimensions as input
|
| 26 |
+
"""
|
| 27 |
+
raise NotImplementedError("Subclasses must implement forward method")
|
| 28 |
+
|
| 29 |
+
def cache_project(self, source_kv_cache: Cache, target_kv_cache: Cache) -> Cache:
|
| 30 |
+
"""
|
| 31 |
+
Project the source kv cache to the target kv cache
|
| 32 |
+
"""
|
| 33 |
+
if not isinstance(source_kv_cache, DynamicCache) or not isinstance(target_kv_cache, DynamicCache):
|
| 34 |
+
raise ValueError("Only DynamicCache is supported")
|
| 35 |
+
|
| 36 |
+
projected_cache = DynamicCache()
|
| 37 |
+
|
| 38 |
+
# Process each layer
|
| 39 |
+
for layer_idx in range(len(source_kv_cache.key_cache)):
|
| 40 |
+
source_key = source_kv_cache.key_cache[layer_idx] # (B, H, N, D_s)
|
| 41 |
+
source_value = source_kv_cache.value_cache[layer_idx] # (B, H, N, D_s)
|
| 42 |
+
|
| 43 |
+
# Get corresponding target tensors (for reference/combination)
|
| 44 |
+
if layer_idx < len(target_kv_cache.key_cache):
|
| 45 |
+
target_key = target_kv_cache.key_cache[layer_idx] # (B, H, N, D_t)
|
| 46 |
+
target_value = target_kv_cache.value_cache[layer_idx] # (B, H, N, D_t)
|
| 47 |
+
else:
|
| 48 |
+
# If target cache doesn't have this layer, create dummy tensors
|
| 49 |
+
B, H, N, D_s = source_key.shape
|
| 50 |
+
D_t = source_key.shape[-1] # Assume same dimension for simplicity
|
| 51 |
+
target_key = torch.zeros(B, H, N, D_t, device=source_key.device, dtype=source_key.dtype)
|
| 52 |
+
target_value = torch.zeros(B, H, N, D_t, device=source_value.device, dtype=source_value.dtype)
|
| 53 |
+
|
| 54 |
+
# Reshape for forward pass: DynamicCache format (B, H, N, D) -> projector format (B, N, H, D)
|
| 55 |
+
source_key_reshaped = source_key.transpose(1, 2)
|
| 56 |
+
source_value_reshaped = source_value.transpose(1, 2)
|
| 57 |
+
target_key_reshaped = target_key.transpose(1, 2)
|
| 58 |
+
target_value_reshaped = target_value.transpose(1, 2)
|
| 59 |
+
|
| 60 |
+
# Project using forward method with tuple input/output
|
| 61 |
+
source_kv = (source_key_reshaped, source_value_reshaped)
|
| 62 |
+
target_kv = (target_key_reshaped, target_value_reshaped)
|
| 63 |
+
projected_key, projected_value = self.forward(source_kv, target_kv)
|
| 64 |
+
|
| 65 |
+
# Reshape back: projector format (B, N, H, D) -> DynamicCache format (B, H, N, D)
|
| 66 |
+
projected_key = projected_key.transpose(1, 2)
|
| 67 |
+
projected_value = projected_value.transpose(1, 2)
|
| 68 |
+
|
| 69 |
+
# Update cache
|
| 70 |
+
projected_cache.update(projected_key, projected_value, layer_idx)
|
| 71 |
+
|
| 72 |
+
return projected_cache
|
| 73 |
+
|
| 74 |
+
@register_model
|
| 75 |
+
@capture_init_args
|
| 76 |
+
class TrivialProjector(Projector):
|
| 77 |
+
"""
|
| 78 |
+
Trivial projector that directly outputs the target key-value pairs without any modification.
|
| 79 |
+
This is useful as a baseline or when you want to effectively disable projection.
|
| 80 |
+
"""
|
| 81 |
+
|
| 82 |
+
def __init__(self, **kwargs):
|
| 83 |
+
"""
|
| 84 |
+
Initialize the trivial projector.
|
| 85 |
+
|
| 86 |
+
Args:
|
| 87 |
+
source_dim: Source dimension (ignored, kept for compatibility)
|
| 88 |
+
target_dim: Target dimension (ignored, kept for compatibility)
|
| 89 |
+
**kwargs: Additional arguments (ignored, kept for compatibility)
|
| 90 |
+
"""
|
| 91 |
+
super().__init__()
|
| 92 |
+
|
| 93 |
+
def forward(self, source_kv: Tuple[Tensor, Tensor], target_kv: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
|
| 94 |
+
"""
|
| 95 |
+
Return the target key-value pairs unchanged, ignoring the source.
|
| 96 |
+
|
| 97 |
+
Args:
|
| 98 |
+
source_kv: Tuple of (key, value) tensors (ignored)
|
| 99 |
+
target_kv: Tuple of (key, value) tensors to return unchanged
|
| 100 |
+
Returns:
|
| 101 |
+
The target key-value pairs unchanged
|
| 102 |
+
"""
|
| 103 |
+
return target_kv
|
| 104 |
+
|
| 105 |
+
@register_model
|
| 106 |
+
@capture_init_args
|
| 107 |
+
class ReplaceProjector(Projector):
|
| 108 |
+
"""
|
| 109 |
+
Replacement projector that projects source key-value tensors to target dimension using MLP,
|
| 110 |
+
then replace target tensors using learnable weights.
|
| 111 |
+
"""
|
| 112 |
+
def __init__(
|
| 113 |
+
self,
|
| 114 |
+
source_dim: int,
|
| 115 |
+
target_dim: int,
|
| 116 |
+
hidden_dim: int = 512,
|
| 117 |
+
num_layers: int = 2,
|
| 118 |
+
dropout: float = 0.1,
|
| 119 |
+
activation: str = "gelu",
|
| 120 |
+
use_layer_norm: bool = True,
|
| 121 |
+
init_weight: float = 0.1,
|
| 122 |
+
anneal_steps: int = 1360,
|
| 123 |
+
initial_temperature: float = 1.0,
|
| 124 |
+
final_temperature: float = 0.01,
|
| 125 |
+
scalar_temperature: float = 0.005,
|
| 126 |
+
# shared_key_projection: nn.Module = None,
|
| 127 |
+
# shared_value_projection: nn.Module = None,
|
| 128 |
+
dtype: torch.dtype = torch.float32
|
| 129 |
+
):
|
| 130 |
+
super().__init__()
|
| 131 |
+
|
| 132 |
+
self.source_dim = source_dim
|
| 133 |
+
self.target_dim = target_dim
|
| 134 |
+
self.hidden_dim = hidden_dim
|
| 135 |
+
self.num_layers = num_layers
|
| 136 |
+
|
| 137 |
+
# Activation function
|
| 138 |
+
if activation.lower() == "gelu":
|
| 139 |
+
self.activation = nn.GELU()
|
| 140 |
+
elif activation.lower() == "relu":
|
| 141 |
+
self.activation = nn.ReLU()
|
| 142 |
+
elif activation.lower() == "silu":
|
| 143 |
+
self.activation = nn.SiLU()
|
| 144 |
+
else:
|
| 145 |
+
raise ValueError(f"Unsupported activation: {activation}")
|
| 146 |
+
|
| 147 |
+
# Build separate MLP layers for key and value projection
|
| 148 |
+
self.key_projection = self._build_mlp(source_dim, hidden_dim, target_dim, num_layers, use_layer_norm, dropout, dtype)
|
| 149 |
+
self.value_projection = self._build_mlp(source_dim, hidden_dim, target_dim, num_layers, use_layer_norm, dropout, dtype)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def _build_mlp(self, source_dim: int, hidden_dim: int, target_dim: int, num_layers: int,
|
| 153 |
+
use_layer_norm: bool, dropout: float, dtype: torch.dtype) -> nn.Sequential:
|
| 154 |
+
"""Build a single MLP projection module"""
|
| 155 |
+
layers = []
|
| 156 |
+
|
| 157 |
+
# Input projection
|
| 158 |
+
layers.append(nn.Linear(source_dim, hidden_dim, dtype=dtype))
|
| 159 |
+
if use_layer_norm:
|
| 160 |
+
layers.append(nn.LayerNorm(hidden_dim, dtype=dtype))
|
| 161 |
+
layers.append(copy.deepcopy(self.activation))
|
| 162 |
+
layers.append(nn.Dropout(dropout))
|
| 163 |
+
|
| 164 |
+
# Hidden layers
|
| 165 |
+
for _ in range(num_layers - 2):
|
| 166 |
+
layers.append(nn.Linear(hidden_dim, hidden_dim, dtype=dtype))
|
| 167 |
+
if use_layer_norm:
|
| 168 |
+
layers.append(nn.LayerNorm(hidden_dim, dtype=dtype))
|
| 169 |
+
layers.append(copy.deepcopy(self.activation))
|
| 170 |
+
layers.append(nn.Dropout(dropout))
|
| 171 |
+
|
| 172 |
+
# Output projection
|
| 173 |
+
if num_layers > 1:
|
| 174 |
+
layers.append(nn.Linear(hidden_dim, target_dim, dtype=dtype))
|
| 175 |
+
else:
|
| 176 |
+
# Single layer case
|
| 177 |
+
layers = [nn.Linear(source_dim, target_dim, dtype=dtype)]
|
| 178 |
+
|
| 179 |
+
return nn.Sequential(*layers)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def forward(self, source_kv: Tuple[Tensor, Tensor], target_kv: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
|
| 183 |
+
"""
|
| 184 |
+
Project source key-value tensors to target dimension and add to target tensors with learnable weights
|
| 185 |
+
Args:
|
| 186 |
+
source_kv: Tuple of (key, value) tensors, each (..., D_s) where ... are arbitrary leading dimensions
|
| 187 |
+
target_kv: Tuple of (key, value) tensors, each (..., D_t) where ... are arbitrary leading dimensions
|
| 188 |
+
Returns:
|
| 189 |
+
Tuple of (key, value) tensors, each (..., D_t) with same leading dimensions as input
|
| 190 |
+
"""
|
| 191 |
+
|
| 192 |
+
source_key, source_value = source_kv
|
| 193 |
+
target_key, target_value = target_kv
|
| 194 |
+
|
| 195 |
+
# Get shapes - assuming format is (B, H, N, D) where H is num_heads, N is seq_len, D is head_dim
|
| 196 |
+
source_shape = source_key.shape # (B, H_s, N, D_s)
|
| 197 |
+
target_shape = target_key.shape # (B, H_t, N, D_t)
|
| 198 |
+
|
| 199 |
+
# Extract dimensions
|
| 200 |
+
batch_size, source_num_heads, seq_len, source_head_dim = source_shape
|
| 201 |
+
_, target_num_heads, _, target_head_dim = target_shape
|
| 202 |
+
|
| 203 |
+
# Reshape source: merge num_heads and head_dim for projection
|
| 204 |
+
# (B, H_s, N, D_s) -> (B, N, H_s * D_s)
|
| 205 |
+
source_key_reshaped = source_key.transpose(1, 2) # (B, N, H_s, D_s)
|
| 206 |
+
source_value_reshaped = source_value.transpose(1, 2) # (B, N, H_s, D_s)
|
| 207 |
+
|
| 208 |
+
source_key_flat = source_key_reshaped.contiguous().view(batch_size, seq_len, source_num_heads * source_head_dim)
|
| 209 |
+
source_value_flat = source_value_reshaped.contiguous().view(batch_size, seq_len, source_num_heads * source_head_dim)
|
| 210 |
+
|
| 211 |
+
# Project source tensors from (H_s * D_s) to (H_t * D_t)
|
| 212 |
+
projected_key_flat = self.key_projection(source_key_flat) # (B, N, H_t * D_t)
|
| 213 |
+
projected_value_flat = self.value_projection(source_value_flat) # (B, N, H_t * D_t)
|
| 214 |
+
|
| 215 |
+
# Reshape projected tensors back to target format
|
| 216 |
+
# (B, N, H_t * D_t) -> (B, N, H_t, D_t) -> (B, H_t, N, D_t)
|
| 217 |
+
projected_key_reshaped = projected_key_flat.view(batch_size, seq_len, target_num_heads, target_head_dim)
|
| 218 |
+
projected_value_reshaped = projected_value_flat.view(batch_size, seq_len, target_num_heads, target_head_dim)
|
| 219 |
+
|
| 220 |
+
projected_key = projected_key_reshaped.transpose(1, 2) # (B, H_t, N, D_t)
|
| 221 |
+
projected_value = projected_value_reshaped.transpose(1, 2) # (B, H_t, N, D_t)
|
| 222 |
+
|
| 223 |
+
return (projected_key, projected_value)
|
| 224 |
+
|
| 225 |
+
class ModernMLP(nn.Module):
|
| 226 |
+
"""
|
| 227 |
+
Modern MLP with residual connections, layer normalization, and configurable architecture.
|
| 228 |
+
"""
|
| 229 |
+
|
| 230 |
+
def __init__(
|
| 231 |
+
self,
|
| 232 |
+
input_dim: int,
|
| 233 |
+
output_dim: int,
|
| 234 |
+
hidden_dim: int = 512,
|
| 235 |
+
num_layers: int = 2,
|
| 236 |
+
activation: str = "gelu",
|
| 237 |
+
use_layer_norm: bool = True,
|
| 238 |
+
use_residual: bool = True,
|
| 239 |
+
dropout: float = 0.1,
|
| 240 |
+
use_swiglu: bool = False,
|
| 241 |
+
dtype: torch.dtype = torch.float32
|
| 242 |
+
):
|
| 243 |
+
super().__init__()
|
| 244 |
+
|
| 245 |
+
self.input_dim = input_dim
|
| 246 |
+
self.output_dim = output_dim
|
| 247 |
+
self.hidden_dim = hidden_dim
|
| 248 |
+
self.num_layers = num_layers
|
| 249 |
+
self.use_residual = use_residual and (input_dim == output_dim)
|
| 250 |
+
self.use_swiglu = use_swiglu
|
| 251 |
+
|
| 252 |
+
# Activation function
|
| 253 |
+
if activation.lower() == "gelu":
|
| 254 |
+
self.activation = nn.GELU()
|
| 255 |
+
elif activation.lower() == "relu":
|
| 256 |
+
self.activation = nn.ReLU()
|
| 257 |
+
elif activation.lower() == "silu":
|
| 258 |
+
self.activation = nn.SiLU()
|
| 259 |
+
else:
|
| 260 |
+
raise ValueError(f"Unsupported activation: {activation}")
|
| 261 |
+
|
| 262 |
+
# Build layers
|
| 263 |
+
self.layers = nn.ModuleList()
|
| 264 |
+
|
| 265 |
+
for i in range(num_layers):
|
| 266 |
+
layer_input_dim = input_dim if i == 0 else hidden_dim
|
| 267 |
+
layer_output_dim = output_dim if i == num_layers - 1 else hidden_dim
|
| 268 |
+
|
| 269 |
+
if self.use_swiglu and i < num_layers - 1: # Don't use SwiGLU on output layer
|
| 270 |
+
layer = SwiGLUBlock(layer_input_dim, layer_output_dim, dtype=dtype)
|
| 271 |
+
else:
|
| 272 |
+
layer = nn.Linear(layer_input_dim, layer_output_dim, dtype=dtype)
|
| 273 |
+
|
| 274 |
+
self.layers.append(layer)
|
| 275 |
+
|
| 276 |
+
# Add layer norm after each layer except the last one
|
| 277 |
+
if use_layer_norm and i < num_layers - 1:
|
| 278 |
+
self.layers.append(nn.LayerNorm(layer_output_dim, dtype=dtype))
|
| 279 |
+
|
| 280 |
+
# Add activation after each layer except the last one
|
| 281 |
+
if i < num_layers - 1 and not self.use_swiglu:
|
| 282 |
+
self.layers.append(copy.deepcopy(self.activation))
|
| 283 |
+
|
| 284 |
+
# Add dropout after activation
|
| 285 |
+
if dropout > 0 and i < num_layers - 1:
|
| 286 |
+
self.layers.append(nn.Dropout(dropout))
|
| 287 |
+
|
| 288 |
+
# Residual projection if dimensions don't match
|
| 289 |
+
if self.use_residual and input_dim != output_dim:
|
| 290 |
+
self.residual_proj = nn.Linear(input_dim, output_dim, dtype=dtype)
|
| 291 |
+
else:
|
| 292 |
+
self.residual_proj = None
|
| 293 |
+
|
| 294 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 295 |
+
"""Forward pass with optional residual connection."""
|
| 296 |
+
residual = x
|
| 297 |
+
|
| 298 |
+
for layer in self.layers:
|
| 299 |
+
x = layer(x)
|
| 300 |
+
|
| 301 |
+
# Add residual connection
|
| 302 |
+
if self.use_residual:
|
| 303 |
+
if self.residual_proj is not None:
|
| 304 |
+
residual = self.residual_proj(residual)
|
| 305 |
+
x = x + residual
|
| 306 |
+
|
| 307 |
+
return x
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
class SwiGLUBlock(nn.Module):
|
| 311 |
+
"""SwiGLU activation block for modern transformer architectures."""
|
| 312 |
+
|
| 313 |
+
def __init__(self, input_dim: int, output_dim: int, dtype: torch.dtype = torch.float32):
|
| 314 |
+
super().__init__()
|
| 315 |
+
self.gate_proj = nn.Linear(input_dim, output_dim, dtype=dtype)
|
| 316 |
+
self.up_proj = nn.Linear(input_dim, output_dim, dtype=dtype)
|
| 317 |
+
self.activation = nn.SiLU()
|
| 318 |
+
|
| 319 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 320 |
+
gate = self.activation(self.gate_proj(x))
|
| 321 |
+
up = self.up_proj(x)
|
| 322 |
+
return gate * up
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
@register_model
|
| 326 |
+
@capture_init_args
|
| 327 |
+
class AllInOneProjector(Projector):
|
| 328 |
+
"""
|
| 329 |
+
Unified projector that consolidates all projection functionalities with modern patterns.
|
| 330 |
+
|
| 331 |
+
Features:
|
| 332 |
+
1. Gate logit granularity: scalar, token-wise, head-wise, or value-wise
|
| 333 |
+
2. (DISABLED) Selector logit granularity: scalar, token-wise, head-wise, or value-wise
|
| 334 |
+
3. Key/Value weight granularity: scalar, token-wise, head-wise, or value-wise
|
| 335 |
+
4. Input-dependent gates and weights via MLP or parameters
|
| 336 |
+
5. Optional concatenation with combiner networks
|
| 337 |
+
6. Modern MLP architecture with residual connections and SwiGLU
|
| 338 |
+
7. Configurable target preservation: choose between traditional blending or simplified projection
|
| 339 |
+
8. Optional adding of target (self) signal to outputs via add_self
|
| 340 |
+
|
| 341 |
+
Target Preservation Modes:
|
| 342 |
+
- preserve_target_weight=True (default): output = (1-weight)*target + gate*selector*weight*projected
|
| 343 |
+
- preserve_target_weight=False: output = target + gate*selector*weight*projected (no weight coefficient on target)
|
| 344 |
+
|
| 345 |
+
Note: Selector functionality has been disabled/commented out.
|
| 346 |
+
"""
|
| 347 |
+
|
| 348 |
+
def __init__(
|
| 349 |
+
self,
|
| 350 |
+
source_dim: int,
|
| 351 |
+
target_dim: int,
|
| 352 |
+
source_num_heads: int = 1,
|
| 353 |
+
target_num_heads: int = 1,
|
| 354 |
+
hidden_dim: int = 512,
|
| 355 |
+
num_layers: int = 2,
|
| 356 |
+
dropout: float = 0.1,
|
| 357 |
+
activation: str = "gelu",
|
| 358 |
+
use_layer_norm: bool = True,
|
| 359 |
+
use_residual: bool = True,
|
| 360 |
+
use_swiglu: bool = False,
|
| 361 |
+
|
| 362 |
+
# Gate configuration
|
| 363 |
+
gate_granularity: Literal["scalar", "token", "head", "head_merged", "value"] = "scalar",
|
| 364 |
+
gate_depends_on_input: bool = False,
|
| 365 |
+
gate_input_features: Optional[str] = "target_key", # "target_key", "target_value", "both", "target_projected_key", "target_projected_value", "target_projected_both"
|
| 366 |
+
gate_init_value: float = 0.0,
|
| 367 |
+
|
| 368 |
+
# Weight configuration
|
| 369 |
+
weight_granularity: Literal["scalar", "token", "head", "head_merged", "value"] = "scalar",
|
| 370 |
+
weight_depends_on_input: bool = False,
|
| 371 |
+
weight_input_features: Optional[str] = "target_key", # "target_key", "target_value", "both", "target_projected_key", "target_projected_value", "target_projected_both"
|
| 372 |
+
weight_init_value: float = 0.0,
|
| 373 |
+
|
| 374 |
+
# Target preservation configuration
|
| 375 |
+
preserve_target_weight: bool = True, # If False, target won't be multiplied by (1 - normalized_weight)
|
| 376 |
+
add_self: bool = True, # If False, target (self) won't be added to outputs
|
| 377 |
+
|
| 378 |
+
# Concat configuration
|
| 379 |
+
use_concat: bool = False,
|
| 380 |
+
# combiner_hidden_dim: int = 128,
|
| 381 |
+
weight_hidden_dim: int = 1024,
|
| 382 |
+
|
| 383 |
+
# Temperature and gumbel
|
| 384 |
+
use_gumbel: bool = True,
|
| 385 |
+
initial_temperature: float = 1.0,
|
| 386 |
+
final_temperature: float = 0.01,
|
| 387 |
+
anneal_steps: int = 1360,
|
| 388 |
+
scalar_temperature: float = 0.005,
|
| 389 |
+
|
| 390 |
+
# Sequence length configuration
|
| 391 |
+
max_sequence_length: int = 8192, # Maximum sequence length for token-level parameters
|
| 392 |
+
|
| 393 |
+
pos_emb: bool = False,
|
| 394 |
+
|
| 395 |
+
dtype: torch.dtype = torch.float32
|
| 396 |
+
):
|
| 397 |
+
super().__init__()
|
| 398 |
+
|
| 399 |
+
self.source_dim = source_dim
|
| 400 |
+
self.target_dim = target_dim
|
| 401 |
+
self.source_num_heads = source_num_heads
|
| 402 |
+
self.target_num_heads = target_num_heads
|
| 403 |
+
self.hidden_dim = hidden_dim
|
| 404 |
+
self.weight_hidden_dim = weight_hidden_dim
|
| 405 |
+
self.max_sequence_length = max_sequence_length
|
| 406 |
+
|
| 407 |
+
# Configuration
|
| 408 |
+
self.gate_granularity = gate_granularity
|
| 409 |
+
self.gate_depends_on_input = gate_depends_on_input
|
| 410 |
+
self.gate_input_features = gate_input_features
|
| 411 |
+
self.weight_granularity = weight_granularity
|
| 412 |
+
self.weight_depends_on_input = weight_depends_on_input
|
| 413 |
+
self.weight_input_features = weight_input_features
|
| 414 |
+
self.preserve_target_weight = preserve_target_weight
|
| 415 |
+
self.add_self = add_self
|
| 416 |
+
self.use_concat = use_concat
|
| 417 |
+
self.use_gumbel = use_gumbel
|
| 418 |
+
self.scalar_temperature = scalar_temperature
|
| 419 |
+
|
| 420 |
+
# Temperature annealing for gate only (removed selector temperature)
|
| 421 |
+
self.register_buffer("gate_temperature", torch.tensor(initial_temperature, dtype=dtype))
|
| 422 |
+
self.initial_temperature = initial_temperature
|
| 423 |
+
self.final_temperature = final_temperature
|
| 424 |
+
self.anneal_steps = anneal_steps
|
| 425 |
+
|
| 426 |
+
# Build projection networks
|
| 427 |
+
self.key_projection = self._build_projection_mlp(
|
| 428 |
+
source_dim * source_num_heads,
|
| 429 |
+
target_dim * target_num_heads,
|
| 430 |
+
hidden_dim, num_layers, activation, use_layer_norm,
|
| 431 |
+
use_residual, dropout, use_swiglu, dtype
|
| 432 |
+
)
|
| 433 |
+
self.value_projection = self._build_projection_mlp(
|
| 434 |
+
source_dim * source_num_heads,
|
| 435 |
+
target_dim * target_num_heads,
|
| 436 |
+
hidden_dim, num_layers, activation, use_layer_norm,
|
| 437 |
+
use_residual, dropout, use_swiglu, dtype
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
# Build gate components
|
| 441 |
+
self._build_gate_components(dtype)
|
| 442 |
+
|
| 443 |
+
# Build weight components
|
| 444 |
+
self._build_weight_components(weight_init_value, dtype)
|
| 445 |
+
|
| 446 |
+
# Build concat components if needed
|
| 447 |
+
if self.use_concat:
|
| 448 |
+
in_dim = target_dim * target_num_heads * 2
|
| 449 |
+
out_dim = target_dim * target_num_heads
|
| 450 |
+
self.key_combiner = nn.Linear(in_dim, out_dim, dtype=dtype)
|
| 451 |
+
self.value_combiner = nn.Linear(in_dim, out_dim, dtype=dtype)
|
| 452 |
+
|
| 453 |
+
def _build_projection_mlp(
|
| 454 |
+
self, input_dim: int, output_dim: int, hidden_dim: int,
|
| 455 |
+
num_layers: int, activation: str, use_layer_norm: bool,
|
| 456 |
+
use_residual: bool, dropout: float, use_swiglu: bool, dtype: torch.dtype
|
| 457 |
+
) -> ModernMLP:
|
| 458 |
+
"""Build modern MLP for projection."""
|
| 459 |
+
return ModernMLP(
|
| 460 |
+
input_dim=input_dim,
|
| 461 |
+
output_dim=output_dim,
|
| 462 |
+
hidden_dim=hidden_dim,
|
| 463 |
+
num_layers=num_layers,
|
| 464 |
+
activation=activation,
|
| 465 |
+
use_layer_norm=use_layer_norm,
|
| 466 |
+
use_residual=use_residual,
|
| 467 |
+
dropout=dropout,
|
| 468 |
+
use_swiglu=use_swiglu,
|
| 469 |
+
dtype=dtype
|
| 470 |
+
)
|
| 471 |
+
|
| 472 |
+
def _build_gate_components(self, dtype: torch.dtype):
|
| 473 |
+
"""Build gate logit components based on configuration."""
|
| 474 |
+
if not self.gate_depends_on_input:
|
| 475 |
+
# Parameter-based gate
|
| 476 |
+
gate_shape = self._get_parameter_shape(self.gate_granularity)
|
| 477 |
+
self.gate_logit = nn.Parameter(torch.zeros(gate_shape, dtype=dtype))
|
| 478 |
+
else:
|
| 479 |
+
# Input-dependent gate via MLP
|
| 480 |
+
input_dim = self._get_gate_input_dim()
|
| 481 |
+
output_dim = self._get_gate_output_dim()
|
| 482 |
+
|
| 483 |
+
self.gate_generator = ModernMLP(
|
| 484 |
+
input_dim=input_dim,
|
| 485 |
+
output_dim=output_dim,
|
| 486 |
+
hidden_dim=self.hidden_dim,
|
| 487 |
+
num_layers=2,
|
| 488 |
+
activation="gelu",
|
| 489 |
+
use_layer_norm=True,
|
| 490 |
+
use_residual=False,
|
| 491 |
+
dropout=0.1,
|
| 492 |
+
dtype=dtype
|
| 493 |
+
)
|
| 494 |
+
|
| 495 |
+
def _build_weight_components(self, weight_init_value: float, dtype: torch.dtype):
|
| 496 |
+
"""Build weight components based on configuration."""
|
| 497 |
+
if not self.weight_depends_on_input:
|
| 498 |
+
# Parameter-based weights
|
| 499 |
+
weight_shape = self._get_parameter_shape(self.weight_granularity)
|
| 500 |
+
self.key_weight = nn.Parameter(torch.full(weight_shape, weight_init_value, dtype=dtype))
|
| 501 |
+
self.value_weight = nn.Parameter(torch.full(weight_shape, weight_init_value, dtype=dtype))
|
| 502 |
+
else:
|
| 503 |
+
# Input-dependent weights via MLP
|
| 504 |
+
input_dim = self._get_weight_input_dim()
|
| 505 |
+
output_dim = self._get_weight_output_dim()
|
| 506 |
+
|
| 507 |
+
# Shared hidden layer for efficiency
|
| 508 |
+
self.weight_hidden = ModernMLP(
|
| 509 |
+
input_dim=input_dim,
|
| 510 |
+
output_dim=self.weight_hidden_dim,
|
| 511 |
+
hidden_dim=self.weight_hidden_dim,
|
| 512 |
+
num_layers=2,
|
| 513 |
+
activation="gelu",
|
| 514 |
+
use_layer_norm=True,
|
| 515 |
+
use_residual=False,
|
| 516 |
+
dropout=0.1,
|
| 517 |
+
dtype=dtype
|
| 518 |
+
)
|
| 519 |
+
|
| 520 |
+
# Separate heads for key and value weights
|
| 521 |
+
self.key_weight_head = nn.Linear(self.weight_hidden_dim, output_dim, dtype=dtype)
|
| 522 |
+
self.value_weight_head = nn.Linear(self.weight_hidden_dim, output_dim, dtype=dtype)
|
| 523 |
+
|
| 524 |
+
def _get_parameter_shape(self, granularity: str) -> tuple:
|
| 525 |
+
"""Get parameter shape based on granularity."""
|
| 526 |
+
if granularity == "scalar":
|
| 527 |
+
return () # Scalar
|
| 528 |
+
elif granularity == "token":
|
| 529 |
+
return (self.max_sequence_length,) # Token-level parameters with max sequence length
|
| 530 |
+
elif granularity == "head":
|
| 531 |
+
return (self.max_sequence_length, self.target_num_heads) # Token and head level parameters
|
| 532 |
+
elif granularity == "head_merged":
|
| 533 |
+
return (self.max_sequence_length, self.target_num_heads) # Token and head level parameters
|
| 534 |
+
elif granularity == "value":
|
| 535 |
+
return (self.max_sequence_length, self.target_num_heads, self.target_dim) # Token, head and value level parameters
|
| 536 |
+
else:
|
| 537 |
+
raise ValueError(f"Invalid granularity: {granularity}")
|
| 538 |
+
|
| 539 |
+
def _get_gate_input_dim(self) -> int:
|
| 540 |
+
"""Get input dimension for gate generator."""
|
| 541 |
+
base_dim = 0
|
| 542 |
+
if self.gate_input_features == "target_key":
|
| 543 |
+
base_dim = self.target_dim
|
| 544 |
+
elif self.gate_input_features == "target_value":
|
| 545 |
+
base_dim = self.target_dim
|
| 546 |
+
elif self.gate_input_features == "both":
|
| 547 |
+
base_dim = self.target_dim * 2
|
| 548 |
+
elif self.gate_input_features == "target_projected_key":
|
| 549 |
+
base_dim = self.target_dim * 2 # target_key + projected_key
|
| 550 |
+
elif self.gate_input_features == "target_projected_value":
|
| 551 |
+
base_dim = self.target_dim * 2 # target_value + projected_value
|
| 552 |
+
elif self.gate_input_features == "target_projected_both":
|
| 553 |
+
base_dim = self.target_dim * 4 # target_key + target_value + projected_key + projected_value
|
| 554 |
+
else:
|
| 555 |
+
raise ValueError(f"Invalid gate input features: {self.gate_input_features}")
|
| 556 |
+
|
| 557 |
+
# Adjust for granularity processing strategy
|
| 558 |
+
if self.gate_granularity == "scalar":
|
| 559 |
+
# Scalar: process aggregated features across all heads
|
| 560 |
+
return base_dim # Use pooled features
|
| 561 |
+
elif self.gate_granularity == "token":
|
| 562 |
+
# Token: process merged head dimensions
|
| 563 |
+
return base_dim * self.target_num_heads # Flatten (H, D) to (H*D)
|
| 564 |
+
elif self.gate_granularity == "head_merged":
|
| 565 |
+
# Head-merged: similar to token granularity, merge H and D
|
| 566 |
+
return base_dim * self.target_num_heads # (B, N, H*D)
|
| 567 |
+
elif self.gate_granularity == "head":
|
| 568 |
+
# Head-local: per head processing, do not merge heads
|
| 569 |
+
return base_dim # (B, H, N, D)
|
| 570 |
+
else: # value
|
| 571 |
+
# Value: process per-head features
|
| 572 |
+
return base_dim # Keep per-head processing (B, H, N, D)
|
| 573 |
+
|
| 574 |
+
def _get_gate_output_dim(self) -> int:
|
| 575 |
+
"""Get output dimension for gate generator."""
|
| 576 |
+
if self.gate_granularity == "scalar":
|
| 577 |
+
return 1
|
| 578 |
+
elif self.gate_granularity == "token":
|
| 579 |
+
return 1 # Per token
|
| 580 |
+
elif self.gate_granularity == "head_merged":
|
| 581 |
+
# Per token per head after merge: output one value per head
|
| 582 |
+
return self.target_num_heads
|
| 583 |
+
elif self.gate_granularity == "head":
|
| 584 |
+
# Per token per head: scalar per head
|
| 585 |
+
return 1
|
| 586 |
+
elif self.gate_granularity == "value":
|
| 587 |
+
return self.target_dim # Per token per head per value (but processed per-head, so output D per head)
|
| 588 |
+
else:
|
| 589 |
+
raise ValueError(f"Invalid gate granularity: {self.gate_granularity}")
|
| 590 |
+
|
| 591 |
+
def _get_weight_input_dim(self) -> int:
|
| 592 |
+
"""Get input dimension for weight generator."""
|
| 593 |
+
base_dim = 0
|
| 594 |
+
if self.weight_input_features == "target_key":
|
| 595 |
+
base_dim = self.target_dim
|
| 596 |
+
elif self.weight_input_features == "target_value":
|
| 597 |
+
base_dim = self.target_dim
|
| 598 |
+
elif self.weight_input_features == "both":
|
| 599 |
+
base_dim = self.target_dim * 2
|
| 600 |
+
elif self.weight_input_features == "target_projected_key":
|
| 601 |
+
base_dim = self.target_dim * 2 # target_key + projected_key
|
| 602 |
+
elif self.weight_input_features == "target_projected_value":
|
| 603 |
+
base_dim = self.target_dim * 2 # target_value + projected_value
|
| 604 |
+
elif self.weight_input_features == "target_projected_both":
|
| 605 |
+
base_dim = self.target_dim * 4 # target_key + target_value + projected_key + projected_value
|
| 606 |
+
else:
|
| 607 |
+
raise ValueError(f"Invalid weight input features: {self.weight_input_features}")
|
| 608 |
+
|
| 609 |
+
# Adjust for granularity processing strategy
|
| 610 |
+
if self.weight_granularity == "scalar":
|
| 611 |
+
# Scalar: process aggregated features across all heads
|
| 612 |
+
return base_dim # Use pooled features
|
| 613 |
+
elif self.weight_granularity == "token":
|
| 614 |
+
# Token: process merged head dimensions
|
| 615 |
+
return base_dim * self.target_num_heads # Flatten (H, D) to (H*D)
|
| 616 |
+
elif self.weight_granularity == "head_merged":
|
| 617 |
+
# Head-merged: similar to token granularity, merge H and D
|
| 618 |
+
return base_dim * self.target_num_heads # (B, N, H*D)
|
| 619 |
+
elif self.weight_granularity == "head":
|
| 620 |
+
# Head-local: per head processing, do not merge heads
|
| 621 |
+
return base_dim # (B, H, N, D)
|
| 622 |
+
else: # value
|
| 623 |
+
# Value: process per-head features
|
| 624 |
+
return base_dim # Keep per-head processing (B, H, N, D)
|
| 625 |
+
|
| 626 |
+
def _get_weight_output_dim(self) -> int:
|
| 627 |
+
"""Get output dimension for weight generator."""
|
| 628 |
+
if self.weight_granularity == "scalar":
|
| 629 |
+
return 1
|
| 630 |
+
elif self.weight_granularity == "token":
|
| 631 |
+
return 1 # Per token
|
| 632 |
+
elif self.weight_granularity == "head_merged":
|
| 633 |
+
# Per token per head after merge: output one value per head
|
| 634 |
+
return self.target_num_heads
|
| 635 |
+
elif self.weight_granularity == "head":
|
| 636 |
+
# Per token per head: scalar per head
|
| 637 |
+
return 1
|
| 638 |
+
elif self.weight_granularity == "value":
|
| 639 |
+
return self.target_dim # Per token per head per value (but processed per-head, so output D per head)
|
| 640 |
+
else:
|
| 641 |
+
raise ValueError(f"Invalid weight granularity: {self.weight_granularity}")
|
| 642 |
+
|
| 643 |
+
def _generate_gates(self, target_key: Tensor, target_value: Tensor, projected_key: Tensor = None, projected_value: Tensor = None) -> Tensor:
|
| 644 |
+
"""Generate gate logits based on configuration."""
|
| 645 |
+
if not self.gate_depends_on_input:
|
| 646 |
+
# Use parameter-based gate
|
| 647 |
+
return self.gate_logit
|
| 648 |
+
else:
|
| 649 |
+
# Generate input-dependent gate
|
| 650 |
+
# First, prepare the base input features
|
| 651 |
+
if self.gate_input_features == "target_key":
|
| 652 |
+
base_input = target_key
|
| 653 |
+
elif self.gate_input_features == "target_value":
|
| 654 |
+
base_input = target_value
|
| 655 |
+
elif self.gate_input_features == "both":
|
| 656 |
+
base_input = torch.cat([target_key, target_value], dim=-1)
|
| 657 |
+
elif self.gate_input_features == "target_projected_key":
|
| 658 |
+
if projected_key is None:
|
| 659 |
+
raise ValueError("projected_key is required for target_projected_key input features")
|
| 660 |
+
base_input = torch.cat([target_key, projected_key], dim=-1)
|
| 661 |
+
elif self.gate_input_features == "target_projected_value":
|
| 662 |
+
if projected_value is None:
|
| 663 |
+
raise ValueError("projected_value is required for target_projected_value input features")
|
| 664 |
+
base_input = torch.cat([target_value, projected_value], dim=-1)
|
| 665 |
+
elif self.gate_input_features == "target_projected_both":
|
| 666 |
+
if projected_key is None or projected_value is None:
|
| 667 |
+
raise ValueError("Both projected_key and projected_value are required for target_projected_both input features")
|
| 668 |
+
base_input = torch.cat([target_key, target_value, projected_key, projected_value], dim=-1)
|
| 669 |
+
|
| 670 |
+
# Now process based on granularity
|
| 671 |
+
# base_input shape: (B, H, N, D_input)
|
| 672 |
+
B, H, N, D_input = base_input.shape
|
| 673 |
+
|
| 674 |
+
if self.gate_granularity == "scalar":
|
| 675 |
+
# For scalar granularity, aggregate all dimensions: (B, H, N, D_input) -> (B, D_input)
|
| 676 |
+
gate_input = base_input.mean(dim=(1, 2)) # Average over heads and tokens
|
| 677 |
+
elif self.gate_granularity == "token":
|
| 678 |
+
# For token granularity, merge H and D_input dimensions: (B, H, N, D_input) -> (B, N, H*D_input)
|
| 679 |
+
gate_input = base_input.transpose(1, 2).contiguous().view(B, N, H * D_input)
|
| 680 |
+
elif self.gate_granularity == "head_merged":
|
| 681 |
+
# For head granularity, merge H and D like token: (B, H, N, D_in) -> (B, N, H*D_in)
|
| 682 |
+
gate_input = base_input.transpose(1, 2).contiguous().view(B, N, H * D_input)
|
| 683 |
+
elif self.gate_granularity == "head":
|
| 684 |
+
# For head granularity, keep per-head processing: (B, H, N, D_input)
|
| 685 |
+
gate_input = base_input
|
| 686 |
+
elif self.gate_granularity == "value":
|
| 687 |
+
# For value granularity, keep per-head processing: (B, H, N, D_input)
|
| 688 |
+
gate_input = base_input
|
| 689 |
+
|
| 690 |
+
return self.gate_generator(gate_input)
|
| 691 |
+
|
| 692 |
+
def _generate_weights(self, target_key: Tensor, target_value: Tensor, projected_key: Tensor = None, projected_value: Tensor = None) -> Tuple[Tensor, Tensor]:
|
| 693 |
+
"""Generate weights based on configuration."""
|
| 694 |
+
if not self.weight_depends_on_input:
|
| 695 |
+
# Use parameter-based weights
|
| 696 |
+
return self.key_weight, self.value_weight
|
| 697 |
+
else:
|
| 698 |
+
# Generate input-dependent weights
|
| 699 |
+
# First, prepare the base input features
|
| 700 |
+
if self.weight_input_features == "target_key":
|
| 701 |
+
base_input = target_key
|
| 702 |
+
elif self.weight_input_features == "target_value":
|
| 703 |
+
base_input = target_value
|
| 704 |
+
elif self.weight_input_features == "both":
|
| 705 |
+
base_input = torch.cat([target_key, target_value], dim=-1)
|
| 706 |
+
elif self.weight_input_features == "target_projected_key":
|
| 707 |
+
if projected_key is None:
|
| 708 |
+
raise ValueError("projected_key is required for target_projected_key input features")
|
| 709 |
+
base_input = torch.cat([target_key, projected_key], dim=-1)
|
| 710 |
+
elif self.weight_input_features == "target_projected_value":
|
| 711 |
+
if projected_value is None:
|
| 712 |
+
raise ValueError("projected_value is required for target_projected_value input features")
|
| 713 |
+
base_input = torch.cat([target_value, projected_value], dim=-1)
|
| 714 |
+
elif self.weight_input_features == "target_projected_both":
|
| 715 |
+
if projected_key is None or projected_value is None:
|
| 716 |
+
raise ValueError("Both projected_key and projected_value are required for target_projected_both input features")
|
| 717 |
+
base_input = torch.cat([target_key, target_value, projected_key, projected_value], dim=-1)
|
| 718 |
+
|
| 719 |
+
# Now process based on granularity
|
| 720 |
+
# base_input shape: (B, H, N, D_input)
|
| 721 |
+
B, H, N, D_input = base_input.shape
|
| 722 |
+
|
| 723 |
+
if self.weight_granularity == "scalar":
|
| 724 |
+
# For scalar granularity, aggregate all dimensions: (B, H, N, D_input) -> (B, D_input)
|
| 725 |
+
weight_input = base_input.mean(dim=(1, 2)) # Average over heads and tokens
|
| 726 |
+
elif self.weight_granularity == "token":
|
| 727 |
+
# For token granularity, merge H and D_input dimensions: (B, H, N, D_input) -> (B, N, H*D_input)
|
| 728 |
+
weight_input = base_input.transpose(1, 2).contiguous().view(B, N, H * D_input)
|
| 729 |
+
elif self.weight_granularity == "head_merged":
|
| 730 |
+
# For head granularity, merge H and D like token: (B, H, N, D_in) -> (B, N, H*D_in)
|
| 731 |
+
weight_input = base_input.transpose(1, 2).contiguous().view(B, N, H * D_input)
|
| 732 |
+
elif self.weight_granularity == "head":
|
| 733 |
+
# For head granularity, keep per-head processing: (B, H, N, D_input)
|
| 734 |
+
weight_input = base_input
|
| 735 |
+
elif self.weight_granularity == "value":
|
| 736 |
+
# For value granularity, keep per-head processing: (B, H, N, D_input)
|
| 737 |
+
weight_input = base_input
|
| 738 |
+
|
| 739 |
+
weight_hidden = self.weight_hidden(weight_input)
|
| 740 |
+
key_weight = self.key_weight_head(weight_hidden)
|
| 741 |
+
value_weight = self.value_weight_head(weight_hidden)
|
| 742 |
+
|
| 743 |
+
return key_weight, value_weight
|
| 744 |
+
|
| 745 |
+
def _apply_gumbel_sigmoid(self, gate_logit: Tensor) -> Tensor:
|
| 746 |
+
"""Apply Gumbel sigmoid trick for training."""
|
| 747 |
+
if self.training and self.use_gumbel:
|
| 748 |
+
gumbel_noise = self._sample_gumbel(gate_logit.shape, gate_logit.device, gate_logit.dtype)
|
| 749 |
+
return torch.sigmoid((gate_logit + gumbel_noise) / self.gate_temperature)
|
| 750 |
+
else:
|
| 751 |
+
return (gate_logit > 0).float()
|
| 752 |
+
|
| 753 |
+
@staticmethod
|
| 754 |
+
def _sample_gumbel(shape: tuple, device: torch.device, dtype: torch.dtype, eps: float = 1e-20) -> Tensor:
|
| 755 |
+
"""Sample from Gumbel distribution."""
|
| 756 |
+
u = torch.rand(shape, device=device, dtype=dtype)
|
| 757 |
+
return -torch.log(-torch.log(u + eps) + eps)
|
| 758 |
+
|
| 759 |
+
def _reshape_for_granularity(self, tensor: Tensor, granularity: str, target_shape: tuple) -> Tensor:
|
| 760 |
+
"""Reshape tensor to match target shape based on granularity."""
|
| 761 |
+
B, H, N, D = target_shape
|
| 762 |
+
|
| 763 |
+
if granularity == "scalar":
|
| 764 |
+
# Scalar -> (B, H, N, D)
|
| 765 |
+
return tensor.view(1, 1, 1, 1).expand(B, H, N, D)
|
| 766 |
+
elif granularity == "token":
|
| 767 |
+
# (max_seq_len,) -> (B, H, N, D) - slice to actual sequence length
|
| 768 |
+
token_params = tensor[:N] # Take first N tokens
|
| 769 |
+
return token_params.view(1, 1, N, 1).expand(B, H, N, D)
|
| 770 |
+
elif granularity == "head":
|
| 771 |
+
# (max_seq_len, H) -> (B, H, N, D) - slice to actual sequence length, each token each head independent
|
| 772 |
+
head_params = tensor[:N, :] # Take first N tokens, all heads: (N, H)
|
| 773 |
+
return head_params.view(1, N, H, 1).transpose(1, 2).expand(B, H, N, D) # (1, N, H, 1) -> (1, H, N, 1) -> (B, H, N, D)
|
| 774 |
+
elif granularity == "head_merged":
|
| 775 |
+
raise NotImplementedError
|
| 776 |
+
elif granularity == "value":
|
| 777 |
+
# (max_seq_len, H, D) -> (B, H, N, D) - slice to actual sequence length, each token each head each value independent
|
| 778 |
+
value_params = tensor[:N, :, :] # Take first N tokens: (N, H, D)
|
| 779 |
+
return value_params.view(1, N, H, D).transpose(1, 2).expand(B, H, N, D) # (1, N, H, D) -> (1, H, N, D) -> (B, H, N, D)
|
| 780 |
+
else:
|
| 781 |
+
raise ValueError(f"Invalid granularity: {granularity}")
|
| 782 |
+
|
| 783 |
+
def update_temperature(self, step: int):
|
| 784 |
+
"""Update temperature using exponential annealing schedule for gate only."""
|
| 785 |
+
# Update gate temperature
|
| 786 |
+
gate_ratio = min(step / self.anneal_steps, 1.0)
|
| 787 |
+
gate_temp = self.initial_temperature * (self.final_temperature / self.initial_temperature) ** gate_ratio
|
| 788 |
+
self.gate_temperature.fill_(gate_temp)
|
| 789 |
+
|
| 790 |
+
|
| 791 |
+
def forward(self, source_kv: Tuple[Tensor, Tensor], target_kv: Tuple[Tensor, Tensor], position_ids: Optional[Tensor] = None, max_pos: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
|
| 792 |
+
"""
|
| 793 |
+
Forward pass with unified projection logic.
|
| 794 |
+
|
| 795 |
+
Args:
|
| 796 |
+
source_kv: Tuple of (key, value) tensors, each (B, H_s, N, D_s)
|
| 797 |
+
target_kv: Tuple of (key, value) tensors, each (B, H_t, N, D_t)
|
| 798 |
+
position_ids: Position ids tensor (B, N), optional, required if pos_emb=True
|
| 799 |
+
Returns:
|
| 800 |
+
Tuple of (key, value) tensors, each (B, H_t, N, D_t)
|
| 801 |
+
"""
|
| 802 |
+
source_key, source_value = source_kv
|
| 803 |
+
target_key, target_value = target_kv
|
| 804 |
+
|
| 805 |
+
# Get shapes
|
| 806 |
+
B, H_s, N, D_s = source_key.shape
|
| 807 |
+
_, H_t, _, D_t = target_key.shape
|
| 808 |
+
|
| 809 |
+
# Reshape for projection: (B, H, N, D) -> (B, N, H*D)
|
| 810 |
+
source_key_flat = source_key.transpose(1, 2).contiguous().view(B, N, H_s * D_s)
|
| 811 |
+
source_value_flat = source_value.transpose(1, 2).contiguous().view(B, N, H_s * D_s)
|
| 812 |
+
|
| 813 |
+
# Project source to target dimension
|
| 814 |
+
projected_key_flat = self.key_projection(source_key_flat) # (B, N, H_t * D_t)
|
| 815 |
+
projected_value_flat = self.value_projection(source_value_flat) # (B, N, H_t * D_t)
|
| 816 |
+
|
| 817 |
+
# Handle concatenation if enabled
|
| 818 |
+
if self.use_concat:
|
| 819 |
+
target_key_flat = target_key.transpose(1, 2).contiguous().view(B, N, H_t * D_t)
|
| 820 |
+
target_value_flat = target_value.transpose(1, 2).contiguous().view(B, N, H_t * D_t)
|
| 821 |
+
|
| 822 |
+
# Concatenate and combine
|
| 823 |
+
combined_key = torch.cat([projected_key_flat, target_key_flat], dim=-1)
|
| 824 |
+
combined_value = torch.cat([projected_value_flat, target_value_flat], dim=-1)
|
| 825 |
+
|
| 826 |
+
final_projected_key_flat = self.key_combiner(combined_key)
|
| 827 |
+
final_projected_value_flat = self.value_combiner(combined_value)
|
| 828 |
+
else:
|
| 829 |
+
final_projected_key_flat = projected_key_flat
|
| 830 |
+
final_projected_value_flat = projected_value_flat
|
| 831 |
+
|
| 832 |
+
# Reshape back: (B, N, H_t * D_t) -> (B, H_t, N, D_t)
|
| 833 |
+
projected_key = final_projected_key_flat.view(B, N, H_t, D_t).transpose(1, 2)
|
| 834 |
+
projected_value = final_projected_value_flat.view(B, N, H_t, D_t).transpose(1, 2)
|
| 835 |
+
|
| 836 |
+
# Generate gates, selectors and weights (may need projected tensors for input features)
|
| 837 |
+
needs_projected_for_gate = self.gate_depends_on_input and self.gate_input_features in [
|
| 838 |
+
"target_projected_key", "target_projected_value", "target_projected_both"
|
| 839 |
+
]
|
| 840 |
+
needs_projected_for_weight = self.weight_depends_on_input and self.weight_input_features in [
|
| 841 |
+
"target_projected_key", "target_projected_value", "target_projected_both"
|
| 842 |
+
]
|
| 843 |
+
|
| 844 |
+
if needs_projected_for_gate or needs_projected_for_weight:
|
| 845 |
+
gate_logit = self._generate_gates(target_key, target_value, projected_key, projected_value)
|
| 846 |
+
key_weight, value_weight = self._generate_weights(target_key, target_value, projected_key, projected_value)
|
| 847 |
+
else:
|
| 848 |
+
gate_logit = self._generate_gates(target_key, target_value)
|
| 849 |
+
key_weight, value_weight = self._generate_weights(target_key, target_value)
|
| 850 |
+
|
| 851 |
+
# Reshape gates and weights to match target shape
|
| 852 |
+
target_shape = (B, H_t, N, D_t)
|
| 853 |
+
if self.gate_depends_on_input:
|
| 854 |
+
# Reshape based on gate granularity - all preserve token dimension N
|
| 855 |
+
if self.gate_granularity == "scalar":
|
| 856 |
+
# For scalar, gate_logit is already (B, 1) from MLP, just expand
|
| 857 |
+
gate_logit = gate_logit.view(B, 1, 1, 1).expand(target_shape)
|
| 858 |
+
elif self.gate_granularity == "token":
|
| 859 |
+
gate_logit = gate_logit.unsqueeze(1).unsqueeze(-1).expand(target_shape) # (B, N, 1) -> (B, H, N, D)
|
| 860 |
+
elif self.gate_granularity == "head_merged":
|
| 861 |
+
# (B, N, H) -> (B, H, N, D) - per token per head, broadcast over D
|
| 862 |
+
gate_logit = gate_logit.permute(0, 2, 1).unsqueeze(-1).expand(B, H_t, N, D_t)
|
| 863 |
+
elif self.gate_granularity == "head":
|
| 864 |
+
# (B, H, N, 1) -> (B, H, N, D) - per token per head scalar, broadcast over D
|
| 865 |
+
gate_logit = gate_logit.expand(B, H_t, N, D_t)
|
| 866 |
+
elif self.gate_granularity == "value":
|
| 867 |
+
# (B, H, N, D) -> (B, H, N, D) - each token each head each value has one value
|
| 868 |
+
pass # Already in correct shape
|
| 869 |
+
else:
|
| 870 |
+
gate_logit = self._reshape_for_granularity(gate_logit, self.gate_granularity, target_shape)
|
| 871 |
+
|
| 872 |
+
if self.weight_depends_on_input:
|
| 873 |
+
# Reshape weights based on granularity - all preserve token dimension N
|
| 874 |
+
if self.weight_granularity == "scalar":
|
| 875 |
+
# For scalar, weights are already (B, 1) from MLP, just expand
|
| 876 |
+
key_weight = key_weight.view(B, 1, 1, 1).expand(target_shape)
|
| 877 |
+
value_weight = value_weight.view(B, 1, 1, 1).expand(target_shape)
|
| 878 |
+
elif self.weight_granularity == "token":
|
| 879 |
+
key_weight = key_weight.unsqueeze(1).expand(target_shape) # (B, N, 1) -> (B, H, N, D)
|
| 880 |
+
value_weight = value_weight.unsqueeze(1).expand(target_shape)
|
| 881 |
+
elif self.weight_granularity == "head_merged":
|
| 882 |
+
# (B, N, H) -> (B, H, N, D) - per token per head, broadcast over D
|
| 883 |
+
key_weight = key_weight.permute(0, 2, 1).unsqueeze(-1).expand(B, H_t, N, D_t)
|
| 884 |
+
value_weight = value_weight.permute(0, 2, 1).unsqueeze(-1).expand(B, H_t, N, D_t)
|
| 885 |
+
elif self.weight_granularity == "head":
|
| 886 |
+
# (B, H, N, 1) -> (B, H, N, D) - per token per head scalar, broadcast over D
|
| 887 |
+
key_weight = key_weight.expand(B, H_t, N, D_t)
|
| 888 |
+
value_weight = value_weight.expand(B, H_t, N, D_t)
|
| 889 |
+
elif self.weight_granularity == "value":
|
| 890 |
+
# (B, H, N, D) -> (B, H, N, D) - each token each head each value has one value
|
| 891 |
+
pass # Already in correct shape
|
| 892 |
+
else:
|
| 893 |
+
key_weight = self._reshape_for_granularity(key_weight, self.weight_granularity, target_shape)
|
| 894 |
+
value_weight = self._reshape_for_granularity(value_weight, self.weight_granularity, target_shape)
|
| 895 |
+
|
| 896 |
+
# Apply gating and selection
|
| 897 |
+
gate = self._apply_gumbel_sigmoid(gate_logit)
|
| 898 |
+
|
| 899 |
+
# Normalize weights using dynamic temperature
|
| 900 |
+
normalized_key_weight = torch.sigmoid(key_weight / self.scalar_temperature)
|
| 901 |
+
normalized_value_weight = torch.sigmoid(value_weight / self.scalar_temperature)
|
| 902 |
+
|
| 903 |
+
# Final combination
|
| 904 |
+
# Compute projected contribution (always present)
|
| 905 |
+
projected_key_term = gate * normalized_key_weight * projected_key
|
| 906 |
+
projected_value_term = gate * normalized_value_weight * projected_value
|
| 907 |
+
|
| 908 |
+
# Compute target (self) contribution depending on flags
|
| 909 |
+
if self.add_self:
|
| 910 |
+
if self.preserve_target_weight:
|
| 911 |
+
target_key_term = (1 - normalized_key_weight) * target_key
|
| 912 |
+
target_value_term = (1 - normalized_value_weight) * target_value
|
| 913 |
+
else:
|
| 914 |
+
target_key_term = target_key
|
| 915 |
+
target_value_term = target_value
|
| 916 |
+
else:
|
| 917 |
+
target_key_term = torch.zeros_like(target_key)
|
| 918 |
+
target_value_term = torch.zeros_like(target_value)
|
| 919 |
+
|
| 920 |
+
# Final outputs
|
| 921 |
+
output_key = target_key_term + projected_key_term
|
| 922 |
+
output_value = target_value_term + projected_value_term
|
| 923 |
+
|
| 924 |
+
return (output_key, output_value)
|
| 925 |
+
|
| 926 |
+
class QwenStyleLayer(nn.Module):
|
| 927 |
+
"""
|
| 928 |
+
One Qwen3-style MLP sublayer:
|
| 929 |
+
y = x + Dropout( down( SiLU(gate(LN(x))) * up(LN(x)) ) )
|
| 930 |
+
- Pre-norm with RMSNorm
|
| 931 |
+
- Bias-free linears
|
| 932 |
+
"""
|
| 933 |
+
def __init__(self, hidden_size: int, intermediate_size: int, dropout: float = 0.0, dtype: torch.dtype = torch.float32):
|
| 934 |
+
super().__init__()
|
| 935 |
+
self.norm = nn.RMSNorm(hidden_size, eps=1e-6, dtype=dtype)
|
| 936 |
+
self.gate = nn.Linear(hidden_size, intermediate_size, bias=False, dtype=dtype)
|
| 937 |
+
self.up = nn.Linear(hidden_size, intermediate_size, bias=False, dtype=dtype)
|
| 938 |
+
self.down = nn.Linear(intermediate_size, hidden_size, bias=False, dtype=dtype)
|
| 939 |
+
self.act = nn.SiLU()
|
| 940 |
+
self.drop = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
|
| 941 |
+
|
| 942 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 943 |
+
h = self.norm(x)
|
| 944 |
+
h = self.act(self.gate(h)) * self.up(h) # SwiGLU
|
| 945 |
+
h = self.down(h)
|
| 946 |
+
h = self.drop(h)
|
| 947 |
+
return x + h
|
| 948 |
+
|
| 949 |
+
class StandardFFNLayer(nn.Module):
|
| 950 |
+
"""
|
| 951 |
+
Pre-norm RMSNorm, classic MLP:
|
| 952 |
+
y = x + Dropout( W2( Act( W1( RMSNorm(x) ) ) ) )
|
| 953 |
+
- No SwiGLU: single hidden nonlinearity (GELU/ReLU/SiLU)
|
| 954 |
+
- Bias-free linears (common in modern LLM FFNs)
|
| 955 |
+
"""
|
| 956 |
+
def __init__(
|
| 957 |
+
self,
|
| 958 |
+
hidden_size: int,
|
| 959 |
+
intermediate_size: int,
|
| 960 |
+
dropout: float = 0.0,
|
| 961 |
+
dtype: torch.dtype = torch.float32,
|
| 962 |
+
activation: str = "gelu",
|
| 963 |
+
):
|
| 964 |
+
super().__init__()
|
| 965 |
+
self.norm = nn.RMSNorm(hidden_size, eps=1e-6, dtype=dtype)
|
| 966 |
+
self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False, dtype=dtype)
|
| 967 |
+
self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False, dtype=dtype)
|
| 968 |
+
self.drop = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
|
| 969 |
+
|
| 970 |
+
act = activation.lower()
|
| 971 |
+
if act == "gelu":
|
| 972 |
+
self.act = nn.GELU()
|
| 973 |
+
elif act == "relu":
|
| 974 |
+
self.act = nn.ReLU()
|
| 975 |
+
elif act == "silu":
|
| 976 |
+
self.act = nn.SiLU()
|
| 977 |
+
else:
|
| 978 |
+
raise ValueError(f"Unsupported activation: {activation}")
|
| 979 |
+
|
| 980 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 981 |
+
h = self.norm(x)
|
| 982 |
+
h = self.act(self.w1(h))
|
| 983 |
+
h = self.w2(h)
|
| 984 |
+
h = self.drop(h)
|
| 985 |
+
return x + h
|
| 986 |
+
|
| 987 |
+
class RegularMLP(nn.Module):
|
| 988 |
+
"""
|
| 989 |
+
Qwen3-style stacked MLP operating at a fixed hidden size.
|
| 990 |
+
- No input/output projections; caller is responsible for projections.
|
| 991 |
+
- num_layers repeats of Qwen-style FFN sublayer (pre-RMSNorm, SwiGLU, bias-free)
|
| 992 |
+
"""
|
| 993 |
+
def __init__(
|
| 994 |
+
self,
|
| 995 |
+
hidden_dim: int = 1024,
|
| 996 |
+
intermediate_dim: int = 3072,
|
| 997 |
+
num_layers: int = 3,
|
| 998 |
+
dropout: float = 0.1,
|
| 999 |
+
dtype: torch.dtype = torch.float32,
|
| 1000 |
+
):
|
| 1001 |
+
super().__init__()
|
| 1002 |
+
assert num_layers >= 1, "num_layers must be >= 1"
|
| 1003 |
+
|
| 1004 |
+
self.blocks = nn.ModuleList([
|
| 1005 |
+
StandardFFNLayer(hidden_size=hidden_dim, intermediate_size=intermediate_dim, dropout=dropout, dtype=dtype)
|
| 1006 |
+
for _ in range(num_layers)
|
| 1007 |
+
])
|
| 1008 |
+
|
| 1009 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 1010 |
+
for blk in self.blocks:
|
| 1011 |
+
x = blk(x)
|
| 1012 |
+
return x
|
| 1013 |
+
|
| 1014 |
+
@register_model
|
| 1015 |
+
@capture_init_args
|
| 1016 |
+
class C2CProjector(Projector):
|
| 1017 |
+
"""
|
| 1018 |
+
Concise projector specialized to a fixed C2C configuration using StandardMLP.
|
| 1019 |
+
- Projections: StandardMLP (pre-RMSNorm, SwiGLU, residual per sublayer)
|
| 1020 |
+
- Concat: enabled, followed by linear combiner to target size
|
| 1021 |
+
- Gate: scalar parameter with Gumbel-sigmoid during training
|
| 1022 |
+
- Weights: input-dependent, head_merged granularity using target and projected key
|
| 1023 |
+
- Target preservation: add_self=True, preserve_target_weight=False
|
| 1024 |
+
- Temperatures: annealed gate temperature (1.0 -> 0.001 over 1929 steps), scalar_temperature=1.0
|
| 1025 |
+
"""
|
| 1026 |
+
|
| 1027 |
+
def __init__(
|
| 1028 |
+
self,
|
| 1029 |
+
source_dim: int,
|
| 1030 |
+
target_dim: int,
|
| 1031 |
+
source_num_heads: int = 1,
|
| 1032 |
+
target_num_heads: int = 1,
|
| 1033 |
+
intermediate_dim: int = 1024,
|
| 1034 |
+
hidden_dim: int = 1024,
|
| 1035 |
+
num_layers: int = 3,
|
| 1036 |
+
dropout: float = 0.1,
|
| 1037 |
+
initial_temperature: float = 1.0,
|
| 1038 |
+
final_temperature: float = 0.001,
|
| 1039 |
+
anneal_steps: int = 1929,
|
| 1040 |
+
dtype: torch.dtype = torch.float32,
|
| 1041 |
+
):
|
| 1042 |
+
super().__init__()
|
| 1043 |
+
|
| 1044 |
+
assert num_layers >= 3, "num_layers must be >= 3"
|
| 1045 |
+
|
| 1046 |
+
# Dimensions
|
| 1047 |
+
self.source_dim = source_dim
|
| 1048 |
+
self.target_dim = target_dim
|
| 1049 |
+
self.source_num_heads = source_num_heads
|
| 1050 |
+
self.target_num_heads = target_num_heads
|
| 1051 |
+
|
| 1052 |
+
# Sizes
|
| 1053 |
+
in_dim = source_dim * source_num_heads
|
| 1054 |
+
out_dim = target_dim * target_num_heads
|
| 1055 |
+
|
| 1056 |
+
# 1) concat(source_X, target_X) then project to hidden_dim
|
| 1057 |
+
self.key_in = nn.Linear(in_dim + out_dim, hidden_dim, bias=True, dtype=dtype)
|
| 1058 |
+
self.value_in = nn.Linear(in_dim + out_dim, hidden_dim, bias=True, dtype=dtype)
|
| 1059 |
+
|
| 1060 |
+
# 2) one-layer common embedding MLP to get intermediate representation (at hidden_dim)
|
| 1061 |
+
self.key_mlp1 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, num_layers=1, dropout=dropout, dtype=dtype)
|
| 1062 |
+
self.value_mlp1 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, num_layers=1, dropout=dropout, dtype=dtype)
|
| 1063 |
+
|
| 1064 |
+
# 3a) intermediate representation → (L-2)-layer MLP for weights → project to head dim
|
| 1065 |
+
self.key_scalar_mlp2 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=hidden_dim, num_layers=1, dropout=dropout, dtype=dtype)
|
| 1066 |
+
self.value_scalar_mlp2 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=hidden_dim, num_layers=1, dropout=dropout, dtype=dtype)
|
| 1067 |
+
self.key_scalar_head = nn.Linear(hidden_dim, target_num_heads, dtype=dtype)
|
| 1068 |
+
self.value_scalar_head = nn.Linear(hidden_dim, target_num_heads, dtype=dtype)
|
| 1069 |
+
|
| 1070 |
+
# 3b) intermediate representation → (L-2)-layer MLP for projected_X → finally project hidden_dim → out_dim
|
| 1071 |
+
self.key_proj_mlp2 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, num_layers=num_layers-2, dropout=dropout, dtype=dtype)
|
| 1072 |
+
self.value_proj_mlp2 = RegularMLP(hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, num_layers=num_layers-2, dropout=dropout, dtype=dtype)
|
| 1073 |
+
self.key_proj_out = nn.Linear(hidden_dim, out_dim, bias=True, dtype=dtype)
|
| 1074 |
+
self.value_proj_out = nn.Linear(hidden_dim, out_dim, bias=True, dtype=dtype)
|
| 1075 |
+
|
| 1076 |
+
# Scalar key/value gate parameters and temperature schedule
|
| 1077 |
+
self.key_gate_logit = nn.Parameter(torch.tensor(0.0, dtype=dtype))
|
| 1078 |
+
self.value_gate_logit = nn.Parameter(torch.tensor(0.0, dtype=dtype))
|
| 1079 |
+
self.use_gumbel = True
|
| 1080 |
+
self.register_buffer("gate_temperature", torch.tensor(initial_temperature, dtype=dtype))
|
| 1081 |
+
self.initial_temperature = initial_temperature
|
| 1082 |
+
self.final_temperature = final_temperature
|
| 1083 |
+
self.anneal_steps = anneal_steps
|
| 1084 |
+
|
| 1085 |
+
# Temperature for weight normalization
|
| 1086 |
+
self.scalar_temperature = 1.0
|
| 1087 |
+
|
| 1088 |
+
def update_temperature(self, step: int):
|
| 1089 |
+
ratio = min(step / self.anneal_steps, 1.0)
|
| 1090 |
+
temp = self.initial_temperature * (self.final_temperature / self.initial_temperature) ** ratio
|
| 1091 |
+
self.gate_temperature.fill_(temp)
|
| 1092 |
+
|
| 1093 |
+
def forward(
|
| 1094 |
+
self,
|
| 1095 |
+
source_kv: Tuple[Tensor, Tensor],
|
| 1096 |
+
target_kv: Tuple[Tensor, Tensor],
|
| 1097 |
+
position_ids: Optional[Tensor] = None,
|
| 1098 |
+
max_pos: Optional[Tensor] = None,
|
| 1099 |
+
) -> Tuple[Tensor, Tensor]:
|
| 1100 |
+
source_key, source_value = source_kv
|
| 1101 |
+
target_key, target_value = target_kv
|
| 1102 |
+
|
| 1103 |
+
B, Hs, N, Ds = source_key.shape
|
| 1104 |
+
_, Ht, _, Dt = target_key.shape
|
| 1105 |
+
|
| 1106 |
+
# Flatten heads
|
| 1107 |
+
source_key_flat = source_key.transpose(1, 2).contiguous().view(B, N, Hs * Ds)
|
| 1108 |
+
source_value_flat = source_value.transpose(1, 2).contiguous().view(B, N, Hs * Ds)
|
| 1109 |
+
target_key_flat = target_key.transpose(1, 2).contiguous().view(B, N, Ht * Dt)
|
| 1110 |
+
target_value_flat = target_value.transpose(1, 2).contiguous().view(B, N, Ht * Dt)
|
| 1111 |
+
|
| 1112 |
+
# 1) concat source and target features along channel
|
| 1113 |
+
key_cat = torch.cat([source_key_flat, target_key_flat], dim=-1)
|
| 1114 |
+
value_cat = torch.cat([source_value_flat, target_value_flat], dim=-1)
|
| 1115 |
+
|
| 1116 |
+
# 2) project to hidden dim
|
| 1117 |
+
key_hidden = self.key_in(key_cat)
|
| 1118 |
+
value_hidden = self.value_in(value_cat)
|
| 1119 |
+
|
| 1120 |
+
# 3) one-layer common embedding MLP to get intermediate representation (at hidden_dim)
|
| 1121 |
+
key_hidden = self.key_mlp1(key_hidden)
|
| 1122 |
+
value_hidden = self.value_mlp1(value_hidden)
|
| 1123 |
+
|
| 1124 |
+
# 4b) intermediate representation -> projected feature path
|
| 1125 |
+
key_proj_hidden = self.key_proj_out(self.key_proj_mlp2(key_hidden)) # (B, N, Ht * Dt)
|
| 1126 |
+
value_proj_hidden = self.value_proj_out(self.value_proj_mlp2(value_hidden)) # (B, N, Ht * Dt)
|
| 1127 |
+
projected_key = key_proj_hidden.view(B, N, Ht, Dt).transpose(1, 2) # (B, Ht, N, Dt)
|
| 1128 |
+
projected_value = value_proj_hidden.view(B, N, Ht, Dt).transpose(1, 2) # (B, Ht, N, Dt)
|
| 1129 |
+
|
| 1130 |
+
# 4a) intermediate representation -> scalar path
|
| 1131 |
+
key_scalar = self.key_scalar_head(self.key_scalar_mlp2(key_hidden)) # (B, N, Ht)
|
| 1132 |
+
value_scalar = self.value_scalar_head(self.value_scalar_mlp2(value_hidden)) # (B, N, Ht)
|
| 1133 |
+
key_scalar = key_scalar.permute(0, 2, 1).unsqueeze(-1) # (B, Ht, N, 1)
|
| 1134 |
+
value_scalar = value_scalar.permute(0, 2, 1).unsqueeze(-1) # (B, Ht, N, 1)
|
| 1135 |
+
|
| 1136 |
+
# Key/value gates: element-wise Gumbel noise with scalar logits (broadcast over channels)
|
| 1137 |
+
key_gate_logit = self.key_gate_logit.view(1, 1, 1, 1)
|
| 1138 |
+
value_gate_logit = self.value_gate_logit.view(1, 1, 1, 1)
|
| 1139 |
+
if self.training and self.use_gumbel:
|
| 1140 |
+
u1 = torch.rand(B, Ht, N, 1, device=key_gate_logit.device, dtype=key_gate_logit.dtype)
|
| 1141 |
+
u2 = torch.rand(B, Ht, N, 1, device=value_gate_logit.device, dtype=value_gate_logit.dtype)
|
| 1142 |
+
g1 = -torch.log(-torch.log(u1 + 1e-20) + 1e-20)
|
| 1143 |
+
g2 = -torch.log(-torch.log(u2 + 1e-20) + 1e-20)
|
| 1144 |
+
key_gate = torch.sigmoid((key_gate_logit + g1) / self.gate_temperature)
|
| 1145 |
+
value_gate = torch.sigmoid((value_gate_logit + g2) / self.gate_temperature)
|
| 1146 |
+
else:
|
| 1147 |
+
key_gate = (key_gate_logit > 0).float()
|
| 1148 |
+
value_gate = (value_gate_logit > 0).float()
|
| 1149 |
+
|
| 1150 |
+
# Normalize scalars (scalar_temperature=1.0)
|
| 1151 |
+
norm_key_scalar = torch.sigmoid(key_scalar)
|
| 1152 |
+
norm_value_scalar = torch.sigmoid(value_scalar)
|
| 1153 |
+
|
| 1154 |
+
# Combine (preserve_target_weight=False, add_self=True)
|
| 1155 |
+
output_key = target_key + key_gate * norm_key_scalar * projected_key
|
| 1156 |
+
output_value = target_value + value_gate * norm_value_scalar * projected_value
|
| 1157 |
+
|
| 1158 |
+
# Expose capture attributes for downstream analysis scripts
|
| 1159 |
+
try:
|
| 1160 |
+
# Store normalized scalars (detach to avoid autograd, keep device-agnostic via CPU)
|
| 1161 |
+
self.last_norm_key_scalar = norm_key_scalar.detach().cpu()
|
| 1162 |
+
self.last_norm_value_scalar = norm_value_scalar.detach().cpu()
|
| 1163 |
+
# Store gate logits as python floats (parameters are scalar)
|
| 1164 |
+
self.last_key_gate_logit = float(self.key_gate_logit.detach().cpu().item())
|
| 1165 |
+
self.last_value_gate_logit = float(self.value_gate_logit.detach().cpu().item())
|
| 1166 |
+
except Exception:
|
| 1167 |
+
# Best-effort capture; never break forward path
|
| 1168 |
+
pass
|
| 1169 |
+
|
| 1170 |
+
return output_key, output_value
|
| 1171 |
+
|
| 1172 |
+
def save_projector(obj: Projector, file_path: str) -> None:
|
| 1173 |
+
save_object(obj, file_path)
|
| 1174 |
+
|
| 1175 |
+
def load_projector(file_path: str, override_args: Optional[dict] = None) -> Projector:
|
| 1176 |
+
return load_object(file_path, get_projector_class, override_args)
|
| 1177 |
+
|
| 1178 |
+
def create_projector(projector_type: str, **kwargs) -> Projector:
|
| 1179 |
+
"""
|
| 1180 |
+
Factory function to create a projector based on type.
|
| 1181 |
+
|
| 1182 |
+
Args:
|
| 1183 |
+
projector_type: String indicating the type of projector
|
| 1184 |
+
**kwargs: Additional arguments to pass to the projector constructor
|
| 1185 |
+
|
| 1186 |
+
Returns:
|
| 1187 |
+
An instance of the appropriate projector
|
| 1188 |
+
"""
|
| 1189 |
+
# Prefer using the unified registry getter (handles case-insensitive keys)
|
| 1190 |
+
try:
|
| 1191 |
+
cls = get_projector_class(projector_type)
|
| 1192 |
+
except ValueError as e:
|
| 1193 |
+
raise e
|
| 1194 |
+
return cls(**kwargs)
|
rosetta/model/sampling.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from typing import Union
|
| 4 |
+
|
| 5 |
+
def sample_token(logits: torch.Tensor, temperature: float = 1.0, top_p: float = 1.0, top_k: int = -1) -> Union[int, torch.Tensor]:
|
| 6 |
+
"""Sample a token from logits using temperature, top-p, and top-k sampling.
|
| 7 |
+
Args:
|
| 8 |
+
logits: Token logits of shape [vocab_size] or [batch_size, vocab_size]
|
| 9 |
+
temperature: Temperature for sampling (>0). Higher values produce more random samples.
|
| 10 |
+
top_p: Top-p probability threshold for nucleus sampling (0 < top_p ≤ 1)
|
| 11 |
+
top_k: Top-k threshold for sampling (if -1, no top-k filtering is applied)
|
| 12 |
+
Returns:
|
| 13 |
+
Sampled token ID (int for single sample, tensor for batch)
|
| 14 |
+
"""
|
| 15 |
+
if not isinstance(logits, torch.Tensor):
|
| 16 |
+
raise TypeError("logits must be a torch.Tensor")
|
| 17 |
+
|
| 18 |
+
if logits.dim() not in [1, 2]:
|
| 19 |
+
raise ValueError("logits must have shape [vocab_size] or [batch_size, vocab_size]")
|
| 20 |
+
|
| 21 |
+
# Handle single dimension input
|
| 22 |
+
is_single_input = logits.dim() == 1
|
| 23 |
+
if is_single_input:
|
| 24 |
+
logits = logits.unsqueeze(0)
|
| 25 |
+
|
| 26 |
+
batch_size = logits.shape[0]
|
| 27 |
+
|
| 28 |
+
# For greedy sampling (temperature=0), just return argmax
|
| 29 |
+
if temperature == 0 or temperature <= 1e-5:
|
| 30 |
+
tokens = torch.argmax(logits, dim=-1)
|
| 31 |
+
return tokens.item() if is_single_input else tokens
|
| 32 |
+
|
| 33 |
+
# Convert to probabilities
|
| 34 |
+
probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
|
| 35 |
+
|
| 36 |
+
# Apply top-k filtering first (if specified)
|
| 37 |
+
if top_k != -1:
|
| 38 |
+
# Get top-k values and indices
|
| 39 |
+
top_k_values, top_k_indices = torch.topk(probs, k=min(top_k, probs.shape[-1]), dim=-1)
|
| 40 |
+
|
| 41 |
+
# Create a mask to zero out non-top-k probabilities
|
| 42 |
+
mask = torch.zeros_like(probs, dtype=torch.bool)
|
| 43 |
+
mask.scatter_(-1, top_k_indices, True)
|
| 44 |
+
|
| 45 |
+
# Zero out non-top-k probabilities
|
| 46 |
+
probs = probs * mask.float()
|
| 47 |
+
|
| 48 |
+
# Renormalize probabilities
|
| 49 |
+
probs = probs / probs.sum(dim=-1, keepdim=True)
|
| 50 |
+
|
| 51 |
+
# Apply top-p (nucleus) sampling
|
| 52 |
+
if top_p < 1.0:
|
| 53 |
+
# Sort probabilities in descending order
|
| 54 |
+
sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
|
| 55 |
+
|
| 56 |
+
# Calculate cumulative probabilities
|
| 57 |
+
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
|
| 58 |
+
|
| 59 |
+
# Create a mask for probabilities to keep
|
| 60 |
+
# Values above top_p threshold are masked out
|
| 61 |
+
mask = cumulative_probs <= top_p
|
| 62 |
+
|
| 63 |
+
# Always keep at least one token
|
| 64 |
+
mask[:, 0] = True
|
| 65 |
+
|
| 66 |
+
# Zero out masked positions to exclude them from sampling
|
| 67 |
+
sorted_probs = sorted_probs * mask.float()
|
| 68 |
+
|
| 69 |
+
# Renormalize probabilities
|
| 70 |
+
sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
|
| 71 |
+
|
| 72 |
+
# Sample from the filtered distribution
|
| 73 |
+
sampled_indices = torch.multinomial(sorted_probs, num_samples=1)
|
| 74 |
+
|
| 75 |
+
# Map back to original vocabulary indices
|
| 76 |
+
tokens = torch.gather(sorted_indices, dim=-1, index=sampled_indices)
|
| 77 |
+
tokens = tokens.squeeze(-1) # Remove sample dimension
|
| 78 |
+
else:
|
| 79 |
+
# Direct sampling if no top-p filtering
|
| 80 |
+
tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
|
| 81 |
+
|
| 82 |
+
return tokens.item() if is_single_input else tokens
|
rosetta/model/wrapper.py
ADDED
|
@@ -0,0 +1,840 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The ensemble of multiple standard transformers LLM models, with automatic kv-cache projection. It shares the same interface as the standard transformers LLM models.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import List, Optional, Union
|
| 6 |
+
import torch
|
| 7 |
+
from torch import nn
|
| 8 |
+
from transformers.cache_utils import Cache, DynamicCache
|
| 9 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 10 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 11 |
+
import json
|
| 12 |
+
|
| 13 |
+
from rosetta.model.projector import Projector
|
| 14 |
+
from rosetta.model.sampling import sample_token
|
| 15 |
+
from transformers.utils import ModelOutput
|
| 16 |
+
try:
|
| 17 |
+
from transformers.generation.utils import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput
|
| 18 |
+
except Exception:
|
| 19 |
+
GreedySearchDecoderOnlyOutput = None
|
| 20 |
+
SampleDecoderOnlyOutput = None
|
| 21 |
+
|
| 22 |
+
def clone_kv_cache(kv_cache: DynamicCache) -> DynamicCache:
|
| 23 |
+
new_cache = DynamicCache()
|
| 24 |
+
for k, v in zip(kv_cache.key_cache, kv_cache.value_cache):
|
| 25 |
+
new_cache.key_cache.append(k.clone().detach())
|
| 26 |
+
new_cache.value_cache.append(v.clone().detach())
|
| 27 |
+
return new_cache
|
| 28 |
+
|
| 29 |
+
def hybrid_to_dynamic(hybrid_cache):
|
| 30 |
+
if hybrid_cache is None:
|
| 31 |
+
return None
|
| 32 |
+
if isinstance(hybrid_cache, DynamicCache):
|
| 33 |
+
return hybrid_cache
|
| 34 |
+
|
| 35 |
+
# 手动从 HybridCache 提取
|
| 36 |
+
if hasattr(hybrid_cache, "key_cache") and hasattr(hybrid_cache, "value_cache"):
|
| 37 |
+
keys = hybrid_cache.key_cache
|
| 38 |
+
values = hybrid_cache.value_cache
|
| 39 |
+
assert len(keys) == len(values), "key/value 层数不一致"
|
| 40 |
+
|
| 41 |
+
legacy_cache = [(k, v) for k, v in zip(keys, values)]
|
| 42 |
+
return DynamicCache.from_legacy_cache(legacy_cache)
|
| 43 |
+
|
| 44 |
+
raise TypeError(f"Unsupported cache type: {type(hybrid_cache)}")
|
| 45 |
+
|
| 46 |
+
class RosettaModel(nn.Module):
|
| 47 |
+
"""
|
| 48 |
+
Drop in replacement for the standard transformers LLM models, like Qwen3ForCausalLM
|
| 49 |
+
"""
|
| 50 |
+
def __init__(self, model_list: List[PreTrainedModel], base_model_idx = 0, projector_list: List[Projector] = [], aggregator_list: List[nn.Module] = []):
|
| 51 |
+
super().__init__()
|
| 52 |
+
# model list: a list of model, model 0 by default is the base model
|
| 53 |
+
# projector list: a list of projector
|
| 54 |
+
# standard init with additional model list parameter
|
| 55 |
+
# kv-cache dict: key (source_model_idx, target_model_idx), value (Cache), assume only convert at prefill with one type of model
|
| 56 |
+
# projector dict: key (source_model_idx, target_model_idx) value dict(key (source_model_layer_idx, M_target value )
|
| 57 |
+
|
| 58 |
+
self.base_model_idx = base_model_idx
|
| 59 |
+
self.model_list = nn.ModuleList(model_list)
|
| 60 |
+
|
| 61 |
+
device = model_list[base_model_idx].device
|
| 62 |
+
dtype = model_list[base_model_idx].dtype
|
| 63 |
+
self.projector_list = nn.ModuleList(projector_list).to(device=device, dtype=dtype)
|
| 64 |
+
self.aggregator_list = nn.ModuleList(aggregator_list).to(device=device, dtype=dtype)
|
| 65 |
+
|
| 66 |
+
self.projector_dict = {}
|
| 67 |
+
self.aggregator_dict = {}
|
| 68 |
+
self.kv_cache_dict = {}
|
| 69 |
+
self._generation_hook_handlers = []
|
| 70 |
+
|
| 71 |
+
@property
|
| 72 |
+
def device(self):
|
| 73 |
+
return self.model_list[self.base_model_idx].device
|
| 74 |
+
|
| 75 |
+
def to(self, device):
|
| 76 |
+
"""
|
| 77 |
+
Move the RosettaModel and all underlying models and projectors to the specified device.
|
| 78 |
+
"""
|
| 79 |
+
super().to(device)
|
| 80 |
+
for model in self.model_list:
|
| 81 |
+
model.to(device)
|
| 82 |
+
for projector in self.projector_list:
|
| 83 |
+
projector.to(device)
|
| 84 |
+
for aggregator in self.aggregator_list:
|
| 85 |
+
aggregator.to(device)
|
| 86 |
+
return self
|
| 87 |
+
|
| 88 |
+
# set projector
|
| 89 |
+
def set_projector_config(self,
|
| 90 |
+
source_model_idx: int,
|
| 91 |
+
source_model_layer_idx: int,
|
| 92 |
+
target_model_idx: int,
|
| 93 |
+
target_model_layer_idx: int,
|
| 94 |
+
projector_idx: int):
|
| 95 |
+
"""
|
| 96 |
+
Set the projector configuration
|
| 97 |
+
Args:
|
| 98 |
+
source_model_idx: int, the index of the source model
|
| 99 |
+
source_model_layer_idx: int, the index of the source model layer
|
| 100 |
+
target_model_idx: int, the index of the target model
|
| 101 |
+
target_model_layer_idx: int, the index of the target model layer
|
| 102 |
+
projector_idx: int, the index of the projector
|
| 103 |
+
|
| 104 |
+
The projector dict structure supports multiple projectors per target layer.
|
| 105 |
+
Structure:
|
| 106 |
+
{
|
| 107 |
+
target_model_idx: {
|
| 108 |
+
source_model_idx: {
|
| 109 |
+
target_model_layer_idx: [(source_model_layer_idx, projector_idx), ...]
|
| 110 |
+
}
|
| 111 |
+
}
|
| 112 |
+
}
|
| 113 |
+
Repeated calls for the same (target, source, target_layer) append additional pairs.
|
| 114 |
+
"""
|
| 115 |
+
|
| 116 |
+
if target_model_idx not in self.projector_dict.keys():
|
| 117 |
+
self.projector_dict[target_model_idx] = {}
|
| 118 |
+
if source_model_idx not in self.projector_dict[target_model_idx].keys():
|
| 119 |
+
self.projector_dict[target_model_idx][source_model_idx] = {}
|
| 120 |
+
# Accumulate list of (source_layer, projector_idx) for this target layer
|
| 121 |
+
layer_entry = self.projector_dict[target_model_idx][source_model_idx].get(target_model_layer_idx)
|
| 122 |
+
if layer_entry is None:
|
| 123 |
+
self.projector_dict[target_model_idx][source_model_idx][target_model_layer_idx] = [(source_model_layer_idx, projector_idx)]
|
| 124 |
+
else:
|
| 125 |
+
layer_entry.append((source_model_layer_idx, projector_idx))
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def load_projector(self, projector_list):
|
| 129 |
+
self.projector_list: List[Projector] = projector_list
|
| 130 |
+
|
| 131 |
+
def load_aggregator(self, aggregator_list):
|
| 132 |
+
self.aggregator_list: List[nn.Module] = aggregator_list
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def get_projector(self,
|
| 136 |
+
source_model_idx,
|
| 137 |
+
source_model_layer_idx,
|
| 138 |
+
target_model_idx,
|
| 139 |
+
target_model_layer_idx):
|
| 140 |
+
pair_list = self.projector_dict[target_model_idx][source_model_idx][target_model_layer_idx]
|
| 141 |
+
if len(pair_list) == 0:
|
| 142 |
+
raise ValueError("No projector configured for the given target layer")
|
| 143 |
+
# Prefer exact source layer match
|
| 144 |
+
for src_layer, projector_id in pair_list:
|
| 145 |
+
if src_layer == source_model_layer_idx:
|
| 146 |
+
return self.projector_list[projector_id]
|
| 147 |
+
# Fallback: return the first projector
|
| 148 |
+
return self.projector_list[pair_list[0][1]]
|
| 149 |
+
|
| 150 |
+
def set_aggregator_idx(self,
|
| 151 |
+
source_model_idx: int,
|
| 152 |
+
target_model_idx: int,
|
| 153 |
+
target_model_layer_idx: int,
|
| 154 |
+
aggregator_idx: int):
|
| 155 |
+
if target_model_idx not in self.aggregator_dict:
|
| 156 |
+
self.aggregator_dict[target_model_idx] = {}
|
| 157 |
+
if source_model_idx not in self.aggregator_dict[target_model_idx]:
|
| 158 |
+
self.aggregator_dict[target_model_idx][source_model_idx] = {}
|
| 159 |
+
self.aggregator_dict[target_model_idx][source_model_idx][target_model_layer_idx] = aggregator_idx
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
@staticmethod
|
| 163 |
+
def load_json(file_name):
|
| 164 |
+
with open(file_name, "r") as f:
|
| 165 |
+
result = json.load(f)
|
| 166 |
+
return result
|
| 167 |
+
|
| 168 |
+
@staticmethod
|
| 169 |
+
def _convert_dict_keys_to_ints(obj):
|
| 170 |
+
"""
|
| 171 |
+
Recursively convert dictionary keys that look like integers back to int.
|
| 172 |
+
This reverses json.dump's coercion of dict keys to strings.
|
| 173 |
+
"""
|
| 174 |
+
if isinstance(obj, dict):
|
| 175 |
+
new_obj = {}
|
| 176 |
+
for key, value in obj.items():
|
| 177 |
+
if isinstance(key, str) and key.lstrip('-').isdigit():
|
| 178 |
+
new_key = int(key)
|
| 179 |
+
else:
|
| 180 |
+
new_key = key
|
| 181 |
+
new_obj[new_key] = RosettaModel._convert_dict_keys_to_ints(value)
|
| 182 |
+
return new_obj
|
| 183 |
+
if isinstance(obj, list):
|
| 184 |
+
return [RosettaModel._convert_dict_keys_to_ints(v) for v in obj]
|
| 185 |
+
return obj
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def save_projector_config(self, file_name):
|
| 189 |
+
with open(file_name, "w") as f:
|
| 190 |
+
json.dump(self.projector_dict, f)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def load_projector_config(self, config_path):
|
| 194 |
+
if config_path.endswith(".json"):
|
| 195 |
+
loaded = RosettaModel.load_json(config_path)
|
| 196 |
+
self.projector_dict = RosettaModel._convert_dict_keys_to_ints(loaded)
|
| 197 |
+
|
| 198 |
+
def save_aggregator_config(self, file_name):
|
| 199 |
+
with open(file_name, "w") as f:
|
| 200 |
+
json.dump(self.aggregator_dict, f)
|
| 201 |
+
|
| 202 |
+
def load_aggregator_config(self, config_path):
|
| 203 |
+
if config_path.endswith(".json"):
|
| 204 |
+
loaded = RosettaModel.load_json(config_path)
|
| 205 |
+
self.aggregator_dict = RosettaModel._convert_dict_keys_to_ints(loaded)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def set_kv_cache_dict(self, source_model_idx, target_model_idx, cache):
|
| 209 |
+
if target_model_idx not in self.kv_cache_dict.keys():
|
| 210 |
+
self.kv_cache_dict[target_model_idx] = {}
|
| 211 |
+
if cache is None:
|
| 212 |
+
# Initialize with a DynamicCache instead of RosettaCache for now
|
| 213 |
+
self.kv_cache_dict[target_model_idx][source_model_idx] = DynamicCache() # noqa, maybe we should use RosettaCache here
|
| 214 |
+
else:
|
| 215 |
+
self.kv_cache_dict[target_model_idx][source_model_idx] = cache
|
| 216 |
+
|
| 217 |
+
def forward(
|
| 218 |
+
self,
|
| 219 |
+
kv_cache_index: Optional[List] = None,
|
| 220 |
+
input_ids: Optional[Union[torch.LongTensor, List[torch.LongTensor]]] = None,
|
| 221 |
+
attention_mask: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
|
| 222 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 223 |
+
past_key_values: Optional[Cache] = None,
|
| 224 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 225 |
+
labels: Optional[torch.LongTensor] = None,
|
| 226 |
+
use_cache: Optional[bool] = None,
|
| 227 |
+
output_attentions: Optional[bool] = None,
|
| 228 |
+
output_hidden_states: Optional[bool] = None,
|
| 229 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 230 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 231 |
+
# **kwargs: Unpack[KwargsForCausalLM],
|
| 232 |
+
*args,
|
| 233 |
+
**kwargs,
|
| 234 |
+
) -> CausalLMOutputWithPast:
|
| 235 |
+
"""
|
| 236 |
+
Forward pass
|
| 237 |
+
KVCache index is a list of tensors with shape (B, sec_seq_len, 2), indicating the source and target kv cache model index
|
| 238 |
+
|
| 239 |
+
If input_ids is LongTensor, default to same input ids for different models
|
| 240 |
+
If input_ids is Tuple, default to different input ids for different models.
|
| 241 |
+
|
| 242 |
+
No Rosetta: (-1, 0)
|
| 243 |
+
"""
|
| 244 |
+
|
| 245 |
+
# noqa
|
| 246 |
+
self.kv_cache_dict = dict()
|
| 247 |
+
|
| 248 |
+
# Handle different input formats: if input_ids is a list, use per-model inputs
|
| 249 |
+
if isinstance(input_ids, list):
|
| 250 |
+
# Use list format: different input_ids and attention_mask for each model
|
| 251 |
+
base_input_ids = input_ids[self.base_model_idx] if input_ids is not None else None
|
| 252 |
+
base_attention_mask = attention_mask[self.base_model_idx] if attention_mask is not None else None
|
| 253 |
+
_, seqlen = base_input_ids.size() if base_input_ids is not None else (0, 0)
|
| 254 |
+
else:
|
| 255 |
+
# Use tensor format: same input_ids and attention_mask for all models (backward compatibility)
|
| 256 |
+
base_input_ids = input_ids
|
| 257 |
+
base_attention_mask = attention_mask
|
| 258 |
+
_, seqlen = input_ids.size() if input_ids is not None else (0, 0)
|
| 259 |
+
|
| 260 |
+
num_sections = len(kv_cache_index) if kv_cache_index is not None else 1
|
| 261 |
+
|
| 262 |
+
section_lengths = [kv_cache_index[i].shape[1] for i in range(num_sections)] if kv_cache_index is not None else [seqlen]
|
| 263 |
+
section_starts = [0]
|
| 264 |
+
for l in section_lengths:
|
| 265 |
+
section_starts.append(section_starts[-1] + l)
|
| 266 |
+
|
| 267 |
+
curr_base_kv_cache = past_key_values
|
| 268 |
+
|
| 269 |
+
if seqlen >= 1:
|
| 270 |
+
for i in range(num_sections):
|
| 271 |
+
start = section_starts[i]
|
| 272 |
+
end = section_starts[i + 1]
|
| 273 |
+
prefill_input_ids = base_input_ids[:, start:end] if base_input_ids is not None else None
|
| 274 |
+
prefill_attention_mask = base_attention_mask[:, :end] if base_attention_mask is not None else None
|
| 275 |
+
prefill_position_ids = position_ids[:, start:end] if position_ids is not None else None
|
| 276 |
+
prefill_labels = labels[:, start:end] if labels is not None else None
|
| 277 |
+
|
| 278 |
+
output = self.model_list[self.base_model_idx].forward(
|
| 279 |
+
input_ids=prefill_input_ids,
|
| 280 |
+
attention_mask=prefill_attention_mask,
|
| 281 |
+
position_ids=prefill_position_ids,
|
| 282 |
+
past_key_values=curr_base_kv_cache,
|
| 283 |
+
labels=prefill_labels,
|
| 284 |
+
use_cache=use_cache,
|
| 285 |
+
output_attentions=output_attentions,
|
| 286 |
+
output_hidden_states=output_hidden_states,
|
| 287 |
+
*args,
|
| 288 |
+
**kwargs
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
if self.base_model_idx not in self.kv_cache_dict:
|
| 292 |
+
self.kv_cache_dict[self.base_model_idx] = {}
|
| 293 |
+
if self.base_model_idx not in self.kv_cache_dict[self.base_model_idx]:
|
| 294 |
+
self.kv_cache_dict[self.base_model_idx][self.base_model_idx] = None
|
| 295 |
+
self.kv_cache_dict[self.base_model_idx][self.base_model_idx] = output.past_key_values
|
| 296 |
+
|
| 297 |
+
curr_base_kv_cache: DynamicCache = output.past_key_values
|
| 298 |
+
|
| 299 |
+
if i != num_sections - 1:
|
| 300 |
+
for source_model_idx in range(1, len(self.model_list)):
|
| 301 |
+
if self.base_model_idx not in self.kv_cache_dict:
|
| 302 |
+
self.kv_cache_dict[self.base_model_idx] = {}
|
| 303 |
+
if source_model_idx not in self.kv_cache_dict[self.base_model_idx]:
|
| 304 |
+
self.kv_cache_dict[self.base_model_idx][source_model_idx] = None
|
| 305 |
+
|
| 306 |
+
# Get model-specific input_ids and attention_mask
|
| 307 |
+
if isinstance(input_ids, list):
|
| 308 |
+
source_input_ids = input_ids[source_model_idx]
|
| 309 |
+
source_attention_mask = attention_mask[source_model_idx] if attention_mask is not None else None
|
| 310 |
+
source_prefill_input_ids = source_input_ids[:, start:end] if source_input_ids is not None else None
|
| 311 |
+
source_prefill_attention_mask = source_attention_mask[:, :end] if source_attention_mask is not None else None
|
| 312 |
+
else:
|
| 313 |
+
# Backward compatibility: use same input for all models
|
| 314 |
+
source_prefill_input_ids = prefill_input_ids
|
| 315 |
+
source_prefill_attention_mask = prefill_attention_mask
|
| 316 |
+
|
| 317 |
+
model = self.model_list[source_model_idx]
|
| 318 |
+
was_training = model.training
|
| 319 |
+
had_gc = getattr(model, "is_gradient_checkpointing", False)
|
| 320 |
+
|
| 321 |
+
try:
|
| 322 |
+
if was_training:
|
| 323 |
+
model.eval()
|
| 324 |
+
if had_gc:
|
| 325 |
+
model.gradient_checkpointing_disable()
|
| 326 |
+
|
| 327 |
+
with torch.no_grad():
|
| 328 |
+
out = model(
|
| 329 |
+
input_ids=source_prefill_input_ids,
|
| 330 |
+
attention_mask=source_prefill_attention_mask,
|
| 331 |
+
position_ids=prefill_position_ids,
|
| 332 |
+
past_key_values=self.kv_cache_dict[self.base_model_idx][source_model_idx],
|
| 333 |
+
use_cache=True,
|
| 334 |
+
return_dict=True,
|
| 335 |
+
)
|
| 336 |
+
curr_source_kv_cache = out.past_key_values
|
| 337 |
+
finally:
|
| 338 |
+
if had_gc:
|
| 339 |
+
model.gradient_checkpointing_enable()
|
| 340 |
+
if was_training:
|
| 341 |
+
model.train()
|
| 342 |
+
|
| 343 |
+
curr_source_kv_cache = hybrid_to_dynamic(curr_source_kv_cache)
|
| 344 |
+
self.kv_cache_dict[self.base_model_idx][source_model_idx] = curr_source_kv_cache
|
| 345 |
+
|
| 346 |
+
# calculate source model kvcache and apply projections
|
| 347 |
+
if self.base_model_idx in self.projector_dict:
|
| 348 |
+
source_model_idx = kv_cache_index[i][0][0][0].item() # Get the source model index from the kv_cache_index
|
| 349 |
+
if source_model_idx != -1:
|
| 350 |
+
for target_layer_idx, entry in self.projector_dict[self.base_model_idx][source_model_idx].items():
|
| 351 |
+
base_key_cache, base_value_cache = curr_base_kv_cache[target_layer_idx]
|
| 352 |
+
new_base_key_cache = base_key_cache[:, :, start:end, :]
|
| 353 |
+
new_base_value_cache = base_value_cache[:, :, start:end, :]
|
| 354 |
+
new_base_kv_cache = (new_base_key_cache, new_base_value_cache)
|
| 355 |
+
|
| 356 |
+
pair_list = entry
|
| 357 |
+
|
| 358 |
+
projected_kv_list = []
|
| 359 |
+
source_kv_list = []
|
| 360 |
+
for source_model_layer_idx, projector_idx in pair_list:
|
| 361 |
+
source_key_cache, source_value_cache = self.kv_cache_dict[self.base_model_idx][source_model_idx][source_model_layer_idx]
|
| 362 |
+
new_source_key_cache = source_key_cache[:, :, start:end, :]
|
| 363 |
+
new_source_value_cache = source_value_cache[:, :, start:end, :]
|
| 364 |
+
new_source_kv_cache = (new_source_key_cache, new_source_value_cache)
|
| 365 |
+
projected_key, projected_value = self.projector_list[projector_idx].forward(
|
| 366 |
+
new_source_kv_cache, # tuple of (key, value), each of shape (B, N, H, D)
|
| 367 |
+
new_base_kv_cache
|
| 368 |
+
)
|
| 369 |
+
projected_kv_list.append((projected_key, projected_value))
|
| 370 |
+
source_kv_list.append(new_source_kv_cache)
|
| 371 |
+
|
| 372 |
+
# Aggregate (fallback to first projector if no aggregator is available)
|
| 373 |
+
use_aggregator = (
|
| 374 |
+
len(projected_kv_list) > 1 and
|
| 375 |
+
len(self.aggregator_list) > 0 and
|
| 376 |
+
self.base_model_idx in self.aggregator_dict and
|
| 377 |
+
source_model_idx in self.aggregator_dict[self.base_model_idx] and
|
| 378 |
+
target_layer_idx in self.aggregator_dict[self.base_model_idx][source_model_idx]
|
| 379 |
+
)
|
| 380 |
+
|
| 381 |
+
if use_aggregator:
|
| 382 |
+
aggregator_idx = self.aggregator_dict[self.base_model_idx][source_model_idx][target_layer_idx]
|
| 383 |
+
agg_key, agg_value = self.aggregator_list[aggregator_idx].forward(
|
| 384 |
+
source_kv_list,
|
| 385 |
+
new_base_kv_cache,
|
| 386 |
+
projected_kv_list
|
| 387 |
+
)
|
| 388 |
+
else:
|
| 389 |
+
# Fallback to first projector result when no aggregator is available
|
| 390 |
+
agg_key, agg_value = projected_kv_list[0]
|
| 391 |
+
|
| 392 |
+
# Update cache with aggregated result
|
| 393 |
+
curr_base_kv_cache.key_cache[target_layer_idx][:, :, start:end, :] = agg_key
|
| 394 |
+
curr_base_kv_cache.value_cache[target_layer_idx][:, :, start:end, :] = agg_value
|
| 395 |
+
|
| 396 |
+
output.past_key_values = curr_base_kv_cache
|
| 397 |
+
|
| 398 |
+
# use base model for decode phase
|
| 399 |
+
else:
|
| 400 |
+
# Handle list input format for decode phase as well
|
| 401 |
+
decode_input_ids = input_ids[self.base_model_idx] if isinstance(input_ids, list) else input_ids
|
| 402 |
+
decode_attention_mask = attention_mask[self.base_model_idx] if isinstance(attention_mask, list) else attention_mask
|
| 403 |
+
|
| 404 |
+
output = self.model_list[self.base_model_idx].forward(
|
| 405 |
+
input_ids=decode_input_ids,
|
| 406 |
+
attention_mask=decode_attention_mask,
|
| 407 |
+
position_ids=position_ids,
|
| 408 |
+
past_key_values=curr_base_kv_cache,
|
| 409 |
+
inputs_embeds=inputs_embeds,
|
| 410 |
+
labels=labels,
|
| 411 |
+
use_cache=use_cache,
|
| 412 |
+
output_attentions=output_attentions,
|
| 413 |
+
output_hidden_states=output_hidden_states,
|
| 414 |
+
cache_position=cache_position,
|
| 415 |
+
*args,
|
| 416 |
+
**kwargs
|
| 417 |
+
)
|
| 418 |
+
return output
|
| 419 |
+
|
| 420 |
+
def oracle_forward(
|
| 421 |
+
self,
|
| 422 |
+
kv_cache_index: Optional[List] = None,
|
| 423 |
+
input_ids: Optional[Union[torch.LongTensor, List[torch.LongTensor]]] = None,
|
| 424 |
+
attention_mask: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
|
| 425 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 426 |
+
past_key_values: Optional[Cache] = None,
|
| 427 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 428 |
+
labels: Optional[torch.LongTensor] = None,
|
| 429 |
+
use_cache: Optional[bool] = None,
|
| 430 |
+
output_attentions: Optional[bool] = None,
|
| 431 |
+
output_hidden_states: Optional[bool] = None,
|
| 432 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 433 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 434 |
+
# **kwargs: Unpack[KwargsForCausalLM],
|
| 435 |
+
*args,
|
| 436 |
+
**kwargs,
|
| 437 |
+
) -> CausalLMOutputWithPast:
|
| 438 |
+
"""
|
| 439 |
+
Forward pass
|
| 440 |
+
KVCache index is a list of tensors with shape (B, sec_seq_len, 2), indicating the source and target kv cache model index
|
| 441 |
+
|
| 442 |
+
If input_ids is LongTensor, default to same input ids for different models
|
| 443 |
+
If input_ids is Tuple, default to different input ids for different models.
|
| 444 |
+
|
| 445 |
+
No Rosetta: (-1, 0)
|
| 446 |
+
"""
|
| 447 |
+
|
| 448 |
+
# noqa
|
| 449 |
+
self.kv_cache_dict = dict()
|
| 450 |
+
|
| 451 |
+
# Handle different input formats: if input_ids is a list, use per-model inputs
|
| 452 |
+
if isinstance(input_ids, list):
|
| 453 |
+
# Use list format: different input_ids and attention_mask for each model
|
| 454 |
+
base_input_ids = input_ids[self.base_model_idx] if input_ids is not None else None
|
| 455 |
+
base_attention_mask = attention_mask[self.base_model_idx] if attention_mask is not None else None
|
| 456 |
+
_, seqlen = base_input_ids.size() if base_input_ids is not None else (0, 0)
|
| 457 |
+
else:
|
| 458 |
+
# Use tensor format: same input_ids and attention_mask for all models (backward compatibility)
|
| 459 |
+
base_input_ids = input_ids
|
| 460 |
+
base_attention_mask = attention_mask
|
| 461 |
+
_, seqlen = input_ids.size() if input_ids is not None else (0, 0)
|
| 462 |
+
|
| 463 |
+
num_sections = len(kv_cache_index) if kv_cache_index is not None else 1
|
| 464 |
+
|
| 465 |
+
section_lengths = [kv_cache_index[i].shape[1] for i in range(num_sections)] if kv_cache_index is not None else [seqlen]
|
| 466 |
+
section_starts = [0]
|
| 467 |
+
for l in section_lengths:
|
| 468 |
+
section_starts.append(section_starts[-1] + l)
|
| 469 |
+
|
| 470 |
+
curr_base_kv_cache = past_key_values
|
| 471 |
+
|
| 472 |
+
loss = nn.MSELoss()
|
| 473 |
+
loss_output = 0
|
| 474 |
+
if seqlen > 1:
|
| 475 |
+
for i in range(num_sections):
|
| 476 |
+
start = section_starts[i]
|
| 477 |
+
end = section_starts[i + 1]
|
| 478 |
+
prefill_input_ids = base_input_ids[:, start:end] if base_input_ids is not None else None
|
| 479 |
+
prefill_attention_mask = base_attention_mask[:, :end] if base_attention_mask is not None else None
|
| 480 |
+
prefill_position_ids = position_ids[:, start:end] if position_ids is not None else None
|
| 481 |
+
prefill_labels = labels[:, start:end] if labels is not None else None
|
| 482 |
+
|
| 483 |
+
# calculate target model kvcache
|
| 484 |
+
output = self.model_list[self.base_model_idx].forward(
|
| 485 |
+
input_ids=prefill_input_ids,
|
| 486 |
+
attention_mask=prefill_attention_mask,
|
| 487 |
+
position_ids=prefill_position_ids,
|
| 488 |
+
past_key_values=curr_base_kv_cache,
|
| 489 |
+
labels=prefill_labels,
|
| 490 |
+
use_cache=use_cache,
|
| 491 |
+
output_attentions=output_attentions,
|
| 492 |
+
output_hidden_states=output_hidden_states,
|
| 493 |
+
*args,
|
| 494 |
+
**kwargs
|
| 495 |
+
)
|
| 496 |
+
|
| 497 |
+
if self.base_model_idx not in self.kv_cache_dict:
|
| 498 |
+
self.kv_cache_dict[self.base_model_idx] = {}
|
| 499 |
+
if self.base_model_idx not in self.kv_cache_dict[self.base_model_idx]:
|
| 500 |
+
self.kv_cache_dict[self.base_model_idx][self.base_model_idx] = None
|
| 501 |
+
self.kv_cache_dict[self.base_model_idx][self.base_model_idx] = output.past_key_values
|
| 502 |
+
|
| 503 |
+
curr_base_kv_cache: DynamicCache = output.past_key_values
|
| 504 |
+
|
| 505 |
+
if i != num_sections - 1:
|
| 506 |
+
for source_model_idx in range(1, len(self.model_list)):
|
| 507 |
+
if self.base_model_idx not in self.kv_cache_dict:
|
| 508 |
+
self.kv_cache_dict[self.base_model_idx] = {}
|
| 509 |
+
if source_model_idx not in self.kv_cache_dict[self.base_model_idx]:
|
| 510 |
+
self.kv_cache_dict[self.base_model_idx][source_model_idx] = None
|
| 511 |
+
|
| 512 |
+
# Get model-specific input_ids and attention_mask
|
| 513 |
+
if isinstance(input_ids, list):
|
| 514 |
+
source_input_ids = input_ids[source_model_idx]
|
| 515 |
+
source_attention_mask = attention_mask[source_model_idx] if attention_mask is not None else None
|
| 516 |
+
source_prefill_input_ids = source_input_ids[:, start:end] if source_input_ids is not None else None
|
| 517 |
+
source_prefill_attention_mask = source_attention_mask[:, :end] if source_attention_mask is not None else None
|
| 518 |
+
else:
|
| 519 |
+
# Backward compatibility: use same input for all models
|
| 520 |
+
source_prefill_input_ids = prefill_input_ids
|
| 521 |
+
source_prefill_attention_mask = prefill_attention_mask
|
| 522 |
+
|
| 523 |
+
curr_source_kv_cache = self.model_list[source_model_idx].forward(
|
| 524 |
+
input_ids=source_prefill_input_ids,
|
| 525 |
+
attention_mask=source_prefill_attention_mask,
|
| 526 |
+
position_ids=prefill_position_ids,
|
| 527 |
+
past_key_values=self.kv_cache_dict[self.base_model_idx][source_model_idx],
|
| 528 |
+
use_cache=use_cache,
|
| 529 |
+
output_attentions=output_attentions,
|
| 530 |
+
output_hidden_states=output_hidden_states,
|
| 531 |
+
*args,
|
| 532 |
+
**kwargs
|
| 533 |
+
).past_key_values
|
| 534 |
+
self.kv_cache_dict[self.base_model_idx][source_model_idx] = curr_source_kv_cache
|
| 535 |
+
|
| 536 |
+
# calculate source model kvcache and apply projections
|
| 537 |
+
if self.base_model_idx in self.projector_dict:
|
| 538 |
+
source_model_idx = kv_cache_index[i][0][0][0].item() # Get the source model index from the kv_cache_index
|
| 539 |
+
if source_model_idx != -1:
|
| 540 |
+
for target_layer_idx, entry in self.projector_dict[self.base_model_idx][source_model_idx].items():
|
| 541 |
+
base_key_cache, base_value_cache = curr_base_kv_cache[target_layer_idx]
|
| 542 |
+
new_base_key_cache = base_key_cache[:, :, start:end, :]
|
| 543 |
+
new_base_value_cache = base_value_cache[:, :, start:end, :]
|
| 544 |
+
new_base_kv_cache = (new_base_key_cache, new_base_value_cache)
|
| 545 |
+
|
| 546 |
+
pair_list = entry
|
| 547 |
+
|
| 548 |
+
projected_kv_list = []
|
| 549 |
+
source_kv_list = []
|
| 550 |
+
for source_model_layer_idx, projector_idx in pair_list:
|
| 551 |
+
source_key_cache, source_value_cache = self.kv_cache_dict[self.base_model_idx][source_model_idx][source_model_layer_idx]
|
| 552 |
+
new_source_key_cache = source_key_cache[:, :, start:end, :]
|
| 553 |
+
new_source_value_cache = source_value_cache[:, :, start:end, :]
|
| 554 |
+
new_source_kv_cache = (new_source_key_cache, new_source_value_cache)
|
| 555 |
+
projected_key, projected_value = self.projector_list[projector_idx].forward(
|
| 556 |
+
new_source_kv_cache, # tuple of (key, value), each of shape (B, N, H, D)
|
| 557 |
+
new_base_kv_cache
|
| 558 |
+
)
|
| 559 |
+
loss_output = loss_output + loss(torch.dstack([projected_key, projected_value]),
|
| 560 |
+
torch.dstack([new_source_key_cache, new_source_value_cache]))
|
| 561 |
+
projected_kv_list.append((projected_key, projected_value))
|
| 562 |
+
source_kv_list.append(new_source_kv_cache)
|
| 563 |
+
|
| 564 |
+
# Aggregate (fallback to first projector if no aggregator is available)
|
| 565 |
+
use_aggregator = (
|
| 566 |
+
len(projected_kv_list) > 1 and
|
| 567 |
+
len(self.aggregator_list) > 0 and
|
| 568 |
+
self.base_model_idx in self.aggregator_dict and
|
| 569 |
+
source_model_idx in self.aggregator_dict[self.base_model_idx] and
|
| 570 |
+
target_layer_idx in self.aggregator_dict[self.base_model_idx][source_model_idx]
|
| 571 |
+
)
|
| 572 |
+
|
| 573 |
+
if use_aggregator:
|
| 574 |
+
aggregator_idx = self.aggregator_dict[self.base_model_idx][source_model_idx][target_layer_idx]
|
| 575 |
+
agg_key, agg_value = self.aggregator_list[aggregator_idx].forward(
|
| 576 |
+
source_kv_list,
|
| 577 |
+
new_base_kv_cache,
|
| 578 |
+
projected_kv_list
|
| 579 |
+
)
|
| 580 |
+
else:
|
| 581 |
+
# Fallback to first projector result when no aggregator is available
|
| 582 |
+
agg_key, agg_value = projected_kv_list[0]
|
| 583 |
+
|
| 584 |
+
# Update cache with aggregated result
|
| 585 |
+
curr_base_kv_cache.key_cache[target_layer_idx][:, :, start:end, :] = agg_key
|
| 586 |
+
curr_base_kv_cache.value_cache[target_layer_idx][:, :, start:end, :] = agg_value
|
| 587 |
+
|
| 588 |
+
output.past_key_values = curr_base_kv_cache
|
| 589 |
+
|
| 590 |
+
# use base model for decode phase
|
| 591 |
+
else:
|
| 592 |
+
# Handle list input format for decode phase as well
|
| 593 |
+
decode_input_ids = input_ids[self.base_model_idx] if isinstance(input_ids, list) else input_ids
|
| 594 |
+
decode_attention_mask = attention_mask[self.base_model_idx] if isinstance(attention_mask, list) else attention_mask
|
| 595 |
+
|
| 596 |
+
output = self.model_list[self.base_model_idx].forward(
|
| 597 |
+
input_ids=decode_input_ids,
|
| 598 |
+
attention_mask=decode_attention_mask,
|
| 599 |
+
position_ids=position_ids,
|
| 600 |
+
past_key_values=curr_base_kv_cache,
|
| 601 |
+
inputs_embeds=inputs_embeds,
|
| 602 |
+
labels=labels,
|
| 603 |
+
use_cache=use_cache,
|
| 604 |
+
output_attentions=output_attentions,
|
| 605 |
+
output_hidden_states=output_hidden_states,
|
| 606 |
+
cache_position=cache_position,
|
| 607 |
+
*args,
|
| 608 |
+
**kwargs
|
| 609 |
+
)
|
| 610 |
+
|
| 611 |
+
return output, loss_output
|
| 612 |
+
|
| 613 |
+
@torch.no_grad()
|
| 614 |
+
def generate(
|
| 615 |
+
self,
|
| 616 |
+
kv_cache_index,
|
| 617 |
+
input_ids,
|
| 618 |
+
max_new_tokens: Optional[int] = None,
|
| 619 |
+
past_key_values: Optional[Cache] = None,
|
| 620 |
+
attention_mask: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
|
| 621 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 622 |
+
eos_token_id: Optional[Union[int, List[int]]] = None,
|
| 623 |
+
pad_token_id: Optional[int] = None,
|
| 624 |
+
temperature: float = 1.0,
|
| 625 |
+
top_p: float = 1.0,
|
| 626 |
+
top_k: int = -1,
|
| 627 |
+
repetition_penalty: float = 1.0,
|
| 628 |
+
presence_penalty: float = 0.0,
|
| 629 |
+
frequency_penalty: float = 0.0,
|
| 630 |
+
do_sample: Optional[bool] = None,
|
| 631 |
+
return_dict_in_generate: Optional[bool] = None,
|
| 632 |
+
output_scores: Optional[bool] = None,
|
| 633 |
+
max_length: Optional[int] = None,
|
| 634 |
+
use_cache: bool = True,
|
| 635 |
+
streamer = None,
|
| 636 |
+
*args,
|
| 637 |
+
**kwargs,
|
| 638 |
+
):
|
| 639 |
+
"""
|
| 640 |
+
New generation loop without using the base model's generate.
|
| 641 |
+
- Uses this module's forward for prefill and per-token decode.
|
| 642 |
+
- Samples tokens via rosetta.model.sampling.sample_token.
|
| 643 |
+
Returns a tensor of shape [batch, prompt_len + generated_len] for the base model stream.
|
| 644 |
+
"""
|
| 645 |
+
# Derive number of tokens to generate
|
| 646 |
+
# If max_new_tokens not provided, infer from max_length
|
| 647 |
+
if isinstance(input_ids, list):
|
| 648 |
+
base_input_ids_for_len = input_ids[self.base_model_idx]
|
| 649 |
+
else:
|
| 650 |
+
base_input_ids_for_len = input_ids
|
| 651 |
+
prompt_len = base_input_ids_for_len.size(1)
|
| 652 |
+
|
| 653 |
+
# Default eos/pad from base model tokenizer/config if not provided
|
| 654 |
+
base_model = self.model_list[self.base_model_idx]
|
| 655 |
+
gen_cfg = getattr(base_model, "generation_config", None)
|
| 656 |
+
cfg_obj = gen_cfg if gen_cfg is not None else getattr(base_model, "config", None)
|
| 657 |
+
if eos_token_id is None and cfg_obj is not None:
|
| 658 |
+
eos_token_id = getattr(cfg_obj, "eos_token_id", None)
|
| 659 |
+
if pad_token_id is None and cfg_obj is not None:
|
| 660 |
+
pad_token_id = getattr(cfg_obj, "pad_token_id", None)
|
| 661 |
+
if pad_token_id is None and eos_token_id is not None:
|
| 662 |
+
pad_token_id = eos_token_id if isinstance(eos_token_id, int) else eos_token_id[0]
|
| 663 |
+
|
| 664 |
+
if max_new_tokens is None:
|
| 665 |
+
if max_length is not None:
|
| 666 |
+
if max_length <= prompt_len:
|
| 667 |
+
max_new_tokens = 0
|
| 668 |
+
else:
|
| 669 |
+
max_new_tokens = max_length - prompt_len
|
| 670 |
+
else:
|
| 671 |
+
raise ValueError("Provide max_new_tokens or max_length")
|
| 672 |
+
if max_new_tokens < 0:
|
| 673 |
+
raise ValueError("max_new_tokens must be non-negative")
|
| 674 |
+
|
| 675 |
+
# Resolve base inputs
|
| 676 |
+
if isinstance(input_ids, list):
|
| 677 |
+
base_input_ids = input_ids[self.base_model_idx]
|
| 678 |
+
base_attention_mask = attention_mask[self.base_model_idx] if attention_mask is not None else None
|
| 679 |
+
else:
|
| 680 |
+
base_input_ids = input_ids
|
| 681 |
+
base_attention_mask = attention_mask
|
| 682 |
+
|
| 683 |
+
if base_attention_mask is None:
|
| 684 |
+
base_attention_mask = torch.ones_like(base_input_ids, dtype=torch.long, device=base_input_ids.device)
|
| 685 |
+
|
| 686 |
+
batch_size = base_input_ids.size(0)
|
| 687 |
+
|
| 688 |
+
# Prefill to build caches and obtain initial logits
|
| 689 |
+
prefill_output = self.forward(
|
| 690 |
+
kv_cache_index=kv_cache_index,
|
| 691 |
+
input_ids=input_ids,
|
| 692 |
+
attention_mask=attention_mask,
|
| 693 |
+
position_ids=position_ids,
|
| 694 |
+
past_key_values=past_key_values,
|
| 695 |
+
use_cache=use_cache,
|
| 696 |
+
*args,
|
| 697 |
+
**kwargs,
|
| 698 |
+
)
|
| 699 |
+
|
| 700 |
+
current_past = prefill_output.past_key_values
|
| 701 |
+
all_input_ids = base_input_ids
|
| 702 |
+
current_attention_mask = base_attention_mask
|
| 703 |
+
|
| 704 |
+
# Initialize streamer with prompt if provided
|
| 705 |
+
if streamer is not None:
|
| 706 |
+
streamer.put(base_input_ids)
|
| 707 |
+
|
| 708 |
+
# EOS handling setup
|
| 709 |
+
eos_set = None
|
| 710 |
+
if eos_token_id is not None:
|
| 711 |
+
eos_set = set(eos_token_id if isinstance(eos_token_id, list) else [eos_token_id])
|
| 712 |
+
finished = torch.zeros(batch_size, dtype=torch.bool, device=all_input_ids.device)
|
| 713 |
+
|
| 714 |
+
# Start from last prefill logits
|
| 715 |
+
last_logits = prefill_output.logits[:, -1, :]
|
| 716 |
+
|
| 717 |
+
# Determine sampling mode
|
| 718 |
+
if do_sample is None:
|
| 719 |
+
do_sample = False
|
| 720 |
+
effective_temperature = temperature if do_sample else 0.0
|
| 721 |
+
|
| 722 |
+
# Optional scores collection
|
| 723 |
+
collect_scores = bool(return_dict_in_generate) and bool(output_scores)
|
| 724 |
+
scores = []
|
| 725 |
+
|
| 726 |
+
for _ in range(max_new_tokens):
|
| 727 |
+
if collect_scores:
|
| 728 |
+
scores.append(last_logits)
|
| 729 |
+
# Apply repetition/presence/frequency penalties to logits before sampling
|
| 730 |
+
adjusted_logits = last_logits
|
| 731 |
+
if (
|
| 732 |
+
(repetition_penalty is not None and repetition_penalty != 1.0) or
|
| 733 |
+
(presence_penalty is not None and presence_penalty != 0.0) or
|
| 734 |
+
(frequency_penalty is not None and frequency_penalty != 0.0)
|
| 735 |
+
):
|
| 736 |
+
adjusted_logits = last_logits.clone()
|
| 737 |
+
vocab_size = adjusted_logits.size(-1)
|
| 738 |
+
# Per-batch penalty application for clarity and correctness
|
| 739 |
+
for b in range(batch_size):
|
| 740 |
+
seq_tokens = all_input_ids[b]
|
| 741 |
+
if seq_tokens.numel() == 0:
|
| 742 |
+
continue
|
| 743 |
+
counts = torch.bincount(seq_tokens, minlength=vocab_size)
|
| 744 |
+
if counts.dtype != torch.float32 and counts.dtype != torch.float64:
|
| 745 |
+
counts = counts.to(adjusted_logits.dtype)
|
| 746 |
+
# Presence penalty: penalize any token that has appeared
|
| 747 |
+
if presence_penalty and presence_penalty != 0.0:
|
| 748 |
+
presence_mask = counts > 0
|
| 749 |
+
if presence_mask.any():
|
| 750 |
+
adjusted_logits[b, presence_mask] = adjusted_logits[b, presence_mask] - presence_penalty
|
| 751 |
+
# Frequency penalty: penalize proportionally to frequency
|
| 752 |
+
if frequency_penalty and frequency_penalty != 0.0:
|
| 753 |
+
adjusted_logits[b] = adjusted_logits[b] - frequency_penalty * counts
|
| 754 |
+
# Repetition penalty (HF-style): divide positive logits, multiply negative logits
|
| 755 |
+
if repetition_penalty and repetition_penalty != 1.0:
|
| 756 |
+
rep_mask = counts > 0
|
| 757 |
+
if rep_mask.any():
|
| 758 |
+
pos_mask = rep_mask & (adjusted_logits[b] > 0)
|
| 759 |
+
neg_mask = rep_mask & ~pos_mask
|
| 760 |
+
if pos_mask.any():
|
| 761 |
+
adjusted_logits[b, pos_mask] = adjusted_logits[b, pos_mask] / repetition_penalty
|
| 762 |
+
if neg_mask.any():
|
| 763 |
+
adjusted_logits[b, neg_mask] = adjusted_logits[b, neg_mask] * repetition_penalty
|
| 764 |
+
|
| 765 |
+
# Sample next token
|
| 766 |
+
next_token = sample_token(adjusted_logits, temperature=effective_temperature, top_p=top_p, top_k=top_k)
|
| 767 |
+
if not isinstance(next_token, torch.Tensor):
|
| 768 |
+
next_token = torch.tensor([next_token], device=all_input_ids.device, dtype=torch.long).repeat(batch_size)
|
| 769 |
+
|
| 770 |
+
# Apply EOS logic
|
| 771 |
+
if eos_set is not None:
|
| 772 |
+
just_finished = torch.zeros_like(finished)
|
| 773 |
+
for eid in eos_set:
|
| 774 |
+
just_finished |= (next_token == eid)
|
| 775 |
+
finished = finished | just_finished
|
| 776 |
+
if pad_token_id is not None:
|
| 777 |
+
next_token = torch.where(
|
| 778 |
+
finished,
|
| 779 |
+
torch.tensor(pad_token_id, device=next_token.device, dtype=next_token.dtype),
|
| 780 |
+
next_token,
|
| 781 |
+
)
|
| 782 |
+
|
| 783 |
+
# Append sampled token
|
| 784 |
+
next_token_unsqueezed = next_token.unsqueeze(1)
|
| 785 |
+
all_input_ids = torch.cat([all_input_ids, next_token_unsqueezed], dim=1)
|
| 786 |
+
current_attention_mask = torch.cat(
|
| 787 |
+
[
|
| 788 |
+
current_attention_mask,
|
| 789 |
+
torch.ones((batch_size, 1), device=current_attention_mask.device, dtype=current_attention_mask.dtype),
|
| 790 |
+
],
|
| 791 |
+
dim=1,
|
| 792 |
+
)
|
| 793 |
+
|
| 794 |
+
# Stream the new token if streamer provided
|
| 795 |
+
if streamer is not None:
|
| 796 |
+
streamer.put(next_token_unsqueezed)
|
| 797 |
+
|
| 798 |
+
# Early stop if all sequences finished
|
| 799 |
+
if eos_set is not None and torch.all(finished):
|
| 800 |
+
break
|
| 801 |
+
|
| 802 |
+
# Decode one step using cached states; pass base-stream tensors
|
| 803 |
+
kv_cache_index = [torch.tensor([-1, 0], dtype=torch.long).repeat(1, 1).unsqueeze(0).to(all_input_ids.device)]
|
| 804 |
+
|
| 805 |
+
decode_output = self.forward(
|
| 806 |
+
kv_cache_index=kv_cache_index,
|
| 807 |
+
input_ids=next_token_unsqueezed,
|
| 808 |
+
attention_mask=current_attention_mask,
|
| 809 |
+
position_ids=None,
|
| 810 |
+
past_key_values=current_past,
|
| 811 |
+
use_cache=True,
|
| 812 |
+
*args,
|
| 813 |
+
**kwargs,
|
| 814 |
+
)
|
| 815 |
+
current_past = decode_output.past_key_values
|
| 816 |
+
last_logits = decode_output.logits[:, -1, :]
|
| 817 |
+
|
| 818 |
+
# End streaming if streamer provided
|
| 819 |
+
if streamer is not None:
|
| 820 |
+
streamer.end()
|
| 821 |
+
|
| 822 |
+
# Return style compatible with HF generate
|
| 823 |
+
if return_dict_in_generate:
|
| 824 |
+
if GreedySearchDecoderOnlyOutput is not None and SampleDecoderOnlyOutput is not None:
|
| 825 |
+
if do_sample:
|
| 826 |
+
return SampleDecoderOnlyOutput(
|
| 827 |
+
sequences=all_input_ids,
|
| 828 |
+
scores=scores if collect_scores else None,
|
| 829 |
+
)
|
| 830 |
+
else:
|
| 831 |
+
return GreedySearchDecoderOnlyOutput(
|
| 832 |
+
sequences=all_input_ids,
|
| 833 |
+
scores=scores if collect_scores else None,
|
| 834 |
+
)
|
| 835 |
+
# Fallback to generic ModelOutput
|
| 836 |
+
result = {"sequences": all_input_ids}
|
| 837 |
+
if collect_scores:
|
| 838 |
+
result["scores"] = scores
|
| 839 |
+
return ModelOutput(**result)
|
| 840 |
+
return all_input_ids
|
rosetta/train/__init__.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Training utilities for RosettaModel
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from .dataset_adapters import (
|
| 6 |
+
ChatDataset,
|
| 7 |
+
RosettaDataCollator,
|
| 8 |
+
)
|
| 9 |
+
from .model_utils import setup_models
|
| 10 |
+
|
| 11 |
+
__all__ = [
|
| 12 |
+
"RosettaTrainer",
|
| 13 |
+
"ProjectorSaveCallback",
|
| 14 |
+
"freeze_model_components",
|
| 15 |
+
"InstructCoderChatDataset",
|
| 16 |
+
"ChatDataset",
|
| 17 |
+
"RosettaDataCollator",
|
| 18 |
+
"create_instructcoder_dataset",
|
| 19 |
+
"setup_models"
|
| 20 |
+
]
|
rosetta/train/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (505 Bytes). View file
|
|
|
rosetta/train/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (616 Bytes). View file
|
|
|
rosetta/train/__pycache__/dataset_adapters.cpython-310.pyc
ADDED
|
Binary file (50 kB). View file
|
|
|
rosetta/train/__pycache__/dataset_adapters.cpython-312.pyc
ADDED
|
Binary file (73.3 kB). View file
|
|
|
rosetta/train/__pycache__/model_utils.cpython-310.pyc
ADDED
|
Binary file (3.99 kB). View file
|
|
|
rosetta/train/__pycache__/model_utils.cpython-312.pyc
ADDED
|
Binary file (5.58 kB). View file
|
|
|
rosetta/train/__pycache__/trainer.cpython-310.pyc
ADDED
|
Binary file (2.56 kB). View file
|
|
|
rosetta/train/__pycache__/trainer.cpython-312.pyc
ADDED
|
Binary file (3.74 kB). View file
|
|
|
rosetta/train/dataset_adapters.py
ADDED
|
@@ -0,0 +1,1648 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Simple dataset adapter for converting InstructCoder to chat format
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import List, Dict, Any, Optional, Union, Callable
|
| 6 |
+
from datasets import load_dataset, load_from_disk
|
| 7 |
+
from torch.utils.data import Dataset
|
| 8 |
+
import torch
|
| 9 |
+
from transformers import AutoTokenizer
|
| 10 |
+
import inspect
|
| 11 |
+
import os
|
| 12 |
+
import hashlib
|
| 13 |
+
# Dataset Registry System
|
| 14 |
+
DATASET_REGISTRY = {}
|
| 15 |
+
|
| 16 |
+
def register_dataset(cls=None, name=None):
|
| 17 |
+
"""
|
| 18 |
+
Register a dataset class in the global registry.
|
| 19 |
+
Can be used as a decorator with or without arguments.
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
cls: The class to register
|
| 23 |
+
name: Optional name to register the class under. If None, uses the class name.
|
| 24 |
+
|
| 25 |
+
Returns:
|
| 26 |
+
The registered class
|
| 27 |
+
"""
|
| 28 |
+
def _register(cls):
|
| 29 |
+
dataset_name = name if name is not None else cls.__name__
|
| 30 |
+
DATASET_REGISTRY[dataset_name] = cls
|
| 31 |
+
# Also register with lowercase name for case-insensitive lookup
|
| 32 |
+
DATASET_REGISTRY[dataset_name.lower()] = cls
|
| 33 |
+
return cls
|
| 34 |
+
|
| 35 |
+
# Called as @register_dataset
|
| 36 |
+
if cls is not None:
|
| 37 |
+
return _register(cls)
|
| 38 |
+
|
| 39 |
+
# Called as @register_dataset() or @register_dataset(name="DatasetName")
|
| 40 |
+
return _register
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def capture_init_args(cls):
|
| 44 |
+
"""
|
| 45 |
+
Decorator to capture initialization arguments of a dataset class.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
cls: The class to decorate
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
The decorated class with automatic init args capture
|
| 52 |
+
"""
|
| 53 |
+
original_init = cls.__init__
|
| 54 |
+
|
| 55 |
+
def new_init(self, *args, **kwargs):
|
| 56 |
+
# Store all initialization arguments
|
| 57 |
+
self._init_args = {}
|
| 58 |
+
|
| 59 |
+
# Get parameter names from the original __init__ method
|
| 60 |
+
sig = inspect.signature(original_init)
|
| 61 |
+
param_names = list(sig.parameters.keys())[1:] # Skip 'self'
|
| 62 |
+
|
| 63 |
+
# Map positional args to parameter names
|
| 64 |
+
for i, arg in enumerate(args):
|
| 65 |
+
if i < len(param_names):
|
| 66 |
+
self._init_args[param_names[i]] = arg
|
| 67 |
+
|
| 68 |
+
# Add keyword args
|
| 69 |
+
self._init_args.update(kwargs)
|
| 70 |
+
|
| 71 |
+
# Call the original __init__
|
| 72 |
+
original_init(self, *args, **kwargs)
|
| 73 |
+
|
| 74 |
+
cls.__init__ = new_init
|
| 75 |
+
return cls
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# Unified batch filtering functions
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def create_text_length_filter(
|
| 82 |
+
max_length: int,
|
| 83 |
+
text_extractor: Callable[[Dict[str, Any]], str],
|
| 84 |
+
tokenizer: Optional[Any] = None,
|
| 85 |
+
use_tokens: bool = False
|
| 86 |
+
):
|
| 87 |
+
"""
|
| 88 |
+
Unified text length filter that can handle both word count and token count filtering.
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
max_length: Maximum allowed length (words or tokens)
|
| 92 |
+
text_extractor: Function that extracts text from a single sample
|
| 93 |
+
tokenizer: Tokenizer for token counting (required if use_tokens=True)
|
| 94 |
+
use_tokens: If True, count tokens; if False, count words
|
| 95 |
+
|
| 96 |
+
Returns:
|
| 97 |
+
Filter function that can be used with dataset.filter(batched=True)
|
| 98 |
+
"""
|
| 99 |
+
if use_tokens and tokenizer is None:
|
| 100 |
+
raise ValueError("Tokenizer must be provided when use_tokens=True")
|
| 101 |
+
|
| 102 |
+
def _text_length_filter_batch(batch):
|
| 103 |
+
batch_size = len(next(iter(batch.values())))
|
| 104 |
+
samples = [{key: values[i] for key, values in batch.items()} for i in range(batch_size)]
|
| 105 |
+
try:
|
| 106 |
+
texts = [text_extractor(sample) for sample in samples]
|
| 107 |
+
if use_tokens:
|
| 108 |
+
if hasattr(tokenizer, 'apply_chat_template') and any(isinstance(t, list) for t in texts):
|
| 109 |
+
rendered = []
|
| 110 |
+
for t in texts:
|
| 111 |
+
if isinstance(t, list):
|
| 112 |
+
rendered.append(tokenizer.apply_chat_template(t, tokenize=False, add_generation_prompt=False))
|
| 113 |
+
else:
|
| 114 |
+
rendered.append(str(t))
|
| 115 |
+
tokenized = tokenizer(rendered, add_special_tokens=False)
|
| 116 |
+
else:
|
| 117 |
+
tokenized = tokenizer([str(t) for t in texts], add_special_tokens=False)
|
| 118 |
+
lengths = [len(ids) for ids in tokenized["input_ids"]]
|
| 119 |
+
else:
|
| 120 |
+
lengths = [len(str(t).split()) for t in texts]
|
| 121 |
+
return [length <= max_length for length in lengths]
|
| 122 |
+
except Exception as e:
|
| 123 |
+
print(f"Error in text length filter: {e}")
|
| 124 |
+
return [False] * batch_size
|
| 125 |
+
|
| 126 |
+
return _text_length_filter_batch
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def create_field_value_filter(target_value: Any, field_name: str, comparison: str = 'equal'):
|
| 130 |
+
"""
|
| 131 |
+
Unified field value filter for exact matching, language filtering, etc.
|
| 132 |
+
|
| 133 |
+
Args:
|
| 134 |
+
target_value: Value to compare against
|
| 135 |
+
field_name: Field name to check
|
| 136 |
+
comparison: Type of comparison ('equal', 'not_equal', 'in', 'not_in')
|
| 137 |
+
|
| 138 |
+
Returns:
|
| 139 |
+
Filter function that can be used with dataset.filter(batched=True)
|
| 140 |
+
"""
|
| 141 |
+
def _field_value_filter_batch(batch):
|
| 142 |
+
field_values = batch.get(field_name, [])
|
| 143 |
+
|
| 144 |
+
if comparison == 'equal':
|
| 145 |
+
return [value == target_value for value in field_values]
|
| 146 |
+
elif comparison == 'not_equal':
|
| 147 |
+
return [value != target_value for value in field_values]
|
| 148 |
+
elif comparison == 'in':
|
| 149 |
+
return [value in target_value for value in field_values]
|
| 150 |
+
elif comparison == 'not_in':
|
| 151 |
+
return [value not in target_value for value in field_values]
|
| 152 |
+
else:
|
| 153 |
+
raise ValueError(f"Unsupported comparison: {comparison}")
|
| 154 |
+
|
| 155 |
+
return _field_value_filter_batch
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def create_modulo_filter(mod_base: int, exclude_values: Union[int, List[int]], field_name: str = '_id'):
|
| 159 |
+
"""
|
| 160 |
+
Unified modulo filter for ID-based filtering.
|
| 161 |
+
|
| 162 |
+
Args:
|
| 163 |
+
mod_base: Modulo base
|
| 164 |
+
exclude_values: Value(s) to exclude (can be single int or list)
|
| 165 |
+
field_name: Field name containing the ID
|
| 166 |
+
|
| 167 |
+
Returns:
|
| 168 |
+
Filter function that can be used with dataset.filter(batched=True)
|
| 169 |
+
"""
|
| 170 |
+
if isinstance(exclude_values, int):
|
| 171 |
+
exclude_values = [exclude_values]
|
| 172 |
+
|
| 173 |
+
def _modulo_filter_batch(batch):
|
| 174 |
+
ids = batch.get(field_name, [])
|
| 175 |
+
results = []
|
| 176 |
+
|
| 177 |
+
for _id in ids:
|
| 178 |
+
try:
|
| 179 |
+
# Try numeric conversion first
|
| 180 |
+
id_num = int(_id)
|
| 181 |
+
mod_result = id_num % mod_base
|
| 182 |
+
except (ValueError, TypeError):
|
| 183 |
+
# Use hash for non-numeric IDs
|
| 184 |
+
id_hash = hash(str(_id))
|
| 185 |
+
mod_result = id_hash % mod_base
|
| 186 |
+
|
| 187 |
+
results.append(mod_result not in exclude_values)
|
| 188 |
+
|
| 189 |
+
return results
|
| 190 |
+
|
| 191 |
+
return _modulo_filter_batch
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def create_conversation_length_filter(min_messages: int, text_field: str = 'conversations'):
|
| 195 |
+
"""
|
| 196 |
+
Unified conversation length filter for OpenHermes-style datasets.
|
| 197 |
+
|
| 198 |
+
Args:
|
| 199 |
+
min_messages: Minimum number of messages required (excluding system messages)
|
| 200 |
+
text_field: Field name containing the conversation
|
| 201 |
+
|
| 202 |
+
Returns:
|
| 203 |
+
Filter function that can be used with dataset.filter(batched=True)
|
| 204 |
+
"""
|
| 205 |
+
def _conversation_length_filter_batch(batch):
|
| 206 |
+
conversations_list = batch.get(text_field, [])
|
| 207 |
+
results = []
|
| 208 |
+
|
| 209 |
+
for conversations in conversations_list:
|
| 210 |
+
try:
|
| 211 |
+
# Extract messages (excluding system)
|
| 212 |
+
message_count = 0
|
| 213 |
+
for msg in conversations:
|
| 214 |
+
role = msg.get('from') or msg.get('role')
|
| 215 |
+
if role in ('human', 'user', 'gpt', 'assistant'):
|
| 216 |
+
message_count += 1
|
| 217 |
+
|
| 218 |
+
results.append(message_count > min_messages)
|
| 219 |
+
except Exception:
|
| 220 |
+
results.append(False)
|
| 221 |
+
|
| 222 |
+
return results
|
| 223 |
+
|
| 224 |
+
return _conversation_length_filter_batch
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
# Text extraction functions for common dataset patterns
|
| 228 |
+
def extract_mmlu_text(sample: Dict[str, Any], question_field: str = 'question', choices_field: str = 'choices') -> str:
|
| 229 |
+
"""Extract text from MMLU-style samples"""
|
| 230 |
+
question = sample.get(question_field, '')
|
| 231 |
+
choices = sample.get(choices_field, [])
|
| 232 |
+
|
| 233 |
+
# Handle both list and dict formats for choices
|
| 234 |
+
if isinstance(choices, dict):
|
| 235 |
+
choices_text = choices.get('text', [])
|
| 236 |
+
else:
|
| 237 |
+
choices_text = choices
|
| 238 |
+
|
| 239 |
+
return (str(question) + " " + " ".join(map(str, choices_text))).strip()
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def extract_chat_text(sample: Dict[str, Any], input_field: str = 'input',
|
| 243 |
+
context_field: str = 'context', answers_field: str = 'answers') -> List[Dict[str, str]]:
|
| 244 |
+
"""Extract chat messages from LongBench-style samples"""
|
| 245 |
+
input_text = str(sample.get(input_field, ''))
|
| 246 |
+
context = str(sample.get(context_field, ''))
|
| 247 |
+
answers = sample.get(answers_field, [])
|
| 248 |
+
|
| 249 |
+
assistant_message = answers[0] if answers and len(answers) > 0 else "No answer provided"
|
| 250 |
+
|
| 251 |
+
# Build complete chat format
|
| 252 |
+
if context:
|
| 253 |
+
human_message = f"Context: {context}\n\nInstruction: {input_text}"
|
| 254 |
+
else:
|
| 255 |
+
human_message = f"Instruction: {input_text}"
|
| 256 |
+
|
| 257 |
+
return [
|
| 258 |
+
{"role": "user", "content": human_message.strip()},
|
| 259 |
+
{"role": "assistant", "content": assistant_message.strip()}
|
| 260 |
+
]
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def extract_conversation_text(sample: Dict[str, Any], text_field: str = 'conversations') -> str:
|
| 264 |
+
"""Extract text from OpenHermes-style conversation samples"""
|
| 265 |
+
conversations = sample.get(text_field, [])
|
| 266 |
+
|
| 267 |
+
if conversations and len(conversations) > 0:
|
| 268 |
+
return conversations[0].get('value', '')
|
| 269 |
+
return ''
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def extract_first_user_message(sample: Dict[str, Any], text_field: str = 'conversations') -> str:
|
| 273 |
+
"""Extract the first human/user message from conversation-style samples."""
|
| 274 |
+
conversations = sample.get(text_field, [])
|
| 275 |
+
for msg in conversations:
|
| 276 |
+
role = msg.get('from') or msg.get('role')
|
| 277 |
+
if role in ('human', 'user'):
|
| 278 |
+
return str(msg.get('value', ''))
|
| 279 |
+
# Fallback to first message if role tags are missing
|
| 280 |
+
if conversations:
|
| 281 |
+
return str(conversations[0].get('value', ''))
|
| 282 |
+
return ''
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def extract_first_assistant_message(sample: Dict[str, Any], text_field: str = 'conversations') -> str:
|
| 286 |
+
"""Extract the first gpt/assistant message from conversation-style samples."""
|
| 287 |
+
conversations = sample.get(text_field, [])
|
| 288 |
+
for msg in conversations:
|
| 289 |
+
role = msg.get('from') or msg.get('role')
|
| 290 |
+
if role in ('gpt', 'assistant'):
|
| 291 |
+
return str(msg.get('value', ''))
|
| 292 |
+
# Fallback to second message if present
|
| 293 |
+
if len(conversations) > 1:
|
| 294 |
+
return str(conversations[1].get('value', ''))
|
| 295 |
+
return ''
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def extract_openhermes_messages(sample: Dict[str, Any], text_field: str = 'conversations') -> List[Dict[str, str]]:
|
| 299 |
+
"""Build chat messages excluding system; include all human/user and gpt/assistant in order."""
|
| 300 |
+
conversation = sample.get(text_field, [])
|
| 301 |
+
messages: List[Dict[str, str]] = []
|
| 302 |
+
for msg in conversation:
|
| 303 |
+
role = msg.get('from') or msg.get('role')
|
| 304 |
+
if role == 'system':
|
| 305 |
+
continue
|
| 306 |
+
if role in ('human', 'user'):
|
| 307 |
+
messages.append({"role": "user", "content": str(msg.get('value', '')).strip()})
|
| 308 |
+
elif role in ('gpt', 'assistant'):
|
| 309 |
+
messages.append({"role": "assistant", "content": str(msg.get('value', ''))})
|
| 310 |
+
return messages
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def extract_instruction_text(sample: Dict[str, Any], instruction_field: str = 'instruction',
|
| 314 |
+
inputs_field: str = 'inputs') -> str:
|
| 315 |
+
"""Extract text from Inkuba-style instruction samples"""
|
| 316 |
+
instruction = sample.get(instruction_field)
|
| 317 |
+
inputs = sample.get(inputs_field, '')
|
| 318 |
+
|
| 319 |
+
if instruction is not None:
|
| 320 |
+
return str(instruction) + "\n\n" + str(inputs)
|
| 321 |
+
else:
|
| 322 |
+
return str(inputs)
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def extract_chat_pair_text(sample: Dict[str, Any], user_field: str = 'inputs',
|
| 326 |
+
assistant_field: str = 'targets') -> List[Dict[str, str]]:
|
| 327 |
+
"""Extract chat messages from Aya-style samples"""
|
| 328 |
+
user_text = str(sample.get(user_field, ''))
|
| 329 |
+
assistant_text = str(sample.get(assistant_field, ''))
|
| 330 |
+
|
| 331 |
+
return [
|
| 332 |
+
{"role": "user", "content": user_text.strip()},
|
| 333 |
+
{"role": "assistant", "content": assistant_text.strip()}
|
| 334 |
+
]
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def extract_dolly_chat_messages(sample: Dict[str, Any]) -> List[Dict[str, str]]:
|
| 339 |
+
"""Extract chat messages from Dolly-style samples.
|
| 340 |
+
|
| 341 |
+
Fields:
|
| 342 |
+
- instruction: str
|
| 343 |
+
- context: str (may be empty)
|
| 344 |
+
- response: str
|
| 345 |
+
- category: optional, may be empty/missing
|
| 346 |
+
"""
|
| 347 |
+
instruction = str(sample.get('instruction', '')).strip()
|
| 348 |
+
context = str(sample.get('context', '') or '').strip()
|
| 349 |
+
response = str(sample.get('response', '')).strip()
|
| 350 |
+
|
| 351 |
+
if context:
|
| 352 |
+
user_message = f"{context}\n\n{instruction}"
|
| 353 |
+
else:
|
| 354 |
+
user_message = f"{instruction}"
|
| 355 |
+
|
| 356 |
+
return [
|
| 357 |
+
{"role": "user", "content": user_message.strip()},
|
| 358 |
+
{"role": "assistant", "content": response}
|
| 359 |
+
]
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def extract_mmmlu_chat_messages(sample: Dict[str, Any]) -> List[Dict[str, str]]:
|
| 363 |
+
"""Extract chat messages from MMMLU-style samples (OpenAI/MMMLU)."""
|
| 364 |
+
choice_labels = ['A', 'B', 'C', 'D']
|
| 365 |
+
|
| 366 |
+
template = (
|
| 367 |
+
"Jibu kwa usahihi swali lifuatalo:\n\n"
|
| 368 |
+
"{{question}}\n\n"
|
| 369 |
+
"Chaguo:\n"
|
| 370 |
+
"{{choices}}\n\n"
|
| 371 |
+
"Maelekezo:\n"
|
| 372 |
+
"- Soma swali na chaguo zote kwa makini.\n"
|
| 373 |
+
"- Chagua jibu sahihi zaidi kati ya yaliyotolewa.\n"
|
| 374 |
+
"- Jibu TU kwa herufi (A, B, C, D) inayolingana na jibu sahihi.\n"
|
| 375 |
+
"- Usijumuishe maelezo, maandishi ya ziada, au alama yoyote ya uakifishaji.\n\n"
|
| 376 |
+
"Jibu lako:"
|
| 377 |
+
)
|
| 378 |
+
|
| 379 |
+
choices_text = ""
|
| 380 |
+
for label in choice_labels:
|
| 381 |
+
content = sample.get(label, '')
|
| 382 |
+
choices_text += f"{label}. {content}\n"
|
| 383 |
+
|
| 384 |
+
user_prompt = template.replace("{{choices}}", choices_text).replace("{{question}}", str(sample.get('Question', '')))
|
| 385 |
+
|
| 386 |
+
correct_label = sample.get('Answer', '')
|
| 387 |
+
correct_content = sample.get(correct_label, '')
|
| 388 |
+
assistant_response = f"**Jibu lako: {correct_label}. {correct_content}.**"
|
| 389 |
+
|
| 390 |
+
return [
|
| 391 |
+
{"role": "user", "content": user_prompt.strip()},
|
| 392 |
+
{"role": "assistant", "content": assistant_response}
|
| 393 |
+
]
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
def apply_batch_filters(dataset, filters: list, filter_descriptions: list = None,
|
| 399 |
+
batch_size: int = 4096, combine_filters: bool = True,
|
| 400 |
+
num_proc: Optional[int] = None):
|
| 401 |
+
"""
|
| 402 |
+
Apply multiple filters using native batched filtering for maximum performance.
|
| 403 |
+
|
| 404 |
+
Args:
|
| 405 |
+
dataset: Dataset to filter
|
| 406 |
+
filters: List of batched filter functions
|
| 407 |
+
filter_descriptions: Optional list of descriptions for logging
|
| 408 |
+
batch_size: Batch size for filtering operations
|
| 409 |
+
combine_filters: If True, combine all filters into a single batched operation
|
| 410 |
+
|
| 411 |
+
Returns:
|
| 412 |
+
Filtered dataset and original length
|
| 413 |
+
"""
|
| 414 |
+
if not filters:
|
| 415 |
+
return dataset, len(dataset)
|
| 416 |
+
|
| 417 |
+
original_len = len(dataset)
|
| 418 |
+
|
| 419 |
+
if combine_filters and len(filters) > 1:
|
| 420 |
+
# Combine all filters into a single batched operation for maximum efficiency
|
| 421 |
+
def _combined_batch_filter(batch):
|
| 422 |
+
# Get results from all filters
|
| 423 |
+
filter_results = []
|
| 424 |
+
for filter_func in filters:
|
| 425 |
+
filter_results.append(filter_func(batch))
|
| 426 |
+
|
| 427 |
+
# Combine results with AND logic
|
| 428 |
+
combined_results = []
|
| 429 |
+
batch_size = len(filter_results[0]) if filter_results else 0
|
| 430 |
+
|
| 431 |
+
for i in range(batch_size):
|
| 432 |
+
combined_results.append(all(result[i] for result in filter_results))
|
| 433 |
+
|
| 434 |
+
return combined_results
|
| 435 |
+
|
| 436 |
+
# Apply combined filter in a single pass
|
| 437 |
+
filtered_dataset = dataset.filter(
|
| 438 |
+
_combined_batch_filter,
|
| 439 |
+
batched=True,
|
| 440 |
+
batch_size=batch_size,
|
| 441 |
+
num_proc=num_proc if num_proc and (num_proc or 0) > 1 else None,
|
| 442 |
+
desc="Combined batch filtering"
|
| 443 |
+
)
|
| 444 |
+
|
| 445 |
+
# Print filtering results
|
| 446 |
+
final_len = len(filtered_dataset)
|
| 447 |
+
if original_len != final_len:
|
| 448 |
+
print(f"Applied combined batch filtering: {original_len} -> {final_len} samples")
|
| 449 |
+
if filter_descriptions:
|
| 450 |
+
for desc in filter_descriptions:
|
| 451 |
+
print(f" - {desc}")
|
| 452 |
+
|
| 453 |
+
else:
|
| 454 |
+
# Apply each filter sequentially with batched processing
|
| 455 |
+
current_dataset = dataset
|
| 456 |
+
|
| 457 |
+
for i, (filter_func, desc) in enumerate(zip(filters, filter_descriptions or [''] * len(filters))):
|
| 458 |
+
pre_filter_len = len(current_dataset)
|
| 459 |
+
|
| 460 |
+
current_dataset = current_dataset.filter(
|
| 461 |
+
filter_func,
|
| 462 |
+
batched=True,
|
| 463 |
+
batch_size=batch_size,
|
| 464 |
+
num_proc=num_proc if num_proc and (num_proc or 0) > 1 else None,
|
| 465 |
+
desc=f"Filtering: {desc}" if desc else f"Filter {i+1}"
|
| 466 |
+
)
|
| 467 |
+
|
| 468 |
+
post_filter_len = len(current_dataset)
|
| 469 |
+
if desc and pre_filter_len != post_filter_len:
|
| 470 |
+
print(f" - {desc}: {pre_filter_len} -> {post_filter_len} samples")
|
| 471 |
+
|
| 472 |
+
filtered_dataset = current_dataset
|
| 473 |
+
final_len = len(filtered_dataset)
|
| 474 |
+
if original_len != final_len:
|
| 475 |
+
print(f"Applied sequential batch filtering: {original_len} -> {final_len} samples")
|
| 476 |
+
|
| 477 |
+
return filtered_dataset, original_len
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
def generate_kv_cache_index(instruction_length: int, full_length: int) -> torch.tensor:
|
| 481 |
+
"""
|
| 482 |
+
Generate KV cache index for the input sequence.
|
| 483 |
+
|
| 484 |
+
Args:
|
| 485 |
+
instruction_length: Length of the instruction tokens
|
| 486 |
+
full_length: Total length of the full conversation tokens
|
| 487 |
+
|
| 488 |
+
Returns:
|
| 489 |
+
Tensor with KV cache index
|
| 490 |
+
"""
|
| 491 |
+
assert instruction_length <= full_length
|
| 492 |
+
|
| 493 |
+
instruction_index = torch.tensor([1, 0], dtype=torch.long).repeat(instruction_length - 1, 1)
|
| 494 |
+
label_index = torch.tensor([-1, 0], dtype=torch.long).repeat(full_length - instruction_length + 1, 1)
|
| 495 |
+
|
| 496 |
+
kv_cache_index = torch.cat([instruction_index, label_index], dim=0) # shape: (seq_len, 2)
|
| 497 |
+
|
| 498 |
+
return kv_cache_index
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
"""
|
| 502 |
+
Instruction dataset
|
| 503 |
+
|
| 504 |
+
Convert any form of inputs to standard message format
|
| 505 |
+
"""
|
| 506 |
+
|
| 507 |
+
@register_dataset
|
| 508 |
+
@capture_init_args
|
| 509 |
+
class LongBenchChatDataset(Dataset):
|
| 510 |
+
"""LongBench数据集转换为LongBench原始格式"""
|
| 511 |
+
|
| 512 |
+
def __init__(self, split: str = "test", num_samples: Optional[int] = None,
|
| 513 |
+
dataset_name: Optional[str] = None, language: Optional[str] = None,
|
| 514 |
+
max_word_count: Optional[int] = None, max_length: Optional[int] = 14000,
|
| 515 |
+
use_longbench_e: bool = True, filter_mod4: bool = True):
|
| 516 |
+
"""
|
| 517 |
+
初始化LongBench数据集
|
| 518 |
+
|
| 519 |
+
Args:
|
| 520 |
+
split: 数据集分割 ("test" - LongBench主要使用test分割)
|
| 521 |
+
num_samples: 使用的样本数量 (None表示全部)
|
| 522 |
+
dataset_name: 特定数据集名称 (None表示所有数据集)
|
| 523 |
+
language: 语言过滤 ("en" 或 "zh")
|
| 524 |
+
max_word_count: 最大词数限制(用于英文文本)
|
| 525 |
+
max_length: 最大字符长度限制
|
| 526 |
+
use_longbench_e: 是否使用LongBench-E版本
|
| 527 |
+
filter_mod4: 是否过滤_id mod4余1的样本
|
| 528 |
+
"""
|
| 529 |
+
print(f"Loading LongBench{' -E' if use_longbench_e else ''} dataset (split: {split}, dataset: {dataset_name})...")
|
| 530 |
+
|
| 531 |
+
# LongBench包含的数据集列表
|
| 532 |
+
longbench_datasets = [
|
| 533 |
+
"narrativeqa", "qasper", "multifieldqa_en", "multifieldqa_zh", "hotpotqa",
|
| 534 |
+
"2wikimqa", "musique", "dureader", "gov_report", "qmsum", "multi_news",
|
| 535 |
+
"vcsum", "trec", "triviaqa", "samsum", "lsht", "passage_count",
|
| 536 |
+
"passage_retrieval_en", "passage_retrieval_zh", "lcc", "repobench-p"
|
| 537 |
+
]
|
| 538 |
+
|
| 539 |
+
longbench_e_datasets = [
|
| 540 |
+
"qasper", "multifieldqa_en", "hotpotqa", "2wikimqa", "gov_report",
|
| 541 |
+
"multi_news", "trec", "triviaqa", "samsum", "passage_count",
|
| 542 |
+
"passage_retrieval_en", "lcc", "repobench-p"
|
| 543 |
+
]
|
| 544 |
+
|
| 545 |
+
target_datasets = longbench_e_datasets if use_longbench_e else longbench_datasets
|
| 546 |
+
|
| 547 |
+
# 定义LongBench提示模板
|
| 548 |
+
self.dataset_prompt_formats = {
|
| 549 |
+
"narrativeqa": "You are given a story, which can be either a novel or a movie script, and a question. Answer the question asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nStory: {context}\n\nNow, answer the question based on the story asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:",
|
| 550 |
+
"qasper": "You are given a scientific article and a question. Answer the question as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nArticle: {context}\n\n Answer the question based on the above article as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:",
|
| 551 |
+
"multifieldqa_en": "Read the following text and answer briefly.\n\n{context}\n\nNow, answer the following question based on the above text, only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
|
| 552 |
+
"multifieldqa_zh": "阅读以下文字并用中文简短回答:\n\n{context}\n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{input}\n回答:",
|
| 553 |
+
"hotpotqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
|
| 554 |
+
"2wikimqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
|
| 555 |
+
"musique": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
|
| 556 |
+
"dureader": "请基于给定的文章回答下述问题。\n\n文章:{context}\n\n请基于上述文章回答下面的问题。\n\n问题:{input}\n回答:",
|
| 557 |
+
"gov_report": "You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{context}\n\nNow, write a one-page summary of the report.\n\nSummary:",
|
| 558 |
+
"qmsum": "You are given a meeting transcript and a query containing a question or instruction. Answer the query in one or more sentences.\n\nTranscript:\n{context}\n\nNow, answer the query based on the above meeting transcript in one or more sentences.\n\nQuery: {input}\nAnswer:",
|
| 559 |
+
"multi_news": "You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{context}\n\nNow, write a one-page summary of all the news.\n\nSummary:",
|
| 560 |
+
"vcsum": "下面有一段会议记录,请你阅读后,写一段总结,总结会议的内容。\n会议记录:\n{context}\n\n会议总结:",
|
| 561 |
+
"trec": "Please determine the type of the question below. Here are some examples of questions.\n\n{context}\n{input}",
|
| 562 |
+
"triviaqa": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n{context}\n\n{input}",
|
| 563 |
+
"samsum": "Summarize the dialogue into a few short sentences. The following are some examples.\n\n{context}\n\n{input}",
|
| 564 |
+
"lsht": "请判断给定新闻的类别,下面是一些例子。\n\n{context}\n{input}",
|
| 565 |
+
"passage_count": "There are some paragraphs below sourced from Wikipedia. Some of them may be duplicates. Please carefully read these paragraphs and determine how many unique paragraphs there are after removing duplicates. In other words, how many non-repeating paragraphs are there in total?\n\n{context}\n\nPlease enter the final count of unique paragraphs after removing duplicates. The output format should only contain the number, such as 1, 2, 3, and so on.\n\nThe final answer is: ",
|
| 566 |
+
"passage_retrieval_en": "Here are 30 paragraphs from Wikipedia, along with an abstract. Please determine which paragraph the abstract is from.\n\n{context}\n\nThe following is an abstract.\n\n{input}\n\nPlease enter the number of the paragraph that the abstract is from. The answer format must be like \"Paragraph 1\", \"Paragraph 2\", etc.\n\nThe answer is: ",
|
| 567 |
+
"passage_retrieval_zh": "以下是若干段落文字,以及其中一个段落的摘要。请确定给定的摘要出自哪一段。\n\n{context}\n\n下面是一个摘要\n\n{input}\n\n请输入摘要所属段落的编号。答案格式必须是\"段落1\",\"段落2\"等格式\n\n答案是:",
|
| 568 |
+
"lcc": "Please complete the code given below. \n{context}Next line of code:\n",
|
| 569 |
+
"repobench-p": "Please complete the code given below. \n{context}{input}Next line of code:\n"
|
| 570 |
+
}
|
| 571 |
+
|
| 572 |
+
# 定义不使用聊天模板的任务
|
| 573 |
+
#self.no_chat_template_tasks = ["trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"]
|
| 574 |
+
self.no_chat_template_tasks=['']
|
| 575 |
+
self.use_longbench_e = use_longbench_e
|
| 576 |
+
self.max_length = max_length
|
| 577 |
+
|
| 578 |
+
if dataset_name:
|
| 579 |
+
if dataset_name not in target_datasets:
|
| 580 |
+
raise ValueError(f"Dataset {dataset_name} not found in LongBench{' -E' if use_longbench_e else ''}")
|
| 581 |
+
target_datasets = [dataset_name]
|
| 582 |
+
self.current_evaluating_subject = dataset_name
|
| 583 |
+
else:
|
| 584 |
+
self.current_evaluating_subject = None
|
| 585 |
+
|
| 586 |
+
# 加载所有选定的数据集
|
| 587 |
+
all_data = []
|
| 588 |
+
for dataset in target_datasets:
|
| 589 |
+
try:
|
| 590 |
+
dataset_suffix = f"{dataset}_e" if use_longbench_e else dataset
|
| 591 |
+
data = load_dataset('THUDM/LongBench', dataset_suffix, split=split)
|
| 592 |
+
print(f" Loaded {len(data)} samples from {dataset}")
|
| 593 |
+
|
| 594 |
+
# 添加数据集名称标识
|
| 595 |
+
data = data.map(lambda x: {"dataset_source": dataset})
|
| 596 |
+
all_data.append(data)
|
| 597 |
+
except Exception as e:
|
| 598 |
+
print(f"Warning: Failed to load {dataset}: {e}")
|
| 599 |
+
continue
|
| 600 |
+
|
| 601 |
+
if not all_data:
|
| 602 |
+
raise ValueError("No datasets were successfully loaded")
|
| 603 |
+
|
| 604 |
+
|
| 605 |
+
from datasets import concatenate_datasets
|
| 606 |
+
self.dataset = concatenate_datasets(all_data)
|
| 607 |
+
|
| 608 |
+
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
|
| 612 |
+
# mod4!=1
|
| 613 |
+
if filter_mod4:
|
| 614 |
+
original_len = len(self.dataset)
|
| 615 |
+
|
| 616 |
+
def _mod4_not_1(example):
|
| 617 |
+
_id = example.get('_id', '')
|
| 618 |
+
id_hash = int(hashlib.sha256(str(_id).encode('utf-8')).hexdigest(), 16)
|
| 619 |
+
|
| 620 |
+
return id_hash % 4 != 1
|
| 621 |
+
|
| 622 |
+
self.dataset = self.dataset.filter(_mod4_not_1)
|
| 623 |
+
print(f"Filtered by _id mod4 != 1: {original_len} -> {len(self.dataset)} samples")
|
| 624 |
+
|
| 625 |
+
# 限制样本数量
|
| 626 |
+
if num_samples and num_samples < len(self.dataset):
|
| 627 |
+
self.dataset = self.dataset.select(range(num_samples))
|
| 628 |
+
|
| 629 |
+
print(f"Loaded total {len(self.dataset)} samples from LongBench{' -E' if use_longbench_e else ''}")
|
| 630 |
+
def __len__(self):
|
| 631 |
+
return len(self.dataset)
|
| 632 |
+
|
| 633 |
+
def _format_longbench_example(self, example: Dict[str, Any], tokenizer: AutoTokenizer) -> str:
|
| 634 |
+
|
| 635 |
+
# 1. 确定任务类型
|
| 636 |
+
dataset_source = example.get('dataset_source', '')
|
| 637 |
+
if self.current_evaluating_subject:
|
| 638 |
+
current_subject = self.current_evaluating_subject
|
| 639 |
+
else:
|
| 640 |
+
current_subject = dataset_source
|
| 641 |
+
|
| 642 |
+
# 仅当字符串以"_e"结尾时才替换
|
| 643 |
+
import re
|
| 644 |
+
subject = re.sub(r"_e$", "", current_subject) if self.use_longbench_e else current_subject
|
| 645 |
+
|
| 646 |
+
# 2. 获取提示模板
|
| 647 |
+
if subject not in self.dataset_prompt_formats:
|
| 648 |
+
subject = "narrativeqa" # 默认模板
|
| 649 |
+
prompt_format = self.dataset_prompt_formats[subject]
|
| 650 |
+
|
| 651 |
+
# 3. 直接使用**example展开所有字段
|
| 652 |
+
raw_prompt = prompt_format.format(**example)
|
| 653 |
+
|
| 654 |
+
# 4. 超长截断逻辑
|
| 655 |
+
tokenized_raw = tokenizer(raw_prompt, truncation=False, return_tensors="pt").input_ids[0]
|
| 656 |
+
if len(tokenized_raw) > self.max_length:
|
| 657 |
+
half_len = int(self.max_length / 2)
|
| 658 |
+
raw_prompt = tokenizer.decode(tokenized_raw[:half_len], skip_special_tokens=True) + \
|
| 659 |
+
tokenizer.decode(tokenized_raw[-half_len:], skip_special_tokens=True)
|
| 660 |
+
|
| 661 |
+
# 5. 应用Chat Template
|
| 662 |
+
|
| 663 |
+
final_prompt = raw_prompt
|
| 664 |
+
print(len(tokenized_raw))
|
| 665 |
+
return final_prompt
|
| 666 |
+
|
| 667 |
+
def __getitem__(self, idx):
|
| 668 |
+
|
| 669 |
+
sample = self.dataset[idx]
|
| 670 |
+
|
| 671 |
+
# 格式化样本
|
| 672 |
+
formatted_prompt = self._format_longbench_example(sample, self.tokenizer)
|
| 673 |
+
|
| 674 |
+
# 提取答案
|
| 675 |
+
answers = sample.get('answers', [])
|
| 676 |
+
assistant_message = answers[0] if answers and len(answers) > 0 else "No answer provided"
|
| 677 |
+
|
| 678 |
+
return [
|
| 679 |
+
{
|
| 680 |
+
"role": "user",
|
| 681 |
+
"content": formatted_prompt.strip()
|
| 682 |
+
},
|
| 683 |
+
{
|
| 684 |
+
"role": "assistant",
|
| 685 |
+
"content": assistant_message.strip()
|
| 686 |
+
}
|
| 687 |
+
]
|
| 688 |
+
|
| 689 |
+
@register_dataset
|
| 690 |
+
@capture_init_args
|
| 691 |
+
class MMLUChatDataset(Dataset):
|
| 692 |
+
"""Simple MMLU dataset converted to chat format"""
|
| 693 |
+
|
| 694 |
+
def __init__(self, split: str = "train", num_samples: Optional[int] = None, max_word_count: Optional[int] = None):
|
| 695 |
+
"""
|
| 696 |
+
Initialize the dataset
|
| 697 |
+
|
| 698 |
+
Args:
|
| 699 |
+
split: Dataset split
|
| 700 |
+
num_samples: Number of samples to use (None for all)
|
| 701 |
+
max_word_count: If set, drop samples whose question + all choices exceed this word count
|
| 702 |
+
"""
|
| 703 |
+
print(f"Loading MMLU dataset (split: {split})...")
|
| 704 |
+
# Load dataset
|
| 705 |
+
dataset = load_dataset("cais/mmlu", "all")
|
| 706 |
+
dataset = dataset[split]
|
| 707 |
+
|
| 708 |
+
# Ensure we have a proper Dataset object
|
| 709 |
+
if hasattr(dataset, 'select'):
|
| 710 |
+
self.dataset = dataset
|
| 711 |
+
else:
|
| 712 |
+
raise ValueError(f"Unexpected dataset type: {type(dataset)}")
|
| 713 |
+
|
| 714 |
+
# Limit samples if specified
|
| 715 |
+
if num_samples and num_samples < len(self.dataset):
|
| 716 |
+
self.dataset = self.dataset.select(range(num_samples))
|
| 717 |
+
|
| 718 |
+
# Apply total token length filtering on full chat (user + assistant)
|
| 719 |
+
if max_word_count is not None:
|
| 720 |
+
# Use a small tokenizer for speed; total token length = chat(user+assistant)
|
| 721 |
+
self._mmlu_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
|
| 722 |
+
extractor = lambda sample: self._build_chat_messages(sample)
|
| 723 |
+
filters = [create_text_length_filter(max_word_count, extractor, self._mmlu_tokenizer, use_tokens=True)]
|
| 724 |
+
filter_descriptions = [f"Token count filter (full chat): max {max_word_count}"]
|
| 725 |
+
self.dataset, _ = apply_batch_filters(self.dataset, filters, filter_descriptions)
|
| 726 |
+
|
| 727 |
+
print(f"Loaded {len(self.dataset)} samples")
|
| 728 |
+
|
| 729 |
+
def __len__(self):
|
| 730 |
+
return len(self.dataset)
|
| 731 |
+
|
| 732 |
+
def __getitem__(self, idx):
|
| 733 |
+
sample = self.dataset[idx]
|
| 734 |
+
return self._build_chat_messages(sample)
|
| 735 |
+
|
| 736 |
+
def _build_chat_messages(self, sample: Dict[str, Any]) -> List[Dict[str, str]]:
|
| 737 |
+
choice_labels = ['A', 'B', 'C', 'D']
|
| 738 |
+
question = sample.get('question', '')
|
| 739 |
+
choices_list = sample.get('choices', [])
|
| 740 |
+
user_prompt = f"Question: {question}\n\nChoices:\n"
|
| 741 |
+
for i, choice in enumerate(choices_list):
|
| 742 |
+
label = choice_labels[i] if i < len(choice_labels) else chr(65 + i)
|
| 743 |
+
user_prompt += f"{label}. {choice}\n"
|
| 744 |
+
ans_idx = sample.get('answer', 0)
|
| 745 |
+
if isinstance(ans_idx, str) and ans_idx.isdigit():
|
| 746 |
+
ans_idx = int(ans_idx)
|
| 747 |
+
ans_label = choice_labels[ans_idx] if 0 <= int(ans_idx) < len(choice_labels) else chr(65 + int(ans_idx))
|
| 748 |
+
assistant_text = f"The correct answer is {ans_label}."
|
| 749 |
+
return [
|
| 750 |
+
{"role": "user", "content": user_prompt.strip()},
|
| 751 |
+
{"role": "assistant", "content": assistant_text.strip()},
|
| 752 |
+
]
|
| 753 |
+
|
| 754 |
+
@register_dataset
|
| 755 |
+
@capture_init_args
|
| 756 |
+
class MMLUCotChatDataset(Dataset):
|
| 757 |
+
"""Simple MMLUCot dataset converted to chat format"""
|
| 758 |
+
|
| 759 |
+
def __init__(self, split: str = "train", num_samples: Optional[int] = None):
|
| 760 |
+
"""
|
| 761 |
+
Initialize the dataset
|
| 762 |
+
|
| 763 |
+
Args:
|
| 764 |
+
split: Dataset split
|
| 765 |
+
num_samples: Number of samples to use (None for all)
|
| 766 |
+
"""
|
| 767 |
+
print(f"Loading MMLUCot dataset (split: {split})...")
|
| 768 |
+
# Load dataset
|
| 769 |
+
dataset = load_dataset("Brench/MMLU-Pro-CoT-Train-43K")
|
| 770 |
+
dataset = dataset[split]
|
| 771 |
+
|
| 772 |
+
# Ensure we have a proper Dataset object
|
| 773 |
+
if hasattr(dataset, 'select'):
|
| 774 |
+
self.dataset = dataset
|
| 775 |
+
else:
|
| 776 |
+
raise ValueError(f"Unexpected dataset type: {type(dataset)}")
|
| 777 |
+
|
| 778 |
+
# Limit samples if specified
|
| 779 |
+
if num_samples and num_samples < len(self.dataset):
|
| 780 |
+
self.dataset = self.dataset.select(range(num_samples))
|
| 781 |
+
|
| 782 |
+
print(f"Loaded {len(self.dataset)} samples")
|
| 783 |
+
|
| 784 |
+
def __len__(self):
|
| 785 |
+
return len(self.dataset)
|
| 786 |
+
|
| 787 |
+
def __getitem__(self, idx):
|
| 788 |
+
sample = self.dataset[idx]
|
| 789 |
+
|
| 790 |
+
user_prompt = sample['question'] + "\n"
|
| 791 |
+
|
| 792 |
+
assistant_response = sample['chain_of_thoughts']
|
| 793 |
+
|
| 794 |
+
return [
|
| 795 |
+
{
|
| 796 |
+
"role": "user",
|
| 797 |
+
"content": user_prompt.strip()
|
| 798 |
+
},
|
| 799 |
+
{
|
| 800 |
+
"role": "assistant",
|
| 801 |
+
"content": assistant_response
|
| 802 |
+
}
|
| 803 |
+
]
|
| 804 |
+
|
| 805 |
+
@register_dataset
|
| 806 |
+
@capture_init_args
|
| 807 |
+
class LLMGeneratedChatDataset(Dataset):
|
| 808 |
+
"""Simple LLM Generated dataset converted to chat format"""
|
| 809 |
+
|
| 810 |
+
def __init__(self, split: str = "train", num_samples: Optional[int] = None, data_path: str = "./teacher_datasets/output/dataset_finished", max_word_count: Optional[int] = None):
|
| 811 |
+
"""
|
| 812 |
+
Initialize the dataset
|
| 813 |
+
|
| 814 |
+
Args:
|
| 815 |
+
split: Dataset split
|
| 816 |
+
num_samples: Number of samples to use (None for all)
|
| 817 |
+
"""
|
| 818 |
+
print(f"Loading LLMGeneratedCot dataset (split: {split})...")
|
| 819 |
+
# Load dataset
|
| 820 |
+
dataset = load_from_disk(data_path)
|
| 821 |
+
|
| 822 |
+
# Ensure we have a proper Dataset object
|
| 823 |
+
if hasattr(dataset, 'select'):
|
| 824 |
+
self.dataset = dataset
|
| 825 |
+
else:
|
| 826 |
+
raise ValueError(f"Unexpected dataset type: {type(dataset)}")
|
| 827 |
+
|
| 828 |
+
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
|
| 829 |
+
|
| 830 |
+
if max_word_count is not None:
|
| 831 |
+
original_len = len(self.dataset)
|
| 832 |
+
half = max_word_count // 2
|
| 833 |
+
def _under_token_limit(batch):
|
| 834 |
+
q = tokenizer(batch["input_text"], add_special_tokens=False, padding=False, truncation=False)
|
| 835 |
+
a = tokenizer(batch["model_response"], add_special_tokens=False, padding=False, truncation=False)
|
| 836 |
+
return [
|
| 837 |
+
(len(q_ids) <= half) and (len(q_ids) + len(a_ids) <= max_word_count)
|
| 838 |
+
for q_ids, a_ids in zip(q["input_ids"], a["input_ids"])
|
| 839 |
+
]
|
| 840 |
+
|
| 841 |
+
self.dataset = self.dataset.filter(
|
| 842 |
+
_under_token_limit,
|
| 843 |
+
batched=True,
|
| 844 |
+
batch_size=2048, # 视显存/内存调大
|
| 845 |
+
num_proc=min(8, os.cpu_count() or 1),
|
| 846 |
+
load_from_cache_file=True,
|
| 847 |
+
desc=f"Filter max_word_count={max_word_count}",
|
| 848 |
+
)
|
| 849 |
+
print(f"Filtered by max_word_count={max_word_count}: {original_len} -> {len(self.dataset)} samples")
|
| 850 |
+
|
| 851 |
+
# Limit samples if specified
|
| 852 |
+
if num_samples and num_samples < len(self.dataset):
|
| 853 |
+
self.dataset = self.dataset.select(range(num_samples))
|
| 854 |
+
|
| 855 |
+
print(f"Loaded {len(self.dataset)} samples")
|
| 856 |
+
|
| 857 |
+
def __len__(self):
|
| 858 |
+
return len(self.dataset)
|
| 859 |
+
|
| 860 |
+
def __getitem__(self, idx):
|
| 861 |
+
sample = self.dataset[idx]
|
| 862 |
+
|
| 863 |
+
input_text = sample.get('input_text', '') or ''
|
| 864 |
+
|
| 865 |
+
# Parse question and choices from input_text, which is expected to contain a
|
| 866 |
+
# "Choices:" section followed by labeled options like "A. ..."
|
| 867 |
+
def _parse_question_and_choices(text: str):
|
| 868 |
+
lines = (text or '').splitlines()
|
| 869 |
+
# Find the line index for "Choices:" (case-insensitive, ignoring spaces)
|
| 870 |
+
choices_idx = -1
|
| 871 |
+
for i, line in enumerate(lines):
|
| 872 |
+
if line.strip().lower().startswith('choices'):
|
| 873 |
+
choices_idx = i
|
| 874 |
+
break
|
| 875 |
+
|
| 876 |
+
if choices_idx == -1:
|
| 877 |
+
# Fallback: no explicit Choices header found
|
| 878 |
+
question_part = text.strip()
|
| 879 |
+
return question_part, ''
|
| 880 |
+
|
| 881 |
+
question_part = '\n'.join(lines[:choices_idx]).strip()
|
| 882 |
+
|
| 883 |
+
# Collect labeled choices until blank line or instruction-like line
|
| 884 |
+
collected = []
|
| 885 |
+
for raw in lines[choices_idx + 1:]:
|
| 886 |
+
s = raw.strip()
|
| 887 |
+
if not s:
|
| 888 |
+
# Stop on first blank after having collected at least one choice
|
| 889 |
+
if collected:
|
| 890 |
+
break
|
| 891 |
+
else:
|
| 892 |
+
continue
|
| 893 |
+
lower = s.lower()
|
| 894 |
+
# Stop when hitting instruction section common in prompts
|
| 895 |
+
if lower.startswith('instructions:') or lower.startswith("let's ") or lower.startswith('you must'):
|
| 896 |
+
break
|
| 897 |
+
# Accept formats like "A. ..." or "A) ..."
|
| 898 |
+
if len(s) >= 3 and s[0] in 'ABCDEFGHIJ' and s[1] in ').' and s[2] == ' ':
|
| 899 |
+
collected.append(s)
|
| 900 |
+
else:
|
| 901 |
+
# If we've started collecting and this line doesn't look like a choice, stop
|
| 902 |
+
if collected:
|
| 903 |
+
break
|
| 904 |
+
# Otherwise ignore preamble noise
|
| 905 |
+
continue
|
| 906 |
+
|
| 907 |
+
choices_block = '\n'.join(collected).strip()
|
| 908 |
+
return question_part, choices_block
|
| 909 |
+
|
| 910 |
+
question, choices_block = _parse_question_and_choices(input_text)
|
| 911 |
+
|
| 912 |
+
# Rebuild user prompt using the evaluation CoT template
|
| 913 |
+
template = """Accurately answer the following question:
|
| 914 |
+
|
| 915 |
+
{{question}}
|
| 916 |
+
|
| 917 |
+
Choices:
|
| 918 |
+
{{choices}}
|
| 919 |
+
|
| 920 |
+
Instructions:
|
| 921 |
+
- Carefully read the question and all options.
|
| 922 |
+
- Let's think step by step and you must explain your reasoning briefly.
|
| 923 |
+
- Then give the final answer.
|
| 924 |
+
- Keep your response within 150 words."""
|
| 925 |
+
|
| 926 |
+
filled_prompt = (
|
| 927 |
+
template
|
| 928 |
+
.replace("{{question}}", question or '')
|
| 929 |
+
.replace("{{choices}}", choices_block or '')
|
| 930 |
+
)
|
| 931 |
+
|
| 932 |
+
user_prompt = filled_prompt.strip() + "\n"
|
| 933 |
+
|
| 934 |
+
assistant_response = sample['model_response']
|
| 935 |
+
|
| 936 |
+
return [
|
| 937 |
+
{
|
| 938 |
+
"role": "user",
|
| 939 |
+
"content": user_prompt.strip()
|
| 940 |
+
},
|
| 941 |
+
{
|
| 942 |
+
"role": "assistant",
|
| 943 |
+
"content": assistant_response
|
| 944 |
+
}
|
| 945 |
+
]
|
| 946 |
+
|
| 947 |
+
@register_dataset
|
| 948 |
+
@capture_init_args
|
| 949 |
+
class OpenBookChatDataset(Dataset):
|
| 950 |
+
"""Simple OpenBook dataset converted to chat format"""
|
| 951 |
+
|
| 952 |
+
def __init__(self, split: str = "train", num_samples: Optional[int] = None):
|
| 953 |
+
"""
|
| 954 |
+
Initialize the dataset
|
| 955 |
+
|
| 956 |
+
Args:
|
| 957 |
+
split: Dataset split
|
| 958 |
+
num_samples: Number of samples to use (None for all)
|
| 959 |
+
"""
|
| 960 |
+
print(f"Loading OpenBook dataset (split: {split})...")
|
| 961 |
+
# Load dataset
|
| 962 |
+
dataset = load_dataset("allenai/openbookqa", "main")
|
| 963 |
+
dataset = dataset[split]
|
| 964 |
+
|
| 965 |
+
# Ensure we have a proper Dataset object
|
| 966 |
+
if hasattr(dataset, 'select'):
|
| 967 |
+
self.dataset = dataset
|
| 968 |
+
else:
|
| 969 |
+
raise ValueError(f"Unexpected dataset type: {type(dataset)}")
|
| 970 |
+
|
| 971 |
+
# Limit samples if specified
|
| 972 |
+
if num_samples and num_samples < len(self.dataset):
|
| 973 |
+
self.dataset = self.dataset.select(range(num_samples))
|
| 974 |
+
|
| 975 |
+
print(f"Loaded {len(self.dataset)} samples")
|
| 976 |
+
|
| 977 |
+
def __len__(self):
|
| 978 |
+
return len(self.dataset)
|
| 979 |
+
|
| 980 |
+
def __getitem__(self, idx):
|
| 981 |
+
sample = self.dataset[idx]
|
| 982 |
+
choice_labels = ['A', 'B', 'C', 'D']
|
| 983 |
+
|
| 984 |
+
user_prompt = (
|
| 985 |
+
f"Question: {sample['question_stem']}\n\n"
|
| 986 |
+
f"Choices:\n"
|
| 987 |
+
)
|
| 988 |
+
for idx, choice in enumerate(sample['choices']['text']):
|
| 989 |
+
label = choice_labels[idx]
|
| 990 |
+
user_prompt += f"{label}. {choice}\n"
|
| 991 |
+
|
| 992 |
+
correct_label = sample["answerKey"]
|
| 993 |
+
assistant_response = f"The correct answer is {correct_label}."
|
| 994 |
+
|
| 995 |
+
return [
|
| 996 |
+
{
|
| 997 |
+
"role": "user",
|
| 998 |
+
"content": user_prompt.strip()
|
| 999 |
+
},
|
| 1000 |
+
{
|
| 1001 |
+
"role": "assistant",
|
| 1002 |
+
"content": assistant_response
|
| 1003 |
+
}
|
| 1004 |
+
]
|
| 1005 |
+
|
| 1006 |
+
@register_dataset
|
| 1007 |
+
@capture_init_args
|
| 1008 |
+
class OpenHermesChatDataset(Dataset):
|
| 1009 |
+
"""Simple general dataset converted to chat format"""
|
| 1010 |
+
|
| 1011 |
+
def __init__(self, split: str = "train", num_samples: Optional[int] = None, max_word_count: Optional[int] = None, min_conversation_turns: int = 0):
|
| 1012 |
+
"""
|
| 1013 |
+
Initialize the dataset
|
| 1014 |
+
|
| 1015 |
+
Args:
|
| 1016 |
+
split: Dataset split
|
| 1017 |
+
num_samples: Number of samples to use (None for all)
|
| 1018 |
+
max_word_count: Maximum token count for filtering
|
| 1019 |
+
min_conversation_turns: Minimum number of conversation turns (default 3 for multi-turn conversations)
|
| 1020 |
+
"""
|
| 1021 |
+
print(f"Loading OpenHermes dataset (split: {split})...")
|
| 1022 |
+
# Load dataset
|
| 1023 |
+
dataset = load_dataset("teknium/OpenHermes-2.5")
|
| 1024 |
+
dataset = dataset[split]
|
| 1025 |
+
|
| 1026 |
+
# Ensure we have a proper Dataset object
|
| 1027 |
+
if hasattr(dataset, 'select'):
|
| 1028 |
+
self.dataset = dataset
|
| 1029 |
+
else:
|
| 1030 |
+
raise ValueError(f"Unexpected dataset type: {type(dataset)}")
|
| 1031 |
+
|
| 1032 |
+
# Limit samples if specified
|
| 1033 |
+
if num_samples and num_samples < len(self.dataset):
|
| 1034 |
+
self.dataset = self.dataset.select(range(num_samples))
|
| 1035 |
+
|
| 1036 |
+
# Apply filters
|
| 1037 |
+
filters = []
|
| 1038 |
+
filter_descriptions = []
|
| 1039 |
+
|
| 1040 |
+
# Filter by minimum conversation length (exclude conversations with <= 2 messages)
|
| 1041 |
+
if min_conversation_turns > 0:
|
| 1042 |
+
filters.append(create_conversation_length_filter(min_conversation_turns - 1, 'conversations'))
|
| 1043 |
+
filter_descriptions.append(f"Conversation length filter: min {min_conversation_turns} messages (multi-turn only)")
|
| 1044 |
+
|
| 1045 |
+
# Apply conversation-level token count filtering (all messages combined <= max_word_count)
|
| 1046 |
+
if max_word_count is not None:
|
| 1047 |
+
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
|
| 1048 |
+
extractor = lambda sample: extract_openhermes_messages(sample, 'conversations')
|
| 1049 |
+
filters.append(create_text_length_filter(max_word_count, extractor, tokenizer, use_tokens=True))
|
| 1050 |
+
filter_descriptions.append(f"Token count filter: max {max_word_count}")
|
| 1051 |
+
|
| 1052 |
+
# Apply all filters
|
| 1053 |
+
if filters:
|
| 1054 |
+
self.dataset, _ = apply_batch_filters(self.dataset, filters, filter_descriptions, num_proc=8)
|
| 1055 |
+
|
| 1056 |
+
print(f"Loaded {len(self.dataset)} samples")
|
| 1057 |
+
|
| 1058 |
+
def __len__(self):
|
| 1059 |
+
return len(self.dataset)
|
| 1060 |
+
|
| 1061 |
+
def __getitem__(self, idx):
|
| 1062 |
+
sample = self.dataset[idx]
|
| 1063 |
+
return extract_openhermes_messages(sample, 'conversations')
|
| 1064 |
+
|
| 1065 |
+
"""
|
| 1066 |
+
Chat dataset
|
| 1067 |
+
|
| 1068 |
+
Convert standard message format to input_ids and labels
|
| 1069 |
+
"""
|
| 1070 |
+
class ChatDataset(Dataset):
|
| 1071 |
+
"""Dataset for chat format training with HuggingFace Trainer compatibility"""
|
| 1072 |
+
|
| 1073 |
+
def __init__(self, chat_dataset, tokenizer: AutoTokenizer, max_length: int = 32768):
|
| 1074 |
+
self.chat_dataset = chat_dataset
|
| 1075 |
+
self.tokenizer = tokenizer
|
| 1076 |
+
self.max_length = max_length
|
| 1077 |
+
|
| 1078 |
+
def __len__(self):
|
| 1079 |
+
return len(self.chat_dataset)
|
| 1080 |
+
|
| 1081 |
+
def __getitem__(self, idx) -> Dict[str, Any]:
|
| 1082 |
+
messages = self.chat_dataset[idx]
|
| 1083 |
+
|
| 1084 |
+
# Get instruction (first message)
|
| 1085 |
+
instruction = self.tokenizer.apply_chat_template(
|
| 1086 |
+
messages[:-1],
|
| 1087 |
+
tokenize=False,
|
| 1088 |
+
add_generation_prompt=True,
|
| 1089 |
+
enable_thinking=False,
|
| 1090 |
+
)
|
| 1091 |
+
|
| 1092 |
+
# Get full conversation
|
| 1093 |
+
full_text = self.tokenizer.apply_chat_template(
|
| 1094 |
+
messages,
|
| 1095 |
+
tokenize=False,
|
| 1096 |
+
add_generation_prompt=False,
|
| 1097 |
+
enable_thinking=False,
|
| 1098 |
+
)
|
| 1099 |
+
|
| 1100 |
+
# Tokenize instruction and full text
|
| 1101 |
+
instruction_tokens = self.tokenizer(instruction, add_special_tokens=False)["input_ids"]
|
| 1102 |
+
full_tokens = self.tokenizer(full_text, add_special_tokens=False)["input_ids"]
|
| 1103 |
+
|
| 1104 |
+
# Truncate if necessary
|
| 1105 |
+
if len(full_tokens) > self.max_length:
|
| 1106 |
+
full_tokens = full_tokens[:self.max_length]
|
| 1107 |
+
|
| 1108 |
+
# Create labels (-100 for instruction tokens, actual tokens for response)
|
| 1109 |
+
labels = [-100] * len(instruction_tokens) + full_tokens[len(instruction_tokens):]
|
| 1110 |
+
# labels = [-100] * (len(full_tokens) - 4) + full_tokens[-4:]
|
| 1111 |
+
if len(labels) > self.max_length:
|
| 1112 |
+
labels = labels[:self.max_length]
|
| 1113 |
+
|
| 1114 |
+
kv_cache_index = generate_kv_cache_index(len(instruction_tokens), len(full_tokens))
|
| 1115 |
+
# kv_cache_index = generate_kv_cache_index(len(full_tokens)-4, len(full_tokens))
|
| 1116 |
+
# kv_cache_index = generate_kv_cache_index(len(full_tokens) + 1, len(full_tokens))
|
| 1117 |
+
|
| 1118 |
+
return {
|
| 1119 |
+
"input_ids": full_tokens,
|
| 1120 |
+
"labels": labels,
|
| 1121 |
+
"kv_cache_index": kv_cache_index
|
| 1122 |
+
}
|
| 1123 |
+
|
| 1124 |
+
|
| 1125 |
+
class AlignedChatDataset(Dataset):
|
| 1126 |
+
"""Dataset that precomputes aligned inputs for SLM/LLM using a TokenAligner"""
|
| 1127 |
+
|
| 1128 |
+
def __init__(self, instruct_dataset: Dataset, aligner: Any, max_length: int = 32768):
|
| 1129 |
+
self.dataset = instruct_dataset
|
| 1130 |
+
self.aligner = aligner
|
| 1131 |
+
self.max_length = max_length
|
| 1132 |
+
|
| 1133 |
+
def __len__(self):
|
| 1134 |
+
return len(self.dataset)
|
| 1135 |
+
|
| 1136 |
+
def __getitem__(self, idx):
|
| 1137 |
+
messages = self.dataset[idx]
|
| 1138 |
+
|
| 1139 |
+
# Build aligned sequences and section map
|
| 1140 |
+
details = self.aligner.align_chat_messages(messages, add_generation_prompt=False, return_details=True)
|
| 1141 |
+
slm_ids: List[int] = details['slm_ids_padded']
|
| 1142 |
+
llm_ids: List[int] = details['llm_ids_padded']
|
| 1143 |
+
sections = details['sections']
|
| 1144 |
+
|
| 1145 |
+
slm_pad_mask = torch.tensor(details['slm_padding_mask'])
|
| 1146 |
+
llm_pad_mask = torch.tensor(details['llm_padding_mask'])
|
| 1147 |
+
message_mask = torch.tensor(details['message_mask'])
|
| 1148 |
+
|
| 1149 |
+
# Determine instruction boundary as start of the last message section
|
| 1150 |
+
instr_end = 0
|
| 1151 |
+
for sec_idx in range(len(sections) - 1, -1, -1):
|
| 1152 |
+
sec = sections[sec_idx]
|
| 1153 |
+
if sec['type'] == 'message':
|
| 1154 |
+
instr_end = sec['slm_range'][0]
|
| 1155 |
+
break
|
| 1156 |
+
|
| 1157 |
+
# Labels: follow ChatDataset policy (-100 for instruction-only, supervise the rest)
|
| 1158 |
+
labels = [-100] * instr_end + slm_ids[instr_end:]
|
| 1159 |
+
if len(labels) > self.max_length:
|
| 1160 |
+
labels = labels[:self.max_length]
|
| 1161 |
+
|
| 1162 |
+
# Truncate inputs if needed
|
| 1163 |
+
if len(slm_ids) > self.max_length:
|
| 1164 |
+
slm_ids = slm_ids[:self.max_length]
|
| 1165 |
+
# Truncate padding mask accordingly
|
| 1166 |
+
slm_pad_mask = slm_pad_mask[:self.max_length]
|
| 1167 |
+
if len(llm_ids) > self.max_length:
|
| 1168 |
+
llm_ids = llm_ids[:self.max_length]
|
| 1169 |
+
llm_pad_mask = llm_pad_mask[:self.max_length]
|
| 1170 |
+
|
| 1171 |
+
# KV cache index based on instruction length
|
| 1172 |
+
kv_cache_index = generate_kv_cache_index(instr_end, len(slm_ids))
|
| 1173 |
+
# Addtionally mask non-message parts
|
| 1174 |
+
kv_cache_index[~message_mask] = torch.tensor([[-1,0]])
|
| 1175 |
+
|
| 1176 |
+
return {
|
| 1177 |
+
"input_ids": [slm_ids, llm_ids],
|
| 1178 |
+
"labels": labels,
|
| 1179 |
+
"kv_cache_index": kv_cache_index,
|
| 1180 |
+
"messages": messages,
|
| 1181 |
+
# Per-model aligned inputs (per-sample, pre-batch)
|
| 1182 |
+
"model_padding_mask": [slm_pad_mask, llm_pad_mask],
|
| 1183 |
+
}
|
| 1184 |
+
|
| 1185 |
+
|
| 1186 |
+
class BaselineChatDataset(Dataset):
|
| 1187 |
+
"""Simple dataset for baseline model training without Rosetta-specific features"""
|
| 1188 |
+
|
| 1189 |
+
def __init__(self, chat_dataset, tokenizer: AutoTokenizer, max_length: int = 2048):
|
| 1190 |
+
self.chat_dataset = chat_dataset
|
| 1191 |
+
self.tokenizer = tokenizer
|
| 1192 |
+
self.max_length = max_length
|
| 1193 |
+
|
| 1194 |
+
def __len__(self):
|
| 1195 |
+
return len(self.chat_dataset)
|
| 1196 |
+
|
| 1197 |
+
def __getitem__(self, idx):
|
| 1198 |
+
messages = self.chat_dataset[idx]
|
| 1199 |
+
|
| 1200 |
+
# Get instruction (first message)
|
| 1201 |
+
instruction = self.tokenizer.apply_chat_template(
|
| 1202 |
+
messages[:1],
|
| 1203 |
+
tokenize=False,
|
| 1204 |
+
add_generation_prompt=True,
|
| 1205 |
+
enable_thinking=False,
|
| 1206 |
+
)
|
| 1207 |
+
|
| 1208 |
+
# Get full conversation
|
| 1209 |
+
full_text = self.tokenizer.apply_chat_template(
|
| 1210 |
+
messages,
|
| 1211 |
+
tokenize=False,
|
| 1212 |
+
add_generation_prompt=False,
|
| 1213 |
+
enable_thinking=False,
|
| 1214 |
+
)
|
| 1215 |
+
|
| 1216 |
+
# Tokenize instruction and full text
|
| 1217 |
+
instruction_tokens = self.tokenizer(instruction, add_special_tokens=False)["input_ids"]
|
| 1218 |
+
full_tokens = self.tokenizer(full_text, add_special_tokens=False)["input_ids"]
|
| 1219 |
+
|
| 1220 |
+
# Truncate if necessary
|
| 1221 |
+
if len(full_tokens) > self.max_length:
|
| 1222 |
+
full_tokens = full_tokens[:self.max_length]
|
| 1223 |
+
|
| 1224 |
+
# Create labels (-100 for instruction tokens, actual tokens for response)
|
| 1225 |
+
labels = [-100] * len(instruction_tokens) + full_tokens[len(instruction_tokens):]
|
| 1226 |
+
if len(labels) > self.max_length:
|
| 1227 |
+
labels = labels[:self.max_length]
|
| 1228 |
+
|
| 1229 |
+
return {
|
| 1230 |
+
"input_ids": full_tokens,
|
| 1231 |
+
"labels": labels,
|
| 1232 |
+
}
|
| 1233 |
+
|
| 1234 |
+
"""
|
| 1235 |
+
Data collator
|
| 1236 |
+
|
| 1237 |
+
Batch chat data to model input
|
| 1238 |
+
"""
|
| 1239 |
+
|
| 1240 |
+
class RosettaDataCollator:
|
| 1241 |
+
"""Improved data collator for RosettaModel training with cleaner logic"""
|
| 1242 |
+
|
| 1243 |
+
def __init__(self, slm_tokenizer: AutoTokenizer, llm_tokenizer: AutoTokenizer = None,
|
| 1244 |
+
pad_to_multiple_of: Optional[int] = None, max_length: Optional[int] = None,
|
| 1245 |
+
aligner: Optional[Any] = None, do_alignment: bool = False):
|
| 1246 |
+
"""
|
| 1247 |
+
Initialize the collator.
|
| 1248 |
+
|
| 1249 |
+
Args:
|
| 1250 |
+
slm_tokenizer: Small language model tokenizer
|
| 1251 |
+
llm_tokenizer: Large language model tokenizer (optional)
|
| 1252 |
+
pad_to_multiple_of: Pad sequence length to multiple of this value
|
| 1253 |
+
max_length: Maximum sequence length
|
| 1254 |
+
aligner: Alignment module (if needed)
|
| 1255 |
+
do_alignment: Whether to perform alignment
|
| 1256 |
+
"""
|
| 1257 |
+
self.slm_tokenizer = slm_tokenizer
|
| 1258 |
+
self.llm_tokenizer = llm_tokenizer
|
| 1259 |
+
self.pad_to_multiple_of = pad_to_multiple_of
|
| 1260 |
+
self.max_length = max_length
|
| 1261 |
+
self.aligner = aligner
|
| 1262 |
+
self.do_alignment = do_alignment
|
| 1263 |
+
|
| 1264 |
+
if self.do_alignment:
|
| 1265 |
+
assert self.aligner is not None, "Aligner must be provided if do_alignment is True"
|
| 1266 |
+
|
| 1267 |
+
# Store padding token IDs for different models
|
| 1268 |
+
self.slm_pad_token_id = self.slm_tokenizer.pad_token_id
|
| 1269 |
+
self.llm_pad_token_id = self.llm_tokenizer.pad_token_id if self.llm_tokenizer else self.slm_pad_token_id
|
| 1270 |
+
|
| 1271 |
+
def _normalize_input_format(self, feature: Dict[str, Any]) -> Dict[str, Any]:
|
| 1272 |
+
"""
|
| 1273 |
+
Normalize input format to handle both single and dual model inputs.
|
| 1274 |
+
|
| 1275 |
+
Args:
|
| 1276 |
+
feature: Input feature dictionary
|
| 1277 |
+
|
| 1278 |
+
Returns:
|
| 1279 |
+
Normalized feature with consistent format
|
| 1280 |
+
"""
|
| 1281 |
+
# Normalize input_ids: ensure it's always a list of tensors
|
| 1282 |
+
input_ids = feature['input_ids']
|
| 1283 |
+
if isinstance(input_ids, list) and len(input_ids) > 0:
|
| 1284 |
+
if isinstance(input_ids[0], list):
|
| 1285 |
+
# Case: [[ids1], [ids2]] -> convert to list of tensors
|
| 1286 |
+
input_ids_tensors = [torch.tensor(ids, dtype=torch.long) for ids in input_ids]
|
| 1287 |
+
else:
|
| 1288 |
+
# Case: [id1, id2, ...] -> single model case
|
| 1289 |
+
input_ids_tensors = [torch.tensor(input_ids, dtype=torch.long)]
|
| 1290 |
+
else:
|
| 1291 |
+
# Fallback: assume single model
|
| 1292 |
+
input_ids_tensors = [torch.tensor(input_ids, dtype=torch.long)]
|
| 1293 |
+
|
| 1294 |
+
# Normalize attention_mask
|
| 1295 |
+
attention_masks = []
|
| 1296 |
+
if "model_padding_mask" in feature:
|
| 1297 |
+
# Use model-specific padding masks
|
| 1298 |
+
for model_padding_mask in feature["model_padding_mask"]:
|
| 1299 |
+
attention_masks.append((~model_padding_mask).float())
|
| 1300 |
+
else:
|
| 1301 |
+
# Generate default attention masks
|
| 1302 |
+
for input_tensor in input_ids_tensors:
|
| 1303 |
+
attention_masks.append(torch.ones(len(input_tensor), dtype=torch.float))
|
| 1304 |
+
|
| 1305 |
+
return {
|
| 1306 |
+
'input_ids': input_ids_tensors,
|
| 1307 |
+
'attention_mask': attention_masks,
|
| 1308 |
+
'labels': torch.tensor(feature['labels'], dtype=torch.long),
|
| 1309 |
+
'kv_cache_index': feature['kv_cache_index'],
|
| 1310 |
+
'position_ids': torch.arange(len(feature['labels']), dtype=torch.long)
|
| 1311 |
+
}
|
| 1312 |
+
|
| 1313 |
+
def _split_into_sections(self, normalized_feature: Dict[str, Any]) -> List[Dict[str, Any]]:
|
| 1314 |
+
"""
|
| 1315 |
+
Split sequence into sections based on kv_cache_index changes.
|
| 1316 |
+
|
| 1317 |
+
Args:
|
| 1318 |
+
normalized_feature: Normalized feature dictionary
|
| 1319 |
+
|
| 1320 |
+
Returns:
|
| 1321 |
+
List of sections
|
| 1322 |
+
"""
|
| 1323 |
+
kv_idx = normalized_feature['kv_cache_index']
|
| 1324 |
+
|
| 1325 |
+
# Find change points in kv_cache_index
|
| 1326 |
+
change_points = [0]
|
| 1327 |
+
for i in range(1, kv_idx.size(0)):
|
| 1328 |
+
if not torch.equal(kv_idx[i], kv_idx[i - 1]):
|
| 1329 |
+
change_points.append(i)
|
| 1330 |
+
change_points.append(kv_idx.size(0))
|
| 1331 |
+
|
| 1332 |
+
# Create sections
|
| 1333 |
+
sections = []
|
| 1334 |
+
for i in range(len(change_points) - 1):
|
| 1335 |
+
start, end = change_points[i], change_points[i + 1]
|
| 1336 |
+
section = {
|
| 1337 |
+
'input_ids': [ids[start:end] for ids in normalized_feature['input_ids']],
|
| 1338 |
+
'attention_mask': [mask[start:end] for mask in normalized_feature['attention_mask']],
|
| 1339 |
+
'labels': normalized_feature['labels'][start:end],
|
| 1340 |
+
'kv_cache_index': normalized_feature['kv_cache_index'][start:end],
|
| 1341 |
+
'position_ids': normalized_feature['position_ids'][start:end]
|
| 1342 |
+
}
|
| 1343 |
+
sections.append(section)
|
| 1344 |
+
|
| 1345 |
+
return sections
|
| 1346 |
+
|
| 1347 |
+
def _pad_sections(self, all_sections: List[List[Dict[str, Any]]]) -> Dict[str, Any]:
|
| 1348 |
+
"""
|
| 1349 |
+
Pad sections to ensure uniform structure across batch.
|
| 1350 |
+
|
| 1351 |
+
Args:
|
| 1352 |
+
all_sections: List of section lists for each sample
|
| 1353 |
+
|
| 1354 |
+
Returns:
|
| 1355 |
+
Padded batch dictionary
|
| 1356 |
+
"""
|
| 1357 |
+
max_sections = max(len(sections) for sections in all_sections)
|
| 1358 |
+
num_models = len(all_sections[0][0]['input_ids']) if all_sections else 1
|
| 1359 |
+
|
| 1360 |
+
# Initialize output structure - keep models separate throughout
|
| 1361 |
+
padded_output = {
|
| 1362 |
+
'input_ids_per_model': [[] for _ in range(num_models)], # One list per model
|
| 1363 |
+
'attention_mask_per_model': [[] for _ in range(num_models)], # One list per model
|
| 1364 |
+
'labels': [],
|
| 1365 |
+
'kv_cache_index': [],
|
| 1366 |
+
'position_ids': []
|
| 1367 |
+
}
|
| 1368 |
+
|
| 1369 |
+
# Process each section index
|
| 1370 |
+
for sec_idx in range(max_sections):
|
| 1371 |
+
section_data = self._collect_section_data(all_sections, sec_idx, num_models)
|
| 1372 |
+
padded_section = self._pad_single_section(section_data, num_models)
|
| 1373 |
+
|
| 1374 |
+
# Add to output - keep models separate
|
| 1375 |
+
for model_idx in range(num_models):
|
| 1376 |
+
padded_output['input_ids_per_model'][model_idx].append(
|
| 1377 |
+
padded_section['input_ids_per_model'][model_idx])
|
| 1378 |
+
padded_output['attention_mask_per_model'][model_idx].append(
|
| 1379 |
+
padded_section['attention_mask_per_model'][model_idx])
|
| 1380 |
+
|
| 1381 |
+
padded_output['labels'].append(padded_section['labels'])
|
| 1382 |
+
padded_output['kv_cache_index'].append(padded_section['kv_cache_index'])
|
| 1383 |
+
padded_output['position_ids'].append(padded_section['position_ids'])
|
| 1384 |
+
|
| 1385 |
+
# Concatenate sections and finalize
|
| 1386 |
+
return self._finalize_output(padded_output, num_models, len(all_sections))
|
| 1387 |
+
|
| 1388 |
+
def _collect_section_data(self, all_sections: List[List[Dict[str, Any]]],
|
| 1389 |
+
sec_idx: int, num_models: int) -> Dict[str, List]:
|
| 1390 |
+
"""Collect data for a specific section across all samples."""
|
| 1391 |
+
# Separate collections for each model to avoid confusion
|
| 1392 |
+
section_data = {
|
| 1393 |
+
'input_ids_per_model': [[] for _ in range(num_models)], # [[slm_seqs], [llm_seqs]]
|
| 1394 |
+
'attention_mask_per_model': [[] for _ in range(num_models)],
|
| 1395 |
+
'labels': [],
|
| 1396 |
+
'kv_cache_index': [],
|
| 1397 |
+
'position_ids': []
|
| 1398 |
+
}
|
| 1399 |
+
|
| 1400 |
+
for sample_sections in all_sections:
|
| 1401 |
+
# Some samples may have fewer sections; create default empty tensors when missing
|
| 1402 |
+
if sec_idx < len(sample_sections):
|
| 1403 |
+
sec = sample_sections[sec_idx]
|
| 1404 |
+
for model_idx in range(num_models):
|
| 1405 |
+
section_data['input_ids_per_model'][model_idx].append(sec['input_ids'][model_idx])
|
| 1406 |
+
section_data['attention_mask_per_model'][model_idx].append(sec['attention_mask'][model_idx])
|
| 1407 |
+
section_data['labels'].append(sec['labels'])
|
| 1408 |
+
section_data['kv_cache_index'].append(sec['kv_cache_index'])
|
| 1409 |
+
section_data['position_ids'].append(sec['position_ids'])
|
| 1410 |
+
else:
|
| 1411 |
+
# Default empty tensors; downstream pad_sequence will pad appropriately
|
| 1412 |
+
for model_idx in range(num_models):
|
| 1413 |
+
section_data['input_ids_per_model'][model_idx].append(torch.tensor([], dtype=torch.long))
|
| 1414 |
+
section_data['attention_mask_per_model'][model_idx].append(torch.tensor([], dtype=torch.float))
|
| 1415 |
+
section_data['labels'].append(torch.tensor([], dtype=torch.long))
|
| 1416 |
+
section_data['kv_cache_index'].append(torch.empty((0, 2), dtype=torch.long))
|
| 1417 |
+
section_data['position_ids'].append(torch.tensor([], dtype=torch.long))
|
| 1418 |
+
|
| 1419 |
+
return section_data
|
| 1420 |
+
|
| 1421 |
+
def _pad_single_section(self, section_data: Dict[str, List], num_models: int) -> Dict[str, Any]:
|
| 1422 |
+
"""Pad tensors within a single section."""
|
| 1423 |
+
# Pad input_ids separately for each model with their respective pad tokens
|
| 1424 |
+
padded_input_ids_per_model = []
|
| 1425 |
+
padded_attention_mask_per_model = []
|
| 1426 |
+
|
| 1427 |
+
for model_idx in range(num_models):
|
| 1428 |
+
pad_token_id = self.slm_pad_token_id if model_idx == 0 else self.llm_pad_token_id
|
| 1429 |
+
|
| 1430 |
+
# Pad input_ids for this model
|
| 1431 |
+
padded_input_ids = torch.nn.utils.rnn.pad_sequence(
|
| 1432 |
+
section_data['input_ids_per_model'][model_idx],
|
| 1433 |
+
batch_first=True,
|
| 1434 |
+
padding_value=pad_token_id
|
| 1435 |
+
)
|
| 1436 |
+
padded_input_ids_per_model.append(padded_input_ids)
|
| 1437 |
+
|
| 1438 |
+
# Pad attention_mask for this model
|
| 1439 |
+
padded_attention_mask = torch.nn.utils.rnn.pad_sequence(
|
| 1440 |
+
section_data['attention_mask_per_model'][model_idx],
|
| 1441 |
+
batch_first=True,
|
| 1442 |
+
padding_value=0
|
| 1443 |
+
)
|
| 1444 |
+
padded_attention_mask_per_model.append(padded_attention_mask)
|
| 1445 |
+
|
| 1446 |
+
# Standard padding for other tensors
|
| 1447 |
+
padded_labels = torch.nn.utils.rnn.pad_sequence(
|
| 1448 |
+
section_data['labels'], batch_first=True, padding_value=-100)
|
| 1449 |
+
padded_kv_cache = torch.nn.utils.rnn.pad_sequence(
|
| 1450 |
+
section_data['kv_cache_index'], batch_first=True, padding_value=-1)
|
| 1451 |
+
padded_position_ids = torch.nn.utils.rnn.pad_sequence(
|
| 1452 |
+
section_data['position_ids'], batch_first=True, padding_value=0)
|
| 1453 |
+
|
| 1454 |
+
return {
|
| 1455 |
+
'input_ids_per_model': padded_input_ids_per_model, # Keep separate per model
|
| 1456 |
+
'attention_mask_per_model': padded_attention_mask_per_model, # Keep separate per model
|
| 1457 |
+
'labels': padded_labels,
|
| 1458 |
+
'kv_cache_index': padded_kv_cache,
|
| 1459 |
+
'position_ids': padded_position_ids,
|
| 1460 |
+
'num_models': num_models
|
| 1461 |
+
}
|
| 1462 |
+
|
| 1463 |
+
def _finalize_output(self, padded_output: Dict[str, List],
|
| 1464 |
+
num_models: int, batch_size: int) -> Dict[str, Any]:
|
| 1465 |
+
"""Finalize the output by concatenating sections - keep models separate throughout."""
|
| 1466 |
+
final_output = {}
|
| 1467 |
+
|
| 1468 |
+
# Handle input_ids and attention_mask - keep separate per model
|
| 1469 |
+
if num_models == 1:
|
| 1470 |
+
# Single model case: concatenate sections for the single model
|
| 1471 |
+
final_output['input_ids'] = torch.cat(padded_output['input_ids_per_model'][0], dim=1)
|
| 1472 |
+
final_output['attention_mask'] = torch.cat(padded_output['attention_mask_per_model'][0], dim=1)
|
| 1473 |
+
else:
|
| 1474 |
+
# Multi-model case: keep as list of tensors, one per model
|
| 1475 |
+
final_output['input_ids'] = [
|
| 1476 |
+
torch.cat(padded_output['input_ids_per_model'][model_idx], dim=1)
|
| 1477 |
+
for model_idx in range(num_models)
|
| 1478 |
+
]
|
| 1479 |
+
final_output['attention_mask'] = [
|
| 1480 |
+
torch.cat(padded_output['attention_mask_per_model'][model_idx], dim=1)
|
| 1481 |
+
for model_idx in range(num_models)
|
| 1482 |
+
]
|
| 1483 |
+
|
| 1484 |
+
# Concatenate other tensors normally
|
| 1485 |
+
final_output['labels'] = torch.cat(padded_output['labels'], dim=1)
|
| 1486 |
+
final_output['position_ids'] = torch.cat(padded_output['position_ids'], dim=1)
|
| 1487 |
+
final_output['kv_cache_index'] = padded_output['kv_cache_index'] # Keep as list of sections
|
| 1488 |
+
|
| 1489 |
+
return final_output
|
| 1490 |
+
|
| 1491 |
+
def _apply_length_constraints(self, output: Dict[str, Any]) -> Dict[str, Any]:
|
| 1492 |
+
"""Apply max_length truncation if specified."""
|
| 1493 |
+
if self.max_length is None:
|
| 1494 |
+
return output
|
| 1495 |
+
|
| 1496 |
+
# Determine current sequence length
|
| 1497 |
+
if isinstance(output['input_ids'], list):
|
| 1498 |
+
seq_length = output['input_ids'][0].size(1)
|
| 1499 |
+
else:
|
| 1500 |
+
seq_length = output['input_ids'].size(1)
|
| 1501 |
+
|
| 1502 |
+
if seq_length <= self.max_length:
|
| 1503 |
+
return output
|
| 1504 |
+
|
| 1505 |
+
# Truncate sequences
|
| 1506 |
+
if isinstance(output['input_ids'], list):
|
| 1507 |
+
output['input_ids'] = [ids[:, :self.max_length] for ids in output['input_ids']]
|
| 1508 |
+
output['attention_mask'] = [mask[:, :self.max_length] for mask in output['attention_mask']]
|
| 1509 |
+
else:
|
| 1510 |
+
output['input_ids'] = output['input_ids'][:, :self.max_length]
|
| 1511 |
+
output['attention_mask'] = output['attention_mask'][:, :self.max_length]
|
| 1512 |
+
|
| 1513 |
+
output['labels'] = output['labels'][:, :self.max_length]
|
| 1514 |
+
output['position_ids'] = output['position_ids'][:, :self.max_length]
|
| 1515 |
+
|
| 1516 |
+
# Truncate kv_cache_index sections appropriately
|
| 1517 |
+
output['kv_cache_index'] = self._truncate_kv_cache_sections(
|
| 1518 |
+
output['kv_cache_index'], self.max_length)
|
| 1519 |
+
|
| 1520 |
+
return output
|
| 1521 |
+
|
| 1522 |
+
def _truncate_kv_cache_sections(self, kv_cache_sections: List[torch.Tensor],
|
| 1523 |
+
max_length: int) -> List[torch.Tensor]:
|
| 1524 |
+
"""Truncate kv_cache sections to fit within max_length."""
|
| 1525 |
+
truncated_sections = []
|
| 1526 |
+
current_pos = 0
|
| 1527 |
+
|
| 1528 |
+
for section in kv_cache_sections:
|
| 1529 |
+
section_length = section.size(1)
|
| 1530 |
+
remaining_length = max_length - current_pos
|
| 1531 |
+
|
| 1532 |
+
if remaining_length <= 0:
|
| 1533 |
+
break
|
| 1534 |
+
elif remaining_length >= section_length:
|
| 1535 |
+
truncated_sections.append(section)
|
| 1536 |
+
current_pos += section_length
|
| 1537 |
+
else:
|
| 1538 |
+
truncated_section = section[:, :remaining_length]
|
| 1539 |
+
truncated_sections.append(truncated_section)
|
| 1540 |
+
break
|
| 1541 |
+
|
| 1542 |
+
return truncated_sections
|
| 1543 |
+
|
| 1544 |
+
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
|
| 1545 |
+
"""
|
| 1546 |
+
Main collation function with improved logic.
|
| 1547 |
+
|
| 1548 |
+
Args:
|
| 1549 |
+
features: List of feature dictionaries from dataset
|
| 1550 |
+
|
| 1551 |
+
Returns:
|
| 1552 |
+
Batched and padded output dictionary
|
| 1553 |
+
"""
|
| 1554 |
+
if not features:
|
| 1555 |
+
return {}
|
| 1556 |
+
|
| 1557 |
+
# Step 1: Normalize input format for all features
|
| 1558 |
+
normalized_features = [self._normalize_input_format(feat) for feat in features]
|
| 1559 |
+
|
| 1560 |
+
# Step 2: Split each feature into sections
|
| 1561 |
+
all_sections = [self._split_into_sections(feat) for feat in normalized_features]
|
| 1562 |
+
|
| 1563 |
+
# Step 3: Pad sections to create uniform batch structure
|
| 1564 |
+
output = self._pad_sections(all_sections)
|
| 1565 |
+
|
| 1566 |
+
# Step 4: Apply length constraints if needed
|
| 1567 |
+
output = self._apply_length_constraints(output)
|
| 1568 |
+
|
| 1569 |
+
return output
|
| 1570 |
+
|
| 1571 |
+
|
| 1572 |
+
class BaselineDataCollator:
|
| 1573 |
+
"""Custom data collator for baseline model training"""
|
| 1574 |
+
|
| 1575 |
+
def __init__(self, tokenizer: AutoTokenizer, pad_to_multiple_of: Optional[int] = None):
|
| 1576 |
+
self.tokenizer = tokenizer
|
| 1577 |
+
self.pad_to_multiple_of = pad_to_multiple_of
|
| 1578 |
+
|
| 1579 |
+
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
| 1580 |
+
# Extract input_ids and labels
|
| 1581 |
+
input_ids = [f["input_ids"] for f in features]
|
| 1582 |
+
labels = [f["labels"] for f in features]
|
| 1583 |
+
|
| 1584 |
+
# Find max length in batch
|
| 1585 |
+
max_length = max(len(ids) for ids in input_ids)
|
| 1586 |
+
|
| 1587 |
+
# Apply pad_to_multiple_of if specified
|
| 1588 |
+
if self.pad_to_multiple_of is not None:
|
| 1589 |
+
max_length = ((max_length + self.pad_to_multiple_of - 1) // self.pad_to_multiple_of) * self.pad_to_multiple_of
|
| 1590 |
+
|
| 1591 |
+
# Pad sequences
|
| 1592 |
+
batch_input_ids = []
|
| 1593 |
+
batch_labels = []
|
| 1594 |
+
batch_attention_mask = []
|
| 1595 |
+
|
| 1596 |
+
for ids, lbls in zip(input_ids, labels):
|
| 1597 |
+
# Pad input_ids
|
| 1598 |
+
padded_ids = ids + [self.tokenizer.pad_token_id] * (max_length - len(ids))
|
| 1599 |
+
batch_input_ids.append(padded_ids)
|
| 1600 |
+
|
| 1601 |
+
# Pad labels (use -100 for padding)
|
| 1602 |
+
padded_labels = lbls + [-100] * (max_length - len(lbls))
|
| 1603 |
+
batch_labels.append(padded_labels)
|
| 1604 |
+
|
| 1605 |
+
# Create attention mask
|
| 1606 |
+
attention_mask = [1] * len(ids) + [0] * (max_length - len(ids))
|
| 1607 |
+
batch_attention_mask.append(attention_mask)
|
| 1608 |
+
|
| 1609 |
+
return {
|
| 1610 |
+
"input_ids": torch.tensor(batch_input_ids, dtype=torch.long),
|
| 1611 |
+
"labels": torch.tensor(batch_labels, dtype=torch.long),
|
| 1612 |
+
"attention_mask": torch.tensor(batch_attention_mask, dtype=torch.long),
|
| 1613 |
+
}
|
| 1614 |
+
|
| 1615 |
+
|
| 1616 |
+
|
| 1617 |
+
"""
|
| 1618 |
+
Helper functions
|
| 1619 |
+
"""
|
| 1620 |
+
|
| 1621 |
+
|
| 1622 |
+
def create_dataset(dataset_type: str, **kwargs) -> Dataset:
|
| 1623 |
+
"""
|
| 1624 |
+
Factory function to create a dataset based on type.
|
| 1625 |
+
|
| 1626 |
+
Args:
|
| 1627 |
+
dataset_type: String indicating the type of dataset
|
| 1628 |
+
**kwargs: Additional arguments to pass to the dataset constructor
|
| 1629 |
+
|
| 1630 |
+
Returns:
|
| 1631 |
+
An instance of the appropriate dataset
|
| 1632 |
+
"""
|
| 1633 |
+
# First, check if dataset_type is directly in the registry (exact match)
|
| 1634 |
+
if dataset_type in DATASET_REGISTRY:
|
| 1635 |
+
return DATASET_REGISTRY[dataset_type](**kwargs)
|
| 1636 |
+
|
| 1637 |
+
# Then check for case-insensitive match
|
| 1638 |
+
dataset_type_lower = dataset_type.lower()
|
| 1639 |
+
if dataset_type_lower in DATASET_REGISTRY:
|
| 1640 |
+
return DATASET_REGISTRY[dataset_type_lower](**kwargs)
|
| 1641 |
+
|
| 1642 |
+
# If not found in registry, raise an error with valid options
|
| 1643 |
+
valid_options = list(
|
| 1644 |
+
set([name for name, cls in DATASET_REGISTRY.items() if name == cls.__name__])
|
| 1645 |
+
) # Only include actual class names
|
| 1646 |
+
raise ValueError(
|
| 1647 |
+
f"Unknown dataset type: {dataset_type}. Valid options are: {valid_options}"
|
| 1648 |
+
)
|
rosetta/train/model_utils.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Model setup utilities for RosettaModel training/evaluation
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from typing import Dict, Any, List
|
| 7 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 8 |
+
|
| 9 |
+
from rosetta.model.wrapper import RosettaModel
|
| 10 |
+
from rosetta.model.projector import create_projector
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
Mapping strategies
|
| 14 |
+
"""
|
| 15 |
+
def k_nearest_sources(num_target_layers: int, num_source_layers: int, k: int) -> Dict[int, List[int]]:
|
| 16 |
+
"""
|
| 17 |
+
Compute a per-target mapping to K nearest source layers.
|
| 18 |
+
|
| 19 |
+
Returns: Dict[target_idx, List[source_idx]] only for targets we map.
|
| 20 |
+
Distances are computed by placing target and source layers uniformly in [0, 1]
|
| 21 |
+
and sorting by absolute distance.
|
| 22 |
+
"""
|
| 23 |
+
if num_target_layers <= 1:
|
| 24 |
+
target_positions = [0.0]
|
| 25 |
+
else:
|
| 26 |
+
target_positions = [i / (num_target_layers - 1) for i in range(num_target_layers)]
|
| 27 |
+
if num_source_layers <= 1:
|
| 28 |
+
source_positions = [0.0]
|
| 29 |
+
else:
|
| 30 |
+
source_positions = [j / (num_source_layers - 1) for j in range(num_source_layers)]
|
| 31 |
+
|
| 32 |
+
mapping: Dict[int, List[int]] = {}
|
| 33 |
+
for t_idx, t_pos in enumerate(target_positions):
|
| 34 |
+
sorted_src = sorted(range(num_source_layers), key=lambda j: abs(source_positions[j] - t_pos))
|
| 35 |
+
chosen = sorted_src[:max(0, k)]
|
| 36 |
+
if len(chosen) > 0:
|
| 37 |
+
mapping[t_idx] = chosen
|
| 38 |
+
return mapping
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def last_aligned_sources(num_target_layers: int, num_source_layers: int, k: int = 1) -> Dict[int, List[int]]:
|
| 42 |
+
"""
|
| 43 |
+
Return a per-target mapping that aligns the last target layer to the last
|
| 44 |
+
source layer and walks toward the front.
|
| 45 |
+
|
| 46 |
+
Returns: Dict[target_idx, List[source_idx]] only for targets we map. For each
|
| 47 |
+
target t, we choose up to K sources anchored at the aligned index, preferring
|
| 48 |
+
backward indices first then forward to satisfy K.
|
| 49 |
+
|
| 50 |
+
Example (T=11, S=33): target 10 -> [32, 31, ...], target 9 -> [31, 30, ...]
|
| 51 |
+
"""
|
| 52 |
+
mapping: Dict[int, List[int]] = {}
|
| 53 |
+
if num_target_layers <= 0 or num_source_layers <= 0:
|
| 54 |
+
return mapping
|
| 55 |
+
|
| 56 |
+
# Align ends; offset >= 0 means extra source layers at the front
|
| 57 |
+
offset = num_source_layers - num_target_layers
|
| 58 |
+
|
| 59 |
+
def take_k_from(s0: int) -> List[int]:
|
| 60 |
+
result: List[int] = []
|
| 61 |
+
# Prefer moving backward from the anchor (last-to-front)
|
| 62 |
+
for back in range(k):
|
| 63 |
+
idx = s0 - back
|
| 64 |
+
if 0 <= idx < num_source_layers:
|
| 65 |
+
result.append(idx)
|
| 66 |
+
# If not enough due to boundary, extend forward
|
| 67 |
+
next_idx = s0 + 1
|
| 68 |
+
while len(result) < k and next_idx < num_source_layers:
|
| 69 |
+
result.append(next_idx)
|
| 70 |
+
next_idx += 1
|
| 71 |
+
return result
|
| 72 |
+
|
| 73 |
+
for t in range(num_target_layers):
|
| 74 |
+
s0 = offset + t
|
| 75 |
+
# Clamp to valid range for edge cases (e.g., fewer source layers)
|
| 76 |
+
if s0 < 0:
|
| 77 |
+
s0 = 0
|
| 78 |
+
elif s0 > num_source_layers - 1:
|
| 79 |
+
s0 = num_source_layers - 1
|
| 80 |
+
chosen = take_k_from(s0)
|
| 81 |
+
if len(chosen) > 0:
|
| 82 |
+
mapping[t] = chosen
|
| 83 |
+
|
| 84 |
+
return mapping
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def setup_models(model_config: Dict[str, Any], device: str = "cuda", dtype: torch.dtype = torch.bfloat16):
|
| 88 |
+
"""Setup RosettaModel with base model, teacher model, and projectors"""
|
| 89 |
+
|
| 90 |
+
# Load tokenizer
|
| 91 |
+
tokenizer = AutoTokenizer.from_pretrained(model_config["base_model"])
|
| 92 |
+
if tokenizer.pad_token is None:
|
| 93 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 94 |
+
|
| 95 |
+
# Load models
|
| 96 |
+
base_model = AutoModelForCausalLM.from_pretrained(
|
| 97 |
+
model_config["base_model"],
|
| 98 |
+
torch_dtype=dtype,
|
| 99 |
+
device_map=device
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
teacher_model = AutoModelForCausalLM.from_pretrained(
|
| 103 |
+
model_config["teacher_model"],
|
| 104 |
+
torch_dtype=dtype,
|
| 105 |
+
device_map=device
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
# Create projector
|
| 109 |
+
projector_config = model_config["projector"]
|
| 110 |
+
projector_params = projector_config["params"].copy()
|
| 111 |
+
projector_params["dtype"] = dtype
|
| 112 |
+
|
| 113 |
+
projector = create_projector(
|
| 114 |
+
projector_config["type"],
|
| 115 |
+
source_dim=teacher_model.config.head_dim,
|
| 116 |
+
target_dim=base_model.config.head_dim,
|
| 117 |
+
**projector_params
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
# Setup RosettaModel
|
| 121 |
+
rosetta_model = RosettaModel(
|
| 122 |
+
model_list=[base_model, teacher_model],
|
| 123 |
+
base_model_idx=0,
|
| 124 |
+
projector_list=[projector]
|
| 125 |
+
).to(device)
|
| 126 |
+
|
| 127 |
+
# Configure projector mappings
|
| 128 |
+
num_layers_to_map = min(
|
| 129 |
+
base_model.config.num_hidden_layers,
|
| 130 |
+
teacher_model.config.num_hidden_layers
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
for layer_idx in range(num_layers_to_map):
|
| 134 |
+
rosetta_model.set_projector_config(
|
| 135 |
+
source_model_idx=1, # Teacher
|
| 136 |
+
source_model_layer_idx=layer_idx,
|
| 137 |
+
target_model_idx=0, # Base
|
| 138 |
+
target_model_layer_idx=layer_idx,
|
| 139 |
+
projector_idx=0
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
return rosetta_model, tokenizer
|
rosetta/utils/__pycache__/evaluate.cpython-310.pyc
ADDED
|
Binary file (15.8 kB). View file
|
|
|
rosetta/utils/__pycache__/evaluate.cpython-312.pyc
ADDED
|
Binary file (24 kB). View file
|
|
|
rosetta/utils/__pycache__/multi_stage.cpython-310.pyc
ADDED
|
Binary file (5.22 kB). View file
|
|
|
rosetta/utils/__pycache__/multi_stage.cpython-312.pyc
ADDED
|
Binary file (7.49 kB). View file
|
|
|
rosetta/utils/__pycache__/registry.cpython-310.pyc
ADDED
|
Binary file (7.43 kB). View file
|
|
|
rosetta/utils/__pycache__/registry.cpython-312.pyc
ADDED
|
Binary file (9.98 kB). View file
|
|
|
rosetta/utils/evaluate.py
ADDED
|
@@ -0,0 +1,639 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Common evaluation utilities for benchmark tasks.
|
| 3 |
+
|
| 4 |
+
This module provides shared functions for model evaluation across different benchmarks
|
| 5 |
+
like MMLU-Redux and MMMLU.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import re
|
| 9 |
+
import os
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import numpy as np
|
| 13 |
+
from typing import Dict, Any, List, Tuple, Optional
|
| 14 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 15 |
+
|
| 16 |
+
from rosetta.model.projector import load_projector
|
| 17 |
+
from rosetta.model.wrapper import RosettaModel
|
| 18 |
+
from rosetta.model.oracle import OracleRosettaModel
|
| 19 |
+
|
| 20 |
+
def build_prompt(dataset: str, locale: str, question: str, choices: str, use_cot: bool, use_template: bool = True) -> str:
|
| 21 |
+
"""
|
| 22 |
+
Build a localized prompt for a given dataset and locale.
|
| 23 |
+
|
| 24 |
+
Currently supports:
|
| 25 |
+
- dataset: "mmmlu"
|
| 26 |
+
- locale: "SW_KE" (Swahili). Other locales fall back to English.
|
| 27 |
+
|
| 28 |
+
Args:
|
| 29 |
+
dataset: Dataset identifier (e.g., "mmmlu")
|
| 30 |
+
locale: Locale/subject code (e.g., "SW_KE")
|
| 31 |
+
question: Question text
|
| 32 |
+
choices: Formatted choices string
|
| 33 |
+
use_cot: Whether to include CoT instruction
|
| 34 |
+
|
| 35 |
+
Returns:
|
| 36 |
+
Localized prompt string
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
# Unified default English templates (shared by MMLU and MMMLU)
|
| 40 |
+
if not use_cot:
|
| 41 |
+
template = """Accurately answer the following question:
|
| 42 |
+
|
| 43 |
+
{{question}}
|
| 44 |
+
|
| 45 |
+
Choices:
|
| 46 |
+
{{choices}}
|
| 47 |
+
|
| 48 |
+
Instructions:
|
| 49 |
+
- Carefully read the question and all options.
|
| 50 |
+
- Select the single most correct answer.
|
| 51 |
+
- Respond ONLY in the following format: "The correct answer is A/B/C/D".
|
| 52 |
+
- Do not include any explanations, additional text, or punctuation besides the answer.
|
| 53 |
+
|
| 54 |
+
The correct answer is"""
|
| 55 |
+
|
| 56 |
+
else:
|
| 57 |
+
template = """Accurately answer the following question:
|
| 58 |
+
|
| 59 |
+
{{question}}
|
| 60 |
+
|
| 61 |
+
Choices:
|
| 62 |
+
{{choices}}
|
| 63 |
+
|
| 64 |
+
Instructions:
|
| 65 |
+
- Carefully read the question and all options.
|
| 66 |
+
- Let's think step by step and explain your reasoning briefly.
|
| 67 |
+
- Then give the final answer starting with The correct answer is"""
|
| 68 |
+
|
| 69 |
+
prompt = template.replace("{{question}}", question)
|
| 70 |
+
prompt = prompt.replace("{{choices}}", choices)
|
| 71 |
+
|
| 72 |
+
if not use_template:
|
| 73 |
+
prompt = question + "\n\nChoices:\n" + choices
|
| 74 |
+
|
| 75 |
+
return prompt
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def parse_answer(answer_str: str) -> List[str]:
|
| 79 |
+
"""
|
| 80 |
+
Parse answer string to extract valid answer options.
|
| 81 |
+
Converts digits 0/1/2/3 to letters A/B/C/D.
|
| 82 |
+
|
| 83 |
+
Args:
|
| 84 |
+
answer_str: String containing answer digits
|
| 85 |
+
|
| 86 |
+
Returns:
|
| 87 |
+
List of parsed answer letters
|
| 88 |
+
"""
|
| 89 |
+
if not isinstance(answer_str, str):
|
| 90 |
+
return []
|
| 91 |
+
valid_digits = [c for c in answer_str if c in {'0','1','2','3'}]
|
| 92 |
+
return sorted(list({
|
| 93 |
+
chr(65 + int(d)) # 0->A, 1->B, 2->C, 3->D
|
| 94 |
+
for d in valid_digits
|
| 95 |
+
}))
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def extract_answer_from_content(text: str) -> Optional[str]:
|
| 99 |
+
"""
|
| 100 |
+
Extract answer from model output with robust multi-pattern matching.
|
| 101 |
+
Supports multiple languages and response formats.
|
| 102 |
+
|
| 103 |
+
Args:
|
| 104 |
+
text: Model output text
|
| 105 |
+
|
| 106 |
+
Returns:
|
| 107 |
+
Extracted answer letter or None
|
| 108 |
+
"""
|
| 109 |
+
text = text.strip()
|
| 110 |
+
if not text:
|
| 111 |
+
return None
|
| 112 |
+
|
| 113 |
+
# Define multiple answer patterns for different languages and formats
|
| 114 |
+
answer_patterns = [
|
| 115 |
+
# English patterns
|
| 116 |
+
r'Answer:\s*(.*)',
|
| 117 |
+
r'answer:\s*(.*)',
|
| 118 |
+
r'ANSWER:\s*(.*)',
|
| 119 |
+
r'Your answer:\s*(.*)',
|
| 120 |
+
r'your answer:\s*(.*)',
|
| 121 |
+
r'YOUR ANSWER:\s*(.*)',
|
| 122 |
+
r'The answer is\s*(.*)',
|
| 123 |
+
r'the answer is\s*(.*)',
|
| 124 |
+
r'THE ANSWER IS\s*(.*)',
|
| 125 |
+
r'Correct answer is\s*(.*)',
|
| 126 |
+
r'correct answer is\s*(.*)',
|
| 127 |
+
r'Correct answer is:\s*(.*)',
|
| 128 |
+
r'correct answer is:\s*(.*)',
|
| 129 |
+
r'Correct answer:\s*(.*)',
|
| 130 |
+
r'correct answer:\s*(.*)',
|
| 131 |
+
r'CORRECT ANSWER:\s*(.*)',
|
| 132 |
+
|
| 133 |
+
# Swahili patterns
|
| 134 |
+
r'Jibu lako:\s*(.*)',
|
| 135 |
+
r'jibu lako:\s*(.*)',
|
| 136 |
+
r'JIBU LAKO:\s*(.*)',
|
| 137 |
+
r'Jibu:\s*(.*)',
|
| 138 |
+
r'jibu:\s*(.*)',
|
| 139 |
+
r'JIBU:\s*(.*)',
|
| 140 |
+
r'Jibu sahihi:\s*(.*)',
|
| 141 |
+
r'jibu sahihi:\s*(.*)',
|
| 142 |
+
r'JIBU SAHIHI:\s*(.*)',
|
| 143 |
+
|
| 144 |
+
# Other common patterns
|
| 145 |
+
r'Response:\s*(.*)',
|
| 146 |
+
r'response:\s*(.*)',
|
| 147 |
+
r'RESPONSE:\s*(.*)',
|
| 148 |
+
r'Choice:\s*(.*)',
|
| 149 |
+
r'choice:\s*(.*)',
|
| 150 |
+
r'CHOICE:\s*(.*)',
|
| 151 |
+
r'Option:\s*(.*)',
|
| 152 |
+
r'option:\s*(.*)',
|
| 153 |
+
r'OPTION:\s*(.*)',
|
| 154 |
+
]
|
| 155 |
+
|
| 156 |
+
# 1. Try to match any of the answer patterns
|
| 157 |
+
for pattern in answer_patterns:
|
| 158 |
+
match = re.search(pattern, text, re.IGNORECASE)
|
| 159 |
+
if match:
|
| 160 |
+
answer_part = match.group(1).strip()
|
| 161 |
+
# Search for first A-D letter in the matched part
|
| 162 |
+
for char in answer_part:
|
| 163 |
+
if char in {'A', 'B', 'C', 'D'}:
|
| 164 |
+
return char
|
| 165 |
+
|
| 166 |
+
# 2. Look for standalone A-D letters that are likely answers
|
| 167 |
+
# Prioritize letters at the end of text or with clear answer-like context
|
| 168 |
+
standalone_patterns = [
|
| 169 |
+
r'\b([A-D])(?:\s*[.,!?:)]?\s*$)', # A-D at end of text with optional punctuation
|
| 170 |
+
r'\b([A-D])(?:\s*[.,!?:)]\s)', # A-D followed by punctuation and space
|
| 171 |
+
r'(?:^|\s)([A-D])(?:\s*$)', # A-D at start or with word boundary at end
|
| 172 |
+
]
|
| 173 |
+
|
| 174 |
+
for pattern in standalone_patterns:
|
| 175 |
+
matches = re.findall(pattern, text, re.IGNORECASE)
|
| 176 |
+
if matches:
|
| 177 |
+
# Check if this looks like mathematical expressions rather than answers
|
| 178 |
+
math_indicators = ['+', '-', '*', '/', '=', '^', 'x^', 'y^', 'z^', 'mod', 'sqrt', 'sin', 'cos', 'tan']
|
| 179 |
+
has_math = any(indicator in text for indicator in math_indicators)
|
| 180 |
+
has_answer_indicators = any(phrase in text.lower() for phrase in ['jibu', 'answer', 'choice', 'option', 'response', 'correct', 'sahihi'])
|
| 181 |
+
|
| 182 |
+
# If it has math indicators but no answer indicators, it's likely mathematical notation
|
| 183 |
+
if has_math and not has_answer_indicators:
|
| 184 |
+
continue # Skip this match, try next pattern
|
| 185 |
+
|
| 186 |
+
return matches[-1].upper()
|
| 187 |
+
|
| 188 |
+
# 3. Fallback: find all A-D letters but be more selective
|
| 189 |
+
all_letters = re.findall(r'\b([A-D])\b', text, re.IGNORECASE)
|
| 190 |
+
if all_letters:
|
| 191 |
+
# Check if this looks like mathematical expressions rather than answers
|
| 192 |
+
math_indicators = ['+', '-', '*', '/', '=', '^', 'x^', 'y^', 'z^', 'mod', 'sqrt', 'sin', 'cos', 'tan']
|
| 193 |
+
has_math = any(indicator in text for indicator in math_indicators)
|
| 194 |
+
has_answer_indicators = any(phrase in text.lower() for phrase in ['jibu', 'answer', 'choice', 'option', 'response', 'correct', 'sahihi'])
|
| 195 |
+
|
| 196 |
+
# If it has math indicators but no answer indicators, it's likely mathematical notation
|
| 197 |
+
if has_math and not has_answer_indicators:
|
| 198 |
+
return None
|
| 199 |
+
|
| 200 |
+
# Otherwise, return the last letter found
|
| 201 |
+
return all_letters[-1].upper()
|
| 202 |
+
|
| 203 |
+
# 3. Search backwards for any A-D letter as fallback
|
| 204 |
+
for char in reversed(text):
|
| 205 |
+
if char in {'A', 'B', 'C', 'D'}:
|
| 206 |
+
return char
|
| 207 |
+
|
| 208 |
+
return None
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def apply_generation_config(model: Any, generation_config: Optional[Dict[str, Any]] = None) -> None:
|
| 212 |
+
"""
|
| 213 |
+
Apply generation configuration to a model and handle sampling parameters.
|
| 214 |
+
|
| 215 |
+
This function applies the provided generation config to the model and removes
|
| 216 |
+
sampling parameters (temperature, top_p, top_k, min_p) when do_sample=False
|
| 217 |
+
to avoid warnings from the transformers library. If no config is provided,
|
| 218 |
+
it defaults to greedy decoding with cleaned sampling parameters.
|
| 219 |
+
|
| 220 |
+
Args:
|
| 221 |
+
model: Model object with generation_config attribute
|
| 222 |
+
generation_config: Optional generation configuration dictionary.
|
| 223 |
+
If None, defaults to greedy decoding (do_sample=False).
|
| 224 |
+
"""
|
| 225 |
+
if not hasattr(model, 'generation_config'):
|
| 226 |
+
return
|
| 227 |
+
|
| 228 |
+
# If no config provided, default to greedy decoding
|
| 229 |
+
if not generation_config:
|
| 230 |
+
generation_config = {'do_sample': False}
|
| 231 |
+
|
| 232 |
+
# Apply all configuration parameters
|
| 233 |
+
for key, value in generation_config.items():
|
| 234 |
+
setattr(model.generation_config, key, value)
|
| 235 |
+
|
| 236 |
+
# Disable sampling parameters if do_sample=False to avoid warnings
|
| 237 |
+
# We set them to None instead of deleting, since some model code may
|
| 238 |
+
# access these attributes unconditionally.
|
| 239 |
+
if not generation_config.get('do_sample', True):
|
| 240 |
+
sampling_params = ['temperature', 'top_p', 'top_k', 'min_p', 'repetition_penalty']
|
| 241 |
+
for param in sampling_params:
|
| 242 |
+
try:
|
| 243 |
+
setattr(model.generation_config, param, None)
|
| 244 |
+
except Exception:
|
| 245 |
+
# If the backend does not allow setting, ignore silently
|
| 246 |
+
pass
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def set_default_chat_template(tokenizer, model_name: str):
|
| 250 |
+
"""
|
| 251 |
+
Set default chat template for models without one.
|
| 252 |
+
|
| 253 |
+
Args:
|
| 254 |
+
tokenizer: Tokenizer object
|
| 255 |
+
model_name: Name of the model
|
| 256 |
+
"""
|
| 257 |
+
if tokenizer.chat_template is None:
|
| 258 |
+
if "UlizaLlama3".lower() in model_name.lower():
|
| 259 |
+
tokenizer.chat_template = (
|
| 260 |
+
"{%- for message in messages %}"
|
| 261 |
+
"{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n' }}"
|
| 262 |
+
"{{- message['content'] }}"
|
| 263 |
+
"{{- '<|eot_id|>' }}"
|
| 264 |
+
"{%- endfor %}"
|
| 265 |
+
"{%- if add_generation_prompt %}"
|
| 266 |
+
"{{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}"
|
| 267 |
+
"{%- endif %}"
|
| 268 |
+
)
|
| 269 |
+
else:
|
| 270 |
+
print(f"Model {model_name} has no chat template, setting default template...")
|
| 271 |
+
default_template = """{% for message in messages %}{% if message['role'] == 'user' %}### Human: {{ message['content'] }}{% elif message['role'] == 'assistant' %}### Assistant: {{ message['content'] }}{% endif %}{% if not loop.last %}
|
| 272 |
+
{% endif %}{% endfor %}{% if add_generation_prompt %}
|
| 273 |
+
### Assistant:{% endif %}"""
|
| 274 |
+
tokenizer.chat_template = default_template
|
| 275 |
+
print("Default chat template has been set.")
|
| 276 |
+
else:
|
| 277 |
+
print(f"Model {model_name} already has a chat template.")
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def load_hf_model(model_name: str, device: torch.device, generation_config: Optional[Dict[str, Any]] = None) -> Tuple[Any, Any]:
|
| 281 |
+
"""
|
| 282 |
+
Load Hugging Face model and tokenizer.
|
| 283 |
+
|
| 284 |
+
Args:
|
| 285 |
+
model_name: Model name or path
|
| 286 |
+
device: Device to load model on
|
| 287 |
+
generation_config: Optional generation configuration to apply
|
| 288 |
+
|
| 289 |
+
Returns:
|
| 290 |
+
Tuple of (model, tokenizer)
|
| 291 |
+
"""
|
| 292 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 293 |
+
str(model_name),
|
| 294 |
+
trust_remote_code=True,
|
| 295 |
+
padding_side='left'
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
if tokenizer.pad_token is None:
|
| 299 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 300 |
+
|
| 301 |
+
# Check and set chat template
|
| 302 |
+
set_default_chat_template(tokenizer, model_name)
|
| 303 |
+
|
| 304 |
+
if model_name == "google/gemma-3-1b-it":
|
| 305 |
+
torch._dynamo.config.cache_size_limit = 64
|
| 306 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 307 |
+
str(model_name),
|
| 308 |
+
torch_dtype=torch.bfloat16,
|
| 309 |
+
device_map={"": device},
|
| 310 |
+
sliding_window=4096
|
| 311 |
+
).eval()
|
| 312 |
+
else:
|
| 313 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 314 |
+
str(model_name),
|
| 315 |
+
torch_dtype=torch.bfloat16,
|
| 316 |
+
device_map={"": device}
|
| 317 |
+
).eval()
|
| 318 |
+
|
| 319 |
+
# Apply generation config
|
| 320 |
+
apply_generation_config(model, generation_config)
|
| 321 |
+
|
| 322 |
+
return model, tokenizer
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def load_rosetta_model(model_config: Dict[str, Any], eval_config: Dict[str, Any],
|
| 326 |
+
device: torch.device, generation_config: Optional[Dict[str, Any]] = None) -> Tuple[Any, Any]:
|
| 327 |
+
"""
|
| 328 |
+
Load Rosetta model with projectors and aggregators.
|
| 329 |
+
|
| 330 |
+
Args:
|
| 331 |
+
model_config: Model configuration dict
|
| 332 |
+
eval_config: Evaluation configuration dict
|
| 333 |
+
device: Device to load model on
|
| 334 |
+
generation_config: Optional generation configuration to apply
|
| 335 |
+
|
| 336 |
+
Returns:
|
| 337 |
+
Tuple of (rosetta_model, tokenizer)
|
| 338 |
+
"""
|
| 339 |
+
# Prefer checkpoints_dir under model.rosetta_config; fall back to eval config for backward compatibility
|
| 340 |
+
rosetta_config = model_config["rosetta_config"]
|
| 341 |
+
checkpoint_dir = rosetta_config.get("checkpoints_dir", eval_config.get("checkpoints_dir"))
|
| 342 |
+
if checkpoint_dir is None:
|
| 343 |
+
raise KeyError("checkpoints_dir must be provided under model.rosetta_config (preferred) or eval config (legacy)")
|
| 344 |
+
slm_model_path = rosetta_config["base_model"]
|
| 345 |
+
llm_model_path = rosetta_config["teacher_model"]
|
| 346 |
+
|
| 347 |
+
# Load tokenizer
|
| 348 |
+
slm_tokenizer = AutoTokenizer.from_pretrained(str(slm_model_path))
|
| 349 |
+
set_default_chat_template(slm_tokenizer, slm_model_path)
|
| 350 |
+
|
| 351 |
+
# Load models
|
| 352 |
+
slm_model = AutoModelForCausalLM.from_pretrained(
|
| 353 |
+
str(slm_model_path),
|
| 354 |
+
torch_dtype=torch.bfloat16,
|
| 355 |
+
device_map={"": device}
|
| 356 |
+
).eval()
|
| 357 |
+
|
| 358 |
+
# Apply generation config to SLM
|
| 359 |
+
apply_generation_config(slm_model, generation_config)
|
| 360 |
+
|
| 361 |
+
if llm_model_path == "google/gemma-3-1b-it":
|
| 362 |
+
llm_model = AutoModelForCausalLM.from_pretrained(
|
| 363 |
+
str(llm_model_path),
|
| 364 |
+
torch_dtype=torch.bfloat16,
|
| 365 |
+
device_map={"": device},
|
| 366 |
+
sliding_window=4096
|
| 367 |
+
).eval()
|
| 368 |
+
else:
|
| 369 |
+
llm_model = AutoModelForCausalLM.from_pretrained(
|
| 370 |
+
str(llm_model_path),
|
| 371 |
+
torch_dtype=torch.bfloat16,
|
| 372 |
+
device_map={"": device}
|
| 373 |
+
).eval()
|
| 374 |
+
|
| 375 |
+
# Apply generation config to LLM
|
| 376 |
+
apply_generation_config(llm_model, generation_config)
|
| 377 |
+
|
| 378 |
+
# Load projectors
|
| 379 |
+
num_projectors = len([f for f in os.listdir(checkpoint_dir) if re.match(r"projector_\d+\.pt", f)])
|
| 380 |
+
projector_list = []
|
| 381 |
+
for t in range(num_projectors):
|
| 382 |
+
json_cfg = os.path.join(checkpoint_dir, f"projector_{t}.json")
|
| 383 |
+
proj = load_projector(json_cfg)
|
| 384 |
+
proj = proj.to(device)
|
| 385 |
+
pt_path = os.path.join(checkpoint_dir, f"projector_{t}.pt")
|
| 386 |
+
if os.path.exists(pt_path):
|
| 387 |
+
state_dict = torch.load(pt_path, map_location=device)
|
| 388 |
+
proj.load_state_dict(state_dict, strict=False)
|
| 389 |
+
projector_list.append(proj)
|
| 390 |
+
|
| 391 |
+
aggregator_list = []
|
| 392 |
+
|
| 393 |
+
# Initialize Rosetta model
|
| 394 |
+
rosetta_model = RosettaModel(
|
| 395 |
+
model_list=[slm_model, llm_model],
|
| 396 |
+
base_model_idx=0,
|
| 397 |
+
projector_list=projector_list,
|
| 398 |
+
aggregator_list=aggregator_list,
|
| 399 |
+
).to(device).eval()
|
| 400 |
+
|
| 401 |
+
# Load projector/aggregator mapping configs
|
| 402 |
+
proj_cfg_path = os.path.join(checkpoint_dir, "projector_config.json")
|
| 403 |
+
agg_cfg_path = os.path.join(checkpoint_dir, "aggregator_config.json")
|
| 404 |
+
rosetta_model.load_projector_config(proj_cfg_path)
|
| 405 |
+
rosetta_model.load_aggregator_config(agg_cfg_path)
|
| 406 |
+
|
| 407 |
+
return rosetta_model, slm_tokenizer
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
def load_oracle_rosetta_model(model_config: Dict[str, Any], eval_config: Dict[str, Any],
|
| 411 |
+
device: torch.device) -> Tuple[Any, Any]:
|
| 412 |
+
"""
|
| 413 |
+
Load Rosetta model with projectors and aggregators.
|
| 414 |
+
|
| 415 |
+
Args:
|
| 416 |
+
model_config: Model configuration dict
|
| 417 |
+
eval_config: Evaluation configuration dict
|
| 418 |
+
device: Device to load model on
|
| 419 |
+
|
| 420 |
+
Returns:
|
| 421 |
+
Tuple of (rosetta_model, tokenizer)
|
| 422 |
+
"""
|
| 423 |
+
# Prefer checkpoints_dir under model.rosetta_config; fall back to eval config for backward compatibility
|
| 424 |
+
rosetta_config = model_config["rosetta_config"]
|
| 425 |
+
checkpoint_dir = rosetta_config.get("checkpoints_dir", eval_config.get("checkpoints_dir"))
|
| 426 |
+
if checkpoint_dir is None:
|
| 427 |
+
raise KeyError("checkpoints_dir must be provided under model.rosetta_config (preferred) or eval config (legacy)")
|
| 428 |
+
slm_model_path = rosetta_config["base_model"]
|
| 429 |
+
llm_model_path = rosetta_config["teacher_model"]
|
| 430 |
+
|
| 431 |
+
# Load tokenizer
|
| 432 |
+
slm_tokenizer = AutoTokenizer.from_pretrained(str(slm_model_path))
|
| 433 |
+
set_default_chat_template(slm_tokenizer, slm_model_path)
|
| 434 |
+
|
| 435 |
+
# Load models
|
| 436 |
+
slm_model = AutoModelForCausalLM.from_pretrained(
|
| 437 |
+
str(slm_model_path),
|
| 438 |
+
torch_dtype=torch.bfloat16,
|
| 439 |
+
device_map={"": device}
|
| 440 |
+
).eval()
|
| 441 |
+
|
| 442 |
+
llm_model = AutoModelForCausalLM.from_pretrained(
|
| 443 |
+
str(llm_model_path),
|
| 444 |
+
torch_dtype=torch.bfloat16,
|
| 445 |
+
device_map={"": device}
|
| 446 |
+
).eval()
|
| 447 |
+
|
| 448 |
+
# Load projectors
|
| 449 |
+
num_projectors = len([f for f in os.listdir(checkpoint_dir) if re.match(r"projector_\d+\.pt", f)])
|
| 450 |
+
projector_list = []
|
| 451 |
+
for t in range(num_projectors):
|
| 452 |
+
json_cfg = os.path.join(checkpoint_dir, f"projector_{t}.json")
|
| 453 |
+
proj = load_projector(json_cfg)
|
| 454 |
+
proj = proj.to(device)
|
| 455 |
+
pt_path = os.path.join(checkpoint_dir, f"projector_{t}.pt")
|
| 456 |
+
if os.path.exists(pt_path):
|
| 457 |
+
state_dict = torch.load(pt_path, map_location=device)
|
| 458 |
+
proj.load_state_dict(state_dict, strict=False)
|
| 459 |
+
projector_list.append(proj)
|
| 460 |
+
|
| 461 |
+
aggregator_list = []
|
| 462 |
+
|
| 463 |
+
# Initialize Rosetta model
|
| 464 |
+
rosetta_model = OracleRosettaModel(
|
| 465 |
+
model_list=[slm_model, llm_model],
|
| 466 |
+
base_model_idx=0,
|
| 467 |
+
projector_list=projector_list,
|
| 468 |
+
aggregator_list=aggregator_list,
|
| 469 |
+
).to(device).eval()
|
| 470 |
+
|
| 471 |
+
# Load projector/aggregator mapping configs
|
| 472 |
+
proj_cfg_path = os.path.join(checkpoint_dir, "projector_config.json")
|
| 473 |
+
agg_cfg_path = os.path.join(checkpoint_dir, "aggregator_config.json")
|
| 474 |
+
rosetta_model.load_projector_config(proj_cfg_path)
|
| 475 |
+
rosetta_model.load_aggregator_config(agg_cfg_path)
|
| 476 |
+
|
| 477 |
+
return rosetta_model, slm_tokenizer
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
def get_option_token_ids(tokenizer, num_options: int = 4) -> List[int]:
|
| 481 |
+
"""
|
| 482 |
+
Get token IDs for options A, B, C, D (or more up to J).
|
| 483 |
+
|
| 484 |
+
Args:
|
| 485 |
+
tokenizer: Tokenizer object
|
| 486 |
+
num_options: Number of options to get (default 4 for A-D, max 10 for A-J)
|
| 487 |
+
|
| 488 |
+
Returns:
|
| 489 |
+
List of token IDs for options
|
| 490 |
+
"""
|
| 491 |
+
# Limit to maximum of 10 options (A-J)
|
| 492 |
+
num_options = min(num_options, 10)
|
| 493 |
+
option_ids = []
|
| 494 |
+
for i in range(num_options):
|
| 495 |
+
letter = chr(65 + i) # A=65, B=66, etc.
|
| 496 |
+
ids = tokenizer.encode(" " + letter, add_special_tokens=False)
|
| 497 |
+
option_ids.append(ids[0] if ids else tokenizer.eos_token_id)
|
| 498 |
+
return option_ids
|
| 499 |
+
|
| 500 |
+
"""
|
| 501 |
+
Deprecated
|
| 502 |
+
"""
|
| 503 |
+
|
| 504 |
+
@torch.no_grad()
|
| 505 |
+
def generate_answer_with_logits(model, tokenizer, prompt: str, option_ids: List[int],
|
| 506 |
+
device: torch.device, model_type: str = "hf") -> Tuple[str, np.ndarray]:
|
| 507 |
+
"""
|
| 508 |
+
Generate answer using logits method.
|
| 509 |
+
|
| 510 |
+
Args:
|
| 511 |
+
model: Model object
|
| 512 |
+
tokenizer: Tokenizer object
|
| 513 |
+
prompt: Input prompt
|
| 514 |
+
option_ids: Token IDs for options A, B, C, D
|
| 515 |
+
device: Device to run on
|
| 516 |
+
model_type: Type of model ("rosetta", "qwen", or "hf")
|
| 517 |
+
|
| 518 |
+
Returns:
|
| 519 |
+
Tuple of (predicted_answer, probabilities)
|
| 520 |
+
"""
|
| 521 |
+
messages = [{
|
| 522 |
+
"role": "user",
|
| 523 |
+
"content": prompt
|
| 524 |
+
}]
|
| 525 |
+
|
| 526 |
+
# Try to apply chat template
|
| 527 |
+
try:
|
| 528 |
+
text = tokenizer.apply_chat_template(
|
| 529 |
+
messages,
|
| 530 |
+
tokenize=False,
|
| 531 |
+
add_generation_prompt=True,
|
| 532 |
+
enable_thinking=False if model_type == "qwen" else None
|
| 533 |
+
)
|
| 534 |
+
except Exception as e:
|
| 535 |
+
print(f"Failed to apply chat template for {model_type} model: {e}")
|
| 536 |
+
text = f"### Human: {prompt}\n### Assistant:"
|
| 537 |
+
|
| 538 |
+
text += "The correct answer is"
|
| 539 |
+
input_ids = tokenizer(text, return_tensors="pt").to(device)['input_ids']
|
| 540 |
+
attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(device)
|
| 541 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
| 542 |
+
|
| 543 |
+
if model_type == "rosetta":
|
| 544 |
+
instruction_index = torch.tensor([1, 0], dtype=torch.long).repeat(input_ids.shape[1]-1, 1).unsqueeze(0).to(device)
|
| 545 |
+
response_index = torch.tensor([[-1, 0]], dtype=torch.long).unsqueeze(0)
|
| 546 |
+
outputs = model.forward(
|
| 547 |
+
input_ids=input_ids,
|
| 548 |
+
attention_mask=attention_mask,
|
| 549 |
+
position_ids=position_ids,
|
| 550 |
+
kv_cache_index=[instruction_index, response_index]
|
| 551 |
+
)
|
| 552 |
+
else:
|
| 553 |
+
outputs = model(input_ids)
|
| 554 |
+
|
| 555 |
+
logits = outputs.logits[0, -1]
|
| 556 |
+
option_logits = torch.tensor([
|
| 557 |
+
logits[option_ids[0]].item(),
|
| 558 |
+
logits[option_ids[1]].item(),
|
| 559 |
+
logits[option_ids[2]].item(),
|
| 560 |
+
logits[option_ids[3]].item()
|
| 561 |
+
])
|
| 562 |
+
|
| 563 |
+
probs = torch.nn.functional.softmax(option_logits, dim=0).numpy()
|
| 564 |
+
pred = chr(65 + np.argmax(probs))
|
| 565 |
+
return pred, probs
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
@torch.no_grad()
|
| 569 |
+
def generate_answer_with_generate(model, tokenizer, prompt: str, device: torch.device,
|
| 570 |
+
model_type: str = "hf") -> Tuple[str, np.ndarray, int, int, str]:
|
| 571 |
+
"""
|
| 572 |
+
Generate answer using text generation method.
|
| 573 |
+
|
| 574 |
+
Args:
|
| 575 |
+
model: Model object
|
| 576 |
+
tokenizer: Tokenizer object
|
| 577 |
+
prompt: Input prompt
|
| 578 |
+
device: Device to run on
|
| 579 |
+
model_type: Type of model ("rosetta" or "hf")
|
| 580 |
+
|
| 581 |
+
Returns:
|
| 582 |
+
Tuple of (predicted_answer, probabilities, input_length, generation_length, generated_text)
|
| 583 |
+
"""
|
| 584 |
+
messages = [{
|
| 585 |
+
"role": "user",
|
| 586 |
+
"content": prompt
|
| 587 |
+
}]
|
| 588 |
+
|
| 589 |
+
# Apply chat template
|
| 590 |
+
try:
|
| 591 |
+
text = tokenizer.apply_chat_template(
|
| 592 |
+
messages,
|
| 593 |
+
tokenize=False,
|
| 594 |
+
add_generation_prompt=True,
|
| 595 |
+
enable_thinking=False
|
| 596 |
+
)
|
| 597 |
+
except Exception as e:
|
| 598 |
+
print(f"Failed to apply chat template: {e}")
|
| 599 |
+
text = f"### Human: {prompt}\n### Assistant:"
|
| 600 |
+
|
| 601 |
+
# Prepare model input
|
| 602 |
+
inputs = tokenizer(text, return_tensors="pt").to(device)
|
| 603 |
+
|
| 604 |
+
# Generation parameters
|
| 605 |
+
sampling_params = {
|
| 606 |
+
'do_sample': True,
|
| 607 |
+
'temperature': 0.7,
|
| 608 |
+
'top_p': 0.8,
|
| 609 |
+
'top_k': 20,
|
| 610 |
+
'min_p': 0.0,
|
| 611 |
+
'repetition_penalty': 1.2,
|
| 612 |
+
'max_new_tokens': 1024
|
| 613 |
+
}
|
| 614 |
+
|
| 615 |
+
# Generate text
|
| 616 |
+
outputs = model.generate(
|
| 617 |
+
input_ids=inputs.input_ids,
|
| 618 |
+
attention_mask=inputs.attention_mask,
|
| 619 |
+
**sampling_params
|
| 620 |
+
)
|
| 621 |
+
|
| 622 |
+
# Parse output
|
| 623 |
+
if isinstance(model, RosettaModel):
|
| 624 |
+
generated_ids = outputs[0]
|
| 625 |
+
else:
|
| 626 |
+
generated_ids = outputs[0][inputs.input_ids.shape[1]:]
|
| 627 |
+
content = tokenizer.decode(generated_ids, skip_special_tokens=True).strip("\n")
|
| 628 |
+
|
| 629 |
+
# Extract answer
|
| 630 |
+
pred = extract_answer_from_content(content)
|
| 631 |
+
|
| 632 |
+
# Return uniform distribution for generate method
|
| 633 |
+
probs = np.array([0.25, 0.25, 0.25, 0.25])
|
| 634 |
+
|
| 635 |
+
input_length = inputs.input_ids.shape[1]
|
| 636 |
+
gen_length = generated_ids.shape[0]
|
| 637 |
+
|
| 638 |
+
return pred, probs, input_length, gen_length, content
|
| 639 |
+
|
rosetta/utils/matheval.py
ADDED
|
@@ -0,0 +1,455 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import json
|
| 3 |
+
import argparse
|
| 4 |
+
import jsonlines
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
from openai import OpenAI
|
| 7 |
+
from typing import Dict, Any
|
| 8 |
+
from math_verify import parse, verify, LatexExtractionConfig, ExprExtractionConfig, StringExtractionConfig
|
| 9 |
+
from latex2sympy2_extended import NormalizationConfig
|
| 10 |
+
from transformers import AutoTokenizer
|
| 11 |
+
import requests
|
| 12 |
+
import os
|
| 13 |
+
|
| 14 |
+
class MathEvaluator:
|
| 15 |
+
|
| 16 |
+
def rule_judge(self, solution_str: str, ground_truth: str, finish_generation: bool = True) -> bool:
|
| 17 |
+
raise NotImplementedError
|
| 18 |
+
|
| 19 |
+
def extract_after_think(self, text: str, truncate_length: int = 1000, finish_generation: bool = True) -> str:
|
| 20 |
+
pattern = r"</think>(.*)"
|
| 21 |
+
match = re.search(pattern, text, re.DOTALL)
|
| 22 |
+
return match.group(1).strip() if (match and finish_generation) else text[-truncate_length:]
|
| 23 |
+
|
| 24 |
+
def get_llm_judge_prompt(self, solution_str: str, ground_truth: str, extracted_answer: str = "", finish_generation: bool = True) -> str:
|
| 25 |
+
raise NotImplementedError
|
| 26 |
+
|
| 27 |
+
def get_llm_judge_prompt_not_finished(self, solution_str: str, ground_truth: str, extracted_answer: str = "", finish_generation: bool = True) -> str:
|
| 28 |
+
return f"""Please determine whether the final answer in the model-generated response was already correctly derived early in the reasoning process, and that the subsequent content consists mainly of unnecessary verification, overthinking, or repetitive reasoning. If correct is derived early, return "YES"; if they are not, return "NO". Only return "YES" or "NO", and do not generate any other content.
|
| 29 |
+
Reference answer: {ground_truth}
|
| 30 |
+
Model-generated response: {solution_str}
|
| 31 |
+
""".strip()
|
| 32 |
+
|
| 33 |
+
def llm_judge(self, solution_str: str, ground_truth: str, extracted_answer: str = "", finish_generation: bool = True) -> bool:
|
| 34 |
+
global OPENAI_CLIENT, MODEL_NAME
|
| 35 |
+
def get_inputs(scene_description):
|
| 36 |
+
body = [
|
| 37 |
+
{"role": "user", "content": scene_description},
|
| 38 |
+
]
|
| 39 |
+
return body
|
| 40 |
+
|
| 41 |
+
def run_api(inputs):
|
| 42 |
+
completion = OPENAI_CLIENT.chat.completions.create(
|
| 43 |
+
model=MODEL_NAME,
|
| 44 |
+
messages=inputs
|
| 45 |
+
)
|
| 46 |
+
return completion.choices[0].message.content.strip()
|
| 47 |
+
if finish_generation:
|
| 48 |
+
scene_description = self.get_llm_judge_prompt(solution_str, ground_truth, extracted_answer, finish_generation)
|
| 49 |
+
else:
|
| 50 |
+
scene_description = self.get_llm_judge_prompt_not_finished(solution_str, ground_truth, extracted_answer, finish_generation)
|
| 51 |
+
inputs = get_inputs(scene_description)
|
| 52 |
+
response = run_api(inputs)
|
| 53 |
+
|
| 54 |
+
return "YES" in response
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class AIMEEvaluator(MathEvaluator):
|
| 58 |
+
def rule_judge(self, solution_str: str, ground_truth: str, finish_generation: bool = True) -> bool:
|
| 59 |
+
# if not ground_truth.startswith("$"):
|
| 60 |
+
# ground_truth = f"${ground_truth}$"
|
| 61 |
+
gold = parse(
|
| 62 |
+
ground_truth,
|
| 63 |
+
extraction_config=[ExprExtractionConfig()],
|
| 64 |
+
)
|
| 65 |
+
answer = parse(
|
| 66 |
+
solution_str,
|
| 67 |
+
extraction_config=[
|
| 68 |
+
LatexExtractionConfig(
|
| 69 |
+
normalization_config=NormalizationConfig(
|
| 70 |
+
nits=False,
|
| 71 |
+
malformed_operators=False,
|
| 72 |
+
basic_latex=True,
|
| 73 |
+
boxed="all",
|
| 74 |
+
units=True,
|
| 75 |
+
),
|
| 76 |
+
boxed_match_priority=0,
|
| 77 |
+
try_extract_without_anchor=False,
|
| 78 |
+
),
|
| 79 |
+
ExprExtractionConfig(),
|
| 80 |
+
],
|
| 81 |
+
extraction_mode="first_match",
|
| 82 |
+
)
|
| 83 |
+
if len(answer) == 0:
|
| 84 |
+
return False, "No extracted answer"
|
| 85 |
+
else:
|
| 86 |
+
return verify(gold, answer), str(answer)
|
| 87 |
+
|
| 88 |
+
def get_llm_judge_prompt(self, solution_str: str, ground_truth: str, extract_answer: str = "", finish_generation: bool = True) -> str:
|
| 89 |
+
solution_str = self.extract_after_think(solution_str, finish_generation=finish_generation)
|
| 90 |
+
return f"""Please determine whether the final answer provided in the model-generated response is equivalent to the reference answer from a math question. The final answer may either be enclosed in \\boxed{{}} or appear after "Answer:". If they are equivalent, return "YES"; if they are not, return "NO". Only return "YES" or "NO", and do not generate any other content.
|
| 91 |
+
Model-generated answer: {solution_str}
|
| 92 |
+
Reference answer: {ground_truth}""".strip()
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class GSM8KEvaluator(MathEvaluator):
|
| 96 |
+
def rule_judge(self, solution_str: str, ground_truth: str, finish_generation: bool = True) -> bool:
|
| 97 |
+
# if not ground_truth.startswith("$"):
|
| 98 |
+
# ground_truth = f"${ground_truth}$"
|
| 99 |
+
gold = parse(
|
| 100 |
+
ground_truth,
|
| 101 |
+
extraction_config=[ExprExtractionConfig()],
|
| 102 |
+
)
|
| 103 |
+
answer = parse(
|
| 104 |
+
solution_str,
|
| 105 |
+
extraction_config=[
|
| 106 |
+
LatexExtractionConfig(
|
| 107 |
+
normalization_config=NormalizationConfig(
|
| 108 |
+
nits=False,
|
| 109 |
+
malformed_operators=False,
|
| 110 |
+
basic_latex=True,
|
| 111 |
+
boxed="all",
|
| 112 |
+
units=True,
|
| 113 |
+
),
|
| 114 |
+
boxed_match_priority=0,
|
| 115 |
+
try_extract_without_anchor=False,
|
| 116 |
+
),
|
| 117 |
+
ExprExtractionConfig(),
|
| 118 |
+
],
|
| 119 |
+
extraction_mode="first_match",
|
| 120 |
+
)
|
| 121 |
+
if len(answer) == 0:
|
| 122 |
+
return False, "No extracted answer"
|
| 123 |
+
else:
|
| 124 |
+
return verify(gold, answer), str(answer)
|
| 125 |
+
|
| 126 |
+
def get_llm_judge_prompt(self, solution_str: str, ground_truth: str, extract_answer: str = "", finish_generation: bool = True) -> str:
|
| 127 |
+
solution_str = self.extract_after_think(solution_str, finish_generation=finish_generation)
|
| 128 |
+
return f"""Please determine whether the final answer provided in the model-generated response with rule-based extracted answer is equivalent to the reference answer from a math question. The final answer may either be enclosed in the \\boxed{{}} or appear after the "Answer:". If they are equivalent, return "YES"; if they are not, return "NO". Only return "YES" or "NO", and do not generate any other content.
|
| 129 |
+
|
| 130 |
+
1. The reference answer does not include percentage signs, units or time formats (e.g., am, pm), but the Model-generated answer may include them.
|
| 131 |
+
For example, 1 is equivalent to 1 %, 1 kg, 1 am, 1 pm, 1:00 am, 1:00 pm, etc.
|
| 132 |
+
Model-generated answer: 1%
|
| 133 |
+
Reference answer: 1
|
| 134 |
+
Your output: YES
|
| 135 |
+
|
| 136 |
+
Model-generated answer: 1 kg
|
| 137 |
+
Reference answer: 1
|
| 138 |
+
Your output: YES
|
| 139 |
+
|
| 140 |
+
Model-generated answer: 1:00 pm
|
| 141 |
+
Reference answer: 1
|
| 142 |
+
Your output: YES
|
| 143 |
+
|
| 144 |
+
2. The reference answer only includes one single number, but the Model-generated answer may include multiple numbers.
|
| 145 |
+
For example, 10 is equivalent to \\boxed{{(4, 6)}}, etc.
|
| 146 |
+
Model-generated answer: 5, 5
|
| 147 |
+
Reference answer: 10
|
| 148 |
+
Your output: YES
|
| 149 |
+
|
| 150 |
+
Model-generated answer: 4, 6
|
| 151 |
+
Reference answer: 10
|
| 152 |
+
Your output: YES
|
| 153 |
+
|
| 154 |
+
Model-generated answer: 86, 42
|
| 155 |
+
Reference answer: 128
|
| 156 |
+
Your output: YES
|
| 157 |
+
|
| 158 |
+
Now let's try a real example.
|
| 159 |
+
Model-generated answer: {solution_str}
|
| 160 |
+
Reference answer: {ground_truth}
|
| 161 |
+
""".strip()
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
class MATH500Evaluator(MathEvaluator):
|
| 165 |
+
def rule_judge(self, solution_str: str, ground_truth: str, finish_generation: bool = True) -> bool:
|
| 166 |
+
if not ground_truth.startswith("$"):
|
| 167 |
+
ground_truth = f"${ground_truth}$"
|
| 168 |
+
gold = parse(
|
| 169 |
+
ground_truth,
|
| 170 |
+
extraction_config=[LatexExtractionConfig()],
|
| 171 |
+
)
|
| 172 |
+
answer = parse(
|
| 173 |
+
solution_str,
|
| 174 |
+
extraction_config=[
|
| 175 |
+
LatexExtractionConfig(
|
| 176 |
+
normalization_config=NormalizationConfig(
|
| 177 |
+
nits=False,
|
| 178 |
+
malformed_operators=False,
|
| 179 |
+
basic_latex=True,
|
| 180 |
+
boxed="all",
|
| 181 |
+
units=True,
|
| 182 |
+
),
|
| 183 |
+
boxed_match_priority=0,
|
| 184 |
+
try_extract_without_anchor=False,
|
| 185 |
+
),
|
| 186 |
+
ExprExtractionConfig(),
|
| 187 |
+
],
|
| 188 |
+
extraction_mode="first_match",
|
| 189 |
+
)
|
| 190 |
+
if len(answer) == 0:
|
| 191 |
+
return False, "No extracted answer"
|
| 192 |
+
else:
|
| 193 |
+
return verify(gold, answer), str(answer)
|
| 194 |
+
def get_llm_judge_prompt(self, solution_str: str, ground_truth: str, extract_answer: str = "", finish_generation: bool = True) -> str:
|
| 195 |
+
solution_str = self.extract_after_think(solution_str, finish_generation=finish_generation)
|
| 196 |
+
return f"""Please determine whether the final answer provided in the model-generated response is equivalent to the reference answer from a math question. The final answer may either be enclosed in \\boxed{{}} or appear after "Answer:". If they are equivalent, return "YES"; if they are not, return "NO". Only return "YES" or "NO", and do not generate any other content.
|
| 197 |
+
Model-generated answer: {solution_str}
|
| 198 |
+
Reference answer: {ground_truth}""".strip()
|
| 199 |
+
|
| 200 |
+
class AMCEvaluator(MathEvaluator):
|
| 201 |
+
def rule_judge(self, solution_str: str, ground_truth: str, finish_generation: bool = True) -> bool:
|
| 202 |
+
if not ground_truth.startswith("$"):
|
| 203 |
+
ground_truth = f"${ground_truth}$"
|
| 204 |
+
gold = parse(
|
| 205 |
+
ground_truth,
|
| 206 |
+
extraction_config=[LatexExtractionConfig()],
|
| 207 |
+
)
|
| 208 |
+
answer = parse(
|
| 209 |
+
solution_str,
|
| 210 |
+
extraction_config=[
|
| 211 |
+
LatexExtractionConfig(
|
| 212 |
+
normalization_config=NormalizationConfig(
|
| 213 |
+
nits=False,
|
| 214 |
+
malformed_operators=False,
|
| 215 |
+
basic_latex=True,
|
| 216 |
+
boxed="all",
|
| 217 |
+
units=True,
|
| 218 |
+
),
|
| 219 |
+
boxed_match_priority=0,
|
| 220 |
+
try_extract_without_anchor=False,
|
| 221 |
+
),
|
| 222 |
+
ExprExtractionConfig(),
|
| 223 |
+
],
|
| 224 |
+
extraction_mode="first_match",
|
| 225 |
+
)
|
| 226 |
+
if len(answer) == 0:
|
| 227 |
+
return False, "No extracted answer"
|
| 228 |
+
else:
|
| 229 |
+
return verify(gold, answer), str(answer)
|
| 230 |
+
def get_llm_judge_prompt(self, solution_str: str, ground_truth: str, extract_answer: str = "", finish_generation: bool = True) -> str:
|
| 231 |
+
solution_str = self.extract_after_think(solution_str, finish_generation=finish_generation)
|
| 232 |
+
return f"""Please determine whether the final answer provided in the model-generated response is equivalent to the reference answer from a math question. The final answer may either be enclosed in \\boxed{{}} or appear after "Answer:". If they are equivalent, return "YES"; if they are not, return "NO". Only return "YES" or "NO", and do not generate any other content.
|
| 233 |
+
Model-generated answer: {solution_str}
|
| 234 |
+
Reference answer: {ground_truth}""".strip()
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
class GPQAEvaluator(MathEvaluator):
|
| 238 |
+
def rule_judge(self, solution_str: str, ground_truth: str, finish_generation: bool = True) -> bool:
|
| 239 |
+
# if not ground_truth.startswith("$"):
|
| 240 |
+
# ground_truth = f"${ground_truth}$"
|
| 241 |
+
gold = parse(
|
| 242 |
+
ground_truth,
|
| 243 |
+
extraction_config=[StringExtractionConfig()],
|
| 244 |
+
)
|
| 245 |
+
answer = parse(
|
| 246 |
+
solution_str,
|
| 247 |
+
extraction_config=[
|
| 248 |
+
StringExtractionConfig(),
|
| 249 |
+
]
|
| 250 |
+
)
|
| 251 |
+
if len(answer) == 0:
|
| 252 |
+
return False, "No extracted answer"
|
| 253 |
+
else:
|
| 254 |
+
return verify(gold, answer), str(answer)
|
| 255 |
+
|
| 256 |
+
def get_llm_judge_prompt(self, solution_str: str, ground_truth: str, extract_answer: str = "", finish_generation: bool = True) -> str:
|
| 257 |
+
solution_str = self.extract_after_think(solution_str, finish_generation=finish_generation)
|
| 258 |
+
return f"""Please determine whether the final answer provided in the model-generated response is equivalent to the reference answer from a multiple choice question. The final answer may either be enclosed in \\boxed{{}} or appear after "Answer:". If they are equivalent, return "YES"; if they are not, return "NO". Only return "YES" or "NO", and do not generate any other content.
|
| 259 |
+
Model-generated answer: {solution_str}
|
| 260 |
+
Reference answer: {ground_truth}""".strip()
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
# class MBPPEvaluator(Evaluator):
|
| 264 |
+
# def rule_judge(self, solution_str: str, ground_truth: str, finish_generation: bool = True) -> bool:
|
| 265 |
+
# return True, "No extracted answer"
|
| 266 |
+
|
| 267 |
+
# def get_llm_judge_prompt(self, solution_str: str, ground_truth: str, extract_answer: str = "", finish_generation: bool = True) -> str:
|
| 268 |
+
# solution_str = self.extract_after_think(solution_str, finish_generation=finish_generation)
|
| 269 |
+
# return f"""Please determine whether the final answer provided in the model-generated response is equivalent to the reference answer from a multiple choice question. The final answer may either be enclosed in \\boxed{{}} or appear after "Answer:". If they are equivalent, return "YES"; if they are not, return "NO". Only return "YES" or "NO", and do not generate any other content.
|
| 270 |
+
# Model-generated answer: {solution_str}
|
| 271 |
+
# Reference answer: {ground_truth}""".strip()
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
# class HUMANEVALEvaluator(Evaluator):
|
| 275 |
+
# def rule_judge(self, solution_str: str, ground_truth: str, finish_generation: bool = True) -> bool:
|
| 276 |
+
# return True, "No extracted answer"
|
| 277 |
+
|
| 278 |
+
# def get_llm_judge_prompt(self, solution_str: str, ground_truth: str, extract_answer: str = "", finish_generation: bool = True) -> str:
|
| 279 |
+
# solution_str = self.extract_after_think(solution_str, finish_generation=finish_generation)
|
| 280 |
+
# return f"""Please determine whether the final answer provided in the model-generated response is equivalent to the reference answer from a multiple choice question. The final answer may either be enclosed in \\boxed{{}} or appear after "Answer:". If they are equivalent, return "YES"; if they are not, return "NO". Only return "YES" or "NO", and do not generate any other content.
|
| 281 |
+
# Model-generated answer: {solution_str}
|
| 282 |
+
# Reference answer: {ground_truth}""".strip()
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
evaluator_map = {
|
| 286 |
+
"aime2024": AIMEEvaluator(),
|
| 287 |
+
"aime2025": AIMEEvaluator(),
|
| 288 |
+
"gsm8k": GSM8KEvaluator(),
|
| 289 |
+
"math500": MATH500Evaluator(),
|
| 290 |
+
"gpqa_diamond": GPQAEvaluator(),
|
| 291 |
+
"amc23": AMCEvaluator(),
|
| 292 |
+
}
|
| 293 |
+
|
| 294 |
+
API_BASE = None
|
| 295 |
+
DEPLOYMENT_NAME = None
|
| 296 |
+
API_VERSION = None
|
| 297 |
+
CONSTRUCTED_URL = None
|
| 298 |
+
API_KEY = None
|
| 299 |
+
HEADERS = None
|
| 300 |
+
OPENAI_CLIENT = None
|
| 301 |
+
MODEL_NAME = None
|
| 302 |
+
|
| 303 |
+
def set_client(api_base=None, deployment_name=None, api_version=None, api_key=None, model_name="gpt-4.1-2025-04-14"):
|
| 304 |
+
global API_BASE, DEPLOYMENT_NAME, API_VERSION, CONSTRUCTED_URL, API_KEY, HEADERS, MODEL_NAME, OPENAI_CLIENT
|
| 305 |
+
|
| 306 |
+
API_BASE = api_base
|
| 307 |
+
DEPLOYMENT_NAME = deployment_name
|
| 308 |
+
API_VERSION = api_version
|
| 309 |
+
CONSTRUCTED_URL = f"{api_base}/openai/deployments/{deployment_name}/chat/completions?api-version={api_version}"
|
| 310 |
+
API_KEY = api_key or os.getenv("OPENAI_API_KEY", "")
|
| 311 |
+
MODEL_NAME = model_name
|
| 312 |
+
HEADERS = {
|
| 313 |
+
"Content-Type": "application/json",
|
| 314 |
+
"api-key": api_key,
|
| 315 |
+
}
|
| 316 |
+
if API_KEY:
|
| 317 |
+
print(f"Using API key: {API_KEY}")
|
| 318 |
+
OPENAI_CLIENT = OpenAI(api_key=API_KEY)
|
| 319 |
+
else:
|
| 320 |
+
OPENAI_CLIENT = None
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
# def call_llm_judge(message: list, args: argparse.Namespace) -> str:
|
| 326 |
+
# """
|
| 327 |
+
# Call the Qwen API with the given message.
|
| 328 |
+
|
| 329 |
+
# Args:
|
| 330 |
+
# message (list): Message list for the API.
|
| 331 |
+
# args (argparse.Namespace): Parsed arguments.
|
| 332 |
+
|
| 333 |
+
# Returns:
|
| 334 |
+
# str: The content of the completion response.
|
| 335 |
+
|
| 336 |
+
# Raises:
|
| 337 |
+
# Exception: When the API call fails.
|
| 338 |
+
# """
|
| 339 |
+
|
| 340 |
+
# try:
|
| 341 |
+
# completion = client.chat.completions.create(
|
| 342 |
+
# model=args.model_name,
|
| 343 |
+
# messages=message,
|
| 344 |
+
# temperature=args.temperature,
|
| 345 |
+
# top_p=args.top_p,
|
| 346 |
+
# )
|
| 347 |
+
# return completion.choices[0].message.content.strip()
|
| 348 |
+
# except Exception as e:
|
| 349 |
+
# print(f"API call failed: {str(e)}")
|
| 350 |
+
# raise
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
# def rule_judge(completin, ground_truth, dataset_name: str) -> Dict[str, Any]:
|
| 354 |
+
# rule_judge_result = None
|
| 355 |
+
|
| 356 |
+
# rule_judge_result = evaluator_map[dataset_name].rule_judge(completin, ground_truth)
|
| 357 |
+
|
| 358 |
+
# if not rule_judge_result:
|
| 359 |
+
# print(f"No valid answer detected | LLM judge")
|
| 360 |
+
# call_llm_judge
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
# return example
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
# def process_example(example: Dict[str, Any], args: argparse.Namespace) -> Dict[str, Any]:
|
| 368 |
+
# """
|
| 369 |
+
# Process an individual example with token counting.
|
| 370 |
+
# """
|
| 371 |
+
# try:
|
| 372 |
+
|
| 373 |
+
# completion_text = example.get('completion', '')
|
| 374 |
+
# before_truncation_text = completion_text.split("</think>")[0]
|
| 375 |
+
# before_truncation_token_count = count_tokens(before_truncation_text, args)
|
| 376 |
+
# after_truncation_token_count = example.get('generated_tokens', 0) - before_truncation_token_count
|
| 377 |
+
|
| 378 |
+
# example['before_truncation_token_count'] = before_truncation_token_count
|
| 379 |
+
# example['after_truncation_token_count'] = after_truncation_token_count
|
| 380 |
+
|
| 381 |
+
# example = post_process(example)
|
| 382 |
+
# if not example.get('rule_judge_result', False):
|
| 383 |
+
# think_truncation = extract_after_think(completion_text) or ''
|
| 384 |
+
|
| 385 |
+
# prompt = generate_prompt(example, think_truncation)
|
| 386 |
+
# messages = format_prompt(prompt)
|
| 387 |
+
# llm_judge_response = call_qwen(messages, args)
|
| 388 |
+
# print(llm_judge_response)
|
| 389 |
+
# example['llm_judge_response'] = llm_judge_response
|
| 390 |
+
# if llm_judge_response == "YES":
|
| 391 |
+
# example['llm_judge_result'] = True
|
| 392 |
+
# example['final_judge_result'] = True
|
| 393 |
+
# elif llm_judge_response == "NO":
|
| 394 |
+
# example['llm_judge_result'] = False
|
| 395 |
+
# example['final_judge_result'] = False
|
| 396 |
+
# else:
|
| 397 |
+
# example['llm_judge_result'] = None
|
| 398 |
+
# example['final_judge_result'] = None
|
| 399 |
+
|
| 400 |
+
# return example
|
| 401 |
+
# except Exception as e:
|
| 402 |
+
# print(f"Failed to process example {example.get('idx', 'unknown')}: {str(e)}")
|
| 403 |
+
# example['final_judge_result'] = None
|
| 404 |
+
# return example
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
# def main():
|
| 409 |
+
# args = parse_arguments()
|
| 410 |
+
# try:
|
| 411 |
+
# final_results = []
|
| 412 |
+
# with jsonlines.open(args.result_save_name, mode='w') as writer:
|
| 413 |
+
# with open(args.data_load_name, 'r', encoding='utf-8') as f:
|
| 414 |
+
# data = json.load(f)
|
| 415 |
+
|
| 416 |
+
# for idx, example in enumerate(data, start=1):
|
| 417 |
+
# print(f"Processed example ID: {idx}")
|
| 418 |
+
# new_example = process_example(example, args)
|
| 419 |
+
# final_results.append(new_example)
|
| 420 |
+
|
| 421 |
+
# writer.write_all(final_results)
|
| 422 |
+
# print(f"Processing complete! Valid results saved to: {args.result_save_name}")
|
| 423 |
+
|
| 424 |
+
# grouped_data = defaultdict(list)
|
| 425 |
+
# for entry in final_results:
|
| 426 |
+
# grouped_data[entry['idx']].append(entry)
|
| 427 |
+
|
| 428 |
+
# pass_at_1_per_idx = {}
|
| 429 |
+
# for idx, entries in grouped_data.items():
|
| 430 |
+
# correct_count = sum(entry['final_judge_result'] for entry in entries if entry['final_judge_result'] is True)
|
| 431 |
+
# total_count = len(entries)
|
| 432 |
+
# pass_at_1_per_idx[idx] = correct_count / total_count if total_count else 0
|
| 433 |
+
|
| 434 |
+
# overall_pass_at_1 = sum(pass_at_1_per_idx.values()) / len(pass_at_1_per_idx)
|
| 435 |
+
# average_generated_tokens = sum(entry['generated_tokens'] for entry in final_results) / len(final_results)
|
| 436 |
+
# average_reasoning_tokens = sum(entry['before_truncation_token_count'] for entry in final_results) / len(final_results)
|
| 437 |
+
# average_answer_tokens = sum(entry['after_truncation_token_count'] for entry in final_results) / len(final_results)
|
| 438 |
+
|
| 439 |
+
# print(f'\nPass@1 per idx:\n{pass_at_1_per_idx}\n')
|
| 440 |
+
# print(f'Overall pass@1: {overall_pass_at_1:.4f}')
|
| 441 |
+
# print(f'Average generated_tokens: {average_generated_tokens:.2f}')
|
| 442 |
+
# print(f'Average reasoning_tokens: {average_reasoning_tokens:.2f}')
|
| 443 |
+
# print(f'Average answer_tokens: {average_answer_tokens:.2f}')
|
| 444 |
+
|
| 445 |
+
# except json.JSONDecodeError as e:
|
| 446 |
+
# print(f"Data loading failed: Invalid JSON format - {str(e)}")
|
| 447 |
+
# raise
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
if __name__ == '__main__':
|
| 451 |
+
api_key = os.getenv("OPENAI_API_KEY", "")
|
| 452 |
+
|
| 453 |
+
set_client(api_base="", deployment_name="", api_version="", api_key=api_key)
|
| 454 |
+
response = "Okay, so I need to find the number of triples (a, b, c) of nonnegative integers such that a + b + c = 300 and the expression a\u00b2b + a\u00b2c + b\u00b2a + b\u00b2c + c\u00b2a + c\u00b2b equals 6,000,000. Hmm, that seems a bit complicated, but maybe I can break it down step by step.\n\nFirst, let me note down the given equations:\n\n1. a + b + c = 300\n2. a\u00b2b + a\u00b2c + b\u00b2a + b\u00b2c + c\u00b2a + c\u00b2b = 6,000,000\n\nI need to find all nonnegative integer triples (a, b, c) that satisfy both equations. Since the problem involves symmetric expressions in a, b, c, maybe there's a way to exploit symmetry here.\n\nLet me first if I can rewrite the second equation in a more manageable form. The expression a\u00b2b + a\u00b2c + b\u00b2a + b\u00b2c + c\u00b2a + c\u00b2b looks symmetric. Maybe I can factor it somehow?\n\nLet me see: Each term is of the form (variable squared times another variable). So, for example, a\u00b2b + a\u00b2c = a\u00b2(b + c), similarly for the other terms. Let me try that:\n\na\u00b2(b + c) + b\u00b2(a + c) + c\u00b2(a + b)\n\nYes, that's the same as the original expression. Now, since a + b + c = 300, maybe I can substitute (b + c) with (300 - a), and similarly for the others. Let's try that substitution.\n\nSo, substituting:\n\na\u00b2(300 - a) + b\u00b2(300 - b) + c\u00b2(300 - c)\n\nWait, hold on. Let me check that again. The original expression is't exactly that. Wait, the's see:\n\nWait, the original expression is a\u00b2(b + a\u00b2c + b\u00b2a + b\u00b2c + c\u00b2a + c\u00b2b. So I group terms as a\u00b2(b + c) + b\u00b2(a + c) + c\u00b2(a + b), then yes, each term is variable squared times the sum of the other two variables. Since the sum of the other two variables is 300 minus the variable. So, for example, (b + c) = 300 - a, so substituting that in, we get:\n\na\u00b2(300 - a) + b\u00b2(300 - b) + c\u00b2(300 - c)\n\nTherefore, the second equation can be rewritten as:\n\na\u00b2(300 - a) + b\u00b2(300 - b) + c\u00b2(300 - c) = 6,000,000\n\nHmm, that might be a useful form. Let me write that down:\n\nEquation 2: a\u00b2(300 - a) + b\u00b2(300 - b) + c\u00b2(300 - c) = 6,000,000\n\nNow, perhaps I can consider expanding expression in terms of the variables. Let me think if there's a way to relate this to the first equation. Since a + b + c = 300, maybe we can express some relationship between these variables here Let't immediately yet.\n\nAlternatively, maybe I can consider expanding the expression or finding another way to express it. Let me see ifAlternatively, perhaps I can consider the expression a\u00b2b + a\u00b2c + b\u00b2a + b\u00b2c + c\u00b2a + c\u00b2b. Let me see if this can be written in terms of (a + b + c)^3 or something similar.\n\nWait, let me recall that (a + b + c)^3 = a^3 + b^3 + c^3 + 3(a + b)(a + c)(b + c). Hmm, not sure if that helps. Alternatively, maybe expanding (a + b + c)(ab + bc + ca). Let me see:\n\n(a + b + c)(ab + bc + ca) = a^2b + a^2c + ab^2 + b^2c + a c^2 + b c^2 + 3abc\n\nWait, that's exactly our expression in the second equation plus 3abc. So:\n\n(a + b + c)(ab + bc + ca) = (a\u00b2b + a\u00b2c + b\u00b2a + b\u00b2c + c\u00b2a + c\u00b2b) + 3abc\n\nTherefore, the second equation can be rewritten as:\n\n(a + b + c)(ab + bc + ca) - 3abc = 6,000,000\n\nSince a know that a + b + c = 300, substitute that in:\n\n300(abab + bc + ca) - 3abc = 6,000,000\n\nHmm, that might be a better way to handle it equation. Let me write that down:\n\nEqu00(ab + bc + ca) - 3abc = 6,000,000\n\nDiv can divide both sides by 3 to simplify:\n\n100(ab + bc + ca) - abc = 2,000,000\n\nSo now, the equation becomes:\n\n100(ab + bc + ca) - abc = 2,000,000\n\nHmm, that seems a bit more manageable. Let me write that as:\n\nabc - 100(ab + bc + ca) = -2,000,000\n\nAlternatively, maybe rearranged as:\n\nabc = 100(ab + bc + ca) - 2,000,000\n\nHmm, not sure yet. Let me see if I can find another relation here.\n\n Since we have a + b + c = 300, perhaps we can denote variables in terms of two variables. say, let me set variables as a, b, and then c = 300 - a - b. Then substitute expressions in terms of a and b. But that might get complicated, but perhaps manageable.\n\nAlternatively, maybe there's a substitution or symmetry here. Let me think.\n\n the problem is symmetric in a, b, c, so perhaps the solutions are symmetric in some way. Maybe two variables are equal, or all three are equal?\n\nLet me first check if all three variables are equal. If a = b = c, then since a + b + c = 300, each would be 100. Let me check if that satisfies the second equation.\n\nIf a = b = c = 100, then compute the second expression:\n\na\u00b2b + a\u00b2c + ...\u00b2a + b\u00b2c + c\u00b2a + c\u00b2b\n\n = 6*(100)^2*100 = 6*10003 = 6,000,000. Wait, that's exactly the required given! So (100, 100, 100) is a solution.\n\n So that's one solution. But the problem asks for all triples, so maybe might be more?\n\nHmm, so maybe there are other triples where variables are not all equal but still satisfy the equations. Let me see.\n\nAlternatively, maybe two variables are equal and the third is different. Let me suppose that two variables are equal, say a = b, and c is different. Let,2a + c = 300, so c = 300 - 2a.\n\nThen, let's plug into the second equation.\n\nFirst, let me compute ab + bc + ca. when a = b:\n\nab + bc + ca = a\u00b2 + a c + a c = a\u00b2 + 2a c\n\nSimilarly, abc = a\u00b2 c\n\nSo substitugging into the equation 100(ab + bc + ca) - abc = 100(a\u00b2 + 2a c) - a\u00b2 c = 100a\u00b2 + 200a c - a\u00b2 c\n\nBut since c = 300 - 2a, substitute that in:\n\n100a\u00b2 + 200a(300 - 2a) - a\u00b2(300 - 2a)\n\nLet me compute each term:\n\nFirst term: 100a\u00b2\n\nSecond term: 200a*(300 - 2a) = 200a*300 - 200a*2a = 60,000a - 400a\u00b2\n\nThird term: -a\u00b2*(300 - 2a) = -300a\u00b2 + 2a\u00b3\n\nSo adding all three terms together:\n\n100a\u00b2 + (60,000a - 400a\u00b2) + (-300a\u00b2 + 2a\u00b3)\n\nCombine like terms:\n\nFor00a\u00b2 - 400a\u00b2 - 300a\u00b2 + 2a\u00b3 + 60,000a\n\nThat's:\n\n(100 - 400 - 300)a\u00b2 + 2a\u00b3 + 60,000a\n\n= (-600a\u00b2) + 2a\u00b3 + 60,000a\n\nSo the entire expression is 2a\u00b3 - 600a\u00b2 + 60,000a\n\nSet this equal to 2,000,000 ( (since the equation was 100(ab + bc + ca) - abc = 2,000,000)\n\nWait:\n\n2a\u00b3 - 600a\u00b2 + 60,000a = 2,000,000\n\nDivide both sides by 2:\n\na\u00b3 - 300a\u00b2 + 30,000a = 1,000,000\n\nBring all terms to left-hand:\n\na\u00b3 - 300a\u00b2 + 30,000a - 1,000,000 = 0\n\nHmm, need to solve this cubic equation for integer a. Let me see if I can factor this or find integer roots.\n\nPossible me denote f(a) = a\u00b3 - 300a\u00b2 + 30,000a - 1,000,000\n\nWe can try possible integer roots. Since a must be a nonnegative integer such that 2a \u2264 300 (since c = 300 - 2a must be nonnegative). so a \u2264 150.\n\nPossible rational roots are factors of 1,000,000 divided by factors of 1, so possible integer roots are divisors of 1,000,000. But that's a lot. Maybe try some values.\n\nFirst, let me try a = 100. Since when all variables are 100, that's a solution. so if a = 100, then c = 300 - 200 = 100, so that's the case where all three are equal, so that should be a root.\n\nLet me check f(100):\n\n100\u00b3 - 300*(100)^2 + 30,000*100 - 1,000,000\n\n= 1,000,000 - 300*10,000 + 3,000,000 - 1,000,000\n\n= 1,000,000 - 3,000,000 + 3,000,000 - 1,000,000 = 0. Yes, so a = 100 is a root.\n\nTherefore, (a - 100) is a factor. Let's perform polynomial division or factor it out.\n\nDivide f(a) by (a - 100):\n\nUsing synthetic division:\n\nCoefficients: 1 | -300 | 30,000 | -1,000,000\n\nDivide by (a - 100), so root at 100.\n\nBring down the 1.\n\nMultiply 1 by 100: 100. Add to next coefficient: -300 + 100 = -200\n\nMultiply -200 by 100 = -20,000. Add to next coefficient: 30,000 + (-20,000) = 10,000\n\nMultiply 10,000 by 100 = 1,000,000. Add to last term: -1,000,000 + 1,000,000 = 0. Perfect.\n\nSo the polynomial factors as (a - 100)(a\u00b2 - 200a + 10,000)\n\nNow set the quadratic factor to zero:\n\na\u00b2 - 200a + 10,000 = 0\n\nLet me compute discriminant D = (200)^2 - 4*1*10,000 = 40,000 - 40,000 = 0\n\nSo discriminant is zero, so a = [200 \u00b1 0]/2 = 100. So the quadratic has a double root at a = 100.\n\nTherefore, the cubic factors as (a - 100)^3 = 0. Wait, but that would mean that the original cubic is (a - 100)^3, but let me check:\n\nWait, (a - 100)^3 = a\u00b3 - 300a\u00b2 + 30,000a - 1,000,000, which matches exactly the cubic. So indeed, the equation is (a - 100)^3 = 0, so the only real root is a = 100. Therefore, in the case where a = b, the only solution is a = b = 100, which gives c = 100. so that's the triple-equal case.\n\nHmm, so if we assume two variables are equal, the only solution is when all three are equal. So maybe the only symmetric with two variables equal is the all-equal case. So perhaps other are solutions where all variables are different?\n\nAlternatively, maybe there are solutions where one variable is 0. Let me check that possibility.\n\nSuppose c = 0. Then the equations become a + b = 300 and and the second equation becomes a\u00b2b + a\u00b2*0 + b\u00b2a + b\u00b2*0 + 0 + 0 = a\u00b2b + ab\u00b2 = ab(a + b) = ab*300.\n\nSo the second equation becomes 300ab = 6,000,000, so ab = 20,000.\n\nSo if have a + b = 300 and ab = 20,000. Let me see if there are integer solutions here here.\n\nWe can set up the quadratic equation: Let a and b be roots of x\u00b2 - 300x + 20,000 = 0.\n\nCompute discriminant D = 300\u00b2 - 4*1*20,000 = 90,000 - 80,000 = 10,000\n\nsqrt(D) = 100\n\n so roots are [300 \u00b1 100]/2 = (400)/2 = 200 or (200)/2 = 100. So the roots are 200 and 100. Therefore, the solutions are (a, b) = (200, 100) or (100, 200). Since c = 0, the triples would be (200, 100, 0), (100, 200, 0), and and also permutations where c is in different positions. Wait, since c was set to 0, but the variables are ordered, so actually, if we set c = 0, then the triples are (a, b, 0) where a + b = 300 and ab = 20,000. So the solutions are (200, 100, 0), (100, 200, 0). But since the variables are ordered, but the problem counts triples (a, b, c) as ordered? Wait, actually, the problem says \"triples of nonnegative integers (a, b, c)\", so I think order matters here. Wait, but actually, in combinatorics, unless specified otherwise, sometimes triples are considered ordered. But let me check the problem statement again.\n\nWait, the problem says \"the number of triples of nonnegative integers (a, b, c)\", so I think they are ordered triples. so different permutationsings count as different triples. However, in the case where c is 0, but we could also have a = 0 or b = 0. So actually's see.\n\nWait, so if c = 0, then we have two solutions (200,100,0 and 100,200,0. Similarly, if a = 0, then we have (0, 200, 100) and (0, 100, 200). Similarly, if b = 0, then (200, 0, 100) and (100, 0, 200). So total of 6 triples where one variable is 0 and the other two are 200 and 100. Additionally, the all-100 case. So that's 7 so far. But maybe there are more solutions where none of the variables are zero?\n\nAlternatively, maybe there are solutions where two variables are zero? Let me check that.\n\n quickly. Suppose two variables are zero, say b = c = 0, then a = 300. Then the second equation becomes a\u00b2b + ... all terms would be zero, so 0 \u2260 6,000,000. So that's invalid. So no solutions with two variables zero.\n\nSo the only solutions with a zero are the ones above with one zero and the other two being 200 and 100. So that gives 6 triples. Plus the all-100 case. So total 7? But maybe there are other triples where none of the variables are zero?\n\nWait, let me see.\n\n Let me suppose of another approach. Let me suppose variables a, b, c all positive integers (since if one is zero, we already considered that case). Let me see if there are solutions where all variables are positive and and different from 100, 100, 100 and 200,100,0 etc.\n\nAlternatively, maybe there can consider the equation 100(ab + bc + ca) - abc = 2,000,000. Let me denote this as:\n\nabc - 100(ab + bc + ca) = -2,000,000\n\nHmm, perhaps we can factor this equation somehow. Let me see:\n\nLet me add 1,000,000 to both sides:\n\nabc - 100(ab + bc + ca) + 1,000,000 = -1,000,000\n\nHmm, not sure. Alternatively, perhaps we't think of variables x = a, y = b, z = c, and see if there's a substitution.\n\nAlternatively, let set variables = 300 - a, y = 300 - b, z = 300 - c, but not sure.\n\nAlternatively, perhaps consider the equation:\n\n100(ab + bc + ca) - abc = 2,000,000. Let me rearr this as:\n\nabc - 100(ab - 100bc - 100ca + 2,000,000 = 0\n\nHmm, perhaps factor terms:\n\nLet me see, perhaps factor terms with ab, bc, ca:\n\nab(c - 100) + bc(a - 100) + ca(b - 100) = -2,000,000\n\nHmm, not sure. Alternatively, perhaps rearr the equation as:\n\nabc = 100(ab + bc + ca) - 2,000,000\n\nLet me think if I can write this as:\n\nabc = 100(ab + bc + ca) - 2,000,000\n\nHmm, perhaps divide both sides by abc (assuming a,b,c \u2260 zero):\n\n1 = 100(1/c + 1/a + 1/b) - 2,000,000/(abc)\n\nHmm, not sure if helpful helps.\n\nAlternatively, maybe think of variables in terms of ratios. Let me suppose that variables are in some proportion. Let me see.\n\nAlternatively, let me consider that in the case where one variable is 0, we have solutions. The case where all variables equal 100 is a solution. Maybe there are other solutions where variables are different but not involving zeros?\n\nAlternatively me suppose that variables a, b, c are all different from 100 and none are zero. Let me see if such solutions exist.\n\nAlternatively, let try to set variables as 150, 100, 50. Let me test that.\n\nWait, let me pick some numbers. Let me see, for't suppose a = 150, b = 100, c = 50. Then a + b + c = 300. Let me compute the second equation:\n\na\u00b2b + a\u00b2c + b\u00b2a + b\u00b2c + c\u00b2a + c\u00b2b\n\nCompute each term:\n\na\u00b2b = 150\u00b2 * 100 = 22500 * 100 = 2,250,000\n\na\u00b2c = 150\u00b2 *50 = 22500 *50 = 1,125,000\n\nb\u00b2a = 100\u00b2 *150 = 10,000 *150 = 1,500,000\n\nb\u00b2c = 100\u00b2 *50 = 10,000 *50 = 500,000\n\nc\u00b2a = 50\u00b2 *150 = 2500 *150 = 375,000\n\nc\u00b2b = 50\u00b2 *100 = 2500 *100 = 250,000\n\nAdding all these up:\n\n2,250,000 + 1,125,000 = 3,375,000\n\n+1,500,000 = 4,875,000\n\n+500,000 = 5,375,000\n\n+375,000 = 5,750,000\n\n+250,000 = 6,000,000\n\nWait a second! That works! So (150, 100, 50) is another solution. Hmm that's a solution where all variables are different and none are zero. So that's another solution. Hmm my previous thought that only the all-100 and the ones with a zero were solutions was incorrect. So there are more solutions.\n\nHmm, so that complicates things. Let me see why that works. Let me check the calculation again.\n\nWait, so (150, 100, 50) gives the required sum and the required product expression6,000,000. So that's a valid solution. So there are more solutions. So I need to find all such triples.\n\nHmm, so perhaps there are multiple families variables are in arithmetic progression? Let me see. 50, 100, 150 is an arithmetic progression with difference 50. Let me see if that's a pattern.\n\nAlternatively, maybe variables a can think of variables as multiples of 50? Let me see.\n\nWait, let me see if there are other triples. that. Let me try (200, 100, 0) which we already know works, but (150, 100, 50) works, and (100, 100, 100) works. Let me see if there are others.\n\nLet me try (200, 50, 50). Let's check:\n\na + b + c = 200 +50 +50 = 300. Good.\n\nCompute the second equation:\n\na\u00b2b + a\u00b2c + b\u00b2a + b\u00b2c + c\u00b2a + c\u00b2b\n\n= 200\u00b2*50 + 200\u00b2*50 + 50\u00b2*200 + 50\u00b2*50 + 50\u00b2*200 + 50\u00b2*50\n\nWait, let me compute each term:\n\na\u00b2b = (200)^2 *50 = 40,000 *50 = 2,000,000\n\na\u00b2c = same as a\u00b2b since b and c are both 50, so another 2,000,000\n\nb\u00b2a = (50)^2 *200 = 2,500 *200 = 500,000\n\nb\u00b2c = (50)^2 *50 = 2,500 *50 = 125,000\n\nc\u00b2a = same as b\u00b2a, so another 500,000\n\nc\u00b2b = same as b\u00b2c, so another 125,000\n\nAdding all together:\n\n2,000,000 + 2,000,000 = 4,000,000\n\n+500,000 + 00,000 = 4,625,000\n\nWait, no, let me do step by step:\n\nFirst two terms: 2,000,000 + 2,000,000 = 4,000,000\n\nThen b\u00b2a + b\u00b2c = 500,000 + 125,000 = 625,000\n\nThen c\u00b2a + c\u00b2b = same as b\u00b2a + b\u00b2c, so another 625,000\n\nTotal: 4,000,000 + 625,000 + 625,000 = 5,250,000. which is less than 6,000,000. So that doesn't work.\n\nHmm, so that's not a solution. So maybe arithmetic progression is a better idea.\n\nWait, let me try another triple. Let me see, maybe't try (120, 100, 80). Let me check sum: 120 + 100 +80 = 300. Good.\n\nCompute the second equation:\n\na\u00b2b + a\u00b2c + b\u00b2a + b\u00b2c + c\u00b2a + c\u00b2b\n\nCompute me compute each term:\n\na\u00b2b = 120\u00b2 *100 = 14,400 *100 = 1,440,000\n\na\u00b2c = 120\u00b2 *80 = 14,400 *80 = 1,152,000\n\nb\u00b2a = 100\u00b2 *120 = 10,000 *120 = 1,200,000\n\nb\u00b2c = 100\u00b2 *80 = 10,000 *80 = 800,000\n\nc\u00b2a = 80\u00b2 *120 = 6,400 *120 = 768,000\n\nc\u00b2b = 80\u00b2 *100 = 6,400 *100 = 640,000\n\nAdding all terms:\n\n1,440,000 + 1,152,000 = 2,592,000\n\n+1,200,000 = 3,792,000\n\n+800,000 = 4,592,000\n\n+768,000 = 5,360,000\n\n+640,000 = 6,000,000\n\nOh! That works too! So (120, 100, 80) is another solution. Interesting. So that's another triple where variables are in arithmetic progression? 80, 100, 120 with difference 20. Wait, so that's also an arithmetic progression. Hmm, so maybe arithmetic progressions are the key here?\n\nWait, let me see another arithmetic progression. Let me try (180, 100, 20). Let me check sumSum is 180 + 100 +20 = 300. Let me compute the second equation:\n\na\u00b2b + a\u00b2c + b\u00b2a + b\u00b2c + c\u00b2a + c\u00b2b\n\na\u00b2b = 180\u00b2 *100 = 32,400 *100 = 3,240,000\n\na\u00b2c = 180\u00b2 *20 = 32,400 *20 = 648,000\n\nb\u00b2a = 100\u00b2 *180 = 10,000 *180 = 1,800,000\n\nb\u00b2c = 100\u00b2 *20 = 10,000 *20 = 200,000\n\nc\u00b2a = 20\u00b2 *180 = 400 *180 = 72,000\n\nc\u00b2b = 20\u00b2 *100 = 400 *100 = 40,000\n\nAdding all terms:\n\n3,240,000 + 648,000 = 3,888,000\n\n+1,800,000 = 5,688,000\n\n+200,000 = 5,888,000\n\n+72,000 = 5,960,000\n\n+40,000 = 6,000,000\n\nWow, that works works! So (180, 100, 20) is another solution. So arithmetic progression with difference 80? Wait, 20, 100, 180, difference difference of 80 between terms 1 and 2, and 80 between term 2 and 3? Wait,100 -20 =80, 180 -100=80. Yes, so that's an arithmetic progression with common difference 80. So that works too as well.\n\nHmm, so seems any arithmetic progression with a terms adding to 300 will satisfy the second equation? Let me check another one.\n\n Let me try (160, 100, 40). Sum is 160 + 00 +40 = 300.\n\nCompute the second equation:\n\na\u00b2b + a\u00b2c + b\u00b2a + b\u00b2c + c\u00b2a + c\u00b2b\n\na\u00b2b = 160\u00b2 *100 = 25,600 *100 = 2,560,000\n\na\u00b2c = 160\u00b2 *40 = 25,600 *40 = 1,024,000\n\nb\u00b2a = 100\u00b2 *160 = 10,000 *160 = 1,600,000\n\nb\u00b2c = 100\u00b2 *40 = 10,000 *40 = 400,000\n\nc\u00b2a = 40\u00b2 *160 = 1,600 *160 = 256,000\n\nc\u00b2b = 40\u00b2 *100 = 1,600 *100 = 160,000\n\nAdding all terms:\n\n2,560,000 + ,024,000 = 3,584,000\n\n+1,600,000 = 5,184,000\n\n+400,000 = 5,584,000\n\n+256,000 = 5,840,000\n\n+160,000 = 6,000,000\n\nYes! That works too. So (160, 100, 40) is another solution. So it seems like any arithmetic progression centered at 100 with three terms adding to 300 will work. Let me see.\n\nWait, let me think of the arithmetic progression. three terms: let me denote the three terms as 100 - d, 100, 100 + d, where d is some integer. Then their sum is (100 - d) + 100 + (100 + d) = 300, which works. So that's an arithmetic progression with common difference d. Let me check if this satisfies the second equation.\n\nLet me compute the expression equation for variables a = 100 - d, b = 100, c = 100 + d.\n\nCompute a\u00b2b + a\u00b2c + b\u00b2a + b\u00b2c + c\u00b2a + c\u00b2b.\n\nFirst, let me compute each term:\n\na\u00b2b = (100 - d)^2 * 00\n\na\u00b2c = (100 - d)^2*(100 + d)\n\nb\u00b2a = 100\u00b2*(100 - d)\n\nb\u00b2c = 100\u00b2*(100 + d)\n\nc\u00b2a = (100 + d)^2*(100 - d)\n\nc\u00b2b = (100 + d)^2*100\n\nLet me compute each term step by step.\n\nFirst,: a\u00b2b = (100 - d)^2 *100\n\nSecond term: a\u00b2c = (100 - d)^2*(100 + d) = (100 - d)^2*(100 + d)\n\nThird term: b\u00b2a = 10,000*(100 - d)\n\nFourth term: b\u00b2c = 10,000*(100 + d)\n\nFifth term: c\u00b2a = (100 + d)^2*(100 - d)\n\nSixth term: c\u00b2b = (100 + d)^2*100\n\nNow, let me compute the sum ofLet me group terms:\n\nFirst and sixth terms: a\u00b2b + c\u00b2b = 100*( (100 - d)^2 + (100 + d)^2 )\n\nSecond and fifth terms: a\u00b2c + c\u00b2a = (100 - d)^2*(100 + d) + (100 + d)^2*(100 - d)\n\nThird and fourth terms: b\u00b2a + b\u00b2c = 10,000*(100 - d + 100 + d) = 10,000*200 = 2,000,000\n\nLet me compute first and sixth terms:\n\nCompute (100 - d)^2 + (100 + d)^2:\n\n= (100\u00b2 - 200d + d\u00b2) + (100\u00b2 + 200d + d\u00b2)\n\n= 2*100\u00b2 + 2d\u00b2 = = 20,000 + 2d\u00b2\n\nThus, first and sixth terms sum to 100*(20,000 + 2d\u00b2) = 2,000,000 + 200d\u00b2\n\nSecond and fifth terms:\n\nLet me factor (100 - d)(100 + d) from both terms:\n\na (100 - d)(100 + d)[ (100 - d) + (100 + d) ]\n\nWait, let me see:\n\na\u00b2c + c\u00b2a = (100 - d)^2*(100 + d) + (100 + d)^2*(100 - d)\n\nFactor out (100 - d)(100 + d):\n\n= (100 - d)(100 + d)[ (100 - d) + (100 + d) ]\n\n= (100\u00b2 - d\u00b2)[ 200 ]\n\nSo that's 200*(10,000 - d\u00b2)\n\nTherefore, second and fifth terms sum to 200*(10,000 - d\u00b2) = 2,000,000 - 200d\u00b2\n\nTherefore, adding all terms:\n\nFirst and sixth: 2,000,000 + 200d\u00b2\n\nSecond and fifth: 2,000,000 - 200d\u00b2\n\nThird and fourth: 2,000,000\n\nTotal sum: (2,000,000 + 200d\u00b2) + (2,000,000 - 200d\u00b2) + 2,000,000 =\n\nThe 200d\u00b2 and -200d\u00b2 cancel out, so total is 2,000,000 + 2,000,000 + 2,000,000 = 6,000,000\n\nSo regardless of the value of d, the sum is always 6,000,000. That's fascinating! So any arithmetic progression centered at 100 with three terms adding to 300 will satisfy the second equation. Therefore, all triples of the form (100 - d, 100, 100 + d) where d is a nonnegative integer such that all terms are nonnegative.\n\nWait, but d can be positive or negative? Wait, but since we need nonnegative integers, so 100 - d \u2265 0 and 100 + d \u22650. Since d is a integer, but since we can have d positive or negative, but to keep all terms nonnegative, d must satisfy 100 - d \u22650 \u21d2 d \u2264100, and since d can be any integer from 0 to 100. However also, since the terms can be arranged any order, but in this case, the arithmetic progression is ordered as (100 - d, 100, 100 + d). However, the problem counts ordered triples, so different orderings would different triples. Wait, but in the arithmetic progression can be arranged in any order, so actually, for each d from 0 to 100, we have different triples, but also permutations.\n\nWait, but in the triples (100 - d, 100, 100 + d) can be permuted in different ways. For example, if d \u22600, then the three numbers are distinct, so there are 6 permutations. However, when d =0, all three are 100, so only one permutation.\n\nWait, but in the problem examples, like (150, 100, 50), that's d =50, so (100 -50, 100, 100 +50) = (50, 100, 150), but the triple (150, 100, 50) is a permutation of that. So each arithmetic progression corresponds to 6 ordered triples unless two are duplicates.\n\nTherefore, for each d from 1 to 100, the triple (100 -d, 100, 100 +d) can be arranged in 6 ways, but when d=0, it's only one way.\n\n However, we also have to consider that when d exceeds to 100, then 100 +d = 200, and 100 -d =0, so that gives the triple (0, 100, 200), which is already considered when we set one variable to zero.\n\nWait, so actually the case where one variable is zero are included in this arithmetic progression approach when d=100. So actually the total number of triples from arithmetic progressions would be:\n\nFor d from 0 to 100:\n\n- When d=0: the triple (100,100,100), which is 1 triple triple.\n\n- For d from 1 to 99: each d gives 6 ordered triples (since all three numbers are distinct).\n\n- When d =100: the triple (0, 100, 200), which can permutations in 3! /1! =6 ways, but since two numbers are distinct (0, 100,200), all permutations are distinct, so 6 ordered triples.\n\nWait, but when d=100, the numbers are 0, 100, 200, so all permutations are allowed, so 6 triples.\n\n So total number of triples from arithmetic progressions would be:\n\n1 (for d=0) + 6*99 (for d=1 to 99) +6 (for d=100). = 1 + 594 +6 = 601? Wait, but that can't be right because when d=100, it's included in the d=1 to 99? Wait no, d=100 is separate.\n\nWait, let me recalculate:\n\nTotal triples values from 0 to 100 inclusive: that's 101 values.\n\nFor each d from 1 to 99: 99 values, each gives 6 triples.\n\nFor d=0: 1 triple.\n\nFor d=100: 1 value, giving 6 triples.\n\nTotal triples: 1 + 99*6 +6 = 1 + 594 +6 = 601.\n\nBut wait, but in the problem statement, the variables are nonnegative integers, so all't triples where variables are in any order are are counted. separately separate. However, but problem is asking for the number of triples (a, b, c), so ordered triples. So this approach would give 601 triples. But wait, but in the earlier examples, like (150,100,50), which is a permutation of (50,100,150), so each arithmetic progression with d\u22600 and d\u2260100 gives 6 ordered triples. However, d=100, the triple is (0,100,200), which also has 6 permutations. And d=0 gives only one triple.\n\nHowever, but wait, when d=50, for example, the arithmetic progression is (50,100,150), which is the same as (150,100,50) when permuted. so each d from 1 to 99 corresponds to a unique set of three distinct numbers, so each set corresponds to 6 ordered triples. So the total number would be 1 + 6*100 = 601? Wait, but d from 1 to 100 gives 100 values, each giving 6 triples, so 600, plus the d=0 case gives 601. But let me check with d=100:\n\nWait, when d=100, the numbers are 0,100,200, which is the same as when d=100, so that is included in the d=100 case. So yes, that's correct.\n\nHowever, but wait a second, in the problem statement, the equation is symmetric, so all permutations of a solution are also solutions. Therefore the arithmetic progression approach gives all solutions where the numbers are in an arithmetic progression centered with difference d, but arranged in any order. So the total number of ordered triples would indeed 601. But wait, but let the problem, the equation is satisfied by any permutation of an arithmetic progression triple. so all permutations permutations are valid solutions.\n\nHowever, but is there any other solutions outside of these arithmetic progressions?\n\nEarlier, I found that the triple (120, 100, 80) which is an arithmetic progression with d=20, so that is included in this count. Similarly, (180,100,20) is d=80, so also included. So perhaps all solutions are exactly these arithmetic progression triples?\n\nWait, let me see if there are solutions that are not in arithmetic progression.\n\nSuppose I take a triple like (140, 100, 60). Let me check if that works.\n\nSum is 140 +100 +60 = 300.\n\nCompute the second equation:\n\na\u00b2b + a\u00b2c + b\u00b2a + b\u00b2c + c\u00b2a + c\u00b2b\n\na\u00b2140, b=100, c=60.\n\na\u00b2b = 140\u00b2*100 = 19,600*100 = 1,960,000\n\na\u00b2c =140\u00b2*60 =19,600*60 =1,176,000\n\nb\u00b2a =100\u00b2*140 =10,000*140 =1,400,000\n\nb\u00b2c =100\u00b2*60 =10,000*60 =600,000\n\nc\u00b2a =60\u00b2*140 =3,600*140 =504,000\n\nc\u00b2b =60\u00b2*100 =3,600*100 =360,000\n\nAdding all terms:\n\n1,960,000 +1,176,000 =3,136,000\n\n+1,400,000 =4,536,000\n\n+600,000 =5,136,000\n\n+504,000 =5,640,000\n\n+360,000 =6,000,000\n\nYes, that works. too. And ( this is an arithmetic progression with d=40: 60,100,140. So yes is included in the arithmetic progression case.\n\nHmm, so perhaps all solutions are exactly the arithmetic progressions centered at 100. Let me see if there is a solution that is not an arithmetic progression.\n\nLet me try a triple like (130, 110, 60). Sum is 130+110+60=300.\n\nCompute the second equation:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\na=130, b=110, c=60.\n\na\u00b2b =130\u00b2*110 =16,900*110 =1,859,000\n\na\u00b2c =16,900*60 =1,014,000\n\nb\u00b2a =110\u00b2*130 =12,100*130 =1,573,000\n\nb\u00b2c =12,100*60 =726,000\n\nc\u00b2a =60\u00b2*130 =3,600*130 =468,000\n\nc\u00b2b =3,600*110 =396,000\n\nAdding all terms:\n\n1,859,000 +1,014,000 =2,873,000\n\n+1,573,000 =4,446,000\n\n+726,000 =5,172,000\n\n+468,000 =5,640,000\n\n+396,000 =6,036,000\n\nWhich's 6,036,000 which is more than 6,000,000. So that doesn't work. So that triple is not a solution.\n\nHmm test: Let me try (120, 90, 150). Wait, sum is 160+90+150=400, which is over. Not good.\n\nLet me try (120, 90, 90). Sum is 120+90+90=300.\n\nCompute the second equation:\n\na\u00b22b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\na=120, b=90, c=90.\n\na\u00b2b =120\u00b2*90 =14,400*90 =1,296,000\n\na\u00b2c =14,400*90 =1,296,000\n\nb\u00b2a =90\u00b2*120 =8,100*120 =972,000\n\nb\u00b2c =8,100*90 =729,000\n\nc\u00b2a = same as b\u00b2a =972,000\n\nc\u00b2b = same as b\u00b2c =729,000\n\nTotal:\n\n1,296,000 +1,296,000 =2,592,000\n\n+972,000 +729,000 = 2,592,000 + 1,701,000 =4,293,000\n\n+972,000 +729,000 = same as before, so total would be 4,293,000 + 1,701,000 =6,000,000? Wait, let me compute step by step:\n\nWait, let me add all terms:\n\nFirst.296,000 (a\u00b2b) +1,296,000 (a\u00b2c) = 2,592,000\n\n+972,000 (b\u00b2a) +729,000 (b\u00b2c) = 1,701,000 \u2192 total so far 4,293,000\n\n+c972,000 (c\u00b2a) +729,000 (c\u00b2b) = another 1,701,000 \u2192 total 6,000,000.\n\nWait, so that works! So (120, 90, 90) is a solution. But this is not an arithmetic progression. Let me check:\n\nThe numbers are 120, 90, 90. So two of them are equal. So this is a case where two variables are equal, but not the middle term. Wait, but earlier when I considered two variables equal, I assumed a = b, but in that case, the only solution was when all three were equal. But here, with't two variables equal (b and c), but a different.\n\n from them. So this is a case where two variables are equal but not the first two.\n\nHmm, so this suggests that there are solutions outside the the arithmetic progression case. So my previous conclusion was wrong.\n\nWait, let let me analyze this case. Let me suppose that two variables are equal, say b = c, then a + 2b = 300, so a = 300 - 2b.\n\nThen, let's compute the second equation:\n\na\u00b2b + a\u00b2c + b\u00b2a + b\u00b2c + c\u00b2a + c\u00b2b\n\nSince b = c, this becomes:\n\na\u00b2b + a\u00b2b + b\u00b2a + b\u00b2b + b\u00b2a + b\u00b2b\n\n= 2a\u00b2b + 2b\u00b2a + 2b\u00b3\n\n= 2a\u00b2b + 2ab\u00b2 + 2b\u00b3\n\nFactor out 2b:\n\n2b(a\u00b2 + ab + b\u00b2)\n\nSet this equal to 6,000,000:\n\n2b(a\u00b2 + ab + b\u00b2) = 6,000,000\n\nDivide both sides by 2:\n\nb(a\u00b2 + ab + b\u00b2) = 3,000,000\n\nBut since a = 300 - 2b, substitute that in:\n\nb[(300 - 2b)^2 + (300 - 2b)b + b\u00b2] = 3,000,000\n\nLet me compute the expression inside the brackets:\n\nFirst term: (300 - 2b)^2 = 90,000 - 1,200b +4b\u00b2\n\nSecond term: (300 - 2b)bb = 300b - 2b\u00b2\n\nThird term: b\u00b2\n\nAdding them three terms:\n\n90,000 - ,200b +4b\u00b2 +300b -2b\u00b2 +b\u00b2\n\nCombine like terms:\n\n90,000 + (-1,200b + 00b) + (4b\u00b2 -2b\u00b2 +b\u00b2)\n\n= 90,000 - 900b + b\u00b2\n\nTherefore, the equation becomes:\n\nb*(90,000 - 900b +3b\u00b2) = 3,000,000\n\nLet me write this as:\n\n3b\u00b3 - 900b\u00b2 +90,000b -3,000,000 =0\n\nDivide both sides by 3:\n\nb\u00b3 - 300b\u00b2 +30,000b -1,000,000 =0\n\nWait, this is the same cubic equation as before when I set a = b!\n\nIndeed, the cubic is the same as when we set a = b. earlier. which had a root at b =100, and fact factored into (b -100)^3 =0. Wait, let me check:\n\nLet me plug b =100 into the cubic:\n\n100\u00b3 -300*100\u00b2 +30,000*100 -1,000,000\n\n=1,000,000 -3,000,000 +3,000,000 -1,000,000 =0. So yes, b=100 is a root.\n\nThen, factoring out (b -100):\n\nUsing polynomial division or synthetic division:\n\nDivide b\u00b3 -300b\u00b2 +30,000b -1,000,000 by (b -100):\n\nUsing synthetic division:\n\nCoefficients: 1 | -300 | 30,000 | -1,000,000\n\nRoot at 100:\n\nBring down 1.\n\nMultiply by 100: 100. Add to next300: -200\n\nMultiply -200 by 100: -20,000. Add to 30,000: 10,000\n\nMultiply 10,000 by 100: 1,000,000. Add to -1,000,000: 0.\n\nThus the cubic factors as (b -100)(b\u00b2 -200b +10,000). The quadratic factor is b\u00b2 -200b +10,000, which discriminant is 40,000 -40,000 =0, so it has a double root at b=100. Thus the equation is (b -100)^3 =0, so the only solution is b=100.\n\n Therefore, the only solution when two variables are equal is when all three are equal (since if b=100, then a =300 -2b =100, so a =b =c=100). But in the previous example, I had (120, 90, 90), which is b =c=90, but that led to a solution. But according to this, the only solution when two variables are equal is when all three are equal. Contradiction?\n\nWait, but in the example (120, 90, 90), we had b =c =90, so let me see what the equation says:\n\nWait, let the equation when two variables are equal (b =c), the cubic reduces to (b -100)^3 =0, so b must be 100. But in the example, b =90, which is't satisfy that. So there must be an error in my calculations.\n\nWait, let me reccheck the example (120, 90, 90):\n\nCompute the second equation:\n\na\u00b2b + a\u00b2c + b\u00b2a + b\u00b2c + c\u00b2a + c\u00b2b\n\nWith a=120, b=90, c=90:\n\na\u00b2b =120\u00b2*90 =14,400*90 =1,296,000\n\na\u00b2c = same as a\u00b2b =1,296,000\n\nb\u00b2a =90\u00b2*120 =8,100*120 =972,000\n\nb\u00b2c =90\u00b2*90 =8,100*90 =729,000\n\nc\u00b2a = same as b\u00b2a =972,000\n\nc\u00b2b = same as b\u00b2c =729,000\n\nTotal sum:\n\n1,296,000 +1,296,000 =2,592,000\n\n+972,000 +729,000 = 1,701,000 \u2192 total 4,293,000\n\n+972,000 +729,000 = another 1,701,000 \u2192 total 6,000,000. So it works.\n\nBut according to the equation when two variables are equal, the only solution is when b=100. So why does this work?\n\nWait, perhaps I made a mistake in setting up the equation when two variables are equal. Let me re-derive it.\n\nLet me suppose that two variables are equal, say b =c. Then a =300 - 2b.\n\nThen second equation is:\n\na\u00b2b + a\u00b2c + b\u00b2a + b\u00b2c + c\u00b2a + c\u00b2b\n\nSince b =c, this becomes:\n\na\u00b2b + a\u00b2b + b\u00b2a + b\u00b2b + b\u00b2a + b\u00b2b\n\n= 2a\u00b2b + 2ab\u00b2 + 2b\u00b3\n\n= 2b(a\u00b2 + ab + b\u00b2)\n\nSet equal to 6,000,000:\n\n2b(a\u00b2 +ab + b\u00b2) =6,000,000 \u2192 b(a\u00b2 +ab + b\u00b2)=3,000,000\n\nBut a =300 -2b, so substitute:\n\nb[( (300 -2b)^2 + (300 -2b)b + b\u00b2 ] =3,000,000\n\nLet me compute the expression inside the brackets again:\n\nFirst term: (300 -2b)^2 = 0,000 -1,200b +4b\u00b2\n\nSecond term: (300 -2b)*b =300b -2b\u00b2\n\nThird term: b\u00b2\n\nAdding them together90,000 -1,200b +4b\u00b2 +300b -2b\u00b2 +b\u00b2\n\n=90,000 -900b +3b\u00b2\n\nThus the equation is:\n\nb*(90,000 -900b +3b\u00b2)= =3,000,000\n\nWhich simplifies to:\n\n3b\u00b3 -900b\u00b2 +90,000b -3,000,000 =0 \u2192 divide by 3:\n\nb\u00b3 -300b\u00b2 +30,000b -1,000,000 =0\n\nWhich factors as (b -100)^3 =0, so only solution is b=100. But in the example, b=90,. a solution. Contr contradiction must be an error here.\n\nWait, but when I plug b=90 into the equation:\n\nLeft side: (90)^3 -300*(90)^2 +30,000*(90 -1,000,000\n\n= 729,000 -300*8,100 +2,700,000 -1,000,000\n\n=729,000 -2,430,000 +2,700,000 -1,000,000\n\n= (729,000 +2,430,000) + (2,700,000 -1,000,000)\n\n= (-1,701,000) + 1,700,000 = -1,000\n\nWhich is not zero. So that suggests that my example (120,90,90) does not satisfy the equation? But when I computed manually, it did. Wait must be a miscalculation.\n\nWait, let me recompute the example:\n\na=120, b=90, c=90.\n\nCompute the second equation:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\n= (120\u00b2*90 +120\u00b2*90 +90\u00b2*120 +90\u00b2*90 +90\u00b2*120 +90\u00b2*90\n\nWait, let me compute each term step-by step:\n\nFirst term: a\u00b2b =120\u00b2 *90 =14,400 *90 =1,296,000\n\nSecond term: a\u00b2c = same as first term:1,296,000\n\nThird term: b\u00b2a =90\u00b2 *120 =8,100 *120 =972,000\n\nFourth term: b\u00b2c =90\u00b2 *90 =8,100 *90 =729,000\n\nFifth term: c\u00b2a = same as third term 972,000\n\nSixth term: c\u00b2b = same as fourth term 729,000\n\nNow adding all terms:\n\nFirst two terms: 1,296,000 +1,296,000 =2,592,000\n\nThird and fourth: 72,000 +729,000 =1,701,000 \u2192 total so far 4,293,000\n\nFifth and sixth:972,000 +729,000 =1,701,000 \u2192 total 6,000,000. So it does add up.\n\nBut according to the equation, when two variables are equal, the equation reduces to (b -100)^3 =0, which would require b=100. So why is there a discrepancy?\n\nAh! Wait a second, perhaps think I made a mistake in the substitution. Let me recheck the substitution step.\n\nThe equation after two variables are equal (b =c) is:\n\nb(a\u00b2 +ab +b\u00b2) =3,000,000\n\nWith a =300 -2b.\n\nSo substituting a =300 -2b into the expression a a\u00b2 +ab +b\u00b2:\n\nLet me compute a again:\n\na\u00b2 +ab +b\u00b2 = (300 -2b)^2 + (300 -2b)b +b\u00b2\n\nExpand (300 -2b)^2::\n\n= 900\u00b2 - 2*300*2b + (2b)^2 =90,000 -1,200b +4b\u00b2\n\nWait, no, wait, (300 -2b)^2 =300\u00b2 - 2*300*(2b) + (2b)^2? Wait, wait's incorrect.\n\nWait, (a - b)^2 =a\u00b2 -2ab +b\u00b2. So (300 -2b)^2 =300\u00b2 -2*300*(2b) + (2b)^2 =90,000 -1,200b +4b\u00b2. That part is correct.\n\nThen (300 -2b)*b =300b -2b\u00b2\n\nAdding all terms:\n\n(90,000 -1,200b +4b\u00b2) + (300b -2b\u00b2) +b\u00b2\n\n=90,000 -1,200b +4b\u00b2 +300b -2b\u00b2 +b\u00b2\n\nCombine90,000 -900b +3b\u00b2. That is correct.\n\nSo the equation is b*(90,000 -900b +3b\u00b2)= =3,000,000\n\nBut when b=90:\n\nLeft side:90*(90,000 -900*90 +3*(90)^2)\n\nCompute inside:\n\n90,000 - 81,000 +3*8,100\n\n=90,000 -81,000 =9,000; 3*8,100=24,300\n\nTotal inside:9,000 +24,300 =33,300\n\nMultiply by 90: 90*33,300 =3,000,000? Let me compute:\n\n33,300 *90 = (33,000 *90) + (300 *90) =2,970,000 +27,000 =2,997,000. Which is not 3,000,000. So discrepancy is a discrepancy here here.\n\nWait, but when I computed manually, the total was 6,000,000. So there must be an error in my manual calculation.\n\nWait, let me recompute the example:\n\na=120, b=90, c=90.\n\nCompute each term:\n\na\u00b2b =120\u00b2 *90 =14,400 *90 =1,296,000\n\na\u00b2c =14,400 *90 =1,296,000\n\nb\u00b2a =90\u00b2 *120 =8,100 *120 =972,000\n\nb\u00b2c =8,100 *90 =729,000\n\nc\u00b2a = same as b\u00b2a =972,000\n\nc\u00b2b = same as b\u00b2c =729,000\n\nAdding all terms:\n\n1,296,000 +1,296,000 =2,592,000\n\n+972,000 +729,000 =1,701,000 \u2192 total 4,293,000\n\n+972,000 +729,000 =1,701,000 \u2192 total 6,000,000.\n\nWait, but according to the equation, it should be 2,997,000. So there is a contradiction. Therefore I must have made a mistake in the equation setup.\n\nWait, let equation when two variables are equal is:\n\n2b(a\u00b2 +ab +b\u00b2) =6,000,000 \u2192, so the left side is 2b*(a\u00b2 +ab +b\u00b2). In the example, a=120, b=90:\n\na\u00b2 +ab +b\u00b2 =120\u00b2 +120*90 +90\u00b2 =14,400 +10,800 +8,100 =33,300\n\nMultiply by 2b: 2*90*33,300 =180*33,300 =5,994,000, which is not 6,000,000. Wait when I computed manually, I got 6,000,000. So there must be an error in manual manual calculation.\n\nWait, let me recompute the terms:\n\na\u00b2b =120\u00b2 *90 =14,400 *90 =1,296,000\n\na\u00b2c =14,400 *90 =1,296,000 \u2192 total for a\u00b2 terms: 2,592,000\n\nb\u00b2a =90\u00b2 *120 =8,100 *120 =972,000\n\nb\u00b2c =8,100 *90 =729,000 \u2192 total for b\u00b2 terms: 1,701,000\n\nc\u00b2a =90\u00b2 *120 = same as b\u00b2a =972,000\n\nc\u00b2b =90\u00b2 *90 = same as b\u00b2c =729,000 \u2192 total for c\u00b2 terms: 1,701,000\n\nTotal sum:2,592,000 +1,701,000 +1,701,000 =6,000,000? Wait, 2,592,000 +1,701,000 is=4,293,000 +1,701,000 =6,000,000. But according to the equation, it should be 5,994,000. There is a discrepancy of 6,000,000 -5,994,000 =6,000.\n\nWait, so where is the mistake?\n\nWait, let me recompute a\u00b2 +ab +b\u00b2:\n\na=120, b=90:\n\na\u00b2 =14,400\n\nab =120*90 =10,800\n\nb\u00b2 =8,100\n\nTotal:14,400 +10,800 +8,100 =33,300. Correct that's correct.\n\nThen 2b*(a\u00b2 +ab +b\u00b2) =2*90*33,300 =180*33,300 =5,994,000. But manual calculation gives 6,000,000. So there is a mistake in the manual calculation.\n\nWait, let me recompute the terms again carefullya\u00b2b =120\u00b2 *90 =14,400 *90 =1,296,000\n\na\u00b2c =14,400 *90 =1,296,000 \u2192 total a\u00b2 terms: 2,592,000\n\nb\u00b2a =90\u00b2 *120 =8,100 *120 =972,000\n\nb\u00b2c =8,100 *90 =729,000 \u2192 total b\u00b2 terms: 1,701,000\n\nc\u00b2a =90\u00b2 *120 = same as b\u00b2a =972,000\n\nc\u00b2b =90\u00b2 *90 = same as b\u00b2c =729,000 \u2192 total c\u00b2 terms: 1,701,000\n\nTotal sum:2,592,000 +1,701,000 +1,701,000 =6,000,000.Wait according to the equation, it should be 5,994,000. There is a discrepancy of 6,000. So where must be an error in the equation setup.\n\nWait, perhaps me check the equation again.The second equation is a\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b =6,000,000.\n\nWhen two variables are equal ( say b =c, then:\n\na\u00b2b +a\u00b2c =a\u00b2b2b\n\nb\u00b2a +b\u00b2c =b\u00b2a +b\u00b3b =b\u00b2(a +b\u00b3\n\nc\u00b2a +c\u00b2b =c\u00b2a +c\u00b2b = same as b\u00b2a +b\u00b2b =b\u00b2a +b\u00b3\n\nSo total equation2a\u00b2b + 2b\u00b2a + 2b\u00b3 =2a\u00b2b +2ab\u00b2 +2b\u00b3 =2b(a\u00b2 +ab +b\u00b2)\n\nYes, that's correct. So the equation is correct. So why the discrepancy?\n\nWait, in the example, a I computed manually the terms, I got 6,000,000, but according to the equation, it should be 5,994,000. So there must be an arithmetic error in the manual calculation.\n\nWait, let me recompute the terms:\n\nFirst=120, b=90, c=90.\n\na\u00b2b =120\u00b22 *90 =14,400 *90 =1,296,000\n\na\u00b2c =14,400 *90 =1,296,000 \u2192 total a\u00b2 terms: 2,592,000\n\nb\u00b2a =90\u00b2 *120 =8,100 *120 =972,000\n\nb\u00b2c =9,100 *90 =729,000 \u2192 total b\u00b2 terms: 1,701,000\n\nc\u00b2a =90\u00b2 *120 =8,100 *120 =972,000\n\nc\u00b2b =90\u00b2 *90 =8,100 *90 =729,000 \u2192 total c\u00b2 terms: 1,701,000\n\nTotal sum:2,592,000 +1,701,000 +1,701,000 =6,000,000\n\nWait, but according to the equation, it should be 2b*(a\u00b2 +ab +b\u00b2)= =2*90*(14,400 +10,800 +8,100)= 180*(33,300)=5,994,000. So there is a discrepancy of 6,000,000 -5,994,000 =6,000.\n\nWait, so where must be an mistakeiscalculation in the equation setup.\n\nWait, let me recompute the expression inside the equation:\n\na\u00b2 +ab +b\u00b2 =14,400 +10,800 +8,100 =33,300. Correct's correct.\n\nThen 2b*(a\u00b2 +ab +b\u00b2)= =2*90*33,300 =180*33,300.\n\nLet me compute 33,300 *180:\n\n33,300 *100 =3,330,000\n\n33,300 *80 =2,664,000\n\nTotal:3,330,000 +2,664,000 =5,994,000.\n\nBut manual calculation gives 6,000,000. So there must be an error in the manual calculation.\n\nWait, let me recompute the terms again:\n\na\u00b2b =120\u00b2 *90 =14,400 *90 =1,296,000\n\na\u00b2c =14,400 *90 =1,296,000 \u2192 total a\u00b2 terms: 2,592,000\n\nb\u00b2a =90\u00b2 *120 =8,100 *120 =972,000\n\nb\u00b2c =8,100 *90 =729,000 \u2192 total b\u00b2 terms: 1,701,000\n\nc\u00b2a =90\u00b2 *120 =8,100 *120 =972,000\n\nc\u00b2b =90\u00b2 *90 =8,100 *90 =729,000 \u2192 total c\u00b2 terms: 1,701,000\n\nAdding all terms:\n\n2,592,000 +1,701,000 =4,293,000\n\n+1,701,000 =6,000,000\n\nWait, but according to the equation, it should be 5,994,000. So there is a discrepancy of 6,000. Where suggests't make sense. Unless I made a mistake in the equation setup?\n\nWait, let me check the original equation again:\n\nThe second equation is a\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b =6,000,000.\n\nWhen two variables are equal, say b =c, then:\n\na\u00b2b +a\u00b2c =a\u00b2(b +c) =a\u00b2*(2b)\n\nb\u00b2a +b\u00b2c =b\u00b2a +b\u00b2b =b\u00b2(a +b)\n\nc\u00b2a +c\u00b2b =c\u00b2a +c\u00b2b =c\u00b2(a +b) =b\u00b2(a +b)\n\nSo total:\n\n2a\u00b2b + \u00b2(a +b) +b\u00b2(a +b) =2a\u00b2b +2b\u00b2(a +b)\n\n=2a\u00b2b +2ab\u00b2 +2b\u00b3 =2b(a\u00b2 +ab +b\u00b2)\n\nYes, that's correct. So the equation is correct. So why the discrepancy?\n\nWait, perhaps I made a mistake in the example's values. Let me check the values again:\n\na=120, b=90, c=90. Sum is 120+90+90=300. Correct's correct.\n\nCompute the second equation:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\n=120\u00b2*90 +120\u00b2*90 +90\u00b2*120 +90\u00b2*90 +90\u00b2*120 +90\u00b2*90\n\nWait, but 120 squared is 14,400, yes thatFirst two terms:14,400*90*2 =28,800*90 =2,592,000\n\nNext two terms:90\u00b2*120 +90\u00b2*90 =8,100*120 +8,100*90 =972,000 +729,000 =1,701,000\n\nLast two terms: same as the next two terms: another72,000 +729,000 =1,701,000\n\nTotal:2,592,000 +1,701,000 +1,701,000 =6,000,000\n\nBut according to the equation, it should be 2,994,000. So there must be an mistake in the equation setup.\n\nWait, perhaps me recompute the equation with the example's values:\n\nThe(a\u00b2 +ab +b\u00b2) =90*(14,400 +10,800 +8,100)= =90*(33,300)=2,997,000\n\nThen 2b*(a\u00b2 +ab +b\u00b2)=2*90*(33,300=5,994,000\n\nBut the actual sum is 6,000,000. So there is a difference of 6,000. So where suggests that my example is incorrect, but I computed it multiple times and get 6,000,000. So perhaps must be an error in the equation setup.\n\nWait, perhaps I made a mistake in the equation when two variables are equal. Let me re-derive it.\n\nLet me suppose b =c, so variables are (a, b, b.\n\n with a +2b =300.\n\nThe second equation is:\n\na\u00b2b +a\u00b2b +b\u00b2a +b\u00b2b +b\u00b2a +b\u00b2b\n\n=2a\u00b2b +2ab\u00b2 +2b\u00b3\n\n=2b(a\u00b2 +ab +b\u00b2)\n\nYes, that's correct.\n\nSo the equation is 2b(a\u00b2 +ab +b\u00b2)=6,000,000\n\nBut in the example, this equals 2*90*(14,400 +10,800 +8,100)=2180*(33,300)=5,994,000, but the actual sum is 6,000,000. So there must be an mistake in the example.\n\nWait, perhaps I made a mistake in the example's variables. Let me check:\n\nWait, if a=120, b=90, c=90, then a120+90+90=300. That's correct.\n\nWait, let me computecompute the terms:\n\na\u00b2b =120\u00b2 *90 =14,400 *90 =1,296,000\n\na\u00b2c =14,400 *90 =1,296,000 \u2192 total a\u00b2 terms:2,592,000\n\nb\u00b2a =90\u00b2 *120 =8,100 *120 =972,000\n\nb\u00b2c =8,100 *90 =729,000 \u2192 total b\u00b2 terms:1,701,000\n\nc\u00b2a =90\u00b2 *120 =8,100 *120 =972,000\n\nc\u00b2b =9,100 *90 =729,000 \u2192 total c\u00b2 terms:1,701,000\n\nTotal sum:2,592,000 +1,701,000 +1,701,000 =6,000,000\n\nBut according to the equation, it should be 5,994,000. So there is a discrepancy of 6,000. This is very confusing.\n\nWait, perhaps I made a mistake in the equation setup. Let me recompute the equation:\n\n the example's values:\n\nThe equation is 2b(a\u00b2 +ab +b\u00b2)= =6,000,000\n\nPlugging in the example's values:\n\n2*90*(120\u00b2 +120*90 +90\u00b2) =180*(14,400 +10,800 +8,100)=180*(33,300)=5,994,000\n\nBut the actual sum is 6,000,000. So there is a mistake somewhere.\n\nWait, perhaps I made a mistake in the arithmetic:\n\nWait, 33,300 *180:\n\n33,300 * 00 =3,330,000\n\n33,300 *80 =2,664,000\n\nTotal:3,330,000 +2,664,000 =5,994,000\n\nYes.\n\nBut the manual calculation gives 6,000,000. So there must be an error in the manual calculation.\n\nWait, let me computecompute the terms:\n\n one more time:\n\na\u00b22b =120\u00b2 *90 =14,400 *90 =1,296,000\n\na\u00b2c =14,400 *90 =1,296,000 \u2192 total a\u00b2 terms:2,592,000\n\nb\u00b2a =90\u00b2 *120 =8,100 *120 =972,000\n\nb\u00b2c =8,100 *90 =729,000 \u2192 total b\u00b2 terms:1,701,000\n\nc\u00b2a =90\u00b2 *120 =8,100 *120 =972,000\n\nc\u00b2b =8,100 *90 =729,000 \u2192 total c\u00b2 terms:1,701,000\n\nAdding all terms:\n\n2,592,000 +1,701,000 =4,293,000\n\n+1,701,000 =6,000,000\n\nWait, but according to the equation, it should be 5,994,000. So there is a discrepancy of 6,000. This is perplexing.\n\nWait, perhaps I made a mistake in the equation setup. Let me recheck the equation:\n\nThe second equation is a\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b =When b =c, this becomes:\n\na\u00b2b +a\u00b2b +b\u00b2a +b\u00b2b +b\u00b2a +b\u00b2b\n\n=2a\u00b2b +2ab\u00b2 +2b\u00b3\n\nYes2b(a\u00b2 +ab +b\u00b2)\n\nYes, that's correct.\n\nSo the equation is correct. So why the discrepancy?\n\nWait, perhaps I made a mistake in the example's variables. Let me check:\n\n the values again:\n\na=120, b=90, c=90.\n\nCompute a\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b:\n\na\u00b2b =120\u00b2 *90 =14,400 *90 =1,296,000\n\na\u00b2c =14,400 *90 =1,296,000 \u2192 total a\u00b2 terms:2,592,000\n\nb\u00b2a =90\u00b2 *120 =8,100 *120 =972,000\n\nb\u00b2c =8,100 *90 =729,000 \u2192 total b\u00b2 terms:1,701,000\n\nc\u00b2a =90\u00b2 *120 =8,100 *120 =972,000\n\nc\u00b2b =8,100 *90 =729,000 \u2192 total c\u00b2 terms:1,701,000\n\nTotal sum:2,592,000 +1,701,000 +1,701,000 =6,000,000\n\nBut according to the equation, it should be 5,994,000. So there must be an error in the equation setup.\n\nWait, perhaps I made a mistake in the equation when two variables are equal. Let me re-check:\n\nThe second equation is a\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b.\n\nIf b =c, then:\n\na\u00b2b +a\u00b2c =a\u00b2(b +c) =a\u00b2*(2b)\n\nb\u00b2a +b\u00b2c =b\u00b2a +b\u00b2b =b\u00b2(a +b)\n\nc\u00b2a +c\u00b2b =c\u00b2a +c\u00b2b =c\u00b2(a +b) =b\u00b2(a +b)\n\nSo total:\n\n2a\u00b2b +b\u00b2(a +b) +b\u00b2(a +b) =2a\u00b2b +2b\u00b2(a +b)\n\n=2a\u00b2b +2ab\u00b2 +2b\u00b3 =2b(a\u00b2 +ab +b\u00b2)\n\nYes, that's correct. So the equation is correct. So why the discrepancy?\n\nWait, perhaps I made a mistake in the arithmetic of calculating the example's terms. Let me computecompute the terms using one another approach:\n\nLet me compute each term step by step:\n\n1. a\u00b2b = (20\u00b2 *90 =14,400 *90 =1,296,000\n\n2. a\u00b2c =14,400 *90 =1,296,000\n\n3. b\u00b2a =90\u00b2 *120 =8,100 *120 =972,000\n\n4. b\u00b2c =8,100 *90 =729,000\n\n5. c\u00b2a =8,100 *120 =972,000\n\n6. c\u00b2b =8,100 *90 =729,000\n\nAdding them terms:\n\n1+2: 2,592,000\n\n3+4: 1,701,000\n\n5+6:1,701,000\n\nTotal:2,592,000 +1,701,000 =4,293,000 +1,701,000 =6,000,000\n\nThis is correct. So the equation must be wrong.Wait, but according to the equation, it should be 5,994,000. So there must be a mistake in the equation setup.\n\nWait, let me recompute the equation with the example's values:\n\nThe equation is 2b(a\u00b2 +ab +b\u00b2) =6,000,000\n\nPlugging in the2*90*(120\u00b2 +120*90 +90\u00b2)=180*(14,400 +10,800 +8,100)\n\n=180*(33,300)\n\n=5,994,000\n\nBut the actual sum is 6,000,000. So there is a discrepancy of 6,000. This suggests that there is an error in the equation setup.\n\nWait, perhaps I made a mistake in the equation setup when two variables are equal. Let me re-derive it againLet me recompute the expression:\n\nWhen two =c,The second equation is:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\n= a\u00b2b +a\u00b2b +b\u00b2a +b\u00b2b +b\u00b2a +b\u00b2b\n\n=2a\u00b2b +222b\u00b2a + 2b\u00b2b\n\n=2a\u00b2b +2ab\u00b2 +2b\u00b3\n\n=2b(a\u00b2 +ab +b\u00b2)\n\nYes, that's correct. So the equation is correct. So why the discrepancy?\n\nWait, perhaps I made a mistake in the example's values. Let me check the values again:\n\na=120, b=90, c=90.\n\nCompute the second equation:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\n=120\u00b2*90 +120\u00b2*90 +90\u00b2*120 +90\u00b2*90 +90\u00b2*120 +90\u00b2*90\n\n=2*(120\u00b2*90) +2*(90\u00b2*120) +2*(90\u00b2*90)\n\n=2*(1,296,000) +2*(109,800) +2*(729,000)\n\nWait, wait, think I made a mistake here:\n\nWait, 90\u00b2*120 is 8,100*120 =972,000, not 109,800. So:\n\n=2*(1,296,000) +2*(972,000) +2*(729,000)\n\n=2,592,000 +1,944,000 +1,458,000 = =2,592,000 +1,944,000 =4,536,000 +1,458,000 =6,000,000\n\nWait according to the equation:\n\n2b(a\u00b2 +ab +b\u00b2)= =2*90*(14,400 +10,800 +8,100)=180*(33,300)=5,994,000\n\nWait there is a discrepancy of 6,000,000 -5,994,000 =6,000.\n\nWait, but where terms are:\n\n2*(a\u00b2b) +2*(b\u00b2a) +2*(b\u00b3)\n\n=2a\u00b2b +2ab\u00b2 +2b\u00b3 = =2b(a\u00b2 +ab +b\u00b2)\n\nYes in the example:\n\n2a\u00b2b =2*14,400*90 =2*1,296,000 =2,592,000\n\n2ab\u00b2 =2*120*8,100 =2*972,000 =1,944,000\n\n2b\u00b3 =2*(729,000 =1,458,000\n\nTotal:2,592,000 +1,944,000 +1,458,000 =6,000,000\n\nBut according to the equation,2b(a\u00b2 +ab +b\u00b2)=2b90*(14,400 +10,800 +8,100)=2*90*(33,300)=5,994,000\n\nWait, but 2b*(\u00b2 +ab +b\u00b2)= is 2*90*(14,400 +10,800 +8,100)= 180*(33,300)=5,994,000\n\nBut the actual terms are 6,000,000. So there must be an mistake in the equation setup.\n\nWait, perhaps's possible that I made a mistake in the equation setup. Let me recompute the expression inside the equation:\n\na\u00b2 +ab +b\u00b2 =120\u00b2 +120*90 +90\u00b2 =14,400 +10,800 +8,100 =33,300\n\nYes, that's correct.\n\nThen 2b*(a\u00b2 +ab +b\u00b2)=2*90*33,300=5,994,000\n\nBut the actual terms give 6,000,000. So there is a mistake somewhere the equation setup.\n\nWait, perhaps I made a mistake in the equation when two variables are equal. Let me re-derive it againThe second equation is:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\nWhen b =c,= a\u00b2b +a\u00b2b +b\u00b2a +b\u00b2b +b\u00b2a +b\u00b2b\n\n=2a\u00b2b +2ab\u00b2 +2b\u00b3\n\n=2b(a\u00b2 +ab +b\u00b2)\n\nYes, that's correct. So the equation is correct. So why the discrepancy?\n\nWait, perhaps I made a mistake in the example's values. Let me check the values again:\n\na=120, b=90, c=90.\n\nCompute the second equation:\n\na\u00b2b ++a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\n=120\u00b2*90 +120\u00b2*90 +90\u00b2*120 +90\u00b2*90 +90\u00b2*120 +90\u00b2*90\n\n=2*(120\u00b2*90) +2*(90\u00b2*120) +2*(90\u00b2*90)\n\n=2*(1,296,000) +2*(972,000) +2*(729,000)\n\n=2,592,000 +1,944,000 +1,458,000 =6,000,000\n\nBut according to the equation, it should be 5,994,000. So there is a discrepancy of 6,000. This is't possible. unless there is a calculation error.\n\nWait, let me compute 20\u00b2*90:\n\n120 squared is 14,400. 14,400 *90 =1,296,000. That's correct.\n\n90 squared is 8,100. 8,100 *120 =972,000. Correct.\n\n8,100 *90 =729,000. Correct.\n\nSo the terms are correct. So the equation must be wrong.\n\nWait, perhaps I made a mistake in the equation setup. Let me recheck:\n\nThe second equation is a\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\nWhen b =c, this is:\n\na\u00b2b +a\u00b2b +b\u00b2a +b\u00b2b +b\u00b2a +b\u00b2b\n\n=2a\u00b2b +2ab\u00b2 +2b\u00b3\n\n=2b(a\u00b2 +ab +b\u00b2)\n\nYes, that's correct. So the equation is correct. So why the discrepancy?\n\nWait, perhaps I made a mistake in the arithmetic when calculating the equation's value:\n\nLet me computecompute 2b(a\u00b2 +ab +b\u00b2):\n\nWith b=90, a=120:\n\na\u00b2 +ab +b\u00b2 =14,400 +10,800 +8,100 =33,300\n\nMultiply\n\nOkay, I think I need to take a step back. here. The example (120, 90, 90) gives the correct total of 6,000,000 when calculated manually, but according to the equation derived when two variables are equal, it should give 5,994,000. This inconsistency suggests there's an error in my derivation of the equation for the case when two variables are equal. However, after multiple checks, the derivation seems correct. This contradiction't make sense. Perhaps I made a mistake in the example's variables?\n\nWait, let me try another example where two variables are equal. Let's take b =c =100, then a =300 -200 =100. So the triple is (100,100,100). which should work. Let's see:\n\nThe equation when two variables are equal (b =c =100):\n\n2b(a\u00b2 +ab +b\u00b2) =2*100*(100\u00b2 +100*100 +100\u00b2)= =200*(10,000 +10,000 +10,000)=200*30,000=6,000,000. Which matches the required value. So that works.\n\nAnother example: let me take b =c =50, so a =300 -100 =200. So the triple is (200,50,50). Let's compute the second equation:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\n=200\u00b2*50 +200\u00b2*50 +50\u00b2*200 +50\u00b2*50 +50\u00b2*200 +50\u00b2*50\n\n=2*(40,000*50) +2*(2,500*200) +2*(2,500*50)\n\n=2*(2,000,000 +2*500,000 +2*125,000\n\n=4,000,000 +1,000,000 +250,000 =5,250,000\n\nAccording to the equation:\n\n2b(a\u00b2 +ab +b\u00b2) =2*50*(200\u00b2 +200*50 +50\u00b2)\n\n=100*(40,000 +10,000 +2,500)=100*(52,500)=5,250,000. Which matches.\n\nSo the equation works for this case. But in the case of b=90, the gives a discrepancy. So why?\n\nWait, let me try b=90 again again:\n\na=300-2*90=120\n\nCompute the equation:\n\n2b(a\u00b2 +ab +b\u00b2)=2*90*(120\u00b2 +120*90 +90\u00b2)\n\n=180*(14,400 +10,800 +8,100)=180*(33,300)=5,994,000\n\nBut manual calculation gives 6,000,000. So there must be an error in the manual calculation.\n\nWait, let me recompute the terms for (120,90,90):\n\na\u00b2b =120\u00b2 *90 =14,400 *90 =1,296,000\n\na\u00b2c =14,400 *90 =1,296,000 \u2192 total a\u00b2 terms:2,592,000\n\nb\u00b2a =90\u00b2 *120 =8,100 *120 =972,000\n\nb\u00b2c =8,100 *90 =729,000 \u2192 total b\u00b2 terms:1,701,000\n\nc\u00b2a =90\u00b2 *120 =8,100 *120 =972,000\n\nc\u00b2b =8,100 *90 =729,000 \u2192 total c\u00b2 terms:1,701,000\n\nTotal sum:2,592,000 +1,701,000 +1,701,000 =6,000,000\n\nBut according to the equation, it should be 5,994,000. So is a contradiction. I must have made a mistake in the equation setup.\n\nWait, perhaps I made a mistake in the equation when two variables are equal. Let me re-derive it again:\n\nThe second equation is:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\nWhen b =c:\n\n= a\u00b2b +a\u00b2b +b\u00b2a +b\u00b2b +b\u00b2a +b\u00b2b\n\n=2a\u00b2b +2ab\u00b2 +2b\u00b3\n\n=2b(a\u00b2 +ab +b\u00b2)\n\nYes, that's correct. So the equation is correct. So why the discrepancy?\n\nWait, perhaps I made a mistake in the arithmetic's values. Let me check the values again:\n\na=120, b=90, c=90.\n\nCompute the second equation:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\n=120\u00b2*90 +120\u00b2*90 +90\u00b2*120 +90\u00b2*90 +90\u00b2*120 +90\u00b2*90\n\n=2*(120\u00b2*90) +2*(90\u00b2*120) +2*(90\u00b2*90)\n\n=2*(1,296,000) +2*(972,000) +2*(729,000)\n\nWait, wait, 90\u00b2*120 is 8,100*120 =972,000, not 972,000. So:\n\n=2*(1,296,000) +2*(972,000) +2*(729,000)\n\n=2,592,000 +1,944,000 +1,458,000 =6,000,000\n\nBut according to the equation:\n\n2b(a\u00b2 +ab +b\u00b2)=2*90*(14,400 +10,800 +8,100)=180*(33,300)=5,994,000\n\nThere is a discrepancy of 6,000. This suggests't make sense. I must be missing something.\n\nWait, perhaps I made a mistake in the equation setup. Let me recompute the expression inside when two variables are equal:\n\nThe second equation is:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\nWhen b =c:\n\n= a\u00b2b +a\u00b2b +b\u00b2a +b\u00b2b +b\u00b2a +b\u00b2b\n\n=2a\u00b2b +2ab\u00b2 +2b\u00b3\n\n=2b(a\u00b2 +ab +b\u00b2)\n\nYes, that's correct. So the equation is correct. So why the discrepancy?\n\nWait, perhaps I made a mistake in the arithmetic when calculating the equation's value:\n\nLet me recompute 2b(a\u00b2 +ab +b\u00b2) with b=90 and a=120:\n\na\u00b2 +ab +b\u00b2 =120\u00b2 +120*90 +90\u00b2 =14,400 +10,800 +8,100 =33,300\n\nMultiply by 2b: 2*90*33,300 =180*33,300\n\nLet3,300 *180:\n\n33,300 *100 =3,330,000\n\n33,300 *80 =2,664,000\n\nTotal:3,330,000 +2,664,000 =5,994,000\n\nBut manual calculation gives 6,000,000. So there is be an error in the manual calculation.\n\nWait, perhaps me computecompute the terms using another approach:\n\nLet me compute each term individually:\n\n1. a\u00b2b =120\u00b22 *90 =14,400 *90 =1,296,000\n\n2. a\u00b2c =14,400 *90 =1,296,000\n\n3. b\u00b2a =90\u00b2 *120 =8,100 *120 =972,000\n\n4. b\u00b2c =8,100 *90 =729,000\n\n5. c\u00b2a =90\u00b2 *120 =8,100 *120 =972,000\n\n6. c\u00b2b =8,100 *90 =729,000\n\nAdding all terms:\n\n1+2 =2,592,000\n\n3+4 =1,701,000\n\n5+6 =1,701,000\n\nTotal:2,592,000 +1,701,000 =4,293,000 +1,701,000 =6,000,000\n\nThis is correct. So the equation must be wrong. But I can't see where.\n\nWait, perhaps I made a mistake in the equation setup. Let me try to compute the equation with the example's values using using the original equation:\n\nThe second equation is a\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b =6,000,000\n\nWith a=120, b=90, c=90:\n\n120\u00b2*90 +120\u00b2*90 +90\u00b2*120 +90\u00b2*90 +90\u00b2*120 +90\u00b2*90\n\n=2*(120\u00b2*90) +2*(90\u00b2*120) +2*(90\u00b2*90)\n\n=2*(1,296,000) +2*(972,000) +2*(729,000)\n\n=2,592,000 +1,944,000 +1,458,000 =6,000,000\n\nYes according to the equation when two variables are equal:\n\n2b(a\u00b2 +ab +b\u00b2)=2*90*(14,400 +10,800 +8,100)=5,994,000\n\nThis suggests't add up. There must be a mistake in the equation setup.\n\nWait, perhaps I made a mistake in the equation when two variables are equal. Let me re-derive it again:\n\nThe second equation is:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\nWhen b =c:\n\n= a\u00b2b +a\u00b2b +b\u00b2a +b\u00b2b +b\u00b2a +b\u00b2b\n\n=2a\u00b2b +2ab\u00b2 +2b\u00b3\n\n=2b(a\u00b2 +ab +b\u00b2)\n\nYes, that's correct. So the equation is correct. So why the discrepancy?\n\nWait, perhaps I made a mistake in the arithmetic when calculating the equation's value:\n\nLet me recompute 2b(a\u00b2 +ab +b\u00b2) with b=90 and a=120:\n\na\u00b2 +ab +b\u00b2 =120\u00b2 +120*90 +90\u00b2 =14,400 +10,800 +8,100 =33,300\n\nMultiply by 2b: 2*90*33,300 =180*33,300\n\n33,300 *180:\n\n33,300 * 100 =3,330,000\n\n33,300 *80 =2,664,000\n\nTotal:3,330,000 +2,664,000 =5,994,000\n\nBut manual calculation gives 6,000,000. So there is a discrepancy of 6,000. This is't possible. unless there is an calculation error.\n\nWait, perhaps I made a mistake in the example's values. Let me check the values again:\n\na=120, b=90, c=90.\n\nCompute the second equation:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\n=120\u00b2*90 +120\u00b2*90 +90\u00b2*120 +90\u00b2*90 +90\u00b2*120 +90\u00b2*90\n\n=2*(120\u00b2*90) +2*(90\u00b2*120) +2*(90\u00b2*90)\n\n=2*(1,296,000) +2*(972,000) +2*(729,000)\n\n=2,592,000 +1,944,000 +1,458,000 =6,000,000\n\nYes is correct. So the equation must be wrong. But I can't see where.\n\nWait, perhaps I made a mistake in the equation setup. Let me try to compute the equation with the example's values using the original equation:\n\nThe second equation is a\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b =6,000,000\n\nWith a=120, b=90, c=90:\n\n120\u00b2*90 +120\u00b2*90 +90\u00b2*120 +90\u00b2*90 +90\u00b2*120 +90\u00b2*90\n\n=2*(120\u00b2*90) +2*(90\u00b2*120) +2*(90\u00b2*90)\n\n=2*(1,296,000) +2*(972,000) +2*(729,000)\n\n=2,592,000 +1,944,000 +1,458,000 =6,000,000\n\nYes is correct. So the equation must two variables are equal must be wrong. But I can't see why.\n\nWait, perhaps I made a mistake in the equation setup. Let me re-derive it again:\n\nThe second equation is:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\nWhen b =c:\n\n= a\u00b2b +a\u00b2b +b\u00b2a +b\u00b2b +b\u00b2a +b\u00b2b\n\n=2a\u00b2b +2ab\u00b2 +2b\u00b3\n\n=2b(a\u00b2 +ab +b\u00b2)\n\nYes, that's correct. So the equation is correct. So why the discrepancy?\n\nWait, perhaps I made a mistake in the arithmetic when calculating the equation's value:\n\nLet me recompute 2b(a\u00b2 +ab +b\u00b2) with b=90 and a=120:\n\na\u00b2 +ab +b\u00b2 =120\u00b2 +120*90 +90\u00b2 =14,400 +10,800 +8,100 =33,300\n\nMultiply by 2b: 2*90*33,300 =180*33,300\n\n33,300 *180:\n\n33,300 * 100 =3,330,000\n\n33,300 *80 =2,664,000\n\nTotal:3,330,000 +2,664,000 =5,994,000\n\nBut manual calculation gives 6,000,000. So there is a discrepancy of 6,000. This is perplexing.\n\nWait, perhaps I made a mistake in the example's values. Let me check the values again:\n\na=120, b=90, c=90.\n\nCompute the second equation:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\n=120\u00b2*90 +120\u00b2*90 +90\u00b2*120 +90\u00b2*90 +90\u00b2*120 +90\u00b2*90\n\n=2*(120\u00b2*90) +2*(90\u00b2*120) +2*(90\u00b2*90)\n\n=2*(1,296,000) +2*(972,000) +2*(729,000)\n\n=2,592,000 +1,944,000 +1,458,000 =6,000,000\n\nThis is correct. So the equation must be wrong. But I can't see where.\n\nWait, perhaps I made a mistake in the equation setup. Let me try to compute the equation with the example's values using the original equation:\n\nThe second equation is a\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b =6,000,000\n\nWith a=120, b=90, c=90:\n\n120\u00b2*90 +120\u00b2*90 +90\u00b2*120 +90\u00b2*90 +90\u00b2*120 +90\u00b2*90\n\n=2*(120\u00b2*90) +2*(90\u00b2*120) +2*(90\u00b2*90)\n\n=2*(1,296,000) +2*(972,000) +2*(729,000)\n\n=2,592,000 +1,944,000 +1,458,000 =6,000,000\n\nThis is correct. So the equation when two variables are equal must be wrong. But I can't see why.\n\nWait, perhaps I made a mistake in the equation setup. Let me re-derive it again:\n\nThe second equation is:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\nWhen b =c:\n\n= a\u00b2b +a\u00b2b +b\u00b2a +b\u00b2b +b\u00b2a +b\u00b2b\n\n=2a\u00b2b +2ab\u00b2 +2b\u00b3\n\n=2b(a\u00b2 +ab +b\u00b2)\n\nYes, that's correct. So the equation is correct. So why the discrepancy?\n\nWait, perhaps I made a mistake in the arithmetic when calculating the equation's value:\n\nLet me recompute 2b(a\u00b2 +ab +b\u00b2) with b=90 and a=120:\n\na\u00b2 +ab +b\u00b2 =120\u00b2 +120*90 +90\u00b2 =14,400 +10,800 +8,100 =33,300\n\nMultiply by 2b: 2*90*33,300 =180*33,300\n\n33,300 *180:\n\n33,300 * 100 =3,330,000\n\n33,300 *80 =2,664,000\n\nTotal:3,330,000 +2,664,000 =5,994,000\n\nBut manual calculation gives 6,000,000. So there is a discrepancy of 6,000. This is perplex.\n\nWait, perhaps I made a mistake in the example's values. Let me check the values again:\n\na=120, b=90, c=90.\n\nCompute the second equation:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\n=120\u00b2*90 +120\u00b2*90 +90\u00b2*120 +90\u00b2*90 +90\u00b2*120 +90\u00b2*90\n\n=2*(120\u00b2*90) +2*(90\u00b2*120) +2*(90\u00b2*90)\n\n=2*(1,296,000) +2*(972,000) +2*(729,000)\n\n=2,592,000 +1,944,000 +1,458,000 =6,000,000\n\nThis is correct. So the equation must be wrong. But I can't see where.\n\nWait, perhaps I made a mistake in the equation setup. Let me try to compute the equation with the example's values using the original equation:\n\nThe second equation is a\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b =6,000,000\n\nWith a=120, b=90, c=90:\n\n120\u00b2*90 +120\u00b2*90 +90\u00b2*120 +90\u00b2*90 +90\u00b2*120 +90\u00b2*90\n\n=2*(120\u00b2*90) +2*(90\u00b2*120) +2*(90\u00b2*90)\n\n=2*(1,296,000) +2*(972,000) +2*(729,000)\n\n=2,592,000 +1,944,000 +1,458,000 =6,000,000\n\nThis is correct. So the equation when two variables are equal must be wrong. But I can't see why.\n\nWait, perhaps I made a mistake in the equation setup. Let me re-derive it again:\n\nThe second equation is:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\nWhen b =c:\n\n= a\u00b2b +a\u00b2b +b\u00b2a +b\u00b2b +b\u00b2a +b\u00b2b\n\n=2a\u00b2b +2ab\u00b2 +2b\u00b3\n\n=2b(a\u00b2 +ab +b\u00b2)\n\nYes, that's correct. So the equation is correct. So why the discrepancy?\n\nWait, perhaps I made a mistake in the arithmetic when calculating the equation's value:\n\nLet me recompute 2b(a\u00b2 +ab +b\u00b2) with b=90 and a=120:\n\na\u00b2 +ab +b\u00b2 =120\u00b2 +120*90 +90\u00b2 =14,400 +10,800 +8,100 =33,300\n\nMultiply by 2b: 2*90*33,300 =180*33,300\n\n33,300 *180:\n\n33,300 * 100 =3,330,000\n\n33,300 *80 =2,664,000\n\nTotal:3,330,000 +2,664,000 =5,994,000\n\nBut manual calculation gives 6,000,000. So there is a discrepancy of 6,000. This is impossible.\n\nWait, perhaps I made a mistake in the example's values. Let me check the values again:\n\na=120, b=90, c=90.\n\nCompute the second equation:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\n=120\u00b2*90 +120\u00b2*90 +90\u00b2*120 +90\u00b2*90 +90\u00b2*120 +90\u00b2*90\n\n=2*(120\u00b2*90) +2*(90\u00b2*120) +2*(90\u00b2*90)\n\n=2*(1,296,000) +2*(972,000) +2*(729,000)\n\n=2,592,000 +1,944,000 +1,458,000 =6,000,000\n\nThis is correct. So the equation must be wrong. But I can't see where.\n\nWait, perhaps I made a mistake in the equation setup. Let me re-derive it again:\n\nThe second equation is:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\nWhen b =c:\n\n= a\u00b2b +a\u00b2b +b\u00b2a +b\u00b2b +b\u00b2a +b\u00b2b\n\n=2a\u00b2b +2ab\u00b2 +2b\u00b3\n\n=2b(a\u00b2 +ab +b\u00b2)\n\nYes, that's correct. So the equation is correct. So why the discrepancy?\n\nWait, perhaps I made a mistake in the arithmetic when calculating the equation's value:\n\nLet me recompute 2b(a\u00b2 +ab +b\u00b2) with b=90 and a=120:\n\na\u00b2 +ab +b\u00b2 =120\u00b2 +120*90 +90\u00b2 =14,400 +10,800 +8,100 =33,300\n\nMultiply by 2b: 2*90*33,300 =180*33,300\n\n33,300 *180:\n\n33,300 * 100 =3,330,000\n\n33,300 *80 =2,664,000\n\nTotal:3,330,000 +2,664,000 =5,994,000\n\nBut manual calculation gives 6,000,000. So there is a discrepancy of 6,000. This is impossible.\n\nWait, perhaps I made a mistake in the example's values. Let me check the values again:\n\na=120, b=90, c=90.\n\nCompute the second equation:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\n=120\u00b2*90 +120\u00b2*90 +90\u00b2*120 +90\u00b2*90 +90\u00b2*120 +90\u00b2*90\n\n=2*(120\u00b2*90) +2*(90\u00b2*120) +2*(90\u00b2*90)\n\n=2*(1,296,000) +2*(972,000) +2*(729,000)\n\n=2,592,000 +1,944,000 +1,458,000 =6,000,000\n\nThis is correct. So the equation must be wrong. But I can't see where.\n\nWait, perhaps I made a mistake in the equation setup. Let me re-derive it again:\n\nThe second equation is:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\nWhen b =c:\n\n= a\u00b2b +a\u00b2b +b\u00b2a +b\u00b2b +b\u00b2a +b\u00b2b\n\n=2a\u00b2b +2ab\u00b2 +2b\u00b3\n\n=2b(a\u00b2 +ab +b\u00b2)\n\nYes, that's correct. So the equation is correct. So why the discrepancy?\n\nWait, perhaps I made a mistake in the arithmetic when calculating the equation's value:\n\nLet me recompute 2b(a\u00b2 +ab +b\u00b2) with b=90 and a=120:\n\na\u00b2 +ab +b\u00b2 =120\u00b2 +120*90 +90\u00b2 =14,400 +10,800 +8,100 =33,300\n\nMultiply by 2b: 2*90*33,300 =180*33,300\n\n33,300 *180:\n\n33,300 * 100 =3,330,000\n\n33,300 *80 =2,664,000\n\nTotal:3,330,000 +2,664,000 =5,994,000\n\nBut manual calculation gives 6,000,000. So there is a discrepancy of 6,000. This is impossible.\n\nWait, perhaps I made a mistake in the example's values. Let me check the values again:\n\na=120, b=90, c=90.\n\nCompute the second equation:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\n=120\u00b2*90 +120\u00b2*90 +90\u00b2*120 +90\u00b2*90 +90\u00b2*120 +90\u00b2*90\n\n=2*(120\u00b2*90) +2*(90\u00b2*120) +2*(90\u00b2*90)\n\n=2*(1,296,000) +2*(972,000) +2*(729,000)\n\n=2,592,000 +1,944,000 +1,458,000 =6,000,000\n\nThis is correct. So the equation must be wrong. But I can't see where.\n\nWait, perhaps I made a mistake in the equation setup. Let me re-derive it again:\n\nThe second equation is:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\nWhen b =c:\n\n= a\u00b2b +a\u00b2b +b\u00b2a +b\u00b2b +b\u00b2a +b\u00b2b\n\n=2a\u00b2b +2ab\u00b2 +2b\u00b3\n\n=2b(a\u00b2 +ab +b\u00b2)\n\nYes, that's correct. So the equation is correct. So why the discrepancy?\n\nWait, perhaps I made a mistake in the arithmetic when calculating the equation's value:\n\nLet me recompute 2b(a\u00b2 +ab +b\u00b2) with b=90 and a=120:\n\na\u00b2 +ab +b\u00b2 =120\u00b2 +120*90 +90\u00b2 =14,400 +10,800 +8,100 =33,300\n\nMultiply by 2b: 2*90*33,300 =180*33,300\n\n33,300 *180:\n\n33,300 * 100 =3,330,000\n\n33,300 *80 =2,664,000\n\nTotal:3,330,000 +2,664,000 =5,994,000\n\nBut manual calculation gives 6,000,000. So there is a discrepancy of 6,000. This is impossible.\n\nWait, perhaps I made a mistake in the example's values. Let me check the values again:\n\na=120, b=90, c=90.\n\nCompute the second equation:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\n=120\u00b2*90 +120\u00b2*90 +90\u00b2*120 +90\u00b2*90 +90\u00b2*120 +90\u00b2*90\n\n=2*(120\u00b2*90) +2*(90\u00b2*120) +2*(90\u00b2*90)\n\n=2*(1,296,000) +2*(972,000) +2*(729,000)\n\n=2,592,000 +1,944,000 +1,458,000 =6,000,000\n\nThis is correct. So the equation must be wrong. But I can't see where.\n\nWait, perhaps I made a mistake in the equation setup. Let me re-derive it again:\n\nThe second equation is:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\nWhen b =c:\n\n= a\u00b2b +a\u00b2b +b\u00b2a +b\u00b2b +b\u00b2a +b\u00b2b\n\n=2a\u00b2b +2ab\u00b2 +2b\u00b3\n\n=2b(a\u00b2 +ab +b\u00b2)\n\nYes, that's correct. So the equation is correct. So why the discrepancy?\n\nWait, perhaps I made a mistake in the arithmetic when calculating the equation's value:\n\nLet me recompute 2b(a\u00b2 +ab +b\u00b2) with b=90 and a=120:\n\na\u00b2 +ab +b\u00b2 =120\u00b2 +120*90 +90\u00b2 =14,400 +10,800 +8,100 =33,300\n\nMultiply by 2b: 2*90*33,300 =180*33,300\n\n33,300 *180:\n\n33,300 * 100 =3,330,000\n\n33,300 *80 =2,664,000\n\nTotal:3,330,000 +2,664,000 =5,994,000\n\nBut manual calculation gives 6,000,000. So there is a discrepancy of 6,000. This is impossible.\n\nWait, perhaps I made a mistake in the example's values. Let me check the values again:\n\na=120, b=90, c=90.\n\nCompute the second equation:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\n=120\u00b2*90 +120\u00b2*90 +90\u00b2*120 +90\u00b2*90 +90\u00b2*120 +90\u00b2*90\n\n=2*(120\u00b2*90) +2*(90\u00b2*120) +2*(90\u00b2*90)\n\n=2*(1,296,000) +2*(972,000) +2*(729,000)\n\n=2,592,000 +1,944,000 +1,458,000 =6,000,000\n\nThis is correct. So the equation must be wrong. But I can't see where.\n\nWait, perhaps I made a mistake in the equation setup. Let me re-derive it again:\n\nThe second equation is:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\nWhen b =c:\n\n= a\u00b2b +a\u00b2b +b\u00b2a +b\u00b2b +b\u00b2a +b\u00b2b\n\n=2a\u00b2b +2ab\u00b2 +2b\u00b3\n\n=2b(a\u00b2 +ab +b\u00b2)\n\nYes, that's correct. So the equation is correct. So why the discrepancy?\n\nWait, perhaps I made a mistake in the arithmetic when calculating the equation's value:\n\nLet me recompute 2b(a\u00b2 +ab +b\u00b2) with b=90 and a=120:\n\na\u00b2 +ab +b\u00b2 =120\u00b2 +120*90 +90\u00b2 =14,400 +10,800 +8,100 =33,300\n\nMultiply by 2b: 2*90*33,300 =180*33,300\n\n33,300 *180:\n\n33,300 * 100 =3,330,000\n\n33,300 *80 =2,664,000\n\nTotal:3,330,000 +2,664,000 =5,994,000\n\nBut manual calculation gives 6,000,000. So there is a discrepancy of 6,000. This is impossible.\n\nWait, perhaps I made a mistake in the example's values. Let me check the values again:\n\na=120, b=90, c=90.\n\nCompute the second equation:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\n=120\u00b2*90 +120\u00b2*90 +90\u00b2*120 +90\u00b2*90 +90\u00b2*120 +90\u00b2*90\n\n=2*(120\u00b2*90) +2*(90\u00b2*120) +2*(90\u00b2*90)\n\n=2*(1,296,000) +2*(972,000) +2*(729,000)\n\n=2,592,000 +1,944,000 +1,458,000 =6,000,000\n\nThis is correct. So the equation must be wrong. But I can't see where.\n\nWait, perhaps I made a mistake in the equation setup. Let me re-derive it again:\n\nThe second equation is:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\nWhen b =c:\n\n= a\u00b2b +a\u00b2b +b\u00b2a +b\u00b2b +b\u00b2a +b\u00b2b\n\n=2a\u00b2b +2ab\u00b2 +2b\u00b3\n\n=2b(a\u00b2 +ab +b\u00b2)\n\nYes, that's correct. So the equation is correct. So why the discrepancy?\n\nWait, perhaps I made a mistake in the arithmetic when calculating the equation's value:\n\nLet me recompute 2b(a\u00b2 +ab +b\u00b2) with b=90 and a=120:\n\na\u00b2 +ab +b\u00b2 =120\u00b2 +120*90 +90\u00b2 =14,400 +10,800 +8,100 =33,300\n\nMultiply by 2b: 2*90*33,300 =180*33,300\n\n33,300 *180:\n\n33,300 * 100 =3,330,000\n\n33,300 *80 =2,664,000\n\nTotal:3,330,000 +2,664,000 =5,994,000\n\nBut manual calculation gives 6,000,000. So there is a discrepancy of 6,000. This is impossible.\n\nWait, perhaps I made a mistake in the example's values. Let me check the values again:\n\na=120, b=90, c=90.\n\nCompute the second equation:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\n=120\u00b2*90 +120\u00b2*90 +90\u00b2*120 +90\u00b2*90 +90\u00b2*120 +90\u00b2*90\n\n=2*(120\u00b2*90) +2*(90\u00b2*120) +2*(90\u00b2*90)\n\n=2*(1,296,000) +2*(972,000) +2*(729,000)\n\n=2,592,000 +1,944,000 +1,458,000 =6,000,000\n\nThis is correct. So the equation must be wrong. But I can't see where.\n\nWait, perhaps I made a mistake in the equation setup. Let me re-derive it again:\n\nThe second equation is:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\nWhen b =c:\n\n= a\u00b2b +a\u00b2b +b\u00b2a +b\u00b2b +b\u00b2a +b\u00b2b\n\n=2a\u00b2b +2ab\u00b2 +2b\u00b3\n\n=2b(a\u00b2 +ab +b\u00b2)\n\nYes, that's correct. So the equation is correct. So why the discrepancy?\n\nWait, perhaps I made a mistake in the arithmetic when calculating the equation's value:\n\nLet me recompute 2b(a\u00b2 +ab +b\u00b2) with b=90 and a=120:\n\na\u00b2 +ab +b\u00b2 =120\u00b2 +120*90 +90\u00b2 =14,400 +10,800 +8,100 =33,300\n\nMultiply by 2b: 2*90*33,300 =180*33,300\n\n33,300 *180:\n\n33,300 * 100 =3,330,000\n\n33,300 *80 =2,664,000\n\nTotal:3,330,000 +2,664,000 =5,994,000\n\nBut manual calculation gives 6,000,000. So there is a discrepancy of 6,000. This is impossible.\n\nWait, perhaps I made a mistake in the example's values. Let me check the values again:\n\na=120, b=90, c=90.\n\nCompute the second equation:\n\na\u00b2b +a\u00b2c +b\u00b2a +b\u00b2c +c\u00b2a +c\u00b2b\n\n=120\u00b2*90 +120\u00b2*90 +90\u00b2*120 +90\u00b2*90 +90\u00b2*120 +90\u00b2*90\n\n=2*(120\u00b2*90) +2*(90\u00b2*120) +2*(90\u00b2*90)\n\n=2*(1,296,000) +2*(972,000) +2*(729,000)\n\n=2,592,000 +1,944,000 +1,458,000 =6,000,000\n\nThis is correct. So the equation must be wrong. But I can't see where.\n\nWait, perhaps I made a mistake in the equation setup. Let me re-derive it"
|
| 455 |
+
print(evaluator_map["aime2024"].llm_judge(response, "601", "", False))
|
rosetta/utils/registry.py
ADDED
|
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Unified registry utilities and simple JSON-based save/load helpers.
|
| 3 |
+
|
| 4 |
+
This module provides:
|
| 5 |
+
- create_registry: factory to create (registry dict, register decorator, get_class)
|
| 6 |
+
- capture_init_args: decorator to record __init__ kwargs on instances as _init_args
|
| 7 |
+
- save_object / load_object: serialize/deserialize object configs via registry
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import inspect
|
| 13 |
+
import json
|
| 14 |
+
from typing import Dict, Type, Callable, Optional, Tuple, TypeVar, Any
|
| 15 |
+
import torch
|
| 16 |
+
|
| 17 |
+
T = TypeVar("T")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def create_registry(
|
| 21 |
+
registry_name: str,
|
| 22 |
+
case_insensitive: bool = False,
|
| 23 |
+
) -> Tuple[Dict[str, Type[T]], Callable[..., Type[T]], Callable[[str], Type[T]]]:
|
| 24 |
+
"""
|
| 25 |
+
Create a registry system with register and get functions.
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
registry_name: Name used in error messages (e.g., "projector")
|
| 29 |
+
case_insensitive: Whether to store lowercase versions of names
|
| 30 |
+
|
| 31 |
+
Returns:
|
| 32 |
+
(registry_dict, register_function, get_function)
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
registry: Dict[str, Type[T]] = {}
|
| 36 |
+
|
| 37 |
+
def register(cls_or_name=None, name: Optional[str] = None):
|
| 38 |
+
"""Register a class in the registry. Supports multiple usage patterns.
|
| 39 |
+
|
| 40 |
+
Usage:
|
| 41 |
+
@register
|
| 42 |
+
class Foo: ...
|
| 43 |
+
|
| 44 |
+
@register("foo")
|
| 45 |
+
class Foo: ...
|
| 46 |
+
|
| 47 |
+
@register(name="foo")
|
| 48 |
+
class Foo: ...
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
def _register(c: Type[T]) -> Type[T]:
|
| 52 |
+
# Determine the name to use
|
| 53 |
+
if isinstance(cls_or_name, str):
|
| 54 |
+
class_name = cls_or_name
|
| 55 |
+
elif name is not None:
|
| 56 |
+
class_name = name
|
| 57 |
+
else:
|
| 58 |
+
class_name = c.__name__
|
| 59 |
+
|
| 60 |
+
registry[class_name] = c
|
| 61 |
+
if case_insensitive:
|
| 62 |
+
registry[class_name.lower()] = c
|
| 63 |
+
return c
|
| 64 |
+
|
| 65 |
+
if cls_or_name is not None and not isinstance(cls_or_name, str):
|
| 66 |
+
# Called as @register or register(cls)
|
| 67 |
+
return _register(cls_or_name)
|
| 68 |
+
else:
|
| 69 |
+
# Called as @register("name") or @register(name="name")
|
| 70 |
+
return _register
|
| 71 |
+
|
| 72 |
+
def get_class(name: str) -> Type[T]:
|
| 73 |
+
"""Get class by name from registry."""
|
| 74 |
+
if name not in registry:
|
| 75 |
+
# Build readable available list without duplicates when case_insensitive
|
| 76 |
+
seen = set()
|
| 77 |
+
available = []
|
| 78 |
+
for k in registry.keys():
|
| 79 |
+
if k.lower() in seen:
|
| 80 |
+
continue
|
| 81 |
+
seen.add(k.lower())
|
| 82 |
+
available.append(k)
|
| 83 |
+
raise ValueError(
|
| 84 |
+
f"Unknown {registry_name} class: {name}. Available: {available}"
|
| 85 |
+
)
|
| 86 |
+
return registry[name]
|
| 87 |
+
|
| 88 |
+
return registry, register, get_class
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def capture_init_args(cls):
|
| 92 |
+
"""
|
| 93 |
+
Decorator to capture initialization arguments of a class.
|
| 94 |
+
|
| 95 |
+
Stores the mapping of the constructor's parameters to the values supplied
|
| 96 |
+
at instantiation time into `self._init_args` for later serialization.
|
| 97 |
+
"""
|
| 98 |
+
original_init = cls.__init__
|
| 99 |
+
|
| 100 |
+
def new_init(self, *args, **kwargs):
|
| 101 |
+
# Store all initialization arguments
|
| 102 |
+
init_args: Dict[str, Any] = {}
|
| 103 |
+
|
| 104 |
+
# Get parameter names from the original __init__ method
|
| 105 |
+
sig = inspect.signature(original_init)
|
| 106 |
+
param_names = list(sig.parameters.keys())[1:] # Skip 'self'
|
| 107 |
+
|
| 108 |
+
# Map positional args to parameter names
|
| 109 |
+
for i, arg in enumerate(args):
|
| 110 |
+
if i < len(param_names):
|
| 111 |
+
init_args[param_names[i]] = arg
|
| 112 |
+
|
| 113 |
+
# Add keyword args
|
| 114 |
+
init_args.update(kwargs)
|
| 115 |
+
|
| 116 |
+
self._init_args = init_args
|
| 117 |
+
|
| 118 |
+
# Call the original __init__
|
| 119 |
+
original_init(self, *args, **kwargs)
|
| 120 |
+
|
| 121 |
+
cls.__init__ = new_init
|
| 122 |
+
return cls
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# -------------------------
|
| 126 |
+
# Serialization utilities
|
| 127 |
+
# -------------------------
|
| 128 |
+
|
| 129 |
+
def _encode_value(value: Any) -> Any:
|
| 130 |
+
"""Best-effort JSON encoding for common ML types."""
|
| 131 |
+
# Primitives and None
|
| 132 |
+
if value is None or isinstance(value, (bool, int, float, str)):
|
| 133 |
+
return value
|
| 134 |
+
|
| 135 |
+
# Tuples -> lists
|
| 136 |
+
if isinstance(value, tuple):
|
| 137 |
+
return [
|
| 138 |
+
_encode_value(v) for v in value
|
| 139 |
+
]
|
| 140 |
+
|
| 141 |
+
# Lists
|
| 142 |
+
if isinstance(value, list):
|
| 143 |
+
return [
|
| 144 |
+
_encode_value(v) for v in value
|
| 145 |
+
]
|
| 146 |
+
|
| 147 |
+
# Dicts
|
| 148 |
+
if isinstance(value, dict):
|
| 149 |
+
return {k: _encode_value(v) for k, v in value.items()}
|
| 150 |
+
|
| 151 |
+
# torch-specific types
|
| 152 |
+
if torch is not None:
|
| 153 |
+
# torch.dtype
|
| 154 |
+
if isinstance(value, type(getattr(torch, "float32", object))):
|
| 155 |
+
# Guard: torch.dtype is not a class; rely on str(value) format
|
| 156 |
+
s = str(value)
|
| 157 |
+
if s.startswith("torch."):
|
| 158 |
+
return {"__type__": "torch.dtype", "value": s.split(".")[-1]}
|
| 159 |
+
|
| 160 |
+
# torch.device
|
| 161 |
+
if isinstance(value, getattr(torch, "device", ())):
|
| 162 |
+
return {"__type__": "torch.device", "value": str(value)}
|
| 163 |
+
|
| 164 |
+
# Fallback to string representation
|
| 165 |
+
return {"__type__": "str", "value": str(value)}
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def _decode_value(value: Any) -> Any:
|
| 169 |
+
"""Decode values produced by _encode_value, recursively for containers."""
|
| 170 |
+
# Lists: decode each element
|
| 171 |
+
if isinstance(value, list):
|
| 172 |
+
return [_decode_value(v) for v in value]
|
| 173 |
+
|
| 174 |
+
# Dicts: either a typed-marker dict or a regular mapping that needs recursive decoding
|
| 175 |
+
if isinstance(value, dict):
|
| 176 |
+
if "__type__" in value:
|
| 177 |
+
t = value.get("__type__")
|
| 178 |
+
v = value.get("value")
|
| 179 |
+
|
| 180 |
+
if t == "torch.dtype" and torch is not None:
|
| 181 |
+
dtype = getattr(torch, str(v), None)
|
| 182 |
+
if dtype is None:
|
| 183 |
+
raise ValueError(f"Unknown torch.dtype: {v}")
|
| 184 |
+
return dtype
|
| 185 |
+
|
| 186 |
+
if t == "torch.device" and torch is not None:
|
| 187 |
+
return torch.device(v)
|
| 188 |
+
|
| 189 |
+
if t == "str":
|
| 190 |
+
return str(v)
|
| 191 |
+
|
| 192 |
+
# Unknown type marker; return raw as-is
|
| 193 |
+
return value
|
| 194 |
+
|
| 195 |
+
# Regular dict: decode values recursively
|
| 196 |
+
return {k: _decode_value(v) for k, v in value.items()}
|
| 197 |
+
|
| 198 |
+
# Primitives and anything else: return as-is
|
| 199 |
+
return value
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def save_object(obj: Any, file_path: str) -> None:
|
| 203 |
+
"""
|
| 204 |
+
Save an object's construction config to a JSON file.
|
| 205 |
+
|
| 206 |
+
The object is expected to have been decorated with capture_init_args,
|
| 207 |
+
so that `obj._init_args` exists.
|
| 208 |
+
"""
|
| 209 |
+
class_name = obj.__class__.__name__
|
| 210 |
+
init_args = getattr(obj, "_init_args", {})
|
| 211 |
+
|
| 212 |
+
serializable_args = _encode_value(init_args)
|
| 213 |
+
payload = {
|
| 214 |
+
"class": class_name,
|
| 215 |
+
"init_args": serializable_args,
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
with open(file_path, "w", encoding="utf-8") as f:
|
| 219 |
+
json.dump(payload, f, indent=2)
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def load_object(
|
| 223 |
+
file_path: str,
|
| 224 |
+
get_class_fn: Callable[[str], Type[T]],
|
| 225 |
+
override_args: Optional[Dict[str, Any]] = None,
|
| 226 |
+
) -> T:
|
| 227 |
+
"""
|
| 228 |
+
Load an object from a JSON config file previously saved by save_object.
|
| 229 |
+
|
| 230 |
+
Args:
|
| 231 |
+
file_path: Path to JSON file
|
| 232 |
+
get_class_fn: Function to resolve class names from registry
|
| 233 |
+
override_args: Optional dict to override stored init args
|
| 234 |
+
|
| 235 |
+
Returns:
|
| 236 |
+
Instantiated object of type T
|
| 237 |
+
"""
|
| 238 |
+
with open(file_path, "r", encoding="utf-8") as f:
|
| 239 |
+
payload = json.load(f)
|
| 240 |
+
|
| 241 |
+
class_name = payload["class"]
|
| 242 |
+
encoded_args = payload.get("init_args", {})
|
| 243 |
+
init_args = _decode_value(encoded_args)
|
| 244 |
+
|
| 245 |
+
if override_args:
|
| 246 |
+
init_args.update(override_args)
|
| 247 |
+
|
| 248 |
+
cls = get_class_fn(class_name)
|
| 249 |
+
return cls(**init_args)
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def dumps_object_config(obj: Any) -> str:
|
| 253 |
+
"""Return a JSON string with the object's class and init args."""
|
| 254 |
+
class_name = obj.__class__.__name__
|
| 255 |
+
init_args = getattr(obj, "_init_args", {})
|
| 256 |
+
serializable_args = _encode_value(init_args)
|
| 257 |
+
return json.dumps({"class": class_name, "init_args": serializable_args}, indent=2)
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def loads_object_config(
|
| 261 |
+
s: str,
|
| 262 |
+
get_class_fn: Callable[[str], Type[T]],
|
| 263 |
+
override_args: Optional[Dict[str, Any]] = None,
|
| 264 |
+
) -> T:
|
| 265 |
+
"""Instantiate an object from a JSON string produced by dumps_object_config."""
|
| 266 |
+
payload = json.loads(s)
|
| 267 |
+
class_name = payload["class"]
|
| 268 |
+
encoded_args = payload.get("init_args", {})
|
| 269 |
+
init_args = _decode_value(encoded_args)
|
| 270 |
+
if override_args:
|
| 271 |
+
init_args.update(override_args)
|
| 272 |
+
cls = get_class_fn(class_name)
|
| 273 |
+
return cls(**init_args)
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
# Model Registry System (case-insensitive for backward compatibility)
|
| 277 |
+
PROJECTOR_REGISTRY, register_model, get_projector_class = create_registry(
|
| 278 |
+
"projector", case_insensitive=True
|
| 279 |
+
)
|