from typing import List, Optional, Union import torch from torch import nn from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_utils import PreTrainedModel from transformers.modeling_outputs import CausalLMOutputWithPast import json from rosetta.model.projector import Projector from rosetta.model.sampling import sample_token from transformers.utils import ModelOutput try: from transformers.generation.utils import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput except Exception: GreedySearchDecoderOnlyOutput = None SampleDecoderOnlyOutput = None from rosetta.model.wrapper import RosettaModel class OracleRosettaModel(nn.Module): """ Drop in replacement for the standard transformers LLM models, like Qwen3ForCausalLM """ def __init__(self, model_list: List[PreTrainedModel], base_model_idx = 0, projector_list: List[Projector] = [], aggregator_list: List[nn.Module] = []): super().__init__() # model list: a list of model, model 0 by default is the base model # projector list: a list of projector # standard init with additional model list parameter # kv-cache dict: key (source_model_idx, target_model_idx), value (Cache), assume only convert at prefill with one type of model # projector dict: key (source_model_idx, target_model_idx) value dict(key (source_model_layer_idx, M_target value ) self.base_model_idx = base_model_idx self.model_list = nn.ModuleList(model_list) device = model_list[base_model_idx].device dtype = model_list[base_model_idx].dtype self.projector_list = nn.ModuleList(projector_list).to(device=device, dtype=dtype) self.aggregator_list = nn.ModuleList(aggregator_list).to(device=device, dtype=dtype) self.projector_dict = {} self.aggregator_dict = {} self.kv_cache_dict = {} self._generation_hook_handlers = [] @property def device(self): return self.model_list[self.base_model_idx].device def to(self, device): """ Move the RosettaModel and all underlying models and projectors to the specified device. """ super().to(device) for model in self.model_list: model.to(device) for projector in self.projector_list: projector.to(device) for aggregator in self.aggregator_list: aggregator.to(device) return self # set projector def set_projector_config(self, source_model_idx: int, source_model_layer_idx: int, target_model_idx: int, target_model_layer_idx: int, projector_idx: int): """ Set the projector configuration Args: source_model_idx: int, the index of the source model source_model_layer_idx: int, the index of the source model layer target_model_idx: int, the index of the target model target_model_layer_idx: int, the index of the target model layer projector_idx: int, the index of the projector The projector dict structure supports multiple projectors per target layer. Structure: { target_model_idx: { source_model_idx: { target_model_layer_idx: [(source_model_layer_idx, projector_idx), ...] } } } Repeated calls for the same (target, source, target_layer) append additional pairs. """ if target_model_idx not in self.projector_dict.keys(): self.projector_dict[target_model_idx] = {} if source_model_idx not in self.projector_dict[target_model_idx].keys(): self.projector_dict[target_model_idx][source_model_idx] = {} # Accumulate list of (source_layer, projector_idx) for this target layer layer_entry = self.projector_dict[target_model_idx][source_model_idx].get(target_model_layer_idx) if layer_entry is None: self.projector_dict[target_model_idx][source_model_idx][target_model_layer_idx] = [(source_model_layer_idx, projector_idx)] else: layer_entry.append((source_model_layer_idx, projector_idx)) def load_projector(self, projector_list): self.projector_list: List[Projector] = projector_list def load_aggregator(self, aggregator_list): self.aggregator_list: List[nn.Module] = aggregator_list def get_projector(self, source_model_idx, source_model_layer_idx, target_model_idx, target_model_layer_idx): pair_list = self.projector_dict[target_model_idx][source_model_idx][target_model_layer_idx] if len(pair_list) == 0: raise ValueError("No projector configured for the given target layer") # Prefer exact source layer match for src_layer, projector_id in pair_list: if src_layer == source_model_layer_idx: return self.projector_list[projector_id] # Fallback: return the first projector return self.projector_list[pair_list[0][1]] def set_aggregator_idx(self, source_model_idx: int, target_model_idx: int, target_model_layer_idx: int, aggregator_idx: int): if target_model_idx not in self.aggregator_dict: self.aggregator_dict[target_model_idx] = {} if source_model_idx not in self.aggregator_dict[target_model_idx]: self.aggregator_dict[target_model_idx][source_model_idx] = {} self.aggregator_dict[target_model_idx][source_model_idx][target_model_layer_idx] = aggregator_idx @staticmethod def load_json(file_name): with open(file_name, "r") as f: result = json.load(f) return result @staticmethod def _convert_dict_keys_to_ints(obj): """ Recursively convert dictionary keys that look like integers back to int. This reverses json.dump's coercion of dict keys to strings. """ if isinstance(obj, dict): new_obj = {} for key, value in obj.items(): if isinstance(key, str) and key.lstrip('-').isdigit(): new_key = int(key) else: new_key = key new_obj[new_key] = RosettaModel._convert_dict_keys_to_ints(value) return new_obj if isinstance(obj, list): return [RosettaModel._convert_dict_keys_to_ints(v) for v in obj] return obj def save_projector_config(self, file_name): with open(file_name, "w") as f: json.dump(self.projector_dict, f) def load_projector_config(self, config_path): if config_path.endswith(".json"): loaded = RosettaModel.load_json(config_path) self.projector_dict = RosettaModel._convert_dict_keys_to_ints(loaded) def save_aggregator_config(self, file_name): with open(file_name, "w") as f: json.dump(self.aggregator_dict, f) def load_aggregator_config(self, config_path): if config_path.endswith(".json"): loaded = RosettaModel.load_json(config_path) self.aggregator_dict = RosettaModel._convert_dict_keys_to_ints(loaded) def set_kv_cache_dict(self, source_model_idx, target_model_idx, cache): if target_model_idx not in self.kv_cache_dict.keys(): self.kv_cache_dict[target_model_idx] = {} if cache is None: # Initialize with a DynamicCache instead of RosettaCache for now self.kv_cache_dict[target_model_idx][source_model_idx] = DynamicCache() # noqa, maybe we should use RosettaCache here else: self.kv_cache_dict[target_model_idx][source_model_idx] = cache def forward( self, kv_cache_index: Optional[List] = None, input_ids: Optional[Union[torch.LongTensor, List[torch.LongTensor]]] = None, attention_mask: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, # **kwargs: Unpack[KwargsForCausalLM], identifier = -1, subject = None, *args, **kwargs, ) -> CausalLMOutputWithPast: """ Forward pass KVCache index is a list of tensors with shape (B, sec_seq_len, 2), indicating the source and target kv cache model index If input_ids is LongTensor, default to same input ids for different models If input_ids is Tuple, default to different input ids for different models. No Rosetta: (-1, 0) """ # noqa self.kv_cache_dict = dict() # Handle different input formats: if input_ids is a list, use per-model inputs if isinstance(input_ids, list): # Use list format: different input_ids and attention_mask for each model base_input_ids = input_ids[self.base_model_idx] if input_ids is not None else None base_attention_mask = attention_mask[self.base_model_idx] if attention_mask is not None else None _, seqlen = base_input_ids.size() if base_input_ids is not None else (0, 0) else: # Use tensor format: same input_ids and attention_mask for all models (backward compatibility) base_input_ids = input_ids base_attention_mask = attention_mask _, seqlen = input_ids.size() if input_ids is not None else (0, 0) num_sections = len(kv_cache_index) if kv_cache_index is not None else 1 section_lengths = [kv_cache_index[i].shape[1] for i in range(num_sections)] if kv_cache_index is not None else [seqlen] section_starts = [0] for l in section_lengths: section_starts.append(section_starts[-1] + l) curr_base_kv_cache = past_key_values if seqlen > 1: for i in range(num_sections): start = section_starts[i] end = section_starts[i + 1] prefill_input_ids = base_input_ids[:, start:end] if base_input_ids is not None else None prefill_attention_mask = base_attention_mask[:, :end] if base_attention_mask is not None else None prefill_position_ids = position_ids[:, start:end] if position_ids is not None else None prefill_labels = labels[:, start:end] if labels is not None else None # calculate target model kvcache output = self.model_list[self.base_model_idx].forward( input_ids=prefill_input_ids, attention_mask=prefill_attention_mask, position_ids=prefill_position_ids, past_key_values=curr_base_kv_cache, labels=prefill_labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, *args, **kwargs ) if self.base_model_idx not in self.kv_cache_dict: self.kv_cache_dict[self.base_model_idx] = {} if self.base_model_idx not in self.kv_cache_dict[self.base_model_idx]: self.kv_cache_dict[self.base_model_idx][self.base_model_idx] = None self.kv_cache_dict[self.base_model_idx][self.base_model_idx] = output.past_key_values curr_base_kv_cache: DynamicCache = output.past_key_values # if i != num_sections - 1: for source_model_idx in range(1, len(self.model_list)): if self.base_model_idx not in self.kv_cache_dict: self.kv_cache_dict[self.base_model_idx] = {} if source_model_idx not in self.kv_cache_dict[self.base_model_idx]: self.kv_cache_dict[self.base_model_idx][source_model_idx] = None # Get model-specific input_ids and attention_mask if isinstance(input_ids, list): source_input_ids = input_ids[source_model_idx] source_attention_mask = attention_mask[source_model_idx] if attention_mask is not None else None source_prefill_input_ids = source_input_ids[:, start:end] if source_input_ids is not None else None source_prefill_attention_mask = source_attention_mask[:, :end] if source_attention_mask is not None else None else: # Backward compatibility: use same input for all models source_prefill_input_ids = prefill_input_ids source_prefill_attention_mask = prefill_attention_mask curr_source_kv_cache = self.model_list[source_model_idx].forward( input_ids=source_prefill_input_ids, attention_mask=source_prefill_attention_mask, position_ids=prefill_position_ids, past_key_values=self.kv_cache_dict[self.base_model_idx][source_model_idx], use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, *args, **kwargs ).past_key_values self.kv_cache_dict[self.base_model_idx][source_model_idx] = curr_source_kv_cache # calculate source model kvcache and apply projections if self.base_model_idx in self.projector_dict: source_model_idx = kv_cache_index[i][0][0][0].item() # Get the source model index from the kv_cache_index if source_model_idx != -1: for target_layer_idx, entry in self.projector_dict[self.base_model_idx][source_model_idx].items(): base_key_cache, base_value_cache = curr_base_kv_cache[target_layer_idx] new_base_key_cache = base_key_cache[:, :, start:end, :] new_base_value_cache = base_value_cache[:, :, start:end, :] new_base_kv_cache = (new_base_key_cache, new_base_value_cache) pair_list = entry projected_kv_list = [] source_kv_list = [] for source_model_layer_idx, projector_idx in pair_list: source_key_cache, source_value_cache = self.kv_cache_dict[self.base_model_idx][source_model_idx][source_model_layer_idx] new_source_key_cache = source_key_cache[:, :, start:end, :] new_source_value_cache = source_value_cache[:, :, start:end, :] new_source_kv_cache = (new_source_key_cache, new_source_value_cache) projected_key, projected_value = self.projector_list[projector_idx].forward( new_source_kv_cache, # tuple of (key, value), each of shape (B, N, H, D) new_base_kv_cache ) projected_kv_list.append((projected_key, projected_value)) # -------------- # save base and projected kv cache torch.save((projected_key, projected_value), f"oracle/projected_kv/{subject}_{identifier}_{i}.pt") torch.save(new_base_kv_cache, f"oracle/target_kv/{subject}_{identifier}_{i}.pt") # -------------- source_kv_list.append(new_source_kv_cache) # Aggregate (fallback to first projector if no aggregator is available) use_aggregator = ( len(projected_kv_list) > 1 and len(self.aggregator_list) > 0 and self.base_model_idx in self.aggregator_dict and source_model_idx in self.aggregator_dict[self.base_model_idx] and target_layer_idx in self.aggregator_dict[self.base_model_idx][source_model_idx] ) if use_aggregator: aggregator_idx = self.aggregator_dict[self.base_model_idx][source_model_idx][target_layer_idx] agg_key, agg_value = self.aggregator_list[aggregator_idx].forward( source_kv_list, new_base_kv_cache, projected_kv_list ) else: # Fallback to first projector result when no aggregator is available agg_key, agg_value = projected_kv_list[0] # Update cache with aggregated result curr_base_kv_cache.key_cache[target_layer_idx][:, :, start:end, :] = agg_key curr_base_kv_cache.value_cache[target_layer_idx][:, :, start:end, :] = agg_value output.past_key_values = curr_base_kv_cache # use base model for decode phase else: # Handle list input format for decode phase as well decode_input_ids = input_ids[self.base_model_idx] if isinstance(input_ids, list) else input_ids decode_attention_mask = attention_mask[self.base_model_idx] if isinstance(attention_mask, list) else attention_mask output = self.model_list[self.base_model_idx].forward( input_ids=decode_input_ids, attention_mask=decode_attention_mask, position_ids=position_ids, past_key_values=curr_base_kv_cache, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, cache_position=cache_position, *args, **kwargs ) return output @torch.no_grad() def generate( self, kv_cache_index, input_ids, max_new_tokens: Optional[int] = None, past_key_values: Optional[Cache] = None, attention_mask: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, position_ids: Optional[torch.LongTensor] = None, eos_token_id: Optional[Union[int, List[int]]] = None, pad_token_id: Optional[int] = None, temperature: float = 1.0, top_p: float = 1.0, top_k: int = -1, do_sample: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, output_scores: Optional[bool] = None, max_length: Optional[int] = None, use_cache: bool = True, *args, **kwargs, ): """ New generation loop without using the base model's generate. - Uses this module's forward for prefill and per-token decode. - Samples tokens via rosetta.model.sampling.sample_token. Returns a tensor of shape [batch, prompt_len + generated_len] for the base model stream. """ # Derive number of tokens to generate # If max_new_tokens not provided, infer from max_length if isinstance(input_ids, list): base_input_ids_for_len = input_ids[self.base_model_idx] else: base_input_ids_for_len = input_ids prompt_len = base_input_ids_for_len.size(1) # Default eos/pad from base model tokenizer/config if not provided base_model = self.model_list[self.base_model_idx] gen_cfg = getattr(base_model, "generation_config", None) cfg_obj = gen_cfg if gen_cfg is not None else getattr(base_model, "config", None) if eos_token_id is None and cfg_obj is not None: eos_token_id = getattr(cfg_obj, "eos_token_id", None) if pad_token_id is None and cfg_obj is not None: pad_token_id = getattr(cfg_obj, "pad_token_id", None) if pad_token_id is None and eos_token_id is not None: pad_token_id = eos_token_id if isinstance(eos_token_id, int) else eos_token_id[0] if max_new_tokens is None: if max_length is not None: if max_length <= prompt_len: max_new_tokens = 0 else: max_new_tokens = max_length - prompt_len else: raise ValueError("Provide max_new_tokens or max_length") if max_new_tokens < 0: raise ValueError("max_new_tokens must be non-negative") # Resolve base inputs if isinstance(input_ids, list): base_input_ids = input_ids[self.base_model_idx] base_attention_mask = attention_mask[self.base_model_idx] if attention_mask is not None else None else: base_input_ids = input_ids base_attention_mask = attention_mask if base_attention_mask is None: base_attention_mask = torch.ones_like(base_input_ids, dtype=torch.long, device=base_input_ids.device) batch_size = base_input_ids.size(0) # Prefill to build caches and obtain initial logits prefill_output = self.forward( kv_cache_index=kv_cache_index, input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, use_cache=use_cache, *args, **kwargs, ) current_past = prefill_output.past_key_values all_input_ids = base_input_ids current_attention_mask = base_attention_mask # EOS handling setup eos_set = None if eos_token_id is not None: eos_set = set(eos_token_id if isinstance(eos_token_id, list) else [eos_token_id]) finished = torch.zeros(batch_size, dtype=torch.bool, device=all_input_ids.device) # Start from last prefill logits last_logits = prefill_output.logits[:, -1, :] # Determine sampling mode if do_sample is None: do_sample = False effective_temperature = temperature if do_sample else 0.0 # Optional scores collection collect_scores = bool(return_dict_in_generate) and bool(output_scores) scores = [] for _ in range(max_new_tokens): if collect_scores: scores.append(last_logits) # Sample next token next_token = sample_token(last_logits, temperature=effective_temperature, top_p=top_p, top_k=top_k) if not isinstance(next_token, torch.Tensor): next_token = torch.tensor([next_token], device=all_input_ids.device, dtype=torch.long).repeat(batch_size) # Apply EOS logic if eos_set is not None: just_finished = torch.zeros_like(finished) for eid in eos_set: just_finished |= (next_token == eid) finished = finished | just_finished if pad_token_id is not None: next_token = torch.where( finished, torch.tensor(pad_token_id, device=next_token.device, dtype=next_token.dtype), next_token, ) # Append sampled token next_token_unsqueezed = next_token.unsqueeze(1) all_input_ids = torch.cat([all_input_ids, next_token_unsqueezed], dim=1) current_attention_mask = torch.cat( [ current_attention_mask, torch.ones((batch_size, 1), device=current_attention_mask.device, dtype=current_attention_mask.dtype), ], dim=1, ) # Early stop if all sequences finished if eos_set is not None and torch.all(finished): break # Decode one step using cached states; pass base-stream tensors kv_cache_index = [torch.tensor([-1, 0], dtype=torch.long).repeat(1, 1).unsqueeze(0).to(all_input_ids.device)] decode_output = self.forward( kv_cache_index=kv_cache_index, input_ids=next_token_unsqueezed, attention_mask=current_attention_mask, position_ids=None, past_key_values=current_past, use_cache=True, *args, **kwargs, ) current_past = decode_output.past_key_values last_logits = decode_output.logits[:, -1, :] # Return style compatible with HF generate if return_dict_in_generate: if GreedySearchDecoderOnlyOutput is not None and SampleDecoderOnlyOutput is not None: if do_sample: return SampleDecoderOnlyOutput( sequences=all_input_ids, scores=scores if collect_scores else None, ) else: return GreedySearchDecoderOnlyOutput( sequences=all_input_ids, scores=scores if collect_scores else None, ) # Fallback to generic ModelOutput result = {"sequences": all_input_ids} if collect_scores: result["scores"] = scores return ModelOutput(**result) return all_input_ids