File size: 6,465 Bytes
5d2c747 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 |
"""
The standard Model Shell. It combines the embedding model,
core model and LM head.
"""
import torch
from models import core_models, embedding_models, model_heads
class ModelShell(torch.nn.Module):
    """
    Unify the embedding model, core model and LM head into a single module.

    The forward pass runs ``embedding_model -> core_model -> model_head``.
    The shell also provides evaluation helpers (``loglikelihood``,
    ``loglikelihood_ids``) and a simple autoregressive ``generate`` loop, and
    optionally applies a weight-initialisation function to every submodule.
    """

    # Annotations are quoted so they are resolved lazily instead of at
    # class-creation time; the referenced types are unchanged.
    def __init__(
        self,
        embedding_model: "embedding_models.EmbedderInterface",
        core_model: "core_models.GenericTransformer",
        model_head: "model_heads.AutoregressiveLMHead",
        weight_init_func=None,
    ):
        """
        Args:
            embedding_model: maps token ids to embeddings; its ``pad_batch``
                (and ``tokenize_input`` for string inputs) are used by the
                log-likelihood helpers.
            core_model: called as ``core_model(x, attention_mask=...)``.
            model_head: maps the core-model output to the forward output.
            weight_init_func: optional callable passed to ``self.apply`` so it
                runs on every submodule (including this one).
        """
        super().__init__()
        self.embedding_model = embedding_model
        self.core_model = core_model
        self.model_head = model_head
        # Expose tokenizer for evaluators expecting model.tokenizer
        self.tokenizer = getattr(embedding_model, "tokenizer", None)
        # initialize model weights
        if weight_init_func is not None:
            self.apply(weight_init_func)
        # Always define ``device`` so helper methods work before any ``.to()``.
        self.device = self._infer_device()

    def _infer_device(self):
        """
        Return the device of the first parameter (or, failing that, buffer).

        Falls back to the last recorded ``self.device`` and finally to CPU
        when the module holds no tensors.
        """
        tensor = next(self.parameters(), None)
        if tensor is None:
            tensor = next(self.buffers(), None)
        if tensor is not None:
            return tensor.device
        return getattr(self, "device", torch.device("cpu"))

    def to(self, *args, **kwargs):
        """
        Move/cast like ``torch.nn.Module.to`` and refresh ``self.device``.

        The device is read back from the moved parameters. This keeps it
        correct for every call form: positional devices, ``device=``
        keywords, dtype-only calls such as ``.to(torch.float16)``, and tensor
        arguments. ``.cuda()``/``.cpu()`` bypass this override, so the
        helper methods below query the parameters directly.
        """
        module = super().to(*args, **kwargs)
        self.device = self._infer_device()
        return module

    def forward(self, token_ids, attention_mask=None, **kwargs):
        """
        The default forward pass, used for training.

        Args:
            token_ids: token-id tensor passed to the embedding model.
            attention_mask: optional mask forwarded to the core model. Boolean
                and floating-point masks (any float dtype, e.g. bf16) are
                passed through unchanged; other dtypes (e.g. integer 0/1
                masks) are converted to ``torch.bool``.
            **kwargs: accepted for caller compatibility and ignored.

        Returns:
            Whatever ``model_head`` returns for the core-model output.
        """
        if (
            attention_mask is not None
            and attention_mask.dtype != torch.bool
            and not attention_mask.is_floating_point()
        ):
            attention_mask = attention_mask.to(dtype=torch.bool)
        # Embeddings, presumably (B, S, H), with positional encoding if the
        # embedder adds one.
        x = self.embedding_model(token_ids)
        x = self.core_model(x, attention_mask=attention_mask)
        return self.model_head(x)

    @staticmethod
    def _extract_logits(outputs):
        """Return the logits from a forward output that may be a tuple/list or a tensor."""
        return outputs[0] if isinstance(outputs, (tuple, list)) else outputs

    def _sequence_loglikelihood(self, token_lists):
        """
        Sum next-token log-probabilities over each right-padded sequence.

        Every non-padding token after the first position is scored against the
        prediction made at the preceding position.

        Args:
            token_lists: list of token-id sequences, one per batch item.

        Returns:
            torch.Tensor of shape (B,) with the summed log-likelihoods.
        """
        padded, mask = self.embedding_model.pad_batch(token_lists, direction="right")
        # cross_entropy targets must be int64
        input_tensor = padded.to(device=self._infer_device(), dtype=torch.long)
        # No attention mask is passed (unchanged behaviour). With right padding,
        # pads follow every scored position; that is only harmless if the
        # core model is causal. NOTE(review): confirm.
        logits = self._extract_logits(self.forward(input_tensor))
        # position t predicts token t + 1
        flat_logits = logits[:, :-1].reshape(-1, logits.size(-1))
        targets = input_tensor[:, 1:].reshape(-1)
        nll = torch.nn.functional.cross_entropy(flat_logits, targets, reduction="none")
        # The mask is aligned with tokens; drop its first column to align with targets.
        target_mask = mask[:, 1:].reshape(-1).to(device=nll.device, dtype=nll.dtype)
        nll = (nll * target_mask).view(input_tensor.size(0), -1).sum(dim=1)
        return -nll

    @torch.no_grad()
    def loglikelihood(self, prefixes, continuations):
        """
        Compute the log-likelihood of ``prefix + " " + continuation`` strings.

        Each pair is joined with a single space and tokenized with
        ``tokenize_input(..., truncate=True)``. The summed next-token
        log-probability of the whole sequence is returned.

        NOTE(review): the sum covers every token after the first, prefix
        tokens included, not only the continuation. Rankings over a shared
        prefix are unaffected; scores across different prefixes are.

        Args:
            prefixes: list[str]
            continuations: list[str]

        Returns:
            torch.Tensor of shape (B,)
        """
        texts = [f"{prefix} {cont}" for prefix, cont in zip(prefixes, continuations)]
        token_lists = [
            self.embedding_model.tokenize_input(text, truncate=True) for text in texts
        ]
        return self._sequence_loglikelihood(token_lists)

    @torch.no_grad()
    def loglikelihood_ids(self, prefix_ids_list, continuation_ids_list):
        """
        Compute log-likelihood using pre-tokenized inputs.

        The prefix and continuation ids are concatenated. The summed
        next-token log-probability of the whole sequence is returned (prefix
        tokens included, as in ``loglikelihood``).

        Args:
            prefix_ids_list: list[list[int]] — tokenized prefixes
            continuation_ids_list: list[list[int]] — tokenized continuations

        Returns:
            torch.Tensor of shape (B,): log-likelihood for each input
        """
        token_lists = [
            list(prefix) + list(continuation)
            for prefix, continuation in zip(prefix_ids_list, continuation_ids_list)
        ]
        return self._sequence_loglikelihood(token_lists)

    @staticmethod
    def _filter_top_k_top_p(logits, top_k=0, top_p=1.0):
        """
        Set logits outside the top-k / nucleus (top-p) set to -inf.

        ``top_k <= 0`` and ``top_p >= 1.0`` disable the respective filter. The
        single most likely token is always kept.
        """
        if top_k is not None and top_k > 0:
            k = min(int(top_k), logits.size(-1))
            kth_value = torch.topk(logits, k, dim=-1).values[..., -1:]
            logits = logits.masked_fill(logits < kth_value, float("-inf"))
        if top_p is not None and top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(logits, dim=-1, descending=True)
            sorted_probs = torch.softmax(sorted_logits, dim=-1)
            # Probability mass strictly ahead of each token. Drop a token once
            # the tokens ahead of it already exceed top_p.
            mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs
            sorted_remove = mass_before > top_p
            sorted_remove[..., 0] = False
            remove = sorted_remove.scatter(-1, sorted_idx, sorted_remove)
            logits = logits.masked_fill(remove, float("-inf"))
        return logits

    @torch.no_grad()
    def generate(
        self,
        input_ids,
        max_new_tokens=20,
        temperature=1.0,
        top_k=0,
        top_p=1.0,
        do_sample=True,
        pad_token_id=None,
        eos_token_id=None,
        output_scores=False,
        return_dict_in_generate=False,
        logits_processor=None,
        use_cache=True,
        **kwargs,
    ):
        """
        Autoregressively append up to ``max_new_tokens`` tokens to ``input_ids``.

        Each step re-runs ``forward`` on the full sequence; ``use_cache`` and
        extra ``**kwargs`` (e.g. an ``attention_mask``) are accepted for
        call-site compatibility but ignored.

        Args:
            input_ids: (B, S) token-id tensor.
            max_new_tokens: maximum number of tokens to append.
            temperature: logit divisor. Applied only when sampling; it must be
                > 0 in that case.
            top_k: if > 0, sample only from the k most likely tokens.
            top_p: if < 1.0, sample only from the nucleus with mass top_p.
            do_sample: sample from the distribution; otherwise take the argmax.
            pad_token_id: token appended to sequences that already emitted EOS
                (defaults to the EOS id when None).
            eos_token_id: int or iterable of ints. Generation stops once every
                sequence has emitted one of them.
            output_scores: collect the processed (B, V) logits of each step.
            return_dict_in_generate: return ``SimpleNamespace(sequences, scores)``.
            logits_processor: iterable of callables ``proc(generated, logits)``.

        Returns:
            The (B, S + n) token tensor, or a namespace with ``sequences`` and
            ``scores`` when ``return_dict_in_generate`` is set.

        Raises:
            ValueError: if sampling with ``temperature <= 0``.
        """
        from types import SimpleNamespace

        if do_sample and temperature <= 0:
            raise ValueError(
                f"temperature must be strictly positive when sampling, got {temperature}"
            )

        generated = input_ids.to(self._infer_device()).clone()

        eos_ids = None
        if eos_token_id is not None:
            eos_ids = torch.as_tensor(
                eos_token_id, dtype=torch.long, device=generated.device
            ).reshape(-1)
            if eos_ids.numel() == 0:
                eos_ids = None
        if eos_ids is not None:
            fill_value = pad_token_id if pad_token_id is not None else int(eos_ids[0])
            finished = torch.zeros(
                generated.size(0), dtype=torch.bool, device=generated.device
            )

        scores_out = []
        for _ in range(max_new_tokens):
            logits = self._extract_logits(self.forward(generated))[:, -1, :]  # (B, V)
            # Greedy argmax is invariant to positive scaling, and dividing by
            # temperature=0 would produce NaN/inf, so scale only when sampling.
            if do_sample and temperature != 1.0:
                logits = logits / temperature
            if logits_processor is not None:
                for proc in logits_processor:
                    logits = proc(generated, logits)
            if do_sample:
                logits = self._filter_top_k_top_p(logits, top_k, top_p)
            if output_scores:
                scores_out.append(logits)

            if do_sample:
                next_tokens = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)
            else:
                next_tokens = logits.argmax(dim=-1, keepdim=True)

            if eos_ids is not None:
                # Sequences that already finished only receive padding.
                next_tokens = next_tokens.masked_fill(finished.unsqueeze(-1), fill_value)
            generated = torch.cat([generated, next_tokens], dim=1)

            if eos_ids is not None:
                finished |= torch.isin(next_tokens.squeeze(-1), eos_ids)
                if finished.all():
                    break

        if return_dict_in_generate:
            return SimpleNamespace(sequences=generated, scores=scores_out)
        return generated
|