|
|
""" |
|
|
The standard Model Shell. It combines the embedding model, |
|
|
core model and LM head. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
|
|
|
from models import core_models, embedding_models, model_heads |
|
|
|
|
|
|
|
|
|
|
|
class ModelShell(torch.nn.Module):
    """
    Unify the embedding model, core model and LM head
    into a single object.

    Optionally applies a weight-initialization function to all
    sub-modules at construction time.
    """

    def __init__(
        self,
        embedding_model: embedding_models.EmbedderInterface,
        core_model: core_models.GenericTransformer,
        model_head: model_heads.AutoregressiveLMHead,
        weight_init_func=None,
    ):
        """
        Args:
            embedding_model: maps token ids to embeddings; may expose a
                ``tokenizer`` attribute and ``tokenize_input``/``pad_batch``.
            core_model: the transformer trunk.
            model_head: projects hidden states to vocabulary logits.
            weight_init_func: optional callable applied to every sub-module
                via ``self.apply`` (standard torch init hook).
        """
        super().__init__()
        self.embedding_model = embedding_model
        self.core_model = core_model
        self.model_head = model_head

        # Convenience handle; not every embedder exposes a tokenizer.
        self.tokenizer = getattr(embedding_model, "tokenizer", None)

        if weight_init_func is not None:
            self.apply(weight_init_func)

    def to(self, *args, **kwargs):
        """
        Move/cast the module, remembering the target device for later
        tensor placement (``loglikelihood``/``generate`` read ``self.device``).

        Fix: the previous implementation read ``args[0]`` unconditionally,
        which raised IndexError for ``model.to(device="cuda")`` and stored a
        dtype as the "device" for ``model.to(torch.float16)``.
        """
        device = kwargs.get("device")
        if device is None and args and not isinstance(args[0], torch.dtype):
            device = args[0]
        if device is not None:
            self.device = device
        return super().to(*args, **kwargs)

    def forward(self, token_ids, attention_mask=None, **kwargs):
        """
        The default forward pass is used for training and
        accepts the token_ids as input.

        Args:
            token_ids: tensor of token indices, batch-first.
            attention_mask: optional mask; integer masks (e.g. 0/1 from a
                tokenizer) are converted to bool, since the accepted dtypes
                are bool and the floating-point types.

        Returns:
            Whatever ``model_head`` produces for the final hidden states.
        """
        if attention_mask is not None and attention_mask.dtype not in (
            torch.bool,
            torch.float16,
            torch.float32,
            torch.float64,
        ):
            attention_mask = attention_mask.to(dtype=torch.bool)

        # Embed -> transform -> project to logits.
        x = self.embedding_model(token_ids)
        x = self.core_model(x, attention_mask=attention_mask)
        x = self.model_head(x)
        return x

    def _sequence_loglikelihood(self, logits, input_tensor, mask):
        """
        Shared tail of the loglikelihood methods: summed log-probability of
        each next token per sequence, with padding masked out.

        Args:
            logits: (B, S, V) output of ``forward``.
            input_tensor: (B, S) token ids used as shifted targets.
            mask: (B, S) pad mask from ``pad_batch`` (non-pad positions count).

        Returns:
            torch.tensor(B): per-sequence log-likelihood (higher = more likely).
        """
        # Shift so position t predicts token t+1.
        flat_logits = logits[:, :-1].reshape(-1, logits.size(-1))
        targets = input_tensor[:, 1:].reshape(-1)
        nll = torch.nn.functional.cross_entropy(flat_logits, targets, reduction="none")
        # Zero out contributions from padded positions.
        shifted_mask = mask[:, 1:].reshape(-1).to(nll.device)
        nll = nll * shifted_mask
        nll = nll.view(input_tensor.size(0), -1).sum(dim=1)
        return -nll

    @torch.no_grad()
    def loglikelihood(self, prefixes, continuations):
        """
        Compute the loglikelihood of continuation
        tokens given a prefix.

        Args:
            prefixes: list[str]
            continuations: list[str]

        Returns:
            ll: torch.tensor(B)

        NOTE(review): the pad mask covers all non-pad positions, so prefix
        tokens contribute to the sum as well, not only the continuation —
        confirm this is the intended scoring.
        """
        total_strings = [f"{prefix} {cont}" for prefix, cont in zip(prefixes, continuations)]
        input_tokens = [
            self.embedding_model.tokenize_input(string, truncate=True)
            for string in total_strings
        ]
        padded_batch, mask = self.embedding_model.pad_batch(input_tokens, direction="right")
        input_tensor = padded_batch.detach().clone().to(device=self.device, dtype=torch.long)

        # (Removed leftover debug code that decoded and printed the first
        # sequence on every call.)
        logits, _ = self.forward(input_tensor)
        return self._sequence_loglikelihood(logits, input_tensor, mask)

    @torch.no_grad()
    def loglikelihood_ids(self, prefix_ids_list, continuation_ids_list):
        """
        Compute log-likelihood using pre-tokenized inputs.

        Args:
            prefix_ids_list: list[list[int]] — tokenized prefixes
            continuation_ids_list: list[list[int]] — tokenized continuations

        Returns:
            torch.tensor(B): log-likelihoods for each input
        """
        input_ids = [
            prefix + continuation
            for prefix, continuation in zip(prefix_ids_list, continuation_ids_list)
        ]
        padded_inputs, mask = self.embedding_model.pad_batch(input_ids, direction="right")
        input_tensor = padded_inputs.to(self.device)

        # (Removed dead code that decoded the first row into an unused string.)
        logits, _ = self.forward(input_tensor)
        return self._sequence_loglikelihood(logits, input_tensor, mask)

    @torch.no_grad()
    def generate(
        self,
        input_ids,
        max_new_tokens=20,
        temperature=1.0,
        top_k=0,
        top_p=1.0,
        do_sample=True,
        pad_token_id=None,
        eos_token_id=None,
        output_scores=False,
        return_dict_in_generate=False,
        logits_processor=None,
        use_cache=True,
        **kwargs,
    ):
        """
        Autoregressively generate up to ``max_new_tokens`` tokens.

        Fix: ``top_k`` and ``top_p`` were previously accepted but silently
        ignored; they are now applied as logit filters before sampling.
        Both are no-ops at their defaults (``top_k=0``, ``top_p=1.0``).

        NOTE(review): ``pad_token_id`` and ``use_cache`` are still accepted
        for API compatibility but unused — the full sequence is re-encoded
        each step with no KV cache.

        Args:
            input_ids: (B, S) prompt token ids.
            max_new_tokens: maximum number of tokens to append.
            temperature: logit divisor; applied before processors/filters.
            top_k: keep only the k most likely tokens (0 disables).
            top_p: nucleus sampling threshold (1.0 disables).
            do_sample: sample from the distribution; otherwise greedy argmax.
            eos_token_id: stop early once every sequence emits this token.
            output_scores: collect the per-step (filtered) logits.
            return_dict_in_generate: return a namespace with ``sequences``
                and ``scores`` instead of the bare tensor.
            logits_processor: iterable of callables ``(ids, logits) -> logits``.

        Returns:
            (B, S + new) tensor, or a SimpleNamespace when
            ``return_dict_in_generate`` is set.
        """
        from types import SimpleNamespace

        import torch.nn.functional as F

        input_ids = input_ids.to(self.device)
        generated = input_ids.clone()
        scores_out = []

        for _ in range(max_new_tokens):
            outputs = self.forward(generated)
            logits = outputs[0] if isinstance(outputs, tuple) else outputs
            logits = logits[:, -1, :]

            if temperature != 1.0:
                logits = logits / temperature

            if logits_processor is not None:
                for proc in logits_processor:
                    logits = proc(generated, logits)

            if top_k and top_k > 0:
                # Top-k: mask everything below the k-th largest logit.
                kth = torch.topk(logits, min(top_k, logits.size(-1)), dim=-1).values[:, -1, None]
                logits = logits.masked_fill(logits < kth, float("-inf"))

            if top_p < 1.0:
                # Nucleus (top-p): drop the probability tail beyond top_p,
                # always keeping at least the single most likely token.
                sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
                cumulative = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                drop = cumulative > top_p
                drop[:, 1:] = drop[:, :-1].clone()
                drop[:, 0] = False
                drop = drop.scatter(1, sorted_idx, drop)
                logits = logits.masked_fill(drop, float("-inf"))

            probs = F.softmax(logits, dim=-1)

            if do_sample:
                next_tokens = torch.multinomial(probs, num_samples=1)
            else:
                next_tokens = torch.argmax(probs, dim=-1, keepdim=True)

            if output_scores:
                scores_out.append(logits)

            generated = torch.cat([generated, next_tokens], dim=1)

            # Stop early once every sequence in the batch hit EOS this step.
            if eos_token_id is not None and (next_tokens == eos_token_id).all():
                break

        if return_dict_in_generate:
            return SimpleNamespace(
                sequences=generated,
                scores=scores_out,
            )
        return generated
|
|
|