Neu256 committed on
Commit c196c4f · 1 Parent(s): c113930

Delete utils.py

Files changed (1)
  1. utils.py +0 -121
utils.py DELETED
@@ -1,121 +0,0 @@
-import os
-import torch
-from datetime import datetime
-
-# hyperparameters
-BATCH_SIZE = 32  # how many independent sequences will we process in parallel?
-BLOCK_SIZE = 64  # what is the maximum context length for predictions?
-MAX_ITER = 500  # number of training iterations
-EVAL_INTER = 1
-LEARNING_RATE = 3e-4
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-NUM_HEAD = 6
-NUM_EMBED = NUM_HEAD * 128
-NUM_LAYER = 6
-DROPOUT = 0.2
-
-def encode(text_seq: str, tokenizer) -> torch.Tensor:
-    """
-    Encode input text into a tensor of token ids using a pre-trained tokenizer.
-    """
-    # tokenize the input text
-    tokens = tokenizer.tokenize(text_seq)
-    # convert the tokens to their corresponding ids
-    token_indices = tokenizer.convert_tokens_to_ids(tokens)
-    token_indices = torch.tensor(token_indices, dtype=torch.long)
-    return token_indices
-
-
-def decode(enc_seq: torch.Tensor, tokenizer) -> str:
-    """
-    Decode a sequence of token ids back into a string.
-    """
-    # convert the tensor of indices to a plain list
-    enc_seq = enc_seq.tolist()
-    # decode the indices back to text
-    text = tokenizer.decode(enc_seq)
-    return text
-
-
-def get_batch(data: torch.Tensor, block_size: int, batch_size: int):
-    """
-    Create a random batch of training data.
-    GPUs allow for parallel processing, so we can feed multiple chunks
-    at once - that is why we need batches: batch_size is how many
-    independent sequences we process in parallel.
-
-    Parameters:
-        data (torch.Tensor): 1-D tensor of token ids to sample batches from
-        block_size (int): length of the context that is processed at once
-        batch_size (int): number of sequences to process in parallel
-
-    Returns:
-        x, y: a tuple with the input token sequences and their targets
-    """
-    ix = torch.randint(len(data) - block_size, (batch_size,))
-    # we stack batch_size rows of token sequences,
-    # so x and y are matrices with batch_size rows
-    # and block_size columns
-    x = torch.stack([data[i : i + block_size] for i in ix])
-    # y is x shifted one position to the right - we predict
-    # each token in y given all the previous tokens as context
-    y = torch.stack([data[i + 1 : i + block_size + 1] for i in ix])
-    x, y = x.to(DEVICE), y.to(DEVICE)
-    return x, y
-
-
-@torch.no_grad()
-def estimate_loss(
-    data: torch.Tensor,
-    model: torch.nn.Module,
-    block_size: int,
-    batch_size: int,
-    eval_iters: int = 10,
-):
-    # average the loss over eval_iters random batches
-    model.eval()
-    losses = torch.zeros(eval_iters)
-    for k in range(eval_iters):
-        X, Y = get_batch(data=data, block_size=block_size, batch_size=batch_size)
-        logits, loss = model(X, Y)
-        losses[k] = loss.item()
-    out = losses.mean()
-    model.train()
-    return out
-
-
-def load_model_from_checkpoint(
-    model_class: torch.nn.Module,
-    path_to_checkpoint: str = "checkpoints/state_dict_model.pt",
-    **kwargs: dict,
-) -> torch.nn.Module:
-    try:
-        state_dict = torch.load(path_to_checkpoint)
-        print("Successfully loaded model from the checkpoint")
-    except Exception as e:
-        print(f"Error loading the model from the checkpoint. {e}")
-        raise  # re-raise: without a state_dict there is nothing to load
-    model = model_class(**kwargs)
-    # load the state_dict into the model
-    model.load_state_dict(state_dict)
-    return model
-
-
-def save_model_to_chekpoint(
-    model: torch.nn.Module, path_to_checkpoint: str = "checkpoints", epoch: int = 0
-):
-    # check if the path exists, otherwise create it
-    if not os.path.exists(path_to_checkpoint):
-        os.makedirs(path_to_checkpoint)
-
-    # timestamp the checkpoint (dd.mm.YYYY_HH-MM-SS; no colons,
-    # so the file name is also valid on Windows)
-    now = datetime.now()
-    dt_string = now.strftime("%d.%m.%Y_%H-%M-%S")
-    checkpoint_name = "checkpoint_epoch-" + str(epoch) + "_" + dt_string + ".pt"
-    full_path = os.path.join(path_to_checkpoint, checkpoint_name)
-    try:
-        torch.save(model.state_dict(), full_path)
-        print(f"Successfully saved the model to {full_path}")
-    except Exception as e:
-        print(f"Error saving the model to checkpoint. {e}")