"""
The standard Model Shell.

It combines the embedding model, core model and LM head.
"""

import logging
from types import SimpleNamespace

import torch
import torch.nn.functional as F

from models import core_models, embedding_models, model_heads

logger = logging.getLogger(__name__)


class ModelShell(torch.nn.Module):
    """
    Unify the embedding model, core model and LM head into a single object.

    Optionally applies a weight-initialisation function to every submodule
    and provides scoring (`loglikelihood*`) and sampling (`generate`) helpers
    on top of the plain forward pass.
    """

    def __init__(
        self,
        embedding_model: embedding_models.EmbedderInterface,
        core_model: core_models.GenericTransformer,
        model_head: model_heads.AutoregressiveLMHead,
        weight_init_func=None,
    ):
        """
        Args:
            embedding_model: maps token ids to embeddings; also used here for
                tokenisation (`tokenize_input`) and padding (`pad_batch`).
            core_model: called as `core_model(x, attention_mask=...)`.
            model_head: maps the core model output to the final output.
            weight_init_func: optional callable passed to `self.apply` to
                initialise the weights of every submodule.
        """
        super().__init__()
        self.embedding_model = embedding_model
        self.core_model = core_model
        self.model_head = model_head

        # Expose tokenizer for evaluators expecting model.tokenizer
        self.tokenizer = getattr(embedding_model, "tokenizer", None)

        # initialize model weights
        if weight_init_func is not None:
            self.apply(weight_init_func)

        # Give `device` a sensible value even if `.to()` is never called, so
        # that the scoring/generation helpers work on a freshly built model.
        first_param = next(self.parameters(), None)
        self.device = (
            first_param.device if first_param is not None else torch.device("cpu")
        )

    def to(self, *args, **kwargs):
        """
        Move/cast the module like `torch.nn.Module.to` and record the device.

        `self.device` is only updated when the call actually names a device
        (positionally, via `device=`, or through a tensor argument); a
        dtype-only call such as `model.to(torch.float16)` leaves it untouched.
        """
        device = kwargs.get("device")
        if device is None and args:
            first = args[0]
            if isinstance(first, torch.Tensor):
                # `Module.to(tensor)` adopts the tensor's device (and dtype).
                device = first.device
            elif isinstance(first, (str, torch.device)) or (
                isinstance(first, int) and not isinstance(first, bool)
            ):
                device = first

        result = super().to(*args, **kwargs)
        # Record the device only once the move has succeeded.
        if device is not None:
            self.device = device
        return result

    def forward(self, token_ids, attention_mask=None, **kwargs):
        """
        The default forward pass is used for training and accepts the
        token_ids as input.

        Args:
            token_ids: token ids fed to the embedding model.
            attention_mask: optional mask forwarded to the core model. Masks
                whose dtype is neither bool nor float16/32/64 (e.g. integer
                masks) are converted to bool first.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            Whatever the model head returns for the core model output.
        """
        if attention_mask is not None and attention_mask.dtype not in (
            torch.bool,
            torch.float16,
            torch.float32,
            torch.float64,
        ):
            attention_mask = attention_mask.to(dtype=torch.bool)

        # pass the token_ids through the embedding model
        # to get B, S, H (with pos encoding if necessary)
        x = self.embedding_model(token_ids)

        # pass the embeddings through the core model
        x = self.core_model(x, attention_mask=attention_mask)

        # pass the core model output through the model head
        x = self.model_head(x)

        return x

    @staticmethod
    def _extract_logits(outputs):
        """Return the logits from a forward output that may be `(logits, ...)`."""
        if isinstance(outputs, (tuple, list)):
            return outputs[0]
        return outputs

    @staticmethod
    def _check_same_length(prefixes, continuations):
        """Raise ValueError when the two batches would be silently truncated by zip."""
        if len(prefixes) != len(continuations):
            raise ValueError(
                "prefixes and continuations must have the same length "
                f"(got {len(prefixes)} and {len(continuations)})"
            )

    def _score_sequences(self, input_tensor, mask):
        """
        Sum next-token log-probabilities over each row of `input_tensor`.

        Position t is scored against the logits at position t-1, so the first
        token of every row is never scored. Per-token scores are multiplied by
        `mask[:, 1:]` before summation.
        NOTE(review): `mask` is presumably 1 for real tokens and 0 for padding
        - confirm against `pad_batch`.
        """
        batch_size, seq_len = input_tensor.size(0), input_tensor.size(1)

        logits = self._extract_logits(self.forward(input_tensor))
        logits = logits[:, :-1].reshape(-1, logits.size(-1))
        targets = input_tensor[:, 1:].reshape(-1)

        nll = F.cross_entropy(logits, targets, reduction="none")
        nll = nll * mask[:, 1:].reshape(-1).to(nll.device)
        # Explicit shape: `view(B, -1)` is ambiguous when there are 0 elements.
        return -nll.view(batch_size, seq_len - 1).sum(dim=1)

    @torch.no_grad()
    def loglikelihood(self, prefixes, continuations):
        """
        Compute the log-likelihood of "prefix continuation" strings.

        Each pair is joined with a single space, tokenised (with truncation)
        and right-padded. The returned score sums the log-probability of every
        token after the first, prefix tokens included, weighted by the padding
        mask; it is not restricted to the continuation tokens.

        Args:
            prefixes: list[str]
            continuations: list[str]

        Returns:
            ll: torch.tensor(B)

        Raises:
            ValueError: if the two lists differ in length.
        """
        prefixes, continuations = list(prefixes), list(continuations)
        self._check_same_length(prefixes, continuations)
        if not prefixes:
            return torch.zeros(0, device=self.device)

        total_strings = [
            f"{prefix} {cont}" for prefix, cont in zip(prefixes, continuations)
        ]
        input_tokens = [
            self.embedding_model.tokenize_input(string, truncate=True)
            for string in total_strings
        ]
        padded_batch, mask = self.embedding_model.pad_batch(
            input_tokens, direction="right"
        )
        input_tensor = padded_batch.detach().clone().to(
            device=self.device, dtype=torch.long
        )

        # Debug aid: token-by-token decoding of the first sequence. Only done
        # when DEBUG logging is enabled, since decoding is not free.
        tokenizer = getattr(self.embedding_model, "tokenizer", None)
        if tokenizer is not None and logger.isEnabledFor(logging.DEBUG):
            pieces = [tokenizer.decode([tok]) for tok in input_tensor[0].tolist()]
            logger.debug("Decoded input text: %s...", "-".join(pieces))

        return self._score_sequences(input_tensor, mask)

    @torch.no_grad()
    def loglikelihood_ids(self, prefix_ids_list, continuation_ids_list):
        """
        Compute log-likelihood using pre-tokenized inputs.

        Each prefix is concatenated with its continuation and the batch is
        right-padded. As in `loglikelihood`, the score sums the log-probability
        of every token after the first (prefix tokens included), weighted by
        the padding mask.

        Args:
            prefix_ids_list: list[list[int]] - tokenized prefixes
            continuation_ids_list: list[list[int]] - tokenized continuations

        Returns:
            torch.tensor(B): log-likelihoods for each input

        Raises:
            ValueError: if the two lists differ in length.
        """
        prefix_ids_list = list(prefix_ids_list)
        continuation_ids_list = list(continuation_ids_list)
        self._check_same_length(prefix_ids_list, continuation_ids_list)
        if not prefix_ids_list:
            return torch.zeros(0, device=self.device)

        input_ids = [
            prefix + continuation
            for prefix, continuation in zip(prefix_ids_list, continuation_ids_list)
        ]
        padded_inputs, mask = self.embedding_model.pad_batch(
            input_ids, direction="right"
        )
        # cross_entropy needs int64 targets; mirror `loglikelihood`.
        input_tensor = padded_inputs.to(device=self.device, dtype=torch.long)

        return self._score_sequences(input_tensor, mask)

    @staticmethod
    def _filter_logits(logits, top_k, top_p):
        """
        Apply top-k and nucleus (top-p) filtering to a (B, V) logits tensor.

        Filtered-out entries are set to -inf. `top_k <= 0`/`None` disables
        top-k; `top_p >= 1`/`None` disables top-p. The highest-scoring token
        of every row is always kept.
        """
        if top_k is not None and top_k > 0:
            k = min(int(top_k), logits.size(-1))
            kth_value = torch.topk(logits, k, dim=-1).values[..., -1:]
            logits = logits.masked_fill(logits < kth_value, float("-inf"))

        if top_p is not None and top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(logits, dim=-1, descending=True)
            cumulative = F.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
            drop_sorted = cumulative > top_p
            # Shift right so the token that crosses the threshold is kept.
            drop_sorted[..., 1:] = drop_sorted[..., :-1].clone()
            drop_sorted[..., 0] = False
            # Map the decision back from sorted order to vocabulary order.
            drop = torch.zeros_like(drop_sorted).scatter(-1, sorted_idx, drop_sorted)
            logits = logits.masked_fill(drop, float("-inf"))

        return logits

    @torch.no_grad()
    def generate(
        self,
        input_ids,
        max_new_tokens=20,
        temperature=1.0,
        top_k=0,
        top_p=1.0,
        do_sample=True,
        pad_token_id=None,
        eos_token_id=None,
        output_scores=False,
        return_dict_in_generate=False,
        logits_processor=None,
        use_cache=True,
        **kwargs,
    ):
        """
        Autoregressively extend `input_ids` by up to `max_new_tokens` tokens.

        The full sequence is re-run through `forward` at every step.

        Args:
            input_ids: (B, S) tensor of prompt token ids.
            max_new_tokens: maximum number of tokens to append.
            temperature: logits are divided by this value; `None` means 1.0
                and a value <= 0 selects greedy decoding.
            top_k: when sampling, keep only the k highest logits (0 = off).
            top_p: when sampling, keep the smallest set of tokens whose
                cumulative probability exceeds `top_p` (1.0 = off).
            do_sample: sample from the distribution if True, else take argmax.
            pad_token_id: token appended to rows that have already produced an
                EOS token; defaults to the (first) EOS id.
            eos_token_id: int or collection of ints that end a sequence.
                Generation stops once every row has produced one.
            output_scores: collect the per-step (processed) logits.
            return_dict_in_generate: return a namespace with `sequences` and
                `scores` instead of the bare sequences tensor.
            logits_processor: optional iterable of callables invoked as
                `proc(generated_ids, logits)` and returning new logits.
            use_cache: accepted for interface compatibility; unused.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            The (B, S + n) tensor of prompt plus generated ids, or a
            `SimpleNamespace(sequences=..., scores=[...])` when
            `return_dict_in_generate` is True.
        """
        input_ids = input_ids.to(self.device)
        generated = input_ids.clone()
        scores_out = []

        if temperature is None:
            temperature = 1.0
        # A non-positive temperature cannot be used as a divisor; treat it as
        # a request for deterministic (greedy) decoding.
        greedy = (not do_sample) or temperature <= 0

        eos_ids = None
        fill_id = None
        if eos_token_id is not None:
            eos_ids = torch.as_tensor(
                eos_token_id, dtype=torch.long, device=generated.device
            ).reshape(-1)
            if eos_ids.numel() == 0:
                eos_ids = None
            else:
                fill_id = pad_token_id if pad_token_id is not None else int(eos_ids[0])
        finished = torch.zeros(
            generated.size(0), dtype=torch.bool, device=generated.device
        )

        for _ in range(max_new_tokens):
            logits = self._extract_logits(self.forward(generated))
            logits = logits[:, -1, :]  # (B, V)

            if temperature > 0 and temperature != 1.0:
                logits = logits / temperature

            if logits_processor is not None:
                for proc in logits_processor:
                    logits = proc(generated, logits)

            if greedy:
                next_tokens = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = self._filter_logits(logits, top_k, top_p)
                probs = F.softmax(logits, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1)

            if output_scores:
                scores_out.append(logits)

            if eos_ids is not None:
                # Rows that already ended only receive padding from now on.
                next_tokens = torch.where(
                    finished.unsqueeze(1),
                    torch.full_like(next_tokens, fill_id),
                    next_tokens,
                )
                finished = finished | (next_tokens == eos_ids.unsqueeze(0)).any(dim=1)

            generated = torch.cat([generated, next_tokens], dim=1)

            if eos_ids is not None and bool(finished.all()):
                break

        if return_dict_in_generate:
            return SimpleNamespace(
                sequences=generated,
                scores=scores_out,
            )
        return generated