"""
Full definition of a GPT Language Model, all of it in this single file.

References:
1) the official GPT-2 TensorFlow implementation released by OpenAI:
https://github.com/openai/gpt-2/blob/master/src/model.py
2) huggingface/transformers PyTorch implementation:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
"""

import inspect
import math
import os
import uuid
from datetime import datetime
from typing import Callable

import pandas as pd
import torch
import torch.nn as nn
from pydantic import BaseModel, ConfigDict
from torch.nn import functional as F
from transformers import PreTrainedTokenizerFast


class LayerNorm(nn.Module):
    """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False"""

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)


class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # flash attention support (fused kernels, PyTorch >= 2.0 only)
        self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention")
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.block_size, config.block_size)).view(
                1, 1, config.block_size, config.block_size
            ),
        )

    def forward(self, x):
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)

        # causal self-attention: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
            y = torch.nn.functional.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0,
                is_causal=True,
            )
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y


class MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x


class Block(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class GPTConfig(BaseModel):
    block_size: int = 1024
    vocab_size: int = 50304  # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
    dropout: float = 0.0
    bias: bool = True  # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
    tokenizer_file: str = 'resources/tokenizer.json'

    model_config = ConfigDict(extra='ignore')
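
# Example (hedged): the sizes below are illustrative, not a preset shipped with
# this file. Unknown keys are silently dropped thanks to extra='ignore', so a
# larger training-config dict can be passed through unchanged:
#
#   config = GPTConfig(n_layer=6, n_head=6, n_embd=384, some_training_flag=True)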


class GPT(nn.Module):

    def __init__(self, config: GPTConfig):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config
        self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=config.tokenizer_file)
        self.end_token = self.tokenizer('[END]')['input_ids'][0]
        self.comma_token = self.tokenizer(',')['input_ids'][0]

        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.vocab_size, config.n_embd),
                wpe=nn.Embedding(config.block_size, config.n_embd),
                drop=nn.Dropout(config.dropout),
                h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                ln_f=LayerNorm(config.n_embd, bias=config.bias),
            )
        )
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # weight tying: share the token embedding matrix with the final classifier
        # https://paperswithcode.com/method/weight-tying
        self.transformer.wte.weight = self.lm_head.weight

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith("c_proj.weight"):
                torch.nn.init.normal_(
                    p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)
                )

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params
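
    # Worked example (hedged): with the defaults here (block_size=1024, n_embd=768),
    # the subtracted wpe table holds 1024 * 768 = 786,432 parameters.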

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        assert (
            t <= self.config.block_size
        ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device)  # shape (t)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
            )
        else:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            logits = self.lm_head(x[:, [-1], :])  # note: using list [-1] to preserve the time dim
            loss = None

        return logits, loss
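
    # Example (hedged): a training-style call returns full logits and a scalar loss;
    # an inference-style call returns logits for the last position only.
    #
    #   logits, loss = model(x, targets=y)  # logits: (b, t, vocab_size), loss: scalar
    #   logits, _ = model(x)                # logits: (b, 1, vocab_size)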

    def crop_block_size(self, block_size):
        # model surgery to decrease the block size if necessary
        # e.g. we may load a pretrained checkpoint (block size 1024)
        # but want to use a smaller block size for some smaller, simpler model
        assert block_size <= self.config.block_size
        self.config.block_size = block_size
        self.transformer.wpe.weight = nn.Parameter(
            self.transformer.wpe.weight[:block_size]
        )
        for block in self.transformer.h:
            if hasattr(block.attn, "bias"):
                block.attn.bias = block.attn.bias[:, :, :block_size, :block_size]
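
    # Example (hedged; 512 is an illustrative value):
    #
    #   model.crop_block_size(512)  # shrink the usable context to 512 tokens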

    @classmethod
    def from_pretrained(cls, model_type, override_args=None):
        assert model_type in {"gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"}
        override_args = override_args or {}
        # only dropout can be overridden
        assert all(k == "dropout" for k in override_args)
        from transformers import GPT2LMHeadModel

        print("loading weights from pretrained gpt: %s" % model_type)

        # n_layer, n_head and n_embd are determined from model_type
        config_args = {
            "gpt2": dict(n_layer=12, n_head=12, n_embd=768),  # 124M params
            "gpt2-medium": dict(n_layer=24, n_head=16, n_embd=1024),  # 350M params
            "gpt2-large": dict(n_layer=36, n_head=20, n_embd=1280),  # 774M params
            "gpt2-xl": dict(n_layer=48, n_head=25, n_embd=1600),  # 1558M params
        }[model_type]
        print("forcing vocab_size=50257, block_size=1024, bias=True")
        config_args["vocab_size"] = 50257  # always 50257 for GPT model checkpoints
        config_args["block_size"] = 1024  # always 1024 for GPT model checkpoints
        config_args["bias"] = True  # always True for GPT model checkpoints
        # we can override the dropout rate, if desired
        if "dropout" in override_args:
            print(f"overriding dropout rate to {override_args['dropout']}")
            config_args["dropout"] = override_args["dropout"]
        # create a from-scratch initialized model
        config = GPTConfig(**config_args)
        model = GPT(config)
        sd = model.state_dict()
        sd_keys = sd.keys()
        sd_keys = [k for k in sd_keys if not k.endswith(".attn.bias")]  # discard this mask / buffer

        # init a huggingface/transformers model
        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
        sd_hf = model_hf.state_dict()

        # copy while ensuring all of the parameters are aligned and match in names and shapes
        sd_keys_hf = sd_hf.keys()
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith(".attn.masked_bias")]  # just a buffer
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith(".attn.bias")]  # same, just the mask (buffer)
        transposed = [
            "attn.c_attn.weight",
            "attn.c_proj.weight",
            "mlp.c_fc.weight",
            "mlp.c_proj.weight",
        ]
        # the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear,
        # which means we have to transpose these weights when we import them
        assert len(sd_keys_hf) == len(
            sd_keys
        ), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
        for k in sd_keys_hf:
            if any(k.endswith(w) for w in transposed):
                # special treatment for the Conv1D weights we need to transpose
                assert sd_hf[k].shape[::-1] == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k].t())
            else:
                # vanilla copy over the other parameters
                assert sd_hf[k].shape == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k])

        return model
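
    # Example (hedged; also requires the tokenizer file from GPTConfig to exist):
    #
    #   model = GPT.from_pretrained("gpt2", override_args={"dropout": 0.0})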

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        # start with all of the candidate parameters
        param_dict = {pn: p for pn, p in self.named_parameters()}
        # filter out those that do not require grad
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameter that is 2D will be weight decayed, otherwise no.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {"params": decay_params, "weight_decay": weight_decay},
            {"params": nodecay_params, "weight_decay": 0.0},
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(
            f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters"
        )
        print(
            f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters"
        )
        # create the AdamW optimizer and use the fused version if it is available
        fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == "cuda"
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(
            optim_groups, lr=learning_rate, betas=betas, **extra_args
        )
        print(f"using fused AdamW: {use_fused}")

        return optimizer
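
    # Example (hedged; hyperparameter values are illustrative, not prescribed here):
    #
    #   optimizer = model.configure_optimizers(
    #       weight_decay=0.1, learning_rate=6e-4, betas=(0.9, 0.95), device_type="cuda"
    #   )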

    def estimate_mfu(self, fwdbwd_per_iter, dt):
        """estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS"""
        # first estimate the number of flops we do per iteration.
        # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
        N = self.get_num_params()
        cfg = self.config
        L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd // cfg.n_head, cfg.block_size
        flops_per_token = 6 * N + 12 * L * H * Q * T
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        # express our flops throughput as ratio of A100 bfloat16 peak flops
        flops_achieved = flops_per_iter * (1.0 / dt)  # per second
        flops_promised = 312e12  # A100 GPU bfloat16 peak flops is 312 TFLOPS
        mfu = flops_achieved / flops_promised
        return mfu
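
    # Worked example (hedged): for GPT-2-sized settings (L=12, H=12, Q=64, T=1024),
    # the attention term 12*L*H*Q*T = 113,246,208 ≈ 0.11 GFLOPs per token, small
    # next to the parameter term 6*N ≈ 0.74 GFLOPs for N ≈ 124M parameters.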

    @property
    def device(self) -> str:
        # all parameters are assumed to live on a single device
        return next(self.lm_head.parameters()).device.type

    @torch.no_grad()
    def generate(
        self,
        idx: torch.Tensor,
        max_new_tokens: int = 12,
        temperature: float = 0.0,
        topn: int = 100,
        pruning_ratio: float = 4,
        pruning_offset: float = 5,
        log_file: str | None = None,
        on_iteration: Callable | None = None,
    ) -> torch.Tensor:
        if topn <= 0:
            raise ValueError('topn should be greater than 0')

        if not 0 < max_new_tokens <= 20:
            raise ValueError('max_new_tokens should be in (0, 20]')

        run_uuid = uuid.uuid4()

        idx = idx.to(self.device)
        sequences = idx.unsqueeze(0)  # (1, t): start from the single prompt sequence

        # cumulative model probability of each live candidate sequence
        probabilities = torch.tensor([1.], device=self.device)

        # sequences that ended in a comma or the end token, with their probabilities
        finished_sequences = torch.tensor([], device=self.device)
        finished_probs = torch.tensor([], device=self.device)

        # number of candidate sequences kept alive per iteration
        sequences_per_iter = round(pruning_offset + topn / pruning_ratio)

        for i in range(max_new_tokens):
            if on_iteration is not None:
                on_iteration()

            # crop the context if it grows beyond block_size
            sequences = sequences[:, -self.config.block_size:]

            logits, _ = self(sequences)
            logits = logits.squeeze(1)  # (num_sequences, vocab_size)

            # probability of every one-token continuation of every candidate
            output_probs = F.softmax(logits, dim=-1)
            new_sequence_probs = output_probs * probabilities.unsqueeze(1)

            if i > 0:
                # a candidate finishes by emitting either a comma or the end token
                comma_token_probs = new_sequence_probs[:, self.comma_token]
                end_token_probs = new_sequence_probs[:, self.end_token]
                _finish_probs = end_token_probs + comma_token_probs

                # pad previously finished sequences with the end token so their width
                # matches the current candidates, which are one token longer
                if finished_sequences.numel() > 0:
                    finished_sequences = F.pad(
                        finished_sequences, (0, 1), value=self.end_token
                    )
                else:
                    finished_sequences = sequences.new_empty((0, sequences.size(1)))
                finished_sequences = torch.cat((finished_sequences, sequences))
                finished_probs = torch.cat((finished_probs, _finish_probs), dim=-1)

            if len(finished_sequences) > topn:
                # prune candidates that can no longer beat the top-n finished
                # sequences, since probabilities only shrink as tokens are appended
                lowest_viable_probability = torch.topk(finished_probs, topn).values[-1]
                viable_sequences = probabilities > lowest_viable_probability

                if viable_sequences.sum() == 0:
                    break

                sequences = sequences[viable_sequences]
                probabilities = probabilities[viable_sequences]
                logits = logits[viable_sequences]
                new_sequence_probs = new_sequence_probs[viable_sequences]

                # likewise zero out individual continuations below the threshold
                token_mask = new_sequence_probs < lowest_viable_probability
                new_sequence_probs[token_mask] = 0
                logits[token_mask] = 0

            # comma/end continuations were accounted for above and are never extended
            new_sequence_probs[:, self.end_token] = 0
            new_sequence_probs[:, self.comma_token] = 0

            num_nonzero_probs = torch.count_nonzero(new_sequence_probs).item()
            num_seqs_next_iter = min(sequences_per_iter, num_nonzero_probs)

            if num_seqs_next_iter == 0:
                break

            if temperature == 0:
                # greedy beam step: keep the highest-probability continuations
                new_sequence_probs = new_sequence_probs.flatten()
                _, idx_next = torch.topk(new_sequence_probs, num_seqs_next_iter)
            else:
                # sample continuations; the offset keeps the division well-behaved
                # for small temperatures
                scaled_logits = logits / (temperature + 1e-1)
                probs_with_temp = F.softmax(scaled_logits, dim=-1)
                probs_with_temp = probs_with_temp * probabilities.unsqueeze(1)

                probs_with_temp[:, self.end_token] = 0
                probs_with_temp[:, self.comma_token] = 0

                probs_with_temp = probs_with_temp.flatten()
                probs_with_temp[probs_with_temp < 0] = 0
                idx_next = torch.multinomial(probs_with_temp, num_seqs_next_iter)

            # map flat indices back to (candidate row, appended token)
            sequence_idx = idx_next // self.config.vocab_size
            token_values = idx_next % self.config.vocab_size

            sequences = sequences[sequence_idx]
            sequences = torch.cat([sequences, token_values.unsqueeze(1)], dim=-1)
            probabilities = new_sequence_probs.flatten()[idx_next]

            if log_file is not None:
                _, current_best_idx = torch.topk(finished_probs, min(topn, len(finished_probs)))
                current_best = finished_sequences[current_best_idx]
                self.log_generation_data(
                    log_file=log_file,
                    run_id=run_uuid,
                    topn=topn,
                    x=idx,
                    iteration=i,
                    probabilities=probabilities,
                    current_preds=current_best,
                    finished_probs=finished_probs,
                )

        _, final_indices = torch.topk(finished_probs, min(topn, len(finished_probs)))
        final_sequences = finished_sequences[final_indices]

        return final_sequences
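
    # Usage sketch (hedged; the prompt text and argument values are illustrative):
    #
    #   prompt = torch.tensor(model.tokenizer("some prompt")["input_ids"])
    #   best = model.generate(prompt, max_new_tokens=10, topn=5)
    #   print(model.tokenizer.batch_decode(best))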

    def log_generation_data(
        self,
        log_file: str,
        run_id: uuid.UUID,
        iteration: int,
        topn: int,
        x: torch.Tensor,
        probabilities: torch.Tensor,
        current_preds: torch.Tensor,
        finished_probs: torch.Tensor,
    ):
        # probability of the top-n-th finished sequence, i.e. the pruning threshold
        if len(finished_probs) >= topn:
            topnth_finished_prob = torch.topk(finished_probs, topn).values[-1].item()
        else:
            topnth_finished_prob = 0

        largest_prob = probabilities.max().item()

        new_row = [{
            'time': datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f'),
            'run_id': str(run_id),
            'topn': topn,
            'iteration': iteration,
            'largest_prob': largest_prob,
            'topnth_finished_prob': topnth_finished_prob,
        }]
        df_new_row = pd.DataFrame(new_row)

        # append to an existing log if present, otherwise start a new one
        if os.path.exists(log_file):
            df = pd.read_csv(log_file, index_col=0)
            df = pd.concat([df, df_new_row], ignore_index=True)
        else:
            df = df_new_row

        df.to_csv(log_file)

    def save_checkpoint(
        self, path, optimizer=None, iter_num=None, best_val_loss=None, config=None
    ):
        # each optional field is stored under its own key so that
        # from_checkpoint can look it up by name
        optimizer = {} if not optimizer else {"optimizer": optimizer.state_dict()}
        iter_num = {} if not iter_num else {"iter_num": iter_num}
        best_val_loss = {} if not best_val_loss else {"best_val_loss": best_val_loss}
        config = {} if not config else {"config": config}
        checkpoint = {
            "model": self.state_dict(),
            "model_args": dict(self.config),
            **optimizer,
            **iter_num,
            **best_val_loss,
            **config,
        }
        torch.save(checkpoint, path)

    @staticmethod
    def from_checkpoint(
        path: str,
        return_train_params: bool = False,
        device: str = 'cpu',
        tokenizer_path: str | None = None,
    ):
        checkpoint = torch.load(path, map_location=device, weights_only=True)

        config = GPTConfig(**checkpoint["model_args"])
        if tokenizer_path:
            config.tokenizer_file = tokenizer_path
        model = GPT(config)
        state_dict = checkpoint["model"]

        # strip the prefix that torch.compile adds to parameter names
        unwanted_prefix = "_orig_mod."
        for k in list(state_dict.keys()):
            if k.startswith(unwanted_prefix):
                state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
        model.load_state_dict(state_dict)
        model.to(device)

        if not return_train_params:
            return model

        iter_num = checkpoint["iter_num"]
        best_val_loss = checkpoint["best_val_loss"]
        optim_state = checkpoint["optimizer"]

        assert isinstance(iter_num, int)
        assert isinstance(best_val_loss, torch.Tensor)
        assert isinstance(optim_state, dict)

        return model, iter_num, best_val_loss, optim_state
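

if __name__ == "__main__":
    # Minimal smoke test (hedged): builds a tiny randomly initialized model and
    # runs one forward pass. It assumes the tokenizer file from GPTConfig exists
    # locally and is skipped otherwise; the sizes are illustrative, not a preset.
    cfg = GPTConfig(n_layer=2, n_head=2, n_embd=64, block_size=32)
    if os.path.exists(cfg.tokenizer_file):
        model = GPT(cfg)
        x = torch.randint(0, cfg.vocab_size, (2, 16))
        logits, loss = model(x, targets=x)
        print(f"logits shape: {tuple(logits.shape)}, loss: {loss.item():.4f}")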