# tinyllm_canonical/models/model_shell.py
# Uploaded to the Hugging Face Hub via huggingface_hub (commit 5d2c747, verified).
"""
The standard Model Shell. It combines the embedding model,
core model and LM head.
"""
import torch
from models import core_models, embedding_models, model_heads
class ModelShell(torch.nn.Module):
    """
    Unify the embedding model, core model and LM head
    into a single object; initializes the weights
    and prints basic model statistics.
    """

    def __init__(
        self,
        embedding_model: embedding_models.EmbedderInterface,
        core_model: core_models.GenericTransformer,
        model_head: model_heads.AutoregressiveLMHead,
        weight_init_func=None,
    ):
        super().__init__()
        self.embedding_model = embedding_model
        self.core_model = core_model
        self.model_head = model_head
        # Expose tokenizer for evaluators expecting model.tokenizer
        self.tokenizer = getattr(embedding_model, "tokenizer", None)
        # Default device, kept in sync by the `to` override below.
        # Previously `self.device` was only set inside `to()`, so the
        # loglikelihood helpers raised AttributeError if `.to()` was
        # never called on the shell.
        self.device = torch.device("cpu")
        # initialize model weights
        if weight_init_func is not None:
            self.apply(weight_init_func)

    # override to() to also record the target device as an attribute
    def to(self, *args, **kwargs):
        """
        Behaves like ``torch.nn.Module.to`` but additionally stores the
        target device in ``self.device``.

        The previous implementation assumed ``args[0]`` was always a
        device, which raised IndexError for ``model.to(device="cuda")``
        and silently stored a dtype for ``model.to(torch.float16)``.
        """
        device = kwargs.get("device")
        if device is None:
            # Only str / torch.device positionals denote a device;
            # dtype or tensor positionals must not clobber self.device.
            for arg in args:
                if isinstance(arg, (str, torch.device)):
                    device = arg
                    break
        if device is not None:
            self.device = torch.device(device)
        return super().to(*args, **kwargs)

    def forward(self, token_ids, attention_mask=None, **kwargs):
        """
        The default forward pass is used for training and
        accepts the token_ids as input.

        Args:
            token_ids: (B, S) integer tensor of token ids.
            attention_mask: optional mask; integer masks are coerced to
                bool so downstream attention kernels accept them.
        Returns:
            The model head output (loglikelihood/generate unpack it as
            ``logits, _`` when the head returns a tuple).
        """
        if attention_mask is not None and attention_mask.dtype not in (
            torch.bool,
            torch.float16,
            torch.float32,
            torch.float64,
        ):
            attention_mask = attention_mask.to(dtype=torch.bool)
        # pass the token_ids through the embedding model
        # to get B, S, H (with pos encoding if necessary)
        x = self.embedding_model(token_ids)
        # pass the embeddings through the core model
        x = self.core_model(x, attention_mask=attention_mask)
        # pass the core model output through the model head
        x = self.model_head(x)
        return x

    def _sequence_nll(self, input_tensor, mask):
        """
        Shared scoring helper for the loglikelihood methods.

        Computes next-token cross-entropy (position t predicts token
        t+1), zeroes out padding positions, and sums per sequence.

        Args:
            input_tensor: (B, S) long tensor on ``self.device``.
            mask: (B, S) tensor, non-zero for real tokens, 0 for padding.
        Returns:
            torch.tensor(B): log-likelihood of each sequence (negated
            summed NLL).
        """
        logits, _ = self.forward(input_tensor)
        # drop the last position (nothing to predict after it) and
        # flatten for cross_entropy
        flat_logits = logits[:, :-1].reshape(-1, logits.size(-1))
        targets = input_tensor[:, 1:].reshape(-1)
        nll = torch.nn.functional.cross_entropy(flat_logits, targets, reduction="none")
        # shift the mask the same way as the targets
        shifted_mask = mask[:, 1:].reshape(-1).to(nll.device)
        nll = nll * shifted_mask
        nll = nll.view(input_tensor.size(0), -1).sum(dim=1)
        return -nll

    @torch.no_grad()
    def loglikelihood(self, prefixes, continuations):
        """
        Compute the loglikelihood of continuation
        tokens given a prefix.
        Args:
            prefixes: list[str]
            continuations: list[str]
        Returns:
            ll: torch.tensor(B)
        """
        total_strings = [f"{prefix} {cont}" for prefix, cont in zip(prefixes, continuations)]
        input_tokens = [self.embedding_model.tokenize_input(string, truncate=True) for string in total_strings]
        padded_batch, mask = self.embedding_model.pad_batch(input_tokens, direction="right")
        input_tensor = padded_batch.detach().clone().to(device=self.device, dtype=torch.long)
        # NOTE: leftover per-token decode + print debugging removed; it
        # ran on every call and decoded tensor scalars one at a time.
        return self._sequence_nll(input_tensor, mask)

    @torch.no_grad()
    def loglikelihood_ids(self, prefix_ids_list, continuation_ids_list):
        """
        Compute log-likelihood using pre-tokenized inputs.
        Args:
            prefix_ids_list: list[list[int]] — tokenized prefixes
            continuation_ids_list: list[list[int]] — tokenized continuations
        Returns:
            torch.tensor(B): log-likelihoods for each input
        """
        input_ids = [prefix + continuation for prefix, continuation in zip(prefix_ids_list, continuation_ids_list)]
        padded_inputs, mask = self.embedding_model.pad_batch(input_ids, direction="right")
        input_tensor = padded_inputs.to(self.device)
        return self._sequence_nll(input_tensor, mask)

    @staticmethod
    def _filter_logits(logits, top_k, top_p):
        """
        Apply top-k and/or top-p (nucleus) filtering to (B, V) logits.

        Disallowed tokens are set to -inf so softmax assigns them zero
        probability. A no-op when ``top_k`` is 0/None and ``top_p`` >= 1.
        """
        if top_k and top_k > 0:
            k = min(top_k, logits.size(-1))
            # value of the k-th largest logit per row
            kth_value = torch.topk(logits, k)[0][..., -1, None]
            logits = logits.masked_fill(logits < kth_value, float("-inf"))
        if top_p is not None and top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_remove = cumulative > top_p
            # shift right so the first token crossing top_p is kept;
            # always keep the single most probable token
            sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
            sorted_remove[..., 0] = False
            remove = sorted_remove.scatter(-1, sorted_indices, sorted_remove)
            logits = logits.masked_fill(remove, float("-inf"))
        return logits

    @torch.no_grad()
    def generate(
        self,
        input_ids,
        max_new_tokens=20,
        temperature=1.0,
        top_k=0,
        top_p=1.0,
        do_sample=True,
        pad_token_id=None,
        eos_token_id=None,
        output_scores=False,
        return_dict_in_generate=False,
        logits_processor=None,
        use_cache=True,
        **kwargs,
    ):
        """
        Autoregressive decoding loop (HF-style signature).

        Args:
            input_ids: (B, S) long tensor of prompt token ids.
            max_new_tokens: number of tokens to append.
            temperature: logits are divided by this before processing.
            top_k / top_p: sampling filters (previously accepted but
                silently ignored; now applied when ``do_sample`` is True).
            do_sample: multinomial sampling if True, greedy otherwise.
            pad_token_id / use_cache: accepted for API compatibility
                only; not used by this implementation (no KV cache).
            eos_token_id: stop early once every sequence emitted it.
            output_scores: collect the per-step (processed) logits.
            return_dict_in_generate: return a namespace with
                ``sequences`` and ``scores`` instead of the raw tensor.
        Returns:
            (B, S + new) tensor, or a SimpleNamespace when
            ``return_dict_in_generate`` is set.
        """
        from types import SimpleNamespace
        import torch.nn.functional as F

        input_ids = input_ids.to(self.device)
        generated = input_ids.clone()
        scores_out = []
        for _ in range(max_new_tokens):
            # no KV cache: re-run the full prefix every step
            outputs = self.forward(generated)
            logits = outputs[0] if isinstance(outputs, tuple) else outputs
            logits = logits[:, -1, :]  # (B, V)
            if temperature != 1.0:
                logits = logits / temperature
            if logits_processor is not None:
                for proc in logits_processor:
                    logits = proc(generated, logits)
            if do_sample:
                logits = self._filter_logits(logits, top_k, top_p)
                probs = F.softmax(logits, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1)
            else:
                # greedy: argmax over logits == argmax over probs
                next_tokens = torch.argmax(logits, dim=-1, keepdim=True)
            if output_scores:
                scores_out.append(logits)
            generated = torch.cat([generated, next_tokens], dim=1)
            if eos_token_id is not None and (next_tokens == eos_token_id).all():
                break
        if return_dict_in_generate:
            return SimpleNamespace(
                sequences=generated,
                scores=scores_out,
            )
        else:
            return generated